Files
leonwww/protocol/library/src/message.rs

565 lines
18 KiB
Rust

use crate::{GurtError, Result, GURT_VERSION};
use crate::protocol::{GurtStatusCode, PROTOCOL_PREFIX, HEADER_SEPARATOR, BODY_SEPARATOR};
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use std::fmt;
use chrono::Utc;
/// Request methods that may appear on a GURT request line.
///
/// Parsed case-insensitively by [`GurtMethod::parse`] and rendered as the upper-case
/// method name by its `Display` impl.
// NOTE(review): variant order is kept as-is; serde formats that encode enums by variant
// index would change their wire output if the variants were reordered.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GurtMethod {
GET,
POST,
PUT,
DELETE,
HEAD,
OPTIONS,
PATCH,
/// Handshake method; its protocol semantics are defined outside this file.
HANDSHAKE,
}
impl GurtMethod {
pub fn parse(s: &str) -> Result<Self> {
match s.to_uppercase().as_str() {
"GET" => Ok(Self::GET),
"POST" => Ok(Self::POST),
"PUT" => Ok(Self::PUT),
"DELETE" => Ok(Self::DELETE),
"HEAD" => Ok(Self::HEAD),
"OPTIONS" => Ok(Self::OPTIONS),
"PATCH" => Ok(Self::PATCH),
"HANDSHAKE" => Ok(Self::HANDSHAKE),
_ => Err(GurtError::invalid_message(format!("Unsupported method: {}", s))),
}
}
}
impl fmt::Display for GurtMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Self::GET => "GET",
Self::POST => "POST",
Self::PUT => "PUT",
Self::DELETE => "DELETE",
Self::HEAD => "HEAD",
Self::OPTIONS => "OPTIONS",
Self::PATCH => "PATCH",
Self::HANDSHAKE => "HANDSHAKE",
};
write!(f, "{}", s)
}
}
/// Header map from header name to value.
///
/// The builders (`with_header`) and parsers in this module store keys lowercased. Parsers
/// also trim surrounding whitespace from keys and values.
pub type GurtHeaders = HashMap<String, String>;
/// A GURT request: method, target path, protocol version, headers and raw body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GurtRequest {
/// Request method from the request line.
pub method: GurtMethod,
/// Request target exactly as written on the request line.
pub path: String,
/// Protocol version with `PROTOCOL_PREFIX` stripped (e.g. `"1.0.0"`).
pub version: String,
/// Request headers; keys are lowercased when set via `with_header` or a parser.
pub headers: GurtHeaders,
/// Raw body bytes; not guaranteed to be UTF-8.
pub body: Vec<u8>,
}
impl GurtRequest {
pub fn new(method: GurtMethod, path: String) -> Self {
Self {
method,
path,
version: GURT_VERSION.to_string(),
headers: GurtHeaders::new(),
body: Vec::new(),
}
}
pub fn with_header<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
self.headers.insert(key.into().to_lowercase(), value.into());
self
}
pub fn with_body<B: Into<Vec<u8>>>(mut self, body: B) -> Self {
self.body = body.into();
self
}
pub fn with_string_body<S: AsRef<str>>(mut self, body: S) -> Self {
self.body = body.as_ref().as_bytes().to_vec();
self
}
pub fn header(&self, key: &str) -> Option<&String> {
self.headers.get(&key.to_lowercase())
}
pub fn text(&self) -> Result<String> {
std::str::from_utf8(&self.body)
.map(|s| s.to_string())
.map_err(|e| GurtError::invalid_message(format!("Invalid UTF-8 body: {}", e)))
}
pub fn parse(data: &str) -> Result<Self> {
Self::parse_bytes(data.as_bytes())
}
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
let body_separator = BODY_SEPARATOR.as_bytes();
let body_separator_pos = data.windows(body_separator.len())
.position(|window| window == body_separator);
let (headers_section, body) = if let Some(pos) = body_separator_pos {
let headers_part = &data[..pos];
let body_part = &data[pos + body_separator.len()..];
(headers_part, body_part.to_vec())
} else {
(data, Vec::new())
};
let headers_str = std::str::from_utf8(headers_section)
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?;
let lines: Vec<&str> = headers_str.split(HEADER_SEPARATOR).collect();
if lines.is_empty() {
return Err(GurtError::invalid_message("Empty request"));
}
let request_line = lines[0];
let parts: Vec<&str> = request_line.split_whitespace().collect();
if parts.len() != 3 {
return Err(GurtError::invalid_message("Invalid request line format"));
}
let method = GurtMethod::parse(parts[0])?;
let path = parts[1].to_string();
if !parts[2].starts_with(PROTOCOL_PREFIX) {
return Err(GurtError::invalid_message("Invalid protocol identifier"));
}
let version_str = &parts[2][PROTOCOL_PREFIX.len()..];
let version = version_str.to_string();
let mut headers = GurtHeaders::new();
for line in lines.iter().skip(1) {
if line.is_empty() {
break;
}
if let Some(colon_pos) = line.find(':') {
let key = line[..colon_pos].trim().to_lowercase();
let value = line[colon_pos + 1..].trim().to_string();
headers.insert(key, value);
}
}
Ok(Self {
method,
path,
version,
headers,
body,
})
}
pub fn to_string(&self) -> String {
let mut message = format!("{} {} {}{}{}",
self.method, self.path, PROTOCOL_PREFIX, self.version, HEADER_SEPARATOR);
let mut headers = self.headers.clone();
if !headers.contains_key("content-length") {
headers.insert("content-length".to_string(), self.body.len().to_string());
}
if !headers.contains_key("user-agent") {
headers.insert("user-agent".to_string(), format!("GURT-Client/{}", GURT_VERSION));
}
for (key, value) in &headers {
message.push_str(&format!("{}: {}{}", key, value, HEADER_SEPARATOR));
}
message.push_str(HEADER_SEPARATOR);
if !self.body.is_empty() {
if let Ok(body_str) = std::str::from_utf8(&self.body) {
message.push_str(body_str);
}
}
message
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut message = format!("{} {} {}{}{}",
self.method, self.path, PROTOCOL_PREFIX, self.version, HEADER_SEPARATOR);
let mut headers = self.headers.clone();
if !headers.contains_key("content-length") {
headers.insert("content-length".to_string(), self.body.len().to_string());
}
if !headers.contains_key("user-agent") {
headers.insert("user-agent".to_string(), format!("GURT-Client/{}", GURT_VERSION));
}
for (key, value) in &headers {
message.push_str(&format!("{}: {}{}", key, value, HEADER_SEPARATOR));
}
message.push_str(HEADER_SEPARATOR);
let mut bytes = message.into_bytes();
bytes.extend_from_slice(&self.body);
bytes
}
}
/// A GURT response: protocol version, status, headers and raw body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GurtResponse {
/// Protocol version with `PROTOCOL_PREFIX` stripped (e.g. `"1.0.0"`).
pub version: String,
/// Numeric status code from the status line.
pub status_code: u16,
/// Reason phrase from the status line, or the canonical message for the code.
pub status_message: String,
/// Response headers; keys are lowercased when set via `with_header` or a parser.
pub headers: GurtHeaders,
/// Raw body bytes; not guaranteed to be UTF-8.
pub body: Vec<u8>,
}
/// The status line and headers of a response, without a body.
// NOTE(review): not constructed anywhere in this file; presumably used where only the head
// has been read (e.g. streaming) — confirm with callers.
#[derive(Debug, Clone)]
pub struct GurtResponseHead {
/// Protocol version with `PROTOCOL_PREFIX` stripped.
pub version: String,
/// Numeric status code.
pub status_code: u16,
/// Reason phrase.
pub status_message: String,
/// Response headers.
pub headers: GurtHeaders,
}
impl GurtResponse {
pub fn new(status_code: GurtStatusCode) -> Self {
Self {
version: GURT_VERSION.to_string(),
status_code: status_code as u16,
status_message: status_code.message().to_string(),
headers: GurtHeaders::new(),
body: Vec::new(),
}
}
pub fn ok() -> Self {
Self::new(GurtStatusCode::Ok)
}
pub fn not_found() -> Self {
Self::new(GurtStatusCode::NotFound)
}
pub fn bad_request() -> Self {
Self::new(GurtStatusCode::BadRequest)
}
pub fn forbidden() -> Self {
Self::new(GurtStatusCode::Forbidden)
}
pub fn internal_server_error() -> Self {
Self::new(GurtStatusCode::InternalServerError)
}
pub fn with_header<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
self.headers.insert(key.into().to_lowercase(), value.into());
self
}
pub fn with_body<B: Into<Vec<u8>>>(mut self, body: B) -> Self {
self.body = body.into();
self
}
pub fn with_string_body<S: AsRef<str>>(mut self, body: S) -> Self {
self.body = body.as_ref().as_bytes().to_vec();
self
}
pub fn with_json_body<T: Serialize>(mut self, data: &T) -> Result<Self> {
let json = serde_json::to_string(data)?;
self.body = json.into_bytes();
self.headers.insert("content-type".to_string(), "application/json".to_string());
Ok(self)
}
pub fn header(&self, key: &str) -> Option<&String> {
self.headers.get(&key.to_lowercase())
}
pub fn text(&self) -> Result<String> {
std::str::from_utf8(&self.body)
.map(|s| s.to_owned())
.map_err(|e| GurtError::invalid_message(format!("Invalid UTF-8 body: {}", e)))
}
pub fn is_success(&self) -> bool {
self.status_code >= 200 && self.status_code < 300
}
pub fn is_client_error(&self) -> bool {
self.status_code >= 400 && self.status_code < 500
}
pub fn is_server_error(&self) -> bool {
self.status_code >= 500
}
pub fn parse(data: &str) -> Result<Self> {
Self::parse_bytes(data.as_bytes())
}
pub fn parse_bytes(data: &[u8]) -> Result<Self> {
let body_separator = BODY_SEPARATOR.as_bytes();
let body_separator_pos = data.windows(body_separator.len())
.position(|window| window == body_separator);
let (headers_section, body) = if let Some(pos) = body_separator_pos {
let headers_part = &data[..pos];
let body_part = &data[pos + body_separator.len()..];
(headers_part, body_part.to_vec())
} else {
(data, Vec::new())
};
let headers_str = std::str::from_utf8(headers_section)
.map_err(|_| GurtError::invalid_message("Invalid UTF-8 in headers"))?;
let lines: Vec<&str> = headers_str.split(HEADER_SEPARATOR).collect();
if lines.is_empty() {
return Err(GurtError::invalid_message("Empty response"));
}
let status_line = lines[0];
let parts: Vec<&str> = status_line.splitn(3, ' ').collect();
if parts.len() < 2 {
return Err(GurtError::invalid_message("Invalid status line format"));
}
if !parts[0].starts_with(PROTOCOL_PREFIX) {
return Err(GurtError::invalid_message("Invalid protocol identifier"));
}
let version_str = &parts[0][PROTOCOL_PREFIX.len()..];
let version = version_str.to_string();
let status_code: u16 = parts[1].parse()
.map_err(|_| GurtError::invalid_message("Invalid status code"))?;
let status_message = if parts.len() > 2 {
parts[2].to_string()
} else {
GurtStatusCode::from_u16(status_code)
.map(|sc| sc.message().to_string())
.unwrap_or_else(|| "Unknown".to_string())
};
let mut headers = GurtHeaders::new();
for line in lines.iter().skip(1) {
if line.is_empty() {
break;
}
if let Some(colon_pos) = line.find(':') {
let key = line[..colon_pos].trim().to_lowercase();
let value = line[colon_pos + 1..].trim().to_string();
headers.insert(key, value);
}
}
Ok(Self {
version,
status_code,
status_message,
headers,
body,
})
}
pub fn to_string(&self) -> String {
let mut message = format!("{}{} {} {}{}",
PROTOCOL_PREFIX, self.version, self.status_code, self.status_message, HEADER_SEPARATOR);
let mut headers = self.headers.clone();
if !headers.contains_key("content-length") {
headers.insert("content-length".to_string(), self.body.len().to_string());
}
if !headers.contains_key("server") {
headers.insert("server".to_string(), format!("GURT/{}", GURT_VERSION));
}
if !headers.contains_key("date") {
let now = Utc::now();
let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
headers.insert("date".to_string(), date_str);
}
for (key, value) in &headers {
message.push_str(&format!("{}: {}{}", key, value, HEADER_SEPARATOR));
}
message.push_str(HEADER_SEPARATOR);
if !self.body.is_empty() {
if let Ok(body_str) = std::str::from_utf8(&self.body) {
message.push_str(body_str);
}
}
message
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut message = format!("{}{} {} {}{}",
PROTOCOL_PREFIX, self.version, self.status_code, self.status_message, HEADER_SEPARATOR);
let mut headers = self.headers.clone();
if !headers.contains_key("content-length") {
headers.insert("content-length".to_string(), self.body.len().to_string());
}
if !headers.contains_key("server") {
headers.insert("server".to_string(), format!("GURT/{}", GURT_VERSION));
}
if !headers.contains_key("date") {
let now = Utc::now();
let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
headers.insert("date".to_string(), date_str);
}
for (key, value) in &headers {
message.push_str(&format!("{}: {}{}", key, value, HEADER_SEPARATOR));
}
message.push_str(HEADER_SEPARATOR);
let mut bytes = message.into_bytes();
bytes.extend_from_slice(&self.body);
bytes
}
}
/// Either side of a GURT exchange, as produced by [`GurtMessage::parse_bytes`].
#[derive(Debug, Clone)]
pub enum GurtMessage {
/// A parsed request.
Request(GurtRequest),
/// A parsed response.
Response(GurtResponse),
}
impl GurtMessage {
    /// Parses a textual message; see [`GurtMessage::parse_bytes`].
    pub fn parse(data: &str) -> Result<Self> {
        Self::parse_bytes(data.as_bytes())
    }

