@@ -24,6 +24,8 @@ import (
2424 "testing"
2525
2626 "github.com/google/cel-go/common"
27+ "github.com/google/cel-go/common/ast"
28+ "github.com/google/cel-go/common/decls"
2729 "github.com/google/cel-go/common/env"
2830 "github.com/google/cel-go/common/operators"
2931 "github.com/google/cel-go/common/types"
@@ -401,6 +403,30 @@ func TestEnvToConfig(t *testing.T) {
401403 },
402404 want : env .NewConfig ("context proto" ).SetContextVariable (env .NewContextVariable ("google.expr.proto3.test.TestAllTypes" )),
403405 },
406+ {
407+ name : "feature flags" ,
408+ opts : []EnvOption {
409+ DefaultUTCTimeZone (false ),
410+ EnableMacroCallTracking (),
411+ },
412+ want : env .NewConfig ("feature flags" ).AddFeatures (
413+ env .NewFeature ("cel.feature.macro_call_tracking" , true ),
414+ ),
415+ },
416+ {
417+ name : "validators" ,
418+ opts : []EnvOption {
419+ ExtendedValidations (),
420+ ASTValidators (ValidateComprehensionNestingLimit (1 )),
421+ },
422+ want : env .NewConfig ("validators" ).AddValidators (
423+ env .NewValidator ("cel.validator.duration" ),
424+ env .NewValidator ("cel.validator.timestamp" ),
425+ env .NewValidator ("cel.validator.matches" ),
426+ env .NewValidator ("cel.validator.homogeneous_literals" ),
427+ env .NewValidator ("cel.validator.comprehension_nesting_limit" ).SetConfig (map [string ]any {"limit" : 1 }),
428+ ),
429+ },
404430 }
405431
406432 for _ , tst := range tests {
@@ -430,11 +456,12 @@ func TestEnvFromConfig(t *testing.T) {
430456 out ref.Val
431457 }
432458 tests := []struct {
433- name string
434- beforeOpts []EnvOption
435- afterOpts []EnvOption
436- conf * env.Config
437- exprs []exprCase
459+ name string
460+ beforeOpts []EnvOption
461+ afterOpts []EnvOption
462+ conf * env.Config
463+ confHandlers []ConfigOptionFactory
464+ exprs []exprCase
438465 }{
439466 {
440467 name : "std env" ,
@@ -617,18 +644,138 @@ func TestEnvFromConfig(t *testing.T) {
617644 },
618645 },
619646 },
647+ {
648+ name : "extensions - config factory" ,
649+ conf : env .NewConfig ("extensions" ).
650+ AddExtensions (env .NewExtension ("plus" , math .MaxUint32 )),
651+ confHandlers : []ConfigOptionFactory {
652+ func (a any ) (EnvOption , bool ) {
653+ ext , ok := a .(* env.Extension )
654+ if ! ok || ext .Name != "plus" {
655+ return nil , false
656+ }
657+ return Function ("plus" , Overload ("plus_int_int" , []* Type {IntType , IntType }, IntType ,
658+ decls .BinaryBinding (func (lhs , rhs ref.Val ) ref.Val {
659+ l := lhs .(types.Int )
660+ r := rhs .(types.Int )
661+ return l + r
662+ }))), true
663+ },
664+ },
665+ exprs : []exprCase {
666+ {
667+ name : "plus" ,
668+ expr : "plus(1, 2)" ,
669+ out : types .Int (3 ),
670+ },
671+ },
672+ },
673+ {
674+ name : "features" ,
675+ conf : env .NewConfig ("features" ).
676+ AddVariables (
677+ env .NewVariable ("m" ,
678+ env .NewTypeDesc ("map" , env .NewTypeDesc ("string" ), env .NewTypeDesc ("string" )))).
679+ AddFeatures (
680+ env .NewFeature ("cel.feature.backtick_escape_syntax" , true ),
681+ env .NewFeature ("cel.feature.unknown_feature_name" , true )),
682+ exprs : []exprCase {
683+ {
684+ name : "optional key" ,
685+ expr : "m.`key-name` == 'value'" ,
686+ in : map [string ]any {"m" : map [string ]string {"key-name" : "value" }},
687+ out : types .True ,
688+ },
689+ },
690+ },
691+ {
692+ name : "validators" ,
693+ conf : env .NewConfig ("validators" ).
694+ AddVariables (
695+ env .NewVariable ("m" ,
696+ env .NewTypeDesc ("map" , env .NewTypeDesc ("string" ), env .NewTypeDesc ("string" ))),
697+ ).
698+ AddValidators (
699+ env .NewValidator (durationValidatorName ),
700+ env .NewValidator (timestampValidatorName ),
701+ env .NewValidator (regexValidatorName ),
702+ env .NewValidator (homogeneousValidatorName ),
703+ env .NewValidator (nestingLimitValidatorName ).SetConfig (map [string ]any {"limit" : 0 }),
704+ ),
705+ exprs : []exprCase {
706+ {
707+ name : "bad duration" ,
708+ expr : "duration('1')" ,
709+ iss : errors .New ("invalid duration" ),
710+ },
711+ {
712+ name : "bad timestamp" ,
713+ expr : "timestamp('1')" ,
714+ iss : errors .New ("invalid timestamp" ),
715+ },
716+ {
717+ name : "bad regex" ,
718+ expr : "'hello'.matches('?^()')" ,
719+ iss : errors .New ("invalid matches" ),
720+ },
721+ {
722+ name : "mixed type list" ,
723+ expr : "[1, 2.0]" ,
724+ iss : errors .New ("expected type 'int'" ),
725+ },
726+ {
727+ name : "disabled comprehension" ,
728+ expr : "[1, 2].exists(x, x % 2 == 0)" ,
729+ iss : errors .New ("comprehension exceeds nesting limit" ),
730+ },
731+ },
732+ },
733+ {
734+ name : "validators - config factory" ,
735+ conf : env .NewConfig ("validators" ).
736+ AddValidators (
737+ env .NewValidator ("cel.validators.return_type" ).SetConfig (map [string ]any {"type_name" : "string" }),
738+ ),
739+ confHandlers : []ConfigOptionFactory {
740+ func (a any ) (EnvOption , bool ) {
741+ val , ok := a .(* env.Validator )
742+ if ! ok || val .Name != "cel.validators.return_type" {
743+ return nil , false
744+ }
745+ typeName , found := val .ConfigValue ("type_name" )
746+ if ! found {
747+ return func (* Env ) (* Env , error ) {
748+ return nil , fmt .Errorf ("invalid validator: %s missing config parameter 'type_name'" , val .Name )
749+ }, true
750+ }
751+ return func (e * Env ) (* Env , error ) {
752+ t , err := env .NewTypeDesc (typeName .(string )).AsCELType (e .CELTypeProvider ())
753+ if err != nil {
754+ return nil , err
755+ }
756+ return ASTValidators (returnTypeValidator {returnType : t })(e )
757+ }, true
758+ },
759+ },
760+ exprs : []exprCase {
761+ {
762+ name : "string - ok" ,
763+ expr : "'hello'" ,
764+ out : types .String ("hello" ),
765+ },
766+ {
767+ name : "int - error" ,
768+ expr : "1" ,
769+ iss : errors .New ("unsupported return type: int, want string" ),
770+ },
771+ },
772+ },
620773 }
621774 for _ , tst := range tests {
622775 tc := tst
623776 t .Run (tc .name , func (t * testing.T ) {
624777 opts := tc .beforeOpts
625- opts = append (opts , FromConfig (tc .conf , func (elem any ) (EnvOption , bool ) {
626- if ext , ok := elem .(* env.Extension ); ok && ext .Name == "optional" {
627- ver , _ := ext .GetVersion ()
628- return OptionalTypes (OptionalTypesVersion (ver )), true
629- }
630- return nil , false
631- }))
778+ opts = append (opts , FromConfig (tc .conf , tc .confHandlers ... ))
632779 opts = append (opts , tc .afterOpts ... )
633780 var e * Env
634781 var err error
@@ -679,6 +826,16 @@ func TestEnvFromConfigErrors(t *testing.T) {
679826 conf * env.Config
680827 want error
681828 }{
829+ {
830+ name : "bad container" ,
831+ conf : env .NewConfig ("bad container" ).SetContainer (".hello.world" ),
832+ want : errors .New ("container name must not contain" ),
833+ },
834+ {
835+ name : "colliding imports" ,
836+ conf : env .NewConfig ("colliding imports" ).AddImports (env .NewImport ("pkg.ImportName" ), env .NewImport ("pkg2.ImportName" )),
837+ want : errors .New ("abbreviation collides" ),
838+ },
682839 {
683840 name : "invalid subset" ,
684841 conf : env .NewConfig ("invalid subset" ).SetStdLib (env .NewLibrarySubset ().SetDisableMacros (true )),
@@ -707,9 +864,21 @@ func TestEnvFromConfigErrors(t *testing.T) {
707864 {
708865 name : "unrecognized extension" ,
709866 conf : env .NewConfig ("unrecognized extension" ).
710- AddExtensions (env .NewExtension ("optional " , math .MaxUint32 )),
867+ AddExtensions (env .NewExtension ("unrecognized " , math .MaxUint32 )),
711868 want : errors .New ("unrecognized extension" ),
712869 },
870+ {
871+ name : "invalid validator config" ,
872+ conf : env .NewConfig ("invalid validator config" ).
873+ AddValidators (env .NewValidator ("cel.validator.comprehension_nesting_limit" )),
874+ want : errors .New ("invalid validator" ),
875+ },
876+ {
877+ name : "invalid validator config type" ,
878+ conf : env .NewConfig ("invalid validator config" ).
879+ AddValidators (env .NewValidator ("cel.validator.comprehension_nesting_limit" ).SetConfig (map [string ]any {"limit" : 2.0 })),
880+ want : errors .New ("invalid validator" ),
881+ },
713882 }
714883 for _ , tst := range tests {
715884 tc := tst
@@ -829,6 +998,26 @@ func mustContextProto(t *testing.T, pb proto.Message) Activation {
829998 return ctx
830999}
8311000
1001+ type returnTypeValidator struct {
1002+ returnType * Type
1003+ }
1004+
1005+ func (returnTypeValidator ) Name () string {
1006+ return "cel.validators.return_type"
1007+ }
1008+
1009+ func (v returnTypeValidator ) Validate (_ * Env , c ValidatorConfig , a * ast.AST , iss * Issues ) {
1010+ if a .GetType (a .Expr ().ID ()) != v .returnType {
1011+ iss .ReportErrorAtID (a .Expr ().ID (),
1012+ "unsupported return type: %s, want %s" ,
1013+ a .GetType (a .Expr ().ID ()), v .returnType .TypeName ())
1014+ }
1015+ }
1016+
1017+ func (v returnTypeValidator ) ToConfig () * env.Validator {
1018+ return env .NewValidator (v .Name ()).SetConfig (map [string ]any {"type_name" : v .returnType .TypeName ()})
1019+ }
1020+
8321021type customLegacyProvider struct {
8331022 provider ref.TypeProvider
8341023}
0 commit comments