Gurty -> usable state
This commit is contained in:
0
protocol/cli/.gitignore
vendored
Normal file
0
protocol/cli/.gitignore
vendored
Normal file
96
protocol/cli/Cargo.lock
generated
96
protocol/cli/Cargo.lock
generated
@@ -91,6 +91,17 @@ dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.5.0"
|
||||
@@ -338,6 +349,12 @@ version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.13"
|
||||
@@ -420,17 +437,27 @@ dependencies = [
|
||||
name = "gurty"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"clap",
|
||||
"colored",
|
||||
"gurt",
|
||||
"indexmap",
|
||||
"mime_guess",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"toml",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
@@ -577,6 +604,16 @@ dependencies = [
|
||||
"icu_properties",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "io-uring"
|
||||
version = "0.7.9"
|
||||
@@ -1089,6 +1126,15 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "0.6.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sharded-slab"
|
||||
version = "0.1.7"
|
||||
@@ -1255,6 +1301,47 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.8.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
"toml_edit",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_datetime"
|
||||
version = "0.6.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_edit"
|
||||
version = "0.22.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
"toml_write",
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_write"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801"
|
||||
|
||||
[[package]]
|
||||
name = "tracing"
|
||||
version = "0.1.41"
|
||||
@@ -1681,6 +1768,15 @@ version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.7.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-rt"
|
||||
version = "0.39.0"
|
||||
|
||||
@@ -23,3 +23,7 @@ tracing-subscriber = "0.3"
|
||||
clap = { version = "4.0", features = ["derive"] }
|
||||
colored = "2.0"
|
||||
mime_guess = "2.0"
|
||||
async-trait = "0.1"
|
||||
toml = "0.8"
|
||||
regex = "1.0"
|
||||
indexmap = "2.0"
|
||||
60
protocol/cli/gurty.toml
Normal file
60
protocol/cli/gurty.toml
Normal file
@@ -0,0 +1,60 @@
|
||||
[server]
|
||||
host = "127.0.0.1"
|
||||
port = 4878
|
||||
protocol_version = "1.0.0"
|
||||
alpn_identifier = "GURT/1.0"
|
||||
max_connections = 10
|
||||
max_message_size = "10MB"
|
||||
|
||||
[server.timeouts]
|
||||
handshake = 5
|
||||
request = 30
|
||||
connection = 10
|
||||
pool_idle = 300
|
||||
|
||||
[tls]
|
||||
certificate = "/path/to/certificate.pem"
|
||||
private_key = "/path/to/private_key.pem"
|
||||
|
||||
[logging]
|
||||
level = "info"
|
||||
# access_log = "/var/log/gurty/access.log"
|
||||
# error_log = "/var/log/gurty/error.log"
|
||||
log_requests = true
|
||||
log_responses = false
|
||||
|
||||
[security]
|
||||
deny_files = [
|
||||
"*.env",
|
||||
"*.config",
|
||||
".git/*",
|
||||
"node_modules/*",
|
||||
"*.key",
|
||||
"*.pem"
|
||||
]
|
||||
|
||||
allowed_methods = ["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH"]
|
||||
rate_limit_requests = 100 # requests per minute
|
||||
rate_limit_connections = 1000 # concurrent connections per IP
|
||||
|
||||
# Error pages configuration
|
||||
[error_pages]
|
||||
# Specific error pages (uncomment and set paths to custom files)
|
||||
# "400" = "/errors/400.html"
|
||||
# "401" = "/errors/401.html"
|
||||
# "403" = "/errors/403.html"
|
||||
# "404" = "/errors/404.html"
|
||||
# "405" = "/errors/405.html"
|
||||
# "429" = "/errors/429.html"
|
||||
# "500" = "/errors/500.html"
|
||||
# "503" = "/errors/503.html"
|
||||
|
||||
[error_pages.default]
|
||||
"400" = '''<!DOCTYPE html>
|
||||
<html><head><title>400 Bad Request</title></head>
|
||||
<body><h1>400 - Bad Request</h1><p>The request could not be understood by the server.</p><a href="/">Back to home</a></body></html>'''
|
||||
|
||||
[headers]
|
||||
server = "GURT/1.0.0"
|
||||
"x-frame-options" = "SAMEORIGIN"
|
||||
"x-content-type-options" = "nosniff"
|
||||
88
protocol/cli/src/cli.rs
Normal file
88
protocol/cli/src/cli.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "server")]
|
||||
#[command(about = "GURT Protocol Server")]
|
||||
#[command(version = "1.0.0")]
|
||||
pub struct Cli {
|
||||
#[command(subcommand)]
|
||||
pub command: Commands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
pub enum Commands {
|
||||
Serve(ServeCommand),
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
pub struct ServeCommand {
|
||||
#[arg(short, long, help = "Configuration file path")]
|
||||
pub config: Option<PathBuf>,
|
||||
|
||||
#[arg(short, long, default_value_t = 4878)]
|
||||
pub port: u16,
|
||||
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
pub host: String,
|
||||
|
||||
#[arg(short, long, default_value = ".")]
|
||||
pub dir: PathBuf,
|
||||
|
||||
#[arg(short, long)]
|
||||
pub verbose: bool,
|
||||
|
||||
#[arg(long, help = "Path to TLS certificate file")]
|
||||
pub cert: Option<PathBuf>,
|
||||
|
||||
#[arg(long, help = "Path to TLS private key file")]
|
||||
pub key: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl ServeCommand {
|
||||
pub fn validate(&self) -> crate::Result<()> {
|
||||
if !self.dir.exists() {
|
||||
return Err(crate::ServerError::InvalidPath(
|
||||
format!("Directory does not exist: {}", self.dir.display())
|
||||
));
|
||||
}
|
||||
|
||||
if !self.dir.is_dir() {
|
||||
return Err(crate::ServerError::InvalidPath(
|
||||
format!("Path is not a directory: {}", self.dir.display())
|
||||
));
|
||||
}
|
||||
|
||||
match (&self.cert, &self.key) {
|
||||
(Some(cert), Some(key)) => {
|
||||
if !cert.exists() {
|
||||
return Err(crate::ServerError::TlsConfiguration(
|
||||
format!("Certificate file does not exist: {}", cert.display())
|
||||
));
|
||||
}
|
||||
if !key.exists() {
|
||||
return Err(crate::ServerError::TlsConfiguration(
|
||||
format!("Key file does not exist: {}", key.display())
|
||||
));
|
||||
}
|
||||
}
|
||||
(Some(_), None) => {
|
||||
return Err(crate::ServerError::TlsConfiguration(
|
||||
"Certificate provided but no key file specified (use --key)".to_string()
|
||||
));
|
||||
}
|
||||
(None, Some(_)) => {
|
||||
return Err(crate::ServerError::TlsConfiguration(
|
||||
"Key provided but no certificate file specified (use --cert)".to_string()
|
||||
));
|
||||
}
|
||||
(None, None) => {
|
||||
return Err(crate::ServerError::TlsConfiguration(
|
||||
"GURT protocol requires TLS encryption. Please provide --cert and --key parameters.".to_string()
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
160
protocol/cli/src/command_handler.rs
Normal file
160
protocol/cli/src/command_handler.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use crate::{
|
||||
cli::ServeCommand,
|
||||
config::GurtConfig,
|
||||
server::FileServerBuilder,
|
||||
Result,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use colored::Colorize;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[async_trait]
|
||||
pub trait CommandHandler {
|
||||
async fn execute(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
pub struct CommandHandlerBuilder {
|
||||
logging_initialized: bool,
|
||||
verbose: bool,
|
||||
}
|
||||
|
||||
impl CommandHandlerBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
logging_initialized: false,
|
||||
verbose: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_logging(mut self, verbose: bool) -> Self {
|
||||
self.verbose = verbose;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn initialize_logging(mut self) -> Self {
|
||||
if !self.logging_initialized {
|
||||
let level = if self.verbose {
|
||||
tracing::Level::DEBUG
|
||||
} else {
|
||||
tracing::Level::INFO
|
||||
};
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(level)
|
||||
.init();
|
||||
|
||||
self.logging_initialized = true;
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build_serve_handler(self, serve_cmd: ServeCommand) -> ServeCommandHandler {
|
||||
ServeCommandHandler::new(serve_cmd)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CommandHandlerBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServeCommandHandler {
|
||||
serve_cmd: ServeCommand,
|
||||
}
|
||||
|
||||
impl ServeCommandHandler {
|
||||
pub fn new(serve_cmd: ServeCommand) -> Self {
|
||||
Self { serve_cmd }
|
||||
}
|
||||
|
||||
fn validate_command(&self) -> Result<()> {
|
||||
if !self.serve_cmd.dir.exists() {
|
||||
return Err(crate::ServerError::InvalidPath(
|
||||
format!("Directory does not exist: {}", self.serve_cmd.dir.display())
|
||||
));
|
||||
}
|
||||
|
||||
if !self.serve_cmd.dir.is_dir() {
|
||||
return Err(crate::ServerError::InvalidPath(
|
||||
format!("Path is not a directory: {}", self.serve_cmd.dir.display())
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_server_config(&self) -> Result<GurtConfig> {
|
||||
let mut config_builder = GurtConfig::builder();
|
||||
|
||||
if let Some(config_file) = &self.serve_cmd.config {
|
||||
config_builder = config_builder.from_file(config_file)?;
|
||||
}
|
||||
|
||||
let config = config_builder
|
||||
.merge_cli_args(&self.serve_cmd)
|
||||
.build()?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn display_startup_info(&self, config: &GurtConfig) {
|
||||
println!("{}", "GURT Protocol Server".bright_cyan().bold());
|
||||
println!("{} {}", "Version".bright_blue(), config.server.protocol_version);
|
||||
println!("{} {}", "Listening on".bright_blue(), config.address());
|
||||
println!("{} {}", "Serving from".bright_blue(), config.server.base_directory.display());
|
||||
|
||||
if config.tls.is_some() {
|
||||
println!("{}", "TLS encryption enabled".bright_green());
|
||||
}
|
||||
|
||||
if let Some(logging) = &config.logging {
|
||||
println!("{} {}", "Log level".bright_blue(), logging.level);
|
||||
if logging.log_requests {
|
||||
println!("{}", "Request logging enabled".bright_green());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(security) = &config.security {
|
||||
println!("{} {} req/min", "Rate limit".bright_blue(), security.rate_limit_requests);
|
||||
if !security.deny_files.is_empty() {
|
||||
println!("{} {} patterns", "File restrictions".bright_blue(), security.deny_files.len());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(headers) = &config.headers {
|
||||
if !headers.is_empty() {
|
||||
println!("{} {} headers", "Custom headers".bright_blue(), headers.len());
|
||||
}
|
||||
}
|
||||
|
||||
println!("{} {}", "Max connections".bright_blue(), config.server.max_connections);
|
||||
println!("{} {}", "Max message size".bright_blue(), config.server.max_message_size);
|
||||
println!();
|
||||
}
|
||||
|
||||
async fn start_server(&self, config: &GurtConfig) -> Result<()> {
|
||||
let server = FileServerBuilder::new(config.clone()).build()?;
|
||||
|
||||
info!("Starting GURT server on {}", config.address());
|
||||
|
||||
if let Err(e) = server.listen(&config.address()).await {
|
||||
error!("Server error: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CommandHandler for ServeCommandHandler {
|
||||
async fn execute(&self) -> Result<()> {
|
||||
self.validate_command()?;
|
||||
|
||||
let config = self.build_server_config()?;
|
||||
|
||||
self.display_startup_info(&config);
|
||||
self.start_server(&config).await
|
||||
}
|
||||
}
|
||||
607
protocol/cli/src/config.rs
Normal file
607
protocol/cli/src/config.rs
Normal file
@@ -0,0 +1,607 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GurtConfig {
|
||||
pub server: ServerConfig,
|
||||
pub tls: Option<TlsConfig>,
|
||||
pub logging: Option<LoggingConfig>,
|
||||
pub security: Option<SecurityConfig>,
|
||||
pub error_pages: Option<ErrorPagesConfig>,
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ServerConfig {
|
||||
#[serde(default = "default_host")]
|
||||
pub host: String,
|
||||
|
||||
#[serde(default = "default_port")]
|
||||
pub port: u16,
|
||||
|
||||
#[serde(default = "default_protocol_version")]
|
||||
pub protocol_version: String,
|
||||
|
||||
#[serde(default = "default_alpn_identifier")]
|
||||
pub alpn_identifier: String,
|
||||
|
||||
pub timeouts: Option<TimeoutsConfig>,
|
||||
|
||||
#[serde(default = "default_max_connections")]
|
||||
pub max_connections: u32,
|
||||
|
||||
#[serde(default = "default_max_message_size")]
|
||||
pub max_message_size: String,
|
||||
|
||||
#[serde(skip)]
|
||||
pub base_directory: Arc<PathBuf>,
|
||||
|
||||
#[serde(skip)]
|
||||
pub verbose: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TimeoutsConfig {
|
||||
#[serde(default = "default_handshake_timeout")]
|
||||
pub handshake: u64,
|
||||
|
||||
#[serde(default = "default_request_timeout")]
|
||||
pub request: u64,
|
||||
|
||||
#[serde(default = "default_connection_timeout")]
|
||||
pub connection: u64,
|
||||
|
||||
#[serde(default = "default_pool_idle_timeout")]
|
||||
pub pool_idle: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsConfig {
|
||||
pub certificate: PathBuf,
|
||||
pub private_key: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LoggingConfig {
|
||||
#[serde(default = "default_log_level")]
|
||||
pub level: String,
|
||||
|
||||
pub access_log: Option<PathBuf>,
|
||||
pub error_log: Option<PathBuf>,
|
||||
|
||||
#[serde(default = "default_log_requests")]
|
||||
pub log_requests: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub log_responses: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SecurityConfig {
|
||||
#[serde(default)]
|
||||
pub deny_files: Vec<String>,
|
||||
|
||||
#[serde(default = "default_allowed_methods")]
|
||||
pub allowed_methods: Vec<String>,
|
||||
|
||||
#[serde(default = "default_rate_limit_requests")]
|
||||
pub rate_limit_requests: u32,
|
||||
|
||||
#[serde(default = "default_rate_limit_connections")]
|
||||
pub rate_limit_connections: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ErrorPagesConfig {
|
||||
#[serde(flatten)]
|
||||
pub pages: HashMap<String, String>,
|
||||
|
||||
pub default: Option<ErrorPageDefaults>,
|
||||
}
|
||||
|
||||
impl ErrorPagesConfig {
|
||||
pub fn get_page(&self, status_code: u16) -> Option<&String> {
|
||||
let code_str = status_code.to_string();
|
||||
self.pages.get(&code_str)
|
||||
}
|
||||
|
||||
pub fn get_default_page(&self, status_code: u16) -> Option<&String> {
|
||||
if let Some(defaults) = &self.default {
|
||||
let code_str = status_code.to_string();
|
||||
defaults.pages.get(&code_str)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_any_page(&self, status_code: u16) -> Option<&String> {
|
||||
self.get_page(status_code)
|
||||
.or_else(|| self.get_default_page(status_code))
|
||||
}
|
||||
|
||||
pub fn get_page_content(&self, status_code: u16, base_dir: &std::path::Path) -> Option<String> {
|
||||
if let Some(page_value) = self.get_page(status_code) {
|
||||
if page_value.starts_with('/') || page_value.starts_with("./") {
|
||||
let file_path = if page_value.starts_with('/') {
|
||||
base_dir.join(&page_value[1..])
|
||||
} else {
|
||||
base_dir.join(page_value)
|
||||
};
|
||||
|
||||
if let Ok(content) = std::fs::read_to_string(&file_path) {
|
||||
return Some(content);
|
||||
} else {
|
||||
tracing::warn!("Failed to read error page file: {}", file_path.display());
|
||||
return None;
|
||||
}
|
||||
} else {
|
||||
return Some(page_value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(page_value) = self.get_default_page(status_code) {
|
||||
return Some(page_value.clone());
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ErrorPageDefaults {
|
||||
#[serde(flatten)]
|
||||
pub pages: HashMap<String, String>,
|
||||
}
|
||||
|
||||
fn default_host() -> String { "127.0.0.1".to_string() }
|
||||
fn default_port() -> u16 { 4878 }
|
||||
fn default_protocol_version() -> String { "1.0.0".to_string() }
|
||||
fn default_alpn_identifier() -> String { "GURT/1.0".to_string() }
|
||||
fn default_max_connections() -> u32 { 10 }
|
||||
fn default_max_message_size() -> String { "10MB".to_string() }
|
||||
fn default_handshake_timeout() -> u64 { 5 }
|
||||
fn default_request_timeout() -> u64 { 30 }
|
||||
fn default_connection_timeout() -> u64 { 10 }
|
||||
fn default_pool_idle_timeout() -> u64 { 300 }
|
||||
fn default_log_level() -> String { "info".to_string() }
|
||||
fn default_log_requests() -> bool { true }
|
||||
fn default_allowed_methods() -> Vec<String> {
|
||||
vec!["GET".to_string(), "POST".to_string(), "PUT".to_string(),
|
||||
"DELETE".to_string(), "HEAD".to_string(), "OPTIONS".to_string(), "PATCH".to_string()]
|
||||
}
|
||||
fn default_rate_limit_requests() -> u32 { 100 }
|
||||
fn default_rate_limit_connections() -> u32 { 10 }
|
||||
|
||||
impl Default for GurtConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
server: ServerConfig::default(),
|
||||
tls: None,
|
||||
logging: None,
|
||||
security: None,
|
||||
error_pages: None,
|
||||
headers: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_host(),
|
||||
port: default_port(),
|
||||
protocol_version: default_protocol_version(),
|
||||
alpn_identifier: default_alpn_identifier(),
|
||||
timeouts: None,
|
||||
max_connections: default_max_connections(),
|
||||
max_message_size: default_max_message_size(),
|
||||
base_directory: Arc::new(PathBuf::from(".")),
|
||||
verbose: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GurtConfig {
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> crate::Result<Self> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| crate::ServerError::InvalidConfiguration(format!("Failed to read config file: {}", e)))?;
|
||||
|
||||
let config: GurtConfig = toml::from_str(&content)
|
||||
.map_err(|e| crate::ServerError::InvalidConfiguration(format!("Failed to parse config file: {}", e)))?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn builder() -> GurtConfigBuilder {
|
||||
GurtConfigBuilder::default()
|
||||
}
|
||||
|
||||
pub fn address(&self) -> String {
|
||||
format!("{}:{}", self.server.host, self.server.port)
|
||||
}
|
||||
|
||||
pub fn max_message_size_bytes(&self) -> crate::Result<u64> {
|
||||
parse_size(&self.server.max_message_size)
|
||||
}
|
||||
|
||||
pub fn get_handshake_timeout(&self) -> Duration {
|
||||
Duration::from_secs(
|
||||
self.server.timeouts
|
||||
.as_ref()
|
||||
.map(|t| t.handshake)
|
||||
.unwrap_or(default_handshake_timeout())
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_request_timeout(&self) -> Duration {
|
||||
Duration::from_secs(
|
||||
self.server.timeouts
|
||||
.as_ref()
|
||||
.map(|t| t.request)
|
||||
.unwrap_or(default_request_timeout())
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_connection_timeout(&self) -> Duration {
|
||||
Duration::from_secs(
|
||||
self.server.timeouts
|
||||
.as_ref()
|
||||
.map(|t| t.connection)
|
||||
.unwrap_or(default_connection_timeout())
|
||||
)
|
||||
}
|
||||
|
||||
pub fn should_deny_file(&self, file_path: &str) -> bool {
|
||||
if let Some(security) = &self.security {
|
||||
for pattern in &security.deny_files {
|
||||
if matches_pattern(file_path, pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn is_method_allowed(&self, method: &str) -> bool {
|
||||
if let Some(security) = &self.security {
|
||||
security.allowed_methods.contains(&method.to_uppercase())
|
||||
} else {
|
||||
default_allowed_methods().contains(&method.to_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_with_directory(base_dir: PathBuf) -> Self {
|
||||
let mut config = Self::default();
|
||||
config.server.base_directory = Arc::new(base_dir);
|
||||
config
|
||||
}
|
||||
|
||||
pub fn from_toml(toml_content: &str, base_dir: PathBuf) -> crate::Result<Self> {
|
||||
let mut config: GurtConfig = toml::from_str(toml_content)
|
||||
.map_err(|e| crate::ServerError::InvalidConfiguration(format!("Failed to parse config: {}", e)))?;
|
||||
|
||||
config.server.base_directory = Arc::new(base_dir);
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> crate::Result<()> {
|
||||
if !self.server.base_directory.exists() || !self.server.base_directory.is_dir() {
|
||||
return Err(crate::ServerError::InvalidConfiguration(
|
||||
format!("Invalid base directory: {}", self.server.base_directory.display())
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(tls) = &self.tls {
|
||||
if !tls.certificate.exists() {
|
||||
return Err(crate::ServerError::TlsConfiguration(
|
||||
format!("Certificate file does not exist: {}", tls.certificate.display())
|
||||
));
|
||||
}
|
||||
if !tls.private_key.exists() {
|
||||
return Err(crate::ServerError::TlsConfiguration(
|
||||
format!("Private key file does not exist: {}", tls.private_key.display())
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct GurtConfigBuilder {
|
||||
config: GurtConfig,
|
||||
}
|
||||
|
||||
impl GurtConfigBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn host<S: Into<String>>(mut self, host: S) -> Self {
|
||||
self.config.server.host = host.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn port(mut self, port: u16) -> Self {
|
||||
self.config.server.port = port;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn base_directory<P: Into<PathBuf>>(mut self, dir: P) -> Self {
|
||||
self.config.server.base_directory = Arc::new(dir.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn verbose(mut self, verbose: bool) -> Self {
|
||||
self.config.server.verbose = verbose;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tls_config(mut self, cert_path: PathBuf, key_path: PathBuf) -> Self {
|
||||
self.config.tls = Some(TlsConfig {
|
||||
certificate: cert_path,
|
||||
private_key: key_path,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
pub fn logging_config(mut self, config: LoggingConfig) -> Self {
|
||||
self.config.logging = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn security_config(mut self, config: SecurityConfig) -> Self {
|
||||
self.config.security = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn error_pages_config(mut self, config: ErrorPagesConfig) -> Self {
|
||||
self.config.error_pages = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
|
||||
self.config.headers = Some(headers);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(mut self, path: P) -> crate::Result<Self> {
|
||||
let file_config = GurtConfig::from_file(path)?;
|
||||
self.config = merge_configs(file_config, self.config);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn merge_cli_args(mut self, cli_args: &crate::cli::ServeCommand) -> Self {
|
||||
self.config.server.host = cli_args.host.clone();
|
||||
self.config.server.port = cli_args.port;
|
||||
self.config.server.base_directory = Arc::new(cli_args.dir.clone());
|
||||
self.config.server.verbose = cli_args.verbose;
|
||||
|
||||
if let (Some(cert), Some(key)) = (&cli_args.cert, &cli_args.key) {
|
||||
self.config.tls = Some(TlsConfig {
|
||||
certificate: cert.clone(),
|
||||
private_key: key.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> crate::Result<GurtConfig> {
|
||||
let config = self.config;
|
||||
|
||||
if !config.server.base_directory.exists() || !config.server.base_directory.is_dir() {
|
||||
return Err(crate::ServerError::InvalidConfiguration(
|
||||
format!("Invalid base directory: {}", config.server.base_directory.display())
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(tls) = &config.tls {
|
||||
if !tls.certificate.exists() {
|
||||
return Err(crate::ServerError::TlsConfiguration(
|
||||
format!("Certificate file does not exist: {}", tls.certificate.display())
|
||||
));
|
||||
}
|
||||
if !tls.private_key.exists() {
|
||||
return Err(crate::ServerError::TlsConfiguration(
|
||||
format!("Private key file does not exist: {}", tls.private_key.display())
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn parse_size(size_str: &str) -> crate::Result<u64> {
|
||||
let size_str = size_str.trim().to_uppercase();
|
||||
|
||||
if let Some(captures) = regex::Regex::new(r"^(\d+(?:\.\d+)?)\s*([KMGT]?B?)$").unwrap().captures(&size_str) {
|
||||
let number: f64 = captures[1].parse()
|
||||
.map_err(|_| crate::ServerError::InvalidConfiguration(format!("Invalid size format: {}", size_str)))?;
|
||||
|
||||
let unit = captures.get(2).map_or("", |m| m.as_str());
|
||||
|
||||
let multiplier: u64 = match unit {
|
||||
"" | "B" => 1,
|
||||
"KB" => 1_000,
|
||||
"MB" => 1_000_000,
|
||||
"GB" => 1_000_000_000,
|
||||
"TB" => 1_000_000_000_000,
|
||||
_ => return Err(crate::ServerError::InvalidConfiguration(format!("Unknown size unit: {}", unit))),
|
||||
};
|
||||
let number = (number * multiplier as f64) as u64;
|
||||
Ok(number)
|
||||
} else {
|
||||
Err(crate::ServerError::InvalidConfiguration(format!("Invalid size format: {}", size_str)))
|
||||
}
|
||||
}
|
||||
|
||||
fn matches_pattern(path: &str, pattern: &str) -> bool {
|
||||
if pattern.ends_with("/*") {
|
||||
let prefix = &pattern[..pattern.len() - 2];
|
||||
path.starts_with(prefix)
|
||||
} else if pattern.starts_with("*.") {
|
||||
let suffix = &pattern[1..];
|
||||
path.ends_with(suffix)
|
||||
} else {
|
||||
path == pattern
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_configs(base: GurtConfig, override_config: GurtConfig) -> GurtConfig {
|
||||
GurtConfig {
|
||||
server: merge_server_configs(base.server, override_config.server),
|
||||
tls: override_config.tls.or(base.tls),
|
||||
logging: override_config.logging.or(base.logging),
|
||||
security: override_config.security.or(base.security),
|
||||
error_pages: override_config.error_pages.or(base.error_pages),
|
||||
headers: override_config.headers.or(base.headers),
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_server_configs(base: ServerConfig, override_config: ServerConfig) -> ServerConfig {
|
||||
ServerConfig {
|
||||
host: if override_config.host != default_host() { override_config.host } else { base.host },
|
||||
port: if override_config.port != default_port() { override_config.port } else { base.port },
|
||||
protocol_version: if override_config.protocol_version != default_protocol_version() {
|
||||
override_config.protocol_version
|
||||
} else {
|
||||
base.protocol_version
|
||||
},
|
||||
alpn_identifier: if override_config.alpn_identifier != default_alpn_identifier() {
|
||||
override_config.alpn_identifier
|
||||
} else {
|
||||
base.alpn_identifier
|
||||
},
|
||||
timeouts: override_config.timeouts.or(base.timeouts),
|
||||
max_connections: if override_config.max_connections != default_max_connections() {
|
||||
override_config.max_connections
|
||||
} else {
|
||||
base.max_connections
|
||||
},
|
||||
max_message_size: if override_config.max_message_size != default_max_message_size() {
|
||||
override_config.max_message_size
|
||||
} else {
|
||||
base.max_message_size
|
||||
},
|
||||
base_directory: override_config.base_directory,
|
||||
verbose: override_config.verbose,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[test]
|
||||
fn test_default_config_creation() {
|
||||
let base_dir = PathBuf::from("/tmp");
|
||||
let mut config = GurtConfig::default();
|
||||
config.server.base_directory = Arc::new(base_dir.clone());
|
||||
|
||||
assert_eq!(config.server.host, "127.0.0.1");
|
||||
assert_eq!(config.server.port, 4878);
|
||||
assert_eq!(config.server.protocol_version, "1.0.0");
|
||||
assert_eq!(config.server.alpn_identifier, "GURT/1.0");
|
||||
assert_eq!(*config.server.base_directory, base_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_from_valid_toml() {
|
||||
let toml_content = r#"
|
||||
[server]
|
||||
host = "0.0.0.0"
|
||||
port = 8080
|
||||
protocol_version = "2.0.0"
|
||||
alpn_identifier = "custom"
|
||||
max_connections = 1000
|
||||
max_message_size = "10MB"
|
||||
|
||||
[security]
|
||||
rate_limit_requests = 60
|
||||
rate_limit_connections = 5
|
||||
"#;
|
||||
|
||||
let base_dir = PathBuf::from("/tmp");
|
||||
let config = GurtConfig::from_toml(toml_content, base_dir).unwrap();
|
||||
|
||||
assert_eq!(config.server.host, "0.0.0.0");
|
||||
assert_eq!(config.server.port, 8080);
|
||||
assert_eq!(config.server.protocol_version, "2.0.0");
|
||||
assert_eq!(config.server.alpn_identifier, "custom");
|
||||
assert_eq!(config.server.max_connections, 1000);
|
||||
|
||||
let security = config.security.unwrap();
|
||||
assert_eq!(security.rate_limit_requests, 60);
|
||||
assert_eq!(security.rate_limit_connections, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_toml_returns_error() {
|
||||
let invalid_toml = r#"
|
||||
[server
|
||||
host = "0.0.0.0"
|
||||
"#;
|
||||
|
||||
let base_dir = PathBuf::from("/tmp");
|
||||
let result = GurtConfig::from_toml(invalid_toml, base_dir);
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_message_size_parsing() {
|
||||
let config = GurtConfig::default();
|
||||
|
||||
assert_eq!(parse_size("1024").unwrap(), 1024);
|
||||
assert_eq!(parse_size("1KB").unwrap(), 1000);
|
||||
assert_eq!(parse_size("1MB").unwrap(), 1000 * 1000);
|
||||
assert_eq!(parse_size("1GB").unwrap(), 1000 * 1000 * 1000);
|
||||
|
||||
assert!(parse_size("invalid").is_err());
|
||||
|
||||
assert!(config.max_message_size_bytes().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tls_config_validation() {
|
||||
let mut config = GurtConfig::default();
|
||||
|
||||
config.tls = Some(TlsConfig {
|
||||
certificate: PathBuf::from("/nonexistent/cert.pem"),
|
||||
private_key: PathBuf::from("/nonexistent/key.pem"),
|
||||
});
|
||||
|
||||
assert!(config.tls.is_some());
|
||||
let tls = config.tls.unwrap();
|
||||
assert_eq!(tls.certificate, PathBuf::from("/nonexistent/cert.pem"));
|
||||
assert_eq!(tls.private_key, PathBuf::from("/nonexistent/key.pem"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_address_formatting() {
|
||||
let config = GurtConfig::default();
|
||||
assert_eq!(config.address(), "127.0.0.1:4878");
|
||||
|
||||
let mut custom_config = GurtConfig::default();
|
||||
custom_config.server.host = "0.0.0.0".to_string();
|
||||
custom_config.server.port = 8080;
|
||||
assert_eq!(custom_config.address(), "0.0.0.0:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timeout_getters() {
|
||||
let config = GurtConfig::default();
|
||||
|
||||
assert_eq!(config.get_handshake_timeout(), Duration::from_secs(5));
|
||||
assert_eq!(config.get_request_timeout(), Duration::from_secs(30));
|
||||
assert_eq!(config.get_connection_timeout(), Duration::from_secs(10));
|
||||
}
|
||||
}
|
||||
38
protocol/cli/src/error.rs
Normal file
38
protocol/cli/src/error.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ServerError {
|
||||
Io(std::io::Error),
|
||||
InvalidPath(String),
|
||||
InvalidConfiguration(String),
|
||||
TlsConfiguration(String),
|
||||
ServerStartup(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for ServerError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ServerError::Io(err) => write!(f, "I/O error: {}", err),
|
||||
ServerError::InvalidPath(path) => write!(f, "Invalid path: {}", path),
|
||||
ServerError::InvalidConfiguration(msg) => write!(f, "Configuration error: {}", msg),
|
||||
ServerError::TlsConfiguration(msg) => write!(f, "TLS configuration error: {}", msg),
|
||||
ServerError::ServerStartup(msg) => write!(f, "Server startup error: {}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ServerError {}
|
||||
|
||||
impl From<std::io::Error> for ServerError {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
ServerError::Io(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<gurt::GurtError> for ServerError {
|
||||
fn from(err: gurt::GurtError) -> Self {
|
||||
ServerError::ServerStartup(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ServerError>;
|
||||
197
protocol/cli/src/handlers.rs
Normal file
197
protocol/cli/src/handlers.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
use std::path::Path;
|
||||
|
||||
pub trait FileHandler: Send + Sync {
|
||||
fn can_handle(&self, path: &Path) -> bool;
|
||||
fn get_content_type(&self, path: &Path) -> String;
|
||||
fn handle_file(&self, path: &Path) -> crate::Result<Vec<u8>>;
|
||||
}
|
||||
|
||||
pub struct DefaultFileHandler;
|
||||
|
||||
impl FileHandler for DefaultFileHandler {
|
||||
fn can_handle(&self, _path: &Path) -> bool {
|
||||
true // Default
|
||||
}
|
||||
|
||||
fn get_content_type(&self, path: &Path) -> String {
|
||||
match path.extension().and_then(|ext| ext.to_str()) {
|
||||
Some("html") | Some("htm") => "text/html".to_string(),
|
||||
Some("css") => "text/css".to_string(),
|
||||
Some("js") => "application/javascript".to_string(),
|
||||
Some("json") => "application/json".to_string(),
|
||||
Some("png") => "image/png".to_string(),
|
||||
Some("jpg") | Some("jpeg") => "image/jpeg".to_string(),
|
||||
Some("gif") => "image/gif".to_string(),
|
||||
Some("svg") => "image/svg+xml".to_string(),
|
||||
Some("ico") => "image/x-icon".to_string(),
|
||||
Some("txt") => "text/plain".to_string(),
|
||||
Some("xml") => "application/xml".to_string(),
|
||||
Some("pdf") => "application/pdf".to_string(),
|
||||
_ => "application/octet-stream".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_file(&self, path: &Path) -> crate::Result<Vec<u8>> {
|
||||
std::fs::read(path).map_err(crate::ServerError::from)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait DirectoryHandler: Send + Sync {
|
||||
fn handle_directory(&self, path: &Path, request_path: &str) -> crate::Result<String>;
|
||||
}
|
||||
|
||||
pub struct DefaultDirectoryHandler;
|
||||
|
||||
impl DirectoryHandler for DefaultDirectoryHandler {
|
||||
fn handle_directory(&self, path: &Path, request_path: &str) -> crate::Result<String> {
|
||||
let entries = std::fs::read_dir(path)?;
|
||||
|
||||
let mut listing = String::from(include_str!("../templates/directory_listing_start.html"));
|
||||
|
||||
if request_path != "/" {
|
||||
listing.push_str(include_str!("../templates/directory_parent_link.html"));
|
||||
}
|
||||
|
||||
listing.push_str(include_str!("../templates/directory_content_start.html"));
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let file_name = entry.file_name();
|
||||
let name = file_name.to_string_lossy();
|
||||
let is_dir = entry.path().is_dir();
|
||||
let display_name = if is_dir {
|
||||
format!("{}/", name)
|
||||
} else {
|
||||
name.to_string()
|
||||
};
|
||||
let class = if is_dir { "dir" } else { "file" };
|
||||
|
||||
listing.push_str(&format!(
|
||||
r#" <a href="{}" class="{}">{}</a>"#,
|
||||
name, class, display_name
|
||||
));
|
||||
listing.push('\n');
|
||||
}
|
||||
|
||||
listing.push_str(include_str!("../templates/directory_listing_end.html"));
|
||||
Ok(listing)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_404_html() -> &'static str {
|
||||
include_str!("../templates/404.html")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::path::Path;
|
||||
|
||||
#[test]
|
||||
fn test_default_file_handler_can_handle_any_file() {
|
||||
let handler = DefaultFileHandler;
|
||||
let path = Path::new("test.txt");
|
||||
assert!(handler.can_handle(path));
|
||||
|
||||
let path = Path::new("some/random/file");
|
||||
assert!(handler.can_handle(path));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_type_detection() {
|
||||
let handler = DefaultFileHandler;
|
||||
|
||||
assert_eq!(handler.get_content_type(Path::new("index.html")), "text/html");
|
||||
assert_eq!(handler.get_content_type(Path::new("style.css")), "text/css");
|
||||
assert_eq!(handler.get_content_type(Path::new("script.js")), "application/javascript");
|
||||
assert_eq!(handler.get_content_type(Path::new("data.json")), "application/json");
|
||||
|
||||
assert_eq!(handler.get_content_type(Path::new("image.png")), "image/png");
|
||||
assert_eq!(handler.get_content_type(Path::new("photo.jpg")), "image/jpeg");
|
||||
assert_eq!(handler.get_content_type(Path::new("photo.jpeg")), "image/jpeg");
|
||||
assert_eq!(handler.get_content_type(Path::new("icon.ico")), "image/x-icon");
|
||||
assert_eq!(handler.get_content_type(Path::new("vector.svg")), "image/svg+xml");
|
||||
|
||||
assert_eq!(handler.get_content_type(Path::new("readme.txt")), "text/plain");
|
||||
assert_eq!(handler.get_content_type(Path::new("data.xml")), "application/xml");
|
||||
assert_eq!(handler.get_content_type(Path::new("document.pdf")), "application/pdf");
|
||||
|
||||
assert_eq!(handler.get_content_type(Path::new("file.unknown")), "application/octet-stream");
|
||||
|
||||
assert_eq!(handler.get_content_type(Path::new("noextension")), "application/octet-stream");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_directory_handler_generates_valid_html() {
|
||||
use std::fs;
|
||||
use std::env;
|
||||
|
||||
let temp_dir = env::temp_dir().join("gurty_test");
|
||||
let _ = fs::create_dir_all(&temp_dir);
|
||||
|
||||
let _ = fs::write(temp_dir.join("test.txt"), "test content");
|
||||
let _ = fs::create_dir_all(temp_dir.join("subdir"));
|
||||
|
||||
let handler = DefaultDirectoryHandler;
|
||||
let result = handler.handle_directory(&temp_dir, "/test/");
|
||||
|
||||
assert!(result.is_ok());
|
||||
let html = result.unwrap();
|
||||
|
||||
assert!(html.contains("<!DOCTYPE html>"));
|
||||
assert!(html.contains("<title>Directory Listing</title>"));
|
||||
assert!(html.contains("← Parent Directory"));
|
||||
assert!(html.contains("test.txt"));
|
||||
assert!(html.contains("subdir/"));
|
||||
|
||||
let _ = fs::remove_dir_all(&temp_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_directory_handler_root_path() {
|
||||
use std::fs;
|
||||
use std::env;
|
||||
|
||||
let temp_dir = env::temp_dir().join("gurty_test_root");
|
||||
let _ = fs::create_dir_all(&temp_dir);
|
||||
|
||||
let handler = DefaultDirectoryHandler;
|
||||
let result = handler.handle_directory(&temp_dir, "/");
|
||||
|
||||
assert!(result.is_ok());
|
||||
let html = result.unwrap();
|
||||
|
||||
assert!(!html.contains("← Parent Directory"));
|
||||
|
||||
let _ = fs::remove_dir_all(&temp_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_404_html_content() {
|
||||
let html = get_404_html();
|
||||
|
||||
assert!(html.contains("<!DOCTYPE html>"));
|
||||
assert!(html.contains("404 Page Not Found"));
|
||||
assert!(html.contains("The requested path was not found"));
|
||||
assert!(html.contains("Back to home"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_directory_handler_with_empty_directory() {
|
||||
use std::fs;
|
||||
use std::env;
|
||||
|
||||
let temp_dir = env::temp_dir().join("gurty_test_empty");
|
||||
let _ = fs::create_dir_all(&temp_dir);
|
||||
|
||||
let handler = DefaultDirectoryHandler;
|
||||
let result = handler.handle_directory(&temp_dir, "/empty/");
|
||||
|
||||
assert!(result.is_ok());
|
||||
let html = result.unwrap();
|
||||
|
||||
assert!(html.contains("<!DOCTYPE html>"));
|
||||
assert!(html.contains("Directory Listing"));
|
||||
|
||||
let _ = fs::remove_dir_all(&temp_dir);
|
||||
}
|
||||
}
|
||||
10
protocol/cli/src/lib.rs
Normal file
10
protocol/cli/src/lib.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
pub mod cli;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod security;
|
||||
pub mod server;
|
||||
pub mod request_handler;
|
||||
pub mod command_handler;
|
||||
pub mod handlers;
|
||||
|
||||
pub use error::{Result, ServerError};
|
||||
@@ -1,323 +1,23 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
use colored::Colorize;
|
||||
use gurt::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
use tracing::error;
|
||||
use tracing_subscriber;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "server")]
|
||||
#[command(about = "GURT Protocol Server")]
|
||||
#[command(version = "1.0.0")]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
command: Commands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Commands {
|
||||
Serve {
|
||||
#[arg(short, long, default_value_t = 4878)]
|
||||
port: u16,
|
||||
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
|
||||
#[arg(short, long, default_value = ".")]
|
||||
dir: PathBuf,
|
||||
|
||||
#[arg(short, long)]
|
||||
verbose: bool,
|
||||
|
||||
#[arg(long, help = "Path to TLS certificate file")]
|
||||
cert: Option<PathBuf>,
|
||||
|
||||
#[arg(long, help = "Path to TLS private key file")]
|
||||
key: Option<PathBuf>,
|
||||
}
|
||||
}
|
||||
use clap::Parser;
|
||||
use gurty::{
|
||||
cli::{Cli, Commands},
|
||||
command_handler::{CommandHandler, CommandHandlerBuilder},
|
||||
Result,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
match cli.command {
|
||||
Commands::Serve { port, host, dir, verbose, cert, key } => {
|
||||
if verbose {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::DEBUG)
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.init();
|
||||
}
|
||||
Commands::Serve(serve_cmd) => {
|
||||
let handler = CommandHandlerBuilder::new()
|
||||
.with_logging(serve_cmd.verbose)
|
||||
.initialize_logging()
|
||||
.build_serve_handler(serve_cmd);
|
||||
|
||||
println!("{}", "GURT Protocol Server".bright_cyan().bold());
|
||||
println!("{} {}:{}", "Listening on".bright_blue(), host, port);
|
||||
println!("{} {}", "Serving from".bright_blue(), dir.display());
|
||||
|
||||
let server = create_file_server(dir, cert, key)?;
|
||||
let addr = format!("{}:{}", host, port);
|
||||
|
||||
if let Err(e) = server.listen(&addr).await {
|
||||
error!("Server error: {}", e);
|
||||
std::process::exit(1);
|
||||
handler.execute().await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_file_server(base_dir: PathBuf, cert_path: Option<PathBuf>, key_path: Option<PathBuf>) -> Result<GurtServer> {
|
||||
let base_dir = std::sync::Arc::new(base_dir);
|
||||
|
||||
let server = match (cert_path, key_path) {
|
||||
(Some(cert), Some(key)) => {
|
||||
println!("TLS using certificate: {}", cert.display());
|
||||
GurtServer::with_tls_certificates(
|
||||
cert.to_str().ok_or_else(|| GurtError::invalid_message("Invalid certificate path"))?,
|
||||
key.to_str().ok_or_else(|| GurtError::invalid_message("Invalid key path"))?
|
||||
)?
|
||||
}
|
||||
(Some(_), None) => {
|
||||
return Err(GurtError::invalid_message("Certificate provided but no key file specified (use --key)"));
|
||||
}
|
||||
(None, Some(_)) => {
|
||||
return Err(GurtError::invalid_message("Key provided but no certificate file specified (use --cert)"));
|
||||
}
|
||||
(None, None) => {
|
||||
return Err(GurtError::invalid_message("GURT protocol requires TLS encryption. Please provide --cert and --key parameters."));
|
||||
}
|
||||
};
|
||||
|
||||
let server = server
|
||||
.get("/", {
|
||||
let base_dir = base_dir.clone();
|
||||
move |_| {
|
||||
let base_dir = base_dir.clone();
|
||||
async move {
|
||||
// Try to serve index.html if it exists
|
||||
let index_path = base_dir.join("index.html");
|
||||
|
||||
if index_path.exists() && index_path.is_file() {
|
||||
match std::fs::read_to_string(&index_path) {
|
||||
Ok(content) => {
|
||||
return Ok(GurtResponse::ok()
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_string_body(content));
|
||||
}
|
||||
Err(_) => {
|
||||
// Fall through to directory listing
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No index.html found, show directory listing
|
||||
match std::fs::read_dir(base_dir.as_ref()) {
|
||||
Ok(entries) => {
|
||||
let mut listing = String::from(r#"
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Directory Listing</title>
|
||||
<style>
|
||||
body { font-sans m-[40px] }
|
||||
.dir { font-bold text-[#0066cc] }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Directory Listing</h1>
|
||||
<div style="flex flex-col gap-2">
|
||||
"#);
|
||||
for entry in entries.flatten() {
|
||||
let file_name = entry.file_name();
|
||||
let name = file_name.to_string_lossy();
|
||||
let is_dir = entry.path().is_dir();
|
||||
let display_name = if is_dir { format!("{}/", name) } else { name.to_string() };
|
||||
let class = if is_dir { "style=\"dir\"" } else { "" };
|
||||
|
||||
listing.push_str(&format!(
|
||||
r#" <a {} href="/{}">{}</a>"#,
|
||||
class, name, display_name
|
||||
));
|
||||
listing.push('\n');
|
||||
}
|
||||
|
||||
listing.push_str("</div></body>\n</html>");
|
||||
|
||||
Ok(GurtResponse::ok()
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_string_body(listing))
|
||||
}
|
||||
Err(_) => {
|
||||
Ok(GurtResponse::internal_server_error()
|
||||
.with_header("Content-Type", "text/plain")
|
||||
.with_string_body("Failed to read directory"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.get("/*", {
|
||||
let base_dir = base_dir.clone();
|
||||
move |ctx| {
|
||||
let base_dir = base_dir.clone();
|
||||
let path = ctx.path().to_string();
|
||||
async move {
|
||||
let mut relative_path = path.strip_prefix('/').unwrap_or(&path).to_string();
|
||||
// Remove any leading slashes to ensure relative path
|
||||
while relative_path.starts_with('/') || relative_path.starts_with('\\') {
|
||||
relative_path = relative_path[1..].to_string();
|
||||
}
|
||||
// If the path is now empty, use "."
|
||||
let relative_path = if relative_path.is_empty() { ".".to_string() } else { relative_path };
|
||||
let file_path = base_dir.join(&relative_path);
|
||||
|
||||
match file_path.canonicalize() {
|
||||
Ok(canonical_path) => {
|
||||
let canonical_base = match base_dir.canonicalize() {
|
||||
Ok(base) => base,
|
||||
Err(_) => {
|
||||
return Ok(GurtResponse::internal_server_error()
|
||||
.with_header("Content-Type", "text/plain")
|
||||
.with_string_body("Server configuration error"));
|
||||
}
|
||||
};
|
||||
|
||||
if !canonical_path.starts_with(&canonical_base) {
|
||||
return Ok(GurtResponse::bad_request()
|
||||
.with_header("Content-Type", "text/plain")
|
||||
.with_string_body("Access denied: Path outside served directory"));
|
||||
}
|
||||
|
||||
if canonical_path.is_file() {
|
||||
match std::fs::read(&canonical_path) {
|
||||
Ok(content) => {
|
||||
let content_type = get_content_type(&canonical_path);
|
||||
Ok(GurtResponse::ok()
|
||||
.with_header("Content-Type", &content_type)
|
||||
.with_body(content))
|
||||
}
|
||||
Err(_) => {
|
||||
Ok(GurtResponse::internal_server_error()
|
||||
.with_header("Content-Type", "text/plain")
|
||||
.with_string_body("Failed to read file"))
|
||||
}
|
||||
}
|
||||
} else if canonical_path.is_dir() {
|
||||
let index_path = canonical_path.join("index.html");
|
||||
if index_path.is_file() {
|
||||
match std::fs::read_to_string(&index_path) {
|
||||
Ok(content) => {
|
||||
Ok(GurtResponse::ok()
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_string_body(content))
|
||||
}
|
||||
Err(_) => {
|
||||
Ok(GurtResponse::internal_server_error()
|
||||
.with_header("Content-Type", "text/plain")
|
||||
.with_string_body("Failed to read index file"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match std::fs::read_dir(&canonical_path) {
|
||||
Ok(entries) => {
|
||||
let mut listing = String::from(r#"
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Directory Listing</title>
|
||||
<style>
|
||||
body { font-sans m-[40px] }
|
||||
.dir { font-bold text-[#0066cc] }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Directory Listing</h1>
|
||||
<p><a href="../">← Parent Directory</a></p>
|
||||
<div style="flex flex-col gap-2">
|
||||
"#);
|
||||
for entry in entries.flatten() {
|
||||
let file_name = entry.file_name();
|
||||
let name = file_name.to_string_lossy();
|
||||
let is_dir = entry.path().is_dir();
|
||||
let display_name = if is_dir { format!("{}/", name) } else { name.to_string() };
|
||||
let class = if is_dir { "style=\"dir\"" } else { "" };
|
||||
|
||||
listing.push_str(&format!(
|
||||
r#" <a {} href="{}">{}</a>"#,
|
||||
class, name, display_name
|
||||
));
|
||||
listing.push('\n');
|
||||
}
|
||||
|
||||
listing.push_str("</div></body>\n</html>");
|
||||
|
||||
Ok(GurtResponse::ok()
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_string_body(listing))
|
||||
}
|
||||
Err(_) => {
|
||||
Ok(GurtResponse::internal_server_error()
|
||||
.with_header("Content-Type", "text/plain")
|
||||
.with_string_body("Failed to read directory"))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// File not found
|
||||
Ok(GurtResponse::not_found()
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_string_body(get_404_html()))
|
||||
}
|
||||
}
|
||||
Err(_e) => {
|
||||
Ok(GurtResponse::not_found()
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_string_body(get_404_html()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(server)
|
||||
}
|
||||
|
||||
fn get_404_html() -> &'static str {
|
||||
r#"<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>404 Not Found</title>
|
||||
<style>
|
||||
body { font-sans m-[40px] text-center }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>404 Page Not Found</h1>
|
||||
<p>The requested path was not found on this GURT server.</p>
|
||||
<p><a href="/">Back to home</a></p>
|
||||
</body>
|
||||
</html>
|
||||
"#
|
||||
}
|
||||
|
||||
fn get_content_type(path: &std::path::Path) -> String {
|
||||
match path.extension().and_then(|ext| ext.to_str()) {
|
||||
Some("html") | Some("htm") => "text/html".to_string(),
|
||||
Some("css") => "text/css".to_string(),
|
||||
Some("js") => "application/javascript".to_string(),
|
||||
Some("json") => "application/json".to_string(),
|
||||
Some("png") => "image/png".to_string(),
|
||||
Some("jpg") | Some("jpeg") => "image/jpeg".to_string(),
|
||||
Some("gif") => "image/gif".to_string(),
|
||||
Some("svg") => "image/svg+xml".to_string(),
|
||||
Some("ico") => "image/x-icon".to_string(),
|
||||
Some("txt") => "text/plain".to_string(),
|
||||
Some("xml") => "application/xml".to_string(),
|
||||
Some("pdf") => "application/pdf".to_string(),
|
||||
_ => "application/octet-stream".to_string(),
|
||||
}
|
||||
}
|
||||
560
protocol/cli/src/request_handler.rs
Normal file
560
protocol/cli/src/request_handler.rs
Normal file
@@ -0,0 +1,560 @@
|
||||
use crate::{
|
||||
handlers::{FileHandler, DirectoryHandler, DefaultFileHandler, DefaultDirectoryHandler},
|
||||
config::GurtConfig,
|
||||
security::SecurityMiddleware,
|
||||
};
|
||||
use gurt::prelude::*;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tracing;
|
||||
|
||||
pub struct RequestHandlerBuilder {
|
||||
file_handler: Arc<dyn FileHandler>,
|
||||
directory_handler: Arc<dyn DirectoryHandler>,
|
||||
base_directory: std::path::PathBuf,
|
||||
config: Option<Arc<GurtConfig>>,
|
||||
}
|
||||
|
||||
impl RequestHandlerBuilder {
|
||||
pub fn new<P: AsRef<Path>>(base_directory: P) -> Self {
|
||||
Self {
|
||||
file_handler: Arc::new(DefaultFileHandler),
|
||||
directory_handler: Arc::new(DefaultDirectoryHandler),
|
||||
base_directory: base_directory.as_ref().to_path_buf(),
|
||||
config: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_file_handler<H: FileHandler + 'static>(mut self, handler: H) -> Self {
|
||||
self.file_handler = Arc::new(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_directory_handler<H: DirectoryHandler + 'static>(mut self, handler: H) -> Self {
|
||||
self.directory_handler = Arc::new(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_config(mut self, config: Arc<GurtConfig>) -> Self {
|
||||
self.config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> RequestHandler {
|
||||
let security = self.config.as_ref().map(|config| SecurityMiddleware::new(config.clone()));
|
||||
|
||||
RequestHandler {
|
||||
file_handler: self.file_handler,
|
||||
directory_handler: self.directory_handler,
|
||||
base_directory: self.base_directory,
|
||||
config: self.config,
|
||||
security,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RequestHandler {
|
||||
file_handler: Arc<dyn FileHandler>,
|
||||
directory_handler: Arc<dyn DirectoryHandler>,
|
||||
base_directory: std::path::PathBuf,
|
||||
config: Option<Arc<GurtConfig>>,
|
||||
security: Option<SecurityMiddleware>,
|
||||
}
|
||||
|
||||
impl RequestHandler {
|
||||
pub fn builder<P: AsRef<Path>>(base_directory: P) -> RequestHandlerBuilder {
|
||||
RequestHandlerBuilder::new(base_directory)
|
||||
}
|
||||
|
||||
fn apply_custom_error_page(&self, mut response: GurtResponse) -> GurtResponse {
|
||||
if response.status_code >= 400 {
|
||||
let custom_content = self.get_custom_error_page(response.status_code)
|
||||
.unwrap_or_else(|| self.get_fallback_error_page(response.status_code));
|
||||
|
||||
response.body = custom_content.into_bytes();
|
||||
response = response.with_header("Content-Type", "text/html");
|
||||
tracing::debug!("Applied error page for status {}", response.status_code);
|
||||
}
|
||||
response
|
||||
}
|
||||
|
||||
fn get_custom_error_page(&self, status_code: u16) -> Option<String> {
|
||||
if let Some(config) = &self.config {
|
||||
if let Some(error_pages) = &config.error_pages {
|
||||
error_pages.get_page_content(status_code, &self.base_directory)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_fallback_error_page(&self, status_code: u16) -> String {
|
||||
let (title, message) = match status_code {
|
||||
400 => ("Bad Request", "The request could not be understood by the server."),
|
||||
401 => ("Unauthorized", "Authentication is required to access this resource."),
|
||||
403 => ("Forbidden", "Access to this resource is denied by server policy."),
|
||||
404 => ("Not Found", "The requested resource was not found on this server."),
|
||||
405 => ("Method Not Allowed", "The request method is not allowed for this resource."),
|
||||
429 => ("Too Many Requests", "You have exceeded the rate limit. Please try again later."),
|
||||
500 => ("Internal Server Error", "The server encountered an error processing your request."),
|
||||
502 => ("Bad Gateway", "The server received an invalid response from an upstream server."),
|
||||
503 => ("Service Unavailable", "The server is temporarily unavailable. Please try again later."),
|
||||
504 => ("Gateway Timeout", "The server did not receive a timely response from an upstream server."),
|
||||
_ => ("Error", "An error occurred while processing your request."),
|
||||
};
|
||||
|
||||
format!(include_str!("../templates/error.html"), status_code, title, status_code, title, message)
|
||||
}
|
||||
|
||||
pub fn check_security(&self, ctx: &ServerContext) -> Option<std::result::Result<GurtResponse, GurtError>> {
|
||||
if let Some(security) = &self.security {
|
||||
let client_ip = ctx.client_ip();
|
||||
let method = ctx.method();
|
||||
|
||||
if !security.is_method_allowed(method) {
|
||||
tracing::warn!("Method {} not allowed from {}", method, client_ip);
|
||||
let response = security.create_method_not_allowed_response()
|
||||
.map(|r| self.apply_global_headers(r));
|
||||
return Some(response);
|
||||
}
|
||||
|
||||
if !security.check_rate_limit(client_ip) {
|
||||
let response = security.create_rate_limit_response()
|
||||
.map(|r| self.apply_global_headers(r));
|
||||
return Some(response);
|
||||
}
|
||||
|
||||
if !security.check_connection_limit(client_ip) {
|
||||
let response = security.create_rate_limit_response()
|
||||
.map(|r| self.apply_global_headers(r));
|
||||
return Some(response);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub fn register_connection(&self, client_ip: std::net::IpAddr) {
|
||||
if let Some(security) = &self.security {
|
||||
security.register_connection(client_ip);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unregister_connection(&self, client_ip: std::net::IpAddr) {
|
||||
if let Some(security) = &self.security {
|
||||
security.unregister_connection(client_ip);
|
||||
}
|
||||
}
|
||||
|
||||
fn is_file_denied(&self, file_path: &Path) -> bool {
|
||||
if let Some(config) = &self.config {
|
||||
let path_str = file_path.to_string_lossy();
|
||||
|
||||
let relative_path = if let Ok(canonical_file) = file_path.canonicalize() {
|
||||
if let Ok(canonical_base) = self.base_directory.canonicalize() {
|
||||
canonical_file.strip_prefix(&canonical_base)
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| path_str.to_string())
|
||||
} else {
|
||||
path_str.to_string()
|
||||
}
|
||||
} else {
|
||||
path_str.to_string()
|
||||
};
|
||||
|
||||
let is_denied = config.should_deny_file(&path_str) || config.should_deny_file(&relative_path);
|
||||
|
||||
if is_denied {
|
||||
tracing::warn!("File access denied by security policy: {}", relative_path);
|
||||
}
|
||||
|
||||
is_denied
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_global_headers(&self, mut response: GurtResponse) -> GurtResponse {
|
||||
response = self.apply_custom_error_page(response);
|
||||
|
||||
if let Some(config) = &self.config {
|
||||
if let Some(headers) = &config.headers {
|
||||
for (key, value) in headers {
|
||||
response = response.with_header(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
response
|
||||
}
|
||||
|
||||
fn create_forbidden_response(&self) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let response = GurtResponse::forbidden()
|
||||
.with_header("Content-Type", "text/html");
|
||||
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
|
||||
pub async fn handle_root_request_with_context(&self, ctx: ServerContext) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let client_ip = ctx.client_ip();
|
||||
|
||||
self.register_connection(client_ip);
|
||||
|
||||
if let Some(security_response) = self.check_security(&ctx) {
|
||||
self.unregister_connection(client_ip);
|
||||
return security_response;
|
||||
}
|
||||
|
||||
let result = self.handle_root_request().await;
|
||||
self.unregister_connection(client_ip);
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn handle_file_request_with_context(&self, request_path: &str, ctx: ServerContext) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let client_ip = ctx.client_ip();
|
||||
|
||||
self.register_connection(client_ip);
|
||||
|
||||
if let Some(security_response) = self.check_security(&ctx) {
|
||||
self.unregister_connection(client_ip);
|
||||
return security_response;
|
||||
}
|
||||
|
||||
let result = self.handle_file_request(request_path).await;
|
||||
self.unregister_connection(client_ip);
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn handle_method_request_with_context(&self, ctx: ServerContext) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let client_ip = ctx.client_ip();
|
||||
let method = ctx.method();
|
||||
|
||||
self.register_connection(client_ip);
|
||||
|
||||
if let Some(security_response) = self.check_security(&ctx) {
|
||||
self.unregister_connection(client_ip);
|
||||
return security_response;
|
||||
}
|
||||
|
||||
let result = match method {
|
||||
gurt::message::GurtMethod::GET => {
|
||||
if ctx.path() == "/" {
|
||||
self.handle_root_request().await
|
||||
} else {
|
||||
self.handle_file_request(ctx.path()).await
|
||||
}
|
||||
}
|
||||
gurt::message::GurtMethod::HEAD => {
|
||||
let mut response = if ctx.path() == "/" {
|
||||
self.handle_root_request().await?
|
||||
} else {
|
||||
self.handle_file_request(ctx.path()).await?
|
||||
};
|
||||
response.body = Vec::new();
|
||||
Ok(response)
|
||||
}
|
||||
gurt::message::GurtMethod::OPTIONS => {
|
||||
let allowed_methods = if let Some(config) = &self.config {
|
||||
if let Some(security) = &config.security {
|
||||
security.allowed_methods.join(", ")
|
||||
} else {
|
||||
"GET, POST, PUT, DELETE, HEAD, OPTIONS, PATCH".to_string()
|
||||
}
|
||||
} else {
|
||||
"GET, POST, PUT, DELETE, HEAD, OPTIONS, PATCH".to_string()
|
||||
};
|
||||
|
||||
let response = GurtResponse::ok()
|
||||
.with_header("Allow", &allowed_methods)
|
||||
.with_header("Content-Type", "text/plain")
|
||||
.with_string_body("Allowed methods");
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
_ => {
|
||||
let response = GurtResponse::new(gurt::protocol::GurtStatusCode::MethodNotAllowed)
|
||||
.with_header("Content-Type", "text/html");
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
};
|
||||
|
||||
self.unregister_connection(client_ip);
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn handle_root_request(&self) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let index_path = self.base_directory.join("index.html");
|
||||
|
||||
if index_path.exists() && index_path.is_file() {
|
||||
if self.is_file_denied(&index_path) {
|
||||
return self.create_forbidden_response();
|
||||
}
|
||||
|
||||
match self.file_handler.handle_file(&index_path) {
|
||||
Ok(content) => {
|
||||
let content_type = self.file_handler.get_content_type(&index_path);
|
||||
let response = GurtResponse::ok()
|
||||
.with_header("Content-Type", &content_type)
|
||||
.with_body(content);
|
||||
return Ok(self.apply_global_headers(response));
|
||||
}
|
||||
Err(_) => {
|
||||
// fall
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match self.directory_handler.handle_directory(&self.base_directory, "/") {
|
||||
Ok(listing) => {
|
||||
let response = GurtResponse::ok()
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_string_body(listing);
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
Err(_) => {
|
||||
let response = GurtResponse::internal_server_error()
|
||||
.with_header("Content-Type", "text/html");
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_file_request(&self, request_path: &str) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let mut relative_path = request_path.strip_prefix('/').unwrap_or(request_path).to_string();
|
||||
|
||||
while relative_path.starts_with('/') || relative_path.starts_with('\\') {
|
||||
relative_path = relative_path[1..].to_string();
|
||||
}
|
||||
|
||||
let relative_path = if relative_path.is_empty() {
|
||||
".".to_string()
|
||||
} else {
|
||||
relative_path
|
||||
};
|
||||
|
||||
let file_path = self.base_directory.join(&relative_path);
|
||||
|
||||
if self.is_file_denied(&file_path) {
|
||||
return self.create_forbidden_response();
|
||||
}
|
||||
|
||||
match file_path.canonicalize() {
|
||||
Ok(canonical_path) => {
|
||||
let canonical_base = match self.base_directory.canonicalize() {
|
||||
Ok(base) => base,
|
||||
Err(_) => {
|
||||
return Ok(GurtResponse::internal_server_error()
|
||||
.with_header("Content-Type", "text/html"));
|
||||
}
|
||||
};
|
||||
|
||||
if !canonical_path.starts_with(&canonical_base) {
|
||||
let response = GurtResponse::bad_request()
|
||||
.with_header("Content-Type", "text/html");
|
||||
return Ok(self.apply_global_headers(response));
|
||||
}
|
||||
|
||||
if self.is_file_denied(&canonical_path) {
|
||||
return self.create_forbidden_response();
|
||||
}
|
||||
|
||||
if canonical_path.is_file() {
|
||||
self.handle_file_response(&canonical_path).await
|
||||
} else if canonical_path.is_dir() {
|
||||
self.handle_directory_response(&canonical_path, request_path).await
|
||||
} else {
|
||||
self.handle_not_found_response().await
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
self.handle_not_found_response().await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_file_response(&self, path: &Path) -> std::result::Result<GurtResponse, GurtError> {
|
||||
match self.file_handler.handle_file(path) {
|
||||
Ok(content) => {
|
||||
let content_type = self.file_handler.get_content_type(path);
|
||||
let response = GurtResponse::ok()
|
||||
.with_header("Content-Type", &content_type)
|
||||
.with_body(content);
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
Err(_) => {
|
||||
let response = GurtResponse::internal_server_error()
|
||||
.with_header("Content-Type", "text/html");
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_directory_response(&self, canonical_path: &Path, request_path: &str) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let index_path = canonical_path.join("index.html");
|
||||
if index_path.is_file() {
|
||||
self.handle_file_response(&index_path).await
|
||||
} else {
|
||||
match self.directory_handler.handle_directory(canonical_path, request_path) {
|
||||
Ok(listing) => {
|
||||
let response = GurtResponse::ok()
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_string_body(listing);
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
Err(_) => {
|
||||
let response = GurtResponse::internal_server_error()
|
||||
.with_header("Content-Type", "text/html");
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_not_found_response(&self) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let content = self.get_custom_error_page(404)
|
||||
.unwrap_or_else(|| crate::handlers::get_404_html().to_string());
|
||||
|
||||
let response = GurtResponse::not_found()
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_string_body(content);
|
||||
Ok(self.apply_global_headers(response))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gurt::GurtStatusCode;
|
||||
use std::fs;
|
||||
use std::env;
|
||||
|
||||
fn create_test_handler() -> RequestHandler {
|
||||
let temp_dir = env::temp_dir().join("gurty_request_handler_test");
|
||||
let _ = fs::create_dir_all(&temp_dir);
|
||||
|
||||
RequestHandler::builder(&temp_dir).build()
|
||||
}
|
||||
|
||||
fn create_test_handler_with_config() -> RequestHandler {
|
||||
let temp_dir = env::temp_dir().join("gurty_request_handler_test_config");
|
||||
let _ = fs::create_dir_all(&temp_dir);
|
||||
|
||||
let config = Arc::new(GurtConfig::default());
|
||||
RequestHandler::builder(&temp_dir)
|
||||
.with_config(config)
|
||||
.build()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_handler_builder() {
|
||||
let temp_dir = env::temp_dir().join("gurty_builder_test");
|
||||
let _ = fs::create_dir_all(&temp_dir);
|
||||
|
||||
let handler = RequestHandler::builder(&temp_dir).build();
|
||||
|
||||
assert_eq!(handler.base_directory, temp_dir);
|
||||
assert!(handler.config.is_none());
|
||||
assert!(handler.security.is_none());
|
||||
|
||||
let _ = fs::remove_dir_all(&temp_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_handler_builder_with_config() {
|
||||
let temp_dir = env::temp_dir().join("gurty_builder_config_test");
|
||||
let _ = fs::create_dir_all(&temp_dir);
|
||||
|
||||
let config = Arc::new(GurtConfig::default());
|
||||
let handler = RequestHandler::builder(&temp_dir)
|
||||
.with_config(config.clone())
|
||||
.build();
|
||||
|
||||
assert!(handler.config.is_some());
|
||||
assert!(handler.security.is_some());
|
||||
|
||||
let _ = fs::remove_dir_all(&temp_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fallback_error_page_generation() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let error_404 = handler.get_fallback_error_page(404);
|
||||
assert!(error_404.contains("404 Not Found"));
|
||||
assert!(error_404.contains("not found"));
|
||||
|
||||
let error_500 = handler.get_fallback_error_page(500);
|
||||
assert!(error_500.contains("500 Internal Server Error"));
|
||||
assert!(error_500.contains("processing your request"));
|
||||
|
||||
let error_429 = handler.get_fallback_error_page(429);
|
||||
assert!(error_429.contains("429 Too Many Requests"));
|
||||
assert!(error_429.contains("rate limit"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_error_page_with_config() {
|
||||
let handler = create_test_handler_with_config();
|
||||
|
||||
let result = handler.get_custom_error_page(404);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_global_headers_without_config() {
|
||||
let handler = create_test_handler();
|
||||
let response = GurtResponse::ok();
|
||||
|
||||
let modified_response = handler.apply_global_headers(response);
|
||||
|
||||
assert_eq!(modified_response.status_code, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_global_headers_with_config() {
|
||||
let temp_dir = env::temp_dir().join("gurty_headers_test");
|
||||
let _ = fs::create_dir_all(&temp_dir);
|
||||
|
||||
let mut config = GurtConfig::default();
|
||||
let mut headers = std::collections::HashMap::new();
|
||||
headers.insert("X-Test-Header".to_string(), "test-value".to_string());
|
||||
config.headers = Some(headers);
|
||||
|
||||
let handler = RequestHandler::builder(&temp_dir)
|
||||
.with_config(Arc::new(config))
|
||||
.build();
|
||||
|
||||
let response = GurtResponse::ok();
|
||||
let modified_response = handler.apply_global_headers(response);
|
||||
|
||||
assert!(modified_response.headers.contains_key("x-test-header"));
|
||||
assert_eq!(modified_response.headers.get("x-test-header").unwrap(), "test-value");
|
||||
|
||||
let _ = fs::remove_dir_all(&temp_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_custom_error_page() {
|
||||
let handler = create_test_handler();
|
||||
let mut response = GurtResponse::new(GurtStatusCode::NotFound);
|
||||
response.body = b"Not Found".to_vec();
|
||||
|
||||
let modified_response = handler.apply_custom_error_page(response);
|
||||
|
||||
assert!(modified_response.status_code >= 400);
|
||||
let body_str = String::from_utf8_lossy(&modified_response.body);
|
||||
assert!(body_str.contains("html"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_custom_error_page_for_success() {
|
||||
let handler = create_test_handler();
|
||||
let mut response = GurtResponse::ok();
|
||||
response.body = b"Success".to_vec();
|
||||
|
||||
let modified_response = handler.apply_custom_error_page(response);
|
||||
|
||||
assert_eq!(modified_response.status_code, 200);
|
||||
assert_eq!(modified_response.body, b"Success".to_vec());
|
||||
}
|
||||
}
|
||||
288
protocol/cli/src/security.rs
Normal file
288
protocol/cli/src/security.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
use crate::config::GurtConfig;
|
||||
use gurt::{prelude::*, GurtMethod, GurtStatusCode};
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::{warn, debug};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RateLimitData {
|
||||
requests: Vec<Instant>,
|
||||
connections: u32,
|
||||
}
|
||||
|
||||
impl RateLimitData {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
requests: Vec::new(),
|
||||
connections: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn cleanup_old_requests(&mut self, window: Duration) {
|
||||
let cutoff = Instant::now() - window;
|
||||
self.requests.retain(|&request_time| request_time > cutoff);
|
||||
}
|
||||
|
||||
fn add_request(&mut self) {
|
||||
self.requests.push(Instant::now());
|
||||
}
|
||||
|
||||
fn request_count(&self) -> usize {
|
||||
self.requests.len()
|
||||
}
|
||||
|
||||
fn increment_connections(&mut self) {
|
||||
self.connections += 1;
|
||||
}
|
||||
|
||||
fn decrement_connections(&mut self) {
|
||||
if self.connections > 0 {
|
||||
self.connections -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn connection_count(&self) -> u32 {
|
||||
self.connections
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SecurityMiddleware {
|
||||
config: Arc<GurtConfig>,
|
||||
rate_limit_data: Arc<Mutex<HashMap<IpAddr, RateLimitData>>>,
|
||||
}
|
||||
|
||||
impl SecurityMiddleware {
|
||||
pub fn new(config: Arc<GurtConfig>) -> Self {
|
||||
Self {
|
||||
config,
|
||||
rate_limit_data: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_method_allowed(&self, method: &GurtMethod) -> bool {
|
||||
if let Some(security) = &self.config.security {
|
||||
let method_str = method.to_string();
|
||||
security.allowed_methods.contains(&method_str)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_rate_limit(&self, client_ip: IpAddr) -> bool {
|
||||
if let Some(security) = &self.config.security {
|
||||
let mut data = self.rate_limit_data.lock().unwrap();
|
||||
let rate_data = data.entry(client_ip).or_insert_with(RateLimitData::new);
|
||||
|
||||
rate_data.cleanup_old_requests(Duration::from_secs(60));
|
||||
|
||||
if rate_data.request_count() >= security.rate_limit_requests as usize {
|
||||
warn!("Rate limit exceeded for IP {}: {} requests in the last minute",
|
||||
client_ip, rate_data.request_count());
|
||||
return false;
|
||||
}
|
||||
|
||||
rate_data.add_request();
|
||||
debug!("Request from {}: {}/{} requests in the last minute",
|
||||
client_ip, rate_data.request_count(), security.rate_limit_requests);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn check_connection_limit(&self, client_ip: IpAddr) -> bool {
|
||||
if let Some(security) = &self.config.security {
|
||||
let mut data = self.rate_limit_data.lock().unwrap();
|
||||
let rate_data = data.entry(client_ip).or_insert_with(RateLimitData::new);
|
||||
|
||||
if rate_data.connection_count() >= security.rate_limit_connections {
|
||||
warn!("Connection limit exceeded for IP {}: {} concurrent connections",
|
||||
client_ip, rate_data.connection_count());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn register_connection(&self, client_ip: IpAddr) {
|
||||
if self.config.security.is_some() {
|
||||
let mut data = self.rate_limit_data.lock().unwrap();
|
||||
let rate_data = data.entry(client_ip).or_insert_with(RateLimitData::new);
|
||||
rate_data.increment_connections();
|
||||
debug!("Connection registered for {}: {} concurrent connections",
|
||||
client_ip, rate_data.connection_count());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unregister_connection(&self, client_ip: IpAddr) {
|
||||
if self.config.security.is_some() {
|
||||
let mut data = self.rate_limit_data.lock().unwrap();
|
||||
if let Some(rate_data) = data.get_mut(&client_ip) {
|
||||
rate_data.decrement_connections();
|
||||
debug!("Connection unregistered for {}: {} concurrent connections remaining",
|
||||
client_ip, rate_data.connection_count());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_method_not_allowed_response(&self) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let response = GurtResponse::new(GurtStatusCode::MethodNotAllowed)
|
||||
.with_header("Content-Type", "text/html");
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub fn create_rate_limit_response(&self) -> std::result::Result<GurtResponse, GurtError> {
|
||||
let response = GurtResponse::new(GurtStatusCode::TooManyRequests)
|
||||
.with_header("Content-Type", "text/html")
|
||||
.with_header("Retry-After", "60");
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
fn create_test_config() -> Arc<GurtConfig> {
|
||||
let mut config = crate::config::GurtConfig::default();
|
||||
config.security = Some(crate::config::SecurityConfig {
|
||||
deny_files: vec!["*.secret".to_string(), "private/*".to_string()],
|
||||
allowed_methods: vec!["GET".to_string(), "POST".to_string()],
|
||||
rate_limit_requests: 5,
|
||||
rate_limit_connections: 2,
|
||||
});
|
||||
Arc::new(config)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limit_data_initialization() {
|
||||
let data = RateLimitData::new();
|
||||
|
||||
assert_eq!(data.request_count(), 0);
|
||||
assert_eq!(data.connection_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limit_data_request_tracking() {
|
||||
let mut data = RateLimitData::new();
|
||||
|
||||
data.add_request();
|
||||
data.add_request();
|
||||
assert_eq!(data.request_count(), 2);
|
||||
|
||||
data.cleanup_old_requests(Duration::from_secs(0));
|
||||
assert_eq!(data.request_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limit_data_connection_tracking() {
|
||||
let mut data = RateLimitData::new();
|
||||
|
||||
data.increment_connections();
|
||||
data.increment_connections();
|
||||
assert_eq!(data.connection_count(), 2);
|
||||
|
||||
data.decrement_connections();
|
||||
assert_eq!(data.connection_count(), 1);
|
||||
|
||||
data.decrement_connections();
|
||||
data.decrement_connections();
|
||||
assert_eq!(data.connection_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_security_middleware_initialization() {
|
||||
let config = create_test_config();
|
||||
let middleware = SecurityMiddleware::new(config.clone());
|
||||
|
||||
assert!(middleware.rate_limit_data.lock().unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connection_tracking() {
|
||||
let config = create_test_config();
|
||||
let middleware = SecurityMiddleware::new(config.clone());
|
||||
let ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||||
|
||||
middleware.register_connection(ip);
|
||||
{
|
||||
let data = middleware.rate_limit_data.lock().unwrap();
|
||||
assert_eq!(data.get(&ip).unwrap().connection_count(), 1);
|
||||
}
|
||||
|
||||
middleware.unregister_connection(ip);
|
||||
{
|
||||
let data = middleware.rate_limit_data.lock().unwrap();
|
||||
assert_eq!(data.get(&ip).unwrap().connection_count(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limiting_requests() {
|
||||
let config = create_test_config();
|
||||
let middleware = SecurityMiddleware::new(config.clone());
|
||||
let ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||||
|
||||
for _ in 0..5 {
|
||||
assert!(middleware.check_rate_limit(ip));
|
||||
}
|
||||
|
||||
assert!(!middleware.check_rate_limit(ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connection_limiting() {
|
||||
let config = create_test_config();
|
||||
let middleware = SecurityMiddleware::new(config.clone());
|
||||
let ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||||
|
||||
middleware.register_connection(ip);
|
||||
middleware.register_connection(ip);
|
||||
|
||||
assert!(!middleware.check_connection_limit(ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_method_validation() {
|
||||
let config = create_test_config();
|
||||
let middleware = SecurityMiddleware::new(config.clone());
|
||||
|
||||
assert!(middleware.is_method_allowed(&GurtMethod::GET));
|
||||
assert!(middleware.is_method_allowed(&GurtMethod::POST));
|
||||
|
||||
assert!(!middleware.is_method_allowed(&GurtMethod::PUT));
|
||||
assert!(!middleware.is_method_allowed(&GurtMethod::DELETE));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_ips_isolation() {
|
||||
let config = create_test_config();
|
||||
let middleware = SecurityMiddleware::new(config.clone());
|
||||
let ip1: IpAddr = "127.0.0.1".parse().unwrap();
|
||||
let ip2: IpAddr = "127.0.0.2".parse().unwrap();
|
||||
|
||||
for _ in 0..6 {
|
||||
middleware.check_rate_limit(ip1);
|
||||
}
|
||||
|
||||
assert!(middleware.check_rate_limit(ip2));
|
||||
assert!(!middleware.check_rate_limit(ip1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_creation() {
|
||||
let config = create_test_config();
|
||||
let middleware = SecurityMiddleware::new(config.clone());
|
||||
|
||||
let response = middleware.create_method_not_allowed_response().unwrap();
|
||||
assert_eq!(response.status_code, 405);
|
||||
|
||||
let response = middleware.create_rate_limit_response().unwrap();
|
||||
assert_eq!(response.status_code, 429);
|
||||
}
|
||||
}
|
||||
301
protocol/cli/src/server.rs
Normal file
301
protocol/cli/src/server.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
use crate::{
|
||||
config::GurtConfig,
|
||||
handlers::{FileHandler, DirectoryHandler, DefaultFileHandler, DefaultDirectoryHandler},
|
||||
request_handler::{RequestHandler, RequestHandlerBuilder},
|
||||
};
|
||||
use gurt::prelude::*;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
pub struct FileServerBuilder {
|
||||
config: GurtConfig,
|
||||
file_handler: Arc<dyn FileHandler>,
|
||||
directory_handler: Arc<dyn DirectoryHandler>,
|
||||
}
|
||||
|
||||
impl FileServerBuilder {
|
||||
pub fn new(config: GurtConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
file_handler: Arc::new(DefaultFileHandler),
|
||||
directory_handler: Arc::new(DefaultDirectoryHandler),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_file_handler<H: FileHandler + 'static>(mut self, handler: H) -> Self {
|
||||
self.file_handler = Arc::new(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_directory_handler<H: DirectoryHandler + 'static>(mut self, handler: H) -> Self {
|
||||
self.directory_handler = Arc::new(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> crate::Result<GurtServer> {
|
||||
let server = self.create_server()?;
|
||||
let request_handler = self.create_request_handler();
|
||||
let server_with_routes = self.add_routes(server, request_handler);
|
||||
Ok(server_with_routes)
|
||||
}
|
||||
|
||||
fn create_server(&self) -> crate::Result<GurtServer> {
|
||||
match &self.config.tls {
|
||||
Some(tls) => {
|
||||
println!("TLS using certificate: {}", tls.certificate.display());
|
||||
GurtServerBuilder::new()
|
||||
.with_tls_certificates(&tls.certificate, &tls.private_key)
|
||||
.with_timeouts(
|
||||
self.config.get_handshake_timeout(),
|
||||
self.config.get_request_timeout(),
|
||||
self.config.get_connection_timeout(),
|
||||
)
|
||||
.build()
|
||||
}
|
||||
None => {
|
||||
Err(crate::ServerError::TlsConfiguration(
|
||||
"GURT protocol requires TLS encryption. Please provide --cert and --key parameters.".to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_request_handler(&self) -> RequestHandler {
|
||||
RequestHandlerBuilder::new(&*self.config.server.base_directory)
|
||||
.with_file_handler(DefaultFileHandler)
|
||||
.with_directory_handler(DefaultDirectoryHandler)
|
||||
.with_config(Arc::new(self.config.clone()))
|
||||
.build()
|
||||
}
|
||||
|
||||
fn add_routes(self, server: GurtServer, request_handler: RequestHandler) -> GurtServer {
|
||||
let request_handler = Arc::new(request_handler);
|
||||
|
||||
let server = server
|
||||
.get("/", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_root_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.get("/*", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let path = ctx.path().to_string();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_file_request_with_context(&path, ctx_clone).await
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let server = server
|
||||
.post("/", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.post("/*", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.put("/", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.put("/*", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.delete("/", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.delete("/*", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.patch("/", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.patch("/*", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.options("/", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.options("/*", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.head("/", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
})
|
||||
.head("/*", {
|
||||
let handler = request_handler.clone();
|
||||
move |ctx| {
|
||||
let handler = handler.clone();
|
||||
let ctx_clone = ctx.clone();
|
||||
async move {
|
||||
handler.handle_method_request_with_context(ctx_clone).await
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
server
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub struct GurtServerBuilder {
|
||||
cert_path: Option<PathBuf>,
|
||||
key_path: Option<PathBuf>,
|
||||
host: Option<String>,
|
||||
port: Option<u16>,
|
||||
handshake_timeout: Option<std::time::Duration>,
|
||||
request_timeout: Option<std::time::Duration>,
|
||||
connection_timeout: Option<std::time::Duration>,
|
||||
}
|
||||
|
||||
impl GurtServerBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cert_path: None,
|
||||
key_path: None,
|
||||
host: None,
|
||||
port: None,
|
||||
handshake_timeout: None,
|
||||
request_timeout: None,
|
||||
connection_timeout: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_tls_certificates<P: Into<PathBuf>>(mut self, cert_path: P, key_path: P) -> Self {
|
||||
self.cert_path = Some(cert_path.into());
|
||||
self.key_path = Some(key_path.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_host<S: Into<String>>(mut self, host: S) -> Self {
|
||||
self.host = Some(host.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_port(mut self, port: u16) -> Self {
|
||||
self.port = Some(port);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_timeouts(mut self, handshake_timeout: std::time::Duration, request_timeout: std::time::Duration, connection_timeout: std::time::Duration) -> Self {
|
||||
self.handshake_timeout = Some(handshake_timeout);
|
||||
self.request_timeout = Some(request_timeout);
|
||||
self.connection_timeout = Some(connection_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> crate::Result<GurtServer> {
|
||||
match (self.cert_path, self.key_path) {
|
||||
(Some(cert), Some(key)) => {
|
||||
let mut server = GurtServer::with_tls_certificates(
|
||||
cert.to_str().ok_or_else(|| {
|
||||
crate::ServerError::TlsConfiguration("Invalid certificate path".to_string())
|
||||
})?,
|
||||
key.to_str().ok_or_else(|| {
|
||||
crate::ServerError::TlsConfiguration("Invalid key path".to_string())
|
||||
})?
|
||||
).map_err(crate::ServerError::from)?;
|
||||
|
||||
if let (Some(handshake), Some(request), Some(connection)) =
|
||||
(self.handshake_timeout, self.request_timeout, self.connection_timeout) {
|
||||
server = server.with_timeouts(handshake, request, connection);
|
||||
}
|
||||
|
||||
Ok(server)
|
||||
}
|
||||
_ => {
|
||||
Err(crate::ServerError::TlsConfiguration(
|
||||
"TLS certificates are required. Use with_tls_certificates() to provide them.".to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GurtServerBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
31
protocol/cli/templates/404.html
Normal file
31
protocol/cli/templates/404.html
Normal file
@@ -0,0 +1,31 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>404 Not Found</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
margin: 40px;
|
||||
text-align: center;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.container {
|
||||
background: white;
|
||||
padding: 60px 40px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
display: inline-block;
|
||||
}
|
||||
h1 { color: #d32f2f; }
|
||||
a { color: #0066cc; text-decoration: none; }
|
||||
a:hover { text-decoration: underline; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>404 Page Not Found</h1>
|
||||
<p>The requested path was not found on this GURT server.</p>
|
||||
<p><a href="/">Back to home</a></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
1
protocol/cli/templates/directory_content_start.html
Normal file
1
protocol/cli/templates/directory_content_start.html
Normal file
@@ -0,0 +1 @@
|
||||
<div>
|
||||
4
protocol/cli/templates/directory_listing_end.html
Normal file
4
protocol/cli/templates/directory_listing_end.html
Normal file
@@ -0,0 +1,4 @@
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
43
protocol/cli/templates/directory_listing_start.html
Normal file
43
protocol/cli/templates/directory_listing_start.html
Normal file
@@ -0,0 +1,43 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Directory Listing</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
margin: 40px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.container {
|
||||
background: white;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
}
|
||||
.dir {
|
||||
font-weight: bold;
|
||||
color: #0066cc;
|
||||
}
|
||||
.file {
|
||||
color: #333;
|
||||
}
|
||||
a {
|
||||
text-decoration: none;
|
||||
display: block;
|
||||
padding: 8px 12px;
|
||||
margin: 2px 0;
|
||||
border-radius: 4px;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
a:hover {
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
.parent {
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Directory Listing</h1>
|
||||
1
protocol/cli/templates/directory_parent_link.html
Normal file
1
protocol/cli/templates/directory_parent_link.html
Normal file
@@ -0,0 +1 @@
|
||||
<a href="../" class="parent">← Parent Directory</a>
|
||||
33
protocol/cli/templates/error.html
Normal file
33
protocol/cli/templates/error.html
Normal file
@@ -0,0 +1,33 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>{} {}</title>
|
||||
<style>
|
||||
body {{
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
margin: 40px;
|
||||
text-align: center;
|
||||
background: #f5f5f5;
|
||||
}}
|
||||
.container {{
|
||||
background: white;
|
||||
padding: 60px 40px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
display: inline-block;
|
||||
max-width: 500px;
|
||||
}}
|
||||
h1 {{ color: #d32f2f; margin-bottom: 20px; }}
|
||||
p {{ color: #555; margin-bottom: 30px; }}
|
||||
a {{ color: #0066cc; text-decoration: none; }}
|
||||
a:hover {{ text-decoration: underline; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>{} {}</h1>
|
||||
<p>{}</p>
|
||||
<p><a href="/">Back to home</a></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -184,12 +184,11 @@ impl GurtClient {
|
||||
|
||||
let bytes_read = conn.connection.read(&mut temp_buffer).await?;
|
||||
if bytes_read == 0 {
|
||||
break; // Connection closed
|
||||
break;
|
||||
}
|
||||
|
||||
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
|
||||
|
||||
// Check for complete message
|
||||
let body_separator = BODY_SEPARATOR.as_bytes();
|
||||
|
||||
if !headers_parsed {
|
||||
@@ -197,7 +196,6 @@ impl GurtClient {
|
||||
headers_end_pos = Some(pos + body_separator.len());
|
||||
headers_parsed = true;
|
||||
|
||||
// Parse headers to get Content-Length
|
||||
let headers_section = &buffer[..pos];
|
||||
if let Ok(headers_str) = std::str::from_utf8(headers_section) {
|
||||
for line in headers_str.lines() {
|
||||
@@ -220,7 +218,6 @@ impl GurtClient {
|
||||
return Ok(buffer);
|
||||
}
|
||||
} else if headers_parsed && expected_body_length.is_none() {
|
||||
// No Content-Length header, return what we have after headers
|
||||
return Ok(buffer);
|
||||
}
|
||||
}
|
||||
@@ -329,7 +326,7 @@ impl GurtClient {
|
||||
}
|
||||
|
||||
match timeout(Duration::from_millis(100), tls_stream.read(&mut temp_buffer)).await {
|
||||
Ok(Ok(0)) => break, // Connection closed
|
||||
Ok(Ok(0)) => break,
|
||||
Ok(Ok(n)) => {
|
||||
buffer.extend_from_slice(&temp_buffer[..n]);
|
||||
|
||||
@@ -394,7 +391,6 @@ impl GurtClient {
|
||||
self.send_request_internal(&host, port, request).await
|
||||
}
|
||||
|
||||
/// POST request with JSON body
|
||||
pub async fn post_json<T: serde::Serialize>(&self, url: &str, data: &T) -> Result<GurtResponse> {
|
||||
let (host, port, path) = self.parse_url(url)?;
|
||||
let json_body = serde_json::to_string(data)?;
|
||||
@@ -408,7 +404,6 @@ impl GurtClient {
|
||||
self.send_request_internal(&host, port, request).await
|
||||
}
|
||||
|
||||
/// PUT request with body
|
||||
pub async fn put(&self, url: &str, body: &str) -> Result<GurtResponse> {
|
||||
let (host, port, path) = self.parse_url(url)?;
|
||||
let request = GurtRequest::new(GurtMethod::PUT, path)
|
||||
@@ -420,7 +415,6 @@ impl GurtClient {
|
||||
self.send_request_internal(&host, port, request).await
|
||||
}
|
||||
|
||||
/// PUT request with JSON body
|
||||
pub async fn put_json<T: serde::Serialize>(&self, url: &str, data: &T) -> Result<GurtResponse> {
|
||||
let (host, port, path) = self.parse_url(url)?;
|
||||
let json_body = serde_json::to_string(data)?;
|
||||
@@ -461,7 +455,6 @@ impl GurtClient {
|
||||
self.send_request_internal(&host, port, request).await
|
||||
}
|
||||
|
||||
/// PATCH request with body
|
||||
pub async fn patch(&self, url: &str, body: &str) -> Result<GurtResponse> {
|
||||
let (host, port, path) = self.parse_url(url)?;
|
||||
let request = GurtRequest::new(GurtMethod::PATCH, path)
|
||||
@@ -473,7 +466,6 @@ impl GurtClient {
|
||||
self.send_request_internal(&host, port, request).await
|
||||
}
|
||||
|
||||
/// PATCH request with JSON body
|
||||
pub async fn patch_json<T: serde::Serialize>(&self, url: &str, data: &T) -> Result<GurtResponse> {
|
||||
let (host, port, path) = self.parse_url(url)?;
|
||||
let json_body = serde_json::to_string(data)?;
|
||||
@@ -548,4 +540,38 @@ mod tests {
|
||||
assert_eq!(port, 8080);
|
||||
assert_eq!(path, "/api/v1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connection_pooling_config() {
|
||||
let config = GurtClientConfig {
|
||||
enable_connection_pooling: true,
|
||||
max_connections_per_host: 8,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let client = GurtClient::with_config(config);
|
||||
assert!(client.config.enable_connection_pooling);
|
||||
assert_eq!(client.config.max_connections_per_host, 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connection_key() {
|
||||
let key1 = ConnectionKey {
|
||||
host: "example.com".to_string(),
|
||||
port: 4878,
|
||||
};
|
||||
|
||||
let key2 = ConnectionKey {
|
||||
host: "example.com".to_string(),
|
||||
port: 4878,
|
||||
};
|
||||
|
||||
let key3 = ConnectionKey {
|
||||
host: "other.com".to_string(),
|
||||
port: 4878,
|
||||
};
|
||||
|
||||
assert_eq!(key1, key2);
|
||||
assert_ne!(key1, key3);
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ pub enum GurtMethod {
|
||||
HEAD,
|
||||
OPTIONS,
|
||||
PATCH,
|
||||
HANDSHAKE, // Special method for protocol handshake
|
||||
HANDSHAKE,
|
||||
}
|
||||
|
||||
impl GurtMethod {
|
||||
@@ -101,7 +101,6 @@ impl GurtRequest {
|
||||
}
|
||||
|
||||
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
|
||||
// Find the header/body separator as bytes
|
||||
let body_separator = BODY_SEPARATOR.as_bytes();
|
||||
let body_separator_pos = data.windows(body_separator.len())
|
||||
.position(|window| window == body_separator);
|
||||
@@ -114,7 +113,6 @@ impl GurtRequest {
|
||||
(data, Vec::new())
|
||||
};
|
||||
|
||||
// Convert headers section to string (should be valid UTF-8)
|
||||
let headers_str = std::str::from_utf8(headers_section)
|
||||
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?;
|
||||
|
||||
@@ -124,7 +122,6 @@ impl GurtRequest {
|
||||
return Err(GurtError::invalid_message("Empty request"));
|
||||
}
|
||||
|
||||
// Parse request line (METHOD path GURT/version)
|
||||
let request_line = lines[0];
|
||||
let parts: Vec<&str> = request_line.split_whitespace().collect();
|
||||
|
||||
@@ -135,7 +132,6 @@ impl GurtRequest {
|
||||
let method = GurtMethod::parse(parts[0])?;
|
||||
let path = parts[1].to_string();
|
||||
|
||||
// Parse protocol version
|
||||
if !parts[2].starts_with(PROTOCOL_PREFIX) {
|
||||
return Err(GurtError::invalid_message("Invalid protocol identifier"));
|
||||
}
|
||||
@@ -143,7 +139,6 @@ impl GurtRequest {
|
||||
let version_str = &parts[2][PROTOCOL_PREFIX.len()..];
|
||||
let version = version_str.to_string();
|
||||
|
||||
// Parse headers
|
||||
let mut headers = GurtHeaders::new();
|
||||
|
||||
for line in lines.iter().skip(1) {
|
||||
@@ -253,6 +248,10 @@ impl GurtResponse {
|
||||
Self::new(GurtStatusCode::BadRequest)
|
||||
}
|
||||
|
||||
pub fn forbidden() -> Self {
|
||||
Self::new(GurtStatusCode::Forbidden)
|
||||
}
|
||||
|
||||
pub fn internal_server_error() -> Self {
|
||||
Self::new(GurtStatusCode::InternalServerError)
|
||||
}
|
||||
@@ -306,7 +305,6 @@ impl GurtResponse {
|
||||
}
|
||||
|
||||
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
|
||||
// Find the header/body separator as bytes
|
||||
let body_separator = BODY_SEPARATOR.as_bytes();
|
||||
let body_separator_pos = data.windows(body_separator.len())
|
||||
.position(|window| window == body_separator);
|
||||
@@ -319,7 +317,6 @@ impl GurtResponse {
|
||||
(data, Vec::new())
|
||||
};
|
||||
|
||||
// Convert headers section to string (should be valid UTF-8)
|
||||
let headers_str = std::str::from_utf8(headers_section)
|
||||
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?;
|
||||
|
||||
@@ -329,7 +326,6 @@ impl GurtResponse {
|
||||
return Err(GurtError::invalid_message("Empty response"));
|
||||
}
|
||||
|
||||
// Parse status line (GURT/version status_code status_message)
|
||||
let status_line = lines[0];
|
||||
let parts: Vec<&str> = status_line.splitn(3, ' ').collect();
|
||||
|
||||
@@ -337,7 +333,6 @@ impl GurtResponse {
|
||||
return Err(GurtError::invalid_message("Invalid status line format"));
|
||||
}
|
||||
|
||||
// Parse protocol version
|
||||
if !parts[0].starts_with(PROTOCOL_PREFIX) {
|
||||
return Err(GurtError::invalid_message("Invalid protocol identifier"));
|
||||
}
|
||||
@@ -356,7 +351,6 @@ impl GurtResponse {
|
||||
.unwrap_or_else(|| "Unknown".to_string())
|
||||
};
|
||||
|
||||
// Parse headers
|
||||
let mut headers = GurtHeaders::new();
|
||||
|
||||
for line in lines.iter().skip(1) {
|
||||
@@ -394,7 +388,6 @@ impl GurtResponse {
|
||||
}
|
||||
|
||||
if !headers.contains_key("date") {
|
||||
// RFC 7231 compliant
|
||||
let now = Utc::now();
|
||||
let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
|
||||
headers.insert("date".to_string(), date_str);
|
||||
@@ -429,7 +422,6 @@ impl GurtResponse {
|
||||
}
|
||||
|
||||
if !headers.contains_key("date") {
|
||||
// RFC 7231 compliant
|
||||
let now = Utc::now();
|
||||
let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
|
||||
headers.insert("date".to_string(), date_str);
|
||||
@@ -441,7 +433,6 @@ impl GurtResponse {
|
||||
|
||||
message.push_str(HEADER_SEPARATOR);
|
||||
|
||||
// Convert headers to bytes and append body as raw bytes
|
||||
let mut bytes = message.into_bytes();
|
||||
bytes.extend_from_slice(&self.body);
|
||||
|
||||
@@ -461,7 +452,6 @@ impl GurtMessage {
|
||||
}
|
||||
|
||||
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
|
||||
// Convert first line to string to determine message type
|
||||
let header_separator = HEADER_SEPARATOR.as_bytes();
|
||||
let first_line_end = data.windows(header_separator.len())
|
||||
.position(|window| window == header_separator)
|
||||
@@ -470,7 +460,6 @@ impl GurtMessage {
|
||||
let first_line = std::str::from_utf8(&data[..first_line_end])
|
||||
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in first line"))?;
|
||||
|
||||
// Check if it's a response (starts with GURT/version) or request (method first)
|
||||
if first_line.starts_with(PROTOCOL_PREFIX) {
|
||||
Ok(GurtMessage::Response(GurtResponse::parse_bytes(data)?))
|
||||
} else {
|
||||
|
||||
@@ -37,6 +37,7 @@ pub enum GurtStatusCode {
|
||||
Timeout = 408,
|
||||
TooLarge = 413,
|
||||
UnsupportedMediaType = 415,
|
||||
TooManyRequests = 429,
|
||||
|
||||
// Server errors
|
||||
InternalServerError = 500,
|
||||
@@ -62,6 +63,7 @@ impl GurtStatusCode {
|
||||
408 => Some(Self::Timeout),
|
||||
413 => Some(Self::TooLarge),
|
||||
415 => Some(Self::UnsupportedMediaType),
|
||||
429 => Some(Self::TooManyRequests),
|
||||
500 => Some(Self::InternalServerError),
|
||||
501 => Some(Self::NotImplemented),
|
||||
502 => Some(Self::BadGateway),
|
||||
@@ -86,6 +88,7 @@ impl GurtStatusCode {
|
||||
Self::Timeout => "TIMEOUT",
|
||||
Self::TooLarge => "TOO_LARGE",
|
||||
Self::UnsupportedMediaType => "UNSUPPORTED_MEDIA_TYPE",
|
||||
Self::TooManyRequests => "TOO_MANY_REQUESTS",
|
||||
Self::InternalServerError => "INTERNAL_SERVER_ERROR",
|
||||
Self::NotImplemented => "NOT_IMPLEMENTED",
|
||||
Self::BadGateway => "BAD_GATEWAY",
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::{
|
||||
};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tokio_rustls::{TlsAcceptor, server::TlsStream};
|
||||
use rustls::pki_types::CertificateDer;
|
||||
use std::collections::HashMap;
|
||||
@@ -136,6 +137,9 @@ impl Route {
|
||||
pub struct GurtServer {
|
||||
routes: Vec<(Route, Arc<dyn GurtHandler>)>,
|
||||
tls_acceptor: Option<TlsAcceptor>,
|
||||
handshake_timeout: Duration,
|
||||
request_timeout: Duration,
|
||||
connection_timeout: Duration,
|
||||
}
|
||||
|
||||
impl GurtServer {
|
||||
@@ -143,9 +147,19 @@ impl GurtServer {
|
||||
Self {
|
||||
routes: Vec::new(),
|
||||
tls_acceptor: None,
|
||||
handshake_timeout: Duration::from_secs(5),
|
||||
request_timeout: Duration::from_secs(30),
|
||||
connection_timeout: Duration::from_secs(10),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_timeouts(mut self, handshake_timeout: Duration, request_timeout: Duration, connection_timeout: Duration) -> Self {
|
||||
self.handshake_timeout = handshake_timeout;
|
||||
self.request_timeout = request_timeout;
|
||||
self.connection_timeout = connection_timeout;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tls_certificates(cert_path: &str, key_path: &str) -> Result<Self> {
|
||||
let mut server = Self::new();
|
||||
server.load_tls_certificates(cert_path, key_path)?;
|
||||
@@ -279,6 +293,7 @@ impl GurtServer {
|
||||
}
|
||||
|
||||
async fn handle_connection(&self, mut stream: TcpStream, addr: SocketAddr) -> Result<()> {
|
||||
let connection_result = timeout(self.connection_timeout, async {
|
||||
self.handle_initial_handshake(&mut stream, addr).await?;
|
||||
|
||||
if let Some(tls_acceptor) = &self.tls_acceptor {
|
||||
@@ -293,9 +308,19 @@ impl GurtServer {
|
||||
warn!("No TLS configuration available, but handshake completed - this violates GURT protocol");
|
||||
Err(GurtError::protocol("TLS is required after handshake but no TLS configuration available"))
|
||||
}
|
||||
}).await;
|
||||
|
||||
match connection_result {
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
warn!("Connection timeout for {}", addr);
|
||||
Err(GurtError::timeout("Connection timeout"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_initial_handshake(&self, stream: &mut TcpStream, addr: SocketAddr) -> Result<()> {
|
||||
let handshake_result = timeout(self.handshake_timeout, async {
|
||||
let mut buffer = Vec::new();
|
||||
let mut temp_buffer = [0u8; 8192];
|
||||
|
||||
@@ -317,7 +342,6 @@ impl GurtServer {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let message = GurtMessage::parse_bytes(&buffer)?;
|
||||
|
||||
match message {
|
||||
@@ -332,6 +356,15 @@ impl GurtServer {
|
||||
Err(GurtError::protocol("Server received response during handshake"))
|
||||
}
|
||||
}
|
||||
}).await;
|
||||
|
||||
match handshake_result {
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
warn!("Handshake timeout for {}", addr);
|
||||
Err(GurtError::timeout("Handshake timeout"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_tls_connection(&self, mut tls_stream: TlsStream<TcpStream>, addr: SocketAddr) -> Result<()> {
|
||||
@@ -342,7 +375,6 @@ impl GurtServer {
|
||||
let bytes_read = match tls_stream.read(&mut temp_buffer).await {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
// Handle UnexpectedEof from clients that don't send close_notify
|
||||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||||
debug!("Client {} closed connection without TLS close_notify (benign)", addr);
|
||||
break;
|
||||
@@ -351,7 +383,7 @@ impl GurtServer {
|
||||
}
|
||||
};
|
||||
if bytes_read == 0 {
|
||||
break; // Connection closed
|
||||
break;
|
||||
}
|
||||
|
||||
buffer.extend_from_slice(&temp_buffer[..bytes_read]);
|
||||
@@ -361,17 +393,31 @@ impl GurtServer {
|
||||
(buffer.starts_with(b"{") && buffer.ends_with(b"}"));
|
||||
|
||||
if has_complete_message {
|
||||
if let Err(e) = self.process_tls_message(&mut tls_stream, addr, &buffer).await {
|
||||
let process_result = timeout(self.request_timeout,
|
||||
self.process_tls_message(&mut tls_stream, addr, &buffer)
|
||||
).await;
|
||||
|
||||
match process_result {
|
||||
Ok(Ok(())) => {
|
||||
debug!("Processed message from {} successfully", addr);
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
error!("Encrypted message processing error from {}: {}", addr, e);
|
||||
let error_response = GurtResponse::internal_server_error()
|
||||
.with_string_body("Internal server error");
|
||||
let _ = tls_stream.write_all(&error_response.to_bytes()).await;
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Request timeout for {}", addr);
|
||||
let timeout_response = GurtResponse::new(GurtStatusCode::Timeout)
|
||||
.with_string_body("Request timeout");
|
||||
let _ = tls_stream.write_all(&timeout_response.to_bytes()).await;
|
||||
}
|
||||
}
|
||||
|
||||
buffer.clear();
|
||||
}
|
||||
|
||||
// Prevent buffer overflow
|
||||
if buffer.len() > MAX_MESSAGE_SIZE {
|
||||
warn!("Message too large from {}, closing connection", addr);
|
||||
break;
|
||||
@@ -422,7 +468,6 @@ impl GurtServer {
|
||||
if let Some(method) = &route.method {
|
||||
allowed_methods.insert(method.to_string());
|
||||
} else {
|
||||
// Route matches any method
|
||||
allowed_methods.extend(vec![
|
||||
"GET".to_string(), "POST".to_string(), "PUT".to_string(),
|
||||
"DELETE".to_string(), "HEAD".to_string(), "PATCH".to_string()
|
||||
@@ -482,7 +527,6 @@ impl GurtServer {
|
||||
async fn handle_encrypted_request(&self, tls_stream: &mut TlsStream<TcpStream>, addr: SocketAddr, request: &GurtRequest) -> Result<()> {
|
||||
debug!("Handling encrypted {} request to {} from {}", request.method, request.path, addr);
|
||||
|
||||
// Find matching route
|
||||
for (route, handler) in &self.routes {
|
||||
if route.matches(&request.method, &request.path) {
|
||||
let context = ServerContext {
|
||||
@@ -492,7 +536,6 @@ impl GurtServer {
|
||||
|
||||
match handler.handle(&context).await {
|
||||
Ok(response) => {
|
||||
// Use to_bytes() to avoid corrupting binary data
|
||||
let response_bytes = response.to_bytes();
|
||||
tls_stream.write_all(&response_bytes).await?;
|
||||
return Ok(());
|
||||
@@ -508,7 +551,6 @@ impl GurtServer {
|
||||
}
|
||||
}
|
||||
|
||||
// No route found - check for default OPTIONS/HEAD handling
|
||||
match request.method {
|
||||
GurtMethod::OPTIONS => {
|
||||
self.handle_default_options(tls_stream, request).await
|
||||
@@ -531,6 +573,9 @@ impl Clone for GurtServer {
|
||||
Self {
|
||||
routes: self.routes.clone(),
|
||||
tls_acceptor: self.tls_acceptor.clone(),
|
||||
handshake_timeout: self.handshake_timeout,
|
||||
request_timeout: self.request_timeout,
|
||||
connection_timeout: self.connection_timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user