From d284073db3aabd112a40df7deb573c56f012d770 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 11 Mar 2026 20:07:42 -0400 Subject: [PATCH 1/4] Cut over sum aggregate function Signed-off-by: Nicholas Gates --- vortex-array/public-api.lock | 118 --- vortex-array/src/arrays/bool/compute/mod.rs | 1 - vortex-array/src/arrays/bool/compute/sum.rs | 46 - .../src/arrays/chunked/compute/mod.rs | 1 - .../src/arrays/chunked/compute/sum.rs | 236 ----- .../src/arrays/constant/compute/mod.rs | 1 - .../src/arrays/constant/compute/sum.rs | 302 ------- .../src/arrays/decimal/compute/mod.rs | 1 - .../src/arrays/decimal/compute/sum.rs | 416 --------- .../src/arrays/extension/compute/mod.rs | 1 - .../src/arrays/extension/compute/sum.rs | 20 - .../src/arrays/primitive/compute/mod.rs | 1 - .../src/arrays/primitive/compute/sum.rs | 140 --- vortex-array/src/compute/mod.rs | 1 - vortex-array/src/compute/sum.rs | 818 ++++++++++++------ 15 files changed, 571 insertions(+), 1532 deletions(-) delete mode 100644 vortex-array/src/arrays/bool/compute/sum.rs delete mode 100644 vortex-array/src/arrays/chunked/compute/sum.rs delete mode 100644 vortex-array/src/arrays/constant/compute/sum.rs delete mode 100644 vortex-array/src/arrays/decimal/compute/sum.rs delete mode 100644 vortex-array/src/arrays/extension/compute/sum.rs delete mode 100644 vortex-array/src/arrays/primitive/compute/sum.rs diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 8b822467195..102f17fc105 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -542,10 +542,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::BoolVTable pub fn vortex_array::arrays::BoolVTable::min_max(&self, array: &vortex_array::arrays::BoolArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::BoolVTable - -pub fn vortex_array::arrays::BoolVTable::sum(&self, array: &vortex_array::arrays::BoolArray, accumulator: 
&vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::bool::BoolMaskedValidityRule pub type vortex_array::arrays::bool::BoolMaskedValidityRule::Parent = vortex_array::arrays::MaskedVTable @@ -720,10 +716,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ChunkedVTable pub fn vortex_array::arrays::ChunkedVTable::min_max(&self, array: &vortex_array::arrays::ChunkedArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::ChunkedVTable - -pub fn vortex_array::arrays::ChunkedVTable::sum(&self, array: &vortex_array::arrays::ChunkedArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::cast::CastReduce for vortex_array::arrays::ChunkedVTable pub fn vortex_array::arrays::ChunkedVTable::cast(array: &vortex_array::arrays::ChunkedArray, dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult> @@ -874,10 +866,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ConstantVTabl pub fn vortex_array::arrays::ConstantVTable::min_max(&self, array: &vortex_array::arrays::ConstantArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::ConstantVTable - -pub fn vortex_array::arrays::ConstantVTable::sum(&self, array: &vortex_array::arrays::ConstantArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::between::BetweenReduce for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::between(array: &vortex_array::arrays::ConstantArray, lower: &vortex_array::ArrayRef, upper: &vortex_array::ArrayRef, options: &vortex_array::scalar_fn::fns::between::BetweenOptions) -> vortex_error::VortexResult> @@ -1154,10 +1142,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::DecimalVTable pub fn 
vortex_array::arrays::DecimalVTable::min_max(&self, array: &vortex_array::arrays::DecimalArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::DecimalVTable - -pub fn vortex_array::arrays::DecimalVTable::sum(&self, array: &vortex_array::arrays::DecimalArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::decimal::DecimalMaskedValidityRule pub type vortex_array::arrays::decimal::DecimalMaskedValidityRule::Parent = vortex_array::arrays::MaskedVTable @@ -1778,10 +1762,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ExtensionVTab pub fn vortex_array::arrays::ExtensionVTable::min_max(&self, array: &vortex_array::arrays::ExtensionArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::ExtensionVTable - -pub fn vortex_array::arrays::ExtensionVTable::sum(&self, array: &vortex_array::arrays::ExtensionArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::binary::CompareKernel for vortex_array::arrays::ExtensionVTable pub fn vortex_array::arrays::ExtensionVTable::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &vortex_array::ArrayRef, operator: vortex_array::scalar_fn::fns::operators::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -3102,10 +3082,6 @@ impl vortex_array::compute::NaNCountKernel for vortex_array::arrays::PrimitiveVT pub fn vortex_array::arrays::PrimitiveVTable::nan_count(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult -impl vortex_array::compute::SumKernel for vortex_array::arrays::PrimitiveVTable - -pub fn vortex_array::arrays::PrimitiveVTable::sum(&self, array: &vortex_array::arrays::PrimitiveArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl 
vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::primitive::PrimitiveMaskedValidityRule pub type vortex_array::arrays::primitive::PrimitiveMaskedValidityRule::Parent = vortex_array::arrays::MaskedVTable @@ -4716,10 +4692,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::BoolVTable pub fn vortex_array::arrays::BoolVTable::min_max(&self, array: &vortex_array::arrays::BoolArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::BoolVTable - -pub fn vortex_array::arrays::BoolVTable::sum(&self, array: &vortex_array::arrays::BoolArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::bool::BoolMaskedValidityRule pub type vortex_array::arrays::bool::BoolMaskedValidityRule::Parent = vortex_array::arrays::MaskedVTable @@ -4892,10 +4864,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ChunkedVTable pub fn vortex_array::arrays::ChunkedVTable::min_max(&self, array: &vortex_array::arrays::ChunkedArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::ChunkedVTable - -pub fn vortex_array::arrays::ChunkedVTable::sum(&self, array: &vortex_array::arrays::ChunkedArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::cast::CastReduce for vortex_array::arrays::ChunkedVTable pub fn vortex_array::arrays::ChunkedVTable::cast(array: &vortex_array::arrays::ChunkedArray, dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult> @@ -5044,10 +5012,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ConstantVTabl pub fn vortex_array::arrays::ConstantVTable::min_max(&self, array: &vortex_array::arrays::ConstantArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::ConstantVTable - 
-pub fn vortex_array::arrays::ConstantVTable::sum(&self, array: &vortex_array::arrays::ConstantArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::between::BetweenReduce for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::between(array: &vortex_array::arrays::ConstantArray, lower: &vortex_array::ArrayRef, upper: &vortex_array::ArrayRef, options: &vortex_array::scalar_fn::fns::between::BetweenOptions) -> vortex_error::VortexResult> @@ -5234,10 +5198,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::DecimalVTable pub fn vortex_array::arrays::DecimalVTable::min_max(&self, array: &vortex_array::arrays::DecimalArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::DecimalVTable - -pub fn vortex_array::arrays::DecimalVTable::sum(&self, array: &vortex_array::arrays::DecimalArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::decimal::DecimalMaskedValidityRule pub type vortex_array::arrays::decimal::DecimalMaskedValidityRule::Parent = vortex_array::arrays::MaskedVTable @@ -5592,10 +5552,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ExtensionVTab pub fn vortex_array::arrays::ExtensionVTable::min_max(&self, array: &vortex_array::arrays::ExtensionArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::ExtensionVTable - -pub fn vortex_array::arrays::ExtensionVTable::sum(&self, array: &vortex_array::arrays::ExtensionArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::binary::CompareKernel for vortex_array::arrays::ExtensionVTable pub fn vortex_array::arrays::ExtensionVTable::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &vortex_array::ArrayRef, 
operator: vortex_array::scalar_fn::fns::operators::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -6690,10 +6646,6 @@ impl vortex_array::compute::NaNCountKernel for vortex_array::arrays::PrimitiveVT pub fn vortex_array::arrays::PrimitiveVTable::nan_count(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult -impl vortex_array::compute::SumKernel for vortex_array::arrays::PrimitiveVTable - -pub fn vortex_array::arrays::PrimitiveVTable::sum(&self, array: &vortex_array::arrays::PrimitiveArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::primitive::PrimitiveMaskedValidityRule pub type vortex_array::arrays::primitive::PrimitiveMaskedValidityRule::Parent = vortex_array::arrays::MaskedVTable @@ -9556,12 +9508,6 @@ impl<'a> core::clone::Clone for vortex_array::compute::InvocationArgs<'a> pub fn vortex_array::compute::InvocationArgs<'a>::clone(&self) -> vortex_array::compute::InvocationArgs<'a> -impl<'a> core::convert::TryFrom<&vortex_array::compute::InvocationArgs<'a>> for vortex_array::compute::SumArgs<'a> - -pub type vortex_array::compute::SumArgs<'a>::Error = vortex_error::VortexError - -pub fn vortex_array::compute::SumArgs<'a>::try_from(value: &vortex_array::compute::InvocationArgs<'a>) -> core::result::Result - pub struct vortex_array::compute::IsConstantKernelAdapter(pub V) impl vortex_array::compute::IsConstantKernelAdapter @@ -9696,36 +9642,6 @@ pub struct vortex_array::compute::NaNCountKernelRef(_) impl inventory::Collect for vortex_array::compute::NaNCountKernelRef -pub struct vortex_array::compute::SumArgs<'a> - -pub vortex_array::compute::SumArgs::accumulator: &'a vortex_array::scalar::Scalar - -pub vortex_array::compute::SumArgs::array: &'a dyn vortex_array::DynArray - -impl<'a> core::convert::TryFrom<&vortex_array::compute::InvocationArgs<'a>> for 
vortex_array::compute::SumArgs<'a> - -pub type vortex_array::compute::SumArgs<'a>::Error = vortex_error::VortexError - -pub fn vortex_array::compute::SumArgs<'a>::try_from(value: &vortex_array::compute::InvocationArgs<'a>) -> core::result::Result - -pub struct vortex_array::compute::SumKernelAdapter(pub V) - -impl vortex_array::compute::SumKernelAdapter - -pub const fn vortex_array::compute::SumKernelAdapter::lift(&'static self) -> vortex_array::compute::SumKernelRef - -impl core::fmt::Debug for vortex_array::compute::SumKernelAdapter - -pub fn vortex_array::compute::SumKernelAdapter::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl vortex_array::compute::Kernel for vortex_array::compute::SumKernelAdapter - -pub fn vortex_array::compute::SumKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> - -pub struct vortex_array::compute::SumKernelRef(_) - -impl inventory::Collect for vortex_array::compute::SumKernelRef - pub struct vortex_array::compute::UnaryArgs<'a, O: vortex_array::compute::Options> pub vortex_array::compute::UnaryArgs::array: &'a dyn vortex_array::DynArray @@ -9910,10 +9826,6 @@ impl vo pub fn vortex_array::compute::NaNCountKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> -impl vortex_array::compute::Kernel for vortex_array::compute::SumKernelAdapter - -pub fn vortex_array::compute::SumKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> - pub trait vortex_array::compute::MinMaxKernel: vortex_array::vtable::VTable pub fn vortex_array::compute::MinMaxKernel::min_max(&self, array: &Self::Array) -> vortex_error::VortexResult> @@ -9998,34 +9910,6 @@ impl vortex_array::compute::Options for vortex_array::scalar_fn::fns::between::B pub fn vortex_array::scalar_fn::fns::between::BetweenOptions::as_any(&self) -> &dyn core::any::Any -pub trait 
vortex_array::compute::SumKernel: vortex_array::vtable::VTable - -pub fn vortex_array::compute::SumKernel::sum(&self, array: &Self::Array, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::BoolVTable - -pub fn vortex_array::arrays::BoolVTable::sum(&self, array: &vortex_array::arrays::BoolArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::ChunkedVTable - -pub fn vortex_array::arrays::ChunkedVTable::sum(&self, array: &vortex_array::arrays::ChunkedArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::ConstantVTable - -pub fn vortex_array::arrays::ConstantVTable::sum(&self, array: &vortex_array::arrays::ConstantArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::DecimalVTable - -pub fn vortex_array::arrays::DecimalVTable::sum(&self, array: &vortex_array::arrays::DecimalArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::ExtensionVTable - -pub fn vortex_array::arrays::ExtensionVTable::sum(&self, array: &vortex_array::arrays::ExtensionArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::PrimitiveVTable - -pub fn vortex_array::arrays::PrimitiveVTable::sum(&self, array: &vortex_array::arrays::PrimitiveArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - pub fn vortex_array::compute::is_constant(array: &vortex_array::ArrayRef) -> vortex_error::VortexResult> pub fn vortex_array::compute::is_constant_opts(array: &vortex_array::ArrayRef, options: &vortex_array::compute::IsConstantOpts) -> 
vortex_error::VortexResult> @@ -10042,8 +9926,6 @@ pub fn vortex_array::compute::nan_count(array: &vortex_array::ArrayRef) -> vorte pub fn vortex_array::compute::sum(array: &vortex_array::ArrayRef) -> vortex_error::VortexResult -pub fn vortex_array::compute::sum_impl(array: &vortex_array::ArrayRef, accumulator: &vortex_array::scalar::Scalar, kernels: &[arcref::ArcRef]) -> vortex_error::VortexResult - pub fn vortex_array::compute::warm_up_vtables() pub mod vortex_array::display diff --git a/vortex-array/src/arrays/bool/compute/mod.rs b/vortex-array/src/arrays/bool/compute/mod.rs index 248ea617f2f..0799b77ad0e 100644 --- a/vortex-array/src/arrays/bool/compute/mod.rs +++ b/vortex-array/src/arrays/bool/compute/mod.rs @@ -10,7 +10,6 @@ mod mask; mod min_max; pub mod rules; mod slice; -mod sum; mod take; #[cfg(test)] diff --git a/vortex-array/src/arrays/bool/compute/sum.rs b/vortex-array/src/arrays/bool/compute/sum.rs deleted file mode 100644 index e5c036f190d..00000000000 --- a/vortex-array/src/arrays/bool/compute/sum.rs +++ /dev/null @@ -1,46 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::ops::BitAnd; - -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_mask::AllOr; - -use crate::arrays::BoolArray; -use crate::arrays::BoolVTable; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::dtype::Nullability; -use crate::register_kernel; -use crate::scalar::Scalar; - -impl SumKernel for BoolVTable { - fn sum(&self, array: &BoolArray, accumulator: &Scalar) -> VortexResult { - let true_count: Option = match array.validity_mask()?.bit_buffer() { - AllOr::All => { - // All-valid - Some(array.to_bit_buffer().true_count() as u64) - } - AllOr::None => { - // All-invalid - unreachable!("All-invalid boolean array should have been handled by entry-point") - } - AllOr::Some(validity_mask) => { - Some(array.to_bit_buffer().bitand(validity_mask).true_count() 
as u64) - } - }; - - let acc_value = accumulator - .as_primitive() - .as_::() - .vortex_expect("cannot be null"); - let result = true_count.and_then(|tc| acc_value.checked_add(tc)); - Ok(match result { - Some(v) => Scalar::primitive(v, Nullability::Nullable), - None => Scalar::null_native::(), - }) - } -} - -register_kernel!(SumKernelAdapter(BoolVTable).lift()); diff --git a/vortex-array/src/arrays/chunked/compute/mod.rs b/vortex-array/src/arrays/chunked/compute/mod.rs index 791e261745c..93b95b5ce9e 100644 --- a/vortex-array/src/arrays/chunked/compute/mod.rs +++ b/vortex-array/src/arrays/chunked/compute/mod.rs @@ -12,7 +12,6 @@ mod mask; mod min_max; pub(crate) mod rules; mod slice; -mod sum; mod take; mod zip; diff --git a/vortex-array/src/arrays/chunked/compute/sum.rs b/vortex-array/src/arrays/chunked/compute/sum.rs deleted file mode 100644 index 3b7f504c03d..00000000000 --- a/vortex-array/src/arrays/chunked/compute/sum.rs +++ /dev/null @@ -1,236 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::arrays::ChunkedArray; -use crate::arrays::ChunkedVTable; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::compute::sum_with_accumulator; -use crate::register_kernel; -use crate::scalar::Scalar; - -impl SumKernel for ChunkedVTable { - fn sum(&self, array: &ChunkedArray, accumulator: &Scalar) -> VortexResult { - array - .chunks - .iter() - .try_fold(accumulator.clone(), |result, chunk| { - sum_with_accumulator(chunk, &result) - }) - } -} - -register_kernel!(SumKernelAdapter(ChunkedVTable).lift()); - -#[cfg(test)] -mod tests { - use vortex_buffer::buffer; - - use crate::array::IntoArray; - use crate::arrays::ChunkedArray; - use crate::arrays::ConstantArray; - use crate::arrays::DecimalArray; - use crate::arrays::PrimitiveArray; - use crate::compute::sum; - use crate::dtype::DType; - use crate::dtype::DecimalDType; - use 
crate::dtype::Nullability; - use crate::dtype::i256; - use crate::scalar::DecimalValue; - use crate::scalar::Scalar; - use crate::validity::Validity; - - #[test] - fn test_sum_chunked_floats_with_nulls() { - // Create chunks with floats including nulls - let chunk1 = - PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]); - - let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]); - - let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]); - - // Create chunked array from the chunks - let dtype = chunk1.dtype().clone(); - let chunked = ChunkedArray::try_new( - vec![ - chunk1.into_array(), - chunk2.into_array(), - chunk3.into_array(), - ], - dtype, - ) - .unwrap(); - - // Compute sum - let result = sum(&chunked.into_array()).unwrap(); - - // Expected sum: 1.5 + 3.2 + 4.8 + 2.1 + 5.7 + 1.0 + 2.5 = 20.8 - assert_eq!(result.as_primitive().as_::(), Some(20.8)); - } - - #[test] - fn test_sum_chunked_floats_all_nulls_is_zero() { - // Create chunks with all nulls - let chunk1 = PrimitiveArray::from_option_iter::(vec![None, None, None]); - let chunk2 = PrimitiveArray::from_option_iter::(vec![None, None]); - - let dtype = chunk1.dtype().clone(); - let chunked = - ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); - // Compute sum - should return null for all nulls - let result = sum(&chunked.into_array()).unwrap(); - assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable)); - } - - #[test] - fn test_sum_chunked_floats_empty_chunks() { - // Test with some empty chunks mixed with non-empty - let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]); - let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullability::Nullable), 0); - let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]); - - let dtype = chunk1.dtype().clone(); - let chunked = ChunkedArray::try_new( - vec![ - chunk1.into_array(), - 
chunk2.into_array(), - chunk3.into_array(), - ], - dtype, - ) - .unwrap(); - - // Compute sum: 10.5 + 20.3 + 5.2 = 36.0 - let result = sum(&chunked.into_array()).unwrap(); - assert_eq!(result.as_primitive().as_::(), Some(36.0)); - } - - #[test] - fn test_sum_chunked_int_almost_all_null_chunks() { - let chunk1 = PrimitiveArray::from_option_iter::(vec![Some(1)]); - let chunk2 = PrimitiveArray::from_option_iter::(vec![None]); - - let dtype = chunk1.dtype().clone(); - let chunked = - ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); - - let result = sum(&chunked.into_array()).unwrap(); - assert_eq!(result.as_primitive().as_::(), Some(1)); - } - - #[test] - fn test_sum_chunked_decimals() { - // Create decimal chunks with precision=10, scale=2 - let decimal_dtype = DecimalDType::new(10, 2); - let chunk1 = DecimalArray::new( - buffer![100i32, 100i32, 100i32, 100i32, 100i32], - decimal_dtype, - Validity::AllValid, - ); - let chunk2 = DecimalArray::new( - buffer![200i32, 200i32, 200i32], - decimal_dtype, - Validity::AllValid, - ); - let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid); - - let dtype = chunk1.dtype().clone(); - let chunked = ChunkedArray::try_new( - vec![ - chunk1.into_array(), - chunk2.into_array(), - chunk3.into_array(), - ], - dtype, - ) - .unwrap(); - - // Compute sum: 5*100 + 3*200 + 2*300 = 500 + 600 + 600 = 1700 (represents 17.00) - let result = sum(&chunked.into_array()).unwrap(); - let decimal_result = result.as_decimal(); - assert_eq!( - decimal_result.decimal_value(), - Some(DecimalValue::I256(i256::from_i128(1700))) - ); - } - - #[test] - fn test_sum_chunked_decimals_with_nulls() { - let decimal_dtype = DecimalDType::new(10, 2); - - // Create chunks with some nulls - all must have same nullability - let chunk1 = DecimalArray::new( - buffer![100i32, 100i32, 100i32], - decimal_dtype, - Validity::AllValid, - ); - let chunk2 = DecimalArray::new( - buffer![0i32, 0i32], - 
decimal_dtype, - Validity::from_iter([false, false]), - ); - let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid); - - let dtype = chunk1.dtype().clone(); - let chunked = ChunkedArray::try_new( - vec![ - chunk1.into_array(), - chunk2.into_array(), - chunk3.into_array(), - ], - dtype, - ) - .unwrap(); - - // Compute sum: 3*100 + 2*200 = 300 + 400 = 700 (nulls ignored) - let result = sum(&chunked.into_array()).unwrap(); - let decimal_result = result.as_decimal(); - assert_eq!( - decimal_result.decimal_value(), - Some(DecimalValue::I256(i256::from_i128(700))) - ); - } - - #[test] - fn test_sum_chunked_decimals_large() { - // Create decimals with precision 3 (max value 999) - // Sum will be 500 + 600 = 1100, which fits in result precision 13 (3+10) - let decimal_dtype = DecimalDType::new(3, 0); - let chunk1 = ConstantArray::new( - Scalar::decimal( - DecimalValue::I16(500), - decimal_dtype, - Nullability::NonNullable, - ), - 1, - ); - let chunk2 = ConstantArray::new( - Scalar::decimal( - DecimalValue::I16(600), - decimal_dtype, - Nullability::NonNullable, - ), - 1, - ); - - let dtype = chunk1.dtype().clone(); - let chunked = - ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); - - // Compute sum: 500 + 600 = 1100 - // Result should have precision 13 (3+10), scale 0 - let result = sum(&chunked.into_array()).unwrap(); - let decimal_result = result.as_decimal(); - assert_eq!( - decimal_result.decimal_value(), - Some(DecimalValue::I256(i256::from_i128(1100))) - ); - assert_eq!( - result.dtype(), - &DType::Decimal(DecimalDType::new(13, 0), Nullability::Nullable) - ); - } -} diff --git a/vortex-array/src/arrays/constant/compute/mod.rs b/vortex-array/src/arrays/constant/compute/mod.rs index 35578829912..f671dcde3d3 100644 --- a/vortex-array/src/arrays/constant/compute/mod.rs +++ b/vortex-array/src/arrays/constant/compute/mod.rs @@ -9,7 +9,6 @@ mod min_max; mod not; pub(crate) mod rules; mod slice; -mod 
sum; mod take; #[cfg(test)] diff --git a/vortex-array/src/arrays/constant/compute/sum.rs b/vortex-array/src/arrays/constant/compute/sum.rs deleted file mode 100644 index 27738d4f39a..00000000000 --- a/vortex-array/src/arrays/constant/compute/sum.rs +++ /dev/null @@ -1,302 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use num_traits::AsPrimitive; -use num_traits::CheckedAdd; -use num_traits::CheckedMul; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; - -use crate::arrays::ConstantArray; -use crate::arrays::ConstantVTable; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::dtype::DType; -use crate::dtype::DecimalDType; -use crate::dtype::NativePType; -use crate::dtype::Nullability; -use crate::dtype::i256; -use crate::expr::stats::Stat; -use crate::match_each_native_ptype; -use crate::register_kernel; -use crate::scalar::DecimalScalar; -use crate::scalar::DecimalValue; -use crate::scalar::PrimitiveScalar; -use crate::scalar::Scalar; -use crate::scalar::ScalarValue; - -impl SumKernel for ConstantVTable { - fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult { - // Compute the expected dtype of the sum. 
- let sum_dtype = Stat::Sum - .dtype(array.dtype()) - .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?; - - let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?; - Scalar::try_new(sum_dtype, sum_value) - } -} - -fn sum_scalar( - scalar: &Scalar, - len: usize, - accumulator: &Scalar, -) -> VortexResult> { - match scalar.dtype() { - DType::Bool(_) => { - let count = match scalar.as_bool().value() { - None => unreachable!("Handled before reaching this point"), - Some(false) => 0u64, - Some(true) => len as u64, - }; - let accumulator = accumulator - .as_primitive() - .as_::() - .vortex_expect("cannot be null"); - Ok(accumulator - .checked_add(count) - .map(|v| ScalarValue::Primitive(v.into()))) - } - DType::Primitive(ptype, _) => { - #[expect(dead_code, reason = "TODO(connor): good question")] - let result = match_each_native_ptype!( - ptype, - unsigned: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }, - signed: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }, - floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) } - ); - Ok(result) - } - DType::Decimal(decimal_dtype, _) => { - sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator) - } - DType::Extension(_) => { - sum_scalar(&scalar.as_extension().to_storage_scalar(), len, accumulator) - } - dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype), - } -} - -fn sum_decimal( - decimal_scalar: DecimalScalar, - array_len: usize, - decimal_dtype: DecimalDType, - accumulator: &Scalar, -) -> VortexResult> { - let result_dtype = Stat::Sum - .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable)) - .vortex_expect("decimal supports sum"); - let result_decimal_type = result_dtype - .as_decimal_opt() - .vortex_expect("must be decimal"); - - let Some(value) = decimal_scalar.decimal_value() 
else { - // Null value: return null - return Ok(None); - }; - - // Convert array_len to DecimalValue for multiplication. - let len_value = DecimalValue::I256(i256::from_i128(array_len as i128)); - - let Some(array_sum) = value - .checked_mul(&len_value) - .filter(|d| d.fits_in_precision(*result_decimal_type)) - else { - return Ok(None); - }; - - // Add accumulator to array_sum. - let initial_decimal = accumulator.as_decimal(); - let initial_dec_value = initial_decimal - .decimal_value() - .unwrap_or(DecimalValue::I256(i256::ZERO)); - - let total = array_sum - .checked_add(&initial_dec_value) - .and_then(|result| { - result - .fits_in_precision(*result_decimal_type) - .then_some(result) - }); - match total { - Some(result_value) => Ok(Some(ScalarValue::from(result_value))), - None => Ok(None), // Overflow - } -} - -fn sum_integral( - primitive_scalar: PrimitiveScalar<'_>, - array_len: usize, - accumulator: &Scalar, -) -> VortexResult> -where - T: NativePType + CheckedMul + CheckedAdd, -{ - let v = primitive_scalar.as_::(); - let array_len = - T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?; - let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else { - return Ok(None); - }; - - let initial = accumulator - .as_primitive() - .as_::() - .vortex_expect("cannot be null"); - Ok(initial.checked_add(&array_sum)) -} - -fn sum_float( - primitive_scalar: PrimitiveScalar<'_>, - array_len: usize, - accumulator: &Scalar, -) -> VortexResult> { - let initial = accumulator - .as_primitive() - .as_::() - .vortex_expect("cannot be null"); - let v = primitive_scalar - .as_::() - .vortex_expect("cannot be null"); - let len_f64: f64 = array_len.as_(); - - Ok(Some(initial + v * len_f64)) -} - -register_kernel!(SumKernelAdapter(ConstantVTable).lift()); - -#[cfg(test)] -mod tests { - use vortex_error::VortexExpect; - - use crate::DynArray; - use crate::IntoArray; - use crate::arrays::ConstantArray; - use crate::compute::sum; - use 
crate::compute::sum_with_accumulator; - use crate::dtype::DType; - use crate::dtype::DecimalDType; - use crate::dtype::Nullability; - use crate::dtype::Nullability::Nullable; - use crate::dtype::PType; - use crate::dtype::i256; - use crate::expr::stats::Stat; - use crate::scalar::DecimalValue; - use crate::scalar::Scalar; - - #[test] - fn test_sum_unsigned() { - let array = ConstantArray::new(5u64, 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, 50u64.into()); - } - - #[test] - fn test_sum_signed() { - let array = ConstantArray::new(-5i64, 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, (-50i64).into()); - } - - #[test] - fn test_sum_nullable_value() { - let array = ConstantArray::new(Scalar::null(DType::Primitive(PType::U32, Nullable)), 10) - .into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, Scalar::primitive(0u64, Nullable)); - } - - #[test] - fn test_sum_bool_false() { - let array = ConstantArray::new(false, 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, 0u64.into()); - } - - #[test] - fn test_sum_bool_true() { - let array = ConstantArray::new(true, 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, 10u64.into()); - } - - #[test] - fn test_sum_bool_null() { - let array = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, Scalar::primitive(0u64, Nullable)); - } - - #[test] - fn test_sum_decimal() { - let decimal_dtype = DecimalDType::new(10, 2); - let array = ConstantArray::new( - Scalar::decimal( - DecimalValue::I64(100), - decimal_dtype, - Nullability::NonNullable, - ), - 5, - ) - .into_array(); - - let result = sum(&array).unwrap(); - - assert_eq!( - result.as_decimal().decimal_value(), - Some(DecimalValue::I256(i256::from_i128(500))) - ); - assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap()); - } - - #[test] - fn 
test_sum_decimal_null() { - let decimal_dtype = DecimalDType::new(10, 2); - let array = ConstantArray::new(Scalar::null(DType::Decimal(decimal_dtype, Nullable)), 10) - .into_array(); - - let result = sum(&array).unwrap(); - assert_eq!( - result, - Scalar::decimal( - DecimalValue::I256(i256::ZERO), - DecimalDType::new(20, 2), - Nullable - ) - ); - } - - #[test] - fn test_sum_decimal_large_value() { - let decimal_dtype = DecimalDType::new(10, 2); - let array = ConstantArray::new( - Scalar::decimal( - DecimalValue::I64(999_999_999), - decimal_dtype, - Nullability::NonNullable, - ), - 100, - ) - .into_array(); - - let result = sum(&array).unwrap(); - assert_eq!( - result.as_decimal().decimal_value(), - Some(DecimalValue::I256(i256::from_i128(99_999_999_900))) - ); - } - - #[test] - fn test_sum_float_non_multiply() { - let acc = -2048669276050936500000000000f64; - let array = ConstantArray::new(6.1811675e16f64, 25); - let sum = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable)) - .vortex_expect("operation should succeed in test"); - assert_eq!( - f64::try_from(&sum).vortex_expect("operation should succeed in test"), - -2048669274505644600000000000f64 - ); - } -} diff --git a/vortex-array/src/arrays/decimal/compute/mod.rs b/vortex-array/src/arrays/decimal/compute/mod.rs index 90743d880ac..9a42f7a69d5 100644 --- a/vortex-array/src/arrays/decimal/compute/mod.rs +++ b/vortex-array/src/arrays/decimal/compute/mod.rs @@ -9,7 +9,6 @@ mod is_sorted; mod mask; mod min_max; pub mod rules; -mod sum; mod take; #[cfg(test)] diff --git a/vortex-array/src/arrays/decimal/compute/sum.rs b/vortex-array/src/arrays/decimal/compute/sum.rs deleted file mode 100644 index a250d72e144..00000000000 --- a/vortex-array/src/arrays/decimal/compute/sum.rs +++ /dev/null @@ -1,416 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use itertools::Itertools; -use num_traits::AsPrimitive; -use num_traits::CheckedAdd; 
-use num_traits::NumOps; -use vortex_buffer::BitBuffer; -use vortex_buffer::Buffer; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_mask::Mask; - -use crate::arrays::DecimalArray; -use crate::arrays::DecimalVTable; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::dtype::DType; -use crate::dtype::DecimalDType; -use crate::dtype::DecimalType; -use crate::dtype::Nullability::Nullable; -use crate::expr::stats::Stat; -use crate::match_each_decimal_value_type; -use crate::register_kernel; -use crate::scalar::DecimalValue; -use crate::scalar::Scalar; - -impl SumKernel for DecimalVTable { - fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult { - let return_dtype = Stat::Sum - .dtype(array.dtype()) - .vortex_expect("sum for decimals exists"); - let return_decimal_dtype = *return_dtype - .as_decimal_opt() - .vortex_expect("must be decimal"); - - // Extract the initial value as a `DecimalValue`. - let initial_decimal = accumulator - .as_decimal() - .decimal_value() - .vortex_expect("cannot be null"); - - let mask = array.validity_mask()?; - let validity = match &mask { - Mask::AllTrue(_) => None, - Mask::Values(mask_values) => Some(mask_values.bit_buffer()), - Mask::AllFalse(_) => { - vortex_bail!("invalid state, all-null array should be checked by top-level sum fn") - } - }; - - let values_type = DecimalType::smallest_decimal_value_type(&return_decimal_dtype); - match_each_decimal_value_type!(array.values_type(), |I| { - match_each_decimal_value_type!(values_type, |O| { - let initial_val: O = initial_decimal - .cast() - .vortex_expect("cannot fail to cast initial value"); - - Ok(sum_to_scalar( - array.buffer::(), - validity, - initial_val, - return_decimal_dtype, - &return_dtype, - )) - }) - }) - } -} - -/// Compute the checked sum and convert the result to a [`Scalar`]. 
-/// -/// Returns a null scalar if the sum overflows the underlying integer type or if the result -/// exceeds the declared decimal precision. -fn sum_to_scalar( - values: Buffer, - validity: Option<&BitBuffer>, - initial: O, - return_decimal_dtype: DecimalDType, - return_dtype: &DType, -) -> Scalar -where - T: AsPrimitive, - O: CheckedAdd + NumOps + Into + Copy + 'static, - bool: AsPrimitive, -{ - let raw_sum = match validity { - Some(v) => sum_decimal_with_validity(values, v, initial), - None => sum_decimal(values, initial), - }; - - raw_sum - .map(Into::::into) - // We have to make sure that the decimal value fits the precision of the decimal dtype. - .filter(|v| v.fits_in_precision(return_decimal_dtype)) - .map(|v| Scalar::decimal(v, return_decimal_dtype, Nullable)) - // If an overflow occurs during summation, or final value does not fit, then return a null. - .unwrap_or_else(|| Scalar::null(return_dtype.clone())) -} - -fn sum_decimal, I: Copy + CheckedAdd + 'static>( - values: Buffer, - initial: I, -) -> Option { - let mut sum = initial; - for v in values.iter() { - let v: I = v.as_(); - sum = CheckedAdd::checked_add(&sum, &v)?; - } - Some(sum) -} - -fn sum_decimal_with_validity(values: Buffer, validity: &BitBuffer, initial: I) -> Option -where - T: AsPrimitive, - I: NumOps + CheckedAdd + Copy + 'static, - bool: AsPrimitive, -{ - let mut sum = initial; - for (v, valid) in values.iter().zip_eq(validity) { - let v: I = v.as_() * valid.as_(); - - sum = CheckedAdd::checked_add(&sum, &v)?; - } - Some(sum) -} - -register_kernel!(SumKernelAdapter(DecimalVTable).lift()); - -#[cfg(test)] -mod tests { - use vortex_buffer::buffer; - use vortex_error::VortexExpect; - - use crate::IntoArray; - use crate::arrays::DecimalArray; - use crate::compute::sum; - use crate::dtype::DType; - use crate::dtype::DecimalDType; - use crate::dtype::Nullability; - use crate::dtype::i256; - use crate::scalar::DecimalValue; - use crate::scalar::Scalar; - use crate::scalar::ScalarValue; - use 
crate::validity::Validity; - - #[test] - fn test_sum_basic() { - let decimal = DecimalArray::new( - buffer![100i32, 200i32, 300i32], - DecimalDType::new(4, 2), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(600i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_with_nulls() { - let decimal = DecimalArray::new( - buffer![100i32, 200i32, 300i32, 400i32], - DecimalDType::new(4, 2), - Validity::from_iter([true, false, true, true]), - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable), - Some(ScalarValue::from(DecimalValue::from(800i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_negative_values() { - let decimal = DecimalArray::new( - buffer![100i32, -200i32, 300i32, -50i32], - DecimalDType::new(4, 2), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(150i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_near_i32_max() { - // Test values close to i32::MAX to ensure proper handling - let near_max = i32::MAX - 1000; - let decimal = DecimalArray::new( - buffer![near_max, 500i32, 400i32], - DecimalDType::new(10, 2), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - // Should use i64 for accumulation since precision increases - let expected_sum = near_max as i64 + 500 + 400; - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - 
assert_eq!(result, expected); - } - - #[test] - fn test_sum_large_i64_values() { - // Test with large i64 values that require i128 accumulation - let large_val = i64::MAX / 4; - let decimal = DecimalArray::new( - buffer![large_val, large_val, large_val, large_val + 1], - DecimalDType::new(19, 0), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected_sum = (large_val as i128) * 4 + 1; - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_overflow_detection() { - use crate::dtype::i256; - - // Create values that will overflow when summed - // Use maximum i128 values that will overflow when added - let max_val = i128::MAX / 2; - let decimal = DecimalArray::new( - buffer![max_val, max_val, max_val], - DecimalDType::new(38, 0), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - // Should use i256 for accumulation - let expected_sum = - i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val); - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_mixed_signs_near_overflow() { - // Test that mixed signs work correctly near overflow boundaries - let large_pos = i64::MAX / 2; - let large_neg = -(i64::MAX / 2); - let decimal = DecimalArray::new( - buffer![large_pos, large_neg, large_pos, 1000i64], - DecimalDType::new(19, 3), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected_sum = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000; - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(29, 3), Nullability::NonNullable), 
- Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_preserves_scale() { - let decimal = DecimalArray::new( - buffer![12345i32, 67890i32, 11111i32], - DecimalDType::new(6, 4), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - // Scale should be preserved, precision increased by 10 - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(91346i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_single_value() { - let decimal = - DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(42i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_with_all_nulls_except_one() { - let decimal = DecimalArray::new( - buffer![100i32, 200i32, 300i32, 400i32], - DecimalDType::new(4, 2), - Validity::from_iter([false, false, true, false]), - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable), - Some(ScalarValue::from(DecimalValue::from(300i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_i128_to_i256_boundary() { - // Test the boundary between i128 and i256 accumulation - let large_i128 = i128::MAX / 10; - let decimal = DecimalArray::new( - buffer![ - large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, - large_i128, large_i128 - ], - DecimalDType::new(38, 0), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - // Should use i256 for accumulation since 9 * (i128::MAX / 10) 
fits in i128 but we increase precision - let expected_sum = i256::from_i128(large_i128).wrapping_pow(1) * i256::from_i128(9); - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_precision_overflow_without_i256_overflow() { - // Construct values that individually fit in precision 76 but whose sum exceeds it, - // while still fitting in `i256`. This ensures we return null for precision overflow - // and not just for arithmetic overflow. - let ten_to_38 = i256::from_i128(10i128.pow(38)); - let ten_to_75 = ten_to_38 * i256::from_i128(10i128.pow(37)); - // 6 * 10^75 is a 76-digit number, which fits in precision 76. - let val = ten_to_75 * i256::from_i128(6); - - let decimal_dtype = DecimalDType::new(76, 0); - let decimal = DecimalArray::new(buffer![val, val], decimal_dtype, Validity::AllValid); - - // Sum = 12 * 10^75 = 1.2 * 10^76, which exceeds precision 76 but fits in `i256`. 
- let result = sum(&decimal.into_array()).unwrap(); - assert_eq!( - result, - Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)) - ); - } - - #[test] - fn test_i256_overflow() { - let decimal_dtype = DecimalDType::new(76, 0); - let decimal = DecimalArray::new( - buffer![i256::MAX, i256::MAX, i256::MAX], - decimal_dtype, - Validity::AllValid, - ); - - assert_eq!( - sum(&decimal.into_array()).vortex_expect("operation should succeed in test"), - Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)) - ); - } -} diff --git a/vortex-array/src/arrays/extension/compute/mod.rs b/vortex-array/src/arrays/extension/compute/mod.rs index b5761e21187..be90418a962 100644 --- a/vortex-array/src/arrays/extension/compute/mod.rs +++ b/vortex-array/src/arrays/extension/compute/mod.rs @@ -10,7 +10,6 @@ mod mask; mod min_max; pub(crate) mod rules; mod slice; -mod sum; mod take; #[cfg(test)] diff --git a/vortex-array/src/arrays/extension/compute/sum.rs b/vortex-array/src/arrays/extension/compute/sum.rs deleted file mode 100644 index 89798e510da..00000000000 --- a/vortex-array/src/arrays/extension/compute/sum.rs +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::arrays::ExtensionArray; -use crate::arrays::ExtensionVTable; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::compute::{self}; -use crate::register_kernel; -use crate::scalar::Scalar; - -impl SumKernel for ExtensionVTable { - fn sum(&self, array: &ExtensionArray, accumulator: &Scalar) -> VortexResult { - compute::sum_with_accumulator(array.storage_array(), accumulator) - } -} - -register_kernel!(SumKernelAdapter(ExtensionVTable).lift()); diff --git a/vortex-array/src/arrays/primitive/compute/mod.rs b/vortex-array/src/arrays/primitive/compute/mod.rs index f0330ef55aa..9800129e5c4 100644 --- a/vortex-array/src/arrays/primitive/compute/mod.rs 
+++ b/vortex-array/src/arrays/primitive/compute/mod.rs @@ -11,7 +11,6 @@ mod min_max; mod nan_count; pub(crate) mod rules; mod slice; -mod sum; mod take; pub use is_constant::*; diff --git a/vortex-array/src/arrays/primitive/compute/sum.rs b/vortex-array/src/arrays/primitive/compute/sum.rs deleted file mode 100644 index 28ecff0a62b..00000000000 --- a/vortex-array/src/arrays/primitive/compute/sum.rs +++ /dev/null @@ -1,140 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use itertools::Itertools; -use num_traits::CheckedAdd; -use num_traits::Float; -use num_traits::ToPrimitive; -use vortex_buffer::BitBuffer; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_mask::AllOr; - -use crate::arrays::PrimitiveArray; -use crate::arrays::PrimitiveVTable; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::dtype::NativePType; -use crate::dtype::Nullability; -use crate::match_each_native_ptype; -use crate::register_kernel; -use crate::scalar::Scalar; - -impl SumKernel for PrimitiveVTable { - fn sum(&self, array: &PrimitiveArray, accumulator: &Scalar) -> VortexResult { - let array_sum_scalar = match array.validity_mask()?.bit_buffer() { - AllOr::All => { - // All-valid - match_each_native_ptype!( - array.ptype(), - unsigned: |T| { - Scalar::from(sum_integer::<_, u64>( - array.as_slice::(), - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - )) - }, - signed: |T| { - Scalar::from(sum_integer::<_, i64>( - array.as_slice::(), - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - )) - }, - floating: |T| { - Scalar::primitive( - sum_float( - array.as_slice::(), - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - ), - Nullability::Nullable, - ) - } - ) - } - AllOr::None => { - // All-invalid, return accumulator - return Ok(accumulator.clone()); - } - AllOr::Some(validity_mask) => { - // Some-valid - 
match_each_native_ptype!( - array.ptype(), - unsigned: |T| { - Scalar::from(sum_integer_with_validity::<_, u64>( - array.as_slice::(), - validity_mask, - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - )) - }, - signed: |T| { - Scalar::from(sum_integer_with_validity::<_, i64>( - array.as_slice::(), - validity_mask, - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - )) - }, - floating: |T| { - Scalar::primitive( - sum_float_with_validity( - array.as_slice::(), - validity_mask, - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - ), - Nullability::Nullable, - ) - } - ) - } - }; - - Ok(array_sum_scalar) - } -} - -register_kernel!(SumKernelAdapter(PrimitiveVTable).lift()); - -fn sum_integer( - values: &[T], - accumulator: R, -) -> Option { - let mut sum = accumulator; - for &x in values { - sum = sum.checked_add(&R::from(x)?)?; - } - Some(sum) -} - -fn sum_integer_with_validity( - values: &[T], - validity: &BitBuffer, - accumulator: R, -) -> Option { - let mut sum: R = accumulator; - for (&x, valid) in values.iter().zip_eq(validity.iter()) { - if valid { - sum = sum.checked_add(&R::from(x)?)?; - } - } - Some(sum) -} - -fn sum_float(values: &[T], accumulator: f64) -> f64 { - let mut sum = accumulator; - for &x in values { - sum += x.to_f64().vortex_expect("Failed to cast value to f64"); - } - sum -} - -fn sum_float_with_validity( - array: &[T], - validity: &BitBuffer, - accumulator: f64, -) -> f64 { - let mut sum = accumulator; - for (&x, valid) in array.iter().zip_eq(validity.iter()) { - if valid { - sum += x.to_f64().vortex_expect("Failed to cast value to f64"); - } - } - sum -} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index a703502a29d..053ffa8868a 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -61,7 +61,6 @@ pub fn warm_up_vtables() { is_sorted::warm_up_vtable(); min_max::warm_up_vtable(); nan_count::warm_up_vtable(); - 
sum::warm_up_vtable(); } impl ComputeFn { diff --git a/vortex-array/src/compute/sum.rs b/vortex-array/src/compute/sum.rs index b74b336d4de..3add5196fa6 100644 --- a/vortex-array/src/compute/sum.rs +++ b/vortex-array/src/compute/sum.rs @@ -1,60 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::LazyLock; - -use arcref::ArcRef; -use num_traits::CheckedAdd; -use num_traits::CheckedSub; -use vortex_error::VortexError; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_ensure; use vortex_error::vortex_err; -use vortex_error::vortex_panic; +use vortex_session::VortexSession; use crate::ArrayRef; -use crate::DynArray; -use crate::IntoArray as _; -use crate::compute::ComputeFn; -use crate::compute::ComputeFnVTable; -use crate::compute::InvocationArgs; -use crate::compute::Kernel; -use crate::compute::Output; -use crate::dtype::DType; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::fns::sum::Sum; use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; -use crate::scalar::NumericOperator; use crate::scalar::Scalar; -use crate::vtable::VTable; - -static SUM_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub(crate) fn warm_up_vtable() -> usize { - SUM_FN.kernels().len() -} - -/// Sum an array with an initial value. -/// -/// If the sum overflows, a null scalar will be returned. -/// If the sum is not supported for the array's dtype, an error will be raised. -/// If the array is all-invalid, the sum will be the accumulator. -/// The accumulator must have a dtype compatible with the sum result dtype. 
-pub(crate) fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult { - SUM_FN - .invoke(&InvocationArgs { - inputs: &[array.into(), accumulator.into()], - options: &(), - })? - .unwrap_scalar() -} /// Sum an array, starting from zero. /// @@ -62,237 +21,130 @@ pub(crate) fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> Vo /// If the sum is not supported for the array's dtype, an error will be raised. /// If the array is all-invalid, the sum will be zero. pub fn sum(array: &ArrayRef) -> VortexResult { - let sum_dtype = Stat::Sum + // Validate that sum is supported for this dtype. + Stat::Sum .dtype(array.dtype()) .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?; - let zero = Scalar::zero_value(&sum_dtype); - sum_with_accumulator(array, &zero) -} - -/// For unary compute functions, it's useful to just have this short-cut. -pub struct SumArgs<'a> { - pub array: &'a dyn DynArray, - pub accumulator: &'a Scalar, -} - -impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> { - type Error = VortexError; - - fn try_from(value: &InvocationArgs<'a>) -> Result { - if value.inputs.len() != 2 { - vortex_bail!("Expected 2 inputs, found {}", value.inputs.len()); - } - let array = value.inputs[0] - .array() - .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?; - let accumulator = value.inputs[1] - .scalar() - .ok_or_else(|| vortex_err!("Expected input 1 to be a scalar"))?; - Ok(SumArgs { array, accumulator }) - } -} - -struct Sum; - -impl ComputeFnVTable for Sum { - fn invoke( - &self, - args: &InvocationArgs, - kernels: &[ArcRef], - ) -> VortexResult { - let SumArgs { array, accumulator } = args.try_into()?; - let array = array.to_array(); - - // Compute the expected dtype of the sum. 
- let sum_dtype = self.return_dtype(args)?; - - vortex_ensure!( - &sum_dtype == accumulator.dtype(), - "sum_dtype {sum_dtype} must match accumulator dtype {}", - accumulator.dtype() - ); - - // Short-circuit using array statistics. - if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) { - // For floats only use stats if accumulator is zero. otherwise we might have numerical - // stability issues. - match &sum_dtype { - DType::Primitive(p, _) => { - if p.is_float() && accumulator.is_zero() == Some(true) { - return Ok(sum_scalar.into()); - } else if p.is_int() { - let sum_from_stat = accumulator - .as_primitive() - .checked_add(&sum_scalar.as_primitive()) - .map(Scalar::from); - return Ok(sum_from_stat - .unwrap_or_else(|| Scalar::null(sum_dtype)) - .into()); - } - } - DType::Decimal(..) => { - let sum_from_stat = accumulator - .as_decimal() - .checked_binary_numeric(&sum_scalar.as_decimal(), NumericOperator::Add) - .map(Scalar::from); - return Ok(sum_from_stat - .unwrap_or_else(|| Scalar::null(sum_dtype)) - .into()); - } - _ => unreachable!("Sum will always be a decimal or a primitive dtype"), - } - } - - let sum_scalar = sum_impl(&array, accumulator, kernels)?; - - // Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator. - match sum_dtype { - DType::Primitive(p, _) => { - if p.is_float() - && accumulator.is_zero() == Some(true) - && let Some(sum_value) = sum_scalar.value().cloned() - { - array - .statistics() - .set(Stat::Sum, Precision::Exact(sum_value)); - } else if p.is_int() - && let Some(less_accumulator) = sum_scalar - .as_primitive() - .checked_sub(&accumulator.as_primitive()) - && let Some(val) = Scalar::from(less_accumulator).into_value() - { - array.statistics().set(Stat::Sum, Precision::Exact(val)); - } - } - DType::Decimal(..) 
=> { - if let Some(less_accumulator) = sum_scalar - .as_decimal() - .checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub) - && let Some(val) = Scalar::from(less_accumulator).into_value() - { - array.statistics().set(Stat::Sum, Precision::Exact(val)); - } - } - _ => unreachable!("Sum will always be a decimal or a primitive dtype"), - } - - Ok(sum_scalar.into()) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - let SumArgs { array, .. } = args.try_into()?; - Stat::Sum - .dtype(array.dtype()) - .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype())) - } - - fn return_len(&self, _args: &InvocationArgs) -> VortexResult { - // The sum function always returns a single scalar value. - Ok(1) - } - - fn is_elementwise(&self) -> bool { - false - } -} -pub struct SumKernelRef(ArcRef); -inventory::collect!(SumKernelRef); - -pub trait SumKernel: VTable { - /// # Preconditions - /// - /// * The array's DType is summable - /// * The array is not all-null - /// * The accumulator must have a dtype compatible with the sum result dtype - fn sum(&self, array: &Self::Array, accumulator: &Scalar) -> VortexResult; -} - -#[derive(Debug)] -pub struct SumKernelAdapter(pub V); - -impl SumKernelAdapter { - pub const fn lift(&'static self) -> SumKernelRef { - SumKernelRef(ArcRef::new_ref(self)) + // Short-circuit using cached array statistics. + if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) { + return Ok(sum_scalar); } -} - -impl Kernel for SumKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let SumArgs { array, accumulator } = args.try_into()?; - let Some(array) = array.as_opt::() else { - return Ok(None); - }; - Ok(Some(V::sum(&self.0, array, accumulator)?.into())) - } -} -/// Sum an array. -/// -/// If the sum overflows, a null scalar will be returned. -/// If the sum is not supported for the array's dtype, an error will be raised. 
-/// If the array is all-invalid, the sum will be the accumulator. -pub fn sum_impl( - array: &ArrayRef, - accumulator: &Scalar, - kernels: &[ArcRef], -) -> VortexResult { - if array.is_empty() || array.all_invalid()? || accumulator.is_null() { - return Ok(accumulator.clone()); - } - - // Try to find a sum kernel - let args = InvocationArgs { - inputs: &[array.into(), accumulator.into()], - options: &(), - }; - for kernel in kernels { - if let Some(output) = kernel.invoke(&args)? { - return output.unwrap_scalar(); - } + // Compute using Accumulator. + let mut acc = Accumulator::try_new( + Sum, + EmptyOptions, + array.dtype().clone(), + VortexSession::empty(), + )?; + acc.accumulate(array)?; + let result = acc.finish()?; + + // Cache the computed sum as a statistic (only if non-null, i.e. no overflow). + if let Some(val) = result.value().cloned() { + array.statistics().set(Stat::Sum, Precision::Exact(val)); } - // Otherwise, canonicalize and try again. - tracing::debug!("No sum implementation found for {}", array.encoding_id()); - if array.is_canonical() { - // Panic to avoid recursion, but it should never be hit. 
- vortex_panic!( - "No sum implementation found for canonical array: {}", - array.encoding_id() - ); - } - let canonical = array.to_canonical()?.into_array(); - sum_with_accumulator(&canonical, accumulator) + Ok(result) } #[cfg(test)] -mod test { +mod tests { + use num_traits::CheckedAdd; use vortex_buffer::buffer; use vortex_error::VortexExpect; + use vortex_error::VortexResult; + use crate::ArrayRef; + use crate::DynArray; use crate::IntoArray as _; use crate::arrays::BoolArray; use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::DecimalArray; use crate::arrays::PrimitiveArray; use crate::compute::sum; - use crate::compute::sum_with_accumulator; use crate::dtype::DType; + use crate::dtype::DecimalDType; use crate::dtype::Nullability; + use crate::dtype::Nullability::Nullable; use crate::dtype::PType; + use crate::dtype::i256; + use crate::expr::stats::Precision; + use crate::expr::stats::Stat; + use crate::expr::stats::StatsProvider; + use crate::scalar::DecimalValue; + use crate::scalar::NumericOperator; use crate::scalar::Scalar; + use crate::scalar::ScalarValue; + use crate::validity::Validity; + + /// Sum an array with an initial value (test-only helper). + fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult { + if accumulator.is_null() { + return Ok(accumulator.clone()); + } + if accumulator.is_zero() == Some(true) { + return sum(array); + } + + let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| { + vortex_error::vortex_err!("Sum not supported for dtype: {}", array.dtype()) + })?; + + // For non-float types, try statistics short-circuit with accumulator. + if !matches!(&sum_dtype, DType::Primitive(p, _) if p.is_float()) + && let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) + { + return add_scalars(&sum_dtype, &sum_scalar, accumulator); + } + + // Compute array sum from zero (also caches stats). 
+ let array_sum = sum(array)?; + + // Combine with the accumulator. + add_scalars(&sum_dtype, &array_sum, accumulator) + } + + /// Add two sum scalars with overflow checking. + fn add_scalars(sum_dtype: &DType, lhs: &Scalar, rhs: &Scalar) -> VortexResult { + if lhs.is_null() || rhs.is_null() { + return Ok(Scalar::null(sum_dtype.as_nullable())); + } + + Ok(match sum_dtype { + DType::Primitive(ptype, _) if ptype.is_float() => { + let lhs_val = f64::try_from(lhs)?; + let rhs_val = f64::try_from(rhs)?; + Scalar::primitive(lhs_val + rhs_val, Nullable) + } + DType::Primitive(..) => lhs + .as_primitive() + .checked_add(&rhs.as_primitive()) + .map(Scalar::from) + .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())), + DType::Decimal(..) => lhs + .as_decimal() + .checked_binary_numeric(&rhs.as_decimal(), NumericOperator::Add) + .map(Scalar::from) + .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())), + _ => unreachable!("Sum will always be a decimal or a primitive dtype"), + }) + } #[test] fn sum_all_invalid() { let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); let result = sum(&array).unwrap(); - assert_eq!(result, Scalar::primitive(0i64, Nullability::Nullable)); + assert_eq!(result, Scalar::primitive(0i64, Nullable)); } #[test] fn sum_all_invalid_float() { let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); let result = sum(&array).unwrap(); - assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable)); + assert_eq!(result, Scalar::primitive(0f64, Nullable)); } #[test] @@ -328,12 +180,484 @@ mod test { .vortex_expect("operation should succeed in test"); let array = array.into_array(); // compute sum with accumulator to populate stats - sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullability::Nullable)).unwrap(); + sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable)).unwrap(); let sum_without_acc = sum(&array).unwrap(); + assert_eq!(sum_without_acc, Scalar::primitive(9i64, 
Nullable)); + } + + // -- Constant array tests (migrated from constant/compute/sum.rs) -- + + #[test] + fn sum_constant_unsigned() { + let array = ConstantArray::new(5u64, 10).into_array(); + let result = sum(&array).unwrap(); + assert_eq!(result, 50u64.into()); + } + + #[test] + fn sum_constant_signed() { + let array = ConstantArray::new(-5i64, 10).into_array(); + let result = sum(&array).unwrap(); + assert_eq!(result, (-50i64).into()); + } + + #[test] + fn sum_constant_nullable_value() { + let array = ConstantArray::new(Scalar::null(DType::Primitive(PType::U32, Nullable)), 10) + .into_array(); + let result = sum(&array).unwrap(); + assert_eq!(result, Scalar::primitive(0u64, Nullable)); + } + + #[test] + fn sum_constant_bool_false() { + let array = ConstantArray::new(false, 10).into_array(); + let result = sum(&array).unwrap(); + assert_eq!(result, 0u64.into()); + } + + #[test] + fn sum_constant_bool_true() { + let array = ConstantArray::new(true, 10).into_array(); + let result = sum(&array).unwrap(); + assert_eq!(result, 10u64.into()); + } + + #[test] + fn sum_constant_bool_null() { + let array = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array(); + let result = sum(&array).unwrap(); + assert_eq!(result, Scalar::primitive(0u64, Nullable)); + } + + #[test] + fn sum_constant_decimal() { + let decimal_dtype = DecimalDType::new(10, 2); + let array = ConstantArray::new( + Scalar::decimal( + DecimalValue::I64(100), + decimal_dtype, + Nullability::NonNullable, + ), + 5, + ) + .into_array(); + + let result = sum(&array).unwrap(); + + assert_eq!( + result.as_decimal().decimal_value(), + Some(DecimalValue::I256(i256::from_i128(500))) + ); + assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap()); + } + + #[test] + fn sum_constant_decimal_null() { + let decimal_dtype = DecimalDType::new(10, 2); + let array = ConstantArray::new(Scalar::null(DType::Decimal(decimal_dtype, Nullable)), 10) + .into_array(); + + let result = sum(&array).unwrap(); 
+ assert_eq!( + result, + Scalar::decimal( + DecimalValue::I256(i256::ZERO), + DecimalDType::new(20, 2), + Nullable + ) + ); + } + + #[test] + fn sum_constant_decimal_large_value() { + let decimal_dtype = DecimalDType::new(10, 2); + let array = ConstantArray::new( + Scalar::decimal( + DecimalValue::I64(999_999_999), + decimal_dtype, + Nullability::NonNullable, + ), + 100, + ) + .into_array(); + + let result = sum(&array).unwrap(); + assert_eq!( + result.as_decimal().decimal_value(), + Some(DecimalValue::I256(i256::from_i128(99_999_999_900))) + ); + } + + #[test] + fn sum_constant_float_non_multiply() { + let acc = -2048669276050936500000000000f64; + let array = ConstantArray::new(6.1811675e16f64, 25); + let sum = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable)) + .vortex_expect("operation should succeed in test"); + assert_eq!( + f64::try_from(&sum).vortex_expect("operation should succeed in test"), + -2048669274505644600000000000f64 + ); + } + + // -- Decimal array tests (migrated from decimal/compute/sum.rs) -- + + #[test] + fn sum_decimal_basic() { + let decimal = DecimalArray::new( + buffer![100i32, 200i32, 300i32], + DecimalDType::new(4, 2), + Validity::AllValid, + ); + + let result = sum(&decimal.into_array()).unwrap(); + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(600i32))), + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn sum_decimal_with_nulls() { + let decimal = DecimalArray::new( + buffer![100i32, 200i32, 300i32, 400i32], + DecimalDType::new(4, 2), + Validity::from_iter([true, false, true, true]), + ); + + let result = sum(&decimal.into_array()).unwrap(); + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(14, 2), Nullable), + Some(ScalarValue::from(DecimalValue::from(800i32))), + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn sum_decimal_negative_values() { + 
let decimal = DecimalArray::new( + buffer![100i32, -200i32, 300i32, -50i32], + DecimalDType::new(4, 2), + Validity::AllValid, + ); + + let result = sum(&decimal.into_array()).unwrap(); + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(150i32))), + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn sum_decimal_near_i32_max() { + let near_max = i32::MAX - 1000; + let decimal = DecimalArray::new( + buffer![near_max, 500i32, 400i32], + DecimalDType::new(10, 2), + Validity::AllValid, + ); + + let result = sum(&decimal.into_array()).unwrap(); + + let expected_sum = near_max as i64 + 500 + 400; + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn sum_decimal_large_i64_values() { + let large_val = i64::MAX / 4; + let decimal = DecimalArray::new( + buffer![large_val, large_val, large_val, large_val + 1], + DecimalDType::new(19, 0), + Validity::AllValid, + ); + + let result = sum(&decimal.into_array()).unwrap(); + + let expected_sum = (large_val as i128) * 4 + 1; + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn sum_decimal_preserves_scale() { + let decimal = DecimalArray::new( + buffer![12345i32, 67890i32, 11111i32], + DecimalDType::new(6, 4), + Validity::AllValid, + ); + + let result = sum(&decimal.into_array()).unwrap(); + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(91346i32))), + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn sum_decimal_single_value() { + let 
decimal = + DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid); + + let result = sum(&decimal.into_array()).unwrap(); + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(42i32))), + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn sum_decimal_all_nulls_except_one() { + let decimal = DecimalArray::new( + buffer![100i32, 200i32, 300i32, 400i32], + DecimalDType::new(4, 2), + Validity::from_iter([false, false, true, false]), + ); + + let result = sum(&decimal.into_array()).unwrap(); + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(14, 2), Nullable), + Some(ScalarValue::from(DecimalValue::from(300i32))), + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn sum_decimal_overflow_detection() { + let max_val = i128::MAX / 2; + let decimal = DecimalArray::new( + buffer![max_val, max_val, max_val], + DecimalDType::new(38, 0), + Validity::AllValid, + ); + + let result = sum(&decimal.into_array()).unwrap(); + + let expected_sum = + i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val); + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn sum_decimal_i256_overflow() { + let decimal_dtype = DecimalDType::new(76, 0); + let decimal = DecimalArray::new( + buffer![i256::MAX, i256::MAX, i256::MAX], + decimal_dtype, + Validity::AllValid, + ); + + assert_eq!( + sum(&decimal.into_array()).vortex_expect("operation should succeed in test"), + Scalar::null(DType::Decimal(decimal_dtype, Nullable)) + ); + } + + // -- Chunked array tests (migrated from chunked/compute/sum.rs) -- + + #[test] + fn sum_chunked_floats_with_nulls() { + let chunk1 = + PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, 
Some(3.2), Some(4.8)]); + let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]); + let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + ) + .unwrap(); + + let result = sum(&chunked.into_array()).unwrap(); + assert_eq!(result.as_primitive().as_::(), Some(20.8)); + } + + #[test] + fn sum_chunked_floats_all_nulls_is_zero() { + let chunk1 = PrimitiveArray::from_option_iter::(vec![None, None, None]); + let chunk2 = PrimitiveArray::from_option_iter::(vec![None, None]); + let dtype = chunk1.dtype().clone(); + let chunked = + ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); + let result = sum(&chunked.into_array()).unwrap(); + assert_eq!(result, Scalar::primitive(0f64, Nullable)); + } + + #[test] + fn sum_chunked_floats_empty_chunks() { + let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]); + let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullable), 0); + let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + ) + .unwrap(); + + let result = sum(&chunked.into_array()).unwrap(); + assert_eq!(result.as_primitive().as_::(), Some(36.0)); + } + + #[test] + fn sum_chunked_int_almost_all_null() { + let chunk1 = PrimitiveArray::from_option_iter::(vec![Some(1)]); + let chunk2 = PrimitiveArray::from_option_iter::(vec![None]); + let dtype = chunk1.dtype().clone(); + let chunked = + ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); + + let result = sum(&chunked.into_array()).unwrap(); + assert_eq!(result.as_primitive().as_::(), Some(1)); + } + + #[test] + fn 
sum_chunked_decimals() { + let decimal_dtype = DecimalDType::new(10, 2); + let chunk1 = DecimalArray::new( + buffer![100i32, 100i32, 100i32, 100i32, 100i32], + decimal_dtype, + Validity::AllValid, + ); + let chunk2 = DecimalArray::new( + buffer![200i32, 200i32, 200i32], + decimal_dtype, + Validity::AllValid, + ); + let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + ) + .unwrap(); + + let result = sum(&chunked.into_array()).unwrap(); + let decimal_result = result.as_decimal(); + assert_eq!( + decimal_result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(1700))) + ); + } + + #[test] + fn sum_chunked_decimals_with_nulls() { + let decimal_dtype = DecimalDType::new(10, 2); + let chunk1 = DecimalArray::new( + buffer![100i32, 100i32, 100i32], + decimal_dtype, + Validity::AllValid, + ); + let chunk2 = DecimalArray::new( + buffer![0i32, 0i32], + decimal_dtype, + Validity::from_iter([false, false]), + ); + let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + ) + .unwrap(); + + let result = sum(&chunked.into_array()).unwrap(); + let decimal_result = result.as_decimal(); + assert_eq!( + decimal_result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(700))) + ); + } + + #[test] + fn sum_chunked_decimals_large() { + let decimal_dtype = DecimalDType::new(3, 0); + let chunk1 = ConstantArray::new( + Scalar::decimal( + DecimalValue::I16(500), + decimal_dtype, + Nullability::NonNullable, + ), + 1, + ); + let chunk2 = ConstantArray::new( + Scalar::decimal( + DecimalValue::I16(600), + decimal_dtype, + Nullability::NonNullable, + ), + 1, + ); + let 
dtype = chunk1.dtype().clone(); + let chunked = + ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); + + let result = sum(&chunked.into_array()).unwrap(); + let decimal_result = result.as_decimal(); + assert_eq!( + decimal_result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(1100))) + ); assert_eq!( - sum_without_acc, - Scalar::primitive(9i64, Nullability::Nullable) + result.dtype(), + &DType::Decimal(DecimalDType::new(13, 0), Nullable) ); } } From 2dac50c5bb443eedee85a8fb75eb6578b33ec606 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 12:37:03 -0400 Subject: [PATCH 2/4] Clean up to point to aggregate function Signed-off-by: Nicholas Gates --- fuzz/src/array/mod.rs | 2 +- fuzz/src/array/sum.rs | 2 +- vortex-array/src/aggregate_fn/accumulator.rs | 17 +++++- .../src/aggregate_fn/accumulator_grouped.rs | 17 +++++- vortex-array/src/aggregate_fn/erased.rs | 8 ++- vortex-array/src/aggregate_fn/fns/sum/bool.rs | 4 +- vortex-array/src/aggregate_fn/fns/sum/mod.rs | 52 +++++++++++++++---- vortex-array/src/aggregate_fn/typed.rs | 8 +-- vortex-array/src/aggregate_fn/vtable.rs | 8 ++- vortex-array/src/expr/stats/mod.rs | 39 ++------------ 10 files changed, 98 insertions(+), 59 deletions(-) diff --git a/fuzz/src/array/mod.rs b/fuzz/src/array/mod.rs index 87c219bb4f5..ba6e519d264 100644 --- a/fuzz/src/array/mod.rs +++ b/fuzz/src/array/mod.rs @@ -42,13 +42,13 @@ use tracing::debug; use vortex_array::ArrayRef; use vortex_array::DynArray; use vortex_array::IntoArray; +use vortex_array::aggregate_fn::fns::sum::sum; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::arbitrary::ArbitraryArray; use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::MinMaxResult; use vortex_array::compute::min_max; -use vortex_array::compute::sum; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::scalar::Scalar; diff --git 
a/fuzz/src/array/sum.rs b/fuzz/src/array/sum.rs index eec4c3954e4..3f6c779de56 100644 --- a/fuzz/src/array/sum.rs +++ b/fuzz/src/array/sum.rs @@ -3,7 +3,7 @@ use vortex_array::Canonical; use vortex_array::IntoArray as _; -use vortex_array::compute::sum; +use vortex_array::aggregate_fn::fns::sum::sum; use vortex_array::scalar::Scalar; use vortex_error::VortexResult; diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index 5e9a12e53fb..6a85b0e3120 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -3,6 +3,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; +use vortex_error::vortex_err; use vortex_session::VortexSession; use crate::AnyCanonical; @@ -46,8 +47,20 @@ impl Accumulator { dtype: DType, session: VortexSession, ) -> VortexResult { - let return_dtype = vtable.return_dtype(&options, &dtype)?; - let partial_dtype = vtable.partial_dtype(&options, &dtype)?; + let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| { + vortex_err!( + "Aggregate function {} cannot be applied to dtype {}", + vtable.id(), + dtype + ) + })?; + let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| { + vortex_err!( + "Aggregate function {} cannot be applied to dtype {}", + vtable.id(), + dtype + ) + })?; let partial = vtable.empty_partial(&options, &dtype)?; let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased(); diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index b2b9cf38b35..afc7d29a71d 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -7,6 +7,7 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; +use vortex_error::vortex_err; use vortex_error::vortex_panic; use 
vortex_mask::Mask; use vortex_session::VortexSession; @@ -70,8 +71,20 @@ impl GroupedAccumulator { session: VortexSession, ) -> VortexResult { let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased(); - let return_dtype = vtable.return_dtype(&options, &dtype)?; - let partial_dtype = vtable.partial_dtype(&options, &dtype)?; + let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| { + vortex_err!( + "Aggregate function {} cannot be applied to dtype {}", + vtable.id(), + dtype + ) + })?; + let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| { + vortex_err!( + "Aggregate function {} cannot be applied to dtype {}", + vtable.id(), + dtype + ) + })?; Ok(Self { vtable, diff --git a/vortex-array/src/aggregate_fn/erased.rs b/vortex-array/src/aggregate_fn/erased.rs index 0d90a395499..b8953aabc3c 100644 --- a/vortex-array/src/aggregate_fn/erased.rs +++ b/vortex-array/src/aggregate_fn/erased.rs @@ -76,12 +76,16 @@ impl AggregateFnRef { } /// Compute the return [`DType`] per group given the input element type. - pub fn return_dtype(&self, input_dtype: &DType) -> VortexResult { + /// + /// Returns `None` if the input dtype is not supported by the aggregate function. + pub fn return_dtype(&self, input_dtype: &DType) -> Option { self.0.return_dtype(input_dtype) } /// DType of the intermediate accumulator state. - pub fn state_dtype(&self, input_dtype: &DType) -> VortexResult { + /// + /// Returns `None` if the input dtype is not supported by the aggregate function. 
+ pub fn state_dtype(&self, input_dtype: &DType) -> Option { self.0.state_dtype(input_dtype) } diff --git a/vortex-array/src/aggregate_fn/fns/sum/bool.rs b/vortex-array/src/aggregate_fn/fns/sum/bool.rs index 0b88c657af1..e1932feea92 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/bool.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/bool.rs @@ -115,7 +115,9 @@ mod tests { #[test] fn sum_bool_return_dtype() -> VortexResult<()> { - let dtype = Sum.return_dtype(&EmptyOptions, &DType::Bool(Nullability::NonNullable))?; + let dtype = Sum + .return_dtype(&EmptyOptions, &DType::Bool(Nullability::NonNullable)) + .unwrap(); assert_eq!(dtype, DType::Primitive(PType::U64, Nullability::Nullable)); Ok(()) } diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 9c82b040372..9ed6baff7c6 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -27,6 +27,8 @@ use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; use crate::dtype::DType; +use crate::dtype::DecimalDType; +use crate::dtype::MAX_PRECISION; use crate::dtype::Nullability; use crate::dtype::PType; use crate::expr::stats::Precision; @@ -78,25 +80,55 @@ impl AggregateFnVTable for Sum { AggregateFnId::new_ref("vortex.sum") } - fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult { - Stat::Sum - .dtype(input_dtype) - .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype)) + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + // When a sum overflows, we return a sum _value_ of null. Therefore, all return dtypes + // are nullable.
+ use Nullability::Nullable; + + Some(match input_dtype { + DType::Bool(_) => DType::Primitive(PType::U64, Nullable), + DType::Primitive(ptype, _) => match ptype { + PType::U8 | PType::U16 | PType::U32 | PType::U64 => { + DType::Primitive(PType::U64, Nullable) + } + PType::I8 | PType::I16 | PType::I32 | PType::I64 => { + DType::Primitive(PType::I64, Nullable) + } + PType::F16 | PType::F32 | PType::F64 => { + // Float sums cannot overflow, but all null floats still end up as null + DType::Primitive(PType::F64, Nullable) + } + }, + DType::Extension(ext_dtype) => { + self.return_dtype(_options, ext_dtype.storage_dtype())? + } + DType::Decimal(decimal_dtype, _) => { + // Both Spark and DataFusion use this heuristic. + // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188 + let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10); + DType::Decimal( + DecimalDType::new(precision, decimal_dtype.scale()), + Nullable, + ) + } + // Unsupported types + _ => return None, + }) } - fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult { + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { self.return_dtype(options, input_dtype) } fn empty_partial( &self, - _options: &Self::Options, + options: &Self::Options, input_dtype: &DType, ) -> VortexResult { - let return_dtype = Stat::Sum - .dtype(input_dtype) - .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))?; - + let return_dtype = self + .return_dtype(options, input_dtype) + .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?; let initial = make_zero_state(&return_dtype); Ok(SumPartial { diff --git a/vortex-array/src/aggregate_fn/typed.rs 
b/vortex-array/src/aggregate_fn/typed.rs index fce90ecf33c..e3bf66037ca 100644 --- a/vortex-array/src/aggregate_fn/typed.rs +++ b/vortex-array/src/aggregate_fn/typed.rs @@ -39,8 +39,8 @@ pub(super) trait DynAggregateFn: 'static + Send + Sync + super::sealed::Sealed { fn id(&self) -> AggregateFnId; fn options_any(&self) -> &dyn Any; - fn return_dtype(&self, input_dtype: &DType) -> VortexResult; - fn state_dtype(&self, input_dtype: &DType) -> VortexResult; + fn return_dtype(&self, input_dtype: &DType) -> Option; + fn state_dtype(&self, input_dtype: &DType) -> Option; fn accumulator( &self, input_dtype: &DType, @@ -84,11 +84,11 @@ impl DynAggregateFn for AggregateFnInner { &self.options } - fn return_dtype(&self, input_dtype: &DType) -> VortexResult { + fn return_dtype(&self, input_dtype: &DType) -> Option { V::return_dtype(&self.vtable, &self.options, input_dtype) } - fn state_dtype(&self, input_dtype: &DType) -> VortexResult { + fn state_dtype(&self, input_dtype: &DType) -> Option { V::partial_dtype(&self.vtable, &self.options, input_dtype) } diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 0e45e8a54fd..50bb4cef1cb 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -61,13 +61,17 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { } /// The return [`DType`] of the aggregate. - fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult; + /// + /// Returns `None` if the aggregate function cannot be applied to the input dtype. + fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option; /// DType of the intermediate partial accumulator state. /// /// Use a struct dtype when multiple fields are needed /// (e.g., Mean: `Struct { sum: f64, count: u64 }`). 
- fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult; + /// + /// Returns `None` if the aggregate function cannot be applied to the input dtype. + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option; /// Return the partial accumulator state for an empty group. fn empty_partial( diff --git a/vortex-array/src/expr/stats/mod.rs b/vortex-array/src/expr/stats/mod.rs index 842d049537c..feaeeb3f311 100644 --- a/vortex-array/src/expr/stats/mod.rs +++ b/vortex-array/src/expr/stats/mod.rs @@ -11,10 +11,7 @@ use num_enum::IntoPrimitive; use num_enum::TryFromPrimitive; use crate::dtype::DType; -use crate::dtype::DecimalDType; -use crate::dtype::MAX_PRECISION; use crate::dtype::Nullability::NonNullable; -use crate::dtype::Nullability::Nullable; use crate::dtype::PType; mod bound; @@ -27,6 +24,10 @@ pub use precision::*; pub use provider::*; pub use stat_bound::*; +use crate::aggregate_fn; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::EmptyOptions; + #[derive( Debug, Clone, @@ -184,37 +185,7 @@ impl Stat { } } Self::Sum => { - // Any array that cannot be summed has a sum DType of null. - // Any array that can be summed, but overflows, has a sum _value_ of null. - // Therefore, we make integer sum stats nullable. - match data_type { - DType::Bool(_) => DType::Primitive(PType::U64, Nullable), - DType::Primitive(ptype, _) => match ptype { - PType::U8 | PType::U16 | PType::U32 | PType::U64 => { - DType::Primitive(PType::U64, Nullable) - } - PType::I8 | PType::I16 | PType::I32 | PType::I64 => { - DType::Primitive(PType::I64, Nullable) - } - PType::F16 | PType::F32 | PType::F64 => { - // Float sums cannot overflow, but all null floats still end up as null - DType::Primitive(PType::F64, Nullable) - } - }, - DType::Extension(ext_dtype) => self.dtype(ext_dtype.storage_dtype())?, - DType::Decimal(decimal_dtype, _) => { - // Both Spark and DataFusion use this heuristic. 
- // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188 - let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10); - DType::Decimal( - DecimalDType::new(precision, decimal_dtype.scale()), - Nullable, - ) - } - // Unsupported types - _ => return None, - } + return aggregate_fn::fns::sum::Sum.return_dtype(&EmptyOptions, data_type); } }) } From 65a115de8ee407fad46c4022e41baae8d6f55682 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 13:18:10 -0400 Subject: [PATCH 3/4] Clean up to point to aggregate function Signed-off-by: Nicholas Gates --- fuzz/src/array/mod.rs | 11 ++- fuzz/src/array/sum.rs | 5 +- vortex-array/public-api.lock | 20 ++--- vortex-array/src/aggregate_fn/accumulator.rs | 30 +++---- .../src/aggregate_fn/accumulator_grouped.rs | 47 ++++------ vortex-array/src/aggregate_fn/erased.rs | 17 +--- vortex-array/src/aggregate_fn/fns/sum/bool.rs | 43 +++++---- .../src/aggregate_fn/fns/sum/constant.rs | 20 +++-- .../src/aggregate_fn/fns/sum/decimal.rs | 53 ++++++++--- vortex-array/src/aggregate_fn/fns/sum/mod.rs | 90 +++++++++++-------- .../src/aggregate_fn/fns/sum/primitive.rs | 33 ++++--- vortex-array/src/aggregate_fn/typed.rs | 27 +----- vortex-array/src/array/mod.rs | 4 +- .../src/arrays/chunked/compute/aggregate.rs | 16 ++-- .../src/compute/conformance/consistency.rs | 8 +- vortex-array/src/compute/sum.rs | 5 +- vortex-array/src/stats/array.rs | 6 +- vortex-array/src/variants.rs | 5 +- vortex-layout/src/layouts/file_stats.rs | 5 +- vortex-layout/src/layouts/zoned/zone_map.rs | 5 +- 20 files changed, 244 insertions(+), 206 deletions(-) diff --git a/fuzz/src/array/mod.rs b/fuzz/src/array/mod.rs index ba6e519d264..02e9519230a 100644 --- a/fuzz/src/array/mod.rs +++ 
b/fuzz/src/array/mod.rs @@ -42,6 +42,7 @@ use tracing::debug; use vortex_array::ArrayRef; use vortex_array::DynArray; use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; use vortex_array::aggregate_fn::fns::sum::sum; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; @@ -68,6 +69,7 @@ use vortex_error::vortex_panic; use vortex_mask::Mask; use vortex_utils::aliases::hash_set::HashSet; +use crate::SESSION; use crate::error::Backtrace; use crate::error::VortexFuzzError; use crate::error::VortexFuzzResult; @@ -173,6 +175,8 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction { let array = ArbitraryArray::arbitrary(u)?.0; let mut current_array = array.to_array(); + let mut ctx = SESSION.create_execution_ctx(); + let mut valid_actions = actions_for_dtype(current_array.dtype()) .into_iter() .collect::>(); @@ -330,6 +334,7 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction { current_array .to_canonical() .vortex_expect("to_canonical should succeed in fuzz test"), + &mut ctx, ) .vortex_expect("sum_canonical_array should succeed in fuzz test"); (Action::Sum, ExpectedValue::Scalar(sum_result)) @@ -566,6 +571,8 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult { let FuzzArrayAction { array, actions } = fuzz_action; let mut current_array = array.to_array(); + let mut ctx = SESSION.create_execution_ctx(); + debug!( "Initial array:\nTree:\n{}Values:\n{:#}", current_array.display_tree(), @@ -640,8 +647,8 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult { current_array = cast_result; } Action::Sum => { - let sum_result = - sum(&current_array).vortex_expect("sum operation should succeed in fuzz test"); + let sum_result = sum(&current_array, &mut ctx) + .vortex_expect("sum operation should succeed in fuzz test"); assert_scalar_eq(&expected.scalar(), &sum_result, i)?; } Action::MinMax => { diff --git a/fuzz/src/array/sum.rs b/fuzz/src/array/sum.rs index 3f6c779de56..9cb2eb9bb90 100644 ---
a/fuzz/src/array/sum.rs +++ b/fuzz/src/array/sum.rs @@ -2,13 +2,14 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::Canonical; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray as _; use vortex_array::aggregate_fn::fns::sum::sum; use vortex_array::scalar::Scalar; use vortex_error::VortexResult; /// Compute sum on the canonical form of the array to get a consistent baseline. -pub fn sum_canonical_array(canonical: Canonical) -> VortexResult { +pub fn sum_canonical_array(canonical: Canonical, ctx: &mut ExecutionCtx) -> VortexResult { // TODO(joe): replace with baseline not using canonical - sum(&canonical.into_array()) + sum(&canonical.into_array(), ctx) } diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 41186d02229..bed3b8b1b58 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -64,7 +64,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partia pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult @@ -76,9 +76,9 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::agg pub fn vortex_array::aggregate_fn::fns::sum::Sum::is_saturated(&self, partial: &Self::Partial) -> bool -pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn 
vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option -pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> @@ -202,9 +202,9 @@ pub fn vortex_array::aggregate_fn::AggregateFnRef::is vortex_array::aggregate_fn::AggregateFnOptions<'_> -pub fn vortex_array::aggregate_fn::AggregateFnRef::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnRef::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option -pub fn vortex_array::aggregate_fn::AggregateFnRef::state_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnRef::state_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::aggregate_fn::AggregateFnRef::vtable_ref(&self) -> core::option::Option<&V> @@ -306,9 +306,9 @@ pub fn vortex_array::aggregate_fn::AggregateFnVTable::id(&self) -> vortex_array: pub fn vortex_array::aggregate_fn::AggregateFnVTable::is_saturated(&self, state: &Self::Partial) -> bool -pub fn vortex_array::aggregate_fn::AggregateFnVTable::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnVTable::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option -pub fn 
vortex_array::aggregate_fn::AggregateFnVTable::return_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnVTable::return_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::aggregate_fn::AggregateFnVTable::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> @@ -324,7 +324,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partia pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult @@ -336,9 +336,9 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::agg pub fn vortex_array::aggregate_fn::fns::sum::Sum::is_saturated(&self, partial: &Self::Partial) -> bool -pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option -pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, 
input_dtype: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index 6a85b0e3120..107b13d0985 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -4,13 +4,12 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; -use vortex_session::VortexSession; use crate::AnyCanonical; use crate::ArrayRef; use crate::Columnar; use crate::DynArray; -use crate::VortexSessionExecute; +use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; @@ -36,17 +35,10 @@ pub struct Accumulator { partial_dtype: DType, /// The partial state of the accumulator, updated after each accumulate/merge call. partial: V::Partial, - /// A session used to lookup custom aggregate kernels. - session: VortexSession, } impl Accumulator { - pub fn try_new( - vtable: V, - options: V::Options, - dtype: DType, - session: VortexSession, - ) -> VortexResult { + pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult { let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| { vortex_err!( "Aggregate function {} cannot be applied to dtype {}", @@ -71,7 +63,6 @@ impl Accumulator { return_dtype, partial_dtype, partial, - session, }) } } @@ -80,7 +71,7 @@ impl Accumulator { /// function is not known at compile time. pub trait DynAccumulator: 'static + Send { /// Accumulate a new array into the accumulator's state. - fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()>; + fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>; /// Whether the accumulator's result is fully determined. 
fn is_saturated(&self) -> bool; @@ -97,7 +88,7 @@ pub trait DynAccumulator: 'static + Send { } impl DynAccumulator for Accumulator { - fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> { + fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { if self.is_saturated() { return Ok(()); } @@ -109,9 +100,9 @@ impl DynAccumulator for Accumulator { batch.dtype() ); - let kernels = &self.session.aggregate_fns().kernels; + let session = ctx.session().clone(); + let kernels = &session.aggregate_fns().kernels; - let mut ctx = self.session.create_execution_ctx(); let mut batch = batch.clone(); for _ in 0..*MAX_ITERATIONS { if batch.is::() { @@ -125,7 +116,7 @@ impl DynAccumulator for Accumulator { .or_else(|| kernels_r.get(&(batch_id, None))) .and_then(|kernel| { kernel - .aggregate(&self.aggregate_fn, &batch, &mut ctx) + .aggregate(&self.aggregate_fn, &batch, ctx) .transpose() }) .transpose()? @@ -141,14 +132,13 @@ impl DynAccumulator for Accumulator { } // Execute one step and try again - batch = batch.execute(&mut ctx)?; + batch = batch.execute(ctx)?; } // Otherwise, execute the batch until it is columnar and accumulate it into the state. 
- let columnar = batch.execute::(&mut ctx)?; + let columnar = batch.execute::(ctx)?; - self.vtable - .accumulate(&mut self.partial, &columnar, &mut ctx) + self.vtable.accumulate(&mut self.partial, &columnar, ctx) } fn is_saturated(&self) -> bool { diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index afc7d29a71d..a4d9c38b60e 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -10,7 +10,6 @@ use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_error::vortex_panic; use vortex_mask::Mask; -use vortex_session::VortexSession; use crate::AnyCanonical; use crate::ArrayRef; @@ -19,7 +18,6 @@ use crate::Columnar; use crate::DynArray; use crate::ExecutionCtx; use crate::IntoArray; -use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnRef; @@ -59,17 +57,10 @@ pub struct GroupedAccumulator { partial_dtype: DType, /// The accumulated state for prior batches of groups. partials: Vec, - /// A session used to lookup custom aggregate kernels. - session: VortexSession, } impl GroupedAccumulator { - pub fn try_new( - vtable: V, - options: V::Options, - dtype: DType, - session: VortexSession, - ) -> VortexResult { + pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult { let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased(); let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| { vortex_err!( @@ -94,7 +85,6 @@ impl GroupedAccumulator { return_dtype, partial_dtype, partials: vec![], - session, }) } } @@ -103,7 +93,7 @@ impl GroupedAccumulator { /// function is not known at compile time. pub trait DynGroupedAccumulator: 'static + Send { /// Accumulate a list of groups into the accumulator. 
- fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()>; + fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>; /// Finish the accumulation and return the partial aggregate results for all groups. /// Resets the accumulator state for the next round of accumulation. @@ -115,7 +105,7 @@ pub trait DynGroupedAccumulator: 'static + Send { } impl DynGroupedAccumulator for GroupedAccumulator { - fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()> { + fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { let elements_dtype = match groups.dtype() { DType::List(elem, _) => elem, DType::FixedSizeList(elem, ..) => elem, @@ -131,17 +121,15 @@ impl DynGroupedAccumulator for GroupedAccumulator { elements_dtype ); - let mut ctx = self.session.create_execution_ctx(); - // We first execute the groups until it is a ListView or FixedSizeList, since we only // dispatch the aggregate kernel over the elements of these arrays. - let canonical = match groups.clone().execute::(&mut ctx)? { + let canonical = match groups.clone().execute::(ctx)? 
{ Columnar::Canonical(c) => c, - Columnar::Constant(c) => c.into_array().execute::(&mut ctx)?, + Columnar::Constant(c) => c.into_array().execute::(ctx)?, }; match canonical { - Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx), - Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx), + Canonical::List(groups) => self.accumulate_list_view(&groups, ctx), + Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx), _ => vortex_panic!("We checked the DType above, so this should never happen"), } } @@ -173,8 +161,7 @@ impl GroupedAccumulator { ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let mut elements = groups.elements().clone(); - let session = self.session.clone(); - + let session = ctx.session().clone(); let kernels = &session.aggregate_fns().grouped_kernels; for _ in 0..*MAX_ITERATIONS { @@ -218,7 +205,13 @@ impl GroupedAccumulator { match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { let offsets = offsets.clone().execute::>(ctx)?; let sizes = sizes.execute::>(ctx)?; - self.accumulate_list_view_typed(&elements, offsets.as_ref(), sizes.as_ref(), &validity) + self.accumulate_list_view_typed( + &elements, + offsets.as_ref(), + sizes.as_ref(), + &validity, + ctx, + ) }) } @@ -228,12 +221,12 @@ impl GroupedAccumulator { offsets: &[O], sizes: &[O], validity: &Mask, + ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let mut accumulator = Accumulator::try_new( self.vtable.clone(), self.options.clone(), self.dtype.clone(), - self.session.clone(), )?; let mut states = builder_with_capacity(&self.partial_dtype, offsets.len()); @@ -243,7 +236,7 @@ impl GroupedAccumulator { if validity.value(offset) { let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group)?; + accumulator.accumulate(&group, ctx)?; states.append_scalar(&accumulator.finish()?)?; } else { states.append_null() @@ -259,8 +252,7 @@ impl GroupedAccumulator { ctx: &mut ExecutionCtx, ) -> 
VortexResult<()> { let mut elements = groups.elements().clone(); - - let session = self.session.clone(); + let session = ctx.session().clone(); let kernels = &session.aggregate_fns().grouped_kernels; for _ in 0..64 { @@ -304,7 +296,6 @@ impl GroupedAccumulator { self.vtable.clone(), self.options.clone(), self.dtype.clone(), - self.session.clone(), )?; let mut states = builder_with_capacity(&self.partial_dtype, groups.len()); @@ -317,7 +308,7 @@ impl GroupedAccumulator { for i in 0..groups.len() { if validity.value(i) { let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group)?; + accumulator.accumulate(&group, ctx)?; states.append_scalar(&accumulator.finish()?)?; } else { states.append_null() diff --git a/vortex-array/src/aggregate_fn/erased.rs b/vortex-array/src/aggregate_fn/erased.rs index b8953aabc3c..c7c75f4b4d6 100644 --- a/vortex-array/src/aggregate_fn/erased.rs +++ b/vortex-array/src/aggregate_fn/erased.rs @@ -12,7 +12,6 @@ use std::sync::Arc; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_session::VortexSession; use vortex_utils::debug_with::DebugWith; use crate::aggregate_fn::AccumulatorRef; @@ -90,21 +89,13 @@ impl AggregateFnRef { } /// Create an accumulator for streaming aggregation. - pub fn accumulator( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult { - self.0.accumulator(input_dtype, session) + pub fn accumulator(&self, input_dtype: &DType) -> VortexResult { + self.0.accumulator(input_dtype) } /// Create a grouped accumulator for grouped streaming aggregation. 
- pub fn accumulator_grouped( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult { - self.0.accumulator_grouped(input_dtype, session) + pub fn accumulator_grouped(&self, input_dtype: &DType) -> VortexResult { + self.0.accumulator_grouped(input_dtype) } } diff --git a/vortex-array/src/aggregate_fn/fns/sum/bool.rs b/vortex-array/src/aggregate_fn/fns/sum/bool.rs index e1932feea92..a2dad06fff6 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/bool.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/bool.rs @@ -29,9 +29,9 @@ pub(super) fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResu #[cfg(test)] mod tests { use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::IntoArray; + use crate::LEGACY_SESSION; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; @@ -42,15 +42,15 @@ mod tests { use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; - - fn session() -> VortexSession { - VortexSession::empty() - } + use crate::executor::VortexSessionExecute; #[test] fn sum_bool_all_true() -> VortexResult<()> { let arr: BoolArray = [true, true, true].into_iter().collect(); - let result = sum(&arr.into_array())?; + let result = sum( + &arr.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; assert_eq!(result.as_primitive().typed_value::(), Some(3)); Ok(()) } @@ -58,7 +58,10 @@ mod tests { #[test] fn sum_bool_mixed() -> VortexResult<()> { let arr: BoolArray = [true, false, true, false, true].into_iter().collect(); - let result = sum(&arr.into_array())?; + let result = sum( + &arr.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; assert_eq!(result.as_primitive().typed_value::(), Some(3)); Ok(()) } @@ -66,7 +69,10 @@ mod tests { #[test] fn sum_bool_all_false() -> VortexResult<()> { let arr: BoolArray = [false, false, false].into_iter().collect(); - let result = sum(&arr.into_array())?; + let 
result = sum( + &arr.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) } @@ -74,7 +80,10 @@ mod tests { #[test] fn sum_bool_with_nulls() -> VortexResult<()> { let arr = BoolArray::from_iter([Some(true), None, Some(true), Some(false)]); - let result = sum(&arr.into_array())?; + let result = sum( + &arr.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; assert_eq!(result.as_primitive().typed_value::(), Some(2)); Ok(()) } @@ -82,7 +91,10 @@ mod tests { #[test] fn sum_bool_all_null() -> VortexResult<()> { let arr = BoolArray::from_iter([None::, None, None]); - let result = sum(&arr.into_array())?; + let result = sum( + &arr.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) } @@ -90,7 +102,7 @@ mod tests { #[test] fn sum_bool_empty_produces_zero() -> VortexResult<()> { let dtype = DType::Bool(Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) @@ -98,16 +110,17 @@ mod tests { #[test] fn sum_bool_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); let dtype = DType::Bool(Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; let batch1: BoolArray = [true, true, false].into_iter().collect(); - acc.accumulate(&batch1.into_array())?; + acc.accumulate(&batch1.into_array(), &mut ctx)?; let result1 = acc.finish()?; assert_eq!(result1.as_primitive().typed_value::(), Some(2)); let batch2: BoolArray = [false, true].into_iter().collect(); - acc.accumulate(&batch2.into_array())?; + acc.accumulate(&batch2.into_array(), &mut ctx)?; let 
result2 = acc.finish()?; assert_eq!(result2.as_primitive().typed_value::(), Some(1)); Ok(()) @@ -125,7 +138,7 @@ mod tests { #[test] fn sum_boolean_from_iter() -> VortexResult<()> { let arr = BoolArray::from_iter([true, false, false, true]).into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result.as_primitive().as_::(), Some(2)); Ok(()) } diff --git a/vortex-array/src/aggregate_fn/fns/sum/constant.rs b/vortex-array/src/aggregate_fn/fns/sum/constant.rs index db59b4d80ef..286a68f229c 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/constant.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/constant.rs @@ -97,6 +97,8 @@ mod tests { use vortex_error::VortexResult; use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; use crate::aggregate_fn::fns::sum::sum; use crate::arrays::ConstantArray; use crate::dtype::DType; @@ -112,7 +114,7 @@ mod tests { #[test] fn sum_constant_unsigned() -> VortexResult<()> { let array = ConstantArray::new(5u64, 10).into_array(); - let result = sum(&array)?; + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result, 50u64.into()); Ok(()) } @@ -120,7 +122,7 @@ mod tests { #[test] fn sum_constant_signed() -> VortexResult<()> { let array = ConstantArray::new(-5i64, 10).into_array(); - let result = sum(&array)?; + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result, (-50i64).into()); Ok(()) } @@ -129,7 +131,7 @@ mod tests { fn sum_constant_nullable_value() -> VortexResult<()> { let array = ConstantArray::new(Scalar::null(DType::Primitive(PType::U32, Nullable)), 10) .into_array(); - let result = sum(&array)?; + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result, Scalar::primitive(0u64, Nullable)); Ok(()) } @@ -137,7 +139,7 @@ mod tests { #[test] fn sum_constant_bool_false() -> VortexResult<()> { let array = ConstantArray::new(false, 
10).into_array(); - let result = sum(&array)?; + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result, 0u64.into()); Ok(()) } @@ -145,7 +147,7 @@ mod tests { #[test] fn sum_constant_bool_true() -> VortexResult<()> { let array = ConstantArray::new(true, 10).into_array(); - let result = sum(&array)?; + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result, 10u64.into()); Ok(()) } @@ -153,7 +155,7 @@ mod tests { #[test] fn sum_constant_bool_null() -> VortexResult<()> { let array = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array(); - let result = sum(&array)?; + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result, Scalar::primitive(0u64, Nullable)); Ok(()) } @@ -171,7 +173,7 @@ mod tests { ) .into_array(); - let result = sum(&array)?; + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!( result.as_decimal().decimal_value(), @@ -187,7 +189,7 @@ mod tests { let array = ConstantArray::new(Scalar::null(DType::Decimal(decimal_dtype, Nullable)), 10) .into_array(); - let result = sum(&array)?; + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!( result, Scalar::decimal( @@ -212,7 +214,7 @@ mod tests { ) .into_array(); - let result = sum(&array)?; + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!( result.as_decimal().decimal_value(), Some(DecimalValue::I256(i256::from_i128(99_999_999_900))) diff --git a/vortex-array/src/aggregate_fn/fns/sum/decimal.rs b/vortex-array/src/aggregate_fn/fns/sum/decimal.rs index 4ffa680b2e6..fc388c57b49 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/decimal.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/decimal.rs @@ -51,6 +51,8 @@ mod tests { use vortex_error::VortexResult; use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; use crate::aggregate_fn::fns::sum::sum; use 
crate::arrays::DecimalArray; use crate::dtype::DType; @@ -71,7 +73,10 @@ mod tests { Validity::AllValid, ); - let result = sum(&decimal.into_array())?; + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), @@ -90,7 +95,10 @@ mod tests { Validity::from_iter([true, false, true, true]), ); - let result = sum(&decimal.into_array())?; + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullable), @@ -109,7 +117,10 @@ mod tests { Validity::AllValid, ); - let result = sum(&decimal.into_array())?; + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), @@ -129,7 +140,10 @@ mod tests { Validity::AllValid, ); - let result = sum(&decimal.into_array())?; + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let expected_sum = near_max as i64 + 500 + 400; let expected = Scalar::try_new( @@ -150,7 +164,10 @@ mod tests { Validity::AllValid, ); - let result = sum(&decimal.into_array())?; + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let expected_sum = (large_val as i128) * 4 + 1; let expected = Scalar::try_new( @@ -170,7 +187,10 @@ mod tests { Validity::AllValid, ); - let result = sum(&decimal.into_array())?; + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable), @@ -186,7 +206,10 @@ mod tests { let decimal = DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid); - let result = sum(&decimal.into_array())?; + let result = sum( + 
&decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable), @@ -205,7 +228,10 @@ mod tests { Validity::from_iter([false, false, true, false]), ); - let result = sum(&decimal.into_array())?; + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullable), @@ -225,7 +251,10 @@ mod tests { Validity::AllValid, ); - let result = sum(&decimal.into_array())?; + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let expected_sum = i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val); @@ -248,7 +277,11 @@ mod tests { ); assert_eq!( - sum(&decimal.into_array()).vortex_expect("operation should succeed in test"), + sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx() + ) + .vortex_expect("operation should succeed in test"), Scalar::null(DType::Decimal(decimal_dtype, Nullable)) ); Ok(()) diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 9ed6baff7c6..df55c6836a0 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -11,7 +11,6 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; -use vortex_session::VortexSession; use self::bool::accumulate_bool; use self::constant::accumulate_constant; @@ -40,7 +39,7 @@ use crate::scalar::Scalar; /// Return the sum of an array. /// /// See [`Sum`] for details. -pub fn sum(array: &ArrayRef) -> VortexResult { +pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { // Short-circuit using cached array statistics. 
if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) { return Ok(sum_scalar); @@ -48,13 +47,8 @@ pub fn sum(array: &ArrayRef) -> VortexResult { // Compute using Accumulator. // TODO(ngates): we may want to wrap this three-step dance up into an extension crate maybe. - let mut acc = Accumulator::try_new( - Sum, - EmptyOptions, - array.dtype().clone(), - VortexSession::empty(), - )?; - acc.accumulate(array)?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?; + acc.accumulate(array, ctx)?; let result = acc.finish()?; // Cache the computed sum as a statistic (only if non-null, i.e. no overflow). @@ -99,9 +93,6 @@ impl AggregateFnVTable for Sum { DType::Primitive(PType::F64, Nullable) } }, - DType::Extension(ext_dtype) => { - self.return_dtype(_options, ext_dtype.storage_dtype())? - } DType::Decimal(decimal_dtype, _) => { // Both Spark and DataFusion use this heuristic. // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 @@ -317,11 +308,12 @@ mod tests { use vortex_buffer::buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::ArrayRef; use crate::DynArray; use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; @@ -351,17 +343,14 @@ mod tests { use crate::scalar::Scalar; use crate::validity::Validity; - fn session() -> VortexSession { - VortexSession::empty() - } - /// Sum an array with an initial value (test-only helper). 
fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); if accumulator.is_null() { return Ok(accumulator.clone()); } if accumulator.is_zero() == Some(true) { - return sum(array); + return sum(array, &mut ctx); } let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| { @@ -376,7 +365,7 @@ mod tests { } // Compute array sum from zero (also caches stats). - let array_sum = sum(array)?; + let array_sum = sum(array, &mut ctx)?; // Combine with the accumulator. add_scalars(&sum_dtype, &array_sum, accumulator) @@ -412,14 +401,15 @@ mod tests { #[test] fn sum_multi_batch() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - acc.accumulate(&batch1)?; + acc.accumulate(&batch1, &mut ctx)?; let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array(); - acc.accumulate(&batch2)?; + acc.accumulate(&batch2, &mut ctx)?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(48)); @@ -428,16 +418,17 @@ mod tests { #[test] fn sum_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - acc.accumulate(&batch1)?; + acc.accumulate(&batch1, &mut ctx)?; let result1 = acc.finish()?; assert_eq!(result1.as_primitive().typed_value::(), Some(30)); let batch2 = PrimitiveArray::new(buffer![3i32, 
6, 9], Validity::NonNullable).into_array(); - acc.accumulate(&batch2)?; + acc.accumulate(&batch2, &mut ctx)?; let result2 = acc.finish()?; assert_eq!(result2.as_primitive().typed_value::(), Some(18)); Ok(()) @@ -477,7 +468,7 @@ mod tests { // compute sum with accumulator to populate stats sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable))?; - let sum_without_acc = sum(&array)?; + let sum_without_acc = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(sum_without_acc, Scalar::primitive(9i64, Nullable)); Ok(()) } @@ -500,9 +491,8 @@ mod tests { // Grouped sum tests fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = - GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone(), session())?; - acc.accumulate_list(groups)?; + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; + acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; acc.finish() } @@ -582,13 +572,14 @@ mod tests { #[test] fn grouped_sum_finish_resets() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype, session())?; + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?; let elements1 = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?; - acc.accumulate_list(&groups1.into_array())?; + acc.accumulate_list(&groups1.into_array(), &mut ctx)?; let result1 = acc.finish()?; let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array(); @@ -596,7 +587,7 @@ mod tests { let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?; - 
acc.accumulate_list(&groups2.into_array())?; + acc.accumulate_list(&groups2.into_array(), &mut ctx)?; let result2 = acc.finish()?; let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array(); @@ -622,7 +613,10 @@ mod tests { dtype, )?; - let result = sum(&chunked.into_array())?; + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; assert_eq!(result.as_primitive().as_::(), Some(20.8)); Ok(()) } @@ -633,7 +627,10 @@ mod tests { let chunk2 = PrimitiveArray::from_option_iter::(vec![None, None]); let dtype = chunk1.dtype().clone(); let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; - let result = sum(&chunked.into_array())?; + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; assert_eq!(result, Scalar::primitive(0f64, Nullable)); Ok(()) } @@ -653,7 +650,10 @@ mod tests { dtype, )?; - let result = sum(&chunked.into_array())?; + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; assert_eq!(result.as_primitive().as_::(), Some(36.0)); Ok(()) } @@ -665,7 +665,10 @@ mod tests { let dtype = chunk1.dtype().clone(); let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; - let result = sum(&chunked.into_array())?; + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; assert_eq!(result.as_primitive().as_::(), Some(1)); Ok(()) } @@ -694,7 +697,10 @@ mod tests { dtype, )?; - let result = sum(&chunked.into_array())?; + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let decimal_result = result.as_decimal(); assert_eq!( decimal_result.decimal_value(), @@ -727,7 +733,10 @@ mod tests { dtype, )?; - let result = sum(&chunked.into_array())?; + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let decimal_result = 
result.as_decimal(); assert_eq!( decimal_result.decimal_value(), @@ -758,7 +767,10 @@ mod tests { let dtype = chunk1.dtype().clone(); let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; - let result = sum(&chunked.into_array())?; + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; let decimal_result = result.as_decimal(); assert_eq!( decimal_result.decimal_value(), diff --git a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs index dcec2ba4159..292711f95bf 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs @@ -113,9 +113,10 @@ fn accumulate_primitive_valid( mod tests { use vortex_buffer::buffer; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; @@ -129,14 +130,10 @@ mod tests { use crate::scalar::Scalar; use crate::validity::Validity; - fn session() -> VortexSession { - VortexSession::empty() - } - #[test] fn sum_i32() -> VortexResult<()> { let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result.as_primitive().typed_value::(), Some(10)); Ok(()) } @@ -144,7 +141,7 @@ mod tests { #[test] fn sum_u8() -> VortexResult<()> { let arr = PrimitiveArray::new(buffer![10u8, 20, 30], Validity::NonNullable).into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result.as_primitive().typed_value::(), Some(60)); Ok(()) } @@ -153,7 +150,7 @@ mod tests { fn sum_f64() -> VortexResult<()> { let arr = PrimitiveArray::new(buffer![1.5f64, 2.5, 
3.0], Validity::NonNullable).into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result.as_primitive().typed_value::(), Some(7.0)); Ok(()) } @@ -161,7 +158,7 @@ mod tests { #[test] fn sum_with_nulls() -> VortexResult<()> { let arr = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result.as_primitive().typed_value::(), Some(6)); Ok(()) } @@ -169,7 +166,7 @@ mod tests { #[test] fn sum_all_null() -> VortexResult<()> { let arr = PrimitiveArray::from_option_iter([None::, None, None]).into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) } @@ -177,7 +174,7 @@ mod tests { #[test] fn sum_all_invalid_float() -> VortexResult<()> { let arr = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result, Scalar::primitive(0f64, Nullable)); Ok(()) } @@ -185,7 +182,7 @@ mod tests { #[test] fn sum_buffer_i32() -> VortexResult<()> { let arr = buffer![1, 1, 1, 1].into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result.as_primitive().as_::(), Some(4)); Ok(()) } @@ -193,7 +190,7 @@ mod tests { #[test] fn sum_buffer_f64() -> VortexResult<()> { let arr = buffer![1., 1., 1., 1.].into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert_eq!(result.as_primitive().as_::(), Some(4.)); Ok(()) } @@ -201,7 +198,7 @@ mod tests { #[test] fn sum_empty_produces_zero() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, 
session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) @@ -210,7 +207,7 @@ mod tests { #[test] fn sum_empty_f64_produces_zero() -> VortexResult<()> { let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0.0)); Ok(()) @@ -219,7 +216,7 @@ mod tests { #[test] fn sum_checked_overflow() -> VortexResult<()> { let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); - let result = sum(&arr)?; + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; assert!(result.is_null()); Ok(()) } @@ -227,12 +224,12 @@ mod tests { #[test] fn sum_checked_overflow_is_saturated() -> VortexResult<()> { let dtype = DType::Primitive(PType::I64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; assert!(!acc.is_saturated()); let batch = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); - acc.accumulate(&batch)?; + acc.accumulate(&batch, &mut LEGACY_SESSION.create_execution_ctx())?; assert!(acc.is_saturated()); // finish resets state, clearing saturation diff --git a/vortex-array/src/aggregate_fn/typed.rs b/vortex-array/src/aggregate_fn/typed.rs index e3bf66037ca..328c0b08835 100644 --- a/vortex-array/src/aggregate_fn/typed.rs +++ b/vortex-array/src/aggregate_fn/typed.rs @@ -19,7 +19,6 @@ use std::hash::Hasher; use std::sync::Arc; use vortex_error::VortexResult; -use vortex_session::VortexSession; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AccumulatorRef; @@ -41,16 +40,8 @@ pub(super) trait DynAggregateFn: 'static + Send + Sync + 
super::sealed::Sealed { fn return_dtype(&self, input_dtype: &DType) -> Option; fn state_dtype(&self, input_dtype: &DType) -> Option; - fn accumulator( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult; - fn accumulator_grouped( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult; + fn accumulator(&self, input_dtype: &DType) -> VortexResult; + fn accumulator_grouped(&self, input_dtype: &DType) -> VortexResult; fn options_serialize(&self) -> VortexResult>>; fn options_eq(&self, other_options: &dyn Any) -> bool; @@ -92,29 +83,19 @@ impl DynAggregateFn for AggregateFnInner { V::partial_dtype(&self.vtable, &self.options, input_dtype) } - fn accumulator( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult { + fn accumulator(&self, input_dtype: &DType) -> VortexResult { Ok(Box::new(Accumulator::try_new( self.vtable.clone(), self.options.clone(), input_dtype.clone(), - session.clone(), )?)) } - fn accumulator_grouped( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult { + fn accumulator_grouped(&self, input_dtype: &DType) -> VortexResult { Ok(Box::new(GroupedAccumulator::try_new( self.vtable.clone(), self.options.clone(), input_dtype.clone(), - session.clone(), )?)) } diff --git a/vortex-array/src/array/mod.rs b/vortex-array/src/array/mod.rs index 5056aa649ed..79eb980d317 100644 --- a/vortex-array/src/array/mod.rs +++ b/vortex-array/src/array/mod.rs @@ -567,6 +567,7 @@ impl DynArray for ArrayAdapter { } } + // TODO(ngates): deprecate this function since it requires compute. 
fn valid_count(&self) -> VortexResult { if let Some(Precision::Exact(invalid_count)) = self.statistics().get_as::(Stat::NullCount) @@ -578,7 +579,8 @@ impl DynArray for ArrayAdapter { Validity::NonNullable | Validity::AllValid => self.len(), Validity::AllInvalid => 0, Validity::Array(a) => { - let array_sum = sum(&a)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let array_sum = sum(&a, &mut ctx)?; array_sum .as_primitive() .as_::() diff --git a/vortex-array/src/arrays/chunked/compute/aggregate.rs b/vortex-array/src/arrays/chunked/compute/aggregate.rs index 243fb8b1a6b..86965f45072 100644 --- a/vortex-array/src/arrays/chunked/compute/aggregate.rs +++ b/vortex-array/src/arrays/chunked/compute/aggregate.rs @@ -24,9 +24,9 @@ impl DynAggregateKernel for ChunkedArrayAggregate { return Ok(None); }; - let mut acc = aggregate_fn.accumulator(chunked.dtype(), ctx.session())?; + let mut acc = aggregate_fn.accumulator(chunked.dtype())?; for chunk in chunked.chunks() { - acc.accumulate(chunk)?; + acc.accumulate(chunk, ctx)?; } Ok(Some(acc.finish()?)) } @@ -37,9 +37,10 @@ mod tests { use vortex_buffer::Buffer; use vortex_buffer::buffer; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; @@ -52,13 +53,10 @@ mod tests { use crate::dtype::PType; use crate::scalar::Scalar; - fn session() -> VortexSession { - VortexSession::empty() - } - fn run_sum(batch: &crate::ArrayRef) -> VortexResult { - let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone(), session())?; - acc.accumulate(batch)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone())?; + acc.accumulate(batch, &mut ctx)?; acc.finish() } diff --git a/vortex-array/src/compute/conformance/consistency.rs 
b/vortex-array/src/compute/conformance/consistency.rs index 6e50e978d46..75f3b94f44e 100644 --- a/vortex-array/src/compute/conformance/consistency.rs +++ b/vortex-array/src/compute/conformance/consistency.rs @@ -27,6 +27,8 @@ use vortex_mask::Mask; use crate::ArrayRef; use crate::DynArray; use crate::IntoArray; +use crate::LEGACY_SESSION; +use crate::VortexSessionExecute; use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::builtins::ArrayBuiltins; @@ -1014,6 +1016,8 @@ fn test_slice_aggregate_consistency(array: &ArrayRef) { use crate::compute::nan_count; use crate::dtype::DType; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let len = array.len(); if len < 5 { return; // Need enough elements for meaningful slice @@ -1051,7 +1055,9 @@ fn test_slice_aggregate_consistency(array: &ArrayRef) { return; } - if let (Ok(slice_sum), Ok(canonical_sum)) = (sum(&sliced), sum(&canonical_sliced)) { + if let (Ok(slice_sum), Ok(canonical_sum)) = + (sum(&sliced, &mut ctx), sum(&canonical_sliced, &mut ctx)) + { // Compare sum scalars assert_eq!( slice_sum, canonical_sum, diff --git a/vortex-array/src/compute/sum.rs b/vortex-array/src/compute/sum.rs index a831e07248a..b1e4fbc6216 100644 --- a/vortex-array/src/compute/sum.rs +++ b/vortex-array/src/compute/sum.rs @@ -4,9 +4,12 @@ use vortex_error::VortexResult; use crate::ArrayRef; +use crate::LEGACY_SESSION; +use crate::VortexSessionExecute; use crate::scalar::Scalar; #[deprecated(note = "use `vortex::array::aggregate_fn::fns::sum::sum` instead")] pub fn sum(array: &ArrayRef) -> VortexResult { - crate::aggregate_fn::fns::sum::sum(array) + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + crate::aggregate_fn::fns::sum::sum(array, &mut ctx) } diff --git a/vortex-array/src/stats/array.rs b/vortex-array/src/stats/array.rs index fc6fcfa877b..01ee4f897be 100644 --- a/vortex-array/src/stats/array.rs +++ b/vortex-array/src/stats/array.rs @@ -15,6 +15,8 @@ use super::StatsSet; use super::StatsSetIntoIter; 
use super::TypedStatsSetRef; use crate::DynArray; +use crate::LEGACY_SESSION; +use crate::VortexSessionExecute; use crate::aggregate_fn::fns::sum::sum; use crate::builders::builder_with_capacity; use crate::compute::MinMaxResult; @@ -142,6 +144,8 @@ impl StatsSetRef<'_> { } pub fn compute_stat(&self, stat: Stat) -> VortexResult> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + // If it's already computed and exact, we can return it. if let Some(Precision::Exact(s)) = self.get(stat) { return Ok(Some(s)); @@ -157,7 +161,7 @@ impl StatsSetRef<'_> { .is_some() .then(|| { // Sum is supported for this dtype. - sum(&array_ref) + sum(&array_ref, &mut ctx) }) .transpose()? } diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 57f5f1dbd30..ce6af05a2e5 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -12,6 +12,8 @@ use vortex_mask::Mask; use crate::DynArray; use crate::ExecutionCtx; +use crate::LEGACY_SESSION; +use crate::VortexSessionExecute; use crate::aggregate_fn::fns::sum::sum; use crate::arrays::BoolArray; use crate::builtins::ArrayBuiltins; @@ -108,7 +110,8 @@ pub struct BoolTyped<'a>(&'a dyn DynArray); impl BoolTyped<'_> { pub fn true_count(&self) -> VortexResult { - let true_count = sum(&self.0.to_array())?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let true_count = sum(&self.0.to_array(), &mut ctx)?; Ok(true_count .as_primitive() .as_::() diff --git a/vortex-layout/src/layouts/file_stats.rs b/vortex-layout/src/layouts/file_stats.rs index 10032dec4c3..48254684451 100644 --- a/vortex-layout/src/layouts/file_stats.rs +++ b/vortex-layout/src/layouts/file_stats.rs @@ -8,7 +8,9 @@ use futures::StreamExt; use itertools::Itertools; use parking_lot::Mutex; use vortex_array::ArrayRef; +use vortex_array::LEGACY_SESSION; use vortex_array::ToCanonical as _; +use vortex_array::VortexSessionExecute; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use 
vortex_array::expr::stats::Stat; @@ -108,6 +110,7 @@ impl FileStatsAccumulator { } pub fn stats_sets(&self) -> Vec { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); self.accumulators .lock() .iter_mut() @@ -116,7 +119,7 @@ impl FileStatsAccumulator { .vortex_expect("as_stats_table should not fail") .map(|table| { table - .to_stats_set(&self.stats) + .to_stats_set(&self.stats, &mut ctx) .vortex_expect("shouldn't fail to convert table we just created") }) .unwrap_or_default() diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 3538f47085f..eeeed8852e3 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use itertools::Itertools; use vortex_array::ArrayRef; use vortex_array::DynArray; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::aggregate_fn::fns::sum::sum; @@ -108,7 +109,7 @@ impl ZoneMap { } /// Returns an aggregated stats set for the table. - pub fn to_stats_set(&self, stats: &[Stat]) -> VortexResult { + pub fn to_stats_set(&self, stats: &[Stat], ctx: &mut ExecutionCtx) -> VortexResult { let mut stats_set = StatsSet::default(); for &stat in stats { let Some(array) = self.get_stat(stat)? else { @@ -127,7 +128,7 @@ impl ZoneMap { } // These stats sum up Stat::NullCount | Stat::NaNCount | Stat::UncompressedSizeInBytes => { - if let Some(sum_value) = sum(&array)? + if let Some(sum_value) = sum(&array, ctx)? .cast(&DType::Primitive(PType::U64, Nullability::Nullable))? 
.into_value() { From 8ae7b1241976d6ecfd615fe571bee3abde6bebb7 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 13:26:01 -0400 Subject: [PATCH 4/4] Clean up to point to aggregate function Signed-off-by: Nicholas Gates --- vortex-array/public-api.lock | 22 +++++++++++----------- vortex-layout/public-api.lock | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index bed3b8b1b58..dd2fa238289 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -84,7 +84,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Sel pub struct vortex_array::aggregate_fn::fns::sum::SumPartial -pub fn vortex_array::aggregate_fn::fns::sum::sum(array: &vortex_array::ArrayRef) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::sum(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult pub mod vortex_array::aggregate_fn::kernels @@ -132,11 +132,11 @@ pub struct vortex_array::aggregate_fn::Accumulator vortex_array::aggregate_fn::Accumulator -pub fn vortex_array::aggregate_fn::Accumulator::try_new(vtable: V, options: ::Options, dtype: vortex_array::dtype::DType, session: vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::Accumulator::try_new(vtable: V, options: ::Options, dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult impl vortex_array::aggregate_fn::DynAccumulator for vortex_array::aggregate_fn::Accumulator -pub fn vortex_array::aggregate_fn::Accumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::Accumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::Accumulator::finish(&mut self) -> vortex_error::VortexResult @@ 
-188,9 +188,9 @@ pub struct vortex_array::aggregate_fn::AggregateFnRef(_) impl vortex_array::aggregate_fn::AggregateFnRef -pub fn vortex_array::aggregate_fn::AggregateFnRef::accumulator(&self, input_dtype: &vortex_array::dtype::DType, session: &vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnRef::accumulator(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult -pub fn vortex_array::aggregate_fn::AggregateFnRef::accumulator_grouped(&self, input_dtype: &vortex_array::dtype::DType, session: &vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnRef::accumulator_grouped(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::aggregate_fn::AggregateFnRef::as_(&self) -> &::Options @@ -260,11 +260,11 @@ pub struct vortex_array::aggregate_fn::GroupedAccumulator vortex_array::aggregate_fn::GroupedAccumulator -pub fn vortex_array::aggregate_fn::GroupedAccumulator::try_new(vtable: V, options: ::Options, dtype: vortex_array::dtype::DType, session: vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::GroupedAccumulator::try_new(vtable: V, options: ::Options, dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult impl vortex_array::aggregate_fn::DynGroupedAccumulator for vortex_array::aggregate_fn::GroupedAccumulator -pub fn vortex_array::aggregate_fn::GroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::GroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::GroupedAccumulator::finish(&mut self) -> vortex_error::VortexResult @@ -352,7 +352,7 @@ pub fn V::bind(&self, options: Self::Options) -> 
vortex_array::aggregate_fn::Agg pub trait vortex_array::aggregate_fn::DynAccumulator: 'static + core::marker::Send -pub fn vortex_array::aggregate_fn::DynAccumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::DynAccumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::DynAccumulator::finish(&mut self) -> vortex_error::VortexResult @@ -362,7 +362,7 @@ pub fn vortex_array::aggregate_fn::DynAccumulator::is_saturated(&self) -> bool impl vortex_array::aggregate_fn::DynAccumulator for vortex_array::aggregate_fn::Accumulator -pub fn vortex_array::aggregate_fn::Accumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::Accumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::Accumulator::finish(&mut self) -> vortex_error::VortexResult @@ -372,7 +372,7 @@ pub fn vortex_array::aggregate_fn::Accumulator::is_saturated(&self) -> bool pub trait vortex_array::aggregate_fn::DynGroupedAccumulator: 'static + core::marker::Send -pub fn vortex_array::aggregate_fn::DynGroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::DynGroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::DynGroupedAccumulator::finish(&mut self) -> vortex_error::VortexResult @@ -380,7 +380,7 @@ pub fn vortex_array::aggregate_fn::DynGroupedAccumulator::flush(&mut self) -> vo impl vortex_array::aggregate_fn::DynGroupedAccumulator for vortex_array::aggregate_fn::GroupedAccumulator -pub fn 
vortex_array::aggregate_fn::GroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::GroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::GroupedAccumulator::finish(&mut self) -> vortex_error::VortexResult diff --git a/vortex-layout/public-api.lock b/vortex-layout/public-api.lock index 8163fa19489..2cb3dec9d9a 100644 --- a/vortex-layout/public-api.lock +++ b/vortex-layout/public-api.lock @@ -836,7 +836,7 @@ pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::present_stats(&self) -> pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::prune(&self, predicate: &vortex_array::expr::expression::Expression, session: &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::to_stats_set(&self, stats: &[vortex_array::expr::stats::Stat]) -> vortex_error::VortexResult +pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::to_stats_set(&self, stats: &[vortex_array::expr::stats::Stat], ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::try_new(column_dtype: vortex_array::dtype::DType, array: vortex_array::arrays::struct_::array::StructArray, stats: alloc::sync::Arc<[vortex_array::expr::stats::Stat]>) -> vortex_error::VortexResult