diff --git a/diskann-quantization/src/multi_vector/block_transposed.rs b/diskann-quantization/src/multi_vector/block_transposed.rs index 10339ea7b..6fe9a1315 100644 --- a/diskann-quantization/src/multi_vector/block_transposed.rs +++ b/diskann-quantization/src/multi_vector/block_transposed.rs @@ -80,7 +80,7 @@ use std::{alloc::Layout, marker::PhantomData, ptr::NonNull}; use diskann_utils::{ - ReborrowMut, + Reborrow, ReborrowMut, strided::StridedView, views::{MatrixView, MutMatrixView}, }; @@ -231,6 +231,15 @@ impl BlockTransposedRepr usize { + self.num_blocks() * GROUP + } + /// The stride (in elements) between the start of consecutive blocks. #[inline] fn block_stride(&self) -> usize { @@ -743,6 +752,15 @@ impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRef<'a, self.data.repr().remainder() } + /// Total number of logical rows rounded up to the next multiple of `GROUP`. + /// + /// This is the number of "available" row slots in the backing allocation, + /// including zero-padded rows in the last (possibly partial) block. + #[inline] + pub fn padded_nrows(&self) -> usize { + self.data.repr().padded_nrows() + } + /// Return a raw typed pointer to the start of the backing data. 
#[inline] pub fn as_ptr(&self) -> *const T { @@ -870,6 +888,7 @@ impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedMut<'a, delegate_to_ref!(pub fn full_blocks(&self) -> usize); delegate_to_ref!(pub fn num_blocks(&self) -> usize); delegate_to_ref!(pub fn remainder(&self) -> usize); + delegate_to_ref!(pub fn padded_nrows(&self) -> usize); delegate_to_ref!(pub fn as_ptr(&self) -> *const T); delegate_to_ref!(pub fn as_slice(&self) -> &[T]); delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T); @@ -1017,6 +1036,7 @@ impl BlockTransposed usize); delegate_to_ref!(pub fn num_blocks(&self) -> usize); delegate_to_ref!(pub fn remainder(&self) -> usize); + delegate_to_ref!(pub fn padded_nrows(&self) -> usize); delegate_to_ref!(pub fn as_ptr(&self) -> *const T); delegate_to_ref!(pub fn as_slice(&self) -> &[T]); delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T); @@ -1072,6 +1092,19 @@ impl BlockTransposed Reborrow<'this> + for BlockTransposed +{ + type Target = BlockTransposedRef<'this, T, GROUP, PACK>; + + #[inline] + fn reborrow(&'this self) -> Self::Target { + self.as_view() + } +} + // ── Factory methods ────────────────────────────────────────────── impl BlockTransposed { @@ -1676,6 +1709,15 @@ mod tests { } } + // ── padded_nrows() returns padded row count ────────────── + + assert_eq!( + transpose.as_view().padded_nrows(), + padded_nrows, + "padded_nrows() mismatch -- {}", + context, + ); + // ── from_matrix_view produces identical results ────────── if nrows > 0 && ncols > 0 { diff --git a/diskann-quantization/src/multi_vector/distance/simple.rs b/diskann-quantization/src/multi_vector/distance/fallback.rs similarity index 92% rename from diskann-quantization/src/multi_vector/distance/simple.rs rename to diskann-quantization/src/multi_vector/distance/fallback.rs index b92f9fa7e..9dc46a576 100644 --- 
a/diskann-quantization/src/multi_vector/distance/simple.rs +++ b/diskann-quantization/src/multi_vector/distance/fallback.rs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -//! Simple kernel implementation of multi-vector distance computation. +//! Fallback kernel implementation of multi-vector distance computation. use std::ops::Deref; @@ -49,17 +49,17 @@ impl<'a, T: Repr> Deref for QueryMatRef<'a, T> { } } -////////////////// -// SimpleKernel // -////////////////// +//////////////////// +// FallbackKernel // +//////////////////// -/// Simple double-loop kernel to compute max-sim distances over multi-vectors. +/// Fallback double-loop kernel to compute max-sim distances over multi-vectors. /// /// This kernel performs a simple double-loop over the rows of `query` /// and the `doc` and dispatches to [`InnerProduct`] to compute the similarity. -pub struct SimpleKernel; +pub struct FallbackKernel; -impl SimpleKernel { +impl FallbackKernel { /// Core kernel for computing per-query-vector max similarities (min negated inner-product). /// /// For each `query` vector, computes the maximum similarity (negated inner product) @@ -128,7 +128,7 @@ where return Err(MaxSimError::InvalidBufferLength(size, n_queries)); } - SimpleKernel::max_sim_kernel(query, doc, |i, score| { + FallbackKernel::max_sim_kernel(query, doc, |i, score| { // SAFETY: We asserted that self.size() == query.num_vectors(), // and i < query.num_vectors() due to the kernel loop bound. 
unsafe { *self.scores.get_unchecked_mut(i) = score }; @@ -151,7 +151,7 @@ where fn evaluate(query: QueryMatRef<'_, Standard>, doc: MatRef<'_, Standard>) -> f32 { let mut sum = 0.0f32; - SimpleKernel::max_sim_kernel(query, doc, |_i, score| { + FallbackKernel::max_sim_kernel(query, doc, |_i, score| { sum += score; }); @@ -185,7 +185,7 @@ mod tests { .fold(f32::MAX, f32::min) } - /// Generate a vector of random f32 values in [-1, 1] for testing + /// Generate deterministic test data. fn make_test_data(len: usize, ceil: usize, shift: usize) -> Vec { (0..len).map(|v| ((v + shift) % ceil) as f32).collect() } @@ -270,9 +270,9 @@ mod tests { ); } - // Check that SimpleKernel is also correct. - SimpleKernel::max_sim_kernel(query, doc, |i, score| { - assert!((scores[i] - score).abs() <= 1e-6) + // Check that FallbackKernel produces the same values as the naive reference. + FallbackKernel::max_sim_kernel(query, doc, |i, score| { + assert!((expected_scores[i] - score).abs() <= 1e-6) }); // Test Chamfer @@ -299,7 +299,7 @@ mod tests { // No query vectors means sum is 0 assert_eq!(result, 0.0); - let result = Chamfer::evaluate(doc.into(), query.deref().reborrow()); + let result = Chamfer::evaluate(QueryMatRef::from(doc), query.deref().reborrow()); assert_eq!(result, 0.0); } diff --git a/diskann-quantization/src/multi_vector/distance/kernels/f16.rs b/diskann-quantization/src/multi_vector/distance/kernels/f16.rs new file mode 100644 index 000000000..a535c68dc --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/f16.rs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! f16 dispatch adapter for block-transposed multi-vector distance. +//! +//! Reuses the f32 micro-kernel family with tile-level f16→f32 conversion +//! via [`ConvertTo`](super::layouts::ConvertTo). No f16-specific micro-kernel +//! code is needed — the [`F32Kernel`](super::f32::F32Kernel) does all the +//! 
SIMD work after conversion. +//! +//! Conversion from f16 to f32 is performed at tile granularity via +//! [`SliceCast`](diskann_vector::conversion::SliceCast), dispatched through +//! the runtime architecture token — the same SIMD level used by the +//! micro-kernel. + +use diskann_wide::Architecture; + +use super::Kernel; +use super::TileBudget; +use super::f32::{F32Kernel, max_ip_kernel}; +use super::layouts; +use crate::multi_vector::{BlockTransposedRef, MatRef, Standard}; + +pub(crate) struct F16Entry; + +impl + diskann_wide::arch::Target3< + A, + (), + BlockTransposedRef<'_, half::f16, GROUP>, + MatRef<'_, Standard>, + &mut [f32], + > for F16Entry +where + A: Architecture, + F32Kernel: Kernel, + layouts::BlockTransposed: layouts::ConvertTo as Kernel>::Left> + + layouts::Layout, + layouts::RowMajor: layouts::ConvertTo as Kernel>::Right> + + layouts::Layout, +{ + #[inline(always)] + fn run( + self, + arch: A, + lhs: BlockTransposedRef<'_, half::f16, GROUP>, + rhs: MatRef<'_, Standard>, + scratch: &mut [f32], + ) { + max_ip_kernel(arch, lhs, rhs, scratch, TileBudget::default()); + } +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/f32/mod.rs b/diskann-quantization/src/multi_vector/distance/kernels/f32/mod.rs new file mode 100644 index 000000000..a900ea356 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/f32/mod.rs @@ -0,0 +1,135 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! f32 micro-kernel family for block-transposed multi-vector distance. +//! +//! Provides: +//! +//! - `F32Kernel` — zero-sized marker type selecting the f32 micro-kernel +//! for `BlockTransposed` data. +//! - [`max_ip_kernel`] — architecture-, element-type-, and GROUP-generic entry point +//! for the reducing max-IP GEMM. Accepts any element type `T` for which +//! [`ConvertTo`](super::layouts::ConvertTo) impls exist (identity for f32, +//! SIMD-accelerated f16→f32, etc.). +//! +//! 
# Architecture-specific micro-kernels +//! +//! - `v3` (x86_64) — V3 (AVX2+FMA) 16×4 micro-kernel (GROUP=16). V4 delegates to V3 at dispatch. +//! - `scalar` — Emulated 8×2 micro-kernel (GROUP=8). Neon delegates to Scalar at dispatch. + +use diskann_wide::Architecture; + +use super::Kernel; +use super::TileBudget; +use super::layouts::{self, DescribeLayout}; +use super::tiled_reduce::tiled_reduce; +use crate::multi_vector::{BlockTransposedRef, MatRef, Standard}; + +mod scalar; +#[cfg(target_arch = "x86_64")] +mod v3; + +/// Zero-sized kernel type for f32 micro-kernels with block size `GROUP`. +pub(crate) struct F32Kernel; + +#[inline(never)] +#[cold] +#[allow(clippy::panic)] +fn max_ip_kernel_panic(scratch_len: usize, padded_nrows: usize, a_ncols: usize, b_dim: usize) { + panic!( + "max_ip_kernel: precondition failed: \ + scratch.len()={scratch_len} (expected {padded_nrows}), \ + a.ncols()={a_ncols}, b.vector_dim()={b_dim}" + ); +} + +/// Compute the reducing max-IP GEMM between a block-transposed A matrix and +/// a row-major B matrix, writing per-A-row max similarities into `scratch`. +/// +/// Thin wrapper over [`tiled_reduce`] using `F32Kernel` for the +/// requested architecture. The element type `T` can be any `Copy` type with +/// matching [`ConvertTo`](super::layouts::ConvertTo) impls (zero-cost for +/// `T = f32`; SIMD f16→f32 conversion once per tile for `T = half::f16`). +/// +/// `scratch` must have length [`BlockTransposedRef::padded_nrows()`] and be +/// initialized to `f32::MIN` before the first call. On return, `scratch[i]` +/// holds the maximum inner product between A row `i` and any B row. +/// +/// # Panics +/// +/// Panics if `scratch.len() != a.padded_nrows()` or `a.ncols() != b.vector_dim()`. 
+pub(super) fn max_ip_kernel( + arch: A, + a: BlockTransposedRef<'_, T, GROUP>, + b: MatRef<'_, Standard>, + scratch: &mut [f32], + budget: TileBudget, +) where + F32Kernel: Kernel, + layouts::BlockTransposed: + layouts::ConvertTo as Kernel>::Left> + layouts::Layout, + layouts::RowMajor: layouts::ConvertTo as Kernel>::Right> + + layouts::Layout, +{ + if scratch.len() != a.padded_nrows() || a.ncols() != b.vector_dim() { + max_ip_kernel_panic(scratch.len(), a.padded_nrows(), a.ncols(), b.vector_dim()); + } + + let k = a.ncols(); + let b_nrows = b.num_vectors(); + + // Compile-time: A_PANEL must equal GROUP for block-transposed layout correctness. + const { assert!( as Kernel>::A_PANEL == GROUP) } + + let ca = a.layout(); + let cb = b.layout(); + + // SAFETY: + // - a.as_ptr() is valid for a.padded_nrows() * k elements of T. + // - MatRef> stores nrows * ncols contiguous T elements. + // - scratch.len() == a.padded_nrows() (checked above). + // - a.padded_nrows() is always a multiple of GROUP, and the const assert above + // verifies A_PANEL == GROUP at compile time. 
+ unsafe { + tiled_reduce::, _, _>( + arch, + &ca, + &cb, + a.as_ptr(), + a.padded_nrows(), + b.as_slice().as_ptr(), + b_nrows, + k, + scratch, + budget, + ); + } +} + +impl + diskann_wide::arch::Target3< + A, + (), + BlockTransposedRef<'_, f32, GROUP>, + MatRef<'_, Standard>, + &mut [f32], + > for F32Kernel +where + A: Architecture, + Self: Kernel, + layouts::BlockTransposed: + layouts::ConvertTo>::Left> + layouts::Layout, + layouts::RowMajor: + layouts::ConvertTo>::Right> + layouts::Layout, +{ + #[inline(always)] + fn run( + self, + arch: A, + lhs: BlockTransposedRef<'_, f32, GROUP>, + rhs: MatRef<'_, Standard>, + scratch: &mut [f32], + ) { + max_ip_kernel(arch, lhs, rhs, scratch, TileBudget::default()); + } +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/f32/scalar.rs b/diskann-quantization/src/multi_vector/distance/kernels/f32/scalar.rs new file mode 100644 index 000000000..bd8fb1c4a --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/f32/scalar.rs @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Scalar (emulated) f32 micro-kernel (8×2). +//! +//! Uses [`Emulated`](diskann_wide::Emulated) (aliased as `f32x8`) for +//! arithmetic — 8 multiply-accumulate operations per inner iteration, 8 scalar +//! comparisons per `max_simd`. Geometry is A_PANEL=8 (1 × f32x8), B_PANEL=2 +//! (matching the `Strategy2x1` pattern used by scalar distance functions +//! elsewhere in the codebase). +//! +//! The inner loop uses separate multiply and add (`a * b + acc`) instead of +//! `mul_add_simd` to avoid calling into libm's software `fma()` routine on +//! x86-64 targets without hardware FMA support. 
+ +use diskann_wide::arch::Scalar; +use diskann_wide::{SIMDMinMax, SIMDVector}; + +use super::super::Kernel; +use super::super::layouts; +use super::super::reduce::Reduce; +use super::F32Kernel; + +diskann_wide::alias!(f32s = ::f32x8); + +// SAFETY: F32Kernel's `full_panel` and `partial_panel` only access +// A_PANEL(8) * k A elements, UNROLL * k B elements, and A_PANEL(8) +// scratch elements — all within the bounds guaranteed by `tiled_reduce`. +unsafe impl Kernel for F32Kernel<8> { + type Left = layouts::BlockTransposed; + type Right = layouts::RowMajor; + const A_PANEL: usize = 8; + const B_PANEL: usize = 2; + + #[inline(always)] + unsafe fn full_panel(arch: Scalar, a: *const f32, b: *const f32, k: usize, r: *mut f32) { + // SAFETY: pointer validity per Kernel contract. + unsafe { scalar_f32_microkernel::<{ Self::B_PANEL }>(arch, a, b, k, r) } + } + + #[inline(always)] + unsafe fn partial_panel( + arch: Scalar, + remainder: usize, + a: *const f32, + b: *const f32, + k: usize, + r: *mut f32, + ) { + // SAFETY: pointer validity per Kernel contract. + unsafe { + match remainder { + 1 => scalar_f32_microkernel::<1>(arch, a, b, k, r), + _ => unreachable!( + "unexpected remainder {remainder} for B_PANEL={}", + Self::B_PANEL + ), + } + } + } +} + +// ── Scalar f32 micro-kernel ────────────────────────────────────── + +/// Emulated micro-kernel: processes 8 A rows × `UNROLL` B rows. +/// +/// Uses separate multiply and add (`a * b + acc`) rather than `mul_add_simd` +/// to avoid calling libm's software `fma()` on x86-64 without hardware FMA. +/// A single register tile covers A_PANEL = 8 = f32s::LANES. B_PANEL=2 +/// follows the `Strategy2x1` pattern from scalar distance functions. +/// +/// # Safety +/// +/// 1. `a_packed` must point to `A_PANEL(8) × k` contiguous `f32` values. +/// 2. `b` must point to `UNROLL` rows of `k` contiguous `f32` values. +/// 3. `r` must point to at least `A_PANEL(8)` writable `f32` values. 
+#[inline(always)] +unsafe fn scalar_f32_microkernel( + arch: Scalar, + a_packed: *const f32, + b: *const f32, + k: usize, + r: *mut f32, +) where + [f32s; UNROLL]: Reduce, +{ + let op = |x: f32s, y: f32s| x.max_simd(y); + + let mut p0 = [f32s::default(arch); UNROLL]; + let offsets: [usize; UNROLL] = core::array::from_fn(|i| k * i); + + let a_stride = f32s::LANES; + + for i in 0..k { + // SAFETY: By preconditions 1 and 2; i < k and j < UNROLL. + unsafe { + let a0 = f32s::load_simd(arch, a_packed.add(a_stride * i)); + + for j in 0..UNROLL { + let bj = f32s::splat(arch, b.add(i + offsets[j]).read_unaligned()); + p0[j] = a0 * bj + p0[j]; + } + } + } + + // SAFETY: By precondition 3. + let mut r0 = unsafe { f32s::load_simd(arch, r) }; + + r0 = op(r0, p0.reduce(&op)); + + // SAFETY: By precondition 3. + unsafe { r0.store_simd(r) }; +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/f32/v3.rs b/diskann-quantization/src/multi_vector/distance/kernels/f32/v3.rs new file mode 100644 index 000000000..b05195b1e --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/f32/v3.rs @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! V3 (AVX2+FMA) f32 micro-kernel (16×4). + +use diskann_wide::arch::x86_64::V3; +use diskann_wide::{SIMDMinMax, SIMDMulAdd, SIMDVector}; + +use super::super::Kernel; +use super::super::layouts; +use super::super::reduce::Reduce; +use super::F32Kernel; + +diskann_wide::alias!(f32s = ::f32x8); + +// SAFETY: F32Kernel's `full_panel` and `partial_panel` only access +// A_PANEL(16) * k A elements, UNROLL * k B elements, and A_PANEL(16) +// scratch elements — all within the bounds guaranteed by `tiled_reduce`. 
+unsafe impl Kernel for F32Kernel<16> { + type Left = layouts::BlockTransposed; + type Right = layouts::RowMajor; + const A_PANEL: usize = 16; + const B_PANEL: usize = 4; + + #[inline(always)] + unsafe fn full_panel(arch: V3, a: *const f32, b: *const f32, k: usize, r: *mut f32) { + // SAFETY: pointer validity per Kernel contract. + unsafe { f32_microkernel::<{ Self::B_PANEL }>(arch, a, b, k, r) } + } + + #[inline(always)] + unsafe fn partial_panel( + arch: V3, + remainder: usize, + a: *const f32, + b: *const f32, + k: usize, + r: *mut f32, + ) { + // SAFETY: pointer validity per Kernel contract. + unsafe { + match remainder { + 1 => f32_microkernel::<1>(arch, a, b, k, r), + 2 => f32_microkernel::<2>(arch, a, b, k, r), + 3 => f32_microkernel::<3>(arch, a, b, k, r), + _ => unreachable!( + "unexpected remainder {remainder} for B_PANEL={}", + Self::B_PANEL + ), + } + } + } +} + +// ── V3 f32 micro-kernel ───────────────────────────────────────── + +/// SIMD micro-kernel: processes 16 A rows × `UNROLL` B rows. +/// +/// Accumulates via FMA into two `f32x8` register tiles, reduces across the +/// `UNROLL` B lanes with `max_simd`, then merges into the scratch buffer `r`. +/// +/// # Safety +/// +/// 1. `a_packed` must point to `A_PANEL(16) × k` contiguous `f32` values. +/// 2. `b` must point to `UNROLL` rows of `k` contiguous `f32` values. +/// 3. `r` must point to at least `A_PANEL(16)` writable `f32` values. +#[inline(always)] +unsafe fn f32_microkernel( + arch: V3, + a_packed: *const f32, + b: *const f32, + k: usize, + r: *mut f32, +) where + [f32s; UNROLL]: Reduce, +{ + let op = |x: f32s, y: f32s| x.max_simd(y); + + let mut p0 = [f32s::default(arch); UNROLL]; + let mut p1 = [f32s::default(arch); UNROLL]; + let offsets: [usize; UNROLL] = core::array::from_fn(|i| k * i); + + let a_stride = 2 * f32s::LANES; + let a_stride_half = f32s::LANES; + + for i in 0..k { + // SAFETY: By preconditions 1 and 2; i < k and j < UNROLL. 
+ unsafe { + let a0 = f32s::load_simd(arch, a_packed.add(a_stride * i)); + let a1 = f32s::load_simd(arch, a_packed.add(a_stride * i + a_stride_half)); + + for j in 0..UNROLL { + let bj = f32s::splat(arch, b.add(i + offsets[j]).read_unaligned()); + p0[j] = a0.mul_add_simd(bj, p0[j]); + p1[j] = a1.mul_add_simd(bj, p1[j]); + } + } + } + + // SAFETY: By precondition 3; LANES < A_PANEL so both halves are in-bounds. + let mut r0 = unsafe { f32s::load_simd(arch, r) }; + // SAFETY: By precondition 3; r.add(LANES) is still within the A_PANEL-sized scratch. + let mut r1 = unsafe { f32s::load_simd(arch, r.add(f32s::LANES)) }; + + r0 = op(r0, p0.reduce(&op)); + r1 = op(r1, p1.reduce(&op)); + + // SAFETY: By precondition 3. + unsafe { r0.store_simd(r) }; + // SAFETY: By precondition 3; r.add(LANES) is still within the A_PANEL-sized scratch. + unsafe { r1.store_simd(r.add(f32s::LANES)) }; +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/layouts.rs b/diskann-quantization/src/multi_vector/distance/kernels/layouts.rs new file mode 100644 index 000000000..e1ec8dd36 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/layouts.rs @@ -0,0 +1,219 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Layout markers and tile-level conversion traits. +//! +//! - [`Layout`] — marker trait: memory layout + element type. +//! - [`BlockTransposed`] / [`RowMajor`] — zero-sized layout markers. +//! - [`DescribeLayout`] — bridges matrix types to layout markers. +//! - [`ConvertTo`] — tile-level conversion (blanket identity + f16→f32). + +use core::marker::PhantomData; + +use diskann_vector::conversion::SliceCast; +use diskann_wide::Architecture; +use diskann_wide::arch::Target2; + +// ── Layout trait ───────────────────────────────────── + +/// Memory layout and element type marker for tile data. 
+pub(super) trait Layout { + type Element: Copy; +} + +// ── Layout markers ─────────────────────────────────── + +/// Block-transposed tile layout: `GROUP` rows per block, `PACK` columns +/// interleaved. Matches [`BlockTransposedRef`](crate::multi_vector::BlockTransposedRef). +pub(super) struct BlockTransposed(PhantomData); + +impl BlockTransposed { + pub(super) fn new() -> Self { + Self(PhantomData) + } +} + +impl Copy for BlockTransposed {} + +impl Clone for BlockTransposed { + fn clone(&self) -> Self { + *self + } +} + +impl Layout for BlockTransposed { + type Element = T; +} + +/// Dense row-major tile layout. Matches [`MatRef>`](crate::multi_vector::MatRef). +pub(super) struct RowMajor(PhantomData); + +impl RowMajor { + pub(super) fn new() -> Self { + Self(PhantomData) + } +} + +impl Copy for RowMajor {} + +impl Clone for RowMajor { + fn clone(&self) -> Self { + *self + } +} + +impl Layout for RowMajor { + type Element = T; +} + +// ── DescribeLayout ─────────────────────────────────── + +/// Bridges a concrete matrix type to its [`Layout`] marker, enabling +/// type inference of [`ConvertTo`] parameters at call sites. +pub(super) trait DescribeLayout { + type Layout: Layout; + + fn layout(&self) -> Self::Layout; +} + +impl DescribeLayout + for crate::multi_vector::BlockTransposedRef<'_, T, GROUP, PACK> +{ + type Layout = BlockTransposed; + + fn layout(&self) -> Self::Layout { + BlockTransposed::new() + } +} + +impl DescribeLayout for crate::multi_vector::MatRef<'_, crate::multi_vector::Standard> { + type Layout = RowMajor; + + fn layout(&self) -> Self::Layout { + RowMajor::new() + } +} + +// ── ConvertTo trait ────────────────────────────────── + +/// Tile-level conversion from layout `Self` to layout `To`. +/// +/// The blanket identity impl covers every layout converting to itself +/// with `Buffer = ()` and zero cost. Explicit impls handle f16→f32 via +/// [`SliceCast`]. 
+/// +/// # Safety +/// +/// Implementors must ensure: +/// - `convert` reads at most `rows * k` source elements. +/// - `convert` writes only within `buf`. +/// - The returned pointer is valid until the next `&mut` access to `buf`. +pub(super) unsafe trait ConvertTo: Layout { + /// Staging buffer for converted tile data (`()` for identity conversions). + type Buffer; + + /// Allocate a buffer for up to `max_tile_rows` rows of dimension `k`. + fn new_buffer(&self, max_tile_rows: usize, k: usize) -> Self::Buffer; + + /// Convert `rows` rows of source data into `buf`, returning a read pointer. + /// + /// # Safety + /// + /// * `src` must point to `rows * k` valid elements in `Self`'s layout. + /// * `buf` must come from [`new_buffer`](Self::new_buffer) with the + /// same `k` and a `max_tile_rows >= rows`. + unsafe fn convert( + &self, + buf: &mut Self::Buffer, + arch: A, + src: *const Self::Element, + rows: usize, + k: usize, + ) -> *const To::Element; +} + +// ── Blanket identity ───────────────────────────────── + +/// Identity conversion: every layout converts to itself at zero cost. +// SAFETY: Identity conversion reads nothing beyond `src` and writes +// nothing into `buf`. The returned pointer is exactly `src`, which is +// valid for the lifetime guaranteed by the caller. +unsafe impl ConvertTo for L { + type Buffer = (); + + fn new_buffer(&self, _max_tile_rows: usize, _k: usize) {} + + unsafe fn convert( + &self, + _buf: &mut (), + _arch: A, + src: *const L::Element, + _rows: usize, + _k: usize, + ) -> *const L::Element { + src + } +} + +// ── f16 → f32 conversions ──────────────────────────── + +/// Block-transposed f16 → block-transposed f32 (element-wise, layout-preserving). +// SAFETY: `SliceCast` converts exactly `rows * k` f16 values from `src` +// into `rows * k` f32 values in `buf`. The returned pointer is +// `buf.as_ptr()`, valid until the next `&mut` access to `buf`. 
+unsafe impl + ConvertTo> for BlockTransposed +where + A: Architecture, + SliceCast: for<'a> Target2, +{ + type Buffer = Vec; + + fn new_buffer(&self, max_tile_rows: usize, k: usize) -> Vec { + vec![0.0f32; max_tile_rows * k] + } + + unsafe fn convert( + &self, + buf: &mut Vec, + arch: A, + src: *const half::f16, + rows: usize, + k: usize, + ) -> *const f32 { + let count = rows * k; + // SAFETY: Caller guarantees `src` points to `count` contiguous f16 values. + let src_slice = unsafe { std::slice::from_raw_parts(src, count) }; + arch.run2(SliceCast::new(), &mut buf[..count], src_slice); + buf.as_ptr() + } +} + +/// Row-major f16 → row-major f32 (element-wise, layout-preserving). +// SAFETY: Same as block-transposed variant — element-wise, layout-preserving. +unsafe impl ConvertTo> for RowMajor +where + A: Architecture, + SliceCast: for<'a> Target2, +{ + type Buffer = Vec; + + fn new_buffer(&self, max_tile_rows: usize, k: usize) -> Vec { + vec![0.0f32; max_tile_rows * k] + } + + unsafe fn convert( + &self, + buf: &mut Vec, + arch: A, + src: *const half::f16, + rows: usize, + k: usize, + ) -> *const f32 { + let count = rows * k; + // SAFETY: Caller guarantees `src` points to `count` contiguous f16 values. + let src_slice = unsafe { std::slice::from_raw_parts(src, count) }; + arch.run2(SliceCast::new(), &mut buf[..count], src_slice); + buf.as_ptr() + } +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/mod.rs b/diskann-quantization/src/multi_vector/distance/kernels/mod.rs new file mode 100644 index 000000000..bd9121a24 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/mod.rs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Block-transposed SIMD kernels for multi-vector distance computation. +//! +//! This module provides a SIMD-accelerated implementation that uses block-transposed +//! 
memory layout for **query** vectors (instead of documents), with documents remaining +//! in row-major format. +//! +//! # Memory Layout +//! +//! - **Query**: Block-transposed (`GROUP` vectors per block, dimensions contiguous +//! within each block). The block size is determined by the kernel's `A_PANEL`. +//! - **Document**: Row-major (standard [`MatRef`](crate::multi_vector::MatRef) format). + +pub(super) mod f16; +pub(super) mod f32; +mod layouts; +mod reduce; +mod tiled_reduce; + +// ── Tile budget ────────────────────────────────────────────────── + +/// Cache budgets fed to the tile planner. +/// +/// `Default` returns the production budgets derived from hardcoded L1/L2 +/// cache-size estimates and fixed fractions. +#[derive(Debug, Clone, Copy)] +struct TileBudget { + /// L2 budget in bytes reserved for A tiles. + l2_a: usize, + /// L1 budget in bytes reserved for B tiles (before A-panel subtraction). + l1_b: usize, +} + +impl Default for TileBudget { + // TODO: Replace hardcoded fallbacks with detected cache sizes + // (e.g. via `diskann_platform`, env-var override, or runtime query). + fn default() -> Self { + const L2_CACHE: usize = 1_250_000; // 1.25 MB fallback + const L1_CACHE: usize = 48_000; // 48 KB fallback + + Self { + // 50% of L2 for A tiles; remainder for B streaming + pollution. + l2_a: L2_CACHE / 2, + // 75% of L1 for B tiles; A micro-panel subtracted at runtime. + l1_b: L1_CACHE * 3 / 4, + } + } +} + +// ── Kernel trait ───────────────────────────────────────────────── + +/// SIMD micro-kernel for the [`tiled_reduce`](tiled_reduce::tiled_reduce) loop. +/// +/// The kernel only sees already-converted data: storage-layout to +/// kernel-layout conversion is handled at tile boundaries by +/// [`ConvertTo`](layouts::ConvertTo), so implementors can assume their input +/// pointers reference `::Element` / +/// `::Element` directly. 
+/// +/// # Safety +/// +/// Implementors must respect the per-method `# Safety` contracts on +/// [`full_panel`](Self::full_panel) and [`partial_panel`](Self::partial_panel). +unsafe trait Kernel { + /// Layout consumed by the A (left / query) side of the micro-kernel. + type Left: layouts::Layout; + /// Layout consumed by the B (right / document) side of the micro-kernel. + type Right: layouts::Layout; + + /// Number of A rows processed per micro-kernel invocation. + const A_PANEL: usize; + /// Number of B rows processed per micro-kernel invocation. + const B_PANEL: usize; + + /// Process one full `A_PANEL × B_PANEL` micro-panel pair. + /// + /// # Safety + /// + /// * `a` must point to `A_PANEL * k` contiguous elements of + /// `::Element`. + /// * `b` must point to `B_PANEL * k` contiguous elements of + /// `::Element`. + /// * `r` must point to at least `A_PANEL` writable `f32` values. + unsafe fn full_panel( + arch: A, + a: *const ::Element, + b: *const ::Element, + k: usize, + r: *mut f32, + ); + + /// Dispatch for `1..(B_PANEL-1)` remainder B rows. + /// + /// # Safety + /// + /// * `a` must point to `A_PANEL * k` contiguous elements of + /// `::Element`. + /// * `b` must point to `remainder * k` contiguous elements of + /// `::Element`. + /// * `r` must point to at least `A_PANEL` writable `f32` values. + unsafe fn partial_panel( + arch: A, + remainder: usize, + a: *const ::Element, + b: *const ::Element, + k: usize, + r: *mut f32, + ); +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/reduce.rs b/diskann-quantization/src/multi_vector/distance/kernels/reduce.rs new file mode 100644 index 000000000..d3dfe85d1 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/reduce.rs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Compile-time unroll reduction over fixed-size accumulator arrays. +//! +//! 
Shared by every micro-kernel family (f32, f16, future u8/i8, …): each +//! kernel keeps `UNROLL` independent SIMD accumulators in the inner loop and +//! folds them down to a single value at the end with a caller-supplied binary +//! operator (e.g. `max_simd`). +//! +//! Implementations are provided for `[T; 1..=4]`, matching the unroll factors +//! currently used by the kernels. The 4-element fold is balanced (`(a⊕b)⊕(c⊕d)`) +//! to shorten the dependency chain; 2- and 3-element folds are left-associative. + +/// Compile-time unroll reduction over fixed-size arrays. +/// +/// Used by the micro-kernels to reduce `UNROLL` accumulators into a single +/// value using a caller-supplied binary operator (e.g. `max_simd`). +pub(super) trait Reduce { + type Element; + fn reduce(&self, f: &F) -> Self::Element + where + F: Fn(Self::Element, Self::Element) -> Self::Element; +} + +impl Reduce for [T; 1] { + type Element = T; + + #[inline(always)] + fn reduce(&self, _f: &F) -> T + where + F: Fn(T, T) -> T, + { + self[0] + } +} + +impl Reduce for [T; 2] { + type Element = T; + + #[inline(always)] + fn reduce(&self, f: &F) -> T + where + F: Fn(T, T) -> T, + { + f(self[0], self[1]) + } +} + +impl Reduce for [T; 3] { + type Element = T; + + #[inline(always)] + fn reduce(&self, f: &F) -> T + where + F: Fn(T, T) -> T, + { + f(f(self[0], self[1]), self[2]) + } +} + +impl Reduce for [T; 4] { + type Element = T; + + #[inline(always)] + fn reduce(&self, f: &F) -> T + where + F: Fn(T, T) -> T, + { + f(f(self[0], self[1]), f(self[2], self[3])) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn reduce_folds_correctly() { + let max = |a: f32, b: f32| a.max(b); + assert_eq!([5.0f32].reduce(&max), 5.0); + assert_eq!([1.0f32, 3.0].reduce(&max), 3.0); + assert_eq!([2.0f32, 1.0, 4.0].reduce(&max), 4.0); + assert_eq!([3.0f32, 1.0, 4.0, 2.0].reduce(&max), 4.0); + } + + /// Verify the exact fold order of each `Reduce` impl using a + /// non-commutative operator (subtraction). 
+ /// + /// - `[a; 1]` → `a` + /// - `[a, b; 2]` → `a - b` + /// - `[a, b, c; 3]` → `(a - b) - c` (left fold) + /// - `[a, b, c, d; 4]` → `(a - b) - (c - d)` (balanced tree) + #[test] + fn reduce_fold_order() { + let sub = |a: f32, b: f32| a - b; + // [10] → 10 + assert_eq!([10.0f32].reduce(&sub), 10.0); + // [10, 3] → 10 - 3 = 7 + assert_eq!([10.0f32, 3.0].reduce(&sub), 7.0); + // [10, 3, 1] → (10 - 3) - 1 = 6 + assert_eq!([10.0f32, 3.0, 1.0].reduce(&sub), 6.0); + // [10, 3, 1, 2] → (10 - 3) - (1 - 2) = 7 - (-1) = 8 + assert_eq!([10.0f32, 3.0, 1.0, 2.0].reduce(&sub), 8.0); + } +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/tiled_reduce.rs b/diskann-quantization/src/multi_vector/distance/kernels/tiled_reduce.rs new file mode 100644 index 000000000..9f02d11da --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/tiled_reduce.rs @@ -0,0 +1,648 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Generic tiling loop for reducing-GEMM micro-kernels. +//! +//! # Tiling Strategy +//! +//! This approach uses a reducing-GEMM pattern modeled after high-performance BLAS +//! implementations: +//! +//! - **L2 cache**: Tiles of A (conventionally the query) are sized to fit in L2. +//! - **L1 cache**: Tiles of B (conventionally the document) plus one micro-panel +//! of A are sized to fit in L1. +//! - **Micro-kernel**: An `A_PANEL × B_PANEL` micro-kernel (e.g. 16×4 for f32 on V3) +//! processes a panel of A rows against a panel of B rows per invocation, +//! accumulating max-IP into a scratch buffer. The panel sizes are determined +//! by the `Kernel` implementation for each element type. +//! +//! The loop itself is layout-agnostic: A and B are described by the generic +//! `LA`/`LB` parameters and converted to the kernel's expected layouts via +//! [`ConvertTo`] at tile boundaries. The current micro-kernels happen to want a +//! 
block-transposed A and a row-major B, but `tiled_reduce` does not require +//! either — any layout pair satisfying the `ConvertTo` bounds works. + +use diskann_wide::Architecture; + +use super::layouts::{ConvertTo, Layout}; +use super::{Kernel, TileBudget}; + +// ── Tile planner ───────────────────────────────────────────────── + +/// Tile-panel counts derived from cache budgets. +#[derive(Debug, Clone, Copy)] +struct FullReduce { + a_panels_per_tile: usize, + + b_panels_per_tile: usize, +} + +impl FullReduce { + /// Compute A-tile and B-tile panel counts from cache budgets. + /// + /// The L1 budget is reduced by one A micro-panel before splitting it into + /// B panels, since both must coexist in L1 during the inner loop. + fn new( + a_row_bytes: usize, + b_row_bytes: usize, + a_panel: usize, + b_panel: usize, + budget: TileBudget, + ) -> Self { + let a_row_bytes = a_row_bytes.max(1); + let b_row_bytes = b_row_bytes.max(1); + + let a_panels_per_tile = (budget.l2_a / (a_row_bytes * a_panel)).max(1); + + let a_panel_bytes = a_panel * a_row_bytes; + let b_tile_budget = budget.l1_b.saturating_sub(a_panel_bytes); + let b_panels_per_tile = (b_tile_budget / (b_row_bytes * b_panel)).max(1); + + Self { + a_panels_per_tile, + b_panels_per_tile, + } + } +} + +// ── Generic tiled reduce ───────────────────────────────────────── + +/// Execute the 5-level tiling loop with a pluggable SIMD micro-kernel and +/// tile-level layout converters. +/// +/// The loop nest is: +/// ```text +/// Loop 1: A tiles (sized to L2) — convert via `ca` +/// Loop 2: B tiles (sized to L1) — convert via `cb` +/// Loop 3: A panels (micro-panels within converted A tile) +/// Loop 4: B panels (micro-panels within converted B tile) +/// Loop 5: k (contraction dim, inside K::full_panel / K::partial_panel) +/// ``` +/// +/// Conversion from storage layout to kernel layout happens once per tile +/// (not per panel), amortizing cost over the entire tile. 
+/// +/// # Safety +/// +/// * `a_ptr` must be valid for `a_padded_nrows * k` elements of `AElem`. +/// * `a_padded_nrows` must be a multiple of `K::A_PANEL`. +/// * `b_ptr` must be valid for `b_nrows * k` elements of `BElem`. +/// * `scratch` must have length ≥ `a_padded_nrows` and be initialized by caller. +#[allow(clippy::too_many_arguments)] +pub(super) unsafe fn tiled_reduce( + arch: A, + ca: &LA, + cb: &LB, + a_ptr: *const LA::Element, + a_padded_nrows: usize, + b_ptr: *const LB::Element, + b_nrows: usize, + k: usize, + scratch: &mut [f32], + budget: TileBudget, +) where + A: Architecture, + K: Kernel, + LA: ConvertTo, + LB: ConvertTo, +{ + let a_row_bytes = k * std::mem::size_of::<::Element>(); + let b_row_bytes = k * std::mem::size_of::<::Element>(); + let plan = FullReduce::new(a_row_bytes, b_row_bytes, K::A_PANEL, K::B_PANEL, budget); + + let b_src_panel_stride = K::B_PANEL * k; + let b_src_tile_stride = b_src_panel_stride * plan.b_panels_per_tile; + + let a_kern_panel_stride = K::A_PANEL * k; + let b_kern_panel_stride = K::B_PANEL * k; + + let b_remainder = b_nrows % K::B_PANEL; + + assert_eq!( + a_padded_nrows % K::A_PANEL, + 0, + "a_padded_nrows ({a_padded_nrows}) must be a multiple of A_PANEL ({})", + K::A_PANEL, + ); + + // Zero-dimensional vectors have IP = 0 for every pair. Fill scratch and + // return to avoid zero-stride infinite loops in the tiling nest. + if k == 0 { + if b_nrows > 0 { + scratch[..a_padded_nrows].fill(0.0); + } + return; + } + + // Allocate conversion buffers once. Identity conversions use `Buffer = ()` + // and these calls are no-ops. + let a_tile_rows = K::A_PANEL * plan.a_panels_per_tile; + let b_tile_rows = K::B_PANEL * plan.b_panels_per_tile; + let mut a_buf = ca.new_buffer(a_tile_rows, k); + let mut b_buf = cb.new_buffer(b_tile_rows, k); + + // SAFETY: Caller guarantees b_ptr is valid for b_nrows * k elements. 
+ let pb_end = unsafe { b_ptr.add(b_nrows * k) }; + // SAFETY: b_remainder < B_PANEL, so pb_end - b_remainder * k is within allocation. + let pb_full_end = unsafe { pb_end.sub(b_remainder * k) }; + + // SAFETY: All pointer arithmetic stays within the respective allocations. + unsafe { + let mut rows_done: usize = 0; + + // Loop 1: Tiles of `A`. + while rows_done < a_padded_nrows { + let tile_rows = a_tile_rows.min(a_padded_nrows - rows_done); + let pa_tile_src = a_ptr.add(rows_done * k); + // SAFETY: rows_done < a_padded_nrows (loop condition), so the + // pointer is in-bounds. + let pr_tile = scratch.as_mut_ptr().add(rows_done); + + // Convert A tile from storage layout to kernel layout. + let pa_tile = ca.convert(&mut a_buf, arch, pa_tile_src, tile_rows, k); + let pa_tile_end = pa_tile.add(tile_rows * k); + + let mut pb_tile_src = b_ptr; + + // Loop 2: Full B-tiles (every panel in the tile is complete). + // SAFETY: `pb_tile_src` is always in `[b_ptr, pb_full_end]` — both within + // the same allocation — so `offset_from` is well-defined. + while pb_full_end.offset_from(pb_tile_src) >= b_src_tile_stride as isize { + // Convert B tile from storage layout to kernel layout. + let pb_tile = cb.convert(&mut b_buf, arch, pb_tile_src, b_tile_rows, k); + let pb_tile_end = pb_tile.add(b_tile_rows * k); + + let mut pa_panel = pa_tile; + let mut pr_panel = pr_tile; + + // Loop 3: Micro-panels of `A`. + while pa_panel < pa_tile_end { + let mut pb_panel = pb_tile; + + // Loop 4: Micro-panels of `B` (all full, no remainder check). + while pb_panel < pb_tile_end { + K::full_panel(arch, pa_panel, pb_panel, k, pr_panel); + pb_panel = pb_panel.add(b_kern_panel_stride); + } + + pa_panel = pa_panel.add(a_kern_panel_stride); + pr_panel = pr_panel.add(K::A_PANEL); + } + pb_tile_src = pb_tile_src.add(b_src_tile_stride); + } + + // Peeled last B-tile: contains remaining full panels + remainder rows. 
+ if pb_tile_src < pb_end { + let remaining_b_rows = b_nrows - ((pb_tile_src.offset_from(b_ptr) as usize) / k); + // Convert remaining B rows. + let pb_tile = cb.convert(&mut b_buf, arch, pb_tile_src, remaining_b_rows, k); + + let full_panels_in_remainder = remaining_b_rows / K::B_PANEL; + let pb_full_end_local = pb_tile.add(full_panels_in_remainder * b_kern_panel_stride); + + let mut pa_panel = pa_tile; + let mut pr_panel = pr_tile; + + // Loop 3 (peeled): Micro-panels of `A`. + while pa_panel < pa_tile_end { + let mut pb_panel = pb_tile; + + // Loop 4 (peeled): Full B-panels in the last tile. + while pb_panel < pb_full_end_local { + K::full_panel(arch, pa_panel, pb_panel, k, pr_panel); + pb_panel = pb_panel.add(b_kern_panel_stride); + } + + // Remainder dispatch: 1..=B_PANEL-1 leftover B-rows. + if b_remainder > 0 { + K::partial_panel(arch, b_remainder, pa_panel, pb_panel, k, pr_panel); + } + + pa_panel = pa_panel.add(a_kern_panel_stride); + pr_panel = pr_panel.add(K::A_PANEL); + } + } + + rows_done += tile_rows; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use diskann_wide::arch::Scalar; + + use super::super::f32::{F32Kernel, max_ip_kernel}; + use super::super::layouts; + use crate::multi_vector::{BlockTransposed, MatRef, Standard}; + + #[test] + fn basic_panel_counts() { + // 16 A-rows × 256 bytes/row = 4096 bytes per A-panel. + // L2 budget 40960 → 40960 / 4096 = 10 A-panels. + // One A-panel = 4096 bytes, L1 budget 36000 → 36000 - 4096 = 31904. + // 4 B-rows × 256 bytes/row = 1024 bytes per B-panel. + // 31904 / 1024 = 31 B-panels. + let plan = FullReduce::new( + 256, + 256, + 16, + 4, + TileBudget { + l2_a: 40960, + l1_b: 36000, + }, + ); + assert_eq!(plan.a_panels_per_tile, 10); + assert_eq!(plan.b_panels_per_tile, 31); + } + + #[test] + fn tiny_budget_clamps_to_one() { + // Budget too small for even one panel — clamp to 1.
+ let plan = FullReduce::new(1024, 1024, 16, 4, TileBudget { l2_a: 1, l1_b: 1 }); + assert_eq!(plan.a_panels_per_tile, 1); + assert_eq!(plan.b_panels_per_tile, 1); + } + + #[test] + fn zero_byte_rows_clamped() { + // Zero-byte rows (e.g. k=0) should not divide by zero. + // FullReduce clamps row bytes to max(1), so a_row_bytes=1, b_row_bytes=1. + let plan = FullReduce::new( + 0, + 0, + 16, + 4, + TileBudget { + l2_a: 100_000, + l1_b: 50_000, + }, + ); + // a_panels = 100_000 / (1 * 16) = 6250 + assert_eq!(plan.a_panels_per_tile, 6250); + // a_panel_bytes = 16 * 1 = 16. b_tile_budget = 50_000 - 16 = 49_984. + // b_panels = 49_984 / (1 * 4) = 12_496 + assert_eq!(plan.b_panels_per_tile, 12_496); + } + + #[test] + fn exact_fit_one_panel() { + // Budget exactly fits one A-panel (16 × 64 = 1024 bytes). + // No room for a second → a_panels = 1. + let plan = FullReduce::new( + 64, + 64, + 16, + 4, + TileBudget { + l2_a: 1024, + l1_b: 2048, + }, + ); + assert_eq!(plan.a_panels_per_tile, 1); + // L1: 2048 - 16*64(=1024) = 1024 for B. 4*64=256 per B-panel → 4 panels. + assert_eq!(plan.b_panels_per_tile, 4); + } + + #[test] + fn l1_saturated_by_a_panel() { + // A-panel alone exceeds L1 budget → b_tile_budget saturates to 0, + // b_panels_per_tile clamps to 1. + let plan = FullReduce::new( + 1024, + 64, + 16, + 4, + TileBudget { + l2_a: 100_000, + l1_b: 100, + }, + ); + assert_eq!(plan.b_panels_per_tile, 1); + } + + #[test] + #[should_panic(expected = "must be a multiple of A_PANEL")] + fn panics_on_unaligned_a_rows() { + let k = 4; + // 9 is not a multiple of A_PANEL (8). + let a = vec![0.0f32; 9 * k]; + let b = vec![0.0f32; 2 * k]; + let mut scratch = vec![f32::MIN; 16]; + + let ca = layouts::BlockTransposed::::new(); + let cb = layouts::RowMajor::::new(); + + // SAFETY: pointers and scratch are correctly sized; we expect a panic. 
+ unsafe { + super::tiled_reduce::, _, _>( + Scalar::new(), + &ca, + &cb, + a.as_ptr(), + 9, + b.as_ptr(), + 2, + k, + &mut scratch, + TileBudget::default(), + ); + } + } + + #[test] + fn zero_dim_fills_scratch_and_returns() { + let a_rows = 8; + let b_rows = 3; + let k = 0; + + let a = Vec::::new(); + let b = Vec::::new(); + let mut scratch = vec![f32::MIN; a_rows]; + + let ca = layouts::BlockTransposed::::new(); + let cb = layouts::RowMajor::::new(); + + // SAFETY: k == 0 so no elements are read; pointers are never dereferenced. + unsafe { + super::tiled_reduce::, _, _>( + Scalar::new(), + &ca, + &cb, + a.as_ptr(), + a_rows, + b.as_ptr(), + b_rows, + k, + &mut scratch, + TileBudget::default(), + ); + } + + for &v in &scratch { + assert_eq!(v, 0.0, "zero-dim IP should be 0.0"); + } + } + + #[test] + fn zero_dim_zero_docs_leaves_scratch_untouched() { + let a_rows = 8; + let mut scratch = vec![f32::MIN; a_rows]; + + let ca = layouts::BlockTransposed::::new(); + let cb = layouts::RowMajor::::new(); + + // SAFETY: k == 0, b_nrows == 0; no elements read. + unsafe { + super::tiled_reduce::, _, _>( + Scalar::new(), + &ca, + &cb, + [].as_ptr(), + a_rows, + [].as_ptr(), + 0, + 0, + &mut scratch, + TileBudget::default(), + ); + } + + for &v in &scratch { + assert_eq!(v, f32::MIN, "zero docs should leave scratch untouched"); + } + } + + // Shared shape matrix for the `tiled_reduce_*_matches_naive` tests. + // Sized to exercise degenerate, prime-`k`, exact-`A_PANEL`, off-by-one + // `A_PANEL`, multi-A-tile, and every B-row remainder class. 
+ // + // Differs from `query_computer::tests::TEST_CASES` (the end-to-end + // shape matrix) by the inclusion of `(3, 2, 0)` and `(3, 0, 4)` — + // zero-`k` and zero-`b_nrows` are kernel-internal early-exit edges + // not relevant to the public `QueryComputer` API surface (which has + // dedicated `chamfer_with_zero_docs` / `max_sim_with_zero_docs` + // tests asserting different contracts) — and by `(8, 3, 4)` / + // `(16, 5, 8)`, which are Scalar-panel arithmetic edges that are + // already covered at the public layer by other shapes that cross the + // same boundaries. + // + // (a_nrows, b_nrows, dim) + const NAIVE_CASES: &[(usize, usize, usize)] = &[ + (1, 1, 1), // Degenerate single-element + (1, 1, 2), // Minimal non-trivial + (1, 1, 4), // Single query, single doc + (1, 5, 8), // Single query, multiple docs + (5, 1, 8), // Multiple queries, single doc + (3, 2, 0), // Zero dimensions, both have rows + (3, 0, 4), // Zero docs + (3, 2, 3), // Prime k + (3, 4, 16), // General case + (5, 3, 5), // Prime k, A-remainder on aarch64 + (7, 7, 32), // Square case + (2, 3, 7), // k not divisible by SIMD lanes + (2, 3, 128), // Larger dimension + (8, 3, 4), // Single A-panel (Scalar), B remainder + (16, 5, 8), // Two A-panels (Scalar), B remainder + (16, 4, 64), // Two A-panels (Scalar), no B remainder; one (V3) + (17, 4, 64), // A-panel remainder on both Scalar and V3 + (32, 5, 16), // Multiple full A-panels, B remainder + (48, 3, 16), // 6 A-panels (Scalar) / 3 (V3) + (16, 6, 32), // V3 B remainder=2 + (16, 7, 32), // V3 B remainder=3 + (16, 8, 32), // No B remainder on either + ]; + + // Two budgets: `default` exercises the peeled tile section only; `tiny` + // forces `a_panels_per_tile=1` and `b_panels_per_tile=1`, which makes + // the main loop body (Loop 2) and multiple A-tile iterations (Loop 1) + // run for every shape. 
+ fn naive_budgets() -> [(&'static str, TileBudget); 2] { + [ + ("default", TileBudget::default()), + ("tiny", TileBudget { l2_a: 1, l1_b: 1 }), + ] + } + + fn naive_max_ip_f32( + a: &[f32], + a_nrows: usize, + b: &[f32], + b_nrows: usize, + k: usize, + ) -> Vec { + (0..a_nrows) + .map(|i| { + (0..b_nrows) + .map(|j| (0..k).map(|d| a[i * k + d] * b[j * k + d]).sum::()) + .fold(f32::MIN, f32::max) + }) + .collect() + } + + fn naive_max_ip_f16( + a: &[half::f16], + a_nrows: usize, + b: &[half::f16], + b_nrows: usize, + k: usize, + ) -> Vec { + (0..a_nrows) + .map(|i| { + (0..b_nrows) + .map(|j| { + (0..k) + .map(|d| a[i * k + d].to_f32() * b[j * k + d].to_f32()) + .sum::() + }) + .fold(f32::MIN, f32::max) + }) + .collect() + } + + /// Run `max_ip_kernel::` against the naive reference for + /// every budget in `naive_budgets()` for one shape, asserting per-row + /// agreement within `tol`. `arch_label` is included in failure + /// messages to identify which arch branch tripped. + #[allow(clippy::too_many_arguments)] + fn check_kernel( + arch: A, + arch_label: &str, + tol: f32, + a_data: &[T], + a_nrows: usize, + b_data: &[T], + b_nrows: usize, + dim: usize, + expected: &[f32], + ) where + A: Architecture, + T: Copy + Default, + F32Kernel: Kernel, + layouts::BlockTransposed: + ConvertTo as Kernel>::Left> + Layout, + layouts::RowMajor: + ConvertTo as Kernel>::Right> + Layout, + { + for &(budget_label, budget) in &naive_budgets() { + let a_mat = MatRef::new(Standard::new(a_nrows, dim).unwrap(), a_data).unwrap(); + let a_bt = BlockTransposed::::from_matrix_view(a_mat.as_matrix_view()); + let b_mat = MatRef::new(Standard::new(b_nrows, dim).unwrap(), b_data).unwrap(); + + let mut scratch = vec![f32::MIN; a_bt.padded_nrows()]; + max_ip_kernel::(arch, a_bt.as_view(), b_mat, &mut scratch, budget); + + for i in 0..a_nrows { + let actual = scratch[i]; + let exp = expected[i]; + assert!( + (actual - exp).abs() < tol, + "[{arch_label}] row {i} mismatch for 
({a_nrows},{b_nrows},{dim}) budget={budget_label}: actual={actual}, expected={exp}", + ); + } + } + } + + /// Exercise the f32 micro-kernels (`F32Kernel<8>` Scalar and, on + /// x86_64 hosts with AVX2+FMA, `F32Kernel<16>` V3) through + /// `tiled_reduce` with both `default` and `tiny` budgets. + /// + /// The `tiny` budget combined with the `NAIVE_CASES` matrix is the + /// only place that drives Loop 1 / Loop 2 of the tiling nest for + /// each registered f32 micro-kernel — the `QueryComputer`-based + /// tests always use the production cache budget and so never enter + /// those loops. + /// + /// The V3 branch compiles on all targets but only executes on + /// x86_64 hosts that expose AVX2+FMA at runtime; silently skips + /// otherwise. CI's native x86_64 runners and `sde-avx512-tests` + /// (Sapphire Rapids ⊇ V3) cover this path. + #[test] + fn tiled_reduce_f32_matches_naive() { + for &(a_nrows, b_nrows, dim) in NAIVE_CASES { + let a_data: Vec = (0..a_nrows * dim).map(|i| (i + 1) as f32).collect(); + let b_data: Vec = (0..b_nrows * dim).map(|i| ((i + 1) * 2) as f32).collect(); + let expected = naive_max_ip_f32(&a_data, a_nrows, &b_data, b_nrows, dim); + + check_kernel::<_, f32, 8>( + Scalar::new(), + "scalar", + 1e-6, + &a_data, + a_nrows, + &b_data, + b_nrows, + dim, + &expected, + ); + + #[cfg(target_arch = "x86_64")] + if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() { + check_kernel::<_, f32, 16>( + arch, + "x86-64-v3", + 1e-6, + &a_data, + a_nrows, + &b_data, + b_nrows, + dim, + &expected, + ); + } + } + } + + /// Exercise the f16 path (`F16Entry` via `F32Kernel` + `ConvertTo`) + /// through `tiled_reduce` with both `default` and `tiny` budgets. + /// + /// Combined with `tiny`, this drives the per-tile f16→f32 conversion + /// buffer through Loop 1 / Loop 2 of the tiling nest, validating + /// buffer reuse across multiple tiles. The V3 branch additionally + /// covers the V3-width conversion on x86_64 hosts with AVX2+FMA. 
+ #[test] + fn tiled_reduce_f16_matches_naive() { + for &(a_nrows, b_nrows, dim) in NAIVE_CASES { + // Use a small ceil so values stay exactly representable in f16 + // (bit-exact agreement with the f32 naive reference). + let ceil = dim.max(1); + let a_data: Vec = (0..a_nrows * dim) + .map(|i| diskann_wide::cast_f32_to_f16(((i + 1) % ceil) as f32)) + .collect(); + let b_data: Vec = (0..b_nrows * dim) + .map(|i| diskann_wide::cast_f32_to_f16((((i + 1) * 2) % ceil) as f32)) + .collect(); + let expected = naive_max_ip_f16(&a_data, a_nrows, &b_data, b_nrows, dim); + + check_kernel::<_, half::f16, 8>( + Scalar::new(), + "scalar", + 1e-1, + &a_data, + a_nrows, + &b_data, + b_nrows, + dim, + &expected, + ); + + #[cfg(target_arch = "x86_64")] + if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() { + check_kernel::<_, half::f16, 16>( + arch, + "x86-64-v3", + 1e-1, + &a_data, + a_nrows, + &b_data, + b_nrows, + dim, + &expected, + ); + } + } + } +} diff --git a/diskann-quantization/src/multi_vector/distance/max_sim.rs b/diskann-quantization/src/multi_vector/distance/max_sim.rs index b9d9a3e95..9ac2b0ed1 100644 --- a/diskann-quantization/src/multi_vector/distance/max_sim.rs +++ b/diskann-quantization/src/multi_vector/distance/max_sim.rs @@ -12,7 +12,7 @@ pub enum MaxSimError { IndexOutOfBounds(usize, usize), #[error("Scores buffer length cannot be 0")] BufferLengthIsZero, - #[error("Invalid buffer length {0} for query size {0}")] + #[error("Invalid buffer length {0} for query size {1}")] InvalidBufferLength(usize, usize), } @@ -22,8 +22,8 @@ pub enum MaxSimError { /// Computes per-query-vector maximum similarities to document vectors. 
/// -/// For each query vector `qᵢ`, finds the maximum similarity (minimum negated -/// inner product) to any document vector: +/// For each query vector `qᵢ`, computes the negated maximum inner product +/// to any document vector: /// /// ```text /// scores[i] = minⱼ -IP(qᵢ, dⱼ) diff --git a/diskann-quantization/src/multi_vector/distance/mod.rs b/diskann-quantization/src/multi_vector/distance/mod.rs index e0b106548..853f60753 100644 --- a/diskann-quantization/src/multi_vector/distance/mod.rs +++ b/diskann-quantization/src/multi_vector/distance/mod.rs @@ -7,9 +7,13 @@ //! //! - [`MaxSim`]: Per-query-vector maximum similarities. //! - [`Chamfer`]: Sum of MaxSim scores (asymmetric Chamfer distance). +//! - [`QueryComputer`]: Architecture-dispatched query computer backed by +//! SIMD-accelerated block-transposed kernels. //! -//! Both are currently implemented using a simple double-loop kernel over -//! [`InnerProduct`](diskann_vector::distance::InnerProduct). +//! The fallback path uses a double-loop kernel over +//! [`InnerProduct`](diskann_vector::distance::InnerProduct). The optimised +//! path (via [`QueryComputer`]) uses block-transposed layout with +//! cache-tiled SIMD micro-kernels. //! //! # Example //! @@ -41,12 +45,15 @@ //! let mut scores = vec![0.0f32; 2]; //! let mut max_sim = MaxSim::new(&mut scores).unwrap(); //! max_sim.evaluate(query, doc); -//! // scores[0] = -1.0 (query[0] matches doc[0]) -//! // scores[1] = 0.0 (query[1] has no good match) +//! // scores[0] = -1.0 (query[0] matches doc[0]: negated max inner product) +//! // scores[1] = 0.0 (query[1] has no good match: max IP was 0) //! 
``` +mod fallback; +mod kernels; mod max_sim; -mod simple; +mod query_computer; +pub use fallback::QueryMatRef; pub use max_sim::{Chamfer, MaxSim, MaxSimError}; -pub use simple::QueryMatRef; +pub use query_computer::QueryComputer; diff --git a/diskann-quantization/src/multi_vector/distance/query_computer/f16.rs b/diskann-quantization/src/multi_vector/distance/query_computer/f16.rs new file mode 100644 index 000000000..e020cf04c --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/query_computer/f16.rs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +use diskann_wide::Architecture; +use diskann_wide::arch::Scalar; +#[cfg(target_arch = "aarch64")] +use diskann_wide::arch::aarch64::Neon; +#[cfg(target_arch = "x86_64")] +use diskann_wide::arch::x86_64::{V3, V4}; + +use super::{DynQueryComputer, Prepared, QueryComputer, build_prepared}; +use crate::multi_vector::distance::kernels::f16::F16Entry; +use crate::multi_vector::{BlockTransposed, BlockTransposedRef, MatRef, Standard}; +use diskann_utils::Reborrow; + +impl QueryComputer { + /// Build an f16 query computer, selecting the optimal architecture and + /// GROUP for the current CPU at runtime. 
+ pub fn new(query: MatRef<'_, Standard>) -> Self { + diskann_wide::arch::dispatch1_no_features(BuildComputer, query) + } +} + +impl DynQueryComputer + for Prepared> +where + A: Architecture, + F16Entry: for<'a> diskann_wide::arch::Target3< + A, + (), + BlockTransposedRef<'a, half::f16, GROUP>, + MatRef<'a, Standard>, + &'a mut [f32], + >, +{ + fn compute_max_sim(&self, doc: MatRef<'_, Standard>, scores: &mut [f32]) { + let mut scratch = vec![f32::MIN; self.prepared.padded_nrows()]; + self.arch.run3( + F16Entry::, + self.prepared.reborrow(), + doc, + &mut scratch, + ); + for (dst, &src) in scores.iter_mut().zip(&scratch[..self.prepared.nrows()]) { + *dst = -src; + } + } + + fn nrows(&self) -> usize { + self.prepared.nrows() + } +} + +#[derive(Clone, Copy)] +pub(super) struct BuildComputer; + +impl diskann_wide::arch::Target1, MatRef<'_, Standard>> + for BuildComputer +{ + fn run(self, arch: Scalar, query: MatRef<'_, Standard>) -> QueryComputer { + QueryComputer { + inner: Box::new(build_prepared::(arch, query)), + } + } +} + +#[cfg(target_arch = "x86_64")] +impl diskann_wide::arch::Target1, MatRef<'_, Standard>> + for BuildComputer +{ + fn run(self, arch: V3, query: MatRef<'_, Standard>) -> QueryComputer { + QueryComputer { + inner: Box::new(build_prepared::(arch, query)), + } + } +} + +#[cfg(target_arch = "x86_64")] +impl diskann_wide::arch::Target1, MatRef<'_, Standard>> + for BuildComputer +{ + fn run(self, arch: V4, query: MatRef<'_, Standard>) -> QueryComputer { + let arch = arch.retarget(); + QueryComputer { + inner: Box::new(build_prepared::(arch, query)), + } + } +} + +#[cfg(target_arch = "aarch64")] +impl diskann_wide::arch::Target1, MatRef<'_, Standard>> + for BuildComputer +{ + fn run(self, arch: Neon, query: MatRef<'_, Standard>) -> QueryComputer { + let arch = arch.retarget(); + QueryComputer { + inner: Box::new(build_prepared::(arch, query)), + } + } +} diff --git a/diskann-quantization/src/multi_vector/distance/query_computer/f32.rs 
b/diskann-quantization/src/multi_vector/distance/query_computer/f32.rs new file mode 100644 index 000000000..b3dd536a6 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/query_computer/f32.rs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +use diskann_wide::Architecture; +use diskann_wide::arch::Scalar; +#[cfg(target_arch = "aarch64")] +use diskann_wide::arch::aarch64::Neon; +#[cfg(target_arch = "x86_64")] +use diskann_wide::arch::x86_64::{V3, V4}; + +use super::{DynQueryComputer, Prepared, QueryComputer, build_prepared}; +use crate::multi_vector::distance::kernels::f32::F32Kernel; +use crate::multi_vector::{BlockTransposed, BlockTransposedRef, MatRef, Standard}; +use diskann_utils::Reborrow; + +impl QueryComputer { + /// Build an f32 query computer, selecting the optimal architecture and + /// GROUP for the current CPU at runtime. + pub fn new(query: MatRef<'_, Standard>) -> Self { + diskann_wide::arch::dispatch1_no_features(BuildComputer, query) + } +} + +impl DynQueryComputer for Prepared> +where + A: Architecture, + F32Kernel: for<'a> diskann_wide::arch::Target3< + A, + (), + BlockTransposedRef<'a, f32, GROUP>, + MatRef<'a, Standard>, + &'a mut [f32], + >, +{ + fn compute_max_sim(&self, doc: MatRef<'_, Standard>, scores: &mut [f32]) { + let mut scratch = vec![f32::MIN; self.prepared.padded_nrows()]; + self.arch.run3( + F32Kernel::, + self.prepared.reborrow(), + doc, + &mut scratch, + ); + for (dst, &src) in scores.iter_mut().zip(&scratch[..self.prepared.nrows()]) { + *dst = -src; + } + } + + fn nrows(&self) -> usize { + self.prepared.nrows() + } +} + +#[derive(Clone, Copy)] +pub(super) struct BuildComputer; + +impl diskann_wide::arch::Target1, MatRef<'_, Standard>> + for BuildComputer +{ + fn run(self, arch: Scalar, query: MatRef<'_, Standard>) -> QueryComputer { + QueryComputer { + inner: Box::new(build_prepared::(arch, query)), + } + } +} + +#[cfg(target_arch = "x86_64")] 
+impl diskann_wide::arch::Target1, MatRef<'_, Standard>> + for BuildComputer +{ + fn run(self, arch: V3, query: MatRef<'_, Standard>) -> QueryComputer { + QueryComputer { + inner: Box::new(build_prepared::(arch, query)), + } + } +} + +#[cfg(target_arch = "x86_64")] +impl diskann_wide::arch::Target1, MatRef<'_, Standard>> + for BuildComputer +{ + fn run(self, arch: V4, query: MatRef<'_, Standard>) -> QueryComputer { + // V4 delegates to V3 — the V3 micro-kernel is valid on V4 hardware. + let arch = arch.retarget(); + QueryComputer { + inner: Box::new(build_prepared::(arch, query)), + } + } +} + +#[cfg(target_arch = "aarch64")] +impl diskann_wide::arch::Target1, MatRef<'_, Standard>> + for BuildComputer +{ + fn run(self, arch: Neon, query: MatRef<'_, Standard>) -> QueryComputer { + // Neon delegates to Scalar. + let arch = arch.retarget(); + QueryComputer { + inner: Box::new(build_prepared::(arch, query)), + } + } +} diff --git a/diskann-quantization/src/multi_vector/distance/query_computer/mod.rs b/diskann-quantization/src/multi_vector/distance/query_computer/mod.rs new file mode 100644 index 000000000..3632a214e --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/query_computer/mod.rs @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Architecture-opaque query computer with runtime dispatch. +//! +//! [`QueryComputer`] wraps a block-transposed query and a captured +//! architecture token behind a trait-object vtable. CPU detection happens +//! once at construction; every subsequent distance call goes through +//! [`Architecture::run3`](diskann_wide::Architecture::run3) with full +//! `#[target_feature]` propagation — no re-dispatch and no enum matching +//! on the hot path. +//! +//! # Usage +//! +//! ``` +//! use diskann_quantization::multi_vector::{ +//! QueryComputer, MatRef, Standard, Chamfer, +//! }; +//! use diskann_vector::PureDistanceFunction; +//! +//! 
let query_data = [1.0f32, 0.0, 0.0, 1.0]; +//! let doc_data = [1.0f32, 0.0, 0.0, 1.0]; +//! +//! let query = MatRef::new(Standard::new(2, 2).unwrap(), &query_data).unwrap(); +//! let doc = MatRef::new(Standard::new(2, 2).unwrap(), &doc_data).unwrap(); +//! +//! // Build — runtime detects arch, picks optimal GROUP, captures both +//! let computer = QueryComputer::::new(query); +//! +//! // Distance — vtable → arch.run3 with target_feature propagation +//! let dist = Chamfer::evaluate(&computer, doc); +//! assert_eq!(dist, -2.0); +//! ``` + +mod f16; +mod f32; + +use diskann_vector::{DistanceFunctionMut, PureDistanceFunction}; + +use crate::multi_vector::{BlockTransposed, MatRef, Standard}; + +use super::max_sim::{Chamfer, MaxSim}; + +/// Architecture-dispatched query computer for multi-vector distance. +pub struct QueryComputer { + inner: Box>, +} + +impl QueryComputer { + /// Number of logical (non-padded) query vectors. + #[inline] + pub fn nrows(&self) -> usize { + self.inner.nrows() + } + + /// Compute Chamfer distance (sum of per-query max similarities, negated). + /// + /// Returns `0.0` if the document has zero vectors. + pub fn chamfer(&self, doc: MatRef<'_, Standard>) -> f32 { + let nq = self.nrows(); + if doc.num_vectors() == 0 { + return 0.0; + } + let mut scores = vec![0.0f32; nq]; + self.max_sim(doc, &mut scores); + scores.iter().sum() + } + + /// Compute per-query-vector max similarities into `scores`. + /// + /// `scores` must have length equal to [`nrows()`](Self::nrows). + /// Each entry is the negated max inner product for that query vector. + /// + /// # Panics + /// + /// Panics if `scores.len() != self.nrows()`. 
+ pub fn max_sim(&self, doc: MatRef<'_, Standard>, scores: &mut [f32]) { + let nq = self.nrows(); + assert_eq!( + scores.len(), + nq, + "scores buffer not right size: {} != {}", + scores.len(), + nq + ); + + if doc.num_vectors() == 0 { + return; + } + + self.inner.compute_max_sim(doc, scores); + } +} + +trait DynQueryComputer { + fn compute_max_sim(&self, doc: MatRef<'_, Standard>, scores: &mut [f32]); + fn nrows(&self) -> usize; +} + +struct Prepared { + arch: A, + prepared: Q, +} + +fn build_prepared( + arch: A, + query: MatRef<'_, Standard>, +) -> Prepared> { + let prepared = BlockTransposed::::from_matrix_view(query.as_matrix_view()); + Prepared { arch, prepared } +} + +impl PureDistanceFunction<&QueryComputer, MatRef<'_, Standard>, f32> for Chamfer { + fn evaluate(query: &QueryComputer, doc: MatRef<'_, Standard>) -> f32 { + query.chamfer(doc) + } +} + +impl DistanceFunctionMut<&QueryComputer, MatRef<'_, Standard>> for MaxSim<'_> { + fn evaluate(&mut self, query: &QueryComputer, doc: MatRef<'_, Standard>) { + query.max_sim(doc, self.scores_mut()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::multi_vector::{Chamfer, MaxSim, QueryMatRef}; + use diskann_vector::distance::InnerProduct; + use diskann_vector::{DistanceFunctionMut, PureDistanceFunction}; + + trait FromF32 { + fn from_f32(v: f32) -> Self; + } + + impl FromF32 for f32 { + fn from_f32(v: f32) -> Self { + v + } + } + + impl FromF32 for half::f16 { + fn from_f32(v: f32) -> Self { + diskann_wide::cast_f32_to_f16(v) + } + } + + fn make_mat(data: &[T], nrows: usize, ncols: usize) -> MatRef<'_, Standard> { + MatRef::new(Standard::new(nrows, ncols).unwrap(), data).unwrap() + } + + fn make_test_data(len: usize, ceil: usize, shift: usize) -> Vec { + (0..len) + .map(|v| T::from_f32(((v + shift) % ceil) as f32)) + .collect() + } + + /// Test cases: (num_queries, num_docs, dim). 
+ /// + /// Sized to exercise: + /// * degenerate single-element shapes, + /// * `k` (dim) not divisible by SIMD lane count, + /// * exact and off-by-one A_PANEL boundaries on both `GROUP=8` (Scalar/Neon) + /// and `GROUP=16` (V3/V4) configurations, + /// * every B-row remainder class for the active `B_PANEL` (1, 2, 3 on V3; + /// 1 on Scalar). + /// + /// Diverges from `kernels::tiled_reduce::tests::NAIVE_CASES`: the + /// kernel-level matrix additionally covers zero-`k` / zero-`b_nrows` + /// (kernel internal early-exit edges, with no public-API meaning — + /// the API contracts for empty docs are pinned by the dedicated + /// `chamfer_with_zero_docs` / `max_sim_with_zero_docs` tests) and a + /// pair of Scalar-panel arithmetic edges already crossed by the + /// shapes below. + const TEST_CASES: &[(usize, usize, usize)] = &[ + (1, 1, 1), // Degenerate single-element + (1, 1, 2), // Minimal non-trivial + (1, 1, 4), // Single query, single doc + (1, 5, 8), // Single query, multiple docs + (5, 1, 8), // Multiple queries, single doc + (3, 2, 3), // Prime k + (3, 4, 16), // General case + (5, 3, 5), // Prime k, A-remainder on aarch64 + (7, 7, 32), // Square case + (2, 3, 7), // k not divisible by SIMD lanes + (2, 3, 128), // Larger dimension + (16, 4, 64), // Exact A_PANEL on x86_64; two panels on aarch64 + (17, 4, 64), // One more than A_PANEL (remainder) + (32, 5, 16), // Multiple full A-panels, remainder B-rows (5 % 4 = 1) + (48, 3, 16), // 3 A-tiles on x86_64; 6 on aarch64 + (16, 6, 32), // Remainder B-rows (6 % 4 = 2) + (16, 7, 32), // Remainder B-rows (7 % 4 = 3) + (16, 8, 32), // No remainder B-rows (8 % 4 = 0) + ]; + + fn check_chamfer_matches( + build: fn(MatRef<'_, Standard>) -> QueryComputer, + tol: f32, + label: &str, + ) where + InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>, + { + for &(nq, nd, dim) in TEST_CASES { + let query_data = make_test_data::(nq * dim, dim, dim / 2); + let doc_data = make_test_data::(nd * dim, dim, dim); + 
+ let query = make_mat(&query_data, nq, dim); + let doc = make_mat(&doc_data, nd, dim); + + let expected = Chamfer::evaluate(QueryMatRef::from(query), doc); + let actual = build(query).chamfer(doc); + + assert!( + (actual - expected).abs() < tol, + "{label}Chamfer mismatch for ({nq},{nd},{dim}): actual={actual}, expected={expected}", + ); + } + } + + fn check_max_sim_matches( + build: fn(MatRef<'_, Standard>) -> QueryComputer, + tol: f32, + label: &str, + ) where + InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>, + { + for &(nq, nd, dim) in TEST_CASES { + let query_data = make_test_data::(nq * dim, dim, dim / 2); + let doc_data = make_test_data::(nd * dim, dim, dim); + + let query = make_mat(&query_data, nq, dim); + let doc = make_mat(&doc_data, nd, dim); + + let mut expected_scores = vec![0.0f32; nq]; + let _ = MaxSim::new(&mut expected_scores) + .unwrap() + .evaluate(QueryMatRef::from(query), doc); + + let computer = build(query); + let mut actual_scores = vec![0.0f32; nq]; + computer.max_sim(doc, &mut actual_scores); + + for i in 0..nq { + assert!( + (actual_scores[i] - expected_scores[i]).abs() < tol, + "{label}MaxSim[{i}] mismatch for ({nq},{nd},{dim}): actual={}, expected={}", + actual_scores[i], + expected_scores[i], + ); + } + } + } + + #[test] + fn query_computer_dimensions() { + let data = vec![1.0f32; 5 * 8]; + let query = make_mat(&data, 5, 8); + let computer = QueryComputer::::new(query); + + assert_eq!(computer.nrows(), 5); + } + + #[test] + fn query_computer_f16_dimensions() { + let data = vec![diskann_wide::cast_f32_to_f16(1.0); 5 * 8]; + let query = make_mat(data.as_slice(), 5, 8); + let computer = QueryComputer::::new(query); + + assert_eq!(computer.nrows(), 5); + } + + #[test] + fn chamfer_with_zero_docs() { + let query = make_mat(&[1.0f32, 0.0, 0.0, 1.0], 2, 2); + let computer = QueryComputer::::new(query); + let doc = make_mat(&[], 0, 2); + assert_eq!(computer.chamfer(doc), 0.0); + } + + #[test] + fn 
max_sim_with_zero_docs() { + let query = make_mat(&[1.0f32, 0.0, 0.0, 1.0], 2, 2); + let computer = QueryComputer::::new(query); + let doc = make_mat::(&[], 0, 2); + let mut scores = vec![0.0f32; 2]; + computer.max_sim(doc, &mut scores); + // With zero docs the scores buffer is left untouched. + for &s in &scores { + assert_eq!(s, 0.0, "zero-doc MaxSim should leave scores untouched"); + } + } + + #[test] + #[should_panic(expected = "scores buffer not right size")] + fn max_sim_panics_on_size_mismatch() { + let query = make_mat(&[1.0f32, 2.0, 3.0, 4.0], 2, 2); + let computer = QueryComputer::::new(query); + let doc = make_mat(&[1.0, 1.0], 1, 2); + let mut scores = vec![0.0f32; 3]; // Wrong size + computer.max_sim(doc, &mut scores); + } + + macro_rules! test_matches_fallback { + ($mod_name:ident, $ty:ty, $tol:expr, $label:literal) => { + mod $mod_name { + use super::*; + + #[test] + fn chamfer_matches_fallback() { + check_chamfer_matches(QueryComputer::<$ty>::new, $tol, $label); + } + + #[test] + fn max_sim_matches_fallback() { + check_max_sim_matches(QueryComputer::<$ty>::new, $tol, $label); + } + } + }; + } + + test_matches_fallback!(f32, f32, 1e-2, "f32 "); + test_matches_fallback!(f16, half::f16, 1e-1, "f16 "); +} diff --git a/diskann-quantization/src/multi_vector/matrix.rs b/diskann-quantization/src/multi_vector/matrix.rs index 65a27b711..70629d44c 100644 --- a/diskann-quantization/src/multi_vector/matrix.rs +++ b/diskann-quantization/src/multi_vector/matrix.rs @@ -29,7 +29,7 @@ use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull}; -use diskann_utils::{Reborrow, ReborrowMut}; +use diskann_utils::{Reborrow, ReborrowMut, views::MatrixView}; use thiserror::Error; use crate::utils; @@ -717,6 +717,20 @@ impl Mat> { pub fn vector_dim(&self) -> usize { self.repr.ncols() } + + /// Return the backing data as a contiguous slice of `T`. + /// + /// The returned slice has `num_vectors() * vector_dim()` elements in row-major order. 
+ #[inline] + pub fn as_slice(&self) -> &[T] { + self.as_view().as_slice() + } + + /// Return a [`MatrixView`] over the backing data. + #[inline] + pub fn as_matrix_view(&self) -> MatrixView<'_, T> { + self.as_view().as_matrix_view() + } } //////////// @@ -832,6 +846,27 @@ impl<'a, T: Copy> MatRef<'a, Standard> { pub fn vector_dim(&self) -> usize { self.repr.ncols() } + + /// Return the backing data as a contiguous slice of `T`. + /// + /// The returned slice has `num_vectors() * vector_dim()` elements in row-major order. + #[inline] + pub fn as_slice(&self) -> &'a [T] { + let len = self.repr.num_elements(); + // SAFETY: `Standard` guarantees `nrows * ncols` contiguous `T` elements + // starting at `self.ptr`. The lifetime `'a` is tied to the original data. + unsafe { std::slice::from_raw_parts(self.ptr.as_ptr().cast::(), len) } + } + + /// Return a [`MatrixView`] over the backing data. + #[allow(clippy::expect_used)] + #[inline] + pub fn as_matrix_view(&self) -> MatrixView<'a, T> { + // `Standard::new` validates that `nrows * ncols` does not overflow, + // so `try_from` is infallible here. + MatrixView::try_from(self.as_slice(), self.num_vectors(), self.vector_dim()) + .expect("Standard has valid dimensions") + } } // Reborrow: Mat -> MatRef @@ -1042,6 +1077,20 @@ impl<'a, T: Copy> MatMut<'a, Standard> { pub fn vector_dim(&self) -> usize { self.repr.ncols() } + + /// Return the backing data as a contiguous slice of `T`. + /// + /// The returned slice has `num_vectors() * vector_dim()` elements in row-major order. + #[inline] + pub fn as_slice(&self) -> &[T] { + self.as_view().as_slice() + } + + /// Return a [`MatrixView`] over the backing data. 
+ #[inline] + pub fn as_matrix_view(&self) -> MatrixView<'_, T> { + self.as_view().as_matrix_view() + } } ////////// @@ -1775,4 +1824,52 @@ mod tests { }) )); } + + #[test] + fn as_matrix_view_roundtrip() { + let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + + // MatRef + let matref = MatRef::new(Standard::new(2, 3).unwrap(), &data).unwrap(); + let view = matref.as_matrix_view(); + assert_eq!(view.nrows(), 2); + assert_eq!(view.ncols(), 3); + for row in 0..2 { + for col in 0..3 { + assert_eq!(view[(row, col)], data[row * 3 + col]); + } + } + assert_eq!(matref.as_slice(), &data); + + // Mat + let mut mat = Mat::new(Standard::::new(2, 3).unwrap(), 0.0f32).unwrap(); + for i in 0..2 { + let r = mat.get_row_mut(i).unwrap(); + for j in 0..3 { + r[j] = data[i * 3 + j]; + } + } + let view = mat.as_matrix_view(); + assert_eq!(view.nrows(), 2); + assert_eq!(view.ncols(), 3); + for row in 0..2 { + for col in 0..3 { + assert_eq!(view[(row, col)], data[row * 3 + col]); + } + } + assert_eq!(mat.as_slice(), &data); + + // MatMut + let mut buf = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let matmut = MatMut::new(Standard::new(2, 3).unwrap(), &mut buf).unwrap(); + let view = matmut.as_matrix_view(); + assert_eq!(view.nrows(), 2); + assert_eq!(view.ncols(), 3); + for row in 0..2 { + for col in 0..3 { + assert_eq!(view[(row, col)], data[row * 3 + col]); + } + } + assert_eq!(matmut.as_slice(), &data); + } } diff --git a/diskann-quantization/src/multi_vector/mod.rs b/diskann-quantization/src/multi_vector/mod.rs index f8598eba0..3670b1aaf 100644 --- a/diskann-quantization/src/multi_vector/mod.rs +++ b/diskann-quantization/src/multi_vector/mod.rs @@ -20,6 +20,7 @@ //! | [`BlockTransposedRef`] | Immutable view of a block-transposed matrix | //! | [`BlockTransposedMut`] | Mutable view of a block-transposed matrix | //! | [`QueryMatRef`] | Query wrapper for asymmetric distances | +//! | [`QueryComputer`] | Architecture-dispatched SIMD query computer | //! 
| [`MaxSim`] | Per-query-vector max similarity computation | //! | [`Chamfer`] | Asymmetric Chamfer distance (sum of MaxSim) | //! @@ -71,7 +72,7 @@ pub mod distance; pub(crate) mod matrix; pub use block_transposed::{BlockTransposed, BlockTransposedMut, BlockTransposedRef}; -pub use distance::{Chamfer, MaxSim, MaxSimError, QueryMatRef}; +pub use distance::{Chamfer, MaxSim, MaxSimError, QueryComputer, QueryMatRef}; pub use matrix::{ Defaulted, LayoutError, Mat, MatMut, MatRef, NewCloned, NewMut, NewOwned, NewRef, Overflow, Repr, ReprMut, ReprOwned, SliceError, Standard,