GURT protocol (lib, cli, gdextension, Flumi integration)

This commit is contained in:
Face
2025-08-14 20:29:19 +03:00
parent 65f3a21890
commit c117e602fe
46 changed files with 6559 additions and 89 deletions

View File

@@ -0,0 +1,302 @@
use crate::{
GurtError, Result, GurtRequest, GurtResponse,
protocol::{DEFAULT_PORT, DEFAULT_CONNECTION_TIMEOUT, DEFAULT_REQUEST_TIMEOUT, DEFAULT_HANDSHAKE_TIMEOUT, BODY_SEPARATOR},
message::GurtMethod,
};
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::{timeout, Duration};
use url::Url;
use tracing::debug;
#[derive(Debug, Clone)]
pub struct ClientConfig {
pub connect_timeout: Duration,
pub request_timeout: Duration,
pub handshake_timeout: Duration,
pub user_agent: String,
pub max_redirects: usize,
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
connect_timeout: Duration::from_secs(DEFAULT_CONNECTION_TIMEOUT),
request_timeout: Duration::from_secs(DEFAULT_REQUEST_TIMEOUT),
handshake_timeout: Duration::from_secs(DEFAULT_HANDSHAKE_TIMEOUT),
user_agent: format!("GURT-Client/{}", crate::GURT_VERSION),
max_redirects: 5,
}
}
}
#[derive(Debug)]
struct PooledConnection {
stream: TcpStream,
}
impl PooledConnection {
fn new(stream: TcpStream) -> Self {
Self { stream }
}
}
pub struct GurtClient {
config: ClientConfig,
}
impl GurtClient {
pub fn new() -> Self {
Self {
config: ClientConfig::default(),
}
}
pub fn with_config(config: ClientConfig) -> Self {
Self {
config,
}
}
async fn create_connection(&self, host: &str, port: u16) -> Result<PooledConnection> {
let addr = format!("{}:{}", host, port);
let stream = timeout(
self.config.connect_timeout,
TcpStream::connect(&addr)
).await
.map_err(|_| GurtError::timeout("Connection timeout"))?
.map_err(|e| GurtError::connection(format!("Failed to connect: {}", e)))?;
let conn = PooledConnection::new(stream);
Ok(conn)
}
async fn read_response_data(&self, stream: &mut TcpStream) -> Result<Vec<u8>> {
let mut buffer = Vec::new();
let mut temp_buffer = [0u8; 8192];
let start_time = std::time::Instant::now();
loop {
if start_time.elapsed() > self.config.request_timeout {
return Err(GurtError::timeout("Response timeout"));
}
let bytes_read = stream.read(&mut temp_buffer).await?;
if bytes_read == 0 {
break; // Connection closed
}
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
// Check for complete message without converting to string
let body_separator = BODY_SEPARATOR.as_bytes();
let has_complete_response = buffer.windows(body_separator.len()).any(|w| w == body_separator) ||
(buffer.starts_with(b"{") && buffer.ends_with(b"}"));
if has_complete_response {
return Ok(buffer);
}
}
if buffer.is_empty() {
Err(GurtError::connection("Connection closed unexpectedly"))
} else {
Ok(buffer)
}
}
async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest) -> Result<GurtResponse> {
debug!("Sending {} {} to {}:{}", request.method, request.path, host, port);
let mut conn = self.create_connection(host, port).await?;
let request_data = request.to_string();
conn.stream.write_all(request_data.as_bytes()).await?;
let response_bytes = timeout(
self.config.request_timeout,
self.read_response_data(&mut conn.stream)
).await
.map_err(|_| GurtError::timeout("Request timeout"))??;
let response = GurtResponse::parse_bytes(&response_bytes)?;
Ok(response)
}
pub async fn get(&self, url: &str) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let request = GurtRequest::new(GurtMethod::GET, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent)
.with_header("Accept", "*/*");
self.send_request_internal(&host, port, request).await
}
pub async fn post(&self, url: &str, body: &str) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let request = GurtRequest::new(GurtMethod::POST, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent)
.with_header("Content-Type", "text/plain")
.with_string_body(body);
self.send_request_internal(&host, port, request).await
}
/// POST request with JSON body
pub async fn post_json<T: serde::Serialize>(&self, url: &str, data: &T) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let json_body = serde_json::to_string(data)?;
let request = GurtRequest::new(GurtMethod::POST, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent)
.with_header("Content-Type", "application/json")
.with_string_body(json_body);
self.send_request_internal(&host, port, request).await
}
/// PUT request with body
pub async fn put(&self, url: &str, body: &str) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let request = GurtRequest::new(GurtMethod::PUT, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent)
.with_header("Content-Type", "text/plain")
.with_string_body(body);
self.send_request_internal(&host, port, request).await
}
/// PUT request with JSON body
pub async fn put_json<T: serde::Serialize>(&self, url: &str, data: &T) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let json_body = serde_json::to_string(data)?;
let request = GurtRequest::new(GurtMethod::PUT, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent)
.with_header("Content-Type", "application/json")
.with_string_body(json_body);
self.send_request_internal(&host, port, request).await
}
pub async fn delete(&self, url: &str) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let request = GurtRequest::new(GurtMethod::DELETE, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent);
self.send_request_internal(&host, port, request).await
}
pub async fn head(&self, url: &str) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let request = GurtRequest::new(GurtMethod::HEAD, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent);
self.send_request_internal(&host, port, request).await
}
pub async fn options(&self, url: &str) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let request = GurtRequest::new(GurtMethod::OPTIONS, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent);
self.send_request_internal(&host, port, request).await
}
/// PATCH request with body
pub async fn patch(&self, url: &str, body: &str) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let request = GurtRequest::new(GurtMethod::PATCH, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent)
.with_header("Content-Type", "text/plain")
.with_string_body(body);
self.send_request_internal(&host, port, request).await
}
/// PATCH request with JSON body
pub async fn patch_json<T: serde::Serialize>(&self, url: &str, data: &T) -> Result<GurtResponse> {
let (host, port, path) = self.parse_url(url)?;
let json_body = serde_json::to_string(data)?;
let request = GurtRequest::new(GurtMethod::PATCH, path)
.with_header("Host", &host)
.with_header("User-Agent", &self.config.user_agent)
.with_header("Content-Type", "application/json")
.with_string_body(json_body);
self.send_request_internal(&host, port, request).await
}
pub async fn send_request(&self, host: &str, port: u16, request: GurtRequest) -> Result<GurtResponse> {
self.send_request_internal(host, port, request).await
}
fn parse_url(&self, url: &str) -> Result<(String, u16, String)> {
let parsed_url = Url::parse(url).map_err(|e| GurtError::invalid_message(format!("Invalid URL: {}", e)))?;
if parsed_url.scheme() != "gurt" {
return Err(GurtError::invalid_message("URL must use gurt:// scheme"));
}
let host = parsed_url.host_str()
.ok_or_else(|| GurtError::invalid_message("URL must have a host"))?
.to_string();
let port = parsed_url.port().unwrap_or(DEFAULT_PORT);
let path = if parsed_url.path().is_empty() {
"/".to_string()
} else {
parsed_url.path().to_string()
};
Ok((host, port, path))
}
}
impl Default for GurtClient {
fn default() -> Self {
Self::new()
}
}
impl Clone for GurtClient {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_url_parsing() {
let client = GurtClient::new();
let (host, port, path) = client.parse_url("gurt://example.com/test").unwrap();
assert_eq!(host, "example.com");
assert_eq!(port, DEFAULT_PORT);
assert_eq!(path, "/test");
let (host, port, path) = client.parse_url("gurt://example.com:8080/api/v1").unwrap();
assert_eq!(host, "example.com");
assert_eq!(port, 8080);
assert_eq!(path, "/api/v1");
}
}

