Gurty -> usable state

This commit is contained in:
vt-d
2025-08-19 22:01:20 +05:30
parent 99f17dc42c
commit aa49fac5b8
24 changed files with 2679 additions and 394 deletions

View File

@@ -184,12 +184,11 @@ impl GurtClient {
let bytes_read = conn.connection.read(&mut temp_buffer).await?;
if bytes_read == 0 {
break; // Connection closed
break;
}
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
// Check for complete message
let body_separator = BODY_SEPARATOR.as_bytes();
if !headers_parsed {
@@ -197,7 +196,6 @@ impl GurtClient {
headers_end_pos = Some(pos + body_separator.len());
headers_parsed = true;
// Parse headers to get Content-Length
let headers_section = &buffer[..pos];
if let Ok(headers_str) = std::str::from_utf8(headers_section) {
for line in headers_str.lines() {
@@ -220,7 +218,6 @@ impl GurtClient {
return Ok(buffer);
}
} else if headers_parsed && expected_body_length.is_none() {
// No Content-Length header, return what we have after headers
return Ok(buffer);
}
}
@@ -329,7 +326,7 @@ impl GurtClient {
}
match timeout(Duration::from_millis(100), tls_stream.read(&mut temp_buffer)).await {
Ok(Ok(0)) => break, // Connection closed
Ok(Ok(0)) => break,
Ok(Ok(n)) => {
buffer.extend_from_slice(&temp_buffer[..n]);
@@ -394,7 +391,6 @@ impl GurtClient {
self.send_request_internal(&host, port, request).await
}
/// POST request with JSON body
pub async fn post_json<T: serde::Serialize>(&self, url: &str, data: &T) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let json_body = serde_json::to_string(data)?;
@@ -408,7 +404,6 @@ impl GurtClient {
self.send_request_internal(&host, port, request).await
}
/// PUT request with body
pub async fn put(&self, url: &str, body: &str) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let request = GurtRequest::new(GurtMethod::PUT, path)
@@ -420,7 +415,6 @@ impl GurtClient {
self.send_request_internal(&host, port, request).await
}
/// PUT request with JSON body
pub async fn put_json<T: serde::Serialize>(&self, url: &str, data: &T) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let json_body = serde_json::to_string(data)?;
@@ -461,7 +455,6 @@ impl GurtClient {
self.send_request_internal(&host, port, request).await
}
/// PATCH request with body
pub async fn patch(&self, url: &str, body: &str) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let request = GurtRequest::new(GurtMethod::PATCH, path)
@@ -473,7 +466,6 @@ impl GurtClient {
self.send_request_internal(&host, port, request).await
}
/// PATCH request with JSON body
pub async fn patch_json<T: serde::Serialize>(&self, url: &str, data: &T) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let json_body = serde_json::to_string(data)?;
@@ -548,4 +540,38 @@ mod tests {
assert_eq!(port, 8080);
assert_eq!(path, "/api/v1");
}
#[test]
fn test_connection_pooling_config() {
let config = GurtClientConfig {
enable_connection_pooling: true,
max_connections_per_host: 8,
..Default::default()
};
let client = GurtClient::with_config(config);
assert!(client.config.enable_connection_pooling);
assert_eq!(client.config.max_connections_per_host, 8);
}
#[test]
fn test_connection_key() {
let key1 = ConnectionKey {
host: "example.com".to_string(),
port: 4878,
};
let key2 = ConnectionKey {
host: "example.com".to_string(),
port: 4878,
};
let key3 = ConnectionKey {
host: "other.com".to_string(),
port: 4878,
};
assert_eq!(key1, key2);
assert_ne!(key1, key3);
}
}

View File

@@ -14,7 +14,7 @@ pub enum GurtMethod {
HEAD,
OPTIONS,
PATCH,
HANDSHAKE, // Special method for protocol handshake
HANDSHAKE,
}
impl GurtMethod {
@@ -101,7 +101,6 @@ impl GurtRequest {
}
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
// Find the header/body separator as bytes
let body_separator = BODY_SEPARATOR.as_bytes();
let body_separator_pos = data.windows(body_separator.len())
.position(|window| window == body_separator);
@@ -114,7 +113,6 @@ impl GurtRequest {
(data, Vec::new())
};
// Convert headers section to string (should be valid UTF-8)
let headers_str = std::str::from_utf8(headers_section)
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?;
@@ -124,7 +122,6 @@ impl GurtRequest {
return Err(GurtError::invalid_message("Empty request"));
}
// Parse request line (METHOD path GURT/version)
let request_line = lines[0];
let parts: Vec<&str> = request_line.split_whitespace().collect();
@@ -135,7 +132,6 @@ impl GurtRequest {
let method = GurtMethod::parse(parts[0])?;
let path = parts[1].to_string();
// Parse protocol version
if !parts[2].starts_with(PROTOCOL_PREFIX) {
return Err(GurtError::invalid_message("Invalid protocol identifier"));
}
@@ -143,7 +139,6 @@ impl GurtRequest {
let version_str = &parts[2][PROTOCOL_PREFIX.len()..];
let version = version_str.to_string();
// Parse headers
let mut headers = GurtHeaders::new();
for line in lines.iter().skip(1) {
@@ -253,6 +248,10 @@ impl GurtResponse {
Self::new(GurtStatusCode::BadRequest)
}
pub fn forbidden() -> Self {
Self::new(GurtStatusCode::Forbidden)
}
pub fn internal_server_error() -> Self {
Self::new(GurtStatusCode::InternalServerError)
}
@@ -306,7 +305,6 @@ impl GurtResponse {
}
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
// Find the header/body separator as bytes
let body_separator = BODY_SEPARATOR.as_bytes();
let body_separator_pos = data.windows(body_separator.len())
.position(|window| window == body_separator);
@@ -319,7 +317,6 @@ impl GurtResponse {
(data, Vec::new())
};
// Convert headers section to string (should be valid UTF-8)
let headers_str = std::str::from_utf8(headers_section)
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?;
@@ -329,7 +326,6 @@ impl GurtResponse {
return Err(GurtError::invalid_message("Empty response"));
}
// Parse status line (GURT/version status_code status_message)
let status_line = lines[0];
let parts: Vec<&str> = status_line.splitn(3, ' ').collect();
@@ -337,7 +333,6 @@ impl GurtResponse {
return Err(GurtError::invalid_message("Invalid status line format"));
}
// Parse protocol version
if !parts[0].starts_with(PROTOCOL_PREFIX) {
return Err(GurtError::invalid_message("Invalid protocol identifier"));
}
@@ -356,7 +351,6 @@ impl GurtResponse {
.unwrap_or_else(|| "Unknown".to_string())
};
// Parse headers
let mut headers = GurtHeaders::new();
for line in lines.iter().skip(1) {
@@ -394,7 +388,6 @@ impl GurtResponse {
}
if !headers.contains_key("date") {
// RFC 7231 compliant
let now = Utc::now();
let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
headers.insert("date".to_string(), date_str);
@@ -429,7 +422,6 @@ impl GurtResponse {
}
if !headers.contains_key("date") {
// RFC 7231 compliant
let now = Utc::now();
let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
headers.insert("date".to_string(), date_str);
@@ -441,7 +433,6 @@ impl GurtResponse {
message.push_str(HEADER_SEPARATOR);
// Convert headers to bytes and append body as raw bytes
let mut bytes = message.into_bytes();
bytes.extend_from_slice(&self.body);
@@ -461,7 +452,6 @@ impl GurtMessage {
}
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
// Convert first line to string to determine message type
let header_separator = HEADER_SEPARATOR.as_bytes();
let first_line_end = data.windows(header_separator.len())
.position(|window| window == header_separator)
@@ -470,7 +460,6 @@ impl GurtMessage {
let first_line = std::str::from_utf8(&data[..first_line_end])
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in first line"))?;
// Check if it's a response (starts with GURT/version) or request (method first)
if first_line.starts_with(PROTOCOL_PREFIX) {
Ok(GurtMessage::Response(GurtResponse::parse_bytes(data)?))
} else {

View File

@@ -37,6 +37,7 @@ pub enum GurtStatusCode {
Timeout = 408,
TooLarge = 413,
UnsupportedMediaType = 415,
TooManyRequests = 429,
// Server errors
InternalServerError = 500,
@@ -62,6 +63,7 @@ impl GurtStatusCode {
408 => Some(Self::Timeout),
413 => Some(Self::TooLarge),
415 => Some(Self::UnsupportedMediaType),
429 => Some(Self::TooManyRequests),
500 => Some(Self::InternalServerError),
501 => Some(Self::NotImplemented),
502 => Some(Self::BadGateway),
@@ -86,6 +88,7 @@ impl GurtStatusCode {
Self::Timeout => "TIMEOUT",
Self::TooLarge => "TOO_LARGE",
Self::UnsupportedMediaType => "UNSUPPORTED_MEDIA_TYPE",
Self::TooManyRequests => "TOO_MANY_REQUESTS",
Self::InternalServerError => "INTERNAL_SERVER_ERROR",
Self::NotImplemented => "NOT_IMPLEMENTED",
Self::BadGateway => "BAD_GATEWAY",

View File

@@ -7,6 +7,7 @@ use crate::{
};
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::{timeout, Duration};
use tokio_rustls::{TlsAcceptor, server::TlsStream};
use rustls::pki_types::CertificateDer;
use std::collections::HashMap;
@@ -136,6 +137,9 @@ impl Route {
pub struct GurtServer {
routes: Vec<(Route, Arc<dyn GurtHandler>)>,
tls_acceptor: Option<TlsAcceptor>,
handshake_timeout: Duration,
request_timeout: Duration,
connection_timeout: Duration,
}
impl GurtServer {
@@ -143,9 +147,19 @@ impl GurtServer {
Self {
routes: Vec::new(),
tls_acceptor: None,
handshake_timeout: Duration::from_secs(5),
request_timeout: Duration::from_secs(30),
connection_timeout: Duration::from_secs(10),
}
}
pub fn with_timeouts(mut self, handshake_timeout: Duration, request_timeout: Duration, connection_timeout: Duration) -> Self {
self.handshake_timeout = handshake_timeout;
self.request_timeout = request_timeout;
self.connection_timeout = connection_timeout;
self
}
pub fn with_tls_certificates(cert_path: &str, key_path: &str) -> Result<Self> {
let mut server = Self::new();
server.load_tls_certificates(cert_path, key_path)?;
@@ -279,57 +293,76 @@ impl GurtServer {
}
async fn handle_connection(&self, mut stream: TcpStream, addr: SocketAddr) -> Result<()> {
self.handle_initial_handshake(&mut stream, addr).await?;
let connection_result = timeout(self.connection_timeout, async {
self.handle_initial_handshake(&mut stream, addr).await?;
if let Some(tls_acceptor) = &self.tls_acceptor {
info!("Upgrading connection to TLS for {}", addr);
let tls_stream = tls_acceptor.accept(stream).await
.map_err(|e| GurtError::crypto(format!("TLS upgrade failed: {}", e)))?;
info!("TLS upgrade completed for {}", addr);
self.handle_tls_connection(tls_stream, addr).await
} else {
warn!("No TLS configuration available, but handshake completed - this violates GURT protocol");
Err(GurtError::protocol("TLS is required after handshake but no TLS configuration available"))
}
}).await;
if let Some(tls_acceptor) = &self.tls_acceptor {
info!("Upgrading connection to TLS for {}", addr);
let tls_stream = tls_acceptor.accept(stream).await
.map_err(|e| GurtError::crypto(format!("TLS upgrade failed: {}", e)))?;
info!("TLS upgrade completed for {}", addr);
self.handle_tls_connection(tls_stream, addr).await
} else {
warn!("No TLS configuration available, but handshake completed - this violates GURT protocol");
Err(GurtError::protocol("TLS is required after handshake but no TLS configuration available"))
match connection_result {
Ok(result) => result,
Err(_) => {
warn!("Connection timeout for {}", addr);
Err(GurtError::timeout("Connection timeout"))
}
}
}
async fn handle_initial_handshake(&self, stream: &mut TcpStream, addr: SocketAddr) -> Result<()> {
let mut buffer = Vec::new();
let mut temp_buffer = [0u8; 8192];
loop {
let bytes_read = stream.read(&mut temp_buffer).await?;
if bytes_read == 0 {
return Err(GurtError::connection("Connection closed during handshake"));
}
let handshake_result = timeout(self.handshake_timeout, async {
let mut buffer = Vec::new();
let mut temp_buffer = [0u8; 8192];
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
let body_separator = BODY_SEPARATOR.as_bytes();
if buffer.windows(body_separator.len()).any(|w| w == body_separator) {
break;
}
if buffer.len() > MAX_MESSAGE_SIZE {
return Err(GurtError::protocol("Handshake message too large"));
}
}
let message = GurtMessage::parse_bytes(&buffer)?;
match message {
GurtMessage::Request(request) => {
if request.method == GurtMethod::HANDSHAKE {
self.send_handshake_response(stream, addr, &request).await
} else {
Err(GurtError::protocol("First message must be HANDSHAKE"))
loop {
let bytes_read = stream.read(&mut temp_buffer).await?;
if bytes_read == 0 {
return Err(GurtError::connection("Connection closed during handshake"));
}
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
let body_separator = BODY_SEPARATOR.as_bytes();
if buffer.windows(body_separator.len()).any(|w| w == body_separator) {
break;
}
if buffer.len() > MAX_MESSAGE_SIZE {
return Err(GurtError::protocol("Handshake message too large"));
}
}
GurtMessage::Response(_) => {
Err(GurtError::protocol("Server received response during handshake"))
let message = GurtMessage::parse_bytes(&buffer)?;
match message {
GurtMessage::Request(request) => {
if request.method == GurtMethod::HANDSHAKE {
self.send_handshake_response(stream, addr, &request).await
} else {
Err(GurtError::protocol("First message must be HANDSHAKE"))
}
}
GurtMessage::Response(_) => {
Err(GurtError::protocol("Server received response during handshake"))
}
}
}).await;
match handshake_result {
Ok(result) => result,
Err(_) => {
warn!("Handshake timeout for {}", addr);
Err(GurtError::timeout("Handshake timeout"))
}
}
}
@@ -342,7 +375,6 @@ impl GurtServer {
let bytes_read = match tls_stream.read(&mut temp_buffer).await {
Ok(n) => n,
Err(e) => {
// Handle UnexpectedEof from clients that don't send close_notify
if e.kind() == std::io::ErrorKind::UnexpectedEof {
debug!("Client {} closed connection without TLS close_notify (benign)", addr);
break;
@@ -351,7 +383,7 @@ impl GurtServer {
}
};
if bytes_read == 0 {
break; // Connection closed
break;
}
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
@@ -361,17 +393,31 @@ impl GurtServer {
(buffer.starts_with(b"{") && buffer.ends_with(b"}"));
if has_complete_message {
if let Err(e) = self.process_tls_message(&mut tls_stream, addr, &buffer).await {
error!("Encrypted message processing error from {}: {}", addr, e);
let error_response = GurtResponse::internal_server_error()
.with_string_body("Internal server error");
let _ = tls_stream.write_all(&error_response.to_bytes()).await;
let process_result = timeout(self.request_timeout,
self.process_tls_message(&mut tls_stream, addr, &buffer)
).await;
match process_result {
Ok(Ok(())) => {
debug!("Processed message from {} successfully", addr);
}
Ok(Err(e)) => {
error!("Encrypted message processing error from {}: {}", addr, e);
let error_response = GurtResponse::internal_server_error()
.with_string_body("Internal server error");
let _ = tls_stream.write_all(&error_response.to_bytes()).await;
}
Err(_) => {
warn!("Request timeout for {}", addr);
let timeout_response = GurtResponse::new(GurtStatusCode::Timeout)
.with_string_body("Request timeout");
let _ = tls_stream.write_all(&timeout_response.to_bytes()).await;
}
}
buffer.clear();
}
// Prevent buffer overflow
if buffer.len() > MAX_MESSAGE_SIZE {
warn!("Message too large from {}, closing connection", addr);
break;
@@ -422,7 +468,6 @@ impl GurtServer {
if let Some(method) = &route.method {
allowed_methods.insert(method.to_string());
} else {
// Route matches any method
allowed_methods.extend(vec![
"GET".to_string(), "POST".to_string(), "PUT".to_string(),
"DELETE".to_string(), "HEAD".to_string(), "PATCH".to_string()
@@ -482,7 +527,6 @@ impl GurtServer {
async fn handle_encrypted_request(&self, tls_stream: &mut TlsStream<TcpStream>, addr: SocketAddr, request: &GurtRequest) -> Result<()> {
debug!("Handling encrypted {} request to {} from {}", request.method, request.path, addr);
// Find matching route
for (route, handler) in &self.routes {
if route.matches(&request.method, &request.path) {
let context = ServerContext {
@@ -492,7 +536,6 @@ impl GurtServer {
match handler.handle(&context).await {
Ok(response) => {
// Use to_bytes() to avoid corrupting binary data
let response_bytes = response.to_bytes();
tls_stream.write_all(&response_bytes).await?;
return Ok(());
@@ -508,7 +551,6 @@ impl GurtServer {
}
}
// No route found - check for default OPTIONS/HEAD handling
match request.method {
GurtMethod::OPTIONS => {
self.handle_default_options(tls_stream, request).await
@@ -531,6 +573,9 @@ impl Clone for GurtServer {
Self {
routes: self.routes.clone(),
tls_acceptor: self.tls_acceptor.clone(),
handshake_timeout: self.handshake_timeout,
request_timeout: self.request_timeout,
connection_timeout: self.connection_timeout,
}
}
}