109 changes: 109 additions & 0 deletions diskann/src/flat/index.rs
@@ -0,0 +1,109 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

//! [`FlatIndex`] — the index wrapper for an on which we do flat search.
Copilot AI commented on Apr 28, 2026:
Doc comment grammar issue: "the index wrapper for an on which we do flat search" is missing a noun (e.g., "for an index" / "around a provider"). Please fix to avoid confusing rustdoc output.

Suggested change
//! [`FlatIndex`] — the index wrapper for an on which we do flat search.
//! [`FlatIndex`] — the index wrapper around a [`DataProvider`] on which we do flat search.


use std::marker::PhantomData;
use std::num::NonZeroUsize;

use diskann_utils::future::SendFuture;

use crate::{
    ANNResult,
    error::IntoANNResult,
    flat::{DistancesUnordered, FlatPostProcess, FlatSearchStrategy},
    graph::{SearchOutputBuffer, index::SearchStats},
    neighbor::{Neighbor, NeighborPriorityQueue},
    provider::DataProvider,
};

/// A thin, `'static` wrapper around a [`DataProvider`] used for flat search.
///
/// The provider is owned by the index. The index is constructed once at process startup and
/// shared across requests; per-query state lives in the [`crate::flat::OnElementsUnordered`]
/// implementation that the [`crate::flat::FlatSearchStrategy`] produces.
#[derive(Debug)]
pub struct FlatIndex<P: DataProvider> {
    /// The backing provider.
    provider: P,
    _marker: PhantomData<fn() -> P>,
Contributor commented:
Flat index already owns a P - the _marker field isn't doing anything besides being confusing 😄.
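A minimal standalone sketch of the simplified struct (the empty `DataProvider` trait here is a stand-in for the crate's real trait, so the sketch compiles outside the crate):

```rust
// Stand-in for the crate's DataProvider trait, for illustration only.
trait DataProvider {}

// FlatIndex already owns a P; the owned field supplies the drop and
// variance behavior on its own, so PhantomData can simply be removed.
#[derive(Debug)]
struct FlatIndex<P: DataProvider> {
    provider: P,
}

impl<P: DataProvider> FlatIndex<P> {
    fn new(provider: P) -> Self {
        Self { provider }
    }

    fn provider(&self) -> &P {
        &self.provider
    }
}
```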

}

impl<P: DataProvider> FlatIndex<P> {
    /// Construct a new [`FlatIndex`] around `provider`.
    pub fn new(provider: P) -> Self {
        Self {
            provider,
            _marker: PhantomData,
        }
    }

    /// Borrow the underlying provider.
    pub fn provider(&self) -> &P {
        &self.provider
    }

    /// Brute-force k-nearest-neighbor flat search.
    ///
    /// Streams every element produced by the strategy's iterator through the query
    /// computer, keeps the best `k` candidates in a [`NeighborPriorityQueue`], and hands
    /// the survivors to the post-processor.
    ///
    /// # Arguments
    /// - `k`: number of nearest neighbors to return.
    /// - `strategy`: produces the per-query iterator and the query computer. See [`FlatSearchStrategy`].
    /// - `processor`: post-processes the survivor candidates into the output type.
    /// - `context`: per-request context threaded through to the provider.
    /// - `query`: the query.
    /// - `output`: caller-owned output buffer.
    pub fn knn_search<S, T, O, OB, PP>(
Contributor commented:
We recently went through a whole thing of adding the Search trait to the graph index to avoid the proliferation of search methods on the index. We should probably do the same here.

        &self,
        k: NonZeroUsize,
        strategy: &S,
        processor: &PP,
        context: &P::Context,
        query: &T,
        output: &mut OB,
    ) -> impl SendFuture<ANNResult<SearchStats>>
    where
        S: FlatSearchStrategy<P, T>,
        T: ?Sized + Sync,
        O: Send,
        OB: SearchOutputBuffer<O> + Send + ?Sized,
        PP: for<'a> FlatPostProcess<S::Callback<'a>, T, O> + Send + Sync,
    {
        async move {
            let mut callback = strategy
                .create_callback(&self.provider, context)
                .into_ann_result()?;

            let computer = strategy.build_query_computer(query).into_ann_result()?;

            let k = k.get();
            let mut queue = NeighborPriorityQueue::new(k);
            let mut cmps: u32 = 0;

            callback
                .distances_unordered(&computer, |id, dist| {
                    cmps += 1;
                    queue.insert(Neighbor::new(id, dist));
arrayka marked this conversation as resolved.
                })
                .await
                .into_ann_result()?;

            let result_count = processor
                .post_process(&mut callback, query, queue.iter().take(k), output)
                .await
                .into_ann_result()? as u32;

            Ok(SearchStats {
Contributor commented:
This should probably be a bespoke return type. The fields like hops and range_search_second_round are meaningless in this context.
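One possible shape for such a bespoke type — a sketch only; the name and exact fields are illustrative, not part of this PR:

```rust
/// Statistics for a flat (sequential) search. Unlike the graph-search
/// `SearchStats`, there is no notion of hops or a second range-search
/// round, so only the meaningful counters remain.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FlatSearchStats {
    /// Number of distance comparisons performed (one per scanned element).
    pub cmps: u32,
    /// Number of results written to the output buffer.
    pub result_count: u32,
}
```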

                cmps,
                hops: 0,
                result_count,
                range_search_second_round: false,
            })
        }
    }
}
Copilot AI commented on Apr 28, 2026:
New behavior (FlatIndex::knn_search) is introduced without tests. Given the repo has unit tests for graph search and output buffers, it would be good to add at least one test covering: (1) correct top-k ordering, (2) that CopyFlatIds writes expected (id, distance) pairs, and (3) that SearchStats { cmps, result_count } are consistent for a tiny in-memory iterator.

Suggested change
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn run_tiny_flat_scan(
        k: usize,
        items: &[(u32, f32)],
    ) -> (Vec<(u32, f32)>, SearchStats) {
        let mut queue = NeighborPriorityQueue::new(k);
        let mut cmps = 0u32;
        for (id, distance) in items.iter().copied() {
            cmps += 1;
            queue.insert(Neighbor::new(id, distance));
        }
        let copied: Vec<(u32, f32)> = queue
            .iter()
            .take(k)
            .map(|neighbor| (neighbor.id, neighbor.distance))
            .collect();
        let stats = SearchStats {
            cmps,
            hops: 0,
            result_count: copied.len() as u32,
            range_search_second_round: false,
        };
        (copied, stats)
    }

    #[test]
    fn knn_search_keeps_top_k_in_distance_order() {
        let (copied, stats) = run_tiny_flat_scan(
            3,
            &[(10, 4.0), (11, 1.5), (12, 3.0), (13, 0.5), (14, 2.0)],
        );
        assert_eq!(copied, vec![(13, 0.5), (11, 1.5), (14, 2.0)]);
        assert_eq!(stats.result_count, 3);
    }

    #[test]
    fn copied_flat_ids_match_expected_id_distance_pairs() {
        let (copied, _) = run_tiny_flat_scan(2, &[(21, 9.0), (22, 1.25), (23, 4.5)]);
        assert_eq!(copied, vec![(22, 1.25), (23, 4.5)]);
    }

    #[test]
    fn search_stats_are_consistent_for_tiny_in_memory_scan() {
        let items = &[(31, 7.0), (32, 2.0), (33, 5.0), (34, 1.0)];
        let (copied, stats) = run_tiny_flat_scan(2, items);
        assert_eq!(stats.cmps, items.len() as u32);
        assert_eq!(stats.hops, 0);
        assert_eq!(stats.result_count, copied.len() as u32);
        assert!(!stats.range_search_second_round);
    }
}

147 changes: 147 additions & 0 deletions diskann/src/flat/iterator.rs
@@ -0,0 +1,147 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

//! [`OnElementsUnordered`] — the sequential-access primitive for scanning a flat index.
//!
//! [`FlatIterator`] — a lending async iterator that can be bridged into
//! [`OnElementsUnordered`] via [`DefaultIteratedOperator`].

use diskann_utils::{Reborrow, future::SendFuture};
use diskann_vector::PreprocessedDistanceFunction;

use crate::{error::StandardError, provider::HasId};

/// Callback-driven sequential scan over the elements of a flat index.
///
/// `OnElementsUnordered` is the streaming counterpart to [`crate::provider::Accessor`].
/// Where an accessor exposes random retrieval by id, this trait exposes a *sequential*
/// walk that invokes a caller-supplied closure for every element.
///
/// Algorithms see only `(Id, ElementRef)` pairs and treat the stream as opaque.
pub trait OnElementsUnordered: HasId + Send + Sync {
    /// A reference to a yielded element with an unconstrained lifetime, suitable for
    /// distance-function HRTB bounds.
    type ElementRef<'a>;

    /// The error type yielded by [`Self::on_elements_unordered`].
    type Error: StandardError;

    /// Drive the entire scan, invoking `f` for each yielded element.
    fn on_elements_unordered<F>(&mut self, f: F) -> impl SendFuture<Result<(), Self::Error>>
    where
        F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>);
}

/// Extension of [`OnElementsUnordered`] that drives the scan with a pre-built query
/// computer, invoking a callback with `(id, distance)` pairs instead of raw elements.
///
/// The concrete computer is instantiated and supplied externally
/// by the [`FlatSearchStrategy`](crate::flat::FlatSearchStrategy).
///
/// The default implementation delegates to [`OnElementsUnordered::on_elements_unordered`],
/// calling `computer.evaluate_similarity` on each element.
pub trait DistancesUnordered: OnElementsUnordered {
    /// Drive the entire scan, scoring each element with `computer` and invoking `f` with
    /// the resulting `(id, distance)` pair.
    fn distances_unordered<C, F>(
Contributor commented:
This needs the concrete type of the distance computer, not a generic. A concrete type is crucial to allow implementors to specialize the implementation.
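One way to express that is to make the computer an associated type on the trait. The sketch below is standalone, with simplified, synchronous stand-ins for the crate's traits; `L1` and `VecScan` are hypothetical:

```rust
// Simplified stand-in for diskann_vector::PreprocessedDistanceFunction.
trait PreprocessedDistanceFunction<T> {
    fn evaluate_similarity(&self, element: T) -> f32;
}

trait DistancesUnordered {
    type Id;
    type ElementRef<'a>;
    // The concrete computer type, fixed by the implementor, so that the
    // implementation can specialize the scan (SIMD batching, prefetching, ...).
    type Computer: for<'a> PreprocessedDistanceFunction<Self::ElementRef<'a>>;

    fn distances_unordered<F>(&mut self, computer: &Self::Computer, f: F)
    where
        F: FnMut(Self::Id, f32);
}

// Tiny demonstration impl: scan a Vec<f32> with an L1-distance computer.
struct VecScan(Vec<f32>);
struct L1 {
    query: f32,
}

impl PreprocessedDistanceFunction<f32> for L1 {
    fn evaluate_similarity(&self, element: f32) -> f32 {
        (self.query - element).abs()
    }
}

impl DistancesUnordered for VecScan {
    type Id = usize;
    type ElementRef<'a> = f32;
    type Computer = L1; // concrete type, so the impl knows exactly what it scores with

    fn distances_unordered<F>(&mut self, computer: &Self::Computer, mut f: F)
    where
        F: FnMut(Self::Id, f32),
    {
        for (id, &x) in self.0.iter().enumerate() {
            f(id, computer.evaluate_similarity(x));
        }
    }
}
```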

        &mut self,
        computer: &C,
        mut f: F,
    ) -> impl SendFuture<Result<(), Self::Error>>
    where
        C: for<'a> PreprocessedDistanceFunction<Self::ElementRef<'a>, f32> + Send + Sync,
        F: Send + FnMut(Self::Id, f32),
arrayka marked this conversation as resolved.
    {
        self.on_elements_unordered(move |id, element| {
            let dist = computer.evaluate_similarity(element);
            f(id, dist);
        })
    }
}

//////////////
// Iterator //
//////////////

/// A lending, asynchronous iterator over the elements of a flat index.
///
/// Implementations provide element-at-a-time access via [`Self::next`]. Providers that
/// only implement `FlatIterator` can be wrapped in [`DefaultIteratedOperator`] to obtain
/// an [`OnElementsUnordered`] implementation automatically.
pub trait FlatIterator: HasId + Send + Sync {
    /// A reference to a yielded element with an unconstrained lifetime, suitable for
    /// distance-function HRTB bounds.
    type ElementRef<'a>;

    /// The concrete element returned by [`Self::next`]. Reborrows to [`Self::ElementRef`].
    type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync
    where
        Self: 'a;

    /// The error type yielded by [`Self::next`].
    type Error: StandardError;
Contributor commented:
On error types, maybe consider ToRanked instead of the Into<ANNError> that StandardError implies? That said, the visitor pattern means the implementation can swallow non-critical errors on its own. So maybe not needed, for the visitor case. But on the iterator case, a ToRanked might be a good idea.


    /// Advance the iterator and asynchronously yield the next `(id, element)` pair.
    ///
    /// Returns `Ok(None)` when the scan is exhausted. The yielded element borrows from
    /// the iterator and is invalidated by the next call to `next`.
    #[allow(clippy::type_complexity)]
    fn next(
        &mut self,
    ) -> impl SendFuture<Result<Option<(Self::Id, Self::Element<'_>)>, Self::Error>>;
}

/////////////
// Default //
/////////////

/// Bridges a [`FlatIterator`] into an [`OnElementsUnordered`] by looping over
/// [`FlatIterator::next`] and reborrowing each element into the closure.
///
/// This is the default adapter for providers that implement element-at-a-time iteration.
/// Providers that can do better (prefetching, SIMD batching, bulk I/O) should implement
/// [`OnElementsUnordered`] directly.
pub struct DefaultIteratedOperator<I> {
Contributor commented:
Nit: maybe just call this "iterated"?

    inner: I,
}

impl<I> DefaultIteratedOperator<I> {
    /// Wrap an iterator to produce an [`OnElementsUnordered`] implementation.
    pub fn new(inner: I) -> Self {
        Self { inner }
    }

    /// Unwrap, returning the inner iterator.
    pub fn into_inner(self) -> I {
        self.inner
    }
}

impl<I: HasId> HasId for DefaultIteratedOperator<I> {
    type Id = I::Id;
}

impl<I> OnElementsUnordered for DefaultIteratedOperator<I>
where
    I: FlatIterator + HasId + Send + Sync,
{
    type ElementRef<'a> = I::ElementRef<'a>;
    type Error = I::Error;

    fn on_elements_unordered<F>(&mut self, mut f: F) -> impl SendFuture<Result<(), Self::Error>>
    where
        F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>),
    {
        async move {
            while let Some((id, element)) = self.inner.next().await? {
                f(id, element.reborrow());
            }
            Ok(())
        }
    }
}

impl<I> DistancesUnordered for DefaultIteratedOperator<I> where I: FlatIterator + HasId + Send + Sync
{}
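The bridge that `DefaultIteratedOperator` implements — loop over `next`, hand each borrowed element to a closure — can be sketched in a simplified, synchronous form. Everything below is illustrative: the element type is fixed to `&[f32]`, and the async `SendFuture`/`Reborrow` machinery is dropped so the sketch stands alone:

```rust
// Simplified, synchronous analogue of the crate's FlatIterator.
trait FlatIter {
    /// Yield the next (id, element) pair, or None when exhausted.
    /// The element borrows from the iterator until the next call.
    fn next(&mut self) -> Option<(u32, &[f32])>;
}

// Hypothetical element-at-a-time scanner over an in-memory dataset.
struct SliceScan<'s> {
    data: &'s [Vec<f32>],
    pos: usize,
}

impl<'s> FlatIter for SliceScan<'s> {
    fn next(&mut self) -> Option<(u32, &[f32])> {
        let item = self.data.get(self.pos)?;
        let id = self.pos as u32;
        self.pos += 1;
        Some((id, item.as_slice()))
    }
}

// The bridge: drive the whole scan, invoking `f` for every element.
fn on_elements_unordered<I, F>(iter: &mut I, mut f: F)
where
    I: FlatIter,
    F: FnMut(u32, &[f32]),
{
    while let Some((id, element)) = iter.next() {
        f(id, element);
    }
}
```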
45 changes: 45 additions & 0 deletions diskann/src/flat/mod.rs
@@ -0,0 +1,45 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

//! Sequential ("flat") search infrastructure.
//!
//! This module is the streaming counterpart to the random-access [`crate::provider::Accessor`]
//! family. It is designed for backends whose natural access pattern is a one-pass scan over
//! their data — for example append-only buffered stores, on-disk shards streamed via I/O,
//! or any provider where random access is significantly more expensive than sequential.
//!
//! # Architecture
//!
//! The module mirrors the layering used by graph search:
//!
//! | Graph (random access) | Flat (sequential) |
//! | :------------------------------------ | :-------------------------------- |
//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] |
//! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] |
//! | [`crate::provider::Accessor`] | [`FlatIterator`] |
//! | [`crate::graph::glue::SearchStrategy`] | [`FlatSearchStrategy`] |
//! | [`crate::graph::glue::SearchPostProcess`] | [`FlatPostProcess`] |
//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] |
//!
//! # Hot loop
//!
//! Algorithms drive the scan via [`FlatIterator::next`] (lending iterator) or implement
//! [`OnElementsUnordered::on_elements_unordered`] directly when batching/prefetching wins.
//! The default bridge, [`DefaultIteratedOperator`], simply loops over `next`.
//!
//! See [`FlatIndex::knn_search`] for the canonical brute-force k-NN algorithm built on these
//! primitives.

pub mod index;
pub mod iterator;
pub mod post_process;
pub mod strategy;

pub use index::FlatIndex;
pub use iterator::{
    DefaultIteratedOperator, DistancesUnordered, FlatIterator, OnElementsUnordered,
};
pub use post_process::{CopyIds, FlatPostProcess};
pub use strategy::FlatSearchStrategy;
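The hot loop described in the module docs above is essentially a streamed top-k selection. A self-contained sketch of the same shape, using `std::collections::BinaryHeap` as a stand-in for the crate's `NeighborPriorityQueue` (names and shapes here are illustrative, not the crate's API):

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

// A candidate ordered by distance, so the heap's top is the worst survivor.
#[derive(PartialEq)]
struct Candidate {
    id: u32,
    dist: f32,
}

impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist.partial_cmp(&other.dist).unwrap_or(Ordering::Equal)
    }
}

/// Brute-force k-NN over a sequential stream of (id, distance) pairs,
/// mirroring the shape of the flat search hot loop.
fn flat_knn(k: usize, scan: impl Iterator<Item = (u32, f32)>) -> Vec<(u32, f32)> {
    let mut heap: BinaryHeap<Candidate> = BinaryHeap::with_capacity(k + 1);
    for (id, dist) in scan {
        heap.push(Candidate { id, dist });
        if heap.len() > k {
            heap.pop(); // evict the current worst candidate
        }
    }
    let mut out: Vec<(u32, f32)> = heap.into_iter().map(|c| (c.id, c.dist)).collect();
    out.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    out
}
```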
72 changes: 72 additions & 0 deletions diskann/src/flat/post_process.rs
@@ -0,0 +1,72 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

//! [`FlatPostProcess`] — terminal stage of the flat search pipeline.

use diskann_utils::future::SendFuture;

use crate::{
error::StandardError, flat::OnElementsUnordered, graph::SearchOutputBuffer, neighbor::Neighbor,
provider::HasId,
};

/// Post-process the survivor candidates produced by a flat search and
/// write them into an output buffer.
///
/// This is the flat counterpart to [`crate::graph::glue::SearchPostProcess`]. Processors
/// receive `&mut S` so they can consult any iterator-owned lookup state (e.g., an
/// `Id -> rich-record` table built up during the scan) when assembling outputs.
///
/// The `O` type parameter lets callers pick the output element type (raw `(Id, f32)`
/// pairs, fully hydrated hits etc.).
pub trait FlatPostProcess<S, T, O = <S as HasId>::Id>
Contributor commented:
One big concern I have is that this does not really share much with the existing graph code. Even though the desire is to share code, post-process routines like diversity search will still need to be implemented twice. I think we can solve several issues at once if we do some massaging to our current trait hierarchy.

Let's assume we do the following: First, add a new trait

/// A general supertrait like `HasId` that we can use to express
/// relationships
pub trait HasElementRef {
    type ElementRef<'a>;
}

/// Restructure the existing `BuildQueryComputer` as a subtrait of
/// `HasElementRef`.
pub trait BuildQueryComputer<T>: HasElementRef {
    type QueryComputerError: std::error::Error + Into<ANNError> + Send + Sync + 'static;
    type QueryComputer: for<'a> PreprocessedDistanceFunction<Self::ElementRef<'a>, f32>
        + Send
        + Sync
        + 'static; // Maybe we can finally drop `'static`?

    fn build_query_computer(
        &self,
        from: T,
    ) -> Result<Self::QueryComputer, Self::QueryComputerError>;
}

Then, Accessor can add a new subtrait of BuildQueryComputer for its need of distances_unordered, and the flat Visitor can do so as well. Crucially, this might let code be shared for SearchPostProcess and avoid the duplication. And also keeps BuildQueryComputer a bit more centralized.

where
    S: OnElementsUnordered,
    T: ?Sized,
{
    /// Errors yielded by [`Self::post_process`].
    type Error: StandardError;

    /// Consume `candidates` (in distance order) and write at most `k` results into
    /// `output`. Returns the number of results written.
    fn post_process<I, B>(
        &self,
        iter: &mut S,
        query: &T,
        candidates: I,
        output: &mut B,
    ) -> impl SendFuture<Result<usize, Self::Error>>
    where
        I: Iterator<Item = Neighbor<S::Id>> + Send,
        B: SearchOutputBuffer<O> + Send + ?Sized;
}

/// A trivial [`FlatPostProcess`] that copies each `(Id, distance)` pair straight into the
/// output buffer.
#[derive(Debug, Default, Clone, Copy)]
pub struct CopyIds;

impl<S, T> FlatPostProcess<S, T> for CopyIds
where
    S: OnElementsUnordered,
    T: ?Sized,
{
    type Error = crate::error::Infallible;
Copilot AI commented on Apr 28, 2026:
CopyFlatIds uses crate::error::Infallible, but the analogous graph::glue::CopyIds uses std::convert::Infallible (graph/glue.rs:417). Using the std type here too would improve consistency and reduce cognitive overhead for readers comparing the two pipelines.

Suggested change
type Error = crate::error::Infallible;
type Error = std::convert::Infallible;


    fn post_process<I, B>(
        &self,
        _iter: &mut S,
        _query: &T,
        candidates: I,
        output: &mut B,
    ) -> impl SendFuture<Result<usize, Self::Error>>
    where
        I: Iterator<Item = Neighbor<<S as HasId>::Id>> + Send,
        B: SearchOutputBuffer<<S as HasId>::Id> + Send + ?Sized,
    {
        let count = output.extend(candidates.map(|n| (n.id, n.distance)));
        std::future::ready(Ok(count))
    }
}