Skip to content

Commit

Permalink
Ensure variables in comprehensions don't collide (#1062)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Nov 6, 2024
1 parent 3f12eca commit 8ad600b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 32 deletions.
66 changes: 34 additions & 32 deletions ext/comprehensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package ext

import (
"fmt"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
Expand Down Expand Up @@ -220,14 +222,11 @@ func (compreV2Lib) ProgramOptions() []cel.ProgramOption {
}

func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}

return mef.NewComprehensionTwoVar(
target,
iterVar1,
Expand All @@ -241,14 +240,11 @@ func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
}

func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}

return mef.NewComprehensionTwoVar(
target,
iterVar1,
Expand All @@ -262,14 +258,11 @@ func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr
}

func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}

return mef.NewComprehensionTwoVar(
target,
iterVar1,
Expand All @@ -285,11 +278,7 @@ func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.E
}

func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -324,11 +313,7 @@ func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
}

func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -362,11 +347,7 @@ func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (a
}

func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -399,10 +380,31 @@ func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Exp
), nil
}

func extractIterVar(meh cel.MacroExprFactory, target ast.Expr) (string, *cel.Error) {
func extractIterVars(mef cel.MacroExprFactory, arg0, arg1 ast.Expr) (string, string, *cel.Error) {
iterVar1, err := extractIterVar(mef, arg0)
if err != nil {
return "", "", err
}
iterVar2, err := extractIterVar(mef, arg1)
if err != nil {
return "", "", err
}
if iterVar1 == iterVar2 {
return "", "", mef.NewError(arg1.ID(), fmt.Sprintf("duplicate variable name: %s", iterVar1))
}
if iterVar1 == parser.AccumulatorName {
return "", "", mef.NewError(arg0.ID(), "iteration variable overwrites accumulator variable")
}
if iterVar2 == parser.AccumulatorName {
return "", "", mef.NewError(arg1.ID(), "iteration variable overwrites accumulator variable")
}
return iterVar1, iterVar2, nil
}

func extractIterVar(mef cel.MacroExprFactory, target ast.Expr) (string, *cel.Error) {
iterVar, found := extractIdent(target)
if !found {
return "", meh.NewError(target.ID(), "argument must be a simple name")
return "", mef.NewError(target.ID(), "argument must be a simple name")
}
return iterVar, nil
}
12 changes: 12 additions & 0 deletions ext/comprehensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,18 @@ func TestTwoVarComprehensionsStaticErrors(t *testing.T) {
expr string
err string
}{
{
expr: "[].all(i, i, i < i)",
err: "duplicate variable name: i",
},
{
expr: "[].all(__result__, i, __result__ < i)",
err: "iteration variable overwrites accumulator variable",
},
{
expr: "[].all(j, __result__, __result__ < j)",
err: "iteration variable overwrites accumulator variable",
},
{
expr: "[].all(i.j, k, i.j < k)",
err: "argument must be a simple name",
Expand Down

0 comments on commit 8ad600b

Please sign in to comment.