View File

@@ -0,0 +1,123 @@
use crate::{GurtError, Result};
use rustls::{ClientConfig, ServerConfig};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio_rustls::{TlsConnector, TlsAcceptor};
use std::sync::Arc;
pub const TLS_VERSION: &str = "TLS/1.3";
pub const GURT_ALPN: &[u8] = b"GURT/1.0";
#[derive(Debug, Clone)]
pub struct TlsConfig {
pub client_config: Option<Arc<ClientConfig>>,
pub server_config: Option<Arc<ServerConfig>>,
}
impl TlsConfig {
pub fn new_client() -> Result<Self> {
let mut config = ClientConfig::builder()
.with_root_certificates(rustls::RootCertStore::empty())
.with_no_client_auth();
config.alpn_protocols = vec![GURT_ALPN.to_vec()];
Ok(Self {
client_config: Some(Arc::new(config)),
server_config: None,
})
}
pub fn new_server(cert_chain: Vec<CertificateDer<'static>>, private_key: PrivateKeyDer<'static>) -> Result<Self> {
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, private_key)
.map_err(|e| GurtError::crypto(format!("TLS server config error: {}", e)))?;
config.alpn_protocols = vec![GURT_ALPN.to_vec()];
Ok(Self {
client_config: None,
server_config: Some(Arc::new(config)),
})
}
pub fn get_connector(&self) -> Result<TlsConnector> {
let config = self.client_config.as_ref()
.ok_or_else(|| GurtError::crypto("No client config available"))?;
Ok(TlsConnector::from(config.clone()))
}
pub fn get_acceptor(&self) -> Result<TlsAcceptor> {
let config = self.server_config.as_ref()
.ok_or_else(|| GurtError::crypto("No server config available"))?;
Ok(TlsAcceptor::from(config.clone()))
}
}
#[derive(Debug)]
pub struct CryptoManager {
tls_config: Option<TlsConfig>,
}
impl CryptoManager {
pub fn new() -> Self {
Self {
tls_config: None,
}
}
pub fn with_tls_config(config: TlsConfig) -> Self {
Self {
tls_config: Some(config),
}
}
pub fn set_tls_config(&mut self, config: TlsConfig) {
self.tls_config = Some(config);
}
pub fn has_tls_config(&self) -> bool {
self.tls_config.is_some()
}
pub fn get_tls_connector(&self) -> Result<TlsConnector> {
let config = self.tls_config.as_ref()
.ok_or_else(|| GurtError::crypto("No TLS config available"))?;
config.get_connector()
}
pub fn get_tls_acceptor(&self) -> Result<TlsAcceptor> {
let config = self.tls_config.as_ref()
.ok_or_else(|| GurtError::crypto("No TLS config available"))?;
config.get_acceptor()
}
}
impl Default for CryptoManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tls_config_creation() {
let client_config = TlsConfig::new_client();
assert!(client_config.is_ok());
let config = client_config.unwrap();
assert!(config.client_config.is_some());
assert!(config.server_config.is_none());
}
#[test]
fn test_crypto_manager() {
let crypto = CryptoManager::new();
assert!(!crypto.has_tls_config());
}
}

View File

@@ -0,0 +1,71 @@
use std::fmt;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum GurtError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Serialization error: {0}")]
Serialization(#[from] serde_json::Error),
#[error("Cryptographic error: {0}")]
Crypto(String),
#[error("Protocol error: {0}")]
Protocol(String),
#[error("Invalid message format: {0}")]
InvalidMessage(String),
#[error("Connection error: {0}")]
Connection(String),
#[error("Handshake failed: {0}")]
Handshake(String),
#[error("Timeout error: {0}")]
Timeout(String),
#[error("Server error: {status} {message}")]
Server { status: u16, message: String },
#[error("Client error: {0}")]
Client(String),
}
pub type Result<T> = std::result::Result<T, GurtError>;
impl GurtError {
pub fn crypto<T: fmt::Display>(msg: T) -> Self {
GurtError::Crypto(msg.to_string())
}
pub fn protocol<T: fmt::Display>(msg: T) -> Self {
GurtError::Protocol(msg.to_string())
}
pub fn invalid_message<T: fmt::Display>(msg: T) -> Self {
GurtError::InvalidMessage(msg.to_string())
}
pub fn connection<T: fmt::Display>(msg: T) -> Self {
GurtError::Connection(msg.to_string())
}
pub fn handshake<T: fmt::Display>(msg: T) -> Self {
GurtError::Handshake(msg.to_string())
}
pub fn timeout<T: fmt::Display>(msg: T) -> Self {
GurtError::Timeout(msg.to_string())
}
pub fn server(status: u16, message: String) -> Self {
GurtError::Server { status, message }
}
pub fn client<T: fmt::Display>(msg: T) -> Self {
GurtError::Client(msg.to_string())
}
}

