diff --git a/fuzz/src/array/mod.rs b/fuzz/src/array/mod.rs index 87c219bb4f5..02e9519230a 100644 --- a/fuzz/src/array/mod.rs +++ b/fuzz/src/array/mod.rs @@ -42,13 +42,14 @@ use tracing::debug; use vortex_array::ArrayRef; use vortex_array::DynArray; use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::fns::sum::sum; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::arbitrary::ArbitraryArray; use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::MinMaxResult; use vortex_array::compute::min_max; -use vortex_array::compute::sum; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::scalar::Scalar; @@ -68,6 +69,7 @@ use vortex_error::vortex_panic; use vortex_mask::Mask; use vortex_utils::aliases::hash_set::HashSet; +use crate::SESSION; use crate::error::Backtrace; use crate::error::VortexFuzzError; use crate::error::VortexFuzzResult; @@ -173,6 +175,8 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction { let array = ArbitraryArray::arbitrary(u)?.0; let mut current_array = array.to_array(); + let mut ctx = SESSION.create_execution_ctx(); + let mut valid_actions = actions_for_dtype(current_array.dtype()) .into_iter() .collect::>(); @@ -330,6 +334,7 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction { current_array .to_canonical() .vortex_expect("to_canonical should succeed in fuzz test"), + &mut ctx, ) .vortex_expect("sum_canonical_array should succeed in fuzz test"); (Action::Sum, ExpectedValue::Scalar(sum_result)) @@ -566,6 +571,8 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult { let FuzzArrayAction { array, actions } = fuzz_action; let mut current_array = array.to_array(); + let mut ctx = SESSION.create_execution_ctx(); + debug!( "Initial array:\nTree:\n{}Values:\n{:#}", current_array.display_tree(), @@ -640,8 +647,8 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult { 
current_array = cast_result; } Action::Sum => { - let sum_result = - sum(&current_array).vortex_expect("sum operation should succeed in fuzz test"); + let sum_result = sum(&current_array, &mut ctx) + .vortex_expect("sum operation should succeed in fuzz test"); assert_scalar_eq(&expected.scalar(), &sum_result, i)?; } Action::MinMax => { diff --git a/fuzz/src/array/sum.rs b/fuzz/src/array/sum.rs index eec4c3954e4..9cb2eb9bb90 100644 --- a/fuzz/src/array/sum.rs +++ b/fuzz/src/array/sum.rs @@ -2,13 +2,14 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::Canonical; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray as _; -use vortex_array::compute::sum; +use vortex_array::aggregate_fn::fns::sum::sum; use vortex_array::scalar::Scalar; use vortex_error::VortexResult; /// Compute sum on the canonical form of the array to get a consistent baseline. -pub fn sum_canonical_array(canonical: Canonical) -> VortexResult { +pub fn sum_canonical_array(canonical: Canonical, ctx: &mut ExecutionCtx) -> VortexResult { // TODO(joe): replace with baseline not using canonical - sum(&canonical.into_array()) + sum(&canonical.into_array(), ctx) } diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index ed872aa095a..6662ee78b14 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -66,7 +66,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partia pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn 
vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult @@ -78,14 +78,16 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::agg pub fn vortex_array::aggregate_fn::fns::sum::Sum::is_saturated(&self, partial: &Self::Partial) -> bool -pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option -pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> pub struct vortex_array::aggregate_fn::fns::sum::SumPartial +pub fn vortex_array::aggregate_fn::fns::sum::sum(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub mod vortex_array::aggregate_fn::kernels pub trait vortex_array::aggregate_fn::kernels::DynAggregateKernel: 'static + core::marker::Send + core::marker::Sync + core::fmt::Debug @@ -132,11 +134,11 @@ pub struct vortex_array::aggregate_fn::Accumulator vortex_array::aggregate_fn::Accumulator -pub fn vortex_array::aggregate_fn::Accumulator::try_new(vtable: V, options: ::Options, dtype: vortex_array::dtype::DType, session: vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::Accumulator::try_new(vtable: V, options: ::Options, dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult impl 
vortex_array::aggregate_fn::DynAccumulator for vortex_array::aggregate_fn::Accumulator -pub fn vortex_array::aggregate_fn::Accumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::Accumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::Accumulator::finish(&mut self) -> vortex_error::VortexResult @@ -188,9 +190,9 @@ pub struct vortex_array::aggregate_fn::AggregateFnRef(_) impl vortex_array::aggregate_fn::AggregateFnRef -pub fn vortex_array::aggregate_fn::AggregateFnRef::accumulator(&self, input_dtype: &vortex_array::dtype::DType, session: &vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnRef::accumulator(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult -pub fn vortex_array::aggregate_fn::AggregateFnRef::accumulator_grouped(&self, input_dtype: &vortex_array::dtype::DType, session: &vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnRef::accumulator_grouped(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::aggregate_fn::AggregateFnRef::as_(&self) -> &::Options @@ -204,9 +206,9 @@ pub fn vortex_array::aggregate_fn::AggregateFnRef::is vortex_array::aggregate_fn::AggregateFnOptions<'_> -pub fn vortex_array::aggregate_fn::AggregateFnRef::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnRef::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option -pub fn vortex_array::aggregate_fn::AggregateFnRef::state_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnRef::state_dtype(&self, input_dtype: 
&vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::aggregate_fn::AggregateFnRef::vtable_ref(&self) -> core::option::Option<&V> @@ -262,11 +264,11 @@ pub struct vortex_array::aggregate_fn::GroupedAccumulator vortex_array::aggregate_fn::GroupedAccumulator -pub fn vortex_array::aggregate_fn::GroupedAccumulator::try_new(vtable: V, options: ::Options, dtype: vortex_array::dtype::DType, session: vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::GroupedAccumulator::try_new(vtable: V, options: ::Options, dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult impl vortex_array::aggregate_fn::DynGroupedAccumulator for vortex_array::aggregate_fn::GroupedAccumulator -pub fn vortex_array::aggregate_fn::GroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::GroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::GroupedAccumulator::finish(&mut self) -> vortex_error::VortexResult @@ -310,9 +312,9 @@ pub fn vortex_array::aggregate_fn::AggregateFnVTable::id(&self) -> vortex_array: pub fn vortex_array::aggregate_fn::AggregateFnVTable::is_saturated(&self, state: &Self::Partial) -> bool -pub fn vortex_array::aggregate_fn::AggregateFnVTable::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnVTable::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option -pub fn vortex_array::aggregate_fn::AggregateFnVTable::return_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::AggregateFnVTable::return_dtype(&self, options: 
&Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::aggregate_fn::AggregateFnVTable::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> @@ -330,7 +332,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partia pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult @@ -342,9 +344,9 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::agg pub fn vortex_array::aggregate_fn::fns::sum::Sum::is_saturated(&self, partial: &Self::Partial) -> bool -pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option -pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> @@ -358,7 +360,7 @@ pub fn V::bind(&self, options: Self::Options) -> 
vortex_array::aggregate_fn::Agg pub trait vortex_array::aggregate_fn::DynAccumulator: 'static + core::marker::Send -pub fn vortex_array::aggregate_fn::DynAccumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::DynAccumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::DynAccumulator::finish(&mut self) -> vortex_error::VortexResult @@ -368,7 +370,7 @@ pub fn vortex_array::aggregate_fn::DynAccumulator::is_saturated(&self) -> bool impl vortex_array::aggregate_fn::DynAccumulator for vortex_array::aggregate_fn::Accumulator -pub fn vortex_array::aggregate_fn::Accumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::Accumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::Accumulator::finish(&mut self) -> vortex_error::VortexResult @@ -378,7 +380,7 @@ pub fn vortex_array::aggregate_fn::Accumulator::is_saturated(&self) -> bool pub trait vortex_array::aggregate_fn::DynGroupedAccumulator: 'static + core::marker::Send -pub fn vortex_array::aggregate_fn::DynGroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::DynGroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::DynGroupedAccumulator::finish(&mut self) -> vortex_error::VortexResult @@ -386,7 +388,7 @@ pub fn vortex_array::aggregate_fn::DynGroupedAccumulator::flush(&mut self) -> vo impl vortex_array::aggregate_fn::DynGroupedAccumulator for vortex_array::aggregate_fn::GroupedAccumulator -pub fn 
vortex_array::aggregate_fn::GroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::GroupedAccumulator::accumulate_list(&mut self, groups: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::GroupedAccumulator::finish(&mut self) -> vortex_error::VortexResult @@ -440,10 +442,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Bool pub fn vortex_array::arrays::Bool::min_max(&self, array: &vortex_array::arrays::BoolArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Bool - -pub fn vortex_array::arrays::Bool::sum(&self, array: &vortex_array::arrays::BoolArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::bool::BoolMaskedValidityRule pub type vortex_array::arrays::bool::BoolMaskedValidityRule::Parent = vortex_array::arrays::Masked @@ -668,10 +666,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Chunked pub fn vortex_array::arrays::Chunked::min_max(&self, array: &vortex_array::arrays::ChunkedArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Chunked - -pub fn vortex_array::arrays::Chunked::sum(&self, array: &vortex_array::arrays::ChunkedArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::cast::CastReduce for vortex_array::arrays::Chunked pub fn vortex_array::arrays::Chunked::cast(array: &vortex_array::arrays::ChunkedArray, dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult> @@ -842,10 +836,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Constant pub fn vortex_array::arrays::Constant::min_max(&self, array: 
&vortex_array::arrays::ConstantArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Constant - -pub fn vortex_array::arrays::Constant::sum(&self, array: &vortex_array::arrays::ConstantArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::between::BetweenReduce for vortex_array::arrays::Constant pub fn vortex_array::arrays::Constant::between(array: &vortex_array::arrays::ConstantArray, lower: &vortex_array::ArrayRef, upper: &vortex_array::ArrayRef, options: &vortex_array::scalar_fn::fns::between::BetweenOptions) -> vortex_error::VortexResult> @@ -1060,10 +1050,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Decimal pub fn vortex_array::arrays::Decimal::min_max(&self, array: &vortex_array::arrays::DecimalArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Decimal - -pub fn vortex_array::arrays::Decimal::sum(&self, array: &vortex_array::arrays::DecimalArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::decimal::DecimalMaskedValidityRule pub type vortex_array::arrays::decimal::DecimalMaskedValidityRule::Parent = vortex_array::arrays::Masked @@ -1724,10 +1710,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Extension pub fn vortex_array::arrays::Extension::min_max(&self, array: &vortex_array::arrays::ExtensionArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Extension - -pub fn vortex_array::arrays::Extension::sum(&self, array: &vortex_array::arrays::ExtensionArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::binary::CompareKernel for vortex_array::arrays::Extension pub fn vortex_array::arrays::Extension::compare(lhs: 
&vortex_array::arrays::ExtensionArray, rhs: &vortex_array::ArrayRef, operator: vortex_array::scalar_fn::fns::operators::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -2976,10 +2958,6 @@ impl vortex_array::compute::NaNCountKernel for vortex_array::arrays::Primitive pub fn vortex_array::arrays::Primitive::nan_count(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult -impl vortex_array::compute::SumKernel for vortex_array::arrays::Primitive - -pub fn vortex_array::arrays::Primitive::sum(&self, array: &vortex_array::arrays::PrimitiveArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::primitive::PrimitiveMaskedValidityRule pub type vortex_array::arrays::primitive::PrimitiveMaskedValidityRule::Parent = vortex_array::arrays::Masked @@ -4640,10 +4618,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Bool pub fn vortex_array::arrays::Bool::min_max(&self, array: &vortex_array::arrays::BoolArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Bool - -pub fn vortex_array::arrays::Bool::sum(&self, array: &vortex_array::arrays::BoolArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::bool::BoolMaskedValidityRule pub type vortex_array::arrays::bool::BoolMaskedValidityRule::Parent = vortex_array::arrays::Masked @@ -4840,10 +4814,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Chunked pub fn vortex_array::arrays::Chunked::min_max(&self, array: &vortex_array::arrays::ChunkedArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Chunked - -pub fn vortex_array::arrays::Chunked::sum(&self, array: &vortex_array::arrays::ChunkedArray, accumulator: 
&vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::cast::CastReduce for vortex_array::arrays::Chunked pub fn vortex_array::arrays::Chunked::cast(array: &vortex_array::arrays::ChunkedArray, dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult> @@ -5012,10 +4982,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Constant pub fn vortex_array::arrays::Constant::min_max(&self, array: &vortex_array::arrays::ConstantArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Constant - -pub fn vortex_array::arrays::Constant::sum(&self, array: &vortex_array::arrays::ConstantArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::between::BetweenReduce for vortex_array::arrays::Constant pub fn vortex_array::arrays::Constant::between(array: &vortex_array::arrays::ConstantArray, lower: &vortex_array::ArrayRef, upper: &vortex_array::ArrayRef, options: &vortex_array::scalar_fn::fns::between::BetweenOptions) -> vortex_error::VortexResult> @@ -5166,10 +5132,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Decimal pub fn vortex_array::arrays::Decimal::min_max(&self, array: &vortex_array::arrays::DecimalArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Decimal - -pub fn vortex_array::arrays::Decimal::sum(&self, array: &vortex_array::arrays::DecimalArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::decimal::DecimalMaskedValidityRule pub type vortex_array::arrays::decimal::DecimalMaskedValidityRule::Parent = vortex_array::arrays::Masked @@ -5538,10 +5500,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Extension pub fn vortex_array::arrays::Extension::min_max(&self, array: 
&vortex_array::arrays::ExtensionArray) -> vortex_error::VortexResult> -impl vortex_array::compute::SumKernel for vortex_array::arrays::Extension - -pub fn vortex_array::arrays::Extension::sum(&self, array: &vortex_array::arrays::ExtensionArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::scalar_fn::fns::binary::CompareKernel for vortex_array::arrays::Extension pub fn vortex_array::arrays::Extension::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &vortex_array::ArrayRef, operator: vortex_array::scalar_fn::fns::operators::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -6588,10 +6546,6 @@ impl vortex_array::compute::NaNCountKernel for vortex_array::arrays::Primitive pub fn vortex_array::arrays::Primitive::nan_count(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult -impl vortex_array::compute::SumKernel for vortex_array::arrays::Primitive - -pub fn vortex_array::arrays::Primitive::sum(&self, array: &vortex_array::arrays::PrimitiveArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::primitive::PrimitiveMaskedValidityRule pub type vortex_array::arrays::primitive::PrimitiveMaskedValidityRule::Parent = vortex_array::arrays::Masked @@ -9564,12 +9518,6 @@ impl<'a> core::clone::Clone for vortex_array::compute::InvocationArgs<'a> pub fn vortex_array::compute::InvocationArgs<'a>::clone(&self) -> vortex_array::compute::InvocationArgs<'a> -impl<'a> core::convert::TryFrom<&vortex_array::compute::InvocationArgs<'a>> for vortex_array::compute::SumArgs<'a> - -pub type vortex_array::compute::SumArgs<'a>::Error = vortex_error::VortexError - -pub fn vortex_array::compute::SumArgs<'a>::try_from(value: &vortex_array::compute::InvocationArgs<'a>) -> core::result::Result - pub struct vortex_array::compute::IsConstantKernelAdapter(pub V) impl 
vortex_array::compute::IsConstantKernelAdapter @@ -9704,36 +9652,6 @@ pub struct vortex_array::compute::NaNCountKernelRef(_) impl inventory::Collect for vortex_array::compute::NaNCountKernelRef -pub struct vortex_array::compute::SumArgs<'a> - -pub vortex_array::compute::SumArgs::accumulator: &'a vortex_array::scalar::Scalar - -pub vortex_array::compute::SumArgs::array: &'a dyn vortex_array::DynArray - -impl<'a> core::convert::TryFrom<&vortex_array::compute::InvocationArgs<'a>> for vortex_array::compute::SumArgs<'a> - -pub type vortex_array::compute::SumArgs<'a>::Error = vortex_error::VortexError - -pub fn vortex_array::compute::SumArgs<'a>::try_from(value: &vortex_array::compute::InvocationArgs<'a>) -> core::result::Result - -pub struct vortex_array::compute::SumKernelAdapter(pub V) - -impl vortex_array::compute::SumKernelAdapter - -pub const fn vortex_array::compute::SumKernelAdapter::lift(&'static self) -> vortex_array::compute::SumKernelRef - -impl core::fmt::Debug for vortex_array::compute::SumKernelAdapter - -pub fn vortex_array::compute::SumKernelAdapter::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl vortex_array::compute::Kernel for vortex_array::compute::SumKernelAdapter - -pub fn vortex_array::compute::SumKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> - -pub struct vortex_array::compute::SumKernelRef(_) - -impl inventory::Collect for vortex_array::compute::SumKernelRef - pub struct vortex_array::compute::UnaryArgs<'a, O: vortex_array::compute::Options> pub vortex_array::compute::UnaryArgs::array: &'a dyn vortex_array::DynArray @@ -9918,10 +9836,6 @@ impl vo pub fn vortex_array::compute::NaNCountKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> -impl vortex_array::compute::Kernel for vortex_array::compute::SumKernelAdapter - -pub fn vortex_array::compute::SumKernelAdapter::invoke(&self, args: 
&vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> - pub trait vortex_array::compute::MinMaxKernel: vortex_array::vtable::VTable pub fn vortex_array::compute::MinMaxKernel::min_max(&self, array: &Self::Array) -> vortex_error::VortexResult> @@ -10006,34 +9920,6 @@ impl vortex_array::compute::Options for vortex_array::scalar_fn::fns::between::B pub fn vortex_array::scalar_fn::fns::between::BetweenOptions::as_any(&self) -> &dyn core::any::Any -pub trait vortex_array::compute::SumKernel: vortex_array::vtable::VTable - -pub fn vortex_array::compute::SumKernel::sum(&self, array: &Self::Array, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::Bool - -pub fn vortex_array::arrays::Bool::sum(&self, array: &vortex_array::arrays::BoolArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::Chunked - -pub fn vortex_array::arrays::Chunked::sum(&self, array: &vortex_array::arrays::ChunkedArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::Constant - -pub fn vortex_array::arrays::Constant::sum(&self, array: &vortex_array::arrays::ConstantArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::Decimal - -pub fn vortex_array::arrays::Decimal::sum(&self, array: &vortex_array::arrays::DecimalArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::Extension - -pub fn vortex_array::arrays::Extension::sum(&self, array: &vortex_array::arrays::ExtensionArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - -impl vortex_array::compute::SumKernel for vortex_array::arrays::Primitive - 
-pub fn vortex_array::arrays::Primitive::sum(&self, array: &vortex_array::arrays::PrimitiveArray, accumulator: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult - pub fn vortex_array::compute::is_constant(array: &vortex_array::ArrayRef) -> vortex_error::VortexResult> pub fn vortex_array::compute::is_constant_opts(array: &vortex_array::ArrayRef, options: &vortex_array::compute::IsConstantOpts) -> vortex_error::VortexResult> @@ -10050,8 +9936,6 @@ pub fn vortex_array::compute::nan_count(array: &vortex_array::ArrayRef) -> vorte pub fn vortex_array::compute::sum(array: &vortex_array::ArrayRef) -> vortex_error::VortexResult -pub fn vortex_array::compute::sum_impl(array: &vortex_array::ArrayRef, accumulator: &vortex_array::scalar::Scalar, kernels: &[arcref::ArcRef]) -> vortex_error::VortexResult - pub fn vortex_array::compute::warm_up_vtables() pub mod vortex_array::display diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index 5e9a12e53fb..107b13d0985 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -3,13 +3,13 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; -use vortex_session::VortexSession; +use vortex_error::vortex_err; use crate::AnyCanonical; use crate::ArrayRef; use crate::Columnar; use crate::DynArray; -use crate::VortexSessionExecute; +use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; @@ -35,19 +35,24 @@ pub struct Accumulator { partial_dtype: DType, /// The partial state of the accumulator, updated after each accumulate/merge call. partial: V::Partial, - /// A session used to lookup custom aggregate kernels. 
- session: VortexSession, } impl Accumulator { - pub fn try_new( - vtable: V, - options: V::Options, - dtype: DType, - session: VortexSession, - ) -> VortexResult { - let return_dtype = vtable.return_dtype(&options, &dtype)?; - let partial_dtype = vtable.partial_dtype(&options, &dtype)?; + pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult { + let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| { + vortex_err!( + "Aggregate function {} cannot be applied to dtype {}", + vtable.id(), + dtype + ) + })?; + let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| { + vortex_err!( + "Aggregate function {} cannot be applied to dtype {}", + vtable.id(), + dtype + ) + })?; let partial = vtable.empty_partial(&options, &dtype)?; let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased(); @@ -58,7 +63,6 @@ impl Accumulator { return_dtype, partial_dtype, partial, - session, }) } } @@ -67,7 +71,7 @@ impl Accumulator { /// function is not known at compile time. pub trait DynAccumulator: 'static + Send { /// Accumulate a new array into the accumulator's state. - fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()>; + fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>; /// Whether the accumulator's result is fully determined. 
fn is_saturated(&self) -> bool; @@ -84,7 +88,7 @@ pub trait DynAccumulator: 'static + Send { } impl DynAccumulator for Accumulator { - fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> { + fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { if self.is_saturated() { return Ok(()); } @@ -96,9 +100,9 @@ impl DynAccumulator for Accumulator { batch.dtype() ); - let kernels = &self.session.aggregate_fns().kernels; + let session = ctx.session().clone(); + let kernels = &session.aggregate_fns().kernels; - let mut ctx = self.session.create_execution_ctx(); let mut batch = batch.clone(); for _ in 0..*MAX_ITERATIONS { if batch.is::() { @@ -112,7 +116,7 @@ impl DynAccumulator for Accumulator { .or_else(|| kernels_r.get(&(batch_id, None))) .and_then(|kernel| { kernel - .aggregate(&self.aggregate_fn, &batch, &mut ctx) + .aggregate(&self.aggregate_fn, &batch, ctx) .transpose() }) .transpose()? @@ -128,14 +132,13 @@ impl DynAccumulator for Accumulator { } // Execute one step and try again - batch = batch.execute(&mut ctx)?; + batch = batch.execute(ctx)?; } // Otherwise, execute the batch until it is columnar and accumulate it into the state. 
- let columnar = batch.execute::(&mut ctx)?; + let columnar = batch.execute::(ctx)?; - self.vtable - .accumulate(&mut self.partial, &columnar, &mut ctx) + self.vtable.accumulate(&mut self.partial, &columnar, ctx) } fn is_saturated(&self) -> bool { diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index b2b9cf38b35..a4d9c38b60e 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -7,9 +7,9 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; +use vortex_error::vortex_err; use vortex_error::vortex_panic; use vortex_mask::Mask; -use vortex_session::VortexSession; use crate::AnyCanonical; use crate::ArrayRef; @@ -18,7 +18,6 @@ use crate::Columnar; use crate::DynArray; use crate::ExecutionCtx; use crate::IntoArray; -use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnRef; @@ -58,20 +57,25 @@ pub struct GroupedAccumulator { partial_dtype: DType, /// The accumulated state for prior batches of groups. partials: Vec, - /// A session used to lookup custom aggregate kernels. 
- session: VortexSession, } impl GroupedAccumulator { - pub fn try_new( - vtable: V, - options: V::Options, - dtype: DType, - session: VortexSession, - ) -> VortexResult { + pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult { let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased(); - let return_dtype = vtable.return_dtype(&options, &dtype)?; - let partial_dtype = vtable.partial_dtype(&options, &dtype)?; + let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| { + vortex_err!( + "Aggregate function {} cannot be applied to dtype {}", + vtable.id(), + dtype + ) + })?; + let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| { + vortex_err!( + "Aggregate function {} cannot be applied to dtype {}", + vtable.id(), + dtype + ) + })?; Ok(Self { vtable, @@ -81,7 +85,6 @@ impl GroupedAccumulator { return_dtype, partial_dtype, partials: vec![], - session, }) } } @@ -90,7 +93,7 @@ impl GroupedAccumulator { /// function is not known at compile time. pub trait DynGroupedAccumulator: 'static + Send { /// Accumulate a list of groups into the accumulator. - fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()>; + fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>; /// Finish the accumulation and return the partial aggregate results for all groups. /// Resets the accumulator state for the next round of accumulation. @@ -102,7 +105,7 @@ pub trait DynGroupedAccumulator: 'static + Send { } impl DynGroupedAccumulator for GroupedAccumulator { - fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()> { + fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { let elements_dtype = match groups.dtype() { DType::List(elem, _) => elem, DType::FixedSizeList(elem, ..) 
=> elem, @@ -118,17 +121,15 @@ impl DynGroupedAccumulator for GroupedAccumulator { elements_dtype ); - let mut ctx = self.session.create_execution_ctx(); - // We first execute the groups until it is a ListView or FixedSizeList, since we only // dispatch the aggregate kernel over the elements of these arrays. - let canonical = match groups.clone().execute::(&mut ctx)? { + let canonical = match groups.clone().execute::(ctx)? { Columnar::Canonical(c) => c, - Columnar::Constant(c) => c.into_array().execute::(&mut ctx)?, + Columnar::Constant(c) => c.into_array().execute::(ctx)?, }; match canonical { - Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx), - Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx), + Canonical::List(groups) => self.accumulate_list_view(&groups, ctx), + Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx), _ => vortex_panic!("We checked the DType above, so this should never happen"), } } @@ -160,8 +161,7 @@ impl GroupedAccumulator { ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let mut elements = groups.elements().clone(); - let session = self.session.clone(); - + let session = ctx.session().clone(); let kernels = &session.aggregate_fns().grouped_kernels; for _ in 0..*MAX_ITERATIONS { @@ -205,7 +205,13 @@ impl GroupedAccumulator { match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { let offsets = offsets.clone().execute::>(ctx)?; let sizes = sizes.execute::>(ctx)?; - self.accumulate_list_view_typed(&elements, offsets.as_ref(), sizes.as_ref(), &validity) + self.accumulate_list_view_typed( + &elements, + offsets.as_ref(), + sizes.as_ref(), + &validity, + ctx, + ) }) } @@ -215,12 +221,12 @@ impl GroupedAccumulator { offsets: &[O], sizes: &[O], validity: &Mask, + ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let mut accumulator = Accumulator::try_new( self.vtable.clone(), self.options.clone(), self.dtype.clone(), - self.session.clone(), )?; let mut states 
= builder_with_capacity(&self.partial_dtype, offsets.len()); @@ -230,7 +236,7 @@ impl GroupedAccumulator { if validity.value(offset) { let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group)?; + accumulator.accumulate(&group, ctx)?; states.append_scalar(&accumulator.finish()?)?; } else { states.append_null() @@ -246,8 +252,7 @@ impl GroupedAccumulator { ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let mut elements = groups.elements().clone(); - - let session = self.session.clone(); + let session = ctx.session().clone(); let kernels = &session.aggregate_fns().grouped_kernels; for _ in 0..64 { @@ -291,7 +296,6 @@ impl GroupedAccumulator { self.vtable.clone(), self.options.clone(), self.dtype.clone(), - self.session.clone(), )?; let mut states = builder_with_capacity(&self.partial_dtype, groups.len()); @@ -304,7 +308,7 @@ impl GroupedAccumulator { for i in 0..groups.len() { if validity.value(i) { let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group)?; + accumulator.accumulate(&group, ctx)?; states.append_scalar(&accumulator.finish()?)?; } else { states.append_null() diff --git a/vortex-array/src/aggregate_fn/erased.rs b/vortex-array/src/aggregate_fn/erased.rs index 78f182b6343..750a7c24f77 100644 --- a/vortex-array/src/aggregate_fn/erased.rs +++ b/vortex-array/src/aggregate_fn/erased.rs @@ -12,7 +12,6 @@ use std::sync::Arc; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_session::VortexSession; use vortex_utils::debug_with::DebugWith; use crate::aggregate_fn::AccumulatorRef; @@ -81,31 +80,27 @@ impl AggregateFnRef { } /// Compute the return [`DType`] per group given the input element type. - pub fn return_dtype(&self, input_dtype: &DType) -> VortexResult { + /// + /// Returns `None` if the input dtype is not supported by the aggregate function. 
+ pub fn return_dtype(&self, input_dtype: &DType) -> Option { self.0.return_dtype(input_dtype) } /// DType of the intermediate accumulator state. - pub fn state_dtype(&self, input_dtype: &DType) -> VortexResult { + /// + /// Returns `None` if the input dtype is not supported by the aggregate function. + pub fn state_dtype(&self, input_dtype: &DType) -> Option { self.0.state_dtype(input_dtype) } /// Create an accumulator for streaming aggregation. - pub fn accumulator( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult { - self.0.accumulator(input_dtype, session) + pub fn accumulator(&self, input_dtype: &DType) -> VortexResult { + self.0.accumulator(input_dtype) } /// Create a grouped accumulator for grouped streaming aggregation. - pub fn accumulator_grouped( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult { - self.0.accumulator_grouped(input_dtype, session) + pub fn accumulator_grouped(&self, input_dtype: &DType) -> VortexResult { + self.0.accumulator_grouped(input_dtype) } } diff --git a/vortex-array/src/aggregate_fn/fns/sum.rs b/vortex-array/src/aggregate_fn/fns/sum.rs deleted file mode 100644 index 52af2f3c4fb..00000000000 --- a/vortex-array/src/aggregate_fn/fns/sum.rs +++ /dev/null @@ -1,837 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::ops::BitAnd; - -use itertools::Itertools; -use num_traits::ToPrimitive; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_error::vortex_panic; -use vortex_mask::AllOr; - -use crate::ArrayRef; -use crate::Canonical; -use crate::Columnar; -use crate::ExecutionCtx; -use crate::aggregate_fn::AggregateFnId; -use crate::aggregate_fn::AggregateFnVTable; -use crate::aggregate_fn::EmptyOptions; -use crate::arrays::BoolArray; -use crate::arrays::ConstantArray; -use crate::arrays::DecimalArray; -use 
crate::arrays::PrimitiveArray; -use crate::dtype::DType; -use crate::dtype::Nullability; -use crate::dtype::PType; -use crate::expr::stats::Stat; -use crate::match_each_decimal_value_type; -use crate::match_each_native_ptype; -use crate::scalar::DecimalValue; -use crate::scalar::Scalar; - -#[derive(Clone, Debug)] -pub struct Sum; - -impl AggregateFnVTable for Sum { - type Options = EmptyOptions; - type Partial = SumPartial; - - fn id(&self) -> AggregateFnId { - AggregateFnId::new_ref("vortex.sum") - } - - fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult { - Stat::Sum - .dtype(input_dtype) - .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype)) - } - - fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult { - self.return_dtype(options, input_dtype) - } - - fn empty_partial( - &self, - _options: &Self::Options, - input_dtype: &DType, - ) -> VortexResult { - let return_dtype = Stat::Sum - .dtype(input_dtype) - .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))?; - - let initial = make_zero_state(&return_dtype); - - Ok(SumPartial { - return_dtype, - current: Some(initial), - }) - } - - fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { - if other.is_null() { - // A null partial means the sub-accumulator saturated (overflow). 
- partial.current = None; - return Ok(()); - } - let Some(ref mut inner) = partial.current else { - return Ok(()); - }; - let saturated = match inner { - SumState::Unsigned(acc) => { - let val = other - .as_primitive() - .typed_value::() - .vortex_expect("checked non-null"); - checked_add_u64(acc, val) - } - SumState::Signed(acc) => { - let val = other - .as_primitive() - .typed_value::() - .vortex_expect("checked non-null"); - checked_add_i64(acc, val) - } - SumState::Float(acc) => { - let val = other - .as_primitive() - .typed_value::() - .vortex_expect("checked non-null"); - *acc += val; - false - } - SumState::Decimal(acc) => { - let val = other - .as_decimal() - .decimal_value() - .vortex_expect("checked non-null"); - match acc.checked_add(&val) { - Some(r) => { - *acc = r; - false - } - None => true, - } - } - }; - if saturated { - partial.current = None; - } - Ok(()) - } - - fn flush(&self, partial: &mut Self::Partial) -> VortexResult { - let result = match &partial.current { - None => Scalar::null(partial.return_dtype.as_nullable()), - Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable), - Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable), - Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable), - Some(SumState::Decimal(v)) => { - let decimal_dtype = *partial - .return_dtype - .as_decimal_opt() - .vortex_expect("return dtype must be decimal"); - Scalar::decimal(*v, decimal_dtype, Nullability::Nullable) - } - }; - - // Reset the state - partial.current = Some(make_zero_state(&partial.return_dtype)); - - Ok(result) - } - - #[inline] - fn is_saturated(&self, partial: &Self::Partial) -> bool { - partial.current.is_none() - } - - fn accumulate( - &self, - partial: &mut Self::Partial, - batch: &Columnar, - _ctx: &mut ExecutionCtx, - ) -> VortexResult<()> { - let mut inner = match partial.current.take() { - Some(inner) => inner, - None => return Ok(()), - }; - - let result = match batch { - 
Columnar::Canonical(c) => match c { - Canonical::Primitive(p) => accumulate_primitive(&mut inner, p), - Canonical::Bool(b) => accumulate_bool(&mut inner, b), - Canonical::Decimal(d) => accumulate_decimal(&mut inner, d), - _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()), - }, - Columnar::Constant(c) => accumulate_constant(&mut inner, c), - }; - - match result { - Ok(false) => partial.current = Some(inner), - Ok(true) => {} // saturated: current stays None - Err(e) => { - partial.current = Some(inner); - return Err(e); - } - } - Ok(()) - } - - fn finalize(&self, partials: ArrayRef) -> VortexResult { - Ok(partials) - } - - fn finalize_scalar(&self, partial: Scalar) -> VortexResult { - Ok(partial) - } -} - -/// The group state for a sum aggregate, containing the accumulated value and configuration -/// needed for reset/result without external context. -pub struct SumPartial { - return_dtype: DType, - /// The current accumulated state, or `None` if saturated (checked overflow). - current: Option, -} - -/// The accumulated sum value. -/// -// TODO(ngates): instead of an enum, we should use a Box to avoid dispatcher over the -// input type every time? Perhaps? -pub enum SumState { - Unsigned(u64), - Signed(i64), - Float(f64), - Decimal(DecimalValue), -} - -fn make_zero_state(return_dtype: &DType) -> SumState { - match return_dtype { - DType::Primitive(ptype, _) => match ptype { - PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0), - PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0), - PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0), - }, - DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)), - _ => vortex_panic!("Unsupported sum type"), - } -} - -/// Checked add for u64, returning true if overflow occurred. 
-#[inline(always)] -fn checked_add_u64(acc: &mut u64, val: u64) -> bool { - match acc.checked_add(val) { - Some(r) => { - *acc = r; - false - } - None => true, - } -} - -/// Checked add for i64, returning true if overflow occurred. -#[inline(always)] -fn checked_add_i64(acc: &mut i64, val: i64) -> bool { - match acc.checked_add(val) { - Some(r) => { - *acc = r; - false - } - None => true, - } -} - -fn accumulate_primitive(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult { - let mask = p.validity_mask()?; - match mask.bit_buffer() { - AllOr::None => Ok(false), - AllOr::All => accumulate_primitive_all(inner, p), - AllOr::Some(validity) => accumulate_primitive_valid(inner, p, validity), - } -} - -fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult { - match inner { - SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(), - unsigned: |T| { - for &v in p.as_slice::() { - if checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) { - return Ok(true); - } - } - Ok(false) - }, - signed: |_T| { vortex_panic!("unsigned sum state with signed input") }, - floating: |_T| { vortex_panic!("unsigned sum state with float input") } - ), - SumState::Signed(acc) => match_each_native_ptype!(p.ptype(), - unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") }, - signed: |T| { - for &v in p.as_slice::() { - if checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) { - return Ok(true); - } - } - Ok(false) - }, - floating: |_T| { vortex_panic!("signed sum state with float input") } - ), - SumState::Float(acc) => match_each_native_ptype!(p.ptype(), - unsigned: |_T| { vortex_panic!("float sum state with unsigned input") }, - signed: |_T| { vortex_panic!("float sum state with signed input") }, - floating: |T| { - for &v in p.as_slice::() { - *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); - } - Ok(false) - } - ), - SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive 
input"), - } -} - -fn accumulate_primitive_valid( - inner: &mut SumState, - p: &PrimitiveArray, - validity: &vortex_buffer::BitBuffer, -) -> VortexResult { - match inner { - SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(), - unsigned: |T| { - for (&v, valid) in p.as_slice::().iter().zip_eq(validity.iter()) { - if valid && checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) { - return Ok(true); - } - } - Ok(false) - }, - signed: |_T| { vortex_panic!("unsigned sum state with signed input") }, - floating: |_T| { vortex_panic!("unsigned sum state with float input") } - ), - SumState::Signed(acc) => match_each_native_ptype!(p.ptype(), - unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") }, - signed: |T| { - for (&v, valid) in p.as_slice::().iter().zip_eq(validity.iter()) { - if valid && checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) { - return Ok(true); - } - } - Ok(false) - }, - floating: |_T| { vortex_panic!("signed sum state with float input") } - ), - SumState::Float(acc) => match_each_native_ptype!(p.ptype(), - unsigned: |_T| { vortex_panic!("float sum state with unsigned input") }, - signed: |_T| { vortex_panic!("float sum state with signed input") }, - floating: |T| { - for (&v, valid) in p.as_slice::().iter().zip_eq(validity.iter()) { - if valid { - *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); - } - } - Ok(false) - } - ), - SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"), - } -} - -fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResult { - let SumState::Unsigned(acc) = inner else { - vortex_panic!("expected unsigned sum state for bool input"); - }; - - let mask = b.validity_mask()?; - let true_count = match mask.bit_buffer() { - AllOr::None => return Ok(false), - AllOr::All => b.to_bit_buffer().true_count() as u64, - AllOr::Some(validity) => b.to_bit_buffer().bitand(validity).true_count() as u64, - }; - - Ok(checked_add_u64(acc, 
true_count)) -} - -/// Accumulate a constant array into the sum state. -/// Computes `scalar * len` and adds to the accumulator. -/// Returns Ok(true) if saturated (overflow), Ok(false) if not. -fn accumulate_constant(inner: &mut SumState, c: &ConstantArray) -> VortexResult { - let scalar = c.scalar(); - if scalar.is_null() || c.is_empty() { - return Ok(false); - } - let len = c.len(); - - match scalar.dtype() { - DType::Bool(_) => { - let SumState::Unsigned(acc) = inner else { - vortex_panic!("expected unsigned sum state for bool input"); - }; - let val = scalar - .as_bool() - .value() - .ok_or_else(|| vortex_err!("Expected non-null bool scalar for sum"))?; - if val { - Ok(checked_add_u64(acc, len as u64)) - } else { - Ok(false) - } - } - DType::Primitive(..) => { - let pvalue = scalar - .as_primitive() - .pvalue() - .ok_or_else(|| vortex_err!("Expected non-null primitive scalar for sum"))?; - match inner { - SumState::Unsigned(acc) => { - let val = pvalue.cast::()?; - match val.checked_mul(len as u64) { - Some(product) => Ok(checked_add_u64(acc, product)), - None => Ok(true), - } - } - SumState::Signed(acc) => { - let val = pvalue.cast::()?; - match i64::try_from(len).ok().and_then(|l| val.checked_mul(l)) { - Some(product) => Ok(checked_add_i64(acc, product)), - None => Ok(true), - } - } - SumState::Float(acc) => { - let val = pvalue.cast::()?; - *acc += val * len as f64; - Ok(false) - } - SumState::Decimal(_) => { - vortex_panic!("decimal sum state with primitive input") - } - } - } - DType::Decimal(..) 
=> { - let SumState::Decimal(acc) = inner else { - vortex_panic!("expected decimal sum state for decimal input"); - }; - let val = scalar - .as_decimal() - .decimal_value() - .ok_or_else(|| vortex_err!("Expected non-null decimal scalar for sum"))?; - let len_decimal = DecimalValue::from(len as i128); - match val.checked_mul(&len_decimal) { - Some(product) => match acc.checked_add(&product) { - Some(r) => { - *acc = r; - Ok(false) - } - None => Ok(true), - }, - None => Ok(true), - } - } - _ => vortex_bail!("Unsupported constant type for sum: {}", scalar.dtype()), - } -} - -/// Accumulate a decimal array into the sum state. -/// Returns Ok(true) if saturated (overflow), Ok(false) if not. -fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult { - let SumState::Decimal(acc) = inner else { - vortex_panic!("expected decimal sum state for decimal input"); - }; - - let mask = d.validity_mask()?; - match mask.bit_buffer() { - AllOr::None => Ok(false), - AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| { - for &v in d.buffer::().iter() { - match acc.checked_add(&DecimalValue::from(v)) { - Some(r) => *acc = r, - None => return Ok(true), - } - } - Ok(false) - }), - AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| { - for (&v, valid) in d.buffer::().iter().zip_eq(validity.iter()) { - if valid { - match acc.checked_add(&DecimalValue::from(v)) { - Some(r) => *acc = r, - None => return Ok(true), - } - } - } - Ok(false) - }), - } -} - -#[cfg(test)] -mod tests { - use vortex_buffer::buffer; - use vortex_error::VortexResult; - use vortex_session::VortexSession; - - use crate::ArrayRef; - use crate::IntoArray; - use crate::aggregate_fn::Accumulator; - use crate::aggregate_fn::AggregateFnVTable; - use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::DynGroupedAccumulator; - use crate::aggregate_fn::EmptyOptions; - use crate::aggregate_fn::GroupedAccumulator; - use crate::aggregate_fn::fns::sum::Sum; - 
use crate::arrays::BoolArray; - use crate::arrays::FixedSizeListArray; - use crate::arrays::PrimitiveArray; - use crate::assert_arrays_eq; - use crate::dtype::DType; - use crate::dtype::Nullability; - use crate::dtype::PType; - use crate::scalar::Scalar; - use crate::validity::Validity; - - fn session() -> VortexSession { - VortexSession::empty() - } - - fn run_sum(batch: &ArrayRef) -> VortexResult { - let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone(), session())?; - acc.accumulate(batch)?; - acc.finish() - } - - // Primitive sum tests - - #[test] - fn sum_i32() -> VortexResult<()> { - let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); - let result = run_sum(&arr)?; - assert_eq!(result.as_primitive().typed_value::(), Some(10)); - Ok(()) - } - - #[test] - fn sum_u8() -> VortexResult<()> { - let arr = PrimitiveArray::new(buffer![10u8, 20, 30], Validity::NonNullable).into_array(); - let result = run_sum(&arr)?; - assert_eq!(result.as_primitive().typed_value::(), Some(60)); - Ok(()) - } - - #[test] - fn sum_f64() -> VortexResult<()> { - let arr = - PrimitiveArray::new(buffer![1.5f64, 2.5, 3.0], Validity::NonNullable).into_array(); - let result = run_sum(&arr)?; - assert_eq!(result.as_primitive().typed_value::(), Some(7.0)); - Ok(()) - } - - #[test] - fn sum_with_nulls() -> VortexResult<()> { - let arr = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array(); - let result = run_sum(&arr)?; - assert_eq!(result.as_primitive().typed_value::(), Some(6)); - Ok(()) - } - - #[test] - fn sum_all_null() -> VortexResult<()> { - // Arrow semantics: sum of all nulls is zero (identity element) - let arr = PrimitiveArray::from_option_iter([None::, None, None]).into_array(); - let result = run_sum(&arr)?; - assert_eq!(result.as_primitive().typed_value::(), Some(0)); - Ok(()) - } - - // Empty accumulator tests - - #[test] - fn sum_empty_produces_zero() -> VortexResult<()> { - let dtype = 
DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; - let result = acc.finish()?; - assert_eq!(result.as_primitive().typed_value::(), Some(0)); - Ok(()) - } - - #[test] - fn sum_empty_f64_produces_zero() -> VortexResult<()> { - let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; - let result = acc.finish()?; - assert_eq!(result.as_primitive().typed_value::(), Some(0.0)); - Ok(()) - } - - // Multi-batch and reset tests - - #[test] - fn sum_multi_batch() -> VortexResult<()> { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; - - let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - acc.accumulate(&batch1)?; - - let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array(); - acc.accumulate(&batch2)?; - - let result = acc.finish()?; - assert_eq!(result.as_primitive().typed_value::(), Some(48)); - Ok(()) - } - - #[test] - fn sum_finish_resets_state() -> VortexResult<()> { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; - - let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - acc.accumulate(&batch1)?; - let result1 = acc.finish()?; - assert_eq!(result1.as_primitive().typed_value::(), Some(30)); - - let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array(); - acc.accumulate(&batch2)?; - let result2 = acc.finish()?; - assert_eq!(result2.as_primitive().typed_value::(), Some(18)); - Ok(()) - } - - // State merge tests (vtable-level) - - #[test] - fn sum_state_merge() -> VortexResult<()> { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut state = 
Sum.empty_partial(&EmptyOptions, &dtype)?; - - let scalar1 = Scalar::primitive(100i64, Nullability::Nullable); - Sum.combine_partials(&mut state, scalar1)?; - - let scalar2 = Scalar::primitive(50i64, Nullability::Nullable); - Sum.combine_partials(&mut state, scalar2)?; - - let result = Sum.flush(&mut state)?; - assert_eq!(result.as_primitive().typed_value::(), Some(150)); - Ok(()) - } - - // Overflow tests - - #[test] - fn sum_checked_overflow() -> VortexResult<()> { - let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); - let result = run_sum(&arr)?; - assert!(result.is_null()); - Ok(()) - } - - #[test] - fn sum_checked_overflow_is_saturated() -> VortexResult<()> { - let dtype = DType::Primitive(PType::I64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; - assert!(!acc.is_saturated()); - - let batch = - PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); - acc.accumulate(&batch)?; - assert!(acc.is_saturated()); - - // finish resets state, clearing saturation - drop(acc.finish()?); - assert!(!acc.is_saturated()); - Ok(()) - } - - // Boolean sum tests - - #[test] - fn sum_bool_all_true() -> VortexResult<()> { - let arr: BoolArray = [true, true, true].into_iter().collect(); - let result = run_sum(&arr.into_array())?; - assert_eq!(result.as_primitive().typed_value::(), Some(3)); - Ok(()) - } - - #[test] - fn sum_bool_mixed() -> VortexResult<()> { - let arr: BoolArray = [true, false, true, false, true].into_iter().collect(); - let result = run_sum(&arr.into_array())?; - assert_eq!(result.as_primitive().typed_value::(), Some(3)); - Ok(()) - } - - #[test] - fn sum_bool_all_false() -> VortexResult<()> { - let arr: BoolArray = [false, false, false].into_iter().collect(); - let result = run_sum(&arr.into_array())?; - assert_eq!(result.as_primitive().typed_value::(), Some(0)); - Ok(()) - } - - #[test] - fn sum_bool_with_nulls() -> VortexResult<()> { - 
let arr = BoolArray::from_iter([Some(true), None, Some(true), Some(false)]); - let result = run_sum(&arr.into_array())?; - assert_eq!(result.as_primitive().typed_value::(), Some(2)); - Ok(()) - } - - #[test] - fn sum_bool_all_null() -> VortexResult<()> { - // Arrow semantics: sum of all nulls is zero (identity element) - let arr = BoolArray::from_iter([None::, None, None]); - let result = run_sum(&arr.into_array())?; - assert_eq!(result.as_primitive().typed_value::(), Some(0)); - Ok(()) - } - - #[test] - fn sum_bool_empty_produces_zero() -> VortexResult<()> { - let dtype = DType::Bool(Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; - let result = acc.finish()?; - assert_eq!(result.as_primitive().typed_value::(), Some(0)); - Ok(()) - } - - #[test] - fn sum_bool_finish_resets_state() -> VortexResult<()> { - let dtype = DType::Bool(Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; - - let batch1: BoolArray = [true, true, false].into_iter().collect(); - acc.accumulate(&batch1.into_array())?; - let result1 = acc.finish()?; - assert_eq!(result1.as_primitive().typed_value::(), Some(2)); - - let batch2: BoolArray = [false, true].into_iter().collect(); - acc.accumulate(&batch2.into_array())?; - let result2 = acc.finish()?; - assert_eq!(result2.as_primitive().typed_value::(), Some(1)); - Ok(()) - } - - #[test] - fn sum_bool_return_dtype() -> VortexResult<()> { - let dtype = Sum.return_dtype(&EmptyOptions, &DType::Bool(Nullability::NonNullable))?; - assert_eq!(dtype, DType::Primitive(PType::U64, Nullability::Nullable)); - Ok(()) - } - - // Grouped sum tests - - fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = - GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone(), session())?; - acc.accumulate_list(groups)?; - acc.finish() - } - - #[test] - fn grouped_sum_fixed_size_list() -> VortexResult<()> { - // Groups: 
[[1,2,3], [4,5,6]] -> sums [6, 15] - let elements = - PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); - let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; - - let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array(); - assert_arrays_eq!(&result, &expected); - Ok(()) - } - - #[test] - fn grouped_sum_with_null_elements() -> VortexResult<()> { - // Groups: [[Some(1), None, Some(3)], [None, Some(5), Some(6)]] -> sums [4, 11] - let elements = - PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)]) - .into_array(); - let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; - - let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array(); - assert_arrays_eq!(&result, &expected); - Ok(()) - } - - #[test] - fn grouped_sum_with_null_group() -> VortexResult<()> { - // Groups: [[1,2,3], null, [7,8,9]] -> sums [6, null, 24] - let elements = - PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable) - .into_array(); - let validity = Validity::from_iter([true, false, true]); - let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; - - let expected = - PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array(); - assert_arrays_eq!(&result, &expected); - Ok(()) - } - - #[test] - fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> { - // Groups: [[None, None], [Some(3), Some(4)]] -> sums [0, 7] (Arrow 
semantics) - let elements = - PrimitiveArray::from_option_iter([None::, None, Some(3), Some(4)]).into_array(); - let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; - - let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array(); - assert_arrays_eq!(&result, &expected); - Ok(()) - } - - #[test] - fn grouped_sum_bool() -> VortexResult<()> { - // Groups: [[true, false, true], [true, true, true]] -> sums [2, 3] - let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect(); - let groups = - FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Bool(Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; - - let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array(); - assert_arrays_eq!(&result, &expected); - Ok(()) - } - - #[test] - fn grouped_sum_finish_resets() -> VortexResult<()> { - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype, session())?; - - // First batch: [[1, 2], [3, 4]] - let elements1 = - PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); - let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?; - acc.accumulate_list(&groups1.into_array())?; - let result1 = acc.finish()?; - - let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array(); - assert_arrays_eq!(&result1, &expected1); - - // Second batch after reset: [[10, 20]] - let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?; - 
acc.accumulate_list(&groups2.into_array())?; - let result2 = acc.finish()?; - - let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array(); - assert_arrays_eq!(&result2, &expected2); - Ok(()) - } -} diff --git a/vortex-array/src/aggregate_fn/fns/sum/bool.rs b/vortex-array/src/aggregate_fn/fns/sum/bool.rs new file mode 100644 index 00000000000..a2dad06fff6 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/bool.rs @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::BitAnd; + +use vortex_error::VortexResult; +use vortex_error::vortex_panic; +use vortex_mask::AllOr; + +use super::SumState; +use super::checked_add_u64; +use crate::arrays::BoolArray; + +pub(super) fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResult { + let SumState::Unsigned(acc) = inner else { + vortex_panic!("expected unsigned sum state for bool input"); + }; + + let mask = b.validity_mask()?; + let true_count = match mask.bit_buffer() { + AllOr::None => return Ok(false), + AllOr::All => b.to_bit_buffer().true_count() as u64, + AllOr::Some(validity) => b.to_bit_buffer().bitand(validity).true_count() as u64, + }; + + Ok(checked_add_u64(acc, true_count)) +} + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::sum::Sum; + use crate::aggregate_fn::fns::sum::sum; + use crate::arrays::BoolArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::executor::VortexSessionExecute; + + #[test] + fn sum_bool_all_true() -> VortexResult<()> { + let arr: BoolArray = [true, true, true].into_iter().collect(); + let result = sum( + &arr.into_array(), + &mut 
LEGACY_SESSION.create_execution_ctx(), + )?; + assert_eq!(result.as_primitive().typed_value::(), Some(3)); + Ok(()) + } + + #[test] + fn sum_bool_mixed() -> VortexResult<()> { + let arr: BoolArray = [true, false, true, false, true].into_iter().collect(); + let result = sum( + &arr.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + assert_eq!(result.as_primitive().typed_value::(), Some(3)); + Ok(()) + } + + #[test] + fn sum_bool_all_false() -> VortexResult<()> { + let arr: BoolArray = [false, false, false].into_iter().collect(); + let result = sum( + &arr.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + assert_eq!(result.as_primitive().typed_value::(), Some(0)); + Ok(()) + } + + #[test] + fn sum_bool_with_nulls() -> VortexResult<()> { + let arr = BoolArray::from_iter([Some(true), None, Some(true), Some(false)]); + let result = sum( + &arr.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + assert_eq!(result.as_primitive().typed_value::(), Some(2)); + Ok(()) + } + + #[test] + fn sum_bool_all_null() -> VortexResult<()> { + let arr = BoolArray::from_iter([None::, None, None]); + let result = sum( + &arr.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + assert_eq!(result.as_primitive().typed_value::(), Some(0)); + Ok(()) + } + + #[test] + fn sum_bool_empty_produces_zero() -> VortexResult<()> { + let dtype = DType::Bool(Nullability::NonNullable); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(0)); + Ok(()) + } + + #[test] + fn sum_bool_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Bool(Nullability::NonNullable); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + + let batch1: BoolArray = [true, true, false].into_iter().collect(); + acc.accumulate(&batch1.into_array(), &mut ctx)?; + let result1 = acc.finish()?; 
+ assert_eq!(result1.as_primitive().typed_value::(), Some(2)); + + let batch2: BoolArray = [false, true].into_iter().collect(); + acc.accumulate(&batch2.into_array(), &mut ctx)?; + let result2 = acc.finish()?; + assert_eq!(result2.as_primitive().typed_value::(), Some(1)); + Ok(()) + } + + #[test] + fn sum_bool_return_dtype() -> VortexResult<()> { + let dtype = Sum + .return_dtype(&EmptyOptions, &DType::Bool(Nullability::NonNullable)) + .unwrap(); + assert_eq!(dtype, DType::Primitive(PType::U64, Nullability::Nullable)); + Ok(()) + } + + #[test] + fn sum_boolean_from_iter() -> VortexResult<()> { + let arr = BoolArray::from_iter([true, false, false, true]).into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result.as_primitive().as_::(), Some(2)); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/sum/constant.rs b/vortex-array/src/aggregate_fn/fns/sum/constant.rs new file mode 100644 index 00000000000..c55b66a71cc --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/constant.rs @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; + +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::scalar::DecimalValue; +use crate::scalar::Scalar; + +/// Compute `scalar * len` for a constant array, returning the product as a sum-typed scalar. +/// +/// Returns `Ok(None)` if the scalar is null (no contribution to the sum). +/// Returns a null scalar on overflow (saturation). 
+pub(super) fn multiply_constant( + scalar: &Scalar, + len: usize, + return_dtype: &DType, +) -> VortexResult> { + if scalar.is_null() || len == 0 { + return Ok(None); + } + + let product = match scalar.dtype() { + DType::Bool(_) => { + let val = scalar + .as_bool() + .value() + .ok_or_else(|| vortex_err!("Expected non-null bool scalar for sum"))?; + if !val { + return Ok(None); + } + Scalar::primitive(len as u64, Nullability::Nullable) + } + DType::Primitive(..) => { + let pvalue = scalar + .as_primitive() + .pvalue() + .ok_or_else(|| vortex_err!("Expected non-null primitive scalar for sum"))?; + match return_dtype { + DType::Primitive(PType::U64, _) => { + let val = pvalue.cast::()?; + match val.checked_mul(len as u64) { + Some(product) => Scalar::primitive(product, Nullability::Nullable), + None => Scalar::null(return_dtype.as_nullable()), + } + } + DType::Primitive(PType::I64, _) => { + let val = pvalue.cast::()?; + match i64::try_from(len).ok().and_then(|l| val.checked_mul(l)) { + Some(product) => Scalar::primitive(product, Nullability::Nullable), + None => Scalar::null(return_dtype.as_nullable()), + } + } + DType::Primitive(PType::F64, _) => { + let val = pvalue.cast::()?; + Scalar::primitive(val * len as f64, Nullability::Nullable) + } + _ => vortex_bail!( + "Unexpected return dtype for primitive sum: {}", + return_dtype + ), + } + } + DType::Decimal(..) 
=> { + let val = scalar + .as_decimal() + .decimal_value() + .ok_or_else(|| vortex_err!("Expected non-null decimal scalar for sum"))?; + let len_decimal = DecimalValue::from(len as i128); + match val.checked_mul(&len_decimal) { + Some(product) => { + let ret_decimal = *return_dtype + .as_decimal_opt() + .ok_or_else(|| vortex_err!("Expected decimal return dtype"))?; + Scalar::decimal(product, ret_decimal, Nullability::Nullable) + } + None => Scalar::null(return_dtype.as_nullable()), + } + } + _ => vortex_bail!("Unsupported constant type for sum: {}", scalar.dtype()), + }; + + Ok(Some(product)) +} + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::fns::sum::sum; + use crate::arrays::ConstantArray; + use crate::dtype::DType; + use crate::dtype::DecimalDType; + use crate::dtype::Nullability; + use crate::dtype::Nullability::Nullable; + use crate::dtype::PType; + use crate::dtype::i256; + use crate::expr::stats::Stat; + use crate::scalar::DecimalValue; + use crate::scalar::Scalar; + + #[test] + fn sum_constant_unsigned() -> VortexResult<()> { + let array = ConstantArray::new(5u64, 10).into_array(); + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result, 50u64.into()); + Ok(()) + } + + #[test] + fn sum_constant_signed() -> VortexResult<()> { + let array = ConstantArray::new(-5i64, 10).into_array(); + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result, (-50i64).into()); + Ok(()) + } + + #[test] + fn sum_constant_nullable_value() -> VortexResult<()> { + let array = ConstantArray::new(Scalar::null(DType::Primitive(PType::U32, Nullable)), 10) + .into_array(); + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result, Scalar::primitive(0u64, Nullable)); + Ok(()) + } + + #[test] + fn sum_constant_bool_false() -> VortexResult<()> { + let 
array = ConstantArray::new(false, 10).into_array(); + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result, 0u64.into()); + Ok(()) + } + + #[test] + fn sum_constant_bool_true() -> VortexResult<()> { + let array = ConstantArray::new(true, 10).into_array(); + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result, 10u64.into()); + Ok(()) + } + + #[test] + fn sum_constant_bool_null() -> VortexResult<()> { + let array = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array(); + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result, Scalar::primitive(0u64, Nullable)); + Ok(()) + } + + #[test] + fn sum_constant_decimal() -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(10, 2); + let array = ConstantArray::new( + Scalar::decimal( + DecimalValue::I64(100), + decimal_dtype, + Nullability::NonNullable, + ), + 5, + ) + .into_array(); + + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; + + assert_eq!( + result.as_decimal().decimal_value(), + Some(DecimalValue::I256(i256::from_i128(500))) + ); + assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap()); + Ok(()) + } + + #[test] + fn sum_constant_decimal_null() -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(10, 2); + let array = ConstantArray::new(Scalar::null(DType::Decimal(decimal_dtype, Nullable)), 10) + .into_array(); + + let result = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!( + result, + Scalar::decimal( + DecimalValue::I256(i256::ZERO), + DecimalDType::new(20, 2), + Nullable + ) + ); + Ok(()) + } + + #[test] + fn sum_constant_decimal_large_value() -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(10, 2); + let array = ConstantArray::new( + Scalar::decimal( + DecimalValue::I64(999_999_999), + decimal_dtype, + Nullability::NonNullable, + ), + 100, + ) + .into_array(); + + let result = sum(&array, 
&mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!( + result.as_decimal().decimal_value(), + Some(DecimalValue::I256(i256::from_i128(99_999_999_900))) + ); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/sum/decimal.rs b/vortex-array/src/aggregate_fn/fns/sum/decimal.rs new file mode 100644 index 00000000000..fc388c57b49 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/decimal.rs @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use itertools::Itertools; +use vortex_error::VortexResult; +use vortex_error::vortex_panic; +use vortex_mask::AllOr; + +use super::SumState; +use crate::arrays::DecimalArray; +use crate::match_each_decimal_value_type; +use crate::scalar::DecimalValue; + +/// Accumulate a decimal array into the sum state. +/// Returns Ok(true) if saturated (overflow), Ok(false) if not. +pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult { + let SumState::Decimal(acc) = inner else { + vortex_panic!("expected decimal sum state for decimal input"); + }; + + let mask = d.validity_mask()?; + match mask.bit_buffer() { + AllOr::None => Ok(false), + AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| { + for &v in d.buffer::().iter() { + match acc.checked_add(&DecimalValue::from(v)) { + Some(r) => *acc = r, + None => return Ok(true), + } + } + Ok(false) + }), + AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| { + for (&v, valid) in d.buffer::().iter().zip_eq(validity.iter()) { + if valid { + match acc.checked_add(&DecimalValue::from(v)) { + Some(r) => *acc = r, + None => return Ok(true), + } + } + } + Ok(false) + }), + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexExpect; + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::fns::sum::sum; + use 
crate::arrays::DecimalArray; + use crate::dtype::DType; + use crate::dtype::DecimalDType; + use crate::dtype::Nullability; + use crate::dtype::Nullability::Nullable; + use crate::dtype::i256; + use crate::scalar::DecimalValue; + use crate::scalar::Scalar; + use crate::scalar::ScalarValue; + use crate::validity::Validity; + + #[test] + fn sum_decimal_basic() -> VortexResult<()> { + let decimal = DecimalArray::new( + buffer![100i32, 200i32, 300i32], + DecimalDType::new(4, 2), + Validity::AllValid, + ); + + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(600i32))), + )?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn sum_decimal_with_nulls() -> VortexResult<()> { + let decimal = DecimalArray::new( + buffer![100i32, 200i32, 300i32, 400i32], + DecimalDType::new(4, 2), + Validity::from_iter([true, false, true, true]), + ); + + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(14, 2), Nullable), + Some(ScalarValue::from(DecimalValue::from(800i32))), + )?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn sum_decimal_negative_values() -> VortexResult<()> { + let decimal = DecimalArray::new( + buffer![100i32, -200i32, 300i32, -50i32], + DecimalDType::new(4, 2), + Validity::AllValid, + ); + + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(150i32))), + )?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn sum_decimal_near_i32_max() -> VortexResult<()> { + let near_max = i32::MAX - 1000; + let decimal = DecimalArray::new( + 
buffer![near_max, 500i32, 400i32], + DecimalDType::new(10, 2), + Validity::AllValid, + ); + + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let expected_sum = near_max as i64 + 500 + 400; + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + )?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn sum_decimal_large_i64_values() -> VortexResult<()> { + let large_val = i64::MAX / 4; + let decimal = DecimalArray::new( + buffer![large_val, large_val, large_val, large_val + 1], + DecimalDType::new(19, 0), + Validity::AllValid, + ); + + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let expected_sum = (large_val as i128) * 4 + 1; + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + )?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn sum_decimal_preserves_scale() -> VortexResult<()> { + let decimal = DecimalArray::new( + buffer![12345i32, 67890i32, 11111i32], + DecimalDType::new(6, 4), + Validity::AllValid, + ); + + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(91346i32))), + )?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn sum_decimal_single_value() -> VortexResult<()> { + let decimal = + DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid); + + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable), + 
Some(ScalarValue::from(DecimalValue::from(42i32))), + )?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn sum_decimal_all_nulls_except_one() -> VortexResult<()> { + let decimal = DecimalArray::new( + buffer![100i32, 200i32, 300i32, 400i32], + DecimalDType::new(4, 2), + Validity::from_iter([false, false, true, false]), + ); + + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(14, 2), Nullable), + Some(ScalarValue::from(DecimalValue::from(300i32))), + )?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn sum_decimal_overflow_detection() -> VortexResult<()> { + let max_val = i128::MAX / 2; + let decimal = DecimalArray::new( + buffer![max_val, max_val, max_val], + DecimalDType::new(38, 0), + Validity::AllValid, + ); + + let result = sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let expected_sum = + i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val); + let expected = Scalar::try_new( + DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + )?; + + assert_eq!(result, expected); + Ok(()) + } + + #[test] + fn sum_decimal_i256_overflow() -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(76, 0); + let decimal = DecimalArray::new( + buffer![i256::MAX, i256::MAX, i256::MAX], + decimal_dtype, + Validity::AllValid, + ); + + assert_eq!( + sum( + &decimal.into_array(), + &mut LEGACY_SESSION.create_execution_ctx() + ) + .vortex_expect("operation should succeed in test"), + Scalar::null(DType::Decimal(decimal_dtype, Nullable)) + ); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs new file mode 100644 index 00000000000..c5b5c94630d --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -0,0 +1,793 @@ 
+// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod bool; +mod constant; +mod decimal; +mod primitive; + +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; +use vortex_error::vortex_panic; + +use self::bool::accumulate_bool; +use self::constant::multiply_constant; +use self::decimal::accumulate_decimal; +use self::primitive::accumulate_primitive; +use crate::ArrayRef; +use crate::Canonical; +use crate::Columnar; +use crate::ExecutionCtx; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::dtype::DType; +use crate::dtype::DecimalDType; +use crate::dtype::MAX_PRECISION; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::expr::stats::Precision; +use crate::expr::stats::Stat; +use crate::expr::stats::StatsProvider; +use crate::scalar::DecimalValue; +use crate::scalar::Scalar; + +/// Return the sum of an array. +/// +/// See [`Sum`] for details. +pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + // Short-circuit using cached array statistics. + if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) { + return Ok(sum_scalar); + } + + // Compute using Accumulator. + // TODO(ngates): we may want to wrap this three-step dance up into an extension crate maybe. + let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?; + acc.accumulate(array, ctx)?; + let result = acc.finish()?; + + // Cache the computed sum as a statistic (only if non-null, i.e. no overflow). + if let Some(val) = result.value().cloned() { + array.statistics().set(Stat::Sum, Precision::Exact(val)); + } + + Ok(result) +} + +/// Sum an array, starting from zero. +/// +/// If the sum overflows, a null scalar will be returned. 
+/// If the array is all-invalid, the sum will be zero. +#[derive(Clone, Debug)] +pub struct Sum; + +impl AggregateFnVTable for Sum { + type Options = EmptyOptions; + type Partial = SumPartial; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.sum") + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + // When a sum overflows, we return a sum _value_ of null. Therefore, we all return dtypes + // are nullable. + use Nullability::Nullable; + + Some(match input_dtype { + DType::Bool(_) => DType::Primitive(PType::U64, Nullable), + DType::Primitive(ptype, _) => match ptype { + PType::U8 | PType::U16 | PType::U32 | PType::U64 => { + DType::Primitive(PType::U64, Nullable) + } + PType::I8 | PType::I16 | PType::I32 | PType::I64 => { + DType::Primitive(PType::I64, Nullable) + } + PType::F16 | PType::F32 | PType::F64 => { + // Float sums cannot overflow, but all null floats still end up as null + DType::Primitive(PType::F64, Nullable) + } + }, + DType::Decimal(decimal_dtype, _) => { + // Both Spark and DataFusion use this heuristic. 
+ // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188 + let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10); + DType::Decimal( + DecimalDType::new(precision, decimal_dtype.scale()), + Nullable, + ) + } + // Unsupported types + _ => return None, + }) + } + + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { + self.return_dtype(options, input_dtype) + } + + fn empty_partial( + &self, + options: &Self::Options, + input_dtype: &DType, + ) -> VortexResult { + let return_dtype = self + .return_dtype(options, input_dtype) + .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?; + let initial = make_zero_state(&return_dtype); + + Ok(SumPartial { + return_dtype, + current: Some(initial), + }) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + if other.is_null() { + // A null partial means the sub-accumulator saturated (overflow). 
+ partial.current = None; + return Ok(()); + } + let Some(ref mut inner) = partial.current else { + return Ok(()); + }; + let saturated = match inner { + SumState::Unsigned(acc) => { + let val = other + .as_primitive() + .typed_value::() + .vortex_expect("checked non-null"); + checked_add_u64(acc, val) + } + SumState::Signed(acc) => { + let val = other + .as_primitive() + .typed_value::() + .vortex_expect("checked non-null"); + checked_add_i64(acc, val) + } + SumState::Float(acc) => { + let val = other + .as_primitive() + .typed_value::() + .vortex_expect("checked non-null"); + *acc += val; + false + } + SumState::Decimal(acc) => { + let val = other + .as_decimal() + .decimal_value() + .vortex_expect("checked non-null"); + match acc.checked_add(&val) { + Some(r) => { + *acc = r; + false + } + None => true, + } + } + }; + if saturated { + partial.current = None; + } + Ok(()) + } + + fn flush(&self, partial: &mut Self::Partial) -> VortexResult { + let result = match &partial.current { + None => Scalar::null(partial.return_dtype.as_nullable()), + Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable), + Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable), + Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable), + Some(SumState::Decimal(v)) => { + let decimal_dtype = *partial + .return_dtype + .as_decimal_opt() + .vortex_expect("return dtype must be decimal"); + Scalar::decimal(*v, decimal_dtype, Nullability::Nullable) + } + }; + + // Reset the state + partial.current = Some(make_zero_state(&partial.return_dtype)); + + Ok(result) + } + + #[inline] + fn is_saturated(&self, partial: &Self::Partial) -> bool { + partial.current.is_none() + } + + fn accumulate( + &self, + partial: &mut Self::Partial, + batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + // Constants compute scalar * len and combine via combine_partials. 
+ if let Columnar::Constant(c) = batch { + if let Some(product) = multiply_constant(c.scalar(), c.len(), &partial.return_dtype)? { + self.combine_partials(partial, product)?; + } + return Ok(()); + } + + let mut inner = match partial.current.take() { + Some(inner) => inner, + None => return Ok(()), + }; + + let result = match batch { + Columnar::Canonical(c) => match c { + Canonical::Primitive(p) => accumulate_primitive(&mut inner, p), + Canonical::Bool(b) => accumulate_bool(&mut inner, b), + Canonical::Decimal(d) => accumulate_decimal(&mut inner, d), + _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()), + }, + Columnar::Constant(_) => unreachable!(), + }; + + match result { + Ok(false) => partial.current = Some(inner), + Ok(true) => {} // saturated: current stays None + Err(e) => { + partial.current = Some(inner); + return Err(e); + } + } + Ok(()) + } + + fn finalize(&self, partials: ArrayRef) -> VortexResult { + Ok(partials) + } + + fn finalize_scalar(&self, partial: Scalar) -> VortexResult { + Ok(partial) + } +} + +/// The group state for a sum aggregate, containing the accumulated value and configuration +/// needed for reset/result without external context. +pub struct SumPartial { + return_dtype: DType, + /// The current accumulated state, or `None` if saturated (checked overflow). + current: Option, +} + +/// The accumulated sum value. +/// +// TODO(ngates): instead of an enum, we should use a Box to avoid dispatcher over the +// input type every time? Perhaps? 
+pub enum SumState { + Unsigned(u64), + Signed(i64), + Float(f64), + Decimal(DecimalValue), +} + +fn make_zero_state(return_dtype: &DType) -> SumState { + match return_dtype { + DType::Primitive(ptype, _) => match ptype { + PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0), + PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0), + PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0), + }, + DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)), + _ => vortex_panic!("Unsupported sum type"), + } +} + +/// Checked add for u64, returning true if overflow occurred. +#[inline(always)] +fn checked_add_u64(acc: &mut u64, val: u64) -> bool { + match acc.checked_add(val) { + Some(r) => { + *acc = r; + false + } + None => true, + } +} + +/// Checked add for i64, returning true if overflow occurred. +#[inline(always)] +fn checked_add_i64(acc: &mut i64, val: i64) -> bool { + match acc.checked_add(val) { + Some(r) => { + *acc = r; + false + } + None => true, + } +} + +#[cfg(test)] +mod tests { + use num_traits::CheckedAdd; + use vortex_buffer::buffer; + use vortex_error::VortexExpect; + use vortex_error::VortexResult; + + use crate::ArrayRef; + use crate::DynArray; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::DynGroupedAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupedAccumulator; + use crate::aggregate_fn::fns::sum::Sum; + use crate::aggregate_fn::fns::sum::sum; + use crate::arrays::BoolArray; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::DecimalArray; + use crate::arrays::FixedSizeListArray; + use crate::arrays::PrimitiveArray; + use crate::assert_arrays_eq; + use crate::dtype::DType; + use crate::dtype::DecimalDType; + use 
crate::dtype::Nullability; + use crate::dtype::Nullability::Nullable; + use crate::dtype::PType; + use crate::dtype::i256; + use crate::expr::stats::Precision; + use crate::expr::stats::Stat; + use crate::expr::stats::StatsProvider; + use crate::scalar::DecimalValue; + use crate::scalar::NumericOperator; + use crate::scalar::Scalar; + use crate::validity::Validity; + + /// Sum an array with an initial value (test-only helper). + fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + if accumulator.is_null() { + return Ok(accumulator.clone()); + } + if accumulator.is_zero() == Some(true) { + return sum(array, &mut ctx); + } + + let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| { + vortex_error::vortex_err!("Sum not supported for dtype: {}", array.dtype()) + })?; + + // For non-float types, try statistics short-circuit with accumulator. + if !matches!(&sum_dtype, DType::Primitive(p, _) if p.is_float()) + && let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) + { + return add_scalars(&sum_dtype, &sum_scalar, accumulator); + } + + // Compute array sum from zero (also caches stats). + let array_sum = sum(array, &mut ctx)?; + + // Combine with the accumulator. + add_scalars(&sum_dtype, &array_sum, accumulator) + } + + /// Add two sum scalars with overflow checking. + fn add_scalars(sum_dtype: &DType, lhs: &Scalar, rhs: &Scalar) -> VortexResult { + if lhs.is_null() || rhs.is_null() { + return Ok(Scalar::null(sum_dtype.as_nullable())); + } + + Ok(match sum_dtype { + DType::Primitive(ptype, _) if ptype.is_float() => { + let lhs_val = f64::try_from(lhs)?; + let rhs_val = f64::try_from(rhs)?; + Scalar::primitive(lhs_val + rhs_val, Nullable) + } + DType::Primitive(..) => lhs + .as_primitive() + .checked_add(&rhs.as_primitive()) + .map(Scalar::from) + .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())), + DType::Decimal(..) 
=> lhs + .as_decimal() + .checked_binary_numeric(&rhs.as_decimal(), NumericOperator::Add) + .map(Scalar::from) + .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())), + _ => unreachable!("Sum will always be a decimal or a primitive dtype"), + }) + } + + // Multi-batch and reset tests + + #[test] + fn sum_multi_batch() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + + let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(48)); + Ok(()) + } + + #[test] + fn sum_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + let result1 = acc.finish()?; + assert_eq!(result1.as_primitive().typed_value::(), Some(30)); + + let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + let result2 = acc.finish()?; + assert_eq!(result2.as_primitive().typed_value::(), Some(18)); + Ok(()) + } + + // State merge tests (vtable-level) + + #[test] + fn sum_state_merge() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?; + + let scalar1 = Scalar::primitive(100i64, Nullable); + Sum.combine_partials(&mut state, scalar1)?; + + let scalar2 = 
Scalar::primitive(50i64, Nullable); + Sum.combine_partials(&mut state, scalar2)?; + + let result = Sum.flush(&mut state)?; + assert_eq!(result.as_primitive().typed_value::(), Some(150)); + Ok(()) + } + + // Stats caching test + + #[test] + fn sum_stats() -> VortexResult<()> { + let array = ChunkedArray::try_new( + vec![ + PrimitiveArray::from_iter([1, 1, 1]).into_array(), + PrimitiveArray::from_iter([2, 2, 2]).into_array(), + ], + DType::Primitive(PType::I32, Nullability::NonNullable), + ) + .vortex_expect("operation should succeed in test"); + let array = array.into_array(); + // compute sum with accumulator to populate stats + sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable))?; + + let sum_without_acc = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(sum_without_acc, Scalar::primitive(9i64, Nullable)); + Ok(()) + } + + // Constant float non-multiply test + + #[test] + fn sum_constant_float_non_multiply() -> VortexResult<()> { + let acc = -2048669276050936500000000000f64; + let array = ConstantArray::new(6.1811675e16f64, 25); + let result = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable)) + .vortex_expect("operation should succeed in test"); + assert_eq!( + f64::try_from(&result).vortex_expect("operation should succeed in test"), + -2048669274505644600000000000f64 + ); + Ok(()) + } + + // Grouped sum tests + + fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; + acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; + acc.finish() + } + + #[test] + fn grouped_sum_fixed_size_list() -> VortexResult<()> { + let elements = + PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); + let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; + + let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let 
result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + + let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn grouped_sum_with_null_elements() -> VortexResult<()> { + let elements = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)]) + .into_array(); + let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; + + let elem_dtype = DType::Primitive(PType::I32, Nullable); + let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + + let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn grouped_sum_with_null_group() -> VortexResult<()> { + let elements = + PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable) + .into_array(); + let validity = Validity::from_iter([true, false, true]); + let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?; + + let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + + let expected = + PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> { + let elements = + PrimitiveArray::from_option_iter([None::, None, Some(3), Some(4)]).into_array(); + let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?; + + let elem_dtype = DType::Primitive(PType::I32, Nullable); + let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + + let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn grouped_sum_bool() -> VortexResult<()> { + let elements: 
BoolArray = [true, false, true, true, true, true].into_iter().collect(); + let groups = + FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?; + + let elem_dtype = DType::Bool(Nullability::NonNullable); + let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + + let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn grouped_sum_finish_resets() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?; + + let elements1 = + PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); + let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?; + acc.accumulate_list(&groups1.into_array(), &mut ctx)?; + let result1 = acc.finish()?; + + let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array(); + assert_arrays_eq!(&result1, &expected1); + + let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?; + acc.accumulate_list(&groups2.into_array(), &mut ctx)?; + let result2 = acc.finish()?; + + let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array(); + assert_arrays_eq!(&result2, &expected2); + Ok(()) + } + + // Chunked array tests + + #[test] + fn sum_chunked_floats_with_nulls() -> VortexResult<()> { + let chunk1 = + PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]); + let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]); + let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + 
vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + )?; + + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + assert_eq!(result.as_primitive().as_::(), Some(20.8)); + Ok(()) + } + + #[test] + fn sum_chunked_floats_all_nulls_is_zero() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter::(vec![None, None, None]); + let chunk2 = PrimitiveArray::from_option_iter::(vec![None, None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + assert_eq!(result, Scalar::primitive(0f64, Nullable)); + Ok(()) + } + + #[test] + fn sum_chunked_floats_empty_chunks() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]); + let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullable), 0); + let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + )?; + + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + assert_eq!(result.as_primitive().as_::(), Some(36.0)); + Ok(()) + } + + #[test] + fn sum_chunked_int_almost_all_null() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter::(vec![Some(1)]); + let chunk2 = PrimitiveArray::from_option_iter::(vec![None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + assert_eq!(result.as_primitive().as_::(), Some(1)); + Ok(()) + } + + #[test] + fn sum_chunked_decimals() -> VortexResult<()> { + let 
decimal_dtype = DecimalDType::new(10, 2); + let chunk1 = DecimalArray::new( + buffer![100i32, 100i32, 100i32, 100i32, 100i32], + decimal_dtype, + Validity::AllValid, + ); + let chunk2 = DecimalArray::new( + buffer![200i32, 200i32, 200i32], + decimal_dtype, + Validity::AllValid, + ); + let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + )?; + + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + let decimal_result = result.as_decimal(); + assert_eq!( + decimal_result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(1700))) + ); + Ok(()) + } + + #[test] + fn sum_chunked_decimals_with_nulls() -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(10, 2); + let chunk1 = DecimalArray::new( + buffer![100i32, 100i32, 100i32], + decimal_dtype, + Validity::AllValid, + ); + let chunk2 = DecimalArray::new( + buffer![0i32, 0i32], + decimal_dtype, + Validity::from_iter([false, false]), + ); + let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + )?; + + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + let decimal_result = result.as_decimal(); + assert_eq!( + decimal_result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(700))) + ); + Ok(()) + } + + #[test] + fn sum_chunked_decimals_large() -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(3, 0); + let chunk1 = ConstantArray::new( + Scalar::decimal( + DecimalValue::I16(500), + decimal_dtype, + Nullability::NonNullable, + ), + 1, + ); + let chunk2 = ConstantArray::new( + Scalar::decimal( + 
DecimalValue::I16(600), + decimal_dtype, + Nullability::NonNullable, + ), + 1, + ); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + + let result = sum( + &chunked.into_array(), + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + let decimal_result = result.as_decimal(); + assert_eq!( + decimal_result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(1100))) + ); + assert_eq!( + result.dtype(), + &DType::Decimal(DecimalDType::new(13, 0), Nullable) + ); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs new file mode 100644 index 00000000000..292711f95bf --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use itertools::Itertools; +use num_traits::ToPrimitive; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_panic; +use vortex_mask::AllOr; + +use super::SumState; +use super::checked_add_i64; +use super::checked_add_u64; +use crate::arrays::PrimitiveArray; +use crate::match_each_native_ptype; + +pub(super) fn accumulate_primitive(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult { + let mask = p.validity_mask()?; + match mask.bit_buffer() { + AllOr::None => Ok(false), + AllOr::All => accumulate_primitive_all(inner, p), + AllOr::Some(validity) => accumulate_primitive_valid(inner, p, validity), + } +} + +fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult { + match inner { + SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(), + unsigned: |T| { + for &v in p.as_slice::() { + if checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) { + return Ok(true); + } + } + Ok(false) + }, + signed: |_T| { vortex_panic!("unsigned sum state with signed input") }, 
+ floating: |_T| { vortex_panic!("unsigned sum state with float input") } + ), + SumState::Signed(acc) => match_each_native_ptype!(p.ptype(), + unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") }, + signed: |T| { + for &v in p.as_slice::() { + if checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) { + return Ok(true); + } + } + Ok(false) + }, + floating: |_T| { vortex_panic!("signed sum state with float input") } + ), + SumState::Float(acc) => match_each_native_ptype!(p.ptype(), + unsigned: |_T| { vortex_panic!("float sum state with unsigned input") }, + signed: |_T| { vortex_panic!("float sum state with signed input") }, + floating: |T| { + for &v in p.as_slice::() { + *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); + } + Ok(false) + } + ), + SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"), + } +} + +fn accumulate_primitive_valid( + inner: &mut SumState, + p: &PrimitiveArray, + validity: &vortex_buffer::BitBuffer, +) -> VortexResult { + match inner { + SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(), + unsigned: |T| { + for (&v, valid) in p.as_slice::().iter().zip_eq(validity.iter()) { + if valid && checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) { + return Ok(true); + } + } + Ok(false) + }, + signed: |_T| { vortex_panic!("unsigned sum state with signed input") }, + floating: |_T| { vortex_panic!("unsigned sum state with float input") } + ), + SumState::Signed(acc) => match_each_native_ptype!(p.ptype(), + unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") }, + signed: |T| { + for (&v, valid) in p.as_slice::().iter().zip_eq(validity.iter()) { + if valid && checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) { + return Ok(true); + } + } + Ok(false) + }, + floating: |_T| { vortex_panic!("signed sum state with float input") } + ), + SumState::Float(acc) => match_each_native_ptype!(p.ptype(), + unsigned: |_T| { 
vortex_panic!("float sum state with unsigned input") }, + signed: |_T| { vortex_panic!("float sum state with signed input") }, + floating: |T| { + for (&v, valid) in p.as_slice::().iter().zip_eq(validity.iter()) { + if valid { + *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); + } + } + Ok(false) + } + ), + SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"), + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::sum::Sum; + use crate::aggregate_fn::fns::sum::sum; + use crate::arrays::PrimitiveArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::Nullability::Nullable; + use crate::dtype::PType; + use crate::scalar::Scalar; + use crate::validity::Validity; + + #[test] + fn sum_i32() -> VortexResult<()> { + let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result.as_primitive().typed_value::(), Some(10)); + Ok(()) + } + + #[test] + fn sum_u8() -> VortexResult<()> { + let arr = PrimitiveArray::new(buffer![10u8, 20, 30], Validity::NonNullable).into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result.as_primitive().typed_value::(), Some(60)); + Ok(()) + } + + #[test] + fn sum_f64() -> VortexResult<()> { + let arr = + PrimitiveArray::new(buffer![1.5f64, 2.5, 3.0], Validity::NonNullable).into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result.as_primitive().typed_value::(), Some(7.0)); + Ok(()) + } + + #[test] + fn sum_with_nulls() -> VortexResult<()> { + let arr = 
PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result.as_primitive().typed_value::(), Some(6)); + Ok(()) + } + + #[test] + fn sum_all_null() -> VortexResult<()> { + let arr = PrimitiveArray::from_option_iter([None::, None, None]).into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result.as_primitive().typed_value::(), Some(0)); + Ok(()) + } + + #[test] + fn sum_all_invalid_float() -> VortexResult<()> { + let arr = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result, Scalar::primitive(0f64, Nullable)); + Ok(()) + } + + #[test] + fn sum_buffer_i32() -> VortexResult<()> { + let arr = buffer![1, 1, 1, 1].into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result.as_primitive().as_::(), Some(4)); + Ok(()) + } + + #[test] + fn sum_buffer_f64() -> VortexResult<()> { + let arr = buffer![1., 1., 1., 1.].into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result.as_primitive().as_::(), Some(4.)); + Ok(()) + } + + #[test] + fn sum_empty_produces_zero() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(0)); + Ok(()) + } + + #[test] + fn sum_empty_f64_produces_zero() -> VortexResult<()> { + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(0.0)); + Ok(()) + } + + #[test] + fn sum_checked_overflow() -> VortexResult<()> { + let arr = 
PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert!(result.is_null()); + Ok(()) + } + + #[test] + fn sum_checked_overflow_is_saturated() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + assert!(!acc.is_saturated()); + + let batch = + PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); + acc.accumulate(&batch, &mut LEGACY_SESSION.create_execution_ctx())?; + assert!(acc.is_saturated()); + + // finish resets state, clearing saturation + drop(acc.finish()?); + assert!(!acc.is_saturated()); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/typed.rs b/vortex-array/src/aggregate_fn/typed.rs index f8521738c69..3d1d4a8d15a 100644 --- a/vortex-array/src/aggregate_fn/typed.rs +++ b/vortex-array/src/aggregate_fn/typed.rs @@ -19,7 +19,6 @@ use std::hash::Hasher; use std::sync::Arc; use vortex_error::VortexResult; -use vortex_session::VortexSession; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AccumulatorRef; @@ -40,18 +39,10 @@ pub(super) trait DynAggregateFn: 'static + Send + Sync + super::sealed::Sealed { fn options_any(&self) -> &dyn Any; fn coerce_args(&self, input_dtype: &DType) -> VortexResult; - fn return_dtype(&self, input_dtype: &DType) -> VortexResult; - fn state_dtype(&self, input_dtype: &DType) -> VortexResult; - fn accumulator( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult; - fn accumulator_grouped( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult; + fn return_dtype(&self, input_dtype: &DType) -> Option; + fn state_dtype(&self, input_dtype: &DType) -> Option; + fn accumulator(&self, input_dtype: &DType) -> VortexResult; + fn accumulator_grouped(&self, input_dtype: &DType) -> VortexResult; fn options_serialize(&self) -> 
VortexResult>>; fn options_eq(&self, other_options: &dyn Any) -> bool; @@ -89,37 +80,27 @@ impl DynAggregateFn for AggregateFnInner { V::coerce_args(&self.vtable, &self.options, input_dtype) } - fn return_dtype(&self, input_dtype: &DType) -> VortexResult { + fn return_dtype(&self, input_dtype: &DType) -> Option { V::return_dtype(&self.vtable, &self.options, input_dtype) } - fn state_dtype(&self, input_dtype: &DType) -> VortexResult { + fn state_dtype(&self, input_dtype: &DType) -> Option { V::partial_dtype(&self.vtable, &self.options, input_dtype) } - fn accumulator( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult { + fn accumulator(&self, input_dtype: &DType) -> VortexResult { Ok(Box::new(Accumulator::try_new( self.vtable.clone(), self.options.clone(), input_dtype.clone(), - session.clone(), )?)) } - fn accumulator_grouped( - &self, - input_dtype: &DType, - session: &VortexSession, - ) -> VortexResult { + fn accumulator_grouped(&self, input_dtype: &DType) -> VortexResult { Ok(Box::new(GroupedAccumulator::try_new( self.vtable.clone(), self.options.clone(), input_dtype.clone(), - session.clone(), )?)) } diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 64db588d336..3843ad29383 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -70,13 +70,17 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { } /// The return [`DType`] of the aggregate. - fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult; + /// + /// Returns `None` if the aggregate function cannot be applied to the input dtype. + fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option; /// DType of the intermediate partial accumulator state. /// /// Use a struct dtype when multiple fields are needed /// (e.g., Mean: `Struct { sum: f64, count: u64 }`). 
- fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult; + /// + /// Returns `None` if the aggregate function cannot be applied to the input dtype. + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option; /// Return the partial accumulator state for an empty group. fn empty_partial( diff --git a/vortex-array/src/array/mod.rs b/vortex-array/src/array/mod.rs index 943ff246710..79eb980d317 100644 --- a/vortex-array/src/array/mod.rs +++ b/vortex-array/src/array/mod.rs @@ -31,6 +31,7 @@ use crate::ExecutionCtx; use crate::LEGACY_SESSION; use crate::ToCanonical; use crate::VortexSessionExecute; +use crate::aggregate_fn::fns::sum::sum; use crate::arrays::Bool; use crate::arrays::Constant; use crate::arrays::DictArray; @@ -43,7 +44,6 @@ use crate::arrays::VarBin; use crate::arrays::VarBinView; use crate::buffer::BufferHandle; use crate::builders::ArrayBuilder; -use crate::compute; use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::stats::Precision; @@ -567,6 +567,7 @@ impl DynArray for ArrayAdapter { } } + // TODO(ngates): deprecate this function since it requires compute. fn valid_count(&self) -> VortexResult { if let Some(Precision::Exact(invalid_count)) = self.statistics().get_as::(Stat::NullCount) @@ -578,8 +579,10 @@ impl DynArray for ArrayAdapter { Validity::NonNullable | Validity::AllValid => self.len(), Validity::AllInvalid => 0, Validity::Array(a) => { - let sum = compute::sum(&a)?; - sum.as_primitive() + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let array_sum = sum(&a, &mut ctx)?; + array_sum + .as_primitive() .as_::() .ok_or_else(|| vortex_err!("sum of validity array is null"))? 
} diff --git a/vortex-array/src/arrays/bool/compute/mod.rs b/vortex-array/src/arrays/bool/compute/mod.rs index 248ea617f2f..0799b77ad0e 100644 --- a/vortex-array/src/arrays/bool/compute/mod.rs +++ b/vortex-array/src/arrays/bool/compute/mod.rs @@ -10,7 +10,6 @@ mod mask; mod min_max; pub mod rules; mod slice; -mod sum; mod take; #[cfg(test)] diff --git a/vortex-array/src/arrays/bool/compute/sum.rs b/vortex-array/src/arrays/bool/compute/sum.rs deleted file mode 100644 index 531e25f2e90..00000000000 --- a/vortex-array/src/arrays/bool/compute/sum.rs +++ /dev/null @@ -1,46 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::ops::BitAnd; - -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_mask::AllOr; - -use crate::arrays::Bool; -use crate::arrays::BoolArray; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::dtype::Nullability; -use crate::register_kernel; -use crate::scalar::Scalar; - -impl SumKernel for Bool { - fn sum(&self, array: &BoolArray, accumulator: &Scalar) -> VortexResult { - let true_count: Option = match array.validity_mask()?.bit_buffer() { - AllOr::All => { - // All-valid - Some(array.to_bit_buffer().true_count() as u64) - } - AllOr::None => { - // All-invalid - unreachable!("All-invalid boolean array should have been handled by entry-point") - } - AllOr::Some(validity_mask) => { - Some(array.to_bit_buffer().bitand(validity_mask).true_count() as u64) - } - }; - - let acc_value = accumulator - .as_primitive() - .as_::() - .vortex_expect("cannot be null"); - let result = true_count.and_then(|tc| acc_value.checked_add(tc)); - Ok(match result { - Some(v) => Scalar::primitive(v, Nullability::Nullable), - None => Scalar::null_native::(), - }) - } -} - -register_kernel!(SumKernelAdapter(Bool).lift()); diff --git a/vortex-array/src/arrays/chunked/compute/aggregate.rs b/vortex-array/src/arrays/chunked/compute/aggregate.rs index 
243fb8b1a6b..86965f45072 100644 --- a/vortex-array/src/arrays/chunked/compute/aggregate.rs +++ b/vortex-array/src/arrays/chunked/compute/aggregate.rs @@ -24,9 +24,9 @@ impl DynAggregateKernel for ChunkedArrayAggregate { return Ok(None); }; - let mut acc = aggregate_fn.accumulator(chunked.dtype(), ctx.session())?; + let mut acc = aggregate_fn.accumulator(chunked.dtype())?; for chunk in chunked.chunks() { - acc.accumulate(chunk)?; + acc.accumulate(chunk, ctx)?; } Ok(Some(acc.finish()?)) } @@ -37,9 +37,10 @@ mod tests { use vortex_buffer::Buffer; use vortex_buffer::buffer; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; @@ -52,13 +53,10 @@ mod tests { use crate::dtype::PType; use crate::scalar::Scalar; - fn session() -> VortexSession { - VortexSession::empty() - } - fn run_sum(batch: &crate::ArrayRef) -> VortexResult { - let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone(), session())?; - acc.accumulate(batch)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone())?; + acc.accumulate(batch, &mut ctx)?; acc.finish() } diff --git a/vortex-array/src/arrays/chunked/compute/mod.rs b/vortex-array/src/arrays/chunked/compute/mod.rs index 791e261745c..93b95b5ce9e 100644 --- a/vortex-array/src/arrays/chunked/compute/mod.rs +++ b/vortex-array/src/arrays/chunked/compute/mod.rs @@ -12,7 +12,6 @@ mod mask; mod min_max; pub(crate) mod rules; mod slice; -mod sum; mod take; mod zip; diff --git a/vortex-array/src/arrays/chunked/compute/sum.rs b/vortex-array/src/arrays/chunked/compute/sum.rs deleted file mode 100644 index 0078b9a8e9a..00000000000 --- a/vortex-array/src/arrays/chunked/compute/sum.rs +++ /dev/null @@ -1,236 +0,0 @@ -// SPDX-License-Identifier: 
Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::arrays::Chunked; -use crate::arrays::ChunkedArray; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::compute::sum_with_accumulator; -use crate::register_kernel; -use crate::scalar::Scalar; - -impl SumKernel for Chunked { - fn sum(&self, array: &ChunkedArray, accumulator: &Scalar) -> VortexResult { - array - .chunks - .iter() - .try_fold(accumulator.clone(), |result, chunk| { - sum_with_accumulator(chunk, &result) - }) - } -} - -register_kernel!(SumKernelAdapter(Chunked).lift()); - -#[cfg(test)] -mod tests { - use vortex_buffer::buffer; - - use crate::array::IntoArray; - use crate::arrays::ChunkedArray; - use crate::arrays::ConstantArray; - use crate::arrays::DecimalArray; - use crate::arrays::PrimitiveArray; - use crate::compute::sum; - use crate::dtype::DType; - use crate::dtype::DecimalDType; - use crate::dtype::Nullability; - use crate::dtype::i256; - use crate::scalar::DecimalValue; - use crate::scalar::Scalar; - use crate::validity::Validity; - - #[test] - fn test_sum_chunked_floats_with_nulls() { - // Create chunks with floats including nulls - let chunk1 = - PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]); - - let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]); - - let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]); - - // Create chunked array from the chunks - let dtype = chunk1.dtype().clone(); - let chunked = ChunkedArray::try_new( - vec![ - chunk1.into_array(), - chunk2.into_array(), - chunk3.into_array(), - ], - dtype, - ) - .unwrap(); - - // Compute sum - let result = sum(&chunked.into_array()).unwrap(); - - // Expected sum: 1.5 + 3.2 + 4.8 + 2.1 + 5.7 + 1.0 + 2.5 = 20.8 - assert_eq!(result.as_primitive().as_::(), Some(20.8)); - } - - #[test] - fn test_sum_chunked_floats_all_nulls_is_zero() { - // 
Create chunks with all nulls - let chunk1 = PrimitiveArray::from_option_iter::(vec![None, None, None]); - let chunk2 = PrimitiveArray::from_option_iter::(vec![None, None]); - - let dtype = chunk1.dtype().clone(); - let chunked = - ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); - // Compute sum - should return null for all nulls - let result = sum(&chunked.into_array()).unwrap(); - assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable)); - } - - #[test] - fn test_sum_chunked_floats_empty_chunks() { - // Test with some empty chunks mixed with non-empty - let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]); - let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullability::Nullable), 0); - let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]); - - let dtype = chunk1.dtype().clone(); - let chunked = ChunkedArray::try_new( - vec![ - chunk1.into_array(), - chunk2.into_array(), - chunk3.into_array(), - ], - dtype, - ) - .unwrap(); - - // Compute sum: 10.5 + 20.3 + 5.2 = 36.0 - let result = sum(&chunked.into_array()).unwrap(); - assert_eq!(result.as_primitive().as_::(), Some(36.0)); - } - - #[test] - fn test_sum_chunked_int_almost_all_null_chunks() { - let chunk1 = PrimitiveArray::from_option_iter::(vec![Some(1)]); - let chunk2 = PrimitiveArray::from_option_iter::(vec![None]); - - let dtype = chunk1.dtype().clone(); - let chunked = - ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); - - let result = sum(&chunked.into_array()).unwrap(); - assert_eq!(result.as_primitive().as_::(), Some(1)); - } - - #[test] - fn test_sum_chunked_decimals() { - // Create decimal chunks with precision=10, scale=2 - let decimal_dtype = DecimalDType::new(10, 2); - let chunk1 = DecimalArray::new( - buffer![100i32, 100i32, 100i32, 100i32, 100i32], - decimal_dtype, - Validity::AllValid, - ); - let chunk2 = DecimalArray::new( - buffer![200i32, 200i32, 200i32], - 
decimal_dtype, - Validity::AllValid, - ); - let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid); - - let dtype = chunk1.dtype().clone(); - let chunked = ChunkedArray::try_new( - vec![ - chunk1.into_array(), - chunk2.into_array(), - chunk3.into_array(), - ], - dtype, - ) - .unwrap(); - - // Compute sum: 5*100 + 3*200 + 2*300 = 500 + 600 + 600 = 1700 (represents 17.00) - let result = sum(&chunked.into_array()).unwrap(); - let decimal_result = result.as_decimal(); - assert_eq!( - decimal_result.decimal_value(), - Some(DecimalValue::I256(i256::from_i128(1700))) - ); - } - - #[test] - fn test_sum_chunked_decimals_with_nulls() { - let decimal_dtype = DecimalDType::new(10, 2); - - // Create chunks with some nulls - all must have same nullability - let chunk1 = DecimalArray::new( - buffer![100i32, 100i32, 100i32], - decimal_dtype, - Validity::AllValid, - ); - let chunk2 = DecimalArray::new( - buffer![0i32, 0i32], - decimal_dtype, - Validity::from_iter([false, false]), - ); - let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid); - - let dtype = chunk1.dtype().clone(); - let chunked = ChunkedArray::try_new( - vec![ - chunk1.into_array(), - chunk2.into_array(), - chunk3.into_array(), - ], - dtype, - ) - .unwrap(); - - // Compute sum: 3*100 + 2*200 = 300 + 400 = 700 (nulls ignored) - let result = sum(&chunked.into_array()).unwrap(); - let decimal_result = result.as_decimal(); - assert_eq!( - decimal_result.decimal_value(), - Some(DecimalValue::I256(i256::from_i128(700))) - ); - } - - #[test] - fn test_sum_chunked_decimals_large() { - // Create decimals with precision 3 (max value 999) - // Sum will be 500 + 600 = 1100, which fits in result precision 13 (3+10) - let decimal_dtype = DecimalDType::new(3, 0); - let chunk1 = ConstantArray::new( - Scalar::decimal( - DecimalValue::I16(500), - decimal_dtype, - Nullability::NonNullable, - ), - 1, - ); - let chunk2 = ConstantArray::new( - Scalar::decimal( - 
DecimalValue::I16(600), - decimal_dtype, - Nullability::NonNullable, - ), - 1, - ); - - let dtype = chunk1.dtype().clone(); - let chunked = - ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); - - // Compute sum: 500 + 600 = 1100 - // Result should have precision 13 (3+10), scale 0 - let result = sum(&chunked.into_array()).unwrap(); - let decimal_result = result.as_decimal(); - assert_eq!( - decimal_result.decimal_value(), - Some(DecimalValue::I256(i256::from_i128(1100))) - ); - assert_eq!( - result.dtype(), - &DType::Decimal(DecimalDType::new(13, 0), Nullability::Nullable) - ); - } -} diff --git a/vortex-array/src/arrays/constant/compute/mod.rs b/vortex-array/src/arrays/constant/compute/mod.rs index 35578829912..f671dcde3d3 100644 --- a/vortex-array/src/arrays/constant/compute/mod.rs +++ b/vortex-array/src/arrays/constant/compute/mod.rs @@ -9,7 +9,6 @@ mod min_max; mod not; pub(crate) mod rules; mod slice; -mod sum; mod take; #[cfg(test)] diff --git a/vortex-array/src/arrays/constant/compute/sum.rs b/vortex-array/src/arrays/constant/compute/sum.rs deleted file mode 100644 index 5ff2825ed70..00000000000 --- a/vortex-array/src/arrays/constant/compute/sum.rs +++ /dev/null @@ -1,302 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use num_traits::AsPrimitive; -use num_traits::CheckedAdd; -use num_traits::CheckedMul; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; - -use crate::arrays::Constant; -use crate::arrays::ConstantArray; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::dtype::DType; -use crate::dtype::DecimalDType; -use crate::dtype::NativePType; -use crate::dtype::Nullability; -use crate::dtype::i256; -use crate::expr::stats::Stat; -use crate::match_each_native_ptype; -use crate::register_kernel; -use crate::scalar::DecimalScalar; -use 
crate::scalar::DecimalValue; -use crate::scalar::PrimitiveScalar; -use crate::scalar::Scalar; -use crate::scalar::ScalarValue; - -impl SumKernel for Constant { - fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult { - // Compute the expected dtype of the sum. - let sum_dtype = Stat::Sum - .dtype(array.dtype()) - .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?; - - let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?; - Scalar::try_new(sum_dtype, sum_value) - } -} - -fn sum_scalar( - scalar: &Scalar, - len: usize, - accumulator: &Scalar, -) -> VortexResult> { - match scalar.dtype() { - DType::Bool(_) => { - let count = match scalar.as_bool().value() { - None => unreachable!("Handled before reaching this point"), - Some(false) => 0u64, - Some(true) => len as u64, - }; - let accumulator = accumulator - .as_primitive() - .as_::() - .vortex_expect("cannot be null"); - Ok(accumulator - .checked_add(count) - .map(|v| ScalarValue::Primitive(v.into()))) - } - DType::Primitive(ptype, _) => { - #[expect(dead_code, reason = "TODO(connor): good question")] - let result = match_each_native_ptype!( - ptype, - unsigned: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }, - signed: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }, - floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) } - ); - Ok(result) - } - DType::Decimal(decimal_dtype, _) => { - sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator) - } - DType::Extension(_) => { - sum_scalar(&scalar.as_extension().to_storage_scalar(), len, accumulator) - } - dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype), - } -} - -fn sum_decimal( - decimal_scalar: DecimalScalar, - array_len: usize, - decimal_dtype: DecimalDType, - accumulator: &Scalar, -) -> VortexResult> { - let 
result_dtype = Stat::Sum - .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable)) - .vortex_expect("decimal supports sum"); - let result_decimal_type = result_dtype - .as_decimal_opt() - .vortex_expect("must be decimal"); - - let Some(value) = decimal_scalar.decimal_value() else { - // Null value: return null - return Ok(None); - }; - - // Convert array_len to DecimalValue for multiplication. - let len_value = DecimalValue::I256(i256::from_i128(array_len as i128)); - - let Some(array_sum) = value - .checked_mul(&len_value) - .filter(|d| d.fits_in_precision(*result_decimal_type)) - else { - return Ok(None); - }; - - // Add accumulator to array_sum. - let initial_decimal = accumulator.as_decimal(); - let initial_dec_value = initial_decimal - .decimal_value() - .unwrap_or(DecimalValue::I256(i256::ZERO)); - - let total = array_sum - .checked_add(&initial_dec_value) - .and_then(|result| { - result - .fits_in_precision(*result_decimal_type) - .then_some(result) - }); - match total { - Some(result_value) => Ok(Some(ScalarValue::from(result_value))), - None => Ok(None), // Overflow - } -} - -fn sum_integral( - primitive_scalar: PrimitiveScalar<'_>, - array_len: usize, - accumulator: &Scalar, -) -> VortexResult> -where - T: NativePType + CheckedMul + CheckedAdd, -{ - let v = primitive_scalar.as_::(); - let array_len = - T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?; - let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else { - return Ok(None); - }; - - let initial = accumulator - .as_primitive() - .as_::() - .vortex_expect("cannot be null"); - Ok(initial.checked_add(&array_sum)) -} - -fn sum_float( - primitive_scalar: PrimitiveScalar<'_>, - array_len: usize, - accumulator: &Scalar, -) -> VortexResult> { - let initial = accumulator - .as_primitive() - .as_::() - .vortex_expect("cannot be null"); - let v = primitive_scalar - .as_::() - .vortex_expect("cannot be null"); - let len_f64: f64 = array_len.as_(); - - 
Ok(Some(initial + v * len_f64)) -} - -register_kernel!(SumKernelAdapter(Constant).lift()); - -#[cfg(test)] -mod tests { - use vortex_error::VortexExpect; - - use crate::DynArray; - use crate::IntoArray; - use crate::arrays::ConstantArray; - use crate::compute::sum; - use crate::compute::sum_with_accumulator; - use crate::dtype::DType; - use crate::dtype::DecimalDType; - use crate::dtype::Nullability; - use crate::dtype::Nullability::Nullable; - use crate::dtype::PType; - use crate::dtype::i256; - use crate::expr::stats::Stat; - use crate::scalar::DecimalValue; - use crate::scalar::Scalar; - - #[test] - fn test_sum_unsigned() { - let array = ConstantArray::new(5u64, 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, 50u64.into()); - } - - #[test] - fn test_sum_signed() { - let array = ConstantArray::new(-5i64, 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, (-50i64).into()); - } - - #[test] - fn test_sum_nullable_value() { - let array = ConstantArray::new(Scalar::null(DType::Primitive(PType::U32, Nullable)), 10) - .into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, Scalar::primitive(0u64, Nullable)); - } - - #[test] - fn test_sum_bool_false() { - let array = ConstantArray::new(false, 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, 0u64.into()); - } - - #[test] - fn test_sum_bool_true() { - let array = ConstantArray::new(true, 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, 10u64.into()); - } - - #[test] - fn test_sum_bool_null() { - let array = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, Scalar::primitive(0u64, Nullable)); - } - - #[test] - fn test_sum_decimal() { - let decimal_dtype = DecimalDType::new(10, 2); - let array = ConstantArray::new( - Scalar::decimal( - DecimalValue::I64(100), - decimal_dtype, - Nullability::NonNullable, - ), - 5, - ) - 
.into_array(); - - let result = sum(&array).unwrap(); - - assert_eq!( - result.as_decimal().decimal_value(), - Some(DecimalValue::I256(i256::from_i128(500))) - ); - assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap()); - } - - #[test] - fn test_sum_decimal_null() { - let decimal_dtype = DecimalDType::new(10, 2); - let array = ConstantArray::new(Scalar::null(DType::Decimal(decimal_dtype, Nullable)), 10) - .into_array(); - - let result = sum(&array).unwrap(); - assert_eq!( - result, - Scalar::decimal( - DecimalValue::I256(i256::ZERO), - DecimalDType::new(20, 2), - Nullable - ) - ); - } - - #[test] - fn test_sum_decimal_large_value() { - let decimal_dtype = DecimalDType::new(10, 2); - let array = ConstantArray::new( - Scalar::decimal( - DecimalValue::I64(999_999_999), - decimal_dtype, - Nullability::NonNullable, - ), - 100, - ) - .into_array(); - - let result = sum(&array).unwrap(); - assert_eq!( - result.as_decimal().decimal_value(), - Some(DecimalValue::I256(i256::from_i128(99_999_999_900))) - ); - } - - #[test] - fn test_sum_float_non_multiply() { - let acc = -2048669276050936500000000000f64; - let array = ConstantArray::new(6.1811675e16f64, 25); - let sum = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable)) - .vortex_expect("operation should succeed in test"); - assert_eq!( - f64::try_from(&sum).vortex_expect("operation should succeed in test"), - -2048669274505644600000000000f64 - ); - } -} diff --git a/vortex-array/src/arrays/decimal/compute/mod.rs b/vortex-array/src/arrays/decimal/compute/mod.rs index 90743d880ac..9a42f7a69d5 100644 --- a/vortex-array/src/arrays/decimal/compute/mod.rs +++ b/vortex-array/src/arrays/decimal/compute/mod.rs @@ -9,7 +9,6 @@ mod is_sorted; mod mask; mod min_max; pub mod rules; -mod sum; mod take; #[cfg(test)] diff --git a/vortex-array/src/arrays/decimal/compute/sum.rs b/vortex-array/src/arrays/decimal/compute/sum.rs deleted file mode 100644 index aafc5ab01fd..00000000000 --- 
a/vortex-array/src/arrays/decimal/compute/sum.rs +++ /dev/null @@ -1,416 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use itertools::Itertools; -use num_traits::AsPrimitive; -use num_traits::CheckedAdd; -use num_traits::NumOps; -use vortex_buffer::BitBuffer; -use vortex_buffer::Buffer; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_mask::Mask; - -use crate::arrays::Decimal; -use crate::arrays::DecimalArray; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::dtype::DType; -use crate::dtype::DecimalDType; -use crate::dtype::DecimalType; -use crate::dtype::Nullability::Nullable; -use crate::expr::stats::Stat; -use crate::match_each_decimal_value_type; -use crate::register_kernel; -use crate::scalar::DecimalValue; -use crate::scalar::Scalar; - -impl SumKernel for Decimal { - fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult { - let return_dtype = Stat::Sum - .dtype(array.dtype()) - .vortex_expect("sum for decimals exists"); - let return_decimal_dtype = *return_dtype - .as_decimal_opt() - .vortex_expect("must be decimal"); - - // Extract the initial value as a `DecimalValue`. 
- let initial_decimal = accumulator - .as_decimal() - .decimal_value() - .vortex_expect("cannot be null"); - - let mask = array.validity_mask()?; - let validity = match &mask { - Mask::AllTrue(_) => None, - Mask::Values(mask_values) => Some(mask_values.bit_buffer()), - Mask::AllFalse(_) => { - vortex_bail!("invalid state, all-null array should be checked by top-level sum fn") - } - }; - - let values_type = DecimalType::smallest_decimal_value_type(&return_decimal_dtype); - match_each_decimal_value_type!(array.values_type(), |I| { - match_each_decimal_value_type!(values_type, |O| { - let initial_val: O = initial_decimal - .cast() - .vortex_expect("cannot fail to cast initial value"); - - Ok(sum_to_scalar( - array.buffer::(), - validity, - initial_val, - return_decimal_dtype, - &return_dtype, - )) - }) - }) - } -} - -/// Compute the checked sum and convert the result to a [`Scalar`]. -/// -/// Returns a null scalar if the sum overflows the underlying integer type or if the result -/// exceeds the declared decimal precision. -fn sum_to_scalar( - values: Buffer, - validity: Option<&BitBuffer>, - initial: O, - return_decimal_dtype: DecimalDType, - return_dtype: &DType, -) -> Scalar -where - T: AsPrimitive, - O: CheckedAdd + NumOps + Into + Copy + 'static, - bool: AsPrimitive, -{ - let raw_sum = match validity { - Some(v) => sum_decimal_with_validity(values, v, initial), - None => sum_decimal(values, initial), - }; - - raw_sum - .map(Into::::into) - // We have to make sure that the decimal value fits the precision of the decimal dtype. - .filter(|v| v.fits_in_precision(return_decimal_dtype)) - .map(|v| Scalar::decimal(v, return_decimal_dtype, Nullable)) - // If an overflow occurs during summation, or final value does not fit, then return a null. 
- .unwrap_or_else(|| Scalar::null(return_dtype.clone())) -} - -fn sum_decimal, I: Copy + CheckedAdd + 'static>( - values: Buffer, - initial: I, -) -> Option { - let mut sum = initial; - for v in values.iter() { - let v: I = v.as_(); - sum = CheckedAdd::checked_add(&sum, &v)?; - } - Some(sum) -} - -fn sum_decimal_with_validity(values: Buffer, validity: &BitBuffer, initial: I) -> Option -where - T: AsPrimitive, - I: NumOps + CheckedAdd + Copy + 'static, - bool: AsPrimitive, -{ - let mut sum = initial; - for (v, valid) in values.iter().zip_eq(validity) { - let v: I = v.as_() * valid.as_(); - - sum = CheckedAdd::checked_add(&sum, &v)?; - } - Some(sum) -} - -register_kernel!(SumKernelAdapter(Decimal).lift()); - -#[cfg(test)] -mod tests { - use vortex_buffer::buffer; - use vortex_error::VortexExpect; - - use crate::IntoArray; - use crate::arrays::DecimalArray; - use crate::compute::sum; - use crate::dtype::DType; - use crate::dtype::DecimalDType; - use crate::dtype::Nullability; - use crate::dtype::i256; - use crate::scalar::DecimalValue; - use crate::scalar::Scalar; - use crate::scalar::ScalarValue; - use crate::validity::Validity; - - #[test] - fn test_sum_basic() { - let decimal = DecimalArray::new( - buffer![100i32, 200i32, 300i32], - DecimalDType::new(4, 2), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(600i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_with_nulls() { - let decimal = DecimalArray::new( - buffer![100i32, 200i32, 300i32, 400i32], - DecimalDType::new(4, 2), - Validity::from_iter([true, false, true, true]), - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable), - Some(ScalarValue::from(DecimalValue::from(800i32))), - ) - 
.unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_negative_values() { - let decimal = DecimalArray::new( - buffer![100i32, -200i32, 300i32, -50i32], - DecimalDType::new(4, 2), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(150i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_near_i32_max() { - // Test values close to i32::MAX to ensure proper handling - let near_max = i32::MAX - 1000; - let decimal = DecimalArray::new( - buffer![near_max, 500i32, 400i32], - DecimalDType::new(10, 2), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - // Should use i64 for accumulation since precision increases - let expected_sum = near_max as i64 + 500 + 400; - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_large_i64_values() { - // Test with large i64 values that require i128 accumulation - let large_val = i64::MAX / 4; - let decimal = DecimalArray::new( - buffer![large_val, large_val, large_val, large_val + 1], - DecimalDType::new(19, 0), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected_sum = (large_val as i128) * 4 + 1; - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_overflow_detection() { - use crate::dtype::i256; - - // Create values that will overflow when summed - // Use maximum i128 values that will overflow when added - let max_val = i128::MAX / 2; - let decimal 
= DecimalArray::new( - buffer![max_val, max_val, max_val], - DecimalDType::new(38, 0), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - // Should use i256 for accumulation - let expected_sum = - i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val); - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_mixed_signs_near_overflow() { - // Test that mixed signs work correctly near overflow boundaries - let large_pos = i64::MAX / 2; - let large_neg = -(i64::MAX / 2); - let decimal = DecimalArray::new( - buffer![large_pos, large_neg, large_pos, 1000i64], - DecimalDType::new(19, 3), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected_sum = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000; - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(29, 3), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_preserves_scale() { - let decimal = DecimalArray::new( - buffer![12345i32, 67890i32, 11111i32], - DecimalDType::new(6, 4), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - // Scale should be preserved, precision increased by 10 - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(91346i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_single_value() { - let decimal = - DecimalArray::new(buffer![42i32], DecimalDType::new(3, 1), Validity::AllValid); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - 
DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(42i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_with_all_nulls_except_one() { - let decimal = DecimalArray::new( - buffer![100i32, 200i32, 300i32, 400i32], - DecimalDType::new(4, 2), - Validity::from_iter([false, false, true, false]), - ); - - let result = sum(&decimal.into_array()).unwrap(); - - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable), - Some(ScalarValue::from(DecimalValue::from(300i32))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_i128_to_i256_boundary() { - // Test the boundary between i128 and i256 accumulation - let large_i128 = i128::MAX / 10; - let decimal = DecimalArray::new( - buffer![ - large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, large_i128, - large_i128, large_i128 - ], - DecimalDType::new(38, 0), - Validity::AllValid, - ); - - let result = sum(&decimal.into_array()).unwrap(); - - // Should use i256 for accumulation since 9 * (i128::MAX / 10) fits in i128 but we increase precision - let expected_sum = i256::from_i128(large_i128).wrapping_pow(1) * i256::from_i128(9); - let expected = Scalar::try_new( - DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), - Some(ScalarValue::from(DecimalValue::from(expected_sum))), - ) - .unwrap(); - - assert_eq!(result, expected); - } - - #[test] - fn test_sum_precision_overflow_without_i256_overflow() { - // Construct values that individually fit in precision 76 but whose sum exceeds it, - // while still fitting in `i256`. This ensures we return null for precision overflow - // and not just for arithmetic overflow. - let ten_to_38 = i256::from_i128(10i128.pow(38)); - let ten_to_75 = ten_to_38 * i256::from_i128(10i128.pow(37)); - // 6 * 10^75 is a 76-digit number, which fits in precision 76. 
- let val = ten_to_75 * i256::from_i128(6); - - let decimal_dtype = DecimalDType::new(76, 0); - let decimal = DecimalArray::new(buffer![val, val], decimal_dtype, Validity::AllValid); - - // Sum = 12 * 10^75 = 1.2 * 10^76, which exceeds precision 76 but fits in `i256`. - let result = sum(&decimal.into_array()).unwrap(); - assert_eq!( - result, - Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)) - ); - } - - #[test] - fn test_i256_overflow() { - let decimal_dtype = DecimalDType::new(76, 0); - let decimal = DecimalArray::new( - buffer![i256::MAX, i256::MAX, i256::MAX], - decimal_dtype, - Validity::AllValid, - ); - - assert_eq!( - sum(&decimal.into_array()).vortex_expect("operation should succeed in test"), - Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)) - ); - } -} diff --git a/vortex-array/src/arrays/extension/compute/mod.rs b/vortex-array/src/arrays/extension/compute/mod.rs index b5761e21187..be90418a962 100644 --- a/vortex-array/src/arrays/extension/compute/mod.rs +++ b/vortex-array/src/arrays/extension/compute/mod.rs @@ -10,7 +10,6 @@ mod mask; mod min_max; pub(crate) mod rules; mod slice; -mod sum; mod take; #[cfg(test)] diff --git a/vortex-array/src/arrays/extension/compute/sum.rs b/vortex-array/src/arrays/extension/compute/sum.rs deleted file mode 100644 index ef93aa6080f..00000000000 --- a/vortex-array/src/arrays/extension/compute/sum.rs +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::arrays::Extension; -use crate::arrays::ExtensionArray; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::compute::{self}; -use crate::register_kernel; -use crate::scalar::Scalar; - -impl SumKernel for Extension { - fn sum(&self, array: &ExtensionArray, accumulator: &Scalar) -> VortexResult { - compute::sum_with_accumulator(array.storage_array(), accumulator) - } -} - 
-register_kernel!(SumKernelAdapter(Extension).lift()); diff --git a/vortex-array/src/arrays/primitive/compute/mod.rs b/vortex-array/src/arrays/primitive/compute/mod.rs index f0330ef55aa..9800129e5c4 100644 --- a/vortex-array/src/arrays/primitive/compute/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/mod.rs @@ -11,7 +11,6 @@ mod min_max; mod nan_count; pub(crate) mod rules; mod slice; -mod sum; mod take; pub use is_constant::*; diff --git a/vortex-array/src/arrays/primitive/compute/sum.rs b/vortex-array/src/arrays/primitive/compute/sum.rs deleted file mode 100644 index a1e02b9dacb..00000000000 --- a/vortex-array/src/arrays/primitive/compute/sum.rs +++ /dev/null @@ -1,140 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use itertools::Itertools; -use num_traits::CheckedAdd; -use num_traits::Float; -use num_traits::ToPrimitive; -use vortex_buffer::BitBuffer; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_mask::AllOr; - -use crate::arrays::Primitive; -use crate::arrays::PrimitiveArray; -use crate::compute::SumKernel; -use crate::compute::SumKernelAdapter; -use crate::dtype::NativePType; -use crate::dtype::Nullability; -use crate::match_each_native_ptype; -use crate::register_kernel; -use crate::scalar::Scalar; - -impl SumKernel for Primitive { - fn sum(&self, array: &PrimitiveArray, accumulator: &Scalar) -> VortexResult { - let array_sum_scalar = match array.validity_mask()?.bit_buffer() { - AllOr::All => { - // All-valid - match_each_native_ptype!( - array.ptype(), - unsigned: |T| { - Scalar::from(sum_integer::<_, u64>( - array.as_slice::(), - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - )) - }, - signed: |T| { - Scalar::from(sum_integer::<_, i64>( - array.as_slice::(), - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - )) - }, - floating: |T| { - Scalar::primitive( - sum_float( - array.as_slice::(), - 
accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - ), - Nullability::Nullable, - ) - } - ) - } - AllOr::None => { - // All-invalid, return accumulator - return Ok(accumulator.clone()); - } - AllOr::Some(validity_mask) => { - // Some-valid - match_each_native_ptype!( - array.ptype(), - unsigned: |T| { - Scalar::from(sum_integer_with_validity::<_, u64>( - array.as_slice::(), - validity_mask, - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - )) - }, - signed: |T| { - Scalar::from(sum_integer_with_validity::<_, i64>( - array.as_slice::(), - validity_mask, - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - )) - }, - floating: |T| { - Scalar::primitive( - sum_float_with_validity( - array.as_slice::(), - validity_mask, - accumulator.as_primitive().as_::().vortex_expect("cannot be null"), - ), - Nullability::Nullable, - ) - } - ) - } - }; - - Ok(array_sum_scalar) - } -} - -register_kernel!(SumKernelAdapter(Primitive).lift()); - -fn sum_integer( - values: &[T], - accumulator: R, -) -> Option { - let mut sum = accumulator; - for &x in values { - sum = sum.checked_add(&R::from(x)?)?; - } - Some(sum) -} - -fn sum_integer_with_validity( - values: &[T], - validity: &BitBuffer, - accumulator: R, -) -> Option { - let mut sum: R = accumulator; - for (&x, valid) in values.iter().zip_eq(validity.iter()) { - if valid { - sum = sum.checked_add(&R::from(x)?)?; - } - } - Some(sum) -} - -fn sum_float(values: &[T], accumulator: f64) -> f64 { - let mut sum = accumulator; - for &x in values { - sum += x.to_f64().vortex_expect("Failed to cast value to f64"); - } - sum -} - -fn sum_float_with_validity( - array: &[T], - validity: &BitBuffer, - accumulator: f64, -) -> f64 { - let mut sum = accumulator; - for (&x, valid) in array.iter().zip_eq(validity.iter()) { - if valid { - sum += x.to_f64().vortex_expect("Failed to cast value to f64"); - } - } - sum -} diff --git a/vortex-array/src/compute/conformance/consistency.rs 
b/vortex-array/src/compute/conformance/consistency.rs index e6452e4b4de..75f3b94f44e 100644 --- a/vortex-array/src/compute/conformance/consistency.rs +++ b/vortex-array/src/compute/conformance/consistency.rs @@ -27,6 +27,8 @@ use vortex_mask::Mask; use crate::ArrayRef; use crate::DynArray; use crate::IntoArray; +use crate::LEGACY_SESSION; +use crate::VortexSessionExecute; use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::builtins::ArrayBuiltins; @@ -1009,11 +1011,13 @@ fn test_boolean_demorgan_consistency(array: &ArrayRef) { /// Aggregate operations on sliced arrays must produce correct results /// regardless of the underlying encoding's offset handling. fn test_slice_aggregate_consistency(array: &ArrayRef) { + use crate::aggregate_fn::fns::sum::sum; use crate::compute::min_max; use crate::compute::nan_count; - use crate::compute::sum; use crate::dtype::DType; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let len = array.len(); if len < 5 { return; // Need enough elements for meaningful slice @@ -1051,7 +1055,9 @@ fn test_slice_aggregate_consistency(array: &ArrayRef) { return; } - if let (Ok(slice_sum), Ok(canonical_sum)) = (sum(&sliced), sum(&canonical_sliced)) { + if let (Ok(slice_sum), Ok(canonical_sum)) = + (sum(&sliced, &mut ctx), sum(&canonical_sliced, &mut ctx)) + { // Compare sum scalars assert_eq!( slice_sum, canonical_sum, diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index a703502a29d..053ffa8868a 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -61,7 +61,6 @@ pub fn warm_up_vtables() { is_sorted::warm_up_vtable(); min_max::warm_up_vtable(); nan_count::warm_up_vtable(); - sum::warm_up_vtable(); } impl ComputeFn { diff --git a/vortex-array/src/compute/sum.rs b/vortex-array/src/compute/sum.rs index b74b336d4de..b1e4fbc6216 100644 --- a/vortex-array/src/compute/sum.rs +++ b/vortex-array/src/compute/sum.rs @@ -1,339 +1,15 @@ // SPDX-License-Identifier: 
Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::LazyLock; - -use arcref::ArcRef; -use num_traits::CheckedAdd; -use num_traits::CheckedSub; -use vortex_error::VortexError; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_ensure; -use vortex_error::vortex_err; -use vortex_error::vortex_panic; use crate::ArrayRef; -use crate::DynArray; -use crate::IntoArray as _; -use crate::compute::ComputeFn; -use crate::compute::ComputeFnVTable; -use crate::compute::InvocationArgs; -use crate::compute::Kernel; -use crate::compute::Output; -use crate::dtype::DType; -use crate::expr::stats::Precision; -use crate::expr::stats::Stat; -use crate::expr::stats::StatsProvider; -use crate::scalar::NumericOperator; +use crate::LEGACY_SESSION; +use crate::VortexSessionExecute; use crate::scalar::Scalar; -use crate::vtable::VTable; - -static SUM_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub(crate) fn warm_up_vtable() -> usize { - SUM_FN.kernels().len() -} - -/// Sum an array with an initial value. -/// -/// If the sum overflows, a null scalar will be returned. -/// If the sum is not supported for the array's dtype, an error will be raised. -/// If the array is all-invalid, the sum will be the accumulator. -/// The accumulator must have a dtype compatible with the sum result dtype. -pub(crate) fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult { - SUM_FN - .invoke(&InvocationArgs { - inputs: &[array.into(), accumulator.into()], - options: &(), - })? - .unwrap_scalar() -} -/// Sum an array, starting from zero. -/// -/// If the sum overflows, a null scalar will be returned. -/// If the sum is not supported for the array's dtype, an error will be raised. -/// If the array is all-invalid, the sum will be zero. 
+#[deprecated(note = "use `vortex::array::aggregate_fn::fns::sum::sum` instead")] pub fn sum(array: &ArrayRef) -> VortexResult { - let sum_dtype = Stat::Sum - .dtype(array.dtype()) - .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?; - let zero = Scalar::zero_value(&sum_dtype); - sum_with_accumulator(array, &zero) -} - -/// For unary compute functions, it's useful to just have this short-cut. -pub struct SumArgs<'a> { - pub array: &'a dyn DynArray, - pub accumulator: &'a Scalar, -} - -impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> { - type Error = VortexError; - - fn try_from(value: &InvocationArgs<'a>) -> Result { - if value.inputs.len() != 2 { - vortex_bail!("Expected 2 inputs, found {}", value.inputs.len()); - } - let array = value.inputs[0] - .array() - .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?; - let accumulator = value.inputs[1] - .scalar() - .ok_or_else(|| vortex_err!("Expected input 1 to be a scalar"))?; - Ok(SumArgs { array, accumulator }) - } -} - -struct Sum; - -impl ComputeFnVTable for Sum { - fn invoke( - &self, - args: &InvocationArgs, - kernels: &[ArcRef], - ) -> VortexResult { - let SumArgs { array, accumulator } = args.try_into()?; - let array = array.to_array(); - - // Compute the expected dtype of the sum. - let sum_dtype = self.return_dtype(args)?; - - vortex_ensure!( - &sum_dtype == accumulator.dtype(), - "sum_dtype {sum_dtype} must match accumulator dtype {}", - accumulator.dtype() - ); - - // Short-circuit using array statistics. - if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) { - // For floats only use stats if accumulator is zero. otherwise we might have numerical - // stability issues. 
- match &sum_dtype { - DType::Primitive(p, _) => { - if p.is_float() && accumulator.is_zero() == Some(true) { - return Ok(sum_scalar.into()); - } else if p.is_int() { - let sum_from_stat = accumulator - .as_primitive() - .checked_add(&sum_scalar.as_primitive()) - .map(Scalar::from); - return Ok(sum_from_stat - .unwrap_or_else(|| Scalar::null(sum_dtype)) - .into()); - } - } - DType::Decimal(..) => { - let sum_from_stat = accumulator - .as_decimal() - .checked_binary_numeric(&sum_scalar.as_decimal(), NumericOperator::Add) - .map(Scalar::from); - return Ok(sum_from_stat - .unwrap_or_else(|| Scalar::null(sum_dtype)) - .into()); - } - _ => unreachable!("Sum will always be a decimal or a primitive dtype"), - } - } - - let sum_scalar = sum_impl(&array, accumulator, kernels)?; - - // Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator. - match sum_dtype { - DType::Primitive(p, _) => { - if p.is_float() - && accumulator.is_zero() == Some(true) - && let Some(sum_value) = sum_scalar.value().cloned() - { - array - .statistics() - .set(Stat::Sum, Precision::Exact(sum_value)); - } else if p.is_int() - && let Some(less_accumulator) = sum_scalar - .as_primitive() - .checked_sub(&accumulator.as_primitive()) - && let Some(val) = Scalar::from(less_accumulator).into_value() - { - array.statistics().set(Stat::Sum, Precision::Exact(val)); - } - } - DType::Decimal(..) => { - if let Some(less_accumulator) = sum_scalar - .as_decimal() - .checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub) - && let Some(val) = Scalar::from(less_accumulator).into_value() - { - array.statistics().set(Stat::Sum, Precision::Exact(val)); - } - } - _ => unreachable!("Sum will always be a decimal or a primitive dtype"), - } - - Ok(sum_scalar.into()) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - let SumArgs { array, .. 
} = args.try_into()?; - Stat::Sum - .dtype(array.dtype()) - .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype())) - } - - fn return_len(&self, _args: &InvocationArgs) -> VortexResult { - // The sum function always returns a single scalar value. - Ok(1) - } - - fn is_elementwise(&self) -> bool { - false - } -} - -pub struct SumKernelRef(ArcRef); -inventory::collect!(SumKernelRef); - -pub trait SumKernel: VTable { - /// # Preconditions - /// - /// * The array's DType is summable - /// * The array is not all-null - /// * The accumulator must have a dtype compatible with the sum result dtype - fn sum(&self, array: &Self::Array, accumulator: &Scalar) -> VortexResult; -} - -#[derive(Debug)] -pub struct SumKernelAdapter(pub V); - -impl SumKernelAdapter { - pub const fn lift(&'static self) -> SumKernelRef { - SumKernelRef(ArcRef::new_ref(self)) - } -} - -impl Kernel for SumKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let SumArgs { array, accumulator } = args.try_into()?; - let Some(array) = array.as_opt::() else { - return Ok(None); - }; - Ok(Some(V::sum(&self.0, array, accumulator)?.into())) - } -} - -/// Sum an array. -/// -/// If the sum overflows, a null scalar will be returned. -/// If the sum is not supported for the array's dtype, an error will be raised. -/// If the array is all-invalid, the sum will be the accumulator. -pub fn sum_impl( - array: &ArrayRef, - accumulator: &Scalar, - kernels: &[ArcRef], -) -> VortexResult { - if array.is_empty() || array.all_invalid()? || accumulator.is_null() { - return Ok(accumulator.clone()); - } - - // Try to find a sum kernel - let args = InvocationArgs { - inputs: &[array.into(), accumulator.into()], - options: &(), - }; - for kernel in kernels { - if let Some(output) = kernel.invoke(&args)? { - return output.unwrap_scalar(); - } - } - - // Otherwise, canonicalize and try again. 
- tracing::debug!("No sum implementation found for {}", array.encoding_id()); - if array.is_canonical() { - // Panic to avoid recursion, but it should never be hit. - vortex_panic!( - "No sum implementation found for canonical array: {}", - array.encoding_id() - ); - } - let canonical = array.to_canonical()?.into_array(); - sum_with_accumulator(&canonical, accumulator) -} - -#[cfg(test)] -mod test { - use vortex_buffer::buffer; - use vortex_error::VortexExpect; - - use crate::IntoArray as _; - use crate::arrays::BoolArray; - use crate::arrays::ChunkedArray; - use crate::arrays::PrimitiveArray; - use crate::compute::sum; - use crate::compute::sum_with_accumulator; - use crate::dtype::DType; - use crate::dtype::Nullability; - use crate::dtype::PType; - use crate::scalar::Scalar; - - #[test] - fn sum_all_invalid() { - let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, Scalar::primitive(0i64, Nullability::Nullable)); - } - - #[test] - fn sum_all_invalid_float() { - let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable)); - } - - #[test] - fn sum_constant() { - let array = buffer![1, 1, 1, 1].into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result.as_primitive().as_::(), Some(4)); - } - - #[test] - fn sum_constant_float() { - let array = buffer![1., 1., 1., 1.].into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result.as_primitive().as_::(), Some(4.)); - } - - #[test] - fn sum_boolean() { - let array = BoolArray::from_iter([true, false, false, true]).into_array(); - let result = sum(&array).unwrap(); - assert_eq!(result.as_primitive().as_::(), Some(2)); - } - - #[test] - fn sum_stats() { - let array = ChunkedArray::try_new( - vec![ - PrimitiveArray::from_iter([1, 1, 1]).into_array(), - PrimitiveArray::from_iter([2, 2, 
2]).into_array(), - ], - DType::Primitive(PType::I32, Nullability::NonNullable), - ) - .vortex_expect("operation should succeed in test"); - let array = array.into_array(); - // compute sum with accumulator to populate stats - sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullability::Nullable)).unwrap(); - - let sum_without_acc = sum(&array).unwrap(); - assert_eq!( - sum_without_acc, - Scalar::primitive(9i64, Nullability::Nullable) - ); - } + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + crate::aggregate_fn::fns::sum::sum(array, &mut ctx) } diff --git a/vortex-array/src/expr/stats/mod.rs b/vortex-array/src/expr/stats/mod.rs index 842d049537c..feaeeb3f311 100644 --- a/vortex-array/src/expr/stats/mod.rs +++ b/vortex-array/src/expr/stats/mod.rs @@ -11,10 +11,7 @@ use num_enum::IntoPrimitive; use num_enum::TryFromPrimitive; use crate::dtype::DType; -use crate::dtype::DecimalDType; -use crate::dtype::MAX_PRECISION; use crate::dtype::Nullability::NonNullable; -use crate::dtype::Nullability::Nullable; use crate::dtype::PType; mod bound; @@ -27,6 +24,10 @@ pub use precision::*; pub use provider::*; pub use stat_bound::*; +use crate::aggregate_fn; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::EmptyOptions; + #[derive( Debug, Clone, @@ -184,37 +185,7 @@ impl Stat { } } Self::Sum => { - // Any array that cannot be summed has a sum DType of null. - // Any array that can be summed, but overflows, has a sum _value_ of null. - // Therefore, we make integer sum stats nullable. 
- match data_type { - DType::Bool(_) => DType::Primitive(PType::U64, Nullable), - DType::Primitive(ptype, _) => match ptype { - PType::U8 | PType::U16 | PType::U32 | PType::U64 => { - DType::Primitive(PType::U64, Nullable) - } - PType::I8 | PType::I16 | PType::I32 | PType::I64 => { - DType::Primitive(PType::I64, Nullable) - } - PType::F16 | PType::F32 | PType::F64 => { - // Float sums cannot overflow, but all null floats still end up as null - DType::Primitive(PType::F64, Nullable) - } - }, - DType::Extension(ext_dtype) => self.dtype(ext_dtype.storage_dtype())?, - DType::Decimal(decimal_dtype, _) => { - // Both Spark and DataFusion use this heuristic. - // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188 - let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10); - DType::Decimal( - DecimalDType::new(precision, decimal_dtype.scale()), - Nullable, - ) - } - // Unsupported types - _ => return None, - } + return aggregate_fn::fns::sum::Sum.return_dtype(&EmptyOptions, data_type); } }) } diff --git a/vortex-array/src/stats/array.rs b/vortex-array/src/stats/array.rs index 8220a5a2deb..01ee4f897be 100644 --- a/vortex-array/src/stats/array.rs +++ b/vortex-array/src/stats/array.rs @@ -15,6 +15,9 @@ use super::StatsSet; use super::StatsSetIntoIter; use super::TypedStatsSetRef; use crate::DynArray; +use crate::LEGACY_SESSION; +use crate::VortexSessionExecute; +use crate::aggregate_fn::fns::sum::sum; use crate::builders::builder_with_capacity; use crate::compute::MinMaxResult; use crate::compute::is_constant; @@ -22,7 +25,6 @@ use crate::compute::is_sorted; use crate::compute::is_strict_sorted; use crate::compute::min_max; use crate::compute::nan_count; -use crate::compute::sum; use 
crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; @@ -142,6 +144,8 @@ impl StatsSetRef<'_> { } pub fn compute_stat(&self, stat: Stat) -> VortexResult> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + // If it's already computed and exact, we can return it. if let Some(Precision::Exact(s)) = self.get(stat) { return Ok(Some(s)); @@ -157,7 +161,7 @@ impl StatsSetRef<'_> { .is_some() .then(|| { // Sum is supported for this dtype. - sum(&array_ref) + sum(&array_ref, &mut ctx) }) .transpose()? } diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 9dea2a7fc41..ce6af05a2e5 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -12,9 +12,11 @@ use vortex_mask::Mask; use crate::DynArray; use crate::ExecutionCtx; +use crate::LEGACY_SESSION; +use crate::VortexSessionExecute; +use crate::aggregate_fn::fns::sum::sum; use crate::arrays::BoolArray; use crate::builtins::ArrayBuiltins; -use crate::compute::sum; use crate::dtype::DType; use crate::dtype::FieldNames; use crate::dtype::PType; @@ -108,7 +110,8 @@ pub struct BoolTyped<'a>(&'a dyn DynArray); impl BoolTyped<'_> { pub fn true_count(&self) -> VortexResult { - let true_count = sum(&self.0.to_array())?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let true_count = sum(&self.0.to_array(), &mut ctx)?; Ok(true_count .as_primitive() .as_::() diff --git a/vortex-layout/public-api.lock b/vortex-layout/public-api.lock index 8163fa19489..2cb3dec9d9a 100644 --- a/vortex-layout/public-api.lock +++ b/vortex-layout/public-api.lock @@ -836,7 +836,7 @@ pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::present_stats(&self) -> pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::prune(&self, predicate: &vortex_array::expr::expression::Expression, session: &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::to_stats_set(&self, stats: 
&[vortex_array::expr::stats::Stat]) -> vortex_error::VortexResult +pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::to_stats_set(&self, stats: &[vortex_array::expr::stats::Stat], ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::try_new(column_dtype: vortex_array::dtype::DType, array: vortex_array::arrays::struct_::array::StructArray, stats: alloc::sync::Arc<[vortex_array::expr::stats::Stat]>) -> vortex_error::VortexResult diff --git a/vortex-layout/src/layouts/file_stats.rs b/vortex-layout/src/layouts/file_stats.rs index 10032dec4c3..48254684451 100644 --- a/vortex-layout/src/layouts/file_stats.rs +++ b/vortex-layout/src/layouts/file_stats.rs @@ -8,7 +8,9 @@ use futures::StreamExt; use itertools::Itertools; use parking_lot::Mutex; use vortex_array::ArrayRef; +use vortex_array::LEGACY_SESSION; use vortex_array::ToCanonical as _; +use vortex_array::VortexSessionExecute; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::expr::stats::Stat; @@ -108,6 +110,7 @@ impl FileStatsAccumulator { } pub fn stats_sets(&self) -> Vec { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); self.accumulators .lock() .iter_mut() @@ -116,7 +119,7 @@ impl FileStatsAccumulator { .vortex_expect("as_stats_table should not fail") .map(|table| { table - .to_stats_set(&self.stats) + .to_stats_set(&self.stats, &mut ctx) .vortex_expect("shouldn't fail to convert table we just created") }) .unwrap_or_default() diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index e1f4b45e62e..eeeed8852e3 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -6,10 +6,11 @@ use std::sync::Arc; use itertools::Itertools; use vortex_array::ArrayRef; use vortex_array::DynArray; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; 
+use vortex_array::aggregate_fn::fns::sum::sum; use vortex_array::arrays::StructArray; -use vortex_array::compute::sum; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; @@ -108,7 +109,7 @@ impl ZoneMap { } /// Returns an aggregated stats set for the table. - pub fn to_stats_set(&self, stats: &[Stat]) -> VortexResult { + pub fn to_stats_set(&self, stats: &[Stat], ctx: &mut ExecutionCtx) -> VortexResult { let mut stats_set = StatsSet::default(); for &stat in stats { let Some(array) = self.get_stat(stat)? else { @@ -127,7 +128,7 @@ impl ZoneMap { } // These stats sum up Stat::NullCount | Stat::NaNCount | Stat::UncompressedSizeInBytes => { - if let Some(sum_value) = sum(&array)? + if let Some(sum_value) = sum(&array, ctx)? .cast(&DType::Primitive(PType::U64, Nullability::Nullable))? .into_value() {