Skip to content

Commit 0b2bf66

Browse files
l46kokcopybara-github
authored andcommitted
Create CelBlock abstraction to centralize cel.@block logic
PiperOrigin-RevId: 940677666
1 parent 3e7dea1 commit 0b2bf66

8 files changed

Lines changed: 217 additions & 99 deletions

File tree

common/ast/BUILD.bazel

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ java_library(
1111
exports = ["//common/src/main/java/dev/cel/common/ast"],
1212
)
1313

14+
java_library(
15+
name = "cel_block",
16+
visibility = ["//:internal"],
17+
exports = ["//common/src/main/java/dev/cel/common/ast:cel_block"],
18+
)
19+
1420
cel_android_library(
1521
name = "ast_android",
1622
exports = ["//common/src/main/java/dev/cel/common/ast:ast_android"],

common/src/main/java/dev/cel/common/ast/BUILD.bazel

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,20 @@ java_library(
5757
],
5858
)
5959

60+
java_library(
61+
name = "cel_block",
62+
srcs = ["CelBlock.java"],
63+
tags = [
64+
],
65+
deps = [
66+
":ast",
67+
"//common:cel_ast",
68+
"//common/annotations",
69+
"//common/navigation",
70+
"@maven//:com_google_guava_guava",
71+
],
72+
)
73+
6074
java_library(
6175
name = "expr_converter",
6276
srcs = EXPR_CONVERTER_SOURCES,
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dev.cel.common.ast;
16+
17+
import static com.google.common.collect.ImmutableList.toImmutableList;
18+
19+
import com.google.common.base.Preconditions;
20+
import com.google.common.collect.ImmutableList;
21+
import dev.cel.common.CelAbstractSyntaxTree;
22+
import dev.cel.common.annotations.Internal;
23+
import dev.cel.common.navigation.CelNavigableExpr;
24+
import java.util.Optional;
25+
26+
/**
27+
* Represents a {@code cel.@block} expression.
28+
*
29+
* <p>CEL Block is used by the CSE (Common Subexpression Elimination) optimizer to hoist common
30+
* subexpressions into an evaluated block.
31+
*/
32+
@Internal
33+
public final class CelBlock {
34+
public static final String FUNCTION_NAME = "cel.@block";
35+
public static final String INDEX_PREFIX = "@index";
36+
37+
private final CelExpr blockExpr;
38+
39+
private CelBlock(CelExpr blockExpr) {
40+
this.blockExpr = blockExpr;
41+
}
42+
43+
public ImmutableList<CelExpr> indices() {
44+
return blockExpr.call().args().get(0).list().elements();
45+
}
46+
47+
public CelExpr result() {
48+
return blockExpr.call().args().get(1);
49+
}
50+
51+
public CelExpr expr() {
52+
return blockExpr;
53+
}
54+
55+
/**
56+
* Extracts a {@link CelBlock} from the given AST.
57+
*
58+
* <p>Enforces the contract that {@code cel.@block} must only appear exactly once and at the root
59+
* of the AST.
60+
*
61+
* @throws IllegalArgumentException if the block is malformed or its indices are invalid.
62+
*/
63+
public static Optional<CelBlock> extract(CelAbstractSyntaxTree ast) {
64+
CelNavigableExpr celNavigableExpr = CelNavigableExpr.fromExpr(ast.getExpr());
65+
66+
ImmutableList<CelExpr> allCelBlocks =
67+
celNavigableExpr
68+
.allNodes()
69+
.map(CelNavigableExpr::expr)
70+
.filter(expr -> expr.callOrDefault().function().equals(FUNCTION_NAME))
71+
.collect(toImmutableList());
72+
if (allCelBlocks.isEmpty()) {
73+
return Optional.empty();
74+
}
75+
76+
Preconditions.checkArgument(
77+
allCelBlocks.size() == 1,
78+
"Expected 1 cel.block function to be present but found %s",
79+
allCelBlocks.size());
80+
Preconditions.checkArgument(
81+
celNavigableExpr.expr().equals(allCelBlocks.get(0)),
82+
"Expected cel.block to be present at root");
83+
84+
return Optional.of(fromExpr(allCelBlocks.get(0)));
85+
}
86+
87+
/**
88+
* Constructs a {@link CelBlock} from a {@link CelExpr}.
89+
*
90+
* @throws IllegalArgumentException if the expression is not a valid block.
91+
*/
92+
private static CelBlock fromExpr(CelExpr expr) {
93+
Preconditions.checkArgument(
94+
expr.exprKind().getKind() == CelExpr.ExprKind.Kind.CALL,
95+
"Expected cel.@block to be a call expression");
96+
Preconditions.checkArgument(
97+
expr.call().function().equals(FUNCTION_NAME), "Expected function to be cel.@block");
98+
Preconditions.checkArgument(
99+
expr.call().args().size() == 2, "Expected exactly 2 arguments for cel.@block");
100+
Preconditions.checkArgument(
101+
expr.call().args().get(0).exprKind().getKind() == CelExpr.ExprKind.Kind.LIST,
102+
"Expected first argument of cel.@block to be a list");
103+
104+
CelBlock block = new CelBlock(expr);
105+
106+
// Assert correctness on block indices used in subexpressions
107+
ImmutableList<CelExpr> subexprs = block.indices();
108+
for (int i = 0; i < subexprs.size(); i++) {
109+
verifyBlockIndex(subexprs.get(i), i, expr);
110+
}
111+
112+
// Assert correctness on block indices used in block result
113+
CelExpr blockResult = block.result();
114+
verifyBlockIndex(blockResult, subexprs.size(), expr);
115+
boolean resultHasAtLeastOneBlockIndex =
116+
CelNavigableExpr.fromExpr(blockResult)
117+
.allNodes()
118+
.map(CelNavigableExpr::expr)
119+
.anyMatch(e -> e.identOrDefault().name().startsWith(INDEX_PREFIX));
120+
Preconditions.checkArgument(
121+
resultHasAtLeastOneBlockIndex,
122+
"Expected at least one reference of index in cel.block result");
123+
124+
return block;
125+
}
126+
127+
private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue, CelExpr rootBlock) {
128+
boolean areAllIndicesValid =
129+
CelNavigableExpr.fromExpr(celExpr)
130+
.allNodes()
131+
.map(CelNavigableExpr::expr)
132+
.filter(expr -> expr.identOrDefault().name().startsWith(INDEX_PREFIX))
133+
.map(CelExpr::ident)
134+
.allMatch(
135+
blockIdent ->
136+
Integer.parseInt(blockIdent.name().substring(INDEX_PREFIX.length()))
137+
< maxIndexValue);
138+
Preconditions.checkArgument(
139+
areAllIndicesValid,
140+
"Illegal block index found. The index value must be less than %s. Expr: %s",
141+
maxIndexValue,
142+
rootBlock);
143+
}
144+
}

optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ java_library(
6060
"//common:mutable_ast",
6161
"//common:mutable_source",
6262
"//common/ast",
63+
"//common/ast:cel_block",
6364
"//common/ast:mutable_expr",
6465
"//common/navigation",
6566
"//common/navigation:common",

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import dev.cel.common.CelSource.Extension.Version;
4242
import dev.cel.common.CelValidationException;
4343
import dev.cel.common.CelVarDecl;
44+
import dev.cel.common.ast.CelBlock;
4445
import dev.cel.common.ast.CelExpr;
4546
import dev.cel.common.ast.CelExpr.CelCall;
4647
import dev.cel.common.ast.CelExpr.CelComprehension;
@@ -238,64 +239,12 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
238239
*/
239240
@VisibleForTesting
240241
static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
241-
CelNavigableExpr celNavigableExpr = CelNavigableExpr.fromExpr(ast.getExpr());
242-
243-
ImmutableList<CelExpr> allCelBlocks =
244-
celNavigableExpr
245-
.allNodes()
246-
.map(CelNavigableExpr::expr)
247-
.filter(expr -> expr.callOrDefault().function().equals(CEL_BLOCK_FUNCTION))
248-
.collect(toImmutableList());
249-
if (allCelBlocks.isEmpty()) {
242+
CelBlock celBlock = CelBlock.extract(ast).orElse(null);
243+
if (celBlock == null) {
250244
return;
251245
}
252246

253-
CelExpr celBlockExpr = allCelBlocks.get(0);
254-
Verify.verify(
255-
allCelBlocks.size() == 1,
256-
"Expected 1 cel.block function to be present but found %s",
257-
allCelBlocks.size());
258-
Verify.verify(
259-
celNavigableExpr.expr().equals(celBlockExpr), "Expected cel.block to be present at root");
260-
261-
// Assert correctness on block indices used in subexpressions
262-
CelCall celBlockCall = celBlockExpr.call();
263-
ImmutableList<CelExpr> subexprs = celBlockCall.args().get(0).list().elements();
264-
for (int i = 0; i < subexprs.size(); i++) {
265-
verifyBlockIndex(subexprs.get(i), i);
266-
}
267-
268-
// Assert correctness on block indices used in block result
269-
CelExpr blockResult = celBlockCall.args().get(1);
270-
verifyBlockIndex(blockResult, subexprs.size());
271-
boolean resultHasAtLeastOneBlockIndex =
272-
CelNavigableExpr.fromExpr(blockResult)
273-
.allNodes()
274-
.map(CelNavigableExpr::expr)
275-
.anyMatch(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX));
276-
Verify.verify(
277-
resultHasAtLeastOneBlockIndex,
278-
"Expected at least one reference of index in cel.block result");
279-
280-
verifyNoInvalidScopedMangledVariables(celBlockExpr);
281-
}
282-
283-
private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
284-
boolean areAllIndicesValid =
285-
CelNavigableExpr.fromExpr(celExpr)
286-
.allNodes()
287-
.map(CelNavigableExpr::expr)
288-
.filter(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX))
289-
.map(CelExpr::ident)
290-
.allMatch(
291-
blockIdent ->
292-
Integer.parseInt(blockIdent.name().substring(BLOCK_INDEX_PREFIX.length()))
293-
< maxIndexValue);
294-
Verify.verify(
295-
areAllIndicesValid,
296-
"Illegal block index found. The index value must be less than %s. Expr: %s",
297-
maxIndexValue,
298-
celExpr);
247+
verifyNoInvalidScopedMangledVariables(celBlock.expr());
299248
}
300249

301250
private static void verifyNoInvalidScopedMangledVariables(CelExpr celExpr) {

0 commit comments

Comments
 (0)