View File

@@ -0,0 +1,24 @@
pub mod protocol;
pub mod crypto;
pub mod server;
pub mod client;
pub mod error;
pub mod message;
pub use error::{GurtError, Result};
pub use message::{GurtMessage, GurtRequest, GurtResponse, GurtMethod};
pub use protocol::{GurtStatusCode, GURT_VERSION, DEFAULT_PORT};
pub use crypto::{CryptoManager, TlsConfig, GURT_ALPN, TLS_VERSION};
pub use server::{GurtServer, GurtHandler, ServerContext, Route};
pub use client::{GurtClient, ClientConfig};
pub mod prelude {
pub use crate::{
GurtError, Result,
GurtMessage, GurtRequest, GurtResponse,
GURT_VERSION, DEFAULT_PORT,
CryptoManager, TlsConfig, GURT_ALPN, TLS_VERSION,
GurtServer, GurtHandler, ServerContext, Route,
GurtClient, ClientConfig,
};
}

View File

@@ -0,0 +1,568 @@
use crate::{GurtError, Result, GURT_VERSION};
use crate::protocol::{GurtStatusCode, PROTOCOL_PREFIX, HEADER_SEPARATOR, BODY_SEPARATOR};
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use std::fmt;
use chrono::Utc;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GurtMethod {
GET,
POST,
PUT,
DELETE,
HEAD,
OPTIONS,
PATCH,
HANDSHAKE, // Special method for protocol handshake
}
impl GurtMethod {
pub fn parse(s: &str) -> Result<Self> {
match s.to_uppercase().as_str() {
"GET" => Ok(Self::GET),
"POST" => Ok(Self::POST),
"PUT" => Ok(Self::PUT),
"DELETE" => Ok(Self::DELETE),
"HEAD" => Ok(Self::HEAD),
"OPTIONS" => Ok(Self::OPTIONS),
"PATCH" => Ok(Self::PATCH),
"HANDSHAKE" => Ok(Self::HANDSHAKE),
_ => Err(GurtError::invalid_message(format!("Unsupported method: {}", s))),
}
}
}
impl fmt::Display for GurtMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Self::GET => "GET",
Self::POST => "POST",
Self::PUT => "PUT",
Self::DELETE => "DELETE",
Self::HEAD => "HEAD",
Self::OPTIONS => "OPTIONS",
Self::PATCH => "PATCH",
Self::HANDSHAKE => "HANDSHAKE",
};
write!(f, "{}", s)
}
}
pub type GurtHeaders = HashMap<String, String>;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GurtRequest {
pub method: GurtMethod,
pub path: String,
pub version: String,
pub headers: GurtHeaders,
pub body: Vec<u8>,
}
impl GurtRequest {
pub fn new(method: GurtMethod, path: String) -> Self {
Self {
method,
path,
version: GURT_VERSION.to_string(),
headers: GurtHeaders::new(),
body: Vec::new(),
}
}
pub fn with_header<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
self.headers.insert(key.into().to_lowercase(), value.into());
self
}
pub fn with_body<B: Into<Vec<u8>>>(mut self, body: B) -> Self {
self.body = body.into();
self
}
pub fn with_string_body<S: AsRef<str>>(mut self, body: S) -> Self {
self.body = body.as_ref().as_bytes().to_vec();
self
}
pub fn header(&self, key: &str) -> Option<&String> {
self.headers.get(&key.to_lowercase())
}
pub fn body_as_string(&self) -> Result<String> {
std::str::from_utf8(&self.body)
.map(|s| s.to_string())
.map_err(|e| GurtError::invalid_message(format!("Invalid UTF-8 body: {}", e)))
}
pub fn parse(data: &str) -> Result<Self> {
Self::parse_bytes(data.as_bytes())
}
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
// Find the header/body separator as bytes
let body_separator = BODY_SEPARATOR.as_bytes();
let body_separator_pos = data.windows(body_separator.len())
.position(|window| window == body_separator);
let (headers_section, body) = if let Some(pos) = body_separator_pos {
let headers_part = &data[..pos];
let body_part = &data[pos + body_separator.len()..];
(headers_part, body_part.to_vec())
} else {
(data, Vec::new())
};
// Convert headers section to string (should be valid UTF-8)
let headers_str = std::str::from_utf8(headers_section)
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?;
let lines: Vec<&str> = headers_str.split(HEADER_SEPARATOR).collect();
if lines.is_empty() {
return Err(GurtError::invalid_message("Empty request"));
}
// Parse request line (METHOD path GURT/version)
let request_line = lines[0];
let parts: Vec<&str> = request_line.split_whitespace().collect();
if parts.len() != 3 {
return Err(GurtError::invalid_message("Invalid request line format"));
}
let method = GurtMethod::parse(parts[0])?;
let path = parts[1].to_string();
// Parse protocol version
if !parts[2].starts_with(PROTOCOL_PREFIX) {
return Err(GurtError::invalid_message("Invalid protocol identifier"));
}
let version_str = &parts[2][PROTOCOL_PREFIX.len()..];
let version = version_str.to_string();
// Parse headers
let mut headers = GurtHeaders::new();
for line in lines.iter().skip(1) {
if line.is_empty() {
break;
}
if let Some(colon_pos) = line.find(':') {
let key = line[..colon_pos].trim().to_lowercase();
let value = line[colon_pos + 1..].trim().to_string();
headers.insert(key, value);
}
}
Ok(Self {
method,
path,
version,
headers,
body,
})
}
pub fn to_string(&self) -> String {
let mut message = format!("{} {} {}{}{}",
self.method, self.path, PROTOCOL_PREFIX, self.version, HEADER_SEPARATOR);
let mut headers = self.headers.clone();
if !headers.contains_key("content-length") {
headers.insert("content-length".to_string(), self.body.len().to_string());
}
if !headers.contains_key("user-agent") {
headers.insert("user-agent".to_string(), format!("GURT-Client/{}", GURT_VERSION));
}
for (key, value) in &headers {
message.push_str(&format!("{}: {}{}", key, value, HEADER_SEPARATOR));
}
message.push_str(HEADER_SEPARATOR);
if !self.body.is_empty() {
if let Ok(body_str) = std::str::from_utf8(&self.body) {
message.push_str(body_str);
}
}
message
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut message = format!("{} {} {}{}{}",
self.method, self.path, PROTOCOL_PREFIX, self.version, HEADER_SEPARATOR);
let mut headers = self.headers.clone();
if !headers.contains_key("content-length") {
headers.insert("content-length".to_string(), self.body.len().to_string());
}
if !headers.contains_key("user-agent") {
headers.insert("user-agent".to_string(), format!("GURT-Client/{}", GURT_VERSION));
}
for (key, value) in &headers {
message.push_str(&format!("{}: {}{}", key, value, HEADER_SEPARATOR));
}
message.push_str(HEADER_SEPARATOR);
let mut bytes = message.into_bytes();
bytes.extend_from_slice(&self.body);
bytes
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GurtResponse {
pub version: String,
pub status_code: u16,
pub status_message: String,
pub headers: GurtHeaders,
pub body: Vec<u8>,
}
impl GurtResponse {
pub fn new(status_code: GurtStatusCode) -> Self {
Self {
version: GURT_VERSION.to_string(),
status_code: status_code as u16,
status_message: status_code.message().to_string(),
headers: GurtHeaders::new(),
body: Vec::new(),
}
}
pub fn ok() -> Self {
Self::new(GurtStatusCode::Ok)
}
pub fn not_found() -> Self {
Self::new(GurtStatusCode::NotFound)
}
pub fn bad_request() -> Self {
Self::new(GurtStatusCode::BadRequest)
}
pub fn internal_server_error() -> Self {
Self::new(GurtStatusCode::InternalServerError)
}
pub fn with_header<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
self.headers.insert(key.into().to_lowercase(), value.into());
self
}
pub fn with_body<B: Into<Vec<u8>>>(mut self, body: B) -> Self {
self.body = body.into();
self
}
pub fn with_string_body<S: AsRef<str>>(mut self, body: S) -> Self {
self.body = body.as_ref().as_bytes().to_vec();
self
}
pub fn with_json_body<T: Serialize>(mut self, data: &T) -> Result<Self> {
let json = serde_json::to_string(data)?;
self.body = json.into_bytes();
self.headers.insert("content-type".to_string(), "application/json".to_string());
Ok(self)
}
pub fn header(&self, key: &str) -> Option<&String> {
self.headers.get(&key.to_lowercase())
}
pub fn body_as_string(&self) -> Result<String> {
std::str::from_utf8(&self.body)
.map(|s| s.to_owned())
.map_err(|e| GurtError::invalid_message(format!("Invalid UTF-8 body: {}", e)))
}
pub fn is_success(&self) -> bool {
self.status_code >= 200 && self.status_code < 300
}
pub fn is_client_error(&self) -> bool {
self.status_code >= 400 && self.status_code < 500
}
pub fn is_server_error(&self) -> bool {
self.status_code >= 500
}
pub fn parse(data: &str) -> Result<Self> {
Self::parse_bytes(data.as_bytes())
}
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
// Find the header/body separator as bytes
let body_separator = BODY_SEPARATOR.as_bytes();
let body_separator_pos = data.windows(body_separator.len())
.position(|window| window == body_separator);
let (headers_section, body) = if let Some(pos) = body_separator_pos {
let headers_part = &data[..pos];
let body_part = &data[pos + body_separator.len()..];
(headers_part, body_part.to_vec())
} else {
(data, Vec::new())
};
// Convert headers section to string (should be valid UTF-8)
let headers_str = std::str::from_utf8(headers_section)
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?;
let lines: Vec<&str> = headers_str.split(HEADER_SEPARATOR).collect();
if lines.is_empty() {
return Err(GurtError::invalid_message("Empty response"));
}
// Parse status line (GURT/version status_code status_message)
let status_line = lines[0];
let parts: Vec<&str> = status_line.splitn(3, ' ').collect();
if parts.len() < 2 {
return Err(GurtError::invalid_message("Invalid status line format"));
}
// Parse protocol version
if !parts[0].starts_with(PROTOCOL_PREFIX) {
return Err(GurtError::invalid_message("Invalid protocol identifier"));
}
let version_str = &parts[0][PROTOCOL_PREFIX.len()..];
let version = version_str.to_string();
let status_code: u16 = parts[1].parse()
.map_err(|_| GurtError::invalid_message("Invalid status code"))?;
let status_message = if parts.len() > 2 {
parts[2].to_string()
} else {
GurtStatusCode::from_u16(status_code)
.map(|sc| sc.message().to_string())
.unwrap_or_else(|| "Unknown".to_string())
};
// Parse headers
let mut headers = GurtHeaders::new();
for line in lines.iter().skip(1) {
if line.is_empty() {
break;
}
if let Some(colon_pos) = line.find(':') {
let key = line[..colon_pos].trim().to_lowercase();
let value = line[colon_pos + 1..].trim().to_string();
headers.insert(key, value);
}
}
Ok(Self {
version,
status_code,
status_message,
headers,
body,
})
}
pub fn to_string(&self) -> String {
let mut message = format!("{}{} {} {}{}",
PROTOCOL_PREFIX, self.version, self.status_code, self.status_message, HEADER_SEPARATOR);
let mut headers = self.headers.clone();
if !headers.contains_key("content-length") {
headers.insert("content-length".to_string(), self.body.len().to_string());
}
if !headers.contains_key("server") {
headers.insert("server".to_string(), format!("GURT/{}", GURT_VERSION));
}
if !headers.contains_key("date") {
// RFC 7231 compliant
let now = Utc::now();
let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
headers.insert("date".to_string(), date_str);
}
for (key, value) in &headers {
message.push_str(&format!("{}: {}{}", key, value, HEADER_SEPARATOR));
}
message.push_str(HEADER_SEPARATOR);
if !self.body.is_empty() {
if let Ok(body_str) = std::str::from_utf8(&self.body) {
message.push_str(body_str);
}
}
message
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut message = format!("{}{} {} {}{}",
PROTOCOL_PREFIX, self.version, self.status_code, self.status_message, HEADER_SEPARATOR);
let mut headers = self.headers.clone();
if !headers.contains_key("content-length") {
headers.insert("content-length".to_string(), self.body.len().to_string());
}
if !headers.contains_key("server") {
headers.insert("server".to_string(), format!("GURT/{}", GURT_VERSION));
}
if !headers.contains_key("date") {
// RFC 7231 compliant
let now = Utc::now();
let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
headers.insert("date".to_string(), date_str);
}
for (key, value) in &headers {
message.push_str(&format!("{}: {}{}", key, value, HEADER_SEPARATOR));
}
message.push_str(HEADER_SEPARATOR);
// Convert headers to bytes and append body as raw bytes
let mut bytes = message.into_bytes();
bytes.extend_from_slice(&self.body);
bytes
}
}
#[derive(Debug, Clone)]
pub enum GurtMessage {
Request(GurtRequest),
Response(GurtResponse),
}
impl GurtMessage {
pub fn parse(data: &str) -> Result<Self> {
Self::parse_bytes(data.as_bytes())
}
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
// Convert first line to string to determine message type
let header_separator = HEADER_SEPARATOR.as_bytes();
let first_line_end = data.windows(header_separator.len())
.position(|window| window == header_separator)
.unwrap_or(data.len());
let first_line = std::str::from_utf8(&data[..first_line_end])
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in first line"))?;
// Check if it's a response (starts with GURT/version) or request (method first)
if first_line.starts_with(PROTOCOL_PREFIX) {
Ok(GurtMessage::Response(GurtResponse::parse_bytes(data)?))
} else {
Ok(GurtMessage::Request(GurtRequest::parse_bytes(data)?))
}
}
pub fn is_request(&self) -> bool {
matches!(self, GurtMessage::Request(_))
}
pub fn is_response(&self) -> bool {
matches!(self, GurtMessage::Response(_))
}
pub fn as_request(&self) -> Option<&GurtRequest> {
match self {
GurtMessage::Request(req) => Some(req),
_ => None,
}
}
pub fn as_response(&self) -> Option<&GurtResponse> {
match self {
GurtMessage::Response(res) => Some(res),
_ => None,
}
}
}
impl fmt::Display for GurtMessage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
GurtMessage::Request(req) => write!(f, "{}", req.to_string()),
GurtMessage::Response(res) => write!(f, "{}", res.to_string()),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_request_parsing() {
let raw = "GET /test GURT/1.0.0\r\nHost: example.com\r\nAccept: text/html\r\n\r\ntest body";
let request = GurtRequest::parse(raw).expect("Failed to parse request");
assert_eq!(request.method, GurtMethod::GET);
assert_eq!(request.path, "/test");
assert_eq!(request.version, GURT_VERSION.to_string());
assert_eq!(request.header("host"), Some(&"example.com".to_string()));
assert_eq!(request.header("accept"), Some(&"text/html".to_string()));
assert_eq!(request.body_as_string().unwrap(), "test body");
}
#[test]
fn test_response_parsing() {
let raw = "GURT/1.0.0 200 OK\r\nContent-Type: text/html\r\n\r\n<html></html>";
let response = GurtResponse::parse(raw).expect("Failed to parse response");
assert_eq!(response.version, GURT_VERSION.to_string());
assert_eq!(response.status_code, 200);
assert_eq!(response.status_message, "OK");
assert_eq!(response.header("content-type"), Some(&"text/html".to_string()));
assert_eq!(response.body_as_string().unwrap(), "<html></html>");
}
#[test]
fn test_request_building() {
let request = GurtRequest::new(GurtMethod::GET, "/test".to_string())
.with_header("Host", "example.com")
.with_string_body("test body");
let raw = request.to_string();
let parsed = GurtRequest::parse(&raw).expect("Failed to parse built request");
assert_eq!(parsed.method, request.method);
assert_eq!(parsed.path, request.path);
assert_eq!(parsed.body, request.body);
}
#[test]
fn test_response_building() {
let response = GurtResponse::ok()
.with_header("Content-Type", "text/html")
.with_string_body("<html></html>");
let raw = response.to_string();
let parsed = GurtResponse::parse(&raw).expect("Failed to parse built response");
assert_eq!(parsed.status_code, response.status_code);
assert_eq!(parsed.body, response.body);
}
}

