diff --git a/policy/engine.go b/policy/engine.go index 736e64fd9..0a0c7583b 100644 --- a/policy/engine.go +++ b/policy/engine.go @@ -19,19 +19,21 @@ import ( "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/inmem" "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/topdown/cache" "github.com/open-policy-agent/opa/topdown/print" "github.com/open-policy-agent/opa/version" ) // Engine represents the policy engine. type Engine struct { - trace bool - builtinErrors bool - modules map[string]*ast.Module - compiler *ast.Compiler - store storage.Store - policies map[string]string - docs map[string]string + trace bool + builtinErrors bool + modules map[string]*ast.Module + compiler *ast.Compiler + store storage.Store + policies map[string]string + docs map[string]string + enableInterQueryCache bool } // CompilerOptions defines the options for the Rego compiler. @@ -171,6 +173,10 @@ func (e *Engine) ShowBuiltinErrors() { e.builtinErrors = true } +func (e *Engine) EnableInterQueryCache() { + e.enableInterQueryCache = true +} + // Check executes all of the loaded policies against the input and returns the results. func (e *Engine) Check(ctx context.Context, configs map[string]any, namespace string) ([]output.CheckResult, error) { var checkResults []output.CheckResult @@ -454,6 +460,9 @@ func (e *Engine) query(ctx context.Context, input any, query string) (output.Que rego.PrintHook(ph), rego.BuiltinErrorList(builtInErrors), } + if e.enableInterQueryCache { + options = append(options, rego.InterQueryBuiltinCache(cache.NewInterQueryCacheWithContext(ctx, nil))) + } regoInstance := rego.New(options...) resultSet, err := regoInstance.Eval(ctx) diff --git a/runner/test.go b/runner/test.go index 77ff8bfe5..65cec989e 100644 --- a/runner/test.go +++ b/runner/test.go @@ -76,6 +76,7 @@ func (t *TestRunner) Run(ctx context.Context, fileList []string) ([]output.Check if err != nil { return nil, fmt.Errorf("load: %w", err) } + engine.EnableInterQueryCache() if t.Trace { engine.EnableTracing()