fix cname issue & test

This commit is contained in:
vt-d
2025-09-09 21:34:48 +05:30
parent b6379f97d2
commit 3b2060f43b
2 changed files with 71 additions and 12 deletions

View File

@@ -109,9 +109,9 @@ impl GurtClient {
}
}
async fn get_pooled_connection(&self, host: &str, port: u16) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
async fn get_pooled_connection(&self, host: &str, port: u16, original_host: Option<&str>) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
if !self.config.enable_connection_pooling {
return self.perform_handshake(host, port).await;
return self.perform_handshake(host, port, original_host).await;
}
let key = ConnectionKey {
@@ -131,7 +131,7 @@ impl GurtClient {
}
debug!("Creating new connection for {}:{}", host, port);
self.perform_handshake(host, port).await
self.perform_handshake(host, port, original_host).await
}
fn return_connection_to_pool(&self, host: &str, port: u16, connection: tokio_rustls::client::TlsStream<TcpStream>) {
@@ -231,13 +231,16 @@ impl GurtClient {
}
}
async fn perform_handshake(&self, host: &str, port: u16) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
async fn perform_handshake(&self, host: &str, port: u16, original_host: Option<&str>) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
debug!("Starting GURT handshake with {}:{}", host, port);
let mut plain_conn = self.create_connection(host, port).await?;
// Use original_host for the Host header if available, otherwise fall back to host
let host_header = original_host.unwrap_or(host);
let handshake_request = GurtRequest::new(GurtMethod::HANDSHAKE, "/".to_string())
.with_header("Host", host)
.with_header("Host", host_header)
.with_header("User-Agent", &self.config.user_agent);
let handshake_data = handshake_request.to_string();
@@ -261,7 +264,10 @@ impl GurtClient {
Connection::Plain(stream) => stream,
};
self.upgrade_to_tls(tcp_stream, host).await
// Use original_host for TLS SNI if available, otherwise fall back to host
let tls_host = original_host.unwrap_or(host);
self.upgrade_to_tls(tcp_stream, tls_host).await
}
async fn upgrade_to_tls(&self, stream: TcpStream, host: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
@@ -323,10 +329,10 @@ impl GurtClient {
Ok(tls_stream)
}
async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest) -> Result<GurtResponse> {
async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest, original_host: Option<&str>) -> Result<GurtResponse> {
debug!("Sending {} {} to {}:{}", request.method, request.path, host, port);
let mut tls_stream = self.get_pooled_connection(host, port).await?;
let mut tls_stream = self.get_pooled_connection(host, port, original_host).await?;
let request_data = request.to_string();
tls_stream.write_all(request_data.as_bytes()).await
@@ -501,7 +507,7 @@ impl GurtClient {
request = request.with_header("Host", host);
self.send_request_internal(&resolved_host, port, request).await
self.send_request_internal(&resolved_host, port, request, Some(host)).await
}
fn parse_gurt_url(&self, url: &str) -> Result<(String, u16, String)> {
@@ -564,7 +570,7 @@ impl GurtClient {
.with_header("Content-Type", "application/json")
.with_string_body(dns_request_body);
let dns_response = self.send_request_internal(&dns_server_ip, self.config.dns_server_port, dns_request).await?;
let dns_response = self.send_request_internal(&dns_server_ip, self.config.dns_server_port, dns_request, None).await?;
if dns_response.status_code != 200 {
return Err(GurtError::invalid_message(format!(
@@ -675,4 +681,57 @@ mod tests {
assert_eq!(key1, key2);
assert_ne!(key1, key3);
}
#[tokio::test]
async fn test_host_header_preserved_with_dns_resolution() {
use crate::message::{GurtMethod, GurtRequest};
let mut config = GurtClientConfig::default();
config.enable_connection_pooling = false;
let client = GurtClient::with_config(config);
{
let mut dns_cache = client.dns_cache.lock().unwrap();
dns_cache.insert("arson.dev".to_string(), "1.1.1.1".to_string());
}
let request = GurtRequest::new(GurtMethod::GET, "/test".to_string());
let original_host = "arson.dev";
let mut test_request = request.clone();
test_request = test_request.with_header("Host", original_host);
assert_eq!(test_request.headers.get("host").unwrap(), original_host);
let resolved = client.resolve_domain("arson.dev").await.unwrap();
assert_eq!(resolved, "1.1.1.1");
let request_with_host = GurtRequest::new(GurtMethod::GET, "/test".to_string())
.with_header("Host", original_host);
assert_eq!(request_with_host.headers.get("host").unwrap(), "arson.dev");
}
#[test]
fn test_handshake_request_uses_original_host() {
use crate::message::{GurtMethod, GurtRequest};
let original_host = "arson.dev";
let host_header = original_host;
let handshake_request = GurtRequest::new(GurtMethod::HANDSHAKE, "/".to_string())
.with_header("Host", host_header)
.with_header("User-Agent", "GURT-Client/1.0.0");
assert_eq!(handshake_request.headers.get("host").unwrap(), "arson.dev");
assert_ne!(handshake_request.headers.get("host").unwrap(), "1.1.1.1");
assert_eq!(handshake_request.method, GurtMethod::HANDSHAKE);
assert_eq!(handshake_request.path, "/");
assert!(handshake_request.headers.contains_key("user-agent"));
}
}

View File

@@ -577,8 +577,8 @@ mod tests {
assert!(!route.matches(&GurtMethod::POST, "/test"));
assert!(!route.matches(&GurtMethod::GET, "/other"));
assert!(!route.matches(&GurtMethod::GET, "/test?foo=bar"));
assert!(!route.matches(&GurtMethod::GET, "/test?page=1&limit=100"));
assert!(route.matches(&GurtMethod::GET, "/test?foo=bar"));
assert!(route.matches(&GurtMethod::GET, "/test?page=1&limit=100"));
let wildcard_route = Route::get("/api/*");
assert!(wildcard_route.matches(&GurtMethod::GET, "/api/users"));