From 3b2060f43ba923646bc69c99fea740bd145752ef Mon Sep 17 00:00:00 2001 From: vt-d Date: Tue, 9 Sep 2025 21:34:48 +0530 Subject: [PATCH] fix cname issue & test --- protocol/library/src/client.rs | 79 +++++++++++++++++++++++++++++----- protocol/library/src/server.rs | 4 +- 2 files changed, 71 insertions(+), 12 deletions(-) diff --git a/protocol/library/src/client.rs b/protocol/library/src/client.rs index c66e354..dfb36b2 100644 --- a/protocol/library/src/client.rs +++ b/protocol/library/src/client.rs @@ -109,9 +109,9 @@ impl GurtClient { } } - async fn get_pooled_connection(&self, host: &str, port: u16) -> Result> { + async fn get_pooled_connection(&self, host: &str, port: u16, original_host: Option<&str>) -> Result> { if !self.config.enable_connection_pooling { - return self.perform_handshake(host, port).await; + return self.perform_handshake(host, port, original_host).await; } let key = ConnectionKey { @@ -131,7 +131,7 @@ impl GurtClient { } debug!("Creating new connection for {}:{}", host, port); - self.perform_handshake(host, port).await + self.perform_handshake(host, port, original_host).await } fn return_connection_to_pool(&self, host: &str, port: u16, connection: tokio_rustls::client::TlsStream) { @@ -231,13 +231,16 @@ impl GurtClient { } } - async fn perform_handshake(&self, host: &str, port: u16) -> Result> { + async fn perform_handshake(&self, host: &str, port: u16, original_host: Option<&str>) -> Result> { debug!("Starting GURT handshake with {}:{}", host, port); let mut plain_conn = self.create_connection(host, port).await?; + // Use original_host for the Host header if available, otherwise fall back to host + let host_header = original_host.unwrap_or(host); + let handshake_request = GurtRequest::new(GurtMethod::HANDSHAKE, "/".to_string()) - .with_header("Host", host) + .with_header("Host", host_header) .with_header("User-Agent", &self.config.user_agent); let handshake_data = handshake_request.to_string(); @@ -261,7 +264,10 @@ impl GurtClient { Connection::Plain(stream) => stream, }; - self.upgrade_to_tls(tcp_stream, host).await + // Use original_host for TLS SNI if available, otherwise fall back to host + let tls_host = original_host.unwrap_or(host); + + self.upgrade_to_tls(tcp_stream, tls_host).await } async fn upgrade_to_tls(&self, stream: TcpStream, host: &str) -> Result> { @@ -323,10 +329,10 @@ impl GurtClient { Ok(tls_stream) } - async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest) -> Result { + async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest, original_host: Option<&str>) -> Result { debug!("Sending {} {} to {}:{}", request.method, request.path, host, port); - let mut tls_stream = self.get_pooled_connection(host, port).await?; + let mut tls_stream = self.get_pooled_connection(host, port, original_host).await?; let request_data = request.to_string(); tls_stream.write_all(request_data.as_bytes()).await @@ -501,7 +507,7 @@ impl GurtClient { request = request.with_header("Host", host); - self.send_request_internal(&resolved_host, port, request).await + self.send_request_internal(&resolved_host, port, request, Some(host)).await } fn parse_gurt_url(&self, url: &str) -> Result<(String, u16, String)> { @@ -564,7 +570,7 @@ impl GurtClient { .with_header("Content-Type", "application/json") .with_string_body(dns_request_body); - let dns_response = self.send_request_internal(&dns_server_ip, self.config.dns_server_port, dns_request).await?; + let dns_response = self.send_request_internal(&dns_server_ip, self.config.dns_server_port, dns_request, None).await?; if dns_response.status_code != 200 { return Err(GurtError::invalid_message(format!( @@ -675,4 +681,57 @@ mod tests { assert_eq!(key1, key2); assert_ne!(key1, key3); } + + #[tokio::test] + async fn test_host_header_preserved_with_dns_resolution() { + use crate::message::{GurtMethod, GurtRequest}; + + let mut config = GurtClientConfig::default(); + config.enable_connection_pooling = false; + let client = GurtClient::with_config(config); + + { + let mut dns_cache = client.dns_cache.lock().unwrap(); + dns_cache.insert("arson.dev".to_string(), "1.1.1.1".to_string()); + } + + let request = GurtRequest::new(GurtMethod::GET, "/test".to_string()); + + let original_host = "arson.dev"; + + let mut test_request = request.clone(); + test_request = test_request.with_header("Host", original_host); + + assert_eq!(test_request.headers.get("host").unwrap(), original_host); + + let resolved = client.resolve_domain("arson.dev").await.unwrap(); + assert_eq!(resolved, "1.1.1.1"); + + let request_with_host = GurtRequest::new(GurtMethod::GET, "/test".to_string()) + .with_header("Host", original_host); + + assert_eq!(request_with_host.headers.get("host").unwrap(), "arson.dev"); + } + + #[test] + fn test_handshake_request_uses_original_host() { + use crate::message::{GurtMethod, GurtRequest}; + + let original_host = "arson.dev"; + + let host_header = original_host; + + let handshake_request = GurtRequest::new(GurtMethod::HANDSHAKE, "/".to_string()) + .with_header("Host", host_header) + .with_header("User-Agent", "GURT-Client/1.0.0"); + + assert_eq!(handshake_request.headers.get("host").unwrap(), "arson.dev"); + assert_ne!(handshake_request.headers.get("host").unwrap(), "1.1.1.1"); + + assert_eq!(handshake_request.method, GurtMethod::HANDSHAKE); + assert_eq!(handshake_request.path, "/"); + + assert!(handshake_request.headers.contains_key("user-agent")); + } + } \ No newline at end of file diff --git a/protocol/library/src/server.rs b/protocol/library/src/server.rs index c744cd2..05c040f 100644 --- a/protocol/library/src/server.rs +++ b/protocol/library/src/server.rs @@ -577,8 +577,8 @@ mod tests { assert!(!route.matches(&GurtMethod::POST, "/test")); assert!(!route.matches(&GurtMethod::GET, "/other")); - assert!(!route.matches(&GurtMethod::GET, "/test?foo=bar")); - assert!(!route.matches(&GurtMethod::GET, "/test?page=1&limit=100")); + assert!(route.matches(&GurtMethod::GET, "/test?foo=bar")); + assert!(route.matches(&GurtMethod::GET, "/test?page=1&limit=100")); let wildcard_route = Route::get("/api/*"); assert!(wildcard_route.matches(&GurtMethod::GET, "/api/users"));