diff --git a/.changeset/stream-multipart-uploads.md b/.changeset/stream-multipart-uploads.md
new file mode 100644
index 00000000..98150928
--- /dev/null
+++ b/.changeset/stream-multipart-uploads.md
@@ -0,0 +1,5 @@
+---
+"@googleworkspace/cli": patch
+---
+
+Stream multipart uploads to avoid OOM on large files. File content is now streamed in chunks via `ReaderStream` instead of being read entirely into memory, reducing memory usage from O(file_size) to O(64 KB).
diff --git a/Cargo.lock b/Cargo.lock
index 5c3a6a02..198e3212 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -867,6 +867,7 @@ dependencies = [
  "anyhow",
  "async-trait",
  "base64",
+ "bytes",
  "chrono",
  "clap",
  "crossterm",
@@ -888,6 +889,7 @@ dependencies = [
  "tempfile",
  "thiserror 2.0.18",
  "tokio",
+ "tokio-util",
  "yup-oauth2",
  "zeroize",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index 1d1fe4d4..8aa0db22 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -47,6 +47,8 @@ thiserror = "2"
 tokio = { version = "1", features = ["full"] }
 yup-oauth2 = "12"
 futures-util = "0.3"
+tokio-util = { version = "0.7", features = ["io"] }
+bytes = "1"
 base64 = "0.22.1"
 derive_builder = "0.20.2"
 ratatui = "0.30.0"
diff --git a/src/executor.rs b/src/executor.rs
index 0a78c625..ab08ea0d 100644
--- a/src/executor.rs
+++ b/src/executor.rs
@@ -22,6 +22,7 @@ use std::collections::{HashMap, HashSet};
 use std::path::PathBuf;
 
 use anyhow::Context;
+use futures_util::stream::TryStreamExt;
 use futures_util::StreamExt;
 use serde_json::{json, Map, Value};
 use tokio::io::AsyncWriteExt;
@@ -183,20 +184,22 @@ async fn build_http_request(
     if input.is_upload {
         let upload_path = upload_path.expect("upload_path must be Some when is_upload is true");
-        let file_bytes = tokio::fs::read(upload_path).await.map_err(|e| {
+        let file_meta = tokio::fs::metadata(upload_path).await.map_err(|e| {
             GwsError::Validation(format!(
-                "Failed to read upload file '{}': {}",
+                "Failed to get metadata for upload file '{}': {}",
                 upload_path, e
             ))
         })?;
+        let file_size = file_meta.len();
 
         request = request.query(&[("uploadType", "multipart")]);
 
         let media_mime =
             resolve_upload_mime(upload_content_type, Some(upload_path), &input.body);
-        let (multipart_body, content_type) =
-            build_multipart_body(&input.body, &file_bytes, &media_mime)?;
+        let (body, content_type, content_length) =
+            build_multipart_stream(&input.body, upload_path, file_size, &media_mime)?;
         request = request.header("Content-Type", content_type);
-        request = request.body(multipart_body);
+        request = request.header("Content-Length", content_length);
+        request = request.body(body);
     } else if let Some(ref body_val) = input.body {
         request = request.header("Content-Type", "application/json");
         request = request.json(body_val);
@@ -827,9 +830,75 @@ fn mime_from_extension(path: &str) -> Option<&'static str> {
     }
 }
 
-/// Builds a multipart/related body for media upload requests.
+/// Builds a streaming multipart/related body for media upload requests.
+///
+/// Instead of reading the entire file into memory, this streams the file in
+/// chunks via `ReaderStream`, keeping memory usage at O(64 KB) regardless of
+/// file size. The `Content-Length` is pre-computed from file metadata so Google
+/// APIs still receive the correct header without buffering.
+///
+/// Returns `(body, content_type, content_length)`.
+fn build_multipart_stream(
+    metadata: &Option<Value>,
+    file_path: &str,
+    file_size: u64,
+    media_mime: &str,
+) -> Result<(reqwest::Body, String, u64), GwsError> {
+    let boundary = format!("gws_boundary_{:016x}", rand::random::<u64>());
+
+    let media_mime = media_mime.to_string();
+
+    let metadata_json = match metadata {
+        Some(m) => serde_json::to_string(m).map_err(|e| {
+            GwsError::Validation(format!("Failed to serialize upload metadata: {e}"))
+        })?,
+        None => "{}".to_string(),
+    };
+
+    let preamble = format!(
+        "--{boundary}\r\nContent-Type: application/json; charset=UTF-8\r\n\r\n{metadata_json}\r\n\
+         --{boundary}\r\nContent-Type: {media_mime}\r\n\r\n"
+    );
+    let postamble = format!("\r\n--{boundary}--\r\n");
+
+    let content_length = preamble.len() as u64 + file_size + postamble.len() as u64;
+    let content_type = format!("multipart/related; boundary={boundary}");
+
+    let preamble_bytes: bytes::Bytes = preamble.into_bytes().into();
+    let postamble_bytes: bytes::Bytes = postamble.into_bytes().into();
+
+    let file_path_owned = file_path.to_owned();
+    let file_stream = futures_util::stream::once(async move {
+        tokio::fs::File::open(&file_path_owned).await.map_err(|e| {
+            std::io::Error::new(
+                e.kind(),
+                format!("failed to open upload file '{}': {}", file_path_owned, e),
+            )
+        })
+    })
+    .map_ok(tokio_util::io::ReaderStream::new)
+    .try_flatten();
+
+    let stream = futures_util::stream::once(async { Ok::<_, std::io::Error>(preamble_bytes) })
+        .chain(file_stream)
+        .chain(futures_util::stream::once(async {
+            Ok::<_, std::io::Error>(postamble_bytes)
+        }));
+
+    Ok((
+        reqwest::Body::wrap_stream(stream),
+        content_type,
+        content_length,
+    ))
+}
+
+/// Builds a buffered multipart/related body for media upload requests.
+///
+/// This is the legacy implementation retained for unit tests that need
+/// a fully materialized body to assert against.
+#[cfg(test)]
 ///
 /// Returns the body bytes and the Content-Type header value (with boundary).
 fn build_multipart_body(
     metadata: &Option<Value>,
     file_bytes: &[u8],
@@ -1369,6 +1438,82 @@ mod tests {
         );
     }
 
+    #[tokio::test]
+    async fn test_build_multipart_stream_content_length() {
+        let dir = tempfile::tempdir().unwrap();
+        let file_path = dir.path().join("small.txt");
+        let file_content = b"Hello stream";
+        std::fs::write(&file_path, file_content).unwrap();
+
+        let metadata = Some(json!({ "name": "small.txt" }));
+        let file_size = file_content.len() as u64;
+
+        let (_body, content_type, declared_len) = build_multipart_stream(
+            &metadata,
+            file_path.to_str().unwrap(),
+            file_size,
+            "text/plain",
+        )
+        .unwrap();
+
+        assert!(content_type.starts_with("multipart/related; boundary="));
+        let boundary = content_type.split("boundary=").nth(1).unwrap();
+
+        // Manually compute expected content length:
+        // preamble = "--{boundary}\r\nContent-Type: application/json; charset=UTF-8\r\n\r\n{json}\r\n--{boundary}\r\nContent-Type: text/plain\r\n\r\n"
+        // postamble = "\r\n--{boundary}--\r\n"
+        let metadata_json = serde_json::to_string(&metadata.unwrap()).unwrap();
+        let preamble = format!(
+            "--{boundary}\r\nContent-Type: application/json; charset=UTF-8\r\n\r\n{metadata_json}\r\n\
+             --{boundary}\r\nContent-Type: text/plain\r\n\r\n"
+        );
+        let postamble = format!("\r\n--{boundary}--\r\n");
+        let expected = preamble.len() as u64 + file_size + postamble.len() as u64;
+        assert_eq!(
+            declared_len, expected,
+            "declared Content-Length must match expected preamble + file + postamble"
+        );
+    }
+
+    #[tokio::test]
+    async fn test_build_multipart_stream_large_file() {
+        let dir = tempfile::tempdir().unwrap();
+        let file_path = dir.path().join("large.bin");
+        // 256 KB — larger than the default 64 KB ReaderStream chunk size
+        let data = vec![0xABu8; 256 * 1024];
+        std::fs::write(&file_path, &data).unwrap();
+
+        let metadata = None;
+        let file_size = data.len() as u64;
+
+        let (_body, content_type, declared_len) = build_multipart_stream(
+            &metadata,
+            file_path.to_str().unwrap(),
+            file_size,
+            "application/octet-stream",
+        )
+        .unwrap();
+
+        // Content-Length must account for the empty-metadata preamble + large file + postamble
+        assert!(
+            declared_len > file_size,
+            "Content-Length ({declared_len}) must be larger than file size ({file_size}) due to multipart framing"
+        );
+
+        // Verify exact arithmetic: preamble overhead + file_size + postamble
+        let boundary = content_type.split("boundary=").nth(1).unwrap();
+        let preamble = format!(
+            "--{boundary}\r\nContent-Type: application/json; charset=UTF-8\r\n\r\n{{}}\r\n\
+             --{boundary}\r\nContent-Type: application/octet-stream\r\n\r\n"
+        );
+        let postamble = format!("\r\n--{boundary}--\r\n");
+        let expected = preamble.len() as u64 + file_size + postamble.len() as u64;
+        assert_eq!(
+            declared_len, expected,
+            "Content-Length must match for multi-chunk files"
+        );
+    }
+
     #[test]
     fn test_build_url_basic() {
         let doc = RestDescription {