Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 218 additions & 4 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,101 @@ use std::time::Instant;
use tokio::{select, signal, task};
use tokio_util::sync::CancellationToken;

/// Reassembles arbitrarily-chunked byte streams into valid UTF-8 strings,
/// carrying incomplete multi-byte sequences across chunk boundaries.
struct ChunkDecoder {
    // Tail bytes of a multi-byte sequence cut off at the end of the previous
    // chunk (0-3 bytes, per `Utf8Error` semantics); prepended to the next feed.
    carry: Vec<u8>,
    // Total bytes successfully decoded so far; used to report the global
    // stream offset of any invalid byte in error messages.
    total_decoded: usize,
}

impl ChunkDecoder {
    fn new() -> Self {
        Self {
            carry: Vec::new(),
            total_decoded: 0,
        }
    }

    /// Feed a chunk of bytes. On success returns the decoded UTF-8 string
    /// (possibly empty when the chunk only extends an incomplete sequence).
    /// On failure returns an error message describing the invalid bytes.
    fn feed(&mut self, chunk: &[u8]) -> Result<String, String> {
        // Fast path: with no pending carry we decode straight from `chunk`
        // without copying it. Only splice carry + chunk when carry is non-empty.
        let combined;
        let data: &[u8] = if self.carry.is_empty() {
            chunk
        } else {
            let mut buf = std::mem::take(&mut self.carry);
            buf.extend_from_slice(chunk);
            combined = buf;
            &combined
        };
        match std::str::from_utf8(data) {
            Ok(s) => {
                self.total_decoded += data.len();
                Ok(s.to_string())
            }
            // `error_len() == None` means the data ends partway through a
            // well-formed multi-byte sequence. The incomplete tail is at most
            // 3 bytes (a 4-byte sequence missing its last byte); stash it and
            // hand back the valid prefix.
            Err(e) if e.error_len().is_none() => {
                let valid_end = e.valid_up_to();
                // unwrap is safe: `valid_up_to` guarantees this range is valid UTF-8.
                let valid = std::str::from_utf8(&data[..valid_end]).unwrap().to_string();
                self.total_decoded += valid_end;
                self.carry = data[valid_end..].to_vec();
                Ok(valid)
            }
            // Genuinely invalid bytes: build a diagnostic with a hex window
            // around the error (bad bytes bracketed) plus the nearby text.
            Err(e) => {
                let valid_end = e.valid_up_to();
                let bad_len = e.error_len().unwrap_or(1);
                let window_start = valid_end.saturating_sub(32);
                let window_end = (valid_end + bad_len + 32).min(data.len());
                let hex_context: Vec<String> = data[window_start..window_end]
                    .iter()
                    .enumerate()
                    .map(|(i, b)| {
                        let abs = i + window_start;
                        if abs >= valid_end && abs < valid_end + bad_len {
                            format!("[{:02x}]", b)
                        } else {
                            format!("{:02x}", b)
                        }
                    })
                    .collect();
                // The window is at most 32 bytes, hence at most 32 chars even
                // after lossy decoding, so no further truncation is needed.
                let valid_prefix = String::from_utf8_lossy(&data[window_start..valid_end]);
                Err(format!(
                    "Invalid UTF-8 in response at byte {} ({} bytes already decoded). \
Hex near error: {} | Text near error: \"{}\u{fffd}...\"",
                    self.total_decoded + valid_end,
                    self.total_decoded,
                    hex_context.join(" "),
                    valid_prefix,
                ))
            }
        }
    }

    /// Call when the stream has ended. Returns an error if there are leftover
    /// carry bytes (truncated multi-byte sequence).
    fn finish(self) -> Result<(), String> {
        if self.carry.is_empty() {
            return Ok(());
        }
        let hex: Vec<String> = self.carry.iter().map(|b| format!("{:02x}", b)).collect();
        Err(format!(
            "Invalid UTF-8 at end of response: stream ended with {} trailing bytes (hex: {})",
            self.carry.len(),
            hex.join(" "),
        ))
    }
}

/// Extract a human-readable error message from a raw server response body.
///
/// Tries to parse as JSON and look for common error fields; falls back to the
Expand Down Expand Up @@ -540,15 +635,21 @@ pub async fn query(context: &mut Context, query_text: String) -> Result<(), Box<

let mut line_buf = String::new();
let mut stream_err: Option<String> = None;
let mut chunk_decoder = ChunkDecoder::new();

