From b6bffaa3c1b2c16b8f7cecdd0caf4c6df5ca5342 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 13 Mar 2026 13:24:54 -0400 Subject: [PATCH 1/7] vector type first draft Signed-off-by: Connor Tsui --- vortex-tensor/src/lib.rs | 1 + vortex-tensor/src/vector/mod.rs | 10 ++ vortex-tensor/src/vector/vtable.rs | 153 +++++++++++++++++++++++++++++ 3 files changed, 164 insertions(+) create mode 100644 vortex-tensor/src/vector/mod.rs create mode 100644 vortex-tensor/src/vector/vtable.rs diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index dc33066bd3b..72ef475f795 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -6,5 +6,6 @@ //! similarity. pub mod fixed_shape; +pub mod vector; pub mod scalar_fns; diff --git a/vortex-tensor/src/vector/mod.rs b/vortex-tensor/src/vector/mod.rs new file mode 100644 index 00000000000..181e08c4e84 --- /dev/null +++ b/vortex-tensor/src/vector/mod.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Vector extension type for fixed-length float vectors (e.g., embeddings). + +/// The VTable for the vector extension type. 
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct Vector; + +mod vtable; diff --git a/vortex-tensor/src/vector/vtable.rs b/vortex-tensor/src/vector/vtable.rs new file mode 100644 index 00000000000..8bf2577e73c --- /dev/null +++ b/vortex-tensor/src/vector/vtable.rs @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::dtype::DType; +use vortex::dtype::extension::ExtDType; +use vortex::dtype::extension::ExtId; +use vortex::dtype::extension::ExtVTable; +use vortex::error::VortexResult; +use vortex::error::vortex_bail; +use vortex::error::vortex_ensure; +use vortex::extension::EmptyMetadata; +use vortex::scalar::ScalarValue; + +use crate::vector::Vector; + +impl ExtVTable for Vector { + type Metadata = EmptyMetadata; + + // TODO(connor): This is just a placeholder for now. + type NativeValue<'a> = &'a ScalarValue; + + fn id(&self) -> ExtId { + ExtId::new_ref("vortex.vector") + } + + fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { + Ok(Vec::new()) + } + + fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { + vortex_ensure!( + metadata.is_empty(), + "Vector metadata must be empty, got {} bytes", + metadata.len() + ); + Ok(EmptyMetadata) + } + + fn validate_dtype(&self, ext_dtype: &ExtDType) -> VortexResult<()> { + let storage_dtype = ext_dtype.storage_dtype(); + let DType::FixedSizeList(element_dtype, _list_size, _nullability) = storage_dtype else { + vortex_bail!("Vector storage dtype must be a FixedSizeList, got {storage_dtype}"); + }; + + vortex_ensure!( + element_dtype.is_float(), + "Vector element dtype must be a float, got {element_dtype}" + ); + vortex_ensure!( + !element_dtype.is_nullable(), + "Vector element dtype must be non-nullable" + ); + + Ok(()) + } + + fn unpack_native<'a>( + &self, + _ext_dtype: &'a ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + Ok(storage_value) + } +} + +#[cfg(test)] +mod 
tests { + use std::sync::Arc; + + use rstest::rstest; + use vortex::dtype::DType; + use vortex::dtype::Nullability; + use vortex::dtype::PType; + use vortex::dtype::extension::ExtDType; + use vortex::dtype::extension::ExtVTable; + use vortex::error::VortexResult; + use vortex::extension::EmptyMetadata; + + use crate::vector::Vector; + + /// Constructs a `FixedSizeList` storage dtype with the given float [`PType`], list size, and + /// [`Nullability`]. + fn vector_storage_dtype(ptype: PType, size: u32, nullability: Nullability) -> DType { + DType::FixedSizeList( + Arc::new(DType::Primitive(ptype, Nullability::NonNullable)), + size, + nullability, + ) + } + + #[rstest] + #[case::f16(PType::F16)] + #[case::f32(PType::F32)] + #[case::f64(PType::F64)] + fn validate_accepts_float_types(#[case] ptype: PType) -> VortexResult<()> { + let storage = vector_storage_dtype(ptype, 128, Nullability::NonNullable); + ExtDType::::try_new(EmptyMetadata, storage)?; + Ok(()) + } + + #[rstest] + #[case::nullable(Nullability::Nullable)] + #[case::non_nullable(Nullability::NonNullable)] + fn validate_accepts_any_outer_nullability( + #[case] nullability: Nullability, + ) -> VortexResult<()> { + let storage = vector_storage_dtype(PType::F32, 128, nullability); + ExtDType::::try_new(EmptyMetadata, storage)?; + Ok(()) + } + + #[test] + fn validate_rejects_non_fsl() { + let storage = DType::Primitive(PType::F32, Nullability::NonNullable); + assert!(ExtDType::::try_new(EmptyMetadata, storage).is_err()); + } + + #[test] + fn validate_rejects_integer_elements() { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U32, Nullability::NonNullable)), + 128, + Nullability::NonNullable, + ); + assert!(ExtDType::::try_new(EmptyMetadata, storage).is_err()); + } + + #[test] + fn validate_rejects_nullable_elements() { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::Nullable)), + 128, + Nullability::NonNullable, + ); + 
assert!(ExtDType::::try_new(EmptyMetadata, storage).is_err()); + } + + #[test] + fn roundtrip_metadata() -> VortexResult<()> { + let vtable = Vector; + let bytes = vtable.serialize_metadata(&EmptyMetadata)?; + assert!(bytes.is_empty()); + let deserialized = vtable.deserialize_metadata(&bytes)?; + assert_eq!(deserialized, EmptyMetadata); + Ok(()) + } + + #[test] + fn deserialize_rejects_non_empty_bytes() { + let vtable = Vector; + assert!(vtable.deserialize_metadata(&[0x01]).is_err()); + } +} From 16ed7db374e1f43a9a458e3dced54eebd8824951 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 13 Mar 2026 13:48:29 -0400 Subject: [PATCH 2/7] add `AnyTensor` matcher and impl cosine similarity for vector Signed-off-by: Connor Tsui --- vortex-tensor/src/lib.rs | 1 + vortex-tensor/src/matcher.rs | 42 +++++++ .../src/scalar_fns/cosine_similarity.rs | 118 ++++++++++++++++-- vortex-tensor/src/vector/vtable.rs | 7 +- 4 files changed, 154 insertions(+), 14 deletions(-) create mode 100644 vortex-tensor/src/matcher.rs diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 72ef475f795..56e96488167 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -8,4 +8,5 @@ pub mod fixed_shape; pub mod vector; +pub mod matcher; pub mod scalar_fns; diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs new file mode 100644 index 00000000000..bb79ad7447e --- /dev/null +++ b/vortex-tensor/src/matcher.rs @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Matcher for tensor-like extension types. + +use vortex::dtype::extension::ExtDTypeRef; +use vortex::dtype::extension::Matcher; + +use crate::fixed_shape::FixedShapeTensor; +use crate::fixed_shape::FixedShapeTensorMetadata; +use crate::vector::Vector; + +/// Matcher for any tensor-like extension type. 
+/// +/// Currently the different kinds of tensors that are available are: +/// +/// - `FixedShapeTensor` +/// - `Vector` +pub struct AnyTensor; + +/// The matched variant of a tensor-like extension type. +#[derive(Debug, PartialEq, Eq)] +pub enum TensorMatch<'a> { + /// A [`FixedShapeTensor`] extension type. + FixedShapeTensor(&'a FixedShapeTensorMetadata), + /// A [`Vector`] extension type. + Vector, +} + +impl Matcher for AnyTensor { + type Match<'a> = TensorMatch<'a>; + + fn try_match<'a>(item: &'a ExtDTypeRef) -> Option> { + if let Some(metadata) = item.metadata_opt::() { + return Some(TensorMatch::FixedShapeTensor(metadata)); + } + if item.metadata_opt::().is_some() { + return Some(TensorMatch::Vector); + } + None + } +} diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index f64ffc4cd11..cad3f2811d7 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -1,8 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Cosine similarity expression for [`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) -//! arrays. +//! Cosine similarity expression for tensor-like extension arrays +//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and +//! [`Vector`](crate::vector::Vector)). 
use std::fmt::Formatter; @@ -19,6 +20,7 @@ use vortex::array::match_each_float_ptype; use vortex::dtype::DType; use vortex::dtype::NativePType; use vortex::dtype::Nullability; +use vortex::dtype::extension::Matcher; use vortex::error::VortexResult; use vortex::error::vortex_bail; use vortex::error::vortex_ensure; @@ -31,18 +33,21 @@ use vortex::scalar_fn::ExecutionArgs; use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; +use crate::matcher::AnyTensor; + // TODO(connor): We will want to add implementations for unit normalized vectors and also vectors // encoded in spherical coordinates. /// Cosine similarity between two columns. /// -/// For [`FixedShapeTensor`], computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of -/// each tensor. The shape and permutation do not affect the result because cosine similarity only -/// depends on the element values, not their logical arrangement. +/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector. +/// The shape and permutation do not affect the result because cosine similarity only depends on the +/// element values, not their logical arrangement. /// -/// Right now, both inputs must be [`FixedShapeTensor`] extension arrays with the same dtype and a -/// float element type. The output is a float column of the same float type. +/// Both inputs must be tensor-like extension arrays ([`FixedShapeTensor`] or [`Vector`]) with the +/// same dtype and a float element type. The output is a float column of the same float type. /// /// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor +/// [`Vector`]: crate::vector::Vector #[derive(Clone)] pub struct CosineSimilarity; @@ -92,10 +97,14 @@ impl ScalarFnVTable for CosineSimilarity { // We don't need to look at rhs anymore since we know lhs and rhs are equal. - // Both inputs must be extension types. + // Both inputs must be tensor-like extension types. 
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| { vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}") })?; + vortex_ensure!( + AnyTensor::matches(lhs_ext), + "cosine_similarity inputs must be an `AnyTensor`, got {lhs}" + ); // Extract the element dtype from the storage FixedSizeList. let element_dtype = lhs_ext @@ -258,6 +267,7 @@ mod tests { use vortex::dtype::Nullability; use vortex::dtype::extension::ExtDType; use vortex::error::VortexResult; + use vortex::extension::EmptyMetadata; use vortex::scalar::Scalar; use vortex::scalar_fn::EmptyOptions; use vortex::scalar_fn::ScalarFn; @@ -265,6 +275,7 @@ mod tests { use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; use crate::scalar_fns::cosine_similarity::CosineSimilarity; + use crate::vector::Vector; /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. /// @@ -460,4 +471,95 @@ mod tests { ); Ok(()) } + + /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. + fn vector_array(dim: u32, elements: &[f64]) -> VortexResult { + let row_count = elements.len() / dim as usize; + + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); + + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + } + + #[test] + fn vector_unit_vectors() -> VortexResult<()> { + let lhs = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // vector 0 + 0.0, 1.0, 0.0, // vector 1 + ], + )?; + let rhs = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // vector 0 + 1.0, 0.0, 0.0, // vector 1 + ], + )?; + + // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0. 
+ assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + Ok(()) + } + + #[test] + fn vector_self_similarity() -> VortexResult<()> { + let arr = vector_array( + 4, + &[ + 1.0, 2.0, 3.0, 4.0, // vector 0 + 0.0, 1.0, 0.0, 0.0, // vector 1 + 5.0, 0.0, 5.0, 0.0, // vector 2 + ], + )?; + + assert_close( + &eval_cosine_similarity(arr.clone(), arr, 3)?, + &[1.0, 1.0, 1.0], + ); + Ok(()) + } + + /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`]. + fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult { + let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); + + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + let storage_scalar = + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); + + let storage = ConstantArray::new(storage_scalar, len).into_array(); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + } + + #[test] + fn vector_constant_query() -> VortexResult<()> { + let data = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // vector 0 + 0.0, 1.0, 0.0, // vector 1 + 0.0, 0.0, 1.0, // vector 2 + 1.0, 0.0, 0.0, // vector 3 + ], + )?; + let query = constant_vector_array(&[1.0, 0.0, 0.0], 4)?; + + assert_close( + &eval_cosine_similarity(data, query, 4)?, + &[1.0, 0.0, 0.0, 1.0], + ); + Ok(()) + } } diff --git a/vortex-tensor/src/vector/vtable.rs b/vortex-tensor/src/vector/vtable.rs index 8bf2577e73c..72f022d4904 100644 --- a/vortex-tensor/src/vector/vtable.rs +++ b/vortex-tensor/src/vector/vtable.rs @@ -27,12 +27,7 @@ impl ExtVTable for Vector { Ok(Vec::new()) } - fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { - vortex_ensure!( - metadata.is_empty(), - "Vector metadata must be empty, got {} bytes", - metadata.len() - ); + fn deserialize_metadata(&self, _metadata: &[u8]) -> 
VortexResult { Ok(EmptyMetadata) } From 213d66597c736d2c82834a3bffc4c0f26ded4dbf Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 13 Mar 2026 14:39:46 -0400 Subject: [PATCH 3/7] add l2 norm scalar fn Signed-off-by: Connor Tsui --- .../src/scalar_fns/cosine_similarity.rs | 71 +--- vortex-tensor/src/scalar_fns/l2_norm.rs | 303 ++++++++++++++++++ vortex-tensor/src/scalar_fns/mod.rs | 5 + vortex-tensor/src/scalar_fns/utils.rs | 87 +++++ vortex-tensor/src/vector/vtable.rs | 6 - 5 files changed, 404 insertions(+), 68 deletions(-) create mode 100644 vortex-tensor/src/scalar_fns/l2_norm.rs create mode 100644 vortex-tensor/src/scalar_fns/utils.rs diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index cad3f2811d7..3c7c2b8cb53 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -9,12 +9,8 @@ use std::fmt::Formatter; use num_traits::Float; use vortex::array::ArrayRef; -use vortex::array::DynArray; use vortex::array::ExecutionCtx; use vortex::array::IntoArray; -use vortex::array::arrays::Constant; -use vortex::array::arrays::ConstantArray; -use vortex::array::arrays::Extension; use vortex::array::arrays::PrimitiveArray; use vortex::array::match_each_float_ptype; use vortex::dtype::DType; @@ -22,7 +18,6 @@ use vortex::dtype::NativePType; use vortex::dtype::Nullability; use vortex::dtype::extension::Matcher; use vortex::error::VortexResult; -use vortex::error::vortex_bail; use vortex::error::vortex_ensure; use vortex::error::vortex_err; use vortex::expr::Expression; @@ -34,9 +29,11 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; +use crate::scalar_fns::utils::extension_element_ptype; +use crate::scalar_fns::utils::extension_list_size; +use crate::scalar_fns::utils::extension_storage; +use crate::scalar_fns::utils::extract_flat_elements; -// TODO(connor): We will want to add 
implementations for unit normalized vectors and also vectors -// encoded in spherical coordinates. /// Cosine similarity between two columns. /// /// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector. @@ -101,33 +98,18 @@ impl ScalarFnVTable for CosineSimilarity { let lhs_ext = lhs.as_extension_opt().ok_or_else(|| { vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}") })?; + vortex_ensure!( AnyTensor::matches(lhs_ext), "cosine_similarity inputs must be an `AnyTensor`, got {lhs}" ); - // Extract the element dtype from the storage FixedSizeList. - let element_dtype = lhs_ext - .storage_dtype() - .as_fixed_size_list_element_opt() - .ok_or_else(|| { - vortex_err!( - "cosine_similarity storage dtype must be a FixedSizeList, got {}", - lhs_ext.storage_dtype() - ) - })?; - - // Element dtype must be a non-nullable float primitive. - vortex_ensure!( - element_dtype.is_float(), - "cosine_similarity element dtype must be a float primitive, got {element_dtype}" - ); + let ptype = extension_element_ptype(lhs_ext)?; vortex_ensure!( - !element_dtype.is_nullable(), - "cosine_similarity element dtype must be non-nullable" + ptype.is_float(), + "cosine_similarity element dtype must be a float primitive, got {ptype}" ); - let ptype = element_dtype.as_ptype(); let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } @@ -149,10 +131,7 @@ impl ScalarFnVTable for CosineSimilarity { lhs.dtype() ) })?; - let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else { - vortex_bail!("expected FixedSizeList storage dtype"); - }; - let list_size = *list_size as usize; + let list_size = extension_list_size(ext)?; // Extract the storage array from each extension input. We pass the storage (FSL) rather // than the extension array to avoid canonicalizing the extension wrapper. 
@@ -203,38 +182,6 @@ impl ScalarFnVTable for CosineSimilarity { } } -/// Extracts the storage array from an extension array without canonicalizing. -fn extension_storage(array: &ArrayRef) -> VortexResult { - let ext = array - .as_opt::() - .ok_or_else(|| vortex_err!("cosine_similarity input must be an extension array"))?; - Ok(ext.storage_array().clone()) -} - -/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList). -/// -/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is -/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)` -/// where `stride` is `list_size` for a full array and `0` for a constant. -fn extract_flat_elements( - storage: &ArrayRef, - list_size: usize, -) -> VortexResult<(PrimitiveArray, usize)> { - if let Some(constant) = storage.as_opt::() { - // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a - // huge amount of data. - let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); - let fsl = single.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); - Ok((elems, 0)) - } else { - // Otherwise we have to fully expand all of the data. - let fsl = storage.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); - Ok((elems, list_size)) - } -} - // TODO(connor): We should try to use a more performant library instead of doing this ourselves. /// Computes cosine similarity between two equal-length float slices. /// diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs new file mode 100644 index 00000000000..b25518f5c0c --- /dev/null +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -0,0 +1,303 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! 
L2 norm expression for tensor-like extension arrays +//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and +//! [`Vector`](crate::vector::Vector)). + +use std::fmt::Formatter; + +use num_traits::Float; +use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::match_each_float_ptype; +use vortex::dtype::DType; +use vortex::dtype::NativePType; +use vortex::dtype::Nullability; +use vortex::dtype::extension::Matcher; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure; +use vortex::error::vortex_err; +use vortex::expr::Expression; +use vortex::scalar_fn::Arity; +use vortex::scalar_fn::ChildName; +use vortex::scalar_fn::EmptyOptions; +use vortex::scalar_fn::ExecutionArgs; +use vortex::scalar_fn::ScalarFnId; +use vortex::scalar_fn::ScalarFnVTable; + +use crate::matcher::AnyTensor; +use crate::scalar_fns::utils::extension_element_ptype; +use crate::scalar_fns::utils::extension_list_size; +use crate::scalar_fns::utils::extension_storage; +use crate::scalar_fns::utils::extract_flat_elements; + +/// L2 norm (Euclidean norm) of a tensor or vector column. +/// +/// Computes `||v|| = sqrt(sum(v_i^2))` over the flat backing buffer of each tensor-like type. +/// +/// The input must be a tensor-like extension array with a float element type. The output is a float +/// column of the same float type. 
+#[derive(Clone)] +pub struct L2Norm; + +impl ScalarFnVTable for L2Norm { + type Options = EmptyOptions; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new_ref("vortex.l2_norm") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("input"), + _ => unreachable!("L2Norm must have exactly one child"), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "l2_norm(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + debug_assert_eq!(arg_dtypes.len(), 1); + + let input_dtype = &arg_dtypes[0]; + + // Input must be a tensor-like extension type. + let ext = input_dtype.as_extension_opt().ok_or_else(|| { + vortex_err!("l2_norm input must be an extension type, got {input_dtype}") + })?; + + vortex_ensure!( + AnyTensor::matches(ext), + "l2_norm input must be an `AnyTensor`, got {input_dtype}" + ); + + let ptype = extension_element_ptype(ext)?; + vortex_ensure!( + ptype.is_float(), + "l2_norm element dtype must be a float primitive, got {ptype}" + ); + + let nullability = Nullability::from(input_dtype.is_nullable()); + Ok(DType::Primitive(ptype, nullability)) + } + + fn execute( + &self, + _options: &Self::Options, + args: &dyn ExecutionArgs, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + let input = args.get(0)?; + let row_count = args.row_count(); + + // Get list size (dimensions) from the dtype. 
+ let ext = input.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "l2_norm input must be an extension type, got {}", + input.dtype() + ) + })?; + let list_size = extension_list_size(ext)?; + + let storage = extension_storage(&input)?; + let (elems, stride) = extract_flat_elements(&storage, list_size)?; + + match_each_float_ptype!(elems.ptype(), |T| { + let slice = elems.as_slice::(); + + let result: PrimitiveArray = (0..row_count) + .map(|i| { + let v = &slice[i * stride..i * stride + list_size]; + l2_norm_row(v) + }) + .collect(); + + Ok(result.into_array()) + }) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + // The result is null if the input tensor is null. + Ok(Some(expression.child(0).validity()?)) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + // Canonicalization of the storage array can fail. + true + } +} + +/// Computes the L2 norm (Euclidean norm) of a float slice. +/// +/// Returns `sqrt(sum(v_i^2))`. A zero-length or all-zero input produces `0.0`. 
+fn l2_norm_row(v: &[T]) -> T { + let mut sum_sq = T::zero(); + for &x in v { + sum_sq = sum_sq + x * x; + } + sum_sq.sqrt() +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex::array::ArrayRef; + use vortex::array::IntoArray; + use vortex::array::ToCanonical; + use vortex::array::arrays::ExtensionArray; + use vortex::array::arrays::FixedSizeListArray; + use vortex::array::arrays::ScalarFnArray; + use vortex::array::validity::Validity; + use vortex::buffer::Buffer; + use vortex::dtype::extension::ExtDType; + use vortex::error::VortexResult; + use vortex::extension::EmptyMetadata; + use vortex::scalar_fn::EmptyOptions; + use vortex::scalar_fn::ScalarFn; + + use crate::fixed_shape::FixedShapeTensor; + use crate::fixed_shape::FixedShapeTensorMetadata; + use crate::scalar_fns::l2_norm::L2Norm; + use crate::vector::Vector; + + /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. + fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult { + let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); + let row_count = elements.len() / list_size as usize; + + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); + + let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); + let ext_dtype = + ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + } + + /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. 
+ fn vector_array(dim: u32, elements: &[f64]) -> VortexResult { + let row_count = elements.len() / dim as usize; + + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); + + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + } + + /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. + fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult> { + let scalar_fn = ScalarFn::new(L2Norm, EmptyOptions).erased(); + let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?; + let prim = result.to_primitive(); + Ok(prim.as_slice::().to_vec()) + } + + #[track_caller] + fn assert_close(actual: &[f64], expected: &[f64]) { + assert_eq!( + actual.len(), + expected.len(), + "length mismatch: got {} elements, expected {}", + actual.len(), + expected.len() + ); + + for (i, (a, e)) in actual.iter().zip(expected).enumerate() { + assert!( + (a - e).abs() < 1e-10, + "element {i}: got {a}, expected {e} (diff = {})", + (a - e).abs() + ); + } + } + + #[test] + fn unit_vector_norm() -> VortexResult<()> { + let arr = tensor_array( + &[3], + &[ + 1.0, 0.0, 0.0, // unit x + 0.0, 1.0, 0.0, // unit y + 0.0, 0.0, 1.0, // unit z + ], + )?; + assert_close(&eval_l2_norm(arr, 3)?, &[1.0, 1.0, 1.0]); + Ok(()) + } + + #[rstest] + #[case::three_four_five(&[2], &[3.0, 4.0], &[5.0])] + #[case::zero_vector(&[3], &[0.0, 0.0, 0.0], &[0.0])] + #[case::single_element(&[1], &[7.0], &[7.0])] + #[case::negative_elements(&[2], &[-3.0, -4.0], &[5.0])] + fn known_norms( + #[case] shape: &[usize], + #[case] elements: &[f64], + #[case] expected: &[f64], + ) -> VortexResult<()> { + let arr = tensor_array(shape, elements)?; + assert_close(&eval_l2_norm(arr, 1)?, expected); + Ok(()) + } + + #[test] + fn multiple_rows() -> VortexResult<()> { + let arr = tensor_array( + &[3], + &[ + 3.0, 
4.0, 0.0, // norm = 5.0 + 0.0, 0.0, 0.0, // norm = 0.0 + 1.0, 1.0, 1.0, // norm = sqrt(3) + ], + )?; + assert_close(&eval_l2_norm(arr, 3)?, &[5.0, 0.0, 3.0_f64.sqrt()]); + Ok(()) + } + + #[test] + fn vector_known_norm() -> VortexResult<()> { + let arr = vector_array(2, &[3.0, 4.0])?; + assert_close(&eval_l2_norm(arr, 1)?, &[5.0]); + Ok(()) + } + + #[test] + fn vector_multiple_rows() -> VortexResult<()> { + let arr = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // norm = 1.0 + 3.0, 4.0, 0.0, // norm = 5.0 + ], + )?; + assert_close(&eval_l2_norm(arr, 2)?, &[1.0, 5.0]); + Ok(()) + } +} diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index 2797589e03f..2597f1115c8 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -1,4 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! Scalar function expressions defined on tensor and tensor-like extension types. + pub mod cosine_similarity; +pub mod l2_norm; + +mod utils; diff --git a/vortex-tensor/src/scalar_fns/utils.rs b/vortex-tensor/src/scalar_fns/utils.rs new file mode 100644 index 00000000000..b33da4ed440 --- /dev/null +++ b/vortex-tensor/src/scalar_fns/utils.rs @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::ArrayRef; +use vortex::array::IntoArray; +use vortex::array::arrays::Constant; +use vortex::array::arrays::ConstantArray; +use vortex::array::arrays::Extension; +use vortex::array::arrays::PrimitiveArray; +use vortex::dtype::DType; +use vortex::dtype::PType; +use vortex::dtype::extension::ExtDTypeRef; +use vortex::error::VortexResult; +use vortex::error::vortex_bail; +use vortex::error::vortex_ensure; +use vortex::error::vortex_err; + +/// Extracts the list size from a tensor-like extension dtype. +/// +/// The storage dtype must be a `FixedSizeList`. 
+pub(crate) fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult { + let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else { + vortex_bail!( + "expected FixedSizeList storage dtype, got {}", + ext.storage_dtype() + ); + }; + + Ok(*list_size as usize) +} + +/// Extracts the float element [`PType`] from a tensor-like extension dtype. +/// +/// The storage dtype must be a `FixedSizeList` of non-nullable primitives. +pub(crate) fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult { + let element_dtype = ext + .storage_dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| { + vortex_err!( + "expected FixedSizeList storage dtype, got {}", + ext.storage_dtype() + ) + })?; + + vortex_ensure!( + !element_dtype.is_nullable(), + "element dtype must be non-nullable" + ); + + Ok(element_dtype.as_ptype()) +} + +/// Extracts the storage array from an extension array without canonicalizing. +pub(crate) fn extension_storage(array: &ArrayRef) -> VortexResult { + let ext = array + .as_opt::() + .ok_or_else(|| vortex_err!("scalar_fn input must be an extension array"))?; + + Ok(ext.storage_array().clone()) +} + +// TODO(connor): it would be nicer if this took a generic parameter and a FnMut arg that we run +// directly on the values without having to return this ugly stride. +/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList). +/// +/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is +/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)` +/// where `stride` is `list_size` for a full array and `0` for a constant. +pub(crate) fn extract_flat_elements( + storage: &ArrayRef, + list_size: usize, +) -> VortexResult<(PrimitiveArray, usize)> { + if let Some(constant) = storage.as_opt::() { + // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge + // amount of data. 
+ let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); + let fsl = single.to_canonical()?.into_fixed_size_list(); + let elems = fsl.elements().to_canonical()?.into_primitive(); + return Ok((elems, 0)); + } + + // Otherwise we have to fully expand all of the data. + let fsl = storage.to_canonical()?.into_fixed_size_list(); + let elems = fsl.elements().to_canonical()?.into_primitive(); + Ok((elems, list_size)) +} diff --git a/vortex-tensor/src/vector/vtable.rs b/vortex-tensor/src/vector/vtable.rs index 72f022d4904..a3206c5150e 100644 --- a/vortex-tensor/src/vector/vtable.rs +++ b/vortex-tensor/src/vector/vtable.rs @@ -139,10 +139,4 @@ mod tests { assert_eq!(deserialized, EmptyMetadata); Ok(()) } - - #[test] - fn deserialize_rejects_non_empty_bytes() { - let vtable = Vector; - assert!(vtable.deserialize_metadata(&[0x01]).is_err()); - } } From cc4c48af31f458fc73928e8f3518f562d3a1bcb5 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 13 Mar 2026 17:44:29 -0400 Subject: [PATCH 4/7] clean up Signed-off-by: Connor Tsui --- .../src/scalar_fns/cosine_similarity.rs | 18 ++------ vortex-tensor/src/scalar_fns/l2_norm.rs | 11 ++--- vortex-tensor/src/scalar_fns/utils.rs | 44 ++++++++++++++++--- 3 files changed, 44 insertions(+), 29 deletions(-) diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 3c7c2b8cb53..d9c952b29bb 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -138,22 +138,12 @@ impl ScalarFnVTable for CosineSimilarity { let lhs_storage = extension_storage(&lhs)?; let rhs_storage = extension_storage(&rhs)?; - // Extract the flat primitive elements from each tensor column. When an input is a - // `ConstantArray` (e.g., a literal query vector), we materialize only a single row - // instead of expanding it to the full row count. 
- let (lhs_elems, lhs_stride) = extract_flat_elements(&lhs_storage, list_size)?; - let (rhs_elems, rhs_stride) = extract_flat_elements(&rhs_storage, list_size)?; - - match_each_float_ptype!(lhs_elems.ptype(), |T| { - let lhs_slice = lhs_elems.as_slice::(); - let rhs_slice = rhs_elems.as_slice::(); + let lhs_flat = extract_flat_elements(&lhs_storage, list_size)?; + let rhs_flat = extract_flat_elements(&rhs_storage, list_size)?; + match_each_float_ptype!(lhs_flat.ptype(), |T| { let result: PrimitiveArray = (0..row_count) - .map(|i| { - let a = &lhs_slice[i * lhs_stride..i * lhs_stride + list_size]; - let b = &rhs_slice[i * rhs_stride..i * rhs_stride + list_size]; - cosine_similarity_row(a, b) - }) + .map(|i| cosine_similarity_row(lhs_flat.row::(i), rhs_flat.row::(i))) .collect(); Ok(result.into_array()) diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index b25518f5c0c..bada964b7ef 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -116,16 +116,11 @@ impl ScalarFnVTable for L2Norm { let list_size = extension_list_size(ext)?; let storage = extension_storage(&input)?; - let (elems, stride) = extract_flat_elements(&storage, list_size)?; - - match_each_float_ptype!(elems.ptype(), |T| { - let slice = elems.as_slice::(); + let flat = extract_flat_elements(&storage, list_size)?; + match_each_float_ptype!(flat.ptype(), |T| { let result: PrimitiveArray = (0..row_count) - .map(|i| { - let v = &slice[i * stride..i * stride + list_size]; - l2_norm_row(v) - }) + .map(|i| l2_norm_row(flat.row::(i))) .collect(); Ok(result.into_array()) diff --git a/vortex-tensor/src/scalar_fns/utils.rs b/vortex-tensor/src/scalar_fns/utils.rs index b33da4ed440..ca7ddb47b02 100644 --- a/vortex-tensor/src/scalar_fns/utils.rs +++ b/vortex-tensor/src/scalar_fns/utils.rs @@ -8,6 +8,7 @@ use vortex::array::arrays::ConstantArray; use vortex::array::arrays::Extension; use vortex::array::arrays::PrimitiveArray; 
use vortex::dtype::DType; +use vortex::dtype::NativePType; use vortex::dtype::PType; use vortex::dtype::extension::ExtDTypeRef; use vortex::error::VortexResult; @@ -60,28 +61,57 @@ pub(crate) fn extension_storage(array: &ArrayRef) -> VortexResult { Ok(ext.storage_array().clone()) } -// TODO(connor): it would be nicer if this took a generic parameter and a FnMut arg that we run -// directly on the values without having to return this ugly stride. +/// The flat primitive elements of a tensor storage array, with typed row access. +/// +/// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a +/// constant input materializes only a single row (stride=0), while a full array uses +/// stride=list_size. +pub(crate) struct FlatElements { + elems: PrimitiveArray, + stride: usize, + list_size: usize, +} + +impl FlatElements { + /// Returns the [`PType`] of the underlying elements. + pub fn ptype(&self) -> PType { + self.elems.ptype() + } + + /// Returns the `i`-th row as a typed slice of length `list_size`. + pub fn row(&self, i: usize) -> &[T] { + let slice = self.elems.as_slice::(); + &slice[i * self.stride..i * self.stride + self.list_size] + } +} + /// Extracts the flat primitive elements from a tensor storage array (FixedSizeList). /// /// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is -/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)` -/// where `stride` is `list_size` for a full array and `0` for a constant. +/// materialized to avoid expanding it to the full column length. pub(crate) fn extract_flat_elements( storage: &ArrayRef, list_size: usize, -) -> VortexResult<(PrimitiveArray, usize)> { +) -> VortexResult { if let Some(constant) = storage.as_opt::() { // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge // amount of data. 
let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); let fsl = single.to_canonical()?.into_fixed_size_list(); let elems = fsl.elements().to_canonical()?.into_primitive(); - return Ok((elems, 0)); + return Ok(FlatElements { + elems, + stride: 0, + list_size, + }); } // Otherwise we have to fully expand all of the data. let fsl = storage.to_canonical()?.into_fixed_size_list(); let elems = fsl.elements().to_canonical()?.into_primitive(); - Ok((elems, list_size)) + Ok(FlatElements { + elems, + stride: list_size, + list_size, + }) } From 1198952fd462e2db3be890eb5a4039536a3c25d8 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Mon, 16 Mar 2026 10:11:28 -0400 Subject: [PATCH 5/7] lockfile Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 102 ++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 2104eb97c95..2b84eef114f 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -90,6 +90,34 @@ pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::hash<__H: core::has impl core::marker::StructuralPartialEq for vortex_tensor::fixed_shape::FixedShapeTensorMetadata +pub mod vortex_tensor::matcher + +pub enum vortex_tensor::matcher::TensorMatch<'a> + +pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(&'a vortex_tensor::fixed_shape::FixedShapeTensorMetadata) + +pub vortex_tensor::matcher::TensorMatch::Vector + +impl<'a> core::cmp::Eq for vortex_tensor::matcher::TensorMatch<'a> + +impl<'a> core::cmp::PartialEq for vortex_tensor::matcher::TensorMatch<'a> + +pub fn vortex_tensor::matcher::TensorMatch<'a>::eq(&self, other: &vortex_tensor::matcher::TensorMatch<'a>) -> bool + +impl<'a> core::fmt::Debug for vortex_tensor::matcher::TensorMatch<'a> + +pub fn vortex_tensor::matcher::TensorMatch<'a>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl<'a> core::marker::StructuralPartialEq for 
vortex_tensor::matcher::TensorMatch<'a> + +pub struct vortex_tensor::matcher::AnyTensor + +impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::matcher::AnyTensor + +pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::TensorMatch<'a> + +pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(item: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option + pub mod vortex_tensor::scalar_fns pub mod vortex_tensor::scalar_fns::cosine_similarity @@ -121,3 +149,77 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_null_s pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> + +pub mod vortex_tensor::scalar_fns::l2_norm + +pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm + +impl core::clone::Clone for vortex_tensor::scalar_fns::l2_norm::L2Norm + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor::scalar_fns::l2_norm::L2Norm + +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm + +pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_array::scalar_fn::vtable::EmptyOptions + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut 
vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::is_fallible(&self, _options: &Self::Options) -> bool + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::is_null_sensitive(&self, _options: &Self::Options) -> bool + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> + +pub mod vortex_tensor::vector + +pub struct vortex_tensor::vector::Vector + +impl core::clone::Clone for vortex_tensor::vector::Vector + +pub fn vortex_tensor::vector::Vector::clone(&self) -> vortex_tensor::vector::Vector + +impl core::cmp::Eq for vortex_tensor::vector::Vector + +impl core::cmp::PartialEq for vortex_tensor::vector::Vector + +pub fn vortex_tensor::vector::Vector::eq(&self, other: &vortex_tensor::vector::Vector) -> bool + +impl core::default::Default for vortex_tensor::vector::Vector + +pub fn vortex_tensor::vector::Vector::default() -> vortex_tensor::vector::Vector + +impl core::fmt::Debug for vortex_tensor::vector::Vector + +pub fn vortex_tensor::vector::Vector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_tensor::vector::Vector + +pub fn vortex_tensor::vector::Vector::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::StructuralPartialEq for vortex_tensor::vector::Vector + +impl vortex_array::dtype::extension::vtable::ExtVTable for 
vortex_tensor::vector::Vector + +pub type vortex_tensor::vector::Vector::Metadata = vortex_array::extension::EmptyMetadata + +pub type vortex_tensor::vector::Vector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue + +pub fn vortex_tensor::vector::Vector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult + +pub fn vortex_tensor::vector::Vector::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_tensor::vector::Vector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_tensor::vector::Vector::unpack_native<'a>(&self, _ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_tensor::vector::Vector::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::typed::ExtDType) -> vortex_error::VortexResult<()> From 5a198e8528068a0f4caa64929ac674a07d29773f Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Mon, 16 Mar 2026 10:59:10 -0400 Subject: [PATCH 6/7] clean up Signed-off-by: Connor Tsui --- vortex-array/src/scalar_fn/vtable.rs | 2 + .../src/scalar_fns/cosine_similarity.rs | 178 +++--------------- vortex-tensor/src/scalar_fns/l2_norm.rs | 97 ++-------- vortex-tensor/src/scalar_fns/utils.rs | 135 ++++++++++++- vortex-tensor/src/vector/mod.rs | 2 +- 5 files changed, 173 insertions(+), 241 deletions(-) diff --git a/vortex-array/src/scalar_fn/vtable.rs b/vortex-array/src/scalar_fn/vtable.rs index ce2cecf5df2..3d9220a531a 100644 --- a/vortex-array/src/scalar_fn/vtable.rs +++ b/vortex-array/src/scalar_fn/vtable.rs @@ -92,6 +92,8 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { Ok(args.to_vec()) } + // TODO(connor): This needs a precondition for the number of args it has, or all implementations + // need to return an error if it is wrong. 
/// Compute the return [`DType`] of the expression if evaluated over the given input types. fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index d9c952b29bb..f90fdb7d006 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -19,6 +19,7 @@ use vortex::dtype::Nullability; use vortex::dtype::extension::Matcher; use vortex::error::VortexResult; use vortex::error::vortex_ensure; +use vortex::error::vortex_ensure_eq; use vortex::error::vortex_err; use vortex::expr::Expression; use vortex::scalar_fn::Arity; @@ -81,7 +82,12 @@ impl ScalarFnVTable for CosineSimilarity { } fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { - debug_assert_eq!(arg_dtypes.len(), 2); + vortex_ensure_eq!( + arg_dtypes.len(), + 2, + "CosineSimilarity requires exactly 2 arguments, got {}", + arg_dtypes.len() + ); let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; @@ -89,25 +95,25 @@ impl ScalarFnVTable for CosineSimilarity { // Both must have the same dtype (ignoring top-level nullability). vortex_ensure!( lhs.eq_ignore_nullability(rhs), - "cosine_similarity requires both inputs to have the same dtype, got {lhs} and {rhs}" + "CosineSimilarity requires both inputs to have the same dtype, got {lhs} and {rhs}" ); // We don't need to look at rhs anymore since we know lhs and rhs are equal. // Both inputs must be tensor-like extension types. 
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| { - vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}") + vortex_err!("CosineSimilarity lhs must be an extension type, got {lhs}") })?; vortex_ensure!( AnyTensor::matches(lhs_ext), - "cosine_similarity inputs must be an `AnyTensor`, got {lhs}" + "CosineSimilarity inputs must be an `AnyTensor`, got {lhs}" ); let ptype = extension_element_ptype(lhs_ext)?; vortex_ensure!( ptype.is_float(), - "cosine_similarity element dtype must be a float primitive, got {ptype}" + "CosineSimilarity element dtype must be a float primitive, got {ptype}" ); let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); @@ -191,79 +197,32 @@ fn cosine_similarity_row(a: &[T], b: &[T]) -> T { #[cfg(test)] mod tests { - use vortex::array::ArrayRef; - use vortex::array::IntoArray; + use rstest::rstest; use vortex::array::ToCanonical; - use vortex::array::arrays::ConstantArray; - use vortex::array::arrays::ExtensionArray; - use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::ScalarFnArray; - use vortex::array::validity::Validity; - use vortex::buffer::Buffer; - use vortex::dtype::DType; - use vortex::dtype::Nullability; - use vortex::dtype::extension::ExtDType; use vortex::error::VortexResult; - use vortex::extension::EmptyMetadata; - use vortex::scalar::Scalar; use vortex::scalar_fn::EmptyOptions; use vortex::scalar_fn::ScalarFn; - use crate::fixed_shape::FixedShapeTensor; - use crate::fixed_shape::FixedShapeTensorMetadata; use crate::scalar_fns::cosine_similarity::CosineSimilarity; - use crate::vector::Vector; - - /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. - /// - /// The number of rows is inferred from the total element count divided by the product of the - /// shape dimensions. For 0-dimensional tensors (scalar), each element is one row. 
- fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult { - let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); - let row_count = elements.len() / list_size as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); - - let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); - let ext_dtype = - ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } + use crate::scalar_fns::utils::test_helpers::assert_close; + use crate::scalar_fns::utils::test_helpers::constant_tensor_array; + use crate::scalar_fns::utils::test_helpers::constant_vector_array; + use crate::scalar_fns::utils::test_helpers::tensor_array; + use crate::scalar_fns::utils::test_helpers::vector_array; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. - fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { + fn eval_cosine_similarity( + lhs: vortex::array::ArrayRef, + rhs: vortex::array::ArrayRef, + len: usize, + ) -> VortexResult> { let scalar_fn = ScalarFn::new(CosineSimilarity, EmptyOptions).erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; let prim = result.to_primitive(); Ok(prim.as_slice::().to_vec()) } - /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` - /// value, with support for NaN (NaN == NaN is considered equal). 
- #[track_caller] - fn assert_close(actual: &[f64], expected: &[f64]) { - assert_eq!( - actual.len(), - expected.len(), - "length mismatch: got {} elements, expected {}", - actual.len(), - expected.len() - ); - - for (i, (a, e)) in actual.iter().zip(expected).enumerate() { - if a.is_nan() && e.is_nan() { - continue; - } - assert!( - (a - e).abs() < 1e-10, - "element {i}: got {a}, expected {e} (diff = {})", - (a - e).abs() - ); - } - } - #[test] fn unit_vectors_1d() -> VortexResult<()> { let lhs = tensor_array( @@ -281,20 +240,18 @@ mod tests { ], )?; - // Row 0: identical → 1.0, row 1: orthogonal → 0.0. + // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); Ok(()) } - use rstest::rstest; - /// Single-row cosine similarity for various vector pairs. #[rstest] - // Antiparallel → -1.0. + // Antiparallel -> -1.0. #[case::opposite(&[3], &[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0], &[-1.0])] - // dot=24, both magnitudes=5 → 24/25 = 0.96. + // dot=24, both magnitudes=5 -> 24/25 = 0.96. #[case::non_unit(&[2], &[3.0, 4.0], &[4.0, 3.0], &[0.96])] - // Zero vector → 0/0 → NaN. + // Zero vector -> 0/0 -> NaN. #[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[f64::NAN])] fn single_row( #[case] shape: &[usize], @@ -333,14 +290,14 @@ mod tests { let lhs = tensor_array(&[], &[5.0, 3.0])?; let rhs = tensor_array(&[], &[5.0, -3.0])?; - // Same sign → 1.0, opposite sign → -1.0. + // Same sign -> 1.0, opposite sign -> -1.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, -1.0]); Ok(()) } #[test] fn many_rows() -> VortexResult<()> { - // 5 tensors of shape [4] compared against themselves → all 1.0. + // 5 tensors of shape [4] compared against themselves -> all 1.0. let lhs = tensor_array( &[4], &[ @@ -360,35 +317,8 @@ mod tests { Ok(()) } - /// Builds an extension array whose storage is a [`ConstantArray`], representing a single - /// query tensor broadcast to `len` rows. 
- fn constant_tensor_array( - shape: &[usize], - elements: &[f64], - len: usize, - ) -> VortexResult { - let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); - - // Build the FSL storage scalar from individual element scalars. - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - // Wrap the FSL scalar in a ConstantArray to avoid materializing `len` copies. - let storage = ConstantArray::new(storage_scalar, len).into_array(); - - let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); - let ext_dtype = - ExtDType::::try_new(metadata, storage.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) - } - #[test] - fn constant_query_vector() -> VortexResult<()> { + fn constant_query_tensor() -> VortexResult<()> { // Compare 4 tensors of shape [3] against a single constant query tensor [1,0,0]. let data = tensor_array( &[3], @@ -401,7 +331,6 @@ mod tests { )?; let query = constant_tensor_array(&[3], &[1.0, 0.0, 0.0], 4)?; - // Only tensor 0 is aligned with the query. assert_close( &eval_cosine_similarity(data, query, 4)?, &[1.0, 0.0, 0.0, 1.0], @@ -409,18 +338,6 @@ mod tests { Ok(()) } - /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. 
- fn vector_array(dim: u32, elements: &[f64]) -> VortexResult { - let row_count = elements.len() / dim as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - #[test] fn vector_unit_vectors() -> VortexResult<()> { let lhs = vector_array( @@ -443,43 +360,6 @@ mod tests { Ok(()) } - #[test] - fn vector_self_similarity() -> VortexResult<()> { - let arr = vector_array( - 4, - &[ - 1.0, 2.0, 3.0, 4.0, // vector 0 - 0.0, 1.0, 0.0, 0.0, // vector 1 - 5.0, 0.0, 5.0, 0.0, // vector 2 - ], - )?; - - assert_close( - &eval_cosine_similarity(arr.clone(), arr, 3)?, - &[1.0, 1.0, 1.0], - ); - Ok(()) - } - - /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`]. - fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult { - let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); - - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - let storage = ConstantArray::new(storage_scalar, len).into_array(); - - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) - } - #[test] fn vector_constant_query() -> VortexResult<()> { let data = vector_array( diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index bada964b7ef..5879bb79aab 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -19,6 +19,7 @@ use vortex::dtype::Nullability; use vortex::dtype::extension::Matcher; use vortex::error::VortexResult; use 
vortex::error::vortex_ensure; +use vortex::error::vortex_ensure_eq; use vortex::error::vortex_err; use vortex::expr::Expression; use vortex::scalar_fn::Arity; @@ -73,24 +74,29 @@ impl ScalarFnVTable for L2Norm { } fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { - debug_assert_eq!(arg_dtypes.len(), 1); + vortex_ensure_eq!( + arg_dtypes.len(), + 1, + "L2Norm requires exactly 1 argument, got {}", + arg_dtypes.len() + ); let input_dtype = &arg_dtypes[0]; // Input must be a tensor-like extension type. let ext = input_dtype.as_extension_opt().ok_or_else(|| { - vortex_err!("l2_norm input must be an extension type, got {input_dtype}") + vortex_err!("L2Norm input must be an extension type, got {input_dtype}") })?; vortex_ensure!( AnyTensor::matches(ext), - "l2_norm input must be an `AnyTensor`, got {input_dtype}" + "L2Norm input must be an `AnyTensor`, got {input_dtype}" ); let ptype = extension_element_ptype(ext)?; vortex_ensure!( ptype.is_float(), - "l2_norm element dtype must be a float primitive, got {ptype}" + "L2Norm element dtype must be a float primitive, got {ptype}" ); let nullability = Nullability::from(input_dtype.is_nullable()); @@ -160,93 +166,25 @@ mod tests { use rstest::rstest; - use vortex::array::ArrayRef; - use vortex::array::IntoArray; use vortex::array::ToCanonical; - use vortex::array::arrays::ExtensionArray; - use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::ScalarFnArray; - use vortex::array::validity::Validity; - use vortex::buffer::Buffer; - use vortex::dtype::extension::ExtDType; use vortex::error::VortexResult; - use vortex::extension::EmptyMetadata; use vortex::scalar_fn::EmptyOptions; use vortex::scalar_fn::ScalarFn; - use crate::fixed_shape::FixedShapeTensor; - use crate::fixed_shape::FixedShapeTensorMetadata; use crate::scalar_fns::l2_norm::L2Norm; - use crate::vector::Vector; - - /// Builds a [`FixedShapeTensor`] extension array 
from flat f64 elements and a logical shape. - fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult { - let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); - let row_count = elements.len() / list_size as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); - - let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); - let ext_dtype = - ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - - /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. - fn vector_array(dim: u32, elements: &[f64]) -> VortexResult { - let row_count = elements.len() / dim as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } + use crate::scalar_fns::utils::test_helpers::assert_close; + use crate::scalar_fns::utils::test_helpers::tensor_array; + use crate::scalar_fns::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. 
- fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult> { + fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult> { let scalar_fn = ScalarFn::new(L2Norm, EmptyOptions).erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?; let prim = result.to_primitive(); Ok(prim.as_slice::().to_vec()) } - #[track_caller] - fn assert_close(actual: &[f64], expected: &[f64]) { - assert_eq!( - actual.len(), - expected.len(), - "length mismatch: got {} elements, expected {}", - actual.len(), - expected.len() - ); - - for (i, (a, e)) in actual.iter().zip(expected).enumerate() { - assert!( - (a - e).abs() < 1e-10, - "element {i}: got {a}, expected {e} (diff = {})", - (a - e).abs() - ); - } - } - - #[test] - fn unit_vector_norm() -> VortexResult<()> { - let arr = tensor_array( - &[3], - &[ - 1.0, 0.0, 0.0, // unit x - 0.0, 1.0, 0.0, // unit y - 0.0, 0.0, 1.0, // unit z - ], - )?; - assert_close(&eval_l2_norm(arr, 3)?, &[1.0, 1.0, 1.0]); - Ok(()) - } - #[rstest] #[case::three_four_five(&[2], &[3.0, 4.0], &[5.0])] #[case::zero_vector(&[3], &[0.0, 0.0, 0.0], &[0.0])] @@ -276,13 +214,6 @@ mod tests { Ok(()) } - #[test] - fn vector_known_norm() -> VortexResult<()> { - let arr = vector_array(2, &[3.0, 4.0])?; - assert_close(&eval_l2_norm(arr, 1)?, &[5.0]); - Ok(()) - } - #[test] fn vector_multiple_rows() -> VortexResult<()> { let arr = vector_array( diff --git a/vortex-tensor/src/scalar_fns/utils.rs b/vortex-tensor/src/scalar_fns/utils.rs index ca7ddb47b02..196807d22a0 100644 --- a/vortex-tensor/src/scalar_fns/utils.rs +++ b/vortex-tensor/src/scalar_fns/utils.rs @@ -19,7 +19,7 @@ use vortex::error::vortex_err; /// Extracts the list size from a tensor-like extension dtype. /// /// The storage dtype must be a `FixedSizeList`. 
-pub(crate) fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult { +pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult { let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else { vortex_bail!( "expected FixedSizeList storage dtype, got {}", @@ -33,7 +33,7 @@ pub(crate) fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult { /// Extracts the float element [`PType`] from a tensor-like extension dtype. /// /// The storage dtype must be a `FixedSizeList` of non-nullable primitives. -pub(crate) fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult { +pub fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult { let element_dtype = ext .storage_dtype() .as_fixed_size_list_element_opt() @@ -53,7 +53,7 @@ pub(crate) fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult } /// Extracts the storage array from an extension array without canonicalizing. -pub(crate) fn extension_storage(array: &ArrayRef) -> VortexResult { +pub fn extension_storage(array: &ArrayRef) -> VortexResult { let ext = array .as_opt::() .ok_or_else(|| vortex_err!("scalar_fn input must be an extension array"))?; @@ -66,7 +66,7 @@ pub(crate) fn extension_storage(array: &ArrayRef) -> VortexResult { /// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a /// constant input materializes only a single row (stride=0), while a full array uses /// stride=list_size. -pub(crate) struct FlatElements { +pub struct FlatElements { elems: PrimitiveArray, stride: usize, list_size: usize, @@ -74,11 +74,13 @@ pub(crate) struct FlatElements { impl FlatElements { /// Returns the [`PType`] of the underlying elements. + #[must_use] pub fn ptype(&self) -> PType { self.elems.ptype() } /// Returns the `i`-th row as a typed slice of length `list_size`. 
+ #[must_use] pub fn row(&self, i: usize) -> &[T] { let slice = self.elems.as_slice::(); &slice[i * self.stride..i * self.stride + self.list_size] @@ -89,10 +91,7 @@ impl FlatElements { /// /// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is /// materialized to avoid expanding it to the full column length. -pub(crate) fn extract_flat_elements( - storage: &ArrayRef, - list_size: usize, -) -> VortexResult { +pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResult { if let Some(constant) = storage.as_opt::() { // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge // amount of data. @@ -115,3 +114,123 @@ pub(crate) fn extract_flat_elements( list_size, }) } + +#[cfg(test)] +pub mod test_helpers { + use vortex::array::ArrayRef; + use vortex::array::IntoArray; + use vortex::array::arrays::ConstantArray; + use vortex::array::arrays::ExtensionArray; + use vortex::array::arrays::FixedSizeListArray; + use vortex::array::validity::Validity; + use vortex::buffer::Buffer; + use vortex::dtype::DType; + use vortex::dtype::Nullability; + use vortex::dtype::extension::ExtDType; + use vortex::error::VortexResult; + use vortex::extension::EmptyMetadata; + use vortex::scalar::Scalar; + + use crate::fixed_shape::FixedShapeTensor; + use crate::fixed_shape::FixedShapeTensorMetadata; + use crate::vector::Vector; + + /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. + /// + /// The number of rows is inferred from the total element count divided by the product of the + /// shape dimensions. For 0-dimensional tensors (scalar), each element is one row. 
+ pub fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult { + let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); + let row_count = elements.len() / list_size as usize; + + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); + + let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); + let ext_dtype = + ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + } + + /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. + pub fn vector_array(dim: u32, elements: &[f64]) -> VortexResult { + let row_count = elements.len() / dim as usize; + + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); + + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + } + + /// Builds a [`FixedShapeTensor`] extension array whose storage is a [`ConstantArray`], + /// representing a single query tensor broadcast to `len` rows. 
+ pub fn constant_tensor_array( + shape: &[usize], + elements: &[f64], + len: usize, + ) -> VortexResult { + let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); + + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + let storage_scalar = + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); + + let storage = ConstantArray::new(storage_scalar, len).into_array(); + + let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); + let ext_dtype = + ExtDType::::try_new(metadata, storage.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + } + + /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`], representing a + /// single query vector broadcast to `len` rows. + pub fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult { + let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); + + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + let storage_scalar = + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); + + let storage = ConstantArray::new(storage_scalar, len).into_array(); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + } + + /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` + /// value, with support for NaN (NaN == NaN is considered equal). 
+ #[track_caller] + pub fn assert_close(actual: &[f64], expected: &[f64]) { + assert_eq!( + actual.len(), + expected.len(), + "length mismatch: got {} elements, expected {}", + actual.len(), + expected.len() + ); + + for (i, (a, e)) in actual.iter().zip(expected).enumerate() { + if a.is_nan() && e.is_nan() { + continue; + } + assert!( + (a - e).abs() < 1e-10, + "element {i}: got {a}, expected {e} (diff = {})", + (a - e).abs() + ); + } + } +} diff --git a/vortex-tensor/src/vector/mod.rs b/vortex-tensor/src/vector/mod.rs index 181e08c4e84..acba6792154 100644 --- a/vortex-tensor/src/vector/mod.rs +++ b/vortex-tensor/src/vector/mod.rs @@ -3,7 +3,7 @@ //! Vector extension type for fixed-length float vectors (e.g., embeddings). -/// The VTable for the vector extension type. +/// The Vector extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct Vector; From acdbd8982d16bd7bf0a2ac6455a6bb8187fa6b53 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Mon, 16 Mar 2026 15:07:47 -0400 Subject: [PATCH 7/7] address comments Signed-off-by: Connor Tsui --- vortex-tensor/src/fixed_shape/vtable.rs | 2 +- vortex-tensor/src/scalar_fns/cosine_similarity.rs | 5 ++--- vortex-tensor/src/scalar_fns/l2_norm.rs | 5 ++--- vortex-tensor/src/vector/vtable.rs | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vortex-tensor/src/fixed_shape/vtable.rs b/vortex-tensor/src/fixed_shape/vtable.rs index 15e47456ba9..136bb4b4718 100644 --- a/vortex-tensor/src/fixed_shape/vtable.rs +++ b/vortex-tensor/src/fixed_shape/vtable.rs @@ -22,7 +22,7 @@ impl ExtVTable for FixedShapeTensor { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new_ref("vortex.fixed_shape_tensor") + ExtId::new_ref("vortex.tensor.fixed_shape_tensor") } fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index f90fdb7d006..b656be107ff 
100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -16,7 +16,6 @@ use vortex::array::match_each_float_ptype; use vortex::dtype::DType; use vortex::dtype::NativePType; use vortex::dtype::Nullability; -use vortex::dtype::extension::Matcher; use vortex::error::VortexResult; use vortex::error::vortex_ensure; use vortex::error::vortex_ensure_eq; @@ -53,7 +52,7 @@ impl ScalarFnVTable for CosineSimilarity { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::new_ref("vortex.cosine_similarity") + ScalarFnId::new_ref("vortex.tensor.cosine_similarity") } fn arity(&self, _options: &Self::Options) -> Arity { @@ -106,7 +105,7 @@ impl ScalarFnVTable for CosineSimilarity { })?; vortex_ensure!( - AnyTensor::matches(lhs_ext), + lhs_ext.is::(), "CosineSimilarity inputs must be an `AnyTensor`, got {lhs}" ); diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 5879bb79aab..ff244d458d3 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -16,7 +16,6 @@ use vortex::array::match_each_float_ptype; use vortex::dtype::DType; use vortex::dtype::NativePType; use vortex::dtype::Nullability; -use vortex::dtype::extension::Matcher; use vortex::error::VortexResult; use vortex::error::vortex_ensure; use vortex::error::vortex_ensure_eq; @@ -48,7 +47,7 @@ impl ScalarFnVTable for L2Norm { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::new_ref("vortex.l2_norm") + ScalarFnId::new_ref("vortex.tensor.l2_norm") } fn arity(&self, _options: &Self::Options) -> Arity { @@ -89,7 +88,7 @@ impl ScalarFnVTable for L2Norm { })?; vortex_ensure!( - AnyTensor::matches(ext), + ext.is::(), "L2Norm input must be an `AnyTensor`, got {input_dtype}" ); diff --git a/vortex-tensor/src/vector/vtable.rs b/vortex-tensor/src/vector/vtable.rs index a3206c5150e..6ab849ab4e9 100644 --- a/vortex-tensor/src/vector/vtable.rs 
+++ b/vortex-tensor/src/vector/vtable.rs @@ -20,7 +20,7 @@ impl ExtVTable for Vector { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new_ref("vortex.vector") + ExtId::new_ref("vortex.tensor.vector") } fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> {