i didnt test this a lot but i added progress bars to GURT:// based downloads 🎉

This commit is contained in:
Gabriella Casap
2025-09-26 14:49:22 +01:00
parent 5e4c381b84
commit 04434f176b
6 changed files with 448 additions and 88 deletions

View File

@@ -527,6 +527,139 @@ impl GurtClient {
Ok((host, port, path))
}
pub async fn stream_request<HeadCb, ChunkCb>(&self,
host: &str,
port: u16,
mut request: GurtRequest,
mut on_head: HeadCb,
mut on_chunk: ChunkCb,
) -> Result<()>
where
HeadCb: FnMut(&crate::message::GurtResponseHead) + Send,
ChunkCb: FnMut(&[u8]) -> bool + Send,
{
let resolved_host = self.resolve_domain(host).await?;
request = request.with_header("Host", host);
let mut tls_stream = self.get_pooled_connection(&resolved_host, port, Some(host)).await?;
let request_data = request.to_string();
tls_stream.write_all(request_data.as_bytes()).await
.map_err(|e| GurtError::connection(format!("Failed to write request: {}", e)))?;
let mut buffer: Vec<u8> = Vec::new();
let mut temp_buffer = [0u8; 8192];
let start_time = std::time::Instant::now();
let mut headers_parsed = false;
let mut expected_body_length: Option<usize> = None;
let mut headers_end_pos: Option<usize> = None;
let mut head_emitted = false;
let mut delivered: usize = 0;
loop {
if start_time.elapsed() > self.config.request_timeout {
return Err(GurtError::timeout("Request timeout"));
}
match timeout(Duration::from_millis(400), tls_stream.read(&mut temp_buffer)).await {
Ok(Ok(0)) => {
if headers_parsed && !head_emitted {
let head = crate::message::GurtResponseHead {
version: String::new(),
status_code: 0,
status_message: String::new(),
headers: std::collections::HashMap::new(),
};
on_head(&head);
head_emitted = true;
}
break;
}
Ok(Ok(n)) => {
buffer.extend_from_slice(&temp_buffer[..n]);
if !headers_parsed {
if let Some(pos) = buffer.windows(4).position(|w| w == b"\r\n\r\n") {
headers_end_pos = Some(pos + 4);
headers_parsed = true;
let headers_section = std::str::from_utf8(&buffer[..pos])
.map_err(|e| GurtError::invalid_message(format!("Invalid UTF-8 in headers: {}", e)))?;
let mut lines = headers_section.split("\r\n");
let status_line = lines.next().unwrap_or("");
let parts: Vec<&str> = status_line.splitn(3, ' ').collect();
let mut version = String::new();
let mut status_code: u16 = 0;
let mut status_message = String::new();
if parts.len() >= 2 {
version = parts[0].to_string();
status_code = parts[1].parse().unwrap_or(0);
if parts.len() > 2 { status_message = parts[2].to_string(); }
}
let mut headers = std::collections::HashMap::new();
for line in lines {
if line.is_empty() { break; }
if let Some(colon) = line.find(':') {
let key = line[..colon].trim().to_lowercase();
let value = line[colon+1..].trim().to_string();
if key == "content-length" { expected_body_length = value.parse().ok(); }
headers.insert(key, value);
}
}
let head = crate::message::GurtResponseHead {
version,
status_code,
status_message,
headers,
};
on_head(&head);
head_emitted = true;
if let Some(end) = headers_end_pos {
if buffer.len() > end {
let body_slice = &buffer[end..];
if !on_chunk(body_slice) {
return Err(GurtError::Cancelled);
}
delivered = body_slice.len();
}
}
}
} else {
if let Some(end) = headers_end_pos {
let available = buffer.len().saturating_sub(end + delivered);
if available > 0 {
let start = end + delivered;
let end_pos = end + delivered + available;
if !on_chunk(&buffer[start..end_pos]) {
return Err(GurtError::Cancelled);
}
delivered += available;
}
if let Some(expected_len) = expected_body_length {
if delivered >= expected_len { break; }
}
}
}
}
Ok(Err(e)) => return Err(GurtError::connection(format!("Read error: {}", e))),
Err(_) => continue,
}
}
if let (Some(end), Some(expected_len)) = (headers_end_pos, expected_body_length) {
if delivered >= expected_len {
self.return_connection_to_pool(&resolved_host, port, tls_stream);
}
}
Ok(())
}
async fn resolve_domain(&self, domain: &str) -> Result<String> {
match self.dns_cache.lock() {
@@ -732,4 +865,4 @@ mod tests {
assert!(handshake_request.headers.contains_key("user-agent"));
}
}
}

View File

@@ -6,7 +6,7 @@ pub mod error;
pub mod message;
pub use error::{GurtError, Result};
pub use message::{GurtMessage, GurtRequest, GurtResponse, GurtMethod};
pub use message::{GurtMessage, GurtRequest, GurtResponse, GurtResponseHead, GurtMethod};
pub use protocol::{GurtStatusCode, GURT_VERSION, DEFAULT_PORT};
pub use crypto::{CryptoManager, TlsConfig, GURT_ALPN, TLS_VERSION};
pub use server::{GurtServer, GurtHandler, ServerContext, Route};
@@ -15,7 +15,7 @@ pub use client::{GurtClient, GurtClientConfig};
pub mod prelude {
pub use crate::{
GurtError, Result,
GurtMessage, GurtRequest, GurtResponse,
GurtMessage, GurtRequest, GurtResponse, GurtResponseHead,
GURT_VERSION, DEFAULT_PORT,
CryptoManager, TlsConfig, GURT_ALPN, TLS_VERSION,
GurtServer, GurtHandler, ServerContext, Route,

View File

@@ -225,6 +225,14 @@ pub struct GurtResponse {
pub body: Vec<u8>,
}
#[derive(Debug, Clone)]
pub struct GurtResponseHead {
pub version: String,
pub status_code: u16,
pub status_message: String,
pub headers: GurtHeaders,
}
impl GurtResponse {
pub fn new(status_code: GurtStatusCode) -> Self {
Self {