Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion diskann-quantization/src/multi_vector/block_transposed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
use std::{alloc::Layout, marker::PhantomData, ptr::NonNull};

use diskann_utils::{
ReborrowMut,
Reborrow, ReborrowMut,
strided::StridedView,
views::{MatrixView, MutMatrixView},
};
Expand Down Expand Up @@ -231,6 +231,15 @@ impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRepr<T, GROU
self.nrows % GROUP
}

/// Number of row slots in the backing storage: the logical row count
/// rounded up to a whole number of `GROUP`-sized blocks.
///
/// Includes the zero-padded row slots of the final block when `nrows`
/// is not an exact multiple of `GROUP`.
#[inline]
pub fn padded_nrows(&self) -> usize {
    GROUP * self.num_blocks()
}

/// The stride (in elements) between the start of consecutive blocks.
#[inline]
fn block_stride(&self) -> usize {
Expand Down Expand Up @@ -743,6 +752,15 @@ impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRef<'a,
self.data.repr().remainder()
}

/// Number of row slots in the backing allocation — the logical row count
/// rounded up to the next multiple of `GROUP`, counting any zero-padded
/// rows in the trailing (possibly partial) block.
#[inline]
pub fn padded_nrows(&self) -> usize {
    let repr = self.data.repr();
    repr.padded_nrows()
}
Comment thread
suri-kumkaran marked this conversation as resolved.

/// Return a raw typed pointer to the start of the backing data.
#[inline]
pub fn as_ptr(&self) -> *const T {
Expand Down Expand Up @@ -870,6 +888,7 @@ impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedMut<'a,
delegate_to_ref!(pub fn full_blocks(&self) -> usize);
delegate_to_ref!(pub fn num_blocks(&self) -> usize);
delegate_to_ref!(pub fn remainder(&self) -> usize);
delegate_to_ref!(pub fn padded_nrows(&self) -> usize);
delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
Expand Down Expand Up @@ -1017,6 +1036,7 @@ impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, P
delegate_to_ref!(pub fn full_blocks(&self) -> usize);
delegate_to_ref!(pub fn num_blocks(&self) -> usize);
delegate_to_ref!(pub fn remainder(&self) -> usize);
delegate_to_ref!(pub fn padded_nrows(&self) -> usize);
delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
Expand Down Expand Up @@ -1072,6 +1092,19 @@ impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, P
}
}

// ── Reborrow ─────────────────────────────────────────────────────

impl<'b, T: Copy, const GROUP: usize, const PACK: usize> Reborrow<'b>
    for BlockTransposed<T, GROUP, PACK>
{
    type Target = BlockTransposedRef<'b, T, GROUP, PACK>;

    /// Borrow this owned container as a lightweight immutable view.
    #[inline]
    fn reborrow(&'b self) -> Self::Target {
        self.as_view()
    }
}

// ── Factory methods ──────────────────────────────────────────────

impl<T: Copy + Default, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
Expand Down Expand Up @@ -1676,6 +1709,15 @@ mod tests {
}
}

// ── padded_nrows() returns padded row count ──────────────

assert_eq!(
transpose.as_view().padded_nrows(),
padded_nrows,
"padded_nrows() mismatch -- {}",
context,
);

// ── from_matrix_view produces identical results ──────────

if nrows > 0 && ncols > 0 {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

//! Simple kernel implementation of multi-vector distance computation.
//! Fallback kernel implementation of multi-vector distance computation.

use std::ops::Deref;

Expand Down Expand Up @@ -49,17 +49,17 @@ impl<'a, T: Repr> Deref for QueryMatRef<'a, T> {
}
}

//////////////////
// SimpleKernel //
//////////////////
////////////////////
// FallbackKernel //
////////////////////