    /// Parses raw bytes as either a response or a request.
    ///
    /// Only the first line (up to the first `HEADER_SEPARATOR`, or the whole input) is
    /// inspected. If it begins with `PROTOCOL_PREFIX` it is treated as a status line and the
    /// input is parsed as a response; otherwise it is parsed as a request.
    ///
    /// # Errors
    /// Fails if the first line is not UTF-8, or propagates the chosen parser's error.
    pub fn parse_bytes(data: &[u8]) -> Result<Self> {
        let separator = HEADER_SEPARATOR.as_bytes();
        let end = match data.windows(separator.len()).position(|w| w == separator) {
            Some(idx) => idx,
            None => data.len(),
        };

        let head_line = match std::str::from_utf8(&data[..end]) {
            Ok(line) => line,
            Err(_) => return Err(GurtError::invalid_message("Invalid UTF-8 in first line")),
        };

        let message = if head_line.starts_with(PROTOCOL_PREFIX) {
            GurtMessage::Response(GurtResponse::parse_bytes(data)?)
        } else {
            GurtMessage::Request(GurtRequest::parse_bytes(data)?)
        };
        Ok(message)
    }

    /// Returns `true` if this is a request.
    pub fn is_request(&self) -> bool {
        self.as_request().is_some()
    }

    /// Returns `true` if this is a response.
    pub fn is_response(&self) -> bool {
        self.as_response().is_some()
    }

