Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vortex-array/src/scalar_fn/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
Ok(args.to_vec())
}

// TODO(connor): This needs a precondition for the number of args it has, or all implementations
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should already be a check using the arity of the function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried backtracing through the code myself but I couldn't find where that is called

// need to return an error if it is wrong.
/// Compute the return [`DType`] of the expression if evaluated over the given input types.
fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;

Expand Down
102 changes: 102 additions & 0 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,34 @@ pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::hash<__H: core::has

impl core::marker::StructuralPartialEq for vortex_tensor::fixed_shape::FixedShapeTensorMetadata

pub mod vortex_tensor::matcher

pub enum vortex_tensor::matcher::TensorMatch<'a>

pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(&'a vortex_tensor::fixed_shape::FixedShapeTensorMetadata)

pub vortex_tensor::matcher::TensorMatch::Vector

impl<'a> core::cmp::Eq for vortex_tensor::matcher::TensorMatch<'a>

impl<'a> core::cmp::PartialEq for vortex_tensor::matcher::TensorMatch<'a>

pub fn vortex_tensor::matcher::TensorMatch<'a>::eq(&self, other: &vortex_tensor::matcher::TensorMatch<'a>) -> bool

impl<'a> core::fmt::Debug for vortex_tensor::matcher::TensorMatch<'a>

pub fn vortex_tensor::matcher::TensorMatch<'a>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl<'a> core::marker::StructuralPartialEq for vortex_tensor::matcher::TensorMatch<'a>

pub struct vortex_tensor::matcher::AnyTensor

impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::matcher::AnyTensor

pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::TensorMatch<'a>

pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(item: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>

pub mod vortex_tensor::scalar_fns

pub mod vortex_tensor::scalar_fns::cosine_similarity
Expand Down Expand Up @@ -121,3 +149,77 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_null_s
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>

pub mod vortex_tensor::scalar_fns::l2_norm

pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm

impl core::clone::Clone for vortex_tensor::scalar_fns::l2_norm::L2Norm

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor::scalar_fns::l2_norm::L2Norm

impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm

pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_array::scalar_fn::vtable::EmptyOptions

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::id(&self) -> vortex_array::scalar_fn::ScalarFnId

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::is_fallible(&self, _options: &Self::Options) -> bool

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::is_null_sensitive(&self, _options: &Self::Options) -> bool

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>

pub mod vortex_tensor::vector

pub struct vortex_tensor::vector::Vector

impl core::clone::Clone for vortex_tensor::vector::Vector

pub fn vortex_tensor::vector::Vector::clone(&self) -> vortex_tensor::vector::Vector

impl core::cmp::Eq for vortex_tensor::vector::Vector

impl core::cmp::PartialEq for vortex_tensor::vector::Vector

pub fn vortex_tensor::vector::Vector::eq(&self, other: &vortex_tensor::vector::Vector) -> bool

impl core::default::Default for vortex_tensor::vector::Vector

pub fn vortex_tensor::vector::Vector::default() -> vortex_tensor::vector::Vector

impl core::fmt::Debug for vortex_tensor::vector::Vector

pub fn vortex_tensor::vector::Vector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl core::hash::Hash for vortex_tensor::vector::Vector

pub fn vortex_tensor::vector::Vector::hash<__H: core::hash::Hasher>(&self, state: &mut __H)

impl core::marker::StructuralPartialEq for vortex_tensor::vector::Vector

impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::vector::Vector

pub type vortex_tensor::vector::Vector::Metadata = vortex_array::extension::EmptyMetadata

pub type vortex_tensor::vector::Vector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue

pub fn vortex_tensor::vector::Vector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult<Self::Metadata>

pub fn vortex_tensor::vector::Vector::id(&self) -> vortex_array::dtype::extension::ExtId

pub fn vortex_tensor::vector::Vector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult<alloc::vec::Vec<u8>>

pub fn vortex_tensor::vector::Vector::unpack_native<'a>(&self, _ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType<Self>, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult<Self::NativeValue>

pub fn vortex_tensor::vector::Vector::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>
2 changes: 1 addition & 1 deletion vortex-tensor/src/fixed_shape/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl ExtVTable for FixedShapeTensor {
type NativeValue<'a> = &'a ScalarValue;

fn id(&self) -> ExtId {
ExtId::new_ref("vortex.fixed_shape_tensor")
ExtId::new_ref("vortex.tensor.fixed_shape_tensor")
}

fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
Expand Down
2 changes: 2 additions & 0 deletions vortex-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@
//! similarity.

pub mod fixed_shape;
pub mod vector;

pub mod matcher;
pub mod scalar_fns;
42 changes: 42 additions & 0 deletions vortex-tensor/src/matcher.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Matcher for tensor-like extension types.

use vortex::dtype::extension::ExtDTypeRef;
use vortex::dtype::extension::Matcher;

use crate::fixed_shape::FixedShapeTensor;
use crate::fixed_shape::FixedShapeTensorMetadata;
use crate::vector::Vector;

/// Matcher for any tensor-like extension type.
///
/// Currently the different kinds of tensors that are available are:
///
/// - `FixedShapeTensor`
/// - `Vector`
pub struct AnyTensor;

/// The matched variant of a tensor-like extension type.
#[derive(Debug, PartialEq, Eq)]
pub enum TensorMatch<'a> {
/// A [`FixedShapeTensor`] extension type.
FixedShapeTensor(&'a FixedShapeTensorMetadata),
/// A [`Vector`] extension type.
Vector,
}

impl Matcher for AnyTensor {
type Match<'a> = TensorMatch<'a>;

fn try_match<'a>(item: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
if let Some(metadata) = item.metadata_opt::<FixedShapeTensor>() {
return Some(TensorMatch::FixedShapeTensor(metadata));
}
if item.metadata_opt::<Vector>().is_some() {
return Some(TensorMatch::Vector);
}
None
}
}
Loading
Loading