From 5d1a668e0120fff799984f9f3ed6c7e1000e1940 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 6 Oct 2025 13:42:43 -0700 Subject: [PATCH] Add support for configuring custom expression-level validators in CelPolicyCompiler PiperOrigin-RevId: 815866955 --- .../src/main/java/dev/cel/policy/BUILD.bazel | 1 + .../cel/policy/CelPolicyCompilerBuilder.java | 19 ++++ .../dev/cel/policy/CelPolicyCompilerImpl.java | 49 +++++++- .../src/test/java/dev/cel/policy/BUILD.bazel | 5 +- .../cel/policy/CelPolicyCompilerImplTest.java | 106 ++++++++++++++++++ 5 files changed, 177 insertions(+), 3 deletions(-) diff --git a/policy/src/main/java/dev/cel/policy/BUILD.bazel b/policy/src/main/java/dev/cel/policy/BUILD.bazel index ee35e8271..6295e3576 100644 --- a/policy/src/main/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/main/java/dev/cel/policy/BUILD.bazel @@ -137,6 +137,7 @@ java_library( ], deps = [ ":compiler", + "//validator:ast_validator", "@maven//:com_google_errorprone_error_prone_annotations", ], ) diff --git a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java index 592a0120d..d6cd8e6a5 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java @@ -16,6 +16,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.CheckReturnValue; +import dev.cel.validator.CelAstValidator; /** Interface for building an instance of {@link CelPolicyCompiler} */ public interface CelPolicyCompilerBuilder { @@ -38,6 +39,24 @@ public interface CelPolicyCompilerBuilder { @CanIgnoreReturnValue CelPolicyCompilerBuilder setAstDepthLimit(int iterationLimit); + /** + * Adds one or more {@link CelAstValidators} to the compiler. These apply per CEL expression in + * the policy. + */ + @CanIgnoreReturnValue + CelPolicyCompilerBuilder addValidators(Iterable validators); + + /** + * Adds one or more {@link CelAstValidators} to the compiler. These apply per CEL expression in + * the policy. + */ + @CanIgnoreReturnValue + CelPolicyCompilerBuilder addValidators(CelAstValidator... validators); + + /** Removes any custom validators from the compiler builder. */ + @CanIgnoreReturnValue + CelPolicyCompilerBuilder clearValidators(); + @CheckReturnValue CelPolicyCompiler build(); } diff --git a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java index f6f893c1c..6005bd4fe 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java @@ -64,6 +64,7 @@ final class CelPolicyCompilerImpl implements CelPolicyCompiler { private final String variablesPrefix; private final int iterationLimit; private final Optional astDepthValidator; + private final Optional validator; @Override public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationException { @@ -194,6 +195,10 @@ private CelCompiledRule compileRuleImpl( CelType outputType = SimpleType.DYN; try { varAst = localCel.compile(expression.value()).getAst(); + if (this.validator.isPresent()) { + CelValidationResult result = this.validator.get().validate(varAst); + varAst = result.getAst(); + } outputType = varAst.getResultType(); } catch (CelValidationException e) { compilerContext.addIssue(expression.id(), e.getErrors()); @@ -212,6 +217,10 @@ private CelCompiledRule compileRuleImpl( CelAbstractSyntaxTree conditionAst; try { conditionAst = localCel.compile(match.condition().value()).getAst(); + if (this.validator.isPresent()) { + CelValidationResult result = this.validator.get().validate(conditionAst); + conditionAst = result.getAst(); + } if (!conditionAst.getResultType().equals(SimpleType.BOOL)) { compilerContext.addIssue( match.condition().id(), @@ -229,6 +238,10 @@ private CelCompiledRule compileRuleImpl( ValueString output = match.result().output(); try { outputAst = localCel.compile(output.value()).getAst(); + if (this.validator.isPresent()) { + CelValidationResult result = this.validator.get().validate(outputAst); + outputAst = result.getAst(); + } } catch (CelValidationException e) { compilerContext.addIssue(output.id(), e.getErrors()); continue; @@ -340,10 +353,12 @@ static final class Builder implements CelPolicyCompilerBuilder { private String variablesPrefix; private int iterationLimit; private Optional astDepthLimitValidator; + private ArrayList validators; private Builder(Cel cel) { this.cel = cel; this.astDepthLimitValidator = Optional.of(AstDepthLimitValidator.DEFAULT); + this.validators = new ArrayList<>(); } @Override @@ -360,6 +375,26 @@ public Builder setIterationLimit(int iterationLimit) { return this; } + @Override + @CanIgnoreReturnValue + public Builder addValidators(Iterable validators) { + validators.forEach(this.validators::add); + return this; + } + + @Override + @CanIgnoreReturnValue + public Builder addValidators(CelAstValidator... validators) { + return addValidators(Arrays.asList(validators)); + } + + @Override + @CanIgnoreReturnValue + public Builder clearValidators() { + this.validators.clear(); + return this; + } + @Override @CanIgnoreReturnValue public CelPolicyCompilerBuilder setAstDepthLimit(int astDepthLimit) { @@ -374,7 +409,7 @@ public CelPolicyCompilerBuilder setAstDepthLimit(int astDepthLimit) { @Override public CelPolicyCompiler build() { return new CelPolicyCompilerImpl( - cel, this.variablesPrefix, this.iterationLimit, astDepthLimitValidator); + cel, this.variablesPrefix, this.iterationLimit, astDepthLimitValidator, validators); } } @@ -388,10 +423,20 @@ private CelPolicyCompilerImpl( Cel cel, String variablesPrefix, int iterationLimit, - Optional astDepthValidator) { + Optional astDepthValidator, + List additionalValidators) { this.cel = checkNotNull(cel); this.variablesPrefix = checkNotNull(variablesPrefix); this.iterationLimit = iterationLimit; this.astDepthValidator = astDepthValidator; + if (additionalValidators.isEmpty()) { + this.validator = Optional.empty(); + } else { + this.validator = + Optional.of( + CelValidatorFactory.standardCelValidatorBuilder(cel) + .addAstValidators(additionalValidators) + .build()); + } } } diff --git a/policy/src/test/java/dev/cel/policy/BUILD.bazel b/policy/src/test/java/dev/cel/policy/BUILD.bazel index 9106caf70..e1195781e 100644 --- a/policy/src/test/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/test/java/dev/cel/policy/BUILD.bazel @@ -17,8 +17,11 @@ java_library( "//bundle:environment_yaml_parser", "//common:cel_ast", "//common:options", + "//common/ast", "//common/formats:value_string", "//common/internal", + "//common/navigation", + "//common/navigation:common", "//common/resources/testdata/proto3:standalone_global_enum_java_proto", "//common/types", "//compiler", @@ -35,8 +38,8 @@ java_library( "//policy:validation_exception", "//runtime", "//runtime:function_binding", - "//runtime:late_function_binding", "//testing/protos:single_file_java_proto", + "//validator:ast_validator", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@maven//:com_google_guava_guava", "@maven//:com_google_testparameterinjector_test_parameter_injector", diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index fa0da8a9a..78727c5e6 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -31,6 +31,11 @@ import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelOptions; +import dev.cel.common.ast.CelConstant; +import dev.cel.common.ast.CelExpr.ExprKind; +import dev.cel.common.navigation.CelNavigableAst; +import dev.cel.common.navigation.CelNavigableExpr; +import dev.cel.common.navigation.TraversalOrder; import dev.cel.common.types.OptionalType; import dev.cel.common.types.SimpleType; import dev.cel.expr.conformance.proto3.TestAllTypes; @@ -47,6 +52,8 @@ import dev.cel.runtime.CelLateFunctionBindings; import dev.cel.testing.testdata.SingleFileProto.SingleFile; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; +import dev.cel.validator.CelAstValidator; +import dev.cel.validator.CelAstValidator.IssuesFactory; import java.io.IOException; import java.util.Map; import java.util.Optional; @@ -265,6 +272,105 @@ public void evaluateYamlPolicy_nestedRuleProducesOptionalOutput() throws Excepti assertThat(evalResult).hasValue(Optional.of(true)); } + static final class NoFooLiteralsValidator implements CelAstValidator { + private static boolean isFooLiteral(CelNavigableExpr node) { + return node.getKind().equals(ExprKind.Kind.CONSTANT) + && node.expr().constant().getKind().equals(CelConstant.Kind.STRING_VALUE) + && node.expr().constant().stringValue().equals("foo"); + } + + @Override + public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) { + navigableAst + .getRoot() + .descendants(TraversalOrder.POST_ORDER) + .filter(NoFooLiteralsValidator::isFooLiteral) + .forEach(node -> issuesFactory.addError(node.id(), "'foo' is a forbidden literal")); + } + } + + @Test + public void evaluateYamlPolicy_validatorReportsErrors() throws Exception { + Cel cel = newCel(); + String policySource = + "name: nested_rule_with_forbidden_literal\n" + + "rule:\n" + + " variables:\n" + + " - name: 'foo'\n" + + " expression: \"(true) ? 'bar' : 'foo'\"\n" + + " match:\n" + + " - condition: |\n" + + " variables.foo in ['foo', 'bar', 'foo']\n" + + " output: >\n" + + " 'foo' == variables.foo\n"; + CelPolicy policy = POLICY_PARSER.parse(policySource); + CelPolicyValidationException e = + assertThrows( + CelPolicyValidationException.class, + () -> + CelPolicyCompilerFactory.newPolicyCompiler(cel) + .addValidators(new NoFooLiteralsValidator()) + .build() + .compile(policy)); + + assertThat(e) + .hasMessageThat() + .contains( + "ERROR: :5:37: 'foo' is a forbidden literal\n" + + " | expression: \"(true) ? 'bar' : 'foo'\"\n" + + " | ....................................^"); + assertThat(e) + .hasMessageThat() + .contains( + "ERROR: :8:27: 'foo' is a forbidden literal\n" + + " | variables.foo in ['foo', 'bar', 'foo']\n" + + " | ..........................^"); + assertThat(e) + .hasMessageThat() + .contains( + "ERROR: :8:41: 'foo' is a forbidden literal\n" + + " | variables.foo in ['foo', 'bar', 'foo']\n" + + " | ........................................^"); + } + + // If the condition fails to validate, then the compiler doesn't attempt to compile or validate + // the output, so second test case for coverage. + @Test + public void evaluateYamlPolicy_validatorReportsOutput() throws Exception { + Cel cel = newCel(); + String policySource = + "name: nested_rule_with_forbidden_literal\n" + + "rule:\n" + + " variables:\n" + + " - name: 'foo'\n" + + " expression: \"(true) ? 'bar' : 'foo'\"\n" + + " match:\n" + + " - output: >\n" + + " 'foo' == variables.foo\n"; + CelPolicy policy = POLICY_PARSER.parse(policySource); + CelPolicyValidationException e = + assertThrows( + CelPolicyValidationException.class, + () -> + CelPolicyCompilerFactory.newPolicyCompiler(cel) + .addValidators(new NoFooLiteralsValidator()) + .build() + .compile(policy)); + + assertThat(e) + .hasMessageThat() + .contains( + "ERROR: :5:37: 'foo' is a forbidden literal\n" + + " | expression: \"(true) ? 'bar' : 'foo'\"\n" + + " | ....................................^"); + assertThat(e) + .hasMessageThat() + .contains( + "ERROR: :8:9: 'foo' is a forbidden literal\n" + + " | 'foo' == variables.foo\n" + + " | ........^"); + } + @Test public void evaluateYamlPolicy_lateBoundFunction() throws Exception { String configSource =