/// Simple double-loop kernel to compute max-sim distances over multi-vectors.
/// Fallback double-loop kernel to compute max-sim distances over multi-vectors.
///
/// This kernel performs a simple double-loop over the rows of `query`
/// and the `doc` and dispatches to [`InnerProduct`] to compute the similarity.
pub struct SimpleKernel;
pub struct FallbackKernel;

impl SimpleKernel {
impl FallbackKernel {
/// Core kernel for computing per-query-vector max similarities (min negated inner-product).
///
/// For each `query` vector, computes the maximum similarity (negated inner product)
Expand Down Expand Up @@ -128,7 +128,7 @@ where
return Err(MaxSimError::InvalidBufferLength(size, n_queries));
}

SimpleKernel::max_sim_kernel(query, doc, |i, score| {
FallbackKernel::max_sim_kernel(query, doc, |i, score| {
// SAFETY: We asserted that self.size() == query.num_vectors(),
// and i < query.num_vectors() due to the kernel loop bound.
unsafe { *self.scores.get_unchecked_mut(i) = score };
Expand All @@ -151,7 +151,7 @@ where
fn evaluate(query: QueryMatRef<'_, Standard<T>>, doc: MatRef<'_, Standard<T>>) -> f32 {
let mut sum = 0.0f32;

SimpleKernel::max_sim_kernel(query, doc, |_i, score| {
FallbackKernel::max_sim_kernel(query, doc, |_i, score| {
sum += score;
});

Expand Down Expand Up @@ -185,7 +185,7 @@ mod tests {
.fold(f32::MAX, f32::min)
}

/// Generate a vector of random f32 values in [-1, 1] for testing
/// Generate deterministic test data.
fn make_test_data(len: usize, ceil: usize, shift: usize) -> Vec<f32> {
(0..len).map(|v| ((v + shift) % ceil) as f32).collect()
}
Expand Down Expand Up @@ -270,9 +270,9 @@ mod tests {
);
}

// Check that SimpleKernel is also correct.
SimpleKernel::max_sim_kernel(query, doc, |i, score| {
assert!((scores[i] - score).abs() <= 1e-6)
// Check that FallbackKernel produces the same values as the naive reference.
FallbackKernel::max_sim_kernel(query, doc, |i, score| {
assert!((expected_scores[i] - score).abs() <= 1e-6)
});

// Test Chamfer
Expand All @@ -299,7 +299,7 @@ mod tests {
// No query vectors means sum is 0
assert_eq!(result, 0.0);

let result = Chamfer::evaluate(doc.into(), query.deref().reborrow());
let result = Chamfer::evaluate(QueryMatRef::from(doc), query.deref().reborrow());

assert_eq!(result, 0.0);
}
Expand Down
52 changes: 52 additions & 0 deletions diskann-quantization/src/multi_vector/distance/kernels/f16.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

//! f16 dispatch adapter for block-transposed multi-vector distance.
//!
//! Reuses the f32 micro-kernel family with tile-level f16→f32 conversion
//! via [`ConvertTo`](super::layouts::ConvertTo). No f16-specific micro-kernel
//! code is needed — the [`F32Kernel`](super::f32::F32Kernel) does all the
//! SIMD work after conversion.
//!
//! Conversion from f16 to f32 is performed at tile granularity via
//! [`SliceCast`](diskann_vector::conversion::SliceCast), dispatched through
//! the runtime architecture token — the same SIMD level used by the
//! micro-kernel.

use diskann_wide::Architecture;

use super::Kernel;
use super::TileBudget;
use super::f32::{F32Kernel, max_ip_kernel};
use super::layouts;
use crate::multi_vector::{BlockTransposedRef, MatRef, Standard};

pub(crate) struct F16Entry<const GROUP: usize>;

// Runtime-dispatch adapter: routes (block-transposed f16 A, row-major f16 B)
// through the f32 micro-kernel family, with tile-level f16→f32 conversion.
impl<A, const GROUP: usize>
    diskann_wide::arch::Target3<
        A,
        (),
        BlockTransposedRef<'_, half::f16, GROUP>,
        MatRef<'_, Standard<half::f16>>,
        &mut [f32],
    > for F16Entry<GROUP>
where
    A: Architecture,
    // An f32 micro-kernel must exist for this architecture...
    F32Kernel<GROUP>: Kernel<A>,
    // ...and f16→f32 tile conversions must be available for both operand layouts.
    layouts::BlockTransposed<half::f16, GROUP>: layouts::ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Left>
        + layouts::Layout<Element = half::f16>,
    layouts::RowMajor<half::f16>: layouts::ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Right>
        + layouts::Layout<Element = half::f16>,
{
    #[inline(always)]
    fn run(
        self,
        arch: A,
        lhs: BlockTransposedRef<'_, half::f16, GROUP>,
        rhs: MatRef<'_, Standard<half::f16>>,
        scratch: &mut [f32],
    ) {
        // Delegate to the shared generic entry point with the default tile
        // budget; `scratch` preconditions are documented on `max_ip_kernel`.
        max_ip_kernel(arch, lhs, rhs, scratch, TileBudget::default());
    }
}
135 changes: 135 additions & 0 deletions diskann-quantization/src/multi_vector/distance/kernels/f32/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

//! f32 micro-kernel family for block-transposed multi-vector distance.
//!
//! Provides:
//!
//! - `F32Kernel<GROUP>` — zero-sized marker type selecting the f32 micro-kernel
//! for `BlockTransposed<f32, GROUP>` data.
//! - [`max_ip_kernel`] — architecture-, element-type-, and GROUP-generic entry point
//! for the reducing max-IP GEMM. Accepts any element type `T` for which
//! [`ConvertTo`](super::layouts::ConvertTo) impls exist (identity for f32,
//! SIMD-accelerated f16→f32, etc.).
//!
//! # Architecture-specific micro-kernels
//!
//! - `v3` (x86_64) — V3 (AVX2+FMA) 16×4 micro-kernel (GROUP=16). V4 delegates to V3 at dispatch.
//! - `scalar` — Emulated 8×2 micro-kernel (GROUP=8). Neon delegates to Scalar at dispatch.

use diskann_wide::Architecture;

use super::Kernel;
use super::TileBudget;
use super::layouts::{self, DescribeLayout};
use super::tiled_reduce::tiled_reduce;
use crate::multi_vector::{BlockTransposedRef, MatRef, Standard};

mod scalar;
#[cfg(target_arch = "x86_64")]
mod v3;

/// Zero-sized kernel type for f32 micro-kernels with block size `GROUP`.
///
/// Selects the architecture-specific f32 micro-kernel (see the `scalar` and
/// `v3` submodules) through the [`Kernel`] trait; carries no runtime state.
pub(crate) struct F32Kernel<const GROUP: usize>;

/// Out-of-line, cold panic path for [`max_ip_kernel`]'s precondition check.
///
/// Kept `#[inline(never)]` + `#[cold]` so the hot caller stays small and the
/// branch to this function is laid out as unlikely.
#[inline(never)]
#[cold]
#[allow(clippy::panic)] // a precondition violation is a caller bug, not a recoverable error
fn max_ip_kernel_panic(scratch_len: usize, padded_nrows: usize, a_ncols: usize, b_dim: usize) {
    panic!(
        "max_ip_kernel: precondition failed: \
         scratch.len()={scratch_len} (expected {padded_nrows}), \
         a.ncols()={a_ncols}, b.vector_dim()={b_dim}"
    );
}

/// Compute the reducing max-IP GEMM between a block-transposed A matrix and
/// a row-major B matrix, writing per-A-row max similarities into `scratch`.
///
/// Thin wrapper over [`tiled_reduce`] using `F32Kernel<GROUP>` for the
/// requested architecture. The element type `T` can be any `Copy` type with
/// matching [`ConvertTo`](super::layouts::ConvertTo) impls (zero-cost for
/// `T = f32`; SIMD f16→f32 conversion once per tile for `T = half::f16`).
///
/// `scratch` must have length [`BlockTransposedRef::padded_nrows()`] and be
/// initialized to `f32::MIN` before the first call. On return, `scratch[i]`
/// holds the maximum inner product between A row `i` and any B row.
///
/// # Panics
///
/// Panics if `scratch.len() != a.padded_nrows()` or `a.ncols() != b.vector_dim()`.
pub(super) fn max_ip_kernel<A: Architecture, T: Copy, const GROUP: usize>(
    arch: A,
    a: BlockTransposedRef<'_, T, GROUP>,
    b: MatRef<'_, Standard<T>>,
    scratch: &mut [f32],
    budget: TileBudget,
) where
    F32Kernel<GROUP>: Kernel<A>,
    layouts::BlockTransposed<T, GROUP>:
        layouts::ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Left> + layouts::Layout<Element = T>,
    layouts::RowMajor<T>: layouts::ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Right>
        + layouts::Layout<Element = T>,
{
    // Precondition check up front; the panic body itself is out-of-line/cold.
    if scratch.len() != a.padded_nrows() || a.ncols() != b.vector_dim() {
        max_ip_kernel_panic(scratch.len(), a.padded_nrows(), a.ncols(), b.vector_dim());
    }

    // Shared inner dimension (equals b.vector_dim(), checked above).
    let k = a.ncols();
    let b_nrows = b.num_vectors();

    // Compile-time: A_PANEL must equal GROUP for block-transposed layout correctness.
    const { assert!(<F32Kernel<GROUP> as Kernel<A>>::A_PANEL == GROUP) }

    let ca = a.layout();
    let cb = b.layout();

    // SAFETY:
    // - a.as_ptr() is valid for a.padded_nrows() * k elements of T.
    // - MatRef<Standard<T>> stores nrows * ncols contiguous T elements.
    // - scratch.len() == a.padded_nrows() (checked above).
    // - a.padded_nrows() is always a multiple of GROUP, and the const assert above
    //   verifies A_PANEL == GROUP at compile time.
    unsafe {
        tiled_reduce::<A, F32Kernel<GROUP>, _, _>(
            arch,
            &ca,
            &cb,
            a.as_ptr(),
            a.padded_nrows(),
            b.as_slice().as_ptr(),
            b_nrows,
            k,
            scratch,
            budget,
        );
    }
}

// Runtime-dispatch entry point for the pure-f32 path: identical shape to the
// f16 adapter, but with `T = f32` so the layout conversions are zero-cost.
impl<A, const GROUP: usize>
    diskann_wide::arch::Target3<
        A,
        (),
        BlockTransposedRef<'_, f32, GROUP>,
        MatRef<'_, Standard<f32>>,
        &mut [f32],
    > for F32Kernel<GROUP>
where
    A: Architecture,
    // The kernel itself must be implemented for this architecture...
    Self: Kernel<A>,
    // ...and (identity) layout conversions must exist for both operands.
    layouts::BlockTransposed<f32, GROUP>:
        layouts::ConvertTo<A, <Self as Kernel<A>>::Left> + layouts::Layout<Element = f32>,
    layouts::RowMajor<f32>:
        layouts::ConvertTo<A, <Self as Kernel<A>>::Right> + layouts::Layout<Element = f32>,
{
    #[inline(always)]
    fn run(
        self,
        arch: A,
        lhs: BlockTransposedRef<'_, f32, GROUP>,
        rhs: MatRef<'_, Standard<f32>>,
        scratch: &mut [f32],
    ) {
        // Delegate with the default tile budget; `scratch` preconditions are
        // documented on `max_ip_kernel`.
        max_ip_kernel(arch, lhs, rhs, scratch, TileBudget::default());
    }
}
Loading
Loading