diff --git a/docs/docs/gurty-cli.md b/docs/docs/gurty-cli.md index a97f034..7ec3a2b 100644 --- a/docs/docs/gurty-cli.md +++ b/docs/docs/gurty-cli.md @@ -8,14 +8,83 @@ sidebar_position: 5 ## Installation -Build Gurty from the protocol CLI directory: +To begin, [install Gurty here](https://gurted.com/download). + +## Configuration + +Gurty supports configuration through TOML files. Use the provided template to get started: ```bash cd protocol/cli -cargo build --release +cp gurty.template.toml gurty.toml ``` -The binary will be available at `target/release/gurty` (or `gurty.exe` on Windows). +### Configuration File Structure + +The configuration file includes the following sections: + +#### Server Settings +```toml +[server] +host = "127.0.0.1" +port = 4878 +protocol_version = "1.0.0" +alpn_identifier = "GURT/1.0" +max_connections = 10 +max_message_size = "10MB" + +[server.timeouts] +handshake = 5 +request = 30 +connection = 10 +pool_idle = 300 +``` + +#### TLS Configuration +```toml +[tls] +certificate = "localhost+2.pem" +private_key = "localhost+2-key.pem" +``` + +#### Logging Options +```toml +[logging] +level = "info" +log_requests = true +log_responses = false +access_log = "/var/log/gurty/access.log" +error_log = "/var/log/gurty/error.log" +``` + +#### Security Settings +```toml +[security] +deny_files = ["*.env", "*.config", ".git/*", "*.key", "*.pem"] +allowed_methods = ["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH"] +rate_limit_requests = 100 +rate_limit_connections = 1000 +``` + +#### Error Pages and Headers +```toml +# Custom error page files +[error_pages] +"404" = "/errors/404.html" +"500" = "/errors/500.html" + +# Default inline error pages +[error_pages.default] +"400" = ''' +400 Bad Request +

400 - Bad Request

The request could not be understood by the server.

''' + +# Custom HTTP headers +[headers] +server = "GURT/1.0.0" +"x-frame-options" = "SAMEORIGIN" +"x-content-type-options" = "nosniff" +``` ## Quick Start @@ -46,7 +115,18 @@ The binary will be available at `target/release/gurty` (or `gurty.exe` on Window - `localhost+2.pem` (certificate) - `localhost+2-key.pem` (private key) -4. **Start GURT server**: +4. **Set up configuration** (optional but recommended): + ```bash + cd protocol/cli + cp gurty.template.toml gurty.toml + ``` + Edit `gurty.toml` to customize settings for development. + +5. **Start GURT server**: + ```bash + cargo run --release serve --config gurty.toml + ``` + Or specify certificates explicitly: ```bash cargo run --release serve --cert localhost+2.pem --key localhost+2-key.pem ``` @@ -68,9 +148,22 @@ The binary will be available at `target/release/gurty` (or `gurty.exe` on Window openssl req -x509 -newkey rsa:4096 -keyout gurt-server.key -out gurt-server.crt -days 365 -nodes ``` -2. **Deploy with production certificates**: +2. **Set up configuration**: ```bash - cargo run --release serve --cert gurt-server.crt --key gurt-server.key --host 0.0.0.0 --port 4878 + cp gurty.template.toml gurty.toml + # Edit gurty.toml for production: + # - Update certificate paths + # - Set host to "0.0.0.0" for external access + # - Configure logging and security settings + ``` + +3. **Deploy with production certificates**: + ```bash + cargo run --release serve --config gurty.toml + ``` + Or specify certificates explicitly: + ```bash + cargo run --release serve --cert gurt-server.crt --key gurt-server.key --config gurty.toml ``` ## Commands @@ -87,19 +180,34 @@ gurty serve [OPTIONS] | Option | Description | Default | |--------|-------------|---------| -| `--cert ` | Path to TLS certificate file | Required | -| `--key ` | Path to TLS private key file | Required | +| `--cert ` | Path to TLS certificate file | Required* | +| `--key ` | Path to TLS private key file | Required* | +| `--config ` | Path to configuration file | None | | `--host ` | Host address to bind to | `127.0.0.1` | | `--port ` | Port number to listen on | `4878` | | `--dir ` | Directory to serve files from | None | | `--log-level ` | Logging level (error, warn, info, debug, trace) | `info` | +*Required unless specified in configuration file + #### Examples +**Using configuration file:** +```bash +gurty serve --config gurty.toml +``` + +**Explicit certificates with configuration:** +```bash +gurty serve --cert localhost+2.pem --key localhost+2-key.pem --config gurty.toml +``` + +**Manual setup without configuration file:** ```bash gurty serve --cert localhost+2.pem --key localhost+2-key.pem --dir ./public ``` -Debug: + +**Debug mode with configuration:** ```bash -gurty serve --cert dev.pem --key dev-key.pem --log-level debug +gurty serve --config gurty.toml --log-level debug ``` diff --git a/protocol/cli/.gitignore b/protocol/cli/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/protocol/cli/Cargo.lock b/protocol/cli/Cargo.lock index b91ca1f..00d11dd 100644 --- a/protocol/cli/Cargo.lock +++ b/protocol/cli/Cargo.lock @@ -91,6 +91,17 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -338,6 +349,12 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "errno" version = "0.3.13" @@ -420,17 +437,27 @@ dependencies = [ name = "gurty" version = "0.1.0" dependencies = [ + "async-trait", "clap", "colored", "gurt", + "indexmap", "mime_guess", + "regex", "serde", "serde_json", "tokio", + "toml", "tracing", "tracing-subscriber", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" + [[package]] name = "heck" version = "0.5.0" @@ -577,6 +604,16 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "indexmap" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "io-uring" version = "0.7.9" @@ -1089,6 +1126,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1255,6 +1301,47 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "tracing" version = "0.1.41" @@ -1681,6 +1768,15 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +[[package]] +name = "winnow" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" +dependencies = [ + "memchr", +] + [[package]] name = "wit-bindgen-rt" version = "0.39.0" diff --git a/protocol/cli/Cargo.toml b/protocol/cli/Cargo.toml index 994e9a8..1821c62 100644 --- a/protocol/cli/Cargo.toml +++ b/protocol/cli/Cargo.toml @@ -22,4 +22,8 @@ tracing-subscriber = "0.3" clap = { version = "4.0", features = ["derive"] } colored = "2.0" -mime_guess = "2.0" \ No newline at end of file +mime_guess = "2.0" +async-trait = "0.1" +toml = "0.8" +regex = "1.0" +indexmap = "2.0" \ No newline at end of file diff --git a/protocol/cli/README.md b/protocol/cli/README.md index 94f0d96..60d23b3 100644 --- a/protocol/cli/README.md +++ b/protocol/cli/README.md @@ -1,5 +1,34 @@ # Gurty - a CLI tool to setup your GURT Protocol server +Gurty is a command-line interface tool for setting up and managing GURT protocol servers. + +## Configuration + +Gurty uses a TOML configuration file to manage server settings. The `gurty.template.toml` file provides a complete configuration template with all available options: + +### Sections + +- **Server**: Basic server settings (host, port, protocol version, connection limits) +- **TLS**: Certificate and private key configuration for secure connections +- **Logging**: Logging levels, request/response logging, and log file paths +- **Security**: File access restrictions, allowed HTTP methods, and rate limiting +- **Error Pages**: Custom error page templates and default error responses +- **Headers**: Custom HTTP headers for security and server identification + +### Using Configuration Files + +1. **Copy the configuration template:** + ```bash + cp gurty.template.toml gurty.toml + ``` + +2. **Edit the configuration** to match your environment. (optional) + +3. **Use the configuration file:** + ```bash + gurty serve --config gurty.toml + ``` + ## Setup for Production For production deployments, you'll need to generate your own certificates since traditional Certificate Authorities don't support custom protocols: @@ -19,14 +48,24 @@ For production deployments, you'll need to generate your own certificates since openssl req -x509 -newkey rsa:4096 -keyout gurt-server.key -out gurt-server.crt -days 365 -nodes ``` -2. **Deploy with production certificates:** +2. **Copy the configuration template and customize:** ```bash - cargo run --release serve --cert gurt-server.crt --key gurt-server.key --host 0.0.0.0 --port 4878 + cp gurty.template.toml gurty.toml + ``` + +3. **Deploy with production certificates and configuration:** + ```bash + gurty serve --config gurty.toml + ``` + Or specify certificates explicitly: + ```bash + gurty serve --cert gurt-server.crt --key gurt-server.key --config gurty.toml ``` ## Development Environment Setup To set up a development environment for GURT, follow these steps: + 1. **Install mkcert:** ```bash # Windows (with Chocolatey) @@ -50,7 +89,16 @@ To set up a development environment for GURT, follow these steps: - `localhost+2.pem` (certificate) - `localhost+2-key.pem` (private key) -4. **Start GURT server with certificates:** +4. **Copy the configuration template and customize:** ```bash - cargo run --release serve --cert localhost+2.pem --key localhost+2-key.pem + cp gurty.template.toml gurty.toml + ``` + +5. **Start GURT server with certificates and configuration:** + ```bash + gurty serve --config gurty.toml + ``` + Or specify certificates explicitly: + ```bash + gurty serve --cert localhost+2.pem --key localhost+2-key.pem --config gurty.toml ``` \ No newline at end of file diff --git a/protocol/cli/gurty.template.toml b/protocol/cli/gurty.template.toml new file mode 100644 index 0000000..ae529d0 --- /dev/null +++ b/protocol/cli/gurty.template.toml @@ -0,0 +1,60 @@ +[server] +host = "127.0.0.1" +port = 4878 +protocol_version = "1.0.0" +alpn_identifier = "GURT/1.0" +max_connections = 10 +max_message_size = "10MB" + +[server.timeouts] +handshake = 5 +request = 30 +connection = 10 +pool_idle = 300 + +[tls] +certificate = "localhost+2.pem" +private_key = "localhost+2-key.pem" + +[logging] +level = "info" +# access_log = "/var/log/gurty/access.log" +# error_log = "/var/log/gurty/error.log" +log_requests = true +log_responses = false + +[security] +deny_files = [ + "*.env", + "*.config", + ".git/*", + "node_modules/*", + "*.key", + "*.pem" +] + +allowed_methods = ["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH"] +rate_limit_requests = 100 # requests per minute +rate_limit_connections = 1000 # concurrent connections per IP + +# Error pages configuration +[error_pages] +# Specific error pages (uncomment and set paths to custom files) +# "400" = "/errors/400.html" +# "401" = "/errors/401.html" +# "403" = "/errors/403.html" +# "404" = "/errors/404.html" +# "405" = "/errors/405.html" +# "429" = "/errors/429.html" +# "500" = "/errors/500.html" +# "503" = "/errors/503.html" + +[error_pages.default] +"400" = ''' +400 Bad Request +

400 - Bad Request

The request could not be understood by the server.

