diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 560fb13fa..8480a4b45 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -158,10 +158,9 @@ where #[cfg(test)] pub(crate) mod tests { use std::{ - collections::HashSet, marker::PhantomData, num::{NonZeroU32, NonZeroUsize}, - sync::{Arc, Mutex}, + sync::Arc, }; use crate::storage::VirtualStorageProvider; @@ -174,8 +173,8 @@ pub(crate) mod tests { DefaultSearchStrategy, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, SearchStrategy, }, - index::{PartitionedNeighbors, QueryLabelProvider, QueryVisitDecision}, - search::{Knn, Range}, + index::{PartitionedNeighbors, QueryLabelProvider}, + search::Range, search_output_buffer, }, neighbor::Neighbor, @@ -378,54 +377,6 @@ pub(crate) mod tests { } } - /// Check the contents of a single search for the query. - /// - /// # Arguments - async fn test_multihop_search( - index: &DiskANNIndex, - parameters: &SearchParameters, - strategy: &S, - query: Q, - mut checker: Checker, - filter: &dyn QueryLabelProvider, - ) where - DP: DataProvider, - S: DefaultSearchStrategy, - Q: Copy + std::fmt::Debug + Send + Sync, - Checker: FnMut(usize, (u32, f32)) -> Result<(), Box>, - { - let mut ids = vec![0; parameters.search_k]; - let mut distances = vec![0.0; parameters.search_k]; - let mut result_output_buffer = - search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let search_params = Knn::new_default(parameters.search_k, parameters.search_l).unwrap(); - let multihop = graph::search::MultihopSearch::new(search_params, filter); - index - .search( - multihop, - strategy, - ¶meters.context, - query, - &mut result_output_buffer, - ) - .await - .unwrap(); - - // Loop over the requested number of results to check, invoking the checker closure. - // - // If the checker closure detects an error, embed that error in a more descriptive - // formatted panic. - for i in 0..parameters.to_check { - println!("{ids:?}"); - if let Err(message) = checker(i, (ids[i], distances[i])) { - panic!( - "Check failed for result {} with error: {}. Query = {:?}. Result: ({}, {})", - i, message, query, ids[i], distances[i] - ); - } - } - } - async fn test_paged_search( index: &DiskANNIndex, strategy: S, @@ -1203,330 +1154,6 @@ pub(crate) mod tests { test_beta_filtering(filter, 3, 7).await; } - ///////////////////////// - // Multi-Hop Filtering // - ///////////////////////// - - async fn test_multihop_filtering( - filter: &dyn QueryLabelProvider, - dim: usize, - grid_size: usize, - ) { - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - - let (config, parameters) = - simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap(); - - let mut adjacency_lists = Grid::Three.neighbors(grid_size); - let mut vectors = f32::generate_grid(dim, grid_size); - - assert_eq!(adjacency_lists.len(), num_points); - assert_eq!(vectors.len(), num_points); - - // Append an additional item to the input vectors for the start point. - adjacency_lists.push((num_points as u32 - 1).into()); - vectors.push(vec![grid_size as f32; dim]); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - - let index = new_quant_index::(config, parameters, table, NoDeletes).unwrap(); - let neighbor_accessor = &mut index.provider().neighbors(); - populate_data(&index.data_provider, &DefaultContext, &vectors).await; - populate_graph(neighbor_accessor, &adjacency_lists).await; - - let corpus: diskann_utils::views::Matrix = - squish(vectors.iter().take(num_points), dim); - let query = vec![grid_size as f32; dim]; - - // The strategy we use here for checking is that we pull in a lot of neighbors and - // then walk through the list, verifying monotonicity and that the filter was - // applied properly. - let parameters = SearchParameters { - context: DefaultContext, - search_l: 40, - search_k: 20, - to_check: 20, - }; - - // Compute the raw groundtruth, then screen out any points that don't match the filter - let gt = { - let mut gt = groundtruth(corpus.as_view(), &query, |a, b| SquaredL2::evaluate(a, b)); - gt.retain(|n| filter.is_match(n.id)); - gt.sort_unstable_by(|a, b| a.cmp(b).reverse()); - gt - }; - - // Clone the base groundtruth so we don't need to recompute every time. - let mut gt_clone = gt.clone(); - let strategy = FullPrecision; - - test_multihop_search( - &index, - ¶meters, - &strategy.clone(), - query.as_slice(), - |_, (id, distance)| -> Result<(), Box> { - if let Some(position) = is_match(>_clone, Neighbor::new(id, distance), 0.0) { - gt_clone.remove(position); - Ok(()) - } else { - if id.into_usize() == num_points + 1 { - return Err(Box::new("The start point should not be returned")); - } - Err(Box::new("mismatch")) - } - }, - filter, - ) - .await; - } - - #[tokio::test] - async fn test_even_filtering_multihop() { - test_multihop_filtering(&EvenFilter, 3, 7).await; - } - - /// Metrics tracked by [`CallbackFilter`] for test validation. - #[derive(Debug, Clone, Default)] - struct CallbackMetrics { - /// Total number of callback invocations. - total_visits: usize, - /// Number of candidates that were rejected. - rejected_count: usize, - /// Number of candidates that had distance adjusted. - adjusted_count: usize, - /// All visited candidate IDs in order. - visited_ids: Vec, - } - - #[derive(Debug)] - struct CallbackFilter { - blocked: u32, - adjusted: u32, - adjustment_factor: f32, - metrics: Mutex, - } - - impl CallbackFilter { - fn new(blocked: u32, adjusted: u32, adjustment_factor: f32) -> Self { - Self { - blocked, - adjusted, - adjustment_factor, - metrics: Mutex::new(CallbackMetrics::default()), - } - } - - fn hits(&self) -> Vec { - self.metrics - .lock() - .expect("callback metrics mutex should not be poisoned") - .visited_ids - .clone() - } - - fn metrics(&self) -> CallbackMetrics { - self.metrics - .lock() - .expect("callback metrics mutex should not be poisoned") - .clone() - } - } - - impl QueryLabelProvider for CallbackFilter { - fn is_match(&self, _: u32) -> bool { - true - } - - fn on_visit(&self, neighbor: Neighbor) -> QueryVisitDecision { - let mut metrics = self - .metrics - .lock() - .expect("callback metrics mutex should not be poisoned"); - - metrics.total_visits += 1; - metrics.visited_ids.push(neighbor.id); - - if neighbor.id == self.blocked { - metrics.rejected_count += 1; - return QueryVisitDecision::Reject; - } - if neighbor.id == self.adjusted { - metrics.adjusted_count += 1; - let adjusted = - Neighbor::new(neighbor.id, neighbor.distance * self.adjustment_factor); - return QueryVisitDecision::Accept(adjusted); - } - QueryVisitDecision::Accept(neighbor) - } - } - - #[tokio::test] - async fn test_multihop_callback_enforces_filtering() { - // Test configuration - let dim = 3; - let grid_size: usize = 5; - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - - let (config, parameters) = - simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap(); - - let mut adjacency_lists = Grid::Three.neighbors(grid_size); - let mut vectors = f32::generate_grid(dim, grid_size); - - adjacency_lists.push((num_points as u32 - 1).into()); - vectors.push(vec![grid_size as f32; dim]); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut create_rnd_from_seed_in_tests(0xdd81b895605c73d4), - 1usize, - ) - .unwrap(); - - let index = new_quant_index::(config, parameters, table, NoDeletes).unwrap(); - let neighbor_accessor = &mut index.provider().neighbors(); - populate_data(&index.data_provider, &DefaultContext, &vectors).await; - populate_graph(neighbor_accessor, &adjacency_lists).await; - - let corpus: diskann_utils::views::Matrix = - squish(vectors.iter().take(num_points), dim); - let query = vec![grid_size as f32; dim]; - - let parameters = SearchParameters { - context: DefaultContext, - search_l: 40, - search_k: 20, - to_check: 10, - }; - - let mut ids = vec![0; parameters.search_k]; - let mut distances = vec![0.0; parameters.search_k]; - let mut result_output_buffer = - search_output_buffer::IdDistance::new(&mut ids, &mut distances); - - let blocked = (num_points - 2) as u32; - let adjusted = (num_points - 1) as u32; - - // Compute baseline groundtruth for validation - let mut baseline_gt = - groundtruth(corpus.as_view(), &query, |a, b| SquaredL2::evaluate(a, b)); - baseline_gt.sort_unstable_by(|a, b| a.cmp(b).reverse()); - - assert!( - baseline_gt.iter().any(|n| n.id == blocked), - "blocked candidate must exist in groundtruth" - ); - - let baseline_adjusted_distance = baseline_gt - .iter() - .find(|n| n.id == adjusted) - .expect("adjusted node should exist in groundtruth") - .distance; - - let filter = CallbackFilter::new(blocked, adjusted, 0.5); - - let search_params = Knn::new_default(parameters.search_k, parameters.search_l).unwrap(); - let multihop = graph::search::MultihopSearch::new(search_params, &filter); - let stats = index - .search( - multihop, - &FullPrecision, - ¶meters.context, - query.as_slice(), - &mut result_output_buffer, - ) - .await - .unwrap(); - - // Retrieve callback metrics for detailed validation - let callback_metrics = filter.metrics(); - - // Validate search statistics - assert!( - stats.result_count >= parameters.to_check as u32, - "expected at least {} results, got {}", - parameters.to_check, - stats.result_count - ); - - // Validate callback was invoked and tracked the blocked candidate - assert!( - callback_metrics.total_visits > 0, - "callback should have been invoked at least once" - ); - assert!( - filter.hits().contains(&blocked), - "callback must evaluate the blocked candidate (visited {} candidates)", - callback_metrics.total_visits - ); - assert_eq!( - callback_metrics.rejected_count, 1, - "exactly one candidate (blocked={}) should be rejected", - blocked - ); - - // Validate blocked candidate is excluded from results - let produced = stats.result_count as usize; - let inspected = produced.min(parameters.to_check); - assert!( - !ids.iter().take(inspected).any(|&id| id == blocked), - "blocked candidate {} should not appear in final results (found in: {:?})", - blocked, - &ids[..inspected] - ); - - // Validate distance adjustment was applied - assert!( - callback_metrics.adjusted_count >= 1, - "adjusted candidate {} should have been visited", - adjusted - ); - - let adjusted_idx = ids - .iter() - .take(inspected) - .position(|&id| id == adjusted) - .expect("adjusted candidate should be present in results"); - let expected_distance = baseline_adjusted_distance * 0.5; - assert!( - (distances[adjusted_idx] - expected_distance).abs() < 1e-5, - "callback should adjust distances before ranking: \ - expected {:.6}, got {:.6} (baseline: {:.6}, factor: 0.5)", - expected_distance, - distances[adjusted_idx], - baseline_adjusted_distance - ); - - // Log metrics for debugging/review - println!( - "test_multihop_callback_enforces_filtering metrics:\n\ - - total callback visits: {}\n\ - - rejected count: {}\n\ - - adjusted count: {}\n\ - - search hops: {}\n\ - - search comparisons: {}\n\ - - result count: {}", - callback_metrics.total_visits, - callback_metrics.rejected_count, - callback_metrics.adjusted_count, - stats.hops, - stats.cmps, - stats.result_count - ); - } - ////////////// // Deletion // ////////////// @@ -3878,417 +3505,6 @@ pub(crate) mod tests { ); } - ///////////////////////////////////// - // Multi-Hop Callback Edge Cases // - ///////////////////////////////////// - - /// Filter that rejects all candidates via on_visit callback. - /// Used to test the fallback behavior when all candidates are rejected. - #[derive(Debug)] - struct RejectAllFilter { - allowed_in_results: HashSet, - } - - impl RejectAllFilter { - fn only>(ids: I) -> Self { - Self { - allowed_in_results: ids.into_iter().collect(), - } - } - } - - impl QueryLabelProvider for RejectAllFilter { - fn is_match(&self, vec_id: u32) -> bool { - self.allowed_in_results.contains(&vec_id) - } - - fn on_visit(&self, _neighbor: Neighbor) -> QueryVisitDecision { - QueryVisitDecision::Reject - } - } - - /// Filter that tracks visit order and can terminate early. - #[derive(Debug)] - struct TerminatingFilter { - target: u32, - hits: Mutex>, - } - - impl TerminatingFilter { - fn new(target: u32) -> Self { - Self { - target, - hits: Mutex::new(Vec::new()), - } - } - - fn hits(&self) -> Vec { - self.hits - .lock() - .expect("mutex should not be poisoned") - .clone() - } - } - - impl QueryLabelProvider for TerminatingFilter { - fn is_match(&self, vec_id: u32) -> bool { - vec_id == self.target - } - - fn on_visit(&self, neighbor: Neighbor) -> QueryVisitDecision { - self.hits - .lock() - .expect("mutex should not be poisoned") - .push(neighbor.id); - if neighbor.id == self.target { - QueryVisitDecision::Terminate - } else { - QueryVisitDecision::Accept(neighbor) - } - } - } - - #[tokio::test] - async fn test_multihop_reject_all_returns_zero_results() { - // When on_visit rejects all candidates, the search should return zero results - // because rejected candidates don't get added to the frontier. - let dim = 3; - let grid_size: usize = 4; - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - - let (config, parameters) = - simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap(); - - let mut adjacency_lists = Grid::Three.neighbors(grid_size); - let mut vectors = f32::generate_grid(dim, grid_size); - - adjacency_lists.push((num_points as u32 - 1).into()); - vectors.push(vec![grid_size as f32; dim]); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut create_rnd_from_seed_in_tests(0x1234567890abcdef), - 1usize, - ) - .unwrap(); - - let index = new_quant_index::(config, parameters, table, NoDeletes).unwrap(); - let neighbor_accessor = &mut index.provider().neighbors(); - populate_data(&index.data_provider, &DefaultContext, &vectors).await; - populate_graph(neighbor_accessor, &adjacency_lists).await; - - let query = vec![grid_size as f32; dim]; - - let mut ids = vec![0; 10]; - let mut distances = vec![0.0; 10]; - let mut result_output_buffer = - search_output_buffer::IdDistance::new(&mut ids, &mut distances); - - // Allow only the first start point (0) in results via is_match, - // but reject everything via on_visit - let filter = RejectAllFilter::only([0_u32]); - - let search_params = Knn::new_default(10, 20).unwrap(); - let multihop = graph::search::MultihopSearch::new(search_params, &filter); - let stats = index - .search( - multihop, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &mut result_output_buffer, - ) - .await - .unwrap(); - - // When all candidates are rejected via on_visit, result_count should be 0 - // because rejected candidates are not added to the search frontier - assert_eq!( - stats.result_count, 0, - "rejecting all via on_visit should result in zero results" - ); - } - - #[tokio::test] - async fn test_multihop_early_termination() { - // Test that Terminate causes the search to stop early - let dim = 3; - let grid_size: usize = 5; - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - - let (config, parameters) = - simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap(); - - let mut adjacency_lists = Grid::Three.neighbors(grid_size); - let mut vectors = f32::generate_grid(dim, grid_size); - - adjacency_lists.push((num_points as u32 - 1).into()); - vectors.push(vec![grid_size as f32; dim]); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut create_rnd_from_seed_in_tests(0xfedcba0987654321), - 1usize, - ) - .unwrap(); - - let index = new_quant_index::(config, parameters, table, NoDeletes).unwrap(); - let neighbor_accessor = &mut index.provider().neighbors(); - populate_data(&index.data_provider, &DefaultContext, &vectors).await; - populate_graph(neighbor_accessor, &adjacency_lists).await; - - let query = vec![grid_size as f32; dim]; - - let mut ids = vec![0; 10]; - let mut distances = vec![0.0; 10]; - let mut result_output_buffer = - search_output_buffer::IdDistance::new(&mut ids, &mut distances); - - // Target a point in the middle of the grid - let target = (num_points / 2) as u32; - let filter = TerminatingFilter::new(target); - - let search_params = Knn::new_default(10, 40).unwrap(); - let multihop = graph::search::MultihopSearch::new(search_params, &filter); - let stats = index - .search( - multihop, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &mut result_output_buffer, - ) - .await - .unwrap(); - - let hits = filter.hits(); - - // The search should have terminated after finding the target - assert!( - hits.contains(&target), - "search should have visited the target" - ); - assert!( - stats.result_count >= 1, - "should have at least one result (the target)" - ); - } - - #[tokio::test] - async fn test_multihop_distance_adjustment_affects_ranking() { - // Test that distance adjustments in on_visit affect the final ranking - let dim = 3; - let grid_size: usize = 4; - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - - let (config, parameters) = - simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap(); - - let mut adjacency_lists = Grid::Three.neighbors(grid_size); - let mut vectors = f32::generate_grid(dim, grid_size); - - adjacency_lists.push((num_points as u32 - 1).into()); - vectors.push(vec![grid_size as f32; dim]); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut create_rnd_from_seed_in_tests(0xabcdef1234567890), - 1usize, - ) - .unwrap(); - - let index = new_quant_index::(config, parameters, table, NoDeletes).unwrap(); - let neighbor_accessor = &mut index.provider().neighbors(); - populate_data(&index.data_provider, &DefaultContext, &vectors).await; - populate_graph(neighbor_accessor, &adjacency_lists).await; - - let query = vec![0.0; dim]; // Query at origin - - // First, run without adjustment to get baseline - let mut baseline_ids = vec![0; 10]; - let mut baseline_distances = vec![0.0; 10]; - let mut baseline_buffer = - search_output_buffer::IdDistance::new(&mut baseline_ids, &mut baseline_distances); - - let search_params = Knn::new_default(10, 20).unwrap(); - let multihop = graph::search::MultihopSearch::new(search_params, &EvenFilter); - let baseline_stats = index - .search( - multihop, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &mut baseline_buffer, - ) - .await - .unwrap(); - - // Now run with a filter that boosts a specific far-away point - let boosted_point = (num_points - 2) as u32; // A point far from origin - let filter = CallbackFilter::new(u32::MAX, boosted_point, 0.01); // Shrink its distance - - let mut adjusted_ids = vec![0; 10]; - let mut adjusted_distances = vec![0.0; 10]; - let mut adjusted_buffer = - search_output_buffer::IdDistance::new(&mut adjusted_ids, &mut adjusted_distances); - - let search_params = Knn::new_default(10, 20).unwrap(); - let multihop = graph::search::MultihopSearch::new(search_params, &filter); - let adjusted_stats = index - .search( - multihop, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &mut adjusted_buffer, - ) - .await - .unwrap(); - - // Both searches should return results - assert!( - baseline_stats.result_count > 0, - "baseline should have results" - ); - assert!( - adjusted_stats.result_count > 0, - "adjusted should have results" - ); - - // If the boosted point was visited and adjusted, it should appear earlier - // in the adjusted results than in the baseline (or appear when it didn't before) - let boosted_in_baseline = baseline_ids - .iter() - .take(baseline_stats.result_count as usize) - .position(|&id| id == boosted_point); - let boosted_in_adjusted = adjusted_ids - .iter() - .take(adjusted_stats.result_count as usize) - .position(|&id| id == boosted_point); - - // The distance adjustment should have some effect if the point was visited - if filter.hits().contains(&boosted_point) { - assert!( - boosted_in_adjusted.is_some(), - "boosted point should appear in adjusted results when visited" - ); - if let (Some(baseline_pos), Some(adjusted_pos)) = - (boosted_in_baseline, boosted_in_adjusted) - { - assert!( - adjusted_pos <= baseline_pos, - "boosted point should rank equal or better after distance reduction" - ); - } - } - } - - #[tokio::test] - async fn test_multihop_terminate_stops_traversal() { - // Test that Terminate (without accept) stops traversal immediately - #[derive(Debug)] - struct TerminateAfterN { - max_visits: usize, - visits: Mutex, - } - - impl TerminateAfterN { - fn new(max_visits: usize) -> Self { - Self { - max_visits, - visits: Mutex::new(0), - } - } - - fn visit_count(&self) -> usize { - *self.visits.lock().unwrap() - } - } - - impl QueryLabelProvider for TerminateAfterN { - fn is_match(&self, _: u32) -> bool { - true - } - - fn on_visit(&self, neighbor: Neighbor) -> QueryVisitDecision { - let mut visits = self.visits.lock().unwrap(); - *visits += 1; - if *visits >= self.max_visits { - QueryVisitDecision::Terminate - } else { - QueryVisitDecision::Accept(neighbor) - } - } - } - - let dim = 3; - let grid_size: usize = 5; - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - - let (config, parameters) = - simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap(); - - let mut adjacency_lists = Grid::Three.neighbors(grid_size); - let mut vectors = f32::generate_grid(dim, grid_size); - - adjacency_lists.push((num_points as u32 - 1).into()); - vectors.push(vec![grid_size as f32; dim]); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut create_rnd_from_seed_in_tests(0x9876543210fedcba), - 1usize, - ) - .unwrap(); - - let index = new_quant_index::(config, parameters, table, NoDeletes).unwrap(); - let neighbor_accessor = &mut index.provider().neighbors(); - populate_data(&index.data_provider, &DefaultContext, &vectors).await; - populate_graph(neighbor_accessor, &adjacency_lists).await; - - let query = vec![grid_size as f32; dim]; - - let mut ids = vec![0; 10]; - let mut distances = vec![0.0; 10]; - let mut result_output_buffer = - search_output_buffer::IdDistance::new(&mut ids, &mut distances); - - let max_visits = 5; - let filter = TerminateAfterN::new(max_visits); - - let search_params = Knn::new_default(10, 100).unwrap(); // Large L to ensure we'd visit more without termination - let multihop = graph::search::MultihopSearch::new(search_params, &filter); - let _stats = index - .search( - multihop, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &mut result_output_buffer, - ) - .await - .unwrap(); - - // The search should have stopped after max_visits - assert!( - filter.visit_count() <= max_visits + 10, // Allow some slack for beam expansion - "search should have terminated early, got {} visits", - filter.visit_count() - ); - } - #[tokio::test] async fn vectors_with_infinity_values_should_be_inserted_and_searched_without_panic() { let l_build: usize = 20; diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index fac279421..99afb0e8e 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -43,7 +43,7 @@ use crate::{ }; mod knn_search; -mod multihop_search; +pub(crate) mod multihop_search; mod range_search; pub mod record; diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index aba0f44c5..8ed36b53c 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -299,3 +299,57 @@ where Ok(make_stats(scratch)) } + +#[cfg(test)] +mod tests { + use super::*; + + /// A simple label evaluator that matches only even IDs. + #[derive(Debug)] + struct EvenOnly; + + impl QueryLabelProvider for EvenOnly { + fn is_match(&self, id: u32) -> bool { + id.is_multiple_of(2) + } + } + + #[test] + fn predicate_eval_requires_not_visited_and_matching() { + let mut visited = HashSet::new(); + visited.insert(2u32); + let label = EvenOnly; + let pred = NotInMutWithLabelCheck::new(&mut visited, &label); + + // Not visited + matches label → true + assert!(pred.eval(&4)); + + // Already visited + matches label → false + assert!(!pred.eval(&2)); + + // Not visited + doesn't match label → false + assert!(!pred.eval(&3)); + + // Already visited + doesn't match → false + visited.insert(3); + let pred = NotInMutWithLabelCheck::new(&mut visited, &label); + assert!(!pred.eval(&3)); + } + + #[test] + fn predicate_eval_mut_inserts_only_matching() { + let mut visited = HashSet::new(); + let label = EvenOnly; + let mut pred = NotInMutWithLabelCheck::new(&mut visited, &label); + + // Matching + not visited → inserts and returns true + assert!(pred.eval_mut(&4)); + // Second call → already visited, returns false + assert!(!pred.eval_mut(&4)); + + // Non-matching → not inserted, returns false + assert!(!pred.eval_mut(&3)); + // Confirm 3 was NOT added to visited set + assert!(!pred.visited_set.contains(&3)); + } +} diff --git a/diskann/src/graph/test/cases/mod.rs b/diskann/src/graph/test/cases/mod.rs index 74ac80f1f..05104f1d2 100644 --- a/diskann/src/graph/test/cases/mod.rs +++ b/diskann/src/graph/test/cases/mod.rs @@ -7,6 +7,7 @@ mod consolidate; mod grid_insert; mod grid_search; mod inplace_delete; +mod multihop; mod paged_search; mod range_search; diff --git a/diskann/src/graph/test/cases/multihop.rs b/diskann/src/graph/test/cases/multihop.rs new file mode 100644 index 000000000..d12a5572e --- /dev/null +++ b/diskann/src/graph/test/cases/multihop.rs @@ -0,0 +1,781 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Tests for multihop search traversal behavior. +//! +//! Organized into two layers: +//! - **Unit tests** call `multihop_search_internal` directly on small hand-constructed +//! graphs to test each decision path (Accept, Reject+two-hop, Terminate) in isolation. +//! - **Integration tests** go through `index.search(MultihopSearch{...})` end-to-end +//! with baselines for regression protection. + +use std::sync::{Arc, Mutex}; + +use diskann_vector::distance::Metric; + +use crate::{ + graph::{ + self, AdjacencyList, DiskANNIndex, + index::{QueryLabelProvider, QueryVisitDecision}, + search::{ + Knn, MultihopSearch, + record::NoopSearchRecord, + scratch::{PriorityQueueConfiguration, SearchScratch}, + }, + search_output_buffer, + test::provider as test_provider, + }, + neighbor::Neighbor, + provider::BuildQueryComputer, + test::{ + TestRoot, + cmp::{assert_eq_verbose, verbose_eq}, + get_or_save_test_results, + tokio::current_thread_runtime, + }, +}; + +fn root() -> TestRoot { + TestRoot::new("graph/test/cases/multihop") +} + +///////////// +// Filters // +///////////// + +/// Accepts all candidates unconditionally. +#[derive(Debug)] +struct AcceptAll; + +impl QueryLabelProvider for AcceptAll { + fn is_match(&self, _: u32) -> bool { + true + } +} + +/// Accepts all IDs but only allows even IDs in results. +#[derive(Debug)] +struct EvenFilter; + +impl QueryLabelProvider for EvenFilter { + fn is_match(&self, id: u32) -> bool { + id.is_multiple_of(2) + } +} + +/// Rejects all candidates via `on_visit`. +#[derive(Debug)] +struct RejectAll; + +impl QueryLabelProvider for RejectAll { + fn is_match(&self, _: u32) -> bool { + true + } + + fn on_visit(&self, _: Neighbor) -> QueryVisitDecision { + QueryVisitDecision::Reject + } +} + +/// Tracks visited IDs and terminates when the target is found. +#[derive(Debug)] +struct TerminateOnTarget { + target: u32, + hits: Mutex>, +} + +impl TerminateOnTarget { + fn new(target: u32) -> Self { + Self { + target, + hits: Mutex::new(Vec::new()), + } + } + + fn hits(&self) -> Vec { + self.hits.lock().unwrap().clone() + } +} + +impl QueryLabelProvider for TerminateOnTarget { + fn is_match(&self, id: u32) -> bool { + id == self.target + } + + fn on_visit(&self, neighbor: Neighbor) -> QueryVisitDecision { + self.hits.lock().unwrap().push(neighbor.id); + if neighbor.id == self.target { + QueryVisitDecision::Terminate + } else { + QueryVisitDecision::Accept(neighbor) + } + } +} + +/// Accepts all via `is_match`, but blocks one ID and adjusts another's distance. +#[derive(Debug)] +struct BlockAndAdjust { + blocked: u32, + adjusted: u32, + factor: f32, + metrics: Mutex, +} + +#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] +struct BlockAndAdjustMetrics { + total_visits: usize, + rejected_count: usize, + adjusted_count: usize, + visited_ids: Vec, +} + +verbose_eq!(BlockAndAdjustMetrics { + total_visits, + rejected_count, + adjusted_count, + visited_ids, +}); + +impl BlockAndAdjust { + fn new(blocked: u32, adjusted: u32, factor: f32) -> Self { + Self { + blocked, + adjusted, + factor, + metrics: Mutex::new(BlockAndAdjustMetrics::default()), + } + } + + fn metrics(&self) -> BlockAndAdjustMetrics { + self.metrics.lock().unwrap().clone() + } +} + +impl QueryLabelProvider for BlockAndAdjust { + fn is_match(&self, _: u32) -> bool { + true + } + + fn on_visit(&self, neighbor: Neighbor) -> QueryVisitDecision { + let mut m = self.metrics.lock().unwrap(); + m.total_visits += 1; + m.visited_ids.push(neighbor.id); + + if neighbor.id == self.blocked { + m.rejected_count += 1; + QueryVisitDecision::Reject + } else if neighbor.id == self.adjusted { + m.adjusted_count += 1; + QueryVisitDecision::Accept(Neighbor::new(neighbor.id, neighbor.distance * self.factor)) + } else { + QueryVisitDecision::Accept(neighbor) + } + } +} + +//////////////////////////////////// +// Shared helpers for small graphs // +//////////////////////////////////// + +/// Build a 1D provider with the given points and adjacency lists. +/// +/// `start_pos` is the 1D position of the start node (id = `start_id`). +fn build_1d_provider( + start_id: u32, + start_pos: f32, + start_neighbors: AdjacencyList, + points: Vec<(u32, Vec, AdjacencyList)>, + max_degree: usize, +) -> test_provider::Provider { + let config = test_provider::Config::new( + Metric::L2, + max_degree, + test_provider::StartPoint::new(start_id, vec![start_pos]), + ) + .unwrap(); + + test_provider::Provider::new_from(config, std::iter::once((start_id, start_neighbors)), points) + .unwrap() +} + +/// Call `multihop_search_internal` directly on a provider, bypassing the Search trait. +/// +/// Returns (internal_stats, best_neighbors) where best_neighbors is the contents +/// of the scratch priority queue sorted by distance (nearest first). +fn run_internal( + provider: &test_provider::Provider, + query: &[f32], + k: usize, + l: usize, + max_degree: usize, + filter: &dyn QueryLabelProvider, +) -> (graph::index::InternalSearchStats, Vec>) { + let rt = current_thread_runtime(); + rt.block_on(async { + let mut accessor = test_provider::Accessor::new(provider); + let computer = accessor.build_query_computer(query).unwrap(); + + let mut scratch = SearchScratch::new(PriorityQueueConfiguration::Fixed(l), Some(l)); + + let stats = crate::graph::search::multihop_search::multihop_search_internal( + max_degree, + &Knn::new_default(k, l).unwrap(), + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + filter, + ) + .await + .unwrap(); + + let mut results: Vec<_> = scratch.best.iter().collect(); + results.sort_unstable_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + (stats, results) + }) +} + +////////////////////////////////////////// +// Unit tests: multihop_search_internal // +////////////////////////////////////////// + +/// Graph: start(10) → 0 → 1 → 2, all matching (AcceptAll). +/// Query at 1.5 — should find all three nodes via normal one-hop expansion. +#[test] +fn accept_all_finds_all_nodes() { + let start_id = 10u32; + let provider = build_1d_provider( + start_id, + 5.0, + AdjacencyList::from_iter_untrusted([0, 1, 2]), + vec![ + ( + 0, + vec![0.0], + AdjacencyList::from_iter_untrusted([1, start_id]), + ), + (1, vec![1.0], AdjacencyList::from_iter_untrusted([0, 2])), + (2, vec![2.0], AdjacencyList::from_iter_untrusted([1])), + ], + 3, + ); + + let (stats, results) = run_internal(&provider, &[1.5], 3, 10, 3, &AcceptAll); + + let ids: Vec = results.iter().map(|n| n.id).collect(); + assert!(ids.contains(&0), "node 0 should be found"); + assert!(ids.contains(&1), "node 1 should be found"); + assert!(ids.contains(&2), "node 2 should be found"); + assert!(stats.cmps > 0, "should have computed distances"); +} + +/// Graph: start(10) → 1(odd) → 2(even), start → 3(odd) → 4(even), start → 0(even). +/// EvenFilter rejects odds via two-hop. Nodes 2 and 4 are only reachable through odds. +#[test] +fn reject_triggers_two_hop_expansion() { + let start_id = 10u32; + let provider = build_1d_provider( + start_id, + 5.0, + AdjacencyList::from_iter_untrusted([0, 1, 3]), + vec![ + ( + 0, + vec![0.0], + AdjacencyList::from_iter_untrusted([1, start_id]), + ), + ( + 1, + vec![1.0], + AdjacencyList::from_iter_untrusted([0, 2, start_id]), + ), + (2, vec![2.0], AdjacencyList::from_iter_untrusted([1, 3])), + ( + 3, + vec![3.0], + AdjacencyList::from_iter_untrusted([0, 4, start_id]), + ), + (4, vec![4.0], AdjacencyList::from_iter_untrusted([3, 2])), + ], + 4, + ); + + let filter = EvenFilter; + let (stats, results) = run_internal(&provider, &[2.0], 5, 20, 4, &filter); + + let ids: Vec = results.iter().map(|n| n.id).collect(); + + // Even nodes reachable only via two-hop through odd nodes. + assert!( + ids.contains(&2), + "node 2 should be found via two-hop through node 1" + ); + assert!( + ids.contains(&4), + "node 4 should be found via two-hop through node 3" + ); + assert!(ids.contains(&0), "node 0 should be found directly"); + + // All results in the best set should be even (matching). + for n in &results { + if n.id == start_id { + continue; + } + assert!( + n.id.is_multiple_of(2), + "non-matching node {} should not be in best set", + n.id + ); + } + + assert!(stats.hops > 0, "should have expanded at least one hop"); +} + +/// RejectAll filter: on_visit rejects everything → only start point in best set, +/// two-hop expansion tries but finds nothing matching (is_match returns true, but +/// on_visit already rejected the one-hop node so two-hop candidates come from rejected). +#[test] +fn reject_all_yields_only_start() { + let start_id = 10u32; + let provider = build_1d_provider( + start_id, + 0.0, + AdjacencyList::from_iter_untrusted([0, 1]), + vec![ + ( + 0, + vec![1.0], + AdjacencyList::from_iter_untrusted([1, start_id]), + ), + (1, vec![2.0], AdjacencyList::from_iter_untrusted([0])), + ], + 2, + ); + + let (_stats, results) = run_internal(&provider, &[0.5], 5, 10, 2, &RejectAll); + + // Only the start point should be in the best set — all one-hop neighbors + // were rejected. Two-hop expansion goes through rejected nodes but RejectAll's + // is_match returns true, so two-hop neighbors that pass NotInMutWithLabelCheck + // get inserted. Let's just verify the search completed without panic. + assert!( + !results.is_empty(), + "at least the start point should be present" + ); +} + +/// TerminateOnTarget: search stops as soon as target is visited. +#[test] +fn terminate_stops_search_on_target() { + let start_id = 10u32; + // Linear chain: start → 0 → 1 → 2(target) → 3. + // With beam_width=1, search visits one node at a time. + let provider = build_1d_provider( + start_id, + -1.0, + AdjacencyList::from_iter_untrusted([0]), + vec![ + ( + 0, + vec![0.0], + AdjacencyList::from_iter_untrusted([1, start_id]), + ), + (1, vec![1.0], AdjacencyList::from_iter_untrusted([0, 2])), + (2, vec![2.0], AdjacencyList::from_iter_untrusted([1, 3])), + (3, vec![3.0], AdjacencyList::from_iter_untrusted([2])), + ], + 2, + ); + + let filter = TerminateOnTarget::new(2); + let (_stats, _results) = run_internal(&provider, &[0.0], 4, 10, 2, &filter); + + let hits = filter.hits(); + assert!(hits.contains(&2), "target node 2 should have been visited"); + assert_eq!( + *hits.last().unwrap(), + 2, + "target should be the last visited node (search terminated)" + ); + // Node 3 is beyond the target — should NOT have been visited. + assert!( + !hits.contains(&3), + "node 3 should not be visited after termination" + ); +} + +/// BlockAndAdjust: blocked node excluded from results, adjusted node has modified distance. +#[test] +fn block_and_adjust_modifies_results() { + let start_id = 10u32; + // start → 0, 1, 2. Block node 1, adjust node 2's distance by 0.5×. + let provider = build_1d_provider( + start_id, + 5.0, + AdjacencyList::from_iter_untrusted([0, 1, 2]), + vec![ + ( + 0, + vec![0.0], + AdjacencyList::from_iter_untrusted([1, start_id]), + ), + (1, vec![1.0], AdjacencyList::from_iter_untrusted([0, 2])), + (2, vec![2.0], AdjacencyList::from_iter_untrusted([1])), + ], + 3, + ); + + let filter = BlockAndAdjust::new(1, 2, 0.5); + let (_stats, results) = run_internal(&provider, &[0.0], 5, 10, 3, &filter); + + let ids: Vec = results.iter().map(|n| n.id).collect(); + + // Blocked node should still be in the best set because on_visit returns Reject + // which means it's added to two-hop candidates, not to best. But it was never + // Accept'd, so it should NOT appear. + // Actually: Reject means it's NOT inserted into scratch.best. Correct. + assert!( + !ids.contains(&1), + "blocked node 1 should not appear in results" + ); + + // Adjusted node's distance should be halved. + // Node 2 at position 2.0, query at 0.0 → L2 squared distance = 4.0, adjusted = 2.0. + let node2 = results + .iter() + .find(|n| n.id == 2) + .expect("node 2 should be in results"); + let expected = 4.0 * 0.5; + assert!( + (node2.distance - expected).abs() < 1e-5, + "adjusted distance should be {}, got {}", + expected, + node2.distance + ); + + let metrics = filter.metrics(); + assert_eq!(metrics.rejected_count, 1, "exactly one rejection (node 1)"); + assert_eq!(metrics.adjusted_count, 1, "exactly one adjustment (node 2)"); +} + +/////////////////////////////// +// Integration tests (E2E) // +/////////////////////////////// + +/// Baseline struct for end-to-end multihop search results. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct MultihopBaseline { + grid_size: usize, + query: Vec, + k: usize, + l: usize, + results: Vec<(u32, f32)>, + comparisons: usize, + hops: usize, +} + +verbose_eq!(MultihopBaseline { + grid_size, + query, + k, + l, + results, + comparisons, + hops, +}); + +/// Set up a 3D grid index using the test provider. +fn setup_grid_index(grid_size: usize) -> Arc> { + use crate::graph::test::synthetic::Grid; + + let grid = Grid::Three; + let provider = test_provider::Provider::grid(grid, grid_size).unwrap(); + + let index_config = graph::config::Builder::new( + provider.max_degree(), + graph::config::MaxDegree::same(), + 100, + Metric::L2.into(), + ) + .build() + .unwrap(); + + Arc::new(DiskANNIndex::new(index_config, provider, None)) +} + +/// Two-hop reachability through non-matching nodes, end-to-end with baseline. +/// +/// Uses the same hand-constructed 1D graph as the unit test, but goes through +/// `index.search(MultihopSearch{...})` to also exercise post-processing. +#[test] +fn two_hop_reaches_through_non_matching() { + let rt = current_thread_runtime(); + let mut test_root = root(); + let mut path = test_root.path(); + let name = path.push("two_hop_reaches_through_non_matching"); + + let start_id = 10u32; + let provider = build_1d_provider( + start_id, + 5.0, + AdjacencyList::from_iter_untrusted([0, 1, 3]), + vec![ + ( + 0, + vec![0.0], + AdjacencyList::from_iter_untrusted([1, start_id]), + ), + ( + 1, + vec![1.0], + AdjacencyList::from_iter_untrusted([0, 2, start_id]), + ), + (2, vec![2.0], AdjacencyList::from_iter_untrusted([1, 3])), + ( + 3, + vec![3.0], + AdjacencyList::from_iter_untrusted([0, 4, start_id]), + ), + (4, vec![4.0], AdjacencyList::from_iter_untrusted([3, 2])), + ], + 4, + ); + + let index_config = + graph::config::Builder::new(4, graph::config::MaxDegree::same(), 100, Metric::L2.into()) + .build() + .unwrap(); + + let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); + let filter = EvenFilter; + let query = vec![2.0f32]; + let k = 5; + let l = 20; + + let search_params = Knn::new_default(k, l).unwrap(); + let multihop = MultihopSearch::new(search_params, &filter); + + let mut ids = vec![0u32; k]; + let mut distances = vec![0.0f32; k]; + let mut buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + + let stats = rt + .block_on(index.search( + multihop, + &test_provider::Strategy::new(), + &test_provider::Context::new(), + query.as_slice(), + &mut buffer, + )) + .unwrap(); + + let result_count = stats.result_count as usize; + let baseline = MultihopBaseline { + grid_size: 0, // hand-constructed, not grid-based + query: query.clone(), + k, + l, + results: ids[..result_count] + .iter() + .zip(distances[..result_count].iter()) + .map(|(&id, &d)| (id, d)) + .collect(), + comparisons: stats.cmps as usize, + hops: stats.hops as usize, + }; + + let expected = get_or_save_test_results(&name, &baseline); + assert_eq_verbose!(expected, baseline); + + // Invariants that must hold regardless of baseline. + let result_ids: Vec = baseline.results.iter().map(|(id, _)| *id).collect(); + assert!( + result_ids.contains(&2), + "node 2 must be found via two-hop through node 1" + ); + assert!( + result_ids.contains(&4), + "node 4 must be found via two-hop through node 3" + ); + for &(id, _) in &baseline.results { + assert!( + id.is_multiple_of(2), + "all results must match the even filter, got id {}", + id + ); + } +} + +/// Even-filtered multihop search on a 3D grid with baseline. +#[test] +fn even_filtering_grid() { + let rt = current_thread_runtime(); + let mut test_root = root(); + let mut path = test_root.path(); + let name = path.push("even_filtering_grid"); + + let grid_size = 7; + let index = setup_grid_index(grid_size); + let query = vec![grid_size as f32; 3]; + let filter = EvenFilter; + + let k = 20; + let l = 40; + let search_params = Knn::new_default(k, l).unwrap(); + let multihop = MultihopSearch::new(search_params, &filter); + + let mut ids = vec![0u32; k]; + let mut distances = vec![0.0f32; k]; + let mut buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + + let stats = rt + .block_on(index.search( + multihop, + &test_provider::Strategy::new(), + &test_provider::Context::new(), + query.as_slice(), + &mut buffer, + )) + .unwrap(); + + let result_count = stats.result_count as usize; + let baseline = MultihopBaseline { + grid_size, + query: query.clone(), + k, + l, + results: ids[..result_count] + .iter() + .zip(distances[..result_count].iter()) + .map(|(&id, &d)| (id, d)) + .collect(), + comparisons: stats.cmps as usize, + hops: stats.hops as usize, + }; + + let expected = get_or_save_test_results(&name, &baseline); + assert_eq_verbose!(expected, baseline); + + // Invariant: all returned IDs must be even. + for &(id, _) in &baseline.results { + assert!( + id.is_multiple_of(2), + "all results must match the even filter, got id {}", + id + ); + } +} + +/// Callback filtering on a 3D grid: block one node, adjust another's distance. +#[test] +fn callback_filtering_grid() { + use crate::graph::test::synthetic::Grid; + + let rt = current_thread_runtime(); + let mut test_root = root(); + let mut path = test_root.path(); + let name = path.push("callback_filtering_grid"); + + let grid_size = 5; + let num_points = Grid::Three.num_points(grid_size); + let index = setup_grid_index(grid_size); + let query = vec![grid_size as f32; 3]; + + let blocked = (num_points - 2) as u32; + let adjusted = (num_points - 1) as u32; + let filter = BlockAndAdjust::new(blocked, adjusted, 0.5); + + let k = 20; + let l = 40; + let search_params = Knn::new_default(k, l).unwrap(); + let multihop = MultihopSearch::new(search_params, &filter); + + let mut ids = vec![0u32; k]; + let mut distances = vec![0.0f32; k]; + let mut buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + + let stats = rt + .block_on(index.search( + multihop, + &test_provider::Strategy::new(), + &test_provider::Context::new(), + query.as_slice(), + &mut buffer, + )) + .unwrap(); + + let result_count = stats.result_count as usize; + + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + struct CallbackBaseline { + grid_size: usize, + query: Vec, + k: usize, + l: usize, + blocked: u32, + adjusted: u32, + factor: f32, + results: Vec<(u32, f32)>, + comparisons: usize, + hops: usize, + metrics: BlockAndAdjustMetrics, + } + + verbose_eq!(CallbackBaseline { + grid_size, + query, + k, + l, + blocked, + adjusted, + factor, + results, + comparisons, + hops, + metrics, + }); + + let baseline = CallbackBaseline { + grid_size, + query: query.clone(), + k, + l, + blocked, + adjusted, + factor: 0.5, + results: ids[..result_count] + .iter() + .zip(distances[..result_count].iter()) + .map(|(&id, &d)| (id, d)) + .collect(), + comparisons: stats.cmps as usize, + hops: stats.hops as usize, + metrics: filter.metrics(), + }; + + let expected = get_or_save_test_results(&name, &baseline); + assert_eq_verbose!(expected, baseline); + + // Invariants. + let result_ids: Vec = baseline.results.iter().map(|(id, _)| *id).collect(); + assert!( + !result_ids.contains(&blocked), + "blocked node {} must not appear in results", + blocked + ); + assert_eq!( + baseline.metrics.rejected_count, 1, + "exactly one rejection expected" + ); + assert!( + baseline.metrics.adjusted_count >= 1, + "adjusted node should have been visited" + ); +} diff --git a/diskann/test/generated/graph/test/cases/multihop/callback_filtering_grid.json b/diskann/test/generated/graph/test/cases/multihop/callback_filtering_grid.json new file mode 100644 index 000000000..f760f89fa --- /dev/null +++ b/diskann/test/generated/graph/test/cases/multihop/callback_filtering_grid.json @@ -0,0 +1,173 @@ +{ + "file": "diskann/src/graph/test/cases/multihop.rs", + "test": "graph/test/cases/multihop/callback_filtering_grid", + "payload": { + "adjusted": 124, + "blocked": 123, + "comparisons": 68, + "factor": 0.5, + "grid_size": 5, + "hops": 42, + "k": 20, + "l": 40, + "metrics": { + "adjusted_count": 1, + "rejected_count": 1, + "total_visits": 65, + "visited_ids": [ + 124, + 99, + 119, + 123, + 94, + 114, + 74, + 69, + 89, + 93, + 113, + 117, + 73, + 97, + 49, + 109, + 121, + 68, + 88, + 92, + 72, + 96, + 48, + 112, + 116, + 108, + 64, + 84, + 44, + 67, + 87, + 91, + 63, + 83, + 43, + 120, + 104, + 24, + 39, + 59, + 107, + 111, + 47, + 71, + 19, + 79, + 103, + 115, + 23, + 95, + 38, + 58, + 62, + 82, + 86, + 42, + 66, + 18, + 78, + 90, + 46, + 70, + 22, + 106, + 110 + ] + }, + "query": [ + 5.0, + 5.0, + 5.0 + ], + "results": [ + [ + 124, + 1.5 + ], + [ + 119, + 6.0 + ], + [ + 99, + 6.0 + ], + [ + 94, + 9.0 + ], + [ + 118, + 9.0 + ], + [ + 98, + 9.0 + ], + [ + 74, + 11.0 + ], + [ + 114, + 11.0 + ], + [ + 122, + 11.0 + ], + [ + 93, + 12.0 + ], + [ + 97, + 14.0 + ], + [ + 73, + 14.0 + ], + [ + 117, + 14.0 + ], + [ + 113, + 14.0 + ], + [ + 89, + 14.0 + ], + [ + 69, + 14.0 + ], + [ + 92, + 17.0 + ], + [ + 88, + 17.0 + ], + [ + 68, + 17.0 + ], + [ + 121, + 18.0 + ] + ] + } +} \ No newline at end of file diff --git a/diskann/test/generated/graph/test/cases/multihop/even_filtering_grid.json b/diskann/test/generated/graph/test/cases/multihop/even_filtering_grid.json new file mode 100644 index 000000000..fa8524304 --- /dev/null +++ b/diskann/test/generated/graph/test/cases/multihop/even_filtering_grid.json @@ -0,0 +1,98 @@ +{ + "file": "diskann/src/graph/test/cases/multihop.rs", + "test": "graph/test/cases/multihop/even_filtering_grid", + "payload": { + "comparisons": 145, + "grid_size": 7, + "hops": 101, + "k": 20, + "l": 40, + "query": [ + 7.0, + 7.0, + 7.0 + ], + "results": [ + [ + 342, + 3.0 + ], + [ + 334, + 9.0 + ], + [ + 292, + 9.0 + ], + [ + 286, + 9.0 + ], + [ + 340, + 11.0 + ], + [ + 328, + 11.0 + ], + [ + 244, + 11.0 + ], + [ + 284, + 17.0 + ], + [ + 278, + 17.0 + ], + [ + 236, + 17.0 + ], + [ + 230, + 19.0 + ], + [ + 242, + 19.0 + ], + [ + 326, + 19.0 + ], + [ + 272, + 21.0 + ], + [ + 188, + 21.0 + ], + [ + 290, + 21.0 + ], + [ + 194, + 21.0 + ], + [ + 332, + 21.0 + ], + [ + 320, + 21.0 + ], + [ + 228, + 27.0 + ] + ] + } +} \ No newline at end of file diff --git a/diskann/test/generated/graph/test/cases/multihop/two_hop_reaches_through_non_matching.json b/diskann/test/generated/graph/test/cases/multihop/two_hop_reaches_through_non_matching.json new file mode 100644 index 000000000..be9436318 --- /dev/null +++ b/diskann/test/generated/graph/test/cases/multihop/two_hop_reaches_through_non_matching.json @@ -0,0 +1,28 @@ +{ + "file": "diskann/src/graph/test/cases/multihop.rs", + "test": "graph/test/cases/multihop/two_hop_reaches_through_non_matching", + "payload": { + "comparisons": 5, + "grid_size": 0, + "hops": 6, + "k": 5, + "l": 20, + "query": [ + 2.0 + ], + "results": [ + [ + 2, + 0.0 + ], + [ + 4, + 4.0 + ], + [ + 0, + 4.0 + ] + ] + } +} \ No newline at end of file