Skip to content

Commit

Permalink
Report safety errors for generated vars
Browse files Browse the repository at this point in the history
Originally this looked like a type checking bug, but in fact it was
related to the safety check.

Previously, we filtered out safety errors for generated vars because
they are meaningless to the user. This led to unsafe queries being
allowed in some edge cases. For example: sum() > 1 would be
expanded to sum(__local0__); __local0__ > 1. The first expression is
unsafe but this would not be reported.

This change ensures that the compiler reports an error in all cases.

With this change, safety errors for generated vars are turned into error
messages indicating the expression is unsafe. Above, there would be a
single error indicating the expression "sum() > 1" is unsafe.

Fixes open-policy-agent#661

Signed-off-by: Torin Sandall <[email protected]>
  • Loading branch information
tsandall committed Mar 21, 2018
1 parent 041169e commit 548c85a
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 17 deletions.
80 changes: 64 additions & 16 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package ast

import (
"fmt"
"sort"
"strings"

"github.com/open-policy-agent/opa/util"
Expand Down Expand Up @@ -552,11 +553,9 @@ func (c *Compiler) checkSafetyRuleBodies() {

func (c *Compiler) checkBodySafety(safe VarSet, m *Module, b Body, l *Location) Body {
reordered, unsafe := reorderBodyForSafety(c.GetArity, safe, b)
if len(unsafe) != 0 {
for v := range unsafe.Vars() {
if !v.IsGenerated() {
c.err(NewError(UnsafeVarErr, l, "var %v is unsafe", v))
}
if errs := safetyErrorSlice(l, unsafe); len(errs) > 0 {
for _, err := range errs {
c.err(err)
}
return b
}
Expand Down Expand Up @@ -1005,20 +1004,11 @@ func (qc *queryCompiler) rewriteLocalAssignments(_ *QueryContext, body Body) (Bo
}

func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {

safe := ReservedVars.Copy()
reordered, unsafe := reorderBodyForSafety(qc.compiler.GetArity, safe, body)

if len(unsafe) != 0 {
var err Errors
for v := range unsafe.Vars() {
if !v.IsGenerated() {
err = append(err, NewError(UnsafeVarErr, body.Loc(), "var %v is unsafe", v))
}
}
return nil, err
if errs := safetyErrorSlice(body.Loc(), unsafe); len(errs) > 0 {
return nil, errs
}

return reordered, nil
}

Expand Down Expand Up @@ -1398,6 +1388,11 @@ func (g *graphTraversal) Visited(u util.T) bool {
return ok
}

type unsafePair struct {
Expr *Expr
Vars VarSet
}

type unsafeVars map[*Expr]VarSet

func (vs unsafeVars) Add(e *Expr, v Var) {
Expand Down Expand Up @@ -1429,6 +1424,16 @@ func (vs unsafeVars) Vars() VarSet {
return r
}

func (vs unsafeVars) Slice() (result []unsafePair) {
for expr, vs := range vs {
result = append(result, unsafePair{
Expr: expr,
Vars: vs,
})
}
return
}

// reorderBodyForSafety returns a copy of the body ordered such that
// left to right evaluation of the body will not encounter unbound variables
// in input positions or negated expressions.
Expand Down Expand Up @@ -2596,3 +2601,46 @@ func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var) (
stack.Insert(v, gv)
return
}

func safetyErrorSlice(l *Location, unsafe unsafeVars) (result Errors) {

if len(unsafe) == 0 {
return
}

for v := range unsafe.Vars() {
if !v.IsGenerated() {
result = append(result, NewError(UnsafeVarErr, l, "var %v is unsafe", v))
}
}

if len(result) > 0 {
return
}

// If the expression contains unsafe generated variables, report which
// expressions are unsafe instead of the variables that are unsafe (since
// the latter are not meaningful to the user.)
pairs := unsafe.Slice()

sort.Slice(pairs, func(i, j int) bool {
return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0
})

// Report at most one error per generated variable.
seen := NewVarSet()

for _, expr := range pairs {
before := len(seen)
for v := range expr.Vars {
if v.IsGenerated() {
seen.Add(v)
}
}
if len(seen) > before {
result = append(result, NewError(UnsafeVarErr, expr.Expr.Location, "expression is unsafe"))
}
}

return
}
2 changes: 1 addition & 1 deletion ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2024,9 +2024,9 @@ func TestQueryCompiler(t *testing.T) {
{"safe vars", `data; abc`, `package ex`, []string{"import input.xyz as abc"}, `{}`, `data; input.xyz`},
{"reorder", `x != 1; x = 0`, "", nil, "", `x = 0; x != 1`},
{"bad with target", "x = 1 with data.p as null", "", nil, "", fmt.Errorf("1 error occurred: 1:7: rego_type_error: with keyword target must be input")},
{"unsafe exprs", "count(sum())", "", nil, "", fmt.Errorf("1 error occurred: 1:1: rego_unsafe_var_error: expression is unsafe")},
{"check types", "x = data.a.b.c.z; y = null; x = y", "", nil, "", fmt.Errorf("match error\n\tleft : number\n\tright : null")},
}

for _, tc := range tests {
runQueryCompilerTest(t, tc.note, tc.q, tc.pkg, tc.imports, tc.input, tc.expected)
}
Expand Down

0 comments on commit 548c85a

Please sign in to comment.