Merge pull request #17 from vt-d/main
Fix CNAME records overwriting host header & test
This commit is contained in:
@@ -109,9 +109,9 @@ impl GurtClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_pooled_connection(&self, host: &str, port: u16) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
async fn get_pooled_connection(&self, host: &str, port: u16, original_host: Option<&str>) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||||
if !self.config.enable_connection_pooling {
|
if !self.config.enable_connection_pooling {
|
||||||
return self.perform_handshake(host, port).await;
|
return self.perform_handshake(host, port, original_host).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
let key = ConnectionKey {
|
let key = ConnectionKey {
|
||||||
@@ -131,7 +131,7 @@ impl GurtClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
debug!("Creating new connection for {}:{}", host, port);
|
debug!("Creating new connection for {}:{}", host, port);
|
||||||
self.perform_handshake(host, port).await
|
self.perform_handshake(host, port, original_host).await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn return_connection_to_pool(&self, host: &str, port: u16, connection: tokio_rustls::client::TlsStream<TcpStream>) {
|
fn return_connection_to_pool(&self, host: &str, port: u16, connection: tokio_rustls::client::TlsStream<TcpStream>) {
|
||||||
@@ -231,13 +231,16 @@ impl GurtClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn perform_handshake(&self, host: &str, port: u16) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
async fn perform_handshake(&self, host: &str, port: u16, original_host: Option<&str>) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||||
debug!("Starting GURT handshake with {}:{}", host, port);
|
debug!("Starting GURT handshake with {}:{}", host, port);
|
||||||
|
|
||||||
let mut plain_conn = self.create_connection(host, port).await?;
|
let mut plain_conn = self.create_connection(host, port).await?;
|
||||||
|
|
||||||
|
// Use original_host for the Host header if available, otherwise fall back to host
|
||||||
|
let host_header = original_host.unwrap_or(host);
|
||||||
|
|
||||||
let handshake_request = GurtRequest::new(GurtMethod::HANDSHAKE, "/".to_string())
|
let handshake_request = GurtRequest::new(GurtMethod::HANDSHAKE, "/".to_string())
|
||||||
.with_header("Host", host)
|
.with_header("Host", host_header)
|
||||||
.with_header("User-Agent", &self.config.user_agent);
|
.with_header("User-Agent", &self.config.user_agent);
|
||||||
|
|
||||||
let handshake_data = handshake_request.to_string();
|
let handshake_data = handshake_request.to_string();
|
||||||
@@ -261,7 +264,10 @@ impl GurtClient {
|
|||||||
Connection::Plain(stream) => stream,
|
Connection::Plain(stream) => stream,
|
||||||
};
|
};
|
||||||
|
|
||||||
self.upgrade_to_tls(tcp_stream, host).await
|
// Use original_host for TLS SNI if available, otherwise fall back to host
|
||||||
|
let tls_host = original_host.unwrap_or(host);
|
||||||
|
|
||||||
|
self.upgrade_to_tls(tcp_stream, tls_host).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn upgrade_to_tls(&self, stream: TcpStream, host: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
async fn upgrade_to_tls(&self, stream: TcpStream, host: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||||
@@ -323,10 +329,10 @@ impl GurtClient {
|
|||||||
Ok(tls_stream)
|
Ok(tls_stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest) -> Result<GurtResponse> {
|
async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest, original_host: Option<&str>) -> Result<GurtResponse> {
|
||||||
debug!("Sending {} {} to {}:{}", request.method, request.path, host, port);
|
debug!("Sending {} {} to {}:{}", request.method, request.path, host, port);
|
||||||
|
|
||||||
let mut tls_stream = self.get_pooled_connection(host, port).await?;
|
let mut tls_stream = self.get_pooled_connection(host, port, original_host).await?;
|
||||||
|
|
||||||
let request_data = request.to_string();
|
let request_data = request.to_string();
|
||||||
tls_stream.write_all(request_data.as_bytes()).await
|
tls_stream.write_all(request_data.as_bytes()).await
|
||||||
@@ -501,7 +507,7 @@ impl GurtClient {
|
|||||||
|
|
||||||
request = request.with_header("Host", host);
|
request = request.with_header("Host", host);
|
||||||
|
|
||||||
self.send_request_internal(&resolved_host, port, request).await
|
self.send_request_internal(&resolved_host, port, request, Some(host)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_gurt_url(&self, url: &str) -> Result<(String, u16, String)> {
|
fn parse_gurt_url(&self, url: &str) -> Result<(String, u16, String)> {
|
||||||
@@ -564,7 +570,7 @@ impl GurtClient {
|
|||||||
.with_header("Content-Type", "application/json")
|
.with_header("Content-Type", "application/json")
|
||||||
.with_string_body(dns_request_body);
|
.with_string_body(dns_request_body);
|
||||||
|
|
||||||
let dns_response = self.send_request_internal(&dns_server_ip, self.config.dns_server_port, dns_request).await?;
|
let dns_response = self.send_request_internal(&dns_server_ip, self.config.dns_server_port, dns_request, None).await?;
|
||||||
|
|
||||||
if dns_response.status_code != 200 {
|
if dns_response.status_code != 200 {
|
||||||
return Err(GurtError::invalid_message(format!(
|
return Err(GurtError::invalid_message(format!(
|
||||||
@@ -675,4 +681,55 @@ mod tests {
|
|||||||
assert_eq!(key1, key2);
|
assert_eq!(key1, key2);
|
||||||
assert_ne!(key1, key3);
|
assert_ne!(key1, key3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_host_header_preserved_with_dns_resolution() {
|
||||||
|
use crate::message::{GurtMethod, GurtRequest};
|
||||||
|
|
||||||
|
let mut config = GurtClientConfig::default();
|
||||||
|
config.enable_connection_pooling = false;
|
||||||
|
let client = GurtClient::with_config(config);
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut dns_cache = client.dns_cache.lock().unwrap();
|
||||||
|
dns_cache.insert("arson.dev".to_string(), "1.1.1.1".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let request = GurtRequest::new(GurtMethod::GET, "/test".to_string());
|
||||||
|
|
||||||
|
let original_host = "arson.dev";
|
||||||
|
|
||||||
|
let mut test_request = request.clone();
|
||||||
|
test_request = test_request.with_header("Host", original_host);
|
||||||
|
|
||||||
|
assert_eq!(test_request.headers.get("host").unwrap(), original_host);
|
||||||
|
|
||||||
|
let resolved = client.resolve_domain("arson.dev").await.unwrap();
|
||||||
|
assert_eq!(resolved, "1.1.1.1");
|
||||||
|
|
||||||
|
let request_with_host = GurtRequest::new(GurtMethod::GET, "/test".to_string())
|
||||||
|
.with_header("Host", original_host);
|
||||||
|
|
||||||
|
assert_eq!(request_with_host.headers.get("host").unwrap(), "arson.dev");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_handshake_request_uses_original_host() {
|
||||||
|
use crate::message::{GurtMethod, GurtRequest};
|
||||||
|
|
||||||
|
let original_host = "arson.dev";
|
||||||
|
|
||||||
|
let handshake_request = GurtRequest::new(GurtMethod::HANDSHAKE, "/".to_string())
|
||||||
|
.with_header("Host", original_host)
|
||||||
|
.with_header("User-Agent", "GURT-Client/1.0.0");
|
||||||
|
|
||||||
|
assert_eq!(handshake_request.headers.get("host").unwrap(), "arson.dev");
|
||||||
|
assert_ne!(handshake_request.headers.get("host").unwrap(), "1.1.1.1");
|
||||||
|
|
||||||
|
assert_eq!(handshake_request.method, GurtMethod::HANDSHAKE);
|
||||||
|
assert_eq!(handshake_request.path, "/");
|
||||||
|
|
||||||
|
assert!(handshake_request.headers.contains_key("user-agent"));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -577,8 +577,8 @@ mod tests {
|
|||||||
assert!(!route.matches(&GurtMethod::POST, "/test"));
|
assert!(!route.matches(&GurtMethod::POST, "/test"));
|
||||||
assert!(!route.matches(&GurtMethod::GET, "/other"));
|
assert!(!route.matches(&GurtMethod::GET, "/other"));
|
||||||
|
|
||||||
assert!(!route.matches(&GurtMethod::GET, "/test?foo=bar"));
|
assert!(route.matches(&GurtMethod::GET, "/test?foo=bar"));
|
||||||
assert!(!route.matches(&GurtMethod::GET, "/test?page=1&limit=100"));
|
assert!(route.matches(&GurtMethod::GET, "/test?page=1&limit=100"));
|
||||||
|
|
||||||
let wildcard_route = Route::get("/api/*");
|
let wildcard_route = Route::get("/api/*");
|
||||||
assert!(wildcard_route.matches(&GurtMethod::GET, "/api/users"));
|
assert!(wildcard_route.matches(&GurtMethod::GET, "/api/users"));
|
||||||
|
|||||||
Reference in New Issue
Block a user