Gurty -> usable state
This commit is contained in:
@@ -7,6 +7,7 @@ use crate::{
|
||||
};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tokio_rustls::{TlsAcceptor, server::TlsStream};
|
||||
use rustls::pki_types::CertificateDer;
|
||||
use std::collections::HashMap;
|
||||
@@ -136,6 +137,9 @@ impl Route {
|
||||
pub struct GurtServer {
|
||||
routes: Vec<(Route, Arc<dyn GurtHandler>)>,
|
||||
tls_acceptor: Option<TlsAcceptor>,
|
||||
handshake_timeout: Duration,
|
||||
request_timeout: Duration,
|
||||
connection_timeout: Duration,
|
||||
}
|
||||
|
||||
impl GurtServer {
|
||||
@@ -143,9 +147,19 @@ impl GurtServer {
|
||||
Self {
|
||||
routes: Vec::new(),
|
||||
tls_acceptor: None,
|
||||
handshake_timeout: Duration::from_secs(5),
|
||||
request_timeout: Duration::from_secs(30),
|
||||
connection_timeout: Duration::from_secs(10),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_timeouts(mut self, handshake_timeout: Duration, request_timeout: Duration, connection_timeout: Duration) -> Self {
|
||||
self.handshake_timeout = handshake_timeout;
|
||||
self.request_timeout = request_timeout;
|
||||
self.connection_timeout = connection_timeout;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tls_certificates(cert_path: &str, key_path: &str) -> Result<Self> {
|
||||
let mut server = Self::new();
|
||||
server.load_tls_certificates(cert_path, key_path)?;
|
||||
@@ -279,57 +293,76 @@ impl GurtServer {
|
||||
}
|
||||
|
||||
async fn handle_connection(&self, mut stream: TcpStream, addr: SocketAddr) -> Result<()> {
|
||||
self.handle_initial_handshake(&mut stream, addr).await?;
|
||||
let connection_result = timeout(self.connection_timeout, async {
|
||||
self.handle_initial_handshake(&mut stream, addr).await?;
|
||||
|
||||
if let Some(tls_acceptor) = &self.tls_acceptor {
|
||||
info!("Upgrading connection to TLS for {}", addr);
|
||||
let tls_stream = tls_acceptor.accept(stream).await
|
||||
.map_err(|e| GurtError::crypto(format!("TLS upgrade failed: {}", e)))?;
|
||||
|
||||
info!("TLS upgrade completed for {}", addr);
|
||||
|
||||
self.handle_tls_connection(tls_stream, addr).await
|
||||
} else {
|
||||
warn!("No TLS configuration available, but handshake completed - this violates GURT protocol");
|
||||
Err(GurtError::protocol("TLS is required after handshake but no TLS configuration available"))
|
||||
}
|
||||
}).await;
|
||||
|
||||
if let Some(tls_acceptor) = &self.tls_acceptor {
|
||||
info!("Upgrading connection to TLS for {}", addr);
|
||||
let tls_stream = tls_acceptor.accept(stream).await
|
||||
.map_err(|e| GurtError::crypto(format!("TLS upgrade failed: {}", e)))?;
|
||||
|
||||
info!("TLS upgrade completed for {}", addr);
|
||||
|
||||
self.handle_tls_connection(tls_stream, addr).await
|
||||
} else {
|
||||
warn!("No TLS configuration available, but handshake completed - this violates GURT protocol");
|
||||
Err(GurtError::protocol("TLS is required after handshake but no TLS configuration available"))
|
||||
match connection_result {
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
warn!("Connection timeout for {}", addr);
|
||||
Err(GurtError::timeout("Connection timeout"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_initial_handshake(&self, stream: &mut TcpStream, addr: SocketAddr) -> Result<()> {
|
||||
let mut buffer = Vec::new();
|
||||
let mut temp_buffer = [0u8; 8192];
|
||||
|
||||
loop {
|
||||
let bytes_read = stream.read(&mut temp_buffer).await?;
|
||||
if bytes_read == 0 {
|
||||
return Err(GurtError::connection("Connection closed during handshake"));
|
||||
}
|
||||
let handshake_result = timeout(self.handshake_timeout, async {
|
||||
let mut buffer = Vec::new();
|
||||
let mut temp_buffer = [0u8; 8192];
|
||||
|
||||
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
|
||||
|
||||
let body_separator = BODY_SEPARATOR.as_bytes();
|
||||
if buffer.windows(body_separator.len()).any(|w| w == body_separator) {
|
||||
break;
|
||||
}
|
||||
|
||||
if buffer.len() > MAX_MESSAGE_SIZE {
|
||||
return Err(GurtError::protocol("Handshake message too large"));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let message = GurtMessage::parse_bytes(&buffer)?;
|
||||
|
||||
match message {
|
||||
GurtMessage::Request(request) => {
|
||||
if request.method == GurtMethod::HANDSHAKE {
|
||||
self.send_handshake_response(stream, addr, &request).await
|
||||
} else {
|
||||
Err(GurtError::protocol("First message must be HANDSHAKE"))
|
||||
loop {
|
||||
let bytes_read = stream.read(&mut temp_buffer).await?;
|
||||
if bytes_read == 0 {
|
||||
return Err(GurtError::connection("Connection closed during handshake"));
|
||||
}
|
||||
|
||||
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
|
||||
|
||||
let body_separator = BODY_SEPARATOR.as_bytes();
|
||||
if buffer.windows(body_separator.len()).any(|w| w == body_separator) {
|
||||
break;
|
||||
}
|
||||
|
||||
if buffer.len() > MAX_MESSAGE_SIZE {
|
||||
return Err(GurtError::protocol("Handshake message too large"));
|
||||
}
|
||||
}
|
||||
GurtMessage::Response(_) => {
|
||||
Err(GurtError::protocol("Server received response during handshake"))
|
||||
|
||||
let message = GurtMessage::parse_bytes(&buffer)?;
|
||||
|
||||
match message {
|
||||
GurtMessage::Request(request) => {
|
||||
if request.method == GurtMethod::HANDSHAKE {
|
||||
self.send_handshake_response(stream, addr, &request).await
|
||||
} else {
|
||||
Err(GurtError::protocol("First message must be HANDSHAKE"))
|
||||
}
|
||||
}
|
||||
GurtMessage::Response(_) => {
|
||||
Err(GurtError::protocol("Server received response during handshake"))
|
||||
}
|
||||
}
|
||||
}).await;
|
||||
|
||||
match handshake_result {
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
warn!("Handshake timeout for {}", addr);
|
||||
Err(GurtError::timeout("Handshake timeout"))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -342,7 +375,6 @@ impl GurtServer {
|
||||
let bytes_read = match tls_stream.read(&mut temp_buffer).await {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
// Handle UnexpectedEof from clients that don't send close_notify
|
||||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||||
debug!("Client {} closed connection without TLS close_notify (benign)", addr);
|
||||
break;
|
||||
@@ -351,7 +383,7 @@ impl GurtServer {
|
||||
}
|
||||
};
|
||||
if bytes_read == 0 {
|
||||
break; // Connection closed
|
||||
break;
|
||||
}
|
||||
|
||||
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
|
||||
@@ -361,17 +393,31 @@ impl GurtServer {
|
||||
(buffer.starts_with(b"{") && buffer.ends_with(b"}"));
|
||||
|
||||
if has_complete_message {
|
||||
if let Err(e) = self.process_tls_message(&mut tls_stream, addr, &buffer).await {
|
||||
error!("Encrypted message processing error from {}: {}", addr, e);
|
||||
let error_response = GurtResponse::internal_server_error()
|
||||
.with_string_body("Internal server error");
|
||||
let _ = tls_stream.write_all(&error_response.to_bytes()).await;
|
||||
let process_result = timeout(self.request_timeout,
|
||||
self.process_tls_message(&mut tls_stream, addr, &buffer)
|
||||
).await;
|
||||
|
||||
match process_result {
|
||||
Ok(Ok(())) => {
|
||||
debug!("Processed message from {} successfully", addr);
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
error!("Encrypted message processing error from {}: {}", addr, e);
|
||||
let error_response = GurtResponse::internal_server_error()
|
||||
.with_string_body("Internal server error");
|
||||
let _ = tls_stream.write_all(&error_response.to_bytes()).await;
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Request timeout for {}", addr);
|
||||
let timeout_response = GurtResponse::new(GurtStatusCode::Timeout)
|
||||
.with_string_body("Request timeout");
|
||||
let _ = tls_stream.write_all(&timeout_response.to_bytes()).await;
|
||||
}
|
||||
}
|
||||
|
||||
buffer.clear();
|
||||
}
|
||||
|
||||
// Prevent buffer overflow
|
||||
if buffer.len() > MAX_MESSAGE_SIZE {
|
||||
warn!("Message too large from {}, closing connection", addr);
|
||||
break;
|
||||
@@ -422,7 +468,6 @@ impl GurtServer {
|
||||
if let Some(method) = &route.method {
|
||||
allowed_methods.insert(method.to_string());
|
||||
} else {
|
||||
// Route matches any method
|
||||
allowed_methods.extend(vec![
|
||||
"GET".to_string(), "POST".to_string(), "PUT".to_string(),
|
||||
"DELETE".to_string(), "HEAD".to_string(), "PATCH".to_string()
|
||||
@@ -482,7 +527,6 @@ impl GurtServer {
|
||||
async fn handle_encrypted_request(&self, tls_stream: &mut TlsStream<TcpStream>, addr: SocketAddr, request: &GurtRequest) -> Result<()> {
|
||||
debug!("Handling encrypted {} request to {} from {}", request.method, request.path, addr);
|
||||
|
||||
// Find matching route
|
||||
for (route, handler) in &self.routes {
|
||||
if route.matches(&request.method, &request.path) {
|
||||
let context = ServerContext {
|
||||
@@ -492,7 +536,6 @@ impl GurtServer {
|
||||
|
||||
match handler.handle(&context).await {
|
||||
Ok(response) => {
|
||||
// Use to_bytes() to avoid corrupting binary data
|
||||
let response_bytes = response.to_bytes();
|
||||
tls_stream.write_all(&response_bytes).await?;
|
||||
return Ok(());
|
||||
@@ -508,7 +551,6 @@ impl GurtServer {
|
||||
}
|
||||
}
|
||||
|
||||
// No route found - check for default OPTIONS/HEAD handling
|
||||
match request.method {
|
||||
GurtMethod::OPTIONS => {
|
||||
self.handle_default_options(tls_stream, request).await
|
||||
@@ -531,6 +573,9 @@ impl Clone for GurtServer {
|
||||
Self {
|
||||
routes: self.routes.clone(),
|
||||
tls_acceptor: self.tls_acceptor.clone(),
|
||||
handshake_timeout: self.handshake_timeout,
|
||||
request_timeout: self.request_timeout,
|
||||
connection_timeout: self.connection_timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user