View File

@@ -0,0 +1,120 @@
use std::fmt;
pub const GURT_VERSION: &str = "1.0.0";
pub const DEFAULT_PORT: u16 = 4878;
pub const PROTOCOL_PREFIX: &str = "GURT/";
pub const HEADER_SEPARATOR: &str = "\r\n";
pub const BODY_SEPARATOR: &str = "\r\n\r\n";
pub const DEFAULT_HANDSHAKE_TIMEOUT: u64 = 5;
pub const DEFAULT_REQUEST_TIMEOUT: u64 = 30;
pub const DEFAULT_CONNECTION_TIMEOUT: u64 = 10;
pub const MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024;
pub const MAX_POOL_SIZE: usize = 10;
pub const POOL_IDLE_TIMEOUT: u64 = 300;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GurtStatusCode {
// Success
Ok = 200,
Created = 201,
Accepted = 202,
NoContent = 204,
// Handshake
SwitchingProtocols = 101,
// Client errors
BadRequest = 400,
Unauthorized = 401,
Forbidden = 403,
NotFound = 404,
MethodNotAllowed = 405,
Timeout = 408,
TooLarge = 413,
UnsupportedMediaType = 415,
// Server errors
InternalServerError = 500,
NotImplemented = 501,
BadGateway = 502,
ServiceUnavailable = 503,
GatewayTimeout = 504,
}
impl GurtStatusCode {
pub fn from_u16(code: u16) -> Option<Self> {
match code {
200 => Some(Self::Ok),
201 => Some(Self::Created),
202 => Some(Self::Accepted),
204 => Some(Self::NoContent),
101 => Some(Self::SwitchingProtocols),
400 => Some(Self::BadRequest),
401 => Some(Self::Unauthorized),
403 => Some(Self::Forbidden),
404 => Some(Self::NotFound),
405 => Some(Self::MethodNotAllowed),
408 => Some(Self::Timeout),
413 => Some(Self::TooLarge),
415 => Some(Self::UnsupportedMediaType),
500 => Some(Self::InternalServerError),
501 => Some(Self::NotImplemented),
502 => Some(Self::BadGateway),
503 => Some(Self::ServiceUnavailable),
504 => Some(Self::GatewayTimeout),
_ => None,
}
}
pub fn message(&self) -> &'static str {
match self {
Self::Ok => "OK",
Self::Created => "CREATED",
Self::Accepted => "ACCEPTED",
Self::NoContent => "NO_CONTENT",
Self::SwitchingProtocols => "SWITCHING_PROTOCOLS",
Self::BadRequest => "BAD_REQUEST",
Self::Unauthorized => "UNAUTHORIZED",
Self::Forbidden => "FORBIDDEN",
Self::NotFound => "NOT_FOUND",
Self::MethodNotAllowed => "METHOD_NOT_ALLOWED",
Self::Timeout => "TIMEOUT",
Self::TooLarge => "TOO_LARGE",
Self::UnsupportedMediaType => "UNSUPPORTED_MEDIA_TYPE",
Self::InternalServerError => "INTERNAL_SERVER_ERROR",
Self::NotImplemented => "NOT_IMPLEMENTED",
Self::BadGateway => "BAD_GATEWAY",
Self::ServiceUnavailable => "SERVICE_UNAVAILABLE",
Self::GatewayTimeout => "GATEWAY_TIMEOUT",
}
}
pub fn is_success(&self) -> bool {
matches!(self, Self::Ok | Self::Created | Self::Accepted | Self::NoContent)
}
pub fn is_client_error(&self) -> bool {
(*self as u16) >= 400 && (*self as u16) < 500
}
pub fn is_server_error(&self) -> bool {
(*self as u16) >= 500
}
}
impl From<GurtStatusCode> for u16 {
fn from(code: GurtStatusCode) -> Self {
code as u16
}
}
impl fmt::Display for GurtStatusCode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", *self as u16)
}
}