    /// Borrows the inner request, if any.
    pub fn as_request(&self) -> Option<&GurtRequest> {
        if let GurtMessage::Request(request) = self {
            Some(request)
        } else {
            None
        }
    }

    /// Borrows the inner response, if any.
    pub fn as_response(&self) -> Option<&GurtResponse> {
        if let GurtMessage::Response(response) = self {
            Some(response)
        } else {
            None
        }
    }
}
impl fmt::Display for GurtMessage {
    /// Writes the wire-format text of the wrapped request or response, as produced by its
    /// `to_string` method.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = match self {
            GurtMessage::Request(request) => request.to_string(),
            GurtMessage::Response(response) => response.to_string(),
        };
        f.write_str(&rendered)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_parsing() {
        let raw = "GET /test GURT/1.0.0\r\nHost: example.com\r\nAccept: text/html\r\n\r\ntest body";
        let request = GurtRequest::parse(raw).expect("Failed to parse request");
        assert_eq!(request.method, GurtMethod::GET);
        assert_eq!(request.path, "/test");
        assert_eq!(request.version, GURT_VERSION.to_string());
        assert_eq!(request.header("host"), Some(&"example.com".to_string()));
        assert_eq!(request.header("accept"), Some(&"text/html".to_string()));
        assert_eq!(request.text().unwrap(), "test body");
    }

    #[test]
    fn test_response_parsing() {
        let raw = "GURT/1.0.0 200 OK\r\nContent-Type: text/html\r\n\r\n<html></html>";
        let response = GurtResponse::parse(raw).expect("Failed to parse response");
        assert_eq!(response.version, GURT_VERSION.to_string());
        assert_eq!(response.status_code, 200);
        assert_eq!(response.status_message, "OK");
        assert_eq!(response.header("content-type"), Some(&"text/html".to_string()));
        assert_eq!(response.text().unwrap(), "<html></html>");
    }

    #[test]
    fn test_request_building() {
        let request = GurtRequest::new(GurtMethod::GET, "/test".to_string())
            .with_header("Host", "example.com")
            .with_string_body("test body");
        let raw = request.to_string();
        let parsed = GurtRequest::parse(&raw).expect("Failed to parse built request");
        assert_eq!(parsed.method, request.method);
        assert_eq!(parsed.path, request.path);
        assert_eq!(parsed.body, request.body);
    }

    #[test]
    fn test_response_building() {
        let response = GurtResponse::ok()
            .with_header("Content-Type", "text/html")
            .with_string_body("<html></html>");
        let raw = response.to_string();
        let parsed = GurtResponse::parse(&raw).expect("Failed to parse built response");
        assert_eq!(parsed.status_code, response.status_code);
        assert_eq!(parsed.body, response.body);
    }

    #[test]
    fn test_method_parse_is_case_insensitive() {
        assert_eq!(GurtMethod::parse("post").unwrap(), GurtMethod::POST);
        assert_eq!(GurtMethod::parse("Handshake").unwrap(), GurtMethod::HANDSHAKE);
        assert!(GurtMethod::parse("BREW").is_err());
    }

    #[test]
    fn test_binary_body_round_trip() {
        // Not valid UTF-8: must survive the byte-oriented path untouched.
        let body = vec![0u8, 159, 146, 150, 255];
        let request = GurtRequest::new(GurtMethod::POST, "/upload".to_string())
            .with_body(body.clone());
        let parsed = GurtRequest::parse_bytes(&request.to_bytes())
            .expect("Failed to parse binary request");
        assert_eq!(parsed.method, GurtMethod::POST);
        assert_eq!(parsed.body, body);
        assert_eq!(parsed.header("content-length"), Some(&body.len().to_string()));
        assert_eq!(parsed.header("CONTENT-LENGTH"), Some(&body.len().to_string()));
    }

    #[test]
    fn test_message_dispatch() {
        let response = GurtMessage::parse("GURT/1.0.0 404 Not Found\r\n\r\n")
            .expect("Failed to parse response message");
        assert!(response.is_response());
        assert!(!response.is_request());
        assert_eq!(response.as_response().map(|r| r.status_code), Some(404));
        assert_eq!(
            response.as_response().map(|r| r.status_message.as_str()),
            Some("Not Found")
        );

        let request = GurtMessage::parse("GET / GURT/1.0.0\r\n\r\n")
            .expect("Failed to parse request message");
        assert!(request.is_request());
        assert_eq!(request.as_request().map(|r| r.path.as_str()), Some("/"));
    }

    #[test]
    fn test_rejects_malformed_request_lines() {
        assert!(GurtRequest::parse("").is_err());
        assert!(GurtRequest::parse("GET /only-two-parts\r\n\r\n").is_err());
        assert!(GurtRequest::parse("GET / HTTP/1.1\r\n\r\n").is_err());
        assert!(GurtRequest::parse("FETCH / GURT/1.0.0\r\n\r\n").is_err());
    }

    #[test]
    fn test_status_classification() {
        assert!(GurtResponse::ok().is_success());
        assert!(!GurtResponse::ok().is_client_error());
        assert!(!GurtResponse::ok().is_server_error());
        assert!(GurtResponse::not_found().is_client_error());
        assert!(GurtResponse::internal_server_error().is_server_error());
    }
}