Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions checker/optional.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ Type OptionalMapOfKV() {
return *kInstance;
}

Type ListOfOptionalV() {
static const absl::NoDestructor<ListType> kInstance(
checker_internal::BuiltinsArena(), OptionalOfV());

return *kInstance;
}

class OptionalNames {
public:
static constexpr char kOptionalType[] = "optional_type";
Expand All @@ -85,6 +92,8 @@ class OptionalNames {
static constexpr char kOptionalIndex[] = "_[?_]";
static constexpr char kOptionalFirst[] = "first";
static constexpr char kOptionalLast[] = "last";
static constexpr char kOptionalUnwrap[] = "optional.unwrap";
static constexpr char kOptionalUnwrapOpt[] = "unwrapOpt";
};

class OptionalOverloads {
Expand Down Expand Up @@ -114,6 +123,9 @@ class OptionalOverloads {
// Syntactic sugar for chained indexing.
static constexpr char kOptionalListIndexInt[] = "optional_list_index_int";
static constexpr char kOptionalMapIndexValue[] = "optional_map_index_value";
// Unwrapping
static constexpr char kOptionalUnwrapList[] = "optional_unwrap_list";
static constexpr char kOptionalUnwrapOptList[] = "optional_unwrapOpt_list";
};

absl::Status RegisterOptionalDecls(TypeCheckerBuilder& builder, int version) {
Expand Down Expand Up @@ -207,6 +219,17 @@ absl::Status RegisterOptionalDecls(TypeCheckerBuilder& builder, int version) {
OptionalOfV(), OptionalMapOfKV(),
TypeParamType("K"))));

CEL_ASSIGN_OR_RETURN(
auto unwrap,
MakeFunctionDecl(
OptionalNames::kOptionalUnwrap,
MakeOverloadDecl(OptionalOverloads::kOptionalUnwrapList, ListOfV(), ListOfOptionalV())));
CEL_ASSIGN_OR_RETURN(
auto unwrap_opt,
MakeFunctionDecl(
OptionalNames::kOptionalUnwrapOpt,
MakeMemberOverloadDecl(OptionalOverloads::kOptionalUnwrapOptList, ListOfV(), ListOfOptionalV())));

CEL_RETURN_IF_ERROR(builder.AddVariable(
MakeVariableDecl(OptionalNames::kOptionalType, TypeOfOptionalOfV())));

Expand All @@ -220,6 +243,8 @@ absl::Status RegisterOptionalDecls(TypeCheckerBuilder& builder, int version) {
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(opt_index)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(select)));
CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(index)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(unwrap)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(unwrap_opt)));