View File

@@ -0,0 +1,563 @@
use crate::{
GurtError, Result, GurtRequest, GurtResponse, GurtMessage,
protocol::{BODY_SEPARATOR, MAX_MESSAGE_SIZE},
message::GurtMethod,
protocol::GurtStatusCode,
crypto::{TLS_VERSION, GURT_ALPN, TlsConfig},
};
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_rustls::{TlsAcceptor, server::TlsStream};
use rustls::pki_types::CertificateDer;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::fs;
use tracing::{info, warn, error, debug};
#[derive(Debug, Clone)]
pub struct ServerContext {
pub remote_addr: SocketAddr,
pub request: GurtRequest,
}
impl ServerContext {
pub fn client_ip(&self) -> std::net::IpAddr {
self.remote_addr.ip()
}
pub fn client_port(&self) -> u16 {
self.remote_addr.port()
}
pub fn method(&self) -> &GurtMethod {
&self.request.method
}
pub fn path(&self) -> &str {
&self.request.path
}
pub fn headers(&self) -> &HashMap<String, String> {
&self.request.headers
}
pub fn body(&self) -> &[u8] {
&self.request.body
}
pub fn body_as_string(&self) -> Result<String> {
self.request.body_as_string()
}
pub fn header(&self, key: &str) -> Option<&String> {
self.request.header(key)
}
}
pub trait GurtHandler: Send + Sync {
fn handle(&self, ctx: &ServerContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<GurtResponse>> + Send + '_>>;
}
pub struct FnHandler<F> {
handler: F,
}
impl<F, Fut> GurtHandler for FnHandler<F>
where
F: Fn(&ServerContext) -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<GurtResponse>> + Send + 'static,
{
fn handle(&self, ctx: &ServerContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<GurtResponse>> + Send + '_>> {
Box::pin((self.handler)(ctx))
}
}
#[derive(Debug, Clone)]
pub struct Route {
method: Option<GurtMethod>,
path_pattern: String,
}
impl Route {
pub fn new(method: Option<GurtMethod>, path_pattern: String) -> Self {
Self { method, path_pattern }
}
pub fn get(path: &str) -> Self {
Self::new(Some(GurtMethod::GET), path.to_string())
}
pub fn post(path: &str) -> Self {
Self::new(Some(GurtMethod::POST), path.to_string())
}
pub fn put(path: &str) -> Self {
Self::new(Some(GurtMethod::PUT), path.to_string())
}
pub fn delete(path: &str) -> Self {
Self::new(Some(GurtMethod::DELETE), path.to_string())
}
pub fn head(path: &str) -> Self {
Self::new(Some(GurtMethod::HEAD), path.to_string())
}
pub fn options(path: &str) -> Self {
Self::new(Some(GurtMethod::OPTIONS), path.to_string())
}
pub fn patch(path: &str) -> Self {
Self::new(Some(GurtMethod::PATCH), path.to_string())
}
pub fn any(path: &str) -> Self {
Self::new(None, path.to_string())
}
pub fn matches(&self, method: &GurtMethod, path: &str) -> bool {
if let Some(route_method) = &self.method {
if route_method != method {
return false;
}
}
self.matches_path(path)
}
pub fn matches_path(&self, path: &str) -> bool {
self.path_pattern == path ||
(self.path_pattern.ends_with('*') && path.starts_with(&self.path_pattern[..self.path_pattern.len()-1]))
}
}
pub struct GurtServer {
routes: Vec<(Route, Arc<dyn GurtHandler>)>,
tls_acceptor: Option<TlsAcceptor>,
}
impl GurtServer {
pub fn new() -> Self {
Self {
routes: Vec::new(),
tls_acceptor: None,
}
}
pub fn with_tls_certificates(cert_path: &str, key_path: &str) -> Result<Self> {
let mut server = Self::new();
server.load_tls_certificates(cert_path, key_path)?;
Ok(server)
}
pub fn load_tls_certificates(&mut self, cert_path: &str, key_path: &str) -> Result<()> {
info!("Loading TLS certificates: cert={}, key={}", cert_path, key_path);
let cert_data = fs::read(cert_path)
.map_err(|e| GurtError::crypto(format!("Failed to read certificate file '{}': {}", cert_path, e)))?;
let key_data = fs::read(key_path)
.map_err(|e| GurtError::crypto(format!("Failed to read private key file '{}': {}", key_path, e)))?;
let mut cursor = std::io::Cursor::new(cert_data);
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cursor)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| GurtError::crypto(format!("Failed to parse certificates: {}", e)))?;
if certs.is_empty() {
return Err(GurtError::crypto("No certificates found in certificate file"));
}
let mut key_cursor = std::io::Cursor::new(key_data);
let private_key = rustls_pemfile::private_key(&mut key_cursor)
.map_err(|e| GurtError::crypto(format!("Failed to parse private key: {}", e)))?
.ok_or_else(|| GurtError::crypto("No private key found in key file"))?;
let tls_config = TlsConfig::new_server(certs, private_key)?;
self.tls_acceptor = Some(tls_config.get_acceptor()?);
info!("TLS certificates loaded successfully");
Ok(())
}
pub fn route<H>(mut self, route: Route, handler: H) -> Self
where
H: GurtHandler + 'static,
{
self.routes.push((route, Arc::new(handler)));
self
}
pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(&ServerContext) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<GurtResponse>> + Send + 'static,
{
self.route(Route::get(path), FnHandler { handler })
}
pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(&ServerContext) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<GurtResponse>> + Send + 'static,
{
self.route(Route::post(path), FnHandler { handler })
}
pub fn put<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(&ServerContext) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<GurtResponse>> + Send + 'static,
{
self.route(Route::put(path), FnHandler { handler })
}
pub fn delete<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(&ServerContext) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<GurtResponse>> + Send + 'static,
{
self.route(Route::delete(path), FnHandler { handler })
}
pub fn head<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(&ServerContext) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<GurtResponse>> + Send + 'static,
{
self.route(Route::head(path), FnHandler { handler })
}
pub fn options<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(&ServerContext) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<GurtResponse>> + Send + 'static,
{
self.route(Route::options(path), FnHandler { handler })
}
pub fn patch<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(&ServerContext) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<GurtResponse>> + Send + 'static,
{
self.route(Route::patch(path), FnHandler { handler })
}
pub fn any<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(&ServerContext) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<GurtResponse>> + Send + 'static,
{
self.route(Route::any(path), FnHandler { handler })
}
pub async fn listen(self, addr: &str) -> Result<()> {
let listener = TcpListener::bind(addr).await?;
info!("GURT server listening on {}", addr);
loop {
match listener.accept().await {
Ok((stream, addr)) => {
info!("Client connected: {}", addr);
let server = self.clone();
tokio::spawn(async move {
if let Err(e) = server.handle_connection(stream, addr).await {
error!("Connection error from {}: {}", addr, e);
}
info!("Client disconnected: {}", addr);
});
}
Err(e) => {
error!("Failed to accept connection: {}", e);
}
}
}
}
async fn handle_connection(&self, mut stream: TcpStream, addr: SocketAddr) -> Result<()> {
self.handle_initial_handshake(&mut stream, addr).await?;
if let Some(tls_acceptor) = &self.tls_acceptor {
info!("Upgrading connection to TLS for {}", addr);
let tls_stream = tls_acceptor.accept(stream).await
.map_err(|e| GurtError::crypto(format!("TLS upgrade failed: {}", e)))?;
info!("TLS upgrade completed for {}", addr);
self.handle_tls_connection(tls_stream, addr).await
} else {
warn!("No TLS configuration available, but handshake completed - this violates GURT protocol");
Err(GurtError::protocol("TLS is required after handshake but no TLS configuration available"))
}
}
async fn handle_initial_handshake(&self, stream: &mut TcpStream, addr: SocketAddr) -> Result<()> {
let mut buffer = Vec::new();
let mut temp_buffer = [0u8; 8192];
loop {
let bytes_read = stream.read(&mut temp_buffer).await?;
if bytes_read == 0 {
return Err(GurtError::connection("Connection closed during handshake"));
}
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
let body_separator = BODY_SEPARATOR.as_bytes();
if buffer.windows(body_separator.len()).any(|w| w == body_separator) {
break;
}
if buffer.len() > MAX_MESSAGE_SIZE {
return Err(GurtError::protocol("Handshake message too large"));
}
}
let message = GurtMessage::parse_bytes(&buffer)?;
match message {
GurtMessage::Request(request) => {
if request.method == GurtMethod::HANDSHAKE {
self.send_handshake_response(stream, addr, &request).await
} else {
Err(GurtError::protocol("First message must be HANDSHAKE"))
}
}
GurtMessage::Response(_) => {
Err(GurtError::protocol("Server received response during handshake"))
}
}
}
async fn handle_tls_connection(&self, mut tls_stream: TlsStream<TcpStream>, addr: SocketAddr) -> Result<()> {
let mut buffer = Vec::new();
let mut temp_buffer = [0u8; 8192];
loop {
let bytes_read = match tls_stream.read(&mut temp_buffer).await {
Ok(n) => n,
Err(e) => {
// Handle UnexpectedEof from clients that don't send close_notify
if e.kind() == std::io::ErrorKind::UnexpectedEof {
debug!("Client {} closed connection without TLS close_notify (benign)", addr);
break;
}
return Err(e.into());
}
};
if bytes_read == 0 {
break; // Connection closed
}
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
let body_separator = BODY_SEPARATOR.as_bytes();
let has_complete_message = buffer.windows(body_separator.len()).any(|w| w == body_separator) ||
(buffer.starts_with(b"{") && buffer.ends_with(b"}"));
if has_complete_message {
if let Err(e) = self.process_tls_message(&mut tls_stream, addr, &buffer).await {
error!("Encrypted message processing error from {}: {}", addr, e);
let error_response = GurtResponse::internal_server_error()
.with_string_body("Internal server error");
let _ = tls_stream.write_all(&error_response.to_bytes()).await;
}
buffer.clear();
}
// Prevent buffer overflow
if buffer.len() > MAX_MESSAGE_SIZE {
warn!("Message too large from {}, closing connection", addr);
break;
}
}
Ok(())
}
async fn send_handshake_response(&self, stream: &mut TcpStream, addr: SocketAddr, _request: &GurtRequest) -> Result<()> {
info!("Sending handshake response to {}", addr);
let response = GurtResponse::new(GurtStatusCode::SwitchingProtocols)
.with_header("GURT-Version", crate::GURT_VERSION.to_string())
.with_header("Encryption", TLS_VERSION)
.with_header("ALPN", std::str::from_utf8(GURT_ALPN).unwrap_or("gurt/1.0"));
let response_bytes = response.to_string().into_bytes();
stream.write_all(&response_bytes).await?;
info!("Handshake response sent to {}, preparing for TLS upgrade", addr);
Ok(())
}
async fn process_tls_message(&self, tls_stream: &mut TlsStream<TcpStream>, addr: SocketAddr, data: &[u8]) -> Result<()> {
let message = GurtMessage::parse_bytes(data)?;
match message {
GurtMessage::Request(request) => {
if request.method == GurtMethod::HANDSHAKE {
Err(GurtError::protocol("Received HANDSHAKE over TLS - protocol violation"))
} else {
self.handle_encrypted_request(tls_stream, addr, &request).await
}
}
GurtMessage::Response(_) => {
warn!("Received response on server, ignoring");
Ok(())
}
}
}
async fn handle_default_options(&self, tls_stream: &mut TlsStream<TcpStream>, request: &GurtRequest) -> Result<()> {
let mut allowed_methods = std::collections::HashSet::new();
for (route, _) in &self.routes {
if route.matches_path(&request.path) {
if let Some(method) = &route.method {
allowed_methods.insert(method.to_string());
} else {
// Route matches any method
allowed_methods.extend(vec![
"GET".to_string(), "POST".to_string(), "PUT".to_string(),
"DELETE".to_string(), "HEAD".to_string(), "PATCH".to_string()
]);
}
}
}
allowed_methods.insert("OPTIONS".to_string());
let mut allowed_methods_vec: Vec<String> = allowed_methods.into_iter().collect();
allowed_methods_vec.sort();
let allow_header = allowed_methods_vec.join(", ");
let response = GurtResponse::ok()
.with_header("Allow", allow_header)
.with_header("Access-Control-Allow-Origin", "*")
.with_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, HEAD, OPTIONS, PATCH")
.with_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
tls_stream.write_all(&response.to_bytes()).await?;
Ok(())
}
async fn handle_default_head(&self, tls_stream: &mut TlsStream<TcpStream>, addr: SocketAddr, request: &GurtRequest) -> Result<()> {
for (route, handler) in &self.routes {
if route.method == Some(GurtMethod::GET) && route.matches(&GurtMethod::GET, &request.path) {
let context = ServerContext {
remote_addr: addr,
request: request.clone(),
};
match handler.handle(&context).await {
Ok(mut response) => {
let original_content_length = response.body.len();
response.body.clear();
response = response.with_header("content-length", original_content_length.to_string());
tls_stream.write_all(&response.to_bytes()).await?;
return Ok(());
}
Err(e) => {
error!("Handler error for HEAD {} (via GET): {}", request.path, e);
let error_response = GurtResponse::internal_server_error();
tls_stream.write_all(&error_response.to_bytes()).await?;
return Ok(());
}
}
}
}
let not_found_response = GurtResponse::not_found();
tls_stream.write_all(&not_found_response.to_bytes()).await?;
Ok(())
}
async fn handle_encrypted_request(&self, tls_stream: &mut TlsStream<TcpStream>, addr: SocketAddr, request: &GurtRequest) -> Result<()> {
debug!("Handling encrypted {} request to {} from {}", request.method, request.path, addr);
// Find matching route
for (route, handler) in &self.routes {
if route.matches(&request.method, &request.path) {
let context = ServerContext {
remote_addr: addr,
request: request.clone(),
};
match handler.handle(&context).await {
Ok(response) => {
// Use to_bytes() to avoid corrupting binary data
let response_bytes = response.to_bytes();
tls_stream.write_all(&response_bytes).await?;
return Ok(());
}
Err(e) => {
error!("Handler error for {} {}: {}", request.method, request.path, e);
let error_response = GurtResponse::internal_server_error()
.with_string_body("Internal server error");
tls_stream.write_all(&error_response.to_bytes()).await?;
return Ok(());
}
}
}
}
// No route found - check for default OPTIONS/HEAD handling
match request.method {
GurtMethod::OPTIONS => {
self.handle_default_options(tls_stream, request).await
}
GurtMethod::HEAD => {
self.handle_default_head(tls_stream, addr, request).await
}
_ => {
let not_found_response = GurtResponse::not_found()
.with_string_body("Not found");
tls_stream.write_all(&not_found_response.to_bytes()).await?;
Ok(())
}
}
}
}
impl Clone for GurtServer {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),
tls_acceptor: self.tls_acceptor.clone(),
}
}
}
impl Default for GurtServer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::test;
#[test]
async fn test_route_matching() {
let route = Route::get("/test");
assert!(route.matches(&GurtMethod::GET, "/test"));
assert!(!route.matches(&GurtMethod::POST, "/test"));
assert!(!route.matches(&GurtMethod::GET, "/other"));
let wildcard_route = Route::get("/api/*");
assert!(wildcard_route.matches(&GurtMethod::GET, "/api/users"));
assert!(wildcard_route.matches(&GurtMethod::GET, "/api/posts"));
assert!(!wildcard_route.matches(&GurtMethod::GET, "/other"));
}
}