608 lines
19 KiB
Rust
608 lines
19 KiB
Rust
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use std::collections::HashMap;
|
||
|
|
use std::path::PathBuf;
|
||
|
|
use std::sync::Arc;
|
||
|
|
use std::time::Duration;
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct GurtConfig {
|
||
|
|
pub server: ServerConfig,
|
||
|
|
pub tls: Option<TlsConfig>,
|
||
|
|
pub logging: Option<LoggingConfig>,
|
||
|
|
pub security: Option<SecurityConfig>,
|
||
|
|
pub error_pages: Option<ErrorPagesConfig>,
|
||
|
|
pub headers: Option<HashMap<String, String>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct ServerConfig {
|
||
|
|
#[serde(default = "default_host")]
|
||
|
|
pub host: String,
|
||
|
|
|
||
|
|
#[serde(default = "default_port")]
|
||
|
|
pub port: u16,
|
||
|
|
|
||
|
|
#[serde(default = "default_protocol_version")]
|
||
|
|
pub protocol_version: String,
|
||
|
|
|
||
|
|
#[serde(default = "default_alpn_identifier")]
|
||
|
|
pub alpn_identifier: String,
|
||
|
|
|
||
|
|
pub timeouts: Option<TimeoutsConfig>,
|
||
|
|
|
||
|
|
#[serde(default = "default_max_connections")]
|
||
|
|
pub max_connections: u32,
|
||
|
|
|
||
|
|
#[serde(default = "default_max_message_size")]
|
||
|
|
pub max_message_size: String,
|
||
|
|
|
||
|
|
#[serde(skip)]
|
||
|
|
pub base_directory: Arc<PathBuf>,
|
||
|
|
|
||
|
|
#[serde(skip)]
|
||
|
|
pub verbose: bool,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct TimeoutsConfig {
|
||
|
|
#[serde(default = "default_handshake_timeout")]
|
||
|
|
pub handshake: u64,
|
||
|
|
|
||
|
|
#[serde(default = "default_request_timeout")]
|
||
|
|
pub request: u64,
|
||
|
|
|
||
|
|
#[serde(default = "default_connection_timeout")]
|
||
|
|
pub connection: u64,
|
||
|
|
|
||
|
|
#[serde(default = "default_pool_idle_timeout")]
|
||
|
|
pub pool_idle: u64,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct TlsConfig {
|
||
|
|
pub certificate: PathBuf,
|
||
|
|
pub private_key: PathBuf,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct LoggingConfig {
|
||
|
|
#[serde(default = "default_log_level")]
|
||
|
|
pub level: String,
|
||
|
|
|
||
|
|
pub access_log: Option<PathBuf>,
|
||
|
|
pub error_log: Option<PathBuf>,
|
||
|
|
|
||
|
|
#[serde(default = "default_log_requests")]
|
||
|
|
pub log_requests: bool,
|
||
|
|
|
||
|
|
#[serde(default)]
|
||
|
|
pub log_responses: bool,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct SecurityConfig {
|
||
|
|
#[serde(default)]
|
||
|
|
pub deny_files: Vec<String>,
|
||
|
|
|
||
|
|
#[serde(default = "default_allowed_methods")]
|
||
|
|
pub allowed_methods: Vec<String>,
|
||
|
|
|
||
|
|
#[serde(default = "default_rate_limit_requests")]
|
||
|
|
pub rate_limit_requests: u32,
|
||
|
|
|
||
|
|
#[serde(default = "default_rate_limit_connections")]
|
||
|
|
pub rate_limit_connections: u32,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct ErrorPagesConfig {
|
||
|
|
#[serde(flatten)]
|
||
|
|
pub pages: HashMap<String, String>,
|
||
|
|
|
||
|
|
pub default: Option<ErrorPageDefaults>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl ErrorPagesConfig {
|
||
|
|
pub fn get_page(&self, status_code: u16) -> Option<&String> {
|
||
|
|
let code_str = status_code.to_string();
|
||
|
|
self.pages.get(&code_str)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn get_default_page(&self, status_code: u16) -> Option<&String> {
|
||
|
|
if let Some(defaults) = &self.default {
|
||
|
|
let code_str = status_code.to_string();
|
||
|
|
defaults.pages.get(&code_str)
|
||
|
|
} else {
|
||
|
|
None
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn get_any_page(&self, status_code: u16) -> Option<&String> {
|
||
|
|
self.get_page(status_code)
|
||
|
|
.or_else(|| self.get_default_page(status_code))
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn get_page_content(&self, status_code: u16, base_dir: &std::path::Path) -> Option<String> {
|
||
|
|
if let Some(page_value) = self.get_page(status_code) {
|
||
|
|
if page_value.starts_with('/') || page_value.starts_with("./") {
|
||
|
|
let file_path = if page_value.starts_with('/') {
|
||
|
|
base_dir.join(&page_value[1..])
|
||
|
|
} else {
|
||
|
|
base_dir.join(page_value)
|
||
|
|
};
|
||
|
|
|
||
|
|
if let Ok(content) = std::fs::read_to_string(&file_path) {
|
||
|
|
return Some(content);
|
||
|
|
} else {
|
||
|
|
tracing::warn!("Failed to read error page file: {}", file_path.display());
|
||
|
|
return None;
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
return Some(page_value.clone());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if let Some(page_value) = self.get_default_page(status_code) {
|
||
|
|
return Some(page_value.clone());
|
||
|
|
}
|
||
|
|
|
||
|
|
None
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct ErrorPageDefaults {
|
||
|
|
#[serde(flatten)]
|
||
|
|
pub pages: HashMap<String, String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Fallback values. Each function is referenced by name from a
// `#[serde(default = "...")]` attribute, and several are reused by the
// `Default` impls, the timeout getters, `is_method_allowed` and the config
// merge logic — renaming one means updating those string references too.

/// Default bind address (IPv4 loopback).
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Default listening port.
fn default_port() -> u16 {
    4878
}

/// Default protocol version string.
fn default_protocol_version() -> String {
    String::from("1.0.0")
}

/// Default ALPN identifier.
fn default_alpn_identifier() -> String {
    String::from("GURT/1.0")
}

/// Default `max_connections` value.
fn default_max_connections() -> u32 {
    10
}

/// Default message-size limit, in the textual form accepted by `parse_size`.
fn default_max_message_size() -> String {
    String::from("10MB")
}

/// Default handshake timeout, in seconds.
fn default_handshake_timeout() -> u64 {
    5
}

/// Default request timeout, in seconds.
fn default_request_timeout() -> u64 {
    30
}

/// Default connection timeout, in seconds.
fn default_connection_timeout() -> u64 {
    10
}

/// Default pool idle timeout (presumably seconds — confirm with its consumer).
fn default_pool_idle_timeout() -> u64 {
    300
}

/// Default log level name.
fn default_log_level() -> String {
    String::from("info")
}

/// Requests are logged unless configured otherwise.
fn default_log_requests() -> bool {
    true
}

/// Methods permitted when no explicit list is configured.
fn default_allowed_methods() -> Vec<String> {
    const METHODS: [&str; 7] = ["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH"];
    METHODS.iter().map(|&method| method.to_owned()).collect()
}

/// Default request rate limit.
fn default_rate_limit_requests() -> u32 {
    100
}

/// Default connection rate limit.
fn default_rate_limit_connections() -> u32 {
    10
}
|
||
|
|
|
||
|
|
impl Default for GurtConfig {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
server: ServerConfig::default(),
|
||
|
|
tls: None,
|
||
|
|
logging: None,
|
||
|
|
security: None,
|
||
|
|
error_pages: None,
|
||
|
|
headers: None,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for ServerConfig {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
host: default_host(),
|
||
|
|
port: default_port(),
|
||
|
|
protocol_version: default_protocol_version(),
|
||
|
|
alpn_identifier: default_alpn_identifier(),
|
||
|
|
timeouts: None,
|
||
|
|
max_connections: default_max_connections(),
|
||
|
|
max_message_size: default_max_message_size(),
|
||
|
|
base_directory: Arc::new(PathBuf::from(".")),
|
||
|
|
verbose: false,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl GurtConfig {
|
||
|
|
pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> crate::Result<Self> {
|
||
|
|
let content = std::fs::read_to_string(path)
|
||
|
|
.map_err(|e| crate::ServerError::InvalidConfiguration(format!("Failed to read config file: {}", e)))?;
|
||
|
|
|
||
|
|
let config: GurtConfig = toml::from_str(&content)
|
||
|
|
.map_err(|e| crate::ServerError::InvalidConfiguration(format!("Failed to parse config file: {}", e)))?;
|
||
|
|
|
||
|
|
Ok(config)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn builder() -> GurtConfigBuilder {
|
||
|
|
GurtConfigBuilder::default()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn address(&self) -> String {
|
||
|
|
format!("{}:{}", self.server.host, self.server.port)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn max_message_size_bytes(&self) -> crate::Result<u64> {
|
||
|
|
parse_size(&self.server.max_message_size)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn get_handshake_timeout(&self) -> Duration {
|
||
|
|
Duration::from_secs(
|
||
|
|
self.server.timeouts
|
||
|
|
.as_ref()
|
||
|
|
.map(|t| t.handshake)
|
||
|
|
.unwrap_or(default_handshake_timeout())
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn get_request_timeout(&self) -> Duration {
|
||
|
|
Duration::from_secs(
|
||
|
|
self.server.timeouts
|
||
|
|
.as_ref()
|
||
|
|
.map(|t| t.request)
|
||
|
|
.unwrap_or(default_request_timeout())
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn get_connection_timeout(&self) -> Duration {
|
||
|
|
Duration::from_secs(
|
||
|
|
self.server.timeouts
|
||
|
|
.as_ref()
|
||
|
|
.map(|t| t.connection)
|
||
|
|
.unwrap_or(default_connection_timeout())
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn should_deny_file(&self, file_path: &str) -> bool {
|
||
|
|
if let Some(security) = &self.security {
|
||
|
|
for pattern in &security.deny_files {
|
||
|
|
if matches_pattern(file_path, pattern) {
|
||
|
|
return true;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
false
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn is_method_allowed(&self, method: &str) -> bool {
|
||
|
|
if let Some(security) = &self.security {
|
||
|
|
security.allowed_methods.contains(&method.to_uppercase())
|
||
|
|
} else {
|
||
|
|
default_allowed_methods().contains(&method.to_uppercase())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn default_with_directory(base_dir: PathBuf) -> Self {
|
||
|
|
let mut config = Self::default();
|
||
|
|
config.server.base_directory = Arc::new(base_dir);
|
||
|
|
config
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn from_toml(toml_content: &str, base_dir: PathBuf) -> crate::Result<Self> {
|
||
|
|
let mut config: GurtConfig = toml::from_str(toml_content)
|
||
|
|
.map_err(|e| crate::ServerError::InvalidConfiguration(format!("Failed to parse config: {}", e)))?;
|
||
|
|
|
||
|
|
config.server.base_directory = Arc::new(base_dir);
|
||
|
|
Ok(config)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn validate(&self) -> crate::Result<()> {
|
||
|
|
if !self.server.base_directory.exists() || !self.server.base_directory.is_dir() {
|
||
|
|
return Err(crate::ServerError::InvalidConfiguration(
|
||
|
|
format!("Invalid base directory: {}", self.server.base_directory.display())
|
||
|
|
));
|
||
|
|
}
|
||
|
|
|
||
|
|
if let Some(tls) = &self.tls {
|
||
|
|
if !tls.certificate.exists() {
|
||
|
|
return Err(crate::ServerError::TlsConfiguration(
|
||
|
|
format!("Certificate file does not exist: {}", tls.certificate.display())
|
||
|
|
));
|
||
|
|
}
|
||
|
|
if !tls.private_key.exists() {
|
||
|
|
return Err(crate::ServerError::TlsConfiguration(
|
||
|
|
format!("Private key file does not exist: {}", tls.private_key.display())
|
||
|
|
));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Default)]
|
||
|
|
pub struct GurtConfigBuilder {
|
||
|
|
config: GurtConfig,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl GurtConfigBuilder {
|
||
|
|
pub fn new() -> Self {
|
||
|
|
Self::default()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn host<S: Into<String>>(mut self, host: S) -> Self {
|
||
|
|
self.config.server.host = host.into();
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn port(mut self, port: u16) -> Self {
|
||
|
|
self.config.server.port = port;
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn base_directory<P: Into<PathBuf>>(mut self, dir: P) -> Self {
|
||
|
|
self.config.server.base_directory = Arc::new(dir.into());
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn verbose(mut self, verbose: bool) -> Self {
|
||
|
|
self.config.server.verbose = verbose;
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn tls_config(mut self, cert_path: PathBuf, key_path: PathBuf) -> Self {
|
||
|
|
self.config.tls = Some(TlsConfig {
|
||
|
|
certificate: cert_path,
|
||
|
|
private_key: key_path,
|
||
|
|
});
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn logging_config(mut self, config: LoggingConfig) -> Self {
|
||
|
|
self.config.logging = Some(config);
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn security_config(mut self, config: SecurityConfig) -> Self {
|
||
|
|
self.config.security = Some(config);
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn error_pages_config(mut self, config: ErrorPagesConfig) -> Self {
|
||
|
|
self.config.error_pages = Some(config);
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
|
||
|
|
self.config.headers = Some(headers);
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn from_file<P: AsRef<std::path::Path>>(mut self, path: P) -> crate::Result<Self> {
|
||
|
|
let file_config = GurtConfig::from_file(path)?;
|
||
|
|
self.config = merge_configs(file_config, self.config);
|
||
|
|
Ok(self)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn merge_cli_args(mut self, cli_args: &crate::cli::ServeCommand) -> Self {
|
||
|
|
self.config.server.host = cli_args.host.clone();
|
||
|
|
self.config.server.port = cli_args.port;
|
||
|
|
self.config.server.base_directory = Arc::new(cli_args.dir.clone());
|
||
|
|
self.config.server.verbose = cli_args.verbose;
|
||
|
|
|
||
|
|
if let (Some(cert), Some(key)) = (&cli_args.cert, &cli_args.key) {
|
||
|
|
self.config.tls = Some(TlsConfig {
|
||
|
|
certificate: cert.clone(),
|
||
|
|
private_key: key.clone(),
|
||
|
|
});
|
||
|
|
}
|
||
|
|
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn build(self) -> crate::Result<GurtConfig> {
|
||
|
|
let config = self.config;
|
||
|
|
|
||
|
|
if !config.server.base_directory.exists() || !config.server.base_directory.is_dir() {
|
||
|
|
return Err(crate::ServerError::InvalidConfiguration(
|
||
|
|
format!("Invalid base directory: {}", config.server.base_directory.display())
|
||
|
|
));
|
||
|
|
}
|
||
|
|
|
||
|
|
if let Some(tls) = &config.tls {
|
||
|
|
if !tls.certificate.exists() {
|
||
|
|
return Err(crate::ServerError::TlsConfiguration(
|
||
|
|
format!("Certificate file does not exist: {}", tls.certificate.display())
|
||
|
|
));
|
||
|
|
}
|
||
|
|
if !tls.private_key.exists() {
|
||
|
|
return Err(crate::ServerError::TlsConfiguration(
|
||
|
|
format!("Private key file does not exist: {}", tls.private_key.display())
|
||
|
|
));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(config)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
fn parse_size(size_str: &str) -> crate::Result<u64> {
|
||
|
|
let size_str = size_str.trim().to_uppercase();
|
||
|
|
|
||
|
|
if let Some(captures) = regex::Regex::new(r"^(\d+(?:\.\d+)?)\s*([KMGT]?B?)$").unwrap().captures(&size_str) {
|
||
|
|
let number: f64 = captures[1].parse()
|
||
|
|
.map_err(|_| crate::ServerError::InvalidConfiguration(format!("Invalid size format: {}", size_str)))?;
|
||
|
|
|
||
|
|
let unit = captures.get(2).map_or("", |m| m.as_str());
|
||
|
|
|
||
|
|
let multiplier: u64 = match unit {
|
||
|
|
"" | "B" => 1,
|
||
|
|
"KB" => 1_000,
|
||
|
|
"MB" => 1_000_000,
|
||
|
|
"GB" => 1_000_000_000,
|
||
|
|
"TB" => 1_000_000_000_000,
|
||
|
|
_ => return Err(crate::ServerError::InvalidConfiguration(format!("Unknown size unit: {}", unit))),
|
||
|
|
};
|
||
|
|
let number = (number * multiplier as f64) as u64;
|
||
|
|
Ok(number)
|
||
|
|
} else {
|
||
|
|
Err(crate::ServerError::InvalidConfiguration(format!("Invalid size format: {}", size_str)))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Tests `path` against one deny pattern.
///
/// Supported forms:
/// - `dir/*` — `dir` itself or anything beneath it. The match stops at a `/`
///   boundary, so `admin/*` does not match `administrator.html`. The special
///   pattern `/*` matches every path.
/// - `*.ext` — any path ending in `.ext`.
/// - anything else — exact string equality.
fn matches_pattern(path: &str, pattern: &str) -> bool {
    if let Some(dir) = pattern.strip_suffix("/*") {
        if dir.is_empty() {
            return true;
        }
        match path.strip_prefix(dir) {
            // Require a separator boundary unless `dir` already ends in one.
            Some(rest) => rest.is_empty() || rest.starts_with('/') || dir.ends_with('/'),
            None => false,
        }
    } else if let Some(suffix) = pattern.strip_prefix('*').filter(|s| s.starts_with('.')) {
        path.ends_with(suffix)
    } else {
        path == pattern
    }
}
|
||
|
|
|
||
|
|
fn merge_configs(base: GurtConfig, override_config: GurtConfig) -> GurtConfig {
|
||
|
|
GurtConfig {
|
||
|
|
server: merge_server_configs(base.server, override_config.server),
|
||
|
|
tls: override_config.tls.or(base.tls),
|
||
|
|
logging: override_config.logging.or(base.logging),
|
||
|
|
security: override_config.security.or(base.security),
|
||
|
|
error_pages: override_config.error_pages.or(base.error_pages),
|
||
|
|
headers: override_config.headers.or(base.headers),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn merge_server_configs(base: ServerConfig, override_config: ServerConfig) -> ServerConfig {
|
||
|
|
ServerConfig {
|
||
|
|
host: if override_config.host != default_host() { override_config.host } else { base.host },
|
||
|
|
port: if override_config.port != default_port() { override_config.port } else { base.port },
|
||
|
|
protocol_version: if override_config.protocol_version != default_protocol_version() {
|
||
|
|
override_config.protocol_version
|
||
|
|
} else {
|
||
|
|
base.protocol_version
|
||
|
|
},
|
||
|
|
alpn_identifier: if override_config.alpn_identifier != default_alpn_identifier() {
|
||
|
|
override_config.alpn_identifier
|
||
|
|
} else {
|
||
|
|
base.alpn_identifier
|
||
|
|
},
|
||
|
|
timeouts: override_config.timeouts.or(base.timeouts),
|
||
|
|
max_connections: if override_config.max_connections != default_max_connections() {
|
||
|
|
override_config.max_connections
|
||
|
|
} else {
|
||
|
|
base.max_connections
|
||
|
|
},
|
||
|
|
max_message_size: if override_config.max_message_size != default_max_message_size() {
|
||
|
|
override_config.max_message_size
|
||
|
|
} else {
|
||
|
|
base.max_message_size
|
||
|
|
},
|
||
|
|
base_directory: override_config.base_directory,
|
||
|
|
verbose: override_config.verbose,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_default_config_creation() {
        let base_dir = PathBuf::from("/tmp");
        let mut config = GurtConfig::default();
        config.server.base_directory = Arc::new(base_dir.clone());

        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 4878);
        assert_eq!(config.server.protocol_version, "1.0.0");
        assert_eq!(config.server.alpn_identifier, "GURT/1.0");
        assert_eq!(*config.server.base_directory, base_dir);
    }

    #[test]
    fn test_config_from_valid_toml() {
        let toml_content = r#"
[server]
host = "0.0.0.0"
port = 8080
protocol_version = "2.0.0"
alpn_identifier = "custom"
max_connections = 1000
max_message_size = "10MB"

[security]
rate_limit_requests = 60
rate_limit_connections = 5
"#;

        let base_dir = PathBuf::from("/tmp");
        let config = GurtConfig::from_toml(toml_content, base_dir.clone()).unwrap();

        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.server.protocol_version, "2.0.0");
        assert_eq!(config.server.alpn_identifier, "custom");
        assert_eq!(config.server.max_connections, 1000);
        assert_eq!(config.server.max_message_size, "10MB");
        assert_eq!(*config.server.base_directory, base_dir);

        let security = config.security.unwrap();
        assert_eq!(security.rate_limit_requests, 60);
        assert_eq!(security.rate_limit_connections, 5);
    }

    #[test]
    fn test_invalid_toml_returns_error() {
        let invalid_toml = r#"
[server
host = "0.0.0.0"
"#;

        let base_dir = PathBuf::from("/tmp");
        let result = GurtConfig::from_toml(invalid_toml, base_dir);

        assert!(result.is_err());
    }

    #[test]
    fn test_max_message_size_parsing() {
        let config = GurtConfig::default();

        assert_eq!(parse_size("1024").unwrap(), 1024);
        assert_eq!(parse_size("1KB").unwrap(), 1000);
        assert_eq!(parse_size("1MB").unwrap(), 1000 * 1000);
        assert_eq!(parse_size("1GB").unwrap(), 1000 * 1000 * 1000);

        assert!(parse_size("invalid").is_err());

        assert!(config.max_message_size_bytes().is_ok());
    }

    #[test]
    fn test_tls_config_validation() {
        // Use a directory that certainly exists so validation reaches the TLS checks.
        let mut config = GurtConfig::default_with_directory(std::env::temp_dir());
        assert!(config.validate().is_ok());

        config.tls = Some(TlsConfig {
            certificate: PathBuf::from("/nonexistent/cert.pem"),
            private_key: PathBuf::from("/nonexistent/key.pem"),
        });

        assert!(matches!(
            config.validate(),
            Err(crate::ServerError::TlsConfiguration(_))
        ));
    }

    #[test]
    fn test_validate_rejects_missing_base_directory() {
        let config =
            GurtConfig::default_with_directory(PathBuf::from("/nonexistent/gurt-base-dir"));
        assert!(matches!(
            config.validate(),
            Err(crate::ServerError::InvalidConfiguration(_))
        ));
    }

    #[test]
    fn test_address_formatting() {
        let config = GurtConfig::default();
        assert_eq!(config.address(), "127.0.0.1:4878");

        let mut custom_config = GurtConfig::default();
        custom_config.server.host = "0.0.0.0".to_string();
        custom_config.server.port = 8080;
        assert_eq!(custom_config.address(), "0.0.0.0:8080");
    }

    #[test]
    fn test_timeout_getters() {
        let config = GurtConfig::default();

        assert_eq!(config.get_handshake_timeout(), Duration::from_secs(5));
        assert_eq!(config.get_request_timeout(), Duration::from_secs(30));
        assert_eq!(config.get_connection_timeout(), Duration::from_secs(10));
    }

    #[test]
    fn test_should_deny_file() {
        let mut config = GurtConfig::default();
        assert!(!config.should_deny_file(".env"));

        config.security = Some(SecurityConfig {
            deny_files: vec!["*.env".to_string(), "private/*".to_string()],
            allowed_methods: default_allowed_methods(),
            rate_limit_requests: 100,
            rate_limit_connections: 10,
        });

        assert!(config.should_deny_file(".env"));
        assert!(config.should_deny_file("config/prod.env"));
        assert!(config.should_deny_file("private/keys.pem"));
        assert!(!config.should_deny_file("public/index.html"));
    }

    #[test]
    fn test_method_allowed_with_defaults() {
        let config = GurtConfig::default();
        assert!(config.is_method_allowed("GET"));
        assert!(config.is_method_allowed("get"));
        assert!(config.is_method_allowed("PATCH"));
        assert!(!config.is_method_allowed("TRACE"));
    }

    #[test]
    fn test_error_page_lookup() {
        let mut pages = HashMap::new();
        pages.insert("404".to_string(), "<h1>Not Found</h1>".to_string());
        let mut default_pages = HashMap::new();
        default_pages.insert("500".to_string(), "Server Error".to_string());

        let error_pages = ErrorPagesConfig {
            pages,
            default: Some(ErrorPageDefaults { pages: default_pages }),
        };

        assert_eq!(
            error_pages.get_any_page(404).map(String::as_str),
            Some("<h1>Not Found</h1>")
        );
        assert_eq!(
            error_pages.get_any_page(500).map(String::as_str),
            Some("Server Error")
        );
        assert!(error_pages.get_any_page(418).is_none());

        let base = std::path::Path::new(".");
        assert_eq!(
            error_pages.get_page_content(404, base).as_deref(),
            Some("<h1>Not Found</h1>")
        );
        assert_eq!(
            error_pages.get_page_content(500, base).as_deref(),
            Some("Server Error")
        );
        assert!(error_pages.get_page_content(418, base).is_none());
    }
}
|