diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 3494b564b..b972b8205 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -261,7 +261,7 @@ where .zip(statistics_vec.par_iter_mut()) .zip(result_counts.par_iter_mut()); - zipped.for_each_in_pool(&pool, |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { + zipped.for_each_in_pool(pool.as_ref(), |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { let vector_filter = if search_params.vector_filters_file.is_none() { None } else { diff --git a/diskann-benchmark/src/backend/index/product.rs b/diskann-benchmark/src/backend/index/product.rs index a857e4e57..dce9f1f2f 100644 --- a/diskann-benchmark/src/backend/index/product.rs +++ b/diskann-benchmark/src/backend/index/product.rs @@ -141,7 +141,7 @@ mod imp { train_data.as_view(), self.input.num_pq_chunks, &mut StdRng::seed_from_u64(self.input.seed), - build.num_threads, + &diskann_providers::utils::create_thread_pool(build.num_threads)?, )? 
}; diff --git a/diskann-disk/benches/benchmarks/kmeans_bench.rs b/diskann-disk/benches/benchmarks/kmeans_bench.rs index ddecf526a..79881782c 100644 --- a/diskann-disk/benches/benchmarks/kmeans_bench.rs +++ b/diskann-disk/benches/benchmarks/kmeans_bench.rs @@ -5,7 +5,7 @@ use criterion::Criterion; use diskann_disk::utils::{compute_vecs_l2sq, k_means_clustering}; -use diskann_providers::utils::{create_thread_pool_for_bench, RayonThreadPool}; +use diskann_providers::utils::{create_thread_pool_for_bench, RayonThreadPoolRef}; use rand::Rng; const NUM_POINTS: usize = 100000; @@ -37,7 +37,7 @@ pub fn benchmark_kmeans(c: &mut Criterion) { MAX_KMEANS_REPS, rng, &mut false, - &pool, + pool.as_ref(), ) }) }); @@ -45,13 +45,13 @@ pub fn benchmark_kmeans(c: &mut Criterion) { group.bench_function("Snrm2 Rust Run", |f| { f.iter(|| { let data_copy = data.clone(); - snrm2_benchmark_rust(&data_copy, NUM_POINTS, DIM, &pool); + snrm2_benchmark_rust(&data_copy, NUM_POINTS, DIM, pool.as_ref()); }) }); } /// compute_vecs_l2sq benchmark -fn snrm2_benchmark_rust(data: &[f32], num_points: usize, dim: usize, pool: &RayonThreadPool) { +fn snrm2_benchmark_rust(data: &[f32], num_points: usize, dim: usize, pool: RayonThreadPoolRef<'_>) { let mut docs_l2sq = vec![0.0; num_points]; compute_vecs_l2sq(&mut docs_l2sq, data, num_points, dim, pool).unwrap(); } diff --git a/diskann-disk/benches/benchmarks_iai/kmeans_bench_iai.rs b/diskann-disk/benches/benchmarks_iai/kmeans_bench_iai.rs index 3769077e4..b1fb764b5 100644 --- a/diskann-disk/benches/benchmarks_iai/kmeans_bench_iai.rs +++ b/diskann-disk/benches/benchmarks_iai/kmeans_bench_iai.rs @@ -42,7 +42,7 @@ pub fn benchmark_kmeans_iai(data: Vec) { MAX_KMEANS_REPS, rng, &mut false, - &pool, + pool.as_ref(), ) .unwrap(); @@ -60,5 +60,5 @@ pub fn snrm2_benchmark_rust_iai(data: Vec) { pub fn snrm2_benchmark_rust(data: &[f32], num_points: usize, dim: usize) { let mut docs_l2sq = vec![0.0; num_points]; let pool = create_thread_pool_for_bench(); - 
compute_vecs_l2sq(&mut docs_l2sq, data, num_points, dim, &pool).unwrap(); + compute_vecs_l2sq(&mut docs_l2sq, data, num_points, dim, pool.as_ref()).unwrap(); } diff --git a/diskann-disk/src/build/builder/build.rs b/diskann-disk/src/build/builder/build.rs index ef1773b9a..dd69640bb 100644 --- a/diskann-disk/src/build/builder/build.rs +++ b/diskann-disk/src/build/builder/build.rs @@ -24,7 +24,7 @@ use diskann_providers::{ }, storage::{AsyncIndexMetadata, DiskGraphOnly, PQStorage}, utils::{ - create_thread_pool, find_medoid_with_sampling, RayonThreadPool, VectorDataIterator, + create_thread_pool, find_medoid_with_sampling, RayonThreadPoolRef, VectorDataIterator, MAX_MEDOID_SAMPLE_SIZE, }, }; @@ -233,10 +233,10 @@ where self.index_configuration.num_threads ); - self.generate_compressed_data(&pool).await?; + self.generate_compressed_data(pool.as_ref()).await?; logger.log_checkpoint(DiskIndexBuildCheckpoint::PqConstruction); - self.build_inmem_index(&pool).await?; + self.build_inmem_index(pool.as_ref()).await?; logger.log_checkpoint(DiskIndexBuildCheckpoint::InmemIndexBuild); // Use physical file to pass the memory index to the disk writer @@ -246,7 +246,7 @@ where Ok(()) } - async fn generate_compressed_data(&mut self, pool: &RayonThreadPool) -> ANNResult<()> { + async fn generate_compressed_data(&mut self, pool: RayonThreadPoolRef<'_>) -> ANNResult<()> { let num_points = self.index_configuration.max_points; let num_chunks = self.disk_build_param.search_pq_chunks(); @@ -289,13 +289,13 @@ where let generator = QuantDataGenerator::< Data::VectorDataType, - PQGeneration, + PQGeneration, >::new( self.index_writer.get_dataset_file(), generator_context, &quantizer_context, )?; - let progress = generator.generate_data(storage_provider, &pool, &self.chunking_config)?; + let progress = generator.generate_data(storage_provider, pool, &self.chunking_config)?; checkpoint_context.update(progress.clone())?; if let Progress::Processed(progress_point) = progress { @@ -310,7 +310,7 @@ 
where Ok(()) } - async fn build_inmem_index(&mut self, pool: &RayonThreadPool) -> ANNResult<()> { + async fn build_inmem_index(&mut self, pool: RayonThreadPoolRef<'_>) -> ANNResult<()> { match determine_build_strategy::( &self.index_configuration, self.disk_build_param.build_memory_limit().in_bytes() as f64, @@ -324,7 +324,7 @@ where } } - async fn build_merged_vamana_index(&mut self, pool: &RayonThreadPool) -> ANNResult<()> { + async fn build_merged_vamana_index(&mut self, pool: RayonThreadPoolRef<'_>) -> ANNResult<()> { let mut logger = PerfLogger::new_disk_index_build_logger(); let mut workflow = MergedVamanaIndexWorkflow::new(self, pool); diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index f2b049a56..efb9bf697 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -11,7 +11,7 @@ use diskann_providers::{ model::{IndexConfiguration, GRAPH_SLACK_FACTOR, MAX_PQ_TRAINING_SET_SIZE}, storage::PQStorage, utils::{ - load_metadata_from_file, RayonThreadPool, SampleVectorReader, SamplingDensity, + load_metadata_from_file, RayonThreadPoolRef, SampleVectorReader, SamplingDensity, READ_WRITE_BLOCK_SIZE, }, }; @@ -468,7 +468,7 @@ pub(crate) fn determine_build_strategy( } pub(crate) struct MergedVamanaIndexWorkflow<'a> { - pool: &'a RayonThreadPool, + pool: RayonThreadPoolRef<'a>, rng: diskann_providers::utils::StandardRng, dataset_file: String, max_degree: u32, @@ -478,7 +478,7 @@ pub(crate) struct MergedVamanaIndexWorkflow<'a> { impl<'a> MergedVamanaIndexWorkflow<'a> { pub(crate) fn new( builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>, - pool: &'a RayonThreadPool, + pool: RayonThreadPoolRef<'a>, ) -> Self where Data: GraphDataType, @@ -528,7 +528,7 @@ impl<'a> MergedVamanaIndexWorkflow<'a> { builder.disk_build_param.build_memory_limit().in_bytes() as f64; // calculate how many partitions we need, in order to fit in RAM budget // save id_map for each partition to disk - 
partition_with_ram_budget::( + partition_with_ram_budget::( &self.dataset_file, builder.index_configuration.dim, sampling_rate, diff --git a/diskann-disk/src/build/builder/quantizer.rs b/diskann-disk/src/build/builder/quantizer.rs index efb4bfd4e..c3eac75fc 100644 --- a/diskann-disk/src/build/builder/quantizer.rs +++ b/diskann-disk/src/build/builder/quantizer.rs @@ -13,7 +13,7 @@ use diskann_providers::{ FixedChunkPQTable, IndexConfiguration, MAX_PQ_TRAINING_SET_SIZE, }, storage::{PQStorage, SQStorage}, - utils::{BridgeErr, PQPathNames}, + utils::{create_thread_pool, BridgeErr, PQPathNames}, }; use diskann_quantization::scalar::train::ScalarQuantizationParameters; use diskann_utils::views::MatrixView; @@ -63,7 +63,7 @@ impl BuildQuantizer { MatrixView::try_from(&train_data, train_size, train_dim).bridge_err()?, num_chunks, &mut rnd, - index_configuration.num_threads, + create_thread_pool(index_configuration.num_threads)?.as_ref(), )? }; // Save at checkpoint. Note the the compressed data path and pivots path here diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index b7b30e94a..1e46ef612 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -1384,7 +1384,7 @@ mod disk_provider_tests { queries .par_row_iter() .enumerate() - .for_each_in_pool(&pool, |(i, query)| { + .for_each_in_pool(pool.as_ref(), |(i, query)| { let mut query_stats = QueryStatistics::default(); let mut indices = vec![0u32; 10]; let mut distances = vec![0f32; 10]; @@ -1444,7 +1444,7 @@ mod disk_provider_tests { queries .par_row_iter() .enumerate() - .for_each_in_pool(&pool, |(i, query)| { + .for_each_in_pool(pool.as_ref(), |(i, query)| { let result = params .index_search_engine .search(query, params.k as u32, params.l as u32, beam_width, None, false) diff --git a/diskann-disk/src/storage/quant/generator.rs b/diskann-disk/src/storage/quant/generator.rs index 
8e5006d6f..34b31146d 100644 --- a/diskann-disk/src/storage/quant/generator.rs +++ b/diskann-disk/src/storage/quant/generator.rs @@ -10,9 +10,8 @@ use std::{ use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult}; use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; -use diskann_providers::{ - forward_threadpool, - utils::{load_metadata_from_file, AsThreadPool, BridgeErr, ParallelIteratorInPool, Timer}, +use diskann_providers::utils::{ + load_metadata_from_file, BridgeErr, ParallelIteratorInPool, RayonThreadPoolRef, Timer, }; use diskann_utils::{io::Metadata, views}; use rayon::iter::IndexedParallelIterator; @@ -99,15 +98,14 @@ where /// 4. Processes data in blocks of size given by chunking_config.data_compression_chunk_vector_count = 50_000 /// 5. Compresses each block in small batch sizes in parallel to (potentially) take advantage of batch compression with quantizer /// 6. Writes compressed blocks to the output file. - pub fn generate_data( + pub fn generate_data( &self, storage_provider: &Storage, // Provider for reading source data and writing compressed results - pool: &Pool, // Thread pool for parallel processing + pool: RayonThreadPoolRef<'_>, // Thread pool for parallel processing chunking_config: &ChunkingConfig, // Configuration for batching and checkpoint handling ) -> ANNResult where Storage: StorageReadProvider + StorageWriteProvider, - Pool: AsThreadPool, { let timer = Timer::new(); @@ -157,7 +155,6 @@ where let mut compressed_buffer = vec![0_u8; block_size * compressed_size]; - forward_threadpool!(pool = pool: Pool); //Every block has size exactly block_size, except for potentially the last one let action = |block_index| -> ANNResult<()> { let start_index: usize = offset + block_index * block_size; @@ -431,7 +428,7 @@ mod generator_tests { ) .unwrap(); // Run generator - let result = generator.generate_data(storage_provider, &&pool, chunking_config); + let result = generator.generate_data(storage_provider, 
pool.as_ref(), chunking_config); (generator, result) } diff --git a/diskann-disk/src/storage/quant/pq/pq_generation.rs b/diskann-disk/src/storage/quant/pq/pq_generation.rs index a8a1557c7..ccd2c30a7 100644 --- a/diskann-disk/src/storage/quant/pq/pq_generation.rs +++ b/diskann-disk/src/storage/quant/pq/pq_generation.rs @@ -8,13 +8,12 @@ use std::marker::PhantomData; use diskann::{utils::VectorRepr, ANNError}; use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; use diskann_providers::{ - forward_threadpool, model::{ pq::{accum_row_inplace, generate_pq_pivots}, GeneratePivotArguments, }, storage::PQStorage, - utils::{AsThreadPool, BridgeErr, Timer}, + utils::{BridgeErr, RayonThreadPoolRef, Timer}, }; use diskann_quantization::{product::TransposedTable, CompressInto}; use diskann_utils::views::MatrixBase; @@ -23,43 +22,39 @@ use tracing::info; use crate::storage::quant::compressor::{CompressionStage, QuantCompressor}; -pub struct PQGenerationContext<'a, Storage, Pool> +pub struct PQGenerationContext<'a, Storage> where Storage: StorageReadProvider + StorageWriteProvider, - Pool: AsThreadPool, { pub pq_storage: PQStorage, pub num_chunks: usize, pub seed: Option, pub p_val: f64, pub storage_provider: &'a Storage, - pub pool: Pool, + pub pool: RayonThreadPoolRef<'a>, pub metric: Metric, pub dim: usize, pub max_kmeans_reps: usize, pub num_centers: usize, } -pub struct PQGeneration<'a, T, Storage, Pool> +pub struct PQGeneration<'a, T, Storage> where T: VectorRepr, Storage: StorageReadProvider + StorageWriteProvider + 'a, - Pool: AsThreadPool, { table: TransposedTable, num_chunks: usize, phantom_data: PhantomData, phantom_storage: PhantomData<&'a Storage>, - phantom_pool: PhantomData, } -impl<'a, T, Storage, Pool> QuantCompressor for PQGeneration<'a, T, Storage, Pool> +impl<'a, T, Storage> QuantCompressor for PQGeneration<'a, T, Storage> where T: VectorRepr, Storage: StorageReadProvider + StorageWriteProvider + 'a, - Pool: AsThreadPool, { - type 
CompressorContext = PQGenerationContext<'a, Storage, Pool>; + type CompressorContext = PQGenerationContext<'a, Storage>; fn new_at_stage( stage: CompressionStage, @@ -76,8 +71,7 @@ where .pq_storage .pivot_data_exist(context.storage_provider); - let pool = &context.pool; - forward_threadpool!(pool = pool: Pool); + let pool = context.pool; if !pivots_exists { if stage == CompressionStage::Resume { @@ -156,7 +150,6 @@ where table, num_chunks, phantom_data: PhantomData, - phantom_pool: PhantomData, phantom_storage: PhantomData, }) } @@ -188,7 +181,7 @@ mod pq_generation_tests { use diskann_providers::storage::{ PQStorage, StorageReadProvider, StorageWriteProvider, VirtualStorageProvider, }; - use diskann_providers::utils::{create_thread_pool_for_test, AsThreadPool}; + use diskann_providers::utils::{create_thread_pool_for_test, RayonThreadPoolRef}; use diskann_utils::{ io::{read_bin, write_bin}, test_data_root, @@ -212,7 +205,7 @@ mod pq_generation_tests { 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, ]; #[allow(clippy::too_many_arguments)] - fn create_new_compressor<'a, R: AsThreadPool, F: vfs::FileSystem>( + fn create_new_compressor<'a, F: vfs::FileSystem>( stage: CompressionStage, provider: &'a VirtualStorageProvider, dim: usize, @@ -220,13 +213,13 @@ mod pq_generation_tests { max_kmeans_reps: usize, num_centers: usize, p_val: f64, - pool: R, + pool: RayonThreadPoolRef<'a>, pivots_path: String, compressed_path: String, data_path: Option<&str>, - ) -> Result, R>, ANNError> { + ) -> Result>, ANNError> { let pq_storage = PQStorage::new(&pivots_path, &compressed_path, data_path); - let context = PQGenerationContext::<'_, _, _> { + let context = PQGenerationContext::<'_, _> { pq_storage, num_chunks, num_centers, @@ -238,7 +231,7 @@ mod pq_generation_tests { metric: Metric::L2, dim, }; - PQGeneration::<_, _, _>::new_at_stage(stage, &context) + PQGeneration::<_, _>::new_at_stage(stage, &context) } #[rstest] @@ -280,7 +273,7 @@ mod 
pq_generation_tests { &pq_storage, &storage_provider, diskann_providers::utils::create_rnd_provider_from_seed_in_tests(42), - &pool, + pool.as_ref(), ) .unwrap(); @@ -292,7 +285,7 @@ mod pq_generation_tests { max_k_means_reps, num_centers, 1.0, //take all the data to compute codebook - &pool, + pool.as_ref(), pivot_file_name_compressor.to_string(), compressed_file_name.to_string(), Some(data_path), @@ -349,7 +342,7 @@ mod pq_generation_tests { max_k_means_reps, num_centers, 1.0, - &pool, + pool.as_ref(), pivot_file_name.to_string(), compressed_file_name.to_string(), Some(data_path), @@ -375,7 +368,7 @@ mod pq_generation_tests { max_k_means_reps, 256, 1.0, - &pool, + pool.as_ref(), TEST_PQ_PIVOTS_PATH.to_string(), "".to_string(), None, @@ -427,7 +420,7 @@ mod pq_generation_tests { max_k_means_reps, centers, 1.0, - &pool, + pool.as_ref(), TEST_PQ_PIVOTS_PATH.to_string(), "".to_string(), None, diff --git a/diskann-disk/src/utils/kmeans.rs b/diskann-disk/src/utils/kmeans.rs index 60f5519ef..63a87e5e2 100644 --- a/diskann-disk/src/utils/kmeans.rs +++ b/diskann-disk/src/utils/kmeans.rs @@ -12,10 +12,7 @@ use std::cmp::min; use diskann::{ANNError, ANNResult}; -use diskann_providers::{ - forward_threadpool, - utils::{AsThreadPool, ParallelIteratorInPool, RayonThreadPool}, -}; +use diskann_providers::utils::{ParallelIteratorInPool, RayonThreadPoolRef}; use diskann_vector::{distance::SquaredL2, PureDistanceFunction}; use hashbrown::HashSet; use rand::{ @@ -43,7 +40,7 @@ fn lloyds_iter( docs_l2sq: &[f32], closest_docs: &mut Vec>, closest_center: &mut [u32], - pool: &RayonThreadPool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult { let compute_residual = true; @@ -118,7 +115,7 @@ fn lloyds_iter( /// new vec [num_centers]`, and `closest_center = new size_t[num_points]` /// Final centers are output in centers as row-major num_centers * dim. 
#[allow(clippy::too_many_arguments)] -pub fn run_lloyds( +pub fn run_lloyds( data: &[f32], num_points: usize, dim: usize, @@ -126,7 +123,7 @@ pub fn run_lloyds( num_centers: usize, max_reps: usize, cancellation_token: &mut bool, - pool: Pool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<(Vec>, Vec, f32)> { let mut residual = f32::MAX; @@ -135,7 +132,6 @@ pub fn run_lloyds( let mut docs_l2sq = vec![0.0; num_points]; - forward_threadpool!(pool = pool); compute_vecs_l2sq(&mut docs_l2sq, data, num_points, dim, pool)?; let mut old_residual; @@ -231,7 +227,7 @@ fn select_random_pivots( /// If there are are fewer than num_center distinct points, pick all unique points as pivots, /// and sample data randomly for the remaining pivots. #[allow(clippy::too_many_arguments)] -pub fn k_meanspp_selecting_pivots( +pub fn k_meanspp_selecting_pivots( data: &[f32], num_points: usize, dim: usize, @@ -239,7 +235,7 @@ pub fn k_meanspp_selecting_pivots( num_centers: usize, rng: &mut impl Rng, cancellation_token: &mut bool, - pool: Pool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> { if num_points > (1 << 23) { return Err(ANNError::log_kmeans_error(format!( @@ -280,7 +276,6 @@ pub fn k_meanspp_selecting_pivots( let mut dist = vec![0.0; num_points]; - forward_threadpool!(pool = pool); // Calculate the distance between each node and the first pivot and store the result in dist. 
dist.par_iter_mut() .enumerate() @@ -394,7 +389,7 @@ pub fn k_meanspp_selecting_pivots( /// k-means algorithm interface #[allow(clippy::too_many_arguments)] -pub fn k_means_clustering( +pub fn k_means_clustering( data: &[f32], num_points: usize, dim: usize, @@ -403,10 +398,8 @@ pub fn k_means_clustering( max_reps: usize, rng: &mut impl Rng, cancellation_token: &mut bool, - pool: Pool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<(Vec>, Vec, f32)> { - forward_threadpool!(pool = pool); - k_meanspp_selecting_pivots( data, num_points, @@ -473,7 +466,7 @@ mod kmeans_test { &docs_l2sq, &mut closest_docs, &mut closest_center, - &pool, + pool.as_ref(), ) .unwrap(); @@ -515,7 +508,7 @@ mod kmeans_test { num_centers, max_reps, &mut (false), - &pool, + pool.as_ref(), ) .unwrap(); @@ -558,7 +551,7 @@ mod kmeans_test { num_centers, max_reps, cancellation_token, - &pool, + pool.as_ref(), ) .unwrap_err(); @@ -699,7 +692,7 @@ mod kmeans_test { num_centers, &mut create_rnd_in_tests(), &mut (false), - &pool, + pool.as_ref(), ) .unwrap(); @@ -751,7 +744,7 @@ mod kmeans_test { num_centers, &mut create_rnd_in_tests(), &mut (false), - &pool, + pool.as_ref(), ) .unwrap(); @@ -765,7 +758,7 @@ mod kmeans_test { num_centers + 1, &mut create_rnd_in_tests(), &mut (false), - &pool, + pool.as_ref(), ) .unwrap(); @@ -779,7 +772,7 @@ mod kmeans_test { num_points, &mut create_rnd_in_tests(), &mut (false), - &pool, + pool.as_ref(), ) .unwrap(); } @@ -829,7 +822,7 @@ mod kmeans_test { num_centers, &mut create_rnd_in_tests(), &mut (false), - &pool, + pool.as_ref(), ) .unwrap(); @@ -945,7 +938,7 @@ mod kmeans_test { num_centers, &mut create_rnd_in_tests(), cancellation_token, - &pool, + pool.as_ref(), ) .unwrap_err(); @@ -968,7 +961,7 @@ mod kmeans_test { let num_centers = 5; let mut pivot_data = vec![0.0; num_centers * pq_dim]; let pool = create_thread_pool_for_test(); - k_meanspp_selecting_pivots(&data, num_points, pq_dim, &mut pivot_data, num_centers, &mut create_rnd_in_tests(), &mut 
(false),&pool).unwrap(); + k_meanspp_selecting_pivots(&data, num_points, pq_dim, &mut pivot_data, num_centers, &mut create_rnd_in_tests(), &mut (false),pool.as_ref()).unwrap(); } } proptest! { @@ -982,7 +975,7 @@ mod kmeans_test { let num_centers = 5; let mut pivot_data = vec![0.0; num_centers * pq_dim]; let pool = create_thread_pool_for_test(); - k_meanspp_selecting_pivots(&data, num_points, pq_dim, &mut pivot_data, num_centers, &mut create_rnd_in_tests(), &mut (false),&pool).unwrap(); + k_meanspp_selecting_pivots(&data, num_points, pq_dim, &mut pivot_data, num_centers, &mut create_rnd_in_tests(), &mut (false),pool.as_ref()).unwrap(); } } proptest! { @@ -996,7 +989,7 @@ mod kmeans_test { let num_centers = 5; let mut pivot_data = vec![0.0; num_centers * pq_dim]; let pool = create_thread_pool_for_test(); - k_meanspp_selecting_pivots(&data, num_points, pq_dim, &mut pivot_data, num_centers, &mut create_rnd_in_tests(), &mut (false),&pool).unwrap(); + k_meanspp_selecting_pivots(&data, num_points, pq_dim, &mut pivot_data, num_centers, &mut create_rnd_in_tests(), &mut (false),pool.as_ref()).unwrap(); } } } diff --git a/diskann-disk/src/utils/math_util.rs b/diskann-disk/src/utils/math_util.rs index 5acc7207d..e048e0fc6 100644 --- a/diskann-disk/src/utils/math_util.rs +++ b/diskann-disk/src/utils/math_util.rs @@ -14,10 +14,7 @@ use std::{cmp::Ordering, collections::BinaryHeap}; use diskann::{ANNError, ANNResult}; use diskann_linalg::{self, Transpose}; -use diskann_providers::{ - forward_threadpool, - utils::{AsThreadPool, ParallelIteratorInPool, RayonThreadPool}, -}; +use diskann_providers::utils::{ParallelIteratorInPool, RayonThreadPoolRef}; use rayon::prelude::*; // This is the chunk size applied when computing the closest centers in a block. 
@@ -90,12 +87,12 @@ fn compute_vec_l2sq(data: &[f32], index: usize, dim: usize) -> f32 { /// Compute L2-squared norms of data stored in row-major num_points * dim, /// need to be pre-allocated -pub fn compute_vecs_l2sq( +pub fn compute_vecs_l2sq( vecs_l2sq: &mut [f32], data: &[f32], num_points: usize, dim: usize, - pool: Pool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> { if data.len() != num_points * dim { return Err(ANNError::log_pq_error(format_args!( @@ -111,7 +108,6 @@ pub fn compute_vecs_l2sq( *vec_l2sq = compute_vec_l2sq(data, i, dim); } } else { - forward_threadpool!(pool = pool); vecs_l2sq .par_iter_mut() .enumerate() @@ -143,7 +139,7 @@ pub fn compute_closest_centers_in_block( center_index: &mut [u32], dist_matrix: &mut [f32], k: usize, - pool: &RayonThreadPool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> { if k > num_centers { return Err(ANNError::log_index_error(format_args!( @@ -246,7 +242,7 @@ pub fn compute_closest_centers_in_block( /// indices is an empty vector. 
Additionally, if pts_norms_squared is not null, /// then it will assume that point norms are pre-computed and use those values #[allow(clippy::too_many_arguments)] -pub fn compute_closest_centers( +pub fn compute_closest_centers( data: &[f32], num_points: usize, dim: usize, @@ -256,7 +252,7 @@ pub fn compute_closest_centers( closest_centers_ivf: &mut [u32], mut inverted_index: Option<&mut Vec>>, pts_norms_squared: Option<&[f32]>, - pool: Pool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> { if k > num_centers { return Err(ANNError::log_index_error(format_args!( @@ -265,8 +261,6 @@ pub fn compute_closest_centers( ))); } - forward_threadpool!(pool = pool); - let pts_norms_squared = if let Some(pts_norms) = pts_norms_squared { pts_norms.to_vec() } else { @@ -370,7 +364,7 @@ mod math_util_test { let mut vecs_l2sq = vec![0.0; num_points]; let pool = create_thread_pool_for_test(); - compute_vecs_l2sq(&mut vecs_l2sq, &data, num_points, dim, &pool).unwrap(); + compute_vecs_l2sq(&mut vecs_l2sq, &data, num_points, dim, pool.as_ref()).unwrap(); let expected = [14.0, 77.0]; @@ -388,7 +382,7 @@ mod math_util_test { let dim = 8; let mut vecs_l2sq = vec![0.0; num_points]; let pool = create_thread_pool_for_test(); - compute_vecs_l2sq(&mut vecs_l2sq, &data, num_points, dim, &pool).unwrap(); + compute_vecs_l2sq(&mut vecs_l2sq, &data, num_points, dim, pool.as_ref()).unwrap(); let expected = [204.0, 1292.0]; @@ -413,9 +407,9 @@ mod math_util_test { ]; let mut docs_l2sq = vec![0.0; num_points]; let pool = create_thread_pool_for_test(); - compute_vecs_l2sq(&mut docs_l2sq, &data, num_points, dim, &pool).unwrap(); + compute_vecs_l2sq(&mut docs_l2sq, &data, num_points, dim, pool.as_ref()).unwrap(); let mut centers_l2sq = vec![0.0; num_centers]; - compute_vecs_l2sq(&mut centers_l2sq, ¢ers, num_centers, dim, &pool).unwrap(); + compute_vecs_l2sq(&mut centers_l2sq, ¢ers, num_centers, dim, pool.as_ref()).unwrap(); let mut center_index = vec![0; num_points]; let mut dist_matrix = vec![0.0; 
num_points * num_centers]; let k = 1; @@ -431,7 +425,7 @@ mod math_util_test { &mut center_index, &mut dist_matrix, k, - &pool, + pool.as_ref(), ) .unwrap(); @@ -460,9 +454,9 @@ mod math_util_test { ]; let mut docs_l2sq = vec![0.0; num_points]; let pool = create_thread_pool_for_test(); - compute_vecs_l2sq(&mut docs_l2sq, &data, num_points, dim, &pool).unwrap(); + compute_vecs_l2sq(&mut docs_l2sq, &data, num_points, dim, pool.as_ref()).unwrap(); let mut centers_l2sq = vec![0.0; num_centers]; - compute_vecs_l2sq(&mut centers_l2sq, ¢ers, num_centers, dim, &pool).unwrap(); + compute_vecs_l2sq(&mut centers_l2sq, ¢ers, num_centers, dim, pool.as_ref()).unwrap(); let k = 2; let mut center_index = vec![0; num_points * k]; let mut dist_matrix = vec![0.0; num_points * num_centers]; @@ -478,7 +472,7 @@ mod math_util_test { &mut center_index, &mut dist_matrix, k, - &pool, + pool.as_ref(), ) .unwrap(); @@ -519,7 +513,7 @@ mod math_util_test { &mut closest_centers_ivf, Some(&mut inverted_index), None, - &pool, + pool.as_ref(), ) .unwrap(); diff --git a/diskann-disk/src/utils/partition.rs b/diskann-disk/src/utils/partition.rs index 7ca7a6f87..5b9dcffac 100644 --- a/diskann-disk/src/utils/partition.rs +++ b/diskann-disk/src/utils/partition.rs @@ -4,10 +4,7 @@ */ use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult}; use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; -use diskann_providers::{ - forward_threadpool, - utils::{gen_random_slice, AsThreadPool, RayonThreadPool, READ_WRITE_BLOCK_SIZE}, -}; +use diskann_providers::utils::{gen_random_slice, RayonThreadPoolRef, READ_WRITE_BLOCK_SIZE}; use crate::utils::{compute_closest_centers, k_meanspp_selecting_pivots, run_lloyds}; use rand::Rng; @@ -22,7 +19,7 @@ use crate::{ const BLOCK_SIZE_LARGE_FILE: u32 = 10_000; #[allow(clippy::too_many_arguments)] -pub fn partition_with_ram_budget( +pub fn partition_with_ram_budget( dataset_file: &str, dim: usize, sampling_rate: f64, @@ -31,16 +28,14 
@@ pub fn partition_with_ram_budget( merged_index_prefix: &str, storage_provider: &StorageProvider, rng: &mut impl Rng, - pool: Pool, + pool: RayonThreadPoolRef<'_>, ram_estimator: F, ) -> ANNResult where T: VectorRepr, StorageProvider: StorageReadProvider + StorageWriteProvider, - Pool: AsThreadPool, F: Fn(u64, u64) -> f64, { - forward_threadpool!(pool = pool); // Find partition size and get pivot data let (num_parts, pivot_data, train_dim) = find_partition_size::( dataset_file, @@ -78,7 +73,7 @@ fn find_partition_size( k_base: usize, storage_provider: &StorageProvider, rng: &mut impl Rng, - pool: &RayonThreadPool, + pool: RayonThreadPoolRef<'_>, ram_estimator: &F, ) -> ANNResult<(usize, Vec, usize)> where @@ -245,7 +240,7 @@ fn shard_data_into_clusters_only_ids( k_base: usize, merged_index_prefix: &str, storage_provider: &StorageProvider, - pool: &RayonThreadPool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> where T: VectorRepr, @@ -369,7 +364,7 @@ fn estimate_cluster_sizes( dim: usize, k_base: usize, cluster_sizes: &mut Vec, - pool: &RayonThreadPool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> { cluster_sizes.clear(); let mut shard_counts = vec![0; num_centers]; @@ -450,7 +445,7 @@ mod partition_test { dim, k_base, &mut cluster_sizes, - &pool, + pool.as_ref(), ) .unwrap(); @@ -511,7 +506,7 @@ mod partition_test { k_base, merged_index_prefix, &storage_provider, - &pool, + pool.as_ref(), ) .unwrap(); @@ -565,7 +560,7 @@ mod partition_test { let merged_index_prefix = "/test_merged_index_prefix"; let pool = create_thread_pool_for_test(); - let num_parts = partition_with_ram_budget::( + let num_parts = partition_with_ram_budget::( dataset_file, 128, //sift is 128 dimensions sampling_rate, @@ -574,7 +569,7 @@ mod partition_test { merged_index_prefix, &storage_provider, &mut diskann_providers::utils::create_rnd_in_tests(), - &pool, + pool.as_ref(), |num_points, dim| { // Simple RAM estimation for test - capture datasize and graph_degree from context use 
diskann_providers::model::GRAPH_SLACK_FACTOR; diff --git a/diskann-providers/benches/benchmarks/compute_pq_bench.rs b/diskann-providers/benches/benchmarks/compute_pq_bench.rs index b89891247..e79b1d22d 100644 --- a/diskann-providers/benches/benchmarks/compute_pq_bench.rs +++ b/diskann-providers/benches/benchmarks/compute_pq_bench.rs @@ -59,7 +59,7 @@ fn generate_benchmark_data( query_centroid_l2_distance .par_chunks_mut(NUM_PQ_CENTROIDS) .enumerate() - .for_each_in_pool(&pool, |(_, chunk)| chunk.copy_from_slice(&vec_256)); + .for_each_in_pool(pool.as_ref(), |(_, chunk)| chunk.copy_from_slice(&vec_256)); let pq_data: Vec = (0..num_pq_chunks * n_pts) .map(|_| rng.random_range(0..256) as u8) diff --git a/diskann-providers/benches/benchmarks/diskann_bench.rs b/diskann-providers/benches/benchmarks/diskann_bench.rs index b6f190cdc..e12d6ab66 100644 --- a/diskann-providers/benches/benchmarks/diskann_bench.rs +++ b/diskann-providers/benches/benchmarks/diskann_bench.rs @@ -62,7 +62,7 @@ async fn test_sift_256_vectors_with_quant_vectors() { train_data.as_view(), 32, &mut diskann_providers::utils::create_rnd_in_tests(), - &pool, + pool.as_ref(), ) .unwrap(); diff --git a/diskann-providers/benches/benchmarks_iai/compute_pq_iai.rs b/diskann-providers/benches/benchmarks_iai/compute_pq_iai.rs index caba9bf28..7ce33f04d 100644 --- a/diskann-providers/benches/benchmarks_iai/compute_pq_iai.rs +++ b/diskann-providers/benches/benchmarks_iai/compute_pq_iai.rs @@ -54,7 +54,7 @@ fn generate_benchmark_data() -> (Vec, Vec, Vec) { query_centroid_l2_distance .par_chunks_mut(NUM_PQ_CENTROIDS) .enumerate() - .for_each_in_pool(&pool, |(_, chunk)| chunk.copy_from_slice(&vec_256)); + .for_each_in_pool(pool.as_ref(), |(_, chunk)| chunk.copy_from_slice(&vec_256)); let pq_data: Vec = (0..NUM_PQ_CHUNKS * n_pts) .map(|_| rng.random_range(0..256) as u8) diff --git a/diskann-providers/benches/benchmarks_iai/diskann_iai.rs b/diskann-providers/benches/benchmarks_iai/diskann_iai.rs index 
ee748575f..0610088d6 100644 --- a/diskann-providers/benches/benchmarks_iai/diskann_iai.rs +++ b/diskann-providers/benches/benchmarks_iai/diskann_iai.rs @@ -58,7 +58,7 @@ async fn test_sift_256_vectors_with_quant_vectors() { train_data.as_view(), 32, &mut diskann_providers::utils::create_rnd_in_tests(), - &pool, + pool.as_ref(), ) .unwrap(); diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 56d3d4436..2a8e24668 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -58,15 +58,12 @@ pub(crate) fn simplified_builder( Ok((config, params)) } -pub fn train_pq( +pub fn train_pq( data: diskann_utils::views::MatrixView, num_pq_chunks: usize, rng: &mut dyn rand::RngCore, - pool: Pool, -) -> ANNResult -where - Pool: crate::utils::AsThreadPool, -{ + pool: crate::utils::RayonThreadPoolRef<'_>, +) -> ANNResult { let dim = data.ncols(); let pivot_args = model::GeneratePivotArguments::new( data.nrows(), @@ -667,7 +664,7 @@ pub(crate) mod tests { squish(vectors.iter(), dim).as_view(), 2.min(dim), // Number of PQ chunks is bounded by the dimension. &mut create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -764,7 +761,7 @@ pub(crate) mod tests { matrix.map(|i| (*i).into()).as_view(), 2.min(dim), // Number of PQ chunks is bounded by the dimension. 
&mut create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -936,7 +933,13 @@ pub(crate) mod tests { let data = T::generate_spherical(num, dim, radius, rng); let table = { let train_data: diskann_utils::views::Matrix = squish(data.iter(), dim); - train_pq(train_data.as_view(), 2.min(dim), rng, 1usize).unwrap() + train_pq( + train_data.as_view(), + 2.min(dim), + rng, + crate::utils::create_thread_pool(1).unwrap().as_ref(), + ) + .unwrap() }; let index = new_quant_index::(config, params, table, NoDeletes).unwrap(); @@ -1116,7 +1119,7 @@ pub(crate) mod tests { squish(vectors.iter(), dim).as_view(), 2.min(dim), // Number of PQ chunks is bounded by the dimension. &mut create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -1233,7 +1236,7 @@ pub(crate) mod tests { squish(vectors.iter(), dim).as_view(), 2.min(dim), // Number of PQ chunks is bounded by the dimension. 
&mut create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -1391,7 +1394,7 @@ pub(crate) mod tests { squish(vectors.iter(), dim).as_view(), 2.min(dim), &mut create_rnd_from_seed_in_tests(0xdd81b895605c73d4), - 1usize, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -2463,7 +2466,7 @@ pub(crate) mod tests { data.as_view(), 32, &mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade), - 1, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -2669,7 +2672,7 @@ pub(crate) mod tests { train_data.as_view(), num_pq_chunks, &mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade), - 1, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -3624,7 +3627,7 @@ pub(crate) mod tests { squish(vectors.iter(), dim).as_view(), 2.min(dim), &mut create_rnd_from_seed_in_tests(0x1234567890abcdef), - 1usize, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -3687,7 +3690,7 @@ pub(crate) mod tests { squish(vectors.iter(), dim).as_view(), 2.min(dim), &mut create_rnd_from_seed_in_tests(0xfedcba0987654321), - 1usize, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -3755,7 +3758,7 @@ pub(crate) mod tests { squish(vectors.iter(), dim).as_view(), 2.min(dim), &mut create_rnd_from_seed_in_tests(0xabcdef1234567890), - 1usize, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); @@ -3902,7 +3905,7 @@ pub(crate) mod tests { squish(vectors.iter(), dim).as_view(), 2.min(dim), &mut create_rnd_from_seed_in_tests(0x9876543210fedcba), - 1usize, + crate::utils::create_thread_pool(1).unwrap().as_ref(), ) .unwrap(); diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index df554b5e8..450076fbb 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -458,7 +458,7 @@ mod tests { train_data.as_view(), 
pq_bytes, &mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade), - 2, + crate::utils::create_thread_pool(2).unwrap().as_ref(), ) .unwrap(); diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index b2ab3da85..9f5ebe781 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -29,11 +29,10 @@ use rayon::prelude::*; use tracing::info; use crate::{ - forward_threadpool, model::GeneratePivotArguments, storage::PQStorage, utils::{ - AsThreadPool, BridgeErr, ParallelIteratorInPool, RandomProvider, Timer, + BridgeErr, ParallelIteratorInPool, RandomProvider, RayonThreadPoolRef, Timer, create_rnd_provider_from_seed, }, }; @@ -64,18 +63,17 @@ where /// k-means in each chunk to compute the PQ pivots and stores in bin format in /// file pq_pivots_path as a s num_centers*dim floating point binary file /// PQ pivot table layout: {pivot offsets data: METADATA_SIZE}{pivot vector:[dim; num_centroid]}{centroid vector:[dim; 1]}{chunk offsets:[chunk_num+1; 1]} -pub fn generate_pq_pivots( +pub fn generate_pq_pivots( parameters: GeneratePivotArguments, train_data: &mut [f32], pq_storage: &PQStorage, storage_provider: &Storage, random_provider: RandomProvider, - pool: Pool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> where Storage: StorageWriteProvider + StorageReadProvider, Random: Rng, - Pool: AsThreadPool, { if pq_storage.pivot_data_exist(storage_provider) { let (file_num_centers, file_dim) = @@ -103,7 +101,6 @@ where &mut chunk_offsets, ); - forward_threadpool!(pool = pool); let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new( parameters.num_centers(), parameters.max_k_means_reps(), @@ -151,7 +148,7 @@ where /// /// Result is stored in the `full_pivot_data`, which must be of size `num_centers * dim`. 
#[allow(clippy::too_many_arguments)] -pub fn generate_pq_pivots_from_membuf, Pool: AsThreadPool>( +pub fn generate_pq_pivots_from_membuf>( parameters: &GeneratePivotArguments, train_data_slice: &[T], centroid: &mut [f32], @@ -159,7 +156,7 @@ pub fn generate_pq_pivots_from_membuf, Pool: AsThreadPool>( full_pivot_data: &mut [f32], rng: &mut (impl Rng + ?Sized), cancellation_token: &mut bool, - pool: Pool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> { if full_pivot_data.len() != parameters.num_centers() * parameters.dim() { return Err(ANNError::log_pq_error( @@ -208,7 +205,6 @@ pub fn generate_pq_pivots_from_membuf, Pool: AsThreadPool>( // Calculate the chunk offsets calculate_chunk_offsets(parameters.dim(), parameters.num_pq_chunks(), offsets); - forward_threadpool!(pool = pool); let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new( parameters.num_centers(), parameters.max_k_means_reps(), @@ -382,18 +378,17 @@ where /// Compressed PQ table layout: {num_points: usize}{num_chunks: usize}{compressed pq table: [num_points; num_chunks]} /// It will start from the start_vector_id and compress the data_file in chunks. /// It validates the existing compressed data_file is consistent with the start_vector_id. 
-pub fn generate_pq_data_from_pivots( +pub fn generate_pq_data_from_pivots( num_centers: usize, num_pq_chunks: usize, pq_storage: &mut PQStorage, storage_provider: &Storage, offset: usize, - pool: Pool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> where T: Copy + VectorRepr, Storage: StorageWriteProvider + StorageReadProvider, - Pool: AsThreadPool, { let timer = Timer::new(); @@ -469,7 +464,6 @@ where let mut buffer = vec![0.0; full_dim * block_size]; - forward_threadpool!(pool = pool); for block_index in 0..num_blocks { let start_index: usize = offset + block_index * block_size; let end_index: usize = std::cmp::min(start_index + block_size, num_points); @@ -612,17 +606,14 @@ pub fn generate_pq_data_from_pivots_from_membuf>( /// PQ pivots computed earlier, partition the co-ordinates into /// `num_pq_chunks`, and find the closest pivots for each point in each chunk. /// This API doesn't involve reading/writing to disk and is used for in-memory. -pub fn generate_pq_data_from_pivots_from_membuf_batch< - T: Copy + Sync + Into, - Pool: AsThreadPool, ->( +pub fn generate_pq_data_from_pivots_from_membuf_batch>( parameters: &GeneratePivotArguments, vector_data: &[T], pivot_data: &[f32], centroid: &[f32], offsets: &[usize], pq_out: &mut [u8], - pool: Pool, + pool: RayonThreadPoolRef<'_>, ) -> ANNResult<()> { // Perform minimal error checking at this level, mainly on the sizes of `vector_data` // and `pq_out`. 
@@ -645,8 +636,6 @@ pub fn generate_pq_data_from_pivots_from_membuf_batch< let translate_to_center = parameters.translate_to_center(); let centroid_option: Option<&[f32]> = translate_to_center.then_some(centroid); - forward_threadpool!(pool = pool); - pq_out .par_chunks_mut(num_pq_chunks) .zip(vector_data.par_chunks(dim)) @@ -722,7 +711,7 @@ mod pq_test { &pq_storage, &storage_provider, crate::utils::create_rnd_provider_from_seed_in_tests(42), - &pool, + pool.as_ref(), ) .unwrap(); @@ -794,7 +783,7 @@ mod pq_test { &mut full_pivot_data, &mut crate::utils::create_rnd_in_tests(), &mut (false), - &pool, + pool.as_ref(), ); assert!(result.is_ok()); @@ -831,7 +820,7 @@ mod pq_test { &pq_storage, &storage_provider, crate::utils::create_rnd_provider_from_seed_in_tests(42), - &pool, + pool.as_ref(), ); // still succeed without training data @@ -873,18 +862,11 @@ mod pq_test { &pq_storage, &storage_provider, crate::utils::create_rnd_provider_from_seed_in_tests(42), - &pool, - ) - .unwrap(); - generate_pq_data_from_pivots::( - 2, - 2, - &mut pq_storage, - &storage_provider, - 0, - &pool, + pool.as_ref(), ) .unwrap(); + generate_pq_data_from_pivots::(2, 2, &mut pq_storage, &storage_provider, 0, pool.as_ref()) + .unwrap(); let compressed = read_bin_from::( &mut storage_provider .open_reader(pq_compressed_vectors_path) @@ -943,7 +925,7 @@ mod pq_test { &mut pivot_data, &mut crate::utils::create_rnd_in_tests(), &mut (false), - &pool, + pool.as_ref(), ) .unwrap(); @@ -1024,17 +1006,17 @@ mod pq_test { &pq_storage, &storage_provider, crate::utils::create_rnd_provider_from_seed_in_tests(42), - &pool, + pool.as_ref(), ) .expect("Failed to generate pivots"); - generate_pq_data_from_pivots::( + generate_pq_data_from_pivots::( NUM_PQ_CENTROIDS, num_pq_chunks, &mut pq_storage, &storage_provider, 0, - &pool, + pool.as_ref(), ) .expect("Failed to generate quantized data"); @@ -1059,7 +1041,7 @@ mod pq_test { membuf_pq_data .par_chunks_mut(num_pq_chunks) .enumerate() - 
.for_each_in_pool(&pool, |(i, membuf_slice)| { + .for_each_in_pool(pool.as_ref(), |(i, membuf_slice)| { generate_pq_data_from_pivots_from_membuf( &full_data_vector[train_dim * i..train_dim * (i + 1)], &full_pivot_data, @@ -1215,7 +1197,7 @@ mod pq_test { &mut full_pivot_data, &mut crate::utils::create_rnd_in_tests(), &mut (false), - &pool, + pool.as_ref(), ); assert!(result.is_ok()); @@ -1247,13 +1229,13 @@ mod pq_test { let pool = create_thread_pool_for_test(); - generate_pq_data_from_pivots::( + generate_pq_data_from_pivots::( NUM_PQ_CENTROIDS, 1, &mut pq_storage, &storage_provider, 0, - &pool, + pool.as_ref(), ) .expect("Failed to generate quantized data"); @@ -1368,7 +1350,7 @@ mod pq_test { &mut full_pivot_data, &mut crate::utils::create_rnd_in_tests(), &mut (false), - &pool, + pool.as_ref(), ) .unwrap(); @@ -1401,7 +1383,7 @@ mod pq_test { &centroid, &offsets, &mut pq_data, - &pool, + pool.as_ref(), ) .unwrap(); diff --git a/diskann-providers/src/storage/index_storage.rs b/diskann-providers/src/storage/index_storage.rs index 459a140cd..21a1d9333 100644 --- a/diskann-providers/src/storage/index_storage.rs +++ b/diskann-providers/src/storage/index_storage.rs @@ -277,7 +277,7 @@ mod tests { train_data.as_view(), pq_bytes, &mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade), - 2, + crate::utils::create_thread_pool(2).unwrap().as_ref(), ) .unwrap(); diff --git a/diskann-providers/src/utils/mod.rs b/diskann-providers/src/utils/mod.rs index 1b51316b8..20743f5d6 100644 --- a/diskann-providers/src/utils/mod.rs +++ b/diskann-providers/src/utils/mod.rs @@ -10,8 +10,8 @@ pub use bridge_error::{Bridge, BridgeErr}; pub mod rayon_util; pub use rayon_util::{ - AsThreadPool, ParallelIteratorInPool, RayonThreadPool, create_thread_pool, - create_thread_pool_for_bench, create_thread_pool_for_test, execute_with_rayon, + ParallelIteratorInPool, RayonThreadPool, RayonThreadPoolRef, create_thread_pool, + create_thread_pool_for_bench, create_thread_pool_for_test, }; mod timer; diff
--git a/diskann-providers/src/utils/rayon_util.rs b/diskann-providers/src/utils/rayon_util.rs index 68a242a1b..ac5c2101d 100644 --- a/diskann-providers/src/utils/rayon_util.rs +++ b/diskann-providers/src/utils/rayon_util.rs @@ -2,27 +2,8 @@ * Copyright (c) Microsoft Corporation. * Licensed under the MIT license. */ -use std::ops::Range; - use diskann::{ANNError, ANNResult}; -use rayon::prelude::{IntoParallelIterator, ParallelIterator}; - -/// based on thread_num, execute the task in parallel using Rayon or serial -#[inline] -pub fn execute_with_rayon(range: Range, num_threads: usize, f: F) -> ANNResult<()> -where - F: Fn(usize) -> ANNResult<()> + Sync + Send + Copy, -{ - if num_threads == 1 { - for i in range { - f(i)?; - } - Ok(()) - } else { - let pool = create_thread_pool(num_threads)?; - range.into_par_iter().try_for_each_in_pool(&pool, f) - } -} +use rayon::prelude::ParallelIterator; /// Creates a new thread pool with the specified number of threads. /// If `num_threads` is 0, it defaults to the number of logical CPUs. @@ -71,65 +52,43 @@ impl RayonThreadPool { { self.0.install(op) } -} -mod sealed { - pub trait Sealed {} -} - -/// This allows either an integer to be provided or an explicit `&RayonThreadPool`. -/// If an integer is provided, we create a new thread-pool with the requested number of -/// threads. -/// -/// This trait should be "sealed" to avoid external users being able to implement it. -/// See [as_threadpool_tests] for examples of how to use this trait. 
-pub trait AsThreadPool: sealed::Sealed + Send + Sync { - type Returns: std::ops::Deref; - fn as_threadpool(&self) -> ANNResult; + pub fn as_ref(&self) -> RayonThreadPoolRef<'_> { + RayonThreadPoolRef(&self.0) + } } -impl sealed::Sealed for usize {} -impl sealed::Sealed for &RayonThreadPool {} +#[derive(Clone, Copy)] +pub struct RayonThreadPoolRef<'a>(&'a rayon::ThreadPool); -impl AsThreadPool for usize { - type Returns = diskann_utils::reborrow::Place; - fn as_threadpool(&self) -> ANNResult { - create_thread_pool(*self).map(diskann_utils::reborrow::Place) +impl<'a> RayonThreadPoolRef<'a> { + /// Wrap an externally-owned `rayon::ThreadPool`. + pub fn new(pool: &'a rayon::ThreadPool) -> Self { + Self(pool) } -} -impl<'a> AsThreadPool for &'a RayonThreadPool { - type Returns = &'a RayonThreadPool; - fn as_threadpool(&self) -> ANNResult { - Ok(self) + pub fn install(self, op: OP) -> R + where + OP: FnOnce() -> R + Send, + R: Send, + { + self.0.install(op) } } -/// The `forward_threadpool` macro simplifies obtaining a thread pool from an input -/// that implements the `AsThreadPool` trait. -#[macro_export] -macro_rules! forward_threadpool { - ($out:ident = $in:ident) => { - $crate::forward_threadpool!($out = $in: _); - }; - ($out:ident = $in:ident: $type:ty) => { - let $out = &*<$type as $crate::utils::AsThreadPool>::as_threadpool(&$in)?; - }; -} - // Allow use of disallowed methods within this trait to provide custom // implementations of common parallel operations that enforce execution // within a specified thread pool. 
#[allow(clippy::disallowed_methods)] pub trait ParallelIteratorInPool: ParallelIterator + Sized { - fn for_each_in_pool(self, pool: &RayonThreadPool, op: OP) + fn for_each_in_pool(self, pool: RayonThreadPoolRef<'_>, op: OP) where OP: Fn(Self::Item) + Sync + Send, { pool.install(|| self.for_each(op)); } - fn for_each_with_in_pool(self, pool: &RayonThreadPool, init: T, op: OP) + fn for_each_with_in_pool(self, pool: RayonThreadPoolRef<'_>, init: T, op: OP) where OP: Fn(&mut T, Self::Item) + Sync + Send, T: Send + Clone, @@ -137,7 +96,7 @@ pub trait ParallelIteratorInPool: ParallelIterator + Sized { pool.install(|| self.for_each_with(init, op)) } - fn for_each_init_in_pool(self, pool: &RayonThreadPool, init: INIT, op: OP) + fn for_each_init_in_pool(self, pool: RayonThreadPoolRef<'_>, init: INIT, op: OP) where OP: Fn(&mut T, Self::Item) + Sync + Send, INIT: Fn() -> T + Sync + Send, @@ -145,7 +104,7 @@ pub trait ParallelIteratorInPool: ParallelIterator + Sized { pool.install(|| self.for_each_init(init, op)) } - fn try_for_each_in_pool(self, pool: &RayonThreadPool, op: OP) -> Result<(), E> + fn try_for_each_in_pool(self, pool: RayonThreadPoolRef<'_>, op: OP) -> Result<(), E> where OP: Fn(Self::Item) -> Result<(), E> + Sync + Send, E: Send, @@ -155,7 +114,7 @@ pub trait ParallelIteratorInPool: ParallelIterator + Sized { fn try_for_each_with_in_pool( self, - pool: &RayonThreadPool, + pool: RayonThreadPoolRef<'_>, init: T, op: OP, ) -> Result<(), E> @@ -169,7 +128,7 @@ pub trait ParallelIteratorInPool: ParallelIterator + Sized { fn try_for_each_init_in_pool( self, - pool: &RayonThreadPool, + pool: RayonThreadPoolRef<'_>, init: INIT, op: OP, ) -> Result<(), E> @@ -181,18 +140,18 @@ pub trait ParallelIteratorInPool: ParallelIterator + Sized { pool.install(|| self.try_for_each_init(init, op)) } - fn count_in_pool(self, pool: &RayonThreadPool) -> usize { + fn count_in_pool(self, pool: RayonThreadPoolRef<'_>) -> usize { pool.install(|| self.count()) } - fn collect_in_pool(self, 
pool: &RayonThreadPool) -> C + fn collect_in_pool(self, pool: RayonThreadPoolRef<'_>) -> C where C: rayon::iter::FromParallelIterator + Send, { pool.install(|| self.collect()) } - fn sum_in_pool(self, pool: &RayonThreadPool) -> S + fn sum_in_pool(self, pool: RayonThreadPoolRef<'_>) -> S where S: Send + std::iter::Sum + std::iter::Sum, { @@ -208,6 +167,7 @@ mod tests { use std::sync::{Mutex, mpsc::channel}; use super::*; + use rayon::prelude::IntoParallelIterator; fn get_num_cpus() -> usize { std::thread::available_parallelism() @@ -280,12 +240,32 @@ mod tests { assert!(rayon::current_thread_index().is_some()); } + #[test] + fn test_bring_your_own_pool() { + let external_pool = rayon::ThreadPoolBuilder::new() + .num_threads(2) + .build() + .unwrap(); + let pool_ref = RayonThreadPoolRef::new(&external_pool); + + let res = Mutex::new(Vec::new()); + (0..5).into_par_iter().for_each_in_pool(pool_ref, |x| { + let mut res = res.lock().unwrap(); + res.push(x); + assert_run_in_rayon_thread(); + }); + + let mut res = res.lock().unwrap(); + res.sort(); + assert_eq!(&res[..], &[0, 1, 2, 3, 4]); + } + #[test] fn test_for_each_in_pool() { let pool = create_thread_pool(4).unwrap(); let res = Mutex::new(Vec::new()); - (0..5).into_par_iter().for_each_in_pool(&pool, |x| { + (0..5).into_par_iter().for_each_in_pool(pool.as_ref(), |x| { let mut res = res.lock().unwrap(); res.push(x); assert_run_in_rayon_thread(); @@ -303,7 +283,7 @@ mod tests { (0..5) .into_par_iter() - .for_each_with_in_pool(&pool, sender, |s, x| s.send(x).unwrap()); + .for_each_with_in_pool(pool.as_ref(), sender, |s, x| s.send(x).unwrap()); let mut res: Vec<_> = receiver.iter().collect(); @@ -317,7 +297,7 @@ mod tests { let pool = create_thread_pool(4).unwrap(); let iter = (0..100).into_par_iter(); iter.for_each_init_in_pool( - &pool, + pool.as_ref(), || 0, |s, i| { assert_run_in_rayon_thread(); @@ -334,7 +314,7 @@ mod tests { assert_run_in_rayon_thread(); i as f32 }); - let list = 
mapped_iter.collect_in_pool::>(&pool); + let list = mapped_iter.collect_in_pool::>(pool.as_ref()); assert!(list.len() == 100); } @@ -342,7 +322,7 @@ mod tests { fn test_try_for_each_in_pool() { let pool = create_thread_pool(4).unwrap(); let iter = (0..100).into_par_iter(); - let result = iter.try_for_each_in_pool(&pool, |i| { + let result = iter.try_for_each_in_pool(pool.as_ref(), |i| { assert_run_in_rayon_thread(); if i < 50 { Ok(()) } else { Err("Error") } }); @@ -354,7 +334,7 @@ mod tests { let pool = create_thread_pool(4).unwrap(); let iter = (0..100).into_par_iter(); let result = iter.try_for_each_init_in_pool( - &pool, + pool.as_ref(), || 0, |_, i| { assert_run_in_rayon_thread(); @@ -368,7 +348,7 @@ mod tests { fn test_try_for_each_with_in_pool() { let pool = create_thread_pool(4).unwrap(); let iter = (0..100).into_par_iter(); - let result = iter.try_for_each_with_in_pool(&pool, 0, |acc, i| { + let result = iter.try_for_each_with_in_pool(pool.as_ref(), 0, |acc, i| { assert_run_in_rayon_thread(); if i < 50 { *acc += i; @@ -384,7 +364,7 @@ mod tests { fn test_count_in_pool() { let pool = create_thread_pool(4).unwrap(); let iter = (0..100).into_par_iter(); - let count = iter.count_in_pool(&pool); + let count = iter.count_in_pool(pool.as_ref()); assert_eq!(count, 100); } @@ -392,7 +372,7 @@ mod tests { fn test_collect_in_pool() { let pool = create_thread_pool(4).unwrap(); let iter = (0..100).into_par_iter(); - let vec = iter.collect_in_pool::>(&pool); + let vec = iter.collect_in_pool::>(pool.as_ref()); assert_eq!(vec.len(), 100); } @@ -400,96 +380,7 @@ mod tests { fn test_sum_in_pool() { let pool = create_thread_pool(4).unwrap(); let iter = (0..100).into_par_iter(); - let sum: i32 = iter.sum_in_pool(&pool); + let sum: i32 = iter.sum_in_pool(pool.as_ref()); assert_eq!(sum, (0..100).sum::()); } } - -#[cfg(test)] -mod as_threadpool_tests { - use super::*; - - fn some_parallel_op(pool: P) -> ANNResult { - forward_threadpool!(pool = pool); - - let ret = 
(0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool); - Ok(ret) - } - - fn another_parallel_op(pool: P) -> ANNResult { - forward_threadpool!(pool = pool); - let ret = (0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool); - Ok(ret) - } - - fn execute_single_parallel_op(pool: P) -> ANNResult { - // Directly pass the thread pool to the function. - some_parallel_op(pool) - } - - fn execute_two_parallel_ops(pool: P) -> ANNResult { - // Need a reference to the thread pool to share it with multiple functions. - forward_threadpool!(pool = pool); - - let ret1 = some_parallel_op(pool)?; - let ret2 = another_parallel_op(pool)?; - Ok(ret1 + ret2) - } - - fn execute_combined_parallel_ops(pool: P) -> ANNResult { - // Need a Threadpool reference to execute the operations. - forward_threadpool!(pool = pool); - - let ret1: f32 = (0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool); - let ret2 = some_parallel_op(pool)?; - Ok(ret1 + ret2) - } - - #[test] - fn test_execute_single_parallel_op_with_usize() { - let num_threads = 4; - let result = execute_single_parallel_op(num_threads); - assert!(result.is_ok()); - assert!(result.unwrap() > 0.0); - } - - #[test] - fn test_execute_single_parallel_op_with_existing_pool() { - let pool = create_thread_pool(4).unwrap(); - let result = execute_single_parallel_op(&pool); - assert!(result.is_ok()); - assert!(result.unwrap() > 0.0); - } - - #[test] - fn test_execute_two_parallel_ops_with_usize() { - let num_threads = 4; - let result = execute_two_parallel_ops(num_threads); - assert!(result.is_ok()); - assert!(result.unwrap() > 0.0); - } - - #[test] - fn test_execute_two_parallel_ops_with_existing_pool() { - let pool = create_thread_pool(4).unwrap(); - let result = execute_two_parallel_ops(&pool); - assert!(result.is_ok()); - assert!(result.unwrap() > 0.0); - } - - #[test] - fn test_execute_combined_parallel_ops_with_usize() { - let num_threads = 4; - let result = execute_combined_parallel_ops(num_threads); - 
assert!(result.is_ok()); - assert!(result.unwrap() > 0.0); - } - - #[test] - fn test_execute_combined_parallel_ops_with_existing_pool() { - let pool = create_thread_pool(4).unwrap(); - let result = execute_combined_parallel_ops(&pool); - assert!(result.is_ok()); - assert!(result.unwrap() > 0.0); - } -} diff --git a/diskann-tools/src/utils/build_pq.rs b/diskann-tools/src/utils/build_pq.rs index d345de3ef..bdb73e51e 100644 --- a/diskann-tools/src/utils/build_pq.rs +++ b/diskann-tools/src/utils/build_pq.rs @@ -14,7 +14,7 @@ use diskann_providers::{ get_disk_index_compressed_pq_file, get_disk_index_pq_pivot_file, FileStorageProvider, PQStorage, }, - utils::{load_metadata_from_file, Timer}, + utils::{create_thread_pool, load_metadata_from_file, Timer}, }; use diskann_vector::distance::Metric; use tracing::info; @@ -65,6 +65,8 @@ pub fn build_pq( &mut random_provider.create_rnd(), )?; + let pool = create_thread_pool(parameters.num_threads)?; + diskann_providers::model::pq::generate_pq_pivots( GeneratePivotArguments::new( num_train, @@ -78,16 +80,16 @@ pub fn build_pq( &pq_storage, &storage_provider, random_provider, - parameters.num_threads, + pool.as_ref(), )?; - diskann_providers::model::pq::generate_pq_data_from_pivots::( + diskann_providers::model::pq::generate_pq_data_from_pivots::( NUM_PQ_CENTROIDS, num_pq_chunks, &mut pq_storage, &storage_provider, 0, - parameters.num_threads, + pool.as_ref(), )?; info!( diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index d68e38451..a56c2b83f 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -500,7 +500,7 @@ where queries_and_neighbor_queue .par_iter_mut() .enumerate() - .for_each_in_pool(&pool, |(idx_query, (query, ref mut neighbor_queue))| { + .for_each_in_pool(pool.as_ref(), |(idx_query, (query, ref mut neighbor_queue))| { for (idx_in_batch, data) in data_batch.iter().enumerate() { let idx = (num_base_points + idx_in_batch) as 
u32; @@ -585,7 +585,7 @@ where query_multivecs_and_neighbor_queue .par_iter_mut() .enumerate() - .for_each_in_pool(&pool, |(query_idx, (query_multivec, neighbor_queue))| { + .for_each_in_pool(pool.as_ref(), |(query_idx, (query_multivec, neighbor_queue))| { for (idx_base, base_multivec) in base_vectors.iter().enumerate() { // check if calculation is allowed by bitmap if present let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps { diff --git a/diskann-tools/src/utils/range_search_disk_index.rs b/diskann-tools/src/utils/range_search_disk_index.rs index df0b0d4e1..cb4834dcb 100644 --- a/diskann-tools/src/utils/range_search_disk_index.rs +++ b/diskann-tools/src/utils/range_search_disk_index.rs @@ -154,7 +154,7 @@ where let test_start = Instant::now(); zipped.for_each_in_pool( - &pool, + pool.as_ref(), |((((res_count, query), query_result_id), query_result_dist), stats)| { let mut associated_data = vec![]; diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 6125eb21d..a0a91fde2 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -241,7 +241,7 @@ where let test_start = Instant::now(); zipped.for_each_in_pool( - &pool, + pool.as_ref(), |( (((((_cmp, query), vector_filter), query_result_id), query_result_dist), stats), result_count,