From 0e3698dbcb6d0cb8b0ca54820e399873d3f8f83d Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 13 Mar 2026 11:00:35 +0000 Subject: [PATCH 01/19] perf[fsst]: like pushdown using a dfa Signed-off-by: Joe Isaacs --- encodings/fsst/src/compute/like.rs | 1299 ++++++++++++++++++++++++++++ encodings/fsst/src/compute/mod.rs | 1 + encodings/fsst/src/kernel.rs | 2 + 3 files changed, 1302 insertions(+) create mode 100644 encodings/fsst/src/compute/like.rs diff --git a/encodings/fsst/src/compute/like.rs b/encodings/fsst/src/compute/like.rs new file mode 100644 index 00000000000..3946c640b30 --- /dev/null +++ b/encodings/fsst/src/compute/like.rs @@ -0,0 +1,1299 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![allow(clippy::cast_possible_truncation)] + +use fsst::ESCAPE_CODE; +use fsst::Symbol; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::ToCanonical; +use vortex_array::arrays::BoolArray; +use vortex_array::match_each_integer_ptype; +use vortex_array::scalar_fn::fns::like::LikeKernel; +use vortex_array::scalar_fn::fns::like::LikeOptions; +use vortex_buffer::BitBuffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; + +use crate::FSST; +use crate::FSSTArray; + +impl LikeKernel for FSST { + #[allow(clippy::cast_possible_truncation)] + fn like( + array: &FSSTArray, + pattern: &ArrayRef, + options: LikeOptions, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let Some(pattern_scalar) = pattern.as_constant() else { + return Ok(None); + }; + + if options.case_insensitive { + return Ok(None); + } + + let Some(pattern_str) = pattern_scalar.as_utf8().value() else { + return Ok(None); + }; + + let Some(like_kind) = LikeKind::parse(pattern_str) else { + return Ok(None); + }; + + let symbols = array.symbols(); + let symbol_lengths = array.symbol_lengths(); + let negated = options.negated; + + // Access the underlying codes VarBinArray buffers directly to avoid + // dyn Iterator overhead from with_iterator. + let codes = array.codes(); + let offsets = codes.offsets().to_primitive(); + let all_bytes = codes.bytes(); + let all_bytes = all_bytes.as_slice(); + let n = codes.len(); + + let result = match like_kind { + LikeKind::Prefix(prefix) => { + let prefix = prefix.as_bytes(); + // FsstPrefixDfa uses 4-bit shift packing: prefix_len + 2 states must fit in 16. + if prefix.len() + 2 > (1 << FsstPrefixDfa::BITS) { + return Ok(None); + } + let dfa = FsstPrefixDfa::new(symbols.as_slice(), symbol_lengths.as_slice(), prefix); + match_each_integer_ptype!(offsets.ptype(), |T| { + let off = offsets.as_slice::(); + dfa_scan_to_bitbuf(n, off, all_bytes, negated, |codes| dfa.matches(codes)) + }) + } + LikeKind::Contains(needle) => { + let needle = needle.as_bytes(); + if needle.len() <= BranchlessShiftDfa::MAX_NEEDLE_LEN { + let dfa = BranchlessShiftDfa::new( + symbols.as_slice(), + symbol_lengths.as_slice(), + needle, + ); + match_each_integer_ptype!(offsets.ptype(), |T| { + let off = offsets.as_slice::(); + dfa_scan_to_bitbuf(n, off, all_bytes, negated, |codes| dfa.matches(codes)) + }) + } else if needle.len() <= FlatBranchlessDfa::MAX_NEEDLE_LEN { + let dfa = FlatBranchlessDfa::new( + symbols.as_slice(), + symbol_lengths.as_slice(), + needle, + ); + match_each_integer_ptype!(offsets.ptype(), |T| { + let off = offsets.as_slice::(); + dfa_scan_to_bitbuf(n, off, all_bytes, negated, |codes| dfa.matches(codes)) + }) + } else { + let dfa = + FsstContainsDfa::new(symbols.as_slice(), symbol_lengths.as_slice(), needle); + match_each_integer_ptype!(offsets.ptype(), |T| { + let off = offsets.as_slice::(); + dfa_scan_to_bitbuf(n, off, all_bytes, negated, |codes| dfa.matches(codes)) + }) + } + } + }; + + // FSST delegates validity to its codes array, so we can read it + // directly without cloning the entire FSSTArray into an ArrayRef. + let validity = array + .codes() + .validity()? + .union_nullability(pattern_scalar.dtype().nullability()); + + Ok(Some(BoolArray::new(result, validity).into_array())) + } +} + +/// Scan all strings through a DFA matcher, packing results directly into a +/// `BitBuffer` one u64 word (64 strings) at a time. This avoids the overhead +/// of `BitBufferMut::collect_bool`'s cross-crate closure indirection and +/// guarantees the compiler can see the full loop body for optimization. +// TODO: add N-way ILP overrun scan for higher throughput on short strings. +#[inline] +fn dfa_scan_to_bitbuf( + n: usize, + offsets: &[T], + all_bytes: &[u8], + negated: bool, + matcher: F, +) -> BitBuffer +where + T: vortex_array::dtype::IntegerPType, + F: Fn(&[u8]) -> bool, +{ + let n_words = n / 64; + let remainder = n % 64; + let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); + + for chunk in 0..n_words { + let base = chunk * 64; + let mut word = 0u64; + let mut start: usize = offsets[base].as_(); + for bit in 0..64 { + let end: usize = offsets[base + bit + 1].as_(); + word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; + start = end; + } + // SAFETY: we allocated capacity for n.div_ceil(64) words. + unsafe { words.push_unchecked(word) }; + } + + if remainder != 0 { + let base = n_words * 64; + let mut word = 0u64; + let mut start: usize = offsets[base].as_(); + for bit in 0..remainder { + let end: usize = offsets[base + bit + 1].as_(); + word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; + start = end; + } + unsafe { words.push_unchecked(word) }; + } + + BitBuffer::new(words.into_byte_buffer().freeze(), n) +} + +/// The subset of LIKE patterns we can handle without decompression. +enum LikeKind<'a> { + /// `prefix%` + Prefix(&'a str), + /// `%needle%` + Contains(&'a str), +} + +impl<'a> LikeKind<'a> { + fn parse(pattern: &'a str) -> Option { + if pattern == "%" { + return Some(LikeKind::Prefix("")); + } + + // Find first wildcard. + let first_wild = pattern.find(['%', '_'])?; + + // `_` as first wildcard means we can't handle it. + if pattern.as_bytes()[first_wild] == b'_' { + return None; + } + + // `prefix%` — single trailing % + if first_wild > 0 && &pattern[first_wild..] == "%" { + return Some(LikeKind::Prefix(&pattern[..first_wild])); + } + + // `%needle%` — leading and trailing %, no inner wildcards + if first_wild == 0 + && pattern.len() > 2 + && pattern.as_bytes()[pattern.len() - 1] == b'%' + && !pattern[1..pattern.len() - 1].contains(['%', '_']) + { + return Some(LikeKind::Contains(&pattern[1..pattern.len() - 1])); + } + + None + } +} + +// --------------------------------------------------------------------------- +// DFA for prefix matching (LIKE 'prefix%') +// --------------------------------------------------------------------------- + +/// Precomputed shift-based DFA for prefix matching on FSST codes. +/// +/// States 0..prefix_len track match progress, plus ACCEPT and FAIL. +/// Uses the same shift-based approach as the contains DFA: all state +/// transitions packed into a `u64` per code byte. For prefixes longer +/// than 13 characters, falls back to a fused u8 table. +struct FsstPrefixDfa { + /// Packed transitions: `(table[code] >> (state * 4)) & 0xF` gives next state. + transitions: [u64; 256], + /// Packed escape transitions for literal bytes. + escape_transitions: [u64; 256], + accept_state: u8, + fail_state: u8, +} + +impl FsstPrefixDfa { + const BITS: u32 = 4; + const MASK: u64 = (1 << Self::BITS) - 1; + + fn new(symbols: &[Symbol], symbol_lengths: &[u8], prefix: &[u8]) -> Self { + // prefix.len() + 2 states (0..prefix_len, accept, fail) must fit in 4 bits. + debug_assert!(prefix.len() + 2 <= (1 << Self::BITS)); + + let n_symbols = symbols.len(); + let accept_state = prefix.len() as u8; + let fail_state = prefix.len() as u8 + 1; + let n_states = prefix.len() + 2; + + // Build per-symbol and per-escape-byte transitions into flat tables. + let mut sym_trans = vec![fail_state; n_states * n_symbols]; + let mut esc_trans = vec![fail_state; n_states * 256]; + + for state in 0..n_states { + if state as u8 == accept_state { + for code in 0..n_symbols { + sym_trans[state * n_symbols + code] = accept_state; + } + for b in 0..256 { + esc_trans[state * 256 + b] = accept_state; + } + continue; + } + if state as u8 == fail_state { + continue; + } + + for code in 0..n_symbols { + let sym = symbols[code].to_u64().to_le_bytes(); + let sym_len = symbol_lengths[code] as usize; + let remaining = prefix.len() - state; + let cmp = sym_len.min(remaining); + + if sym[..cmp] == prefix[state..state + cmp] { + let next = state + cmp; + sym_trans[state * n_symbols + code] = if next >= prefix.len() { + accept_state + } else { + next as u8 + }; + } + } + + for b in 0..256usize { + if b as u8 == prefix[state] { + let next = state + 1; + esc_trans[state * 256 + b] = if next >= prefix.len() { + accept_state + } else { + next as u8 + }; + } + } + } + + // Fuse symbol transitions into a 256-wide table. + let escape_sentinel = fail_state + 1; + let mut fused = vec![fail_state; n_states * 256]; + for state in 0..n_states { + for code in 0..n_symbols { + fused[state * 256 + code] = sym_trans[state * n_symbols + code]; + } + fused[state * 256 + ESCAPE_CODE as usize] = escape_sentinel; + } + + // Pack into u64 shift tables. + let mut transitions = [0u64; 256]; + for code_byte in 0..256usize { + let mut packed = 0u64; + for state in 0..n_states { + packed |= (fused[state * 256 + code_byte] as u64) << (state as u32 * Self::BITS); + } + transitions[code_byte] = packed; + } + + let mut escape_transitions = [0u64; 256]; + for byte_val in 0..256usize { + let mut packed = 0u64; + for state in 0..n_states { + packed |= (esc_trans[state * 256 + byte_val] as u64) << (state as u32 * Self::BITS); + } + escape_transitions[byte_val] = packed; + } + + Self { + transitions, + escape_transitions, + accept_state, + fail_state, + } + } + + #[inline] + fn matches(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + let mut pos = 0; + while pos < codes.len() { + let code = codes[pos]; + pos += 1; + let packed = self.transitions[code as usize]; + let next = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + if next == self.fail_state + 1 { + // Escape sentinel: read literal byte. + if pos >= codes.len() { + return false; + } + let b = codes[pos]; + pos += 1; + let esc_packed = self.escape_transitions[b as usize]; + state = ((esc_packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } else { + state = next; + } + if state == self.accept_state { + return true; + } + if state == self.fail_state { + return false; + } + } + state == self.accept_state + } +} + +// --------------------------------------------------------------------------- +// DFA for contains matching (LIKE '%needle%') +// --------------------------------------------------------------------------- + +/// Contains DFA for long needles (>14 chars). Short needles (len <= 7) are +/// handled by `BranchlessShiftDfa`, medium needles (8-14) by +/// `FlatBranchlessDfa`. +enum FsstContainsDfa { + /// Shift-based DFA for medium needles (len 8-14). + Shift(Box), + /// Fused u8 table DFA for long needles (len > 14). + Fused(FusedDfa), +} + +impl FsstContainsDfa { + fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + if needle.len() <= ShiftDfa::MAX_NEEDLE_LEN { + FsstContainsDfa::Shift(Box::new(ShiftDfa::new(symbols, symbol_lengths, needle))) + } else { + FsstContainsDfa::Fused(FusedDfa::new(symbols, symbol_lengths, needle)) + } + } + + #[inline] + fn matches(&self, codes: &[u8]) -> bool { + match self { + FsstContainsDfa::Shift(dfa) => dfa.matches(codes), + FsstContainsDfa::Fused(dfa) => dfa.matches(codes), + } + } +} + +/// Branchless escape-folded DFA for short needles (len <= 7). +/// +/// Folds escape handling into the state space so that `matches()` is +/// completely branchless (except for loop control). The state layout is: +/// - States 0..N-1: normal match-progress states +/// - State N: accept (sticky for all inputs) +/// - States N+1..2N: escape states (state `s+N+1` means "was in state `s`, +/// just consumed ESCAPE_CODE") +/// +/// Total states: 2N+1. With 4-bit packing, max N=7. +/// +/// Uses a decomposed hierarchical lookup that processes 4 code bytes per +/// loop iteration with only ~3 KB of tables: +/// +/// 1. **Equivalence class table** (256 B): maps each code byte to a class +/// id. Bytes with identical transition u64s share a class -- typically +/// only ~6-10 classes exist (needle chars + escape + "miss-all"). +/// 2. **Pair-compose table** (~N^2 B): maps `(class0, class1)` to a 2-byte +/// palette index. Typically ~36 entries. +/// 3. **4-byte compose table** (~M^2 x 8 B): maps `(palette0, palette1)` to +/// the composed packed u64 for all 4 bytes. Typically ~81 entries = 648 B. +/// +/// Each loop iteration: 4 class lookups (parallel, 256 B table) -> 2 +/// pair-compose lookups (parallel, ~36 B table) -> 1 compose lookup +/// (~648 B table) -> 1 shift+mask. All tables fit in L1 cache. +struct BranchlessShiftDfa { + /// Maps each code byte to its equivalence class. Bytes with the same + /// packed transition u64 share a class. (256 bytes) + eq_class: [u8; 256], + /// Maps `(class0 * n_classes + class1)` -> 2-byte palette index. + pair_compose: Vec, + /// Number of equivalence classes (stride for pair_compose). + n_classes: usize, + /// Maps `(palette0 * n_palette + palette1)` -> composed packed u64 + /// for 4 bytes. + compose_4b: Vec, + /// Number of unique 2-byte palette entries (stride for compose_4b). + n_palette: usize, + /// 1-byte fallback transitions for trailing bytes. + transitions_1b: [u64; 256], + /// 2-byte palette for the remainder path (2-3 trailing bytes). + palette_2b: Vec, + accept_state: u8, +} + +impl BranchlessShiftDfa { + const BITS: u32 = 4; + const MASK: u64 = (1 << Self::BITS) - 1; + /// Maximum needle length: need 2N+1 states to fit in 16 slots (4 bits). + /// 2*7+1 = 15 <= 16, so max N = 7. + const MAX_NEEDLE_LEN: usize = 7; + + fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + let n = needle.len(); + debug_assert!(n <= Self::MAX_NEEDLE_LEN); + + let accept_state = n as u8; + let total_states = 2 * n + 1; + debug_assert!(total_states <= (1 << Self::BITS)); + + let transitions_1b = + Self::build_1b_transitions(symbols, symbol_lengths, needle, total_states); + + // Build equivalence classes: group bytes with identical transition u64. + let mut eq_class = [0u8; 256]; + let mut class_representatives: Vec = Vec::new(); + for byte_val in 0..256usize { + let t = transitions_1b[byte_val]; + let cls = class_representatives + .iter() + .position(|&v| v == t) + .unwrap_or_else(|| { + class_representatives.push(t); + class_representatives.len() - 1 + }); + eq_class[byte_val] = cls as u8; + } + let n_classes = class_representatives.len(); + + // Build pair-compose: for each (class0, class1), compose the two + // 1-byte transitions and deduplicate into a 2-byte palette. + let (pair_compose, palette_2b) = + Self::build_pair_compose(&class_representatives, n_classes, total_states); + + // Build 4-byte composition: compose_4b[p0 * n + p1] gives the packed + // u64 for applying palette_2b[p0] then palette_2b[p1] in sequence. + let n_palette = palette_2b.len(); + let compose_4b = Self::build_compose_4b(&palette_2b, total_states); + + Self { + eq_class, + pair_compose, + n_classes, + compose_4b, + n_palette, + transitions_1b, + palette_2b, + accept_state, + } + } + + /// Build the 1-byte packed transition table from FSST symbols and + /// a byte-level KMP table, folding escape handling into the state space. + fn build_1b_transitions( + symbols: &[Symbol], + symbol_lengths: &[u8], + needle: &[u8], + total_states: usize, + ) -> [u64; 256] { + let n = needle.len(); + let n_symbols = symbols.len(); + let accept_state = n as u8; + let n_normal_states = n + 1; + + let byte_table = kmp_byte_transitions(needle); + + // Build per-symbol transitions for normal states. + let mut sym_trans = vec![0u8; n_normal_states * n_symbols]; + for state in 0..n_normal_states { + for code in 0..n_symbols { + if state as u8 == accept_state { + sym_trans[state * n_symbols + code] = accept_state; + continue; + } + let sym = symbols[code].to_u64().to_le_bytes(); + let sym_len = symbol_lengths[code] as usize; + let mut s = state as u16; + for &b in &sym[..sym_len] { + if s == accept_state as u16 { + break; + } + s = byte_table[s as usize * 256 + b as usize]; + } + sym_trans[state * n_symbols + code] = s as u8; + } + } + + // Build fused transition table with escape folding. + let mut fused = vec![0u8; total_states * 256]; + for code_byte in 0..256usize { + for s in 0..n { + if code_byte == ESCAPE_CODE as usize { + fused[s * 256 + code_byte] = (s + n + 1) as u8; + } else if code_byte < n_symbols { + fused[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; + } + } + fused[n * 256 + code_byte] = accept_state; + for s in 0..n { + let esc_state = s + n + 1; + let next = byte_table[s * 256 + code_byte] as u8; + fused[esc_state * 256 + code_byte] = next; + } + } + + // Pack into u64 shift table. + let mut transitions = [0u64; 256]; + for code_byte in 0..256usize { + let mut packed = 0u64; + for state in 0..total_states { + packed |= (fused[state * 256 + code_byte] as u64) << (state as u32 * Self::BITS); + } + transitions[code_byte] = packed; + } + transitions + } + + /// Build the pair-compose table and 2-byte palette from equivalence + /// class representatives. + fn build_pair_compose( + class_reps: &[u64], + n_classes: usize, + total_states: usize, + ) -> (Vec, Vec) { + let mut pair_compose = vec![0u8; n_classes * n_classes]; + let mut palette_2b: Vec = Vec::new(); + + for c0 in 0..n_classes { + for c1 in 0..n_classes { + let t0 = class_reps[c0]; + let t1 = class_reps[c1]; + let mut packed = 0u64; + for state in 0..total_states { + let mid = ((t0 >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let final_s = ((t1 >> (mid as u32 * Self::BITS)) & Self::MASK) as u8; + packed |= (final_s as u64) << (state as u32 * Self::BITS); + } + let idx = palette_2b + .iter() + .position(|&v| v == packed) + .unwrap_or_else(|| { + palette_2b.push(packed); + palette_2b.len() - 1 + }); + pair_compose[c0 * n_classes + c1] = idx as u8; + } + } + (pair_compose, palette_2b) + } + + /// Compose pairs of 2-byte palette entries into a 4-byte lookup table. + fn build_compose_4b(palette_2b: &[u64], total_states: usize) -> Vec { + let n = palette_2b.len(); + let mut compose = vec![0u64; n * n]; + for p0 in 0..n { + for p1 in 0..n { + let mut packed = 0u64; + for state in 0..total_states { + let mid = ((palette_2b[p0] >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let final_s = + ((palette_2b[p1] >> (mid as u32 * Self::BITS)) & Self::MASK) as u8; + packed |= (final_s as u64) << (state as u32 * Self::BITS); + } + compose[p0 * n + p1] = packed; + } + } + compose + } + + /// Process remaining bytes after the interleaved common prefix. + #[inline] + fn finish_tail(&self, mut state: u8, codes: &[u8]) -> u8 { + let chunks = codes.chunks_exact(4); + let rem = chunks.remainder(); + + for chunk in chunks { + let ec0 = unsafe { *self.eq_class.get_unchecked(chunk[0] as usize) } as usize; + let ec1 = unsafe { *self.eq_class.get_unchecked(chunk[1] as usize) } as usize; + let ec2 = unsafe { *self.eq_class.get_unchecked(chunk[2] as usize) } as usize; + let ec3 = unsafe { *self.eq_class.get_unchecked(chunk[3] as usize) } as usize; + let p0 = + unsafe { *self.pair_compose.get_unchecked(ec0 * self.n_classes + ec1) } as usize; + let p1 = + unsafe { *self.pair_compose.get_unchecked(ec2 * self.n_classes + ec3) } as usize; + let packed = unsafe { *self.compose_4b.get_unchecked(p0 * self.n_palette + p1) }; + state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } + + if rem.len() >= 2 { + let ec0 = self.eq_class[rem[0] as usize] as usize; + let ec1 = self.eq_class[rem[1] as usize] as usize; + let p = self.pair_compose[ec0 * self.n_classes + ec1] as usize; + let packed = self.palette_2b[p]; + state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + if rem.len() == 3 { + let packed = self.transitions_1b[rem[2] as usize]; + state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } + } else if rem.len() == 1 { + let packed = self.transitions_1b[rem[0] as usize]; + state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } + + state + } + + /// Branchless matching processing four code bytes per iteration. + #[inline(never)] + fn matches(&self, codes: &[u8]) -> bool { + self.finish_tail(0, codes) == self.accept_state + } +} + +/// Flat u8 escape-folded DFA for medium needles (8-14 chars). +/// +/// Like `BranchlessShiftDfa`, folds escape handling into the state space +/// (2N+1 states), but uses a flat `u8` transition table instead of +/// shift-packed `u64`. Supports up to 14-char needles (2*14+1 = 29 states). +/// Table size: 29 * 256 = 7,424 bytes, fits in L1. +struct FlatBranchlessDfa { + /// transitions[state * 256 + byte] -> next state + transitions: Vec, + accept_state: u8, +} + +impl FlatBranchlessDfa { + const MAX_NEEDLE_LEN: usize = 14; + + fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + let n = needle.len(); + debug_assert!(n <= Self::MAX_NEEDLE_LEN); + + let accept_state = n as u8; + let total_states = 2 * n + 1; + let n_symbols = symbols.len(); + + let byte_table = kmp_byte_transitions(needle); + + // Build per-symbol transitions for normal states. + let mut sym_trans = vec![0u8; (n + 1) * n_symbols]; + for state in 0..=n { + for code in 0..n_symbols { + if state as u8 == accept_state { + sym_trans[state * n_symbols + code] = accept_state; + continue; + } + let sym = symbols[code].to_u64().to_le_bytes(); + let sym_len = symbol_lengths[code] as usize; + let mut s = state as u16; + for &b in &sym[..sym_len] { + if s == accept_state as u16 { + break; + } + s = byte_table[s as usize * 256 + b as usize]; + } + sym_trans[state * n_symbols + code] = s as u8; + } + } + + // Build fused transition table with escape folding. + let mut transitions = vec![0u8; total_states * 256]; + for code_byte in 0..256usize { + // Normal states 0..n + for s in 0..n { + if code_byte == ESCAPE_CODE as usize { + transitions[s * 256 + code_byte] = (s + n + 1) as u8; + } else if code_byte < n_symbols { + transitions[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; + } + } + // Accept state (sticky) + transitions[n * 256 + code_byte] = accept_state; + // Escape states n+1..2n + for s in 0..n { + let esc_state = s + n + 1; + let next = byte_table[s * 256 + code_byte] as u8; + transitions[esc_state * 256 + code_byte] = next; + } + } + + Self { + transitions, + accept_state, + } + } + + #[inline(never)] + fn matches(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + for &byte in codes { + state = self.transitions[state as usize * 256 + byte as usize]; + } + state == self.accept_state + } +} + +/// Shift-based DFA: packs all state transitions into a `u64` per input byte. +/// +/// For a DFA with S states (S <= 16, using 4 bits each), we store transitions +/// for ALL states in one `u64`. Transition: `next = (table[code] >> (state * 4)) & 0xF`. +/// +/// Supports needles up to 14 characters (needle.len() + 2 <= 16 to fit escape +/// sentinel). This covers virtually all practical LIKE patterns. +struct ShiftDfa { + /// For each code byte (0..255): a `u64` packing all state transitions. + /// Bits `[state*4 .. state*4+4)` encode the next state for that input. + transitions: [u64; 256], + /// Same layout for escape byte transitions. + escape_transitions: [u64; 256], + accept_state: u8, + escape_sentinel: u8, +} + +impl ShiftDfa { + const BITS: u32 = 4; + const MASK: u64 = (1 << Self::BITS) - 1; + /// Maximum needle length: 2^BITS - 2 (need room for accept + sentinel). + const MAX_NEEDLE_LEN: usize = (1 << Self::BITS) - 2; + + fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); + + let n_symbols = symbols.len(); + let n_states = needle.len() + 1; + let accept_state = needle.len() as u8; + let escape_sentinel = needle.len() as u8 + 1; + + let byte_table = kmp_byte_transitions(needle); + + // Build per-symbol transitions into a flat table first. + let mut sym_trans = vec![0u16; n_states * n_symbols]; + for state in 0..n_states { + for code in 0..n_symbols { + if state as u8 == accept_state { + sym_trans[state * n_symbols + code] = accept_state as u16; + continue; + } + let sym = symbols[code].to_u64().to_le_bytes(); + let sym_len = symbol_lengths[code] as usize; + let mut s = state as u16; + for &b in &sym[..sym_len] { + if s == accept_state as u16 { + break; + } + s = byte_table[s as usize * 256 + b as usize]; + } + sym_trans[state * n_symbols + code] = s; + } + } + + // Build fused 256-wide table, then pack into u64 shift tables. + let mut fused = vec![0u8; n_states * 256]; + for state in 0..n_states { + for code in 0..n_symbols { + fused[state * 256 + code] = sym_trans[state * n_symbols + code] as u8; + } + fused[state * 256 + ESCAPE_CODE as usize] = escape_sentinel; + } + + let mut transitions = [0u64; 256]; + for code_byte in 0..256usize { + let mut packed = 0u64; + for state in 0..n_states { + let next = fused[state * 256 + code_byte]; + packed |= (next as u64) << (state as u32 * Self::BITS); + } + transitions[code_byte] = packed; + } + + let mut escape_transitions = [0u64; 256]; + for byte_val in 0..256usize { + let mut packed = 0u64; + for state in 0..n_states { + let next = byte_table[state * 256 + byte_val] as u8; + packed |= (next as u64) << (state as u32 * Self::BITS); + } + escape_transitions[byte_val] = packed; + } + + Self { + transitions, + escape_transitions, + accept_state, + escape_sentinel, + } + } + + /// Match with iterator-based traversal. + /// + /// Using `iter.next()` instead of manual index + bounds check helps the + /// compiler eliminate redundant bounds checks. + #[inline] + fn matches(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + let mut iter = codes.iter(); + while let Some(&code) = iter.next() { + let packed = self.transitions[code as usize]; + let next = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + if next == self.escape_sentinel { + let Some(&b) = iter.next() else { + return false; + }; + let esc_packed = self.escape_transitions[b as usize]; + state = ((esc_packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } else { + state = next; + } + } + state == self.accept_state + } +} + +/// Fused 256-entry u8 table DFA. Fallback for needles > 14 characters. +struct FusedDfa { + transitions: Vec, + escape_transitions: Vec, + accept_state: u8, + escape_sentinel: u8, +} + +impl FusedDfa { + fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + let n_symbols = symbols.len(); + let accept_state = needle.len() as u8; + let n_states = needle.len() + 1; + let escape_sentinel = needle.len() as u8 + 1; + + let byte_table = kmp_byte_transitions(needle); + + let mut symbol_transitions = vec![0u16; n_states * n_symbols]; + for state in 0..n_states { + for code in 0..n_symbols { + if state as u8 == accept_state { + symbol_transitions[state * n_symbols + code] = accept_state as u16; + continue; + } + let sym = symbols[code].to_u64().to_le_bytes(); + let sym_len = symbol_lengths[code] as usize; + let mut s = state as u16; + for &b in &sym[..sym_len] { + if s == accept_state as u16 { + break; + } + s = byte_table[s as usize * 256 + b as usize]; + } + symbol_transitions[state * n_symbols + code] = s; + } + } + + let mut transitions = vec![0u8; n_states * 256]; + for state in 0..n_states { + for code in 0..n_symbols { + transitions[state * 256 + code] = + symbol_transitions[state * n_symbols + code] as u8; + } + transitions[state * 256 + ESCAPE_CODE as usize] = escape_sentinel; + } + + let escape_transitions: Vec = byte_table.iter().map(|&v| v as u8).collect(); + + Self { + transitions, + escape_transitions, + accept_state, + escape_sentinel, + } + } + + #[inline] + fn matches(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + let mut pos = 0; + while pos < codes.len() { + let code = codes[pos]; + pos += 1; + let next = self.transitions[state as usize * 256 + code as usize]; + if next == self.escape_sentinel { + if pos >= codes.len() { + return false; + } + let b = codes[pos]; + pos += 1; + state = self.escape_transitions[state as usize * 256 + b as usize]; + } else { + state = next; + } + if state == self.accept_state { + return true; + } + } + false + } +} + +// --------------------------------------------------------------------------- +// KMP helpers +// --------------------------------------------------------------------------- + +fn kmp_byte_transitions(needle: &[u8]) -> Vec { + let n_states = needle.len() + 1; + let accept = needle.len() as u16; + let failure = kmp_failure_table(needle); + + let mut table = vec![0u16; n_states * 256]; + for state in 0..n_states { + for byte in 0..256u16 { + if state == needle.len() { + table[state * 256 + byte as usize] = accept; + continue; + } + let mut s = state; + loop { + if byte as u8 == needle[s] { + s += 1; + break; + } + if s == 0 { + break; + } + s = failure[s - 1]; + } + table[state * 256 + byte as usize] = s as u16; + } + } + table +} + +fn kmp_failure_table(needle: &[u8]) -> Vec { + let mut failure = vec![0usize; needle.len()]; + let mut k = 0; + for i in 1..needle.len() { + while k > 0 && needle[k] != needle[i] { + k = failure[k - 1]; + } + if needle[k] == needle[i] { + k += 1; + } + failure[i] = k; + } + failure +} + +#[cfg(test)] +mod tests { + use std::sync::LazyLock; + + use vortex_array::Canonical; + use vortex_array::IntoArray; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::BoolArray; + use vortex_array::arrays::ConstantArray; + use vortex_array::arrays::VarBinArray; + use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; + use vortex_array::assert_arrays_eq; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::scalar_fn::fns::like::Like; + use vortex_array::scalar_fn::fns::like::LikeKernel; + use vortex_array::scalar_fn::fns::like::LikeOptions; + use vortex_array::session::ArraySession; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use crate::FSST; + use crate::FSSTArray; + use crate::fsst_compress; + use crate::fsst_train_compressor; + + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + fn make_fsst(strings: &[Option<&str>], nullability: Nullability) -> FSSTArray { + let varbin = VarBinArray::from_iter(strings.iter().copied(), DType::Utf8(nullability)); + let compressor = fsst_train_compressor(&varbin); + fsst_compress(varbin, &compressor) + } + + fn run_like(array: FSSTArray, pattern: &str, opts: LikeOptions) -> VortexResult { + let len = array.len(); + let arr = array.into_array(); + let pattern = ConstantArray::new(pattern, len).into_array(); + let result = Like + .try_new_array(len, opts, [arr, pattern])? + .into_array() + .execute::(&mut SESSION.create_execution_ctx())?; + Ok(result.into_bool()) + } + + #[test] + fn test_like_prefix() -> VortexResult<()> { + let fsst = make_fsst( + &[ + Some("http://example.com"), + Some("http://test.org"), + Some("ftp://files.net"), + Some("http://vortex.dev"), + Some("ssh://server.io"), + ], + Nullability::NonNullable, + ); + let result = run_like(fsst, "http%", LikeOptions::default())?; + assert_arrays_eq!( + &result, + &BoolArray::from_iter([true, true, false, true, false]) + ); + Ok(()) + } + + #[test] + fn test_like_prefix_with_nulls() -> VortexResult<()> { + let fsst = make_fsst( + &[Some("hello"), None, Some("help"), None, Some("goodbye")], + Nullability::Nullable, + ); + let result = run_like(fsst, "hel%", LikeOptions::default())?; + assert_arrays_eq!( + &result, + &BoolArray::from_iter([Some(true), None, Some(true), None, Some(false)]) + ); + Ok(()) + } + + #[test] + fn test_like_contains() -> VortexResult<()> { + let fsst = make_fsst( + &[ + Some("hello world"), + Some("say hello"), + Some("goodbye"), + Some("hellooo"), + ], + Nullability::NonNullable, + ); + let result = run_like(fsst, "%hello%", LikeOptions::default())?; + assert_arrays_eq!(&result, &BoolArray::from_iter([true, true, false, true])); + Ok(()) + } + + #[test] + fn test_like_contains_cross_symbol() -> VortexResult<()> { + let fsst = make_fsst( + &[ + Some("the quick brown fox jumps over the lazy dog"), + Some("a short string"), + Some("the lazy dog sleeps"), + Some("no match"), + ], + Nullability::NonNullable, + ); + let result = run_like(fsst, "%lazy dog%", LikeOptions::default())?; + assert_arrays_eq!(&result, &BoolArray::from_iter([true, false, true, false])); + Ok(()) + } + + #[test] + fn test_not_like_contains() -> VortexResult<()> { + let fsst = make_fsst( + &[Some("foobar_sdf"), Some("sdf_start"), Some("nothing")], + Nullability::NonNullable, + ); + let opts = LikeOptions { + negated: true, + case_insensitive: false, + }; + let result = run_like(fsst, "%sdf%", opts)?; + assert_arrays_eq!(&result, &BoolArray::from_iter([false, false, true])); + Ok(()) + } + + #[test] + fn test_like_match_all() -> VortexResult<()> { + let fsst = make_fsst( + &[Some("abc"), Some(""), Some("xyz")], + Nullability::NonNullable, + ); + let result = run_like(fsst, "%", LikeOptions::default())?; + assert_arrays_eq!(&result, &BoolArray::from_iter([true, true, true])); + Ok(()) + } + + /// Call `LikeKernel::like` directly on the FSSTArray and verify it + /// returns `Some(...)` (i.e. the kernel handles it, rather than + /// returning `None` which would mean "fall back to decompress"). + #[test] + fn test_like_prefix_kernel_handles() -> VortexResult<()> { + let fsst = make_fsst( + &[Some("http://a.com"), Some("ftp://b.com")], + Nullability::NonNullable, + ); + let pattern = ConstantArray::new("http%", fsst.len()).into_array(); + let mut ctx = SESSION.create_execution_ctx(); + + let result = ::like(&fsst, &pattern, LikeOptions::default(), &mut ctx)?; + assert!(result.is_some(), "FSST LikeKernel should handle prefix%"); + assert_arrays_eq!(result.unwrap(), BoolArray::from_iter([true, false])); + Ok(()) + } + + /// Same direct-call check for the contains pattern `%needle%`. + #[test] + fn test_like_contains_kernel_handles() -> VortexResult<()> { + let fsst = make_fsst( + &[Some("hello world"), Some("goodbye")], + Nullability::NonNullable, + ); + let pattern = ConstantArray::new("%world%", fsst.len()).into_array(); + let mut ctx = SESSION.create_execution_ctx(); + + let result = ::like(&fsst, &pattern, LikeOptions::default(), &mut ctx)?; + assert!(result.is_some(), "FSST LikeKernel should handle %needle%"); + assert_arrays_eq!(result.unwrap(), BoolArray::from_iter([true, false])); + Ok(()) + } + + /// Patterns we can't handle should return `None` (fall back). + #[test] + fn test_like_kernel_falls_back_for_complex_pattern() -> VortexResult<()> { + let fsst = make_fsst(&[Some("abc"), Some("def")], Nullability::NonNullable); + let mut ctx = SESSION.create_execution_ctx(); + + // Suffix pattern -- not handled. + let pattern = ConstantArray::new("%abc", fsst.len()).into_array(); + let result = ::like(&fsst, &pattern, LikeOptions::default(), &mut ctx)?; + assert!(result.is_none(), "suffix pattern should fall back"); + + // Underscore wildcard -- not handled. + let pattern = ConstantArray::new("a_c", fsst.len()).into_array(); + let result = ::like(&fsst, &pattern, LikeOptions::default(), &mut ctx)?; + assert!(result.is_none(), "underscore pattern should fall back"); + + // Case-insensitive -- not handled. + let pattern = ConstantArray::new("abc%", fsst.len()).into_array(); + let opts = LikeOptions { + negated: false, + case_insensitive: true, + }; + let result = ::like(&fsst, &pattern, opts, &mut ctx)?; + assert!(result.is_none(), "ilike should fall back"); + + Ok(()) + } + + // ----------------------------------------------------------------------- + // Fuzz tests: compare FSST kernel against naive string matching + // ----------------------------------------------------------------------- + + use rand::Rng; + use rand::SeedableRng; + use rand::rngs::StdRng; + + fn random_string(rng: &mut StdRng, max_len: usize) -> String { + let len = rng.random_range(0..=max_len); + // Use a small alphabet to increase substring hit rate. + (0..len) + .map(|_| (b'a' + rng.random_range(0..6u8)) as char) + .collect() + } + + fn fuzz_contains(seed: u64, needle_len: usize, n_strings: usize) -> VortexResult<()> { + let mut rng = StdRng::seed_from_u64(seed); + + let needle: String = (0..needle_len) + .map(|_| (b'a' + rng.random_range(0..6u8)) as char) + .collect(); + + let owned: Vec = (0..n_strings) + .map(|_| random_string(&mut rng, 80)) + .collect(); + let strings: Vec> = owned.iter().map(|s| Some(s.as_str())).collect(); + + let expected: Vec = owned.iter().map(|s| s.contains(&needle)).collect(); + + let fsst = make_fsst(&strings, Nullability::NonNullable); + let pattern = format!("%{needle}%"); + let result = run_like(fsst, &pattern, LikeOptions::default())?; + + let got: Vec = (0..n_strings) + .map(|i| result.to_bit_buffer().value(i)) + .collect(); + + for (i, (e, g)) in expected.iter().zip(got.iter()).enumerate() { + assert_eq!( + e, g, + "mismatch at index {i}: string={:?}, needle={needle:?}, expected={e}, got={g}", + &owned[i], + ); + } + Ok(()) + } + + fn fuzz_prefix(seed: u64, prefix_len: usize, n_strings: usize) -> VortexResult<()> { + let mut rng = StdRng::seed_from_u64(seed); + + let prefix: String = (0..prefix_len) + .map(|_| (b'a' + rng.random_range(0..6u8)) as char) + .collect(); + + let owned: Vec = (0..n_strings) + .map(|_| random_string(&mut rng, 80)) + .collect(); + let strings: Vec> = owned.iter().map(|s| Some(s.as_str())).collect(); + + let expected: Vec = owned.iter().map(|s| s.starts_with(&prefix)).collect(); + + let fsst = make_fsst(&strings, Nullability::NonNullable); + let pattern = format!("{prefix}%"); + let result = run_like(fsst, &pattern, LikeOptions::default())?; + + let got: Vec = (0..n_strings) + .map(|i| result.to_bit_buffer().value(i)) + .collect(); + + for (i, (e, g)) in expected.iter().zip(got.iter()).enumerate() { + assert_eq!( + e, g, + "mismatch at index {i}: string={:?}, prefix={prefix:?}, expected={e}, got={g}", + &owned[i], + ); + } + Ok(()) + } + + /// Fuzz contains with short needles (1-7 chars) -> BranchlessShiftDfa + #[test] + fn fuzz_contains_short_needle() -> VortexResult<()> { + for seed in 0..50 { + for needle_len in 1..=7 { + fuzz_contains(seed, needle_len, 200)?; + } + } + Ok(()) + } + + /// Fuzz contains with medium needles (8-14 chars) -> FlatBranchlessDfa + #[test] + fn fuzz_contains_medium_needle() -> VortexResult<()> { + for seed in 0..50 { + for needle_len in [8, 10, 14] { + fuzz_contains(seed, needle_len, 200)?; + } + } + Ok(()) + } + + /// Fuzz contains with long needles (>14 chars) -> FsstContainsDfa + #[test] + fn fuzz_contains_long_needle() -> VortexResult<()> { + for seed in 0..30 { + for needle_len in [15, 20, 30] { + fuzz_contains(seed, needle_len, 200)?; + } + } + Ok(()) + } + + /// Fuzz prefix matching + #[test] + fn fuzz_prefix_matching() -> VortexResult<()> { + for seed in 0..50 { + for prefix_len in [1, 3, 5, 10, 13] { + fuzz_prefix(seed, prefix_len, 200)?; + } + } + Ok(()) + } +} diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 839deb6c588..f49c2954a04 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -4,6 +4,7 @@ mod cast; mod compare; mod filter; +mod like; use vortex_array::ArrayRef; use vortex_array::DynArray; diff --git a/encodings/fsst/src/kernel.rs b/encodings/fsst/src/kernel.rs index 8d2a08fba2b..3ec36dd1b32 100644 --- a/encodings/fsst/src/kernel.rs +++ b/encodings/fsst/src/kernel.rs @@ -5,6 +5,7 @@ use vortex_array::arrays::dict::TakeExecuteAdaptor; use vortex_array::arrays::filter::FilterExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use vortex_array::scalar_fn::fns::binary::CompareExecuteAdaptor; +use vortex_array::scalar_fn::fns::like::LikeExecuteAdaptor; use crate::FSST; @@ -12,6 +13,7 @@ pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&CompareExecuteAdaptor(FSST)), ParentKernelSet::lift(&FilterExecuteAdaptor(FSST)), ParentKernelSet::lift(&TakeExecuteAdaptor(FSST)), + ParentKernelSet::lift(&LikeExecuteAdaptor(FSST)), ]); #[cfg(test)] From 9f2ff66abab0dd82888f5f24460f32ade4d4f0da Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 13 Mar 2026 11:30:24 +0000 Subject: [PATCH 02/19] perf[fsst]: like pushdown using a dfa Signed-off-by: Joe Isaacs --- encodings/fsst/src/compute/like.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/encodings/fsst/src/compute/like.rs b/encodings/fsst/src/compute/like.rs index 3946c640b30..5d29f469dc8 100644 --- a/encodings/fsst/src/compute/like.rs +++ b/encodings/fsst/src/compute/like.rs @@ -1019,6 +1019,10 @@ mod tests { Ok(result.into_bool()) } + fn like(array: FSSTArray, pattern: &str) -> VortexResult { + run_like(array, pattern, LikeOptions::default()) + } + #[test] fn test_like_prefix() -> VortexResult<()> { let fsst = make_fsst( @@ -1031,7 +1035,7 @@ mod tests { ], Nullability::NonNullable, ); - let result = run_like(fsst, "http%", LikeOptions::default())?; + let result = like(fsst, "http%")?; assert_arrays_eq!( &result, &BoolArray::from_iter([true, true, false, true, false]) @@ -1045,7 +1049,7 @@ mod tests { &[Some("hello"), None, Some("help"), None, Some("goodbye")], Nullability::Nullable, ); - let result = run_like(fsst, "hel%", LikeOptions::default())?; + let result = like(fsst, "hel%")?; // spellchecker:disable-line assert_arrays_eq!( &result, &BoolArray::from_iter([Some(true), None, Some(true), None, Some(false)]) @@ -1064,7 +1068,7 @@ mod tests { ], Nullability::NonNullable, ); - let result = run_like(fsst, "%hello%", LikeOptions::default())?; + let result = like(fsst, "%hello%")?; assert_arrays_eq!(&result, &BoolArray::from_iter([true, true, false, true])); Ok(()) } @@ -1080,7 +1084,7 @@ mod tests { ], Nullability::NonNullable, ); - let result = run_like(fsst, "%lazy dog%", LikeOptions::default())?; + let result = like(fsst, "%lazy dog%")?; assert_arrays_eq!(&result, &BoolArray::from_iter([true, false, true, false])); Ok(()) } @@ -1106,7 +1110,7 @@ mod tests { &[Some("abc"), Some(""), Some("xyz")], Nullability::NonNullable, ); - let result = run_like(fsst, "%", LikeOptions::default())?; + let result = like(fsst, "%")?; assert_arrays_eq!(&result, &BoolArray::from_iter([true, true, true])); Ok(()) } From ebf5457fdefb0fc8ee9bc2e47fb093d1d12456f5 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 13 Mar 2026 11:58:39 +0000 Subject: [PATCH 03/19] perf[fsst]: like pushdown using a dfa Signed-off-by: Joe Isaacs --- encodings/fsst/public-api.lock | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/encodings/fsst/public-api.lock b/encodings/fsst/public-api.lock index c25ba6b44f2..c7f958d609c 100644 --- a/encodings/fsst/public-api.lock +++ b/encodings/fsst/public-api.lock @@ -30,6 +30,10 @@ impl vortex_array::scalar_fn::fns::cast::kernel::CastReduce for vortex_fsst::FSS pub fn vortex_fsst::FSST::cast(array: &vortex_fsst::FSSTArray, dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::scalar_fn::fns::like::kernel::LikeKernel for vortex_fsst::FSST + +pub fn vortex_fsst::FSST::like(array: &vortex_fsst::FSSTArray, pattern: &vortex_array::array::ArrayRef, options: vortex_array::scalar_fn::fns::like::LikeOptions, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::vtable::VTable for vortex_fsst::FSST pub type vortex_fsst::FSST::Array = vortex_fsst::FSSTArray From 53b7b3d4409e2d20d94a1d14d0b6d2d3e9258029 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Tue, 17 Mar 2026 10:31:32 +0000 Subject: [PATCH 04/19] clean up Signed-off-by: Joe Isaacs --- encodings/fsst/Cargo.toml | 2 +- encodings/fsst/README.md | 20 + .../{fsst_contains.rs => fsst_like.rs} | 39 +- encodings/fsst/src/compute/like.rs | 1019 ++------------ encodings/fsst/src/dfa.rs | 1192 +++++++++++++++++ encodings/fsst/src/lib.rs | 1 + 6 files changed, 1340 insertions(+), 933 deletions(-) rename encodings/fsst/benches/{fsst_contains.rs => fsst_like.rs} (79%) create mode 100644 encodings/fsst/src/dfa.rs diff --git a/encodings/fsst/Cargo.toml b/encodings/fsst/Cargo.toml index eb08bbda959..b95eeb1f444 100644 --- a/encodings/fsst/Cargo.toml +++ b/encodings/fsst/Cargo.toml @@ -41,7 +41,7 @@ harness = false required-features = ["_test-harness"] [[bench]] -name = "fsst_contains" +name = "fsst_like" harness = false required-features = ["_test-harness"] diff --git a/encodings/fsst/README.md b/encodings/fsst/README.md index 0e08c6e7fc8..83668515f26 100644 --- a/encodings/fsst/README.md +++ b/encodings/fsst/README.md @@ -2,3 +2,23 @@ A Vortex Encoding for Binary and Utf8 data that utilizes the [Fast Static Symbol Table](https://github.com/spiraldb/fsst) compression algorithm. + +## LIKE Pushdown + +The FSST encoding has a specialized LIKE fast path for a narrow subset of +patterns: + +- `prefix%` +- `%needle%` + +Unsupported shapes, including `_`, `%suffix`, or patterns with interior +wildcards, fall back to ordinary decompression-based LIKE evaluation. + +There are also two implementation limits on the pushdown path, both measured in +pattern bytes: + +- `prefix%` supports up to 13 bytes. +- `%needle%` supports up to 254 bytes. + +Patterns beyond those limits are still evaluated correctly, but they do so via +the fallback path instead of the DFA matcher. diff --git a/encodings/fsst/benches/fsst_contains.rs b/encodings/fsst/benches/fsst_like.rs similarity index 79% rename from encodings/fsst/benches/fsst_contains.rs rename to encodings/fsst/benches/fsst_like.rs index 6885ad0543e..12e78e2d7fb 100644 --- a/encodings/fsst/benches/fsst_contains.rs +++ b/encodings/fsst/benches/fsst_like.rs @@ -80,7 +80,19 @@ impl Dataset { } } - fn pattern(&self) -> &'static str { + fn prefix_pattern(&self) -> &'static str { + match self { + Self::Urls => "https%", + Self::Cb => "https://www.%", + Self::Log => "192.168%", + Self::Json => r#"{"id%"#, + Self::Path => "/home%", + Self::Email => "john%", + Self::Rare => "xyz%", + } + } + + fn contains_pattern(&self) -> &'static str { match self { Self::Urls => "%google%", Self::Cb => "%yandex%", @@ -93,15 +105,10 @@ impl Dataset { } } -#[divan::bench(args = [ - Dataset::Urls, Dataset::Cb, Dataset::Log, Dataset::Json, - Dataset::Path, Dataset::Email, Dataset::Rare, -])] -fn fsst_like(bencher: Bencher, dataset: &Dataset) { - let fsst = dataset.fsst_array(); +fn bench_like(bencher: Bencher, fsst: &FSSTArray, pattern: &str) { let len = fsst.len(); let arr = fsst.clone().into_array(); - let pattern = ConstantArray::new(dataset.pattern(), len).into_array(); + let pattern = ConstantArray::new(pattern, len).into_array(); bencher.bench_local(|| { Like.try_new_array(len, LikeOptions::default(), [arr.clone(), pattern.clone()]) .unwrap() @@ -110,3 +117,19 @@ fn fsst_like(bencher: Bencher, dataset: &Dataset) { .unwrap() }); } + +#[divan::bench(args = [ + Dataset::Urls, Dataset::Cb, Dataset::Log, Dataset::Json, + Dataset::Path, Dataset::Email, Dataset::Rare, +])] +fn fsst_prefix(bencher: Bencher, dataset: &Dataset) { + bench_like(bencher, dataset.fsst_array(), dataset.prefix_pattern()); +} + +#[divan::bench(args = [ + Dataset::Urls, Dataset::Cb, Dataset::Log, Dataset::Json, + Dataset::Path, Dataset::Email, Dataset::Rare, +])] +fn fsst_contains(bencher: Bencher, dataset: &Dataset) { + bench_like(bencher, dataset.fsst_array(), dataset.contains_pattern()); +} diff --git a/encodings/fsst/src/compute/like.rs b/encodings/fsst/src/compute/like.rs index 5d29f469dc8..c53f621f25f 100644 --- a/encodings/fsst/src/compute/like.rs +++ b/encodings/fsst/src/compute/like.rs @@ -3,8 +3,6 @@ #![allow(clippy::cast_possible_truncation)] -use fsst::ESCAPE_CODE; -use fsst::Symbol; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; @@ -13,12 +11,12 @@ use vortex_array::arrays::BoolArray; use vortex_array::match_each_integer_ptype; use vortex_array::scalar_fn::fns::like::LikeKernel; use vortex_array::scalar_fn::fns::like::LikeOptions; -use vortex_buffer::BitBuffer; -use vortex_buffer::BufferMut; use vortex_error::VortexResult; use crate::FSST; use crate::FSSTArray; +use crate::dfa::FsstMatcher; +use crate::dfa::dfa_scan_to_bitbuf; impl LikeKernel for FSST { #[allow(clippy::cast_possible_truncation)] @@ -40,67 +38,26 @@ impl LikeKernel for FSST { return Ok(None); }; - let Some(like_kind) = LikeKind::parse(pattern_str) else { + let symbols = array.symbols(); + let symbol_lengths = array.symbol_lengths(); + + let Some(matcher) = + FsstMatcher::try_new(symbols.as_slice(), symbol_lengths.as_slice(), pattern_str)? + else { return Ok(None); }; - let symbols = array.symbols(); - let symbol_lengths = array.symbol_lengths(); let negated = options.negated; - - // Access the underlying codes VarBinArray buffers directly to avoid - // dyn Iterator overhead from with_iterator. let codes = array.codes(); let offsets = codes.offsets().to_primitive(); let all_bytes = codes.bytes(); let all_bytes = all_bytes.as_slice(); let n = codes.len(); - let result = match like_kind { - LikeKind::Prefix(prefix) => { - let prefix = prefix.as_bytes(); - // FsstPrefixDfa uses 4-bit shift packing: prefix_len + 2 states must fit in 16. - if prefix.len() + 2 > (1 << FsstPrefixDfa::BITS) { - return Ok(None); - } - let dfa = FsstPrefixDfa::new(symbols.as_slice(), symbol_lengths.as_slice(), prefix); - match_each_integer_ptype!(offsets.ptype(), |T| { - let off = offsets.as_slice::(); - dfa_scan_to_bitbuf(n, off, all_bytes, negated, |codes| dfa.matches(codes)) - }) - } - LikeKind::Contains(needle) => { - let needle = needle.as_bytes(); - if needle.len() <= BranchlessShiftDfa::MAX_NEEDLE_LEN { - let dfa = BranchlessShiftDfa::new( - symbols.as_slice(), - symbol_lengths.as_slice(), - needle, - ); - match_each_integer_ptype!(offsets.ptype(), |T| { - let off = offsets.as_slice::(); - dfa_scan_to_bitbuf(n, off, all_bytes, negated, |codes| dfa.matches(codes)) - }) - } else if needle.len() <= FlatBranchlessDfa::MAX_NEEDLE_LEN { - let dfa = FlatBranchlessDfa::new( - symbols.as_slice(), - symbol_lengths.as_slice(), - needle, - ); - match_each_integer_ptype!(offsets.ptype(), |T| { - let off = offsets.as_slice::(); - dfa_scan_to_bitbuf(n, off, all_bytes, negated, |codes| dfa.matches(codes)) - }) - } else { - let dfa = - FsstContainsDfa::new(symbols.as_slice(), symbol_lengths.as_slice(), needle); - match_each_integer_ptype!(offsets.ptype(), |T| { - let off = offsets.as_slice::(); - dfa_scan_to_bitbuf(n, off, all_bytes, negated, |codes| dfa.matches(codes)) - }) - } - } - }; + let result = match_each_integer_ptype!(offsets.ptype(), |T| { + let off = offsets.as_slice::(); + dfa_scan_to_bitbuf(n, off, all_bytes, negated, |codes| matcher.matches(codes)) + }); // FSST delegates validity to its codes array, so we can read it // directly without cloning the entire FSSTArray into an ArrayRef. @@ -113,870 +70,13 @@ impl LikeKernel for FSST { } } -/// Scan all strings through a DFA matcher, packing results directly into a -/// `BitBuffer` one u64 word (64 strings) at a time. This avoids the overhead -/// of `BitBufferMut::collect_bool`'s cross-crate closure indirection and -/// guarantees the compiler can see the full loop body for optimization. -// TODO: add N-way ILP overrun scan for higher throughput on short strings. -#[inline] -fn dfa_scan_to_bitbuf( - n: usize, - offsets: &[T], - all_bytes: &[u8], - negated: bool, - matcher: F, -) -> BitBuffer -where - T: vortex_array::dtype::IntegerPType, - F: Fn(&[u8]) -> bool, -{ - let n_words = n / 64; - let remainder = n % 64; - let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); - - for chunk in 0..n_words { - let base = chunk * 64; - let mut word = 0u64; - let mut start: usize = offsets[base].as_(); - for bit in 0..64 { - let end: usize = offsets[base + bit + 1].as_(); - word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; - start = end; - } - // SAFETY: we allocated capacity for n.div_ceil(64) words. - unsafe { words.push_unchecked(word) }; - } - - if remainder != 0 { - let base = n_words * 64; - let mut word = 0u64; - let mut start: usize = offsets[base].as_(); - for bit in 0..remainder { - let end: usize = offsets[base + bit + 1].as_(); - word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; - start = end; - } - unsafe { words.push_unchecked(word) }; - } - - BitBuffer::new(words.into_byte_buffer().freeze(), n) -} - -/// The subset of LIKE patterns we can handle without decompression. -enum LikeKind<'a> { - /// `prefix%` - Prefix(&'a str), - /// `%needle%` - Contains(&'a str), -} - -impl<'a> LikeKind<'a> { - fn parse(pattern: &'a str) -> Option { - if pattern == "%" { - return Some(LikeKind::Prefix("")); - } - - // Find first wildcard. - let first_wild = pattern.find(['%', '_'])?; - - // `_` as first wildcard means we can't handle it. - if pattern.as_bytes()[first_wild] == b'_' { - return None; - } - - // `prefix%` — single trailing % - if first_wild > 0 && &pattern[first_wild..] == "%" { - return Some(LikeKind::Prefix(&pattern[..first_wild])); - } - - // `%needle%` — leading and trailing %, no inner wildcards - if first_wild == 0 - && pattern.len() > 2 - && pattern.as_bytes()[pattern.len() - 1] == b'%' - && !pattern[1..pattern.len() - 1].contains(['%', '_']) - { - return Some(LikeKind::Contains(&pattern[1..pattern.len() - 1])); - } - - None - } -} - -// --------------------------------------------------------------------------- -// DFA for prefix matching (LIKE 'prefix%') -// --------------------------------------------------------------------------- - -/// Precomputed shift-based DFA for prefix matching on FSST codes. -/// -/// States 0..prefix_len track match progress, plus ACCEPT and FAIL. -/// Uses the same shift-based approach as the contains DFA: all state -/// transitions packed into a `u64` per code byte. For prefixes longer -/// than 13 characters, falls back to a fused u8 table. -struct FsstPrefixDfa { - /// Packed transitions: `(table[code] >> (state * 4)) & 0xF` gives next state. - transitions: [u64; 256], - /// Packed escape transitions for literal bytes. - escape_transitions: [u64; 256], - accept_state: u8, - fail_state: u8, -} - -impl FsstPrefixDfa { - const BITS: u32 = 4; - const MASK: u64 = (1 << Self::BITS) - 1; - - fn new(symbols: &[Symbol], symbol_lengths: &[u8], prefix: &[u8]) -> Self { - // prefix.len() + 2 states (0..prefix_len, accept, fail) must fit in 4 bits. - debug_assert!(prefix.len() + 2 <= (1 << Self::BITS)); - - let n_symbols = symbols.len(); - let accept_state = prefix.len() as u8; - let fail_state = prefix.len() as u8 + 1; - let n_states = prefix.len() + 2; - - // Build per-symbol and per-escape-byte transitions into flat tables. - let mut sym_trans = vec![fail_state; n_states * n_symbols]; - let mut esc_trans = vec![fail_state; n_states * 256]; - - for state in 0..n_states { - if state as u8 == accept_state { - for code in 0..n_symbols { - sym_trans[state * n_symbols + code] = accept_state; - } - for b in 0..256 { - esc_trans[state * 256 + b] = accept_state; - } - continue; - } - if state as u8 == fail_state { - continue; - } - - for code in 0..n_symbols { - let sym = symbols[code].to_u64().to_le_bytes(); - let sym_len = symbol_lengths[code] as usize; - let remaining = prefix.len() - state; - let cmp = sym_len.min(remaining); - - if sym[..cmp] == prefix[state..state + cmp] { - let next = state + cmp; - sym_trans[state * n_symbols + code] = if next >= prefix.len() { - accept_state - } else { - next as u8 - }; - } - } - - for b in 0..256usize { - if b as u8 == prefix[state] { - let next = state + 1; - esc_trans[state * 256 + b] = if next >= prefix.len() { - accept_state - } else { - next as u8 - }; - } - } - } - - // Fuse symbol transitions into a 256-wide table. - let escape_sentinel = fail_state + 1; - let mut fused = vec![fail_state; n_states * 256]; - for state in 0..n_states { - for code in 0..n_symbols { - fused[state * 256 + code] = sym_trans[state * n_symbols + code]; - } - fused[state * 256 + ESCAPE_CODE as usize] = escape_sentinel; - } - - // Pack into u64 shift tables. - let mut transitions = [0u64; 256]; - for code_byte in 0..256usize { - let mut packed = 0u64; - for state in 0..n_states { - packed |= (fused[state * 256 + code_byte] as u64) << (state as u32 * Self::BITS); - } - transitions[code_byte] = packed; - } - - let mut escape_transitions = [0u64; 256]; - for byte_val in 0..256usize { - let mut packed = 0u64; - for state in 0..n_states { - packed |= (esc_trans[state * 256 + byte_val] as u64) << (state as u32 * Self::BITS); - } - escape_transitions[byte_val] = packed; - } - - Self { - transitions, - escape_transitions, - accept_state, - fail_state, - } - } - - #[inline] - fn matches(&self, codes: &[u8]) -> bool { - let mut state = 0u8; - let mut pos = 0; - while pos < codes.len() { - let code = codes[pos]; - pos += 1; - let packed = self.transitions[code as usize]; - let next = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - if next == self.fail_state + 1 { - // Escape sentinel: read literal byte. - if pos >= codes.len() { - return false; - } - let b = codes[pos]; - pos += 1; - let esc_packed = self.escape_transitions[b as usize]; - state = ((esc_packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - } else { - state = next; - } - if state == self.accept_state { - return true; - } - if state == self.fail_state { - return false; - } - } - state == self.accept_state - } -} - -// --------------------------------------------------------------------------- -// DFA for contains matching (LIKE '%needle%') -// --------------------------------------------------------------------------- - -/// Contains DFA for long needles (>14 chars). Short needles (len <= 7) are -/// handled by `BranchlessShiftDfa`, medium needles (8-14) by -/// `FlatBranchlessDfa`. -enum FsstContainsDfa { - /// Shift-based DFA for medium needles (len 8-14). - Shift(Box), - /// Fused u8 table DFA for long needles (len > 14). - Fused(FusedDfa), -} - -impl FsstContainsDfa { - fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - if needle.len() <= ShiftDfa::MAX_NEEDLE_LEN { - FsstContainsDfa::Shift(Box::new(ShiftDfa::new(symbols, symbol_lengths, needle))) - } else { - FsstContainsDfa::Fused(FusedDfa::new(symbols, symbol_lengths, needle)) - } - } - - #[inline] - fn matches(&self, codes: &[u8]) -> bool { - match self { - FsstContainsDfa::Shift(dfa) => dfa.matches(codes), - FsstContainsDfa::Fused(dfa) => dfa.matches(codes), - } - } -} - -/// Branchless escape-folded DFA for short needles (len <= 7). -/// -/// Folds escape handling into the state space so that `matches()` is -/// completely branchless (except for loop control). The state layout is: -/// - States 0..N-1: normal match-progress states -/// - State N: accept (sticky for all inputs) -/// - States N+1..2N: escape states (state `s+N+1` means "was in state `s`, -/// just consumed ESCAPE_CODE") -/// -/// Total states: 2N+1. With 4-bit packing, max N=7. -/// -/// Uses a decomposed hierarchical lookup that processes 4 code bytes per -/// loop iteration with only ~3 KB of tables: -/// -/// 1. **Equivalence class table** (256 B): maps each code byte to a class -/// id. Bytes with identical transition u64s share a class -- typically -/// only ~6-10 classes exist (needle chars + escape + "miss-all"). -/// 2. **Pair-compose table** (~N^2 B): maps `(class0, class1)` to a 2-byte -/// palette index. Typically ~36 entries. -/// 3. **4-byte compose table** (~M^2 x 8 B): maps `(palette0, palette1)` to -/// the composed packed u64 for all 4 bytes. Typically ~81 entries = 648 B. -/// -/// Each loop iteration: 4 class lookups (parallel, 256 B table) -> 2 -/// pair-compose lookups (parallel, ~36 B table) -> 1 compose lookup -/// (~648 B table) -> 1 shift+mask. All tables fit in L1 cache. -struct BranchlessShiftDfa { - /// Maps each code byte to its equivalence class. Bytes with the same - /// packed transition u64 share a class. (256 bytes) - eq_class: [u8; 256], - /// Maps `(class0 * n_classes + class1)` -> 2-byte palette index. - pair_compose: Vec, - /// Number of equivalence classes (stride for pair_compose). - n_classes: usize, - /// Maps `(palette0 * n_palette + palette1)` -> composed packed u64 - /// for 4 bytes. - compose_4b: Vec, - /// Number of unique 2-byte palette entries (stride for compose_4b). - n_palette: usize, - /// 1-byte fallback transitions for trailing bytes. - transitions_1b: [u64; 256], - /// 2-byte palette for the remainder path (2-3 trailing bytes). - palette_2b: Vec, - accept_state: u8, -} - -impl BranchlessShiftDfa { - const BITS: u32 = 4; - const MASK: u64 = (1 << Self::BITS) - 1; - /// Maximum needle length: need 2N+1 states to fit in 16 slots (4 bits). - /// 2*7+1 = 15 <= 16, so max N = 7. - const MAX_NEEDLE_LEN: usize = 7; - - fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - let n = needle.len(); - debug_assert!(n <= Self::MAX_NEEDLE_LEN); - - let accept_state = n as u8; - let total_states = 2 * n + 1; - debug_assert!(total_states <= (1 << Self::BITS)); - - let transitions_1b = - Self::build_1b_transitions(symbols, symbol_lengths, needle, total_states); - - // Build equivalence classes: group bytes with identical transition u64. - let mut eq_class = [0u8; 256]; - let mut class_representatives: Vec = Vec::new(); - for byte_val in 0..256usize { - let t = transitions_1b[byte_val]; - let cls = class_representatives - .iter() - .position(|&v| v == t) - .unwrap_or_else(|| { - class_representatives.push(t); - class_representatives.len() - 1 - }); - eq_class[byte_val] = cls as u8; - } - let n_classes = class_representatives.len(); - - // Build pair-compose: for each (class0, class1), compose the two - // 1-byte transitions and deduplicate into a 2-byte palette. - let (pair_compose, palette_2b) = - Self::build_pair_compose(&class_representatives, n_classes, total_states); - - // Build 4-byte composition: compose_4b[p0 * n + p1] gives the packed - // u64 for applying palette_2b[p0] then palette_2b[p1] in sequence. - let n_palette = palette_2b.len(); - let compose_4b = Self::build_compose_4b(&palette_2b, total_states); - - Self { - eq_class, - pair_compose, - n_classes, - compose_4b, - n_palette, - transitions_1b, - palette_2b, - accept_state, - } - } - - /// Build the 1-byte packed transition table from FSST symbols and - /// a byte-level KMP table, folding escape handling into the state space. - fn build_1b_transitions( - symbols: &[Symbol], - symbol_lengths: &[u8], - needle: &[u8], - total_states: usize, - ) -> [u64; 256] { - let n = needle.len(); - let n_symbols = symbols.len(); - let accept_state = n as u8; - let n_normal_states = n + 1; - - let byte_table = kmp_byte_transitions(needle); - - // Build per-symbol transitions for normal states. - let mut sym_trans = vec![0u8; n_normal_states * n_symbols]; - for state in 0..n_normal_states { - for code in 0..n_symbols { - if state as u8 == accept_state { - sym_trans[state * n_symbols + code] = accept_state; - continue; - } - let sym = symbols[code].to_u64().to_le_bytes(); - let sym_len = symbol_lengths[code] as usize; - let mut s = state as u16; - for &b in &sym[..sym_len] { - if s == accept_state as u16 { - break; - } - s = byte_table[s as usize * 256 + b as usize]; - } - sym_trans[state * n_symbols + code] = s as u8; - } - } - - // Build fused transition table with escape folding. - let mut fused = vec![0u8; total_states * 256]; - for code_byte in 0..256usize { - for s in 0..n { - if code_byte == ESCAPE_CODE as usize { - fused[s * 256 + code_byte] = (s + n + 1) as u8; - } else if code_byte < n_symbols { - fused[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; - } - } - fused[n * 256 + code_byte] = accept_state; - for s in 0..n { - let esc_state = s + n + 1; - let next = byte_table[s * 256 + code_byte] as u8; - fused[esc_state * 256 + code_byte] = next; - } - } - - // Pack into u64 shift table. - let mut transitions = [0u64; 256]; - for code_byte in 0..256usize { - let mut packed = 0u64; - for state in 0..total_states { - packed |= (fused[state * 256 + code_byte] as u64) << (state as u32 * Self::BITS); - } - transitions[code_byte] = packed; - } - transitions - } - - /// Build the pair-compose table and 2-byte palette from equivalence - /// class representatives. - fn build_pair_compose( - class_reps: &[u64], - n_classes: usize, - total_states: usize, - ) -> (Vec, Vec) { - let mut pair_compose = vec![0u8; n_classes * n_classes]; - let mut palette_2b: Vec = Vec::new(); - - for c0 in 0..n_classes { - for c1 in 0..n_classes { - let t0 = class_reps[c0]; - let t1 = class_reps[c1]; - let mut packed = 0u64; - for state in 0..total_states { - let mid = ((t0 >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - let final_s = ((t1 >> (mid as u32 * Self::BITS)) & Self::MASK) as u8; - packed |= (final_s as u64) << (state as u32 * Self::BITS); - } - let idx = palette_2b - .iter() - .position(|&v| v == packed) - .unwrap_or_else(|| { - palette_2b.push(packed); - palette_2b.len() - 1 - }); - pair_compose[c0 * n_classes + c1] = idx as u8; - } - } - (pair_compose, palette_2b) - } - - /// Compose pairs of 2-byte palette entries into a 4-byte lookup table. - fn build_compose_4b(palette_2b: &[u64], total_states: usize) -> Vec { - let n = palette_2b.len(); - let mut compose = vec![0u64; n * n]; - for p0 in 0..n { - for p1 in 0..n { - let mut packed = 0u64; - for state in 0..total_states { - let mid = ((palette_2b[p0] >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - let final_s = - ((palette_2b[p1] >> (mid as u32 * Self::BITS)) & Self::MASK) as u8; - packed |= (final_s as u64) << (state as u32 * Self::BITS); - } - compose[p0 * n + p1] = packed; - } - } - compose - } - - /// Process remaining bytes after the interleaved common prefix. - #[inline] - fn finish_tail(&self, mut state: u8, codes: &[u8]) -> u8 { - let chunks = codes.chunks_exact(4); - let rem = chunks.remainder(); - - for chunk in chunks { - let ec0 = unsafe { *self.eq_class.get_unchecked(chunk[0] as usize) } as usize; - let ec1 = unsafe { *self.eq_class.get_unchecked(chunk[1] as usize) } as usize; - let ec2 = unsafe { *self.eq_class.get_unchecked(chunk[2] as usize) } as usize; - let ec3 = unsafe { *self.eq_class.get_unchecked(chunk[3] as usize) } as usize; - let p0 = - unsafe { *self.pair_compose.get_unchecked(ec0 * self.n_classes + ec1) } as usize; - let p1 = - unsafe { *self.pair_compose.get_unchecked(ec2 * self.n_classes + ec3) } as usize; - let packed = unsafe { *self.compose_4b.get_unchecked(p0 * self.n_palette + p1) }; - state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - } - - if rem.len() >= 2 { - let ec0 = self.eq_class[rem[0] as usize] as usize; - let ec1 = self.eq_class[rem[1] as usize] as usize; - let p = self.pair_compose[ec0 * self.n_classes + ec1] as usize; - let packed = self.palette_2b[p]; - state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - if rem.len() == 3 { - let packed = self.transitions_1b[rem[2] as usize]; - state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - } - } else if rem.len() == 1 { - let packed = self.transitions_1b[rem[0] as usize]; - state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - } - - state - } - - /// Branchless matching processing four code bytes per iteration. - #[inline(never)] - fn matches(&self, codes: &[u8]) -> bool { - self.finish_tail(0, codes) == self.accept_state - } -} - -/// Flat u8 escape-folded DFA for medium needles (8-14 chars). -/// -/// Like `BranchlessShiftDfa`, folds escape handling into the state space -/// (2N+1 states), but uses a flat `u8` transition table instead of -/// shift-packed `u64`. Supports up to 14-char needles (2*14+1 = 29 states). -/// Table size: 29 * 256 = 7,424 bytes, fits in L1. -struct FlatBranchlessDfa { - /// transitions[state * 256 + byte] -> next state - transitions: Vec, - accept_state: u8, -} - -impl FlatBranchlessDfa { - const MAX_NEEDLE_LEN: usize = 14; - - fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - let n = needle.len(); - debug_assert!(n <= Self::MAX_NEEDLE_LEN); - - let accept_state = n as u8; - let total_states = 2 * n + 1; - let n_symbols = symbols.len(); - - let byte_table = kmp_byte_transitions(needle); - - // Build per-symbol transitions for normal states. - let mut sym_trans = vec![0u8; (n + 1) * n_symbols]; - for state in 0..=n { - for code in 0..n_symbols { - if state as u8 == accept_state { - sym_trans[state * n_symbols + code] = accept_state; - continue; - } - let sym = symbols[code].to_u64().to_le_bytes(); - let sym_len = symbol_lengths[code] as usize; - let mut s = state as u16; - for &b in &sym[..sym_len] { - if s == accept_state as u16 { - break; - } - s = byte_table[s as usize * 256 + b as usize]; - } - sym_trans[state * n_symbols + code] = s as u8; - } - } - - // Build fused transition table with escape folding. - let mut transitions = vec![0u8; total_states * 256]; - for code_byte in 0..256usize { - // Normal states 0..n - for s in 0..n { - if code_byte == ESCAPE_CODE as usize { - transitions[s * 256 + code_byte] = (s + n + 1) as u8; - } else if code_byte < n_symbols { - transitions[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; - } - } - // Accept state (sticky) - transitions[n * 256 + code_byte] = accept_state; - // Escape states n+1..2n - for s in 0..n { - let esc_state = s + n + 1; - let next = byte_table[s * 256 + code_byte] as u8; - transitions[esc_state * 256 + code_byte] = next; - } - } - - Self { - transitions, - accept_state, - } - } - - #[inline(never)] - fn matches(&self, codes: &[u8]) -> bool { - let mut state = 0u8; - for &byte in codes { - state = self.transitions[state as usize * 256 + byte as usize]; - } - state == self.accept_state - } -} - -/// Shift-based DFA: packs all state transitions into a `u64` per input byte. -/// -/// For a DFA with S states (S <= 16, using 4 bits each), we store transitions -/// for ALL states in one `u64`. Transition: `next = (table[code] >> (state * 4)) & 0xF`. -/// -/// Supports needles up to 14 characters (needle.len() + 2 <= 16 to fit escape -/// sentinel). This covers virtually all practical LIKE patterns. -struct ShiftDfa { - /// For each code byte (0..255): a `u64` packing all state transitions. - /// Bits `[state*4 .. state*4+4)` encode the next state for that input. - transitions: [u64; 256], - /// Same layout for escape byte transitions. - escape_transitions: [u64; 256], - accept_state: u8, - escape_sentinel: u8, -} - -impl ShiftDfa { - const BITS: u32 = 4; - const MASK: u64 = (1 << Self::BITS) - 1; - /// Maximum needle length: 2^BITS - 2 (need room for accept + sentinel). - const MAX_NEEDLE_LEN: usize = (1 << Self::BITS) - 2; - - fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); - - let n_symbols = symbols.len(); - let n_states = needle.len() + 1; - let accept_state = needle.len() as u8; - let escape_sentinel = needle.len() as u8 + 1; - - let byte_table = kmp_byte_transitions(needle); - - // Build per-symbol transitions into a flat table first. - let mut sym_trans = vec![0u16; n_states * n_symbols]; - for state in 0..n_states { - for code in 0..n_symbols { - if state as u8 == accept_state { - sym_trans[state * n_symbols + code] = accept_state as u16; - continue; - } - let sym = symbols[code].to_u64().to_le_bytes(); - let sym_len = symbol_lengths[code] as usize; - let mut s = state as u16; - for &b in &sym[..sym_len] { - if s == accept_state as u16 { - break; - } - s = byte_table[s as usize * 256 + b as usize]; - } - sym_trans[state * n_symbols + code] = s; - } - } - - // Build fused 256-wide table, then pack into u64 shift tables. - let mut fused = vec![0u8; n_states * 256]; - for state in 0..n_states { - for code in 0..n_symbols { - fused[state * 256 + code] = sym_trans[state * n_symbols + code] as u8; - } - fused[state * 256 + ESCAPE_CODE as usize] = escape_sentinel; - } - - let mut transitions = [0u64; 256]; - for code_byte in 0..256usize { - let mut packed = 0u64; - for state in 0..n_states { - let next = fused[state * 256 + code_byte]; - packed |= (next as u64) << (state as u32 * Self::BITS); - } - transitions[code_byte] = packed; - } - - let mut escape_transitions = [0u64; 256]; - for byte_val in 0..256usize { - let mut packed = 0u64; - for state in 0..n_states { - let next = byte_table[state * 256 + byte_val] as u8; - packed |= (next as u64) << (state as u32 * Self::BITS); - } - escape_transitions[byte_val] = packed; - } - - Self { - transitions, - escape_transitions, - accept_state, - escape_sentinel, - } - } - - /// Match with iterator-based traversal. - /// - /// Using `iter.next()` instead of manual index + bounds check helps the - /// compiler eliminate redundant bounds checks. - #[inline] - fn matches(&self, codes: &[u8]) -> bool { - let mut state = 0u8; - let mut iter = codes.iter(); - while let Some(&code) = iter.next() { - let packed = self.transitions[code as usize]; - let next = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - if next == self.escape_sentinel { - let Some(&b) = iter.next() else { - return false; - }; - let esc_packed = self.escape_transitions[b as usize]; - state = ((esc_packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - } else { - state = next; - } - } - state == self.accept_state - } -} - -/// Fused 256-entry u8 table DFA. Fallback for needles > 14 characters. -struct FusedDfa { - transitions: Vec, - escape_transitions: Vec, - accept_state: u8, - escape_sentinel: u8, -} - -impl FusedDfa { - fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - let n_symbols = symbols.len(); - let accept_state = needle.len() as u8; - let n_states = needle.len() + 1; - let escape_sentinel = needle.len() as u8 + 1; - - let byte_table = kmp_byte_transitions(needle); - - let mut symbol_transitions = vec![0u16; n_states * n_symbols]; - for state in 0..n_states { - for code in 0..n_symbols { - if state as u8 == accept_state { - symbol_transitions[state * n_symbols + code] = accept_state as u16; - continue; - } - let sym = symbols[code].to_u64().to_le_bytes(); - let sym_len = symbol_lengths[code] as usize; - let mut s = state as u16; - for &b in &sym[..sym_len] { - if s == accept_state as u16 { - break; - } - s = byte_table[s as usize * 256 + b as usize]; - } - symbol_transitions[state * n_symbols + code] = s; - } - } - - let mut transitions = vec![0u8; n_states * 256]; - for state in 0..n_states { - for code in 0..n_symbols { - transitions[state * 256 + code] = - symbol_transitions[state * n_symbols + code] as u8; - } - transitions[state * 256 + ESCAPE_CODE as usize] = escape_sentinel; - } - - let escape_transitions: Vec = byte_table.iter().map(|&v| v as u8).collect(); - - Self { - transitions, - escape_transitions, - accept_state, - escape_sentinel, - } - } - - #[inline] - fn matches(&self, codes: &[u8]) -> bool { - let mut state = 0u8; - let mut pos = 0; - while pos < codes.len() { - let code = codes[pos]; - pos += 1; - let next = self.transitions[state as usize * 256 + code as usize]; - if next == self.escape_sentinel { - if pos >= codes.len() { - return false; - } - let b = codes[pos]; - pos += 1; - state = self.escape_transitions[state as usize * 256 + b as usize]; - } else { - state = next; - } - if state == self.accept_state { - return true; - } - } - false - } -} - -// --------------------------------------------------------------------------- -// KMP helpers -// --------------------------------------------------------------------------- - -fn kmp_byte_transitions(needle: &[u8]) -> Vec { - let n_states = needle.len() + 1; - let accept = needle.len() as u16; - let failure = kmp_failure_table(needle); - - let mut table = vec![0u16; n_states * 256]; - for state in 0..n_states { - for byte in 0..256u16 { - if state == needle.len() { - table[state * 256 + byte as usize] = accept; - continue; - } - let mut s = state; - loop { - if byte as u8 == needle[s] { - s += 1; - break; - } - if s == 0 { - break; - } - s = failure[s - 1]; - } - table[state * 256 + byte as usize] = s as u16; - } - } - table -} - -fn kmp_failure_table(needle: &[u8]) -> Vec { - let mut failure = vec![0usize; needle.len()]; - let mut k = 0; - for i in 1..needle.len() { - while k > 0 && needle[k] != needle[i] { - k = failure[k - 1]; - } - if needle[k] == needle[i] { - k += 1; - } - failure[i] = k; - } - failure -} - #[cfg(test)] mod tests { use std::sync::LazyLock; + use rand::Rng; + use rand::SeedableRng; + use rand::rngs::StdRng; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; @@ -1155,11 +255,6 @@ mod tests { let fsst = make_fsst(&[Some("abc"), Some("def")], Nullability::NonNullable); let mut ctx = SESSION.create_execution_ctx(); - // Suffix pattern -- not handled. - let pattern = ConstantArray::new("%abc", fsst.len()).into_array(); - let result = ::like(&fsst, &pattern, LikeOptions::default(), &mut ctx)?; - assert!(result.is_none(), "suffix pattern should fall back"); - // Underscore wildcard -- not handled. let pattern = ConstantArray::new("a_c", fsst.len()).into_array(); let result = ::like(&fsst, &pattern, LikeOptions::default(), &mut ctx)?; @@ -1177,14 +272,90 @@ mod tests { Ok(()) } + #[test] + fn test_like_long_prefix_falls_back_but_still_matches() -> VortexResult<()> { + let fsst = make_fsst( + &[ + Some("abcdefghijklmn-tail"), + Some("abcdefghijklmx-tail"), + Some("abcdefghijklmn"), + ], + Nullability::NonNullable, + ); + let pattern = "abcdefghijklmn%"; + + let direct = ::like( + &fsst, + &ConstantArray::new(pattern, fsst.len()).into_array(), + LikeOptions::default(), + &mut SESSION.create_execution_ctx(), + )?; + assert!( + direct.is_none(), + "14-byte prefixes exceed the packed prefix DFA and should fall back" + ); + + let result = like(fsst, pattern)?; + assert_arrays_eq!(&result, &BoolArray::from_iter([true, false, true])); + Ok(()) + } + + #[test] + fn test_like_long_contains_falls_back_but_still_matches() -> VortexResult<()> { + let needle = "a".repeat(255); + let matching = format!("xx{needle}yy"); + let non_matching = format!("xx{}byy", "a".repeat(254)); + let exact = needle.clone(); + let pattern = format!("%{needle}%"); + + let fsst = make_fsst( + &[Some(&matching), Some(&non_matching), Some(&exact)], + Nullability::NonNullable, + ); + + let direct = ::like( + &fsst, + &ConstantArray::new(pattern.as_str(), fsst.len()).into_array(), + LikeOptions::default(), + &mut SESSION.create_execution_ctx(), + )?; + assert!( + direct.is_none(), + "contains needles longer than 254 bytes exceed the DFA's u8 state space" + ); + + let result = like(fsst, &pattern)?; + assert_arrays_eq!(&result, &BoolArray::from_iter([true, false, true])); + Ok(()) + } + + #[test] + fn test_like_contains_len_254_kernel_handles() -> VortexResult<()> { + let needle = "a".repeat(254); + let matching = format!("xx{needle}yy"); + let non_matching = format!("xx{}byy", "a".repeat(253)); + let pattern = format!("%{needle}%"); + + let fsst = make_fsst( + &[Some(&matching), Some(&non_matching), Some(needle.as_str())], + Nullability::NonNullable, + ); + + let direct = ::like( + &fsst, + &ConstantArray::new(pattern.as_str(), fsst.len()).into_array(), + LikeOptions::default(), + &mut SESSION.create_execution_ctx(), + )?; + assert!(direct.is_some(), "254-byte contains needle should stay on the DFA path"); + assert_arrays_eq!(direct.unwrap(), BoolArray::from_iter([true, false, true])); + Ok(()) + } + // ----------------------------------------------------------------------- // Fuzz tests: compare FSST kernel against naive string matching // ----------------------------------------------------------------------- - use rand::Rng; - use rand::SeedableRng; - use rand::rngs::StdRng; - fn random_string(rng: &mut StdRng, max_len: usize) -> String { let len = rng.random_range(0..=max_len); // Use a small alphabet to increase substring hit rate. diff --git a/encodings/fsst/src/dfa.rs b/encodings/fsst/src/dfa.rs new file mode 100644 index 00000000000..2865f5bc4fb --- /dev/null +++ b/encodings/fsst/src/dfa.rs @@ -0,0 +1,1192 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! # FSST LIKE Pushdown via DFA Construction +//! +//! This module implements DFA-based pattern matching directly on FSST-compressed +//! strings, without decompressing them. It handles two pattern shapes: +//! +//! - **Prefix**: `'prefix%'` — matches strings starting with a literal prefix. +//! - **Contains**: `'%needle%'` — matches strings containing a literal substring. +//! +//! Pushdown is intentionally conservative. If the pattern shape is unsupported, +//! or if the pattern exceeds the DFA's representable state space, construction +//! returns `None` and the caller must fall back to ordinary decompression-based +//! LIKE evaluation. +//! +//! TODO(joe): suffix (`'%suffix'`) pushdown. Two approaches: +//! - **Forward DFA**: use a non-sticky accept state with KMP fallback transitions, +//! check `state == accept` after processing all codes. Branchless and vectorizable. +//! - **Backward scan**: walk the compressed code stream in reverse, comparing symbol +//! bytes from the end. Simpler, no DFA construction, but requires reverse parsing +//! of the FSST escape mechanism. +//! +//! ## Background: FSST Encoding +//! +//! [FSST](https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf) compresses strings by +//! replacing frequent byte sequences with single-byte **symbol codes** (0–254). Code +//! byte 255 is reserved as the **escape code**: the next byte is a literal (uncompressed) +//! byte. So a compressed string is a stream of: +//! +//! ```text +//! [symbol_code] ... [symbol_code] [ESCAPE literal_byte] [symbol_code] ... +//! ``` +//! +//! A single symbol can expand to 1–8 bytes. Matching on compressed codes requires +//! the DFA to handle multi-byte symbol expansions and the escape mechanism. +//! +//! ## The Algorithm: KMP → Byte Table → Symbol Table → Packed DFA +//! +//! Construction proceeds through four stages: +//! +//! ### Stage 1: KMP Failure Function +//! +//! We compute the standard [KMP](https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm) +//! failure function for the needle bytes. This tells us, on a mismatch at +//! position `i`, the longest proper prefix of `needle[0..i]` that is also a +//! suffix — i.e., where to resume matching instead of starting over. +//! +//! ```text +//! Needle: "abcabd" +//! Failure: [0, 0, 0, 1, 2, 0] +//! ^ ^ +//! At position 3 ('a'), the prefix "a" matches suffix "a" +//! At position 4 ('b'), the prefix "ab" matches suffix "ab" +//! ``` +//! +//! ### Stage 2: Byte-Level Transition Table +//! +//! From the failure function, we build a full `(state × byte) → state` transition +//! table. State `i` means "we have matched `needle[0..i]`". State `n` (= needle +//! length) is the **accept** state. +//! +//! ```text +//! Needle: "aba" (3 states + accept) +//! +//! Input byte +//! State 'a' 'b' other +//! ───── ──── ──── ───── +//! 0 1 0 0 ← looking for first 'a' +//! 1 1 2 0 ← matched "a", want 'b' +//! 2 3✓ 0 0 ← matched "ab", want 'a' +//! 3✓ 3✓ 3✓ 3✓ ← accept (sticky) +//! ``` +//! +//! For prefix matching, a mismatch at any state goes to a **fail** state (no +//! fallback). For contains matching, mismatches follow KMP fallback transitions +//! so we can find the needle anywhere in the string. +//! +//! ### Stage 3: Symbol-Level Transition Table +//! +//! FSST symbols can be multi-byte. To compute the transition for symbol code `c` +//! in state `s`, we simulate feeding each byte of the symbol through the byte +//! table: +//! +//! ```text +//! Symbol #42 = "the" (3 bytes) +//! State 0 + 't' → 0, + 'h' → 0, + 'e' → 0 ⟹ sym_trans[0][42] = 0 +//! +//! If needle = "them": +//! State 0 + 't' → 1, + 'h' → 2, + 'e' → 3 ⟹ sym_trans[0][42] = 3 +//! ``` +//! +//! We then build a **fused 256-wide table**: for code bytes 0–254, use the +//! symbol transition; for code byte 255 (ESCAPE_CODE), transition to a +//! special sentinel that tells the scanner to read the next literal byte. +//! +//! ### Stage 4: Packing into the Final Representation +//! +//! The fused table can be stored in different layouts depending on the number +//! of states: +//! +//! - **Shift-packed `u64`** (≤16 states): Each state needs 4 bits. All state +//! transitions for one input byte fit in a single `u64`. Lookup: +//! `next = (table[byte] >> (state * 4)) & 0xF`. One cache line per lookup. +//! +//! - **Flat `u8` table** (≤255 states): `transitions[state * 256 + byte]`. +//! Larger, but still bounded by the `u8` state representation. +//! +//! ## State-Space Limits +//! +//! The public behavior is shaped by two implementation limits, both measured in +//! pattern **bytes** rather than Unicode scalar values: +//! +//! - `prefix%` pushdown is limited to **13 bytes**. The packed prefix DFA uses +//! 4-bit state ids and needs room for normal prefix-progress states, an +//! accept state, a fail state, and one escape sentinel for FSST literals. +//! - `%needle%` pushdown is limited to **254 bytes**. The long-needle DFA stores +//! states in `u8`, so it needs room for every match-progress state plus both +//! the accept state and the escape sentinel. +//! +//! Patterns beyond those limits are still valid LIKE patterns; they simply do +//! not use FSST pushdown and must be evaluated through the fallback path. +//! +//! ## DFA Variants and When Each Is Used +//! +//! ```text +//! ┌───────────────┬──────────────────────────────────────────────────────┐ +//! │ Pattern │ Needle length → DFA variant │ +//! ├───────────────┼──────────────────────────────────────────────────────┤ +//! │ prefix% │ 0–13 → FsstPrefixDfa (shift-packed, no KMP) │ +//! ├───────────────┼──────────────────────────────────────────────────────┤ +//! │ %needle% │ 1–7 → BranchlessShiftDfa (hierarchical 4-byte) │ +//! │ │ 8–14 → FlatBranchlessDfa (flat u8, escape-folded)│ +//! │ │ 15–254 → FusedDfa (escape sentinel) │ +//! └───────────────┴──────────────────────────────────────────────────────┘ +//! ``` +//! +//! ## Escape Handling Strategies +//! +//! There are two ways to handle the FSST escape code in the DFA: +//! +//! **Escape sentinel** (used by `ShiftDfa`, `FusedDfa`, `FsstPrefixDfa`): +//! The escape code maps to a sentinel state. The scanner checks for it and +//! reads the next byte from a separate escape transition table. +//! +//! ```text +//! loop: +//! state = transitions[byte] // might be sentinel +//! if state == SENTINEL: +//! state = escape_transitions[next_byte] // branch +//! ``` +//! +//! **Escape folding** (used by `BranchlessShiftDfa`, `FlatBranchlessDfa`): +//! Escape states are folded into the state space. State `s+N+1` means "was in +//! state `s`, just consumed ESCAPE_CODE". The next byte's transition from an +//! escape state uses the byte-level table. No branch needed in the scanner. +//! +//! ```text +//! States: [0..N-1: normal] [N: accept] [N+1..2N: escape shadows] +//! Total: 2N+1 states. With 4-bit packing, max N=7. +//! +//! loop: +//! state = transitions[state][byte] // branchless! +//! ``` + +#![allow(clippy::cast_possible_truncation)] + +use fsst::ESCAPE_CODE; +use fsst::Symbol; +use vortex_buffer::BitBuffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; + +// --------------------------------------------------------------------------- +// FsstMatcher — unified public API +// --------------------------------------------------------------------------- + +/// A compiled matcher for LIKE patterns on FSST-compressed strings. +/// +/// Encapsulates pattern parsing and DFA variant selection. Returns `None` from +/// [`try_new`](Self::try_new) for patterns that cannot be evaluated without +/// decompression (e.g., `_` wildcards, multiple `%` in non-standard positions, +/// or patterns that exceed the DFA's representable byte-length limits). +pub(crate) struct FsstMatcher { + inner: MatcherInner, +} + +enum MatcherInner { + MatchAll, + Prefix(Box), + ContainsBranchless(Box), + ContainsFlat(FlatBranchlessDfa), + Contains(FsstContainsDfa), +} + +impl FsstMatcher { + /// Try to build a matcher for the given LIKE pattern. + /// + /// Returns `Ok(None)` if the pattern shape is not supported for pushdown + /// (e.g. `_` wildcards, multiple non-bookend `%`, `prefix%` longer than + /// 13 bytes, or `%needle%` longer than 254 bytes). + pub(crate) fn try_new( + symbols: &[Symbol], + symbol_lengths: &[u8], + pattern: &str, + ) -> VortexResult> { + let Some(like_kind) = LikeKind::parse(pattern) else { + return Ok(None); + }; + + let inner = match like_kind { + LikeKind::Prefix("") => MatcherInner::MatchAll, + LikeKind::Prefix(prefix) => { + let prefix = prefix.as_bytes(); + if prefix.len() > FsstPrefixDfa::MAX_PREFIX_LEN { + return Ok(None); + } + MatcherInner::Prefix(Box::new(FsstPrefixDfa::new( + symbols, + symbol_lengths, + prefix, + ))) + } + LikeKind::Contains(needle) => { + let needle = needle.as_bytes(); + if needle.len() > FusedDfa::MAX_NEEDLE_LEN { + return Ok(None); + } + if needle.len() <= BranchlessShiftDfa::MAX_NEEDLE_LEN { + MatcherInner::ContainsBranchless(Box::new(BranchlessShiftDfa::new( + symbols, + symbol_lengths, + needle, + ))) + } else if needle.len() <= FlatBranchlessDfa::MAX_NEEDLE_LEN { + MatcherInner::ContainsFlat(FlatBranchlessDfa::new( + symbols, + symbol_lengths, + needle, + )) + } else { + MatcherInner::Contains(FsstContainsDfa::new(symbols, symbol_lengths, needle)) + } + } + }; + + Ok(Some(Self { inner })) + } + + /// Run the matcher on a single FSST-compressed code sequence. + #[inline] + pub(crate) fn matches(&self, codes: &[u8]) -> bool { + match &self.inner { + MatcherInner::MatchAll => true, + MatcherInner::Prefix(dfa) => dfa.matches(codes), + MatcherInner::ContainsBranchless(dfa) => dfa.matches(codes), + MatcherInner::ContainsFlat(dfa) => dfa.matches(codes), + MatcherInner::Contains(dfa) => dfa.matches(codes), + } + } +} + +/// The subset of LIKE patterns we can handle without decompression. +enum LikeKind<'a> { + /// `prefix%` + Prefix(&'a str), + /// `%needle%` + Contains(&'a str), +} + +impl<'a> LikeKind<'a> { + fn parse(pattern: &'a str) -> Option { + if pattern == "%" { + return Some(LikeKind::Prefix("")); + } + + // Find first wildcard. + let first_wild = pattern.find(['%', '_'])?; + + // `_` as first wildcard means we can't handle it. + if pattern.as_bytes()[first_wild] == b'_' { + return None; + } + + // `prefix%` — single trailing % + if first_wild > 0 && &pattern[first_wild..] == "%" { + return Some(LikeKind::Prefix(&pattern[..first_wild])); + } + + // `%needle%` — leading and trailing %, no inner wildcards + if first_wild == 0 + && pattern.len() > 2 + && pattern.as_bytes()[pattern.len() - 1] == b'%' + && !pattern[1..pattern.len() - 1].contains(['%', '_']) + { + return Some(LikeKind::Contains(&pattern[1..pattern.len() - 1])); + } + + None + } +} + +// --------------------------------------------------------------------------- +// Scan helper +// --------------------------------------------------------------------------- + +/// Scan all strings through a DFA matcher, packing results directly into a +/// `BitBuffer` one u64 word (64 strings) at a time. This avoids the overhead +/// of `BitBufferMut::collect_bool`'s cross-crate closure indirection and +/// guarantees the compiler can see the full loop body for optimization. +// TODO: add N-way ILP overrun scan for higher throughput on short strings. +#[inline] +pub(crate) fn dfa_scan_to_bitbuf( + n: usize, + offsets: &[T], + all_bytes: &[u8], + negated: bool, + matcher: F, +) -> BitBuffer +where + T: vortex_array::dtype::IntegerPType, + F: Fn(&[u8]) -> bool, +{ + let n_words = n / 64; + let remainder = n % 64; + let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); + + for chunk in 0..n_words { + let base = chunk * 64; + let mut word = 0u64; + let mut start: usize = offsets[base].as_(); + for bit in 0..64 { + let end: usize = offsets[base + bit + 1].as_(); + word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; + start = end; + } + // SAFETY: we allocated capacity for n.div_ceil(64) words. + unsafe { words.push_unchecked(word) }; + } + + if remainder != 0 { + let base = n_words * 64; + let mut word = 0u64; + let mut start: usize = offsets[base].as_(); + for bit in 0..remainder { + let end: usize = offsets[base + bit + 1].as_(); + word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; + start = end; + } + unsafe { words.push_unchecked(word) }; + } + + BitBuffer::new(words.into_byte_buffer().freeze(), n) +} + +// --------------------------------------------------------------------------- +// Shared DFA construction helpers +// --------------------------------------------------------------------------- + +/// Builds the per-symbol transition table for FSST symbols. +/// +/// For each `(state, symbol_code)` pair, simulates feeding the symbol's bytes +/// through the byte-level transition table to compute the resulting state. +/// +/// Returns a flat `Vec` indexed as `[state * n_symbols + code]`. +fn build_symbol_transitions( + symbols: &[Symbol], + symbol_lengths: &[u8], + byte_table: &[u16], + n_states: usize, + accept_state: u8, +) -> Vec { + let n_symbols = symbols.len(); + let mut sym_trans = vec![0u8; n_states * n_symbols]; + for state in 0..n_states { + for code in 0..n_symbols { + if state as u8 == accept_state { + sym_trans[state * n_symbols + code] = accept_state; + continue; + } + let sym = symbols[code].to_u64().to_le_bytes(); + let sym_len = symbol_lengths[code] as usize; + let mut s = state as u16; + for &b in &sym[..sym_len] { + if s == accept_state as u16 { + break; + } + s = byte_table[s as usize * 256 + b as usize]; + } + sym_trans[state * n_symbols + code] = s as u8; + } + } + sym_trans +} + +/// Builds a fused 256-wide transition table from symbol transitions. +/// +/// For each `(state, code_byte)`: +/// - Code bytes `0..n_symbols`: use the symbol transition +/// - `ESCAPE_CODE`: maps to `escape_value` (either a sentinel or escape state) +/// - All others: use `default` (typically 0 for contains, fail_state for prefix) +/// +/// Returns a flat `Vec` indexed as `[state * 256 + code_byte]`. +fn build_fused_table( + sym_trans: &[u8], + n_symbols: usize, + n_states: usize, + escape_value_fn: impl Fn(usize) -> u8, + default: u8, +) -> Vec { + let mut fused = vec![default; n_states * 256]; + for state in 0..n_states { + for code in 0..n_symbols { + fused[state * 256 + code] = sym_trans[state * n_symbols + code]; + } + fused[state * 256 + ESCAPE_CODE as usize] = escape_value_fn(state); + } + fused +} + +/// Packs a fused table into shift-encoded `u64` arrays. +/// +/// Each `u64` encodes transitions for ALL states for one input byte. +/// Lookup: `next = (table[byte] >> (state * BITS)) & MASK`. +fn pack_shift_table(fused: &[u8], n_states: usize, bits: u32) -> [u64; 256] { + let mut packed = [0u64; 256]; + for code_byte in 0..256usize { + let mut val = 0u64; + for state in 0..n_states { + val |= (fused[state * 256 + code_byte] as u64) << (state as u32 * bits); + } + packed[code_byte] = val; + } + packed +} + +/// Packs a byte-level KMP table into shift-encoded `u64` arrays for escape handling. +fn pack_escape_shift_table(byte_table: &[u16], n_states: usize, bits: u32) -> [u64; 256] { + let mut packed = [0u64; 256]; + for byte_val in 0..256usize { + let mut val = 0u64; + for state in 0..n_states { + let next = byte_table[state * 256 + byte_val] as u8; + val |= (next as u64) << (state as u32 * bits); + } + packed[byte_val] = val; + } + packed +} + +// --------------------------------------------------------------------------- +// DFA for prefix matching (LIKE 'prefix%') +// --------------------------------------------------------------------------- + +/// Precomputed shift-based DFA for prefix matching on FSST codes. +/// +/// States 0..prefix_len track match progress, plus ACCEPT and FAIL. +/// Uses the same shift-based approach as the contains DFA: all state +/// transitions packed into a `u64` per code byte. For prefixes longer +/// than 13 characters, pushdown is disabled and LIKE falls back. +struct FsstPrefixDfa { + /// Packed transitions: `(table[code] >> (state * 4)) & 0xF` gives next state. + transitions: [u64; 256], + /// Packed escape transitions for literal bytes. + escape_transitions: [u64; 256], + accept_state: u8, + fail_state: u8, +} + +impl FsstPrefixDfa { + pub(crate) const BITS: u32 = 4; + const MASK: u64 = (1 << Self::BITS) - 1; + const MAX_PREFIX_LEN: usize = (1 << Self::BITS) as usize - 3; + + pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], prefix: &[u8]) -> Self { + // Need room for states 0..prefix_len, accept, fail, and an escape sentinel. + debug_assert!(prefix.len() <= Self::MAX_PREFIX_LEN); + + let accept_state = prefix.len() as u8; + let fail_state = prefix.len() as u8 + 1; + let n_states = prefix.len() + 2; + + // Prefix matching uses a simpler transition rule than KMP: on mismatch + // we go to fail_state (no fallback). Build the byte table inline. + let byte_table = Self::build_prefix_byte_table(prefix, accept_state, fail_state); + + let sym_trans = + build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); + + // Override fail_state rows: fail is sticky. + let escape_sentinel = fail_state + 1; + let mut fused = build_fused_table( + &sym_trans, + symbols.len(), + n_states, + |_| escape_sentinel, + fail_state, + ); + + // Accept state is sticky for all inputs. + for code_byte in 0..256usize { + fused[accept_state as usize * 256 + code_byte] = accept_state; + } + // Fail state is sticky for all inputs. + for code_byte in 0..256usize { + fused[fail_state as usize * 256 + code_byte] = fail_state; + } + + let transitions = pack_shift_table(&fused, n_states, Self::BITS); + + // Build escape transitions from the byte table. + let mut esc_trans = vec![fail_state; n_states * 256]; + for state in 0..n_states { + if state as u8 == accept_state { + for b in 0..256 { + esc_trans[state * 256 + b] = accept_state; + } + } else if state as u8 != fail_state { + for b in 0..256usize { + if b as u8 == prefix[state] { + let next = state + 1; + esc_trans[state * 256 + b] = if next >= prefix.len() { + accept_state + } else { + next as u8 + }; + } + } + } + } + let escape_transitions = pack_shift_table(&esc_trans, n_states, Self::BITS); + + Self { + transitions, + escape_transitions, + accept_state, + fail_state, + } + } + + /// Build a byte-level transition table for prefix matching (no KMP fallback). + fn build_prefix_byte_table(prefix: &[u8], accept_state: u8, fail_state: u8) -> Vec { + let n_states = prefix.len() + 2; + let mut table = vec![fail_state as u16; n_states * 256]; + + for state in 0..n_states { + if state as u8 == accept_state { + for byte in 0..256 { + table[state * 256 + byte] = accept_state as u16; + } + } else if state as u8 != fail_state { + // Only the correct next byte advances; everything else fails. + let next_byte = prefix[state]; + let next_state = if state + 1 >= prefix.len() { + accept_state as u16 + } else { + (state + 1) as u16 + }; + table[state * 256 + next_byte as usize] = next_state; + } + } + table + } + + #[inline] + pub(crate) fn matches(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + let mut pos = 0; + while pos < codes.len() { + let code = codes[pos]; + pos += 1; + let packed = self.transitions[code as usize]; + let next = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + if next == self.fail_state + 1 { + // Escape sentinel: read literal byte. + if pos >= codes.len() { + return false; + } + let b = codes[pos]; + pos += 1; + let esc_packed = self.escape_transitions[b as usize]; + state = ((esc_packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } else { + state = next; + } + if state == self.accept_state { + return true; + } + if state == self.fail_state { + return false; + } + } + state == self.accept_state + } +} + +// --------------------------------------------------------------------------- +// DFA for contains matching (LIKE '%needle%') +// --------------------------------------------------------------------------- + +/// Contains DFA dispatch for long needles (>14 bytes). Short needles (len <= 7) +/// are handled by `BranchlessShiftDfa`, medium needles (8-14) by +/// `FlatBranchlessDfa`, and longer supported needles (15-254) by `FusedDfa`. +enum FsstContainsDfa { + /// Retained internal alternative; not currently selected by `FsstMatcher`. + Shift(Box), + /// Fused u8 table DFA for long needles (15-254 bytes). + Fused(FusedDfa), +} + +impl FsstContainsDfa { + pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + if needle.len() <= ShiftDfa::MAX_NEEDLE_LEN { + FsstContainsDfa::Shift(Box::new(ShiftDfa::new(symbols, symbol_lengths, needle))) + } else { + FsstContainsDfa::Fused(FusedDfa::new(symbols, symbol_lengths, needle)) + } + } + + #[inline] + pub(crate) fn matches(&self, codes: &[u8]) -> bool { + match self { + FsstContainsDfa::Shift(dfa) => dfa.matches(codes), + FsstContainsDfa::Fused(dfa) => dfa.matches(codes), + } + } +} + +/// Branchless escape-folded DFA for short needles (len <= 7). +/// +/// Folds escape handling into the state space so that `matches()` is +/// completely branchless (except for loop control). The state layout is: +/// - States 0..N-1: normal match-progress states +/// - State N: accept (sticky for all inputs) +/// - States N+1..2N: escape states (state `s+N+1` means "was in state `s`, +/// just consumed ESCAPE_CODE") +/// +/// Total states: 2N+1. With 4-bit packing, max N=7. +/// +/// Uses a decomposed hierarchical lookup that processes 4 code bytes per +/// loop iteration with only ~3 KB of tables: +/// +/// 1. **Equivalence class table** (256 B): maps each code byte to a class +/// id. Bytes with identical transition u64s share a class -- typically +/// only ~6-10 classes exist (needle chars + escape + "miss-all"). +/// 2. **Pair-compose table** (~N^2 B): maps `(class0, class1)` to a 2-byte +/// palette index. Typically ~36 entries. +/// 3. **4-byte compose table** (~M^2 x 8 B): maps `(palette0, palette1)` to +/// the composed packed u64 for all 4 bytes. Typically ~81 entries = 648 B. +/// +/// Each loop iteration: 4 class lookups (parallel, 256 B table) -> 2 +/// pair-compose lookups (parallel, ~36 B table) -> 1 compose lookup +/// (~648 B table) -> 1 shift+mask. All tables fit in L1 cache. +struct BranchlessShiftDfa { + /// Maps each code byte to its equivalence class. Bytes with the same + /// packed transition u64 share a class. (256 bytes) + eq_class: [u8; 256], + /// Maps `(class0 * n_classes + class1)` -> 2-byte palette index. + pair_compose: Vec, + /// Number of equivalence classes (stride for pair_compose). + n_classes: usize, + /// Maps `(palette0 * n_palette + palette1)` -> composed packed u64 + /// for 4 bytes. + compose_4b: Vec, + /// Number of unique 2-byte palette entries (stride for compose_4b). + n_palette: usize, + /// 1-byte fallback transitions for trailing bytes. + transitions_1b: [u64; 256], + /// 2-byte palette for the remainder path (2-3 trailing bytes). + palette_2b: Vec, + accept_state: u8, +} + +impl BranchlessShiftDfa { + const BITS: u32 = 4; + const MASK: u64 = (1 << Self::BITS) - 1; + /// Maximum needle length: need 2N+1 states to fit in 16 slots (4 bits). + /// 2*7+1 = 15 <= 16, so max N = 7. + pub(crate) const MAX_NEEDLE_LEN: usize = 7; + + pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + let n = needle.len(); + debug_assert!(n <= Self::MAX_NEEDLE_LEN); + + let accept_state = n as u8; + let total_states = 2 * n + 1; + debug_assert!(total_states <= (1 << Self::BITS)); + + let transitions_1b = + Self::build_escape_folded_transitions(symbols, symbol_lengths, needle, total_states); + + // Build equivalence classes: group bytes with identical transition u64. + let mut eq_class = [0u8; 256]; + let mut class_representatives: Vec = Vec::new(); + for byte_val in 0..256usize { + let t = transitions_1b[byte_val]; + let cls = class_representatives + .iter() + .position(|&v| v == t) + .unwrap_or_else(|| { + class_representatives.push(t); + class_representatives.len() - 1 + }); + eq_class[byte_val] = cls as u8; + } + let n_classes = class_representatives.len(); + + // Build pair-compose: for each (class0, class1), compose the two + // 1-byte transitions and deduplicate into a 2-byte palette. + let (pair_compose, palette_2b) = + Self::build_pair_compose(&class_representatives, n_classes, total_states); + + // Build 4-byte composition: compose_4b[p0 * n + p1] gives the packed + // u64 for applying palette_2b[p0] then palette_2b[p1] in sequence. + let n_palette = palette_2b.len(); + let compose_4b = Self::build_compose_4b(&palette_2b, total_states); + + Self { + eq_class, + pair_compose, + n_classes, + compose_4b, + n_palette, + transitions_1b, + palette_2b, + accept_state, + } + } + + /// Build the 1-byte packed transition table with escape handling folded + /// into the state space (no branch needed in the scanner). + fn build_escape_folded_transitions( + symbols: &[Symbol], + symbol_lengths: &[u8], + needle: &[u8], + total_states: usize, + ) -> [u64; 256] { + let n = needle.len(); + let n_normal_states = n + 1; + let accept_state = n as u8; + + let byte_table = kmp_byte_transitions(needle); + let sym_trans = build_symbol_transitions( + symbols, + symbol_lengths, + &byte_table, + n_normal_states, + accept_state, + ); + + // Build fused transition table with escape folding. + let n_symbols = symbols.len(); + let mut fused = vec![0u8; total_states * 256]; + for code_byte in 0..256usize { + for s in 0..n { + if code_byte == ESCAPE_CODE as usize { + fused[s * 256 + code_byte] = (s + n + 1) as u8; + } else if code_byte < n_symbols { + fused[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; + } + } + fused[n * 256 + code_byte] = accept_state; + for s in 0..n { + let esc_state = s + n + 1; + let next = byte_table[s * 256 + code_byte] as u8; + fused[esc_state * 256 + code_byte] = next; + } + } + + // Pack into u64 shift table. + pack_shift_table(&fused, total_states, Self::BITS) + } + + /// Build the pair-compose table and 2-byte palette from equivalence + /// class representatives. + fn build_pair_compose( + class_reps: &[u64], + n_classes: usize, + total_states: usize, + ) -> (Vec, Vec) { + let mut pair_compose = vec![0u8; n_classes * n_classes]; + let mut palette_2b: Vec = Vec::new(); + + for c0 in 0..n_classes { + for c1 in 0..n_classes { + let t0 = class_reps[c0]; + let t1 = class_reps[c1]; + let mut packed = 0u64; + for state in 0..total_states { + let mid = ((t0 >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let final_s = ((t1 >> (mid as u32 * Self::BITS)) & Self::MASK) as u8; + packed |= (final_s as u64) << (state as u32 * Self::BITS); + } + let idx = palette_2b + .iter() + .position(|&v| v == packed) + .unwrap_or_else(|| { + palette_2b.push(packed); + palette_2b.len() - 1 + }); + pair_compose[c0 * n_classes + c1] = idx as u8; + } + } + (pair_compose, palette_2b) + } + + /// Compose pairs of 2-byte palette entries into a 4-byte lookup table. + fn build_compose_4b(palette_2b: &[u64], total_states: usize) -> Vec { + let n = palette_2b.len(); + let mut compose = vec![0u64; n * n]; + for p0 in 0..n { + for p1 in 0..n { + let mut packed = 0u64; + for state in 0..total_states { + let mid = ((palette_2b[p0] >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let final_s = + ((palette_2b[p1] >> (mid as u32 * Self::BITS)) & Self::MASK) as u8; + packed |= (final_s as u64) << (state as u32 * Self::BITS); + } + compose[p0 * n + p1] = packed; + } + } + compose + } + + /// Process remaining bytes after the interleaved common prefix. + #[inline] + fn finish_tail(&self, mut state: u8, codes: &[u8]) -> u8 { + let chunks = codes.chunks_exact(4); + let rem = chunks.remainder(); + + for chunk in chunks { + let ec0 = unsafe { *self.eq_class.get_unchecked(chunk[0] as usize) } as usize; + let ec1 = unsafe { *self.eq_class.get_unchecked(chunk[1] as usize) } as usize; + let ec2 = unsafe { *self.eq_class.get_unchecked(chunk[2] as usize) } as usize; + let ec3 = unsafe { *self.eq_class.get_unchecked(chunk[3] as usize) } as usize; + let p0 = + unsafe { *self.pair_compose.get_unchecked(ec0 * self.n_classes + ec1) } as usize; + let p1 = + unsafe { *self.pair_compose.get_unchecked(ec2 * self.n_classes + ec3) } as usize; + let packed = unsafe { *self.compose_4b.get_unchecked(p0 * self.n_palette + p1) }; + state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } + + if rem.len() >= 2 { + let ec0 = self.eq_class[rem[0] as usize] as usize; + let ec1 = self.eq_class[rem[1] as usize] as usize; + let p = self.pair_compose[ec0 * self.n_classes + ec1] as usize; + let packed = self.palette_2b[p]; + state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + if rem.len() == 3 { + let packed = self.transitions_1b[rem[2] as usize]; + state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } + } else if rem.len() == 1 { + let packed = self.transitions_1b[rem[0] as usize]; + state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } + + state + } + + /// Branchless matching processing four code bytes per iteration. + #[inline(never)] + pub(crate) fn matches(&self, codes: &[u8]) -> bool { + self.finish_tail(0, codes) == self.accept_state + } +} + +/// Flat u8 escape-folded DFA for medium needles (8-14 chars). +/// +/// Like `BranchlessShiftDfa`, folds escape handling into the state space +/// (2N+1 states), but uses a flat `u8` transition table instead of +/// shift-packed `u64`. Supports up to 14-char needles (2*14+1 = 29 states). +/// Table size: 29 * 256 = 7,424 bytes, fits in L1. +struct FlatBranchlessDfa { + /// transitions[state * 256 + byte] -> next state + transitions: Vec, + accept_state: u8, +} + +impl FlatBranchlessDfa { + pub(crate) const MAX_NEEDLE_LEN: usize = 14; + + pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + let n = needle.len(); + debug_assert!(n <= Self::MAX_NEEDLE_LEN); + + let accept_state = n as u8; + let total_states = 2 * n + 1; + let n_symbols = symbols.len(); + + let byte_table = kmp_byte_transitions(needle); + let sym_trans = + build_symbol_transitions(symbols, symbol_lengths, &byte_table, n + 1, accept_state); + + // Build fused transition table with escape folding. + let mut transitions = vec![0u8; total_states * 256]; + for code_byte in 0..256usize { + // Normal states 0..n + for s in 0..n { + if code_byte == ESCAPE_CODE as usize { + transitions[s * 256 + code_byte] = (s + n + 1) as u8; + } else if code_byte < n_symbols { + transitions[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; + } + } + // Accept state (sticky) + transitions[n * 256 + code_byte] = accept_state; + // Escape states n+1..2n + for s in 0..n { + let esc_state = s + n + 1; + let next = byte_table[s * 256 + code_byte] as u8; + transitions[esc_state * 256 + code_byte] = next; + } + } + + Self { + transitions, + accept_state, + } + } + + #[inline(never)] + pub(crate) fn matches(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + for &byte in codes { + state = self.transitions[state as usize * 256 + byte as usize]; + } + state == self.accept_state + } +} + +/// Shift-based DFA: packs all state transitions into a `u64` per input byte. +/// +/// For a DFA with S states (S <= 16, using 4 bits each), we store transitions +/// for ALL states in one `u64`. Transition: `next = (table[code] >> (state * 4)) & 0xF`. +/// +/// Supports needles up to 14 characters (needle.len() + 2 <= 16 to fit escape +/// sentinel). This covers virtually all practical LIKE patterns. +pub(crate) struct ShiftDfa { + /// For each code byte (0..255): a `u64` packing all state transitions. + /// Bits `[state*4 .. state*4+4)` encode the next state for that input. + transitions: [u64; 256], + /// Same layout for escape byte transitions. + escape_transitions: [u64; 256], + accept_state: u8, + escape_sentinel: u8, +} + +impl ShiftDfa { + const BITS: u32 = 4; + const MASK: u64 = (1 << Self::BITS) - 1; + /// Maximum needle length: 2^BITS - 2 (need room for accept + sentinel). + const MAX_NEEDLE_LEN: usize = (1 << Self::BITS) - 2; + + fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); + + let n_states = needle.len() + 1; + let accept_state = needle.len() as u8; + let escape_sentinel = needle.len() as u8 + 1; + + let byte_table = kmp_byte_transitions(needle); + let sym_trans = + build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); + + let fused = build_fused_table(&sym_trans, symbols.len(), n_states, |_| escape_sentinel, 0); + + let transitions = pack_shift_table(&fused, n_states, Self::BITS); + let escape_transitions = pack_escape_shift_table(&byte_table, n_states, Self::BITS); + + Self { + transitions, + escape_transitions, + accept_state, + escape_sentinel, + } + } + + /// Match with iterator-based traversal. + /// + /// Using `iter.next()` instead of manual index + bounds check helps the + /// compiler eliminate redundant bounds checks. + #[inline] + fn matches(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + let mut iter = codes.iter(); + while let Some(&code) = iter.next() { + let packed = self.transitions[code as usize]; + let next = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + if next == self.escape_sentinel { + let Some(&b) = iter.next() else { + return false; + }; + let esc_packed = self.escape_transitions[b as usize]; + state = ((esc_packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + } else { + state = next; + } + } + state == self.accept_state + } +} + +/// Fused 256-entry u8 table DFA for contains needles in the 15-254 byte range. +/// +/// This representation stores state ids in `u8`, so it cannot represent +/// needles longer than 254 bytes once the accept state and escape sentinel are +/// included. +pub(crate) struct FusedDfa { + transitions: Vec, + escape_transitions: Vec, + accept_state: u8, + escape_sentinel: u8, +} + +impl FusedDfa { + const MAX_NEEDLE_LEN: usize = u8::MAX as usize - 1; + + fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { + debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); + + let n_states = needle.len() + 1; + let accept_state = needle.len() as u8; + let escape_sentinel = needle.len() as u8 + 1; + + let byte_table = kmp_byte_transitions(needle); + let sym_trans = + build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); + + let transitions = + build_fused_table(&sym_trans, symbols.len(), n_states, |_| escape_sentinel, 0); + + let escape_transitions: Vec = byte_table.iter().map(|&v| v as u8).collect(); + + Self { + transitions, + escape_transitions, + accept_state, + escape_sentinel, + } + } + + #[inline] + fn matches(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + let mut pos = 0; + while pos < codes.len() { + let code = codes[pos]; + pos += 1; + let next = self.transitions[state as usize * 256 + code as usize]; + if next == self.escape_sentinel { + if pos >= codes.len() { + return false; + } + let b = codes[pos]; + pos += 1; + state = self.escape_transitions[state as usize * 256 + b as usize]; + } else { + state = next; + } + if state == self.accept_state { + return true; + } + } + false + } +} + +// --------------------------------------------------------------------------- +// KMP helpers +// --------------------------------------------------------------------------- + +fn kmp_byte_transitions(needle: &[u8]) -> Vec { + let n_states = needle.len() + 1; + let accept = needle.len() as u16; + let failure = kmp_failure_table(needle); + + let mut table = vec![0u16; n_states * 256]; + for state in 0..n_states { + for byte in 0..256u16 { + if state == needle.len() { + table[state * 256 + byte as usize] = accept; + continue; + } + let mut s = state; + loop { + if byte as u8 == needle[s] { + s += 1; + break; + } + if s == 0 { + break; + } + s = failure[s - 1]; + } + table[state * 256 + byte as usize] = s as u16; + } + } + table +} + +fn kmp_failure_table(needle: &[u8]) -> Vec { + let mut failure = vec![0usize; needle.len()]; + let mut k = 0; + for i in 1..needle.len() { + while k > 0 && needle[k] != needle[i] { + k = failure[k - 1]; + } + if needle[k] == needle[i] { + k += 1; + } + failure[i] = k; + } + failure +} + +#[cfg(test)] +mod tests { + use fsst::ESCAPE_CODE; + + use super::FusedDfa; + use super::FsstMatcher; + use super::FsstPrefixDfa; + use super::LikeKind; + + fn escaped(bytes: &[u8]) -> Vec { + let mut codes = Vec::with_capacity(bytes.len() * 2); + for &b in bytes { + codes.push(ESCAPE_CODE); + codes.push(b); + } + codes + } + + #[test] + fn test_like_kind_parse() { + assert!(matches!( + LikeKind::parse("http%"), + Some(LikeKind::Prefix("http")) + )); + assert!(matches!( + LikeKind::parse("%needle%"), + Some(LikeKind::Contains("needle")) + )); + assert!(matches!(LikeKind::parse("%"), Some(LikeKind::Prefix("")))); + // Suffix and underscore patterns are not supported. + assert!(LikeKind::parse("%suffix").is_none()); + assert!(LikeKind::parse("a_c").is_none()); + } + + #[test] + fn test_prefix_pushdown_len_13_with_escapes() { + let matcher = FsstMatcher::try_new(&[], &[], "abcdefghijklm%") + .unwrap() + .unwrap(); + + assert!(matcher.matches(&escaped(b"abcdefghijklm"))); + assert!(!matcher.matches(&escaped(b"abcdefghijklx"))); + } + + #[test] + fn test_prefix_pushdown_rejects_len_14() { + debug_assert_eq!(FsstPrefixDfa::MAX_PREFIX_LEN, 13); + assert!( + FsstMatcher::try_new(&[], &[], "abcdefghijklmn%") + .unwrap() + .is_none() + ); + } + + #[test] + fn test_contains_pushdown_len_254_with_escapes() { + let needle = "a".repeat(FusedDfa::MAX_NEEDLE_LEN); + let pattern = format!("%{needle}%"); + let matcher = FsstMatcher::try_new(&[], &[], &pattern).unwrap().unwrap(); + + assert!(matcher.matches(&escaped(needle.as_bytes()))); + + let mut mismatch = needle.into_bytes(); + mismatch[FusedDfa::MAX_NEEDLE_LEN - 1] = b'b'; + assert!(!matcher.matches(&escaped(&mismatch))); + } + + #[test] + fn test_contains_pushdown_rejects_len_255() { + let needle = "a".repeat(FusedDfa::MAX_NEEDLE_LEN + 1); + let pattern = format!("%{needle}%"); + assert!(FsstMatcher::try_new(&[], &[], &pattern).unwrap().is_none()); + } +} diff --git a/encodings/fsst/src/lib.rs b/encodings/fsst/src/lib.rs index 5cc75c59b2a..3305c0e66fc 100644 --- a/encodings/fsst/src/lib.rs +++ b/encodings/fsst/src/lib.rs @@ -15,6 +15,7 @@ mod array; mod canonical; mod compress; mod compute; +mod dfa; mod kernel; mod ops; mod rules; From e08fb69ad3e806da9fa9e20175794ebe4541cfb4 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Tue, 17 Mar 2026 11:14:08 +0000 Subject: [PATCH 05/19] clean up Signed-off-by: Joe Isaacs --- .github/workflows/fuzz.yml | 36 +++ Cargo.lock | 1 + encodings/fsst/Cargo.toml | 4 + encodings/fsst/benches/bitpack_strategy.rs | 278 +++++++++++++++++++++ encodings/fsst/src/compute/like.rs | 6 +- encodings/fsst/src/dfa.rs | 253 +++++++++++-------- fuzz/Cargo.toml | 9 + fuzz/fuzz_targets/fsst_like.rs | 35 +++ fuzz/src/fsst_like.rs | 164 ++++++++++++ fuzz/src/lib.rs | 3 + 10 files changed, 679 insertions(+), 110 deletions(-) create mode 100644 encodings/fsst/benches/bitpack_strategy.rs create mode 100644 fuzz/fuzz_targets/fsst_like.rs create mode 100644 fuzz/src/fsst_like.rs diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml index 49758ac260e..d7d97198023 100644 --- a/.github/workflows/fuzz.yml +++ b/.github/workflows/fuzz.yml @@ -99,6 +99,42 @@ jobs: gh_token: ${{ secrets.GITHUB_TOKEN }} incident_io_alert_token: ${{ secrets.INCIDENT_IO_ALERT_TOKEN }} + # ============================================================================ + # FSST LIKE Fuzzer + # ============================================================================ + fsst_like_fuzz: + name: "FSST LIKE Fuzz" + uses: ./.github/workflows/run-fuzzer.yml + with: + fuzz_target: fsst_like + jobs: 16 + secrets: + R2_FUZZ_ACCESS_KEY_ID: ${{ secrets.R2_FUZZ_ACCESS_KEY_ID }} + R2_FUZZ_SECRET_ACCESS_KEY: ${{ secrets.R2_FUZZ_SECRET_ACCESS_KEY }} + + report-fsst-like-fuzz-failures: + name: "Report FSST LIKE Fuzz Failures" + needs: fsst_like_fuzz + if: always() && needs.fsst_like_fuzz.outputs.crashes_found == 'true' + permissions: + issues: write + contents: read + id-token: write + pull-requests: read + uses: ./.github/workflows/report-fuzz-crash.yml + with: + fuzz_target: fsst_like + crash_file: ${{ needs.fsst_like_fuzz.outputs.first_crash_name }} + artifact_url: ${{ needs.fsst_like_fuzz.outputs.artifact_url }} + artifact_name: fsst_like-crash-artifacts + logs_artifact_name: fsst_like-logs + branch: ${{ github.ref_name }} + commit: ${{ github.sha }} + secrets: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + gh_token: ${{ secrets.GITHUB_TOKEN }} + incident_io_alert_token: ${{ secrets.INCIDENT_IO_ALERT_TOKEN }} + # ============================================================================ # Compress Roundtrip Fuzzer # ============================================================================ diff --git a/Cargo.lock b/Cargo.lock index f3d1d80a2da..bc33ddb3d75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10219,6 +10219,7 @@ dependencies = [ "vortex-cuda", "vortex-error", "vortex-file", + "vortex-fsst", "vortex-io", "vortex-mask", "vortex-runend", diff --git a/encodings/fsst/Cargo.toml b/encodings/fsst/Cargo.toml index b95eeb1f444..c5dd318f9a9 100644 --- a/encodings/fsst/Cargo.toml +++ b/encodings/fsst/Cargo.toml @@ -50,6 +50,10 @@ name = "fsst_url_compare" harness = false required-features = ["_test-harness"] +[[bench]] +name = "bitpack_strategy" +harness = false + [[bench]] name = "chunked_dict_fsst_builder" harness = false diff --git a/encodings/fsst/benches/bitpack_strategy.rs b/encodings/fsst/benches/bitpack_strategy.rs new file mode 100644 index 00000000000..d3d6ee348aa --- /dev/null +++ b/encodings/fsst/benches/bitpack_strategy.rs @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Microbenchmark comparing three bit-packing strategies for `dfa_scan_to_bitbuf`: +//! +//! 1. **manual_word** — current: pack directly into u64 words one bit at a time +//! 2. **collect_bool** — use `BitBuffer::collect_bool` with a closure +//! 3. **bool_buf_64** — write results into `[bool; 64]` stack buffer, then compress +//! +//! Uses a trivial matcher (single byte comparison) so that the packing +//! overhead dominates rather than per-string work. + +#![allow(clippy::unwrap_used, clippy::cast_possible_truncation)] + +use divan::Bencher; +use vortex_buffer::BitBuffer; +use vortex_buffer::BufferMut; + +fn main() { + divan::main(); +} + +// --------------------------------------------------------------------------- +// Test data: precomputed bool results + offsets (to keep the offset-reading +// overhead consistent across strategies while making the matcher trivial) +// --------------------------------------------------------------------------- + +const N: usize = 100_000; + +struct TestData { + /// Precomputed match results for each of the N "strings". + results: Vec, + /// Fake offsets array (N+1 entries) so the offset-reading overhead is + /// included, matching the real `dfa_scan_to_bitbuf` pattern. + offsets: Vec, + /// Single-byte "strings" — just used so the matcher reads *something*. + bytes: Vec, +} + +impl TestData { + fn new() -> Self { + let mut rng_state: u64 = 0xDEAD_BEEF; + let mut next = || -> u64 { + rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1); + rng_state >> 33 + }; + + // Each "string" is exactly 1 byte so the matcher is essentially free. + let bytes: Vec = (0..N).map(|_| (next() % 256) as u8).collect(); + let offsets: Vec = (0..=N).map(|i| i as u32).collect(); + let results: Vec = bytes.iter().map(|&b| b >= 128).collect(); + + Self { + results, + offsets, + bytes, + } + } +} + +// Trivial matcher: single byte check. The real work being benchmarked is +// the bit-packing loop, not this function. +#[inline(always)] +fn matcher(data: &[u8]) -> bool { + // SAFETY: benchmark guarantees non-empty slices + unsafe { *data.get_unchecked(0) >= 128 } +} + +// --------------------------------------------------------------------------- +// Strategy 1: manual word packing (current implementation) +// --------------------------------------------------------------------------- + +#[inline(never)] +fn scan_manual_word(offsets: &[u32], all_bytes: &[u8], n: usize, negated: bool) -> BitBuffer { + let n_words = n / 64; + let remainder = n % 64; + let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); + + for chunk in 0..n_words { + let base = chunk * 64; + let mut word = 0u64; + let mut start = offsets[base] as usize; + for bit in 0..64 { + let end = offsets[base + bit + 1] as usize; + word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; + start = end; + } + unsafe { words.push_unchecked(word) }; + } + + if remainder != 0 { + let base = n_words * 64; + let mut word = 0u64; + let mut start = offsets[base] as usize; + for bit in 0..remainder { + let end = offsets[base + bit + 1] as usize; + word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; + start = end; + } + unsafe { words.push_unchecked(word) }; + } + + BitBuffer::new(words.into_byte_buffer().freeze(), n) +} + +// --------------------------------------------------------------------------- +// Strategy 2: BitBuffer::collect_bool +// --------------------------------------------------------------------------- + +#[inline(never)] +fn scan_collect_bool(offsets: &[u32], all_bytes: &[u8], n: usize, negated: bool) -> BitBuffer { + let mut start = offsets[0] as usize; + BitBuffer::collect_bool(n, |i| { + let end = offsets[i + 1] as usize; + let result = matcher(&all_bytes[start..end]) != negated; + start = end; + result + }) +} + +// --------------------------------------------------------------------------- +// Strategy 3: [bool; 64] stack buffer then compress +// --------------------------------------------------------------------------- + +#[inline(never)] +fn scan_bool_buf(offsets: &[u32], all_bytes: &[u8], n: usize, negated: bool) -> BitBuffer { + let n_words = n / 64; + let remainder = n % 64; + let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); + + for chunk in 0..n_words { + let base = chunk * 64; + let mut bools = [false; 64]; + let mut start = offsets[base] as usize; + for bit in 0..64 { + let end = offsets[base + bit + 1] as usize; + bools[bit] = matcher(&all_bytes[start..end]) != negated; + start = end; + } + let mut word = 0u64; + for bit in 0..64 { + word |= (bools[bit] as u64) << bit; + } + unsafe { words.push_unchecked(word) }; + } + + if remainder != 0 { + let base = n_words * 64; + let mut bools = [false; 64]; + let mut start = offsets[base] as usize; + for bit in 0..remainder { + let end = offsets[base + bit + 1] as usize; + bools[bit] = matcher(&all_bytes[start..end]) != negated; + start = end; + } + let mut word = 0u64; + for bit in 0..remainder { + word |= (bools[bit] as u64) << bit; + } + unsafe { words.push_unchecked(word) }; + } + + BitBuffer::new(words.into_byte_buffer().freeze(), n) +} + +// --------------------------------------------------------------------------- +// Strategy 4: precomputed bools (pure packing, no matcher at all) +// Isolates *just* the bool→bitbuffer packing cost. +// --------------------------------------------------------------------------- + +#[inline(never)] +fn pack_from_slice_manual(results: &[bool], n: usize) -> BitBuffer { + let n_words = n / 64; + let remainder = n % 64; + let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); + + for chunk in 0..n_words { + let base = chunk * 64; + let mut word = 0u64; + for bit in 0..64 { + word |= (results[base + bit] as u64) << bit; + } + unsafe { words.push_unchecked(word) }; + } + + if remainder != 0 { + let base = n_words * 64; + let mut word = 0u64; + for bit in 0..remainder { + word |= (results[base + bit] as u64) << bit; + } + unsafe { words.push_unchecked(word) }; + } + + BitBuffer::new(words.into_byte_buffer().freeze(), n) +} + +#[inline(never)] +fn pack_from_slice_collect_bool(results: &[bool], n: usize) -> BitBuffer { + BitBuffer::collect_bool(n, |i| results[i]) +} + +#[inline(never)] +fn pack_from_slice_bool_buf(results: &[bool], n: usize) -> BitBuffer { + let n_words = n / 64; + let remainder = n % 64; + let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); + + for chunk in 0..n_words { + let base = chunk * 64; + let mut bools = [false; 64]; + bools.copy_from_slice(&results[base..base + 64]); + let mut word = 0u64; + for bit in 0..64 { + word |= (bools[bit] as u64) << bit; + } + unsafe { words.push_unchecked(word) }; + } + + if remainder != 0 { + let base = n_words * 64; + let mut bools = [false; 64]; + bools[..remainder].copy_from_slice(&results[base..base + remainder]); + let mut word = 0u64; + for bit in 0..remainder { + word |= (bools[bit] as u64) << bit; + } + unsafe { words.push_unchecked(word) }; + } + + BitBuffer::new(words.into_byte_buffer().freeze(), n) +} + +// --------------------------------------------------------------------------- +// Benchmarks +// --------------------------------------------------------------------------- + +static TEST_DATA: std::sync::LazyLock = std::sync::LazyLock::new(TestData::new); + +// --- Group 1: with offset reading + trivial matcher (dfa_scan_to_bitbuf shape) --- + +#[divan::bench] +fn with_offsets_manual_word(bencher: Bencher) { + let data = &*TEST_DATA; + bencher.bench_local(|| scan_manual_word(&data.offsets, &data.bytes, N, false)); +} + +#[divan::bench] +fn with_offsets_collect_bool(bencher: Bencher) { + let data = &*TEST_DATA; + bencher.bench_local(|| scan_collect_bool(&data.offsets, &data.bytes, N, false)); +} + +#[divan::bench] +fn with_offsets_bool_buf_64(bencher: Bencher) { + let data = &*TEST_DATA; + bencher.bench_local(|| scan_bool_buf(&data.offsets, &data.bytes, N, false)); +} + +// --- Group 2: pure packing from precomputed bools (isolates packing cost) --- + +#[divan::bench] +fn pure_pack_manual_word(bencher: Bencher) { + let data = &*TEST_DATA; + bencher.bench_local(|| pack_from_slice_manual(&data.results, N)); +} + +#[divan::bench] +fn pure_pack_collect_bool(bencher: Bencher) { + let data = &*TEST_DATA; + bencher.bench_local(|| pack_from_slice_collect_bool(&data.results, N)); +} + +#[divan::bench] +fn pure_pack_bool_buf_64(bencher: Bencher) { + let data = &*TEST_DATA; + bencher.bench_local(|| pack_from_slice_bool_buf(&data.results, N)); +} diff --git a/encodings/fsst/src/compute/like.rs b/encodings/fsst/src/compute/like.rs index c53f621f25f..ad0f65ee0ac 100644 --- a/encodings/fsst/src/compute/like.rs +++ b/encodings/fsst/src/compute/like.rs @@ -19,7 +19,6 @@ use crate::dfa::FsstMatcher; use crate::dfa::dfa_scan_to_bitbuf; impl LikeKernel for FSST { - #[allow(clippy::cast_possible_truncation)] fn like( array: &FSSTArray, pattern: &ArrayRef, @@ -347,7 +346,10 @@ mod tests { LikeOptions::default(), &mut SESSION.create_execution_ctx(), )?; - assert!(direct.is_some(), "254-byte contains needle should stay on the DFA path"); + assert!( + direct.is_some(), + "254-byte contains needle should stay on the DFA path" + ); assert_arrays_eq!(direct.unwrap(), BoolArray::from_iter([true, false, true])); Ok(()) } diff --git a/encodings/fsst/src/dfa.rs b/encodings/fsst/src/dfa.rs index 2865f5bc4fb..865bfe53f28 100644 --- a/encodings/fsst/src/dfa.rs +++ b/encodings/fsst/src/dfa.rs @@ -163,8 +163,6 @@ //! state = transitions[state][byte] // branchless! //! ``` -#![allow(clippy::cast_possible_truncation)] - use fsst::ESCAPE_CODE; use fsst::Symbol; use vortex_buffer::BitBuffer; @@ -270,30 +268,17 @@ enum LikeKind<'a> { impl<'a> LikeKind<'a> { fn parse(pattern: &'a str) -> Option { - if pattern == "%" { - return Some(LikeKind::Prefix("")); - } - - // Find first wildcard. - let first_wild = pattern.find(['%', '_'])?; - - // `_` as first wildcard means we can't handle it. - if pattern.as_bytes()[first_wild] == b'_' { - return None; - } - - // `prefix%` — single trailing % - if first_wild > 0 && &pattern[first_wild..] == "%" { - return Some(LikeKind::Prefix(&pattern[..first_wild])); + // `prefix%` (including just `%` where prefix is empty) + if let Some(prefix) = pattern.strip_suffix('%') { + if !prefix.contains(['%', '_']) { + return Some(LikeKind::Prefix(prefix)); + } } - // `%needle%` — leading and trailing %, no inner wildcards - if first_wild == 0 - && pattern.len() > 2 - && pattern.as_bytes()[pattern.len() - 1] == b'%' - && !pattern[1..pattern.len() - 1].contains(['%', '_']) - { - return Some(LikeKind::Contains(&pattern[1..pattern.len() - 1])); + // `%needle%` + let inner = pattern.strip_prefix('%')?.strip_suffix('%')?; + if !inner.contains(['%', '_']) { + return Some(LikeKind::Contains(inner)); } None @@ -354,7 +339,22 @@ where } // --------------------------------------------------------------------------- -// Shared DFA construction helpers +// Shared helpers +// --------------------------------------------------------------------------- + +/// Extract a state id from a shift-packed `u64` word. +/// +/// Each state occupies `bits` bits. The mask `(1 << bits) - 1` guarantees the +/// result is at most 15 (for `bits = 4`), which always fits in `u8`. +#[inline(always)] +fn shift_extract(packed: u64, state: u8, bits: u32) -> u8 { + let mask = (1u64 << bits) - 1; + // bits ≤ 4 ⇒ mask ≤ 15 ⇒ result ≤ 15, always fits in u8. + u8::try_from((packed >> (u32::from(state) * bits)) & mask).unwrap() +} + +// --------------------------------------------------------------------------- +// DFA construction helpers // --------------------------------------------------------------------------- /// Builds the per-symbol transition table for FSST symbols. @@ -374,20 +374,22 @@ fn build_symbol_transitions( let mut sym_trans = vec![0u8; n_states * n_symbols]; for state in 0..n_states { for code in 0..n_symbols { - if state as u8 == accept_state { + if state == usize::from(accept_state) { sym_trans[state * n_symbols + code] = accept_state; continue; } let sym = symbols[code].to_u64().to_le_bytes(); - let sym_len = symbol_lengths[code] as usize; - let mut s = state as u16; + let sym_len = usize::from(symbol_lengths[code]); + // state < n_states ≤ 256, fits in u16 + let mut s = u16::try_from(state).unwrap(); for &b in &sym[..sym_len] { - if s == accept_state as u16 { + if s == u16::from(accept_state) { break; } - s = byte_table[s as usize * 256 + b as usize]; + s = byte_table[usize::from(s) * 256 + usize::from(b)]; } - sym_trans[state * n_symbols + code] = s as u8; + // s is a state id from byte_table, always < n_states ≤ 256 + sym_trans[state * n_symbols + code] = u8::try_from(s).unwrap(); } } sym_trans @@ -427,7 +429,9 @@ fn pack_shift_table(fused: &[u8], n_states: usize, bits: u32) -> [u64; 256] { for code_byte in 0..256usize { let mut val = 0u64; for state in 0..n_states { - val |= (fused[state * 256 + code_byte] as u64) << (state as u32 * bits); + // state < n_states ≤ 16 for 4-bit packing, fits in u32 + val |= u64::from(fused[state * 256 + code_byte]) + << (u32::try_from(state).unwrap() * bits); } packed[code_byte] = val; } @@ -440,8 +444,9 @@ fn pack_escape_shift_table(byte_table: &[u16], n_states: usize, bits: u32) -> [u for byte_val in 0..256usize { let mut val = 0u64; for state in 0..n_states { - let next = byte_table[state * 256 + byte_val] as u8; - val |= (next as u64) << (state as u32 * bits); + // byte_table values are state ids < n_states ≤ 256, fit in u8 + let next = u8::try_from(byte_table[state * 256 + byte_val]).unwrap(); + val |= u64::from(next) << (u32::try_from(state).unwrap() * bits); } packed[byte_val] = val; } @@ -469,15 +474,15 @@ struct FsstPrefixDfa { impl FsstPrefixDfa { pub(crate) const BITS: u32 = 4; - const MASK: u64 = (1 << Self::BITS) - 1; const MAX_PREFIX_LEN: usize = (1 << Self::BITS) as usize - 3; pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], prefix: &[u8]) -> Self { // Need room for states 0..prefix_len, accept, fail, and an escape sentinel. debug_assert!(prefix.len() <= Self::MAX_PREFIX_LEN); - let accept_state = prefix.len() as u8; - let fail_state = prefix.len() as u8 + 1; + // prefix.len() ≤ MAX_PREFIX_LEN (13), fits in u8 + let accept_state = u8::try_from(prefix.len()).unwrap(); + let fail_state = u8::try_from(prefix.len() + 1).unwrap(); let n_states = prefix.len() + 2; // Prefix matching uses a simpler transition rule than KMP: on mismatch @@ -511,18 +516,19 @@ impl FsstPrefixDfa { // Build escape transitions from the byte table. let mut esc_trans = vec![fail_state; n_states * 256]; for state in 0..n_states { - if state as u8 == accept_state { + if state == usize::from(accept_state) { for b in 0..256 { esc_trans[state * 256 + b] = accept_state; } - } else if state as u8 != fail_state { + } else if state != usize::from(fail_state) { for b in 0..256usize { - if b as u8 == prefix[state] { + if b == usize::from(prefix[state]) { let next = state + 1; esc_trans[state * 256 + b] = if next >= prefix.len() { accept_state } else { - next as u8 + // next ≤ prefix.len() ≤ 13, fits in u8 + u8::try_from(next).unwrap() }; } } @@ -541,22 +547,23 @@ impl FsstPrefixDfa { /// Build a byte-level transition table for prefix matching (no KMP fallback). fn build_prefix_byte_table(prefix: &[u8], accept_state: u8, fail_state: u8) -> Vec { let n_states = prefix.len() + 2; - let mut table = vec![fail_state as u16; n_states * 256]; + let mut table = vec![u16::from(fail_state); n_states * 256]; for state in 0..n_states { - if state as u8 == accept_state { + if state == usize::from(accept_state) { for byte in 0..256 { - table[state * 256 + byte] = accept_state as u16; + table[state * 256 + byte] = u16::from(accept_state); } - } else if state as u8 != fail_state { + } else if state != usize::from(fail_state) { // Only the correct next byte advances; everything else fails. let next_byte = prefix[state]; let next_state = if state + 1 >= prefix.len() { - accept_state as u16 + u16::from(accept_state) } else { - (state + 1) as u16 + // state + 1 ≤ prefix.len() ≤ 13, fits in u16 + u16::try_from(state + 1).unwrap() }; - table[state * 256 + next_byte as usize] = next_state; + table[state * 256 + usize::from(next_byte)] = next_state; } } table @@ -569,8 +576,9 @@ impl FsstPrefixDfa { while pos < codes.len() { let code = codes[pos]; pos += 1; - let packed = self.transitions[code as usize]; - let next = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let packed = self.transitions[usize::from(code)]; + // Masked to BITS (4) bits, result ≤ 15, fits in u8 + let next = shift_extract(packed, state, Self::BITS); if next == self.fail_state + 1 { // Escape sentinel: read literal byte. if pos >= codes.len() { @@ -578,8 +586,8 @@ impl FsstPrefixDfa { } let b = codes[pos]; pos += 1; - let esc_packed = self.escape_transitions[b as usize]; - state = ((esc_packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let esc_packed = self.escape_transitions[usize::from(b)]; + state = shift_extract(esc_packed, state, Self::BITS); } else { state = next; } @@ -673,7 +681,6 @@ struct BranchlessShiftDfa { impl BranchlessShiftDfa { const BITS: u32 = 4; - const MASK: u64 = (1 << Self::BITS) - 1; /// Maximum needle length: need 2N+1 states to fit in 16 slots (4 bits). /// 2*7+1 = 15 <= 16, so max N = 7. pub(crate) const MAX_NEEDLE_LEN: usize = 7; @@ -682,7 +689,8 @@ impl BranchlessShiftDfa { let n = needle.len(); debug_assert!(n <= Self::MAX_NEEDLE_LEN); - let accept_state = n as u8; + // n ≤ MAX_NEEDLE_LEN (7), fits in u8 + let accept_state = u8::try_from(n).unwrap(); let total_states = 2 * n + 1; debug_assert!(total_states <= (1 << Self::BITS)); @@ -701,7 +709,8 @@ impl BranchlessShiftDfa { class_representatives.push(t); class_representatives.len() - 1 }); - eq_class[byte_val] = cls as u8; + // At most 256 equivalence classes (one per byte value), fits in u8 + eq_class[byte_val] = u8::try_from(cls).unwrap(); } let n_classes = class_representatives.len(); @@ -737,7 +746,8 @@ impl BranchlessShiftDfa { ) -> [u64; 256] { let n = needle.len(); let n_normal_states = n + 1; - let accept_state = n as u8; + // n ≤ MAX_NEEDLE_LEN (7), fits in u8 + let accept_state = u8::try_from(n).unwrap(); let byte_table = kmp_byte_transitions(needle); let sym_trans = build_symbol_transitions( @@ -753,8 +763,9 @@ impl BranchlessShiftDfa { let mut fused = vec![0u8; total_states * 256]; for code_byte in 0..256usize { for s in 0..n { - if code_byte == ESCAPE_CODE as usize { - fused[s * 256 + code_byte] = (s + n + 1) as u8; + if code_byte == usize::from(ESCAPE_CODE) { + // s + n + 1 ≤ 2*7 = 14, fits in u8 + fused[s * 256 + code_byte] = u8::try_from(s + n + 1).unwrap(); } else if code_byte < n_symbols { fused[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; } @@ -762,7 +773,8 @@ impl BranchlessShiftDfa { fused[n * 256 + code_byte] = accept_state; for s in 0..n { let esc_state = s + n + 1; - let next = byte_table[s * 256 + code_byte] as u8; + // byte_table values are state ids < n_normal_states ≤ 8 + let next = u8::try_from(byte_table[s * 256 + code_byte]).unwrap(); fused[esc_state * 256 + code_byte] = next; } } @@ -787,9 +799,10 @@ impl BranchlessShiftDfa { let t1 = class_reps[c1]; let mut packed = 0u64; for state in 0..total_states { - let mid = ((t0 >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - let final_s = ((t1 >> (mid as u32 * Self::BITS)) & Self::MASK) as u8; - packed |= (final_s as u64) << (state as u32 * Self::BITS); + let state_shift = u32::try_from(state).unwrap() * Self::BITS; + let mid = shift_extract(t0, u8::try_from(state).unwrap(), Self::BITS); + let final_s = shift_extract(t1, mid, Self::BITS); + packed |= u64::from(final_s) << state_shift; } let idx = palette_2b .iter() @@ -798,7 +811,8 @@ impl BranchlessShiftDfa { palette_2b.push(packed); palette_2b.len() - 1 }); - pair_compose[c0 * n_classes + c1] = idx as u8; + // Palette size bounded by n_classes^2, in practice ≤ ~36 + pair_compose[c0 * n_classes + c1] = u8::try_from(idx).unwrap(); } } (pair_compose, palette_2b) @@ -812,10 +826,10 @@ impl BranchlessShiftDfa { for p1 in 0..n { let mut packed = 0u64; for state in 0..total_states { - let mid = ((palette_2b[p0] >> (state as u32 * Self::BITS)) & Self::MASK) as u8; - let final_s = - ((palette_2b[p1] >> (mid as u32 * Self::BITS)) & Self::MASK) as u8; - packed |= (final_s as u64) << (state as u32 * Self::BITS); + let state_shift = u32::try_from(state).unwrap() * Self::BITS; + let mid = shift_extract(palette_2b[p0], u8::try_from(state).unwrap(), Self::BITS); + let final_s = shift_extract(palette_2b[p1], mid, Self::BITS); + packed |= u64::from(final_s) << state_shift; } compose[p0 * n + p1] = packed; } @@ -830,31 +844,42 @@ impl BranchlessShiftDfa { let rem = chunks.remainder(); for chunk in chunks { - let ec0 = unsafe { *self.eq_class.get_unchecked(chunk[0] as usize) } as usize; - let ec1 = unsafe { *self.eq_class.get_unchecked(chunk[1] as usize) } as usize; - let ec2 = unsafe { *self.eq_class.get_unchecked(chunk[2] as usize) } as usize; - let ec3 = unsafe { *self.eq_class.get_unchecked(chunk[3] as usize) } as usize; - let p0 = - unsafe { *self.pair_compose.get_unchecked(ec0 * self.n_classes + ec1) } as usize; - let p1 = - unsafe { *self.pair_compose.get_unchecked(ec2 * self.n_classes + ec3) } as usize; - let packed = unsafe { *self.compose_4b.get_unchecked(p0 * self.n_palette + p1) }; - state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + // SAFETY: chunk[i] is u8, eq_class has 256 entries — index always in bounds. + let ec0 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[0])) }; + let ec1 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[1])) }; + let ec2 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[2])) }; + let ec3 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[3])) }; + let p0 = unsafe { + *self + .pair_compose + .get_unchecked(usize::from(ec0) * self.n_classes + usize::from(ec1)) + }; + let p1 = unsafe { + *self + .pair_compose + .get_unchecked(usize::from(ec2) * self.n_classes + usize::from(ec3)) + }; + let packed = unsafe { + *self + .compose_4b + .get_unchecked(usize::from(p0) * self.n_palette + usize::from(p1)) + }; + state = shift_extract(packed, state, Self::BITS); } if rem.len() >= 2 { - let ec0 = self.eq_class[rem[0] as usize] as usize; - let ec1 = self.eq_class[rem[1] as usize] as usize; - let p = self.pair_compose[ec0 * self.n_classes + ec1] as usize; - let packed = self.palette_2b[p]; - state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let ec0 = self.eq_class[usize::from(rem[0])]; + let ec1 = self.eq_class[usize::from(rem[1])]; + let p = self.pair_compose[usize::from(ec0) * self.n_classes + usize::from(ec1)]; + let packed = self.palette_2b[usize::from(p)]; + state = shift_extract(packed, state, Self::BITS); if rem.len() == 3 { - let packed = self.transitions_1b[rem[2] as usize]; - state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let packed = self.transitions_1b[usize::from(rem[2])]; + state = shift_extract(packed, state, Self::BITS); } } else if rem.len() == 1 { - let packed = self.transitions_1b[rem[0] as usize]; - state = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let packed = self.transitions_1b[usize::from(rem[0])]; + state = shift_extract(packed, state, Self::BITS); } state @@ -886,7 +911,8 @@ impl FlatBranchlessDfa { let n = needle.len(); debug_assert!(n <= Self::MAX_NEEDLE_LEN); - let accept_state = n as u8; + // n ≤ MAX_NEEDLE_LEN (14), fits in u8 + let accept_state = u8::try_from(n).unwrap(); let total_states = 2 * n + 1; let n_symbols = symbols.len(); @@ -899,8 +925,9 @@ impl FlatBranchlessDfa { for code_byte in 0..256usize { // Normal states 0..n for s in 0..n { - if code_byte == ESCAPE_CODE as usize { - transitions[s * 256 + code_byte] = (s + n + 1) as u8; + if code_byte == usize::from(ESCAPE_CODE) { + // s + n + 1 ≤ 2*14 = 28, fits in u8 + transitions[s * 256 + code_byte] = u8::try_from(s + n + 1).unwrap(); } else if code_byte < n_symbols { transitions[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; } @@ -910,7 +937,8 @@ impl FlatBranchlessDfa { // Escape states n+1..2n for s in 0..n { let esc_state = s + n + 1; - let next = byte_table[s * 256 + code_byte] as u8; + // byte_table values are state ids < n+1 ≤ 15 + let next = u8::try_from(byte_table[s * 256 + code_byte]).unwrap(); transitions[esc_state * 256 + code_byte] = next; } } @@ -925,7 +953,7 @@ impl FlatBranchlessDfa { pub(crate) fn matches(&self, codes: &[u8]) -> bool { let mut state = 0u8; for &byte in codes { - state = self.transitions[state as usize * 256 + byte as usize]; + state = self.transitions[usize::from(state) * 256 + usize::from(byte)]; } state == self.accept_state } @@ -950,7 +978,6 @@ pub(crate) struct ShiftDfa { impl ShiftDfa { const BITS: u32 = 4; - const MASK: u64 = (1 << Self::BITS) - 1; /// Maximum needle length: 2^BITS - 2 (need room for accept + sentinel). const MAX_NEEDLE_LEN: usize = (1 << Self::BITS) - 2; @@ -958,8 +985,9 @@ impl ShiftDfa { debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); let n_states = needle.len() + 1; - let accept_state = needle.len() as u8; - let escape_sentinel = needle.len() as u8 + 1; + // needle.len() ≤ MAX_NEEDLE_LEN (14), fits in u8 + let accept_state = u8::try_from(needle.len()).unwrap(); + let escape_sentinel = u8::try_from(needle.len() + 1).unwrap(); let byte_table = kmp_byte_transitions(needle); let sym_trans = @@ -987,14 +1015,14 @@ impl ShiftDfa { let mut state = 0u8; let mut iter = codes.iter(); while let Some(&code) = iter.next() { - let packed = self.transitions[code as usize]; - let next = ((packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let packed = self.transitions[usize::from(code)]; + let next = shift_extract(packed, state, Self::BITS); if next == self.escape_sentinel { let Some(&b) = iter.next() else { return false; }; - let esc_packed = self.escape_transitions[b as usize]; - state = ((esc_packed >> (state as u32 * Self::BITS)) & Self::MASK) as u8; + let esc_packed = self.escape_transitions[usize::from(b)]; + state = shift_extract(esc_packed, state, Self::BITS); } else { state = next; } @@ -1022,8 +1050,10 @@ impl FusedDfa { debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); let n_states = needle.len() + 1; - let accept_state = needle.len() as u8; - let escape_sentinel = needle.len() as u8 + 1; + // needle.len() ≤ 254, fits in u8 + let accept_state = u8::try_from(needle.len()).unwrap(); + // needle.len() + 1 ≤ 255, fits in u8 + let escape_sentinel = u8::try_from(needle.len() + 1).unwrap(); let byte_table = kmp_byte_transitions(needle); let sym_trans = @@ -1032,7 +1062,11 @@ impl FusedDfa { let transitions = build_fused_table(&sym_trans, symbols.len(), n_states, |_| escape_sentinel, 0); - let escape_transitions: Vec = byte_table.iter().map(|&v| v as u8).collect(); + // byte_table values are state ids < n_states ≤ 255 + let escape_transitions: Vec = byte_table + .iter() + .map(|&v| u8::try_from(v).unwrap()) + .collect(); Self { transitions, @@ -1049,14 +1083,14 @@ impl FusedDfa { while pos < codes.len() { let code = codes[pos]; pos += 1; - let next = self.transitions[state as usize * 256 + code as usize]; + let next = self.transitions[usize::from(state) * 256 + usize::from(code)]; if next == self.escape_sentinel { if pos >= codes.len() { return false; } let b = codes[pos]; pos += 1; - state = self.escape_transitions[state as usize * 256 + b as usize]; + state = self.escape_transitions[usize::from(state) * 256 + usize::from(b)]; } else { state = next; } @@ -1074,19 +1108,21 @@ impl FusedDfa { fn kmp_byte_transitions(needle: &[u8]) -> Vec { let n_states = needle.len() + 1; - let accept = needle.len() as u16; + // needle.len() ≤ 254, fits in u16 + let accept = u16::try_from(needle.len()).unwrap(); let failure = kmp_failure_table(needle); let mut table = vec![0u16; n_states * 256]; for state in 0..n_states { for byte in 0..256u16 { if state == needle.len() { - table[state * 256 + byte as usize] = accept; + table[state * 256 + usize::from(byte)] = accept; continue; } let mut s = state; loop { - if byte as u8 == needle[s] { + // byte iterates 0..256, compare without truncation + if byte == u16::from(needle[s]) { s += 1; break; } @@ -1095,7 +1131,8 @@ fn kmp_byte_transitions(needle: &[u8]) -> Vec { } s = failure[s - 1]; } - table[state * 256 + byte as usize] = s as u16; + // s ≤ needle.len() ≤ 254, fits in u16 + table[state * 256 + usize::from(byte)] = u16::try_from(s).unwrap(); } } table @@ -1120,9 +1157,9 @@ fn kmp_failure_table(needle: &[u8]) -> Vec { mod tests { use fsst::ESCAPE_CODE; - use super::FusedDfa; use super::FsstMatcher; use super::FsstPrefixDfa; + use super::FusedDfa; use super::LikeKind; fn escaped(bytes: &[u8]) -> Vec { diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index b80a00d66fa..e2d05b706f9 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -36,6 +36,7 @@ vortex-array = { workspace = true, features = ["arbitrary", "_test-harness"] } vortex-btrblocks = { workspace = true } vortex-buffer = { workspace = true } vortex-error = { workspace = true } +vortex-fsst = { workspace = true } vortex-io = { workspace = true } vortex-mask = { workspace = true } vortex-runend = { workspace = true, features = ["arbitrary"] } @@ -88,6 +89,14 @@ path = "fuzz_targets/compress_roundtrip.rs" test = false required-features = ["native"] +[[bin]] +bench = false +doc = false +name = "fsst_like" +path = "fuzz_targets/fsst_like.rs" +test = false +required-features = ["native"] + [[bin]] bench = false doc = false diff --git a/fuzz/fuzz_targets/fsst_like.rs b/fuzz/fuzz_targets/fsst_like.rs new file mode 100644 index 00000000000..8e03badff00 --- /dev/null +++ b/fuzz/fuzz_targets/fsst_like.rs @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![no_main] +#![allow(clippy::unwrap_used, clippy::result_large_err)] + +use std::str::FromStr; + +use libfuzzer_sys::Corpus; +use libfuzzer_sys::fuzz_target; +use tracing::level_filters::LevelFilter; +use vortex_error::vortex_panic; +use vortex_fuzz::FuzzFsstLike; +use vortex_fuzz::run_fsst_like_fuzz; + +fuzz_target!( + init: { + let fmt = tracing_subscriber::fmt::format() + .with_ansi(false) + .without_time() + .compact(); + let level = std::env::var("RUST_LOG").map( + |v| LevelFilter::from_str(v.as_str()).unwrap()).unwrap_or(LevelFilter::INFO); + tracing_subscriber::fmt() + .event_format(fmt) + .with_max_level(level) + .init(); + }, + |fuzz_action: FuzzFsstLike| -> Corpus { + match run_fsst_like_fuzz(fuzz_action) { + Ok(true) => Corpus::Keep, + Ok(false) => Corpus::Reject, + Err(e) => vortex_panic!("{e}"), + } +}); diff --git a/fuzz/src/fsst_like.rs b/fuzz/src/fsst_like.rs new file mode 100644 index 00000000000..866b078eae0 --- /dev/null +++ b/fuzz/src/fsst_like.rs @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Fuzzer for FSST LIKE pushdown: compresses arbitrary strings with FSST, then +//! runs a LIKE pattern on both the compressed and uncompressed arrays, asserting +//! that the boolean results are identical. + +use std::sync::LazyLock; + +use arbitrary::Arbitrary; +use arbitrary::Unstructured; +use vortex_array::Canonical; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::VarBinArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::scalar_fn::fns::like::Like; +use vortex_array::scalar_fn::fns::like::LikeOptions; +use vortex_array::session::ArraySession; +use vortex_error::VortexResult; +use vortex_fsst::FSSTArray; +use vortex_fsst::fsst_compress; +use vortex_fsst::fsst_train_compressor; +use vortex_session::VortexSession; + +use crate::error::Backtrace; +use crate::error::VortexFuzzError; +use crate::error::VortexFuzzResult; + +static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + +/// Fuzz input: a set of strings and a LIKE pattern. +#[derive(Debug)] +pub struct FuzzFsstLike { + pub strings: Vec, + pub pattern: String, + pub negated: bool, +} + +impl<'a> Arbitrary<'a> for FuzzFsstLike { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + // Generate 1-200 strings, each 0-100 bytes from a small alphabet + // to increase FSST symbol reuse and substring hits. + let n_strings: usize = u.int_in_range(1..=200)?; + let mut strings = Vec::with_capacity(n_strings); + for _ in 0..n_strings { + let len: usize = u.int_in_range(0..=100)?; + let s: String = (0..len) + .map(|_| { + let b = u.int_in_range(b'a'..=b'h').unwrap_or(b'a'); + b as char + }) + .collect(); + strings.push(s); + } + + // Generate a pattern: pick a shape then fill in the literal part. + let needle_len: usize = u.int_in_range(0..=30)?; + let needle: String = (0..needle_len) + .map(|_| { + let b = u.int_in_range(b'a'..=b'h').unwrap_or(b'a'); + b as char + }) + .collect(); + + let pattern = match u.int_in_range(0..=2)? { + 0 => format!("{needle}%"), // prefix + 1 => format!("%{needle}%"), // contains + _ => format!("%{needle}"), // suffix (should fall back, still correct) + }; + + let negated: bool = u.arbitrary()?; + + Ok(FuzzFsstLike { + strings, + pattern, + negated, + }) + } +} + +/// Run the FSST LIKE fuzzer: compare LIKE on compressed vs uncompressed. +/// +/// Returns: +/// - `Ok(true)` — keep in corpus +/// - `Ok(false)` — reject (e.g. too few strings) +/// - `Err(_)` — mismatch found (bug) +#[allow(clippy::result_large_err)] +pub fn run_fsst_like_fuzz(fuzz: FuzzFsstLike) -> VortexFuzzResult { + let FuzzFsstLike { + strings, + pattern, + negated, + } = fuzz; + + if strings.is_empty() { + return Ok(false); + } + + let len = strings.len(); + + // Build uncompressed VarBinArray. + let varbin = VarBinArray::from_iter( + strings.iter().map(|s| Some(s.as_str())), + DType::Utf8(Nullability::NonNullable), + ); + + // Train FSST compressor and compress. + let compressor = fsst_train_compressor(&varbin); + let fsst_array: FSSTArray = fsst_compress(varbin.clone(), &compressor); + + let opts = LikeOptions { + negated, + case_insensitive: false, + }; + + // Run LIKE on the uncompressed array. + let expected = run_like_on_array(varbin.into_array().as_ref(), &pattern, len, opts) + .map_err(|err| VortexFuzzError::VortexError(err, Backtrace::capture()))?; + + // Run LIKE on the FSST-compressed array. + let actual = run_like_on_array(fsst_array.into_array().as_ref(), &pattern, len, opts) + .map_err(|err| VortexFuzzError::VortexError(err, Backtrace::capture()))?; + + // Compare bit-for-bit. + let expected_bits = expected.to_bit_buffer(); + let actual_bits = actual.to_bit_buffer(); + for idx in 0..len { + let expected_val = expected_bits.value(idx); + let actual_val = actual_bits.value(idx); + if expected_val != actual_val { + return Err(VortexFuzzError::ScalarMismatch( + expected_val.into(), + actual_val.into(), + idx, + Backtrace::capture(), + )); + } + } + + Ok(true) +} + +fn run_like_on_array( + array: &dyn vortex_array::DynArray, + pattern: &str, + len: usize, + opts: LikeOptions, +) -> VortexResult { + use vortex_array::ArrayRef; + use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; + + let arr: ArrayRef = array.to_array(); + let pattern_arr = ConstantArray::new(pattern, len).into_array(); + let result = Like + .try_new_array(len, opts, [arr, pattern_arr])? + .into_array() + .execute::(&mut SESSION.create_execution_ctx())?; + Ok(result.into_bool()) +} diff --git a/fuzz/src/lib.rs b/fuzz/src/lib.rs index 1d117e6d113..910aa1bdc0c 100644 --- a/fuzz/src/lib.rs +++ b/fuzz/src/lib.rs @@ -6,6 +6,7 @@ mod array; pub mod compress; pub mod error; +pub mod fsst_like; // File module only available for native builds (requires vortex-file which uses tokio) #[cfg(not(target_arch = "wasm32"))] @@ -24,6 +25,8 @@ pub use compress::FuzzCompressRoundtrip; pub use compress::run_compress_roundtrip; #[cfg(not(target_arch = "wasm32"))] pub use file::FuzzFileAction; +pub use fsst_like::FuzzFsstLike; +pub use fsst_like::run_fsst_like_fuzz; #[cfg(feature = "cuda")] pub use gpu::FuzzCompressGpu; #[cfg(feature = "cuda")] From 0fd7acbda984e0fa77fa9f582fc3657c879c08c7 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Tue, 17 Mar 2026 11:42:01 +0000 Subject: [PATCH 06/19] clean up Signed-off-by: Joe Isaacs --- encodings/fsst/Cargo.toml | 4 - encodings/fsst/benches/bitpack_strategy.rs | 278 ------------- encodings/fsst/src/dfa.rs | 433 ++++++++------------- 3 files changed, 163 insertions(+), 552 deletions(-) delete mode 100644 encodings/fsst/benches/bitpack_strategy.rs diff --git a/encodings/fsst/Cargo.toml b/encodings/fsst/Cargo.toml index c5dd318f9a9..b95eeb1f444 100644 --- a/encodings/fsst/Cargo.toml +++ b/encodings/fsst/Cargo.toml @@ -50,10 +50,6 @@ name = "fsst_url_compare" harness = false required-features = ["_test-harness"] -[[bench]] -name = "bitpack_strategy" -harness = false - [[bench]] name = "chunked_dict_fsst_builder" harness = false diff --git a/encodings/fsst/benches/bitpack_strategy.rs b/encodings/fsst/benches/bitpack_strategy.rs deleted file mode 100644 index d3d6ee348aa..00000000000 --- a/encodings/fsst/benches/bitpack_strategy.rs +++ /dev/null @@ -1,278 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Microbenchmark comparing three bit-packing strategies for `dfa_scan_to_bitbuf`: -//! -//! 1. **manual_word** — current: pack directly into u64 words one bit at a time -//! 2. **collect_bool** — use `BitBuffer::collect_bool` with a closure -//! 3. **bool_buf_64** — write results into `[bool; 64]` stack buffer, then compress -//! -//! Uses a trivial matcher (single byte comparison) so that the packing -//! overhead dominates rather than per-string work. - -#![allow(clippy::unwrap_used, clippy::cast_possible_truncation)] - -use divan::Bencher; -use vortex_buffer::BitBuffer; -use vortex_buffer::BufferMut; - -fn main() { - divan::main(); -} - -// --------------------------------------------------------------------------- -// Test data: precomputed bool results + offsets (to keep the offset-reading -// overhead consistent across strategies while making the matcher trivial) -// --------------------------------------------------------------------------- - -const N: usize = 100_000; - -struct TestData { - /// Precomputed match results for each of the N "strings". - results: Vec, - /// Fake offsets array (N+1 entries) so the offset-reading overhead is - /// included, matching the real `dfa_scan_to_bitbuf` pattern. - offsets: Vec, - /// Single-byte "strings" — just used so the matcher reads *something*. - bytes: Vec, -} - -impl TestData { - fn new() -> Self { - let mut rng_state: u64 = 0xDEAD_BEEF; - let mut next = || -> u64 { - rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1); - rng_state >> 33 - }; - - // Each "string" is exactly 1 byte so the matcher is essentially free. - let bytes: Vec = (0..N).map(|_| (next() % 256) as u8).collect(); - let offsets: Vec = (0..=N).map(|i| i as u32).collect(); - let results: Vec = bytes.iter().map(|&b| b >= 128).collect(); - - Self { - results, - offsets, - bytes, - } - } -} - -// Trivial matcher: single byte check. The real work being benchmarked is -// the bit-packing loop, not this function. -#[inline(always)] -fn matcher(data: &[u8]) -> bool { - // SAFETY: benchmark guarantees non-empty slices - unsafe { *data.get_unchecked(0) >= 128 } -} - -// --------------------------------------------------------------------------- -// Strategy 1: manual word packing (current implementation) -// --------------------------------------------------------------------------- - -#[inline(never)] -fn scan_manual_word(offsets: &[u32], all_bytes: &[u8], n: usize, negated: bool) -> BitBuffer { - let n_words = n / 64; - let remainder = n % 64; - let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); - - for chunk in 0..n_words { - let base = chunk * 64; - let mut word = 0u64; - let mut start = offsets[base] as usize; - for bit in 0..64 { - let end = offsets[base + bit + 1] as usize; - word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; - start = end; - } - unsafe { words.push_unchecked(word) }; - } - - if remainder != 0 { - let base = n_words * 64; - let mut word = 0u64; - let mut start = offsets[base] as usize; - for bit in 0..remainder { - let end = offsets[base + bit + 1] as usize; - word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; - start = end; - } - unsafe { words.push_unchecked(word) }; - } - - BitBuffer::new(words.into_byte_buffer().freeze(), n) -} - -// --------------------------------------------------------------------------- -// Strategy 2: BitBuffer::collect_bool -// --------------------------------------------------------------------------- - -#[inline(never)] -fn scan_collect_bool(offsets: &[u32], all_bytes: &[u8], n: usize, negated: bool) -> BitBuffer { - let mut start = offsets[0] as usize; - BitBuffer::collect_bool(n, |i| { - let end = offsets[i + 1] as usize; - let result = matcher(&all_bytes[start..end]) != negated; - start = end; - result - }) -} - -// --------------------------------------------------------------------------- -// Strategy 3: [bool; 64] stack buffer then compress -// --------------------------------------------------------------------------- - -#[inline(never)] -fn scan_bool_buf(offsets: &[u32], all_bytes: &[u8], n: usize, negated: bool) -> BitBuffer { - let n_words = n / 64; - let remainder = n % 64; - let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); - - for chunk in 0..n_words { - let base = chunk * 64; - let mut bools = [false; 64]; - let mut start = offsets[base] as usize; - for bit in 0..64 { - let end = offsets[base + bit + 1] as usize; - bools[bit] = matcher(&all_bytes[start..end]) != negated; - start = end; - } - let mut word = 0u64; - for bit in 0..64 { - word |= (bools[bit] as u64) << bit; - } - unsafe { words.push_unchecked(word) }; - } - - if remainder != 0 { - let base = n_words * 64; - let mut bools = [false; 64]; - let mut start = offsets[base] as usize; - for bit in 0..remainder { - let end = offsets[base + bit + 1] as usize; - bools[bit] = matcher(&all_bytes[start..end]) != negated; - start = end; - } - let mut word = 0u64; - for bit in 0..remainder { - word |= (bools[bit] as u64) << bit; - } - unsafe { words.push_unchecked(word) }; - } - - BitBuffer::new(words.into_byte_buffer().freeze(), n) -} - -// --------------------------------------------------------------------------- -// Strategy 4: precomputed bools (pure packing, no matcher at all) -// Isolates *just* the bool→bitbuffer packing cost. -// --------------------------------------------------------------------------- - -#[inline(never)] -fn pack_from_slice_manual(results: &[bool], n: usize) -> BitBuffer { - let n_words = n / 64; - let remainder = n % 64; - let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); - - for chunk in 0..n_words { - let base = chunk * 64; - let mut word = 0u64; - for bit in 0..64 { - word |= (results[base + bit] as u64) << bit; - } - unsafe { words.push_unchecked(word) }; - } - - if remainder != 0 { - let base = n_words * 64; - let mut word = 0u64; - for bit in 0..remainder { - word |= (results[base + bit] as u64) << bit; - } - unsafe { words.push_unchecked(word) }; - } - - BitBuffer::new(words.into_byte_buffer().freeze(), n) -} - -#[inline(never)] -fn pack_from_slice_collect_bool(results: &[bool], n: usize) -> BitBuffer { - BitBuffer::collect_bool(n, |i| results[i]) -} - -#[inline(never)] -fn pack_from_slice_bool_buf(results: &[bool], n: usize) -> BitBuffer { - let n_words = n / 64; - let remainder = n % 64; - let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); - - for chunk in 0..n_words { - let base = chunk * 64; - let mut bools = [false; 64]; - bools.copy_from_slice(&results[base..base + 64]); - let mut word = 0u64; - for bit in 0..64 { - word |= (bools[bit] as u64) << bit; - } - unsafe { words.push_unchecked(word) }; - } - - if remainder != 0 { - let base = n_words * 64; - let mut bools = [false; 64]; - bools[..remainder].copy_from_slice(&results[base..base + remainder]); - let mut word = 0u64; - for bit in 0..remainder { - word |= (bools[bit] as u64) << bit; - } - unsafe { words.push_unchecked(word) }; - } - - BitBuffer::new(words.into_byte_buffer().freeze(), n) -} - -// --------------------------------------------------------------------------- -// Benchmarks -// --------------------------------------------------------------------------- - -static TEST_DATA: std::sync::LazyLock = std::sync::LazyLock::new(TestData::new); - -// --- Group 1: with offset reading + trivial matcher (dfa_scan_to_bitbuf shape) --- - -#[divan::bench] -fn with_offsets_manual_word(bencher: Bencher) { - let data = &*TEST_DATA; - bencher.bench_local(|| scan_manual_word(&data.offsets, &data.bytes, N, false)); -} - -#[divan::bench] -fn with_offsets_collect_bool(bencher: Bencher) { - let data = &*TEST_DATA; - bencher.bench_local(|| scan_collect_bool(&data.offsets, &data.bytes, N, false)); -} - -#[divan::bench] -fn with_offsets_bool_buf_64(bencher: Bencher) { - let data = &*TEST_DATA; - bencher.bench_local(|| scan_bool_buf(&data.offsets, &data.bytes, N, false)); -} - -// --- Group 2: pure packing from precomputed bools (isolates packing cost) --- - -#[divan::bench] -fn pure_pack_manual_word(bencher: Bencher) { - let data = &*TEST_DATA; - bencher.bench_local(|| pack_from_slice_manual(&data.results, N)); -} - -#[divan::bench] -fn pure_pack_collect_bool(bencher: Bencher) { - let data = &*TEST_DATA; - bencher.bench_local(|| pack_from_slice_collect_bool(&data.results, N)); -} - -#[divan::bench] -fn pure_pack_bool_buf_64(bencher: Bencher) { - let data = &*TEST_DATA; - bencher.bench_local(|| pack_from_slice_bool_buf(&data.results, N)); -} diff --git a/encodings/fsst/src/dfa.rs b/encodings/fsst/src/dfa.rs index 865bfe53f28..9f2746941f2 100644 --- a/encodings/fsst/src/dfa.rs +++ b/encodings/fsst/src/dfa.rs @@ -139,7 +139,7 @@ //! //! There are two ways to handle the FSST escape code in the DFA: //! -//! **Escape sentinel** (used by `ShiftDfa`, `FusedDfa`, `FsstPrefixDfa`): +//! **Escape sentinel** (used by `FusedDfa`, `FsstPrefixDfa`): //! The escape code maps to a sentinel state. The scanner checks for it and //! reads the next byte from a separate escape transition table. //! @@ -166,7 +166,6 @@ use fsst::ESCAPE_CODE; use fsst::Symbol; use vortex_buffer::BitBuffer; -use vortex_buffer::BufferMut; use vortex_error::VortexResult; // --------------------------------------------------------------------------- @@ -188,7 +187,7 @@ enum MatcherInner { Prefix(Box), ContainsBranchless(Box), ContainsFlat(FlatBranchlessDfa), - Contains(FsstContainsDfa), + ContainsFused(FusedDfa), } impl FsstMatcher { @@ -237,7 +236,7 @@ impl FsstMatcher { needle, )) } else { - MatcherInner::Contains(FsstContainsDfa::new(symbols, symbol_lengths, needle)) + MatcherInner::ContainsFused(FusedDfa::new(symbols, symbol_lengths, needle)) } } }; @@ -253,7 +252,7 @@ impl FsstMatcher { MatcherInner::Prefix(dfa) => dfa.matches(codes), MatcherInner::ContainsBranchless(dfa) => dfa.matches(codes), MatcherInner::ContainsFlat(dfa) => dfa.matches(codes), - MatcherInner::Contains(dfa) => dfa.matches(codes), + MatcherInner::ContainsFused(dfa) => dfa.matches(codes), } } } @@ -269,10 +268,10 @@ enum LikeKind<'a> { impl<'a> LikeKind<'a> { fn parse(pattern: &'a str) -> Option { // `prefix%` (including just `%` where prefix is empty) - if let Some(prefix) = pattern.strip_suffix('%') { - if !prefix.contains(['%', '_']) { - return Some(LikeKind::Prefix(prefix)); - } + if let Some(prefix) = pattern.strip_suffix('%') + && !prefix.contains(['%', '_']) + { + return Some(LikeKind::Prefix(prefix)); } // `%needle%` @@ -289,10 +288,6 @@ impl<'a> LikeKind<'a> { // Scan helper // --------------------------------------------------------------------------- -/// Scan all strings through a DFA matcher, packing results directly into a -/// `BitBuffer` one u64 word (64 strings) at a time. This avoids the overhead -/// of `BitBufferMut::collect_bool`'s cross-crate closure indirection and -/// guarantees the compiler can see the full loop body for optimization. // TODO: add N-way ILP overrun scan for higher throughput on short strings. #[inline] pub(crate) fn dfa_scan_to_bitbuf( @@ -306,36 +301,13 @@ where T: vortex_array::dtype::IntegerPType, F: Fn(&[u8]) -> bool, { - let n_words = n / 64; - let remainder = n % 64; - let mut words: BufferMut = BufferMut::with_capacity(n.div_ceil(64)); - - for chunk in 0..n_words { - let base = chunk * 64; - let mut word = 0u64; - let mut start: usize = offsets[base].as_(); - for bit in 0..64 { - let end: usize = offsets[base + bit + 1].as_(); - word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; - start = end; - } - // SAFETY: we allocated capacity for n.div_ceil(64) words. - unsafe { words.push_unchecked(word) }; - } - - if remainder != 0 { - let base = n_words * 64; - let mut word = 0u64; - let mut start: usize = offsets[base].as_(); - for bit in 0..remainder { - let end: usize = offsets[base + bit + 1].as_(); - word |= ((matcher(&all_bytes[start..end]) != negated) as u64) << bit; - start = end; - } - unsafe { words.push_unchecked(word) }; - } - - BitBuffer::new(words.into_byte_buffer().freeze(), n) + let mut start: usize = offsets[0].as_(); + BitBuffer::collect_bool(n, |i| { + let end: usize = offsets[i + 1].as_(); + let result = matcher(&all_bytes[start..end]) != negated; + start = end; + result + }) } // --------------------------------------------------------------------------- @@ -346,11 +318,14 @@ where /// /// Each state occupies `bits` bits. The mask `(1 << bits) - 1` guarantees the /// result is at most 15 (for `bits = 4`), which always fits in `u8`. +#[expect( + clippy::cast_possible_truncation, + reason = "masked to `bits` bits (≤4), result ≤ 15" +)] #[inline(always)] fn shift_extract(packed: u64, state: u8, bits: u32) -> u8 { let mask = (1u64 << bits) - 1; - // bits ≤ 4 ⇒ mask ≤ 15 ⇒ result ≤ 15, always fits in u8. - u8::try_from((packed >> (u32::from(state) * bits)) & mask).unwrap() + ((packed >> (u32::from(state) * bits)) & mask) as u8 } // --------------------------------------------------------------------------- @@ -380,16 +355,21 @@ fn build_symbol_transitions( } let sym = symbols[code].to_u64().to_le_bytes(); let sym_len = usize::from(symbol_lengths[code]); - // state < n_states ≤ 256, fits in u16 - let mut s = u16::try_from(state).unwrap(); + #[expect(clippy::cast_possible_truncation, reason = "state < n_states ≤ 256")] + let mut s = state as u16; for &b in &sym[..sym_len] { if s == u16::from(accept_state) { break; } s = byte_table[usize::from(s) * 256 + usize::from(b)]; } - // s is a state id from byte_table, always < n_states ≤ 256 - sym_trans[state * n_symbols + code] = u8::try_from(s).unwrap(); + #[expect( + clippy::cast_possible_truncation, + reason = "s is a state id < n_states ≤ 256" + )] + { + sym_trans[state * n_symbols + code] = s as u8; + } } } sym_trans @@ -429,28 +409,64 @@ fn pack_shift_table(fused: &[u8], n_states: usize, bits: u32) -> [u64; 256] { for code_byte in 0..256usize { let mut val = 0u64; for state in 0..n_states { - // state < n_states ≤ 16 for 4-bit packing, fits in u32 - val |= u64::from(fused[state * 256 + code_byte]) - << (u32::try_from(state).unwrap() * bits); + #[expect(clippy::cast_possible_truncation, reason = "state < n_states ≤ 16")] + let shift = state as u32 * bits; + val |= u64::from(fused[state * 256 + code_byte]) << shift; } packed[code_byte] = val; } packed } -/// Packs a byte-level KMP table into shift-encoded `u64` arrays for escape handling. -fn pack_escape_shift_table(byte_table: &[u16], n_states: usize, bits: u32) -> [u64; 256] { - let mut packed = [0u64; 256]; - for byte_val in 0..256usize { - let mut val = 0u64; - for state in 0..n_states { - // byte_table values are state ids < n_states ≤ 256, fit in u8 - let next = u8::try_from(byte_table[state * 256 + byte_val]).unwrap(); - val |= u64::from(next) << (u32::try_from(state).unwrap() * bits); +/// Builds an escape-folded fused transition table for contains matching. +/// +/// State layout: `[0..n-1]` match progress, `[n]` accept (sticky), `[n+1..2n]` escape shadows. +/// Total states: `2 * needle.len() + 1`. +/// +/// For normal states, the escape code maps to the corresponding escape shadow state. +/// Escape shadow states use byte-level KMP transitions so the next literal byte +/// resumes matching correctly — no branch needed in the scanner. +fn build_escape_folded_table(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Vec { + let n = needle.len(); + let total_states = 2 * n + 1; + #[expect( + clippy::cast_possible_truncation, + reason = "n ≤ FlatBranchlessDfa::MAX_NEEDLE_LEN (14)" + )] + let accept_state = n as u8; + + let byte_table = kmp_byte_transitions(needle); + let sym_trans = + build_symbol_transitions(symbols, symbol_lengths, &byte_table, n + 1, accept_state); + + let n_symbols = symbols.len(); + let mut fused = vec![0u8; total_states * 256]; + for code_byte in 0..256usize { + // Normal states 0..n + for s in 0..n { + if code_byte == usize::from(ESCAPE_CODE) { + #[expect(clippy::cast_possible_truncation, reason = "s + n + 1 ≤ 2*14 = 28")] + { + fused[s * 256 + code_byte] = (s + n + 1) as u8; + } + } else if code_byte < n_symbols { + fused[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; + } + } + // Accept state (sticky) + fused[n * 256 + code_byte] = accept_state; + // Escape shadow states n+1..2n + for s in 0..n { + let esc_state = s + n + 1; + #[expect( + clippy::cast_possible_truncation, + reason = "byte_table state ids < n+1 ≤ 15" + )] + let next = byte_table[s * 256 + code_byte] as u8; + fused[esc_state * 256 + code_byte] = next; } - packed[byte_val] = val; } - packed + fused } // --------------------------------------------------------------------------- @@ -480,9 +496,13 @@ impl FsstPrefixDfa { // Need room for states 0..prefix_len, accept, fail, and an escape sentinel. debug_assert!(prefix.len() <= Self::MAX_PREFIX_LEN); - // prefix.len() ≤ MAX_PREFIX_LEN (13), fits in u8 - let accept_state = u8::try_from(prefix.len()).unwrap(); - let fail_state = u8::try_from(prefix.len() + 1).unwrap(); + #[expect( + clippy::cast_possible_truncation, + reason = "prefix.len() ≤ MAX_PREFIX_LEN (13)" + )] + let accept_state = prefix.len() as u8; + #[expect(clippy::cast_possible_truncation, reason = "prefix.len() + 1 ≤ 14")] + let fail_state = (prefix.len() + 1) as u8; let n_states = prefix.len() + 2; // Prefix matching uses a simpler transition rule than KMP: on mismatch @@ -527,8 +547,13 @@ impl FsstPrefixDfa { esc_trans[state * 256 + b] = if next >= prefix.len() { accept_state } else { - // next ≤ prefix.len() ≤ 13, fits in u8 - u8::try_from(next).unwrap() + #[expect( + clippy::cast_possible_truncation, + reason = "next ≤ prefix.len() ≤ 13" + )] + { + next as u8 + } }; } } @@ -560,8 +585,13 @@ impl FsstPrefixDfa { let next_state = if state + 1 >= prefix.len() { u16::from(accept_state) } else { - // state + 1 ≤ prefix.len() ≤ 13, fits in u16 - u16::try_from(state + 1).unwrap() + #[expect( + clippy::cast_possible_truncation, + reason = "state + 1 ≤ prefix.len() ≤ 13" + )] + { + (state + 1) as u16 + } }; table[state * 256 + usize::from(next_byte)] = next_state; } @@ -606,34 +636,6 @@ impl FsstPrefixDfa { // DFA for contains matching (LIKE '%needle%') // --------------------------------------------------------------------------- -/// Contains DFA dispatch for long needles (>14 bytes). Short needles (len <= 7) -/// are handled by `BranchlessShiftDfa`, medium needles (8-14) by -/// `FlatBranchlessDfa`, and longer supported needles (15-254) by `FusedDfa`. -enum FsstContainsDfa { - /// Retained internal alternative; not currently selected by `FsstMatcher`. - Shift(Box), - /// Fused u8 table DFA for long needles (15-254 bytes). - Fused(FusedDfa), -} - -impl FsstContainsDfa { - pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - if needle.len() <= ShiftDfa::MAX_NEEDLE_LEN { - FsstContainsDfa::Shift(Box::new(ShiftDfa::new(symbols, symbol_lengths, needle))) - } else { - FsstContainsDfa::Fused(FusedDfa::new(symbols, symbol_lengths, needle)) - } - } - - #[inline] - pub(crate) fn matches(&self, codes: &[u8]) -> bool { - match self { - FsstContainsDfa::Shift(dfa) => dfa.matches(codes), - FsstContainsDfa::Fused(dfa) => dfa.matches(codes), - } - } -} - /// Branchless escape-folded DFA for short needles (len <= 7). /// /// Folds escape handling into the state space so that `matches()` is @@ -689,13 +691,13 @@ impl BranchlessShiftDfa { let n = needle.len(); debug_assert!(n <= Self::MAX_NEEDLE_LEN); - // n ≤ MAX_NEEDLE_LEN (7), fits in u8 - let accept_state = u8::try_from(n).unwrap(); + #[expect(clippy::cast_possible_truncation, reason = "n ≤ MAX_NEEDLE_LEN (7)")] + let accept_state = n as u8; let total_states = 2 * n + 1; debug_assert!(total_states <= (1 << Self::BITS)); - let transitions_1b = - Self::build_escape_folded_transitions(symbols, symbol_lengths, needle, total_states); + let fused = build_escape_folded_table(symbols, symbol_lengths, needle); + let transitions_1b = pack_shift_table(&fused, total_states, Self::BITS); // Build equivalence classes: group bytes with identical transition u64. let mut eq_class = [0u8; 256]; @@ -709,8 +711,10 @@ impl BranchlessShiftDfa { class_representatives.push(t); class_representatives.len() - 1 }); - // At most 256 equivalence classes (one per byte value), fits in u8 - eq_class[byte_val] = u8::try_from(cls).unwrap(); + #[expect(clippy::cast_possible_truncation, reason = "≤ 256 equivalence classes")] + { + eq_class[byte_val] = cls as u8; + } } let n_classes = class_representatives.len(); @@ -736,53 +740,6 @@ impl BranchlessShiftDfa { } } - /// Build the 1-byte packed transition table with escape handling folded - /// into the state space (no branch needed in the scanner). - fn build_escape_folded_transitions( - symbols: &[Symbol], - symbol_lengths: &[u8], - needle: &[u8], - total_states: usize, - ) -> [u64; 256] { - let n = needle.len(); - let n_normal_states = n + 1; - // n ≤ MAX_NEEDLE_LEN (7), fits in u8 - let accept_state = u8::try_from(n).unwrap(); - - let byte_table = kmp_byte_transitions(needle); - let sym_trans = build_symbol_transitions( - symbols, - symbol_lengths, - &byte_table, - n_normal_states, - accept_state, - ); - - // Build fused transition table with escape folding. - let n_symbols = symbols.len(); - let mut fused = vec![0u8; total_states * 256]; - for code_byte in 0..256usize { - for s in 0..n { - if code_byte == usize::from(ESCAPE_CODE) { - // s + n + 1 ≤ 2*7 = 14, fits in u8 - fused[s * 256 + code_byte] = u8::try_from(s + n + 1).unwrap(); - } else if code_byte < n_symbols { - fused[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; - } - } - fused[n * 256 + code_byte] = accept_state; - for s in 0..n { - let esc_state = s + n + 1; - // byte_table values are state ids < n_normal_states ≤ 8 - let next = u8::try_from(byte_table[s * 256 + code_byte]).unwrap(); - fused[esc_state * 256 + code_byte] = next; - } - } - - // Pack into u64 shift table. - pack_shift_table(&fused, total_states, Self::BITS) - } - /// Build the pair-compose table and 2-byte palette from equivalence /// class representatives. fn build_pair_compose( @@ -799,8 +756,17 @@ impl BranchlessShiftDfa { let t1 = class_reps[c1]; let mut packed = 0u64; for state in 0..total_states { - let state_shift = u32::try_from(state).unwrap() * Self::BITS; - let mid = shift_extract(t0, u8::try_from(state).unwrap(), Self::BITS); + #[expect( + clippy::cast_possible_truncation, + reason = "state < total_states ≤ 16" + )] + let state_u8 = state as u8; + #[expect( + clippy::cast_possible_truncation, + reason = "state < total_states ≤ 16" + )] + let state_shift = state as u32 * Self::BITS; + let mid = shift_extract(t0, state_u8, Self::BITS); let final_s = shift_extract(t1, mid, Self::BITS); packed |= u64::from(final_s) << state_shift; } @@ -811,8 +777,13 @@ impl BranchlessShiftDfa { palette_2b.push(packed); palette_2b.len() - 1 }); - // Palette size bounded by n_classes^2, in practice ≤ ~36 - pair_compose[c0 * n_classes + c1] = u8::try_from(idx).unwrap(); + #[expect( + clippy::cast_possible_truncation, + reason = "palette size ≤ n_classes² ≤ 256" + )] + { + pair_compose[c0 * n_classes + c1] = idx as u8; + } } } (pair_compose, palette_2b) @@ -826,8 +797,17 @@ impl BranchlessShiftDfa { for p1 in 0..n { let mut packed = 0u64; for state in 0..total_states { - let state_shift = u32::try_from(state).unwrap() * Self::BITS; - let mid = shift_extract(palette_2b[p0], u8::try_from(state).unwrap(), Self::BITS); + #[expect( + clippy::cast_possible_truncation, + reason = "state < total_states ≤ 16" + )] + let state_u8 = state as u8; + #[expect( + clippy::cast_possible_truncation, + reason = "state < total_states ≤ 16" + )] + let state_shift = state as u32 * Self::BITS; + let mid = shift_extract(palette_2b[p0], state_u8, Self::BITS); let final_s = shift_extract(palette_2b[p1], mid, Self::BITS); packed |= u64::from(final_s) << state_shift; } @@ -908,40 +888,15 @@ impl FlatBranchlessDfa { pub(crate) const MAX_NEEDLE_LEN: usize = 14; pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - let n = needle.len(); - debug_assert!(n <= Self::MAX_NEEDLE_LEN); - - // n ≤ MAX_NEEDLE_LEN (14), fits in u8 - let accept_state = u8::try_from(n).unwrap(); - let total_states = 2 * n + 1; - let n_symbols = symbols.len(); + debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); - let byte_table = kmp_byte_transitions(needle); - let sym_trans = - build_symbol_transitions(symbols, symbol_lengths, &byte_table, n + 1, accept_state); + #[expect( + clippy::cast_possible_truncation, + reason = "needle.len() ≤ MAX_NEEDLE_LEN (14)" + )] + let accept_state = needle.len() as u8; - // Build fused transition table with escape folding. - let mut transitions = vec![0u8; total_states * 256]; - for code_byte in 0..256usize { - // Normal states 0..n - for s in 0..n { - if code_byte == usize::from(ESCAPE_CODE) { - // s + n + 1 ≤ 2*14 = 28, fits in u8 - transitions[s * 256 + code_byte] = u8::try_from(s + n + 1).unwrap(); - } else if code_byte < n_symbols { - transitions[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; - } - } - // Accept state (sticky) - transitions[n * 256 + code_byte] = accept_state; - // Escape states n+1..2n - for s in 0..n { - let esc_state = s + n + 1; - // byte_table values are state ids < n+1 ≤ 15 - let next = u8::try_from(byte_table[s * 256 + code_byte]).unwrap(); - transitions[esc_state * 256 + code_byte] = next; - } - } + let transitions = build_escape_folded_table(symbols, symbol_lengths, needle); Self { transitions, @@ -959,78 +914,6 @@ impl FlatBranchlessDfa { } } -/// Shift-based DFA: packs all state transitions into a `u64` per input byte. -/// -/// For a DFA with S states (S <= 16, using 4 bits each), we store transitions -/// for ALL states in one `u64`. Transition: `next = (table[code] >> (state * 4)) & 0xF`. -/// -/// Supports needles up to 14 characters (needle.len() + 2 <= 16 to fit escape -/// sentinel). This covers virtually all practical LIKE patterns. -pub(crate) struct ShiftDfa { - /// For each code byte (0..255): a `u64` packing all state transitions. - /// Bits `[state*4 .. state*4+4)` encode the next state for that input. - transitions: [u64; 256], - /// Same layout for escape byte transitions. - escape_transitions: [u64; 256], - accept_state: u8, - escape_sentinel: u8, -} - -impl ShiftDfa { - const BITS: u32 = 4; - /// Maximum needle length: 2^BITS - 2 (need room for accept + sentinel). - const MAX_NEEDLE_LEN: usize = (1 << Self::BITS) - 2; - - fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); - - let n_states = needle.len() + 1; - // needle.len() ≤ MAX_NEEDLE_LEN (14), fits in u8 - let accept_state = u8::try_from(needle.len()).unwrap(); - let escape_sentinel = u8::try_from(needle.len() + 1).unwrap(); - - let byte_table = kmp_byte_transitions(needle); - let sym_trans = - build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); - - let fused = build_fused_table(&sym_trans, symbols.len(), n_states, |_| escape_sentinel, 0); - - let transitions = pack_shift_table(&fused, n_states, Self::BITS); - let escape_transitions = pack_escape_shift_table(&byte_table, n_states, Self::BITS); - - Self { - transitions, - escape_transitions, - accept_state, - escape_sentinel, - } - } - - /// Match with iterator-based traversal. - /// - /// Using `iter.next()` instead of manual index + bounds check helps the - /// compiler eliminate redundant bounds checks. - #[inline] - fn matches(&self, codes: &[u8]) -> bool { - let mut state = 0u8; - let mut iter = codes.iter(); - while let Some(&code) = iter.next() { - let packed = self.transitions[usize::from(code)]; - let next = shift_extract(packed, state, Self::BITS); - if next == self.escape_sentinel { - let Some(&b) = iter.next() else { - return false; - }; - let esc_packed = self.escape_transitions[usize::from(b)]; - state = shift_extract(esc_packed, state, Self::BITS); - } else { - state = next; - } - } - state == self.accept_state - } -} - /// Fused 256-entry u8 table DFA for contains needles in the 15-254 byte range. /// /// This representation stores state ids in `u8`, so it cannot represent @@ -1050,10 +933,10 @@ impl FusedDfa { debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); let n_states = needle.len() + 1; - // needle.len() ≤ 254, fits in u8 - let accept_state = u8::try_from(needle.len()).unwrap(); - // needle.len() + 1 ≤ 255, fits in u8 - let escape_sentinel = u8::try_from(needle.len() + 1).unwrap(); + #[expect(clippy::cast_possible_truncation, reason = "needle.len() ≤ 254")] + let accept_state = needle.len() as u8; + #[expect(clippy::cast_possible_truncation, reason = "needle.len() + 1 ≤ 255")] + let escape_sentinel = (needle.len() + 1) as u8; let byte_table = kmp_byte_transitions(needle); let sym_trans = @@ -1065,7 +948,15 @@ impl FusedDfa { // byte_table values are state ids < n_states ≤ 255 let escape_transitions: Vec = byte_table .iter() - .map(|&v| u8::try_from(v).unwrap()) + .map(|&v| { + #[expect( + clippy::cast_possible_truncation, + reason = "state ids < n_states ≤ 255" + )] + { + v as u8 + } + }) .collect(); Self { @@ -1108,8 +999,8 @@ impl FusedDfa { fn kmp_byte_transitions(needle: &[u8]) -> Vec { let n_states = needle.len() + 1; - // needle.len() ≤ 254, fits in u16 - let accept = u16::try_from(needle.len()).unwrap(); + #[expect(clippy::cast_possible_truncation, reason = "needle.len() ≤ 254")] + let accept = needle.len() as u16; let failure = kmp_failure_table(needle); let mut table = vec![0u16; n_states * 256]; @@ -1131,8 +1022,10 @@ fn kmp_byte_transitions(needle: &[u8]) -> Vec { } s = failure[s - 1]; } - // s ≤ needle.len() ≤ 254, fits in u16 - table[state * 256 + usize::from(byte)] = u16::try_from(s).unwrap(); + #[expect(clippy::cast_possible_truncation, reason = "s ≤ needle.len() ≤ 254")] + { + table[state * 256 + usize::from(byte)] = s as u16; + } } } table From b0e94b8f444745e99bc97bc3c8e8d7c9afc111c8 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Wed, 18 Mar 2026 14:22:26 +0000 Subject: [PATCH 07/19] clean up Signed-off-by: Joe Isaacs --- encodings/fsst/src/dfa.rs | 1122 -------------------- encodings/fsst/src/dfa/DFA_NOTES.md | 124 +++ encodings/fsst/src/dfa/branchless_shift.rs | 227 ++++ encodings/fsst/src/dfa/flat_contains.rs | 157 +++ encodings/fsst/src/dfa/mod.rs | 539 ++++++++++ encodings/fsst/src/dfa/prefix.rs | 165 +++ encodings/fsst/src/dfa/tests.rs | 74 ++ 7 files changed, 1286 insertions(+), 1122 deletions(-) delete mode 100644 encodings/fsst/src/dfa.rs create mode 100644 encodings/fsst/src/dfa/DFA_NOTES.md create mode 100644 encodings/fsst/src/dfa/branchless_shift.rs create mode 100644 encodings/fsst/src/dfa/flat_contains.rs create mode 100644 encodings/fsst/src/dfa/mod.rs create mode 100644 encodings/fsst/src/dfa/prefix.rs create mode 100644 encodings/fsst/src/dfa/tests.rs diff --git a/encodings/fsst/src/dfa.rs b/encodings/fsst/src/dfa.rs deleted file mode 100644 index 9f2746941f2..00000000000 --- a/encodings/fsst/src/dfa.rs +++ /dev/null @@ -1,1122 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! # FSST LIKE Pushdown via DFA Construction -//! -//! This module implements DFA-based pattern matching directly on FSST-compressed -//! strings, without decompressing them. It handles two pattern shapes: -//! -//! - **Prefix**: `'prefix%'` — matches strings starting with a literal prefix. -//! - **Contains**: `'%needle%'` — matches strings containing a literal substring. -//! -//! Pushdown is intentionally conservative. If the pattern shape is unsupported, -//! or if the pattern exceeds the DFA's representable state space, construction -//! returns `None` and the caller must fall back to ordinary decompression-based -//! LIKE evaluation. -//! -//! TODO(joe): suffix (`'%suffix'`) pushdown. Two approaches: -//! - **Forward DFA**: use a non-sticky accept state with KMP fallback transitions, -//! check `state == accept` after processing all codes. Branchless and vectorizable. -//! - **Backward scan**: walk the compressed code stream in reverse, comparing symbol -//! bytes from the end. Simpler, no DFA construction, but requires reverse parsing -//! of the FSST escape mechanism. -//! -//! ## Background: FSST Encoding -//! -//! [FSST](https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf) compresses strings by -//! replacing frequent byte sequences with single-byte **symbol codes** (0–254). Code -//! byte 255 is reserved as the **escape code**: the next byte is a literal (uncompressed) -//! byte. So a compressed string is a stream of: -//! -//! ```text -//! [symbol_code] ... [symbol_code] [ESCAPE literal_byte] [symbol_code] ... -//! ``` -//! -//! A single symbol can expand to 1–8 bytes. Matching on compressed codes requires -//! the DFA to handle multi-byte symbol expansions and the escape mechanism. -//! -//! ## The Algorithm: KMP → Byte Table → Symbol Table → Packed DFA -//! -//! Construction proceeds through four stages: -//! -//! ### Stage 1: KMP Failure Function -//! -//! We compute the standard [KMP](https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm) -//! failure function for the needle bytes. This tells us, on a mismatch at -//! position `i`, the longest proper prefix of `needle[0..i]` that is also a -//! suffix — i.e., where to resume matching instead of starting over. -//! -//! ```text -//! Needle: "abcabd" -//! Failure: [0, 0, 0, 1, 2, 0] -//! ^ ^ -//! At position 3 ('a'), the prefix "a" matches suffix "a" -//! At position 4 ('b'), the prefix "ab" matches suffix "ab" -//! ``` -//! -//! ### Stage 2: Byte-Level Transition Table -//! -//! From the failure function, we build a full `(state × byte) → state` transition -//! table. State `i` means "we have matched `needle[0..i]`". State `n` (= needle -//! length) is the **accept** state. -//! -//! ```text -//! Needle: "aba" (3 states + accept) -//! -//! Input byte -//! State 'a' 'b' other -//! ───── ──── ──── ───── -//! 0 1 0 0 ← looking for first 'a' -//! 1 1 2 0 ← matched "a", want 'b' -//! 2 3✓ 0 0 ← matched "ab", want 'a' -//! 3✓ 3✓ 3✓ 3✓ ← accept (sticky) -//! ``` -//! -//! For prefix matching, a mismatch at any state goes to a **fail** state (no -//! fallback). For contains matching, mismatches follow KMP fallback transitions -//! so we can find the needle anywhere in the string. -//! -//! ### Stage 3: Symbol-Level Transition Table -//! -//! FSST symbols can be multi-byte. To compute the transition for symbol code `c` -//! in state `s`, we simulate feeding each byte of the symbol through the byte -//! table: -//! -//! ```text -//! Symbol #42 = "the" (3 bytes) -//! State 0 + 't' → 0, + 'h' → 0, + 'e' → 0 ⟹ sym_trans[0][42] = 0 -//! -//! If needle = "them": -//! State 0 + 't' → 1, + 'h' → 2, + 'e' → 3 ⟹ sym_trans[0][42] = 3 -//! ``` -//! -//! We then build a **fused 256-wide table**: for code bytes 0–254, use the -//! symbol transition; for code byte 255 (ESCAPE_CODE), transition to a -//! special sentinel that tells the scanner to read the next literal byte. -//! -//! ### Stage 4: Packing into the Final Representation -//! -//! The fused table can be stored in different layouts depending on the number -//! of states: -//! -//! - **Shift-packed `u64`** (≤16 states): Each state needs 4 bits. All state -//! transitions for one input byte fit in a single `u64`. Lookup: -//! `next = (table[byte] >> (state * 4)) & 0xF`. One cache line per lookup. -//! -//! - **Flat `u8` table** (≤255 states): `transitions[state * 256 + byte]`. -//! Larger, but still bounded by the `u8` state representation. -//! -//! ## State-Space Limits -//! -//! The public behavior is shaped by two implementation limits, both measured in -//! pattern **bytes** rather than Unicode scalar values: -//! -//! - `prefix%` pushdown is limited to **13 bytes**. The packed prefix DFA uses -//! 4-bit state ids and needs room for normal prefix-progress states, an -//! accept state, a fail state, and one escape sentinel for FSST literals. -//! - `%needle%` pushdown is limited to **254 bytes**. The long-needle DFA stores -//! states in `u8`, so it needs room for every match-progress state plus both -//! the accept state and the escape sentinel. -//! -//! Patterns beyond those limits are still valid LIKE patterns; they simply do -//! not use FSST pushdown and must be evaluated through the fallback path. -//! -//! ## DFA Variants and When Each Is Used -//! -//! ```text -//! ┌───────────────┬──────────────────────────────────────────────────────┐ -//! │ Pattern │ Needle length → DFA variant │ -//! ├───────────────┼──────────────────────────────────────────────────────┤ -//! │ prefix% │ 0–13 → FsstPrefixDfa (shift-packed, no KMP) │ -//! ├───────────────┼──────────────────────────────────────────────────────┤ -//! │ %needle% │ 1–7 → BranchlessShiftDfa (hierarchical 4-byte) │ -//! │ │ 8–14 → FlatBranchlessDfa (flat u8, escape-folded)│ -//! │ │ 15–254 → FusedDfa (escape sentinel) │ -//! └───────────────┴──────────────────────────────────────────────────────┘ -//! ``` -//! -//! ## Escape Handling Strategies -//! -//! There are two ways to handle the FSST escape code in the DFA: -//! -//! **Escape sentinel** (used by `FusedDfa`, `FsstPrefixDfa`): -//! The escape code maps to a sentinel state. The scanner checks for it and -//! reads the next byte from a separate escape transition table. -//! -//! ```text -//! loop: -//! state = transitions[byte] // might be sentinel -//! if state == SENTINEL: -//! state = escape_transitions[next_byte] // branch -//! ``` -//! -//! **Escape folding** (used by `BranchlessShiftDfa`, `FlatBranchlessDfa`): -//! Escape states are folded into the state space. State `s+N+1` means "was in -//! state `s`, just consumed ESCAPE_CODE". The next byte's transition from an -//! escape state uses the byte-level table. No branch needed in the scanner. -//! -//! ```text -//! States: [0..N-1: normal] [N: accept] [N+1..2N: escape shadows] -//! Total: 2N+1 states. With 4-bit packing, max N=7. -//! -//! loop: -//! state = transitions[state][byte] // branchless! -//! ``` - -use fsst::ESCAPE_CODE; -use fsst::Symbol; -use vortex_buffer::BitBuffer; -use vortex_error::VortexResult; - -// --------------------------------------------------------------------------- -// FsstMatcher — unified public API -// --------------------------------------------------------------------------- - -/// A compiled matcher for LIKE patterns on FSST-compressed strings. -/// -/// Encapsulates pattern parsing and DFA variant selection. Returns `None` from -/// [`try_new`](Self::try_new) for patterns that cannot be evaluated without -/// decompression (e.g., `_` wildcards, multiple `%` in non-standard positions, -/// or patterns that exceed the DFA's representable byte-length limits). -pub(crate) struct FsstMatcher { - inner: MatcherInner, -} - -enum MatcherInner { - MatchAll, - Prefix(Box), - ContainsBranchless(Box), - ContainsFlat(FlatBranchlessDfa), - ContainsFused(FusedDfa), -} - -impl FsstMatcher { - /// Try to build a matcher for the given LIKE pattern. - /// - /// Returns `Ok(None)` if the pattern shape is not supported for pushdown - /// (e.g. `_` wildcards, multiple non-bookend `%`, `prefix%` longer than - /// 13 bytes, or `%needle%` longer than 254 bytes). - pub(crate) fn try_new( - symbols: &[Symbol], - symbol_lengths: &[u8], - pattern: &str, - ) -> VortexResult> { - let Some(like_kind) = LikeKind::parse(pattern) else { - return Ok(None); - }; - - let inner = match like_kind { - LikeKind::Prefix("") => MatcherInner::MatchAll, - LikeKind::Prefix(prefix) => { - let prefix = prefix.as_bytes(); - if prefix.len() > FsstPrefixDfa::MAX_PREFIX_LEN { - return Ok(None); - } - MatcherInner::Prefix(Box::new(FsstPrefixDfa::new( - symbols, - symbol_lengths, - prefix, - ))) - } - LikeKind::Contains(needle) => { - let needle = needle.as_bytes(); - if needle.len() > FusedDfa::MAX_NEEDLE_LEN { - return Ok(None); - } - if needle.len() <= BranchlessShiftDfa::MAX_NEEDLE_LEN { - MatcherInner::ContainsBranchless(Box::new(BranchlessShiftDfa::new( - symbols, - symbol_lengths, - needle, - ))) - } else if needle.len() <= FlatBranchlessDfa::MAX_NEEDLE_LEN { - MatcherInner::ContainsFlat(FlatBranchlessDfa::new( - symbols, - symbol_lengths, - needle, - )) - } else { - MatcherInner::ContainsFused(FusedDfa::new(symbols, symbol_lengths, needle)) - } - } - }; - - Ok(Some(Self { inner })) - } - - /// Run the matcher on a single FSST-compressed code sequence. - #[inline] - pub(crate) fn matches(&self, codes: &[u8]) -> bool { - match &self.inner { - MatcherInner::MatchAll => true, - MatcherInner::Prefix(dfa) => dfa.matches(codes), - MatcherInner::ContainsBranchless(dfa) => dfa.matches(codes), - MatcherInner::ContainsFlat(dfa) => dfa.matches(codes), - MatcherInner::ContainsFused(dfa) => dfa.matches(codes), - } - } -} - -/// The subset of LIKE patterns we can handle without decompression. -enum LikeKind<'a> { - /// `prefix%` - Prefix(&'a str), - /// `%needle%` - Contains(&'a str), -} - -impl<'a> LikeKind<'a> { - fn parse(pattern: &'a str) -> Option { - // `prefix%` (including just `%` where prefix is empty) - if let Some(prefix) = pattern.strip_suffix('%') - && !prefix.contains(['%', '_']) - { - return Some(LikeKind::Prefix(prefix)); - } - - // `%needle%` - let inner = pattern.strip_prefix('%')?.strip_suffix('%')?; - if !inner.contains(['%', '_']) { - return Some(LikeKind::Contains(inner)); - } - - None - } -} - -// --------------------------------------------------------------------------- -// Scan helper -// --------------------------------------------------------------------------- - -// TODO: add N-way ILP overrun scan for higher throughput on short strings. -#[inline] -pub(crate) fn dfa_scan_to_bitbuf( - n: usize, - offsets: &[T], - all_bytes: &[u8], - negated: bool, - matcher: F, -) -> BitBuffer -where - T: vortex_array::dtype::IntegerPType, - F: Fn(&[u8]) -> bool, -{ - let mut start: usize = offsets[0].as_(); - BitBuffer::collect_bool(n, |i| { - let end: usize = offsets[i + 1].as_(); - let result = matcher(&all_bytes[start..end]) != negated; - start = end; - result - }) -} - -// --------------------------------------------------------------------------- -// Shared helpers -// --------------------------------------------------------------------------- - -/// Extract a state id from a shift-packed `u64` word. -/// -/// Each state occupies `bits` bits. The mask `(1 << bits) - 1` guarantees the -/// result is at most 15 (for `bits = 4`), which always fits in `u8`. -#[expect( - clippy::cast_possible_truncation, - reason = "masked to `bits` bits (≤4), result ≤ 15" -)] -#[inline(always)] -fn shift_extract(packed: u64, state: u8, bits: u32) -> u8 { - let mask = (1u64 << bits) - 1; - ((packed >> (u32::from(state) * bits)) & mask) as u8 -} - -// --------------------------------------------------------------------------- -// DFA construction helpers -// --------------------------------------------------------------------------- - -/// Builds the per-symbol transition table for FSST symbols. -/// -/// For each `(state, symbol_code)` pair, simulates feeding the symbol's bytes -/// through the byte-level transition table to compute the resulting state. -/// -/// Returns a flat `Vec` indexed as `[state * n_symbols + code]`. -fn build_symbol_transitions( - symbols: &[Symbol], - symbol_lengths: &[u8], - byte_table: &[u16], - n_states: usize, - accept_state: u8, -) -> Vec { - let n_symbols = symbols.len(); - let mut sym_trans = vec![0u8; n_states * n_symbols]; - for state in 0..n_states { - for code in 0..n_symbols { - if state == usize::from(accept_state) { - sym_trans[state * n_symbols + code] = accept_state; - continue; - } - let sym = symbols[code].to_u64().to_le_bytes(); - let sym_len = usize::from(symbol_lengths[code]); - #[expect(clippy::cast_possible_truncation, reason = "state < n_states ≤ 256")] - let mut s = state as u16; - for &b in &sym[..sym_len] { - if s == u16::from(accept_state) { - break; - } - s = byte_table[usize::from(s) * 256 + usize::from(b)]; - } - #[expect( - clippy::cast_possible_truncation, - reason = "s is a state id < n_states ≤ 256" - )] - { - sym_trans[state * n_symbols + code] = s as u8; - } - } - } - sym_trans -} - -/// Builds a fused 256-wide transition table from symbol transitions. -/// -/// For each `(state, code_byte)`: -/// - Code bytes `0..n_symbols`: use the symbol transition -/// - `ESCAPE_CODE`: maps to `escape_value` (either a sentinel or escape state) -/// - All others: use `default` (typically 0 for contains, fail_state for prefix) -/// -/// Returns a flat `Vec` indexed as `[state * 256 + code_byte]`. -fn build_fused_table( - sym_trans: &[u8], - n_symbols: usize, - n_states: usize, - escape_value_fn: impl Fn(usize) -> u8, - default: u8, -) -> Vec { - let mut fused = vec![default; n_states * 256]; - for state in 0..n_states { - for code in 0..n_symbols { - fused[state * 256 + code] = sym_trans[state * n_symbols + code]; - } - fused[state * 256 + ESCAPE_CODE as usize] = escape_value_fn(state); - } - fused -} - -/// Packs a fused table into shift-encoded `u64` arrays. -/// -/// Each `u64` encodes transitions for ALL states for one input byte. -/// Lookup: `next = (table[byte] >> (state * BITS)) & MASK`. -fn pack_shift_table(fused: &[u8], n_states: usize, bits: u32) -> [u64; 256] { - let mut packed = [0u64; 256]; - for code_byte in 0..256usize { - let mut val = 0u64; - for state in 0..n_states { - #[expect(clippy::cast_possible_truncation, reason = "state < n_states ≤ 16")] - let shift = state as u32 * bits; - val |= u64::from(fused[state * 256 + code_byte]) << shift; - } - packed[code_byte] = val; - } - packed -} - -/// Builds an escape-folded fused transition table for contains matching. -/// -/// State layout: `[0..n-1]` match progress, `[n]` accept (sticky), `[n+1..2n]` escape shadows. -/// Total states: `2 * needle.len() + 1`. -/// -/// For normal states, the escape code maps to the corresponding escape shadow state. -/// Escape shadow states use byte-level KMP transitions so the next literal byte -/// resumes matching correctly — no branch needed in the scanner. -fn build_escape_folded_table(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Vec { - let n = needle.len(); - let total_states = 2 * n + 1; - #[expect( - clippy::cast_possible_truncation, - reason = "n ≤ FlatBranchlessDfa::MAX_NEEDLE_LEN (14)" - )] - let accept_state = n as u8; - - let byte_table = kmp_byte_transitions(needle); - let sym_trans = - build_symbol_transitions(symbols, symbol_lengths, &byte_table, n + 1, accept_state); - - let n_symbols = symbols.len(); - let mut fused = vec![0u8; total_states * 256]; - for code_byte in 0..256usize { - // Normal states 0..n - for s in 0..n { - if code_byte == usize::from(ESCAPE_CODE) { - #[expect(clippy::cast_possible_truncation, reason = "s + n + 1 ≤ 2*14 = 28")] - { - fused[s * 256 + code_byte] = (s + n + 1) as u8; - } - } else if code_byte < n_symbols { - fused[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; - } - } - // Accept state (sticky) - fused[n * 256 + code_byte] = accept_state; - // Escape shadow states n+1..2n - for s in 0..n { - let esc_state = s + n + 1; - #[expect( - clippy::cast_possible_truncation, - reason = "byte_table state ids < n+1 ≤ 15" - )] - let next = byte_table[s * 256 + code_byte] as u8; - fused[esc_state * 256 + code_byte] = next; - } - } - fused -} - -// --------------------------------------------------------------------------- -// DFA for prefix matching (LIKE 'prefix%') -// --------------------------------------------------------------------------- - -/// Precomputed shift-based DFA for prefix matching on FSST codes. -/// -/// States 0..prefix_len track match progress, plus ACCEPT and FAIL. -/// Uses the same shift-based approach as the contains DFA: all state -/// transitions packed into a `u64` per code byte. For prefixes longer -/// than 13 characters, pushdown is disabled and LIKE falls back. -struct FsstPrefixDfa { - /// Packed transitions: `(table[code] >> (state * 4)) & 0xF` gives next state. - transitions: [u64; 256], - /// Packed escape transitions for literal bytes. - escape_transitions: [u64; 256], - accept_state: u8, - fail_state: u8, -} - -impl FsstPrefixDfa { - pub(crate) const BITS: u32 = 4; - const MAX_PREFIX_LEN: usize = (1 << Self::BITS) as usize - 3; - - pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], prefix: &[u8]) -> Self { - // Need room for states 0..prefix_len, accept, fail, and an escape sentinel. - debug_assert!(prefix.len() <= Self::MAX_PREFIX_LEN); - - #[expect( - clippy::cast_possible_truncation, - reason = "prefix.len() ≤ MAX_PREFIX_LEN (13)" - )] - let accept_state = prefix.len() as u8; - #[expect(clippy::cast_possible_truncation, reason = "prefix.len() + 1 ≤ 14")] - let fail_state = (prefix.len() + 1) as u8; - let n_states = prefix.len() + 2; - - // Prefix matching uses a simpler transition rule than KMP: on mismatch - // we go to fail_state (no fallback). Build the byte table inline. - let byte_table = Self::build_prefix_byte_table(prefix, accept_state, fail_state); - - let sym_trans = - build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); - - // Override fail_state rows: fail is sticky. - let escape_sentinel = fail_state + 1; - let mut fused = build_fused_table( - &sym_trans, - symbols.len(), - n_states, - |_| escape_sentinel, - fail_state, - ); - - // Accept state is sticky for all inputs. - for code_byte in 0..256usize { - fused[accept_state as usize * 256 + code_byte] = accept_state; - } - // Fail state is sticky for all inputs. - for code_byte in 0..256usize { - fused[fail_state as usize * 256 + code_byte] = fail_state; - } - - let transitions = pack_shift_table(&fused, n_states, Self::BITS); - - // Build escape transitions from the byte table. - let mut esc_trans = vec![fail_state; n_states * 256]; - for state in 0..n_states { - if state == usize::from(accept_state) { - for b in 0..256 { - esc_trans[state * 256 + b] = accept_state; - } - } else if state != usize::from(fail_state) { - for b in 0..256usize { - if b == usize::from(prefix[state]) { - let next = state + 1; - esc_trans[state * 256 + b] = if next >= prefix.len() { - accept_state - } else { - #[expect( - clippy::cast_possible_truncation, - reason = "next ≤ prefix.len() ≤ 13" - )] - { - next as u8 - } - }; - } - } - } - } - let escape_transitions = pack_shift_table(&esc_trans, n_states, Self::BITS); - - Self { - transitions, - escape_transitions, - accept_state, - fail_state, - } - } - - /// Build a byte-level transition table for prefix matching (no KMP fallback). - fn build_prefix_byte_table(prefix: &[u8], accept_state: u8, fail_state: u8) -> Vec { - let n_states = prefix.len() + 2; - let mut table = vec![u16::from(fail_state); n_states * 256]; - - for state in 0..n_states { - if state == usize::from(accept_state) { - for byte in 0..256 { - table[state * 256 + byte] = u16::from(accept_state); - } - } else if state != usize::from(fail_state) { - // Only the correct next byte advances; everything else fails. - let next_byte = prefix[state]; - let next_state = if state + 1 >= prefix.len() { - u16::from(accept_state) - } else { - #[expect( - clippy::cast_possible_truncation, - reason = "state + 1 ≤ prefix.len() ≤ 13" - )] - { - (state + 1) as u16 - } - }; - table[state * 256 + usize::from(next_byte)] = next_state; - } - } - table - } - - #[inline] - pub(crate) fn matches(&self, codes: &[u8]) -> bool { - let mut state = 0u8; - let mut pos = 0; - while pos < codes.len() { - let code = codes[pos]; - pos += 1; - let packed = self.transitions[usize::from(code)]; - // Masked to BITS (4) bits, result ≤ 15, fits in u8 - let next = shift_extract(packed, state, Self::BITS); - if next == self.fail_state + 1 { - // Escape sentinel: read literal byte. - if pos >= codes.len() { - return false; - } - let b = codes[pos]; - pos += 1; - let esc_packed = self.escape_transitions[usize::from(b)]; - state = shift_extract(esc_packed, state, Self::BITS); - } else { - state = next; - } - if state == self.accept_state { - return true; - } - if state == self.fail_state { - return false; - } - } - state == self.accept_state - } -} - -// --------------------------------------------------------------------------- -// DFA for contains matching (LIKE '%needle%') -// --------------------------------------------------------------------------- - -/// Branchless escape-folded DFA for short needles (len <= 7). -/// -/// Folds escape handling into the state space so that `matches()` is -/// completely branchless (except for loop control). The state layout is: -/// - States 0..N-1: normal match-progress states -/// - State N: accept (sticky for all inputs) -/// - States N+1..2N: escape states (state `s+N+1` means "was in state `s`, -/// just consumed ESCAPE_CODE") -/// -/// Total states: 2N+1. With 4-bit packing, max N=7. -/// -/// Uses a decomposed hierarchical lookup that processes 4 code bytes per -/// loop iteration with only ~3 KB of tables: -/// -/// 1. **Equivalence class table** (256 B): maps each code byte to a class -/// id. Bytes with identical transition u64s share a class -- typically -/// only ~6-10 classes exist (needle chars + escape + "miss-all"). -/// 2. **Pair-compose table** (~N^2 B): maps `(class0, class1)` to a 2-byte -/// palette index. Typically ~36 entries. -/// 3. **4-byte compose table** (~M^2 x 8 B): maps `(palette0, palette1)` to -/// the composed packed u64 for all 4 bytes. Typically ~81 entries = 648 B. -/// -/// Each loop iteration: 4 class lookups (parallel, 256 B table) -> 2 -/// pair-compose lookups (parallel, ~36 B table) -> 1 compose lookup -/// (~648 B table) -> 1 shift+mask. All tables fit in L1 cache. -struct BranchlessShiftDfa { - /// Maps each code byte to its equivalence class. Bytes with the same - /// packed transition u64 share a class. (256 bytes) - eq_class: [u8; 256], - /// Maps `(class0 * n_classes + class1)` -> 2-byte palette index. - pair_compose: Vec, - /// Number of equivalence classes (stride for pair_compose). - n_classes: usize, - /// Maps `(palette0 * n_palette + palette1)` -> composed packed u64 - /// for 4 bytes. - compose_4b: Vec, - /// Number of unique 2-byte palette entries (stride for compose_4b). - n_palette: usize, - /// 1-byte fallback transitions for trailing bytes. - transitions_1b: [u64; 256], - /// 2-byte palette for the remainder path (2-3 trailing bytes). - palette_2b: Vec, - accept_state: u8, -} - -impl BranchlessShiftDfa { - const BITS: u32 = 4; - /// Maximum needle length: need 2N+1 states to fit in 16 slots (4 bits). - /// 2*7+1 = 15 <= 16, so max N = 7. - pub(crate) const MAX_NEEDLE_LEN: usize = 7; - - pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - let n = needle.len(); - debug_assert!(n <= Self::MAX_NEEDLE_LEN); - - #[expect(clippy::cast_possible_truncation, reason = "n ≤ MAX_NEEDLE_LEN (7)")] - let accept_state = n as u8; - let total_states = 2 * n + 1; - debug_assert!(total_states <= (1 << Self::BITS)); - - let fused = build_escape_folded_table(symbols, symbol_lengths, needle); - let transitions_1b = pack_shift_table(&fused, total_states, Self::BITS); - - // Build equivalence classes: group bytes with identical transition u64. - let mut eq_class = [0u8; 256]; - let mut class_representatives: Vec = Vec::new(); - for byte_val in 0..256usize { - let t = transitions_1b[byte_val]; - let cls = class_representatives - .iter() - .position(|&v| v == t) - .unwrap_or_else(|| { - class_representatives.push(t); - class_representatives.len() - 1 - }); - #[expect(clippy::cast_possible_truncation, reason = "≤ 256 equivalence classes")] - { - eq_class[byte_val] = cls as u8; - } - } - let n_classes = class_representatives.len(); - - // Build pair-compose: for each (class0, class1), compose the two - // 1-byte transitions and deduplicate into a 2-byte palette. - let (pair_compose, palette_2b) = - Self::build_pair_compose(&class_representatives, n_classes, total_states); - - // Build 4-byte composition: compose_4b[p0 * n + p1] gives the packed - // u64 for applying palette_2b[p0] then palette_2b[p1] in sequence. - let n_palette = palette_2b.len(); - let compose_4b = Self::build_compose_4b(&palette_2b, total_states); - - Self { - eq_class, - pair_compose, - n_classes, - compose_4b, - n_palette, - transitions_1b, - palette_2b, - accept_state, - } - } - - /// Build the pair-compose table and 2-byte palette from equivalence - /// class representatives. - fn build_pair_compose( - class_reps: &[u64], - n_classes: usize, - total_states: usize, - ) -> (Vec, Vec) { - let mut pair_compose = vec![0u8; n_classes * n_classes]; - let mut palette_2b: Vec = Vec::new(); - - for c0 in 0..n_classes { - for c1 in 0..n_classes { - let t0 = class_reps[c0]; - let t1 = class_reps[c1]; - let mut packed = 0u64; - for state in 0..total_states { - #[expect( - clippy::cast_possible_truncation, - reason = "state < total_states ≤ 16" - )] - let state_u8 = state as u8; - #[expect( - clippy::cast_possible_truncation, - reason = "state < total_states ≤ 16" - )] - let state_shift = state as u32 * Self::BITS; - let mid = shift_extract(t0, state_u8, Self::BITS); - let final_s = shift_extract(t1, mid, Self::BITS); - packed |= u64::from(final_s) << state_shift; - } - let idx = palette_2b - .iter() - .position(|&v| v == packed) - .unwrap_or_else(|| { - palette_2b.push(packed); - palette_2b.len() - 1 - }); - #[expect( - clippy::cast_possible_truncation, - reason = "palette size ≤ n_classes² ≤ 256" - )] - { - pair_compose[c0 * n_classes + c1] = idx as u8; - } - } - } - (pair_compose, palette_2b) - } - - /// Compose pairs of 2-byte palette entries into a 4-byte lookup table. - fn build_compose_4b(palette_2b: &[u64], total_states: usize) -> Vec { - let n = palette_2b.len(); - let mut compose = vec![0u64; n * n]; - for p0 in 0..n { - for p1 in 0..n { - let mut packed = 0u64; - for state in 0..total_states { - #[expect( - clippy::cast_possible_truncation, - reason = "state < total_states ≤ 16" - )] - let state_u8 = state as u8; - #[expect( - clippy::cast_possible_truncation, - reason = "state < total_states ≤ 16" - )] - let state_shift = state as u32 * Self::BITS; - let mid = shift_extract(palette_2b[p0], state_u8, Self::BITS); - let final_s = shift_extract(palette_2b[p1], mid, Self::BITS); - packed |= u64::from(final_s) << state_shift; - } - compose[p0 * n + p1] = packed; - } - } - compose - } - - /// Process remaining bytes after the interleaved common prefix. - #[inline] - fn finish_tail(&self, mut state: u8, codes: &[u8]) -> u8 { - let chunks = codes.chunks_exact(4); - let rem = chunks.remainder(); - - for chunk in chunks { - // SAFETY: chunk[i] is u8, eq_class has 256 entries — index always in bounds. - let ec0 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[0])) }; - let ec1 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[1])) }; - let ec2 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[2])) }; - let ec3 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[3])) }; - let p0 = unsafe { - *self - .pair_compose - .get_unchecked(usize::from(ec0) * self.n_classes + usize::from(ec1)) - }; - let p1 = unsafe { - *self - .pair_compose - .get_unchecked(usize::from(ec2) * self.n_classes + usize::from(ec3)) - }; - let packed = unsafe { - *self - .compose_4b - .get_unchecked(usize::from(p0) * self.n_palette + usize::from(p1)) - }; - state = shift_extract(packed, state, Self::BITS); - } - - if rem.len() >= 2 { - let ec0 = self.eq_class[usize::from(rem[0])]; - let ec1 = self.eq_class[usize::from(rem[1])]; - let p = self.pair_compose[usize::from(ec0) * self.n_classes + usize::from(ec1)]; - let packed = self.palette_2b[usize::from(p)]; - state = shift_extract(packed, state, Self::BITS); - if rem.len() == 3 { - let packed = self.transitions_1b[usize::from(rem[2])]; - state = shift_extract(packed, state, Self::BITS); - } - } else if rem.len() == 1 { - let packed = self.transitions_1b[usize::from(rem[0])]; - state = shift_extract(packed, state, Self::BITS); - } - - state - } - - /// Branchless matching processing four code bytes per iteration. - #[inline(never)] - pub(crate) fn matches(&self, codes: &[u8]) -> bool { - self.finish_tail(0, codes) == self.accept_state - } -} - -/// Flat u8 escape-folded DFA for medium needles (8-14 chars). -/// -/// Like `BranchlessShiftDfa`, folds escape handling into the state space -/// (2N+1 states), but uses a flat `u8` transition table instead of -/// shift-packed `u64`. Supports up to 14-char needles (2*14+1 = 29 states). -/// Table size: 29 * 256 = 7,424 bytes, fits in L1. -struct FlatBranchlessDfa { - /// transitions[state * 256 + byte] -> next state - transitions: Vec, - accept_state: u8, -} - -impl FlatBranchlessDfa { - pub(crate) const MAX_NEEDLE_LEN: usize = 14; - - pub(crate) fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); - - #[expect( - clippy::cast_possible_truncation, - reason = "needle.len() ≤ MAX_NEEDLE_LEN (14)" - )] - let accept_state = needle.len() as u8; - - let transitions = build_escape_folded_table(symbols, symbol_lengths, needle); - - Self { - transitions, - accept_state, - } - } - - #[inline(never)] - pub(crate) fn matches(&self, codes: &[u8]) -> bool { - let mut state = 0u8; - for &byte in codes { - state = self.transitions[usize::from(state) * 256 + usize::from(byte)]; - } - state == self.accept_state - } -} - -/// Fused 256-entry u8 table DFA for contains needles in the 15-254 byte range. -/// -/// This representation stores state ids in `u8`, so it cannot represent -/// needles longer than 254 bytes once the accept state and escape sentinel are -/// included. -pub(crate) struct FusedDfa { - transitions: Vec, - escape_transitions: Vec, - accept_state: u8, - escape_sentinel: u8, -} - -impl FusedDfa { - const MAX_NEEDLE_LEN: usize = u8::MAX as usize - 1; - - fn new(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Self { - debug_assert!(needle.len() <= Self::MAX_NEEDLE_LEN); - - let n_states = needle.len() + 1; - #[expect(clippy::cast_possible_truncation, reason = "needle.len() ≤ 254")] - let accept_state = needle.len() as u8; - #[expect(clippy::cast_possible_truncation, reason = "needle.len() + 1 ≤ 255")] - let escape_sentinel = (needle.len() + 1) as u8; - - let byte_table = kmp_byte_transitions(needle); - let sym_trans = - build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); - - let transitions = - build_fused_table(&sym_trans, symbols.len(), n_states, |_| escape_sentinel, 0); - - // byte_table values are state ids < n_states ≤ 255 - let escape_transitions: Vec = byte_table - .iter() - .map(|&v| { - #[expect( - clippy::cast_possible_truncation, - reason = "state ids < n_states ≤ 255" - )] - { - v as u8 - } - }) - .collect(); - - Self { - transitions, - escape_transitions, - accept_state, - escape_sentinel, - } - } - - #[inline] - fn matches(&self, codes: &[u8]) -> bool { - let mut state = 0u8; - let mut pos = 0; - while pos < codes.len() { - let code = codes[pos]; - pos += 1; - let next = self.transitions[usize::from(state) * 256 + usize::from(code)]; - if next == self.escape_sentinel { - if pos >= codes.len() { - return false; - } - let b = codes[pos]; - pos += 1; - state = self.escape_transitions[usize::from(state) * 256 + usize::from(b)]; - } else { - state = next; - } - if state == self.accept_state { - return true; - } - } - false - } -} - -// --------------------------------------------------------------------------- -// KMP helpers -// --------------------------------------------------------------------------- - -fn kmp_byte_transitions(needle: &[u8]) -> Vec { - let n_states = needle.len() + 1; - #[expect(clippy::cast_possible_truncation, reason = "needle.len() ≤ 254")] - let accept = needle.len() as u16; - let failure = kmp_failure_table(needle); - - let mut table = vec![0u16; n_states * 256]; - for state in 0..n_states { - for byte in 0..256u16 { - if state == needle.len() { - table[state * 256 + usize::from(byte)] = accept; - continue; - } - let mut s = state; - loop { - // byte iterates 0..256, compare without truncation - if byte == u16::from(needle[s]) { - s += 1; - break; - } - if s == 0 { - break; - } - s = failure[s - 1]; - } - #[expect(clippy::cast_possible_truncation, reason = "s ≤ needle.len() ≤ 254")] - { - table[state * 256 + usize::from(byte)] = s as u16; - } - } - } - table -} - -fn kmp_failure_table(needle: &[u8]) -> Vec { - let mut failure = vec![0usize; needle.len()]; - let mut k = 0; - for i in 1..needle.len() { - while k > 0 && needle[k] != needle[i] { - k = failure[k - 1]; - } - if needle[k] == needle[i] { - k += 1; - } - failure[i] = k; - } - failure -} - -#[cfg(test)] -mod tests { - use fsst::ESCAPE_CODE; - - use super::FsstMatcher; - use super::FsstPrefixDfa; - use super::FusedDfa; - use super::LikeKind; - - fn escaped(bytes: &[u8]) -> Vec { - let mut codes = Vec::with_capacity(bytes.len() * 2); - for &b in bytes { - codes.push(ESCAPE_CODE); - codes.push(b); - } - codes - } - - #[test] - fn test_like_kind_parse() { - assert!(matches!( - LikeKind::parse("http%"), - Some(LikeKind::Prefix("http")) - )); - assert!(matches!( - LikeKind::parse("%needle%"), - Some(LikeKind::Contains("needle")) - )); - assert!(matches!(LikeKind::parse("%"), Some(LikeKind::Prefix("")))); - // Suffix and underscore patterns are not supported. - assert!(LikeKind::parse("%suffix").is_none()); - assert!(LikeKind::parse("a_c").is_none()); - } - - #[test] - fn test_prefix_pushdown_len_13_with_escapes() { - let matcher = FsstMatcher::try_new(&[], &[], "abcdefghijklm%") - .unwrap() - .unwrap(); - - assert!(matcher.matches(&escaped(b"abcdefghijklm"))); - assert!(!matcher.matches(&escaped(b"abcdefghijklx"))); - } - - #[test] - fn test_prefix_pushdown_rejects_len_14() { - debug_assert_eq!(FsstPrefixDfa::MAX_PREFIX_LEN, 13); - assert!( - FsstMatcher::try_new(&[], &[], "abcdefghijklmn%") - .unwrap() - .is_none() - ); - } - - #[test] - fn test_contains_pushdown_len_254_with_escapes() { - let needle = "a".repeat(FusedDfa::MAX_NEEDLE_LEN); - let pattern = format!("%{needle}%"); - let matcher = FsstMatcher::try_new(&[], &[], &pattern).unwrap().unwrap(); - - assert!(matcher.matches(&escaped(needle.as_bytes()))); - - let mut mismatch = needle.into_bytes(); - mismatch[FusedDfa::MAX_NEEDLE_LEN - 1] = b'b'; - assert!(!matcher.matches(&escaped(&mismatch))); - } - - #[test] - fn test_contains_pushdown_rejects_len_255() { - let needle = "a".repeat(FusedDfa::MAX_NEEDLE_LEN + 1); - let pattern = format!("%{needle}%"); - assert!(FsstMatcher::try_new(&[], &[], &pattern).unwrap().is_none()); - } -} diff --git a/encodings/fsst/src/dfa/DFA_NOTES.md b/encodings/fsst/src/dfa/DFA_NOTES.md new file mode 100644 index 00000000000..2fe7708f55d --- /dev/null +++ b/encodings/fsst/src/dfa/DFA_NOTES.md @@ -0,0 +1,124 @@ +# DFA Refactoring Notes + +## Summary of changes (from 1229 → 1110 lines) + +Unified 5 DFA structs down to 3: + +| Before | After | What happened | +|--------|-------|---------------| +| `ShiftDfa` | (deleted) | Dead code — `FsstContainsDfa` only routed needles >14 to it, but `ShiftDfa::MAX_NEEDLE_LEN` was 14, so the arm was unreachable | +| `FsstContainsDfa` | (deleted) | Dispatch enum wrapping dead `ShiftDfa` arm; only the `FusedDfa` path was reachable | +| `FlatBranchlessDfa` | `FlatContainsDfa` | Merged with `FusedDfa` into single struct with `EscapeStrategy` enum | +| `FusedDfa` | `FlatContainsDfa` | Merged (see above) | +| `BranchlessShiftDfa` | `BranchlessShiftDfa` | Unchanged | +| `FsstPrefixDfa` | `FsstPrefixDfa` | Simplified escape transition building | + +Other changes: +- Extracted `build_escape_folded_table()` (shared by `BranchlessShiftDfa` and `FlatContainsDfa`) +- Extracted `compose_packed()` (shared by `build_pair_compose` and `build_compose_4b`) +- Extended `FlatContainsDfa` folded range from 14 → 127 (2*127+1=255 fits in u8) +- Simplified `FsstPrefixDfa` escape transitions (reuse byte_table directly) +- Deleted `pack_escape_shift_table` (only caller was `ShiftDfa`) + +## Removed code (recoverable from git) + +All removed code is in commit `e08fb69ad` (the starting point). Key pieces: + +### `ShiftDfa` (~70 lines) +Shift-packed `[u64; 256]` DFA using escape sentinel. Was identical in scan loop to +`BranchlessShiftDfa` but without the hierarchical 4-byte compose optimization. +Recovery: `git show e08fb69ad:encodings/fsst/src/dfa.rs` lines ~956-1027. + +### `pack_escape_shift_table` (~15 lines) +Built a separate shift-packed escape transition table. Only used by `ShiftDfa`. +Recovery: same commit, lines ~418-433. + +### `FsstContainsDfa` enum (~25 lines) +Dispatch enum: `ShiftDfa` for len ≤ 14, `FusedDfa` for len > 14. +Since caller guaranteed len > 14, the `ShiftDfa` arm was dead. +Recovery: same commit, lines ~592-615. + +## Benchmark results: escape strategy comparison + +Sentinel-only is 28-45% slower than folded for needles 8-14. +Both strategies must be kept in `FlatContainsDfa`. + +| Benchmark | Needle len | Folded (ms) | Sentinel (ms) | Regression | +|-----------|-----------|-------------|---------------|------------| +| contains/log | 9 | 5.449 | 7.480 | +37% | +| contains/json | 10 | 2.390 | 3.466 | +45% | +| contains/path | 14 | 0.937 | 1.199 | +28% | + +## Current benchmark baseline (post-refactor) + +``` +fsst_like fastest │ slowest │ median │ mean │ samples │ iters +├─ fsst_contains │ │ │ │ │ +│ ├─ cb 1.593 ms │ 2.122 ms │ 1.725 ms │ 1.745 ms │ 100 │ 100 +│ ├─ email 492.9 µs │ 697.7 µs │ 526.3 µs │ 544.2 µs │ 100 │ 100 +│ ├─ json 2.282 ms │ 2.731 ms │ 2.401 ms │ 2.406 ms │ 100 │ 100 +│ ├─ log 5.191 ms │ 5.919 ms │ 5.426 ms │ 5.439 ms │ 100 │ 100 +│ ├─ path 894.3 µs │ 1.076 ms │ 941.1 µs │ 952.8 µs │ 100 │ 100 +│ ├─ rare 1.674 ms │ 4.55 ms │ 1.814 ms │ 1.992 ms │ 100 │ 100 +│ ╰─ urls 736.8 µs │ 959.6 µs │ 837.1 µs │ 844.6 µs │ 100 │ 100 +╰─ fsst_prefix │ │ │ │ │ + ├─ cb 541.7 µs │ 761 µs │ 585.2 µs │ 598.1 µs │ 100 │ 100 + ├─ email 197.9 µs │ 305.8 µs │ 208.2 µs │ 214.6 µs │ 100 │ 100 + ├─ json 141.9 µs │ 352.6 µs │ 145.5 µs │ 151.8 µs │ 100 │ 100 + ├─ log 259.6 µs │ 378.1 µs │ 278.5 µs │ 285.3 µs │ 100 │ 100 + ├─ path 214.2 µs │ 281.1 µs │ 227.1 µs │ 230.9 µs │ 100 │ 100 + ├─ rare 153.7 µs │ 191.9 µs │ 157.1 µs │ 160.8 µs │ 100 │ 100 + ╰─ urls 260.7 µs │ 445.4 µs │ 294.2 µs │ 297.7 µs │ 100 │ 100 +``` + +DFA routing per benchmark: +- cb, email, rare, urls (needle ≤ 7) → `BranchlessShiftDfa` +- log (9), json (10), path (14) → `FlatContainsDfa` (folded) +- No benchmark exercises sentinel path (would need needle > 127) + +## Post integer-type cleanup benchmarks + +After eliminating `u16`, tightening `usize` → `u8` in `compose_packed`, `pack_shift_table`, +`kmp_failure_table`, and `kmp_byte_transitions`. All within noise of baseline. + +| Benchmark | Baseline (ms) | Current (ms) | Delta | +|-----------|--------------|-------------|-------| +| contains/cb | 1.725 | 1.695 | -1.7% | +| contains/email | 0.526 | 0.542 | +2.9% | +| contains/json | 2.401 | 2.452 | +2.1% | +| contains/log | 5.426 | 5.447 | +0.4% | +| contains/path | 0.941 | 0.949 | +0.8% | +| contains/rare | 1.814 | 1.762 | -2.9% | +| contains/urls | 0.837 | 0.812 | -3.0% | +| prefix/cb | 0.585 | 0.568 | -3.0% | +| prefix/email | 0.208 | 0.215 | +3.0% | +| prefix/json | 0.146 | 0.145 | -0.2% | +| prefix/log | 0.279 | 0.270 | -3.1% | +| prefix/path | 0.227 | 0.224 | -1.2% | +| prefix/rare | 0.157 | 0.159 | +1.1% | +| prefix/urls | 0.294 | 0.288 | -2.1% | + +## Optimization ideas for later + +### 1. 8-byte-per-iter BranchlessShiftDfa +Extend `BranchlessShiftDfa` to process 8 bytes/iteration via two 4-byte composes. +Would reduce loop overhead for long compressed strings. Tables stay the same size, +just add a `compose_8b` level on top of `compose_4b`. + +### 2. Branchless prefix DFA +`FsstPrefixDfa` currently uses escape sentinel + branch. Could use escape-folding +(like the contains DFAs) to make the prefix scan branchless. Needs 2*prefix_len+1 +states to fit in 4-bit packing, so max prefix drops from 13 to 7. Worth it if +prefix matching is a bottleneck. + +### 3. Further struct merging +`BranchlessShiftDfa` and `FlatContainsDfa` (folded) share the same escape-folded +state layout. They differ only in table representation (shift-packed u64 vs flat u8). +Could theoretically be merged, but the hierarchical 4-byte compose in +`BranchlessShiftDfa` is fundamentally different from the flat scan, so sharing code +wouldn't simplify much. + +### 4. Suffix pushdown (`%suffix`) +Two approaches noted in the module doc: +- Forward DFA with non-sticky accept (check state == accept after all codes) +- Backward scan of compressed stream diff --git a/encodings/fsst/src/dfa/branchless_shift.rs b/encodings/fsst/src/dfa/branchless_shift.rs new file mode 100644 index 00000000000..2facc49d8fd --- /dev/null +++ b/encodings/fsst/src/dfa/branchless_shift.rs @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Branchless shift-packed DFA for short contains matching (`LIKE '%needle%'`, needle ≤ 7). + +use fsst::Symbol; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use super::build_escape_folded_table; +use super::compose_packed; +use super::pack_shift_table; +use super::shift_extract; + +/// Branchless escape-folded DFA for short needles (len <= 7). +/// +/// Folds escape handling into the state space so that `matches()` is +/// completely branchless (except for loop control). The state layout is: +/// - States 0..N-1: normal match-progress states +/// - State N: accept (sticky for all inputs) +/// - States N+1..2N: escape states (state `s+N+1` means "was in state `s`, +/// just consumed ESCAPE_CODE") +/// +/// Total states: 2N+1. With 4-bit packing, max N=7. +/// +/// Uses a decomposed hierarchical lookup that processes 4 code bytes per +/// loop iteration with only ~3 KB of tables: +/// +/// 1. **Equivalence class table** (256 B): maps each code byte to a class +/// id. Bytes with identical transition u64s share a class -- typically +/// only ~6-10 classes exist (needle chars + escape + "miss-all"). +/// 2. **Pair-compose table** (~N^2 B): maps `(class0, class1)` to a 2-byte +/// palette index. Typically ~36 entries. +/// 3. **4-byte compose table** (~M^2 x 8 B): maps `(palette0, palette1)` to +/// the composed packed u64 for all 4 bytes. Typically ~81 entries = 648 B. +/// +/// Each loop iteration: 4 class lookups (parallel, 256 B table) -> 2 +/// pair-compose lookups (parallel, ~36 B table) -> 1 compose lookup +/// (~648 B table) -> 1 shift+mask. All tables fit in L1 cache. +pub(crate) struct BranchlessShiftDfa { + /// Maps each code byte to its equivalence class. Bytes with the same + /// packed transition u64 share a class. (256 bytes) + eq_class: [u8; 256], + /// Maps `(class0 * n_classes + class1)` -> 2-byte palette index. + pair_compose: Vec, + /// Number of equivalence classes (stride for pair_compose). + n_classes: usize, + /// Maps `(palette0 * n_palette + palette1)` -> composed packed u64 + /// for 4 bytes. + compose_4b: Vec, + /// Number of unique 2-byte palette entries (stride for compose_4b). + n_palette: usize, + /// 1-byte fallback transitions for trailing bytes. + transitions_1b: [u64; 256], + /// 2-byte palette for the remainder path (2-3 trailing bytes). + palette_2b: Vec, + accept_state: u8, +} + +impl BranchlessShiftDfa { + const BITS: u32 = 4; + /// Maximum needle length: need 2N+1 states to fit in 16 slots (4 bits). + /// 2*7+1 = 15 <= 16, so max N = 7. + pub(crate) const MAX_NEEDLE_LEN: usize = 7; + + pub(crate) fn new( + symbols: &[Symbol], + symbol_lengths: &[u8], + needle: &[u8], + ) -> VortexResult { + let n = needle.len(); + if n > Self::MAX_NEEDLE_LEN { + vortex_bail!( + "needle length {} exceeds maximum {} for branchless shift DFA", + n, + Self::MAX_NEEDLE_LEN + ); + } + + #[expect(clippy::cast_possible_truncation, reason = "n ≤ MAX_NEEDLE_LEN (7)")] + let accept_state = n as u8; + let total_states = 2 * accept_state + 1; + + let fused = build_escape_folded_table(symbols, symbol_lengths, needle); + let transitions_1b = pack_shift_table(&fused, total_states, Self::BITS); + + // Build equivalence classes: group bytes with identical transition u64. + let mut eq_class = [0u8; 256]; + let mut class_representatives: Vec = Vec::new(); + for byte_val in 0..256usize { + let t = transitions_1b[byte_val]; + let cls = class_representatives + .iter() + .position(|&v| v == t) + .unwrap_or_else(|| { + class_representatives.push(t); + class_representatives.len() - 1 + }); + #[expect(clippy::cast_possible_truncation, reason = "≤ 256 equivalence classes")] + { + eq_class[byte_val] = cls as u8; + } + } + let n_classes = class_representatives.len(); + + // Build pair-compose: for each (class0, class1), compose the two + // 1-byte transitions and deduplicate into a 2-byte palette. + let (pair_compose, palette_2b) = + Self::build_pair_compose(&class_representatives, n_classes, total_states); + + // Build 4-byte composition: compose_4b[p0 * n + p1] gives the packed + // u64 for applying palette_2b[p0] then palette_2b[p1] in sequence. + let n_palette = palette_2b.len(); + let compose_4b = Self::build_compose_4b(&palette_2b, total_states); + + Ok(Self { + eq_class, + pair_compose, + n_classes, + compose_4b, + n_palette, + transitions_1b, + palette_2b, + accept_state, + }) + } + + /// Build the pair-compose table and 2-byte palette from equivalence + /// class representatives. + fn build_pair_compose( + class_reps: &[u64], + n_classes: usize, + total_states: u8, + ) -> (Vec, Vec) { + let mut pair_compose = vec![0u8; n_classes * n_classes]; + let mut palette_2b: Vec = Vec::new(); + + for c0 in 0..n_classes { + for c1 in 0..n_classes { + let packed = + compose_packed(class_reps[c0], class_reps[c1], total_states, Self::BITS); + let idx = palette_2b + .iter() + .position(|&v| v == packed) + .unwrap_or_else(|| { + palette_2b.push(packed); + palette_2b.len() - 1 + }); + #[expect( + clippy::cast_possible_truncation, + reason = "palette size ≤ n_classes² ≤ 256" + )] + { + pair_compose[c0 * n_classes + c1] = idx as u8; + } + } + } + (pair_compose, palette_2b) + } + + /// Compose pairs of 2-byte palette entries into a 4-byte lookup table. + fn build_compose_4b(palette_2b: &[u64], total_states: u8) -> Vec { + let n = palette_2b.len(); + let mut compose = vec![0u64; n * n]; + for p0 in 0..n { + for p1 in 0..n { + compose[p0 * n + p1] = + compose_packed(palette_2b[p0], palette_2b[p1], total_states, Self::BITS); + } + } + compose + } + + /// Process remaining bytes after the interleaved common prefix. + #[inline] + fn finish_tail(&self, mut state: u8, codes: &[u8]) -> u8 { + let chunks = codes.chunks_exact(4); + let rem = chunks.remainder(); + + for chunk in chunks { + // SAFETY: chunk[i] is u8, eq_class has 256 entries — index always in bounds. + let ec0 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[0])) }; + let ec1 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[1])) }; + let ec2 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[2])) }; + let ec3 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[3])) }; + let p0 = unsafe { + *self + .pair_compose + .get_unchecked(usize::from(ec0) * self.n_classes + usize::from(ec1)) + }; + let p1 = unsafe { + *self + .pair_compose + .get_unchecked(usize::from(ec2) * self.n_classes + usize::from(ec3)) + }; + let packed = unsafe { + *self + .compose_4b + .get_unchecked(usize::from(p0) * self.n_palette + usize::from(p1)) + }; + state = shift_extract(packed, state, Self::BITS); + } + + if rem.len() >= 2 { + let ec0 = self.eq_class[usize::from(rem[0])]; + let ec1 = self.eq_class[usize::from(rem[1])]; + let p = self.pair_compose[usize::from(ec0) * self.n_classes + usize::from(ec1)]; + let packed = self.palette_2b[usize::from(p)]; + state = shift_extract(packed, state, Self::BITS); + if rem.len() == 3 { + let packed = self.transitions_1b[usize::from(rem[2])]; + state = shift_extract(packed, state, Self::BITS); + } + } else if rem.len() == 1 { + let packed = self.transitions_1b[usize::from(rem[0])]; + state = shift_extract(packed, state, Self::BITS); + } + + state + } + + /// Branchless matching processing four code bytes per iteration. + #[inline(never)] + pub(crate) fn matches(&self, codes: &[u8]) -> bool { + self.finish_tail(0, codes) == self.accept_state + } +} diff --git a/encodings/fsst/src/dfa/flat_contains.rs b/encodings/fsst/src/dfa/flat_contains.rs new file mode 100644 index 00000000000..f6227c8f31f --- /dev/null +++ b/encodings/fsst/src/dfa/flat_contains.rs @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Flat `u8` transition table DFA for contains matching (`LIKE '%needle%'`, needle 8-254). + +use fsst::Symbol; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use super::build_escape_folded_table; +use super::build_fused_table; +use super::build_symbol_transitions; +use super::kmp_byte_transitions; + +/// Flat `u8` transition table DFA for contains matching (needles 8-254 bytes). +/// +/// Uses two escape strategies depending on needle length: +/// - **Escape-folded** (needle ≤ 127): escape handling is folded into the state +/// space (2N+1 states), making the scan loop branchless. +/// - **Escape sentinel** (needle 128-254): escape code maps to a sentinel state +/// with a separate byte-level escape table. Required because 2N+1 > 255 won't +/// fit in `u8`. +pub(crate) struct FlatContainsDfa { + /// `transitions[state * 256 + byte]` -> next state. + transitions: Vec, + accept_state: u8, + escape: EscapeStrategy, +} + +/// How the flat DFA handles the FSST escape code. +enum EscapeStrategy { + /// Escape states folded into the transition table (branchless scan). + Folded, + /// Escape code maps to a sentinel; next byte uses a separate table. + Sentinel { + escape_transitions: Vec, + sentinel: u8, + }, +} + +impl FlatContainsDfa { + /// Maximum needle for escape-folded mode: 2N+1 ≤ 255, so N ≤ 127. + const MAX_FOLDED_LEN: usize = 127; + /// Maximum needle overall: need accept + sentinel to fit in u8. + pub(crate) const MAX_NEEDLE_LEN: usize = u8::MAX as usize - 1; + + pub(crate) fn new( + symbols: &[Symbol], + symbol_lengths: &[u8], + needle: &[u8], + ) -> VortexResult { + if needle.len() > Self::MAX_NEEDLE_LEN { + vortex_bail!( + "needle length {} exceeds maximum {} for flat contains DFA", + needle.len(), + Self::MAX_NEEDLE_LEN + ); + } + + let accept_state = u8::try_from(needle.len()) + .vortex_expect("FlatContainsDfa: accept state must fit into u8"); + + if needle.len() <= Self::MAX_FOLDED_LEN { + let transitions = build_escape_folded_table(symbols, symbol_lengths, needle); + Ok(Self { + transitions, + accept_state, + escape: EscapeStrategy::Folded, + }) + } else { + let n_states = accept_state + 1; + let sentinel = n_states; + + let byte_table = kmp_byte_transitions(needle); + let sym_trans = build_symbol_transitions( + symbols, + symbol_lengths, + &byte_table, + n_states, + accept_state, + ); + let transitions = + build_fused_table(&sym_trans, symbols.len(), n_states, |_| sentinel, 0); + + let escape_transitions = byte_table; + + Ok(Self { + transitions, + accept_state, + escape: EscapeStrategy::Sentinel { + escape_transitions, + sentinel, + }, + }) + } + } + + #[inline(never)] + pub(crate) fn matches(&self, codes: &[u8]) -> bool { + match &self.escape { + EscapeStrategy::Folded => self.matches_folded(codes), + EscapeStrategy::Sentinel { + escape_transitions, + sentinel, + } => Self::matches_sentinel( + codes, + &self.transitions, + escape_transitions, + self.accept_state, + *sentinel, + ), + } + } + + /// Branchless scan: escape handling is folded into the state space. + #[inline(always)] + fn matches_folded(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + for &byte in codes { + state = self.transitions[usize::from(state) * 256 + usize::from(byte)]; + } + state == self.accept_state + } + + /// Sentinel scan: escape code triggers a separate table lookup. + #[inline(always)] + fn matches_sentinel( + codes: &[u8], + transitions: &[u8], + escape_transitions: &[u8], + accept_state: u8, + sentinel: u8, + ) -> bool { + let mut state = 0u8; + let mut pos = 0; + while pos < codes.len() { + let code = codes[pos]; + pos += 1; + let next = transitions[usize::from(state) * 256 + usize::from(code)]; + if next == sentinel { + if pos >= codes.len() { + return false; + } + let b = codes[pos]; + pos += 1; + state = escape_transitions[usize::from(state) * 256 + usize::from(b)]; + } else { + state = next; + } + if state == accept_state { + return true; + } + } + false + } +} diff --git a/encodings/fsst/src/dfa/mod.rs b/encodings/fsst/src/dfa/mod.rs new file mode 100644 index 00000000000..fd1da377583 --- /dev/null +++ b/encodings/fsst/src/dfa/mod.rs @@ -0,0 +1,539 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! # FSST LIKE Pushdown via DFA Construction +//! +//! This module implements DFA-based pattern matching directly on FSST-compressed +//! strings, without decompressing them. It handles two pattern shapes: +//! +//! - **Prefix**: `'prefix%'` — matches strings starting with a literal prefix. +//! - **Contains**: `'%needle%'` — matches strings containing a literal substring. +//! +//! Pushdown is intentionally conservative. If the pattern shape is unsupported, +//! or if the pattern exceeds the DFA's representable state space, construction +//! returns `None` and the caller must fall back to ordinary decompression-based +//! LIKE evaluation. +//! +//! TODO(joe): suffix (`'%suffix'`) pushdown. Two approaches: +//! - **Forward DFA**: use a non-sticky accept state with KMP fallback transitions, +//! check `state == accept` after processing all codes. Branchless and vectorizable. +//! - **Backward scan**: walk the compressed code stream in reverse, comparing symbol +//! bytes from the end. Simpler, no DFA construction, but requires reverse parsing +//! of the FSST escape mechanism. +//! +//! ## Background: FSST Encoding +//! +//! [FSST](https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf) compresses strings by +//! replacing frequent byte sequences with single-byte **symbol codes** (0–254). Code +//! byte 255 is reserved as the **escape code**: the next byte is a literal (uncompressed) +//! byte. So a compressed string is a stream of: +//! +//! ```text +//! [symbol_code] ... [symbol_code] [ESCAPE literal_byte] [symbol_code] ... +//! ``` +//! +//! A single symbol can expand to 1–8 bytes. Matching on compressed codes requires +//! the DFA to handle multi-byte symbol expansions and the escape mechanism. +//! +//! ## The Algorithm: KMP → Byte Table → Symbol Table → Packed DFA +//! +//! Construction proceeds through four stages: +//! +//! ### Stage 1: KMP Failure Function +//! +//! We compute the standard [KMP](https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm) +//! failure function for the needle bytes. This tells us, on a mismatch at +//! position `i`, the longest proper prefix of `needle[0..i]` that is also a +//! suffix — i.e., where to resume matching instead of starting over. +//! +//! ```text +//! Needle: "abcabd" +//! Failure: [0, 0, 0, 1, 2, 0] +//! ^ ^ +//! At position 3 ('a'), the prefix "a" matches suffix "a" +//! At position 4 ('b'), the prefix "ab" matches suffix "ab" +//! ``` +//! +//! ### Stage 2: Byte-Level Transition Table +//! +//! From the failure function, we build a full `(state × byte) → state` transition +//! table. State `i` means "we have matched `needle[0..i]`". State `n` (= needle +//! length) is the **accept** state. +//! +//! ```text +//! Needle: "aba" (3 states + accept) +//! +//! Input byte +//! State 'a' 'b' other +//! ───── ──── ──── ───── +//! 0 1 0 0 ← looking for first 'a' +//! 1 1 2 0 ← matched "a", want 'b' +//! 2 3✓ 0 0 ← matched "ab", want 'a' +//! 3✓ 3✓ 3✓ 3✓ ← accept (sticky) +//! ``` +//! +//! For prefix matching, a mismatch at any state goes to a **fail** state (no +//! fallback). For contains matching, mismatches follow KMP fallback transitions +//! so we can find the needle anywhere in the string. +//! +//! ### Stage 3: Symbol-Level Transition Table +//! +//! FSST symbols can be multi-byte. To compute the transition for symbol code `c` +//! in state `s`, we simulate feeding each byte of the symbol through the byte +//! table: +//! +//! ```text +//! Symbol #42 = "the" (3 bytes) +//! State 0 + 't' → 0, + 'h' → 0, + 'e' → 0 ⟹ sym_trans[0][42] = 0 +//! +//! If needle = "them": +//! State 0 + 't' → 1, + 'h' → 2, + 'e' → 3 ⟹ sym_trans[0][42] = 3 +//! ``` +//! +//! We then build a **fused 256-wide table**: for code bytes 0–254, use the +//! symbol transition; for code byte 255 (ESCAPE_CODE), transition to a +//! special sentinel that tells the scanner to read the next literal byte. +//! +//! ### Stage 4: Packing into the Final Representation +//! +//! The fused table can be stored in different layouts depending on the number +//! of states: +//! +//! - **Shift-packed `u64`** (≤16 states): Each state needs 4 bits. All state +//! transitions for one input byte fit in a single `u64`. Lookup: +//! `next = (table[byte] >> (state * 4)) & 0xF`. One cache line per lookup. +//! +//! - **Flat `u8` table** (≤255 states): `transitions[state * 256 + byte]`. +//! Larger, but still bounded by the `u8` state representation. +//! +//! ## State-Space Limits +//! +//! The public behavior is shaped by two implementation limits, both measured in +//! pattern **bytes** rather than Unicode scalar values: +//! +//! - `prefix%` pushdown is limited to **13 bytes**. The packed prefix DFA uses +//! 4-bit state ids and needs room for normal prefix-progress states, an +//! accept state, a fail state, and one escape sentinel for FSST literals. +//! - `%needle%` pushdown is limited to **254 bytes**. The long-needle DFA stores +//! states in `u8`, so it needs room for every match-progress state plus both +//! the accept state and the escape sentinel. +//! +//! Patterns beyond those limits are still valid LIKE patterns; they simply do +//! not use FSST pushdown and must be evaluated through the fallback path. +//! +//! ## DFA Variants and When Each Is Used +//! +//! ```text +//! ┌───────────────┬──────────────────────────────────────────────────────┐ +//! │ Pattern │ Needle length → DFA variant │ +//! ├───────────────┼──────────────────────────────────────────────────────┤ +//! │ prefix% │ 0–13 → FsstPrefixDfa (shift-packed, no KMP) │ +//! ├───────────────┼──────────────────────────────────────────────────────┤ +//! │ %needle% │ 1–7 → BranchlessShiftDfa (hierarchical 4-byte) │ +//! │ │ 8–127 → FlatContainsDfa (flat u8, esc-folded) │ +//! │ │ 128–254 → FlatContainsDfa (flat u8, esc-sentinel) │ +//! └───────────────┴──────────────────────────────────────────────────────┘ +//! ``` +//! +//! ## Escape Handling Strategies +//! +//! There are two ways to handle the FSST escape code in the DFA: +//! +//! **Escape sentinel** (used by `FlatContainsDfa` for long needles, `FsstPrefixDfa`): +//! The escape code maps to a sentinel state. The scanner checks for it and +//! reads the next byte from a separate escape transition table. +//! +//! ```text +//! loop: +//! state = transitions[byte] // might be sentinel +//! if state == SENTINEL: +//! state = escape_transitions[next_byte] // branch +//! ``` +//! +//! **Escape folding** (used by `BranchlessShiftDfa`, `FlatContainsDfa` for short needles): +//! Escape states are folded into the state space. State `s+N+1` means "was in +//! state `s`, just consumed ESCAPE_CODE". The next byte's transition from an +//! escape state uses the byte-level table. No branch needed in the scanner. +//! +//! ```text +//! States: [0..N-1: normal] [N: accept] [N+1..2N: escape shadows] +//! Total: 2N+1 states. With 4-bit packing, max N=7. +//! +//! loop: +//! state = transitions[state][byte] // branchless! +//! ``` + +mod branchless_shift; +mod flat_contains; +mod prefix; +#[cfg(test)] +mod tests; + +use branchless_shift::BranchlessShiftDfa; +use flat_contains::FlatContainsDfa; +use fsst::ESCAPE_CODE; +use fsst::Symbol; +use prefix::FsstPrefixDfa; +use vortex_buffer::BitBuffer; +use vortex_error::VortexResult; + +// --------------------------------------------------------------------------- +// FsstMatcher — unified public API +// --------------------------------------------------------------------------- + +/// A compiled matcher for LIKE patterns on FSST-compressed strings. +/// +/// Encapsulates pattern parsing and DFA variant selection. Returns `None` from +/// [`try_new`](Self::try_new) for patterns that cannot be evaluated without +/// decompression (e.g., `_` wildcards, multiple `%` in non-standard positions, +/// or patterns that exceed the DFA's representable byte-length limits). +pub(crate) struct FsstMatcher { + inner: MatcherInner, +} + +enum MatcherInner { + MatchAll, + Prefix(Box), + ContainsBranchless(Box), + ContainsFlat(FlatContainsDfa), +} + +impl FsstMatcher { + /// Try to build a matcher for the given LIKE pattern. + /// + /// Returns `Ok(None)` if the pattern shape is not supported for pushdown + /// (e.g. `_` wildcards, multiple non-bookend `%`, `prefix%` longer than + /// 13 bytes, or `%needle%` longer than 254 bytes). + pub(crate) fn try_new( + symbols: &[Symbol], + symbol_lengths: &[u8], + pattern: &str, + ) -> VortexResult> { + let Some(like_kind) = LikeKind::parse(pattern) else { + return Ok(None); + }; + + let inner = match like_kind { + LikeKind::Prefix("") => MatcherInner::MatchAll, + LikeKind::Prefix(prefix) => { + let prefix = prefix.as_bytes(); + if prefix.len() > FsstPrefixDfa::MAX_PREFIX_LEN { + return Ok(None); + } + MatcherInner::Prefix(Box::new(FsstPrefixDfa::new( + symbols, + symbol_lengths, + prefix, + )?)) + } + LikeKind::Contains(needle) => { + let needle = needle.as_bytes(); + if needle.len() > FlatContainsDfa::MAX_NEEDLE_LEN { + return Ok(None); + } + if needle.len() <= BranchlessShiftDfa::MAX_NEEDLE_LEN { + MatcherInner::ContainsBranchless(Box::new(BranchlessShiftDfa::new( + symbols, + symbol_lengths, + needle, + )?)) + } else { + MatcherInner::ContainsFlat(FlatContainsDfa::new( + symbols, + symbol_lengths, + needle, + )?) + } + } + }; + + Ok(Some(Self { inner })) + } + + /// Run the matcher on a single FSST-compressed code sequence. + #[inline] + pub(crate) fn matches(&self, codes: &[u8]) -> bool { + match &self.inner { + MatcherInner::MatchAll => true, + MatcherInner::Prefix(dfa) => dfa.matches(codes), + MatcherInner::ContainsBranchless(dfa) => dfa.matches(codes), + MatcherInner::ContainsFlat(dfa) => dfa.matches(codes), + } + } +} + +/// The subset of LIKE patterns we can handle without decompression. +enum LikeKind<'a> { + /// `prefix%` + Prefix(&'a str), + /// `%needle%` + Contains(&'a str), +} + +impl<'a> LikeKind<'a> { + fn parse(pattern: &'a str) -> Option { + // `prefix%` (including just `%` where prefix is empty) + if let Some(prefix) = pattern.strip_suffix('%') + && !prefix.contains(['%', '_']) + { + return Some(LikeKind::Prefix(prefix)); + } + + // `%needle%` + let inner = pattern.strip_prefix('%')?.strip_suffix('%')?; + if !inner.contains(['%', '_']) { + return Some(LikeKind::Contains(inner)); + } + + None + } +} + +// --------------------------------------------------------------------------- +// Scan helper +// --------------------------------------------------------------------------- + +// TODO: add N-way ILP overrun scan for higher throughput on short strings. +#[inline] +pub(crate) fn dfa_scan_to_bitbuf( + n: usize, + offsets: &[T], + all_bytes: &[u8], + negated: bool, + matcher: F, +) -> BitBuffer +where + T: vortex_array::dtype::IntegerPType, + F: Fn(&[u8]) -> bool, +{ + let mut start: usize = offsets[0].as_(); + BitBuffer::collect_bool(n, |i| { + let end: usize = offsets[i + 1].as_(); + let result = matcher(&all_bytes[start..end]) != negated; + start = end; + result + }) +} + +// --------------------------------------------------------------------------- +// Shared helpers — used by multiple DFA implementations +// --------------------------------------------------------------------------- + +/// Extract a state id from a shift-packed `u64` word. +/// +/// Each state occupies `bits` bits. The mask `(1 << bits) - 1` guarantees the +/// result is at most 15 (for `bits = 4`), which always fits in `u8`. +#[expect( + clippy::cast_possible_truncation, + reason = "masked to `bits` bits (≤4), result ≤ 15" +)] +#[inline(always)] +fn shift_extract(packed: u64, state: u8, bits: u32) -> u8 { + let mask = (1u64 << bits) - 1; + ((packed >> (u32::from(state) * bits)) & mask) as u8 +} + +/// Compose two shift-packed transition `u64`s: for each state, apply `first` +/// then `second`, packing the result back into a single `u64`. +fn compose_packed(first: u64, second: u64, total_states: u8, bits: u32) -> u64 { + let mut packed = 0u64; + for state in 0..total_states { + let mid = shift_extract(first, state, bits); + let final_s = shift_extract(second, mid, bits); + packed |= u64::from(final_s) << (u32::from(state) * bits); + } + packed +} + +// --------------------------------------------------------------------------- +// DFA construction helpers +// --------------------------------------------------------------------------- + +/// Builds the per-symbol transition table for FSST symbols. +/// +/// For each `(state, symbol_code)` pair, simulates feeding the symbol's bytes +/// through the byte-level transition table to compute the resulting state. +/// +/// Returns a flat `Vec` indexed as `[state * n_symbols + code]`. +fn build_symbol_transitions( + symbols: &[Symbol], + symbol_lengths: &[u8], + byte_table: &[u8], + n_states: u8, + accept_state: u8, +) -> Vec { + let n_states = usize::from(n_states); + let n_symbols = symbols.len(); + let mut sym_trans = vec![0u8; n_states * n_symbols]; + for state in 0..n_states { + for code in 0..n_symbols { + if state == usize::from(accept_state) { + sym_trans[state * n_symbols + code] = accept_state; + continue; + } + let sym = symbols[code].to_u64().to_le_bytes(); + let sym_len = usize::from(symbol_lengths[code]); + let mut s = state; + for &b in &sym[..sym_len] { + if s == usize::from(accept_state) { + break; + } + s = usize::from(byte_table[s * 256 + usize::from(b)]); + } + #[expect( + clippy::cast_possible_truncation, + reason = "s is a state id < n_states ≤ 256" + )] + { + sym_trans[state * n_symbols + code] = s as u8; + } + } + } + sym_trans +} + +/// Builds a fused 256-wide transition table from symbol transitions. +/// +/// For each `(state, code_byte)`: +/// - Code bytes `0..n_symbols`: use the symbol transition +/// - `ESCAPE_CODE`: maps to `escape_value` (either a sentinel or escape state) +/// - All others: use `default` (typically 0 for contains, fail_state for prefix) +/// +/// Returns a flat `Vec` indexed as `[state * 256 + code_byte]`. +fn build_fused_table( + sym_trans: &[u8], + n_symbols: usize, + n_states: u8, + escape_value_fn: impl Fn(u8) -> u8, + default: u8, +) -> Vec { + let mut fused = vec![default; usize::from(n_states) * 256]; + for state in 0..n_states { + let s = usize::from(state); + for code in 0..n_symbols { + fused[s * 256 + code] = sym_trans[s * n_symbols + code]; + } + fused[s * 256 + usize::from(ESCAPE_CODE)] = escape_value_fn(state); + } + fused +} + +/// Packs a fused table into shift-encoded `u64` arrays. +/// +/// Each `u64` encodes transitions for ALL states for one input byte. +/// Lookup: `next = (table[byte] >> (state * BITS)) & MASK`. +fn pack_shift_table(fused: &[u8], n_states: u8, bits: u32) -> [u64; 256] { + let mut packed = [0u64; 256]; + for code_byte in 0..256usize { + let mut val = 0u64; + for state in 0..n_states { + val |= + u64::from(fused[usize::from(state) * 256 + code_byte]) << (u32::from(state) * bits); + } + packed[code_byte] = val; + } + packed +} + +/// Builds an escape-folded fused transition table for contains matching. +/// +/// State layout: `[0..n-1]` match progress, `[n]` accept (sticky), `[n+1..2n]` escape shadows. +/// Total states: `2 * needle.len() + 1`. +/// +/// For normal states, the escape code maps to the corresponding escape shadow state. +/// Escape shadow states use byte-level KMP transitions so the next literal byte +/// resumes matching correctly — no branch needed in the scanner. +fn build_escape_folded_table(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Vec { + #[expect( + clippy::cast_possible_truncation, + reason = "needle.len() ≤ FlatContainsDfa::MAX_FOLDED_LEN (127)" + )] + let n = needle.len() as u8; + let accept_state = n; + let total_states = usize::from(2 * n + 1); + + let byte_table = kmp_byte_transitions(needle); + let sym_trans = + build_symbol_transitions(symbols, symbol_lengths, &byte_table, n + 1, accept_state); + + let n_symbols = symbols.len(); + let n_usize = usize::from(n); + let mut fused = vec![0u8; total_states * 256]; + for code_byte in 0..256usize { + // Normal states 0..n + for s in 0..n_usize { + if code_byte == usize::from(ESCAPE_CODE) { + #[expect(clippy::cast_possible_truncation, reason = "s + n + 1 ≤ 2*127+1 = 255")] + { + fused[s * 256 + code_byte] = (s + n_usize + 1) as u8; + } + } else if code_byte < n_symbols { + fused[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; + } + } + // Accept state (sticky) + fused[n_usize * 256 + code_byte] = accept_state; + // Escape shadow states n+1..2n + for s in 0..n_usize { + let esc_state = s + n_usize + 1; + fused[esc_state * 256 + code_byte] = byte_table[s * 256 + code_byte]; + } + } + fused +} + +// --------------------------------------------------------------------------- +// KMP helpers +// --------------------------------------------------------------------------- + +fn kmp_byte_transitions(needle: &[u8]) -> Vec { + let n_states = needle.len() + 1; + #[expect( + clippy::cast_possible_truncation, + reason = "needle.len() ≤ 254, accept state fits in u8" + )] + let accept = needle.len() as u8; + let failure = kmp_failure_table(needle); + + let mut table = vec![0u8; n_states * 256]; + for state in 0..n_states { + for byte in 0..256usize { + if state == needle.len() { + table[state * 256 + byte] = accept; + continue; + } + #[expect( + clippy::cast_possible_truncation, + reason = "state < needle.len() ≤ 254" + )] + let mut s = state as u8; + loop { + if byte == usize::from(needle[usize::from(s)]) { + s += 1; + break; + } + if s == 0 { + break; + } + s = failure[usize::from(s) - 1]; + } + table[state * 256 + byte] = s; + } + } + table +} + +fn kmp_failure_table(needle: &[u8]) -> Vec { + let mut failure = vec![0u8; needle.len()]; + let mut k = 0u8; + for i in 1..needle.len() { + while k > 0 && needle[usize::from(k)] != needle[i] { + k = failure[usize::from(k) - 1]; + } + if needle[usize::from(k)] == needle[i] { + k += 1; + } + failure[i] = k; + } + failure +} diff --git a/encodings/fsst/src/dfa/prefix.rs b/encodings/fsst/src/dfa/prefix.rs new file mode 100644 index 00000000000..b00cd34bda2 --- /dev/null +++ b/encodings/fsst/src/dfa/prefix.rs @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! DFA for prefix matching (`LIKE 'prefix%'`). +//! +//! TODO(joe): support longer prefixes (14–253 bytes) via a flat `Vec` table +//! with escape sentinel, similar to `FlatContainsDfa`. The construction is simpler +//! than contains (no KMP — mismatches go to a sticky fail state). Would need states +//! 0..N (progress) + accept + fail + sentinel, so N+3 ≤ 256 → max prefix = 253. + +use fsst::Symbol; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use super::build_fused_table; +use super::build_symbol_transitions; +use super::pack_shift_table; +use super::shift_extract; + +/// Precomputed shift-based DFA for prefix matching on FSST codes. +/// +/// States 0..prefix_len track match progress, plus ACCEPT and FAIL. +/// Uses the same shift-based approach as the contains DFA: all state +/// transitions packed into a `u64` per code byte. For prefixes longer +/// than 13 characters, pushdown is disabled and LIKE falls back. +/// +/// ```text +/// Prefix: "http" (4 progress states + accept + fail) +/// +/// Input byte +/// State 'h' 't' 'p' other +/// ───── ──── ──── ──── ───── +/// 0 1 F F F ← want 'h' +/// 1 F 2 F F ← want 't' +/// 2 F 3 F F ← want 't' +/// 3 F F 4✓ F ← want 'p' +/// 4✓ 4✓ 4✓ 4✓ 4✓ ← accept (sticky) +/// F F F F F ← fail (sticky) +/// +/// Escape handling: code 255 → sentinel → read next literal byte → byte table +/// ``` +pub(crate) struct FsstPrefixDfa { + /// Packed transitions: `(table[code] >> (state * 4)) & 0xF` gives next state. + transitions: [u64; 256], + /// Packed escape transitions for literal bytes. + escape_transitions: [u64; 256], + accept_state: u8, + fail_state: u8, +} + +impl FsstPrefixDfa { + pub(crate) const BITS: u32 = 4; + pub(crate) const MAX_PREFIX_LEN: usize = (1 << Self::BITS) as usize - 3; + + pub(crate) fn new( + symbols: &[Symbol], + symbol_lengths: &[u8], + prefix: &[u8], + ) -> VortexResult { + if prefix.len() > Self::MAX_PREFIX_LEN { + vortex_bail!( + "prefix length {} exceeds maximum {} for shift-packed prefix DFA", + prefix.len(), + Self::MAX_PREFIX_LEN + ); + } + + let accept_state = u8::try_from(prefix.len()).vortex_expect("prefix fits in u8"); + let fail_state = accept_state + 1; + let n_states = fail_state + 1; + + // Prefix matching uses a simpler transition rule than KMP: on mismatch + // we go to fail_state (no fallback). Build the byte table inline. + let byte_table = Self::build_prefix_byte_table(prefix, accept_state, fail_state); + + let sym_trans = + build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); + + // Override fail_state rows: fail is sticky. + let escape_sentinel = fail_state + 1; + let mut fused = build_fused_table( + &sym_trans, + symbols.len(), + n_states, + |_| escape_sentinel, + fail_state, + ); + + // Accept and fail states are sticky for all inputs. + let accept_row = usize::from(accept_state) * 256; + fused[accept_row..accept_row + 256].fill(accept_state); + let fail_row = usize::from(fail_state) * 256; + fused[fail_row..fail_row + 256].fill(fail_state); + + let transitions = pack_shift_table(&fused, n_states, Self::BITS); + + // Escape transitions: for an escaped literal byte, use the byte-level transition. + let escape_transitions = pack_shift_table(&byte_table, n_states, Self::BITS); + + Ok(Self { + transitions, + escape_transitions, + accept_state, + fail_state, + }) + } + + /// Build a byte-level transition table for prefix matching (no KMP fallback). + fn build_prefix_byte_table(prefix: &[u8], accept_state: u8, fail_state: u8) -> Vec { + let n_states = fail_state + 1; + let mut table = vec![fail_state; usize::from(n_states) * 256]; + + for state in 0..n_states { + let s = usize::from(state); + if state == accept_state { + for byte in 0..256 { + table[s * 256 + byte] = accept_state; + } + } else if state != fail_state { + // Only the correct next byte advances; everything else fails. + let next_byte = prefix[s]; + let next_state = if s + 1 >= prefix.len() { + accept_state + } else { + state + 1 + }; + table[s * 256 + usize::from(next_byte)] = next_state; + } + } + table + } + + #[inline] + pub(crate) fn matches(&self, codes: &[u8]) -> bool { + let mut state = 0u8; + let mut pos = 0; + while pos < codes.len() { + let code = codes[pos]; + pos += 1; + let packed = self.transitions[usize::from(code)]; + // Masked to BITS (4) bits, result ≤ 15, fits in u8 + let next = shift_extract(packed, state, Self::BITS); + if next == self.fail_state + 1 { + // Escape sentinel: read literal byte. + if pos >= codes.len() { + return false; + } + let b = codes[pos]; + pos += 1; + let esc_packed = self.escape_transitions[usize::from(b)]; + state = shift_extract(esc_packed, state, Self::BITS); + } else { + state = next; + } + if state == self.accept_state { + return true; + } + if state == self.fail_state { + return false; + } + } + state == self.accept_state + } +} diff --git a/encodings/fsst/src/dfa/tests.rs b/encodings/fsst/src/dfa/tests.rs new file mode 100644 index 00000000000..b7fadbcf928 --- /dev/null +++ b/encodings/fsst/src/dfa/tests.rs @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use fsst::ESCAPE_CODE; + +use super::FsstMatcher; +use super::LikeKind; +use super::flat_contains::FlatContainsDfa; +use super::prefix::FsstPrefixDfa; + +fn escaped(bytes: &[u8]) -> Vec { + let mut codes = Vec::with_capacity(bytes.len() * 2); + for &b in bytes { + codes.push(ESCAPE_CODE); + codes.push(b); + } + codes +} + +#[test] +fn test_like_kind_parse() { + assert!(matches!( + LikeKind::parse("http%"), + Some(LikeKind::Prefix("http")) + )); + assert!(matches!( + LikeKind::parse("%needle%"), + Some(LikeKind::Contains("needle")) + )); + assert!(matches!(LikeKind::parse("%"), Some(LikeKind::Prefix("")))); + // Suffix and underscore patterns are not supported. + assert!(LikeKind::parse("%suffix").is_none()); + assert!(LikeKind::parse("a_c").is_none()); +} + +#[test] +fn test_prefix_pushdown_len_13_with_escapes() { + let matcher = FsstMatcher::try_new(&[], &[], "abcdefghijklm%") + .unwrap() + .unwrap(); + + assert!(matcher.matches(&escaped(b"abcdefghijklm"))); + assert!(!matcher.matches(&escaped(b"abcdefghijklx"))); +} + +#[test] +fn test_prefix_pushdown_rejects_len_14() { + debug_assert_eq!(FsstPrefixDfa::MAX_PREFIX_LEN, 13); + assert!( + FsstMatcher::try_new(&[], &[], "abcdefghijklmn%") + .unwrap() + .is_none() + ); +} + +#[test] +fn test_contains_pushdown_len_254_with_escapes() { + let needle = "a".repeat(FlatContainsDfa::MAX_NEEDLE_LEN); + let pattern = format!("%{needle}%"); + let matcher = FsstMatcher::try_new(&[], &[], &pattern).unwrap().unwrap(); + + assert!(matcher.matches(&escaped(needle.as_bytes()))); + + let mut mismatch = needle.into_bytes(); + mismatch[FlatContainsDfa::MAX_NEEDLE_LEN - 1] = b'b'; + assert!(!matcher.matches(&escaped(&mismatch))); +} + +#[test] +fn test_contains_pushdown_rejects_len_255() { + let needle = "a".repeat(FlatContainsDfa::MAX_NEEDLE_LEN + 1); + let pattern = format!("%{needle}%"); + assert!(FsstMatcher::try_new(&[], &[], &pattern).unwrap().is_none()); +} From 7faf9f36f75bad311ea9d3a2bfe21898fc6f1dfc Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Wed, 18 Mar 2026 18:48:44 +0000 Subject: [PATCH 08/19] clean up Signed-off-by: Joe Isaacs --- encodings/fsst/README.md | 2 +- encodings/fsst/src/compute/like.rs | 12 ++- encodings/fsst/src/dfa/DFA_NOTES.md | 9 +- encodings/fsst/src/dfa/mod.rs | 42 ++++----- encodings/fsst/src/dfa/prefix.rs | 134 +++++++++++++--------------- encodings/fsst/src/dfa/tests.rs | 118 +++++++++++++++++++++++- 6 files changed, 204 insertions(+), 113 deletions(-) diff --git a/encodings/fsst/README.md b/encodings/fsst/README.md index 83668515f26..7cc53ba07f9 100644 --- a/encodings/fsst/README.md +++ b/encodings/fsst/README.md @@ -17,7 +17,7 @@ wildcards, fall back to ordinary decompression-based LIKE evaluation. There are also two implementation limits on the pushdown path, both measured in pattern bytes: -- `prefix%` supports up to 13 bytes. +- `prefix%` supports up to 253 bytes. - `%needle%` supports up to 254 bytes. Patterns beyond those limits are still evaluated correctly, but they do so via diff --git a/encodings/fsst/src/compute/like.rs b/encodings/fsst/src/compute/like.rs index ad0f65ee0ac..732708a64c1 100644 --- a/encodings/fsst/src/compute/like.rs +++ b/encodings/fsst/src/compute/like.rs @@ -272,7 +272,7 @@ mod tests { } #[test] - fn test_like_long_prefix_falls_back_but_still_matches() -> VortexResult<()> { + fn test_like_long_prefix_handled_by_flat_dfa() -> VortexResult<()> { let fsst = make_fsst( &[ Some("abcdefghijklmn-tail"), @@ -290,12 +290,10 @@ mod tests { &mut SESSION.create_execution_ctx(), )?; assert!( - direct.is_none(), - "14-byte prefixes exceed the packed prefix DFA and should fall back" + direct.is_some(), + "14-byte prefixes are now handled by the flat prefix DFA" ); - - let result = like(fsst, pattern)?; - assert_arrays_eq!(&result, &BoolArray::from_iter([true, false, true])); + assert_arrays_eq!(direct.unwrap(), BoolArray::from_iter([true, false, true])); Ok(()) } @@ -467,7 +465,7 @@ mod tests { #[test] fn fuzz_prefix_matching() -> VortexResult<()> { for seed in 0..50 { - for prefix_len in [1, 3, 5, 10, 13] { + for prefix_len in [1, 3, 5, 10, 13, 20, 40] { fuzz_prefix(seed, prefix_len, 200)?; } } diff --git a/encodings/fsst/src/dfa/DFA_NOTES.md b/encodings/fsst/src/dfa/DFA_NOTES.md index 2fe7708f55d..f22483ca28f 100644 --- a/encodings/fsst/src/dfa/DFA_NOTES.md +++ b/encodings/fsst/src/dfa/DFA_NOTES.md @@ -11,13 +11,13 @@ Unified 5 DFA structs down to 3: | `FlatBranchlessDfa` | `FlatContainsDfa` | Merged with `FusedDfa` into single struct with `EscapeStrategy` enum | | `FusedDfa` | `FlatContainsDfa` | Merged (see above) | | `BranchlessShiftDfa` | `BranchlessShiftDfa` | Unchanged | -| `FsstPrefixDfa` | `FsstPrefixDfa` | Simplified escape transition building | +| `FlatPrefixDfa` | `FlatPrefixDfa` | Simplified escape transition building | Other changes: - Extracted `build_escape_folded_table()` (shared by `BranchlessShiftDfa` and `FlatContainsDfa`) - Extracted `compose_packed()` (shared by `build_pair_compose` and `build_compose_4b`) - Extended `FlatContainsDfa` folded range from 14 → 127 (2*127+1=255 fits in u8) -- Simplified `FsstPrefixDfa` escape transitions (reuse byte_table directly) +- Simplified `FlatPrefixDfa` escape transitions (reuse byte_table directly) - Deleted `pack_escape_shift_table` (only caller was `ShiftDfa`) ## Removed code (recoverable from git) @@ -106,10 +106,9 @@ Would reduce loop overhead for long compressed strings. Tables stay the same siz just add a `compose_8b` level on top of `compose_4b`. ### 2. Branchless prefix DFA -`FsstPrefixDfa` currently uses escape sentinel + branch. Could use escape-folding +`FlatPrefixDfa` currently uses escape sentinel + branch. Could use escape-folding (like the contains DFAs) to make the prefix scan branchless. Needs 2*prefix_len+1 -states to fit in 4-bit packing, so max prefix drops from 13 to 7. Worth it if -prefix matching is a bottleneck. +states, so max prefix would drop. Worth it if prefix matching is a bottleneck. ### 3. Further struct merging `BranchlessShiftDfa` and `FlatContainsDfa` (folded) share the same escape-folded diff --git a/encodings/fsst/src/dfa/mod.rs b/encodings/fsst/src/dfa/mod.rs index fd1da377583..49c710d6430 100644 --- a/encodings/fsst/src/dfa/mod.rs +++ b/encodings/fsst/src/dfa/mod.rs @@ -111,9 +111,9 @@ //! The public behavior is shaped by two implementation limits, both measured in //! pattern **bytes** rather than Unicode scalar values: //! -//! - `prefix%` pushdown is limited to **13 bytes**. The packed prefix DFA uses -//! 4-bit state ids and needs room for normal prefix-progress states, an -//! accept state, a fail state, and one escape sentinel for FSST literals. +//! - `prefix%` pushdown is limited to **253 bytes**. The flat prefix DFA uses +//! `u8` state ids and needs room for progress states, an accept state, a +//! fail state, and one escape sentinel (N+3 ≤ 256). //! - `%needle%` pushdown is limited to **254 bytes**. The long-needle DFA stores //! states in `u8`, so it needs room for every match-progress state plus both //! the accept state and the escape sentinel. @@ -127,7 +127,7 @@ //! ┌───────────────┬──────────────────────────────────────────────────────┐ //! │ Pattern │ Needle length → DFA variant │ //! ├───────────────┼──────────────────────────────────────────────────────┤ -//! │ prefix% │ 0–13 → FsstPrefixDfa (shift-packed, no KMP) │ +//! │ prefix% │ 0–253 → FlatPrefixDfa (flat u8, esc-sentinel) │ //! ├───────────────┼──────────────────────────────────────────────────────┤ //! │ %needle% │ 1–7 → BranchlessShiftDfa (hierarchical 4-byte) │ //! │ │ 8–127 → FlatContainsDfa (flat u8, esc-folded) │ @@ -139,7 +139,7 @@ //! //! There are two ways to handle the FSST escape code in the DFA: //! -//! **Escape sentinel** (used by `FlatContainsDfa` for long needles, `FsstPrefixDfa`): +//! **Escape sentinel** (used by `FlatContainsDfa` for long needles, `FlatPrefixDfa`): //! The escape code maps to a sentinel state. The scanner checks for it and //! reads the next byte from a separate escape transition table. //! @@ -173,7 +173,7 @@ use branchless_shift::BranchlessShiftDfa; use flat_contains::FlatContainsDfa; use fsst::ESCAPE_CODE; use fsst::Symbol; -use prefix::FsstPrefixDfa; +use prefix::FlatPrefixDfa; use vortex_buffer::BitBuffer; use vortex_error::VortexResult; @@ -193,7 +193,7 @@ pub(crate) struct FsstMatcher { enum MatcherInner { MatchAll, - Prefix(Box), + Prefix(FlatPrefixDfa), ContainsBranchless(Box), ContainsFlat(FlatContainsDfa), } @@ -203,7 +203,7 @@ impl FsstMatcher { /// /// Returns `Ok(None)` if the pattern shape is not supported for pushdown /// (e.g. `_` wildcards, multiple non-bookend `%`, `prefix%` longer than - /// 13 bytes, or `%needle%` longer than 254 bytes). + /// 253 bytes, or `%needle%` longer than 254 bytes). pub(crate) fn try_new( symbols: &[Symbol], symbol_lengths: &[u8], @@ -217,14 +217,10 @@ impl FsstMatcher { LikeKind::Prefix("") => MatcherInner::MatchAll, LikeKind::Prefix(prefix) => { let prefix = prefix.as_bytes(); - if prefix.len() > FsstPrefixDfa::MAX_PREFIX_LEN { + if prefix.len() > FlatPrefixDfa::MAX_PREFIX_LEN { return Ok(None); } - MatcherInner::Prefix(Box::new(FsstPrefixDfa::new( - symbols, - symbol_lengths, - prefix, - )?)) + MatcherInner::Prefix(FlatPrefixDfa::new(symbols, symbol_lengths, prefix)?) } LikeKind::Contains(needle) => { let needle = needle.as_bytes(); @@ -373,20 +369,18 @@ fn build_symbol_transitions( } let sym = symbols[code].to_u64().to_le_bytes(); let sym_len = usize::from(symbol_lengths[code]); - let mut s = state; - for &b in &sym[..sym_len] { - if s == usize::from(accept_state) { - break; - } - s = usize::from(byte_table[s * 256 + usize::from(b)]); - } #[expect( clippy::cast_possible_truncation, - reason = "s is a state id < n_states ≤ 256" + reason = "state < n_states ≤ 256" )] - { - sym_trans[state * n_symbols + code] = s as u8; + let mut s = state as u8; + for &b in &sym[..sym_len] { + if s == accept_state { + break; + } + s = byte_table[usize::from(s) * 256 + usize::from(b)]; } + sym_trans[state * n_symbols + code] = s; } } sym_trans diff --git a/encodings/fsst/src/dfa/prefix.rs b/encodings/fsst/src/dfa/prefix.rs index b00cd34bda2..c30da87f3ba 100644 --- a/encodings/fsst/src/dfa/prefix.rs +++ b/encodings/fsst/src/dfa/prefix.rs @@ -1,12 +1,16 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! DFA for prefix matching (`LIKE 'prefix%'`). +//! Flat `u8` transition table DFA for prefix matching (`LIKE 'prefix%'`). //! -//! TODO(joe): support longer prefixes (14–253 bytes) via a flat `Vec` table -//! with escape sentinel, similar to `FlatContainsDfa`. The construction is simpler -//! than contains (no KMP — mismatches go to a sticky fail state). Would need states -//! 0..N (progress) + accept + fail + sentinel, so N+3 ≤ 256 → max prefix = 253. +//! Supports prefixes up to 253 bytes (states: 0..N progress + accept + fail + +//! sentinel ≤ 256). +//! +//! TODO(joe): for short prefixes (≤13 bytes), a shift-packed `[u64; 256]` +//! representation would be simpler and easier to read — all state transitions +//! for one input byte fit in a single `u64`. Benchmarks showed no meaningful +//! perf difference (see `benches/BENCH_RESULTS.md`), so we use flat-only for +//! now to keep the code simple and support long prefixes. use fsst::Symbol; use vortex_error::VortexExpect; @@ -15,15 +19,40 @@ use vortex_error::vortex_bail; use super::build_fused_table; use super::build_symbol_transitions; -use super::pack_shift_table; -use super::shift_extract; -/// Precomputed shift-based DFA for prefix matching on FSST codes. +/// Build a byte-level transition table for prefix matching (no KMP fallback). /// -/// States 0..prefix_len track match progress, plus ACCEPT and FAIL. -/// Uses the same shift-based approach as the contains DFA: all state -/// transitions packed into a `u64` per code byte. For prefixes longer -/// than 13 characters, pushdown is disabled and LIKE falls back. +/// For each state, only the correct next byte advances; everything else goes +/// to the fail state. +fn build_prefix_byte_table(prefix: &[u8], accept_state: u8, fail_state: u8) -> Vec { + let n_states = fail_state + 1; + let mut table = vec![fail_state; usize::from(n_states) * 256]; + + for state in 0..n_states { + let s = usize::from(state); + if state == accept_state { + for byte in 0..256 { + table[s * 256 + byte] = accept_state; + } + } else if state != fail_state { + // Only the correct next byte advances; everything else fails. + let next_byte = prefix[s]; + let next_state = if s + 1 >= prefix.len() { + accept_state + } else { + state + 1 + }; + table[s * 256 + usize::from(next_byte)] = next_state; + } + } + table +} + +/// Flat `u8` transition table DFA for prefix matching on FSST codes. +/// +/// States 0..prefix_len track match progress, plus ACCEPT, FAIL, and an +/// escape SENTINEL. Transitions are stored in a flat `Vec` indexed as +/// `[state * 256 + byte]`. /// /// ```text /// Prefix: "http" (4 progress states + accept + fail) @@ -40,18 +69,18 @@ use super::shift_extract; /// /// Escape handling: code 255 → sentinel → read next literal byte → byte table /// ``` -pub(crate) struct FsstPrefixDfa { - /// Packed transitions: `(table[code] >> (state * 4)) & 0xF` gives next state. - transitions: [u64; 256], - /// Packed escape transitions for literal bytes. - escape_transitions: [u64; 256], +pub(crate) struct FlatPrefixDfa { + /// `transitions[state * 256 + byte]` -> next state. + transitions: Vec, + /// `escape_transitions[state * 256 + byte]` -> next state for escaped bytes. + escape_transitions: Vec, accept_state: u8, fail_state: u8, + sentinel: u8, } -impl FsstPrefixDfa { - pub(crate) const BITS: u32 = 4; - pub(crate) const MAX_PREFIX_LEN: usize = (1 << Self::BITS) as usize - 3; +impl FlatPrefixDfa { + pub(crate) const MAX_PREFIX_LEN: usize = 253; pub(crate) fn new( symbols: &[Symbol], @@ -60,7 +89,7 @@ impl FsstPrefixDfa { ) -> VortexResult { if prefix.len() > Self::MAX_PREFIX_LEN { vortex_bail!( - "prefix length {} exceeds maximum {} for shift-packed prefix DFA", + "prefix length {} exceeds maximum {} for flat prefix DFA", prefix.len(), Self::MAX_PREFIX_LEN ); @@ -69,68 +98,33 @@ impl FsstPrefixDfa { let accept_state = u8::try_from(prefix.len()).vortex_expect("prefix fits in u8"); let fail_state = accept_state + 1; let n_states = fail_state + 1; + let sentinel = fail_state + 1; - // Prefix matching uses a simpler transition rule than KMP: on mismatch - // we go to fail_state (no fallback). Build the byte table inline. - let byte_table = Self::build_prefix_byte_table(prefix, accept_state, fail_state); + // Step 1: byte-level transitions + let byte_table = build_prefix_byte_table(prefix, accept_state, fail_state); + // Step 2: symbol-level transitions let sym_trans = build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); - // Override fail_state rows: fail is sticky. - let escape_sentinel = fail_state + 1; - let mut fused = build_fused_table( + // Step 3: fused table with escape sentinel + let transitions = build_fused_table( &sym_trans, symbols.len(), n_states, - |_| escape_sentinel, + |_| sentinel, fail_state, ); - // Accept and fail states are sticky for all inputs. - let accept_row = usize::from(accept_state) * 256; - fused[accept_row..accept_row + 256].fill(accept_state); - let fail_row = usize::from(fail_state) * 256; - fused[fail_row..fail_row + 256].fill(fail_state); - - let transitions = pack_shift_table(&fused, n_states, Self::BITS); - - // Escape transitions: for an escaped literal byte, use the byte-level transition. - let escape_transitions = pack_shift_table(&byte_table, n_states, Self::BITS); - Ok(Self { transitions, - escape_transitions, + escape_transitions: byte_table, accept_state, fail_state, + sentinel, }) } - /// Build a byte-level transition table for prefix matching (no KMP fallback). - fn build_prefix_byte_table(prefix: &[u8], accept_state: u8, fail_state: u8) -> Vec { - let n_states = fail_state + 1; - let mut table = vec![fail_state; usize::from(n_states) * 256]; - - for state in 0..n_states { - let s = usize::from(state); - if state == accept_state { - for byte in 0..256 { - table[s * 256 + byte] = accept_state; - } - } else if state != fail_state { - // Only the correct next byte advances; everything else fails. - let next_byte = prefix[s]; - let next_state = if s + 1 >= prefix.len() { - accept_state - } else { - state + 1 - }; - table[s * 256 + usize::from(next_byte)] = next_state; - } - } - table - } - #[inline] pub(crate) fn matches(&self, codes: &[u8]) -> bool { let mut state = 0u8; @@ -138,18 +132,14 @@ impl FsstPrefixDfa { while pos < codes.len() { let code = codes[pos]; pos += 1; - let packed = self.transitions[usize::from(code)]; - // Masked to BITS (4) bits, result ≤ 15, fits in u8 - let next = shift_extract(packed, state, Self::BITS); - if next == self.fail_state + 1 { - // Escape sentinel: read literal byte. + let next = self.transitions[usize::from(state) * 256 + usize::from(code)]; + if next == self.sentinel { if pos >= codes.len() { return false; } let b = codes[pos]; pos += 1; - let esc_packed = self.escape_transitions[usize::from(b)]; - state = shift_extract(esc_packed, state, Self::BITS); + state = self.escape_transitions[usize::from(state) * 256 + usize::from(b)]; } else { state = next; } diff --git a/encodings/fsst/src/dfa/tests.rs b/encodings/fsst/src/dfa/tests.rs index b7fadbcf928..8b2a99aa4c9 100644 --- a/encodings/fsst/src/dfa/tests.rs +++ b/encodings/fsst/src/dfa/tests.rs @@ -2,11 +2,20 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use fsst::ESCAPE_CODE; +use fsst::Symbol; +use vortex_error::VortexResult; use super::FsstMatcher; use super::LikeKind; use super::flat_contains::FlatContainsDfa; -use super::prefix::FsstPrefixDfa; +use super::prefix::FlatPrefixDfa; + +/// Helper: make a Symbol from a byte string (up to 8 bytes, zero-padded). +fn sym(bytes: &[u8]) -> Symbol { + let mut buf = [0u8; 8]; + buf[..bytes.len()].copy_from_slice(bytes); + Symbol::from_slice(&buf) +} fn escaped(bytes: &[u8]) -> Vec { let mut codes = Vec::with_capacity(bytes.len() * 2); @@ -33,6 +42,84 @@ fn test_like_kind_parse() { assert!(LikeKind::parse("a_c").is_none()); } +/// No symbols — all bytes escaped. Simplest case to see the two tables. +#[test] +fn test_prefix_dfa_no_symbols() -> VortexResult<()> { + let dfa = FlatPrefixDfa::new(&[], &[], b"ab")?; + + assert!(dfa.matches(&escaped(b"abx"))); + assert!(dfa.matches(&escaped(b"ab"))); + assert!(!dfa.matches(&escaped(b"a"))); + assert!(!dfa.matches(&escaped(b"ax"))); + assert!(!dfa.matches(&escaped(b"ba"))); + assert!(!dfa.matches(&[])); + + Ok(()) +} + +/// With symbols — shows how multi-byte symbols interact with prefix matching. +/// +/// Symbol table: code 0 = "ht", code 1 = "tp" +/// Prefix: "http" +/// +/// The string "http" can be encoded as: +/// [0, 1] — two symbols: "ht" + "tp" +/// [ESC,h, ESC,t, ESC,t, ESC,p] — all escaped +/// [0, ESC,t, ESC,p] — symbol "ht" + escaped "t" + escaped "p" +#[test] +fn test_prefix_dfa_with_symbols() -> VortexResult<()> { + let symbols = [sym(b"ht"), sym(b"tp")]; + let lengths = [2u8, 2]; + let dfa = FlatPrefixDfa::new(&symbols, &lengths, b"http")?; + + // "http" via two symbols: code 0 ("ht") + code 1 ("tp") → accept + assert!(dfa.matches(&[0, 1])); + + // "http" all escaped + assert!(dfa.matches(&escaped(b"http"))); + + // "http" mixed: symbol "ht" + escaped "tp" + assert!(dfa.matches(&[0, ESCAPE_CODE, b't', ESCAPE_CODE, b'p'])); + + // "htxx" via symbol "ht" + escaped "xx" → fail after "ht" advances to state 2, + // then 'x' doesn't match 't' + assert!(!dfa.matches(&[0, ESCAPE_CODE, b'x', ESCAPE_CODE, b'x'])); + + // "tp" alone → symbol "tp" from state 0 feeds 't','p' through byte table: + // state 0 wants 'h', sees 't' → fail + assert!(!dfa.matches(&[1])); + + Ok(()) +} + +/// Longer prefix showing more progress states. +#[test] +fn test_prefix_dfa_longer() -> VortexResult<()> { + // code 0 = "tp" (2 bytes), code 1 = "htt" (3 bytes), code 2 = "p:/" (3 bytes) + let symbols = [sym(b"tp"), sym(b"htt"), sym(b"p:/")]; + let lengths = [2u8, 3, 3]; + let dfa = FlatPrefixDfa::new(&symbols, &lengths, b"http://")?; + + // "http://e" via symbols: "htt"(1) + "p:/"(2) + escaped "/" + escaped "e" + // "htt" = states 0→1→2→3, "p:/" = states 3→4→5→6, "/" = state 6→accept + assert!(dfa.matches(&[1, 2, ESCAPE_CODE, b'/', ESCAPE_CODE, b'e'])); + + // "http:/" — 6 chars, missing the 7th '/' + assert!(!dfa.matches(&[1, ESCAPE_CODE, b'p', ESCAPE_CODE, b':', ESCAPE_CODE, b'/',])); + + // "http://" all escaped — 7 chars, exact match + assert!(dfa.matches(&escaped(b"http://"))); + + // "tp" alone (code 0) from state 0: feeds 't','p' → state 0 wants 'h', sees 't' → fail + assert!(!dfa.matches(&[0])); + + // "htt" + "tp" = "httpp"? No — "htt" → states 0→1→2→3, then "tp": + // state 3 wants 'p', sees 't' → fail immediately + assert!(!dfa.matches(&[1, 0])); + + Ok(()) +} + #[test] fn test_prefix_pushdown_len_13_with_escapes() { let matcher = FsstMatcher::try_new(&[], &[], "abcdefghijklm%") @@ -44,15 +131,38 @@ fn test_prefix_pushdown_len_13_with_escapes() { } #[test] -fn test_prefix_pushdown_rejects_len_14() { - debug_assert_eq!(FsstPrefixDfa::MAX_PREFIX_LEN, 13); +fn test_prefix_pushdown_len_14_now_handled() { + // 14-byte prefix is now handled by FlatPrefixDfa (was rejected by shift-packed). assert!( FsstMatcher::try_new(&[], &[], "abcdefghijklmn%") .unwrap() - .is_none() + .is_some() ); } +#[test] +fn test_prefix_pushdown_long_prefix() -> VortexResult<()> { + let prefix = "a".repeat(FlatPrefixDfa::MAX_PREFIX_LEN); + let pattern = format!("{prefix}%"); + let matcher = FsstMatcher::try_new(&[], &[], &pattern)?.unwrap(); + + assert!(matcher.matches(&escaped(prefix.as_bytes()))); + + let mut mismatch = prefix.into_bytes(); + mismatch[FlatPrefixDfa::MAX_PREFIX_LEN - 1] = b'b'; + assert!(!matcher.matches(&escaped(&mismatch))); + + Ok(()) +} + +#[test] +fn test_prefix_pushdown_rejects_len_254() { + debug_assert_eq!(FlatPrefixDfa::MAX_PREFIX_LEN, 253); + let prefix = "a".repeat(254); + let pattern = format!("{prefix}%"); + assert!(FsstMatcher::try_new(&[], &[], &pattern).unwrap().is_none()); +} + #[test] fn test_contains_pushdown_len_254_with_escapes() { let needle = "a".repeat(FlatContainsDfa::MAX_NEEDLE_LEN); From f1c6a021d002c976c884d04df1682d4ef2b808b1 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Wed, 18 Mar 2026 19:16:36 +0000 Subject: [PATCH 09/19] clean up Signed-off-by: Joe Isaacs --- encodings/fsst/src/dfa/DFA_NOTES.md | 123 ----------- encodings/fsst/src/dfa/branchless_shift.rs | 227 --------------------- encodings/fsst/src/dfa/flat_contains.rs | 180 ++++++++-------- encodings/fsst/src/dfa/mod.rs | 221 +++----------------- encodings/fsst/src/dfa/prefix.rs | 58 +++--- 5 files changed, 144 insertions(+), 665 deletions(-) delete mode 100644 encodings/fsst/src/dfa/DFA_NOTES.md delete mode 100644 encodings/fsst/src/dfa/branchless_shift.rs diff --git a/encodings/fsst/src/dfa/DFA_NOTES.md b/encodings/fsst/src/dfa/DFA_NOTES.md deleted file mode 100644 index f22483ca28f..00000000000 --- a/encodings/fsst/src/dfa/DFA_NOTES.md +++ /dev/null @@ -1,123 +0,0 @@ -# DFA Refactoring Notes - -## Summary of changes (from 1229 → 1110 lines) - -Unified 5 DFA structs down to 3: - -| Before | After | What happened | -|--------|-------|---------------| -| `ShiftDfa` | (deleted) | Dead code — `FsstContainsDfa` only routed needles >14 to it, but `ShiftDfa::MAX_NEEDLE_LEN` was 14, so the arm was unreachable | -| `FsstContainsDfa` | (deleted) | Dispatch enum wrapping dead `ShiftDfa` arm; only the `FusedDfa` path was reachable | -| `FlatBranchlessDfa` | `FlatContainsDfa` | Merged with `FusedDfa` into single struct with `EscapeStrategy` enum | -| `FusedDfa` | `FlatContainsDfa` | Merged (see above) | -| `BranchlessShiftDfa` | `BranchlessShiftDfa` | Unchanged | -| `FlatPrefixDfa` | `FlatPrefixDfa` | Simplified escape transition building | - -Other changes: -- Extracted `build_escape_folded_table()` (shared by `BranchlessShiftDfa` and `FlatContainsDfa`) -- Extracted `compose_packed()` (shared by `build_pair_compose` and `build_compose_4b`) -- Extended `FlatContainsDfa` folded range from 14 → 127 (2*127+1=255 fits in u8) -- Simplified `FlatPrefixDfa` escape transitions (reuse byte_table directly) -- Deleted `pack_escape_shift_table` (only caller was `ShiftDfa`) - -## Removed code (recoverable from git) - -All removed code is in commit `e08fb69ad` (the starting point). Key pieces: - -### `ShiftDfa` (~70 lines) -Shift-packed `[u64; 256]` DFA using escape sentinel. Was identical in scan loop to -`BranchlessShiftDfa` but without the hierarchical 4-byte compose optimization. -Recovery: `git show e08fb69ad:encodings/fsst/src/dfa.rs` lines ~956-1027. - -### `pack_escape_shift_table` (~15 lines) -Built a separate shift-packed escape transition table. Only used by `ShiftDfa`. -Recovery: same commit, lines ~418-433. - -### `FsstContainsDfa` enum (~25 lines) -Dispatch enum: `ShiftDfa` for len ≤ 14, `FusedDfa` for len > 14. -Since caller guaranteed len > 14, the `ShiftDfa` arm was dead. -Recovery: same commit, lines ~592-615. - -## Benchmark results: escape strategy comparison - -Sentinel-only is 28-45% slower than folded for needles 8-14. -Both strategies must be kept in `FlatContainsDfa`. - -| Benchmark | Needle len | Folded (ms) | Sentinel (ms) | Regression | -|-----------|-----------|-------------|---------------|------------| -| contains/log | 9 | 5.449 | 7.480 | +37% | -| contains/json | 10 | 2.390 | 3.466 | +45% | -| contains/path | 14 | 0.937 | 1.199 | +28% | - -## Current benchmark baseline (post-refactor) - -``` -fsst_like fastest │ slowest │ median │ mean │ samples │ iters -├─ fsst_contains │ │ │ │ │ -│ ├─ cb 1.593 ms │ 2.122 ms │ 1.725 ms │ 1.745 ms │ 100 │ 100 -│ ├─ email 492.9 µs │ 697.7 µs │ 526.3 µs │ 544.2 µs │ 100 │ 100 -│ ├─ json 2.282 ms │ 2.731 ms │ 2.401 ms │ 2.406 ms │ 100 │ 100 -│ ├─ log 5.191 ms │ 5.919 ms │ 5.426 ms │ 5.439 ms │ 100 │ 100 -│ ├─ path 894.3 µs │ 1.076 ms │ 941.1 µs │ 952.8 µs │ 100 │ 100 -│ ├─ rare 1.674 ms │ 4.55 ms │ 1.814 ms │ 1.992 ms │ 100 │ 100 -│ ╰─ urls 736.8 µs │ 959.6 µs │ 837.1 µs │ 844.6 µs │ 100 │ 100 -╰─ fsst_prefix │ │ │ │ │ - ├─ cb 541.7 µs │ 761 µs │ 585.2 µs │ 598.1 µs │ 100 │ 100 - ├─ email 197.9 µs │ 305.8 µs │ 208.2 µs │ 214.6 µs │ 100 │ 100 - ├─ json 141.9 µs │ 352.6 µs │ 145.5 µs │ 151.8 µs │ 100 │ 100 - ├─ log 259.6 µs │ 378.1 µs │ 278.5 µs │ 285.3 µs │ 100 │ 100 - ├─ path 214.2 µs │ 281.1 µs │ 227.1 µs │ 230.9 µs │ 100 │ 100 - ├─ rare 153.7 µs │ 191.9 µs │ 157.1 µs │ 160.8 µs │ 100 │ 100 - ╰─ urls 260.7 µs │ 445.4 µs │ 294.2 µs │ 297.7 µs │ 100 │ 100 -``` - -DFA routing per benchmark: -- cb, email, rare, urls (needle ≤ 7) → `BranchlessShiftDfa` -- log (9), json (10), path (14) → `FlatContainsDfa` (folded) -- No benchmark exercises sentinel path (would need needle > 127) - -## Post integer-type cleanup benchmarks - -After eliminating `u16`, tightening `usize` → `u8` in `compose_packed`, `pack_shift_table`, -`kmp_failure_table`, and `kmp_byte_transitions`. All within noise of baseline. - -| Benchmark | Baseline (ms) | Current (ms) | Delta | -|-----------|--------------|-------------|-------| -| contains/cb | 1.725 | 1.695 | -1.7% | -| contains/email | 0.526 | 0.542 | +2.9% | -| contains/json | 2.401 | 2.452 | +2.1% | -| contains/log | 5.426 | 5.447 | +0.4% | -| contains/path | 0.941 | 0.949 | +0.8% | -| contains/rare | 1.814 | 1.762 | -2.9% | -| contains/urls | 0.837 | 0.812 | -3.0% | -| prefix/cb | 0.585 | 0.568 | -3.0% | -| prefix/email | 0.208 | 0.215 | +3.0% | -| prefix/json | 0.146 | 0.145 | -0.2% | -| prefix/log | 0.279 | 0.270 | -3.1% | -| prefix/path | 0.227 | 0.224 | -1.2% | -| prefix/rare | 0.157 | 0.159 | +1.1% | -| prefix/urls | 0.294 | 0.288 | -2.1% | - -## Optimization ideas for later - -### 1. 8-byte-per-iter BranchlessShiftDfa -Extend `BranchlessShiftDfa` to process 8 bytes/iteration via two 4-byte composes. -Would reduce loop overhead for long compressed strings. Tables stay the same size, -just add a `compose_8b` level on top of `compose_4b`. - -### 2. Branchless prefix DFA -`FlatPrefixDfa` currently uses escape sentinel + branch. Could use escape-folding -(like the contains DFAs) to make the prefix scan branchless. Needs 2*prefix_len+1 -states, so max prefix would drop. Worth it if prefix matching is a bottleneck. - -### 3. Further struct merging -`BranchlessShiftDfa` and `FlatContainsDfa` (folded) share the same escape-folded -state layout. They differ only in table representation (shift-packed u64 vs flat u8). -Could theoretically be merged, but the hierarchical 4-byte compose in -`BranchlessShiftDfa` is fundamentally different from the flat scan, so sharing code -wouldn't simplify much. - -### 4. Suffix pushdown (`%suffix`) -Two approaches noted in the module doc: -- Forward DFA with non-sticky accept (check state == accept after all codes) -- Backward scan of compressed stream diff --git a/encodings/fsst/src/dfa/branchless_shift.rs b/encodings/fsst/src/dfa/branchless_shift.rs deleted file mode 100644 index 2facc49d8fd..00000000000 --- a/encodings/fsst/src/dfa/branchless_shift.rs +++ /dev/null @@ -1,227 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Branchless shift-packed DFA for short contains matching (`LIKE '%needle%'`, needle ≤ 7). - -use fsst::Symbol; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; - -use super::build_escape_folded_table; -use super::compose_packed; -use super::pack_shift_table; -use super::shift_extract; - -/// Branchless escape-folded DFA for short needles (len <= 7). -/// -/// Folds escape handling into the state space so that `matches()` is -/// completely branchless (except for loop control). The state layout is: -/// - States 0..N-1: normal match-progress states -/// - State N: accept (sticky for all inputs) -/// - States N+1..2N: escape states (state `s+N+1` means "was in state `s`, -/// just consumed ESCAPE_CODE") -/// -/// Total states: 2N+1. With 4-bit packing, max N=7. -/// -/// Uses a decomposed hierarchical lookup that processes 4 code bytes per -/// loop iteration with only ~3 KB of tables: -/// -/// 1. **Equivalence class table** (256 B): maps each code byte to a class -/// id. Bytes with identical transition u64s share a class -- typically -/// only ~6-10 classes exist (needle chars + escape + "miss-all"). -/// 2. **Pair-compose table** (~N^2 B): maps `(class0, class1)` to a 2-byte -/// palette index. Typically ~36 entries. -/// 3. **4-byte compose table** (~M^2 x 8 B): maps `(palette0, palette1)` to -/// the composed packed u64 for all 4 bytes. Typically ~81 entries = 648 B. -/// -/// Each loop iteration: 4 class lookups (parallel, 256 B table) -> 2 -/// pair-compose lookups (parallel, ~36 B table) -> 1 compose lookup -/// (~648 B table) -> 1 shift+mask. All tables fit in L1 cache. -pub(crate) struct BranchlessShiftDfa { - /// Maps each code byte to its equivalence class. Bytes with the same - /// packed transition u64 share a class. (256 bytes) - eq_class: [u8; 256], - /// Maps `(class0 * n_classes + class1)` -> 2-byte palette index. - pair_compose: Vec, - /// Number of equivalence classes (stride for pair_compose). - n_classes: usize, - /// Maps `(palette0 * n_palette + palette1)` -> composed packed u64 - /// for 4 bytes. - compose_4b: Vec, - /// Number of unique 2-byte palette entries (stride for compose_4b). - n_palette: usize, - /// 1-byte fallback transitions for trailing bytes. - transitions_1b: [u64; 256], - /// 2-byte palette for the remainder path (2-3 trailing bytes). - palette_2b: Vec, - accept_state: u8, -} - -impl BranchlessShiftDfa { - const BITS: u32 = 4; - /// Maximum needle length: need 2N+1 states to fit in 16 slots (4 bits). - /// 2*7+1 = 15 <= 16, so max N = 7. - pub(crate) const MAX_NEEDLE_LEN: usize = 7; - - pub(crate) fn new( - symbols: &[Symbol], - symbol_lengths: &[u8], - needle: &[u8], - ) -> VortexResult { - let n = needle.len(); - if n > Self::MAX_NEEDLE_LEN { - vortex_bail!( - "needle length {} exceeds maximum {} for branchless shift DFA", - n, - Self::MAX_NEEDLE_LEN - ); - } - - #[expect(clippy::cast_possible_truncation, reason = "n ≤ MAX_NEEDLE_LEN (7)")] - let accept_state = n as u8; - let total_states = 2 * accept_state + 1; - - let fused = build_escape_folded_table(symbols, symbol_lengths, needle); - let transitions_1b = pack_shift_table(&fused, total_states, Self::BITS); - - // Build equivalence classes: group bytes with identical transition u64. - let mut eq_class = [0u8; 256]; - let mut class_representatives: Vec = Vec::new(); - for byte_val in 0..256usize { - let t = transitions_1b[byte_val]; - let cls = class_representatives - .iter() - .position(|&v| v == t) - .unwrap_or_else(|| { - class_representatives.push(t); - class_representatives.len() - 1 - }); - #[expect(clippy::cast_possible_truncation, reason = "≤ 256 equivalence classes")] - { - eq_class[byte_val] = cls as u8; - } - } - let n_classes = class_representatives.len(); - - // Build pair-compose: for each (class0, class1), compose the two - // 1-byte transitions and deduplicate into a 2-byte palette. - let (pair_compose, palette_2b) = - Self::build_pair_compose(&class_representatives, n_classes, total_states); - - // Build 4-byte composition: compose_4b[p0 * n + p1] gives the packed - // u64 for applying palette_2b[p0] then palette_2b[p1] in sequence. - let n_palette = palette_2b.len(); - let compose_4b = Self::build_compose_4b(&palette_2b, total_states); - - Ok(Self { - eq_class, - pair_compose, - n_classes, - compose_4b, - n_palette, - transitions_1b, - palette_2b, - accept_state, - }) - } - - /// Build the pair-compose table and 2-byte palette from equivalence - /// class representatives. - fn build_pair_compose( - class_reps: &[u64], - n_classes: usize, - total_states: u8, - ) -> (Vec, Vec) { - let mut pair_compose = vec![0u8; n_classes * n_classes]; - let mut palette_2b: Vec = Vec::new(); - - for c0 in 0..n_classes { - for c1 in 0..n_classes { - let packed = - compose_packed(class_reps[c0], class_reps[c1], total_states, Self::BITS); - let idx = palette_2b - .iter() - .position(|&v| v == packed) - .unwrap_or_else(|| { - palette_2b.push(packed); - palette_2b.len() - 1 - }); - #[expect( - clippy::cast_possible_truncation, - reason = "palette size ≤ n_classes² ≤ 256" - )] - { - pair_compose[c0 * n_classes + c1] = idx as u8; - } - } - } - (pair_compose, palette_2b) - } - - /// Compose pairs of 2-byte palette entries into a 4-byte lookup table. - fn build_compose_4b(palette_2b: &[u64], total_states: u8) -> Vec { - let n = palette_2b.len(); - let mut compose = vec![0u64; n * n]; - for p0 in 0..n { - for p1 in 0..n { - compose[p0 * n + p1] = - compose_packed(palette_2b[p0], palette_2b[p1], total_states, Self::BITS); - } - } - compose - } - - /// Process remaining bytes after the interleaved common prefix. - #[inline] - fn finish_tail(&self, mut state: u8, codes: &[u8]) -> u8 { - let chunks = codes.chunks_exact(4); - let rem = chunks.remainder(); - - for chunk in chunks { - // SAFETY: chunk[i] is u8, eq_class has 256 entries — index always in bounds. - let ec0 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[0])) }; - let ec1 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[1])) }; - let ec2 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[2])) }; - let ec3 = unsafe { *self.eq_class.get_unchecked(usize::from(chunk[3])) }; - let p0 = unsafe { - *self - .pair_compose - .get_unchecked(usize::from(ec0) * self.n_classes + usize::from(ec1)) - }; - let p1 = unsafe { - *self - .pair_compose - .get_unchecked(usize::from(ec2) * self.n_classes + usize::from(ec3)) - }; - let packed = unsafe { - *self - .compose_4b - .get_unchecked(usize::from(p0) * self.n_palette + usize::from(p1)) - }; - state = shift_extract(packed, state, Self::BITS); - } - - if rem.len() >= 2 { - let ec0 = self.eq_class[usize::from(rem[0])]; - let ec1 = self.eq_class[usize::from(rem[1])]; - let p = self.pair_compose[usize::from(ec0) * self.n_classes + usize::from(ec1)]; - let packed = self.palette_2b[usize::from(p)]; - state = shift_extract(packed, state, Self::BITS); - if rem.len() == 3 { - let packed = self.transitions_1b[usize::from(rem[2])]; - state = shift_extract(packed, state, Self::BITS); - } - } else if rem.len() == 1 { - let packed = self.transitions_1b[usize::from(rem[0])]; - state = shift_extract(packed, state, Self::BITS); - } - - state - } - - /// Branchless matching processing four code bytes per iteration. - #[inline(never)] - pub(crate) fn matches(&self, codes: &[u8]) -> bool { - self.finish_tail(0, codes) == self.accept_state - } -} diff --git a/encodings/fsst/src/dfa/flat_contains.rs b/encodings/fsst/src/dfa/flat_contains.rs index f6227c8f31f..64ce71117fa 100644 --- a/encodings/fsst/src/dfa/flat_contains.rs +++ b/encodings/fsst/src/dfa/flat_contains.rs @@ -1,48 +1,94 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Flat `u8` transition table DFA for contains matching (`LIKE '%needle%'`, needle 8-254). +//! Flat `u8` transition table DFA for contains matching (`LIKE '%needle%'`). +//! +//! Uses an escape-sentinel strategy: the FSST escape code maps to a sentinel +//! state, and the next literal byte is looked up in a separate byte-level +//! transition table. +//! +//! ## Construction (needle = `"aba"`, symbols = `[0:"ab", 1:"ba"]`) +//! +//! ### Step 1: KMP byte-level transition table +//! +//! Build a `(state × byte) → state` table using the KMP failure function. +//! States 0..2 track match progress, state 3 is accept (sticky). +//! +//! ```text +//! Input byte +//! State 'a' 'b' other +//! ───── ──── ──── ───── +//! 0 1 0 0 ← want 'a' +//! 1 1 2 0 ← matched "a", want 'b' (KMP: 'a'→stay at 1) +//! 2 3✓ 0 0 ← matched "ab", want 'a' +//! 3✓ 3✓ 3✓ 3✓ ← accept (sticky) +//! ``` +//! +//! ### Step 2: Symbol-level transitions +//! +//! For each `(state, symbol)` pair, simulate feeding the symbol's bytes +//! through the byte table: +//! +//! ```text +//! Symbol 0 = "ab" (2 bytes): +//! state 0 + 'a' → 1, + 'b' → 2 ⟹ sym_trans[0][0] = 2 +//! state 1 + 'a' → 1, + 'b' → 2 ⟹ sym_trans[1][0] = 2 +//! state 2 + 'a' → 3✓ ⟹ sym_trans[2][0] = 3✓ (accept) +//! +//! Symbol 1 = "ba" (2 bytes): +//! state 0 + 'b' → 0, + 'a' → 1 ⟹ sym_trans[0][1] = 1 +//! state 1 + 'b' → 2, + 'a' → 3✓ ⟹ sym_trans[1][1] = 3✓ (accept) +//! state 2 + 'b' → 0, + 'a' → 1 ⟹ sym_trans[2][1] = 1 +//! ``` +//! +//! ### Step 3: Fused 256-wide table with escape sentinel +//! +//! Merge symbol transitions into a 256-wide table. Code bytes 0–1 use symbol +//! transitions, code 255 (ESCAPE_CODE) maps to the sentinel (4), and +//! unused code bytes default to 0: +//! +//! ```text +//! Code byte +//! State 0("ab") 1("ba") 2..254 255(ESC) +//! ───── ─────── ─────── ────── ──────── +//! 0 2 1 0 4(S) +//! 1 2 3✓ 0 4(S) +//! 2 3✓ 1 0 4(S) +//! 3✓ 3✓ 3✓ 3✓ 3✓ +//! ``` +//! +//! When the scanner sees sentinel (4), it reads the next byte and looks it +//! up in the byte-level escape table (from step 1). +//! +//! TODO(joe): for short needles (≤7 bytes), a branchless escape-folded DFA +//! with hierarchical 4-byte composition is ~2x faster. For needles ≤127 bytes, +//! an escape-folded flat DFA (2N+1 states) avoids the sentinel branch. +//! See commit 7faf9f36f for those implementations. use fsst::Symbol; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use super::build_escape_folded_table; use super::build_fused_table; use super::build_symbol_transitions; use super::kmp_byte_transitions; -/// Flat `u8` transition table DFA for contains matching (needles 8-254 bytes). +/// Flat `u8` transition table DFA for contains matching. /// -/// Uses two escape strategies depending on needle length: -/// - **Escape-folded** (needle ≤ 127): escape handling is folded into the state -/// space (2N+1 states), making the scan loop branchless. -/// - **Escape sentinel** (needle 128-254): escape code maps to a sentinel state -/// with a separate byte-level escape table. Required because 2N+1 > 255 won't -/// fit in `u8`. +/// The escape code maps to a sentinel state; the next literal byte is looked +/// up in a separate byte-level escape table. pub(crate) struct FlatContainsDfa { /// `transitions[state * 256 + byte]` -> next state. transitions: Vec, + /// `escape_transitions[state * 256 + byte]` -> next state for escaped bytes. + escape_transitions: Vec, accept_state: u8, - escape: EscapeStrategy, -} - -/// How the flat DFA handles the FSST escape code. -enum EscapeStrategy { - /// Escape states folded into the transition table (branchless scan). - Folded, - /// Escape code maps to a sentinel; next byte uses a separate table. - Sentinel { - escape_transitions: Vec, - sentinel: u8, - }, + sentinel: u8, } impl FlatContainsDfa { - /// Maximum needle for escape-folded mode: 2N+1 ≤ 255, so N ≤ 127. - const MAX_FOLDED_LEN: usize = 127; - /// Maximum needle overall: need accept + sentinel to fit in u8. + /// Maximum needle length: need accept + sentinel to fit in u8. pub(crate) const MAX_NEEDLE_LEN: usize = u8::MAX as usize - 1; pub(crate) fn new( @@ -60,95 +106,41 @@ impl FlatContainsDfa { let accept_state = u8::try_from(needle.len()) .vortex_expect("FlatContainsDfa: accept state must fit into u8"); + let n_states = accept_state + 1; + let sentinel = n_states; - if needle.len() <= Self::MAX_FOLDED_LEN { - let transitions = build_escape_folded_table(symbols, symbol_lengths, needle); - Ok(Self { - transitions, - accept_state, - escape: EscapeStrategy::Folded, - }) - } else { - let n_states = accept_state + 1; - let sentinel = n_states; - - let byte_table = kmp_byte_transitions(needle); - let sym_trans = build_symbol_transitions( - symbols, - symbol_lengths, - &byte_table, - n_states, - accept_state, - ); - let transitions = - build_fused_table(&sym_trans, symbols.len(), n_states, |_| sentinel, 0); + let byte_table = kmp_byte_transitions(needle); + let sym_trans = + build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); + let transitions = build_fused_table(&sym_trans, symbols.len(), n_states, |_| sentinel, 0); - let escape_transitions = byte_table; - - Ok(Self { - transitions, - accept_state, - escape: EscapeStrategy::Sentinel { - escape_transitions, - sentinel, - }, - }) - } + Ok(Self { + transitions, + escape_transitions: byte_table, + accept_state, + sentinel, + }) } #[inline(never)] pub(crate) fn matches(&self, codes: &[u8]) -> bool { - match &self.escape { - EscapeStrategy::Folded => self.matches_folded(codes), - EscapeStrategy::Sentinel { - escape_transitions, - sentinel, - } => Self::matches_sentinel( - codes, - &self.transitions, - escape_transitions, - self.accept_state, - *sentinel, - ), - } - } - - /// Branchless scan: escape handling is folded into the state space. - #[inline(always)] - fn matches_folded(&self, codes: &[u8]) -> bool { - let mut state = 0u8; - for &byte in codes { - state = self.transitions[usize::from(state) * 256 + usize::from(byte)]; - } - state == self.accept_state - } - - /// Sentinel scan: escape code triggers a separate table lookup. - #[inline(always)] - fn matches_sentinel( - codes: &[u8], - transitions: &[u8], - escape_transitions: &[u8], - accept_state: u8, - sentinel: u8, - ) -> bool { let mut state = 0u8; let mut pos = 0; while pos < codes.len() { let code = codes[pos]; pos += 1; - let next = transitions[usize::from(state) * 256 + usize::from(code)]; - if next == sentinel { + let next = self.transitions[usize::from(state) * 256 + usize::from(code)]; + if next == self.sentinel { if pos >= codes.len() { return false; } let b = codes[pos]; pos += 1; - state = escape_transitions[usize::from(state) * 256 + usize::from(b)]; + state = self.escape_transitions[usize::from(state) * 256 + usize::from(b)]; } else { state = next; } - if state == accept_state { + if state == self.accept_state { return true; } } diff --git a/encodings/fsst/src/dfa/mod.rs b/encodings/fsst/src/dfa/mod.rs index 49c710d6430..f4b73959d16 100644 --- a/encodings/fsst/src/dfa/mod.rs +++ b/encodings/fsst/src/dfa/mod.rs @@ -35,7 +35,7 @@ //! A single symbol can expand to 1–8 bytes. Matching on compressed codes requires //! the DFA to handle multi-byte symbol expansions and the escape mechanism. //! -//! ## The Algorithm: KMP → Byte Table → Symbol Table → Packed DFA +//! ## The Algorithm: KMP → Byte Table → Symbol Table → Flat DFA //! //! Construction proceeds through four stages: //! @@ -94,17 +94,17 @@ //! symbol transition; for code byte 255 (ESCAPE_CODE), transition to a //! special sentinel that tells the scanner to read the next literal byte. //! -//! ### Stage 4: Packing into the Final Representation +//! ### Stage 4: Flat `u8` Table //! -//! The fused table can be stored in different layouts depending on the number -//! of states: +//! The fused table is stored as a flat `Vec` indexed as +//! `transitions[state * 256 + byte]`. Both the prefix and contains DFAs use +//! escape-sentinel handling: when the scanner sees the sentinel value, it reads +//! the next byte from a separate byte-level escape table. //! -//! - **Shift-packed `u64`** (≤16 states): Each state needs 4 bits. All state -//! transitions for one input byte fit in a single `u64`. Lookup: -//! `next = (table[byte] >> (state * 4)) & 0xF`. One cache line per lookup. -//! -//! - **Flat `u8` table** (≤255 states): `transitions[state * 256 + byte]`. -//! Larger, but still bounded by the `u8` state representation. +//! TODO(joe): for short contains needles (≤7 bytes), a branchless escape-folded +//! DFA with hierarchical 4-byte composition is ~2x faster. For needles ≤127 +//! bytes, an escape-folded flat DFA (2N+1 states) avoids the sentinel branch. +//! See commit 7faf9f36f for those implementations. //! //! ## State-Space Limits //! @@ -114,67 +114,24 @@ //! - `prefix%` pushdown is limited to **253 bytes**. The flat prefix DFA uses //! `u8` state ids and needs room for progress states, an accept state, a //! fail state, and one escape sentinel (N+3 ≤ 256). -//! - `%needle%` pushdown is limited to **254 bytes**. The long-needle DFA stores +//! - `%needle%` pushdown is limited to **254 bytes**. The contains DFA stores //! states in `u8`, so it needs room for every match-progress state plus both //! the accept state and the escape sentinel. //! //! Patterns beyond those limits are still valid LIKE patterns; they simply do //! not use FSST pushdown and must be evaluated through the fallback path. -//! -//! ## DFA Variants and When Each Is Used -//! -//! ```text -//! ┌───────────────┬──────────────────────────────────────────────────────┐ -//! │ Pattern │ Needle length → DFA variant │ -//! ├───────────────┼──────────────────────────────────────────────────────┤ -//! │ prefix% │ 0–253 → FlatPrefixDfa (flat u8, esc-sentinel) │ -//! ├───────────────┼──────────────────────────────────────────────────────┤ -//! │ %needle% │ 1–7 → BranchlessShiftDfa (hierarchical 4-byte) │ -//! │ │ 8–127 → FlatContainsDfa (flat u8, esc-folded) │ -//! │ │ 128–254 → FlatContainsDfa (flat u8, esc-sentinel) │ -//! └───────────────┴──────────────────────────────────────────────────────┘ -//! ``` -//! -//! ## Escape Handling Strategies -//! -//! There are two ways to handle the FSST escape code in the DFA: -//! -//! **Escape sentinel** (used by `FlatContainsDfa` for long needles, `FlatPrefixDfa`): -//! The escape code maps to a sentinel state. The scanner checks for it and -//! reads the next byte from a separate escape transition table. -//! -//! ```text -//! loop: -//! state = transitions[byte] // might be sentinel -//! if state == SENTINEL: -//! state = escape_transitions[next_byte] // branch -//! ``` -//! -//! **Escape folding** (used by `BranchlessShiftDfa`, `FlatContainsDfa` for short needles): -//! Escape states are folded into the state space. State `s+N+1` means "was in -//! state `s`, just consumed ESCAPE_CODE". The next byte's transition from an -//! escape state uses the byte-level table. No branch needed in the scanner. -//! -//! ```text -//! States: [0..N-1: normal] [N: accept] [N+1..2N: escape shadows] -//! Total: 2N+1 states. With 4-bit packing, max N=7. -//! -//! loop: -//! state = transitions[state][byte] // branchless! -//! ``` -mod branchless_shift; mod flat_contains; mod prefix; #[cfg(test)] mod tests; -use branchless_shift::BranchlessShiftDfa; use flat_contains::FlatContainsDfa; use fsst::ESCAPE_CODE; use fsst::Symbol; use prefix::FlatPrefixDfa; use vortex_buffer::BitBuffer; +use vortex_error::VortexExpect; use vortex_error::VortexResult; // --------------------------------------------------------------------------- @@ -194,8 +151,7 @@ pub(crate) struct FsstMatcher { enum MatcherInner { MatchAll, Prefix(FlatPrefixDfa), - ContainsBranchless(Box), - ContainsFlat(FlatContainsDfa), + Contains(FlatContainsDfa), } impl FsstMatcher { @@ -227,19 +183,7 @@ impl FsstMatcher { if needle.len() > FlatContainsDfa::MAX_NEEDLE_LEN { return Ok(None); } - if needle.len() <= BranchlessShiftDfa::MAX_NEEDLE_LEN { - MatcherInner::ContainsBranchless(Box::new(BranchlessShiftDfa::new( - symbols, - symbol_lengths, - needle, - )?)) - } else { - MatcherInner::ContainsFlat(FlatContainsDfa::new( - symbols, - symbol_lengths, - needle, - )?) - } + MatcherInner::Contains(FlatContainsDfa::new(symbols, symbol_lengths, needle)?) } }; @@ -252,8 +196,7 @@ impl FsstMatcher { match &self.inner { MatcherInner::MatchAll => true, MatcherInner::Prefix(dfa) => dfa.matches(codes), - MatcherInner::ContainsBranchless(dfa) => dfa.matches(codes), - MatcherInner::ContainsFlat(dfa) => dfa.matches(codes), + MatcherInner::Contains(dfa) => dfa.matches(codes), } } } @@ -311,36 +254,6 @@ where }) } -// --------------------------------------------------------------------------- -// Shared helpers — used by multiple DFA implementations -// --------------------------------------------------------------------------- - -/// Extract a state id from a shift-packed `u64` word. -/// -/// Each state occupies `bits` bits. The mask `(1 << bits) - 1` guarantees the -/// result is at most 15 (for `bits = 4`), which always fits in `u8`. -#[expect( - clippy::cast_possible_truncation, - reason = "masked to `bits` bits (≤4), result ≤ 15" -)] -#[inline(always)] -fn shift_extract(packed: u64, state: u8, bits: u32) -> u8 { - let mask = (1u64 << bits) - 1; - ((packed >> (u32::from(state) * bits)) & mask) as u8 -} - -/// Compose two shift-packed transition `u64`s: for each state, apply `first` -/// then `second`, packing the result back into a single `u64`. -fn compose_packed(first: u64, second: u64, total_states: u8, bits: u32) -> u64 { - let mut packed = 0u64; - for state in 0..total_states { - let mid = shift_extract(first, state, bits); - let final_s = shift_extract(second, mid, bits); - packed |= u64::from(final_s) << (u32::from(state) * bits); - } - packed -} - // --------------------------------------------------------------------------- // DFA construction helpers // --------------------------------------------------------------------------- @@ -358,29 +271,24 @@ fn build_symbol_transitions( n_states: u8, accept_state: u8, ) -> Vec { - let n_states = usize::from(n_states); let n_symbols = symbols.len(); - let mut sym_trans = vec![0u8; n_states * n_symbols]; + let mut sym_trans = vec![0u8; n_states as usize * n_symbols]; for state in 0..n_states { for code in 0..n_symbols { - if state == usize::from(accept_state) { - sym_trans[state * n_symbols + code] = accept_state; + if state == accept_state { + sym_trans[state as usize * n_symbols + code] = accept_state; continue; } let sym = symbols[code].to_u64().to_le_bytes(); let sym_len = usize::from(symbol_lengths[code]); - #[expect( - clippy::cast_possible_truncation, - reason = "state < n_states ≤ 256" - )] - let mut s = state as u8; + let mut s = state; for &b in &sym[..sym_len] { if s == accept_state { break; } - s = byte_table[usize::from(s) * 256 + usize::from(b)]; + s = byte_table[s as usize * 256 + b as usize]; } - sym_trans[state * n_symbols + code] = s; + sym_trans[state as usize * n_symbols + code] = s; } } sym_trans @@ -412,95 +320,24 @@ fn build_fused_table( fused } -/// Packs a fused table into shift-encoded `u64` arrays. -/// -/// Each `u64` encodes transitions for ALL states for one input byte. -/// Lookup: `next = (table[byte] >> (state * BITS)) & MASK`. -fn pack_shift_table(fused: &[u8], n_states: u8, bits: u32) -> [u64; 256] { - let mut packed = [0u64; 256]; - for code_byte in 0..256usize { - let mut val = 0u64; - for state in 0..n_states { - val |= - u64::from(fused[usize::from(state) * 256 + code_byte]) << (u32::from(state) * bits); - } - packed[code_byte] = val; - } - packed -} - -/// Builds an escape-folded fused transition table for contains matching. -/// -/// State layout: `[0..n-1]` match progress, `[n]` accept (sticky), `[n+1..2n]` escape shadows. -/// Total states: `2 * needle.len() + 1`. -/// -/// For normal states, the escape code maps to the corresponding escape shadow state. -/// Escape shadow states use byte-level KMP transitions so the next literal byte -/// resumes matching correctly — no branch needed in the scanner. -fn build_escape_folded_table(symbols: &[Symbol], symbol_lengths: &[u8], needle: &[u8]) -> Vec { - #[expect( - clippy::cast_possible_truncation, - reason = "needle.len() ≤ FlatContainsDfa::MAX_FOLDED_LEN (127)" - )] - let n = needle.len() as u8; - let accept_state = n; - let total_states = usize::from(2 * n + 1); - - let byte_table = kmp_byte_transitions(needle); - let sym_trans = - build_symbol_transitions(symbols, symbol_lengths, &byte_table, n + 1, accept_state); - - let n_symbols = symbols.len(); - let n_usize = usize::from(n); - let mut fused = vec![0u8; total_states * 256]; - for code_byte in 0..256usize { - // Normal states 0..n - for s in 0..n_usize { - if code_byte == usize::from(ESCAPE_CODE) { - #[expect(clippy::cast_possible_truncation, reason = "s + n + 1 ≤ 2*127+1 = 255")] - { - fused[s * 256 + code_byte] = (s + n_usize + 1) as u8; - } - } else if code_byte < n_symbols { - fused[s * 256 + code_byte] = sym_trans[s * n_symbols + code_byte]; - } - } - // Accept state (sticky) - fused[n_usize * 256 + code_byte] = accept_state; - // Escape shadow states n+1..2n - for s in 0..n_usize { - let esc_state = s + n_usize + 1; - fused[esc_state * 256 + code_byte] = byte_table[s * 256 + code_byte]; - } - } - fused -} - // --------------------------------------------------------------------------- // KMP helpers // --------------------------------------------------------------------------- fn kmp_byte_transitions(needle: &[u8]) -> Vec { - let n_states = needle.len() + 1; - #[expect( - clippy::cast_possible_truncation, - reason = "needle.len() ≤ 254, accept state fits in u8" - )] - let accept = needle.len() as u8; + let n_states = u8::try_from(needle.len() + 1) + .vortex_expect("kmp_byte_transitions: must have needle.len() ≤ 255"); + let accept = n_states - 1; let failure = kmp_failure_table(needle); - let mut table = vec![0u8; n_states * 256]; + let mut table = vec![0u8; n_states as usize * 256]; for state in 0..n_states { for byte in 0..256usize { - if state == needle.len() { - table[state * 256 + byte] = accept; + if state == accept { + table[state as usize * 256 + byte] = accept; continue; } - #[expect( - clippy::cast_possible_truncation, - reason = "state < needle.len() ≤ 254" - )] - let mut s = state as u8; + let mut s = state; loop { if byte == usize::from(needle[usize::from(s)]) { s += 1; @@ -511,7 +348,7 @@ fn kmp_byte_transitions(needle: &[u8]) -> Vec { } s = failure[usize::from(s) - 1]; } - table[state * 256 + byte] = s; + table[state as usize * 256 + byte] = s; } } table diff --git a/encodings/fsst/src/dfa/prefix.rs b/encodings/fsst/src/dfa/prefix.rs index c30da87f3ba..20a07c2aaa3 100644 --- a/encodings/fsst/src/dfa/prefix.rs +++ b/encodings/fsst/src/dfa/prefix.rs @@ -20,34 +20,6 @@ use vortex_error::vortex_bail; use super::build_fused_table; use super::build_symbol_transitions; -/// Build a byte-level transition table for prefix matching (no KMP fallback). -/// -/// For each state, only the correct next byte advances; everything else goes -/// to the fail state. -fn build_prefix_byte_table(prefix: &[u8], accept_state: u8, fail_state: u8) -> Vec { - let n_states = fail_state + 1; - let mut table = vec![fail_state; usize::from(n_states) * 256]; - - for state in 0..n_states { - let s = usize::from(state); - if state == accept_state { - for byte in 0..256 { - table[s * 256 + byte] = accept_state; - } - } else if state != fail_state { - // Only the correct next byte advances; everything else fails. - let next_byte = prefix[s]; - let next_state = if s + 1 >= prefix.len() { - accept_state - } else { - state + 1 - }; - table[s * 256 + usize::from(next_byte)] = next_state; - } - } - table -} - /// Flat `u8` transition table DFA for prefix matching on FSST codes. /// /// States 0..prefix_len track match progress, plus ACCEPT, FAIL, and an @@ -80,7 +52,7 @@ pub(crate) struct FlatPrefixDfa { } impl FlatPrefixDfa { - pub(crate) const MAX_PREFIX_LEN: usize = 253; + pub(crate) const MAX_PREFIX_LEN: usize = (u8::MAX - 2) as usize; pub(crate) fn new( symbols: &[Symbol], @@ -153,3 +125,31 @@ impl FlatPrefixDfa { state == self.accept_state } } + +/// Build a byte-level transition table for prefix matching (no KMP fallback). +/// +/// For each state, only the correct next byte advances; everything else goes +/// to the fail state. +fn build_prefix_byte_table(prefix: &[u8], accept_state: u8, fail_state: u8) -> Vec { + let n_states = fail_state + 1; + let mut table = vec![fail_state; usize::from(n_states) * 256]; + + for state in 0..n_states { + let s = usize::from(state); + if state == accept_state { + for byte in 0..256 { + table[s * 256 + byte] = accept_state; + } + } else if state != fail_state { + // Only the correct next byte advances; everything else fails. + let next_byte = prefix[s]; + let next_state = if s + 1 >= prefix.len() { + accept_state + } else { + state + 1 + }; + table[s * 256 + usize::from(next_byte)] = next_state; + } + } + table +} From 9d278dc5e5d04a120e74969de81d2a9652f5d44f Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 14:09:10 +0000 Subject: [PATCH 10/19] fixup Signed-off-by: Joe Isaacs --- encodings/fsst/src/dfa/flat_contains.rs | 2 +- encodings/fsst/src/dfa/prefix.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/encodings/fsst/src/dfa/flat_contains.rs b/encodings/fsst/src/dfa/flat_contains.rs index 64ce71117fa..b2b62d1d0a5 100644 --- a/encodings/fsst/src/dfa/flat_contains.rs +++ b/encodings/fsst/src/dfa/flat_contains.rs @@ -6,6 +6,7 @@ //! Uses an escape-sentinel strategy: the FSST escape code maps to a sentinel //! state, and the next literal byte is looked up in a separate byte-level //! transition table. +//! This is to support needles up to u8::MAX long. //! //! ## Construction (needle = `"aba"`, symbols = `[0:"ab", 1:"ba"]`) //! @@ -122,7 +123,6 @@ impl FlatContainsDfa { }) } - #[inline(never)] pub(crate) fn matches(&self, codes: &[u8]) -> bool { let mut state = 0u8; let mut pos = 0; diff --git a/encodings/fsst/src/dfa/prefix.rs b/encodings/fsst/src/dfa/prefix.rs index 20a07c2aaa3..92667aba928 100644 --- a/encodings/fsst/src/dfa/prefix.rs +++ b/encodings/fsst/src/dfa/prefix.rs @@ -9,7 +9,7 @@ //! TODO(joe): for short prefixes (≤13 bytes), a shift-packed `[u64; 256]` //! representation would be simpler and easier to read — all state transitions //! for one input byte fit in a single `u64`. Benchmarks showed no meaningful -//! perf difference (see `benches/BENCH_RESULTS.md`), so we use flat-only for +//! perf difference, so we use flat-only for //! now to keep the code simple and support long prefixes. use fsst::Symbol; @@ -97,7 +97,6 @@ impl FlatPrefixDfa { }) } - #[inline] pub(crate) fn matches(&self, codes: &[u8]) -> bool { let mut state = 0u8; let mut pos = 0; From 6c8e616c2d66173c3a5ca7dfeaedeebe0c868037 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 14:19:09 +0000 Subject: [PATCH 11/19] fixup Signed-off-by: Joe Isaacs --- fuzz/src/fsst_like.rs | 44 +++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/fuzz/src/fsst_like.rs b/fuzz/src/fsst_like.rs index 866b078eae0..8ee332e6c9f 100644 --- a/fuzz/src/fsst_like.rs +++ b/fuzz/src/fsst_like.rs @@ -33,6 +33,21 @@ use crate::error::VortexFuzzResult; static SESSION: LazyLock = LazyLock::new(|| VortexSession::empty().with::()); +/// A random string from a small alphabet (`a..=h`) with bounded length. +#[derive(Debug)] +struct SmallAlphabetString { + max_len: usize, +} + +impl SmallAlphabetString { + fn generate(&self, u: &mut Unstructured<'_>) -> arbitrary::Result { + let len: usize = u.int_in_range(0..=self.max_len)?; + Ok((0..len) + .map(|_| u.int_in_range(b'a'..=b'h').expect("cannot make char") as char) + .collect()) + } +} + /// Fuzz input: a set of strings and a LIKE pattern. #[derive(Debug)] pub struct FuzzFsstLike { @@ -43,34 +58,19 @@ pub struct FuzzFsstLike { impl<'a> Arbitrary<'a> for FuzzFsstLike { fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { - // Generate 1-200 strings, each 0-100 bytes from a small alphabet - // to increase FSST symbol reuse and substring hits. let n_strings: usize = u.int_in_range(1..=200)?; - let mut strings = Vec::with_capacity(n_strings); - for _ in 0..n_strings { - let len: usize = u.int_in_range(0..=100)?; - let s: String = (0..len) - .map(|_| { - let b = u.int_in_range(b'a'..=b'h').unwrap_or(b'a'); - b as char - }) - .collect(); - strings.push(s); - } + let str_gen = SmallAlphabetString { max_len: 512 }; + let strings: Vec = (0..n_strings) + .map(|_| str_gen.generate(u)) + .collect::>()?; - // Generate a pattern: pick a shape then fill in the literal part. - let needle_len: usize = u.int_in_range(0..=30)?; - let needle: String = (0..needle_len) - .map(|_| { - let b = u.int_in_range(b'a'..=b'h').unwrap_or(b'a'); - b as char - }) - .collect(); + let needle = SmallAlphabetString { max_len: 254 }.generate(u)?; let pattern = match u.int_in_range(0..=2)? { 0 => format!("{needle}%"), // prefix 1 => format!("%{needle}%"), // contains - _ => format!("%{needle}"), // suffix (should fall back, still correct) + 2 => format!("%{needle}"), // suffix (should fall back, still correct) + _ => unreachable!(""), }; let negated: bool = u.arbitrary()?; From 2113488110bc60075b618871ae44851f8aa0976d Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 14:22:57 +0000 Subject: [PATCH 12/19] fixup Signed-off-by: Joe Isaacs --- encodings/fsst/src/dfa/flat_contains.rs | 4 +++- encodings/fsst/src/dfa/prefix.rs | 5 +---- fuzz/src/fsst_like.rs | 6 +++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/encodings/fsst/src/dfa/flat_contains.rs b/encodings/fsst/src/dfa/flat_contains.rs index b2b62d1d0a5..a6cdac24cf6 100644 --- a/encodings/fsst/src/dfa/flat_contains.rs +++ b/encodings/fsst/src/dfa/flat_contains.rs @@ -10,7 +10,9 @@ //! //! ## Construction (needle = `"aba"`, symbols = `[0:"ab", 1:"ba"]`) //! -//! ### Step 1: KMP byte-level transition table +//! ### Step 1: KMP (Knuth–Morris–Pratt) byte-level transition table +//! +//! See: https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm //! //! Build a `(state × byte) → state` table using the KMP failure function. //! States 0..2 track match progress, state 3 is accept (sticky). diff --git a/encodings/fsst/src/dfa/prefix.rs b/encodings/fsst/src/dfa/prefix.rs index 92667aba928..5c40affd7a1 100644 --- a/encodings/fsst/src/dfa/prefix.rs +++ b/encodings/fsst/src/dfa/prefix.rs @@ -72,14 +72,11 @@ impl FlatPrefixDfa { let n_states = fail_state + 1; let sentinel = fail_state + 1; - // Step 1: byte-level transitions let byte_table = build_prefix_byte_table(prefix, accept_state, fail_state); - // Step 2: symbol-level transitions let sym_trans = build_symbol_transitions(symbols, symbol_lengths, &byte_table, n_states, accept_state); - // Step 3: fused table with escape sentinel let transitions = build_fused_table( &sym_trans, symbols.len(), @@ -125,7 +122,7 @@ impl FlatPrefixDfa { } } -/// Build a byte-level transition table for prefix matching (no KMP fallback). +/// Build a byte-level transition table for prefix matching. /// /// For each state, only the correct next byte advances; everything else goes /// to the fail state. diff --git a/fuzz/src/fsst_like.rs b/fuzz/src/fsst_like.rs index 8ee332e6c9f..646dfb20a96 100644 --- a/fuzz/src/fsst_like.rs +++ b/fuzz/src/fsst_like.rs @@ -42,9 +42,9 @@ struct SmallAlphabetString { impl SmallAlphabetString { fn generate(&self, u: &mut Unstructured<'_>) -> arbitrary::Result { let len: usize = u.int_in_range(0..=self.max_len)?; - Ok((0..len) - .map(|_| u.int_in_range(b'a'..=b'h').expect("cannot make char") as char) - .collect()) + (0..len) + .map(|_| Ok(u.int_in_range(b'a'..=b'h')? as char)) + .collect() } } From 173cbda9827a629b7b1181b02a43f9b9103c7642 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 14:27:11 +0000 Subject: [PATCH 13/19] fixup Signed-off-by: Joe Isaacs --- vortex-layout/src/layouts/dict/reader.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/vortex-layout/src/layouts/dict/reader.rs b/vortex-layout/src/layouts/dict/reader.rs index 5054fcd27f3..3649596919e 100644 --- a/vortex-layout/src/layouts/dict/reader.rs +++ b/vortex-layout/src/layouts/dict/reader.rs @@ -106,6 +106,25 @@ impl DictReader { .clone() } + // This is the dict values array without canonicalization, if not already canonical + fn values_array_uncanonical(&self) -> SharedArrayFuture { + // We capture the name, so it may be wrong if we re-use the same reader within multiple + // different parent readers. But that's rare... + let values_len = self.values_len; + self.values_array.get().cloned().unwrap_or_else(|| { + self.values + .projection_evaluation( + &(0..values_len as u64), + &root(), + MaskFuture::new_true(values_len), + ) + .vortex_expect("must construct dict values array evaluation") + .map_err(Arc::new) + .boxed() + .shared() + }) + } + fn values_eval(&self, expr: Expression) -> SharedArrayFuture { // This is unsound since we cannot be sure that all the values are referenced in the query // after applying the filter, so if the expression is fallible this might fail when it @@ -120,7 +139,7 @@ impl DictReader { self.values_evals .entry(expr.clone()) .or_insert_with(|| { - self.values_array() + self.values_array_uncanonical() .map(move |array| { let array = array?.apply(&expr)?; Ok(SharedArray::new(array).into_array()) From 5881376fcbe2ad98e8cd4bf4579ab3816a0da147 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 14:43:07 +0000 Subject: [PATCH 14/19] fixup Signed-off-by: Joe Isaacs --- encodings/fsst/src/dfa/mod.rs | 2 +- encodings/fsst/src/dfa/tests.rs | 95 ++++++++++++++++++++++++ fuzz/src/fsst_like.rs | 13 +++- vortex-layout/src/layouts/dict/reader.rs | 2 +- 4 files changed, 106 insertions(+), 6 deletions(-) diff --git a/encodings/fsst/src/dfa/mod.rs b/encodings/fsst/src/dfa/mod.rs index f4b73959d16..f810da5649f 100644 --- a/encodings/fsst/src/dfa/mod.rs +++ b/encodings/fsst/src/dfa/mod.rs @@ -170,7 +170,7 @@ impl FsstMatcher { }; let inner = match like_kind { - LikeKind::Prefix("") => MatcherInner::MatchAll, + LikeKind::Prefix("") | LikeKind::Contains("") => MatcherInner::MatchAll, LikeKind::Prefix(prefix) => { let prefix = prefix.as_bytes(); if prefix.len() > FlatPrefixDfa::MAX_PREFIX_LEN { diff --git a/encodings/fsst/src/dfa/tests.rs b/encodings/fsst/src/dfa/tests.rs index 8b2a99aa4c9..05362eac154 100644 --- a/encodings/fsst/src/dfa/tests.rs +++ b/encodings/fsst/src/dfa/tests.rs @@ -1,14 +1,36 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::sync::LazyLock; + use fsst::ESCAPE_CODE; use fsst::Symbol; +use rstest::rstest; +use vortex_array::Canonical; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::VarBinArray; +use vortex_array::assert_arrays_eq; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::scalar_fn::fns::like::Like; +use vortex_array::scalar_fn::fns::like::LikeOptions; +use vortex_array::session::ArraySession; use vortex_error::VortexResult; +use vortex_session::VortexSession; use super::FsstMatcher; use super::LikeKind; use super::flat_contains::FlatContainsDfa; use super::prefix::FlatPrefixDfa; +use crate::FSSTArray; +use crate::fsst_compress; +use crate::fsst_train_compressor; + +static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); /// Helper: make a Symbol from a byte string (up to 8 bytes, zero-padded). fn sym(bytes: &[u8]) -> Symbol { @@ -182,3 +204,76 @@ fn test_contains_pushdown_rejects_len_255() { let pattern = format!("%{needle}%"); assert!(FsstMatcher::try_new(&[], &[], &pattern).unwrap().is_none()); } + +// --------------------------------------------------------------------------- +// End-to-end edge cases: FSST compress → LIKE → compare booleans +// --------------------------------------------------------------------------- + +fn make_fsst(strings: &[Option<&str>]) -> FSSTArray { + let varbin = VarBinArray::from_iter( + strings.iter().copied(), + DType::Utf8(Nullability::NonNullable), + ); + let compressor = fsst_train_compressor(&varbin); + fsst_compress(varbin, &compressor) +} + +fn run_like(array: FSSTArray, pattern: &str) -> VortexResult { + use vortex_array::ArrayRef; + use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; + + let len = array.len(); + let arr: ArrayRef = array.into_array(); + let pattern_arr = ConstantArray::new(pattern, len).into_array(); + let result = Like + .try_new_array(len, LikeOptions::default(), [arr, pattern_arr])? + .into_array() + .execute::(&mut SESSION.create_execution_ctx())?; + Ok(result.into_bool()) +} + +#[rstest] +// Empty strings +#[case(&[""], "aaaa%", &[false])] +#[case(&[""], "%aaaa%", &[false])] +#[case(&[""], "%", &[true])] +#[case(&["", "", ""], "%", &[true, true, true])] +// Single-char patterns +#[case(&["a", "b", ""], "a%", &[true, false, false])] +#[case(&["a", "b", ""], "%a%", &[true, false, false])] +// Needle longer than every input string +#[case(&["ab", "abc", ""], "%abcd%", &[false, false, false])] +#[case(&["ab", "abc", ""], "abcd%", &[false, false, false])] +// Exact match (prefix pattern = entire string + %) +#[case(&["abc", "abcd", "ab"], "abc%", &[true, true, false])] +#[case(&["abc", "abcd", "ab"], "%abc%", &[true, true, false])] +// Repeated characters — KMP overlap +#[case(&["aa", "aaa", "aaaa", "aba"], "%aaa%", &[false, true, true, false])] +#[case(&["aab", "aaab", "a"], "aaa%", &[false, true, false])] +// Needle at different positions +#[case(&["xxabcyy", "abcyy", "xxabc", "abc", "xabx"], "%abc%", &[true, true, true, true, false])] +// All identical strings +#[case(&["aaa", "aaa", "aaa"], "%aaa%", &[true, true, true])] +#[case(&["aaa", "aaa", "aaa"], "bbb%", &[false, false, false])] +// Single element arrays +#[case(&["hello"], "hello%", &[true])] +#[case(&["hello"], "hellx%", &[false])] +#[case(&["hello"], "%ello%", &[true])] +#[case(&["hello"], "%ellx%", &[false])] +// Overlapping KMP pattern "abab" +#[case(&["ababab", "abab", "aba", "xababx"], "%abab%", &[true, true, false, true])] +// Prefix that shares chars with rest of string +#[case(&["abab", "abba", "abcd"], "ab%", &[true, true, true])] +#[case(&["abab", "abba", "abcd", "ba"], "ab%", &[true, true, true, false])] +fn test_like_edge_cases( + #[case] strings: &[&str], + #[case] pattern: &str, + #[case] expected: &[bool], +) -> VortexResult<()> { + let opts: Vec> = strings.iter().map(|s| Some(*s)).collect(); + let fsst = make_fsst(&opts); + let result = run_like(fsst, pattern)?; + let expected_arr = BoolArray::from_iter(expected.iter().copied()); + assert_arrays_eq!(&result, &expected_arr); + Ok(()) +} diff --git a/fuzz/src/fsst_like.rs b/fuzz/src/fsst_like.rs index 646dfb20a96..f2f0f3bac48 100644 --- a/fuzz/src/fsst_like.rs +++ b/fuzz/src/fsst_like.rs @@ -133,10 +133,15 @@ pub fn run_fsst_like_fuzz(fuzz: FuzzFsstLike) -> VortexFuzzResult { let expected_val = expected_bits.value(idx); let actual_val = actual_bits.value(idx); if expected_val != actual_val { - return Err(VortexFuzzError::ScalarMismatch( - expected_val.into(), - actual_val.into(), - idx, + return Err(VortexFuzzError::VortexError( + vortex_error::vortex_err!( + "FSST LIKE mismatch at index {idx}:\n \ + pattern: {pattern:?}\n \ + string: {:?}\n \ + expected: {expected_val}\n \ + actual: {actual_val}", + &strings[idx], + ), Backtrace::capture(), )); } diff --git a/vortex-layout/src/layouts/dict/reader.rs b/vortex-layout/src/layouts/dict/reader.rs index 3649596919e..ded15f6ace0 100644 --- a/vortex-layout/src/layouts/dict/reader.rs +++ b/vortex-layout/src/layouts/dict/reader.rs @@ -225,7 +225,7 @@ impl LayoutReader for DictReader { mask: MaskFuture, ) -> VortexResult>> { // TODO: fix up expr partitioning with fallible & null sensitive annotations - let values_eval = self.values_eval(root()); + let values_eval = self.values_array(); let codes_eval = self .codes .projection_evaluation(row_range, &root(), mask) From ef16508f0d2545f13b502cc1fca550b51049c1a1 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 15:04:59 +0000 Subject: [PATCH 15/19] fixup Signed-off-by: Joe Isaacs --- encodings/fsst/src/compute/like.rs | 120 ----------------------------- encodings/fsst/src/dfa/tests.rs | 2 + 2 files changed, 2 insertions(+), 120 deletions(-) diff --git a/encodings/fsst/src/compute/like.rs b/encodings/fsst/src/compute/like.rs index 732708a64c1..be663b16c9f 100644 --- a/encodings/fsst/src/compute/like.rs +++ b/encodings/fsst/src/compute/like.rs @@ -351,124 +351,4 @@ mod tests { assert_arrays_eq!(direct.unwrap(), BoolArray::from_iter([true, false, true])); Ok(()) } - - // ----------------------------------------------------------------------- - // Fuzz tests: compare FSST kernel against naive string matching - // ----------------------------------------------------------------------- - - fn random_string(rng: &mut StdRng, max_len: usize) -> String { - let len = rng.random_range(0..=max_len); - // Use a small alphabet to increase substring hit rate. - (0..len) - .map(|_| (b'a' + rng.random_range(0..6u8)) as char) - .collect() - } - - fn fuzz_contains(seed: u64, needle_len: usize, n_strings: usize) -> VortexResult<()> { - let mut rng = StdRng::seed_from_u64(seed); - - let needle: String = (0..needle_len) - .map(|_| (b'a' + rng.random_range(0..6u8)) as char) - .collect(); - - let owned: Vec = (0..n_strings) - .map(|_| random_string(&mut rng, 80)) - .collect(); - let strings: Vec> = owned.iter().map(|s| Some(s.as_str())).collect(); - - let expected: Vec = owned.iter().map(|s| s.contains(&needle)).collect(); - - let fsst = make_fsst(&strings, Nullability::NonNullable); - let pattern = format!("%{needle}%"); - let result = run_like(fsst, &pattern, LikeOptions::default())?; - - let got: Vec = (0..n_strings) - .map(|i| result.to_bit_buffer().value(i)) - .collect(); - - for (i, (e, g)) in expected.iter().zip(got.iter()).enumerate() { - assert_eq!( - e, g, - "mismatch at index {i}: string={:?}, needle={needle:?}, expected={e}, got={g}", - &owned[i], - ); - } - Ok(()) - } - - fn fuzz_prefix(seed: u64, prefix_len: usize, n_strings: usize) -> VortexResult<()> { - let mut rng = StdRng::seed_from_u64(seed); - - let prefix: String = (0..prefix_len) - .map(|_| (b'a' + rng.random_range(0..6u8)) as char) - .collect(); - - let owned: Vec = (0..n_strings) - .map(|_| random_string(&mut rng, 80)) - .collect(); - let strings: Vec> = owned.iter().map(|s| Some(s.as_str())).collect(); - - let expected: Vec = owned.iter().map(|s| s.starts_with(&prefix)).collect(); - - let fsst = make_fsst(&strings, Nullability::NonNullable); - let pattern = format!("{prefix}%"); - let result = run_like(fsst, &pattern, LikeOptions::default())?; - - let got: Vec = (0..n_strings) - .map(|i| result.to_bit_buffer().value(i)) - .collect(); - - for (i, (e, g)) in expected.iter().zip(got.iter()).enumerate() { - assert_eq!( - e, g, - "mismatch at index {i}: string={:?}, prefix={prefix:?}, expected={e}, got={g}", - &owned[i], - ); - } - Ok(()) - } - - /// Fuzz contains with short needles (1-7 chars) -> BranchlessShiftDfa - #[test] - fn fuzz_contains_short_needle() -> VortexResult<()> { - for seed in 0..50 { - for needle_len in 1..=7 { - fuzz_contains(seed, needle_len, 200)?; - } - } - Ok(()) - } - - /// Fuzz contains with medium needles (8-14 chars) -> FlatBranchlessDfa - #[test] - fn fuzz_contains_medium_needle() -> VortexResult<()> { - for seed in 0..50 { - for needle_len in [8, 10, 14] { - fuzz_contains(seed, needle_len, 200)?; - } - } - Ok(()) - } - - /// Fuzz contains with long needles (>14 chars) -> FsstContainsDfa - #[test] - fn fuzz_contains_long_needle() -> VortexResult<()> { - for seed in 0..30 { - for needle_len in [15, 20, 30] { - fuzz_contains(seed, needle_len, 200)?; - } - } - Ok(()) - } - - /// Fuzz prefix matching - #[test] - fn fuzz_prefix_matching() -> VortexResult<()> { - for seed in 0..50 { - for prefix_len in [1, 3, 5, 10, 13, 20, 40] { - fuzz_prefix(seed, prefix_len, 200)?; - } - } - Ok(()) - } } diff --git a/encodings/fsst/src/dfa/tests.rs b/encodings/fsst/src/dfa/tests.rs index 05362eac154..663ed90b6d9 100644 --- a/encodings/fsst/src/dfa/tests.rs +++ b/encodings/fsst/src/dfa/tests.rs @@ -237,7 +237,9 @@ fn run_like(array: FSSTArray, pattern: &str) -> VortexResult { #[case(&[""], "aaaa%", &[false])] #[case(&[""], "%aaaa%", &[false])] #[case(&[""], "%", &[true])] +#[case(&[""], "%%", &[true])] #[case(&["", "", ""], "%", &[true, true, true])] +#[case(&["", "abc", ""], "%%", &[true, true, true])] // Single-char patterns #[case(&["a", "b", ""], "a%", &[true, false, false])] #[case(&["a", "b", ""], "%a%", &[true, false, false])] From ef2445b8c1b0aca0907ab7084608b3c3b2d9fb46 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 16:20:36 +0000 Subject: [PATCH 16/19] fixup Signed-off-by: Joe Isaacs --- encodings/fsst/src/dfa/mod.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/encodings/fsst/src/dfa/mod.rs b/encodings/fsst/src/dfa/mod.rs index f810da5649f..35fe84509d1 100644 --- a/encodings/fsst/src/dfa/mod.rs +++ b/encodings/fsst/src/dfa/mod.rs @@ -191,7 +191,6 @@ impl FsstMatcher { } /// Run the matcher on a single FSST-compressed code sequence. - #[inline] pub(crate) fn matches(&self, codes: &[u8]) -> bool { match &self.inner { MatcherInner::MatchAll => true, @@ -233,7 +232,6 @@ impl<'a> LikeKind<'a> { // --------------------------------------------------------------------------- // TODO: add N-way ILP overrun scan for higher throughput on short strings. -#[inline] pub(crate) fn dfa_scan_to_bitbuf( n: usize, offsets: &[T], From b013dad2605212ee90338945831c167019f06ae0 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 16:41:00 +0000 Subject: [PATCH 17/19] fixup Signed-off-by: Joe Isaacs --- _typos.toml | 2 +- encodings/fsst/src/compute/like.rs | 17 ++++-- encodings/fsst/src/dfa/flat_contains.rs | 2 +- encodings/fsst/src/dfa/mod.rs | 21 ++++--- encodings/fsst/src/dfa/tests.rs | 78 ++++++++++++++++++------- 5 files changed, 81 insertions(+), 39 deletions(-) diff --git a/_typos.toml b/_typos.toml index 4e482a52d3d..8bfc27178c7 100644 --- a/_typos.toml +++ b/_typos.toml @@ -8,7 +8,7 @@ extend-ignore-re = [ ] [files] -extend-exclude = ["/vortex-bench/**", "/docs/references.bib", "benchmarks/**", "vortex-sqllogictest/slt/**"] +extend-exclude = ["/vortex-bench/**", "/docs/references.bib", "benchmarks/**", "vortex-sqllogictest/slt/**", "encodings/fsst/src/dfa/tests.rs"] [type.py] extend-ignore-identifiers-re = [ diff --git a/encodings/fsst/src/compute/like.rs b/encodings/fsst/src/compute/like.rs index be663b16c9f..c438a9ab01e 100644 --- a/encodings/fsst/src/compute/like.rs +++ b/encodings/fsst/src/compute/like.rs @@ -33,7 +33,17 @@ impl LikeKernel for FSST { return Ok(None); } - let Some(pattern_str) = pattern_scalar.as_utf8().value() else { + let pattern_bytes: &[u8] = if let Some(s) = pattern_scalar.as_utf8_opt() { + let Some(v) = s.value() else { + return Ok(None); + }; + v.as_ref() + } else if let Some(b) = pattern_scalar.as_binary_opt() { + let Some(v) = b.value() else { + return Ok(None); + }; + v + } else { return Ok(None); }; @@ -41,7 +51,7 @@ impl LikeKernel for FSST { let symbol_lengths = array.symbol_lengths(); let Some(matcher) = - FsstMatcher::try_new(symbols.as_slice(), symbol_lengths.as_slice(), pattern_str)? + FsstMatcher::try_new(symbols.as_slice(), symbol_lengths.as_slice(), pattern_bytes)? else { return Ok(None); }; @@ -73,9 +83,6 @@ impl LikeKernel for FSST { mod tests { use std::sync::LazyLock; - use rand::Rng; - use rand::SeedableRng; - use rand::rngs::StdRng; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; diff --git a/encodings/fsst/src/dfa/flat_contains.rs b/encodings/fsst/src/dfa/flat_contains.rs index a6cdac24cf6..a33b82b4a0a 100644 --- a/encodings/fsst/src/dfa/flat_contains.rs +++ b/encodings/fsst/src/dfa/flat_contains.rs @@ -12,7 +12,7 @@ //! //! ### Step 1: KMP (Knuth–Morris–Pratt) byte-level transition table //! -//! See: https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm +//! See: //! //! Build a `(state × byte) → state` table using the KMP failure function. //! States 0..2 track match progress, state 3 is accept (sticky). diff --git a/encodings/fsst/src/dfa/mod.rs b/encodings/fsst/src/dfa/mod.rs index 35fe84509d1..ce615268346 100644 --- a/encodings/fsst/src/dfa/mod.rs +++ b/encodings/fsst/src/dfa/mod.rs @@ -163,23 +163,21 @@ impl FsstMatcher { pub(crate) fn try_new( symbols: &[Symbol], symbol_lengths: &[u8], - pattern: &str, + pattern: &[u8], ) -> VortexResult> { let Some(like_kind) = LikeKind::parse(pattern) else { return Ok(None); }; let inner = match like_kind { - LikeKind::Prefix("") | LikeKind::Contains("") => MatcherInner::MatchAll, + LikeKind::Prefix(b"") | LikeKind::Contains(b"") => MatcherInner::MatchAll, LikeKind::Prefix(prefix) => { - let prefix = prefix.as_bytes(); if prefix.len() > FlatPrefixDfa::MAX_PREFIX_LEN { return Ok(None); } MatcherInner::Prefix(FlatPrefixDfa::new(symbols, symbol_lengths, prefix)?) } LikeKind::Contains(needle) => { - let needle = needle.as_bytes(); if needle.len() > FlatContainsDfa::MAX_NEEDLE_LEN { return Ok(None); } @@ -203,23 +201,24 @@ impl FsstMatcher { /// The subset of LIKE patterns we can handle without decompression. enum LikeKind<'a> { /// `prefix%` - Prefix(&'a str), + Prefix(&'a [u8]), /// `%needle%` - Contains(&'a str), + Contains(&'a [u8]), } impl<'a> LikeKind<'a> { - fn parse(pattern: &'a str) -> Option { + fn parse(pattern: &'a [u8]) -> Option { // `prefix%` (including just `%` where prefix is empty) - if let Some(prefix) = pattern.strip_suffix('%') - && !prefix.contains(['%', '_']) + if let Some(prefix) = pattern.strip_suffix(&[b'%']) + && !prefix.contains(&b'%') + && !prefix.contains(&b'_') { return Some(LikeKind::Prefix(prefix)); } // `%needle%` - let inner = pattern.strip_prefix('%')?.strip_suffix('%')?; - if !inner.contains(['%', '_']) { + let inner = pattern.strip_prefix(&[b'%'])?.strip_suffix(&[b'%'])?; + if !inner.contains(&b'%') && !inner.contains(&b'_') { return Some(LikeKind::Contains(inner)); } diff --git a/encodings/fsst/src/dfa/tests.rs b/encodings/fsst/src/dfa/tests.rs index 663ed90b6d9..d135014c42e 100644 --- a/encodings/fsst/src/dfa/tests.rs +++ b/encodings/fsst/src/dfa/tests.rs @@ -6,12 +6,14 @@ use std::sync::LazyLock; use fsst::ESCAPE_CODE; use fsst::Symbol; use rstest::rstest; +use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::BoolArray; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::VarBinArray; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::assert_arrays_eq; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; @@ -51,17 +53,17 @@ fn escaped(bytes: &[u8]) -> Vec { #[test] fn test_like_kind_parse() { assert!(matches!( - LikeKind::parse("http%"), - Some(LikeKind::Prefix("http")) + LikeKind::parse(b"http%"), + Some(LikeKind::Prefix(b"http")) )); assert!(matches!( - LikeKind::parse("%needle%"), - Some(LikeKind::Contains("needle")) + LikeKind::parse(b"%needle%"), + Some(LikeKind::Contains(b"needle")) )); - assert!(matches!(LikeKind::parse("%"), Some(LikeKind::Prefix("")))); + assert!(matches!(LikeKind::parse(b"%"), Some(LikeKind::Prefix(b"")))); // Suffix and underscore patterns are not supported. - assert!(LikeKind::parse("%suffix").is_none()); - assert!(LikeKind::parse("a_c").is_none()); + assert!(LikeKind::parse(b"%suffix").is_none()); + assert!(LikeKind::parse(b"a_c").is_none()); } /// No symbols — all bytes escaped. Simplest case to see the two tables. @@ -144,7 +146,7 @@ fn test_prefix_dfa_longer() -> VortexResult<()> { #[test] fn test_prefix_pushdown_len_13_with_escapes() { - let matcher = FsstMatcher::try_new(&[], &[], "abcdefghijklm%") + let matcher = FsstMatcher::try_new(&[], &[], b"abcdefghijklm%") .unwrap() .unwrap(); @@ -156,7 +158,7 @@ fn test_prefix_pushdown_len_13_with_escapes() { fn test_prefix_pushdown_len_14_now_handled() { // 14-byte prefix is now handled by FlatPrefixDfa (was rejected by shift-packed). assert!( - FsstMatcher::try_new(&[], &[], "abcdefghijklmn%") + FsstMatcher::try_new(&[], &[], b"abcdefghijklmn%") .unwrap() .is_some() ); @@ -166,7 +168,7 @@ fn test_prefix_pushdown_len_14_now_handled() { fn test_prefix_pushdown_long_prefix() -> VortexResult<()> { let prefix = "a".repeat(FlatPrefixDfa::MAX_PREFIX_LEN); let pattern = format!("{prefix}%"); - let matcher = FsstMatcher::try_new(&[], &[], &pattern)?.unwrap(); + let matcher = FsstMatcher::try_new(&[], &[], pattern.as_bytes())?.unwrap(); assert!(matcher.matches(&escaped(prefix.as_bytes()))); @@ -182,14 +184,20 @@ fn test_prefix_pushdown_rejects_len_254() { debug_assert_eq!(FlatPrefixDfa::MAX_PREFIX_LEN, 253); let prefix = "a".repeat(254); let pattern = format!("{prefix}%"); - assert!(FsstMatcher::try_new(&[], &[], &pattern).unwrap().is_none()); + assert!( + FsstMatcher::try_new(&[], &[], pattern.as_bytes()) + .unwrap() + .is_none() + ); } #[test] fn test_contains_pushdown_len_254_with_escapes() { let needle = "a".repeat(FlatContainsDfa::MAX_NEEDLE_LEN); let pattern = format!("%{needle}%"); - let matcher = FsstMatcher::try_new(&[], &[], &pattern).unwrap().unwrap(); + let matcher = FsstMatcher::try_new(&[], &[], pattern.as_bytes()) + .unwrap() + .unwrap(); assert!(matcher.matches(&escaped(needle.as_bytes()))); @@ -202,14 +210,18 @@ fn test_contains_pushdown_len_254_with_escapes() { fn test_contains_pushdown_rejects_len_255() { let needle = "a".repeat(FlatContainsDfa::MAX_NEEDLE_LEN + 1); let pattern = format!("%{needle}%"); - assert!(FsstMatcher::try_new(&[], &[], &pattern).unwrap().is_none()); + assert!( + FsstMatcher::try_new(&[], &[], pattern.as_bytes()) + .unwrap() + .is_none() + ); } // --------------------------------------------------------------------------- // End-to-end edge cases: FSST compress → LIKE → compare booleans // --------------------------------------------------------------------------- -fn make_fsst(strings: &[Option<&str>]) -> FSSTArray { +fn make_fsst_str(strings: &[Option<&str>]) -> FSSTArray { let varbin = VarBinArray::from_iter( strings.iter().copied(), DType::Utf8(Nullability::NonNullable), @@ -218,13 +230,9 @@ fn make_fsst(strings: &[Option<&str>]) -> FSSTArray { fsst_compress(varbin, &compressor) } -fn run_like(array: FSSTArray, pattern: &str) -> VortexResult { - use vortex_array::ArrayRef; - use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; - +fn run_like(array: FSSTArray, pattern_arr: ArrayRef) -> VortexResult { let len = array.len(); let arr: ArrayRef = array.into_array(); - let pattern_arr = ConstantArray::new(pattern, len).into_array(); let result = Like .try_new_array(len, LikeOptions::default(), [arr, pattern_arr])? .into_array() @@ -267,14 +275,42 @@ fn run_like(array: FSSTArray, pattern: &str) -> VortexResult { // Prefix that shares chars with rest of string #[case(&["abab", "abba", "abcd"], "ab%", &[true, true, true])] #[case(&["abab", "abba", "abcd", "ba"], "ab%", &[true, true, true, false])] +// The string "aabaabaabaab" requires multi-level KMP fallback at the 'a' after "aabaabaab" +#[case(&["aabaabaabaab", "aabaabaax", "xaabaabaab"], "%aabaabaab%", &[true, false, true])] +#[case(&["café latte", "naïve approach", "café noir"], "café%", &[true, false, true])] +#[case(&["日本語テスト", "日本語データ", "英語テスト"], "%日本語%", &[true, true, false])] +// 10-byte needle, contains: match at start, middle, end, exact, and near-miss +#[case( + &["abcdefghijxxx", "xxxabcdefghij", "xxabcdefghijxx", "abcdefghij", "abcdefghxx"], + "%abcdefghij%", + &[true, true, true, true, false] +)] +// 10-byte prefix: same needle but anchored at the start of the string +#[case( + &["abcdefghijxxx", "abcdefghij", "xabcdefghij", "abcdefghxx"], + "abcdefghij%", + &[true, true, false, false] +)] +// 9-byte needle with KMP-relevant overlap ("abcabcabc"): +// failure table = [0,0,0,1,2,3,4,5,6], so a partial match of "abcabcab" +// followed by a mismatch must fall back to state 5 ("abcab"), not restart. +// This exercises multi-level KMP backtracking across symbol boundaries. +#[case( + &["xxabcabcabcxx", "abcabcabc", "abcabcabx", "abcabcxx"], + "%abcabcabc%", + &[true, true, false, false] +)] fn test_like_edge_cases( #[case] strings: &[&str], #[case] pattern: &str, #[case] expected: &[bool], ) -> VortexResult<()> { let opts: Vec> = strings.iter().map(|s| Some(*s)).collect(); - let fsst = make_fsst(&opts); - let result = run_like(fsst, pattern)?; + let fsst_arr = make_fsst_str(&opts); + let result = run_like( + fsst_arr, + ConstantArray::new(pattern, opts.len()).into_array(), + )?; let expected_arr = BoolArray::from_iter(expected.iter().copied()); assert_arrays_eq!(&result, &expected_arr); Ok(()) From 39c193b863e3e5f7819c40c6e44206de5726a680 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 17:09:10 +0000 Subject: [PATCH 18/19] update Signed-off-by: Joe Isaacs --- encodings/fsst/src/dfa/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/encodings/fsst/src/dfa/mod.rs b/encodings/fsst/src/dfa/mod.rs index ce615268346..358fd3a7ab5 100644 --- a/encodings/fsst/src/dfa/mod.rs +++ b/encodings/fsst/src/dfa/mod.rs @@ -209,7 +209,7 @@ enum LikeKind<'a> { impl<'a> LikeKind<'a> { fn parse(pattern: &'a [u8]) -> Option { // `prefix%` (including just `%` where prefix is empty) - if let Some(prefix) = pattern.strip_suffix(&[b'%']) + if let Some(prefix) = pattern.strip_suffix(b"%") && !prefix.contains(&b'%') && !prefix.contains(&b'_') { @@ -217,7 +217,7 @@ impl<'a> LikeKind<'a> { } // `%needle%` - let inner = pattern.strip_prefix(&[b'%'])?.strip_suffix(&[b'%'])?; + let inner = pattern.strip_prefix(b"%")?.strip_suffix(b"%")?; if !inner.contains(&b'%') && !inner.contains(&b'_') { return Some(LikeKind::Contains(inner)); } From 0a60e0186ca48a9558ba7231af8d104bea69f29c Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 19 Mar 2026 17:11:04 +0000 Subject: [PATCH 19/19] update Signed-off-by: Joe Isaacs --- _typos.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/_typos.toml b/_typos.toml index 8bfc27178c7..e9cf23d68b7 100644 --- a/_typos.toml +++ b/_typos.toml @@ -2,13 +2,13 @@ extend-ignore-identifiers-re = ["ffor", "FFOR", "FoR", "typ", "ratatui"] # We support a few common special comments to tell the checker to ignore sections of code extend-ignore-re = [ - "(#|//)\\s*spellchecker:ignore-next-line\\n.*", # Ignore the next line - "(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # Ignore line that ends with this hint + "(#|//)\\s*spellchecker:ignore-next-line\\n.*", # Ignore the next line + "(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # Ignore line that ends with this hint "(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # Ignore block between hints ] [files] -extend-exclude = ["/vortex-bench/**", "/docs/references.bib", "benchmarks/**", "vortex-sqllogictest/slt/**", "encodings/fsst/src/dfa/tests.rs"] +extend-exclude = ["/vortex-bench/**", "/docs/references.bib", "benchmarks/**", "vortex-sqllogictest/slt/**", "encodings/fsst/src/dfa/tests.rs", "encodings/fsst/src/dfa/flat_contains.rs"] [type.py] extend-ignore-identifiers-re = [