Merge pull request #117 from someoneidoknow/main
support progress bars in `gurt.download` while using GURT://
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "gurtlib"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
authors = ["FaceDev"]
|
||||
license = "MIT"
|
||||
|
||||
@@ -26,6 +26,7 @@ pub struct GurtClientConfig {
|
||||
pub custom_ca_certificates: Vec<String>,
|
||||
pub dns_server_ip: String,
|
||||
pub dns_server_port: u16,
|
||||
pub read_timeout: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
@@ -52,6 +53,7 @@ impl Default for GurtClientConfig {
|
||||
custom_ca_certificates: Vec::new(),
|
||||
dns_server_ip: "135.125.163.131".to_string(),
|
||||
dns_server_port: 4878,
|
||||
read_timeout: Duration::from_secs(5),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -529,6 +531,132 @@ impl GurtClient {
|
||||
|
||||
Ok((host, port, path))
|
||||
}
|
||||
|
||||
pub async fn stream_request<HeadCb, ChunkCb>(&self,
|
||||
host: &str,
|
||||
port: u16,
|
||||
mut request: GurtRequest,
|
||||
mut on_head: HeadCb,
|
||||
mut on_chunk: ChunkCb,
|
||||
) -> Result<()>
|
||||
where
|
||||
HeadCb: FnMut(&crate::message::GurtResponseHead) + Send,
|
||||
ChunkCb: FnMut(&[u8]) -> bool + Send,
|
||||
{
|
||||
let resolved_host = self.resolve_domain(host).await?;
|
||||
request = request.with_header("Host", host);
|
||||
|
||||
let mut tls_stream = self.get_pooled_connection(&resolved_host, port, Some(host)).await?;
|
||||
|
||||
let request_data = request.to_string();
|
||||
tls_stream.write_all(request_data.as_bytes()).await
|
||||
.map_err(|e| GurtError::connection(format!("Failed to write request: {}", e)))?;
|
||||
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
let mut temp_buffer = [0u8; 8192];
|
||||
let start_time = std::time::Instant::now();
|
||||
let mut headers_parsed = false;
|
||||
let mut expected_body_length: Option<usize> = None;
|
||||
let mut headers_end_pos: Option<usize> = None;
|
||||
let mut head_emitted = false;
|
||||
let mut delivered: usize = 0;
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > self.config.request_timeout {
|
||||
return Err(GurtError::timeout("Request timeout"));
|
||||
}
|
||||
|
||||
match timeout(self.config.read_timeout, tls_stream.read(&mut temp_buffer)).await {
|
||||
Ok(Ok(0)) => {
|
||||
if headers_parsed && !head_emitted {
|
||||
return Err(GurtError::connection("Connection closed before response headers were fully received"));
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ok(Ok(n)) => {
|
||||
buffer.extend_from_slice(&temp_buffer[..n]);
|
||||
|
||||
if !headers_parsed {
|
||||
if let Some(pos) = buffer.windows(4).position(|w| w == b"\r\n\r\n") {
|
||||
headers_end_pos = Some(pos + 4);
|
||||
headers_parsed = true;
|
||||
|
||||
let headers_section = std::str::from_utf8(&buffer[..pos])
|
||||
.map_err(|e| GurtError::invalid_message(format!("Invalid UTF-8 in headers: {}", e)))?;
|
||||
|
||||
let mut lines = headers_section.split("\r\n");
|
||||
let status_line = lines.next().unwrap_or("");
|
||||
let parts: Vec<&str> = status_line.splitn(3, ' ').collect();
|
||||
let mut version = String::new();
|
||||
let mut status_code: u16 = 0;
|
||||
let mut status_message = String::new();
|
||||
if parts.len() >= 2 {
|
||||
version = parts[0].to_string();
|
||||
status_code = parts[1].parse().unwrap_or(0);
|
||||
if parts.len() > 2 { status_message = parts[2].to_string(); }
|
||||
}
|
||||
|
||||
let mut headers = std::collections::HashMap::new();
|
||||
for line in lines {
|
||||
if line.is_empty() { break; }
|
||||
if let Some(colon) = line.find(':') {
|
||||
let key = line[..colon].trim().to_lowercase();
|
||||
let value = line[colon+1..].trim().to_string();
|
||||
if key == "content-length" { expected_body_length = value.parse().ok(); }
|
||||
headers.insert(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
let head = crate::message::GurtResponseHead {
|
||||
version,
|
||||
status_code,
|
||||
status_message,
|
||||
headers,
|
||||
};
|
||||
on_head(&head);
|
||||
head_emitted = true;
|
||||
|
||||
if let Some(end) = headers_end_pos {
|
||||
if buffer.len() > end {
|
||||
let body_slice = &buffer[end..];
|
||||
if !on_chunk(body_slice) {
|
||||
return Err(GurtError::Cancelled);
|
||||
}
|
||||
delivered = body_slice.len();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if let Some(end) = headers_end_pos {
|
||||
let available = buffer.len().saturating_sub(end + delivered);
|
||||
if available > 0 {
|
||||
let start = end + delivered;
|
||||
let end_pos = end + delivered + available;
|
||||
if !on_chunk(&buffer[start..end_pos]) {
|
||||
return Err(GurtError::Cancelled);
|
||||
}
|
||||
delivered += available;
|
||||
}
|
||||
|
||||
if let Some(expected_len) = expected_body_length {
|
||||
if delivered >= expected_len { break; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => return Err(GurtError::connection(format!("Read error: {}", e))),
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
|
||||
if let (Some(end), Some(expected_len)) = (headers_end_pos, expected_body_length) {
|
||||
if delivered >= expected_len {
|
||||
self.return_connection_to_pool(&resolved_host, port, tls_stream);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn resolve_domain(&self, domain: &str) -> Result<String> {
|
||||
match self.dns_cache.lock() {
|
||||
@@ -734,4 +862,4 @@ mod tests {
|
||||
assert!(handshake_request.headers.contains_key("user-agent"));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,9 @@ pub enum GurtError {
|
||||
|
||||
#[error("Client error: {0}")]
|
||||
Client(String),
|
||||
|
||||
#[error("Cancelled")]
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, GurtError>;
|
||||
|
||||
@@ -6,7 +6,7 @@ pub mod error;
|
||||
pub mod message;
|
||||
|
||||
pub use error::{GurtError, Result};
|
||||
pub use message::{GurtMessage, GurtRequest, GurtResponse, GurtMethod};
|
||||
pub use message::{GurtMessage, GurtRequest, GurtResponse, GurtResponseHead, GurtMethod};
|
||||
pub use protocol::{GurtStatusCode, GURT_VERSION, DEFAULT_PORT};
|
||||
pub use crypto::{CryptoManager, TlsConfig, GURT_ALPN, TLS_VERSION};
|
||||
pub use server::{GurtServer, GurtHandler, ServerContext, Route};
|
||||
@@ -15,7 +15,7 @@ pub use client::{GurtClient, GurtClientConfig};
|
||||
pub mod prelude {
|
||||
pub use crate::{
|
||||
GurtError, Result,
|
||||
GurtMessage, GurtRequest, GurtResponse,
|
||||
GurtMessage, GurtRequest, GurtResponse, GurtResponseHead,
|
||||
GURT_VERSION, DEFAULT_PORT,
|
||||
CryptoManager, TlsConfig, GURT_ALPN, TLS_VERSION,
|
||||
GurtServer, GurtHandler, ServerContext, Route,
|
||||
|
||||
@@ -225,6 +225,14 @@ pub struct GurtResponse {
|
||||
pub body: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GurtResponseHead {
|
||||
pub version: String,
|
||||
pub status_code: u16,
|
||||
pub status_message: String,
|
||||
pub headers: GurtHeaders,
|
||||
}
|
||||
|
||||
impl GurtResponse {
|
||||
pub fn new(status_code: GurtStatusCode) -> Self {
|
||||
Self {
|
||||
|
||||
Reference in New Issue
Block a user