'stream: loop {
match resp.chunk().await {
Err(e) => { stream_err = Some(e.to_string()); break 'stream; }
Ok(None) => break 'stream,
Ok(None) => {
if let Err(e) = chunk_decoder.finish() {
stream_err = Some(e);
}
break 'stream;
}
Ok(Some(chunk)) => {
match std::str::from_utf8(&chunk) {
Err(_) => { stream_err = Some("Invalid UTF-8 in response".into()); break 'stream; }
Ok(s) => line_buf.push_str(s),
match chunk_decoder.feed(&chunk) {
Ok(s) => line_buf.push_str(&s),
Err(e) => { stream_err = Some(e); break 'stream; }
}
while let Some(nl) = line_buf.find('\n') {
let line = line_buf[..nl].trim().to_string();
Expand Down Expand Up @@ -1520,4 +1621,117 @@ mod tests {
apply_update_parameters(&mut ctx, "transaction_id=cafebabe,transaction_sequence_id=0").unwrap();
assert!(ctx.in_transaction(), "second transaction must register correctly");
}

#[test]
fn test_chunk_decoder_pure_ascii() {
    // ASCII never spans chunk boundaries, so each feed echoes its input.
    let mut decoder = ChunkDecoder::new();
    for (input, expected) in [(&b"hello "[..], "hello "), (&b"world"[..], "world")] {
        assert_eq!(decoder.feed(input).unwrap(), expected);
    }
    assert!(decoder.finish().is_ok());
}

#[test]
fn test_chunk_decoder_multibyte_not_split() {
    // "─" (e2 94 80) arrives whole within a single chunk, so nothing carries.
    let text = "hello ─ world";
    let mut decoder = ChunkDecoder::new();
    let decoded = decoder.feed(text.as_bytes()).unwrap();
    assert_eq!(decoded, text);
    assert!(decoder.finish().is_ok());
}

#[test]
fn test_chunk_decoder_split_two_byte_char() {
    // "ñ" (c3 b1) is broken across the chunk boundary.
    let bytes = "señor".as_bytes();
    let boundary = bytes.iter().position(|&b| b == 0xc3).unwrap();
    let (head, tail) = bytes.split_at(boundary);
    let mut decoder = ChunkDecoder::new();
    assert_eq!(decoder.feed(head).unwrap(), "se");
    assert_eq!(decoder.feed(tail).unwrap(), "ñor");
    assert!(decoder.finish().is_ok());
}

#[test]
fn test_chunk_decoder_split_three_byte_char_after_first() {
    // "┴" (e2 94 b4) — the chunk boundary falls right after the e2 lead byte.
    let bytes = "ab┴cd".as_bytes();
    let (head, tail) = bytes.split_at(3); // "ab" plus the lone e2
    let mut decoder = ChunkDecoder::new();
    assert_eq!(decoder.feed(head).unwrap(), "ab");
    assert_eq!(decoder.feed(tail).unwrap(), "┴cd");
    assert!(decoder.finish().is_ok());
}

#[test]
fn test_chunk_decoder_split_three_byte_char_after_second() {
    // "┴" (e2 94 b4) — boundary after the second byte of the sequence.
    let bytes = "ab┴cd".as_bytes();
    let (head, tail) = bytes.split_at(4); // "ab" plus e2 94
    let mut decoder = ChunkDecoder::new();
    assert_eq!(decoder.feed(head).unwrap(), "ab");
    assert_eq!(decoder.feed(tail).unwrap(), "┴cd");
    assert!(decoder.finish().is_ok());
}

#[test]
fn test_chunk_decoder_split_four_byte_char() {
    // "😀" (f0 9f 98 80) — boundary after its first byte.
    let bytes = "a😀b".as_bytes();
    let (head, tail) = bytes.split_at(2);
    let mut decoder = ChunkDecoder::new();
    assert_eq!(decoder.feed(head).unwrap(), "a");
    assert_eq!(decoder.feed(tail).unwrap(), "😀b");
    assert!(decoder.finish().is_ok());
}

#[test]
fn test_chunk_decoder_many_splits_across_chunks() {
    // Feed one byte at a time: every multi-byte char crosses a boundary, so
    // the carry mechanism is exercised on nearly every call.
    let original = "──┬──┴──";
    let mut decoder = ChunkDecoder::new();
    let reassembled: String = original
        .as_bytes()
        .iter()
        .map(|&b| decoder.feed(&[b]).unwrap())
        .collect();
    assert!(decoder.finish().is_ok());
    assert_eq!(reassembled, original);
}

#[test]
fn test_chunk_decoder_truncated_stream() {
    // A lone e2 lead byte with no continuation, then end of stream.
    let mut decoder = ChunkDecoder::new();
    assert_eq!(decoder.feed(&[0xe2]).unwrap(), "");
    let message = decoder.finish().unwrap_err();
    assert!(message.contains("trailing bytes"), "error: {}", message);
    assert!(message.contains("e2"), "error should contain hex: {}", message);
}

#[test]
fn test_chunk_decoder_invalid_byte() {
    // 0xff can never appear in well-formed UTF-8.
    let mut decoder = ChunkDecoder::new();
    let message = decoder.feed(&[0x61, 0x62, 0xff, 0x63]).unwrap_err();
    assert!(message.contains("Invalid UTF-8"), "error: {}", message);
    assert!(message.contains("[ff]"), "error should bracket the bad byte: {}", message);
}

#[test]
fn test_chunk_decoder_invalid_continuation() {
    // An e2 lead byte demands a continuation in 0x80..=0xbf; 0xff is not one.
    let mut decoder = ChunkDecoder::new();
    let message = decoder.feed(b"hello\xe2\xff").unwrap_err();
    assert!(message.contains("Invalid UTF-8"), "error: {}", message);
}

#[test]
fn test_chunk_decoder_error_reports_total_position() {
    // 10 valid bytes in chunk one, 5 more in chunk two, then an invalid 0xff:
    // the error must cite the global stream offset, not the chunk-local one.
    let mut decoder = ChunkDecoder::new();
    assert_eq!(decoder.feed(b"0123456789").unwrap(), "0123456789");
    let message = decoder.feed(b"abcde\xff").unwrap_err();
    assert!(
        message.contains("at byte 15"),
        "should report global byte position 15 (10+5): {}",
        message
    );
}
}
Loading