diff --git a/diskann-benchmark-runner/src/benchmark.rs b/diskann-benchmark-runner/src/benchmark.rs index 6c196d8af..27cb910a9 100644 --- a/diskann-benchmark-runner/src/benchmark.rs +++ b/diskann-benchmark-runner/src/benchmark.rs @@ -15,7 +15,7 @@ use crate::{ /// Benchmarks consist of an [`Input`] and a corresponding serialized `Output`. Inputs will /// first be validated with the benchmark using [`try_match`](Self::try_match). Only /// successful matches will be passed to [`run`](Self::run). -pub trait Benchmark { +pub trait Benchmark: 'static { /// The [`Input`] type this benchmark matches against. type Input: Input + 'static; @@ -32,7 +32,7 @@ pub trait Benchmark { /// On failure, returns `Err(FailureScore)`. In the [`crate::registry::Benchmarks`] /// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations /// are encouraged to generate ranked [`FailureScore`]s to assist in user level debugging. - fn try_match(input: &Self::Input) -> Result; + fn try_match(&self, input: &Self::Input) -> Result; /// Return descriptive information about the benchmark. /// @@ -40,6 +40,7 @@ pub trait Benchmark { /// If `input` is `Some`, and is an unsuccessful match, diagnostic information about what /// was expected should be generated to help users. fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&Self::Input>, ) -> std::fmt::Result; @@ -52,6 +53,7 @@ pub trait Benchmark { /// /// Implementors may assume that [`Self::try_match`] returned `Ok` on `input`. fn run( + &self, input: &Self::Input, checkpoint: Checkpoint<'_>, output: &mut dyn Output, @@ -88,6 +90,7 @@ pub trait Regression: Benchmark Deserialize<'a>> { /// stream. Instead, all diagnostics should be encoded in the returned [`PassFail`] type /// for reporting upstream. fn check( + &self, tolerances: &Self::Tolerances, input: &Self::Input, before: &Self::Output, @@ -109,8 +112,6 @@ pub enum PassFail { pub(crate) mod internal { use super::*; - use std::marker::PhantomData; - use anyhow::Context; use thiserror::Error; @@ -176,38 +177,32 @@ pub(crate) mod internal { } } - pub(crate) trait AsRegression { - fn as_regression(&self) -> Option<&dyn Regression>; + pub(crate) trait AsRegression { + fn as_regression(benchmark: &T) -> Option<&dyn Regression>; } - #[derive(Debug, Clone)] + #[derive(Debug, Clone, Copy)] pub(crate) struct NoRegression; - impl AsRegression for NoRegression { - fn as_regression(&self) -> Option<&dyn Regression> { + impl AsRegression for NoRegression { + fn as_regression(_benchmark: &T) -> Option<&dyn Regression> { None } } #[derive(Debug, Clone, Copy)] - pub(crate) struct WithRegression(PhantomData); + pub(crate) struct WithRegression; - impl WithRegression { - pub(crate) const fn new() -> Self { - Self(PhantomData) - } - } - - impl AsRegression for WithRegression + impl AsRegression for WithRegression where T: super::Regression, { - fn as_regression(&self) -> Option<&dyn Regression> { - Some(self) + fn as_regression(benchmark: &T) -> Option<&dyn Regression> { + Some(benchmark) } } - impl Regression for WithRegression + impl Regression for T where T: super::Regression, { @@ -242,7 +237,7 @@ pub(crate) mod internal { let after = T::Output::deserialize(after) .map_err(|err| DeserializationError::new(Kind::After, err))?; - let passfail = match T::check(tolerance, input, &before, &after)? { + let passfail = match self.check(tolerance, input, &before, &after)? { PassFail::Pass(pass) => PassFail::Pass(Checked::new(pass)?), PassFail::Fail(fail) => PassFail::Fail(Checked::new(fail)?), }; @@ -253,21 +248,15 @@ pub(crate) mod internal { #[derive(Debug, Clone, Copy)] pub(crate) struct Wrapper { - regression: R, - _type: PhantomData, - } - - impl Wrapper { - pub(crate) const fn new() -> Self { - Self::new_with(NoRegression) - } + benchmark: T, + _regression: R, } impl Wrapper { - pub(crate) const fn new_with(regression: R) -> Self { + pub(crate) const fn new(benchmark: T, regression: R) -> Self { Self { - regression, - _type: PhantomData, + benchmark, + _regression: regression, } } } @@ -278,11 +267,11 @@ pub(crate) mod internal { impl Benchmark for Wrapper where T: super::Benchmark, - R: AsRegression, + R: AsRegression, { fn try_match(&self, input: &Any) -> Result { if let Some(cast) = input.downcast_ref::() { - T::try_match(cast) + self.benchmark.try_match(cast) } else { Err(MATCH_FAIL) } @@ -295,7 +284,7 @@ pub(crate) mod internal { ) -> std::fmt::Result { match input { Some(input) => match input.downcast_ref::() { - Some(cast) => T::description(f, Some(cast)), + Some(cast) => self.benchmark.description(f, Some(cast)), None => write!( f, "expected tag \"{}\" - instead got \"{}\"", @@ -305,7 +294,7 @@ pub(crate) mod internal { }, None => { writeln!(f, "tag \"{}\"", ::tag())?; - T::description(f, None) + self.benchmark.description(f, None) } } } @@ -318,7 +307,7 @@ pub(crate) mod internal { ) -> anyhow::Result { match input.downcast_ref::() { Some(input) => { - let result = T::run(input, checkpoint, output)?; + let result = self.benchmark.run(input, checkpoint, output)?; Ok(serde_json::to_value(result)?) } None => Err(BadDownCast::new(T::Input::tag(), input.tag()).into()), @@ -327,7 +316,7 @@ pub(crate) mod internal { // Extensions fn as_regression(&self) -> Option<&dyn Regression> { - self.regression.as_regression() + R::as_regression(&self.benchmark) } } diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index 73e10c605..5d8c7366c 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -108,13 +108,16 @@ impl Benchmarks { } /// Register a new benchmark with the given name. - pub fn register(&mut self, name: impl Into) + pub fn register(&mut self, name: impl Into, benchmark: T) where - T: Benchmark + 'static, + T: Benchmark, { self.benchmarks.push(RegisteredBenchmark { name: name.into(), - benchmark: Box::new(benchmark::internal::Wrapper::::new()), + benchmark: Box::new(benchmark::internal::Wrapper::::new( + benchmark, + benchmark::internal::NoRegression, + )), }); } @@ -212,12 +215,13 @@ impl Benchmarks { /// /// Upon registration, the associated [`Regression::Tolerances`] input and the benchmark /// itself will be reachable via [`Check`](crate::app::Check). - pub fn register_regression(&mut self, name: impl Into) + pub fn register_regression(&mut self, name: impl Into, benchmark: T) where - T: Regression + 'static, + T: Regression, { - let registered = benchmark::internal::Wrapper::::new_with( - benchmark::internal::WithRegression::::new(), + let registered = benchmark::internal::Wrapper::::new( + benchmark, + benchmark::internal::WithRegression, ); self.benchmarks.push(RegisteredBenchmark { name: name.into(), diff --git a/diskann-benchmark-runner/src/test/dim.rs b/diskann-benchmark-runner/src/test/dim.rs index 07e73e2a6..f0eae36a3 100644 --- a/diskann-benchmark-runner/src/test/dim.rs +++ b/diskann-benchmark-runner/src/test/dim.rs @@ -99,7 +99,7 @@ impl Benchmark for SimpleBench { type Input = DimInput; type Output = usize; - fn try_match(input: &DimInput) -> Result { + fn try_match(&self, input: &DimInput) -> Result { if input.dim.is_none() { Ok(MatchScore(0)) } else { @@ -107,7 +107,11 @@ impl Benchmark for SimpleBench { } } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&DimInput>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&DimInput>, + ) -> std::fmt::Result { match input { Some(input) if input.dim.is_none() => write!(f, "successful match"), Some(_) => write!(f, "expected dim=None"), @@ -116,6 +120,7 @@ impl Benchmark for SimpleBench { } fn run( + &self, input: &DimInput, _checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, @@ -133,11 +138,15 @@ impl Benchmark for DimBench { type Input = DimInput; type Output = usize; - fn try_match(_input: &DimInput) -> Result { + fn try_match(&self, _input: &DimInput) -> Result { Ok(MatchScore(0)) } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&DimInput>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&DimInput>, + ) -> std::fmt::Result { if input.is_some() { write!(f, "perfect match") } else { @@ -146,6 +155,7 @@ impl Benchmark for DimBench { } fn run( + &self, input: &DimInput, _checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, @@ -161,6 +171,7 @@ impl Regression for DimBench { type Fail = &'static str; fn check( + &self, tolerance: &Tolerance, input: &DimInput, before: &usize, diff --git a/diskann-benchmark-runner/src/test/mod.rs b/diskann-benchmark-runner/src/test/mod.rs index 540842d4f..ea9853e5e 100644 --- a/diskann-benchmark-runner/src/test/mod.rs +++ b/diskann-benchmark-runner/src/test/mod.rs @@ -22,10 +22,13 @@ pub fn register_inputs(inputs: &mut registry::Inputs) -> anyhow::Result<()> { } pub fn register_benchmarks(benchmarks: &mut registry::Benchmarks) { - benchmarks.register_regression::>("type-bench-f32"); - benchmarks.register_regression::>("type-bench-i8"); - benchmarks.register_regression::>("exact-type-bench-f32-1000"); + benchmarks.register_regression("type-bench-f32", typed::TypeBench::::new()); + benchmarks.register_regression("type-bench-i8", typed::TypeBench::::new()); + benchmarks.register_regression( + "exact-type-bench-f32-1000", + typed::ExactTypeBench::::new(), + ); - benchmarks.register::("simple-bench"); - benchmarks.register_regression::("dim-bench"); + benchmarks.register("simple-bench", dim::SimpleBench); + benchmarks.register_regression("dim-bench", dim::DimBench); } diff --git a/diskann-benchmark-runner/src/test/typed.rs b/diskann-benchmark-runner/src/test/typed.rs index ed49b8b22..cae95f66d 100644 --- a/diskann-benchmark-runner/src/test/typed.rs +++ b/diskann-benchmark-runner/src/test/typed.rs @@ -129,6 +129,12 @@ impl CheckDeserialization for Tolerance { #[derive(Debug)] pub(super) struct TypeBench(std::marker::PhantomData); +impl TypeBench { + pub(super) fn new() -> Self { + Self(std::marker::PhantomData) + } +} + impl Benchmark for TypeBench where T: 'static, @@ -137,17 +143,22 @@ where type Input = TypeInput; type Output = String; - fn try_match(input: &TypeInput) -> Result { + fn try_match(&self, input: &TypeInput) -> Result { // Try to match based on data type. // Add a small penalty so `ExactTypeBench` can be more specific if it hits. Type::::try_match(&input.data_type).map(|m| MatchScore(m.0 + 10)) } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&TypeInput>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&TypeInput>, + ) -> std::fmt::Result { Type::::description(f, input.map(|i| &i.data_type)) } fn run( + &self, input: &TypeInput, checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, @@ -169,6 +180,7 @@ where type Fail = DataType; fn check( + &self, _tolerance: &Tolerance, input: &TypeInput, before: &String, @@ -189,6 +201,12 @@ where #[derive(Debug)] pub(super) struct ExactTypeBench(std::marker::PhantomData); +impl ExactTypeBench { + pub(super) fn new() -> Self { + Self(std::marker::PhantomData) + } +} + impl Benchmark for ExactTypeBench where T: 'static, @@ -197,7 +215,7 @@ where type Input = TypeInput; type Output = String; - fn try_match(input: &TypeInput) -> Result { + fn try_match(&self, input: &TypeInput) -> Result { if input.dim == N { Type::::try_match(&input.data_type) } else { @@ -205,7 +223,11 @@ where } } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&TypeInput>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&TypeInput>, + ) -> std::fmt::Result { match input { None => { write!(f, "{}, dim={}", Description::>::new(), N) @@ -232,6 +254,7 @@ where } fn run( + &self, input: &TypeInput, checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, @@ -253,6 +276,7 @@ where type Fail = String; fn check( + &self, _tolerance: &Tolerance, input: &TypeInput, before: &String, diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 4fb921590..8d72efb91 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -303,137 +303,104 @@ impl std::fmt::Display for CheckResult { // Benchmark Registration // //////////////////////////// -macro_rules! register { - ($arch:literal, $dispatcher:ident, $name:literal, $($kernel:tt)*) => { - #[cfg(target_arch = $arch)] - $dispatcher.register_regression::<$($kernel)*>($name) - }; - ($dispatcher:ident, $name:literal, $($kernel:tt)*) => { - $dispatcher.register_regression::<$($kernel)*>($name) - }; -} - fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) { // x86-64-v4 - register!( - "x86_64", - dispatcher, - "simd-op-f32xf32-x86_64_V4", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-f16xf16-x86_64_V4", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-u8xu8-x86_64_V4", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-i8xi8-x86_64_V4", - Kernel - ); + #[cfg(target_arch = "x86_64")] + { + dispatcher.register_regression( + "simd-op-f32xf32-x86_64_V4", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-f16xf16-x86_64_V4", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-u8xu8-x86_64_V4", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-i8xi8-x86_64_V4", + Kernel::::new(), + ); + } // x86-64-v3 - register!( - "x86_64", - dispatcher, - "simd-op-f32xf32-x86_64_V3", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-f16xf16-x86_64_V3", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-u8xu8-x86_64_V3", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-i8xi8-x86_64_V3", - Kernel - ); + #[cfg(target_arch = "x86_64")] + { + dispatcher.register_regression( + "simd-op-f32xf32-x86_64_V3", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-f16xf16-x86_64_V3", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-u8xu8-x86_64_V3", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-i8xi8-x86_64_V3", + Kernel::::new(), + ); + } // aarch64-neon - register!( - "aarch64", - dispatcher, - "simd-op-f32xf32-aarch64_neon", - Kernel - ); - register!( - "aarch64", - dispatcher, - "simd-op-f16xf16-aarch64_neon", - Kernel - ); - register!( - "aarch64", - dispatcher, - "simd-op-u8xu8-aarch64_neon", - Kernel - ); - register!( - "aarch64", - dispatcher, - "simd-op-i8xi8-aarch64_neon", - Kernel - ); + #[cfg(target_arch = "aarch64")] + { + dispatcher.register_regression( + "simd-op-f32xf32-aarch64_neon", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-f16xf16-aarch64_neon", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-u8xu8-aarch64_neon", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-i8xi8-aarch64_neon", + Kernel::::new(), + ); + } // scalar - register!( - dispatcher, + dispatcher.register_regression( "simd-op-f32xf32-scalar", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-f16xf16-scalar", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-u8xu8-scalar", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-i8xi8-scalar", - Kernel + Kernel::::new(), ); // reference - register!( - dispatcher, + dispatcher.register_regression( "simd-op-f32xf32-reference", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-f16xf16-reference", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-u8xu8-reference", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-i8xi8-reference", - Kernel + Kernel::::new(), ); } @@ -449,14 +416,12 @@ struct Reference; struct Identity(T); struct Kernel { - arch: A, _type: std::marker::PhantomData<(A, Q, D)>, } impl Kernel { - fn new(arch: A) -> Self { + fn new() -> Self { Self { - arch, _type: std::marker::PhantomData, } } @@ -582,13 +547,16 @@ where datatype::Type: DispatchRule, datatype::Type: DispatchRule, Identity: DispatchRule, - Kernel: RunBenchmark, + Kernel: RunBenchmark, + A: 'static, + Q: 'static, + D: 'static, { type Input = SimdOp; type Output = Vec; // Matching simply requires that we match the inner type. - fn try_match(from: &SimdOp) -> Result { + fn try_match(&self, from: &SimdOp) -> Result { let mut failscore: Option = None; if datatype::Type::::try_match(&from.query_type).is_err() { *failscore.get_or_insert(0) += 10; @@ -607,19 +575,23 @@ where } fn run( + &self, input: &SimdOp, _: diskann_benchmark_runner::Checkpoint<'_>, mut output: &mut dyn diskann_benchmark_runner::Output, ) -> anyhow::Result { let arch = Identity::::convert(input.arch)?.0; - let kernel = Self::new(arch); writeln!(output, "{}", input)?; - let results = kernel.run(input)?; + let results = self.run_benchmark(input, arch)?; writeln!(output, "\n\n{}", DisplayWrapper(&*results))?; Ok(results) } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&SimdOp>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&SimdOp>, + ) -> std::fmt::Result { match input { None => { describeln!( @@ -659,13 +631,17 @@ where datatype::Type: DispatchRule, datatype::Type: DispatchRule, Identity: DispatchRule, - Kernel: RunBenchmark, + Kernel: RunBenchmark, + A: 'static, + Q: 'static, + D: 'static, { type Tolerances = SimdTolerance; type Pass = CheckResult; type Fail = CheckResult; fn check( + &self, tolerance: &SimdTolerance, _input: &SimdOp, before: &Vec, @@ -724,8 +700,8 @@ where // Benchmark // /////////////// -trait RunBenchmark { - fn run(self, input: &SimdOp) -> Result, anyhow::Error>; +trait RunBenchmark { + fn run_benchmark(&self, input: &SimdOp, arch: A) -> Result, anyhow::Error>; } #[derive(Debug, Serialize, Deserialize)] @@ -856,8 +832,12 @@ impl Data { macro_rules! stamp { (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => { - impl RunBenchmark for Kernel { - fn run(self, input: &SimdOp) -> Result, anyhow::Error> { + impl RunBenchmark for Kernel { + fn run_benchmark( + &self, + input: &SimdOp, + _arch: Reference, + ) -> Result, anyhow::Error> { let mut results = Vec::new(); for run in input.runs.iter() { let data = Data::<$Q, $D>::new(run); @@ -873,8 +853,12 @@ macro_rules! stamp { } }; ($arch:path, $Q:ty, $D:ty) => { - impl RunBenchmark for Kernel<$arch, $Q, $D> { - fn run(self, input: &SimdOp) -> Result, anyhow::Error> { + impl RunBenchmark<$arch> for Kernel<$arch, $Q, $D> { + fn run_benchmark( + &self, + input: &SimdOp, + arch: $arch, + ) -> Result, anyhow::Error> { let mut results = Vec::new(); let l2 = &simd::L2 {}; @@ -891,16 +875,13 @@ macro_rules! stamp { // target features. let result = match run.distance { SimilarityMeasure::SquaredL2 => data.run(run, |q, d| { - self.arch - .run2(|q, d| simd::simd_op(l2, self.arch, q, d), q, d) + arch.run2(|q, d| simd::simd_op(l2, arch, q, d), q, d) }), SimilarityMeasure::InnerProduct => data.run(run, |q, d| { - self.arch - .run2(|q, d| simd::simd_op(ip, self.arch, q, d), q, d) + arch.run2(|q, d| simd::simd_op(ip, arch, q, d), q, d) }), SimilarityMeasure::Cosine => data.run(run, |q, d| { - self.arch - .run2(|q, d| simd::simd_op(cosine, self.arch, q, d), q, d) + arch.run2(|q, d| simd::simd_op(cosine, arch, q, d), q, d) }), }; results.push(result) @@ -1237,60 +1218,64 @@ mod tests { #[test] fn check_rejects_mismatched_runs() { - type Bench = Kernel; + let kernel = Kernel::::new(); - let err = Bench::check( - &tolerance(0.0), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], - &vec![tiny_result(SimilarityMeasure::Cosine, 100)], - ) - .unwrap_err(); + let err = kernel + .check( + &tolerance(0.0), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], + &vec![tiny_result(SimilarityMeasure::Cosine, 100)], + ) + .unwrap_err(); assert_eq!(err.to_string(), "run 0 mismatched"); } #[test] fn check_allows_negative_relative_change() { - type Bench = Kernel; + let kernel = Kernel::::new(); - let result = Bench::check( - &tolerance(0.0), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], - &vec![tiny_result(SimilarityMeasure::SquaredL2, 95)], - ) - .unwrap(); + let result = kernel + .check( + &tolerance(0.0), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], + &vec![tiny_result(SimilarityMeasure::SquaredL2, 95)], + ) + .unwrap(); assert!(matches!(result, PassFail::Pass(_))); } #[test] fn check_passes_on_tolerance_boundary() { - type Bench = Kernel; + let kernel = Kernel::::new(); - let result = Bench::check( - &tolerance(0.05), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], - &vec![tiny_result(SimilarityMeasure::SquaredL2, 105)], - ) - .unwrap(); + let result = kernel + .check( + &tolerance(0.05), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], + &vec![tiny_result(SimilarityMeasure::SquaredL2, 105)], + ) + .unwrap(); assert!(matches!(result, PassFail::Pass(_))); } #[test] fn check_fails_above_tolerance_boundary() { - type Bench = Kernel; + let kernel = Kernel::::new(); - let result = Bench::check( - &tolerance(0.05), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], - &vec![tiny_result(SimilarityMeasure::SquaredL2, 106)], - ) - .unwrap(); + let result = kernel + .check( + &tolerance(0.05), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], + &vec![tiny_result(SimilarityMeasure::SquaredL2, 106)], + ) + .unwrap(); assert!(matches!(result, PassFail::Fail(_))); } @@ -1322,15 +1307,16 @@ mod tests { // We require at least a non-zero value. #[test] fn zero_values_rejected() { - type Bench = Kernel; - - let result = Bench::check( - &tolerance(0.05), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)], - &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)], - ) - .unwrap(); + let kernel = Kernel::::new(); + + let result = kernel + .check( + &tolerance(0.05), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)], + &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)], + ) + .unwrap(); assert!(matches!(result, PassFail::Fail(_))); } diff --git a/diskann-benchmark/src/backend/disk_index/benchmarks.rs b/diskann-benchmark/src/backend/disk_index/benchmarks.rs index fa9b036ad..6c5298dd8 100644 --- a/diskann-benchmark/src/backend/disk_index/benchmarks.rs +++ b/diskann-benchmark/src/backend/disk_index/benchmarks.rs @@ -30,8 +30,7 @@ use crate::{ }; /// Disk Index -struct DiskIndex<'a, T> { - input: &'a DiskIndexOperation, +struct DiskIndex { _vector_type: std::marker::PhantomData, } @@ -41,53 +40,18 @@ pub(super) struct DiskIndexStats { pub(super) search: DiskSearchStats, } -impl<'a, T> DiskIndex<'a, T> +impl DiskIndex where T: VectorRepr, { - fn new(input: &'a DiskIndexOperation) -> Self { + fn new() -> Self { Self { - input, _vector_type: std::marker::PhantomData, } } - - fn run( - &self, - _checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input.source)?; - let (build_stats, index_load) = match &self.input.source { - DiskIndexSource::Load(load) => Ok((None, (*load).clone())), - DiskIndexSource::Build(build) => build_disk_index::(&FileStorageProvider, build) - .map(|stats| { - ( - Some(stats), - DiskIndexLoad { - data_type: build.data_type, - load_path: build.save_path.clone(), - }, - ) - }), - }?; - if let Some(build_stats) = &build_stats { - writeln!(output, "{}", build_stats)?; - } - - writeln!(output, "{}", self.input.search_phase)?; - let search_stats = - search_disk_index::(&index_load, &self.input.search_phase, &FileStorageProvider)?; - writeln!(output, "{}", search_stats)?; - - Ok(DiskIndexStats { - build: build_stats, - search: search_stats, - }) - } } -impl Benchmark for DiskIndex<'static, T> +impl Benchmark for DiskIndex where T: VectorRepr + 'static, Type: DispatchRule, @@ -95,7 +59,7 @@ where type Input = DiskIndexOperation; type Output = DiskIndexStats; - fn try_match(input: &DiskIndexOperation) -> Result { + fn try_match(&self, input: &DiskIndexOperation) -> Result { match &input.source { DiskIndexSource::Load(load) => Type::::try_match(&load.data_type), DiskIndexSource::Build(build) => Type::::try_match(&build.data_type), @@ -103,6 +67,7 @@ where } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&DiskIndexOperation>, ) -> std::fmt::Result { @@ -116,11 +81,39 @@ where } fn run( + &self, input: &DiskIndexOperation, - checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, ) -> anyhow::Result { - DiskIndex::::new(input).run(checkpoint, output) + writeln!(output, "{}", input.source)?; + + let (build_stats, index_load) = match &input.source { + DiskIndexSource::Load(load) => Ok((None, (*load).clone())), + DiskIndexSource::Build(build) => build_disk_index::(&FileStorageProvider, build) + .map(|stats| { + ( + Some(stats), + DiskIndexLoad { + data_type: build.data_type, + load_path: build.save_path.clone(), + }, + ) + }), + }?; + if let Some(build_stats) = &build_stats { + writeln!(output, "{}", build_stats)?; + } + + writeln!(output, "{}", input.search_phase)?; + let search_stats = + search_disk_index::(&index_load, &input.search_phase, &FileStorageProvider)?; + writeln!(output, "{}", search_stats)?; + + Ok(DiskIndexStats { + build: build_stats, + search: search_stats, + }) } } @@ -129,10 +122,10 @@ where //////////////////////////// pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { - benchmarks.register_regression::>("disk-index-f32"); - benchmarks.register_regression::>("disk-index-f16"); - benchmarks.register_regression::>("disk-index-u8"); - benchmarks.register_regression::>("disk-index-i8"); + benchmarks.register_regression("disk-index-f32", DiskIndex::::new()); + benchmarks.register_regression("disk-index-f16", DiskIndex::::new()); + benchmarks.register_regression("disk-index-u8", DiskIndex::::new()); + benchmarks.register_regression("disk-index-i8", DiskIndex::::new()); } ///////////////////////// @@ -302,7 +295,7 @@ fn check_metric( } } -impl Regression for DiskIndex<'static, T> +impl Regression for DiskIndex where T: VectorRepr + 'static, Type: DispatchRule, @@ -312,6 +305,7 @@ where type Fail = DiskIndexCheckResult; fn check( + &self, tolerances: &DiskIndexTolerance, _input: &DiskIndexOperation, before: &DiskIndexStats, diff --git a/diskann-benchmark/src/backend/exhaustive/minmax.rs b/diskann-benchmark/src/backend/exhaustive/minmax.rs index 26b3da16a..73b57733f 100644 --- a/diskann-benchmark/src/backend/exhaustive/minmax.rs +++ b/diskann-benchmark/src/backend/exhaustive/minmax.rs @@ -12,10 +12,10 @@ crate::utils::stub_impl!("minmax-quantization", inputs::exhaustive::MinMax); // MinMax - requires feature "minmax-quantization" #[cfg(feature = "minmax-quantization")] pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); + benchmarks.register(NAME, imp::MinMaxQ::<1>); + benchmarks.register(NAME, imp::MinMaxQ::<2>); + benchmarks.register(NAME, imp::MinMaxQ::<4>); + benchmarks.register(NAME, imp::MinMaxQ::<8>); } // Stub implementation @@ -85,23 +85,21 @@ mod imp { Ok(progress) } - /// The dispatcher target for `spherical-quantization` operations. - pub(super) struct MinMaxQ<'a, const NBITS: usize> { - input: &'a inputs::exhaustive::MinMax, - } - - impl<'a, const NBITS: usize> MinMaxQ<'a, NBITS> { - pub(super) fn new(input: &'a inputs::exhaustive::MinMax) -> Self { - Self { input } - } + /// The dispatcher target for `minmax-quantization` operations. + #[derive(Debug, Clone, Copy)] + pub(super) struct MinMaxQ; - pub(super) fn run(self, mut output: &mut dyn Output) -> anyhow::Result + impl MinMaxQ { + pub(super) fn run( + &self, + input: &inputs::exhaustive::MinMax, + mut output: &mut dyn Output, + ) -> anyhow::Result where Unsigned: Representation, Plan: algos::CreateQuantComputer>, { - let input = &self.input; - writeln!(output, "{}", self.input)?; + writeln!(output, "{}", input)?; // Training let data = f32::converting_load(datafiles::BinFile(&input.data), input.data_type)?; @@ -111,13 +109,13 @@ mod imp { let dim = NonZeroUsize::new(data.ncols()).unwrap(); let transform = Transform::new( - (&self.input.transform_kind).into(), + (&input.transform_kind).into(), dim, Some(&mut rng), diskann_quantization::alloc::GlobalAllocator, )?; - let quantizer = MinMaxQuantizer::new(transform, Positive::new(self.input.scale)?); + let quantizer = MinMaxQuantizer::new(transform, Positive::new(input.scale)?); let training_time: MicroSeconds = start.elapsed().into(); @@ -198,7 +196,7 @@ mod imp { } } - impl Benchmark for MinMaxQ<'static, NBITS> + impl Benchmark for MinMaxQ where Unsigned: Representation, Plan: algos::CreateQuantComputer>, @@ -206,7 +204,10 @@ mod imp { type Input = inputs::exhaustive::MinMax; type Output = Results; - fn try_match(input: &inputs::exhaustive::MinMax) -> Result { + fn try_match( + &self, + input: &inputs::exhaustive::MinMax, + ) -> Result { let num_bits = input.num_bits.get(); if num_bits == NBITS { Ok(MatchScore(0)) @@ -218,6 +219,7 @@ mod imp { } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&inputs::exhaustive::MinMax>, ) -> std::fmt::Result { @@ -246,11 +248,12 @@ mod imp { } fn run( + &self, input: &inputs::exhaustive::MinMax, _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { - MinMaxQ::::new(input).run(output) + self.run(input, output) } } diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 4723753a0..43f55a8ff 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -11,7 +11,7 @@ crate::utils::stub_impl!("product-quantization", inputs::exhaustive::Product); pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { #[cfg(feature = "product-quantization")] - benchmarks.register::>(NAME); + benchmarks.register(NAME, imp::ProductQ); #[cfg(not(feature = "product-quantization"))] imp::register(NAME, benchmarks) @@ -65,18 +65,16 @@ mod imp { } /// The dispatcher target for `spherical-quantization` operations. - pub(super) struct ProductQ<'a> { - input: &'a inputs::exhaustive::Product, - } - - impl<'a> ProductQ<'a> { - pub(super) fn new(input: &'a inputs::exhaustive::Product) -> Self { - Self { input } - } + #[derive(Debug, Clone, Copy)] + pub(super) struct ProductQ; - pub(super) fn run(self, mut output: &mut dyn Output) -> anyhow::Result { - let input = &self.input; - writeln!(output, "{}", self.input)?; + impl ProductQ { + pub(super) fn run( + &self, + input: &inputs::exhaustive::Product, + mut output: &mut dyn Output, + ) -> anyhow::Result { + writeln!(output, "{}", input)?; // Training let data = f32::converting_load(datafiles::BinFile(&input.data), input.data_type)?; @@ -190,15 +188,19 @@ mod imp { } } - impl Benchmark for ProductQ<'static> { + impl Benchmark for ProductQ { type Input = inputs::exhaustive::Product; type Output = Results; - fn try_match(_input: &inputs::exhaustive::Product) -> Result { + fn try_match( + &self, + _input: &inputs::exhaustive::Product, + ) -> Result { Ok(MatchScore(0)) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&inputs::exhaustive::Product>, ) -> std::fmt::Result { @@ -210,11 +212,12 @@ mod imp { } fn run( + &self, input: &inputs::exhaustive::Product, _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { - ProductQ::new(input).run(output) + self.run(input, output) } } diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index 1c0881c56..9b1f9a935 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -12,10 +12,10 @@ crate::utils::stub_impl!("spherical-quantization", inputs::exhaustive::Spherical // Spherical - requires feature "spherical-quantization" #[cfg(feature = "spherical-quantization")] pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); + benchmarks.register(NAME, imp::SphericalQ::<1>); + benchmarks.register(NAME, imp::SphericalQ::<2>); + benchmarks.register(NAME, imp::SphericalQ::<4>); + benchmarks.register(NAME, imp::SphericalQ::<8>); } // Stub implementation @@ -79,16 +79,14 @@ mod imp { } /// The dispatcher target for `spherical-quantization` operations. - pub(super) struct SphericalQ<'a, const NBITS: usize> { - input: &'a inputs::exhaustive::Spherical, - } - - impl<'a, const NBITS: usize> SphericalQ<'a, NBITS> { - pub(super) fn new(input: &'a inputs::exhaustive::Spherical) -> Self { - Self { input } - } + pub(super) struct SphericalQ; - pub(super) fn run(self, mut output: &mut dyn Output) -> anyhow::Result + impl SphericalQ { + pub(super) fn run( + &self, + input: &inputs::exhaustive::Spherical, + mut output: &mut dyn Output, + ) -> anyhow::Result where Unsigned: Representation, Plan: algos::CreateQuantComputer>, @@ -97,8 +95,7 @@ mod imp { SphericalQuantizer: for<'x> CompressIntoWith<&'x [f32], DataMut<'x, NBITS>, ScopedAllocator<'x>>, { - let input = &self.input; - writeln!(output, "{}", self.input)?; + writeln!(output, "{}", input)?; // Training let data = f32::converting_load(datafiles::BinFile(&input.data), input.data_type)?; @@ -202,7 +199,7 @@ mod imp { } } - impl Benchmark for SphericalQ<'static, NBITS> + impl Benchmark for SphericalQ where Unsigned: Representation, Plan: algos::CreateQuantComputer>, @@ -214,7 +211,10 @@ mod imp { type Input = inputs::exhaustive::Spherical; type Output = Results; - fn try_match(input: &inputs::exhaustive::Spherical) -> Result { + fn try_match( + &self, + input: &inputs::exhaustive::Spherical, + ) -> Result { let num_bits = input.num_bits.get(); if num_bits == NBITS { Ok(MatchScore(0)) @@ -226,6 +226,7 @@ mod imp { } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&inputs::exhaustive::Spherical>, ) -> std::fmt::Result { @@ -254,11 +255,12 @@ mod imp { } fn run( + &self, input: &inputs::exhaustive::Spherical, _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { - SphericalQ::::new(input).run(output) + self.run(input, output) } } diff --git a/diskann-benchmark/src/backend/filters/benchmark.rs b/diskann-benchmark/src/backend/filters/benchmark.rs index a90ea41ed..43a0717b4 100644 --- a/diskann-benchmark/src/backend/filters/benchmark.rs +++ b/diskann-benchmark/src/backend/filters/benchmark.rs @@ -29,29 +29,23 @@ use crate::{ }; pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>("metadata-index-build"); + benchmarks.register("metadata-index-build", MetadataIndexJob); } -// Metadata-only index job wrapper -pub(super) struct MetadataIndexJob<'a> { - input: &'a crate::inputs::filters::MetadataIndexBuild, -} - -impl<'a> MetadataIndexJob<'a> { - fn new(input: &'a crate::inputs::filters::MetadataIndexBuild) -> Self { - Self { input } - } -} +// Metadata-only index job. +#[derive(Debug)] +struct MetadataIndexJob; -impl Benchmark for MetadataIndexJob<'static> { +impl Benchmark for MetadataIndexJob { type Input = MetadataIndexBuild; type Output = MetadataIndexBuildStats; - fn try_match(_input: &MetadataIndexBuild) -> Result { + fn try_match(&self, _input: &MetadataIndexBuild) -> Result { Ok(MatchScore(1)) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, _input: Option<&MetadataIndexBuild>, ) -> std::fmt::Result { @@ -63,90 +57,89 @@ impl Benchmark for MetadataIndexJob<'static> { } fn run( + &self, input: &MetadataIndexBuild, checkpoint: Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { - MetadataIndexJob::new(input).run(checkpoint, output) + run(input, checkpoint, output) } } -impl<'a> MetadataIndexJob<'a> { - fn run( - self, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - // Print the input description so the user sees the job configuration. - writeln!(output, "{}", self.input)?; - - // Use the supplied filter parameters (required for metadata-only build) - let filter_params = &self.input.filter_params; - - // Reuse the helper: build index, parse predicates, produce BitmapFilters and telemetry - let (bitmap_filters_vec, filter_search_results, _label_count) = - prepare_bitmap_filters_from_paths_with_kind( - filter_params.data_labels.as_ref(), - filter_params.query_predicates.as_ref(), - self.input.inverted_index_type, - checkpoint, - )?; - - // Collect per-query matching counts and compute aggregates - let counts: Vec = bitmap_filters_vec.iter().map(|bf| bf.count()).collect(); - let query_count = counts.len(); - let total_matching: usize = counts.iter().cloned().sum(); +fn run( + input: &crate::inputs::filters::MetadataIndexBuild, + checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, +) -> Result { + // Print the input description so the user sees the job configuration. + writeln!(output, "{}", input)?; + + // Use the supplied filter parameters (required for metadata-only build) + let filter_params = &input.filter_params; + + // Reuse the helper: build index, parse predicates, produce BitmapFilters and telemetry + let (bitmap_filters_vec, filter_search_results, _label_count) = + prepare_bitmap_filters_from_paths_with_kind( + filter_params.data_labels.as_ref(), + filter_params.query_predicates.as_ref(), + input.inverted_index_type, + checkpoint, + )?; - // counts_avg will be computed below via the shared percentiles utility - let mut sorted = counts.clone(); - // Use the shared percentiles utility when we have values. - let ( - counts_p1, - counts_p5, - counts_p10, - counts_p50, - counts_p90, - counts_p95, - counts_p99, - counts_avg, - ) = if sorted.is_empty() { - ( - 0usize, 0usize, 0usize, 0usize, 0usize, 0usize, 0usize, 0.0f64, - ) - } else { - sorted.sort_unstable(); - let p = percentiles::compute_percentiles(&mut sorted)?; - // p.median is f64; round to nearest usize for display/storage - let p50 = p.median.round() as usize; - let p90 = p.p90; - let p99 = p.p99; - let n = sorted.len(); - let p1 = sorted[(n / 100).min(n - 1)]; - let p5 = sorted[((5 * n) / 100).min(n - 1)]; - let p10 = sorted[((10 * n) / 100).min(n - 1)]; - let p95 = sorted[((95 * n) / 100).min(n - 1)]; - (p1, p5, p10, p50, p90, p95, p99, p.mean) - }; + // Collect per-query matching counts and compute aggregates + let counts: Vec = bitmap_filters_vec.iter().map(|bf| bf.count()).collect(); + let query_count = counts.len(); + let total_matching: usize = counts.iter().cloned().sum(); + + // counts_avg will be computed below via the shared percentiles utility + let mut sorted = counts.clone(); + // Use the shared percentiles utility when we have values. + let ( + counts_p1, + counts_p5, + counts_p10, + counts_p50, + counts_p90, + counts_p95, + counts_p99, + counts_avg, + ) = if sorted.is_empty() { + ( + 0usize, 0usize, 0usize, 0usize, 0usize, 0usize, 0usize, 0.0f64, + ) + } else { + sorted.sort_unstable(); + let p = percentiles::compute_percentiles(&mut sorted)?; + // p.median is f64; round to nearest usize for display/storage + let p50 = p.median.round() as usize; + let p90 = p.p90; + let p99 = p.p99; + let n = sorted.len(); + let p1 = sorted[(n / 100).min(n - 1)]; + let p5 = sorted[((5 * n) / 100).min(n - 1)]; + let p10 = sorted[((10 * n) / 100).min(n - 1)]; + let p95 = sorted[((95 * n) / 100).min(n - 1)]; + (p1, p5, p10, p50, p90, p95, p99, p.mean) + }; - let stats = MetadataIndexBuildStats { - label_count: _label_count, - query_count, - total_matching, - counts_avg, - counts_p1, - counts_p5, - counts_p10, - counts_p50, - counts_p90, - counts_p95, - counts_p99, - filter: filter_search_results, - }; + let stats = MetadataIndexBuildStats { + label_count: _label_count, + query_count, + total_matching, + counts_avg, + counts_p1, + counts_p5, + counts_p10, + counts_p50, + counts_p90, + counts_p95, + counts_p99, + filter: filter_search_results, + }; - // Print the human-readable summary for interactive runs. - writeln!(output, "\n\n{}", stats)?; - Ok(stats) - } + // Print the human-readable summary for interactive runs. + writeln!(output, "\n\n{}", stats)?; + Ok(stats) } #[derive(Debug, Serialize)] diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index d38332ee1..ad7fd697a 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -32,7 +32,6 @@ use diskann_utils::{ views::{Matrix, MatrixView}, }; use half::f16; -use serde::Serialize; use super::{ build::{self, load_index, save_index, single_or_multi_insert, BuildStats}, @@ -57,66 +56,48 @@ use crate::{ pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { // Full Precision - benchmarks.register::>("async-full-precision-f32"); - benchmarks.register::>("async-full-precision-f16"); - benchmarks.register::>("async-full-precision-u8"); - benchmarks.register::>("async-full-precision-i8"); + benchmarks.register("async-full-precision-f32", FullPrecision::::new()); + benchmarks.register("async-full-precision-f16", FullPrecision::::new()); + benchmarks.register("async-full-precision-u8", FullPrecision::::new()); + benchmarks.register("async-full-precision-i8", FullPrecision::::new()); // Dynamic Full Precision - benchmarks.register::>("async-dynamic-full-precision-f32"); - benchmarks.register::>("async-dynamic-full-precision-f16"); - benchmarks.register::>("async-dynamic-full-precision-u8"); - benchmarks.register::>("async-dynamic-full-precision-i8"); + benchmarks.register( + "async-dynamic-full-precision-f32", + DynamicFullPrecision::::new(), + ); + benchmarks.register( + "async-dynamic-full-precision-f16", + DynamicFullPrecision::::new(), + ); + benchmarks.register( + "async-dynamic-full-precision-u8", + DynamicFullPrecision::::new(), + ); + benchmarks.register( + "async-dynamic-full-precision-i8", + DynamicFullPrecision::::new(), + ); product::register_benchmarks(benchmarks); scalar::register_benchmarks(benchmarks); spherical::register_benchmarks(benchmarks); } -////////////// -// Dispatch // -////////////// - -pub(super) trait BuildAndSearch<'a> { - /// The telemetry associated with the build and search. - type Data: Serialize; - - /// Run the job, returning either the completed data or an error. - fn run( - self, - checkpoint: Checkpoint<'_>, - output: &mut dyn Output, - ) -> Result; -} - -pub(super) trait BuildAndDynamicRun<'a> { - /// The telemetry associated with the build and dynamic run. - type Data: Serialize; - - /// Run the runbook, returning either the completed data or an error. - fn run( - self, - checkpoint: Checkpoint<'_>, - output: &mut dyn Output, - ) -> Result; -} - // Full Precision -pub(super) struct FullPrecision<'a, T> { - input: &'a IndexOperation, +pub(super) struct FullPrecision { _type: std::marker::PhantomData, } -impl<'a, T> FullPrecision<'a, T> { - fn new(input: &'a IndexOperation) -> Self { +impl FullPrecision { + pub(super) fn new() -> Self { Self { - input, _type: std::marker::PhantomData, } } } -impl Benchmark for FullPrecision<'static, T> +impl Benchmark for FullPrecision where T: VectorRepr + diskann_utils::sampling::WithApproximateNorm @@ -126,7 +107,7 @@ where type Input = IndexOperation; type Output = BuildResult; - fn try_match(input: &IndexOperation) -> Result { + fn try_match(&self, input: &IndexOperation) -> Result { match &input.source { IndexSource::Load(load) => datatype::Type::::try_match(&load.data_type), IndexSource::Build(build) => datatype::Type::::try_match(&build.data_type), @@ -134,6 +115,7 @@ where } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&IndexOperation>, ) -> std::fmt::Result { @@ -151,30 +133,79 @@ where } fn run( + &self, input: &IndexOperation, checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + mut output: &mut dyn Output, ) -> anyhow::Result { - BuildAndSearch::run(FullPrecision::::new(input), checkpoint, output) + writeln!(output, "{}", input)?; + let (index, build_stats) = match &input.source { + IndexSource::Build(build) => { + let (index, build_stats) = run_build( + build, + common::FullPrecision, + None, + output, + |data| { + let index = diskann_async::new_index::( + build.try_as_config()?.build()?, + build.inmem_parameters(data.nrows(), data.ncols()), + common::NoDeletes, + )?; + build::set_start_points( + index.provider(), + data.as_view(), + build.start_point_strategy, + )?; + Ok(index) + }, + single_or_multi_insert, + )?; + + // save the index if requested + if let Some(save_path) = &build.save_path { + utils::tokio::block_on(save_index(index.clone(), save_path))?; + } + + (index, Some(build_stats)) + } + IndexSource::Load(load) => { + let index_config: &IndexConfiguration = &load.to_config()?; + + let index = + { utils::tokio::block_on(load_index::<_>(&load.load_path, index_config))? }; + + (Arc::new(index), None::) + } + }; + + let result = run_search_outer( + &input.search_phase, + common::FullPrecision, + index, + build_stats, + checkpoint, + )?; + + writeln!(output, "\n\n{}", result)?; + Ok(result) } } // Async Dynamic Run -pub(super) struct DynamicFullPrecision<'a, T> { - input: &'a DynamicIndexRun, +pub(super) struct DynamicFullPrecision { _type: std::marker::PhantomData, } -impl<'a, T> DynamicFullPrecision<'a, T> { - fn new(input: &'a DynamicIndexRun) -> Self { +impl DynamicFullPrecision { + fn new() -> Self { Self { - input, _type: std::marker::PhantomData, } } } -impl Benchmark for DynamicFullPrecision<'static, T> +impl Benchmark for DynamicFullPrecision where T: VectorRepr + diskann_utils::sampling::WithApproximateNorm @@ -184,11 +215,12 @@ where type Input = DynamicIndexRun; type Output = Vec>; - fn try_match(input: &DynamicIndexRun) -> Result { + fn try_match(&self, input: &DynamicIndexRun) -> Result { datatype::Type::::try_match(&input.build.data_type) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&DynamicIndexRun>, ) -> std::fmt::Result { @@ -196,11 +228,64 @@ where } fn run( + &self, input: &DynamicIndexRun, - checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, ) -> anyhow::Result>> { - BuildAndDynamicRun::run(DynamicFullPrecision::::new(input), checkpoint, output) + writeln!(output, "{}", input)?; + + let groundtruth_directory = input + .runbook_params + .resolved_gt_directory + .as_ref() + .ok_or_else(|| { + anyhow::anyhow!("Ground truth directory path was not resolved during validation") + })?; + + let mut runbook = bigann::RunBook::load( + &input.runbook_params.runbook_path, + &input.runbook_params.dataset_name, + &mut bigann::ScanDirectory::new(groundtruth_directory)?, + )?; + + let mut streamer = full_precision_streaming::(input, runbook.max_points())?; + + let mut results = Vec::new(); + let stages = runbook.len(); + let mut i = 1; + + runbook.run_with( + &mut streamer, + |o: managed::Stats| -> anyhow::Result<()> { + if o.inner().is_maintain() { + let message = format!("Ran maintenance before stage {}", i); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + } else { + let message = + format!("Finished stage {} of {}: {}", i, stages, o.inner().kind()); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + i += 1; + } + writeln!(output, "{}", o)?; + results.push(o); + Ok(()) + }, + )?; + + write!( + output, + "{}", + crate::utils::SmallBanner("End of Run Summary") + )?; + + writeln!( + output, + "{}", + streaming::stats::Summary::new(results.iter().map(|r| r.inner())) + )?; + + Ok(results) } } @@ -395,141 +480,6 @@ where } } -impl<'a, T> BuildAndSearch<'a> for FullPrecision<'a, T> -where - T: VectorRepr - + diskann_utils::sampling::WithApproximateNorm - + diskann::graph::SampleableForStart, -{ - type Data = BuildResult; - fn run( - self, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; - let (index, build_stats) = match &self.input.source { - IndexSource::Build(build) => { - let (index, build_stats) = run_build( - build, - common::FullPrecision, - None, - output, - |data| { - let index = diskann_async::new_index::( - build.try_as_config()?.build()?, - build.inmem_parameters(data.nrows(), data.ncols()), - common::NoDeletes, - )?; - build::set_start_points( - index.provider(), - data.as_view(), - build.start_point_strategy, - )?; - Ok(index) - }, - single_or_multi_insert, - )?; - - // save the index if requested - if let Some(save_path) = &build.save_path { - utils::tokio::block_on(save_index(index.clone(), save_path))?; - } - - (index, Some(build_stats)) - } - IndexSource::Load(load) => { - let index_config: &IndexConfiguration = &load.to_config()?; - - let index = - { utils::tokio::block_on(load_index::<_>(&load.load_path, index_config))? }; - - (Arc::new(index), None::) - } - }; - - let result = run_search_outer( - &self.input.search_phase, - common::FullPrecision, - index, - build_stats, - checkpoint, - )?; - - writeln!(output, "\n\n{}", result)?; - Ok(result) - } -} - -impl<'a, T> BuildAndDynamicRun<'a> for DynamicFullPrecision<'a, T> -where - T: VectorRepr - + diskann_utils::sampling::WithApproximateNorm - + diskann::graph::SampleableForStart, -{ - type Data = Vec>; - fn run( - self, - _checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; - - let groundtruth_directory = self - .input - .runbook_params - .resolved_gt_directory - .as_ref() - .ok_or_else(|| { - anyhow::anyhow!("Ground truth directory path was not resolved during validation") - })?; - - let mut runbook = bigann::RunBook::load( - &self.input.runbook_params.runbook_path, - &self.input.runbook_params.dataset_name, - &mut bigann::ScanDirectory::new(groundtruth_directory)?, - )?; - - let mut streamer = full_precision_streaming(&self, runbook.max_points())?; - - let mut results = Vec::new(); - let stages = runbook.len(); - let mut i = 1; - - runbook.run_with( - &mut streamer, - |o: managed::Stats| -> anyhow::Result<()> { - if o.inner().is_maintain() { - let message = format!("Ran maintenance before stage {}", i); - write!(output, "{}", crate::utils::SmallBanner(&message))?; - } else { - let message = - format!("Finished stage {} of {}: {}", i, stages, o.inner().kind()); - write!(output, "{}", crate::utils::SmallBanner(&message))?; - i += 1; - } - writeln!(output, "{}", o)?; - results.push(o); - Ok(()) - }, - )?; - - write!( - output, - "{}", - crate::utils::SmallBanner("End of Run Summary") - )?; - - writeln!( - output, - "{}", - streaming::stats::Summary::new(results.iter().map(|r| r.inner())) - )?; - - Ok(results) - } -} - /// The stack looks like this: /// /// - Bottom: [`FullPrecisionStream`]: The core streaming index implementation. @@ -540,19 +490,19 @@ where /// /// This function constructs the entire stack. fn full_precision_streaming( - config: &DynamicFullPrecision<'_, T>, + input: &DynamicIndexRun, max_points: usize, ) -> anyhow::Result>> where T: bytemuck::Pod + VectorRepr + WithApproximateNorm + SampleableForStart, { - let topk = match &config.input.search_phase { + let topk = match &input.search_phase { SearchPhase::Topk(topk) => topk, _ => anyhow::bail!("Only TopK is currently supported by the streaming index"), }; - let consolidate_threshold: f32 = config.input.runbook_params.consolidate_threshold; + let consolidate_threshold: f32 = input.runbook_params.consolidate_threshold; - let data = datafiles::load_dataset::(datafiles::BinFile(&config.input.build.data))?; + let data = datafiles::load_dataset::(datafiles::BinFile(&input.build.data))?; let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile( &topk.queries, ))?); @@ -561,28 +511,25 @@ where let max_points = ((max_points as f32) * (1.0 + 2.0 * consolidate_threshold)).ceil() as usize; let index = diskann_async::new_index::( - config - .input - .try_as_config(config.input.build.l_build)? - .build()?, - config.input.inmem_parameters(max_points, data.ncols()), + input.try_as_config(input.build.l_build)?.build()?, + input.inmem_parameters(max_points, data.ncols()), common::TableBasedDeletes, )?; build::set_start_points( index.provider(), data.as_view(), - config.input.build.start_point_strategy, + input.build.start_point_strategy, )?; - let num_threads_and_tasks = NonZeroUsize::new(config.input.build.num_threads).unwrap(); + let num_threads_and_tasks = NonZeroUsize::new(input.build.num_threads).unwrap(); let managed_stream = FullPrecisionStream { index, search: topk.clone(), runtime: benchmark_core::tokio::runtime(num_threads_and_tasks.get())?, ntasks: num_threads_and_tasks, - inplace_delete_num_to_replace: config.input.runbook_params.ip_delete_num_to_replace, - inplace_delete_method: config.input.runbook_params.ip_delete_method.into(), + inplace_delete_num_to_replace: input.runbook_params.ip_delete_num_to_replace, + inplace_delete_method: input.runbook_params.ip_delete_method.into(), }; let managed = Managed::new(max_points, consolidate_threshold, managed_stream); diff --git a/diskann-benchmark/src/backend/index/product.rs b/diskann-benchmark/src/backend/index/product.rs index a857e4e57..a529217bf 100644 --- a/diskann-benchmark/src/backend/index/product.rs +++ b/diskann-benchmark/src/backend/index/product.rs @@ -13,8 +13,8 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { { use half::f16; - benchmarks.register::>("async-pq-f32"); - benchmarks.register::>("async-pq-f16"); + benchmarks.register("async-pq-f32", imp::ProductQuantized::::new()); + benchmarks.register("async-pq-f16", imp::ProductQuantized::::new()); } // Stub implementation @@ -42,7 +42,7 @@ mod imp { use crate::{ backend::index::{ - benchmarks::{run_build, run_search_outer, BuildAndSearch, FullPrecision}, + benchmarks::{run_build, run_search_outer, FullPrecision}, build::{self, load_index, save_index, single_or_multi_insert, BuildStats}, result::QuantBuildResult, }, @@ -50,21 +50,19 @@ mod imp { utils::{self, datafiles}, }; - pub(super) struct ProductQuantized<'a, T> { - input: &'a IndexPQOperation, + pub(super) struct ProductQuantized { _type: std::marker::PhantomData, } - impl<'a, T> ProductQuantized<'a, T> { - fn new(input: &'a IndexPQOperation) -> Self { + impl ProductQuantized { + pub(super) fn new() -> Self { Self { - input, _type: std::marker::PhantomData, } } } - impl Benchmark for ProductQuantized<'static, T> + impl Benchmark for ProductQuantized where T: VectorRepr + diskann_utils::sampling::WithApproximateNorm @@ -74,51 +72,31 @@ mod imp { type Input = IndexPQOperation; type Output = QuantBuildResult; - fn try_match(input: &IndexPQOperation) -> Result { - as Benchmark>::try_match(&input.index_operation) + fn try_match(&self, input: &IndexPQOperation) -> Result { + FullPrecision::::new().try_match(&input.index_operation) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&IndexPQOperation>, ) -> std::fmt::Result { - as Benchmark>::description( - f, - input.map(|f| &f.index_operation), - ) + FullPrecision::::new().description(f, input.map(|f| &f.index_operation)) } fn run( + &self, input: &IndexPQOperation, checkpoint: Checkpoint<'_>, - output: &mut dyn Output, - ) -> anyhow::Result { - let pq = ProductQuantized::::new(input); - BuildAndSearch::run(pq, checkpoint, output) - } - } - - impl<'a, T> BuildAndSearch<'a> for ProductQuantized<'a, T> - where - T: VectorRepr - + diskann_utils::sampling::WithApproximateNorm - + diskann::graph::SampleableForStart, - datatype::Type: DispatchRule, - { - type Data = QuantBuildResult; - fn run( - self, - checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; + ) -> anyhow::Result { + writeln!(output, "{}", input)?; - let hybrid = common::Hybrid::new(self.input.max_fp_vecs_per_prune); + let hybrid = common::Hybrid::new(input.max_fp_vecs_per_prune); - let (index, build_stats, quant_training_time) = match &self.input.index_operation.source - { + let (index, build_stats, quant_training_time) = match &input.index_operation.source { IndexSource::Load(load) => { - let index_config: &IndexConfiguration = &self.input.to_config()?; + let index_config: &IndexConfiguration = &input.to_config()?; let index = { utils::tokio::block_on(load_index::<_>(&load.load_path, index_config))? }; @@ -139,17 +117,16 @@ mod imp { diskann_async::train_pq( train_data.as_view(), - self.input.num_pq_chunks, - &mut StdRng::seed_from_u64(self.input.seed), + input.num_pq_chunks, + &mut StdRng::seed_from_u64(input.seed), build.num_threads, )? }; let create_index = |data_view: MatrixView| { let index = diskann_async::new_quant_index::( - self.input.try_as_config()?.build()?, - self.input - .inmem_parameters(data_view.nrows(), data_view.ncols())?, + input.try_as_config()?.build()?, + input.inmem_parameters(data_view.nrows(), data_view.ncols())?, table, common::NoDeletes, )?; @@ -180,9 +157,9 @@ mod imp { } }; - let build = if self.input.use_fp_for_search { + let build = if input.use_fp_for_search { run_search_outer( - &self.input.index_operation.search_phase, + &input.index_operation.search_phase, common::FullPrecision, index, build_stats, @@ -190,7 +167,7 @@ mod imp { )? } else { run_search_outer( - &self.input.index_operation.search_phase, + &input.index_operation.search_phase, hybrid, index, build_stats, diff --git a/diskann-benchmark/src/backend/index/scalar.rs b/diskann-benchmark/src/backend/index/scalar.rs index b418c0d7b..85c230605 100644 --- a/diskann-benchmark/src/backend/index/scalar.rs +++ b/diskann-benchmark/src/backend/index/scalar.rs @@ -14,17 +14,17 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { use half::f16; // f32 - benchmarks.register::>("async-sq-8-bit-f32"); - benchmarks.register::>("async-sq-4-bit-f32"); - benchmarks.register::>("async-sq-2-bit-f32"); - benchmarks.register::>("async-sq-1-bit-f32"); + benchmarks.register("async-sq-8-bit-f32", imp::ScalarQuantized::<8, f32>::new()); + benchmarks.register("async-sq-4-bit-f32", imp::ScalarQuantized::<4, f32>::new()); + benchmarks.register("async-sq-2-bit-f32", imp::ScalarQuantized::<2, f32>::new()); + benchmarks.register("async-sq-1-bit-f32", imp::ScalarQuantized::<1, f32>::new()); // f16 - benchmarks.register::>("async-sq-8-bit-f16"); - benchmarks.register::>("async-sq-4-bit-f16"); - benchmarks.register::>("async-sq-2-bit-f16"); - benchmarks.register::>("async-sq-1-bit-f16"); + benchmarks.register("async-sq-8-bit-f16", imp::ScalarQuantized::<8, f16>::new()); + benchmarks.register("async-sq-4-bit-f16", imp::ScalarQuantized::<4, f16>::new()); + benchmarks.register("async-sq-2-bit-f16", imp::ScalarQuantized::<2, f16>::new()); + benchmarks.register("async-sq-1-bit-f16", imp::ScalarQuantized::<1, f16>::new()); // i8 - benchmarks.register::>("async-sq-1-bit-i8"); + benchmarks.register("async-sq-1-bit-i8", imp::ScalarQuantized::<1, i8>::new()); } // Stub implementation @@ -55,7 +55,7 @@ mod imp { use crate::{ backend::index::{ - benchmarks::{run_build, run_search_outer, BuildAndSearch, FullPrecision}, + benchmarks::{run_build, run_search_outer, FullPrecision}, build::{self, load_index, only_single_insert, save_index, BuildStats}, result::QuantBuildResult, }, @@ -64,16 +64,13 @@ mod imp { }; // Scalar Quantized - pub(super) struct ScalarQuantized<'a, const NBITS: usize, T> { - input: &'a IndexSQOperation, + pub(super) struct ScalarQuantized { _type: std::marker::PhantomData, } - impl<'a, const NBITS: usize, T> ScalarQuantized<'a, NBITS, T> { - fn new(input: &'a IndexSQOperation) -> Self { - assert_eq!(input.num_bits, NBITS); + impl ScalarQuantized { + pub(super) fn new() -> Self { Self { - input, _type: std::marker::PhantomData, } } @@ -81,11 +78,11 @@ mod imp { macro_rules! impl_sq_build { ($N:literal, $T: ty) => { - impl Benchmark for ScalarQuantized<'static, $N, $T> { + impl Benchmark for ScalarQuantized<$N, $T> { type Input = IndexSQOperation; type Output = QuantBuildResult; - fn try_match(input: &IndexSQOperation) -> Result { + fn try_match(&self, input: &IndexSQOperation) -> Result { let mut failure_score: Option = None; match input.index_operation.source { IndexSource::Load(_) => {} @@ -96,7 +93,7 @@ mod imp { } } - if as Benchmark>::try_match(&input.index_operation) + if FullPrecision::<$T>::new().try_match(&input.index_operation) .is_err() { *failure_score.get_or_insert(0) += 1; @@ -113,6 +110,7 @@ mod imp { } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&IndexSQOperation>, ) -> std::fmt::Result { @@ -173,25 +171,20 @@ mod imp { } fn run( + &self, input: &IndexSQOperation, checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + mut output: &mut dyn Output, ) -> anyhow::Result { - let sq = ScalarQuantized::<$N, $T>::new(input); - BuildAndSearch::run(sq, checkpoint, output) - } - } + assert_eq!( + input.num_bits, + $N, + "INTERNAL ERROR: this should not have passed the match check" + ); - impl<'a> BuildAndSearch<'a> for ScalarQuantized<'a, $N, $T> { - type Data = QuantBuildResult; - fn run( - self, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; + writeln!(output, "{}", input)?; - let (index, build_stats, quant_training_time) = match &self.input.index_operation.source { + let (index, build_stats, quant_training_time) = match &input.index_operation.source { IndexSource::Load(load) => { let index_config: &IndexConfiguration = &load.to_config()?; @@ -208,7 +201,7 @@ mod imp { let start = std::time::Instant::now(); let quantizer = diskann_quantization::scalar::train::ScalarQuantizationParameters::new( - diskann_quantization::num::Positive::new(self.input.standard_deviations).context( + diskann_quantization::num::Positive::new(input.standard_deviations).context( "please file a bug report, this should not have made it past the\ front end", )?, @@ -216,8 +209,8 @@ mod imp { .train(data.as_view()); let create_index = |data_view: MatrixView<$T>| { let index = diskann_async::new_quant_index::<$T, _, _>( - self.input.try_as_config()?.build()?, - self.input + input.try_as_config()?.build()?, + input .inmem_parameters(data_view.nrows(), data_view.ncols())?, inmem::WithBits::<$N>::new(quantizer), common::NoDeletes, @@ -247,9 +240,9 @@ mod imp { }; - let build = if self.input.use_fp_for_search { + let build = if input.use_fp_for_search { run_search_outer( - &self.input.index_operation.search_phase, + &input.index_operation.search_phase, common::FullPrecision, index, build_stats, @@ -257,7 +250,7 @@ mod imp { )? } else { run_search_outer( - &self.input.index_operation.search_phase, + &input.index_operation.search_phase, common::Quantized, index, build_stats, diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 33cb2e2fe..507337da7 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -16,9 +16,9 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { #[cfg(feature = "spherical-quantization")] { - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); + benchmarks.register(NAME, imp::SphericalQ::<1>); + benchmarks.register(NAME, imp::SphericalQ::<2>); + benchmarks.register(NAME, imp::SphericalQ::<4>); } // Stub implementation @@ -52,7 +52,6 @@ mod imp { use crate::{ backend::index::{ - benchmarks::BuildAndSearch, build::{self, only_single_insert, BuildStats}, result::AggregatedSearchResults, search, @@ -68,15 +67,7 @@ mod imp { }; /// The dispatcher target for `spherical-quantization` operations. - pub(super) struct SphericalQ<'a, const NBITS: usize> { - input: &'a SphericalQuantBuild, - } - - impl<'a, const NBITS: usize> SphericalQ<'a, NBITS> { - pub(super) fn new(input: &'a SphericalQuantBuild) -> Self { - Self { input } - } - } + pub(super) struct SphericalQ; macro_rules! write_field { ($f:ident, $field:tt, $fmt:literal, $($expr:tt)*) => { @@ -126,11 +117,14 @@ mod imp { macro_rules! build_and_search { ($N:literal) => { - impl Benchmark for SphericalQ<'static, $N> { + impl Benchmark for SphericalQ<$N> { type Input = SphericalQuantBuild; type Output = SphericalBuildResult; - fn try_match(input: &SphericalQuantBuild) -> Result { + fn try_match( + &self, + input: &SphericalQuantBuild, + ) -> Result { let mut failure_score: Option = None; if input.build.multi_insert.is_some() { failure_score = Some(1); @@ -157,6 +151,7 @@ mod imp { } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&SphericalQuantBuild>, ) -> std::fmt::Result { @@ -200,42 +195,37 @@ mod imp { } fn run( + &self, input: &SphericalQuantBuild, checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + mut output: &mut dyn Output, ) -> anyhow::Result { - let sq = SphericalQ::<$N>::new(input); - BuildAndSearch::run(sq, checkpoint, output) - } - } + assert_eq!( + input.num_bits.get(), + $N, + "INTERNAL ERROR: this should not have passed the match check" + ); - impl<'a> BuildAndSearch<'a> for SphericalQ<'a, $N> { - type Data = SphericalBuildResult; - fn run( - self, - _checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; + writeln!(output, "{}", input)?; - let build = &self.input.build; + let build = &input.build; let data: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&build.data))?); let start = std::time::Instant::now(); let m: diskann_vector::distance::Metric = build.distance.into(); - let pre_scale = match self.input.pre_scale { + let pre_scale = match input.pre_scale { Some(v) => v.try_into()?, None => diskann_quantization::spherical::PreScale::None, }; let quantizer = diskann_quantization::spherical::SphericalQuantizer::train( data.as_view(), - (&self.input.transform_kind).into(), + (&input.transform_kind).into(), m.try_into()?, pre_scale, - &mut rand::rngs::StdRng::seed_from_u64(self.input.seed), + &mut rand::rngs::StdRng::seed_from_u64(input.seed), GlobalAllocator, )?; @@ -244,8 +234,8 @@ mod imp { // We manually inline the build and search loops because we support // multiple different kinds of searches. let index = diskann_async::new_quant_index::( - self.input.try_as_config()?.build()?, - self.input.inmem_parameters(data.nrows(), data.ncols()), + input.try_as_config()?.build()?, + input.inmem_parameters(data.nrows(), data.ncols()), diskann_quantization::spherical::iface::Impl::<$N>::new(quantizer)?, NoDeletes, )?; @@ -274,12 +264,12 @@ mod imp { runs: Vec::new(), }; - match &self.input.search_phase { + match &input.search_phase { SearchPhase::Topk(search_phase) => { // Handle Topk search phase // Save construction stats before running queries. - _checkpoint.checkpoint(&result)?; + checkpoint.checkpoint(&result)?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&search_phase.queries), @@ -295,7 +285,7 @@ mod imp { &search_phase.runs, ); - for &layout in self.input.query_layouts.iter() { + for &layout in input.query_layouts.iter() { let knn = benchmark_core::search::graph::KNN::new( index.clone(), queries.clone(), @@ -317,7 +307,7 @@ mod imp { // Handle Range search phase // Save construction stats before running queries. - _checkpoint.checkpoint(&result)?; + checkpoint.checkpoint(&result)?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&search_phase.queries), @@ -333,7 +323,7 @@ mod imp { &search_phase.runs, ); - for &layout in self.input.query_layouts.iter() { + for &layout in input.query_layouts.iter() { let range = benchmark_core::search::graph::Range::new( index.clone(), queries.clone(), @@ -358,7 +348,7 @@ mod imp { // Handle Beta Filtered Topk search phase // Save construction stats before running queries. - _checkpoint.checkpoint(&result)?; + checkpoint.checkpoint(&result)?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&search_phase.queries), @@ -384,7 +374,7 @@ mod imp { .map(utils::filters::as_query_label_provider) .collect(); - for &layout in self.input.query_layouts.iter() { + for &layout in input.query_layouts.iter() { let strategy = inmem::spherical::Quantized::search(layout.into()); let search_strategies = setup_filter_strategies( search_phase.beta, @@ -414,7 +404,7 @@ mod imp { // Handle Beta Filtered Topk search phase // Save construction stats before running queries. - _checkpoint.checkpoint(&result)?; + checkpoint.checkpoint(&result)?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&search_phase.queries), @@ -440,7 +430,7 @@ mod imp { .map(utils::filters::as_query_label_provider) .collect(); - for &layout in self.input.query_layouts.iter() { + for &layout in input.query_layouts.iter() { let multihop = benchmark_core::search::graph::MultiHop::new( index.clone(), queries.clone(), diff --git a/diskann-benchmark/src/utils/mod.rs b/diskann-benchmark/src/utils/mod.rs index ebdac8116..cee6cacd3 100644 --- a/diskann-benchmark/src/utils/mod.rs +++ b/diskann-benchmark/src/utils/mod.rs @@ -111,7 +111,7 @@ macro_rules! stub_impl { use crate::inputs; pub(super) fn register(name: &str, registry: &mut Benchmarks) { - registry.register::(name); + registry.register(name, Stub); } /// An empty placeholder to provide a hint for the necessary feature. @@ -121,11 +121,12 @@ macro_rules! stub_impl { type Input = $input; type Output = serde_json::Value; - fn try_match(_input: &$input) -> Result { + fn try_match(&self, _input: &$input) -> Result { Err(FailureScore(0)) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, _input: Option<&$input>, ) -> std::fmt::Result { @@ -134,6 +135,7 @@ macro_rules! stub_impl { } fn run( + &self, _input: &$input, _checkpoint: Checkpoint<'_>, _output: &mut dyn Output,