diff --git a/CHANGELOG.md b/CHANGELOG.md index 471d17f55..d7f5a796e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +- Refactor creation of prepared queries (#604, @mjungsbluth) + + When using the opa-envoy-plugin as a Go library, the interface EvalContext contains a breaking change [#604](https://github.com/open-policy-agent/opa-envoy-plugin/pull/604) that allows users of the library to control all three types of options that can be passed during preparation and evaluation of the underlying Rego query. + ### Fixes - Support escaped forward-slashes (`\/`) in JSON request bodies (#256, @Dakatan). diff --git a/envoyauth/evaluation.go b/envoyauth/evaluation.go index 719eb7a00..c2c09a85d 100644 --- a/envoyauth/evaluation.go +++ b/envoyauth/evaluation.go @@ -3,13 +3,11 @@ package envoyauth import ( "context" "fmt" - "sync" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/config" "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/topdown/builtins" @@ -25,17 +23,21 @@ type EvalContext interface { Store() storage.Store Compiler() *ast.Compiler Runtime() *ast.Term - PreparedQueryDoOnce() *sync.Once InterQueryBuiltinCache() iCache.InterQueryCache - PreparedQuery() *rego.PreparedEvalQuery - SetPreparedQuery(*rego.PreparedEvalQuery) Logger() logging.Logger Config() *config.Config DistributedTracing() tracing.Options + CreatePreparedQueryOnce(opts PrepareQueryOpts) (*rego.PreparedEvalQuery, error) +} + +// PrepareQueryOpts - Options to prepare a Rego query to be passed to the CreatePreparedQueryOnce method +type PrepareQueryOpts struct { + Opts []func(*rego.Rego) + PrepareOpts []rego.PrepareOption } // Eval - Evaluates an input against a provided EvalContext and yields result -func Eval(ctx context.Context, evalContext EvalContext, input ast.Value, result *EvalResult, opts ...func(*rego.Rego)) error { +func Eval(ctx context.Context, evalContext EvalContext, input ast.Value, result *EvalResult, evalOpts ...rego.EvalOption) error { var err error logger := evalContext.Logger() @@ -64,7 +66,19 @@ func Eval(ctx context.Context, evalContext EvalContext, input ast.Value, result "txn": result.TxnID, }).Debug("Executing policy query.") - err = constructPreparedQuery(evalContext, result.Txn, result.Metrics, opts) + pq, err := evalContext.CreatePreparedQueryOnce( + PrepareQueryOpts{ + Opts: []func(*rego.Rego){ + rego.Metrics(result.Metrics), + rego.ParsedQuery(evalContext.ParsedQuery()), + rego.Compiler(evalContext.Compiler()), + rego.Store(evalContext.Store()), + rego.Transaction(result.Txn), + rego.Runtime(evalContext.Runtime()), + rego.EnablePrintStatements(true), + rego.DistributedTracingOpts(evalContext.DistributedTracing()), + }, + }) if err != nil { return err } @@ -76,15 +90,22 @@ func Eval(ctx context.Context, evalContext EvalContext, input ast.Value, result ndbCache = builtins.NDBCache{} } + evalOpts = append( + []rego.EvalOption{ + rego.EvalParsedInput(input), + rego.EvalTransaction(result.Txn), + rego.EvalMetrics(result.Metrics), + rego.EvalInterQueryBuiltinCache(evalContext.InterQueryBuiltinCache()), + rego.EvalPrintHook(&ph), + rego.EvalNDBuiltinCache(ndbCache), + }, + evalOpts..., + ) + var rs rego.ResultSet - rs, err = evalContext.PreparedQuery().Eval( + rs, err = pq.Eval( ctx, - rego.EvalParsedInput(input), - rego.EvalTransaction(result.Txn), - rego.EvalMetrics(result.Metrics), - rego.EvalInterQueryBuiltinCache(evalContext.InterQueryBuiltinCache()), - rego.EvalPrintHook(&ph), - rego.EvalNDBuiltinCache(ndbCache), + evalOpts..., ) switch { @@ -101,28 +122,6 @@ func Eval(ctx context.Context, evalContext EvalContext, input ast.Value, result return nil } -func constructPreparedQuery(evalContext EvalContext, txn storage.Transaction, m metrics.Metrics, opts []func(*rego.Rego)) error { - var err error - var pq rego.PreparedEvalQuery - evalContext.PreparedQueryDoOnce().Do(func() { - opts = append(opts, - rego.Metrics(m), - rego.ParsedQuery(evalContext.ParsedQuery()), - rego.Compiler(evalContext.Compiler()), - rego.Store(evalContext.Store()), - rego.Transaction(txn), - rego.Runtime(evalContext.Runtime()), - rego.EnablePrintStatements(true), - rego.DistributedTracingOpts(evalContext.DistributedTracing()), - ) - - pq, err = rego.New(opts...).PrepareForEval(context.Background()) - evalContext.SetPreparedQuery(&pq) - }) - - return err -} - func getRevision(ctx context.Context, store storage.Store, txn storage.Transaction, result *EvalResult) error { revisions := map[string]string{} diff --git a/envoyauth/evaluation_test.go b/envoyauth/evaluation_test.go index 496bde6fc..f32249059 100644 --- a/envoyauth/evaluation_test.go +++ b/envoyauth/evaluation_test.go @@ -21,6 +21,7 @@ import ( "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/inmem" iCache "github.com/open-policy-agent/opa/topdown/cache" + "github.com/open-policy-agent/opa/topdown/print" ) func TestGetRevisionLegacy(t *testing.T) { @@ -116,6 +117,15 @@ func TestGetRevisionMulti(t *testing.T) { } +type testPrintHook struct { + printed string +} + +func (h *testPrintHook) Print(pctx print.Context, msg string) error { + h.printed = msg + return nil +} + func TestEval(t *testing.T) { ctx := context.Background() @@ -180,6 +190,17 @@ func TestEval(t *testing.T) { if err != nil { t.Fatal(err) } + + hook := testPrintHook{} + + erp, _, _ := NewEvalResult() + if err := Eval(ctx, server, inputValue, erp, rego.EvalPrintHook(&hook)); err != nil { + t.Fatal(err) + } + + if exp, act := "{\"firstname\": \"foo\", \"lastname\": \"bar\"}", hook.printed; exp != act { + t.Errorf("expected last printed message to be %q, got %q", exp, act) + } } func testAuthzServer(logger logging.Logger) (*mockExtAuthzGrpcServer, error) { @@ -254,6 +275,7 @@ type mockExtAuthzGrpcServer struct { manager *plugins.Manager preparedQuery *rego.PreparedEvalQuery preparedQueryDoOnce *sync.Once + preparedQueryErr error distributedTracingOpts tracing.Options } @@ -301,6 +323,17 @@ func (m *mockExtAuthzGrpcServer) DistributedTracing() tracing.Options { return m.distributedTracingOpts } +func (m *mockExtAuthzGrpcServer) CreatePreparedQueryOnce(opts PrepareQueryOpts) (*rego.PreparedEvalQuery, error) { + m.preparedQueryDoOnce.Do(func() { + pq, err := rego.New(opts.Opts...).PrepareForEval(context.Background()) + + m.preparedQuery = &pq + m.preparedQueryErr = err + }) + + return m.preparedQuery, m.preparedQueryErr +} + type testPlugin struct { events []logs.EventV1 } diff --git a/internal/internal.go b/internal/internal.go index 8e72cc808..b5dffdf53 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -225,6 +225,7 @@ type envoyExtAuthzGrpcServer struct { manager *plugins.Manager preparedQuery *rego.PreparedEvalQuery preparedQueryDoOnce *sync.Once + preparedQueryErr error interQueryBuiltinCache iCache.InterQueryCache distributedTracingOpts tracing.Options metricAuthzDuration prometheus.HistogramVec @@ -279,6 +280,17 @@ func (p *envoyExtAuthzGrpcServer) DistributedTracing() tracing.Options { return p.distributedTracingOpts } +func (p *envoyExtAuthzGrpcServer) CreatePreparedQueryOnce(opts envoyauth.PrepareQueryOpts) (*rego.PreparedEvalQuery, error) { + p.preparedQueryDoOnce.Do(func() { + pq, err := rego.New(opts.Opts...).PrepareForEval(context.Background()) + + p.preparedQuery = &pq + p.preparedQueryErr = err + }) + + return p.preparedQuery, p.preparedQueryErr +} + func (p *envoyExtAuthzGrpcServer) Start(ctx context.Context) error { p.manager.UpdatePluginStatus(PluginName, &plugins.Status{State: plugins.StateNotReady}) go p.listen()