if (version == 0 || version == 1) {
return absl::OkStatus();
Expand Down
2 changes: 2 additions & 0 deletions checker/standard_library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ absl::Status AddTypeConversions(TypeCheckerBuilder& builder) {
// Int
FunctionDecl to_int;
to_int.set_name(StandardFunctions::kInt);
CEL_RETURN_IF_ERROR(to_int.AddOverload(
MakeOverloadDecl(StandardOverloadIds::kBoolToInt, IntType(), BoolType())));
CEL_RETURN_IF_ERROR(to_int.AddOverload(
MakeOverloadDecl(StandardOverloadIds::kIntToInt, IntType(), IntType())));
CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl(
Expand Down
2 changes: 2 additions & 0 deletions common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,9 @@ cc_library(
],
deps = [
":kind",
":type",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
Expand Down
171 changes: 161 additions & 10 deletions common/function_descriptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,152 @@
#include <cstddef>

#include "absl/base/macros.h"
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "common/kind.h"
#include "common/type.h"

namespace cel {

namespace {

// Recursive type matching with TypeParam binding support.
// Returns true if types match (i.e., they conflict and cannot coexist).
//
// Matching semantics:
// - dyn/any act as wildcards that match any type
// - TypeParam requires consistent binding across all arguments
// - Container types (list, map, opaque) are matched recursively
//
// The bindings map tracks TypeParam name -> bound Type mappings to ensure
// consistent binding. For example, (A, A) matches (int, int) but not
// (int, string) because A cannot bind to both int and string.
bool TypeMatches(const Type& a, const Type& b,
absl::flat_hash_map<std::string, Type>& bindings) {
// Wildcard types match anything (consistent with type checker's IsWildCardType)
// - dyn: dynamic type, defers type checking to runtime
// - any: legacy wildcard type
// - error: error type propagates through expressions
if (a.IsDyn() || b.IsDyn()) {
return true;
}
if (a.IsAny() || b.IsAny()) {
return true;
}
if (a.IsError() || b.IsError()) {
return true;
}

// TypeParam handling - requires consistent binding
if (a.IsTypeParam()) {
std::string name(a.GetTypeParam().name());
auto it = bindings.find(name);
if (it != bindings.end()) {
// Already bound - check consistency with the bound type
return TypeMatches(it->second, b, bindings);
} else {
// Not yet bound - bind to b and return match
bindings[name] = b;
return true;
}
}
if (b.IsTypeParam()) {
std::string name(b.GetTypeParam().name());
auto it = bindings.find(name);
if (it != bindings.end()) {
// Already bound - check consistency with the bound type
return TypeMatches(a, it->second, bindings);
} else {
// Not yet bound - bind to a and return match
bindings[name] = a;
return true;
}
}

// Different kinds don't match (e.g., int vs string)
TypeKind a_kind = a.kind();
TypeKind b_kind = b.kind();
if (a_kind != b_kind) {
return false;
}

// Same kind - check type-specific details
switch (a_kind) {
case TypeKind::kList: {
// Recursively check element types
return TypeMatches(a.GetList().GetElement(),
b.GetList().GetElement(), bindings);
}
case TypeKind::kMap: {
// Recursively check key and value types
return TypeMatches(a.GetMap().GetKey(), b.GetMap().GetKey(),
bindings) &&
TypeMatches(a.GetMap().GetValue(), b.GetMap().GetValue(),
bindings);
}
case TypeKind::kStruct: {
// Empty StructType acts as wildcard for any struct
StructType struct_a = a.GetStruct();
StructType struct_b = b.GetStruct();
if (!struct_a || !struct_b) {
// Empty struct matches any struct
return true;
}
// Both have names - must match exactly
return struct_a.name() == struct_b.name();
}
case TypeKind::kOpaque: {
// Opaque types (including Optional) - must have same name and
// recursively matching parameters
OpaqueType opaque_a = a.GetOpaque();
OpaqueType opaque_b = b.GetOpaque();
// Empty OpaqueType acts as wildcard for any opaque
if (!opaque_a || !opaque_b) {
return true;
}
if (opaque_a.name() != opaque_b.name()) {
return false;
}
TypeParameters params_a = opaque_a.GetParameters();
TypeParameters params_b = opaque_b.GetParameters();
if (params_a.size() != params_b.size()) {
return false;
}
for (size_t i = 0; i < params_a.size(); i++) {
if (!TypeMatches(params_a[i], params_b[i], bindings)) {
return false;
}
}
return true;
}
default:
// Basic types with same kind always match
return true;
}
}

// Converts a span of Kinds to a vector of Types.
// Uses the implicit Type(Kind) constructor for conversion.
std::vector<Type> KindsToTypes(absl::Span<const Kind> kinds) {
// Uses implicit Type(Kind) constructor for conversion.
std::vector<Type> types(kinds.begin(), kinds.end());
return types;
}

} // namespace

const std::vector<Kind>& FunctionDescriptor::kinds() const {
return impl_->kinds;
}

bool FunctionDescriptor::ShapeMatches(bool receiver_style,
absl::Span<const Kind> types) const {
// Convert Kinds to Types and delegate to the Type-based version.
return ShapeMatches(receiver_style, KindsToTypes(types));
}

bool FunctionDescriptor::ShapeMatches(bool receiver_style,
absl::Span<const Type> types) const {
if (this->receiver_style() != receiver_style) {
return false;
}
Expand All @@ -33,23 +172,33 @@ bool FunctionDescriptor::ShapeMatches(bool receiver_style,
return false;
}

// Use type-level matching with consistent TypeParam binding.
// The binding context is shared across all arguments to ensure
// TypeParam consistency (e.g., (A, A) requires both args to be same type).
absl::flat_hash_map<std::string, Type> bindings;

for (size_t i = 0; i < this->types().size(); i++) {
Kind this_type = this->types()[i];
Kind other_type = types[i];
if (this_type != Kind::kAny && other_type != Kind::kAny &&
this_type != other_type) {
if (!TypeMatches(this->types()[i], types[i], bindings)) {
return false;
}
}
return true;
}

bool FunctionDescriptor::operator==(const FunctionDescriptor& other) const {
return impl_.get() == other.impl_.get() ||
(name() == other.name() &&
receiver_style() == other.receiver_style() &&
types().size() == other.types().size() &&
std::equal(types().begin(), types().end(), other.types().begin()));
if (impl_.get() == other.impl_.get()) {
return true;
}
if (name() != other.name() ||
receiver_style() != other.receiver_style() ||
types().size() != other.types().size()) {
return false;
}
// Compare using Kind for backward compatibility.
// Full Type comparison can be added later if needed.
const auto& lhs_types = types();
const auto& rhs_types = other.types();
return std::equal(lhs_types.begin(), lhs_types.end(), rhs_types.begin());
}

bool FunctionDescriptor::operator<(const FunctionDescriptor& other) const {
Expand All @@ -73,7 +222,9 @@ bool FunctionDescriptor::operator<(const FunctionDescriptor& other) const {
auto rhs_begin = other.types().begin();
auto rhs_end = other.types().end();
while (lhs_begin != lhs_end && rhs_begin != rhs_end) {
if (*lhs_begin < *rhs_begin) {
// Compare types lexicographically using DebugString as stable ordering
// This ensures consistent ordering for all Type variants
if (lhs_begin->DebugString() < rhs_begin->DebugString()) {
return true;
}
if (!(*lhs_begin == *rhs_begin)) {
Expand Down
Loading