Back to home''' + +[headers] +server = "GURT/1.0.0" +"x-frame-options" = "SAMEORIGIN" +"x-content-type-options" = "nosniff" diff --git a/protocol/cli/src/cli.rs b/protocol/cli/src/cli.rs new file mode 100644 index 0000000..b0e5a20 --- /dev/null +++ b/protocol/cli/src/cli.rs @@ -0,0 +1,88 @@ +use clap::{Parser, Subcommand}; +use std::path::PathBuf; + +#[derive(Parser)] +#[command(name = "server")] +#[command(about = "GURT Protocol Server")] +#[command(version = "1.0.0")] +pub struct Cli { + #[command(subcommand)] + pub command: Commands, +} + +#[derive(Subcommand)] +pub enum Commands { + Serve(ServeCommand), +} + +#[derive(Parser)] +pub struct ServeCommand { + #[arg(short, long, help = "Configuration file path")] + pub config: Option, + + #[arg(short, long, default_value_t = 4878)] + pub port: u16, + + #[arg(long, default_value = "127.0.0.1")] + pub host: String, + + #[arg(short, long, default_value = ".")] + pub dir: PathBuf, + + #[arg(short, long)] + pub verbose: bool, + + #[arg(long, help = "Path to TLS certificate file")] + pub cert: Option, + + #[arg(long, help = "Path to TLS private key file")] + pub key: Option, +} + +impl ServeCommand { + pub fn validate(&self) -> crate::Result<()> { + if !self.dir.exists() { + return Err(crate::ServerError::InvalidPath( + format!("Directory does not exist: {}", self.dir.display()) + )); + } + + if !self.dir.is_dir() { + return Err(crate::ServerError::InvalidPath( + format!("Path is not a directory: {}", self.dir.display()) + )); + } + + match (&self.cert, &self.key) { + (Some(cert), Some(key)) => { + if !cert.exists() { + return Err(crate::ServerError::TlsConfiguration( + format!("Certificate file does not exist: {}", cert.display()) + )); + } + if !key.exists() { + return Err(crate::ServerError::TlsConfiguration( + format!("Key file does not exist: {}", key.display()) + )); + } + } + (Some(_), None) => { + return Err(crate::ServerError::TlsConfiguration( + "Certificate provided but no key file specified (use --key)".to_string() + )); + } + (None, Some(_)) => { + return Err(crate::ServerError::TlsConfiguration( + "Key provided but no certificate file specified (use --cert)".to_string() + )); + } + (None, None) => { + return Err(crate::ServerError::TlsConfiguration( + "GURT protocol requires TLS encryption. Please provide --cert and --key parameters.".to_string() + )); + } + } + + Ok(()) + } +} diff --git a/protocol/cli/src/command_handler.rs b/protocol/cli/src/command_handler.rs new file mode 100644 index 0000000..fc21637 --- /dev/null +++ b/protocol/cli/src/command_handler.rs @@ -0,0 +1,160 @@ +use crate::{ + cli::ServeCommand, + config::GurtConfig, + server::FileServerBuilder, + Result, +}; +use async_trait::async_trait; +use colored::Colorize; +use tracing::{error, info}; + +#[async_trait] +pub trait CommandHandler { + async fn execute(&self) -> Result<()>; +} + +pub struct CommandHandlerBuilder { + logging_initialized: bool, + verbose: bool, +} + +impl CommandHandlerBuilder { + pub fn new() -> Self { + Self { + logging_initialized: false, + verbose: false, + } + } + + pub fn with_logging(mut self, verbose: bool) -> Self { + self.verbose = verbose; + self + } + + pub fn initialize_logging(mut self) -> Self { + if !self.logging_initialized { + let level = if self.verbose { + tracing::Level::DEBUG + } else { + tracing::Level::INFO + }; + + tracing_subscriber::fmt() + .with_max_level(level) + .init(); + + self.logging_initialized = true; + } + self + } + + pub fn build_serve_handler(self, serve_cmd: ServeCommand) -> ServeCommandHandler { + ServeCommandHandler::new(serve_cmd) + } +} + +impl Default for CommandHandlerBuilder { + fn default() -> Self { + Self::new() + } +} + +pub struct ServeCommandHandler { + serve_cmd: ServeCommand, +} + +impl ServeCommandHandler { + pub fn new(serve_cmd: ServeCommand) -> Self { + Self { serve_cmd } + } + + fn validate_command(&self) -> Result<()> { + if !self.serve_cmd.dir.exists() { + return Err(crate::ServerError::InvalidPath( + format!("Directory does not exist: {}", self.serve_cmd.dir.display()) + )); + } + + if !self.serve_cmd.dir.is_dir() { + return Err(crate::ServerError::InvalidPath( + format!("Path is not a directory: {}", self.serve_cmd.dir.display()) + )); + } + + Ok(()) + } + + fn build_server_config(&self) -> Result { + let mut config_builder = GurtConfig::builder(); + + if let Some(config_file) = &self.serve_cmd.config { + config_builder = config_builder.from_file(config_file)?; + } + + let config = config_builder + .merge_cli_args(&self.serve_cmd) + .build()?; + + Ok(config) + } + + fn display_startup_info(&self, config: &GurtConfig) { + println!("{}", "GURT Protocol Server".bright_cyan().bold()); + println!("{} {}", "Version".bright_blue(), config.server.protocol_version); + println!("{} {}", "Listening on".bright_blue(), config.address()); + println!("{} {}", "Serving from".bright_blue(), config.server.base_directory.display()); + + if config.tls.is_some() { + println!("{}", "TLS encryption enabled".bright_green()); + } + + if let Some(logging) = &config.logging { + println!("{} {}", "Log level".bright_blue(), logging.level); + if logging.log_requests { + println!("{}", "Request logging enabled".bright_green()); + } + } + + if let Some(security) = &config.security { + println!("{} {} req/min", "Rate limit".bright_blue(), security.rate_limit_requests); + if !security.deny_files.is_empty() { + println!("{} {} patterns", "File restrictions".bright_blue(), security.deny_files.len()); + } + } + + if let Some(headers) = &config.headers { + if !headers.is_empty() { + println!("{} {} headers", "Custom headers".bright_blue(), headers.len()); + } + } + + println!("{} {}", "Max connections".bright_blue(), config.server.max_connections); + println!("{} {}", "Max message size".bright_blue(), config.server.max_message_size); + println!(); + } + + async fn start_server(&self, config: &GurtConfig) -> Result<()> { + let server = FileServerBuilder::new(config.clone()).build()?; + + info!("Starting GURT server on {}", config.address()); + + if let Err(e) = server.listen(&config.address()).await { + error!("Server error: {}", e); + std::process::exit(1); + } + + Ok(()) + } +} + +#[async_trait] +impl CommandHandler for ServeCommandHandler { + async fn execute(&self) -> Result<()> { + self.validate_command()?; + + let config = self.build_server_config()?; + + self.display_startup_info(&config); + self.start_server(&config).await + } +} diff --git a/protocol/cli/src/config.rs b/protocol/cli/src/config.rs new file mode 100644 index 0000000..2a09e0c --- /dev/null +++ b/protocol/cli/src/config.rs @@ -0,0 +1,607 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GurtConfig { + pub server: ServerConfig, + pub tls: Option, + pub logging: Option, + pub security: Option, + pub error_pages: Option, + pub headers: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfig { + #[serde(default = "default_host")] + pub host: String, + + #[serde(default = "default_port")] + pub port: u16, + + #[serde(default = "default_protocol_version")] + pub protocol_version: String, + + #[serde(default = "default_alpn_identifier")] + pub alpn_identifier: String, + + pub timeouts: Option, + + #[serde(default = "default_max_connections")] + pub max_connections: u32, + + #[serde(default = "default_max_message_size")] + pub max_message_size: String, + + #[serde(skip)] + pub base_directory: Arc, + + #[serde(skip)] + pub verbose: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TimeoutsConfig { + #[serde(default = "default_handshake_timeout")] + pub handshake: u64, + + #[serde(default = "default_request_timeout")] + pub request: u64, + + #[serde(default = "default_connection_timeout")] + pub connection: u64, + + #[serde(default = "default_pool_idle_timeout")] + pub pool_idle: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsConfig { + pub certificate: PathBuf, + pub private_key: PathBuf, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoggingConfig { + #[serde(default = "default_log_level")] + pub level: String, + + pub access_log: Option, + pub error_log: Option, + + #[serde(default = "default_log_requests")] + pub log_requests: bool, + + #[serde(default)] + pub log_responses: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SecurityConfig { + #[serde(default)] + pub deny_files: Vec, + + #[serde(default = "default_allowed_methods")] + pub allowed_methods: Vec, + + #[serde(default = "default_rate_limit_requests")] + pub rate_limit_requests: u32, + + #[serde(default = "default_rate_limit_connections")] + pub rate_limit_connections: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorPagesConfig { + #[serde(flatten)] + pub pages: HashMap, + + pub default: Option, +} + +impl ErrorPagesConfig { + pub fn get_page(&self, status_code: u16) -> Option<&String> { + let code_str = status_code.to_string(); + self.pages.get(&code_str) + } + + pub fn get_default_page(&self, status_code: u16) -> Option<&String> { + if let Some(defaults) = &self.default { + let code_str = status_code.to_string(); + defaults.pages.get(&code_str) + } else { + None + } + } + + pub fn get_any_page(&self, status_code: u16) -> Option<&String> { + self.get_page(status_code) + .or_else(|| self.get_default_page(status_code)) + } + + pub fn get_page_content(&self, status_code: u16, base_dir: &std::path::Path) -> Option { + if let Some(page_value) = self.get_page(status_code) { + if page_value.starts_with('/') || page_value.starts_with("./") { + let file_path = if page_value.starts_with('/') { + base_dir.join(&page_value[1..]) + } else { + base_dir.join(page_value) + }; + + if let Ok(content) = std::fs::read_to_string(&file_path) { + return Some(content); + } else { + tracing::warn!("Failed to read error page file: {}", file_path.display()); + return None; + } + } else { + return Some(page_value.clone()); + } + } + + if let Some(page_value) = self.get_default_page(status_code) { + return Some(page_value.clone()); + } + + None + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorPageDefaults { + #[serde(flatten)] + pub pages: HashMap, +} + +fn default_host() -> String { "127.0.0.1".to_string() } +fn default_port() -> u16 { 4878 } +fn default_protocol_version() -> String { "1.0.0".to_string() } +fn default_alpn_identifier() -> String { "GURT/1.0".to_string() } +fn default_max_connections() -> u32 { 10 } +fn default_max_message_size() -> String { "10MB".to_string() } +fn default_handshake_timeout() -> u64 { 5 } +fn default_request_timeout() -> u64 { 30 } +fn default_connection_timeout() -> u64 { 10 } +fn default_pool_idle_timeout() -> u64 { 300 } +fn default_log_level() -> String { "info".to_string() } +fn default_log_requests() -> bool { true } +fn default_allowed_methods() -> Vec { + vec!["GET".to_string(), "POST".to_string(), "PUT".to_string(), + "DELETE".to_string(), "HEAD".to_string(), "OPTIONS".to_string(), "PATCH".to_string()] +} +fn default_rate_limit_requests() -> u32 { 100 } +fn default_rate_limit_connections() -> u32 { 10 } + +impl Default for GurtConfig { + fn default() -> Self { + Self { + server: ServerConfig::default(), + tls: None, + logging: None, + security: None, + error_pages: None, + headers: None, + } + } +} + +impl Default for ServerConfig { + fn default() -> Self { + Self { + host: default_host(), + port: default_port(), + protocol_version: default_protocol_version(), + alpn_identifier: default_alpn_identifier(), + timeouts: None, + max_connections: default_max_connections(), + max_message_size: default_max_message_size(), + base_directory: Arc::new(PathBuf::from(".")), + verbose: false, + } + } +} + +impl GurtConfig { + pub fn from_file>(path: P) -> crate::Result { + let content = std::fs::read_to_string(path) + .map_err(|e| crate::ServerError::InvalidConfiguration(format!("Failed to read config file: {}", e)))?; + + let config: GurtConfig = toml::from_str(&content) + .map_err(|e| crate::ServerError::InvalidConfiguration(format!("Failed to parse config file: {}", e)))?; + + Ok(config) + } + + pub fn builder() -> GurtConfigBuilder { + GurtConfigBuilder::default() + } + + pub fn address(&self) -> String { + format!("{}:{}", self.server.host, self.server.port) + } + + pub fn max_message_size_bytes(&self) -> crate::Result { + parse_size(&self.server.max_message_size) + } + + pub fn get_handshake_timeout(&self) -> Duration { + Duration::from_secs( + self.server.timeouts + .as_ref() + .map(|t| t.handshake) + .unwrap_or(default_handshake_timeout()) + ) + } + + pub fn get_request_timeout(&self) -> Duration { + Duration::from_secs( + self.server.timeouts + .as_ref() + .map(|t| t.request) + .unwrap_or(default_request_timeout()) + ) + } + + pub fn get_connection_timeout(&self) -> Duration { + Duration::from_secs( + self.server.timeouts + .as_ref() + .map(|t| t.connection) + .unwrap_or(default_connection_timeout()) + ) + } + + pub fn should_deny_file(&self, file_path: &str) -> bool { + if let Some(security) = &self.security { + for pattern in &security.deny_files { + if matches_pattern(file_path, pattern) { + return true; + } + } + } + false + } + + pub fn is_method_allowed(&self, method: &str) -> bool { + if let Some(security) = &self.security { + security.allowed_methods.contains(&method.to_uppercase()) + } else { + default_allowed_methods().contains(&method.to_uppercase()) + } + } + + pub fn default_with_directory(base_dir: PathBuf) -> Self { + let mut config = Self::default(); + config.server.base_directory = Arc::new(base_dir); + config + } + + pub fn from_toml(toml_content: &str, base_dir: PathBuf) -> crate::Result { + let mut config: GurtConfig = toml::from_str(toml_content) + .map_err(|e| crate::ServerError::InvalidConfiguration(format!("Failed to parse config: {}", e)))?; + + config.server.base_directory = Arc::new(base_dir); + Ok(config) + } + + pub fn validate(&self) -> crate::Result<()> { + if !self.server.base_directory.exists() || !self.server.base_directory.is_dir() { + return Err(crate::ServerError::InvalidConfiguration( + format!("Invalid base directory: {}", self.server.base_directory.display()) + )); + } + + if let Some(tls) = &self.tls { + if !tls.certificate.exists() { + return Err(crate::ServerError::TlsConfiguration( + format!("Certificate file does not exist: {}", tls.certificate.display()) + )); + } + if !tls.private_key.exists() { + return Err(crate::ServerError::TlsConfiguration( + format!("Private key file does not exist: {}", tls.private_key.display()) + )); + } + } + + Ok(()) + } +} + +#[derive(Default)] +pub struct GurtConfigBuilder { + config: GurtConfig, +} + +impl GurtConfigBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn host>(mut self, host: S) -> Self { + self.config.server.host = host.into(); + self + } + + pub fn port(mut self, port: u16) -> Self { + self.config.server.port = port; + self + } + + pub fn base_directory>(mut self, dir: P) -> Self { + self.config.server.base_directory = Arc::new(dir.into()); + self + } + + pub fn verbose(mut self, verbose: bool) -> Self { + self.config.server.verbose = verbose; + self + } + + pub fn tls_config(mut self, cert_path: PathBuf, key_path: PathBuf) -> Self { + self.config.tls = Some(TlsConfig { + certificate: cert_path, + private_key: key_path, + }); + self + } + + pub fn logging_config(mut self, config: LoggingConfig) -> Self { + self.config.logging = Some(config); + self + } + + pub fn security_config(mut self, config: SecurityConfig) -> Self { + self.config.security = Some(config); + self + } + + pub fn error_pages_config(mut self, config: ErrorPagesConfig) -> Self { + self.config.error_pages = Some(config); + self + } + + pub fn headers(mut self, headers: HashMap) -> Self { + self.config.headers = Some(headers); + self + } + + pub fn from_file>(mut self, path: P) -> crate::Result { + let file_config = GurtConfig::from_file(path)?; + self.config = merge_configs(file_config, self.config); + Ok(self) + } + + pub fn merge_cli_args(mut self, cli_args: &crate::cli::ServeCommand) -> Self { + self.config.server.host = cli_args.host.clone(); + self.config.server.port = cli_args.port; + self.config.server.base_directory = Arc::new(cli_args.dir.clone()); + self.config.server.verbose = cli_args.verbose; + + if let (Some(cert), Some(key)) = (&cli_args.cert, &cli_args.key) { + self.config.tls = Some(TlsConfig { + certificate: cert.clone(), + private_key: key.clone(), + }); + } + + self + } + + pub fn build(self) -> crate::Result { + let config = self.config; + + if !config.server.base_directory.exists() || !config.server.base_directory.is_dir() { + return Err(crate::ServerError::InvalidConfiguration( + format!("Invalid base directory: {}", config.server.base_directory.display()) + )); + } + + if let Some(tls) = &config.tls { + if !tls.certificate.exists() { + return Err(crate::ServerError::TlsConfiguration( + format!("Certificate file does not exist: {}", tls.certificate.display()) + )); + } + if !tls.private_key.exists() { + return Err(crate::ServerError::TlsConfiguration( + format!("Private key file does not exist: {}", tls.private_key.display()) + )); + } + } + + Ok(config) + } +} + + +fn parse_size(size_str: &str) -> crate::Result { + let size_str = size_str.trim().to_uppercase(); + + if let Some(captures) = regex::Regex::new(r"^(\d+(?:\.\d+)?)\s*([KMGT]?B?)$").unwrap().captures(&size_str) { + let number: f64 = captures[1].parse() + .map_err(|_| crate::ServerError::InvalidConfiguration(format!("Invalid size format: {}", size_str)))?; + + let unit = captures.get(2).map_or("", |m| m.as_str()); + + let multiplier: u64 = match unit { + "" | "B" => 1, + "KB" => 1_000, + "MB" => 1_000_000, + "GB" => 1_000_000_000, + "TB" => 1_000_000_000_000, + _ => return Err(crate::ServerError::InvalidConfiguration(format!("Unknown size unit: {}", unit))), + }; + let number = (number * multiplier as f64) as u64; + Ok(number) + } else { + Err(crate::ServerError::InvalidConfiguration(format!("Invalid size format: {}", size_str))) + } +} + +fn matches_pattern(path: &str, pattern: &str) -> bool { + if pattern.ends_with("/*") { + let prefix = &pattern[..pattern.len() - 2]; + path.starts_with(prefix) + } else if pattern.starts_with("*.") { + let suffix = &pattern[1..]; + path.ends_with(suffix) + } else { + path == pattern + } +} + +fn merge_configs(base: GurtConfig, override_config: GurtConfig) -> GurtConfig { + GurtConfig { + server: merge_server_configs(base.server, override_config.server), + tls: override_config.tls.or(base.tls), + logging: override_config.logging.or(base.logging), + security: override_config.security.or(base.security), + error_pages: override_config.error_pages.or(base.error_pages), + headers: override_config.headers.or(base.headers), + } +} + +fn merge_server_configs(base: ServerConfig, override_config: ServerConfig) -> ServerConfig { + ServerConfig { + host: if override_config.host != default_host() { override_config.host } else { base.host }, + port: if override_config.port != default_port() { override_config.port } else { base.port }, + protocol_version: if override_config.protocol_version != default_protocol_version() { + override_config.protocol_version + } else { + base.protocol_version + }, + alpn_identifier: if override_config.alpn_identifier != default_alpn_identifier() { + override_config.alpn_identifier + } else { + base.alpn_identifier + }, + timeouts: override_config.timeouts.or(base.timeouts), + max_connections: if override_config.max_connections != default_max_connections() { + override_config.max_connections + } else { + base.max_connections + }, + max_message_size: if override_config.max_message_size != default_max_message_size() { + override_config.max_message_size + } else { + base.max_message_size + }, + base_directory: override_config.base_directory, + verbose: override_config.verbose, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_default_config_creation() { + let base_dir = PathBuf::from("/tmp"); + let mut config = GurtConfig::default(); + config.server.base_directory = Arc::new(base_dir.clone()); + + assert_eq!(config.server.host, "127.0.0.1"); + assert_eq!(config.server.port, 4878); + assert_eq!(config.server.protocol_version, "1.0.0"); + assert_eq!(config.server.alpn_identifier, "GURT/1.0"); + assert_eq!(*config.server.base_directory, base_dir); + } + + #[test] + fn test_config_from_valid_toml() { + let toml_content = r#" +[server] +host = "0.0.0.0" +port = 8080 +protocol_version = "2.0.0" +alpn_identifier = "custom" +max_connections = 1000 +max_message_size = "10MB" + +[security] +rate_limit_requests = 60 +rate_limit_connections = 5 +"#; + + let base_dir = PathBuf::from("/tmp"); + let config = GurtConfig::from_toml(toml_content, base_dir).unwrap(); + + assert_eq!(config.server.host, "0.0.0.0"); + assert_eq!(config.server.port, 8080); + assert_eq!(config.server.protocol_version, "2.0.0"); + assert_eq!(config.server.alpn_identifier, "custom"); + assert_eq!(config.server.max_connections, 1000); + + let security = config.security.unwrap(); + assert_eq!(security.rate_limit_requests, 60); + assert_eq!(security.rate_limit_connections, 5); + } + + #[test] + fn test_invalid_toml_returns_error() { + let invalid_toml = r#" +[server +host = "0.0.0.0" +"#; + + let base_dir = PathBuf::from("/tmp"); + let result = GurtConfig::from_toml(invalid_toml, base_dir); + + assert!(result.is_err()); + } + + #[test] + fn test_max_message_size_parsing() { + let config = GurtConfig::default(); + + assert_eq!(parse_size("1024").unwrap(), 1024); + assert_eq!(parse_size("1KB").unwrap(), 1000); + assert_eq!(parse_size("1MB").unwrap(), 1000 * 1000); + assert_eq!(parse_size("1GB").unwrap(), 1000 * 1000 * 1000); + + assert!(parse_size("invalid").is_err()); + + assert!(config.max_message_size_bytes().is_ok()); + } + + #[test] + fn test_tls_config_validation() { + let mut config = GurtConfig::default(); + + config.tls = Some(TlsConfig { + certificate: PathBuf::from("/nonexistent/cert.pem"), + private_key: PathBuf::from("/nonexistent/key.pem"), + }); + + assert!(config.tls.is_some()); + let tls = config.tls.unwrap(); + assert_eq!(tls.certificate, PathBuf::from("/nonexistent/cert.pem")); + assert_eq!(tls.private_key, PathBuf::from("/nonexistent/key.pem")); + } + + #[test] + fn test_address_formatting() { + let config = GurtConfig::default(); + assert_eq!(config.address(), "127.0.0.1:4878"); + + let mut custom_config = GurtConfig::default(); + custom_config.server.host = "0.0.0.0".to_string(); + custom_config.server.port = 8080; + assert_eq!(custom_config.address(), "0.0.0.0:8080"); + } + + #[test] + fn test_timeout_getters() { + let config = GurtConfig::default(); + + assert_eq!(config.get_handshake_timeout(), Duration::from_secs(5)); + assert_eq!(config.get_request_timeout(), Duration::from_secs(30)); + assert_eq!(config.get_connection_timeout(), Duration::from_secs(10)); + } +} diff --git a/protocol/cli/src/error.rs b/protocol/cli/src/error.rs new file mode 100644 index 0000000..b99c997 --- /dev/null +++ b/protocol/cli/src/error.rs @@ -0,0 +1,38 @@ +use std::fmt; + +#[derive(Debug)] +pub enum ServerError { + Io(std::io::Error), + InvalidPath(String), + InvalidConfiguration(String), + TlsConfiguration(String), + ServerStartup(String), +} + +impl fmt::Display for ServerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ServerError::Io(err) => write!(f, "I/O error: {}", err), + ServerError::InvalidPath(path) => write!(f, "Invalid path: {}", path), + ServerError::InvalidConfiguration(msg) => write!(f, "Configuration error: {}", msg), + ServerError::TlsConfiguration(msg) => write!(f, "TLS configuration error: {}", msg), + ServerError::ServerStartup(msg) => write!(f, "Server startup error: {}", msg), + } + } +} + +impl std::error::Error for ServerError {} + +impl From for ServerError { + fn from(err: std::io::Error) -> Self { + ServerError::Io(err) + } +} + +impl From for ServerError { + fn from(err: gurt::GurtError) -> Self { + ServerError::ServerStartup(err.to_string()) + } +} + +pub type Result = std::result::Result; diff --git a/protocol/cli/src/handlers.rs b/protocol/cli/src/handlers.rs new file mode 100644 index 0000000..1e244b2 --- /dev/null +++ b/protocol/cli/src/handlers.rs @@ -0,0 +1,197 @@ +use std::path::Path; + +pub trait FileHandler: Send + Sync { + fn can_handle(&self, path: &Path) -> bool; + fn get_content_type(&self, path: &Path) -> String; + fn handle_file(&self, path: &Path) -> crate::Result>; +} + +pub struct DefaultFileHandler; + +impl FileHandler for DefaultFileHandler { + fn can_handle(&self, _path: &Path) -> bool { + true // Default + } + + fn get_content_type(&self, path: &Path) -> String { + match path.extension().and_then(|ext| ext.to_str()) { + Some("html") | Some("htm") => "text/html".to_string(), + Some("css") => "text/css".to_string(), + Some("js") => "application/javascript".to_string(), + Some("json") => "application/json".to_string(), + Some("png") => "image/png".to_string(), + Some("jpg") | Some("jpeg") => "image/jpeg".to_string(), + Some("gif") => "image/gif".to_string(), + Some("svg") => "image/svg+xml".to_string(), + Some("ico") => "image/x-icon".to_string(), + Some("txt") => "text/plain".to_string(), + Some("xml") => "application/xml".to_string(), + Some("pdf") => "application/pdf".to_string(), + _ => "application/octet-stream".to_string(), + } + } + + fn handle_file(&self, path: &Path) -> crate::Result> { + std::fs::read(path).map_err(crate::ServerError::from) + } +} + +pub trait DirectoryHandler: Send + Sync { + fn handle_directory(&self, path: &Path, request_path: &str) -> crate::Result; +} + +pub struct DefaultDirectoryHandler; + +impl DirectoryHandler for DefaultDirectoryHandler { + fn handle_directory(&self, path: &Path, request_path: &str) -> crate::Result { + let entries = std::fs::read_dir(path)?; + + let mut listing = String::from(include_str!("../templates/directory_listing_start.html")); + + if request_path != "/" { + listing.push_str(include_str!("../templates/directory_parent_link.html")); + } + + listing.push_str(include_str!("../templates/directory_content_start.html")); + + for entry in entries.flatten() { + let file_name = entry.file_name(); + let name = file_name.to_string_lossy(); + let is_dir = entry.path().is_dir(); + let display_name = if is_dir { + format!("{}/", name) + } else { + name.to_string() + }; + let class = if is_dir { "dir" } else { "file" }; + + listing.push_str(&format!( + r#" {}"#, + name, class, display_name + )); + listing.push('\n'); + } + + listing.push_str(include_str!("../templates/directory_listing_end.html")); + Ok(listing) + } +} + +pub fn get_404_html() -> &'static str { + include_str!("../templates/404.html") +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::Path; + + #[test] + fn test_default_file_handler_can_handle_any_file() { + let handler = DefaultFileHandler; + let path = Path::new("test.txt"); + assert!(handler.can_handle(path)); + + let path = Path::new("some/random/file"); + assert!(handler.can_handle(path)); + } + + #[test] + fn test_content_type_detection() { + let handler = DefaultFileHandler; + + assert_eq!(handler.get_content_type(Path::new("index.html")), "text/html"); + assert_eq!(handler.get_content_type(Path::new("style.css")), "text/css"); + assert_eq!(handler.get_content_type(Path::new("script.js")), "application/javascript"); + assert_eq!(handler.get_content_type(Path::new("data.json")), "application/json"); + + assert_eq!(handler.get_content_type(Path::new("image.png")), "image/png"); + assert_eq!(handler.get_content_type(Path::new("photo.jpg")), "image/jpeg"); + assert_eq!(handler.get_content_type(Path::new("photo.jpeg")), "image/jpeg"); + assert_eq!(handler.get_content_type(Path::new("icon.ico")), "image/x-icon"); + assert_eq!(handler.get_content_type(Path::new("vector.svg")), "image/svg+xml"); + + assert_eq!(handler.get_content_type(Path::new("readme.txt")), "text/plain"); + assert_eq!(handler.get_content_type(Path::new("data.xml")), "application/xml"); + assert_eq!(handler.get_content_type(Path::new("document.pdf")), "application/pdf"); + + assert_eq!(handler.get_content_type(Path::new("file.unknown")), "application/octet-stream"); + + assert_eq!(handler.get_content_type(Path::new("noextension")), "application/octet-stream"); + } + + #[test] + fn test_directory_handler_generates_valid_html() { + use std::fs; + use std::env; + + let temp_dir = env::temp_dir().join("gurty_test"); + let _ = fs::create_dir_all(&temp_dir); + + let _ = fs::write(temp_dir.join("test.txt"), "test content"); + let _ = fs::create_dir_all(temp_dir.join("subdir")); + + let handler = DefaultDirectoryHandler; + let result = handler.handle_directory(&temp_dir, "/test/"); + + assert!(result.is_ok()); + let html = result.unwrap(); + + assert!(html.contains("")); + assert!(html.contains("Directory Listing")); + assert!(html.contains("← Parent Directory")); + assert!(html.contains("test.txt")); + assert!(html.contains("subdir/")); + + let _ = fs::remove_dir_all(&temp_dir); + } + + #[test] + fn test_directory_handler_root_path() { + use std::fs; + use std::env; + + let temp_dir = env::temp_dir().join("gurty_test_root"); + let _ = fs::create_dir_all(&temp_dir); + + let handler = DefaultDirectoryHandler; + let result = handler.handle_directory(&temp_dir, "/"); + + assert!(result.is_ok()); + let html = result.unwrap(); + + assert!(!html.contains("← Parent Directory")); + + let _ = fs::remove_dir_all(&temp_dir); + } + + #[test] + fn test_get_404_html_content() { + let html = get_404_html(); + + assert!(html.contains("")); + assert!(html.contains("404 Page Not Found")); + assert!(html.contains("The requested path was not found")); + assert!(html.contains("Back to home")); + } + + #[test] + fn test_directory_handler_with_empty_directory() { + use std::fs; + use std::env; + + let temp_dir = env::temp_dir().join("gurty_test_empty"); + let _ = fs::create_dir_all(&temp_dir); + + let handler = DefaultDirectoryHandler; + let result = handler.handle_directory(&temp_dir, "/empty/"); + + assert!(result.is_ok()); + let html = result.unwrap(); + + assert!(html.contains("")); + assert!(html.contains("Directory Listing")); + + let _ = fs::remove_dir_all(&temp_dir); + } +} diff --git a/protocol/cli/src/lib.rs b/protocol/cli/src/lib.rs new file mode 100644 index 0000000..42f8cb6 --- /dev/null +++ b/protocol/cli/src/lib.rs @@ -0,0 +1,10 @@ +pub mod cli; +pub mod config; +pub mod error; +pub mod security; +pub mod server; +pub mod request_handler; +pub mod command_handler; +pub mod handlers; + +pub use error::{Result, ServerError}; diff --git a/protocol/cli/src/main.rs b/protocol/cli/src/main.rs index cfe709a..d417f55 100644 --- a/protocol/cli/src/main.rs +++ b/protocol/cli/src/main.rs @@ -1,323 +1,23 @@ -use clap::{Parser, Subcommand}; -use colored::Colorize; -use gurt::prelude::*; -use std::path::PathBuf; -use tracing::error; -use tracing_subscriber; - -#[derive(Parser)] -#[command(name = "server")] -#[command(about = "GURT Protocol Server")] -#[command(version = "1.0.0")] -struct Cli { - #[command(subcommand)] - command: Commands, -} - -#[derive(Subcommand)] -enum Commands { - Serve { - #[arg(short, long, default_value_t = 4878)] - port: u16, - - #[arg(long, default_value = "127.0.0.1")] - host: String, - - #[arg(short, long, default_value = ".")] - dir: PathBuf, - - #[arg(short, long)] - verbose: bool, - - #[arg(long, help = "Path to TLS certificate file")] - cert: Option, - - #[arg(long, help = "Path to TLS private key file")] - key: Option, - } -} +use clap::Parser; +use gurty::{ + cli::{Cli, Commands}, + command_handler::{CommandHandler, CommandHandlerBuilder}, + Result, +}; #[tokio::main] async fn main() -> Result<()> { let cli = Cli::parse(); match cli.command { - Commands::Serve { port, host, dir, verbose, cert, key } => { - if verbose { - tracing_subscriber::fmt() - .with_max_level(tracing::Level::DEBUG) - .init(); - } else { - tracing_subscriber::fmt() - .with_max_level(tracing::Level::INFO) - .init(); - } + Commands::Serve(serve_cmd) => { + let handler = CommandHandlerBuilder::new() + .with_logging(serve_cmd.verbose) + .initialize_logging() + .build_serve_handler(serve_cmd); - println!("{}", "GURT Protocol Server".bright_cyan().bold()); - println!("{} {}:{}", "Listening on".bright_blue(), host, port); - println!("{} {}", "Serving from".bright_blue(), dir.display()); - - let server = create_file_server(dir, cert, key)?; - let addr = format!("{}:{}", host, port); - - if let Err(e) = server.listen(&addr).await { - error!("Server error: {}", e); - std::process::exit(1); - } + handler.execute().await } } - - Ok(()) } -fn create_file_server(base_dir: PathBuf, cert_path: Option, key_path: Option) -> Result { - let base_dir = std::sync::Arc::new(base_dir); - - let server = match (cert_path, key_path) { - (Some(cert), Some(key)) => { - println!("TLS using certificate: {}", cert.display()); - GurtServer::with_tls_certificates( - cert.to_str().ok_or_else(|| GurtError::invalid_message("Invalid certificate path"))?, - key.to_str().ok_or_else(|| GurtError::invalid_message("Invalid key path"))? - )? - } - (Some(_), None) => { - return Err(GurtError::invalid_message("Certificate provided but no key file specified (use --key)")); - } - (None, Some(_)) => { - return Err(GurtError::invalid_message("Key provided but no certificate file specified (use --cert)")); - } - (None, None) => { - return Err(GurtError::invalid_message("GURT protocol requires TLS encryption. Please provide --cert and --key parameters.")); - } - }; - - let server = server - .get("/", { - let base_dir = base_dir.clone(); - move |_| { - let base_dir = base_dir.clone(); - async move { - // Try to serve index.html if it exists - let index_path = base_dir.join("index.html"); - - if index_path.exists() && index_path.is_file() { - match std::fs::read_to_string(&index_path) { - Ok(content) => { - return Ok(GurtResponse::ok() - .with_header("Content-Type", "text/html") - .with_string_body(content)); - } - Err(_) => { - // Fall through to directory listing - } - } - } - - // No index.html found, show directory listing - match std::fs::read_dir(base_dir.as_ref()) { - Ok(entries) => { - let mut listing = String::from(r#" - - - - Directory Listing - - - -

Directory Listing

-
-"#); - for entry in entries.flatten() { - let file_name = entry.file_name(); - let name = file_name.to_string_lossy(); - let is_dir = entry.path().is_dir(); - let display_name = if is_dir { format!("{}/", name) } else { name.to_string() }; - let class = if is_dir { "style=\"dir\"" } else { "" }; - - listing.push_str(&format!( - r#" {}"#, - class, name, display_name - )); - listing.push('\n'); - } - - listing.push_str("
\n"); - - Ok(GurtResponse::ok() - .with_header("Content-Type", "text/html") - .with_string_body(listing)) - } - Err(_) => { - Ok(GurtResponse::internal_server_error() - .with_header("Content-Type", "text/plain") - .with_string_body("Failed to read directory")) - } - } - } - } - }) - .get("/*", { - let base_dir = base_dir.clone(); - move |ctx| { - let base_dir = base_dir.clone(); - let path = ctx.path().to_string(); - async move { - let mut relative_path = path.strip_prefix('/').unwrap_or(&path).to_string(); - // Remove any leading slashes to ensure relative path - while relative_path.starts_with('/') || relative_path.starts_with('\\') { - relative_path = relative_path[1..].to_string(); - } - // If the path is now empty, use "." - let relative_path = if relative_path.is_empty() { ".".to_string() } else { relative_path }; - let file_path = base_dir.join(&relative_path); - - match file_path.canonicalize() { - Ok(canonical_path) => { - let canonical_base = match base_dir.canonicalize() { - Ok(base) => base, - Err(_) => { - return Ok(GurtResponse::internal_server_error() - .with_header("Content-Type", "text/plain") - .with_string_body("Server configuration error")); - } - }; - - if !canonical_path.starts_with(&canonical_base) { - return Ok(GurtResponse::bad_request() - .with_header("Content-Type", "text/plain") - .with_string_body("Access denied: Path outside served directory")); - } - - if canonical_path.is_file() { - match std::fs::read(&canonical_path) { - Ok(content) => { - let content_type = get_content_type(&canonical_path); - Ok(GurtResponse::ok() - .with_header("Content-Type", &content_type) - .with_body(content)) - } - Err(_) => { - Ok(GurtResponse::internal_server_error() - .with_header("Content-Type", "text/plain") - .with_string_body("Failed to read file")) - } - } - } else if canonical_path.is_dir() { - let index_path = canonical_path.join("index.html"); - if index_path.is_file() { - match std::fs::read_to_string(&index_path) { - Ok(content) => { - Ok(GurtResponse::ok() - .with_header("Content-Type", "text/html") - .with_string_body(content)) - } - Err(_) => { - Ok(GurtResponse::internal_server_error() - .with_header("Content-Type", "text/plain") - .with_string_body("Failed to read index file")) - } - } - } else { - match std::fs::read_dir(&canonical_path) { - Ok(entries) => { - let mut listing = String::from(r#" - - - - Directory Listing - - - -

Directory Listing

-

← Parent Directory

-
-"#); - for entry in entries.flatten() { - let file_name = entry.file_name(); - let name = file_name.to_string_lossy(); - let is_dir = entry.path().is_dir(); - let display_name = if is_dir { format!("{}/", name) } else { name.to_string() }; - let class = if is_dir { "style=\"dir\"" } else { "" }; - - listing.push_str(&format!( - r#" {}"#, - class, name, display_name - )); - listing.push('\n'); - } - - listing.push_str("
\n"); - - Ok(GurtResponse::ok() - .with_header("Content-Type", "text/html") - .with_string_body(listing)) - } - Err(_) => { - Ok(GurtResponse::internal_server_error() - .with_header("Content-Type", "text/plain") - .with_string_body("Failed to read directory")) - } - } - } - } else { - // File not found - Ok(GurtResponse::not_found() - .with_header("Content-Type", "text/html") - .with_string_body(get_404_html())) - } - } - Err(_e) => { - Ok(GurtResponse::not_found() - .with_header("Content-Type", "text/html") - .with_string_body(get_404_html())) - } - } - } - } - }); - - Ok(server) -} - -fn get_404_html() -> &'static str { - r#" - - - 404 Not Found - - - -

404 Page Not Found

-

The requested path was not found on this GURT server.

-

Back to home

- - -"# -} - -fn get_content_type(path: &std::path::Path) -> String { - match path.extension().and_then(|ext| ext.to_str()) { - Some("html") | Some("htm") => "text/html".to_string(), - Some("css") => "text/css".to_string(), - Some("js") => "application/javascript".to_string(), - Some("json") => "application/json".to_string(), - Some("png") => "image/png".to_string(), - Some("jpg") | Some("jpeg") => "image/jpeg".to_string(), - Some("gif") => "image/gif".to_string(), - Some("svg") => "image/svg+xml".to_string(), - Some("ico") => "image/x-icon".to_string(), - Some("txt") => "text/plain".to_string(), - Some("xml") => "application/xml".to_string(), - Some("pdf") => "application/pdf".to_string(), - _ => "application/octet-stream".to_string(), - } -} \ No newline at end of file diff --git a/protocol/cli/src/request_handler.rs b/protocol/cli/src/request_handler.rs new file mode 100644 index 0000000..aee5681 --- /dev/null +++ b/protocol/cli/src/request_handler.rs @@ -0,0 +1,560 @@ +use crate::{ + handlers::{FileHandler, DirectoryHandler, DefaultFileHandler, DefaultDirectoryHandler}, + config::GurtConfig, + security::SecurityMiddleware, +}; +use gurt::prelude::*; +use std::path::Path; +use std::sync::Arc; +use tracing; + +pub struct RequestHandlerBuilder { + file_handler: Arc, + directory_handler: Arc, + base_directory: std::path::PathBuf, + config: Option>, +} + +impl RequestHandlerBuilder { + pub fn new>(base_directory: P) -> Self { + Self { + file_handler: Arc::new(DefaultFileHandler), + directory_handler: Arc::new(DefaultDirectoryHandler), + base_directory: base_directory.as_ref().to_path_buf(), + config: None, + } + } + + pub fn with_file_handler(mut self, handler: H) -> Self { + self.file_handler = Arc::new(handler); + self + } + + pub fn with_directory_handler(mut self, handler: H) -> Self { + self.directory_handler = Arc::new(handler); + self + } + + pub fn with_config(mut self, config: Arc) -> Self { + self.config = Some(config); + self + } + + pub fn build(self) -> RequestHandler { + let security = self.config.as_ref().map(|config| SecurityMiddleware::new(config.clone())); + + RequestHandler { + file_handler: self.file_handler, + directory_handler: self.directory_handler, + base_directory: self.base_directory, + config: self.config, + security, + } + } +} + +pub struct RequestHandler { + file_handler: Arc, + directory_handler: Arc, + base_directory: std::path::PathBuf, + config: Option>, + security: Option, +} + +impl RequestHandler { + pub fn builder>(base_directory: P) -> RequestHandlerBuilder { + RequestHandlerBuilder::new(base_directory) + } + + fn apply_custom_error_page(&self, mut response: GurtResponse) -> GurtResponse { + if response.status_code >= 400 { + let custom_content = self.get_custom_error_page(response.status_code) + .unwrap_or_else(|| self.get_fallback_error_page(response.status_code)); + + response.body = custom_content.into_bytes(); + response = response.with_header("Content-Type", "text/html"); + tracing::debug!("Applied error page for status {}", response.status_code); + } + response + } + + fn get_custom_error_page(&self, status_code: u16) -> Option { + if let Some(config) = &self.config { + if let Some(error_pages) = &config.error_pages { + error_pages.get_page_content(status_code, &self.base_directory) + } else { + None + } + } else { + None + } + } + + fn get_fallback_error_page(&self, status_code: u16) -> String { + let (title, message) = match status_code { + 400 => ("Bad Request", "The request could not be understood by the server."), + 401 => ("Unauthorized", "Authentication is required to access this resource."), + 403 => ("Forbidden", "Access to this resource is denied by server policy."), + 404 => ("Not Found", "The requested resource was not found on this server."), + 405 => ("Method Not Allowed", "The request method is not allowed for this resource."), + 429 => ("Too Many Requests", "You have exceeded the rate limit. Please try again later."), + 500 => ("Internal Server Error", "The server encountered an error processing your request."), + 502 => ("Bad Gateway", "The server received an invalid response from an upstream server."), + 503 => ("Service Unavailable", "The server is temporarily unavailable. Please try again later."), + 504 => ("Gateway Timeout", "The server did not receive a timely response from an upstream server."), + _ => ("Error", "An error occurred while processing your request."), + }; + + format!(include_str!("../templates/error.html"), status_code, title, status_code, title, message) + } + + pub fn check_security(&self, ctx: &ServerContext) -> Option> { + if let Some(security) = &self.security { + let client_ip = ctx.client_ip(); + let method = ctx.method(); + + if !security.is_method_allowed(method) { + tracing::warn!("Method {} not allowed from {}", method, client_ip); + let response = security.create_method_not_allowed_response() + .map(|r| self.apply_global_headers(r)); + return Some(response); + } + + if !security.check_rate_limit(client_ip) { + let response = security.create_rate_limit_response() + .map(|r| self.apply_global_headers(r)); + return Some(response); + } + + if !security.check_connection_limit(client_ip) { + let response = security.create_rate_limit_response() + .map(|r| self.apply_global_headers(r)); + return Some(response); + } + } + + None + } + + pub fn register_connection(&self, client_ip: std::net::IpAddr) { + if let Some(security) = &self.security { + security.register_connection(client_ip); + } + } + + pub fn unregister_connection(&self, client_ip: std::net::IpAddr) { + if let Some(security) = &self.security { + security.unregister_connection(client_ip); + } + } + + fn is_file_denied(&self, file_path: &Path) -> bool { + if let Some(config) = &self.config { + let path_str = file_path.to_string_lossy(); + + let relative_path = if let Ok(canonical_file) = file_path.canonicalize() { + if let Ok(canonical_base) = self.base_directory.canonicalize() { + canonical_file.strip_prefix(&canonical_base) + .map(|p| p.to_string_lossy().to_string()) + .unwrap_or_else(|_| path_str.to_string()) + } else { + path_str.to_string() + } + } else { + path_str.to_string() + }; + + let is_denied = config.should_deny_file(&path_str) || config.should_deny_file(&relative_path); + + if is_denied { + tracing::warn!("File access denied by security policy: {}", relative_path); + } + + is_denied + } else { + false + } + } + + fn apply_global_headers(&self, mut response: GurtResponse) -> GurtResponse { + response = self.apply_custom_error_page(response); + + if let Some(config) = &self.config { + if let Some(headers) = &config.headers { + for (key, value) in headers { + response = response.with_header(key, value); + } + } + } + response + } + + fn create_forbidden_response(&self) -> std::result::Result { + let response = GurtResponse::forbidden() + .with_header("Content-Type", "text/html"); + + Ok(self.apply_global_headers(response)) + } + + pub async fn handle_root_request_with_context(&self, ctx: ServerContext) -> std::result::Result { + let client_ip = ctx.client_ip(); + + self.register_connection(client_ip); + + if let Some(security_response) = self.check_security(&ctx) { + self.unregister_connection(client_ip); + return security_response; + } + + let result = self.handle_root_request().await; + self.unregister_connection(client_ip); + result + } + + pub async fn handle_file_request_with_context(&self, request_path: &str, ctx: ServerContext) -> std::result::Result { + let client_ip = ctx.client_ip(); + + self.register_connection(client_ip); + + if let Some(security_response) = self.check_security(&ctx) { + self.unregister_connection(client_ip); + return security_response; + } + + let result = self.handle_file_request(request_path).await; + self.unregister_connection(client_ip); + result + } + + pub async fn handle_method_request_with_context(&self, ctx: ServerContext) -> std::result::Result { + let client_ip = ctx.client_ip(); + let method = ctx.method(); + + self.register_connection(client_ip); + + if let Some(security_response) = self.check_security(&ctx) { + self.unregister_connection(client_ip); + return security_response; + } + + let result = match method { + gurt::message::GurtMethod::GET => { + if ctx.path() == "/" { + self.handle_root_request().await + } else { + self.handle_file_request(ctx.path()).await + } + } + gurt::message::GurtMethod::HEAD => { + let mut response = if ctx.path() == "/" { + self.handle_root_request().await? + } else { + self.handle_file_request(ctx.path()).await? + }; + response.body = Vec::new(); + Ok(response) + } + gurt::message::GurtMethod::OPTIONS => { + let allowed_methods = if let Some(config) = &self.config { + if let Some(security) = &config.security { + security.allowed_methods.join(", ") + } else { + "GET, POST, PUT, DELETE, HEAD, OPTIONS, PATCH".to_string() + } + } else { + "GET, POST, PUT, DELETE, HEAD, OPTIONS, PATCH".to_string() + }; + + let response = GurtResponse::ok() + .with_header("Allow", &allowed_methods) + .with_header("Content-Type", "text/plain") + .with_string_body("Allowed methods"); + Ok(self.apply_global_headers(response)) + } + _ => { + let response = GurtResponse::new(gurt::protocol::GurtStatusCode::MethodNotAllowed) + .with_header("Content-Type", "text/html"); + Ok(self.apply_global_headers(response)) + } + }; + + self.unregister_connection(client_ip); + result + } + + pub async fn handle_root_request(&self) -> std::result::Result { + let index_path = self.base_directory.join("index.html"); + + if index_path.exists() && index_path.is_file() { + if self.is_file_denied(&index_path) { + return self.create_forbidden_response(); + } + + match self.file_handler.handle_file(&index_path) { + Ok(content) => { + let content_type = self.file_handler.get_content_type(&index_path); + let response = GurtResponse::ok() + .with_header("Content-Type", &content_type) + .with_body(content); + return Ok(self.apply_global_headers(response)); + } + Err(_) => { + // fall + } + } + } + + match self.directory_handler.handle_directory(&self.base_directory, "/") { + Ok(listing) => { + let response = GurtResponse::ok() + .with_header("Content-Type", "text/html") + .with_string_body(listing); + Ok(self.apply_global_headers(response)) + } + Err(_) => { + let response = GurtResponse::internal_server_error() + .with_header("Content-Type", "text/html"); + Ok(self.apply_global_headers(response)) + } + } + } + + pub async fn handle_file_request(&self, request_path: &str) -> std::result::Result { + let mut relative_path = request_path.strip_prefix('/').unwrap_or(request_path).to_string(); + + while relative_path.starts_with('/') || relative_path.starts_with('\\') { + relative_path = relative_path[1..].to_string(); + } + + let relative_path = if relative_path.is_empty() { + ".".to_string() + } else { + relative_path + }; + + let file_path = self.base_directory.join(&relative_path); + + if self.is_file_denied(&file_path) { + return self.create_forbidden_response(); + } + + match file_path.canonicalize() { + Ok(canonical_path) => { + let canonical_base = match self.base_directory.canonicalize() { + Ok(base) => base, + Err(_) => { + return Ok(GurtResponse::internal_server_error() + .with_header("Content-Type", "text/html")); + } + }; + + if !canonical_path.starts_with(&canonical_base) { + let response = GurtResponse::bad_request() + .with_header("Content-Type", "text/html"); + return Ok(self.apply_global_headers(response)); + } + + if self.is_file_denied(&canonical_path) { + return self.create_forbidden_response(); + } + + if canonical_path.is_file() { + self.handle_file_response(&canonical_path).await + } else if canonical_path.is_dir() { + self.handle_directory_response(&canonical_path, request_path).await + } else { + self.handle_not_found_response().await + } + } + Err(_) => { + self.handle_not_found_response().await + } + } + } + + async fn handle_file_response(&self, path: &Path) -> std::result::Result { + match self.file_handler.handle_file(path) { + Ok(content) => { + let content_type = self.file_handler.get_content_type(path); + let response = GurtResponse::ok() + .with_header("Content-Type", &content_type) + .with_body(content); + Ok(self.apply_global_headers(response)) + } + Err(_) => { + let response = GurtResponse::internal_server_error() + .with_header("Content-Type", "text/html"); + Ok(self.apply_global_headers(response)) + } + } + } + + async fn handle_directory_response(&self, canonical_path: &Path, request_path: &str) -> std::result::Result { + let index_path = canonical_path.join("index.html"); + if index_path.is_file() { + self.handle_file_response(&index_path).await + } else { + match self.directory_handler.handle_directory(canonical_path, request_path) { + Ok(listing) => { + let response = GurtResponse::ok() + .with_header("Content-Type", "text/html") + .with_string_body(listing); + Ok(self.apply_global_headers(response)) + } + Err(_) => { + let response = GurtResponse::internal_server_error() + .with_header("Content-Type", "text/html"); + Ok(self.apply_global_headers(response)) + } + } + } + } + + async fn handle_not_found_response(&self) -> std::result::Result { + let content = self.get_custom_error_page(404) + .unwrap_or_else(|| crate::handlers::get_404_html().to_string()); + + let response = GurtResponse::not_found() + .with_header("Content-Type", "text/html") + .with_string_body(content); + Ok(self.apply_global_headers(response)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use gurt::GurtStatusCode; + use std::fs; + use std::env; + + fn create_test_handler() -> RequestHandler { + let temp_dir = env::temp_dir().join("gurty_request_handler_test"); + let _ = fs::create_dir_all(&temp_dir); + + RequestHandler::builder(&temp_dir).build() + } + + fn create_test_handler_with_config() -> RequestHandler { + let temp_dir = env::temp_dir().join("gurty_request_handler_test_config"); + let _ = fs::create_dir_all(&temp_dir); + + let config = Arc::new(GurtConfig::default()); + RequestHandler::builder(&temp_dir) + .with_config(config) + .build() + } + + #[test] + fn test_request_handler_builder() { + let temp_dir = env::temp_dir().join("gurty_builder_test"); + let _ = fs::create_dir_all(&temp_dir); + + let handler = RequestHandler::builder(&temp_dir).build(); + + assert_eq!(handler.base_directory, temp_dir); + assert!(handler.config.is_none()); + assert!(handler.security.is_none()); + + let _ = fs::remove_dir_all(&temp_dir); + } + + #[test] + fn test_request_handler_builder_with_config() { + let temp_dir = env::temp_dir().join("gurty_builder_config_test"); + let _ = fs::create_dir_all(&temp_dir); + + let config = Arc::new(GurtConfig::default()); + let handler = RequestHandler::builder(&temp_dir) + .with_config(config.clone()) + .build(); + + assert!(handler.config.is_some()); + assert!(handler.security.is_some()); + + let _ = fs::remove_dir_all(&temp_dir); + } + + #[test] + fn test_fallback_error_page_generation() { + let handler = create_test_handler(); + + let error_404 = handler.get_fallback_error_page(404); + assert!(error_404.contains("404 Not Found")); + assert!(error_404.contains("not found")); + + let error_500 = handler.get_fallback_error_page(500); + assert!(error_500.contains("500 Internal Server Error")); + assert!(error_500.contains("processing your request")); + + let error_429 = handler.get_fallback_error_page(429); + assert!(error_429.contains("429 Too Many Requests")); + assert!(error_429.contains("rate limit")); + } + + #[test] + fn test_custom_error_page_with_config() { + let handler = create_test_handler_with_config(); + + let result = handler.get_custom_error_page(404); + assert!(result.is_none()); + } + + #[test] + fn test_apply_global_headers_without_config() { + let handler = create_test_handler(); + let response = GurtResponse::ok(); + + let modified_response = handler.apply_global_headers(response); + + assert_eq!(modified_response.status_code, 200); + } + + #[test] + fn test_apply_global_headers_with_config() { + let temp_dir = env::temp_dir().join("gurty_headers_test"); + let _ = fs::create_dir_all(&temp_dir); + + let mut config = GurtConfig::default(); + let mut headers = std::collections::HashMap::new(); + headers.insert("X-Test-Header".to_string(), "test-value".to_string()); + config.headers = Some(headers); + + let handler = RequestHandler::builder(&temp_dir) + .with_config(Arc::new(config)) + .build(); + + let response = GurtResponse::ok(); + let modified_response = handler.apply_global_headers(response); + + assert!(modified_response.headers.contains_key("x-test-header")); + assert_eq!(modified_response.headers.get("x-test-header").unwrap(), "test-value"); + + let _ = fs::remove_dir_all(&temp_dir); + } + + #[test] + fn test_apply_custom_error_page() { + let handler = create_test_handler(); + let mut response = GurtResponse::new(GurtStatusCode::NotFound); + response.body = b"Not Found".to_vec(); + + let modified_response = handler.apply_custom_error_page(response); + + assert!(modified_response.status_code >= 400); + let body_str = String::from_utf8_lossy(&modified_response.body); + assert!(body_str.contains("html")); + } + + #[test] + fn test_apply_custom_error_page_for_success() { + let handler = create_test_handler(); + let mut response = GurtResponse::ok(); + response.body = b"Success".to_vec(); + + let modified_response = handler.apply_custom_error_page(response); + + assert_eq!(modified_response.status_code, 200); + assert_eq!(modified_response.body, b"Success".to_vec()); + } +} diff --git a/protocol/cli/src/security.rs b/protocol/cli/src/security.rs new file mode 100644 index 0000000..f738cc0 --- /dev/null +++ b/protocol/cli/src/security.rs @@ -0,0 +1,288 @@ +use crate::config::GurtConfig; +use gurt::{prelude::*, GurtMethod, GurtStatusCode}; +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; +use tracing::{warn, debug}; + +#[derive(Debug)] +pub struct RateLimitData { + requests: Vec, + connections: u32, +} + +impl RateLimitData { + fn new() -> Self { + Self { + requests: Vec::new(), + connections: 0, + } + } + + fn cleanup_old_requests(&mut self, window: Duration) { + let cutoff = Instant::now() - window; + self.requests.retain(|&request_time| request_time > cutoff); + } + + fn add_request(&mut self) { + self.requests.push(Instant::now()); + } + + fn request_count(&self) -> usize { + self.requests.len() + } + + fn increment_connections(&mut self) { + self.connections += 1; + } + + fn decrement_connections(&mut self) { + if self.connections > 0 { + self.connections -= 1; + } + } + + fn connection_count(&self) -> u32 { + self.connections + } +} + +pub struct SecurityMiddleware { + config: Arc, + rate_limit_data: Arc>>, +} + +impl SecurityMiddleware { + pub fn new(config: Arc) -> Self { + Self { + config, + rate_limit_data: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub fn is_method_allowed(&self, method: &GurtMethod) -> bool { + if let Some(security) = &self.config.security { + let method_str = method.to_string(); + security.allowed_methods.contains(&method_str) + } else { + true + } + } + + pub fn check_rate_limit(&self, client_ip: IpAddr) -> bool { + if let Some(security) = &self.config.security { + let mut data = self.rate_limit_data.lock().unwrap(); + let rate_data = data.entry(client_ip).or_insert_with(RateLimitData::new); + + rate_data.cleanup_old_requests(Duration::from_secs(60)); + + if rate_data.request_count() >= security.rate_limit_requests as usize { + warn!("Rate limit exceeded for IP {}: {} requests in the last minute", + client_ip, rate_data.request_count()); + return false; + } + + rate_data.add_request(); + debug!("Request from {}: {}/{} requests in the last minute", + client_ip, rate_data.request_count(), security.rate_limit_requests); + } + + true + } + + pub fn check_connection_limit(&self, client_ip: IpAddr) -> bool { + if let Some(security) = &self.config.security { + let mut data = self.rate_limit_data.lock().unwrap(); + let rate_data = data.entry(client_ip).or_insert_with(RateLimitData::new); + + if rate_data.connection_count() >= security.rate_limit_connections { + warn!("Connection limit exceeded for IP {}: {} concurrent connections", + client_ip, rate_data.connection_count()); + return false; + } + } + + true + } + + pub fn register_connection(&self, client_ip: IpAddr) { + if self.config.security.is_some() { + let mut data = self.rate_limit_data.lock().unwrap(); + let rate_data = data.entry(client_ip).or_insert_with(RateLimitData::new); + rate_data.increment_connections(); + debug!("Connection registered for {}: {} concurrent connections", + client_ip, rate_data.connection_count()); + } + } + + pub fn unregister_connection(&self, client_ip: IpAddr) { + if self.config.security.is_some() { + let mut data = self.rate_limit_data.lock().unwrap(); + if let Some(rate_data) = data.get_mut(&client_ip) { + rate_data.decrement_connections(); + debug!("Connection unregistered for {}: {} concurrent connections remaining", + client_ip, rate_data.connection_count()); + } + } + } + + pub fn create_method_not_allowed_response(&self) -> std::result::Result { + let response = GurtResponse::new(GurtStatusCode::MethodNotAllowed) + .with_header("Content-Type", "text/html"); + Ok(response) + } + + pub fn create_rate_limit_response(&self) -> std::result::Result { + let response = GurtResponse::new(GurtStatusCode::TooManyRequests) + .with_header("Content-Type", "text/html") + .with_header("Retry-After", "60"); + Ok(response) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::IpAddr; + use std::sync::Arc; + use std::time::Duration; + + fn create_test_config() -> Arc { + let mut config = crate::config::GurtConfig::default(); + config.security = Some(crate::config::SecurityConfig { + deny_files: vec!["*.secret".to_string(), "private/*".to_string()], + allowed_methods: vec!["GET".to_string(), "POST".to_string()], + rate_limit_requests: 5, + rate_limit_connections: 2, + }); + Arc::new(config) + } + + #[test] + fn test_rate_limit_data_initialization() { + let data = RateLimitData::new(); + + assert_eq!(data.request_count(), 0); + assert_eq!(data.connection_count(), 0); + } + + #[test] + fn test_rate_limit_data_request_tracking() { + let mut data = RateLimitData::new(); + + data.add_request(); + data.add_request(); + assert_eq!(data.request_count(), 2); + + data.cleanup_old_requests(Duration::from_secs(0)); + assert_eq!(data.request_count(), 0); + } + + #[test] + fn test_rate_limit_data_connection_tracking() { + let mut data = RateLimitData::new(); + + data.increment_connections(); + data.increment_connections(); + assert_eq!(data.connection_count(), 2); + + data.decrement_connections(); + assert_eq!(data.connection_count(), 1); + + data.decrement_connections(); + data.decrement_connections(); + assert_eq!(data.connection_count(), 0); + } + + #[test] + fn test_security_middleware_initialization() { + let config = create_test_config(); + let middleware = SecurityMiddleware::new(config.clone()); + + assert!(middleware.rate_limit_data.lock().unwrap().is_empty()); + } + + #[test] + fn test_connection_tracking() { + let config = create_test_config(); + let middleware = SecurityMiddleware::new(config.clone()); + let ip: IpAddr = "127.0.0.1".parse().unwrap(); + + middleware.register_connection(ip); + { + let data = middleware.rate_limit_data.lock().unwrap(); + assert_eq!(data.get(&ip).unwrap().connection_count(), 1); + } + + middleware.unregister_connection(ip); + { + let data = middleware.rate_limit_data.lock().unwrap(); + assert_eq!(data.get(&ip).unwrap().connection_count(), 0); + } + } + + #[test] + fn test_rate_limiting_requests() { + let config = create_test_config(); + let middleware = SecurityMiddleware::new(config.clone()); + let ip: IpAddr = "127.0.0.1".parse().unwrap(); + + for _ in 0..5 { + assert!(middleware.check_rate_limit(ip)); + } + + assert!(!middleware.check_rate_limit(ip)); + } + + #[test] + fn test_connection_limiting() { + let config = create_test_config(); + let middleware = SecurityMiddleware::new(config.clone()); + let ip: IpAddr = "127.0.0.1".parse().unwrap(); + + middleware.register_connection(ip); + middleware.register_connection(ip); + + assert!(!middleware.check_connection_limit(ip)); + } + + #[test] + fn test_method_validation() { + let config = create_test_config(); + let middleware = SecurityMiddleware::new(config.clone()); + + assert!(middleware.is_method_allowed(&GurtMethod::GET)); + assert!(middleware.is_method_allowed(&GurtMethod::POST)); + + assert!(!middleware.is_method_allowed(&GurtMethod::PUT)); + assert!(!middleware.is_method_allowed(&GurtMethod::DELETE)); + } + + #[test] + fn test_multiple_ips_isolation() { + let config = create_test_config(); + let middleware = SecurityMiddleware::new(config.clone()); + let ip1: IpAddr = "127.0.0.1".parse().unwrap(); + let ip2: IpAddr = "127.0.0.2".parse().unwrap(); + + for _ in 0..6 { + middleware.check_rate_limit(ip1); + } + + assert!(middleware.check_rate_limit(ip2)); + assert!(!middleware.check_rate_limit(ip1)); + } + + #[test] + fn test_response_creation() { + let config = create_test_config(); + let middleware = SecurityMiddleware::new(config.clone()); + + let response = middleware.create_method_not_allowed_response().unwrap(); + assert_eq!(response.status_code, 405); + + let response = middleware.create_rate_limit_response().unwrap(); + assert_eq!(response.status_code, 429); + } +} diff --git a/protocol/cli/src/server.rs b/protocol/cli/src/server.rs new file mode 100644 index 0000000..f772f4e --- /dev/null +++ b/protocol/cli/src/server.rs @@ -0,0 +1,301 @@ +use crate::{ + config::GurtConfig, + handlers::{FileHandler, DirectoryHandler, DefaultFileHandler, DefaultDirectoryHandler}, + request_handler::{RequestHandler, RequestHandlerBuilder}, +}; +use gurt::prelude::*; +use std::{path::PathBuf, sync::Arc}; + +pub struct FileServerBuilder { + config: GurtConfig, + file_handler: Arc, + directory_handler: Arc, +} + +impl FileServerBuilder { + pub fn new(config: GurtConfig) -> Self { + Self { + config, + file_handler: Arc::new(DefaultFileHandler), + directory_handler: Arc::new(DefaultDirectoryHandler), + } + } + + pub fn with_file_handler(mut self, handler: H) -> Self { + self.file_handler = Arc::new(handler); + self + } + + pub fn with_directory_handler(mut self, handler: H) -> Self { + self.directory_handler = Arc::new(handler); + self + } + + pub fn build(self) -> crate::Result { + let server = self.create_server()?; + let request_handler = self.create_request_handler(); + let server_with_routes = self.add_routes(server, request_handler); + Ok(server_with_routes) + } + + fn create_server(&self) -> crate::Result { + match &self.config.tls { + Some(tls) => { + println!("TLS using certificate: {}", tls.certificate.display()); + GurtServerBuilder::new() + .with_tls_certificates(&tls.certificate, &tls.private_key) + .with_timeouts( + self.config.get_handshake_timeout(), + self.config.get_request_timeout(), + self.config.get_connection_timeout(), + ) + .build() + } + None => { + Err(crate::ServerError::TlsConfiguration( + "GURT protocol requires TLS encryption. Please provide --cert and --key parameters.".to_string() + )) + } + } + } + + fn create_request_handler(&self) -> RequestHandler { + RequestHandlerBuilder::new(&*self.config.server.base_directory) + .with_file_handler(DefaultFileHandler) + .with_directory_handler(DefaultDirectoryHandler) + .with_config(Arc::new(self.config.clone())) + .build() + } + + fn add_routes(self, server: GurtServer, request_handler: RequestHandler) -> GurtServer { + let request_handler = Arc::new(request_handler); + + let server = server + .get("/", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_root_request_with_context(ctx_clone).await + } + } + }) + .get("/*", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let path = ctx.path().to_string(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_file_request_with_context(&path, ctx_clone).await + } + } + }); + + let server = server + .post("/", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .post("/*", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .put("/", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .put("/*", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .delete("/", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .delete("/*", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .patch("/", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .patch("/*", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .options("/", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .options("/*", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .head("/", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }) + .head("/*", { + let handler = request_handler.clone(); + move |ctx| { + let handler = handler.clone(); + let ctx_clone = ctx.clone(); + async move { + handler.handle_method_request_with_context(ctx_clone).await + } + } + }); + + server + } +} + + +pub struct GurtServerBuilder { + cert_path: Option, + key_path: Option, + host: Option, + port: Option, + handshake_timeout: Option, + request_timeout: Option, + connection_timeout: Option, +} + +impl GurtServerBuilder { + pub fn new() -> Self { + Self { + cert_path: None, + key_path: None, + host: None, + port: None, + handshake_timeout: None, + request_timeout: None, + connection_timeout: None, + } + } + + pub fn with_tls_certificates>(mut self, cert_path: P, key_path: P) -> Self { + self.cert_path = Some(cert_path.into()); + self.key_path = Some(key_path.into()); + self + } + + pub fn with_host>(mut self, host: S) -> Self { + self.host = Some(host.into()); + self + } + + pub fn with_port(mut self, port: u16) -> Self { + self.port = Some(port); + self + } + + pub fn with_timeouts(mut self, handshake_timeout: std::time::Duration, request_timeout: std::time::Duration, connection_timeout: std::time::Duration) -> Self { + self.handshake_timeout = Some(handshake_timeout); + self.request_timeout = Some(request_timeout); + self.connection_timeout = Some(connection_timeout); + self + } + + pub fn build(self) -> crate::Result { + match (self.cert_path, self.key_path) { + (Some(cert), Some(key)) => { + let mut server = GurtServer::with_tls_certificates( + cert.to_str().ok_or_else(|| { + crate::ServerError::TlsConfiguration("Invalid certificate path".to_string()) + })?, + key.to_str().ok_or_else(|| { + crate::ServerError::TlsConfiguration("Invalid key path".to_string()) + })? + ).map_err(crate::ServerError::from)?; + + if let (Some(handshake), Some(request), Some(connection)) = + (self.handshake_timeout, self.request_timeout, self.connection_timeout) { + server = server.with_timeouts(handshake, request, connection); + } + + Ok(server) + } + _ => { + Err(crate::ServerError::TlsConfiguration( + "TLS certificates are required. Use with_tls_certificates() to provide them.".to_string() + )) + } + } + } +} + +impl Default for GurtServerBuilder { + fn default() -> Self { + Self::new() + } +} \ No newline at end of file diff --git a/protocol/cli/templates/404.html b/protocol/cli/templates/404.html new file mode 100644 index 0000000..7caa7dc --- /dev/null +++ b/protocol/cli/templates/404.html @@ -0,0 +1,24 @@ + + + + 404 Not Found + + + +
+

404 Page Not Found

+

The requested path was not found on this GURT server.

+

Back to home

+
+ + diff --git a/protocol/cli/templates/directory_content_start.html b/protocol/cli/templates/directory_content_start.html new file mode 100644 index 0000000..7892025 --- /dev/null +++ b/protocol/cli/templates/directory_content_start.html @@ -0,0 +1 @@ +
diff --git a/protocol/cli/templates/directory_listing_end.html b/protocol/cli/templates/directory_listing_end.html new file mode 100644 index 0000000..ce29789 --- /dev/null +++ b/protocol/cli/templates/directory_listing_end.html @@ -0,0 +1,3 @@ +
+ + diff --git a/protocol/cli/templates/directory_listing_start.html b/protocol/cli/templates/directory_listing_start.html new file mode 100644 index 0000000..0ecb69e --- /dev/null +++ b/protocol/cli/templates/directory_listing_start.html @@ -0,0 +1,9 @@ + + Directory Listing + + + +

Directory Listing

diff --git a/protocol/cli/templates/directory_parent_link.html b/protocol/cli/templates/directory_parent_link.html new file mode 100644 index 0000000..77d0f81 --- /dev/null +++ b/protocol/cli/templates/directory_parent_link.html @@ -0,0 +1 @@ +

← Parent Directory

diff --git a/protocol/cli/templates/error.html b/protocol/cli/templates/error.html new file mode 100644 index 0000000..2862045 --- /dev/null +++ b/protocol/cli/templates/error.html @@ -0,0 +1,21 @@ + + {} {} + + + +
+

{} {}

+

{}

+

Back to home

+
+ \ No newline at end of file diff --git a/protocol/library/src/client.rs b/protocol/library/src/client.rs index f508442..2d5269b 100644 --- a/protocol/library/src/client.rs +++ b/protocol/library/src/client.rs @@ -184,12 +184,11 @@ impl GurtClient { let bytes_read = conn.connection.read(&mut temp_buffer).await?; if bytes_read == 0 { - break; // Connection closed + break; } buffer.extend_from_slice(&temp_buffer[..bytes_read]); - // Check for complete message let body_separator = BODY_SEPARATOR.as_bytes(); if !headers_parsed { @@ -197,7 +196,6 @@ impl GurtClient { headers_end_pos = Some(pos + body_separator.len()); headers_parsed = true; - // Parse headers to get Content-Length let headers_section = &buffer[..pos]; if let Ok(headers_str) = std::str::from_utf8(headers_section) { for line in headers_str.lines() { @@ -220,7 +218,6 @@ impl GurtClient { return Ok(buffer); } } else if headers_parsed && expected_body_length.is_none() { - // No Content-Length header, return what we have after headers return Ok(buffer); } } @@ -329,7 +326,7 @@ impl GurtClient { } match timeout(Duration::from_millis(100), tls_stream.read(&mut temp_buffer)).await { - Ok(Ok(0)) => break, // Connection closed + Ok(Ok(0)) => break, Ok(Ok(n)) => { buffer.extend_from_slice(&temp_buffer[..n]); @@ -394,7 +391,6 @@ impl GurtClient { self.send_request_internal(&host, port, request).await } - /// POST request with JSON body pub async fn post_json(&self, url: &str, data: &T) -> Result { let (host, port, path) = self.parse_url(url)?; let json_body = serde_json::to_string(data)?; @@ -408,7 +404,6 @@ impl GurtClient { self.send_request_internal(&host, port, request).await } - /// PUT request with body pub async fn put(&self, url: &str, body: &str) -> Result { let (host, port, path) = self.parse_url(url)?; let request = GurtRequest::new(GurtMethod::PUT, path) @@ -420,7 +415,6 @@ impl GurtClient { self.send_request_internal(&host, port, request).await } - /// PUT request with JSON body pub async fn put_json(&self, url: &str, data: &T) -> Result { let (host, port, path) = self.parse_url(url)?; let json_body = serde_json::to_string(data)?; @@ -461,7 +455,6 @@ impl GurtClient { self.send_request_internal(&host, port, request).await } - /// PATCH request with body pub async fn patch(&self, url: &str, body: &str) -> Result { let (host, port, path) = self.parse_url(url)?; let request = GurtRequest::new(GurtMethod::PATCH, path) @@ -473,7 +466,6 @@ impl GurtClient { self.send_request_internal(&host, port, request).await } - /// PATCH request with JSON body pub async fn patch_json(&self, url: &str, data: &T) -> Result { let (host, port, path) = self.parse_url(url)?; let json_body = serde_json::to_string(data)?; @@ -548,4 +540,38 @@ mod tests { assert_eq!(port, 8080); assert_eq!(path, "/api/v1"); } + + #[test] + fn test_connection_pooling_config() { + let config = GurtClientConfig { + enable_connection_pooling: true, + max_connections_per_host: 8, + ..Default::default() + }; + + let client = GurtClient::with_config(config); + assert!(client.config.enable_connection_pooling); + assert_eq!(client.config.max_connections_per_host, 8); + } + + #[test] + fn test_connection_key() { + let key1 = ConnectionKey { + host: "example.com".to_string(), + port: 4878, + }; + + let key2 = ConnectionKey { + host: "example.com".to_string(), + port: 4878, + }; + + let key3 = ConnectionKey { + host: "other.com".to_string(), + port: 4878, + }; + + assert_eq!(key1, key2); + assert_ne!(key1, key3); + } } \ No newline at end of file diff --git a/protocol/library/src/message.rs b/protocol/library/src/message.rs index 4ebf0e0..2b6044a 100644 --- a/protocol/library/src/message.rs +++ b/protocol/library/src/message.rs @@ -14,7 +14,7 @@ pub enum GurtMethod { HEAD, OPTIONS, PATCH, - HANDSHAKE, // Special method for protocol handshake + HANDSHAKE, } impl GurtMethod { @@ -101,7 +101,6 @@ impl GurtRequest { } pub fn parse_bytes(data: &[u8]) -> Result { - // Find the header/body separator as bytes let body_separator = BODY_SEPARATOR.as_bytes(); let body_separator_pos = data.windows(body_separator.len()) .position(|window| window == body_separator); @@ -114,7 +113,6 @@ impl GurtRequest { (data, Vec::new()) }; - // Convert headers section to string (should be valid UTF-8) let headers_str = std::str::from_utf8(headers_section) .map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?; @@ -124,7 +122,6 @@ impl GurtRequest { return Err(GurtError::invalid_message("Empty request")); } - // Parse request line (METHOD path GURT/version) let request_line = lines[0]; let parts: Vec<&str> = request_line.split_whitespace().collect(); @@ -135,7 +132,6 @@ impl GurtRequest { let method = GurtMethod::parse(parts[0])?; let path = parts[1].to_string(); - // Parse protocol version if !parts[2].starts_with(PROTOCOL_PREFIX) { return Err(GurtError::invalid_message("Invalid protocol identifier")); } @@ -143,7 +139,6 @@ impl GurtRequest { let version_str = &parts[2][PROTOCOL_PREFIX.len()..]; let version = version_str.to_string(); - // Parse headers let mut headers = GurtHeaders::new(); for line in lines.iter().skip(1) { @@ -253,6 +248,10 @@ impl GurtResponse { Self::new(GurtStatusCode::BadRequest) } + pub fn forbidden() -> Self { + Self::new(GurtStatusCode::Forbidden) + } + pub fn internal_server_error() -> Self { Self::new(GurtStatusCode::InternalServerError) } @@ -306,7 +305,6 @@ impl GurtResponse { } pub fn parse_bytes(data: &[u8]) -> Result { - // Find the header/body separator as bytes let body_separator = BODY_SEPARATOR.as_bytes(); let body_separator_pos = data.windows(body_separator.len()) .position(|window| window == body_separator); @@ -319,7 +317,6 @@ impl GurtResponse { (data, Vec::new()) }; - // Convert headers section to string (should be valid UTF-8) let headers_str = std::str::from_utf8(headers_section) .map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?; @@ -329,7 +326,6 @@ impl GurtResponse { return Err(GurtError::invalid_message("Empty response")); } - // Parse status line (GURT/version status_code status_message) let status_line = lines[0]; let parts: Vec<&str> = status_line.splitn(3, ' ').collect(); @@ -337,7 +333,6 @@ impl GurtResponse { return Err(GurtError::invalid_message("Invalid status line format")); } - // Parse protocol version if !parts[0].starts_with(PROTOCOL_PREFIX) { return Err(GurtError::invalid_message("Invalid protocol identifier")); } @@ -356,7 +351,6 @@ impl GurtResponse { .unwrap_or_else(|| "Unknown".to_string()) }; - // Parse headers let mut headers = GurtHeaders::new(); for line in lines.iter().skip(1) { @@ -394,7 +388,6 @@ impl GurtResponse { } if !headers.contains_key("date") { - // RFC 7231 compliant let now = Utc::now(); let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string(); headers.insert("date".to_string(), date_str); @@ -429,7 +422,6 @@ impl GurtResponse { } if !headers.contains_key("date") { - // RFC 7231 compliant let now = Utc::now(); let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string(); headers.insert("date".to_string(), date_str); @@ -441,7 +433,6 @@ impl GurtResponse { message.push_str(HEADER_SEPARATOR); - // Convert headers to bytes and append body as raw bytes let mut bytes = message.into_bytes(); bytes.extend_from_slice(&self.body); @@ -461,7 +452,6 @@ impl GurtMessage { } pub fn parse_bytes(data: &[u8]) -> Result { - // Convert first line to string to determine message type let header_separator = HEADER_SEPARATOR.as_bytes(); let first_line_end = data.windows(header_separator.len()) .position(|window| window == header_separator) @@ -470,7 +460,6 @@ impl GurtMessage { let first_line = std::str::from_utf8(&data[..first_line_end]) .map_err(|_| GurtError::invalid_message("Invalid UTF-8 in first line"))?; - // Check if it's a response (starts with GURT/version) or request (method first) if first_line.starts_with(PROTOCOL_PREFIX) { Ok(GurtMessage::Response(GurtResponse::parse_bytes(data)?)) } else { diff --git a/protocol/library/src/protocol.rs b/protocol/library/src/protocol.rs index 9066d4b..9415a77 100644 --- a/protocol/library/src/protocol.rs +++ b/protocol/library/src/protocol.rs @@ -37,6 +37,7 @@ pub enum GurtStatusCode { Timeout = 408, TooLarge = 413, UnsupportedMediaType = 415, + TooManyRequests = 429, // Server errors InternalServerError = 500, @@ -62,6 +63,7 @@ impl GurtStatusCode { 408 => Some(Self::Timeout), 413 => Some(Self::TooLarge), 415 => Some(Self::UnsupportedMediaType), + 429 => Some(Self::TooManyRequests), 500 => Some(Self::InternalServerError), 501 => Some(Self::NotImplemented), 502 => Some(Self::BadGateway), @@ -86,6 +88,7 @@ impl GurtStatusCode { Self::Timeout => "TIMEOUT", Self::TooLarge => "TOO_LARGE", Self::UnsupportedMediaType => "UNSUPPORTED_MEDIA_TYPE", + Self::TooManyRequests => "TOO_MANY_REQUESTS", Self::InternalServerError => "INTERNAL_SERVER_ERROR", Self::NotImplemented => "NOT_IMPLEMENTED", Self::BadGateway => "BAD_GATEWAY", diff --git a/protocol/library/src/server.rs b/protocol/library/src/server.rs index f5f7182..faabcbf 100644 --- a/protocol/library/src/server.rs +++ b/protocol/library/src/server.rs @@ -7,6 +7,7 @@ use crate::{ }; use tokio::net::{TcpListener, TcpStream}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::time::{timeout, Duration}; use tokio_rustls::{TlsAcceptor, server::TlsStream}; use rustls::pki_types::CertificateDer; use std::collections::HashMap; @@ -136,6 +137,9 @@ impl Route { pub struct GurtServer { routes: Vec<(Route, Arc)>, tls_acceptor: Option, + handshake_timeout: Duration, + request_timeout: Duration, + connection_timeout: Duration, } impl GurtServer { @@ -143,9 +147,19 @@ impl GurtServer { Self { routes: Vec::new(), tls_acceptor: None, + handshake_timeout: Duration::from_secs(5), + request_timeout: Duration::from_secs(30), + connection_timeout: Duration::from_secs(10), } } + pub fn with_timeouts(mut self, handshake_timeout: Duration, request_timeout: Duration, connection_timeout: Duration) -> Self { + self.handshake_timeout = handshake_timeout; + self.request_timeout = request_timeout; + self.connection_timeout = connection_timeout; + self + } + pub fn with_tls_certificates(cert_path: &str, key_path: &str) -> Result { let mut server = Self::new(); server.load_tls_certificates(cert_path, key_path)?; @@ -279,57 +293,76 @@ impl GurtServer { } async fn handle_connection(&self, mut stream: TcpStream, addr: SocketAddr) -> Result<()> { - self.handle_initial_handshake(&mut stream, addr).await?; + let connection_result = timeout(self.connection_timeout, async { + self.handle_initial_handshake(&mut stream, addr).await?; + + if let Some(tls_acceptor) = &self.tls_acceptor { + info!("Upgrading connection to TLS for {}", addr); + let tls_stream = tls_acceptor.accept(stream).await + .map_err(|e| GurtError::crypto(format!("TLS upgrade failed: {}", e)))?; + + info!("TLS upgrade completed for {}", addr); + + self.handle_tls_connection(tls_stream, addr).await + } else { + warn!("No TLS configuration available, but handshake completed - this violates GURT protocol"); + Err(GurtError::protocol("TLS is required after handshake but no TLS configuration available")) + } + }).await; - if let Some(tls_acceptor) = &self.tls_acceptor { - info!("Upgrading connection to TLS for {}", addr); - let tls_stream = tls_acceptor.accept(stream).await - .map_err(|e| GurtError::crypto(format!("TLS upgrade failed: {}", e)))?; - - info!("TLS upgrade completed for {}", addr); - - self.handle_tls_connection(tls_stream, addr).await - } else { - warn!("No TLS configuration available, but handshake completed - this violates GURT protocol"); - Err(GurtError::protocol("TLS is required after handshake but no TLS configuration available")) + match connection_result { + Ok(result) => result, + Err(_) => { + warn!("Connection timeout for {}", addr); + Err(GurtError::timeout("Connection timeout")) + } } } async fn handle_initial_handshake(&self, stream: &mut TcpStream, addr: SocketAddr) -> Result<()> { - let mut buffer = Vec::new(); - let mut temp_buffer = [0u8; 8192]; - - loop { - let bytes_read = stream.read(&mut temp_buffer).await?; - if bytes_read == 0 { - return Err(GurtError::connection("Connection closed during handshake")); - } + let handshake_result = timeout(self.handshake_timeout, async { + let mut buffer = Vec::new(); + let mut temp_buffer = [0u8; 8192]; - buffer.extend_from_slice(&temp_buffer[..bytes_read]); - - let body_separator = BODY_SEPARATOR.as_bytes(); - if buffer.windows(body_separator.len()).any(|w| w == body_separator) { - break; - } - - if buffer.len() > MAX_MESSAGE_SIZE { - return Err(GurtError::protocol("Handshake message too large")); - } - } - - - let message = GurtMessage::parse_bytes(&buffer)?; - - match message { - GurtMessage::Request(request) => { - if request.method == GurtMethod::HANDSHAKE { - self.send_handshake_response(stream, addr, &request).await - } else { - Err(GurtError::protocol("First message must be HANDSHAKE")) + loop { + let bytes_read = stream.read(&mut temp_buffer).await?; + if bytes_read == 0 { + return Err(GurtError::connection("Connection closed during handshake")); + } + + buffer.extend_from_slice(&temp_buffer[..bytes_read]); + + let body_separator = BODY_SEPARATOR.as_bytes(); + if buffer.windows(body_separator.len()).any(|w| w == body_separator) { + break; + } + + if buffer.len() > MAX_MESSAGE_SIZE { + return Err(GurtError::protocol("Handshake message too large")); } } - GurtMessage::Response(_) => { - Err(GurtError::protocol("Server received response during handshake")) + + let message = GurtMessage::parse_bytes(&buffer)?; + + match message { + GurtMessage::Request(request) => { + if request.method == GurtMethod::HANDSHAKE { + self.send_handshake_response(stream, addr, &request).await + } else { + Err(GurtError::protocol("First message must be HANDSHAKE")) + } + } + GurtMessage::Response(_) => { + Err(GurtError::protocol("Server received response during handshake")) + } + } + }).await; + + match handshake_result { + Ok(result) => result, + Err(_) => { + warn!("Handshake timeout for {}", addr); + Err(GurtError::timeout("Handshake timeout")) } } } @@ -342,7 +375,6 @@ impl GurtServer { let bytes_read = match tls_stream.read(&mut temp_buffer).await { Ok(n) => n, Err(e) => { - // Handle UnexpectedEof from clients that don't send close_notify if e.kind() == std::io::ErrorKind::UnexpectedEof { debug!("Client {} closed connection without TLS close_notify (benign)", addr); break; @@ -351,7 +383,7 @@ impl GurtServer { } }; if bytes_read == 0 { - break; // Connection closed + break; } buffer.extend_from_slice(&temp_buffer[..bytes_read]); @@ -361,17 +393,31 @@ impl GurtServer { (buffer.starts_with(b"{") && buffer.ends_with(b"}")); if has_complete_message { - if let Err(e) = self.process_tls_message(&mut tls_stream, addr, &buffer).await { - error!("Encrypted message processing error from {}: {}", addr, e); - let error_response = GurtResponse::internal_server_error() - .with_string_body("Internal server error"); - let _ = tls_stream.write_all(&error_response.to_bytes()).await; + let process_result = timeout(self.request_timeout, + self.process_tls_message(&mut tls_stream, addr, &buffer) + ).await; + + match process_result { + Ok(Ok(())) => { + debug!("Processed message from {} successfully", addr); + } + Ok(Err(e)) => { + error!("Encrypted message processing error from {}: {}", addr, e); + let error_response = GurtResponse::internal_server_error() + .with_string_body("Internal server error"); + let _ = tls_stream.write_all(&error_response.to_bytes()).await; + } + Err(_) => { + warn!("Request timeout for {}", addr); + let timeout_response = GurtResponse::new(GurtStatusCode::Timeout) + .with_string_body("Request timeout"); + let _ = tls_stream.write_all(&timeout_response.to_bytes()).await; + } } buffer.clear(); } - // Prevent buffer overflow if buffer.len() > MAX_MESSAGE_SIZE { warn!("Message too large from {}, closing connection", addr); break; @@ -422,7 +468,6 @@ impl GurtServer { if let Some(method) = &route.method { allowed_methods.insert(method.to_string()); } else { - // Route matches any method allowed_methods.extend(vec![ "GET".to_string(), "POST".to_string(), "PUT".to_string(), "DELETE".to_string(), "HEAD".to_string(), "PATCH".to_string() @@ -482,7 +527,6 @@ impl GurtServer { async fn handle_encrypted_request(&self, tls_stream: &mut TlsStream, addr: SocketAddr, request: &GurtRequest) -> Result<()> { debug!("Handling encrypted {} request to {} from {}", request.method, request.path, addr); - // Find matching route for (route, handler) in &self.routes { if route.matches(&request.method, &request.path) { let context = ServerContext { @@ -492,7 +536,6 @@ impl GurtServer { match handler.handle(&context).await { Ok(response) => { - // Use to_bytes() to avoid corrupting binary data let response_bytes = response.to_bytes(); tls_stream.write_all(&response_bytes).await?; return Ok(()); @@ -508,7 +551,6 @@ impl GurtServer { } } - // No route found - check for default OPTIONS/HEAD handling match request.method { GurtMethod::OPTIONS => { self.handle_default_options(tls_stream, request).await @@ -531,6 +573,9 @@ impl Clone for GurtServer { Self { routes: self.routes.clone(), tls_acceptor: self.tls_acceptor.clone(), + handshake_timeout: self.handshake_timeout, + request_timeout: self.request_timeout, + connection_timeout: self.connection_timeout, } } }