289 lines
9.1 KiB
Rust
289 lines
9.1 KiB
Rust
|
|
use crate::config::GurtConfig;
|
||
|
|
use gurt::{prelude::*, GurtMethod, GurtStatusCode};
|
||
|
|
use std::collections::HashMap;
|
||
|
|
use std::net::IpAddr;
|
||
|
|
use std::sync::{Arc, Mutex};
|
||
|
|
use std::time::{Duration, Instant};
|
||
|
|
use tracing::{warn, debug};
|
||
|
|
|
||
|
|
#[derive(Debug)]
|
||
|
|
pub struct RateLimitData {
|
||
|
|
requests: Vec<Instant>,
|
||
|
|
connections: u32,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl RateLimitData {
|
||
|
|
fn new() -> Self {
|
||
|
|
Self {
|
||
|
|
requests: Vec::new(),
|
||
|
|
connections: 0,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn cleanup_old_requests(&mut self, window: Duration) {
|
||
|
|
let cutoff = Instant::now() - window;
|
||
|
|
self.requests.retain(|&request_time| request_time > cutoff);
|
||
|
|
}
|
||
|
|
|
||
|
|
fn add_request(&mut self) {
|
||
|
|
self.requests.push(Instant::now());
|
||
|
|
}
|
||
|
|
|
||
|
|
fn request_count(&self) -> usize {
|
||
|
|
self.requests.len()
|
||
|
|
}
|
||
|
|
|
||
|
|
fn increment_connections(&mut self) {
|
||
|
|
self.connections += 1;
|
||
|
|
}
|
||
|
|
|
||
|
|
fn decrement_connections(&mut self) {
|
||
|
|
if self.connections > 0 {
|
||
|
|
self.connections -= 1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn connection_count(&self) -> u32 {
|
||
|
|
self.connections
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub struct SecurityMiddleware {
|
||
|
|
config: Arc<GurtConfig>,
|
||
|
|
rate_limit_data: Arc<Mutex<HashMap<IpAddr, RateLimitData>>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl SecurityMiddleware {
|
||
|
|
pub fn new(config: Arc<GurtConfig>) -> Self {
|
||
|
|
Self {
|
||
|
|
config,
|
||
|
|
rate_limit_data: Arc::new(Mutex::new(HashMap::new())),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn is_method_allowed(&self, method: &GurtMethod) -> bool {
|
||
|
|
if let Some(security) = &self.config.security {
|
||
|
|
let method_str = method.to_string();
|
||
|
|
security.allowed_methods.contains(&method_str)
|
||
|
|
} else {
|
||
|
|
true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn check_rate_limit(&self, client_ip: IpAddr) -> bool {
|
||
|
|
if let Some(security) = &self.config.security {
|
||
|
|
let mut data = self.rate_limit_data.lock().unwrap();
|
||
|
|
let rate_data = data.entry(client_ip).or_insert_with(RateLimitData::new);
|
||
|
|
|
||
|
|
rate_data.cleanup_old_requests(Duration::from_secs(60));
|
||
|
|
|
||
|
|
if rate_data.request_count() >= security.rate_limit_requests as usize {
|
||
|
|
warn!("Rate limit exceeded for IP {}: {} requests in the last minute",
|
||
|
|
client_ip, rate_data.request_count());
|
||
|
|
return false;
|
||
|
|
}
|
||
|
|
|
||
|
|
rate_data.add_request();
|
||
|
|
debug!("Request from {}: {}/{} requests in the last minute",
|
||
|
|
client_ip, rate_data.request_count(), security.rate_limit_requests);
|
||
|
|
}
|
||
|
|
|
||
|
|
true
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn check_connection_limit(&self, client_ip: IpAddr) -> bool {
|
||
|
|
if let Some(security) = &self.config.security {
|
||
|
|
let mut data = self.rate_limit_data.lock().unwrap();
|
||
|
|
let rate_data = data.entry(client_ip).or_insert_with(RateLimitData::new);
|
||
|
|
|
||
|
|
if rate_data.connection_count() >= security.rate_limit_connections {
|
||
|
|
warn!("Connection limit exceeded for IP {}: {} concurrent connections",
|
||
|
|
client_ip, rate_data.connection_count());
|
||
|
|
return false;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
true
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn register_connection(&self, client_ip: IpAddr) {
|
||
|
|
if self.config.security.is_some() {
|
||
|
|
let mut data = self.rate_limit_data.lock().unwrap();
|
||
|
|
let rate_data = data.entry(client_ip).or_insert_with(RateLimitData::new);
|
||
|
|
rate_data.increment_connections();
|
||
|
|
debug!("Connection registered for {}: {} concurrent connections",
|
||
|
|
client_ip, rate_data.connection_count());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn unregister_connection(&self, client_ip: IpAddr) {
|
||
|
|
if self.config.security.is_some() {
|
||
|
|
let mut data = self.rate_limit_data.lock().unwrap();
|
||
|
|
if let Some(rate_data) = data.get_mut(&client_ip) {
|
||
|
|
rate_data.decrement_connections();
|
||
|
|
debug!("Connection unregistered for {}: {} concurrent connections remaining",
|
||
|
|
client_ip, rate_data.connection_count());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn create_method_not_allowed_response(&self) -> std::result::Result<GurtResponse, GurtError> {
|
||
|
|
let response = GurtResponse::new(GurtStatusCode::MethodNotAllowed)
|
||
|
|
.with_header("Content-Type", "text/html");
|
||
|
|
Ok(response)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn create_rate_limit_response(&self) -> std::result::Result<GurtResponse, GurtError> {
|
||
|
|
let response = GurtResponse::new(GurtStatusCode::TooManyRequests)
|
||
|
|
.with_header("Content-Type", "text/html")
|
||
|
|
.with_header("Retry-After", "60");
|
||
|
|
Ok(response)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
|
||
|
|
mod tests {
|
||
|
|
use super::*;
|
||
|
|
use std::net::IpAddr;
|
||
|
|
use std::sync::Arc;
|
||
|
|
use std::time::Duration;
|
||
|
|
|
||
|
|
fn create_test_config() -> Arc<GurtConfig> {
|
||
|
|
let mut config = crate::config::GurtConfig::default();
|
||
|
|
config.security = Some(crate::config::SecurityConfig {
|
||
|
|
deny_files: vec!["*.secret".to_string(), "private/*".to_string()],
|
||
|
|
allowed_methods: vec!["GET".to_string(), "POST".to_string()],
|
||
|
|
rate_limit_requests: 5,
|
||
|
|
rate_limit_connections: 2,
|
||
|
|
});
|
||
|
|
Arc::new(config)
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_rate_limit_data_initialization() {
|
||
|
|
let data = RateLimitData::new();
|
||
|
|
|
||
|
|
assert_eq!(data.request_count(), 0);
|
||
|
|
assert_eq!(data.connection_count(), 0);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_rate_limit_data_request_tracking() {
|
||
|
|
let mut data = RateLimitData::new();
|
||
|
|
|
||
|
|
data.add_request();
|
||
|
|
data.add_request();
|
||
|
|
assert_eq!(data.request_count(), 2);
|
||
|
|
|
||
|
|
data.cleanup_old_requests(Duration::from_secs(0));
|
||
|
|
assert_eq!(data.request_count(), 0);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_rate_limit_data_connection_tracking() {
|
||
|
|
let mut data = RateLimitData::new();
|
||
|
|
|
||
|
|
data.increment_connections();
|
||
|
|
data.increment_connections();
|
||
|
|
assert_eq!(data.connection_count(), 2);
|
||
|
|
|
||
|
|
data.decrement_connections();
|
||
|
|
assert_eq!(data.connection_count(), 1);
|
||
|
|
|
||
|
|
data.decrement_connections();
|
||
|
|
data.decrement_connections();
|
||
|
|
assert_eq!(data.connection_count(), 0);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_security_middleware_initialization() {
|
||
|
|
let config = create_test_config();
|
||
|
|
let middleware = SecurityMiddleware::new(config.clone());
|
||
|
|
|
||
|
|
assert!(middleware.rate_limit_data.lock().unwrap().is_empty());
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_connection_tracking() {
|
||
|
|
let config = create_test_config();
|
||
|
|
let middleware = SecurityMiddleware::new(config.clone());
|
||
|
|
let ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||
|
|
|
||
|
|
middleware.register_connection(ip);
|
||
|
|
{
|
||
|
|
let data = middleware.rate_limit_data.lock().unwrap();
|
||
|
|
assert_eq!(data.get(&ip).unwrap().connection_count(), 1);
|
||
|
|
}
|
||
|
|
|
||
|
|
middleware.unregister_connection(ip);
|
||
|
|
{
|
||
|
|
let data = middleware.rate_limit_data.lock().unwrap();
|
||
|
|
assert_eq!(data.get(&ip).unwrap().connection_count(), 0);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_rate_limiting_requests() {
|
||
|
|
let config = create_test_config();
|
||
|
|
let middleware = SecurityMiddleware::new(config.clone());
|
||
|
|
let ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||
|
|
|
||
|
|
for _ in 0..5 {
|
||
|
|
assert!(middleware.check_rate_limit(ip));
|
||
|
|
}
|
||
|
|
|
||
|
|
assert!(!middleware.check_rate_limit(ip));
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_connection_limiting() {
|
||
|
|
let config = create_test_config();
|
||
|
|
let middleware = SecurityMiddleware::new(config.clone());
|
||
|
|
let ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||
|
|
|
||
|
|
middleware.register_connection(ip);
|
||
|
|
middleware.register_connection(ip);
|
||
|
|
|
||
|
|
assert!(!middleware.check_connection_limit(ip));
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_method_validation() {
|
||
|
|
let config = create_test_config();
|
||
|
|
let middleware = SecurityMiddleware::new(config.clone());
|
||
|
|
|
||
|
|
assert!(middleware.is_method_allowed(&GurtMethod::GET));
|
||
|
|
assert!(middleware.is_method_allowed(&GurtMethod::POST));
|
||
|
|
|
||
|
|
assert!(!middleware.is_method_allowed(&GurtMethod::PUT));
|
||
|
|
assert!(!middleware.is_method_allowed(&GurtMethod::DELETE));
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_multiple_ips_isolation() {
|
||
|
|
let config = create_test_config();
|
||
|
|
let middleware = SecurityMiddleware::new(config.clone());
|
||
|
|
let ip1: IpAddr = "127.0.0.1".parse().unwrap();
|
||
|
|
let ip2: IpAddr = "127.0.0.2".parse().unwrap();
|
||
|
|
|
||
|
|
for _ in 0..6 {
|
||
|
|
middleware.check_rate_limit(ip1);
|
||
|
|
}
|
||
|
|
|
||
|
|
assert!(middleware.check_rate_limit(ip2));
|
||
|
|
assert!(!middleware.check_rate_limit(ip1));
|
||
|
|
}
|
||
|
|
|
||
|
|
#[test]
|
||
|
|
fn test_response_creation() {
|
||
|
|
let config = create_test_config();
|
||
|
|
let middleware = SecurityMiddleware::new(config.clone());
|
||
|
|
|
||
|
|
let response = middleware.create_method_not_allowed_response().unwrap();
|
||
|
|
assert_eq!(response.status_code, 405);
|
||
|
|
|
||
|
|
let response = middleware.create_rate_limit_response().unwrap();
|
||
|
|
assert_eq!(response.status_code, 429);
|
||
|
|
}
|
||
|
|
}
|