fix cname issue & test
This commit is contained in:
@@ -109,9 +109,9 @@ impl GurtClient {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_pooled_connection(&self, host: &str, port: u16) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
async fn get_pooled_connection(&self, host: &str, port: u16, original_host: Option<&str>) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
if !self.config.enable_connection_pooling {
|
||||
return self.perform_handshake(host, port).await;
|
||||
return self.perform_handshake(host, port, original_host).await;
|
||||
}
|
||||
|
||||
let key = ConnectionKey {
|
||||
@@ -131,7 +131,7 @@ impl GurtClient {
|
||||
}
|
||||
|
||||
debug!("Creating new connection for {}:{}", host, port);
|
||||
self.perform_handshake(host, port).await
|
||||
self.perform_handshake(host, port, original_host).await
|
||||
}
|
||||
|
||||
fn return_connection_to_pool(&self, host: &str, port: u16, connection: tokio_rustls::client::TlsStream<TcpStream>) {
|
||||
@@ -231,13 +231,16 @@ impl GurtClient {
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_handshake(&self, host: &str, port: u16) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
async fn perform_handshake(&self, host: &str, port: u16, original_host: Option<&str>) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
debug!("Starting GURT handshake with {}:{}", host, port);
|
||||
|
||||
let mut plain_conn = self.create_connection(host, port).await?;
|
||||
|
||||
// Use original_host for the Host header if available, otherwise fall back to host
|
||||
let host_header = original_host.unwrap_or(host);
|
||||
|
||||
let handshake_request = GurtRequest::new(GurtMethod::HANDSHAKE, "/".to_string())
|
||||
.with_header("Host", host)
|
||||
.with_header("Host", host_header)
|
||||
.with_header("User-Agent", &self.config.user_agent);
|
||||
|
||||
let handshake_data = handshake_request.to_string();
|
||||
@@ -261,7 +264,10 @@ impl GurtClient {
|
||||
Connection::Plain(stream) => stream,
|
||||
};
|
||||
|
||||
self.upgrade_to_tls(tcp_stream, host).await
|
||||
// Use original_host for TLS SNI if available, otherwise fall back to host
|
||||
let tls_host = original_host.unwrap_or(host);
|
||||
|
||||
self.upgrade_to_tls(tcp_stream, tls_host).await
|
||||
}
|
||||
|
||||
async fn upgrade_to_tls(&self, stream: TcpStream, host: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
@@ -323,10 +329,10 @@ impl GurtClient {
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest) -> Result<GurtResponse> {
|
||||
async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest, original_host: Option<&str>) -> Result<GurtResponse> {
|
||||
debug!("Sending {} {} to {}:{}", request.method, request.path, host, port);
|
||||
|
||||
let mut tls_stream = self.get_pooled_connection(host, port).await?;
|
||||
let mut tls_stream = self.get_pooled_connection(host, port, original_host).await?;
|
||||
|
||||
let request_data = request.to_string();
|
||||
tls_stream.write_all(request_data.as_bytes()).await
|
||||
@@ -501,7 +507,7 @@ impl GurtClient {
|
||||
|
||||
request = request.with_header("Host", host);
|
||||
|
||||
self.send_request_internal(&resolved_host, port, request).await
|
||||
self.send_request_internal(&resolved_host, port, request, Some(host)).await
|
||||
}
|
||||
|
||||
fn parse_gurt_url(&self, url: &str) -> Result<(String, u16, String)> {
|
||||
@@ -564,7 +570,7 @@ impl GurtClient {
|
||||
.with_header("Content-Type", "application/json")
|
||||
.with_string_body(dns_request_body);
|
||||
|
||||
let dns_response = self.send_request_internal(&dns_server_ip, self.config.dns_server_port, dns_request).await?;
|
||||
let dns_response = self.send_request_internal(&dns_server_ip, self.config.dns_server_port, dns_request, None).await?;
|
||||
|
||||
if dns_response.status_code != 200 {
|
||||
return Err(GurtError::invalid_message(format!(
|
||||
@@ -675,4 +681,57 @@ mod tests {
|
||||
assert_eq!(key1, key2);
|
||||
assert_ne!(key1, key3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_host_header_preserved_with_dns_resolution() {
|
||||
use crate::message::{GurtMethod, GurtRequest};
|
||||
|
||||
let mut config = GurtClientConfig::default();
|
||||
config.enable_connection_pooling = false;
|
||||
let client = GurtClient::with_config(config);
|
||||
|
||||
{
|
||||
let mut dns_cache = client.dns_cache.lock().unwrap();
|
||||
dns_cache.insert("arson.dev".to_string(), "1.1.1.1".to_string());
|
||||
}
|
||||
|
||||
let request = GurtRequest::new(GurtMethod::GET, "/test".to_string());
|
||||
|
||||
let original_host = "arson.dev";
|
||||
|
||||
let mut test_request = request.clone();
|
||||
test_request = test_request.with_header("Host", original_host);
|
||||
|
||||
assert_eq!(test_request.headers.get("host").unwrap(), original_host);
|
||||
|
||||
let resolved = client.resolve_domain("arson.dev").await.unwrap();
|
||||
assert_eq!(resolved, "1.1.1.1");
|
||||
|
||||
let request_with_host = GurtRequest::new(GurtMethod::GET, "/test".to_string())
|
||||
.with_header("Host", original_host);
|
||||
|
||||
assert_eq!(request_with_host.headers.get("host").unwrap(), "arson.dev");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handshake_request_uses_original_host() {
|
||||
use crate::message::{GurtMethod, GurtRequest};
|
||||
|
||||
let original_host = "arson.dev";
|
||||
|
||||
let host_header = original_host;
|
||||
|
||||
let handshake_request = GurtRequest::new(GurtMethod::HANDSHAKE, "/".to_string())
|
||||
.with_header("Host", host_header)
|
||||
.with_header("User-Agent", "GURT-Client/1.0.0");
|
||||
|
||||
assert_eq!(handshake_request.headers.get("host").unwrap(), "arson.dev");
|
||||
assert_ne!(handshake_request.headers.get("host").unwrap(), "1.1.1.1");
|
||||
|
||||
assert_eq!(handshake_request.method, GurtMethod::HANDSHAKE);
|
||||
assert_eq!(handshake_request.path, "/");
|
||||
|
||||
assert!(handshake_request.headers.contains_key("user-agent"));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -577,8 +577,8 @@ mod tests {
|
||||
assert!(!route.matches(&GurtMethod::POST, "/test"));
|
||||
assert!(!route.matches(&GurtMethod::GET, "/other"));
|
||||
|
||||
assert!(!route.matches(&GurtMethod::GET, "/test?foo=bar"));
|
||||
assert!(!route.matches(&GurtMethod::GET, "/test?page=1&limit=100"));
|
||||
assert!(route.matches(&GurtMethod::GET, "/test?foo=bar"));
|
||||
assert!(route.matches(&GurtMethod::GET, "/test?page=1&limit=100"));
|
||||
|
||||
let wildcard_route = Route::get("/api/*");
|
||||
assert!(wildcard_route.matches(&GurtMethod::GET, "/api/users"));
|
||||
|
||||
Reference in New Issue
Block a user