diff --git a/go.mod b/go.mod index 46358dc49..65eb74240 100644 --- a/go.mod +++ b/go.mod @@ -1,29 +1,28 @@ module github.com/open-policy-agent/opa-envoy-plugin go 1.22.0 - toolchain go1.23.1 require ( github.com/envoyproxy/go-control-plane/envoy v1.32.2 github.com/golang/protobuf v1.5.4 - github.com/open-policy-agent/opa v0.70.0 + github.com/open-policy-agent/opa v1.0.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.20.5 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 - go.opentelemetry.io/otel v1.31.0 - go.opentelemetry.io/otel/sdk v1.31.0 - go.opentelemetry.io/otel/trace v1.31.0 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 + go.opentelemetry.io/otel v1.33.0 + go.opentelemetry.io/otel/sdk v1.33.0 + go.opentelemetry.io/otel/trace v1.33.0 golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 golang.org/x/tools v0.28.0 - google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 + google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 google.golang.org/grpc v1.69.2 google.golang.org/protobuf v1.36.2 ) require ( - github.com/containerd/errdefs v0.3.0 // indirect + github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/magiconair/properties v1.8.7 // indirect @@ -37,8 +36,8 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/viper v1.18.2 // indirect - github.com/stretchr/testify v1.10.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect @@ -56,14 +55,14 @@ require ( github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect - github.com/containerd/containerd v1.7.23 // indirect + github.com/containerd/containerd v1.7.24 // indirect github.com/containerd/log v0.1.0 // indirect github.com/dgraph-io/badger/v3 v3.2103.5 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -75,7 +74,7 @@ require ( github.com/google/flatbuffers v1.12.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/mux v1.8.1 // indirect - github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect @@ -99,18 +98,18 @@ require ( github.com/yashtewari/glob-intersection v0.2.0 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/propagators/b3 v1.28.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0 // indirect - go.opentelemetry.io/otel/metric v1.31.0 // indirect - go.opentelemetry.io/proto/otlp v1.3.1 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.33.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.33.0 // indirect + go.opentelemetry.io/otel/metric v1.33.0 // indirect + go.opentelemetry.io/proto/otlp v1.4.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect golang.org/x/mod v0.22.0 // indirect - golang.org/x/net v0.32.0 // indirect + golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect - golang.org/x/time v0.7.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53 // indirect + golang.org/x/time v0.8.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 // indirect oras.land/oras-go/v2 v2.3.1 // indirect sigs.k8s.io/yaml v1.4.0 // indirect ) diff --git a/go.sum b/go.sum index 13162a48a..c4627e499 100644 --- a/go.sum +++ b/go.sum @@ -32,12 +32,12 @@ github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 h1:QVw89YDxXxEe+l8gU8E github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM= github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw= -github.com/containerd/containerd v1.7.23 h1:H2CClyUkmpKAGlhQp95g2WXHfLYc7whAuvZGBNYOOwQ= -github.com/containerd/containerd v1.7.23/go.mod h1:7QUzfURqZWCZV7RLNEn1XjUCQLEf0bkaK4GjUaZehxw= +github.com/containerd/containerd v1.7.24 h1:zxszGrGjrra1yYJW/6rhm9cJ1ZQ8rkKBR48brqsa7nA= +github.com/containerd/containerd v1.7.24/go.mod h1:7QUzfURqZWCZV7RLNEn1XjUCQLEf0bkaK4GjUaZehxw= github.com/containerd/continuity v0.4.2 h1:v3y/4Yz5jwnvqPKJJ+7Wf93fyWoCB3F5EclWG023MDM= github.com/containerd/continuity v0.4.2/go.mod h1:F6PTNCKepoxEaXLQp3wDAjygEnImnZ/7o4JzpodfroQ= -github.com/containerd/errdefs v0.3.0 h1:FSZgGOeK4yuT/+DnF07/Olde/q4KBoMsaamhXxIMDp4= -github.com/containerd/errdefs v0.3.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= +github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= +github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= @@ -79,8 +79,8 @@ github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= +github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -132,8 +132,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 h1:TmHmbvxPmaegwhDubVz0lICL0J5Ka2vwTzhoePEXsGE= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKAPIS5qsmQDqZna/PgVt4rWtI= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= @@ -176,8 +176,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= -github.com/open-policy-agent/opa v0.70.0 h1:B3cqCN2iQAyKxK6+GI+N40uqkin+wzIrM7YA60t9x1U= -github.com/open-policy-agent/opa v0.70.0/go.mod h1:Y/nm5NY0BX0BqjBriKUiV81sCl8XOjjvqQG7dXrggtI= +github.com/open-policy-agent/opa v1.0.0 h1:fZsEwxg1knpPvUn0YDJuJZBcbVg4G3zKpWa3+CnYK+I= +github.com/open-policy-agent/opa v1.0.0/go.mod h1:+JyoH12I0+zqyC1iX7a2tmoQlipwAEGvOhVJMhmy+rM= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= @@ -209,8 +209,8 @@ github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5X github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= @@ -272,28 +272,30 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 h1:9G6E0TXzGFVfTnawRzrPl83iHOAV7L8NJiR8RSGYV1g= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0/go.mod h1:azvtTADFQJA8mX80jIH/akaE7h+dbm/sVuaHqN13w74= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 h1:4K4tsIXefpVJtvA/8srF4V4y0akAoPHkIslgAkjixJA= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0/go.mod h1:jjdQuTGVsXV4vSs+CJ2qYDeDPf9yIJV23qlIzBm73Vg= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0/go.mod h1:umTcuxiv1n/s/S6/c2AT/g2CQ7u5C59sHDNmfSwgz7Q= go.opentelemetry.io/contrib/propagators/b3 v1.28.0 h1:XR6CFQrQ/ttAYmTBX2loUEFGdk1h17pxYI8828dk/1Y= go.opentelemetry.io/contrib/propagators/b3 v1.28.0/go.mod h1:DWRkzJONLquRz7OJPh2rRbZ7MugQj62rk7g6HRnEqh0= -go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY= -go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 h1:3Q/xZUyC1BBkualc9ROb4G8qkH90LXEIICcs5zv1OYY= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0/go.mod h1:s75jGIWA9OfCMzF0xr+ZgfrB5FEbbV7UuYo32ahUiFI= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0 h1:R3X6ZXmNPRR8ul6i3WgFURCHzaXjHdm0karRG/+dj3s= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0/go.mod h1:QWFXnDavXWwMx2EEcZsf3yxgEKAqsxQ+Syjp+seyInw= -go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE= -go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY= -go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk= -go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0= +go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw= +go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.33.0 h1:Vh5HayB/0HHfOQA7Ctx69E/Y/DcQSMPpKANYVMQ7fBA= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.33.0/go.mod h1:cpgtDBaqD/6ok/UG0jT15/uKjAY8mRA53diogHBg3UI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.33.0 h1:5pojmb1U1AogINhN3SurB+zm/nIcusopeBNp42f45QM= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.33.0/go.mod h1:57gTHJSE5S1tqg+EKsLPlTWhpHMsWlVmer+LA926XiA= +go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ= +go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M= +go.opentelemetry.io/otel/sdk v1.33.0 h1:iax7M131HuAm9QkZotNHEfstof92xM+N8sr3uHXc2IM= +go.opentelemetry.io/otel/sdk v1.33.0/go.mod h1:A1Q5oi7/9XaMlIWzPSxLRWOI8nG3FnzHJNbiENQuihM= go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc= go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8= -go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys= -go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A= -go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= -go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= +go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= +go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= +go.opentelemetry.io/proto/otlp v1.4.0 h1:TA9WRvW6zMwP+Ssb6fLoUIuirti1gGbP28GcKG1jgeg= +go.opentelemetry.io/proto/otlp v1.4.0/go.mod h1:PPBWZIP98o2ElSqI35IHfu7hIhSwvc5N38Jw8pXuGFY= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= @@ -328,8 +330,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= -golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -354,8 +356,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= -golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= +golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -377,10 +379,10 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53 h1:fVoAXEKA4+yufmbdVYv+SE73+cPZbbbe8paLsHfkK+U= -google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53/go.mod h1:riSXTwQ4+nqmPGtobMFyW5FqVAmIs0St6VPp4Ug7CE4= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 h1:X58yt85/IXCx0Y3ZwN6sEIKZzQtDEYaBWrDvErdXrRE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= +google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 h1:CkkIfIt50+lT6NHAVoRYEyAvQGFM7xEwXUUywFvEb3Q= +google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576/go.mod h1:1R3kvZ1dtP3+4p4d3G8uJ8rFk/fWlScl38vanWACI08= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 h1:8ZmaLZE4XWrtU3MyClkYqqtl6Oegr3235h7jxsDyqCY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= diff --git a/vendor/github.com/containerd/containerd/content/local/store.go b/vendor/github.com/containerd/containerd/content/local/store.go index e1baee4c2..efe886014 100644 --- a/vendor/github.com/containerd/containerd/content/local/store.go +++ b/vendor/github.com/containerd/containerd/content/local/store.go @@ -67,6 +67,8 @@ type LabelStore interface { type store struct { root string ls LabelStore + + ensureIngestRootOnce func() error } // NewStore returns a local content store @@ -80,14 +82,13 @@ func NewStore(root string) (content.Store, error) { // require labels and should use `NewStore`. `NewLabeledStore` is primarily // useful for tests or standalone implementations. func NewLabeledStore(root string, ls LabelStore) (content.Store, error) { - if err := os.MkdirAll(filepath.Join(root, "ingest"), 0777); err != nil { - return nil, err - } - - return &store{ + s := &store{ root: root, ls: ls, - }, nil + } + + s.ensureIngestRootOnce = sync.OnceValue(s.ensureIngestRoot) + return s, nil } func (s *store) Info(ctx context.Context, dgst digest.Digest) (content.Info, error) { @@ -294,6 +295,9 @@ func (s *store) Status(ctx context.Context, ref string) (content.Status, error) func (s *store) ListStatuses(ctx context.Context, fs ...string) ([]content.Status, error) { fp, err := os.Open(filepath.Join(s.root, "ingest")) if err != nil { + if os.IsNotExist(err) { + return nil, nil + } return nil, err } @@ -344,6 +348,9 @@ func (s *store) ListStatuses(ctx context.Context, fs ...string) ([]content.Statu func (s *store) WalkStatusRefs(ctx context.Context, fn func(string) error) error { fp, err := os.Open(filepath.Join(s.root, "ingest")) if err != nil { + if os.IsNotExist(err) { + return nil + } return err } @@ -545,6 +552,11 @@ func (s *store) writer(ctx context.Context, ref string, total int64, expected di ) foundValidIngest := false + + if err := s.ensureIngestRootOnce(); err != nil { + return nil, err + } + // ensure that the ingest path has been created. if err := os.Mkdir(path, 0755); err != nil { if !os.IsExist(err) { @@ -655,6 +667,10 @@ func (s *store) ingestPaths(ref string) (string, string, string) { return fp, rp, dp } +func (s *store) ensureIngestRoot() error { + return os.MkdirAll(filepath.Join(s.root, "ingest"), 0777) +} + func readFileString(path string) (string, error) { p, err := os.ReadFile(path) return string(p), err diff --git a/vendor/github.com/containerd/containerd/remotes/docker/resolver.go b/vendor/github.com/containerd/containerd/remotes/docker/resolver.go index 4ca2b921e..8ce4cccc0 100644 --- a/vendor/github.com/containerd/containerd/remotes/docker/resolver.go +++ b/vendor/github.com/containerd/containerd/remotes/docker/resolver.go @@ -25,8 +25,10 @@ import ( "net" "net/http" "net/url" + "os" "path" "strings" + "sync" "github.com/containerd/log" "github.com/opencontainers/go-digest" @@ -717,13 +719,18 @@ func NewHTTPFallback(transport http.RoundTripper) http.RoundTripper { type httpFallback struct { super http.RoundTripper host string + mu sync.Mutex } func (f *httpFallback) RoundTrip(r *http.Request) (*http.Response, error) { + f.mu.Lock() + fallback := f.host == r.URL.Host + f.mu.Unlock() + // only fall back if the same host had previously fell back - if f.host != r.URL.Host { + if !fallback { resp, err := f.super.RoundTrip(r) - if !isTLSError(err) { + if !isTLSError(err) && !isPortError(err, r.URL.Host) { return resp, err } } @@ -734,8 +741,12 @@ func (f *httpFallback) RoundTrip(r *http.Request) (*http.Response, error) { plainHTTPRequest := *r plainHTTPRequest.URL = &plainHTTPUrl - if f.host != r.URL.Host { - f.host = r.URL.Host + if !fallback { + f.mu.Lock() + if f.host != r.URL.Host { + f.host = r.URL.Host + } + f.mu.Unlock() // update body on the second attempt if r.Body != nil && r.GetBody != nil { @@ -765,6 +776,18 @@ func isTLSError(err error) bool { return false } +func isPortError(err error, host string) bool { + if isConnError(err) || os.IsTimeout(err) { + if _, port, _ := net.SplitHostPort(host); port != "" { + // Port is specified, will not retry on different port with scheme change + return false + } + return true + } + + return false +} + // HTTPFallback is an http.RoundTripper which allows fallback from https to http // for registry endpoints with configurations for both http and TLS, such as // defaulted localhost endpoints. diff --git a/vendor/github.com/containerd/containerd/remotes/docker/resolver_unix.go b/vendor/github.com/containerd/containerd/remotes/docker/resolver_unix.go new file mode 100644 index 000000000..4ef0e0062 --- /dev/null +++ b/vendor/github.com/containerd/containerd/remotes/docker/resolver_unix.go @@ -0,0 +1,28 @@ +//go:build !windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package docker + +import ( + "errors" + "syscall" +) + +func isConnError(err error) bool { + return errors.Is(err, syscall.ECONNREFUSED) +} diff --git a/vendor/github.com/containerd/containerd/remotes/docker/resolver_windows.go b/vendor/github.com/containerd/containerd/remotes/docker/resolver_windows.go new file mode 100644 index 000000000..9c98df04b --- /dev/null +++ b/vendor/github.com/containerd/containerd/remotes/docker/resolver_windows.go @@ -0,0 +1,30 @@ +//go:build windows + +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package docker + +import ( + "errors" + "syscall" + + "golang.org/x/sys/windows" +) + +func isConnError(err error) bool { + return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, windows.WSAECONNREFUSED) +} diff --git a/vendor/github.com/containerd/containerd/version/version.go b/vendor/github.com/containerd/containerd/version/version.go index c61791188..b83e75964 100644 --- a/vendor/github.com/containerd/containerd/version/version.go +++ b/vendor/github.com/containerd/containerd/version/version.go @@ -23,7 +23,7 @@ var ( Package = "github.com/containerd/containerd" // Version holds the complete version number. Filled in at linking time. - Version = "1.7.23+unknown" + Version = "1.7.24+unknown" // Revision is filled with the VCS (e.g. git) revision being used to build // the program at linking time. diff --git a/vendor/github.com/fsnotify/fsnotify/.cirrus.yml b/vendor/github.com/fsnotify/fsnotify/.cirrus.yml index ffc7b992b..f4e7dbf37 100644 --- a/vendor/github.com/fsnotify/fsnotify/.cirrus.yml +++ b/vendor/github.com/fsnotify/fsnotify/.cirrus.yml @@ -1,7 +1,7 @@ freebsd_task: name: 'FreeBSD' freebsd_instance: - image_family: freebsd-13-2 + image_family: freebsd-14-1 install_script: - pkg update -f - pkg install -y go @@ -9,5 +9,6 @@ freebsd_task: # run tests as user "cirrus" instead of root - pw useradd cirrus -m - chown -R cirrus:cirrus . - - FSNOTIFY_BUFFER=4096 sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race ./... - - sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race ./... + - FSNOTIFY_BUFFER=4096 sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race ./... + - sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race ./... + - FSNOTIFY_DEBUG=1 sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race -v ./... diff --git a/vendor/github.com/fsnotify/fsnotify/.editorconfig b/vendor/github.com/fsnotify/fsnotify/.editorconfig deleted file mode 100644 index fad895851..000000000 --- a/vendor/github.com/fsnotify/fsnotify/.editorconfig +++ /dev/null @@ -1,12 +0,0 @@ -root = true - -[*.go] -indent_style = tab -indent_size = 4 -insert_final_newline = true - -[*.{yml,yaml}] -indent_style = space -indent_size = 2 -insert_final_newline = true -trim_trailing_whitespace = true diff --git a/vendor/github.com/fsnotify/fsnotify/.gitattributes b/vendor/github.com/fsnotify/fsnotify/.gitattributes deleted file mode 100644 index 32f1001be..000000000 --- a/vendor/github.com/fsnotify/fsnotify/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -go.sum linguist-generated diff --git a/vendor/github.com/fsnotify/fsnotify/.gitignore b/vendor/github.com/fsnotify/fsnotify/.gitignore index 391cc076b..daea9dd6d 100644 --- a/vendor/github.com/fsnotify/fsnotify/.gitignore +++ b/vendor/github.com/fsnotify/fsnotify/.gitignore @@ -5,3 +5,6 @@ # Output of go build ./cmd/fsnotify /fsnotify /fsnotify.exe + +/test/kqueue +/test/a.out diff --git a/vendor/github.com/fsnotify/fsnotify/CHANGELOG.md b/vendor/github.com/fsnotify/fsnotify/CHANGELOG.md index e0e575754..fa854785d 100644 --- a/vendor/github.com/fsnotify/fsnotify/CHANGELOG.md +++ b/vendor/github.com/fsnotify/fsnotify/CHANGELOG.md @@ -1,8 +1,36 @@ # Changelog -Unreleased ----------- -Nothing yet. +1.8.0 2023-10-31 +---------------- + +### Additions + +- all: add `FSNOTIFY_DEBUG` to print debug logs to stderr ([#619]) + +### Changes and fixes + +- windows: fix behaviour of `WatchList()` to be consistent with other platforms ([#610]) + +- kqueue: ignore events with Ident=0 ([#590]) + +- kqueue: set O_CLOEXEC to prevent passing file descriptors to children ([#617]) + +- kqueue: emit events as "/path/dir/file" instead of "path/link/file" when watching a symlink ([#625]) + +- inotify: don't send event for IN_DELETE_SELF when also watching the parent ([#620]) + +- inotify: fix panic when calling Remove() in a goroutine ([#650]) + +- fen: allow watching subdirectories of watched directories ([#621]) + +[#590]: https://github.com/fsnotify/fsnotify/pull/590 +[#610]: https://github.com/fsnotify/fsnotify/pull/610 +[#617]: https://github.com/fsnotify/fsnotify/pull/617 +[#619]: https://github.com/fsnotify/fsnotify/pull/619 +[#620]: https://github.com/fsnotify/fsnotify/pull/620 +[#621]: https://github.com/fsnotify/fsnotify/pull/621 +[#625]: https://github.com/fsnotify/fsnotify/pull/625 +[#650]: https://github.com/fsnotify/fsnotify/pull/650 1.7.0 - 2023-10-22 ------------------ diff --git a/vendor/github.com/fsnotify/fsnotify/CONTRIBUTING.md b/vendor/github.com/fsnotify/fsnotify/CONTRIBUTING.md index ea379759d..e4ac2a2ff 100644 --- a/vendor/github.com/fsnotify/fsnotify/CONTRIBUTING.md +++ b/vendor/github.com/fsnotify/fsnotify/CONTRIBUTING.md @@ -1,7 +1,7 @@ Thank you for your interest in contributing to fsnotify! We try to review and merge PRs in a reasonable timeframe, but please be aware that: -- To avoid "wasted" work, please discus changes on the issue tracker first. You +- To avoid "wasted" work, please discuss changes on the issue tracker first. You can just send PRs, but they may end up being rejected for one reason or the other. @@ -20,6 +20,124 @@ platforms. Testing different platforms locally can be done with something like Use the `-short` flag to make the "stress test" run faster. +Writing new tests +----------------- +Scripts in the testdata directory allow creating test cases in a "shell-like" +syntax. The basic format is: + + script + + Output: + desired output + +For example: + + # Create a new empty file with some data. + watch / + echo data >/file + + Output: + create /file + write /file + +Just create a new file to add a new test; select which tests to run with +`-run TestScript/[path]`. + +script +------ +The script is a "shell-like" script: + + cmd arg arg + +Comments are supported with `#`: + + # Comment + cmd arg arg # Comment + +All operations are done in a temp directory; a path like "/foo" is rewritten to +"/tmp/TestFoo/foo". + +Arguments can be quoted with `"` or `'`; there are no escapes and they're +functionally identical right now, but this may change in the future, so best to +assume shell-like rules. + + touch "/file with spaces" + +End-of-line escapes with `\` are not supported. + +### Supported commands + + watch path [ops] # Watch the path, reporting events for it. Nothing is + # watched by default. Optionally a list of ops can be + # given, as with AddWith(path, WithOps(...)). + unwatch path # Stop watching the path. + watchlist n # Assert watchlist length. + + stop # Stop running the script; for debugging. + debug [yes/no] # Enable/disable FSNOTIFY_DEBUG (tests are run in + parallel by default, so -parallel=1 is probably a good + idea). + + touch path + mkdir [-p] dir + ln -s target link # Only ln -s supported. + mkfifo path + mknod dev path + mv src dst + rm [-r] path + chmod mode path # Octal only + sleep time-in-ms + + cat path # Read path (does nothing with the data; just reads it). + echo str >>path # Append "str" to "path". + echo str >path # Truncate "path" and write "str". + + require reason # Skip the test if "reason" is true; "skip" and + skip reason # "require" behave identical; it supports both for + # readability. Possible reasons are: + # + # always Always skip this test. + # symlink Symlinks are supported (requires admin + # permissions on Windows). + # mkfifo Platform doesn't support FIFO named sockets. + # mknod Platform doesn't support device nodes. + + +output +------ +After `Output:` the desired output is given; this is indented by convention, but +that's not required. + +The format of that is: + + # Comment + event path # Comment + + system: + event path + system2: + event path + +Every event is one line, and any whitespace between the event and path are +ignored. The path can optionally be surrounded in ". Anything after a "#" is +ignored. + +Platform-specific tests can be added after GOOS; for example: + + watch / + touch /file + + Output: + # Tested if nothing else matches + create /file + + # Windows-specific test. + windows: + write /file + +You can specify multiple platforms with a comma (e.g. "windows, linux:"). +"kqueue" is a shortcut for all kqueue systems (BSD, macOS). + [goon]: https://github.com/arp242/goon [Vagrant]: https://www.vagrantup.com/ diff --git a/vendor/github.com/fsnotify/fsnotify/backend_fen.go b/vendor/github.com/fsnotify/fsnotify/backend_fen.go index 28497f1dd..c349c326c 100644 --- a/vendor/github.com/fsnotify/fsnotify/backend_fen.go +++ b/vendor/github.com/fsnotify/fsnotify/backend_fen.go @@ -1,8 +1,8 @@ //go:build solaris -// +build solaris -// Note: the documentation on the Watcher type and methods is generated from -// mkdoc.zsh +// FEN backend for illumos (supported) and Solaris (untested, but should work). +// +// See port_create(3c) etc. for docs. https://www.illumos.org/man/3C/port_create package fsnotify @@ -12,150 +12,33 @@ import ( "os" "path/filepath" "sync" + "time" + "github.com/fsnotify/fsnotify/internal" "golang.org/x/sys/unix" ) -// Watcher watches a set of paths, delivering events on a channel. -// -// A watcher should not be copied (e.g. pass it by pointer, rather than by -// value). -// -// # Linux notes -// -// When a file is removed a Remove event won't be emitted until all file -// descriptors are closed, and deletes will always emit a Chmod. For example: -// -// fp := os.Open("file") -// os.Remove("file") // Triggers Chmod -// fp.Close() // Triggers Remove -// -// This is the event that inotify sends, so not much can be changed about this. -// -// The fs.inotify.max_user_watches sysctl variable specifies the upper limit -// for the number of watches per user, and fs.inotify.max_user_instances -// specifies the maximum number of inotify instances per user. Every Watcher you -// create is an "instance", and every path you add is a "watch". -// -// These are also exposed in /proc as /proc/sys/fs/inotify/max_user_watches and -// /proc/sys/fs/inotify/max_user_instances -// -// To increase them you can use sysctl or write the value to the /proc file: -// -// # Default values on Linux 5.18 -// sysctl fs.inotify.max_user_watches=124983 -// sysctl fs.inotify.max_user_instances=128 -// -// To make the changes persist on reboot edit /etc/sysctl.conf or -// /usr/lib/sysctl.d/50-default.conf (details differ per Linux distro; check -// your distro's documentation): -// -// fs.inotify.max_user_watches=124983 -// fs.inotify.max_user_instances=128 -// -// Reaching the limit will result in a "no space left on device" or "too many open -// files" error. -// -// # kqueue notes (macOS, BSD) -// -// kqueue requires opening a file descriptor for every file that's being watched; -// so if you're watching a directory with five files then that's six file -// descriptors. You will run in to your system's "max open files" limit faster on -// these platforms. -// -// The sysctl variables kern.maxfiles and kern.maxfilesperproc can be used to -// control the maximum number of open files, as well as /etc/login.conf on BSD -// systems. -// -// # Windows notes -// -// Paths can be added as "C:\path\to\dir", but forward slashes -// ("C:/path/to/dir") will also work. -// -// When a watched directory is removed it will always send an event for the -// directory itself, but may not send events for all files in that directory. -// Sometimes it will send events for all times, sometimes it will send no -// events, and often only for some files. -// -// The default ReadDirectoryChangesW() buffer size is 64K, which is the largest -// value that is guaranteed to work with SMB filesystems. If you have many -// events in quick succession this may not be enough, and you will have to use -// [WithBufferSize] to increase the value. -type Watcher struct { - // Events sends the filesystem change events. - // - // fsnotify can send the following events; a "path" here can refer to a - // file, directory, symbolic link, or special file like a FIFO. - // - // fsnotify.Create A new path was created; this may be followed by one - // or more Write events if data also gets written to a - // file. - // - // fsnotify.Remove A path was removed. - // - // fsnotify.Rename A path was renamed. A rename is always sent with the - // old path as Event.Name, and a Create event will be - // sent with the new name. Renames are only sent for - // paths that are currently watched; e.g. moving an - // unmonitored file into a monitored directory will - // show up as just a Create. Similarly, renaming a file - // to outside a monitored directory will show up as - // only a Rename. - // - // fsnotify.Write A file or named pipe was written to. A Truncate will - // also trigger a Write. A single "write action" - // initiated by the user may show up as one or multiple - // writes, depending on when the system syncs things to - // disk. For example when compiling a large Go program - // you may get hundreds of Write events, and you may - // want to wait until you've stopped receiving them - // (see the dedup example in cmd/fsnotify). - // - // Some systems may send Write event for directories - // when the directory content changes. - // - // fsnotify.Chmod Attributes were changed. On Linux this is also sent - // when a file is removed (or more accurately, when a - // link to an inode is removed). On kqueue it's sent - // when a file is truncated. On Windows it's never - // sent. +type fen struct { Events chan Event - - // Errors sends any errors. - // - // ErrEventOverflow is used to indicate there are too many events: - // - // - inotify: There are too many queued events (fs.inotify.max_queued_events sysctl) - // - windows: The buffer size is too small; WithBufferSize() can be used to increase it. - // - kqueue, fen: Not used. Errors chan error mu sync.Mutex port *unix.EventPort - done chan struct{} // Channel for sending a "quit message" to the reader goroutine - dirs map[string]struct{} // Explicitly watched directories - watches map[string]struct{} // Explicitly watched non-directories + done chan struct{} // Channel for sending a "quit message" to the reader goroutine + dirs map[string]Op // Explicitly watched directories + watches map[string]Op // Explicitly watched non-directories } -// NewWatcher creates a new Watcher. -func NewWatcher() (*Watcher, error) { - return NewBufferedWatcher(0) +func newBackend(ev chan Event, errs chan error) (backend, error) { + return newBufferedBackend(0, ev, errs) } -// NewBufferedWatcher creates a new Watcher with a buffered Watcher.Events -// channel. -// -// The main use case for this is situations with a very large number of events -// where the kernel buffer size can't be increased (e.g. due to lack of -// permissions). An unbuffered Watcher will perform better for almost all use -// cases, and whenever possible you will be better off increasing the kernel -// buffers instead of adding a large userspace buffer. -func NewBufferedWatcher(sz uint) (*Watcher, error) { - w := &Watcher{ - Events: make(chan Event, sz), - Errors: make(chan error), - dirs: make(map[string]struct{}), - watches: make(map[string]struct{}), +func newBufferedBackend(sz uint, ev chan Event, errs chan error) (backend, error) { + w := &fen{ + Events: ev, + Errors: errs, + dirs: make(map[string]Op), + watches: make(map[string]Op), done: make(chan struct{}), } @@ -171,27 +54,30 @@ func NewBufferedWatcher(sz uint) (*Watcher, error) { // sendEvent attempts to send an event to the user, returning true if the event // was put in the channel successfully and false if the watcher has been closed. -func (w *Watcher) sendEvent(name string, op Op) (sent bool) { +func (w *fen) sendEvent(name string, op Op) (sent bool) { select { - case w.Events <- Event{Name: name, Op: op}: - return true case <-w.done: return false + case w.Events <- Event{Name: name, Op: op}: + return true } } // sendError attempts to send an error to the user, returning true if the error // was put in the channel successfully and false if the watcher has been closed. -func (w *Watcher) sendError(err error) (sent bool) { - select { - case w.Errors <- err: +func (w *fen) sendError(err error) (sent bool) { + if err == nil { return true + } + select { case <-w.done: return false + case w.Errors <- err: + return true } } -func (w *Watcher) isClosed() bool { +func (w *fen) isClosed() bool { select { case <-w.done: return true @@ -200,8 +86,7 @@ func (w *Watcher) isClosed() bool { } } -// Close removes all watches and closes the Events channel. -func (w *Watcher) Close() error { +func (w *fen) Close() error { // Take the lock used by associateFile to prevent lingering events from // being processed after the close w.mu.Lock() @@ -213,60 +98,21 @@ func (w *Watcher) Close() error { return w.port.Close() } -// Add starts monitoring the path for changes. -// -// A path can only be watched once; watching it more than once is a no-op and will -// not return an error. Paths that do not yet exist on the filesystem cannot be -// watched. -// -// A watch will be automatically removed if the watched path is deleted or -// renamed. The exception is the Windows backend, which doesn't remove the -// watcher on renames. -// -// Notifications on network filesystems (NFS, SMB, FUSE, etc.) or special -// filesystems (/proc, /sys, etc.) generally don't work. -// -// Returns [ErrClosed] if [Watcher.Close] was called. -// -// See [Watcher.AddWith] for a version that allows adding options. -// -// # Watching directories -// -// All files in a directory are monitored, including new files that are created -// after the watcher is started. Subdirectories are not watched (i.e. it's -// non-recursive). -// -// # Watching files -// -// Watching individual files (rather than directories) is generally not -// recommended as many programs (especially editors) update files atomically: it -// will write to a temporary file which is then moved to to destination, -// overwriting the original (or some variant thereof). The watcher on the -// original file is now lost, as that no longer exists. -// -// The upshot of this is that a power failure or crash won't leave a -// half-written file. -// -// Watch the parent directory and use Event.Name to filter out files you're not -// interested in. There is an example of this in cmd/fsnotify/file.go. -func (w *Watcher) Add(name string) error { return w.AddWith(name) } +func (w *fen) Add(name string) error { return w.AddWith(name) } -// AddWith is like [Watcher.Add], but allows adding options. When using Add() -// the defaults described below are used. -// -// Possible options are: -// -// - [WithBufferSize] sets the buffer size for the Windows backend; no-op on -// other platforms. The default is 64K (65536 bytes). -func (w *Watcher) AddWith(name string, opts ...addOpt) error { +func (w *fen) AddWith(name string, opts ...addOpt) error { if w.isClosed() { return ErrClosed } - if w.port.PathIsWatched(name) { - return nil + if debug { + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s AddWith(%q)\n", + time.Now().Format("15:04:05.000000000"), name) } - _ = getOptions(opts...) + with := getOptions(opts...) + if !w.xSupports(with.op) { + return fmt.Errorf("%w: %s", xErrUnsupported, with.op) + } // Currently we resolve symlinks that were explicitly requested to be // watched. Otherwise we would use LStat here. @@ -283,7 +129,7 @@ func (w *Watcher) AddWith(name string, opts ...addOpt) error { } w.mu.Lock() - w.dirs[name] = struct{}{} + w.dirs[name] = with.op w.mu.Unlock() return nil } @@ -294,26 +140,22 @@ func (w *Watcher) AddWith(name string, opts ...addOpt) error { } w.mu.Lock() - w.watches[name] = struct{}{} + w.watches[name] = with.op w.mu.Unlock() return nil } -// Remove stops monitoring the path for changes. -// -// Directories are always removed non-recursively. For example, if you added -// /tmp/dir and /tmp/dir/subdir then you will need to remove both. -// -// Removing a path that has not yet been added returns [ErrNonExistentWatch]. -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) Remove(name string) error { +func (w *fen) Remove(name string) error { if w.isClosed() { return nil } if !w.port.PathIsWatched(name) { return fmt.Errorf("%w: %s", ErrNonExistentWatch, name) } + if debug { + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s Remove(%q)\n", + time.Now().Format("15:04:05.000000000"), name) + } // The user has expressed an intent. Immediately remove this name from // whichever watch list it might be in. If it's not in there the delete @@ -346,7 +188,7 @@ func (w *Watcher) Remove(name string) error { } // readEvents contains the main loop that runs in a goroutine watching for events. -func (w *Watcher) readEvents() { +func (w *fen) readEvents() { // If this function returns, the watcher has been closed and we can close // these channels defer func() { @@ -382,17 +224,19 @@ func (w *Watcher) readEvents() { continue } + if debug { + internal.Debug(pevent.Path, pevent.Events) + } + err = w.handleEvent(&pevent) - if err != nil { - if !w.sendError(err) { - return - } + if !w.sendError(err) { + return } } } } -func (w *Watcher) handleDirectory(path string, stat os.FileInfo, follow bool, handler func(string, os.FileInfo, bool) error) error { +func (w *fen) handleDirectory(path string, stat os.FileInfo, follow bool, handler func(string, os.FileInfo, bool) error) error { files, err := os.ReadDir(path) if err != nil { return err @@ -418,7 +262,7 @@ func (w *Watcher) handleDirectory(path string, stat os.FileInfo, follow bool, ha // bitmap matches more than one event type (e.g. the file was both modified and // had the attributes changed between when the association was created and the // when event was returned) -func (w *Watcher) handleEvent(event *unix.PortEvent) error { +func (w *fen) handleEvent(event *unix.PortEvent) error { var ( events = event.Events path = event.Path @@ -510,15 +354,9 @@ func (w *Watcher) handleEvent(event *unix.PortEvent) error { } if events&unix.FILE_MODIFIED != 0 { - if fmode.IsDir() { - if watchedDir { - if err := w.updateDirectory(path); err != nil { - return err - } - } else { - if !w.sendEvent(path, Write) { - return nil - } + if fmode.IsDir() && watchedDir { + if err := w.updateDirectory(path); err != nil { + return err } } else { if !w.sendEvent(path, Write) { @@ -543,7 +381,7 @@ func (w *Watcher) handleEvent(event *unix.PortEvent) error { return nil } -func (w *Watcher) updateDirectory(path string) error { +func (w *fen) updateDirectory(path string) error { // The directory was modified, so we must find unwatched entities and watch // them. If something was removed from the directory, nothing will happen, // as everything else should still be watched. @@ -563,10 +401,8 @@ func (w *Watcher) updateDirectory(path string) error { return err } err = w.associateFile(path, finfo, false) - if err != nil { - if !w.sendError(err) { - return nil - } + if !w.sendError(err) { + return nil } if !w.sendEvent(path, Create) { return nil @@ -575,7 +411,7 @@ func (w *Watcher) updateDirectory(path string) error { return nil } -func (w *Watcher) associateFile(path string, stat os.FileInfo, follow bool) error { +func (w *fen) associateFile(path string, stat os.FileInfo, follow bool) error { if w.isClosed() { return ErrClosed } @@ -593,34 +429,34 @@ func (w *Watcher) associateFile(path string, stat os.FileInfo, follow bool) erro // cleared up that discrepancy. The most likely cause is that the event // has fired but we haven't processed it yet. err := w.port.DissociatePath(path) - if err != nil && err != unix.ENOENT { + if err != nil && !errors.Is(err, unix.ENOENT) { return err } } - // FILE_NOFOLLOW means we watch symlinks themselves rather than their - // targets. - events := unix.FILE_MODIFIED | unix.FILE_ATTRIB | unix.FILE_NOFOLLOW - if follow { - // We *DO* follow symlinks for explicitly watched entries. - events = unix.FILE_MODIFIED | unix.FILE_ATTRIB + + var events int + if !follow { + // Watch symlinks themselves rather than their targets unless this entry + // is explicitly watched. + events |= unix.FILE_NOFOLLOW + } + if true { // TODO: implement withOps() + events |= unix.FILE_MODIFIED } - return w.port.AssociatePath(path, stat, - events, - stat.Mode()) + if true { + events |= unix.FILE_ATTRIB + } + return w.port.AssociatePath(path, stat, events, stat.Mode()) } -func (w *Watcher) dissociateFile(path string, stat os.FileInfo, unused bool) error { +func (w *fen) dissociateFile(path string, stat os.FileInfo, unused bool) error { if !w.port.PathIsWatched(path) { return nil } return w.port.DissociatePath(path) } -// WatchList returns all paths explicitly added with [Watcher.Add] (and are not -// yet removed). -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) WatchList() []string { +func (w *fen) WatchList() []string { if w.isClosed() { return nil } @@ -638,3 +474,11 @@ func (w *Watcher) WatchList() []string { return entries } + +func (w *fen) xSupports(op Op) bool { + if op.Has(xUnportableOpen) || op.Has(xUnportableRead) || + op.Has(xUnportableCloseWrite) || op.Has(xUnportableCloseRead) { + return false + } + return true +} diff --git a/vendor/github.com/fsnotify/fsnotify/backend_inotify.go b/vendor/github.com/fsnotify/fsnotify/backend_inotify.go index 921c1c1e4..36c311694 100644 --- a/vendor/github.com/fsnotify/fsnotify/backend_inotify.go +++ b/vendor/github.com/fsnotify/fsnotify/backend_inotify.go @@ -1,8 +1,4 @@ //go:build linux && !appengine -// +build linux,!appengine - -// Note: the documentation on the Watcher type and methods is generated from -// mkdoc.zsh package fsnotify @@ -10,127 +6,20 @@ import ( "errors" "fmt" "io" + "io/fs" "os" "path/filepath" "strings" "sync" + "time" "unsafe" + "github.com/fsnotify/fsnotify/internal" "golang.org/x/sys/unix" ) -// Watcher watches a set of paths, delivering events on a channel. -// -// A watcher should not be copied (e.g. pass it by pointer, rather than by -// value). -// -// # Linux notes -// -// When a file is removed a Remove event won't be emitted until all file -// descriptors are closed, and deletes will always emit a Chmod. For example: -// -// fp := os.Open("file") -// os.Remove("file") // Triggers Chmod -// fp.Close() // Triggers Remove -// -// This is the event that inotify sends, so not much can be changed about this. -// -// The fs.inotify.max_user_watches sysctl variable specifies the upper limit -// for the number of watches per user, and fs.inotify.max_user_instances -// specifies the maximum number of inotify instances per user. Every Watcher you -// create is an "instance", and every path you add is a "watch". -// -// These are also exposed in /proc as /proc/sys/fs/inotify/max_user_watches and -// /proc/sys/fs/inotify/max_user_instances -// -// To increase them you can use sysctl or write the value to the /proc file: -// -// # Default values on Linux 5.18 -// sysctl fs.inotify.max_user_watches=124983 -// sysctl fs.inotify.max_user_instances=128 -// -// To make the changes persist on reboot edit /etc/sysctl.conf or -// /usr/lib/sysctl.d/50-default.conf (details differ per Linux distro; check -// your distro's documentation): -// -// fs.inotify.max_user_watches=124983 -// fs.inotify.max_user_instances=128 -// -// Reaching the limit will result in a "no space left on device" or "too many open -// files" error. -// -// # kqueue notes (macOS, BSD) -// -// kqueue requires opening a file descriptor for every file that's being watched; -// so if you're watching a directory with five files then that's six file -// descriptors. You will run in to your system's "max open files" limit faster on -// these platforms. -// -// The sysctl variables kern.maxfiles and kern.maxfilesperproc can be used to -// control the maximum number of open files, as well as /etc/login.conf on BSD -// systems. -// -// # Windows notes -// -// Paths can be added as "C:\path\to\dir", but forward slashes -// ("C:/path/to/dir") will also work. -// -// When a watched directory is removed it will always send an event for the -// directory itself, but may not send events for all files in that directory. -// Sometimes it will send events for all times, sometimes it will send no -// events, and often only for some files. -// -// The default ReadDirectoryChangesW() buffer size is 64K, which is the largest -// value that is guaranteed to work with SMB filesystems. If you have many -// events in quick succession this may not be enough, and you will have to use -// [WithBufferSize] to increase the value. -type Watcher struct { - // Events sends the filesystem change events. - // - // fsnotify can send the following events; a "path" here can refer to a - // file, directory, symbolic link, or special file like a FIFO. - // - // fsnotify.Create A new path was created; this may be followed by one - // or more Write events if data also gets written to a - // file. - // - // fsnotify.Remove A path was removed. - // - // fsnotify.Rename A path was renamed. A rename is always sent with the - // old path as Event.Name, and a Create event will be - // sent with the new name. Renames are only sent for - // paths that are currently watched; e.g. moving an - // unmonitored file into a monitored directory will - // show up as just a Create. Similarly, renaming a file - // to outside a monitored directory will show up as - // only a Rename. - // - // fsnotify.Write A file or named pipe was written to. A Truncate will - // also trigger a Write. A single "write action" - // initiated by the user may show up as one or multiple - // writes, depending on when the system syncs things to - // disk. For example when compiling a large Go program - // you may get hundreds of Write events, and you may - // want to wait until you've stopped receiving them - // (see the dedup example in cmd/fsnotify). - // - // Some systems may send Write event for directories - // when the directory content changes. - // - // fsnotify.Chmod Attributes were changed. On Linux this is also sent - // when a file is removed (or more accurately, when a - // link to an inode is removed). On kqueue it's sent - // when a file is truncated. On Windows it's never - // sent. +type inotify struct { Events chan Event - - // Errors sends any errors. - // - // ErrEventOverflow is used to indicate there are too many events: - // - // - inotify: There are too many queued events (fs.inotify.max_queued_events sysctl) - // - windows: The buffer size is too small; WithBufferSize() can be used to increase it. - // - kqueue, fen: Not used. Errors chan error // Store fd here as os.File.Read() will no longer return on close after @@ -139,8 +28,26 @@ type Watcher struct { inotifyFile *os.File watches *watches done chan struct{} // Channel for sending a "quit message" to the reader goroutine - closeMu sync.Mutex + doneMu sync.Mutex doneResp chan struct{} // Channel to respond to Close + + // Store rename cookies in an array, with the index wrapping to 0. Almost + // all of the time what we get is a MOVED_FROM to set the cookie and the + // next event inotify sends will be MOVED_TO to read it. However, this is + // not guaranteed – as described in inotify(7) – and we may get other events + // between the two MOVED_* events (including other MOVED_* ones). + // + // A second issue is that moving a file outside the watched directory will + // trigger a MOVED_FROM to set the cookie, but we never see the MOVED_TO to + // read and delete it. So just storing it in a map would slowly leak memory. + // + // Doing it like this gives us a simple fast LRU-cache that won't allocate. + // Ten items should be more than enough for our purpose, and a loop over + // such a short array is faster than a map access anyway (not that it hugely + // matters since we're talking about hundreds of ns at the most, but still). + cookies [10]koekje + cookieIndex uint8 + cookiesMu sync.Mutex } type ( @@ -150,9 +57,14 @@ type ( path map[string]uint32 // pathname → wd } watch struct { - wd uint32 // Watch descriptor (as returned by the inotify_add_watch() syscall) - flags uint32 // inotify flags of this watch (see inotify(7) for the list of valid flags) - path string // Watch path. + wd uint32 // Watch descriptor (as returned by the inotify_add_watch() syscall) + flags uint32 // inotify flags of this watch (see inotify(7) for the list of valid flags) + path string // Watch path. + recurse bool // Recursion with ./...? + } + koekje struct { + cookie uint32 + path string } ) @@ -179,23 +91,45 @@ func (w *watches) add(ww *watch) { func (w *watches) remove(wd uint32) { w.mu.Lock() defer w.mu.Unlock() - delete(w.path, w.wd[wd].path) + watch := w.wd[wd] // Could have had Remove() called. See #616. + if watch == nil { + return + } + delete(w.path, watch.path) delete(w.wd, wd) } -func (w *watches) removePath(path string) (uint32, bool) { +func (w *watches) removePath(path string) ([]uint32, error) { w.mu.Lock() defer w.mu.Unlock() + path, recurse := recursivePath(path) wd, ok := w.path[path] if !ok { - return 0, false + return nil, fmt.Errorf("%w: %s", ErrNonExistentWatch, path) + } + + watch := w.wd[wd] + if recurse && !watch.recurse { + return nil, fmt.Errorf("can't use /... with non-recursive watch %q", path) } delete(w.path, path) delete(w.wd, wd) + if !watch.recurse { + return []uint32{wd}, nil + } - return wd, true + wds := make([]uint32, 0, 8) + wds = append(wds, wd) + for p, rwd := range w.path { + if filepath.HasPrefix(p, path) { + delete(w.path, p) + delete(w.wd, rwd) + wds = append(wds, rwd) + } + } + return wds, nil } func (w *watches) byPath(path string) *watch { @@ -236,20 +170,11 @@ func (w *watches) updatePath(path string, f func(*watch) (*watch, error)) error return nil } -// NewWatcher creates a new Watcher. -func NewWatcher() (*Watcher, error) { - return NewBufferedWatcher(0) +func newBackend(ev chan Event, errs chan error) (backend, error) { + return newBufferedBackend(0, ev, errs) } -// NewBufferedWatcher creates a new Watcher with a buffered Watcher.Events -// channel. -// -// The main use case for this is situations with a very large number of events -// where the kernel buffer size can't be increased (e.g. due to lack of -// permissions). An unbuffered Watcher will perform better for almost all use -// cases, and whenever possible you will be better off increasing the kernel -// buffers instead of adding a large userspace buffer. -func NewBufferedWatcher(sz uint) (*Watcher, error) { +func newBufferedBackend(sz uint, ev chan Event, errs chan error) (backend, error) { // Need to set nonblocking mode for SetDeadline to work, otherwise blocking // I/O operations won't terminate on close. fd, errno := unix.InotifyInit1(unix.IN_CLOEXEC | unix.IN_NONBLOCK) @@ -257,12 +182,12 @@ func NewBufferedWatcher(sz uint) (*Watcher, error) { return nil, errno } - w := &Watcher{ + w := &inotify{ + Events: ev, + Errors: errs, fd: fd, inotifyFile: os.NewFile(uintptr(fd), ""), watches: newWatches(), - Events: make(chan Event, sz), - Errors: make(chan error), done: make(chan struct{}), doneResp: make(chan struct{}), } @@ -272,26 +197,29 @@ func NewBufferedWatcher(sz uint) (*Watcher, error) { } // Returns true if the event was sent, or false if watcher is closed. -func (w *Watcher) sendEvent(e Event) bool { +func (w *inotify) sendEvent(e Event) bool { select { - case w.Events <- e: - return true case <-w.done: return false + case w.Events <- e: + return true } } // Returns true if the error was sent, or false if watcher is closed. -func (w *Watcher) sendError(err error) bool { - select { - case w.Errors <- err: +func (w *inotify) sendError(err error) bool { + if err == nil { return true + } + select { case <-w.done: return false + case w.Errors <- err: + return true } } -func (w *Watcher) isClosed() bool { +func (w *inotify) isClosed() bool { select { case <-w.done: return true @@ -300,15 +228,14 @@ func (w *Watcher) isClosed() bool { } } -// Close removes all watches and closes the Events channel. -func (w *Watcher) Close() error { - w.closeMu.Lock() +func (w *inotify) Close() error { + w.doneMu.Lock() if w.isClosed() { - w.closeMu.Unlock() + w.doneMu.Unlock() return nil } close(w.done) - w.closeMu.Unlock() + w.doneMu.Unlock() // Causes any blocking reads to return with an error, provided the file // still supports deadline operations. @@ -323,78 +250,104 @@ func (w *Watcher) Close() error { return nil } -// Add starts monitoring the path for changes. -// -// A path can only be watched once; watching it more than once is a no-op and will -// not return an error. Paths that do not yet exist on the filesystem cannot be -// watched. -// -// A watch will be automatically removed if the watched path is deleted or -// renamed. The exception is the Windows backend, which doesn't remove the -// watcher on renames. -// -// Notifications on network filesystems (NFS, SMB, FUSE, etc.) or special -// filesystems (/proc, /sys, etc.) generally don't work. -// -// Returns [ErrClosed] if [Watcher.Close] was called. -// -// See [Watcher.AddWith] for a version that allows adding options. -// -// # Watching directories -// -// All files in a directory are monitored, including new files that are created -// after the watcher is started. Subdirectories are not watched (i.e. it's -// non-recursive). -// -// # Watching files -// -// Watching individual files (rather than directories) is generally not -// recommended as many programs (especially editors) update files atomically: it -// will write to a temporary file which is then moved to to destination, -// overwriting the original (or some variant thereof). The watcher on the -// original file is now lost, as that no longer exists. -// -// The upshot of this is that a power failure or crash won't leave a -// half-written file. -// -// Watch the parent directory and use Event.Name to filter out files you're not -// interested in. There is an example of this in cmd/fsnotify/file.go. -func (w *Watcher) Add(name string) error { return w.AddWith(name) } - -// AddWith is like [Watcher.Add], but allows adding options. When using Add() -// the defaults described below are used. -// -// Possible options are: -// -// - [WithBufferSize] sets the buffer size for the Windows backend; no-op on -// other platforms. The default is 64K (65536 bytes). -func (w *Watcher) AddWith(name string, opts ...addOpt) error { +func (w *inotify) Add(name string) error { return w.AddWith(name) } + +func (w *inotify) AddWith(path string, opts ...addOpt) error { if w.isClosed() { return ErrClosed } + if debug { + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s AddWith(%q)\n", + time.Now().Format("15:04:05.000000000"), path) + } + + with := getOptions(opts...) + if !w.xSupports(with.op) { + return fmt.Errorf("%w: %s", xErrUnsupported, with.op) + } - name = filepath.Clean(name) - _ = getOptions(opts...) + path, recurse := recursivePath(path) + if recurse { + return filepath.WalkDir(path, func(root string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() { + if root == path { + return fmt.Errorf("fsnotify: not a directory: %q", path) + } + return nil + } - var flags uint32 = unix.IN_MOVED_TO | unix.IN_MOVED_FROM | - unix.IN_CREATE | unix.IN_ATTRIB | unix.IN_MODIFY | - unix.IN_MOVE_SELF | unix.IN_DELETE | unix.IN_DELETE_SELF + // Send a Create event when adding new directory from a recursive + // watch; this is for "mkdir -p one/two/three". Usually all those + // directories will be created before we can set up watchers on the + // subdirectories, so only "one" would be sent as a Create event and + // not "one/two" and "one/two/three" (inotifywait -r has the same + // problem). + if with.sendCreate && root != path { + w.sendEvent(Event{Name: root, Op: Create}) + } + + return w.add(root, with, true) + }) + } - return w.watches.updatePath(name, func(existing *watch) (*watch, error) { + return w.add(path, with, false) +} + +func (w *inotify) add(path string, with withOpts, recurse bool) error { + var flags uint32 + if with.noFollow { + flags |= unix.IN_DONT_FOLLOW + } + if with.op.Has(Create) { + flags |= unix.IN_CREATE + } + if with.op.Has(Write) { + flags |= unix.IN_MODIFY + } + if with.op.Has(Remove) { + flags |= unix.IN_DELETE | unix.IN_DELETE_SELF + } + if with.op.Has(Rename) { + flags |= unix.IN_MOVED_TO | unix.IN_MOVED_FROM | unix.IN_MOVE_SELF + } + if with.op.Has(Chmod) { + flags |= unix.IN_ATTRIB + } + if with.op.Has(xUnportableOpen) { + flags |= unix.IN_OPEN + } + if with.op.Has(xUnportableRead) { + flags |= unix.IN_ACCESS + } + if with.op.Has(xUnportableCloseWrite) { + flags |= unix.IN_CLOSE_WRITE + } + if with.op.Has(xUnportableCloseRead) { + flags |= unix.IN_CLOSE_NOWRITE + } + return w.register(path, flags, recurse) +} + +func (w *inotify) register(path string, flags uint32, recurse bool) error { + return w.watches.updatePath(path, func(existing *watch) (*watch, error) { if existing != nil { flags |= existing.flags | unix.IN_MASK_ADD } - wd, err := unix.InotifyAddWatch(w.fd, name, flags) + wd, err := unix.InotifyAddWatch(w.fd, path, flags) if wd == -1 { return nil, err } if existing == nil { return &watch{ - wd: uint32(wd), - path: name, - flags: flags, + wd: uint32(wd), + path: path, + flags: flags, + recurse: recurse, }, nil } @@ -404,49 +357,44 @@ func (w *Watcher) AddWith(name string, opts ...addOpt) error { }) } -// Remove stops monitoring the path for changes. -// -// Directories are always removed non-recursively. For example, if you added -// /tmp/dir and /tmp/dir/subdir then you will need to remove both. -// -// Removing a path that has not yet been added returns [ErrNonExistentWatch]. -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) Remove(name string) error { +func (w *inotify) Remove(name string) error { if w.isClosed() { return nil } + if debug { + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s Remove(%q)\n", + time.Now().Format("15:04:05.000000000"), name) + } return w.remove(filepath.Clean(name)) } -func (w *Watcher) remove(name string) error { - wd, ok := w.watches.removePath(name) - if !ok { - return fmt.Errorf("%w: %s", ErrNonExistentWatch, name) - } - - success, errno := unix.InotifyRmWatch(w.fd, wd) - if success == -1 { - // TODO: Perhaps it's not helpful to return an error here in every case; - // The only two possible errors are: - // - // - EBADF, which happens when w.fd is not a valid file descriptor - // of any kind. - // - EINVAL, which is when fd is not an inotify descriptor or wd - // is not a valid watch descriptor. Watch descriptors are - // invalidated when they are removed explicitly or implicitly; - // explicitly by inotify_rm_watch, implicitly when the file they - // are watching is deleted. - return errno +func (w *inotify) remove(name string) error { + wds, err := w.watches.removePath(name) + if err != nil { + return err + } + + for _, wd := range wds { + _, err := unix.InotifyRmWatch(w.fd, wd) + if err != nil { + // TODO: Perhaps it's not helpful to return an error here in every + // case; the only two possible errors are: + // + // EBADF, which happens when w.fd is not a valid file descriptor of + // any kind. + // + // EINVAL, which is when fd is not an inotify descriptor or wd is + // not a valid watch descriptor. Watch descriptors are invalidated + // when they are removed explicitly or implicitly; explicitly by + // inotify_rm_watch, implicitly when the file they are watching is + // deleted. + return err + } } return nil } -// WatchList returns all paths explicitly added with [Watcher.Add] (and are not -// yet removed). -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) WatchList() []string { +func (w *inotify) WatchList() []string { if w.isClosed() { return nil } @@ -463,7 +411,7 @@ func (w *Watcher) WatchList() []string { // readEvents reads from the inotify file descriptor, converts the // received events into Event objects and sends them via the Events channel -func (w *Watcher) readEvents() { +func (w *inotify) readEvents() { defer func() { close(w.doneResp) close(w.Errors) @@ -506,15 +454,17 @@ func (w *Watcher) readEvents() { continue } - var offset uint32 // We don't know how many events we just read into the buffer // While the offset points to at least one whole event... + var offset uint32 for offset <= uint32(n-unix.SizeofInotifyEvent) { var ( // Point "raw" to the event in the buffer raw = (*unix.InotifyEvent)(unsafe.Pointer(&buf[offset])) mask = uint32(raw.Mask) nameLen = uint32(raw.Len) + // Move to the next event in the buffer + next = func() { offset += unix.SizeofInotifyEvent + nameLen } ) if mask&unix.IN_Q_OVERFLOW != 0 { @@ -523,21 +473,53 @@ func (w *Watcher) readEvents() { } } - // If the event happened to the watched directory or the watched file, the kernel - // doesn't append the filename to the event, but we would like to always fill the - // the "Name" field with a valid filename. We retrieve the path of the watch from - // the "paths" map. + /// If the event happened to the watched directory or the watched + /// file, the kernel doesn't append the filename to the event, but + /// we would like to always fill the the "Name" field with a valid + /// filename. We retrieve the path of the watch from the "paths" + /// map. watch := w.watches.byWd(uint32(raw.Wd)) + /// Can be nil if Remove() was called in another goroutine for this + /// path inbetween reading the events from the kernel and reading + /// the internal state. Not much we can do about it, so just skip. + /// See #616. + if watch == nil { + next() + continue + } + + name := watch.path + if nameLen > 0 { + /// Point "bytes" at the first byte of the filename + bytes := (*[unix.PathMax]byte)(unsafe.Pointer(&buf[offset+unix.SizeofInotifyEvent]))[:nameLen:nameLen] + /// The filename is padded with NULL bytes. TrimRight() gets rid of those. + name += "/" + strings.TrimRight(string(bytes[0:nameLen]), "\000") + } + + if debug { + internal.Debug(name, raw.Mask, raw.Cookie) + } + + if mask&unix.IN_IGNORED != 0 { //&& event.Op != 0 + next() + continue + } // inotify will automatically remove the watch on deletes; just need // to clean our state here. - if watch != nil && mask&unix.IN_DELETE_SELF == unix.IN_DELETE_SELF { + if mask&unix.IN_DELETE_SELF == unix.IN_DELETE_SELF { w.watches.remove(watch.wd) } + // We can't really update the state when a watched path is moved; // only IN_MOVE_SELF is sent and not IN_MOVED_{FROM,TO}. So remove // the watch. - if watch != nil && mask&unix.IN_MOVE_SELF == unix.IN_MOVE_SELF { + if mask&unix.IN_MOVE_SELF == unix.IN_MOVE_SELF { + if watch.recurse { + next() // Do nothing + continue + } + err := w.remove(watch.path) if err != nil && !errors.Is(err, ErrNonExistentWatch) { if !w.sendError(err) { @@ -546,34 +528,69 @@ func (w *Watcher) readEvents() { } } - var name string - if watch != nil { - name = watch.path - } - if nameLen > 0 { - // Point "bytes" at the first byte of the filename - bytes := (*[unix.PathMax]byte)(unsafe.Pointer(&buf[offset+unix.SizeofInotifyEvent]))[:nameLen:nameLen] - // The filename is padded with NULL bytes. TrimRight() gets rid of those. - name += "/" + strings.TrimRight(string(bytes[0:nameLen]), "\000") + /// Skip if we're watching both this path and the parent; the parent + /// will already send a delete so no need to do it twice. + if mask&unix.IN_DELETE_SELF != 0 { + if _, ok := w.watches.path[filepath.Dir(watch.path)]; ok { + next() + continue + } } - event := w.newEvent(name, mask) + ev := w.newEvent(name, mask, raw.Cookie) + // Need to update watch path for recurse. + if watch.recurse { + isDir := mask&unix.IN_ISDIR == unix.IN_ISDIR + /// New directory created: set up watch on it. + if isDir && ev.Has(Create) { + err := w.register(ev.Name, watch.flags, true) + if !w.sendError(err) { + return + } - // Send the events that are not ignored on the events channel - if mask&unix.IN_IGNORED == 0 { - if !w.sendEvent(event) { - return + // This was a directory rename, so we need to update all + // the children. + // + // TODO: this is of course pretty slow; we should use a + // better data structure for storing all of this, e.g. store + // children in the watch. I have some code for this in my + // kqueue refactor we can use in the future. For now I'm + // okay with this as it's not publicly available. + // Correctness first, performance second. + if ev.renamedFrom != "" { + w.watches.mu.Lock() + for k, ww := range w.watches.wd { + if k == watch.wd || ww.path == ev.Name { + continue + } + if strings.HasPrefix(ww.path, ev.renamedFrom) { + ww.path = strings.Replace(ww.path, ev.renamedFrom, ev.Name, 1) + w.watches.wd[k] = ww + } + } + w.watches.mu.Unlock() + } } } - // Move to the next event in the buffer - offset += unix.SizeofInotifyEvent + nameLen + /// Send the events that are not ignored on the events channel + if !w.sendEvent(ev) { + return + } + next() } } } -// newEvent returns an platform-independent Event based on an inotify mask. -func (w *Watcher) newEvent(name string, mask uint32) Event { +func (w *inotify) isRecursive(path string) bool { + ww := w.watches.byPath(path) + if ww == nil { // path could be a file, so also check the Dir. + ww = w.watches.byPath(filepath.Dir(path)) + } + return ww != nil && ww.recurse +} + +func (w *inotify) newEvent(name string, mask, cookie uint32) Event { e := Event{Name: name} if mask&unix.IN_CREATE == unix.IN_CREATE || mask&unix.IN_MOVED_TO == unix.IN_MOVED_TO { e.Op |= Create @@ -584,11 +601,58 @@ func (w *Watcher) newEvent(name string, mask uint32) Event { if mask&unix.IN_MODIFY == unix.IN_MODIFY { e.Op |= Write } + if mask&unix.IN_OPEN == unix.IN_OPEN { + e.Op |= xUnportableOpen + } + if mask&unix.IN_ACCESS == unix.IN_ACCESS { + e.Op |= xUnportableRead + } + if mask&unix.IN_CLOSE_WRITE == unix.IN_CLOSE_WRITE { + e.Op |= xUnportableCloseWrite + } + if mask&unix.IN_CLOSE_NOWRITE == unix.IN_CLOSE_NOWRITE { + e.Op |= xUnportableCloseRead + } if mask&unix.IN_MOVE_SELF == unix.IN_MOVE_SELF || mask&unix.IN_MOVED_FROM == unix.IN_MOVED_FROM { e.Op |= Rename } if mask&unix.IN_ATTRIB == unix.IN_ATTRIB { e.Op |= Chmod } + + if cookie != 0 { + if mask&unix.IN_MOVED_FROM == unix.IN_MOVED_FROM { + w.cookiesMu.Lock() + w.cookies[w.cookieIndex] = koekje{cookie: cookie, path: e.Name} + w.cookieIndex++ + if w.cookieIndex > 9 { + w.cookieIndex = 0 + } + w.cookiesMu.Unlock() + } else if mask&unix.IN_MOVED_TO == unix.IN_MOVED_TO { + w.cookiesMu.Lock() + var prev string + for _, c := range w.cookies { + if c.cookie == cookie { + prev = c.path + break + } + } + w.cookiesMu.Unlock() + e.renamedFrom = prev + } + } return e } + +func (w *inotify) xSupports(op Op) bool { + return true // Supports everything. +} + +func (w *inotify) state() { + w.watches.mu.Lock() + defer w.watches.mu.Unlock() + for wd, ww := range w.watches.wd { + fmt.Fprintf(os.Stderr, "%4d: recurse=%t %q\n", wd, ww.recurse, ww.path) + } +} diff --git a/vendor/github.com/fsnotify/fsnotify/backend_kqueue.go b/vendor/github.com/fsnotify/fsnotify/backend_kqueue.go index 063a0915a..d8de5ab76 100644 --- a/vendor/github.com/fsnotify/fsnotify/backend_kqueue.go +++ b/vendor/github.com/fsnotify/fsnotify/backend_kqueue.go @@ -1,8 +1,4 @@ //go:build freebsd || openbsd || netbsd || dragonfly || darwin -// +build freebsd openbsd netbsd dragonfly darwin - -// Note: the documentation on the Watcher type and methods is generated from -// mkdoc.zsh package fsnotify @@ -11,174 +7,195 @@ import ( "fmt" "os" "path/filepath" + "runtime" "sync" + "time" + "github.com/fsnotify/fsnotify/internal" "golang.org/x/sys/unix" ) -// Watcher watches a set of paths, delivering events on a channel. -// -// A watcher should not be copied (e.g. pass it by pointer, rather than by -// value). -// -// # Linux notes -// -// When a file is removed a Remove event won't be emitted until all file -// descriptors are closed, and deletes will always emit a Chmod. For example: -// -// fp := os.Open("file") -// os.Remove("file") // Triggers Chmod -// fp.Close() // Triggers Remove -// -// This is the event that inotify sends, so not much can be changed about this. -// -// The fs.inotify.max_user_watches sysctl variable specifies the upper limit -// for the number of watches per user, and fs.inotify.max_user_instances -// specifies the maximum number of inotify instances per user. Every Watcher you -// create is an "instance", and every path you add is a "watch". -// -// These are also exposed in /proc as /proc/sys/fs/inotify/max_user_watches and -// /proc/sys/fs/inotify/max_user_instances -// -// To increase them you can use sysctl or write the value to the /proc file: -// -// # Default values on Linux 5.18 -// sysctl fs.inotify.max_user_watches=124983 -// sysctl fs.inotify.max_user_instances=128 -// -// To make the changes persist on reboot edit /etc/sysctl.conf or -// /usr/lib/sysctl.d/50-default.conf (details differ per Linux distro; check -// your distro's documentation): -// -// fs.inotify.max_user_watches=124983 -// fs.inotify.max_user_instances=128 -// -// Reaching the limit will result in a "no space left on device" or "too many open -// files" error. -// -// # kqueue notes (macOS, BSD) -// -// kqueue requires opening a file descriptor for every file that's being watched; -// so if you're watching a directory with five files then that's six file -// descriptors. You will run in to your system's "max open files" limit faster on -// these platforms. -// -// The sysctl variables kern.maxfiles and kern.maxfilesperproc can be used to -// control the maximum number of open files, as well as /etc/login.conf on BSD -// systems. -// -// # Windows notes -// -// Paths can be added as "C:\path\to\dir", but forward slashes -// ("C:/path/to/dir") will also work. -// -// When a watched directory is removed it will always send an event for the -// directory itself, but may not send events for all files in that directory. -// Sometimes it will send events for all times, sometimes it will send no -// events, and often only for some files. -// -// The default ReadDirectoryChangesW() buffer size is 64K, which is the largest -// value that is guaranteed to work with SMB filesystems. If you have many -// events in quick succession this may not be enough, and you will have to use -// [WithBufferSize] to increase the value. -type Watcher struct { - // Events sends the filesystem change events. - // - // fsnotify can send the following events; a "path" here can refer to a - // file, directory, symbolic link, or special file like a FIFO. - // - // fsnotify.Create A new path was created; this may be followed by one - // or more Write events if data also gets written to a - // file. - // - // fsnotify.Remove A path was removed. - // - // fsnotify.Rename A path was renamed. A rename is always sent with the - // old path as Event.Name, and a Create event will be - // sent with the new name. Renames are only sent for - // paths that are currently watched; e.g. moving an - // unmonitored file into a monitored directory will - // show up as just a Create. Similarly, renaming a file - // to outside a monitored directory will show up as - // only a Rename. - // - // fsnotify.Write A file or named pipe was written to. A Truncate will - // also trigger a Write. A single "write action" - // initiated by the user may show up as one or multiple - // writes, depending on when the system syncs things to - // disk. For example when compiling a large Go program - // you may get hundreds of Write events, and you may - // want to wait until you've stopped receiving them - // (see the dedup example in cmd/fsnotify). - // - // Some systems may send Write event for directories - // when the directory content changes. - // - // fsnotify.Chmod Attributes were changed. On Linux this is also sent - // when a file is removed (or more accurately, when a - // link to an inode is removed). On kqueue it's sent - // when a file is truncated. On Windows it's never - // sent. +type kqueue struct { Events chan Event - - // Errors sends any errors. - // - // ErrEventOverflow is used to indicate there are too many events: - // - // - inotify: There are too many queued events (fs.inotify.max_queued_events sysctl) - // - windows: The buffer size is too small; WithBufferSize() can be used to increase it. - // - kqueue, fen: Not used. Errors chan error - done chan struct{} - kq int // File descriptor (as returned by the kqueue() syscall). - closepipe [2]int // Pipe used for closing. - mu sync.Mutex // Protects access to watcher data - watches map[string]int // Watched file descriptors (key: path). - watchesByDir map[string]map[int]struct{} // Watched file descriptors indexed by the parent directory (key: dirname(path)). - userWatches map[string]struct{} // Watches added with Watcher.Add() - dirFlags map[string]uint32 // Watched directories to fflags used in kqueue. - paths map[int]pathInfo // File descriptors to path names for processing kqueue events. - fileExists map[string]struct{} // Keep track of if we know this file exists (to stop duplicate create events). - isClosed bool // Set to true when Close() is first called + kq int // File descriptor (as returned by the kqueue() syscall). + closepipe [2]int // Pipe used for closing kq. + watches *watches + done chan struct{} + doneMu sync.Mutex } -type pathInfo struct { - name string - isDir bool +type ( + watches struct { + mu sync.RWMutex + wd map[int]watch // wd → watch + path map[string]int // pathname → wd + byDir map[string]map[int]struct{} // dirname(path) → wd + seen map[string]struct{} // Keep track of if we know this file exists. + byUser map[string]struct{} // Watches added with Watcher.Add() + } + watch struct { + wd int + name string + linkName string // In case of links; name is the target, and this is the link. + isDir bool + dirFlags uint32 + } +) + +func newWatches() *watches { + return &watches{ + wd: make(map[int]watch), + path: make(map[string]int), + byDir: make(map[string]map[int]struct{}), + seen: make(map[string]struct{}), + byUser: make(map[string]struct{}), + } } -// NewWatcher creates a new Watcher. -func NewWatcher() (*Watcher, error) { - return NewBufferedWatcher(0) +func (w *watches) listPaths(userOnly bool) []string { + w.mu.RLock() + defer w.mu.RUnlock() + + if userOnly { + l := make([]string, 0, len(w.byUser)) + for p := range w.byUser { + l = append(l, p) + } + return l + } + + l := make([]string, 0, len(w.path)) + for p := range w.path { + l = append(l, p) + } + return l } -// NewBufferedWatcher creates a new Watcher with a buffered Watcher.Events -// channel. -// -// The main use case for this is situations with a very large number of events -// where the kernel buffer size can't be increased (e.g. due to lack of -// permissions). An unbuffered Watcher will perform better for almost all use -// cases, and whenever possible you will be better off increasing the kernel -// buffers instead of adding a large userspace buffer. -func NewBufferedWatcher(sz uint) (*Watcher, error) { +func (w *watches) watchesInDir(path string) []string { + w.mu.RLock() + defer w.mu.RUnlock() + + l := make([]string, 0, 4) + for fd := range w.byDir[path] { + info := w.wd[fd] + if _, ok := w.byUser[info.name]; !ok { + l = append(l, info.name) + } + } + return l +} + +// Mark path as added by the user. +func (w *watches) addUserWatch(path string) { + w.mu.Lock() + defer w.mu.Unlock() + w.byUser[path] = struct{}{} +} + +func (w *watches) addLink(path string, fd int) { + w.mu.Lock() + defer w.mu.Unlock() + + w.path[path] = fd + w.seen[path] = struct{}{} +} + +func (w *watches) add(path, linkPath string, fd int, isDir bool) { + w.mu.Lock() + defer w.mu.Unlock() + + w.path[path] = fd + w.wd[fd] = watch{wd: fd, name: path, linkName: linkPath, isDir: isDir} + + parent := filepath.Dir(path) + byDir, ok := w.byDir[parent] + if !ok { + byDir = make(map[int]struct{}, 1) + w.byDir[parent] = byDir + } + byDir[fd] = struct{}{} +} + +func (w *watches) byWd(fd int) (watch, bool) { + w.mu.RLock() + defer w.mu.RUnlock() + info, ok := w.wd[fd] + return info, ok +} + +func (w *watches) byPath(path string) (watch, bool) { + w.mu.RLock() + defer w.mu.RUnlock() + info, ok := w.wd[w.path[path]] + return info, ok +} + +func (w *watches) updateDirFlags(path string, flags uint32) { + w.mu.Lock() + defer w.mu.Unlock() + + fd := w.path[path] + info := w.wd[fd] + info.dirFlags = flags + w.wd[fd] = info +} + +func (w *watches) remove(fd int, path string) bool { + w.mu.Lock() + defer w.mu.Unlock() + + isDir := w.wd[fd].isDir + delete(w.path, path) + delete(w.byUser, path) + + parent := filepath.Dir(path) + delete(w.byDir[parent], fd) + + if len(w.byDir[parent]) == 0 { + delete(w.byDir, parent) + } + + delete(w.wd, fd) + delete(w.seen, path) + return isDir +} + +func (w *watches) markSeen(path string, exists bool) { + w.mu.Lock() + defer w.mu.Unlock() + if exists { + w.seen[path] = struct{}{} + } else { + delete(w.seen, path) + } +} + +func (w *watches) seenBefore(path string) bool { + w.mu.RLock() + defer w.mu.RUnlock() + _, ok := w.seen[path] + return ok +} + +func newBackend(ev chan Event, errs chan error) (backend, error) { + return newBufferedBackend(0, ev, errs) +} + +func newBufferedBackend(sz uint, ev chan Event, errs chan error) (backend, error) { kq, closepipe, err := newKqueue() if err != nil { return nil, err } - w := &Watcher{ - kq: kq, - closepipe: closepipe, - watches: make(map[string]int), - watchesByDir: make(map[string]map[int]struct{}), - dirFlags: make(map[string]uint32), - paths: make(map[int]pathInfo), - fileExists: make(map[string]struct{}), - userWatches: make(map[string]struct{}), - Events: make(chan Event, sz), - Errors: make(chan error), - done: make(chan struct{}), + w := &kqueue{ + Events: ev, + Errors: errs, + kq: kq, + closepipe: closepipe, + done: make(chan struct{}), + watches: newWatches(), } go w.readEvents() @@ -203,6 +220,8 @@ func newKqueue() (kq int, closepipe [2]int, err error) { unix.Close(kq) return kq, closepipe, err } + unix.CloseOnExec(closepipe[0]) + unix.CloseOnExec(closepipe[1]) // Register changes to listen on the closepipe. changes := make([]unix.Kevent_t, 1) @@ -221,166 +240,108 @@ func newKqueue() (kq int, closepipe [2]int, err error) { } // Returns true if the event was sent, or false if watcher is closed. -func (w *Watcher) sendEvent(e Event) bool { +func (w *kqueue) sendEvent(e Event) bool { select { - case w.Events <- e: - return true case <-w.done: return false + case w.Events <- e: + return true } } // Returns true if the error was sent, or false if watcher is closed. -func (w *Watcher) sendError(err error) bool { +func (w *kqueue) sendError(err error) bool { + if err == nil { + return true + } select { + case <-w.done: + return false case w.Errors <- err: return true + } +} + +func (w *kqueue) isClosed() bool { + select { case <-w.done: + return true + default: return false } } -// Close removes all watches and closes the Events channel. -func (w *Watcher) Close() error { - w.mu.Lock() - if w.isClosed { - w.mu.Unlock() +func (w *kqueue) Close() error { + w.doneMu.Lock() + if w.isClosed() { + w.doneMu.Unlock() return nil } - w.isClosed = true + close(w.done) + w.doneMu.Unlock() - // copy paths to remove while locked - pathsToRemove := make([]string, 0, len(w.watches)) - for name := range w.watches { - pathsToRemove = append(pathsToRemove, name) - } - w.mu.Unlock() // Unlock before calling Remove, which also locks + pathsToRemove := w.watches.listPaths(false) for _, name := range pathsToRemove { w.Remove(name) } // Send "quit" message to the reader goroutine. unix.Close(w.closepipe[1]) - close(w.done) - return nil } -// Add starts monitoring the path for changes. -// -// A path can only be watched once; watching it more than once is a no-op and will -// not return an error. Paths that do not yet exist on the filesystem cannot be -// watched. -// -// A watch will be automatically removed if the watched path is deleted or -// renamed. The exception is the Windows backend, which doesn't remove the -// watcher on renames. -// -// Notifications on network filesystems (NFS, SMB, FUSE, etc.) or special -// filesystems (/proc, /sys, etc.) generally don't work. -// -// Returns [ErrClosed] if [Watcher.Close] was called. -// -// See [Watcher.AddWith] for a version that allows adding options. -// -// # Watching directories -// -// All files in a directory are monitored, including new files that are created -// after the watcher is started. Subdirectories are not watched (i.e. it's -// non-recursive). -// -// # Watching files -// -// Watching individual files (rather than directories) is generally not -// recommended as many programs (especially editors) update files atomically: it -// will write to a temporary file which is then moved to to destination, -// overwriting the original (or some variant thereof). The watcher on the -// original file is now lost, as that no longer exists. -// -// The upshot of this is that a power failure or crash won't leave a -// half-written file. -// -// Watch the parent directory and use Event.Name to filter out files you're not -// interested in. There is an example of this in cmd/fsnotify/file.go. -func (w *Watcher) Add(name string) error { return w.AddWith(name) } +func (w *kqueue) Add(name string) error { return w.AddWith(name) } -// AddWith is like [Watcher.Add], but allows adding options. When using Add() -// the defaults described below are used. -// -// Possible options are: -// -// - [WithBufferSize] sets the buffer size for the Windows backend; no-op on -// other platforms. The default is 64K (65536 bytes). -func (w *Watcher) AddWith(name string, opts ...addOpt) error { - _ = getOptions(opts...) +func (w *kqueue) AddWith(name string, opts ...addOpt) error { + if debug { + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s AddWith(%q)\n", + time.Now().Format("15:04:05.000000000"), name) + } + + with := getOptions(opts...) + if !w.xSupports(with.op) { + return fmt.Errorf("%w: %s", xErrUnsupported, with.op) + } - w.mu.Lock() - w.userWatches[name] = struct{}{} - w.mu.Unlock() _, err := w.addWatch(name, noteAllEvents) - return err + if err != nil { + return err + } + w.watches.addUserWatch(name) + return nil } -// Remove stops monitoring the path for changes. -// -// Directories are always removed non-recursively. For example, if you added -// /tmp/dir and /tmp/dir/subdir then you will need to remove both. -// -// Removing a path that has not yet been added returns [ErrNonExistentWatch]. -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) Remove(name string) error { +func (w *kqueue) Remove(name string) error { + if debug { + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s Remove(%q)\n", + time.Now().Format("15:04:05.000000000"), name) + } return w.remove(name, true) } -func (w *Watcher) remove(name string, unwatchFiles bool) error { - name = filepath.Clean(name) - w.mu.Lock() - if w.isClosed { - w.mu.Unlock() +func (w *kqueue) remove(name string, unwatchFiles bool) error { + if w.isClosed() { return nil } - watchfd, ok := w.watches[name] - w.mu.Unlock() + + name = filepath.Clean(name) + info, ok := w.watches.byPath(name) if !ok { return fmt.Errorf("%w: %s", ErrNonExistentWatch, name) } - err := w.register([]int{watchfd}, unix.EV_DELETE, 0) + err := w.register([]int{info.wd}, unix.EV_DELETE, 0) if err != nil { return err } - unix.Close(watchfd) - - w.mu.Lock() - isDir := w.paths[watchfd].isDir - delete(w.watches, name) - delete(w.userWatches, name) - - parentName := filepath.Dir(name) - delete(w.watchesByDir[parentName], watchfd) - - if len(w.watchesByDir[parentName]) == 0 { - delete(w.watchesByDir, parentName) - } + unix.Close(info.wd) - delete(w.paths, watchfd) - delete(w.dirFlags, name) - delete(w.fileExists, name) - w.mu.Unlock() + isDir := w.watches.remove(info.wd, name) // Find all watched paths that are in this directory that are not external. if unwatchFiles && isDir { - var pathsToRemove []string - w.mu.Lock() - for fd := range w.watchesByDir[name] { - path := w.paths[fd] - if _, ok := w.userWatches[path.name]; !ok { - pathsToRemove = append(pathsToRemove, path.name) - } - } - w.mu.Unlock() + pathsToRemove := w.watches.watchesInDir(name) for _, name := range pathsToRemove { // Since these are internal, not much sense in propagating error to // the user, as that will just confuse them with an error about a @@ -391,23 +352,11 @@ func (w *Watcher) remove(name string, unwatchFiles bool) error { return nil } -// WatchList returns all paths explicitly added with [Watcher.Add] (and are not -// yet removed). -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) WatchList() []string { - w.mu.Lock() - defer w.mu.Unlock() - if w.isClosed { +func (w *kqueue) WatchList() []string { + if w.isClosed() { return nil } - - entries := make([]string, 0, len(w.userWatches)) - for pathname := range w.userWatches { - entries = append(entries, pathname) - } - - return entries + return w.watches.listPaths(true) } // Watch all events (except NOTE_EXTEND, NOTE_LINK, NOTE_REVOKE) @@ -417,34 +366,26 @@ const noteAllEvents = unix.NOTE_DELETE | unix.NOTE_WRITE | unix.NOTE_ATTRIB | un // described in kevent(2). // // Returns the real path to the file which was added, with symlinks resolved. -func (w *Watcher) addWatch(name string, flags uint32) (string, error) { - var isDir bool - name = filepath.Clean(name) - - w.mu.Lock() - if w.isClosed { - w.mu.Unlock() +func (w *kqueue) addWatch(name string, flags uint32) (string, error) { + if w.isClosed() { return "", ErrClosed } - watchfd, alreadyWatching := w.watches[name] - // We already have a watch, but we can still override flags. - if alreadyWatching { - isDir = w.paths[watchfd].isDir - } - w.mu.Unlock() + name = filepath.Clean(name) + + info, alreadyWatching := w.watches.byPath(name) if !alreadyWatching { fi, err := os.Lstat(name) if err != nil { return "", err } - // Don't watch sockets or named pipes + // Don't watch sockets or named pipes. if (fi.Mode()&os.ModeSocket == os.ModeSocket) || (fi.Mode()&os.ModeNamedPipe == os.ModeNamedPipe) { return "", nil } - // Follow Symlinks. + // Follow symlinks. if fi.Mode()&os.ModeSymlink == os.ModeSymlink { link, err := os.Readlink(name) if err != nil { @@ -455,18 +396,15 @@ func (w *Watcher) addWatch(name string, flags uint32) (string, error) { return "", nil } - w.mu.Lock() - _, alreadyWatching = w.watches[link] - w.mu.Unlock() - + _, alreadyWatching = w.watches.byPath(link) if alreadyWatching { // Add to watches so we don't get spurious Create events later // on when we diff the directories. - w.watches[name] = 0 - w.fileExists[name] = struct{}{} + w.watches.addLink(name, 0) return link, nil } + info.linkName = name name = link fi, err = os.Lstat(name) if err != nil { @@ -477,7 +415,7 @@ func (w *Watcher) addWatch(name string, flags uint32) (string, error) { // Retry on EINTR; open() can return EINTR in practice on macOS. // See #354, and Go issues 11180 and 39237. for { - watchfd, err = unix.Open(name, openMode, 0) + info.wd, err = unix.Open(name, openMode, 0) if err == nil { break } @@ -488,40 +426,25 @@ func (w *Watcher) addWatch(name string, flags uint32) (string, error) { return "", err } - isDir = fi.IsDir() + info.isDir = fi.IsDir() } - err := w.register([]int{watchfd}, unix.EV_ADD|unix.EV_CLEAR|unix.EV_ENABLE, flags) + err := w.register([]int{info.wd}, unix.EV_ADD|unix.EV_CLEAR|unix.EV_ENABLE, flags) if err != nil { - unix.Close(watchfd) + unix.Close(info.wd) return "", err } if !alreadyWatching { - w.mu.Lock() - parentName := filepath.Dir(name) - w.watches[name] = watchfd - - watchesByDir, ok := w.watchesByDir[parentName] - if !ok { - watchesByDir = make(map[int]struct{}, 1) - w.watchesByDir[parentName] = watchesByDir - } - watchesByDir[watchfd] = struct{}{} - w.paths[watchfd] = pathInfo{name: name, isDir: isDir} - w.mu.Unlock() + w.watches.add(name, info.linkName, info.wd, info.isDir) } - if isDir { - // Watch the directory if it has not been watched before, or if it was - // watched before, but perhaps only a NOTE_DELETE (watchDirectoryFiles) - w.mu.Lock() - + // Watch the directory if it has not been watched before, or if it was + // watched before, but perhaps only a NOTE_DELETE (watchDirectoryFiles) + if info.isDir { watchDir := (flags&unix.NOTE_WRITE) == unix.NOTE_WRITE && - (!alreadyWatching || (w.dirFlags[name]&unix.NOTE_WRITE) != unix.NOTE_WRITE) - // Store flags so this watch can be updated later - w.dirFlags[name] = flags - w.mu.Unlock() + (!alreadyWatching || (info.dirFlags&unix.NOTE_WRITE) != unix.NOTE_WRITE) + w.watches.updateDirFlags(name, flags) if watchDir { if err := w.watchDirectoryFiles(name); err != nil { @@ -534,7 +457,7 @@ func (w *Watcher) addWatch(name string, flags uint32) (string, error) { // readEvents reads from kqueue and converts the received kevents into // Event values that it sends down the Events channel. -func (w *Watcher) readEvents() { +func (w *kqueue) readEvents() { defer func() { close(w.Events) close(w.Errors) @@ -543,50 +466,65 @@ func (w *Watcher) readEvents() { }() eventBuffer := make([]unix.Kevent_t, 10) - for closed := false; !closed; { + for { kevents, err := w.read(eventBuffer) // EINTR is okay, the syscall was interrupted before timeout expired. if err != nil && err != unix.EINTR { if !w.sendError(fmt.Errorf("fsnotify.readEvents: %w", err)) { - closed = true + return } - continue } - // Flush the events we received to the Events channel for _, kevent := range kevents { var ( - watchfd = int(kevent.Ident) - mask = uint32(kevent.Fflags) + wd = int(kevent.Ident) + mask = uint32(kevent.Fflags) ) // Shut down the loop when the pipe is closed, but only after all // other events have been processed. - if watchfd == w.closepipe[0] { - closed = true - continue + if wd == w.closepipe[0] { + return } - w.mu.Lock() - path := w.paths[watchfd] - w.mu.Unlock() + path, ok := w.watches.byWd(wd) + if debug { + internal.Debug(path.name, &kevent) + } - event := w.newEvent(path.name, mask) + // On macOS it seems that sometimes an event with Ident=0 is + // delivered, and no other flags/information beyond that, even + // though we never saw such a file descriptor. For example in + // TestWatchSymlink/277 (usually at the end, but sometimes sooner): + // + // fmt.Printf("READ: %2d %#v\n", kevent.Ident, kevent) + // unix.Kevent_t{Ident:0x2a, Filter:-4, Flags:0x25, Fflags:0x2, Data:0, Udata:(*uint8)(nil)} + // unix.Kevent_t{Ident:0x0, Filter:-4, Flags:0x25, Fflags:0x2, Data:0, Udata:(*uint8)(nil)} + // + // The first is a normal event, the second with Ident 0. No error + // flag, no data, no ... nothing. + // + // I read a bit through bsd/kern_event.c from the xnu source, but I + // don't really see an obvious location where this is triggered – + // this doesn't seem intentional, but idk... + // + // Technically fd 0 is a valid descriptor, so only skip it if + // there's no path, and if we're on macOS. + if !ok && kevent.Ident == 0 && runtime.GOOS == "darwin" { + continue + } + + event := w.newEvent(path.name, path.linkName, mask) if event.Has(Rename) || event.Has(Remove) { w.remove(event.Name, false) - w.mu.Lock() - delete(w.fileExists, event.Name) - w.mu.Unlock() + w.watches.markSeen(event.Name, false) } if path.isDir && event.Has(Write) && !event.Has(Remove) { - w.sendDirectoryChangeEvents(event.Name) - } else { - if !w.sendEvent(event) { - closed = true - continue - } + w.dirChange(event.Name) + } else if !w.sendEvent(event) { + return } if event.Has(Remove) { @@ -594,25 +532,34 @@ func (w *Watcher) readEvents() { // mv f1 f2 will delete f2, then create f2. if path.isDir { fileDir := filepath.Clean(event.Name) - w.mu.Lock() - _, found := w.watches[fileDir] - w.mu.Unlock() + _, found := w.watches.byPath(fileDir) if found { - err := w.sendDirectoryChangeEvents(fileDir) - if err != nil { - if !w.sendError(err) { - closed = true - } + // TODO: this branch is never triggered in any test. + // Added in d6220df (2012). + // isDir check added in 8611c35 (2016): https://github.com/fsnotify/fsnotify/pull/111 + // + // I don't really get how this can be triggered either. + // And it wasn't triggered in the patch that added it, + // either. + // + // Original also had a comment: + // make sure the directory exists before we watch for + // changes. When we do a recursive watch and perform + // rm -rf, the parent directory might have gone + // missing, ignore the missing directory and let the + // upcoming delete event remove the watch from the + // parent directory. + err := w.dirChange(fileDir) + if !w.sendError(err) { + return } } } else { - filePath := filepath.Clean(event.Name) - if fi, err := os.Lstat(filePath); err == nil { - err := w.sendFileCreatedEventIfNew(filePath, fi) - if err != nil { - if !w.sendError(err) { - closed = true - } + path := filepath.Clean(event.Name) + if fi, err := os.Lstat(path); err == nil { + err := w.sendCreateIfNew(path, fi) + if !w.sendError(err) { + return } } } @@ -622,8 +569,14 @@ func (w *Watcher) readEvents() { } // newEvent returns an platform-independent Event based on kqueue Fflags. -func (w *Watcher) newEvent(name string, mask uint32) Event { +func (w *kqueue) newEvent(name, linkName string, mask uint32) Event { e := Event{Name: name} + if linkName != "" { + // If the user watched "/path/link" then emit events as "/path/link" + // rather than "/path/target". + e.Name = linkName + } + if mask&unix.NOTE_DELETE == unix.NOTE_DELETE { e.Op |= Remove } @@ -645,8 +598,7 @@ func (w *Watcher) newEvent(name string, mask uint32) Event { } // watchDirectoryFiles to mimic inotify when adding a watch on a directory -func (w *Watcher) watchDirectoryFiles(dirPath string) error { - // Get all files +func (w *kqueue) watchDirectoryFiles(dirPath string) error { files, err := os.ReadDir(dirPath) if err != nil { return err @@ -674,9 +626,7 @@ func (w *Watcher) watchDirectoryFiles(dirPath string) error { } } - w.mu.Lock() - w.fileExists[cleanPath] = struct{}{} - w.mu.Unlock() + w.watches.markSeen(cleanPath, true) } return nil @@ -686,7 +636,7 @@ func (w *Watcher) watchDirectoryFiles(dirPath string) error { // // This functionality is to have the BSD watcher match the inotify, which sends // a create event for files created in a watched directory. -func (w *Watcher) sendDirectoryChangeEvents(dir string) error { +func (w *kqueue) dirChange(dir string) error { files, err := os.ReadDir(dir) if err != nil { // Directory no longer exists: we can ignore this safely. kqueue will @@ -694,61 +644,51 @@ func (w *Watcher) sendDirectoryChangeEvents(dir string) error { if errors.Is(err, os.ErrNotExist) { return nil } - return fmt.Errorf("fsnotify.sendDirectoryChangeEvents: %w", err) + return fmt.Errorf("fsnotify.dirChange: %w", err) } for _, f := range files { fi, err := f.Info() if err != nil { - return fmt.Errorf("fsnotify.sendDirectoryChangeEvents: %w", err) + return fmt.Errorf("fsnotify.dirChange: %w", err) } - err = w.sendFileCreatedEventIfNew(filepath.Join(dir, fi.Name()), fi) + err = w.sendCreateIfNew(filepath.Join(dir, fi.Name()), fi) if err != nil { // Don't need to send an error if this file isn't readable. if errors.Is(err, unix.EACCES) || errors.Is(err, unix.EPERM) { return nil } - return fmt.Errorf("fsnotify.sendDirectoryChangeEvents: %w", err) + return fmt.Errorf("fsnotify.dirChange: %w", err) } } return nil } -// sendFileCreatedEvent sends a create event if the file isn't already being tracked. -func (w *Watcher) sendFileCreatedEventIfNew(filePath string, fi os.FileInfo) (err error) { - w.mu.Lock() - _, doesExist := w.fileExists[filePath] - w.mu.Unlock() - if !doesExist { - if !w.sendEvent(Event{Name: filePath, Op: Create}) { - return +// Send a create event if the file isn't already being tracked, and start +// watching this file. +func (w *kqueue) sendCreateIfNew(path string, fi os.FileInfo) error { + if !w.watches.seenBefore(path) { + if !w.sendEvent(Event{Name: path, Op: Create}) { + return nil } } - // like watchDirectoryFiles (but without doing another ReadDir) - filePath, err = w.internalWatch(filePath, fi) + // Like watchDirectoryFiles, but without doing another ReadDir. + path, err := w.internalWatch(path, fi) if err != nil { return err } - - w.mu.Lock() - w.fileExists[filePath] = struct{}{} - w.mu.Unlock() - + w.watches.markSeen(path, true) return nil } -func (w *Watcher) internalWatch(name string, fi os.FileInfo) (string, error) { +func (w *kqueue) internalWatch(name string, fi os.FileInfo) (string, error) { if fi.IsDir() { // mimic Linux providing delete events for subdirectories, but preserve // the flags used if currently watching subdirectory - w.mu.Lock() - flags := w.dirFlags[name] - w.mu.Unlock() - - flags |= unix.NOTE_DELETE | unix.NOTE_RENAME - return w.addWatch(name, flags) + info, _ := w.watches.byPath(name) + return w.addWatch(name, info.dirFlags|unix.NOTE_DELETE|unix.NOTE_RENAME) } // watch file to mimic Linux inotify @@ -756,7 +696,7 @@ func (w *Watcher) internalWatch(name string, fi os.FileInfo) (string, error) { } // Register events with the queue. -func (w *Watcher) register(fds []int, flags int, fflags uint32) error { +func (w *kqueue) register(fds []int, flags int, fflags uint32) error { changes := make([]unix.Kevent_t, len(fds)) for i, fd := range fds { // SetKevent converts int to the platform-specific types. @@ -773,10 +713,21 @@ func (w *Watcher) register(fds []int, flags int, fflags uint32) error { } // read retrieves pending events, or waits until an event occurs. -func (w *Watcher) read(events []unix.Kevent_t) ([]unix.Kevent_t, error) { +func (w *kqueue) read(events []unix.Kevent_t) ([]unix.Kevent_t, error) { n, err := unix.Kevent(w.kq, nil, events, nil) if err != nil { return nil, err } return events[0:n], nil } + +func (w *kqueue) xSupports(op Op) bool { + if runtime.GOOS == "freebsd" { + //return true // Supports everything. + } + if op.Has(xUnportableOpen) || op.Has(xUnportableRead) || + op.Has(xUnportableCloseWrite) || op.Has(xUnportableCloseRead) { + return false + } + return true +} diff --git a/vendor/github.com/fsnotify/fsnotify/backend_other.go b/vendor/github.com/fsnotify/fsnotify/backend_other.go index d34a23c01..5eb5dbc66 100644 --- a/vendor/github.com/fsnotify/fsnotify/backend_other.go +++ b/vendor/github.com/fsnotify/fsnotify/backend_other.go @@ -1,205 +1,23 @@ //go:build appengine || (!darwin && !dragonfly && !freebsd && !openbsd && !linux && !netbsd && !solaris && !windows) -// +build appengine !darwin,!dragonfly,!freebsd,!openbsd,!linux,!netbsd,!solaris,!windows - -// Note: the documentation on the Watcher type and methods is generated from -// mkdoc.zsh package fsnotify import "errors" -// Watcher watches a set of paths, delivering events on a channel. -// -// A watcher should not be copied (e.g. pass it by pointer, rather than by -// value). -// -// # Linux notes -// -// When a file is removed a Remove event won't be emitted until all file -// descriptors are closed, and deletes will always emit a Chmod. For example: -// -// fp := os.Open("file") -// os.Remove("file") // Triggers Chmod -// fp.Close() // Triggers Remove -// -// This is the event that inotify sends, so not much can be changed about this. -// -// The fs.inotify.max_user_watches sysctl variable specifies the upper limit -// for the number of watches per user, and fs.inotify.max_user_instances -// specifies the maximum number of inotify instances per user. Every Watcher you -// create is an "instance", and every path you add is a "watch". -// -// These are also exposed in /proc as /proc/sys/fs/inotify/max_user_watches and -// /proc/sys/fs/inotify/max_user_instances -// -// To increase them you can use sysctl or write the value to the /proc file: -// -// # Default values on Linux 5.18 -// sysctl fs.inotify.max_user_watches=124983 -// sysctl fs.inotify.max_user_instances=128 -// -// To make the changes persist on reboot edit /etc/sysctl.conf or -// /usr/lib/sysctl.d/50-default.conf (details differ per Linux distro; check -// your distro's documentation): -// -// fs.inotify.max_user_watches=124983 -// fs.inotify.max_user_instances=128 -// -// Reaching the limit will result in a "no space left on device" or "too many open -// files" error. -// -// # kqueue notes (macOS, BSD) -// -// kqueue requires opening a file descriptor for every file that's being watched; -// so if you're watching a directory with five files then that's six file -// descriptors. You will run in to your system's "max open files" limit faster on -// these platforms. -// -// The sysctl variables kern.maxfiles and kern.maxfilesperproc can be used to -// control the maximum number of open files, as well as /etc/login.conf on BSD -// systems. -// -// # Windows notes -// -// Paths can be added as "C:\path\to\dir", but forward slashes -// ("C:/path/to/dir") will also work. -// -// When a watched directory is removed it will always send an event for the -// directory itself, but may not send events for all files in that directory. -// Sometimes it will send events for all times, sometimes it will send no -// events, and often only for some files. -// -// The default ReadDirectoryChangesW() buffer size is 64K, which is the largest -// value that is guaranteed to work with SMB filesystems. If you have many -// events in quick succession this may not be enough, and you will have to use -// [WithBufferSize] to increase the value. -type Watcher struct { - // Events sends the filesystem change events. - // - // fsnotify can send the following events; a "path" here can refer to a - // file, directory, symbolic link, or special file like a FIFO. - // - // fsnotify.Create A new path was created; this may be followed by one - // or more Write events if data also gets written to a - // file. - // - // fsnotify.Remove A path was removed. - // - // fsnotify.Rename A path was renamed. A rename is always sent with the - // old path as Event.Name, and a Create event will be - // sent with the new name. Renames are only sent for - // paths that are currently watched; e.g. moving an - // unmonitored file into a monitored directory will - // show up as just a Create. Similarly, renaming a file - // to outside a monitored directory will show up as - // only a Rename. - // - // fsnotify.Write A file or named pipe was written to. A Truncate will - // also trigger a Write. A single "write action" - // initiated by the user may show up as one or multiple - // writes, depending on when the system syncs things to - // disk. For example when compiling a large Go program - // you may get hundreds of Write events, and you may - // want to wait until you've stopped receiving them - // (see the dedup example in cmd/fsnotify). - // - // Some systems may send Write event for directories - // when the directory content changes. - // - // fsnotify.Chmod Attributes were changed. On Linux this is also sent - // when a file is removed (or more accurately, when a - // link to an inode is removed). On kqueue it's sent - // when a file is truncated. On Windows it's never - // sent. +type other struct { Events chan Event - - // Errors sends any errors. - // - // ErrEventOverflow is used to indicate there are too many events: - // - // - inotify: There are too many queued events (fs.inotify.max_queued_events sysctl) - // - windows: The buffer size is too small; WithBufferSize() can be used to increase it. - // - kqueue, fen: Not used. Errors chan error } -// NewWatcher creates a new Watcher. -func NewWatcher() (*Watcher, error) { +func newBackend(ev chan Event, errs chan error) (backend, error) { return nil, errors.New("fsnotify not supported on the current platform") } - -// NewBufferedWatcher creates a new Watcher with a buffered Watcher.Events -// channel. -// -// The main use case for this is situations with a very large number of events -// where the kernel buffer size can't be increased (e.g. due to lack of -// permissions). An unbuffered Watcher will perform better for almost all use -// cases, and whenever possible you will be better off increasing the kernel -// buffers instead of adding a large userspace buffer. -func NewBufferedWatcher(sz uint) (*Watcher, error) { return NewWatcher() } - -// Close removes all watches and closes the Events channel. -func (w *Watcher) Close() error { return nil } - -// WatchList returns all paths explicitly added with [Watcher.Add] (and are not -// yet removed). -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) WatchList() []string { return nil } - -// Add starts monitoring the path for changes. -// -// A path can only be watched once; watching it more than once is a no-op and will -// not return an error. Paths that do not yet exist on the filesystem cannot be -// watched. -// -// A watch will be automatically removed if the watched path is deleted or -// renamed. The exception is the Windows backend, which doesn't remove the -// watcher on renames. -// -// Notifications on network filesystems (NFS, SMB, FUSE, etc.) or special -// filesystems (/proc, /sys, etc.) generally don't work. -// -// Returns [ErrClosed] if [Watcher.Close] was called. -// -// See [Watcher.AddWith] for a version that allows adding options. -// -// # Watching directories -// -// All files in a directory are monitored, including new files that are created -// after the watcher is started. Subdirectories are not watched (i.e. it's -// non-recursive). -// -// # Watching files -// -// Watching individual files (rather than directories) is generally not -// recommended as many programs (especially editors) update files atomically: it -// will write to a temporary file which is then moved to to destination, -// overwriting the original (or some variant thereof). The watcher on the -// original file is now lost, as that no longer exists. -// -// The upshot of this is that a power failure or crash won't leave a -// half-written file. -// -// Watch the parent directory and use Event.Name to filter out files you're not -// interested in. There is an example of this in cmd/fsnotify/file.go. -func (w *Watcher) Add(name string) error { return nil } - -// AddWith is like [Watcher.Add], but allows adding options. When using Add() -// the defaults described below are used. -// -// Possible options are: -// -// - [WithBufferSize] sets the buffer size for the Windows backend; no-op on -// other platforms. The default is 64K (65536 bytes). -func (w *Watcher) AddWith(name string, opts ...addOpt) error { return nil } - -// Remove stops monitoring the path for changes. -// -// Directories are always removed non-recursively. For example, if you added -// /tmp/dir and /tmp/dir/subdir then you will need to remove both. -// -// Removing a path that has not yet been added returns [ErrNonExistentWatch]. -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) Remove(name string) error { return nil } +func newBufferedBackend(sz uint, ev chan Event, errs chan error) (backend, error) { + return newBackend(ev, errs) +} +func (w *other) Close() error { return nil } +func (w *other) WatchList() []string { return nil } +func (w *other) Add(name string) error { return nil } +func (w *other) AddWith(name string, opts ...addOpt) error { return nil } +func (w *other) Remove(name string) error { return nil } +func (w *other) xSupports(op Op) bool { return false } diff --git a/vendor/github.com/fsnotify/fsnotify/backend_windows.go b/vendor/github.com/fsnotify/fsnotify/backend_windows.go index 9bc91e5d6..c54a63083 100644 --- a/vendor/github.com/fsnotify/fsnotify/backend_windows.go +++ b/vendor/github.com/fsnotify/fsnotify/backend_windows.go @@ -1,12 +1,8 @@ //go:build windows -// +build windows // Windows backend based on ReadDirectoryChangesW() // // https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-readdirectorychangesw -// -// Note: the documentation on the Watcher type and methods is generated from -// mkdoc.zsh package fsnotify @@ -19,123 +15,15 @@ import ( "runtime" "strings" "sync" + "time" "unsafe" + "github.com/fsnotify/fsnotify/internal" "golang.org/x/sys/windows" ) -// Watcher watches a set of paths, delivering events on a channel. -// -// A watcher should not be copied (e.g. pass it by pointer, rather than by -// value). -// -// # Linux notes -// -// When a file is removed a Remove event won't be emitted until all file -// descriptors are closed, and deletes will always emit a Chmod. For example: -// -// fp := os.Open("file") -// os.Remove("file") // Triggers Chmod -// fp.Close() // Triggers Remove -// -// This is the event that inotify sends, so not much can be changed about this. -// -// The fs.inotify.max_user_watches sysctl variable specifies the upper limit -// for the number of watches per user, and fs.inotify.max_user_instances -// specifies the maximum number of inotify instances per user. Every Watcher you -// create is an "instance", and every path you add is a "watch". -// -// These are also exposed in /proc as /proc/sys/fs/inotify/max_user_watches and -// /proc/sys/fs/inotify/max_user_instances -// -// To increase them you can use sysctl or write the value to the /proc file: -// -// # Default values on Linux 5.18 -// sysctl fs.inotify.max_user_watches=124983 -// sysctl fs.inotify.max_user_instances=128 -// -// To make the changes persist on reboot edit /etc/sysctl.conf or -// /usr/lib/sysctl.d/50-default.conf (details differ per Linux distro; check -// your distro's documentation): -// -// fs.inotify.max_user_watches=124983 -// fs.inotify.max_user_instances=128 -// -// Reaching the limit will result in a "no space left on device" or "too many open -// files" error. -// -// # kqueue notes (macOS, BSD) -// -// kqueue requires opening a file descriptor for every file that's being watched; -// so if you're watching a directory with five files then that's six file -// descriptors. You will run in to your system's "max open files" limit faster on -// these platforms. -// -// The sysctl variables kern.maxfiles and kern.maxfilesperproc can be used to -// control the maximum number of open files, as well as /etc/login.conf on BSD -// systems. -// -// # Windows notes -// -// Paths can be added as "C:\path\to\dir", but forward slashes -// ("C:/path/to/dir") will also work. -// -// When a watched directory is removed it will always send an event for the -// directory itself, but may not send events for all files in that directory. -// Sometimes it will send events for all times, sometimes it will send no -// events, and often only for some files. -// -// The default ReadDirectoryChangesW() buffer size is 64K, which is the largest -// value that is guaranteed to work with SMB filesystems. If you have many -// events in quick succession this may not be enough, and you will have to use -// [WithBufferSize] to increase the value. -type Watcher struct { - // Events sends the filesystem change events. - // - // fsnotify can send the following events; a "path" here can refer to a - // file, directory, symbolic link, or special file like a FIFO. - // - // fsnotify.Create A new path was created; this may be followed by one - // or more Write events if data also gets written to a - // file. - // - // fsnotify.Remove A path was removed. - // - // fsnotify.Rename A path was renamed. A rename is always sent with the - // old path as Event.Name, and a Create event will be - // sent with the new name. Renames are only sent for - // paths that are currently watched; e.g. moving an - // unmonitored file into a monitored directory will - // show up as just a Create. Similarly, renaming a file - // to outside a monitored directory will show up as - // only a Rename. - // - // fsnotify.Write A file or named pipe was written to. A Truncate will - // also trigger a Write. A single "write action" - // initiated by the user may show up as one or multiple - // writes, depending on when the system syncs things to - // disk. For example when compiling a large Go program - // you may get hundreds of Write events, and you may - // want to wait until you've stopped receiving them - // (see the dedup example in cmd/fsnotify). - // - // Some systems may send Write event for directories - // when the directory content changes. - // - // fsnotify.Chmod Attributes were changed. On Linux this is also sent - // when a file is removed (or more accurately, when a - // link to an inode is removed). On kqueue it's sent - // when a file is truncated. On Windows it's never - // sent. +type readDirChangesW struct { Events chan Event - - // Errors sends any errors. - // - // ErrEventOverflow is used to indicate there are too many events: - // - // - inotify: There are too many queued events (fs.inotify.max_queued_events sysctl) - // - windows: The buffer size is too small; WithBufferSize() can be used to increase it. - // - kqueue, fen: Not used. Errors chan error port windows.Handle // Handle to completion port @@ -147,48 +35,40 @@ type Watcher struct { closed bool // Set to true when Close() is first called } -// NewWatcher creates a new Watcher. -func NewWatcher() (*Watcher, error) { - return NewBufferedWatcher(50) +func newBackend(ev chan Event, errs chan error) (backend, error) { + return newBufferedBackend(50, ev, errs) } -// NewBufferedWatcher creates a new Watcher with a buffered Watcher.Events -// channel. -// -// The main use case for this is situations with a very large number of events -// where the kernel buffer size can't be increased (e.g. due to lack of -// permissions). An unbuffered Watcher will perform better for almost all use -// cases, and whenever possible you will be better off increasing the kernel -// buffers instead of adding a large userspace buffer. -func NewBufferedWatcher(sz uint) (*Watcher, error) { +func newBufferedBackend(sz uint, ev chan Event, errs chan error) (backend, error) { port, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { return nil, os.NewSyscallError("CreateIoCompletionPort", err) } - w := &Watcher{ + w := &readDirChangesW{ + Events: ev, + Errors: errs, port: port, watches: make(watchMap), input: make(chan *input, 1), - Events: make(chan Event, sz), - Errors: make(chan error), quit: make(chan chan<- error, 1), } go w.readEvents() return w, nil } -func (w *Watcher) isClosed() bool { +func (w *readDirChangesW) isClosed() bool { w.mu.Lock() defer w.mu.Unlock() return w.closed } -func (w *Watcher) sendEvent(name string, mask uint64) bool { +func (w *readDirChangesW) sendEvent(name, renamedFrom string, mask uint64) bool { if mask == 0 { return false } event := w.newEvent(name, uint32(mask)) + event.renamedFrom = renamedFrom select { case ch := <-w.quit: w.quit <- ch @@ -198,17 +78,19 @@ func (w *Watcher) sendEvent(name string, mask uint64) bool { } // Returns true if the error was sent, or false if watcher is closed. -func (w *Watcher) sendError(err error) bool { +func (w *readDirChangesW) sendError(err error) bool { + if err == nil { + return true + } select { case w.Errors <- err: return true case <-w.quit: + return false } - return false } -// Close removes all watches and closes the Events channel. -func (w *Watcher) Close() error { +func (w *readDirChangesW) Close() error { if w.isClosed() { return nil } @@ -226,57 +108,21 @@ func (w *Watcher) Close() error { return <-ch } -// Add starts monitoring the path for changes. -// -// A path can only be watched once; watching it more than once is a no-op and will -// not return an error. Paths that do not yet exist on the filesystem cannot be -// watched. -// -// A watch will be automatically removed if the watched path is deleted or -// renamed. The exception is the Windows backend, which doesn't remove the -// watcher on renames. -// -// Notifications on network filesystems (NFS, SMB, FUSE, etc.) or special -// filesystems (/proc, /sys, etc.) generally don't work. -// -// Returns [ErrClosed] if [Watcher.Close] was called. -// -// See [Watcher.AddWith] for a version that allows adding options. -// -// # Watching directories -// -// All files in a directory are monitored, including new files that are created -// after the watcher is started. Subdirectories are not watched (i.e. it's -// non-recursive). -// -// # Watching files -// -// Watching individual files (rather than directories) is generally not -// recommended as many programs (especially editors) update files atomically: it -// will write to a temporary file which is then moved to to destination, -// overwriting the original (or some variant thereof). The watcher on the -// original file is now lost, as that no longer exists. -// -// The upshot of this is that a power failure or crash won't leave a -// half-written file. -// -// Watch the parent directory and use Event.Name to filter out files you're not -// interested in. There is an example of this in cmd/fsnotify/file.go. -func (w *Watcher) Add(name string) error { return w.AddWith(name) } +func (w *readDirChangesW) Add(name string) error { return w.AddWith(name) } -// AddWith is like [Watcher.Add], but allows adding options. When using Add() -// the defaults described below are used. -// -// Possible options are: -// -// - [WithBufferSize] sets the buffer size for the Windows backend; no-op on -// other platforms. The default is 64K (65536 bytes). -func (w *Watcher) AddWith(name string, opts ...addOpt) error { +func (w *readDirChangesW) AddWith(name string, opts ...addOpt) error { if w.isClosed() { return ErrClosed } + if debug { + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s AddWith(%q)\n", + time.Now().Format("15:04:05.000000000"), filepath.ToSlash(name)) + } with := getOptions(opts...) + if !w.xSupports(with.op) { + return fmt.Errorf("%w: %s", xErrUnsupported, with.op) + } if with.bufsize < 4096 { return fmt.Errorf("fsnotify.WithBufferSize: buffer size cannot be smaller than 4096 bytes") } @@ -295,18 +141,14 @@ func (w *Watcher) AddWith(name string, opts ...addOpt) error { return <-in.reply } -// Remove stops monitoring the path for changes. -// -// Directories are always removed non-recursively. For example, if you added -// /tmp/dir and /tmp/dir/subdir then you will need to remove both. -// -// Removing a path that has not yet been added returns [ErrNonExistentWatch]. -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) Remove(name string) error { +func (w *readDirChangesW) Remove(name string) error { if w.isClosed() { return nil } + if debug { + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s Remove(%q)\n", + time.Now().Format("15:04:05.000000000"), filepath.ToSlash(name)) + } in := &input{ op: opRemoveWatch, @@ -320,11 +162,7 @@ func (w *Watcher) Remove(name string) error { return <-in.reply } -// WatchList returns all paths explicitly added with [Watcher.Add] (and are not -// yet removed). -// -// Returns nil if [Watcher.Close] was called. -func (w *Watcher) WatchList() []string { +func (w *readDirChangesW) WatchList() []string { if w.isClosed() { return nil } @@ -335,7 +173,13 @@ func (w *Watcher) WatchList() []string { entries := make([]string, 0, len(w.watches)) for _, entry := range w.watches { for _, watchEntry := range entry { - entries = append(entries, watchEntry.path) + for name := range watchEntry.names { + entries = append(entries, filepath.Join(watchEntry.path, name)) + } + // the directory itself is being watched + if watchEntry.mask != 0 { + entries = append(entries, watchEntry.path) + } } } @@ -361,7 +205,7 @@ const ( sysFSIGNORED = 0x8000 ) -func (w *Watcher) newEvent(name string, mask uint32) Event { +func (w *readDirChangesW) newEvent(name string, mask uint32) Event { e := Event{Name: name} if mask&sysFSCREATE == sysFSCREATE || mask&sysFSMOVEDTO == sysFSMOVEDTO { e.Op |= Create @@ -417,7 +261,7 @@ type ( watchMap map[uint32]indexMap ) -func (w *Watcher) wakeupReader() error { +func (w *readDirChangesW) wakeupReader() error { err := windows.PostQueuedCompletionStatus(w.port, 0, 0, nil) if err != nil { return os.NewSyscallError("PostQueuedCompletionStatus", err) @@ -425,7 +269,7 @@ func (w *Watcher) wakeupReader() error { return nil } -func (w *Watcher) getDir(pathname string) (dir string, err error) { +func (w *readDirChangesW) getDir(pathname string) (dir string, err error) { attr, err := windows.GetFileAttributes(windows.StringToUTF16Ptr(pathname)) if err != nil { return "", os.NewSyscallError("GetFileAttributes", err) @@ -439,7 +283,7 @@ func (w *Watcher) getDir(pathname string) (dir string, err error) { return } -func (w *Watcher) getIno(path string) (ino *inode, err error) { +func (w *readDirChangesW) getIno(path string) (ino *inode, err error) { h, err := windows.CreateFile(windows.StringToUTF16Ptr(path), windows.FILE_LIST_DIRECTORY, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, @@ -482,9 +326,8 @@ func (m watchMap) set(ino *inode, watch *watch) { } // Must run within the I/O thread. -func (w *Watcher) addWatch(pathname string, flags uint64, bufsize int) error { - //pathname, recurse := recursivePath(pathname) - recurse := false +func (w *readDirChangesW) addWatch(pathname string, flags uint64, bufsize int) error { + pathname, recurse := recursivePath(pathname) dir, err := w.getDir(pathname) if err != nil { @@ -538,7 +381,7 @@ func (w *Watcher) addWatch(pathname string, flags uint64, bufsize int) error { } // Must run within the I/O thread. -func (w *Watcher) remWatch(pathname string) error { +func (w *readDirChangesW) remWatch(pathname string) error { pathname, recurse := recursivePath(pathname) dir, err := w.getDir(pathname) @@ -566,11 +409,11 @@ func (w *Watcher) remWatch(pathname string) error { return fmt.Errorf("%w: %s", ErrNonExistentWatch, pathname) } if pathname == dir { - w.sendEvent(watch.path, watch.mask&sysFSIGNORED) + w.sendEvent(watch.path, "", watch.mask&sysFSIGNORED) watch.mask = 0 } else { name := filepath.Base(pathname) - w.sendEvent(filepath.Join(watch.path, name), watch.names[name]&sysFSIGNORED) + w.sendEvent(filepath.Join(watch.path, name), "", watch.names[name]&sysFSIGNORED) delete(watch.names, name) } @@ -578,23 +421,23 @@ func (w *Watcher) remWatch(pathname string) error { } // Must run within the I/O thread. -func (w *Watcher) deleteWatch(watch *watch) { +func (w *readDirChangesW) deleteWatch(watch *watch) { for name, mask := range watch.names { if mask&provisional == 0 { - w.sendEvent(filepath.Join(watch.path, name), mask&sysFSIGNORED) + w.sendEvent(filepath.Join(watch.path, name), "", mask&sysFSIGNORED) } delete(watch.names, name) } if watch.mask != 0 { if watch.mask&provisional == 0 { - w.sendEvent(watch.path, watch.mask&sysFSIGNORED) + w.sendEvent(watch.path, "", watch.mask&sysFSIGNORED) } watch.mask = 0 } } // Must run within the I/O thread. -func (w *Watcher) startRead(watch *watch) error { +func (w *readDirChangesW) startRead(watch *watch) error { err := windows.CancelIo(watch.ino.handle) if err != nil { w.sendError(os.NewSyscallError("CancelIo", err)) @@ -624,7 +467,7 @@ func (w *Watcher) startRead(watch *watch) error { err := os.NewSyscallError("ReadDirectoryChanges", rdErr) if rdErr == windows.ERROR_ACCESS_DENIED && watch.mask&provisional == 0 { // Watched directory was probably removed - w.sendEvent(watch.path, watch.mask&sysFSDELETESELF) + w.sendEvent(watch.path, "", watch.mask&sysFSDELETESELF) err = nil } w.deleteWatch(watch) @@ -637,7 +480,7 @@ func (w *Watcher) startRead(watch *watch) error { // readEvents reads from the I/O completion port, converts the // received events into Event objects and sends them via the Events channel. // Entry point to the I/O thread. -func (w *Watcher) readEvents() { +func (w *readDirChangesW) readEvents() { var ( n uint32 key uintptr @@ -700,7 +543,7 @@ func (w *Watcher) readEvents() { } case windows.ERROR_ACCESS_DENIED: // Watched directory was probably removed - w.sendEvent(watch.path, watch.mask&sysFSDELETESELF) + w.sendEvent(watch.path, "", watch.mask&sysFSDELETESELF) w.deleteWatch(watch) w.startRead(watch) continue @@ -733,6 +576,10 @@ func (w *Watcher) readEvents() { name := windows.UTF16ToString(buf) fullname := filepath.Join(watch.path, name) + if debug { + internal.Debug(fullname, raw.Action) + } + var mask uint64 switch raw.Action { case windows.FILE_ACTION_REMOVED: @@ -761,21 +608,22 @@ func (w *Watcher) readEvents() { } } - sendNameEvent := func() { - w.sendEvent(fullname, watch.names[name]&mask) - } if raw.Action != windows.FILE_ACTION_RENAMED_NEW_NAME { - sendNameEvent() + w.sendEvent(fullname, "", watch.names[name]&mask) } if raw.Action == windows.FILE_ACTION_REMOVED { - w.sendEvent(fullname, watch.names[name]&sysFSIGNORED) + w.sendEvent(fullname, "", watch.names[name]&sysFSIGNORED) delete(watch.names, name) } - w.sendEvent(fullname, watch.mask&w.toFSnotifyFlags(raw.Action)) + if watch.rename != "" && raw.Action == windows.FILE_ACTION_RENAMED_NEW_NAME { + w.sendEvent(fullname, filepath.Join(watch.path, watch.rename), watch.mask&w.toFSnotifyFlags(raw.Action)) + } else { + w.sendEvent(fullname, "", watch.mask&w.toFSnotifyFlags(raw.Action)) + } + if raw.Action == windows.FILE_ACTION_RENAMED_NEW_NAME { - fullname = filepath.Join(watch.path, watch.rename) - sendNameEvent() + w.sendEvent(filepath.Join(watch.path, watch.rename), "", watch.names[name]&mask) } // Move to the next event in the buffer @@ -787,8 +635,7 @@ func (w *Watcher) readEvents() { // Error! if offset >= n { //lint:ignore ST1005 Windows should be capitalized - w.sendError(errors.New( - "Windows system assumed buffer larger than it is, events have likely been missed")) + w.sendError(errors.New("Windows system assumed buffer larger than it is, events have likely been missed")) break } } @@ -799,7 +646,7 @@ func (w *Watcher) readEvents() { } } -func (w *Watcher) toWindowsFlags(mask uint64) uint32 { +func (w *readDirChangesW) toWindowsFlags(mask uint64) uint32 { var m uint32 if mask&sysFSMODIFY != 0 { m |= windows.FILE_NOTIFY_CHANGE_LAST_WRITE @@ -810,7 +657,7 @@ func (w *Watcher) toWindowsFlags(mask uint64) uint32 { return m } -func (w *Watcher) toFSnotifyFlags(action uint32) uint64 { +func (w *readDirChangesW) toFSnotifyFlags(action uint32) uint64 { switch action { case windows.FILE_ACTION_ADDED: return sysFSCREATE @@ -825,3 +672,11 @@ func (w *Watcher) toFSnotifyFlags(action uint32) uint64 { } return 0 } + +func (w *readDirChangesW) xSupports(op Op) bool { + if op.Has(xUnportableOpen) || op.Has(xUnportableRead) || + op.Has(xUnportableCloseWrite) || op.Has(xUnportableCloseRead) { + return false + } + return true +} diff --git a/vendor/github.com/fsnotify/fsnotify/fsnotify.go b/vendor/github.com/fsnotify/fsnotify/fsnotify.go index 24c99cc49..0760efe91 100644 --- a/vendor/github.com/fsnotify/fsnotify/fsnotify.go +++ b/vendor/github.com/fsnotify/fsnotify/fsnotify.go @@ -3,19 +3,146 @@ // // Currently supported systems: // -// Linux 2.6.32+ via inotify -// BSD, macOS via kqueue -// Windows via ReadDirectoryChangesW -// illumos via FEN +// - Linux via inotify +// - BSD, macOS via kqueue +// - Windows via ReadDirectoryChangesW +// - illumos via FEN +// +// # FSNOTIFY_DEBUG +// +// Set the FSNOTIFY_DEBUG environment variable to "1" to print debug messages to +// stderr. This can be useful to track down some problems, especially in cases +// where fsnotify is used as an indirect dependency. +// +// Every event will be printed as soon as there's something useful to print, +// with as little processing from fsnotify. +// +// Example output: +// +// FSNOTIFY_DEBUG: 11:34:23.633087586 256:IN_CREATE → "/tmp/file-1" +// FSNOTIFY_DEBUG: 11:34:23.633202319 4:IN_ATTRIB → "/tmp/file-1" +// FSNOTIFY_DEBUG: 11:34:28.989728764 512:IN_DELETE → "/tmp/file-1" package fsnotify import ( "errors" "fmt" + "os" "path/filepath" "strings" ) +// Watcher watches a set of paths, delivering events on a channel. +// +// A watcher should not be copied (e.g. pass it by pointer, rather than by +// value). +// +// # Linux notes +// +// When a file is removed a Remove event won't be emitted until all file +// descriptors are closed, and deletes will always emit a Chmod. For example: +// +// fp := os.Open("file") +// os.Remove("file") // Triggers Chmod +// fp.Close() // Triggers Remove +// +// This is the event that inotify sends, so not much can be changed about this. +// +// The fs.inotify.max_user_watches sysctl variable specifies the upper limit +// for the number of watches per user, and fs.inotify.max_user_instances +// specifies the maximum number of inotify instances per user. Every Watcher you +// create is an "instance", and every path you add is a "watch". +// +// These are also exposed in /proc as /proc/sys/fs/inotify/max_user_watches and +// /proc/sys/fs/inotify/max_user_instances +// +// To increase them you can use sysctl or write the value to the /proc file: +// +// # Default values on Linux 5.18 +// sysctl fs.inotify.max_user_watches=124983 +// sysctl fs.inotify.max_user_instances=128 +// +// To make the changes persist on reboot edit /etc/sysctl.conf or +// /usr/lib/sysctl.d/50-default.conf (details differ per Linux distro; check +// your distro's documentation): +// +// fs.inotify.max_user_watches=124983 +// fs.inotify.max_user_instances=128 +// +// Reaching the limit will result in a "no space left on device" or "too many open +// files" error. +// +// # kqueue notes (macOS, BSD) +// +// kqueue requires opening a file descriptor for every file that's being watched; +// so if you're watching a directory with five files then that's six file +// descriptors. You will run in to your system's "max open files" limit faster on +// these platforms. +// +// The sysctl variables kern.maxfiles and kern.maxfilesperproc can be used to +// control the maximum number of open files, as well as /etc/login.conf on BSD +// systems. +// +// # Windows notes +// +// Paths can be added as "C:\\path\\to\\dir", but forward slashes +// ("C:/path/to/dir") will also work. +// +// When a watched directory is removed it will always send an event for the +// directory itself, but may not send events for all files in that directory. +// Sometimes it will send events for all files, sometimes it will send no +// events, and often only for some files. +// +// The default ReadDirectoryChangesW() buffer size is 64K, which is the largest +// value that is guaranteed to work with SMB filesystems. If you have many +// events in quick succession this may not be enough, and you will have to use +// [WithBufferSize] to increase the value. +type Watcher struct { + b backend + + // Events sends the filesystem change events. + // + // fsnotify can send the following events; a "path" here can refer to a + // file, directory, symbolic link, or special file like a FIFO. + // + // fsnotify.Create A new path was created; this may be followed by one + // or more Write events if data also gets written to a + // file. + // + // fsnotify.Remove A path was removed. + // + // fsnotify.Rename A path was renamed. A rename is always sent with the + // old path as Event.Name, and a Create event will be + // sent with the new name. Renames are only sent for + // paths that are currently watched; e.g. moving an + // unmonitored file into a monitored directory will + // show up as just a Create. Similarly, renaming a file + // to outside a monitored directory will show up as + // only a Rename. + // + // fsnotify.Write A file or named pipe was written to. A Truncate will + // also trigger a Write. A single "write action" + // initiated by the user may show up as one or multiple + // writes, depending on when the system syncs things to + // disk. For example when compiling a large Go program + // you may get hundreds of Write events, and you may + // want to wait until you've stopped receiving them + // (see the dedup example in cmd/fsnotify). + // + // Some systems may send Write event for directories + // when the directory content changes. + // + // fsnotify.Chmod Attributes were changed. On Linux this is also sent + // when a file is removed (or more accurately, when a + // link to an inode is removed). On kqueue it's sent + // when a file is truncated. On Windows it's never + // sent. + Events chan Event + + // Errors sends any errors. + Errors chan error +} + // Event represents a file system notification. type Event struct { // Path to the file or directory. @@ -30,6 +157,16 @@ type Event struct { // This is a bitmask and some systems may send multiple operations at once. // Use the Event.Has() method instead of comparing with ==. Op Op + + // Create events will have this set to the old path if it's a rename. This + // only works when both the source and destination are watched. It's not + // reliable when watching individual files, only directories. + // + // For example "mv /tmp/file /tmp/rename" will emit: + // + // Event{Op: Rename, Name: "/tmp/file"} + // Event{Op: Create, Name: "/tmp/rename", RenamedFrom: "/tmp/file"} + renamedFrom string } // Op describes a set of file operations. @@ -50,7 +187,7 @@ const ( // example "remove to trash" is often a rename). Remove - // The path was renamed to something else; any watched on it will be + // The path was renamed to something else; any watches on it will be // removed. Rename @@ -60,15 +197,155 @@ const ( // get triggered very frequently by some software. For example, Spotlight // indexing on macOS, anti-virus software, backup software, etc. Chmod + + // File descriptor was opened. + // + // Only works on Linux and FreeBSD. + xUnportableOpen + + // File was read from. + // + // Only works on Linux and FreeBSD. + xUnportableRead + + // File opened for writing was closed. + // + // Only works on Linux and FreeBSD. + // + // The advantage of using this over Write is that it's more reliable than + // waiting for Write events to stop. It's also faster (if you're not + // listening to Write events): copying a file of a few GB can easily + // generate tens of thousands of Write events in a short span of time. + xUnportableCloseWrite + + // File opened for reading was closed. + // + // Only works on Linux and FreeBSD. + xUnportableCloseRead ) -// Common errors that can be reported. var ( + // ErrNonExistentWatch is used when Remove() is called on a path that's not + // added. ErrNonExistentWatch = errors.New("fsnotify: can't remove non-existent watch") - ErrEventOverflow = errors.New("fsnotify: queue or buffer overflow") - ErrClosed = errors.New("fsnotify: watcher already closed") + + // ErrClosed is used when trying to operate on a closed Watcher. + ErrClosed = errors.New("fsnotify: watcher already closed") + + // ErrEventOverflow is reported from the Errors channel when there are too + // many events: + // + // - inotify: inotify returns IN_Q_OVERFLOW – because there are too + // many queued events (the fs.inotify.max_queued_events + // sysctl can be used to increase this). + // - windows: The buffer size is too small; WithBufferSize() can be used to increase it. + // - kqueue, fen: Not used. + ErrEventOverflow = errors.New("fsnotify: queue or buffer overflow") + + // ErrUnsupported is returned by AddWith() when WithOps() specified an + // Unportable event that's not supported on this platform. + xErrUnsupported = errors.New("fsnotify: not supported with this backend") ) +// NewWatcher creates a new Watcher. +func NewWatcher() (*Watcher, error) { + ev, errs := make(chan Event), make(chan error) + b, err := newBackend(ev, errs) + if err != nil { + return nil, err + } + return &Watcher{b: b, Events: ev, Errors: errs}, nil +} + +// NewBufferedWatcher creates a new Watcher with a buffered Watcher.Events +// channel. +// +// The main use case for this is situations with a very large number of events +// where the kernel buffer size can't be increased (e.g. due to lack of +// permissions). An unbuffered Watcher will perform better for almost all use +// cases, and whenever possible you will be better off increasing the kernel +// buffers instead of adding a large userspace buffer. +func NewBufferedWatcher(sz uint) (*Watcher, error) { + ev, errs := make(chan Event), make(chan error) + b, err := newBufferedBackend(sz, ev, errs) + if err != nil { + return nil, err + } + return &Watcher{b: b, Events: ev, Errors: errs}, nil +} + +// Add starts monitoring the path for changes. +// +// A path can only be watched once; watching it more than once is a no-op and will +// not return an error. Paths that do not yet exist on the filesystem cannot be +// watched. +// +// A watch will be automatically removed if the watched path is deleted or +// renamed. The exception is the Windows backend, which doesn't remove the +// watcher on renames. +// +// Notifications on network filesystems (NFS, SMB, FUSE, etc.) or special +// filesystems (/proc, /sys, etc.) generally don't work. +// +// Returns [ErrClosed] if [Watcher.Close] was called. +// +// See [Watcher.AddWith] for a version that allows adding options. +// +// # Watching directories +// +// All files in a directory are monitored, including new files that are created +// after the watcher is started. Subdirectories are not watched (i.e. it's +// non-recursive). +// +// # Watching files +// +// Watching individual files (rather than directories) is generally not +// recommended as many programs (especially editors) update files atomically: it +// will write to a temporary file which is then moved to destination, +// overwriting the original (or some variant thereof). The watcher on the +// original file is now lost, as that no longer exists. +// +// The upshot of this is that a power failure or crash won't leave a +// half-written file. +// +// Watch the parent directory and use Event.Name to filter out files you're not +// interested in. There is an example of this in cmd/fsnotify/file.go. +func (w *Watcher) Add(path string) error { return w.b.Add(path) } + +// AddWith is like [Watcher.Add], but allows adding options. When using Add() +// the defaults described below are used. +// +// Possible options are: +// +// - [WithBufferSize] sets the buffer size for the Windows backend; no-op on +// other platforms. The default is 64K (65536 bytes). +func (w *Watcher) AddWith(path string, opts ...addOpt) error { return w.b.AddWith(path, opts...) } + +// Remove stops monitoring the path for changes. +// +// Directories are always removed non-recursively. For example, if you added +// /tmp/dir and /tmp/dir/subdir then you will need to remove both. +// +// Removing a path that has not yet been added returns [ErrNonExistentWatch]. +// +// Returns nil if [Watcher.Close] was called. +func (w *Watcher) Remove(path string) error { return w.b.Remove(path) } + +// Close removes all watches and closes the Events channel. +func (w *Watcher) Close() error { return w.b.Close() } + +// WatchList returns all paths explicitly added with [Watcher.Add] (and are not +// yet removed). +// +// Returns nil if [Watcher.Close] was called. +func (w *Watcher) WatchList() []string { return w.b.WatchList() } + +// Supports reports if all the listed operations are supported by this platform. +// +// Create, Write, Remove, Rename, and Chmod are always supported. It can only +// return false for an Op starting with Unportable. +func (w *Watcher) xSupports(op Op) bool { return w.b.xSupports(op) } + func (o Op) String() string { var b strings.Builder if o.Has(Create) { @@ -80,6 +357,18 @@ func (o Op) String() string { if o.Has(Write) { b.WriteString("|WRITE") } + if o.Has(xUnportableOpen) { + b.WriteString("|OPEN") + } + if o.Has(xUnportableRead) { + b.WriteString("|READ") + } + if o.Has(xUnportableCloseWrite) { + b.WriteString("|CLOSE_WRITE") + } + if o.Has(xUnportableCloseRead) { + b.WriteString("|CLOSE_READ") + } if o.Has(Rename) { b.WriteString("|RENAME") } @@ -100,24 +389,48 @@ func (e Event) Has(op Op) bool { return e.Op.Has(op) } // String returns a string representation of the event with their path. func (e Event) String() string { + if e.renamedFrom != "" { + return fmt.Sprintf("%-13s %q ← %q", e.Op.String(), e.Name, e.renamedFrom) + } return fmt.Sprintf("%-13s %q", e.Op.String(), e.Name) } type ( + backend interface { + Add(string) error + AddWith(string, ...addOpt) error + Remove(string) error + WatchList() []string + Close() error + xSupports(Op) bool + } addOpt func(opt *withOpts) withOpts struct { - bufsize int + bufsize int + op Op + noFollow bool + sendCreate bool } ) +var debug = func() bool { + // Check for exactly "1" (rather than mere existence) so we can add + // options/flags in the future. I don't know if we ever want that, but it's + // nice to leave the option open. + return os.Getenv("FSNOTIFY_DEBUG") == "1" +}() + var defaultOpts = withOpts{ bufsize: 65536, // 64K + op: Create | Write | Remove | Rename | Chmod, } func getOptions(opts ...addOpt) withOpts { with := defaultOpts for _, o := range opts { - o(&with) + if o != nil { + o(&with) + } } return with } @@ -136,9 +449,44 @@ func WithBufferSize(bytes int) addOpt { return func(opt *withOpts) { opt.bufsize = bytes } } +// WithOps sets which operations to listen for. The default is [Create], +// [Write], [Remove], [Rename], and [Chmod]. +// +// Excluding operations you're not interested in can save quite a bit of CPU +// time; in some use cases there may be hundreds of thousands of useless Write +// or Chmod operations per second. +// +// This can also be used to add unportable operations not supported by all +// platforms; unportable operations all start with "Unportable": +// [UnportableOpen], [UnportableRead], [UnportableCloseWrite], and +// [UnportableCloseRead]. +// +// AddWith returns an error when using an unportable operation that's not +// supported. Use [Watcher.Support] to check for support. +func withOps(op Op) addOpt { + return func(opt *withOpts) { opt.op = op } +} + +// WithNoFollow disables following symlinks, so the symlinks themselves are +// watched. +func withNoFollow() addOpt { + return func(opt *withOpts) { opt.noFollow = true } +} + +// "Internal" option for recursive watches on inotify. +func withCreate() addOpt { + return func(opt *withOpts) { opt.sendCreate = true } +} + +var enableRecurse = false + // Check if this path is recursive (ends with "/..." or "\..."), and return the // path with the /... stripped. func recursivePath(path string) (string, bool) { + path = filepath.Clean(path) + if !enableRecurse { // Only enabled in tests for now. + return path, false + } if filepath.Base(path) == "..." { return filepath.Dir(path), true } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/darwin.go b/vendor/github.com/fsnotify/fsnotify/internal/darwin.go new file mode 100644 index 000000000..b0eab1009 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/darwin.go @@ -0,0 +1,39 @@ +//go:build darwin + +package internal + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +var ( + SyscallEACCES = syscall.EACCES + UnixEACCES = unix.EACCES +) + +var maxfiles uint64 + +// Go 1.19 will do this automatically: https://go-review.googlesource.com/c/go/+/393354/ +func SetRlimit() { + var l syscall.Rlimit + err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &l) + if err == nil && l.Cur != l.Max { + l.Cur = l.Max + syscall.Setrlimit(syscall.RLIMIT_NOFILE, &l) + } + maxfiles = l.Cur + + if n, err := syscall.SysctlUint32("kern.maxfiles"); err == nil && uint64(n) < maxfiles { + maxfiles = uint64(n) + } + + if n, err := syscall.SysctlUint32("kern.maxfilesperproc"); err == nil && uint64(n) < maxfiles { + maxfiles = uint64(n) + } +} + +func Maxfiles() uint64 { return maxfiles } +func Mkfifo(path string, mode uint32) error { return unix.Mkfifo(path, mode) } +func Mknod(path string, mode uint32, dev int) error { return unix.Mknod(path, mode, dev) } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_darwin.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_darwin.go new file mode 100644 index 000000000..928319fb0 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_darwin.go @@ -0,0 +1,57 @@ +package internal + +import "golang.org/x/sys/unix" + +var names = []struct { + n string + m uint32 +}{ + {"NOTE_ABSOLUTE", unix.NOTE_ABSOLUTE}, + {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, + {"NOTE_BACKGROUND", unix.NOTE_BACKGROUND}, + {"NOTE_CHILD", unix.NOTE_CHILD}, + {"NOTE_CRITICAL", unix.NOTE_CRITICAL}, + {"NOTE_DELETE", unix.NOTE_DELETE}, + {"NOTE_EXEC", unix.NOTE_EXEC}, + {"NOTE_EXIT", unix.NOTE_EXIT}, + {"NOTE_EXITSTATUS", unix.NOTE_EXITSTATUS}, + {"NOTE_EXIT_CSERROR", unix.NOTE_EXIT_CSERROR}, + {"NOTE_EXIT_DECRYPTFAIL", unix.NOTE_EXIT_DECRYPTFAIL}, + {"NOTE_EXIT_DETAIL", unix.NOTE_EXIT_DETAIL}, + {"NOTE_EXIT_DETAIL_MASK", unix.NOTE_EXIT_DETAIL_MASK}, + {"NOTE_EXIT_MEMORY", unix.NOTE_EXIT_MEMORY}, + {"NOTE_EXIT_REPARENTED", unix.NOTE_EXIT_REPARENTED}, + {"NOTE_EXTEND", unix.NOTE_EXTEND}, + {"NOTE_FFAND", unix.NOTE_FFAND}, + {"NOTE_FFCOPY", unix.NOTE_FFCOPY}, + {"NOTE_FFCTRLMASK", unix.NOTE_FFCTRLMASK}, + {"NOTE_FFLAGSMASK", unix.NOTE_FFLAGSMASK}, + {"NOTE_FFNOP", unix.NOTE_FFNOP}, + {"NOTE_FFOR", unix.NOTE_FFOR}, + {"NOTE_FORK", unix.NOTE_FORK}, + {"NOTE_FUNLOCK", unix.NOTE_FUNLOCK}, + {"NOTE_LEEWAY", unix.NOTE_LEEWAY}, + {"NOTE_LINK", unix.NOTE_LINK}, + {"NOTE_LOWAT", unix.NOTE_LOWAT}, + {"NOTE_MACHTIME", unix.NOTE_MACHTIME}, + {"NOTE_MACH_CONTINUOUS_TIME", unix.NOTE_MACH_CONTINUOUS_TIME}, + {"NOTE_NONE", unix.NOTE_NONE}, + {"NOTE_NSECONDS", unix.NOTE_NSECONDS}, + {"NOTE_OOB", unix.NOTE_OOB}, + //{"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, -0x100000 (?!) + {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, + {"NOTE_REAP", unix.NOTE_REAP}, + {"NOTE_RENAME", unix.NOTE_RENAME}, + {"NOTE_REVOKE", unix.NOTE_REVOKE}, + {"NOTE_SECONDS", unix.NOTE_SECONDS}, + {"NOTE_SIGNAL", unix.NOTE_SIGNAL}, + {"NOTE_TRACK", unix.NOTE_TRACK}, + {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, + {"NOTE_TRIGGER", unix.NOTE_TRIGGER}, + {"NOTE_USECONDS", unix.NOTE_USECONDS}, + {"NOTE_VM_ERROR", unix.NOTE_VM_ERROR}, + {"NOTE_VM_PRESSURE", unix.NOTE_VM_PRESSURE}, + {"NOTE_VM_PRESSURE_SUDDEN_TERMINATE", unix.NOTE_VM_PRESSURE_SUDDEN_TERMINATE}, + {"NOTE_VM_PRESSURE_TERMINATE", unix.NOTE_VM_PRESSURE_TERMINATE}, + {"NOTE_WRITE", unix.NOTE_WRITE}, +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_dragonfly.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_dragonfly.go new file mode 100644 index 000000000..3186b0c34 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_dragonfly.go @@ -0,0 +1,33 @@ +package internal + +import "golang.org/x/sys/unix" + +var names = []struct { + n string + m uint32 +}{ + {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, + {"NOTE_CHILD", unix.NOTE_CHILD}, + {"NOTE_DELETE", unix.NOTE_DELETE}, + {"NOTE_EXEC", unix.NOTE_EXEC}, + {"NOTE_EXIT", unix.NOTE_EXIT}, + {"NOTE_EXTEND", unix.NOTE_EXTEND}, + {"NOTE_FFAND", unix.NOTE_FFAND}, + {"NOTE_FFCOPY", unix.NOTE_FFCOPY}, + {"NOTE_FFCTRLMASK", unix.NOTE_FFCTRLMASK}, + {"NOTE_FFLAGSMASK", unix.NOTE_FFLAGSMASK}, + {"NOTE_FFNOP", unix.NOTE_FFNOP}, + {"NOTE_FFOR", unix.NOTE_FFOR}, + {"NOTE_FORK", unix.NOTE_FORK}, + {"NOTE_LINK", unix.NOTE_LINK}, + {"NOTE_LOWAT", unix.NOTE_LOWAT}, + {"NOTE_OOB", unix.NOTE_OOB}, + {"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, + {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, + {"NOTE_RENAME", unix.NOTE_RENAME}, + {"NOTE_REVOKE", unix.NOTE_REVOKE}, + {"NOTE_TRACK", unix.NOTE_TRACK}, + {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, + {"NOTE_TRIGGER", unix.NOTE_TRIGGER}, + {"NOTE_WRITE", unix.NOTE_WRITE}, +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_freebsd.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_freebsd.go new file mode 100644 index 000000000..f69fdb930 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_freebsd.go @@ -0,0 +1,42 @@ +package internal + +import "golang.org/x/sys/unix" + +var names = []struct { + n string + m uint32 +}{ + {"NOTE_ABSTIME", unix.NOTE_ABSTIME}, + {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, + {"NOTE_CHILD", unix.NOTE_CHILD}, + {"NOTE_CLOSE", unix.NOTE_CLOSE}, + {"NOTE_CLOSE_WRITE", unix.NOTE_CLOSE_WRITE}, + {"NOTE_DELETE", unix.NOTE_DELETE}, + {"NOTE_EXEC", unix.NOTE_EXEC}, + {"NOTE_EXIT", unix.NOTE_EXIT}, + {"NOTE_EXTEND", unix.NOTE_EXTEND}, + {"NOTE_FFAND", unix.NOTE_FFAND}, + {"NOTE_FFCOPY", unix.NOTE_FFCOPY}, + {"NOTE_FFCTRLMASK", unix.NOTE_FFCTRLMASK}, + {"NOTE_FFLAGSMASK", unix.NOTE_FFLAGSMASK}, + {"NOTE_FFNOP", unix.NOTE_FFNOP}, + {"NOTE_FFOR", unix.NOTE_FFOR}, + {"NOTE_FILE_POLL", unix.NOTE_FILE_POLL}, + {"NOTE_FORK", unix.NOTE_FORK}, + {"NOTE_LINK", unix.NOTE_LINK}, + {"NOTE_LOWAT", unix.NOTE_LOWAT}, + {"NOTE_MSECONDS", unix.NOTE_MSECONDS}, + {"NOTE_NSECONDS", unix.NOTE_NSECONDS}, + {"NOTE_OPEN", unix.NOTE_OPEN}, + {"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, + {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, + {"NOTE_READ", unix.NOTE_READ}, + {"NOTE_RENAME", unix.NOTE_RENAME}, + {"NOTE_REVOKE", unix.NOTE_REVOKE}, + {"NOTE_SECONDS", unix.NOTE_SECONDS}, + {"NOTE_TRACK", unix.NOTE_TRACK}, + {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, + {"NOTE_TRIGGER", unix.NOTE_TRIGGER}, + {"NOTE_USECONDS", unix.NOTE_USECONDS}, + {"NOTE_WRITE", unix.NOTE_WRITE}, +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_kqueue.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_kqueue.go new file mode 100644 index 000000000..607e683bd --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_kqueue.go @@ -0,0 +1,32 @@ +//go:build freebsd || openbsd || netbsd || dragonfly || darwin + +package internal + +import ( + "fmt" + "os" + "strings" + "time" + + "golang.org/x/sys/unix" +) + +func Debug(name string, kevent *unix.Kevent_t) { + mask := uint32(kevent.Fflags) + + var ( + l []string + unknown = mask + ) + for _, n := range names { + if mask&n.m == n.m { + l = append(l, n.n) + unknown ^= n.m + } + } + if unknown > 0 { + l = append(l, fmt.Sprintf("0x%x", unknown)) + } + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s %10d:%-60s → %q\n", + time.Now().Format("15:04:05.000000000"), mask, strings.Join(l, " | "), name) +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_linux.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_linux.go new file mode 100644 index 000000000..35c734be4 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_linux.go @@ -0,0 +1,56 @@ +package internal + +import ( + "fmt" + "os" + "strings" + "time" + + "golang.org/x/sys/unix" +) + +func Debug(name string, mask, cookie uint32) { + names := []struct { + n string + m uint32 + }{ + {"IN_ACCESS", unix.IN_ACCESS}, + {"IN_ATTRIB", unix.IN_ATTRIB}, + {"IN_CLOSE", unix.IN_CLOSE}, + {"IN_CLOSE_NOWRITE", unix.IN_CLOSE_NOWRITE}, + {"IN_CLOSE_WRITE", unix.IN_CLOSE_WRITE}, + {"IN_CREATE", unix.IN_CREATE}, + {"IN_DELETE", unix.IN_DELETE}, + {"IN_DELETE_SELF", unix.IN_DELETE_SELF}, + {"IN_IGNORED", unix.IN_IGNORED}, + {"IN_ISDIR", unix.IN_ISDIR}, + {"IN_MODIFY", unix.IN_MODIFY}, + {"IN_MOVE", unix.IN_MOVE}, + {"IN_MOVED_FROM", unix.IN_MOVED_FROM}, + {"IN_MOVED_TO", unix.IN_MOVED_TO}, + {"IN_MOVE_SELF", unix.IN_MOVE_SELF}, + {"IN_OPEN", unix.IN_OPEN}, + {"IN_Q_OVERFLOW", unix.IN_Q_OVERFLOW}, + {"IN_UNMOUNT", unix.IN_UNMOUNT}, + } + + var ( + l []string + unknown = mask + ) + for _, n := range names { + if mask&n.m == n.m { + l = append(l, n.n) + unknown ^= n.m + } + } + if unknown > 0 { + l = append(l, fmt.Sprintf("0x%x", unknown)) + } + var c string + if cookie > 0 { + c = fmt.Sprintf("(cookie: %d) ", cookie) + } + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s %-30s → %s%q\n", + time.Now().Format("15:04:05.000000000"), strings.Join(l, "|"), c, name) +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_netbsd.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_netbsd.go new file mode 100644 index 000000000..e5b3b6f69 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_netbsd.go @@ -0,0 +1,25 @@ +package internal + +import "golang.org/x/sys/unix" + +var names = []struct { + n string + m uint32 +}{ + {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, + {"NOTE_CHILD", unix.NOTE_CHILD}, + {"NOTE_DELETE", unix.NOTE_DELETE}, + {"NOTE_EXEC", unix.NOTE_EXEC}, + {"NOTE_EXIT", unix.NOTE_EXIT}, + {"NOTE_EXTEND", unix.NOTE_EXTEND}, + {"NOTE_FORK", unix.NOTE_FORK}, + {"NOTE_LINK", unix.NOTE_LINK}, + {"NOTE_LOWAT", unix.NOTE_LOWAT}, + {"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, + {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, + {"NOTE_RENAME", unix.NOTE_RENAME}, + {"NOTE_REVOKE", unix.NOTE_REVOKE}, + {"NOTE_TRACK", unix.NOTE_TRACK}, + {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, + {"NOTE_WRITE", unix.NOTE_WRITE}, +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_openbsd.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_openbsd.go new file mode 100644 index 000000000..1dd455bc5 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_openbsd.go @@ -0,0 +1,28 @@ +package internal + +import "golang.org/x/sys/unix" + +var names = []struct { + n string + m uint32 +}{ + {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, + // {"NOTE_CHANGE", unix.NOTE_CHANGE}, // Not on 386? + {"NOTE_CHILD", unix.NOTE_CHILD}, + {"NOTE_DELETE", unix.NOTE_DELETE}, + {"NOTE_EOF", unix.NOTE_EOF}, + {"NOTE_EXEC", unix.NOTE_EXEC}, + {"NOTE_EXIT", unix.NOTE_EXIT}, + {"NOTE_EXTEND", unix.NOTE_EXTEND}, + {"NOTE_FORK", unix.NOTE_FORK}, + {"NOTE_LINK", unix.NOTE_LINK}, + {"NOTE_LOWAT", unix.NOTE_LOWAT}, + {"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, + {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, + {"NOTE_RENAME", unix.NOTE_RENAME}, + {"NOTE_REVOKE", unix.NOTE_REVOKE}, + {"NOTE_TRACK", unix.NOTE_TRACK}, + {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, + {"NOTE_TRUNCATE", unix.NOTE_TRUNCATE}, + {"NOTE_WRITE", unix.NOTE_WRITE}, +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_solaris.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_solaris.go new file mode 100644 index 000000000..f1b2e73bd --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_solaris.go @@ -0,0 +1,45 @@ +package internal + +import ( + "fmt" + "os" + "strings" + "time" + + "golang.org/x/sys/unix" +) + +func Debug(name string, mask int32) { + names := []struct { + n string + m int32 + }{ + {"FILE_ACCESS", unix.FILE_ACCESS}, + {"FILE_MODIFIED", unix.FILE_MODIFIED}, + {"FILE_ATTRIB", unix.FILE_ATTRIB}, + {"FILE_TRUNC", unix.FILE_TRUNC}, + {"FILE_NOFOLLOW", unix.FILE_NOFOLLOW}, + {"FILE_DELETE", unix.FILE_DELETE}, + {"FILE_RENAME_TO", unix.FILE_RENAME_TO}, + {"FILE_RENAME_FROM", unix.FILE_RENAME_FROM}, + {"UNMOUNTED", unix.UNMOUNTED}, + {"MOUNTEDOVER", unix.MOUNTEDOVER}, + {"FILE_EXCEPTION", unix.FILE_EXCEPTION}, + } + + var ( + l []string + unknown = mask + ) + for _, n := range names { + if mask&n.m == n.m { + l = append(l, n.n) + unknown ^= n.m + } + } + if unknown > 0 { + l = append(l, fmt.Sprintf("0x%x", unknown)) + } + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s %10d:%-30s → %q\n", + time.Now().Format("15:04:05.000000000"), mask, strings.Join(l, " | "), name) +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_windows.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_windows.go new file mode 100644 index 000000000..52bf4ce53 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_windows.go @@ -0,0 +1,40 @@ +package internal + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "golang.org/x/sys/windows" +) + +func Debug(name string, mask uint32) { + names := []struct { + n string + m uint32 + }{ + {"FILE_ACTION_ADDED", windows.FILE_ACTION_ADDED}, + {"FILE_ACTION_REMOVED", windows.FILE_ACTION_REMOVED}, + {"FILE_ACTION_MODIFIED", windows.FILE_ACTION_MODIFIED}, + {"FILE_ACTION_RENAMED_OLD_NAME", windows.FILE_ACTION_RENAMED_OLD_NAME}, + {"FILE_ACTION_RENAMED_NEW_NAME", windows.FILE_ACTION_RENAMED_NEW_NAME}, + } + + var ( + l []string + unknown = mask + ) + for _, n := range names { + if mask&n.m == n.m { + l = append(l, n.n) + unknown ^= n.m + } + } + if unknown > 0 { + l = append(l, fmt.Sprintf("0x%x", unknown)) + } + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s %-65s → %q\n", + time.Now().Format("15:04:05.000000000"), strings.Join(l, " | "), filepath.ToSlash(name)) +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/freebsd.go b/vendor/github.com/fsnotify/fsnotify/internal/freebsd.go new file mode 100644 index 000000000..547df1df8 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/freebsd.go @@ -0,0 +1,31 @@ +//go:build freebsd + +package internal + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +var ( + SyscallEACCES = syscall.EACCES + UnixEACCES = unix.EACCES +) + +var maxfiles uint64 + +func SetRlimit() { + // Go 1.19 will do this automatically: https://go-review.googlesource.com/c/go/+/393354/ + var l syscall.Rlimit + err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &l) + if err == nil && l.Cur != l.Max { + l.Cur = l.Max + syscall.Setrlimit(syscall.RLIMIT_NOFILE, &l) + } + maxfiles = uint64(l.Cur) +} + +func Maxfiles() uint64 { return maxfiles } +func Mkfifo(path string, mode uint32) error { return unix.Mkfifo(path, mode) } +func Mknod(path string, mode uint32, dev int) error { return unix.Mknod(path, mode, uint64(dev)) } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/internal.go b/vendor/github.com/fsnotify/fsnotify/internal/internal.go new file mode 100644 index 000000000..7daa45e19 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/internal.go @@ -0,0 +1,2 @@ +// Package internal contains some helpers. +package internal diff --git a/vendor/github.com/fsnotify/fsnotify/internal/unix.go b/vendor/github.com/fsnotify/fsnotify/internal/unix.go new file mode 100644 index 000000000..30976ce97 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/unix.go @@ -0,0 +1,31 @@ +//go:build !windows && !darwin && !freebsd + +package internal + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +var ( + SyscallEACCES = syscall.EACCES + UnixEACCES = unix.EACCES +) + +var maxfiles uint64 + +func SetRlimit() { + // Go 1.19 will do this automatically: https://go-review.googlesource.com/c/go/+/393354/ + var l syscall.Rlimit + err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &l) + if err == nil && l.Cur != l.Max { + l.Cur = l.Max + syscall.Setrlimit(syscall.RLIMIT_NOFILE, &l) + } + maxfiles = uint64(l.Cur) +} + +func Maxfiles() uint64 { return maxfiles } +func Mkfifo(path string, mode uint32) error { return unix.Mkfifo(path, mode) } +func Mknod(path string, mode uint32, dev int) error { return unix.Mknod(path, mode, dev) } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/unix2.go b/vendor/github.com/fsnotify/fsnotify/internal/unix2.go new file mode 100644 index 000000000..37dfeddc2 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/unix2.go @@ -0,0 +1,7 @@ +//go:build !windows + +package internal + +func HasPrivilegesForSymlink() bool { + return true +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/windows.go b/vendor/github.com/fsnotify/fsnotify/internal/windows.go new file mode 100644 index 000000000..a72c64954 --- /dev/null +++ b/vendor/github.com/fsnotify/fsnotify/internal/windows.go @@ -0,0 +1,41 @@ +//go:build windows + +package internal + +import ( + "errors" + + "golang.org/x/sys/windows" +) + +// Just a dummy. +var ( + SyscallEACCES = errors.New("dummy") + UnixEACCES = errors.New("dummy") +) + +func SetRlimit() {} +func Maxfiles() uint64 { return 1<<64 - 1 } +func Mkfifo(path string, mode uint32) error { return errors.New("no FIFOs on Windows") } +func Mknod(path string, mode uint32, dev int) error { return errors.New("no device nodes on Windows") } + +func HasPrivilegesForSymlink() bool { + var sid *windows.SID + err := windows.AllocateAndInitializeSid( + &windows.SECURITY_NT_AUTHORITY, + 2, + windows.SECURITY_BUILTIN_DOMAIN_RID, + windows.DOMAIN_ALIAS_RID_ADMINS, + 0, 0, 0, 0, 0, 0, + &sid) + if err != nil { + return false + } + defer windows.FreeSid(sid) + token := windows.Token(0) + member, err := token.IsMember(sid) + if err != nil { + return false + } + return member || token.IsElevated() +} diff --git a/vendor/github.com/fsnotify/fsnotify/mkdoc.zsh b/vendor/github.com/fsnotify/fsnotify/mkdoc.zsh deleted file mode 100644 index 99012ae65..000000000 --- a/vendor/github.com/fsnotify/fsnotify/mkdoc.zsh +++ /dev/null @@ -1,259 +0,0 @@ -#!/usr/bin/env zsh -[ "${ZSH_VERSION:-}" = "" ] && echo >&2 "Only works with zsh" && exit 1 -setopt err_exit no_unset pipefail extended_glob - -# Simple script to update the godoc comments on all watchers so you don't need -# to update the same comment 5 times. - -watcher=$(</tmp/x - print -r -- $cmt >>/tmp/x - tail -n+$(( end + 1 )) $file >>/tmp/x - mv /tmp/x $file - done -} - -set-cmt '^type Watcher struct ' $watcher -set-cmt '^func NewWatcher(' $new -set-cmt '^func NewBufferedWatcher(' $newbuffered -set-cmt '^func (w \*Watcher) Add(' $add -set-cmt '^func (w \*Watcher) AddWith(' $addwith -set-cmt '^func (w \*Watcher) Remove(' $remove -set-cmt '^func (w \*Watcher) Close(' $close -set-cmt '^func (w \*Watcher) WatchList(' $watchlist -set-cmt '^[[:space:]]*Events *chan Event$' $events -set-cmt '^[[:space:]]*Errors *chan error$' $errors diff --git a/vendor/github.com/fsnotify/fsnotify/system_bsd.go b/vendor/github.com/fsnotify/fsnotify/system_bsd.go index 4322b0b88..f65e8fe3e 100644 --- a/vendor/github.com/fsnotify/fsnotify/system_bsd.go +++ b/vendor/github.com/fsnotify/fsnotify/system_bsd.go @@ -1,5 +1,4 @@ //go:build freebsd || openbsd || netbsd || dragonfly -// +build freebsd openbsd netbsd dragonfly package fsnotify diff --git a/vendor/github.com/fsnotify/fsnotify/system_darwin.go b/vendor/github.com/fsnotify/fsnotify/system_darwin.go index 5da5ffa78..a29fc7aab 100644 --- a/vendor/github.com/fsnotify/fsnotify/system_darwin.go +++ b/vendor/github.com/fsnotify/fsnotify/system_darwin.go @@ -1,5 +1,4 @@ //go:build darwin -// +build darwin package fsnotify diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/BUILD.bazel b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/BUILD.bazel index 78d7c9f5c..a65d88eb8 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/BUILD.bazel +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/BUILD.bazel @@ -73,7 +73,7 @@ go_test( "@org_golang_google_genproto_googleapis_api//httpbody", "@org_golang_google_genproto_googleapis_rpc//errdetails", "@org_golang_google_genproto_googleapis_rpc//status", - "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//health/grpc_health_v1", "@org_golang_google_grpc//metadata", diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go index 5dd4e4478..2f2b34243 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go @@ -49,6 +49,7 @@ var malformedHTTPHeaders = map[string]struct{}{ type ( rpcMethodKey struct{} httpPathPatternKey struct{} + httpPatternKey struct{} AnnotateContextOption func(ctx context.Context) context.Context ) @@ -404,3 +405,13 @@ func HTTPPathPattern(ctx context.Context) (string, bool) { func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context { return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern) } + +// HTTPPattern returns the HTTP path pattern struct relating to the HTTP handler, if one exists. +func HTTPPattern(ctx context.Context) (Pattern, bool) { + v, ok := ctx.Value(httpPatternKey{}).(Pattern) + return v, ok +} + +func withHTTPPattern(ctx context.Context, httpPattern Pattern) context.Context { + return context.WithValue(ctx, httpPatternKey{}, httpPattern) +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/convert.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/convert.go index d7b15fcfb..2e50082ad 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/convert.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/convert.go @@ -94,7 +94,7 @@ func Int64(val string) (int64, error) { } // Int64Slice converts 'val' where individual integers are separated by -// 'sep' into a int64 slice. +// 'sep' into an int64 slice. func Int64Slice(val, sep string) ([]int64, error) { s := strings.Split(val, sep) values := make([]int64, len(s)) @@ -118,7 +118,7 @@ func Int32(val string) (int32, error) { } // Int32Slice converts 'val' where individual integers are separated by -// 'sep' into a int32 slice. +// 'sep' into an int32 slice. func Int32Slice(val, sep string) ([]int32, error) { s := strings.Split(val, sep) values := make([]int32, len(s)) @@ -190,7 +190,7 @@ func Bytes(val string) ([]byte, error) { } // BytesSlice converts 'val' where individual bytes sequences, encoded in URL-safe -// base64 without padding, are separated by 'sep' into a slice of bytes slices slice. +// base64 without padding, are separated by 'sep' into a slice of byte slices. func BytesSlice(val, sep string) ([][]byte, error) { s := strings.Split(val, sep) values := make([][]byte, len(s)) diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/errors.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/errors.go index 568299869..41cd4f503 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/errors.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/errors.go @@ -81,6 +81,21 @@ func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.R mux.errorHandler(ctx, mux, marshaler, w, r, err) } +// HTTPStreamError uses the mux-configured stream error handler to notify error to the client without closing the connection. +func HTTPStreamError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) { + st := mux.streamErrorHandler(ctx, err) + msg := errorChunk(st) + buf, err := marshaler.Marshal(msg) + if err != nil { + grpclog.Errorf("Failed to marshal an error: %v", err) + return + } + if _, err := w.Write(buf); err != nil { + grpclog.Errorf("Failed to notify error to client: %v", err) + return + } +} + // DefaultHTTPErrorHandler is the default error handler. // If "err" is a gRPC Status, the function replies with the status code mapped by HTTPStatusFromCode. // If "err" is a HTTPStatusError, the function replies with the status code provide by that struct. This is @@ -93,6 +108,7 @@ func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.R func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) { // return Internal when Marshal failed const fallback = `{"code": 13, "message": "failed to marshal error message"}` + const fallbackRewriter = `{"code": 13, "message": "failed to rewrite error message"}` var customStatus *HTTPStatusError if errors.As(err, &customStatus) { @@ -100,19 +116,28 @@ func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marsh } s := status.Convert(err) - pb := s.Proto() w.Header().Del("Trailer") w.Header().Del("Transfer-Encoding") - contentType := marshaler.ContentType(pb) + respRw, err := mux.forwardResponseRewriter(ctx, s.Proto()) + if err != nil { + grpclog.Errorf("Failed to rewrite error message %q: %v", s, err) + w.WriteHeader(http.StatusInternalServerError) + if _, err := io.WriteString(w, fallbackRewriter); err != nil { + grpclog.Errorf("Failed to write response: %v", err) + } + return + } + + contentType := marshaler.ContentType(respRw) w.Header().Set("Content-Type", contentType) if s.Code() == codes.Unauthenticated { w.Header().Set("WWW-Authenticate", s.Message()) } - buf, merr := marshaler.Marshal(pb) + buf, merr := marshaler.Marshal(respRw) if merr != nil { grpclog.Errorf("Failed to marshal error message %q: %v", s, merr) w.WriteHeader(http.StatusInternalServerError) diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/fieldmask.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/fieldmask.go index 9005d6a0b..2fcd7af3c 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/fieldmask.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/fieldmask.go @@ -155,7 +155,7 @@ func buildPathsBlindly(name string, in interface{}) []string { return paths } -// fieldMaskPathItem stores a in-progress deconstruction of a path for a fieldmask +// fieldMaskPathItem stores an in-progress deconstruction of a path for a fieldmask type fieldMaskPathItem struct { // the list of prior fields leading up to node connected by dots path string diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/handler.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/handler.go index de1eef1f4..0fa907656 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/handler.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/handler.go @@ -3,6 +3,7 @@ package runtime import ( "context" "errors" + "fmt" "io" "net/http" "net/textproto" @@ -55,20 +56,33 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal return } + respRw, err := mux.forwardResponseRewriter(ctx, resp) + if err != nil { + grpclog.Errorf("Rewrite error: %v", err) + handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter) + return + } + if !wroteHeader { - w.Header().Set("Content-Type", marshaler.ContentType(resp)) + var contentType string + if sct, ok := marshaler.(StreamContentType); ok { + contentType = sct.StreamContentType(respRw) + } else { + contentType = marshaler.ContentType(respRw) + } + w.Header().Set("Content-Type", contentType) } var buf []byte - httpBody, isHTTPBody := resp.(*httpbody.HttpBody) + httpBody, isHTTPBody := respRw.(*httpbody.HttpBody) switch { - case resp == nil: + case respRw == nil: buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response"))) case isHTTPBody: buf = httpBody.GetData() default: - result := map[string]interface{}{"result": resp} - if rb, ok := resp.(responseBody); ok { + result := map[string]interface{}{"result": respRw} + if rb, ok := respRw.(responseBody); ok { result["result"] = rb.XXX_ResponseBody() } @@ -164,12 +178,17 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha HTTPError(ctx, mux, marshaler, w, req, err) return } + respRw, err := mux.forwardResponseRewriter(ctx, resp) + if err != nil { + grpclog.Errorf("Rewrite error: %v", err) + HTTPError(ctx, mux, marshaler, w, req, err) + return + } var buf []byte - var err error - if rb, ok := resp.(responseBody); ok { + if rb, ok := respRw.(responseBody); ok { buf, err = marshaler.Marshal(rb.XXX_ResponseBody()) } else { - buf, err = marshaler.Marshal(resp) + buf, err = marshaler.Marshal(respRw) } if err != nil { grpclog.Errorf("Marshal error: %v", err) @@ -181,7 +200,7 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha w.Header().Set("Content-Length", strconv.Itoa(len(buf))) } - if _, err = w.Write(buf); err != nil { + if _, err = w.Write(buf); err != nil && !errors.Is(err, http.ErrBodyNotAllowed) { grpclog.Errorf("Failed to write response: %v", err) } @@ -201,8 +220,7 @@ func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, re } for _, opt := range opts { if err := opt(ctx, w, resp); err != nil { - grpclog.Errorf("Error handling ForwardResponseOptions: %v", err) - return err + return fmt.Errorf("error handling ForwardResponseOptions: %w", err) } } return nil diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/marshaler.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/marshaler.go index 2c0d25ff4..b1dfc37af 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/marshaler.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/marshaler.go @@ -48,3 +48,11 @@ type Delimited interface { // Delimiter returns the record separator for the stream. Delimiter() []byte } + +// StreamContentType defines the streaming content type. +type StreamContentType interface { + // StreamContentType returns the content type for a stream. This shares the + // same behaviour as for `Marshaler.ContentType`, but is called, if present, + // in the case of a streamed response. + StreamContentType(v interface{}) string +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/marshaler_registry.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/marshaler_registry.go index 0b051e6e8..07c28112c 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/marshaler_registry.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/marshaler_registry.go @@ -86,8 +86,8 @@ func (m marshalerRegistry) add(mime string, marshaler Marshaler) error { // It allows for a mapping of case-sensitive Content-Type MIME type string to runtime.Marshaler interfaces. // // For example, you could allow the client to specify the use of the runtime.JSONPb marshaler -// with a "application/jsonpb" Content-Type and the use of the runtime.JSONBuiltin marshaler -// with a "application/json" Content-Type. +// with an "application/jsonpb" Content-Type and the use of the runtime.JSONBuiltin marshaler +// with an "application/json" Content-Type. // "*" can be used to match any Content-Type. // This can be attached to a ServerMux with the marshaler option. func makeMarshalerMIMERegistry() marshalerRegistry { diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go index ed9a7e438..60c2065dd 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go @@ -48,12 +48,19 @@ var encodedPathSplitter = regexp.MustCompile("(/|%2F)") // A HandlerFunc handles a specific pair of path pattern and HTTP method. type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) +// A Middleware handler wraps another HandlerFunc to do some pre- and/or post-processing of the request. This is used as an alternative to gRPC interceptors when using the direct-to-implementation +// registration methods. It is generally recommended to use gRPC client or server interceptors instead +// where possible. +type Middleware func(HandlerFunc) HandlerFunc + // ServeMux is a request multiplexer for grpc-gateway. // It matches http requests to patterns and invokes the corresponding handler. type ServeMux struct { // handlers maps HTTP method to a list of handlers. handlers map[string][]handler + middlewares []Middleware forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error + forwardResponseRewriter ForwardResponseRewriter marshalers marshalerRegistry incomingHeaderMatcher HeaderMatcherFunc outgoingHeaderMatcher HeaderMatcherFunc @@ -69,6 +76,24 @@ type ServeMux struct { // ServeMuxOption is an option that can be given to a ServeMux on construction. type ServeMuxOption func(*ServeMux) +// ForwardResponseRewriter is the signature of a function that is capable of rewriting messages +// before they are forwarded in a unary, stream, or error response. +type ForwardResponseRewriter func(ctx context.Context, response proto.Message) (any, error) + +// WithForwardResponseRewriter returns a ServeMuxOption that allows for implementers to insert logic +// that can rewrite the final response before it is forwarded. +// +// The response rewriter function is called during unary message forwarding, stream message +// forwarding and when errors are being forwarded. +// +// NOTE: Using this option will likely make what is generated by `protoc-gen-openapiv2` incorrect. +// Since this option involves making runtime changes to the response shape or type. +func WithForwardResponseRewriter(fwdResponseRewriter ForwardResponseRewriter) ServeMuxOption { + return func(sm *ServeMux) { + sm.forwardResponseRewriter = fwdResponseRewriter + } +} + // WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption. // // forwardResponseOption is an option that will be called on the relevant context.Context, @@ -89,6 +114,15 @@ func WithUnescapingMode(mode UnescapingMode) ServeMuxOption { } } +// WithMiddlewares sets server middleware for all handlers. This is useful as an alternative to gRPC +// interceptors when using the direct-to-implementation registration methods and cannot rely +// on gRPC interceptors. It's recommended to use gRPC interceptors instead if possible. +func WithMiddlewares(middlewares ...Middleware) ServeMuxOption { + return func(serveMux *ServeMux) { + serveMux.middlewares = append(serveMux.middlewares, middlewares...) + } +} + // SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters. // Configuring this will mean the generated OpenAPI output is no longer correct, and it should be // done with careful consideration. @@ -277,13 +311,14 @@ func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMux // NewServeMux returns a new ServeMux whose internal mapping is empty. func NewServeMux(opts ...ServeMuxOption) *ServeMux { serveMux := &ServeMux{ - handlers: make(map[string][]handler), - forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0), - marshalers: makeMarshalerMIMERegistry(), - errorHandler: DefaultHTTPErrorHandler, - streamErrorHandler: DefaultStreamErrorHandler, - routingErrorHandler: DefaultRoutingErrorHandler, - unescapingMode: UnescapingModeDefault, + handlers: make(map[string][]handler), + forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0), + forwardResponseRewriter: func(ctx context.Context, response proto.Message) (any, error) { return response, nil }, + marshalers: makeMarshalerMIMERegistry(), + errorHandler: DefaultHTTPErrorHandler, + streamErrorHandler: DefaultStreamErrorHandler, + routingErrorHandler: DefaultRoutingErrorHandler, + unescapingMode: UnescapingModeDefault, } for _, opt := range opts { @@ -305,6 +340,9 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux { // Handle associates "h" to the pair of HTTP method and path pattern. func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) { + if len(s.middlewares) > 0 { + h = chainMiddlewares(s.middlewares)(h) + } s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...) } @@ -405,7 +443,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { } continue } - h.h(w, r, pathParams) + s.handleHandler(h, w, r, pathParams) return } @@ -458,7 +496,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } - h.h(w, r, pathParams) + s.handleHandler(h, w, r, pathParams) return } _, outboundMarshaler := MarshalerForRequest(s, r) @@ -484,3 +522,16 @@ type handler struct { pat Pattern h HandlerFunc } + +func (s *ServeMux) handleHandler(h handler, w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + h.h(w, r.WithContext(withHTTPPattern(r.Context(), h.pat)), pathParams) +} + +func chainMiddlewares(mws []Middleware) Middleware { + return func(next HandlerFunc) HandlerFunc { + for i := len(mws); i > 0; i-- { + next = mws[i-1](next) + } + return next + } +} diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/proto2_convert.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/proto2_convert.go index d549407f2..f710036b3 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/proto2_convert.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/proto2_convert.go @@ -40,7 +40,7 @@ func Float32P(val string) (*float32, error) { } // Int64P parses the given string representation of an integer -// and returns a pointer to a int64 whose value is same as the parsed integer. +// and returns a pointer to an int64 whose value is same as the parsed integer. func Int64P(val string) (*int64, error) { i, err := Int64(val) if err != nil { @@ -50,7 +50,7 @@ func Int64P(val string) (*int64, error) { } // Int32P parses the given string representation of an integer -// and returns a pointer to a int32 whose value is same as the parsed integer. +// and returns a pointer to an int32 whose value is same as the parsed integer. func Int32P(val string) (*int32, error) { i, err := Int32(val) if err != nil { diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/query.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/query.go index fe634174b..93fb09922 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/query.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/query.go @@ -291,7 +291,11 @@ func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (p if err != nil { return protoreflect.Value{}, err } - msg = timestamppb.New(t) + timestamp := timestamppb.New(t) + if ok := timestamp.IsValid(); !ok { + return protoreflect.Value{}, fmt.Errorf("%s before 0001-01-01", value) + } + msg = timestamp case "google.protobuf.Duration": d, err := time.ParseDuration(value) if err != nil { diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/utilities/pattern.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/utilities/pattern.go index dfe7de486..38ca39cc5 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/utilities/pattern.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/utilities/pattern.go @@ -1,6 +1,6 @@ package utilities -// An OpCode is a opcode of compiled path patterns. +// OpCode is an opcode of compiled path patterns. type OpCode int // These constants are the valid values of OpCode. diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/utilities/string_array_flag.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/utilities/string_array_flag.go index d224ab776..66aa5f2dc 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/utilities/string_array_flag.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/utilities/string_array_flag.go @@ -5,7 +5,7 @@ import ( "strings" ) -// flagInterface is an cut down interface to `flag` +// flagInterface is a cut down interface to `flag` type flagInterface interface { Var(value flag.Value, name string, usage string) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/annotations.go b/vendor/github.com/open-policy-agent/opa/ast/annotations.go index d6267a0e6..533290d32 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/annotations.go +++ b/vendor/github.com/open-policy-agent/opa/ast/annotations.go @@ -5,973 +5,29 @@ package ast import ( - "encoding/json" - "fmt" - "net/url" - "sort" - "strings" - - astJSON "github.com/open-policy-agent/opa/ast/json" - "github.com/open-policy-agent/opa/internal/deepcopy" - "github.com/open-policy-agent/opa/util" -) - -const ( - annotationScopePackage = "package" - annotationScopeImport = "import" - annotationScopeRule = "rule" - annotationScopeDocument = "document" - annotationScopeSubpackages = "subpackages" + v1 "github.com/open-policy-agent/opa/v1/ast" ) type ( // Annotations represents metadata attached to other AST nodes such as rules. - Annotations struct { - Scope string `json:"scope"` - Title string `json:"title,omitempty"` - Entrypoint bool `json:"entrypoint,omitempty"` - Description string `json:"description,omitempty"` - Organizations []string `json:"organizations,omitempty"` - RelatedResources []*RelatedResourceAnnotation `json:"related_resources,omitempty"` - Authors []*AuthorAnnotation `json:"authors,omitempty"` - Schemas []*SchemaAnnotation `json:"schemas,omitempty"` - Custom map[string]interface{} `json:"custom,omitempty"` - Location *Location `json:"location,omitempty"` - - comments []*Comment - node Node - jsonOptions astJSON.Options - } + Annotations = v1.Annotations // SchemaAnnotation contains a schema declaration for the document identified by the path. - SchemaAnnotation struct { - Path Ref `json:"path"` - Schema Ref `json:"schema,omitempty"` - Definition *interface{} `json:"definition,omitempty"` - } - - AuthorAnnotation struct { - Name string `json:"name"` - Email string `json:"email,omitempty"` - } - - RelatedResourceAnnotation struct { - Ref url.URL `json:"ref"` - Description string `json:"description,omitempty"` - } - - AnnotationSet struct { - byRule map[*Rule][]*Annotations - byPackage map[int]*Annotations - byPath *annotationTreeNode - modules []*Module // Modules this set was constructed from - } + SchemaAnnotation = v1.SchemaAnnotation - annotationTreeNode struct { - Value *Annotations - Children map[Value]*annotationTreeNode // we assume key elements are hashable (vars and strings only!) - } + AuthorAnnotation = v1.AuthorAnnotation - AnnotationsRef struct { - Path Ref `json:"path"` // The path of the node the annotations are applied to - Annotations *Annotations `json:"annotations,omitempty"` - Location *Location `json:"location,omitempty"` // The location of the node the annotations are applied to + RelatedResourceAnnotation = v1.RelatedResourceAnnotation - jsonOptions astJSON.Options + AnnotationSet = v1.AnnotationSet - node Node // The node the annotations are applied to - } + AnnotationsRef = v1.AnnotationsRef - AnnotationsRefSet []*AnnotationsRef + AnnotationsRefSet = v1.AnnotationsRefSet - FlatAnnotationsRefSet AnnotationsRefSet + FlatAnnotationsRefSet = v1.FlatAnnotationsRefSet ) -func (a *Annotations) String() string { - bs, _ := a.MarshalJSON() - return string(bs) -} - -// Loc returns the location of this annotation. -func (a *Annotations) Loc() *Location { - return a.Location -} - -// SetLoc updates the location of this annotation. -func (a *Annotations) SetLoc(l *Location) { - a.Location = l -} - -// EndLoc returns the location of this annotation's last comment line. -func (a *Annotations) EndLoc() *Location { - count := len(a.comments) - if count == 0 { - return a.Location - } - return a.comments[count-1].Location -} - -// Compare returns an integer indicating if a is less than, equal to, or greater -// than other. -func (a *Annotations) Compare(other *Annotations) int { - - if a == nil && other == nil { - return 0 - } - - if a == nil { - return -1 - } - - if other == nil { - return 1 - } - - if cmp := scopeCompare(a.Scope, other.Scope); cmp != 0 { - return cmp - } - - if cmp := strings.Compare(a.Title, other.Title); cmp != 0 { - return cmp - } - - if cmp := strings.Compare(a.Description, other.Description); cmp != 0 { - return cmp - } - - if cmp := compareStringLists(a.Organizations, other.Organizations); cmp != 0 { - return cmp - } - - if cmp := compareRelatedResources(a.RelatedResources, other.RelatedResources); cmp != 0 { - return cmp - } - - if cmp := compareAuthors(a.Authors, other.Authors); cmp != 0 { - return cmp - } - - if cmp := compareSchemas(a.Schemas, other.Schemas); cmp != 0 { - return cmp - } - - if a.Entrypoint != other.Entrypoint { - if a.Entrypoint { - return 1 - } - return -1 - } - - if cmp := util.Compare(a.Custom, other.Custom); cmp != 0 { - return cmp - } - - return 0 -} - -// GetTargetPath returns the path of the node these Annotations are applied to (the target) -func (a *Annotations) GetTargetPath() Ref { - switch n := a.node.(type) { - case *Package: - return n.Path - case *Rule: - return n.Ref().GroundPrefix() - default: - return nil - } -} - -func (a *Annotations) setJSONOptions(opts astJSON.Options) { - a.jsonOptions = opts - if a.Location != nil { - a.Location.JSONOptions = opts - } -} - -func (a *Annotations) MarshalJSON() ([]byte, error) { - if a == nil { - return []byte(`{"scope":""}`), nil - } - - data := map[string]interface{}{ - "scope": a.Scope, - } - - if a.Title != "" { - data["title"] = a.Title - } - - if a.Description != "" { - data["description"] = a.Description - } - - if a.Entrypoint { - data["entrypoint"] = a.Entrypoint - } - - if len(a.Organizations) > 0 { - data["organizations"] = a.Organizations - } - - if len(a.RelatedResources) > 0 { - data["related_resources"] = a.RelatedResources - } - - if len(a.Authors) > 0 { - data["authors"] = a.Authors - } - - if len(a.Schemas) > 0 { - data["schemas"] = a.Schemas - } - - if len(a.Custom) > 0 { - data["custom"] = a.Custom - } - - if a.jsonOptions.MarshalOptions.IncludeLocation.Annotations { - if a.Location != nil { - data["location"] = a.Location - } - } - - return json.Marshal(data) -} - func NewAnnotationsRef(a *Annotations) *AnnotationsRef { - var loc *Location - if a.node != nil { - loc = a.node.Loc() - } - - return &AnnotationsRef{ - Location: loc, - Path: a.GetTargetPath(), - Annotations: a, - node: a.node, - jsonOptions: a.jsonOptions, - } -} - -func (ar *AnnotationsRef) GetPackage() *Package { - switch n := ar.node.(type) { - case *Package: - return n - case *Rule: - return n.Module.Package - default: - return nil - } -} - -func (ar *AnnotationsRef) GetRule() *Rule { - switch n := ar.node.(type) { - case *Rule: - return n - default: - return nil - } -} - -func (ar *AnnotationsRef) MarshalJSON() ([]byte, error) { - data := map[string]interface{}{ - "path": ar.Path, - } - - if ar.Annotations != nil { - data["annotations"] = ar.Annotations - } - - if ar.jsonOptions.MarshalOptions.IncludeLocation.AnnotationsRef { - if ar.Location != nil { - data["location"] = ar.Location - } - } - - return json.Marshal(data) -} - -func scopeCompare(s1, s2 string) int { - - o1 := scopeOrder(s1) - o2 := scopeOrder(s2) - - if o2 < o1 { - return 1 - } else if o2 > o1 { - return -1 - } - - if s1 < s2 { - return -1 - } else if s2 < s1 { - return 1 - } - - return 0 -} - -func scopeOrder(s string) int { - switch s { - case annotationScopeRule: - return 1 - } - return 0 -} - -func compareAuthors(a, b []*AuthorAnnotation) int { - if len(a) > len(b) { - return 1 - } else if len(a) < len(b) { - return -1 - } - - for i := 0; i < len(a); i++ { - if cmp := a[i].Compare(b[i]); cmp != 0 { - return cmp - } - } - - return 0 -} - -func compareRelatedResources(a, b []*RelatedResourceAnnotation) int { - if len(a) > len(b) { - return 1 - } else if len(a) < len(b) { - return -1 - } - - for i := 0; i < len(a); i++ { - if cmp := strings.Compare(a[i].String(), b[i].String()); cmp != 0 { - return cmp - } - } - - return 0 -} - -func compareSchemas(a, b []*SchemaAnnotation) int { - maxLen := len(a) - if len(b) < maxLen { - maxLen = len(b) - } - - for i := 0; i < maxLen; i++ { - if cmp := a[i].Compare(b[i]); cmp != 0 { - return cmp - } - } - - if len(a) > len(b) { - return 1 - } else if len(a) < len(b) { - return -1 - } - - return 0 -} - -func compareStringLists(a, b []string) int { - if len(a) > len(b) { - return 1 - } else if len(a) < len(b) { - return -1 - } - - for i := 0; i < len(a); i++ { - if cmp := strings.Compare(a[i], b[i]); cmp != 0 { - return cmp - } - } - - return 0 -} - -// Copy returns a deep copy of s. -func (a *Annotations) Copy(node Node) *Annotations { - cpy := *a - - cpy.Organizations = make([]string, len(a.Organizations)) - copy(cpy.Organizations, a.Organizations) - - cpy.RelatedResources = make([]*RelatedResourceAnnotation, len(a.RelatedResources)) - for i := range a.RelatedResources { - cpy.RelatedResources[i] = a.RelatedResources[i].Copy() - } - - cpy.Authors = make([]*AuthorAnnotation, len(a.Authors)) - for i := range a.Authors { - cpy.Authors[i] = a.Authors[i].Copy() - } - - cpy.Schemas = make([]*SchemaAnnotation, len(a.Schemas)) - for i := range a.Schemas { - cpy.Schemas[i] = a.Schemas[i].Copy() - } - - cpy.Custom = deepcopy.Map(a.Custom) - - cpy.node = node - - return &cpy -} - -// toObject constructs an AST Object from the annotation. -func (a *Annotations) toObject() (*Object, *Error) { - obj := NewObject() - - if a == nil { - return &obj, nil - } - - if len(a.Scope) > 0 { - obj.Insert(StringTerm("scope"), StringTerm(a.Scope)) - } - - if len(a.Title) > 0 { - obj.Insert(StringTerm("title"), StringTerm(a.Title)) - } - - if a.Entrypoint { - obj.Insert(StringTerm("entrypoint"), BooleanTerm(true)) - } - - if len(a.Description) > 0 { - obj.Insert(StringTerm("description"), StringTerm(a.Description)) - } - - if len(a.Organizations) > 0 { - orgs := make([]*Term, 0, len(a.Organizations)) - for _, org := range a.Organizations { - orgs = append(orgs, StringTerm(org)) - } - obj.Insert(StringTerm("organizations"), ArrayTerm(orgs...)) - } - - if len(a.RelatedResources) > 0 { - rrs := make([]*Term, 0, len(a.RelatedResources)) - for _, rr := range a.RelatedResources { - rrObj := NewObject(Item(StringTerm("ref"), StringTerm(rr.Ref.String()))) - if len(rr.Description) > 0 { - rrObj.Insert(StringTerm("description"), StringTerm(rr.Description)) - } - rrs = append(rrs, NewTerm(rrObj)) - } - obj.Insert(StringTerm("related_resources"), ArrayTerm(rrs...)) - } - - if len(a.Authors) > 0 { - as := make([]*Term, 0, len(a.Authors)) - for _, author := range a.Authors { - aObj := NewObject() - if len(author.Name) > 0 { - aObj.Insert(StringTerm("name"), StringTerm(author.Name)) - } - if len(author.Email) > 0 { - aObj.Insert(StringTerm("email"), StringTerm(author.Email)) - } - as = append(as, NewTerm(aObj)) - } - obj.Insert(StringTerm("authors"), ArrayTerm(as...)) - } - - if len(a.Schemas) > 0 { - ss := make([]*Term, 0, len(a.Schemas)) - for _, s := range a.Schemas { - sObj := NewObject() - if len(s.Path) > 0 { - sObj.Insert(StringTerm("path"), NewTerm(s.Path.toArray())) - } - if len(s.Schema) > 0 { - sObj.Insert(StringTerm("schema"), NewTerm(s.Schema.toArray())) - } - if s.Definition != nil { - def, err := InterfaceToValue(s.Definition) - if err != nil { - return nil, NewError(CompileErr, a.Location, "invalid definition in schema annotation: %s", err.Error()) - } - sObj.Insert(StringTerm("definition"), NewTerm(def)) - } - ss = append(ss, NewTerm(sObj)) - } - obj.Insert(StringTerm("schemas"), ArrayTerm(ss...)) - } - - if len(a.Custom) > 0 { - c, err := InterfaceToValue(a.Custom) - if err != nil { - return nil, NewError(CompileErr, a.Location, "invalid custom annotation %s", err.Error()) - } - obj.Insert(StringTerm("custom"), NewTerm(c)) - } - - return &obj, nil -} - -func attachRuleAnnotations(mod *Module) { - // make a copy of the annotations - cpy := make([]*Annotations, len(mod.Annotations)) - for i, a := range mod.Annotations { - cpy[i] = a.Copy(a.node) - } - - for _, rule := range mod.Rules { - var j int - var found bool - for i, a := range cpy { - if rule.Ref().GroundPrefix().Equal(a.GetTargetPath()) { - if a.Scope == annotationScopeDocument { - rule.Annotations = append(rule.Annotations, a) - } else if a.Scope == annotationScopeRule && rule.Loc().Row > a.Location.Row { - j = i - found = true - rule.Annotations = append(rule.Annotations, a) - } - } - } - - if found && j < len(cpy) { - cpy = append(cpy[:j], cpy[j+1:]...) - } - } -} - -func attachAnnotationsNodes(mod *Module) Errors { - var errs Errors - - // Find first non-annotation statement following each annotation and attach - // the annotation to that statement. - for _, a := range mod.Annotations { - for _, stmt := range mod.stmts { - _, ok := stmt.(*Annotations) - if !ok { - if stmt.Loc().Row > a.Location.Row { - a.node = stmt - break - } - } - } - - if a.Scope == "" { - switch a.node.(type) { - case *Rule: - if a.Entrypoint { - a.Scope = annotationScopeDocument - } else { - a.Scope = annotationScopeRule - } - case *Package: - a.Scope = annotationScopePackage - case *Import: - a.Scope = annotationScopeImport - } - } - - if err := validateAnnotationScopeAttachment(a); err != nil { - errs = append(errs, err) - } - - if err := validateAnnotationEntrypointAttachment(a); err != nil { - errs = append(errs, err) - } - } - - return errs -} - -func validateAnnotationScopeAttachment(a *Annotations) *Error { - - switch a.Scope { - case annotationScopeRule, annotationScopeDocument: - if _, ok := a.node.(*Rule); ok { - return nil - } - return newScopeAttachmentErr(a, "rule") - case annotationScopePackage, annotationScopeSubpackages: - if _, ok := a.node.(*Package); ok { - return nil - } - return newScopeAttachmentErr(a, "package") - } - - return NewError(ParseErr, a.Loc(), "invalid annotation scope '%v'. Use one of '%s', '%s', '%s', or '%s'", - a.Scope, annotationScopeRule, annotationScopeDocument, annotationScopePackage, annotationScopeSubpackages) -} - -func validateAnnotationEntrypointAttachment(a *Annotations) *Error { - if a.Entrypoint && !(a.Scope == annotationScopeDocument || a.Scope == annotationScopePackage) { - return NewError( - ParseErr, a.Loc(), "annotation entrypoint applied to non-document or package scope '%v'", a.Scope) - } - return nil -} - -// Copy returns a deep copy of a. -func (a *AuthorAnnotation) Copy() *AuthorAnnotation { - cpy := *a - return &cpy -} - -// Compare returns an integer indicating if s is less than, equal to, or greater -// than other. -func (a *AuthorAnnotation) Compare(other *AuthorAnnotation) int { - if cmp := strings.Compare(a.Name, other.Name); cmp != 0 { - return cmp - } - - if cmp := strings.Compare(a.Email, other.Email); cmp != 0 { - return cmp - } - - return 0 -} - -func (a *AuthorAnnotation) String() string { - if len(a.Email) == 0 { - return a.Name - } else if len(a.Name) == 0 { - return fmt.Sprintf("<%s>", a.Email) - } - return fmt.Sprintf("%s <%s>", a.Name, a.Email) -} - -// Copy returns a deep copy of rr. -func (rr *RelatedResourceAnnotation) Copy() *RelatedResourceAnnotation { - cpy := *rr - return &cpy -} - -// Compare returns an integer indicating if s is less than, equal to, or greater -// than other. -func (rr *RelatedResourceAnnotation) Compare(other *RelatedResourceAnnotation) int { - if cmp := strings.Compare(rr.Description, other.Description); cmp != 0 { - return cmp - } - - if cmp := strings.Compare(rr.Ref.String(), other.Ref.String()); cmp != 0 { - return cmp - } - - return 0 -} - -func (rr *RelatedResourceAnnotation) String() string { - bs, _ := json.Marshal(rr) - return string(bs) -} - -func (rr *RelatedResourceAnnotation) MarshalJSON() ([]byte, error) { - d := map[string]interface{}{ - "ref": rr.Ref.String(), - } - - if len(rr.Description) > 0 { - d["description"] = rr.Description - } - - return json.Marshal(d) -} - -// Copy returns a deep copy of s. -func (s *SchemaAnnotation) Copy() *SchemaAnnotation { - cpy := *s - return &cpy -} - -// Compare returns an integer indicating if s is less than, equal to, or greater -// than other. -func (s *SchemaAnnotation) Compare(other *SchemaAnnotation) int { - - if cmp := s.Path.Compare(other.Path); cmp != 0 { - return cmp - } - - if cmp := s.Schema.Compare(other.Schema); cmp != 0 { - return cmp - } - - if s.Definition != nil && other.Definition == nil { - return -1 - } else if s.Definition == nil && other.Definition != nil { - return 1 - } else if s.Definition != nil && other.Definition != nil { - return util.Compare(*s.Definition, *other.Definition) - } - - return 0 -} - -func (s *SchemaAnnotation) String() string { - bs, _ := json.Marshal(s) - return string(bs) -} - -func newAnnotationSet() *AnnotationSet { - return &AnnotationSet{ - byRule: map[*Rule][]*Annotations{}, - byPackage: map[int]*Annotations{}, - byPath: newAnnotationTree(), - } -} - -func BuildAnnotationSet(modules []*Module) (*AnnotationSet, Errors) { - as := newAnnotationSet() - var errs Errors - for _, m := range modules { - for _, a := range m.Annotations { - if err := as.add(a); err != nil { - errs = append(errs, err) - } - } - } - if len(errs) > 0 { - return nil, errs - } - as.modules = modules - return as, nil -} - -// NOTE(philipc): During copy propagation, the underlying Nodes can be -// stripped away from the annotations, leading to nil deref panics. We -// silently ignore these cases for now, as a workaround. -func (as *AnnotationSet) add(a *Annotations) *Error { - switch a.Scope { - case annotationScopeRule: - if rule, ok := a.node.(*Rule); ok { - as.byRule[rule] = append(as.byRule[rule], a) - } - case annotationScopePackage: - if pkg, ok := a.node.(*Package); ok { - hash := pkg.Path.Hash() - if exist, ok := as.byPackage[hash]; ok { - return errAnnotationRedeclared(a, exist.Location) - } - as.byPackage[hash] = a - } - case annotationScopeDocument: - if rule, ok := a.node.(*Rule); ok { - path := rule.Ref().GroundPrefix() - x := as.byPath.get(path) - if x != nil { - return errAnnotationRedeclared(a, x.Value.Location) - } - as.byPath.insert(path, a) - } - case annotationScopeSubpackages: - if pkg, ok := a.node.(*Package); ok { - x := as.byPath.get(pkg.Path) - if x != nil && x.Value != nil { - return errAnnotationRedeclared(a, x.Value.Location) - } - as.byPath.insert(pkg.Path, a) - } - } - return nil -} - -func (as *AnnotationSet) GetRuleScope(r *Rule) []*Annotations { - if as == nil { - return nil - } - return as.byRule[r] -} - -func (as *AnnotationSet) GetSubpackagesScope(path Ref) []*Annotations { - if as == nil { - return nil - } - return as.byPath.ancestors(path) -} - -func (as *AnnotationSet) GetDocumentScope(path Ref) *Annotations { - if as == nil { - return nil - } - if node := as.byPath.get(path); node != nil { - return node.Value - } - return nil -} - -func (as *AnnotationSet) GetPackageScope(pkg *Package) *Annotations { - if as == nil { - return nil - } - return as.byPackage[pkg.Path.Hash()] -} - -// Flatten returns a flattened list view of this AnnotationSet. -// The returned slice is sorted, first by the annotations' target path, then by their target location -func (as *AnnotationSet) Flatten() FlatAnnotationsRefSet { - // This preallocation often won't be optimal, but it's superior to starting with a nil slice. - refs := make([]*AnnotationsRef, 0, len(as.byPath.Children)+len(as.byRule)+len(as.byPackage)) - - refs = as.byPath.flatten(refs) - - for _, a := range as.byPackage { - refs = append(refs, NewAnnotationsRef(a)) - } - - for _, as := range as.byRule { - for _, a := range as { - refs = append(refs, NewAnnotationsRef(a)) - } - } - - // Sort by path, then annotation location, for stable output - sort.SliceStable(refs, func(i, j int) bool { - return refs[i].Compare(refs[j]) < 0 - }) - - return refs -} - -// Chain returns the chain of annotations leading up to the given rule. -// The returned slice is ordered as follows -// 0. Entries for the given rule, ordered from the METADATA block declared immediately above the rule, to the block declared farthest away (always at least one entry) -// 1. The 'document' scope entry, if any -// 2. The 'package' scope entry, if any -// 3. Entries for the 'subpackages' scope, if any; ordered from the closest package path to the fartest. E.g.: 'do.re.mi', 'do.re', 'do' -// The returned slice is guaranteed to always contain at least one entry, corresponding to the given rule. -func (as *AnnotationSet) Chain(rule *Rule) AnnotationsRefSet { - var refs []*AnnotationsRef - - ruleAnnots := as.GetRuleScope(rule) - - if len(ruleAnnots) >= 1 { - for _, a := range ruleAnnots { - refs = append(refs, NewAnnotationsRef(a)) - } - } else { - // Make sure there is always a leading entry representing the passed rule, even if it has no annotations - refs = append(refs, &AnnotationsRef{ - Location: rule.Location, - Path: rule.Ref().GroundPrefix(), - node: rule, - }) - } - - if len(refs) > 1 { - // Sort by annotation location; chain must start with annotations declared closest to rule, then going outward - sort.SliceStable(refs, func(i, j int) bool { - return refs[i].Annotations.Location.Compare(refs[j].Annotations.Location) > 0 - }) - } - - docAnnots := as.GetDocumentScope(rule.Ref().GroundPrefix()) - if docAnnots != nil { - refs = append(refs, NewAnnotationsRef(docAnnots)) - } - - pkg := rule.Module.Package - pkgAnnots := as.GetPackageScope(pkg) - if pkgAnnots != nil { - refs = append(refs, NewAnnotationsRef(pkgAnnots)) - } - - subPkgAnnots := as.GetSubpackagesScope(pkg.Path) - // We need to reverse the order, as subPkgAnnots ordering will start at the root, - // whereas we want to end at the root. - for i := len(subPkgAnnots) - 1; i >= 0; i-- { - refs = append(refs, NewAnnotationsRef(subPkgAnnots[i])) - } - - return refs -} - -func (ars FlatAnnotationsRefSet) Insert(ar *AnnotationsRef) FlatAnnotationsRefSet { - result := make(FlatAnnotationsRefSet, 0, len(ars)+1) - - // insertion sort, first by path, then location - for i, current := range ars { - if ar.Compare(current) < 0 { - result = append(result, ar) - result = append(result, ars[i:]...) - break - } - result = append(result, current) - } - - if len(result) < len(ars)+1 { - result = append(result, ar) - } - - return result -} - -func newAnnotationTree() *annotationTreeNode { - return &annotationTreeNode{ - Value: nil, - Children: map[Value]*annotationTreeNode{}, - } -} - -func (t *annotationTreeNode) insert(path Ref, value *Annotations) { - node := t - for _, k := range path { - child, ok := node.Children[k.Value] - if !ok { - child = newAnnotationTree() - node.Children[k.Value] = child - } - node = child - } - node.Value = value -} - -func (t *annotationTreeNode) get(path Ref) *annotationTreeNode { - node := t - for _, k := range path { - if node == nil { - return nil - } - child, ok := node.Children[k.Value] - if !ok { - return nil - } - node = child - } - return node -} - -// ancestors returns a slice of annotations in ascending order, starting with the root of ref; e.g.: 'root', 'root.foo', 'root.foo.bar'. -func (t *annotationTreeNode) ancestors(path Ref) (result []*Annotations) { - node := t - for _, k := range path { - if node == nil { - return result - } - child, ok := node.Children[k.Value] - if !ok { - return result - } - if child.Value != nil { - result = append(result, child.Value) - } - node = child - } - return result -} - -func (t *annotationTreeNode) flatten(refs []*AnnotationsRef) []*AnnotationsRef { - if a := t.Value; a != nil { - refs = append(refs, NewAnnotationsRef(a)) - } - for _, c := range t.Children { - refs = c.flatten(refs) - } - return refs -} - -func (ar *AnnotationsRef) Compare(other *AnnotationsRef) int { - if c := ar.Path.Compare(other.Path); c != 0 { - return c - } - - if c := ar.Annotations.Location.Compare(other.Annotations.Location); c != 0 { - return c - } - - return ar.Annotations.Compare(other.Annotations) + return v1.NewAnnotationsRef(a) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/builtins.go b/vendor/github.com/open-policy-agent/opa/ast/builtins.go index f54d91d31..d0ab69a16 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/builtins.go +++ b/vendor/github.com/open-policy-agent/opa/ast/builtins.go @@ -5,1348 +5,230 @@ package ast import ( - "strings" - - "github.com/open-policy-agent/opa/types" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // Builtins is the registry of built-in functions supported by OPA. // Call RegisterBuiltin to add a new built-in. -var Builtins []*Builtin +var Builtins = v1.Builtins // RegisterBuiltin adds a new built-in function to the registry. func RegisterBuiltin(b *Builtin) { - Builtins = append(Builtins, b) - BuiltinMap[b.Name] = b - if len(b.Infix) > 0 { - BuiltinMap[b.Infix] = b - } + v1.RegisterBuiltin(b) } // DefaultBuiltins is the registry of built-in functions supported in OPA // by default. When adding a new built-in function to OPA, update this // list. -var DefaultBuiltins = [...]*Builtin{ - // Unification/equality ("=") - Equality, - - // Assignment (":=") - Assign, - - // Membership, infix "in": `x in xs` - Member, - MemberWithKey, - - // Comparisons - GreaterThan, - GreaterThanEq, - LessThan, - LessThanEq, - NotEqual, - Equal, - - // Arithmetic - Plus, - Minus, - Multiply, - Divide, - Ceil, - Floor, - Round, - Abs, - Rem, - - // Bitwise Arithmetic - BitsOr, - BitsAnd, - BitsNegate, - BitsXOr, - BitsShiftLeft, - BitsShiftRight, - - // Binary - And, - Or, - - // Aggregates - Count, - Sum, - Product, - Max, - Min, - Any, - All, - - // Arrays - ArrayConcat, - ArraySlice, - ArrayReverse, - - // Conversions - ToNumber, - - // Casts (DEPRECATED) - CastObject, - CastNull, - CastBoolean, - CastString, - CastSet, - CastArray, - - // Regular Expressions - RegexIsValid, - RegexMatch, - RegexMatchDeprecated, - RegexSplit, - GlobsMatch, - RegexTemplateMatch, - RegexFind, - RegexFindAllStringSubmatch, - RegexReplace, - - // Sets - SetDiff, - Intersection, - Union, - - // Strings - AnyPrefixMatch, - AnySuffixMatch, - Concat, - FormatInt, - IndexOf, - IndexOfN, - Substring, - Lower, - Upper, - Contains, - StringCount, - StartsWith, - EndsWith, - Split, - Replace, - ReplaceN, - Trim, - TrimLeft, - TrimPrefix, - TrimRight, - TrimSuffix, - TrimSpace, - Sprintf, - StringReverse, - RenderTemplate, - - // Numbers - NumbersRange, - NumbersRangeStep, - RandIntn, - - // Encoding - JSONMarshal, - JSONMarshalWithOptions, - JSONUnmarshal, - JSONIsValid, - Base64Encode, - Base64Decode, - Base64IsValid, - Base64UrlEncode, - Base64UrlEncodeNoPad, - Base64UrlDecode, - URLQueryDecode, - URLQueryEncode, - URLQueryEncodeObject, - URLQueryDecodeObject, - YAMLMarshal, - YAMLUnmarshal, - YAMLIsValid, - HexEncode, - HexDecode, - - // Object Manipulation - ObjectUnion, - ObjectUnionN, - ObjectRemove, - ObjectFilter, - ObjectGet, - ObjectKeys, - ObjectSubset, - - // JSON Object Manipulation - JSONFilter, - JSONRemove, - JSONPatch, - - // Tokens - JWTDecode, - JWTVerifyRS256, - JWTVerifyRS384, - JWTVerifyRS512, - JWTVerifyPS256, - JWTVerifyPS384, - JWTVerifyPS512, - JWTVerifyES256, - JWTVerifyES384, - JWTVerifyES512, - JWTVerifyHS256, - JWTVerifyHS384, - JWTVerifyHS512, - JWTDecodeVerify, - JWTEncodeSignRaw, - JWTEncodeSign, - - // Time - NowNanos, - ParseNanos, - ParseRFC3339Nanos, - ParseDurationNanos, - Format, - Date, - Clock, - Weekday, - AddDate, - Diff, - - // Crypto - CryptoX509ParseCertificates, - CryptoX509ParseAndVerifyCertificates, - CryptoX509ParseAndVerifyCertificatesWithOptions, - CryptoMd5, - CryptoSha1, - CryptoSha256, - CryptoX509ParseCertificateRequest, - CryptoX509ParseRSAPrivateKey, - CryptoX509ParseKeyPair, - CryptoParsePrivateKeys, - CryptoHmacMd5, - CryptoHmacSha1, - CryptoHmacSha256, - CryptoHmacSha512, - CryptoHmacEqual, - - // Graphs - WalkBuiltin, - ReachableBuiltin, - ReachablePathsBuiltin, - - // Sort - Sort, - - // Types - IsNumber, - IsString, - IsBoolean, - IsArray, - IsSet, - IsObject, - IsNull, - TypeNameBuiltin, - - // HTTP - HTTPSend, - - // GraphQL - GraphQLParse, - GraphQLParseAndVerify, - GraphQLParseQuery, - GraphQLParseSchema, - GraphQLIsValid, - GraphQLSchemaIsValid, - - // JSON Schema - JSONSchemaVerify, - JSONMatchSchema, - - // Cloud Provider Helpers - ProvidersAWSSignReqObj, - - // Rego - RegoParseModule, - RegoMetadataChain, - RegoMetadataRule, - - // OPA - OPARuntime, - - // Tracing - Trace, - - // Networking - NetCIDROverlap, - NetCIDRIntersects, - NetCIDRContains, - NetCIDRContainsMatches, - NetCIDRExpand, - NetCIDRMerge, - NetLookupIPAddr, - NetCIDRIsValid, - - // Glob - GlobMatch, - GlobQuoteMeta, - - // Units - UnitsParse, - UnitsParseBytes, - - // UUIDs - UUIDRFC4122, - UUIDParse, - - // SemVers - SemVerIsValid, - SemVerCompare, - - // Printing - Print, - InternalPrint, -} +var DefaultBuiltins = v1.DefaultBuiltins // BuiltinMap provides a convenient mapping of built-in names to // built-in definitions. -var BuiltinMap map[string]*Builtin +var BuiltinMap = v1.BuiltinMap // Deprecated: Builtins can now be directly annotated with the // Nondeterministic property, and when set to true, will be ignored // for partial evaluation. -var IgnoreDuringPartialEval = []*Builtin{ - RandIntn, - UUIDRFC4122, - JWTDecodeVerify, - JWTEncodeSignRaw, - JWTEncodeSign, - NowNanos, - HTTPSend, - OPARuntime, - NetLookupIPAddr, -} +var IgnoreDuringPartialEval = v1.IgnoreDuringPartialEval /** * Unification */ // Equality represents the "=" operator. -var Equality = &Builtin{ - Name: "eq", - Infix: "=", - Decl: types.NewFunction( - types.Args(types.A, types.A), - types.B, - ), -} +var Equality = v1.Equality /** * Assignment */ // Assign represents the assignment (":=") operator. -var Assign = &Builtin{ - Name: "assign", - Infix: ":=", - Decl: types.NewFunction( - types.Args(types.A, types.A), - types.B, - ), -} +var Assign = v1.Assign // Member represents the `in` (infix) operator. -var Member = &Builtin{ - Name: "internal.member_2", - Infix: "in", - Decl: types.NewFunction( - types.Args( - types.A, - types.A, - ), - types.B, - ), -} +var Member = v1.Member // MemberWithKey represents the `in` (infix) operator when used // with two terms on the lhs, i.e., `k, v in obj`. -var MemberWithKey = &Builtin{ - Name: "internal.member_3", - Infix: "in", - Decl: types.NewFunction( - types.Args( - types.A, - types.A, - types.A, - ), - types.B, - ), -} +var MemberWithKey = v1.MemberWithKey -/** - * Comparisons - */ -var comparison = category("comparison") - -var GreaterThan = &Builtin{ - Name: "gt", - Infix: ">", - Categories: comparison, - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - types.Named("y", types.A), - ), - types.Named("result", types.B).Description("true if `x` is greater than `y`; false otherwise"), - ), -} +var GreaterThan = v1.GreaterThan -var GreaterThanEq = &Builtin{ - Name: "gte", - Infix: ">=", - Categories: comparison, - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - types.Named("y", types.A), - ), - types.Named("result", types.B).Description("true if `x` is greater or equal to `y`; false otherwise"), - ), -} +var GreaterThanEq = v1.GreaterThanEq // LessThan represents the "<" comparison operator. -var LessThan = &Builtin{ - Name: "lt", - Infix: "<", - Categories: comparison, - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - types.Named("y", types.A), - ), - types.Named("result", types.B).Description("true if `x` is less than `y`; false otherwise"), - ), -} +var LessThan = v1.LessThan -var LessThanEq = &Builtin{ - Name: "lte", - Infix: "<=", - Categories: comparison, - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - types.Named("y", types.A), - ), - types.Named("result", types.B).Description("true if `x` is less than or equal to `y`; false otherwise"), - ), -} +var LessThanEq = v1.LessThanEq -var NotEqual = &Builtin{ - Name: "neq", - Infix: "!=", - Categories: comparison, - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - types.Named("y", types.A), - ), - types.Named("result", types.B).Description("true if `x` is not equal to `y`; false otherwise"), - ), -} +var NotEqual = v1.NotEqual // Equal represents the "==" comparison operator. -var Equal = &Builtin{ - Name: "equal", - Infix: "==", - Categories: comparison, - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - types.Named("y", types.A), - ), - types.Named("result", types.B).Description("true if `x` is equal to `y`; false otherwise"), - ), -} +var Equal = v1.Equal -/** - * Arithmetic - */ -var number = category("numbers") - -var Plus = &Builtin{ - Name: "plus", - Infix: "+", - Description: "Plus adds two numbers together.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - types.Named("y", types.N), - ), - types.Named("z", types.N).Description("the sum of `x` and `y`"), - ), - Categories: number, -} +var Plus = v1.Plus -var Minus = &Builtin{ - Name: "minus", - Infix: "-", - Description: "Minus subtracts the second number from the first number or computes the difference between two sets.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.NewAny(types.N, types.NewSet(types.A))), - types.Named("y", types.NewAny(types.N, types.NewSet(types.A))), - ), - types.Named("z", types.NewAny(types.N, types.NewSet(types.A))).Description("the difference of `x` and `y`"), - ), - Categories: category("sets", "numbers"), -} +var Minus = v1.Minus -var Multiply = &Builtin{ - Name: "mul", - Infix: "*", - Description: "Multiplies two numbers.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - types.Named("y", types.N), - ), - types.Named("z", types.N).Description("the product of `x` and `y`"), - ), - Categories: number, -} +var Multiply = v1.Multiply -var Divide = &Builtin{ - Name: "div", - Infix: "/", - Description: "Divides the first number by the second number.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N).Description("the dividend"), - types.Named("y", types.N).Description("the divisor"), - ), - types.Named("z", types.N).Description("the result of `x` divided by `y`"), - ), - Categories: number, -} +var Divide = v1.Divide -var Round = &Builtin{ - Name: "round", - Description: "Rounds the number to the nearest integer.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N).Description("the number to round"), - ), - types.Named("y", types.N).Description("the result of rounding `x`"), - ), - Categories: number, -} +var Round = v1.Round -var Ceil = &Builtin{ - Name: "ceil", - Description: "Rounds the number _up_ to the nearest integer.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N).Description("the number to round"), - ), - types.Named("y", types.N).Description("the result of rounding `x` _up_"), - ), - Categories: number, -} +var Ceil = v1.Ceil -var Floor = &Builtin{ - Name: "floor", - Description: "Rounds the number _down_ to the nearest integer.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N).Description("the number to round"), - ), - types.Named("y", types.N).Description("the result of rounding `x` _down_"), - ), - Categories: number, -} +var Floor = v1.Floor -var Abs = &Builtin{ - Name: "abs", - Description: "Returns the number without its sign.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - ), - types.Named("y", types.N).Description("the absolute value of `x`"), - ), - Categories: number, -} +var Abs = v1.Abs -var Rem = &Builtin{ - Name: "rem", - Infix: "%", - Description: "Returns the remainder for of `x` divided by `y`, for `y != 0`.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - types.Named("y", types.N), - ), - types.Named("z", types.N).Description("the remainder"), - ), - Categories: number, -} +var Rem = v1.Rem /** * Bitwise */ -var BitsOr = &Builtin{ - Name: "bits.or", - Description: "Returns the bitwise \"OR\" of two integers.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - types.Named("y", types.N), - ), - types.Named("z", types.N), - ), -} +var BitsOr = v1.BitsOr -var BitsAnd = &Builtin{ - Name: "bits.and", - Description: "Returns the bitwise \"AND\" of two integers.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - types.Named("y", types.N), - ), - types.Named("z", types.N), - ), -} +var BitsAnd = v1.BitsAnd -var BitsNegate = &Builtin{ - Name: "bits.negate", - Description: "Returns the bitwise negation (flip) of an integer.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - ), - types.Named("z", types.N), - ), -} +var BitsNegate = v1.BitsNegate -var BitsXOr = &Builtin{ - Name: "bits.xor", - Description: "Returns the bitwise \"XOR\" (exclusive-or) of two integers.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - types.Named("y", types.N), - ), - types.Named("z", types.N), - ), -} +var BitsXOr = v1.BitsXOr -var BitsShiftLeft = &Builtin{ - Name: "bits.lsh", - Description: "Returns a new integer with its bits shifted `s` bits to the left.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - types.Named("s", types.N), - ), - types.Named("z", types.N), - ), -} +var BitsShiftLeft = v1.BitsShiftLeft -var BitsShiftRight = &Builtin{ - Name: "bits.rsh", - Description: "Returns a new integer with its bits shifted `s` bits to the right.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.N), - types.Named("s", types.N), - ), - types.Named("z", types.N), - ), -} +var BitsShiftRight = v1.BitsShiftRight /** * Sets */ -var sets = category("sets") - -var And = &Builtin{ - Name: "and", - Infix: "&", - Description: "Returns the intersection of two sets.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.NewSet(types.A)), - types.Named("y", types.NewSet(types.A)), - ), - types.Named("z", types.NewSet(types.A)).Description("the intersection of `x` and `y`"), - ), - Categories: sets, -} +var And = v1.And // Or performs a union operation on sets. -var Or = &Builtin{ - Name: "or", - Infix: "|", - Description: "Returns the union of two sets.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.NewSet(types.A)), - types.Named("y", types.NewSet(types.A)), - ), - types.Named("z", types.NewSet(types.A)).Description("the union of `x` and `y`"), - ), - Categories: sets, -} +var Or = v1.Or -var Intersection = &Builtin{ - Name: "intersection", - Description: "Returns the intersection of the given input sets.", - Decl: types.NewFunction( - types.Args( - types.Named("xs", types.NewSet(types.NewSet(types.A))).Description("set of sets to intersect"), - ), - types.Named("y", types.NewSet(types.A)).Description("the intersection of all `xs` sets"), - ), - Categories: sets, -} +var Intersection = v1.Intersection -var Union = &Builtin{ - Name: "union", - Description: "Returns the union of the given input sets.", - Decl: types.NewFunction( - types.Args( - types.Named("xs", types.NewSet(types.NewSet(types.A))).Description("set of sets to merge"), - ), - types.Named("y", types.NewSet(types.A)).Description("the union of all `xs` sets"), - ), - Categories: sets, -} +var Union = v1.Union /** * Aggregates */ -var aggregates = category("aggregates") - -var Count = &Builtin{ - Name: "count", - Description: " Count takes a collection or string and returns the number of elements (or characters) in it.", - Decl: types.NewFunction( - types.Args( - types.Named("collection", types.NewAny( - types.NewSet(types.A), - types.NewArray(nil, types.A), - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - types.S, - )).Description("the set/array/object/string to be counted"), - ), - types.Named("n", types.N).Description("the count of elements, key/val pairs, or characters, respectively."), - ), - Categories: aggregates, -} +var Count = v1.Count -var Sum = &Builtin{ - Name: "sum", - Description: "Sums elements of an array or set of numbers.", - Decl: types.NewFunction( - types.Args( - types.Named("collection", types.NewAny( - types.NewSet(types.N), - types.NewArray(nil, types.N), - )), - ), - types.Named("n", types.N).Description("the sum of all elements"), - ), - Categories: aggregates, -} +var Sum = v1.Sum -var Product = &Builtin{ - Name: "product", - Description: "Muliplies elements of an array or set of numbers", - Decl: types.NewFunction( - types.Args( - types.Named("collection", types.NewAny( - types.NewSet(types.N), - types.NewArray(nil, types.N), - )), - ), - types.Named("n", types.N).Description("the product of all elements"), - ), - Categories: aggregates, -} +var Product = v1.Product -var Max = &Builtin{ - Name: "max", - Description: "Returns the maximum value in a collection.", - Decl: types.NewFunction( - types.Args( - types.Named("collection", types.NewAny( - types.NewSet(types.A), - types.NewArray(nil, types.A), - )), - ), - types.Named("n", types.A).Description("the maximum of all elements"), - ), - Categories: aggregates, -} +var Max = v1.Max -var Min = &Builtin{ - Name: "min", - Description: "Returns the minimum value in a collection.", - Decl: types.NewFunction( - types.Args( - types.Named("collection", types.NewAny( - types.NewSet(types.A), - types.NewArray(nil, types.A), - )), - ), - types.Named("n", types.A).Description("the minimum of all elements"), - ), - Categories: aggregates, -} +var Min = v1.Min /** * Sorting */ -var Sort = &Builtin{ - Name: "sort", - Description: "Returns a sorted array.", - Decl: types.NewFunction( - types.Args( - types.Named("collection", types.NewAny( - types.NewArray(nil, types.A), - types.NewSet(types.A), - )).Description("the array or set to be sorted"), - ), - types.Named("n", types.NewArray(nil, types.A)).Description("the sorted array"), - ), - Categories: aggregates, -} +var Sort = v1.Sort /** * Arrays */ -var ArrayConcat = &Builtin{ - Name: "array.concat", - Description: "Concatenates two arrays.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.NewArray(nil, types.A)), - types.Named("y", types.NewArray(nil, types.A)), - ), - types.Named("z", types.NewArray(nil, types.A)).Description("the concatenation of `x` and `y`"), - ), -} +var ArrayConcat = v1.ArrayConcat -var ArraySlice = &Builtin{ - Name: "array.slice", - Description: "Returns a slice of a given array. If `start` is greater or equal than `stop`, `slice` is `[]`.", - Decl: types.NewFunction( - types.Args( - types.Named("arr", types.NewArray(nil, types.A)).Description("the array to be sliced"), - types.Named("start", types.NewNumber()).Description("the start index of the returned slice; if less than zero, it's clamped to 0"), - types.Named("stop", types.NewNumber()).Description("the stop index of the returned slice; if larger than `count(arr)`, it's clamped to `count(arr)`"), - ), - types.Named("slice", types.NewArray(nil, types.A)).Description("the subslice of `array`, from `start` to `end`, including `arr[start]`, but excluding `arr[end]`"), - ), -} // NOTE(sr): this function really needs examples - -var ArrayReverse = &Builtin{ - Name: "array.reverse", - Description: "Returns the reverse of a given array.", - Decl: types.NewFunction( - types.Args( - types.Named("arr", types.NewArray(nil, types.A)).Description("the array to be reversed"), - ), - types.Named("rev", types.NewArray(nil, types.A)).Description("an array containing the elements of `arr` in reverse order"), - ), -} +var ArraySlice = v1.ArraySlice + +var ArrayReverse = v1.ArrayReverse /** * Conversions */ -var conversions = category("conversions") - -var ToNumber = &Builtin{ - Name: "to_number", - Description: "Converts a string, bool, or number value to a number: Strings are converted to numbers using `strconv.Atoi`, Boolean `false` is converted to 0 and `true` is converted to 1.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.NewAny( - types.N, - types.S, - types.B, - types.NewNull(), - )), - ), - types.Named("num", types.N), - ), - Categories: conversions, -} + +var ToNumber = v1.ToNumber /** * Regular Expressions */ -var RegexMatch = &Builtin{ - Name: "regex.match", - Description: "Matches a string against a regular expression.", - Decl: types.NewFunction( - types.Args( - types.Named("pattern", types.S).Description("regular expression"), - types.Named("value", types.S).Description("value to match against `pattern`"), - ), - types.Named("result", types.B), - ), -} +var RegexMatch = v1.RegexMatch -var RegexIsValid = &Builtin{ - Name: "regex.is_valid", - Description: "Checks if a string is a valid regular expression: the detailed syntax for patterns is defined by https://github.com/google/re2/wiki/Syntax.", - Decl: types.NewFunction( - types.Args( - types.Named("pattern", types.S).Description("regular expression"), - ), - types.Named("result", types.B), - ), -} +var RegexIsValid = v1.RegexIsValid -var RegexFindAllStringSubmatch = &Builtin{ - Name: "regex.find_all_string_submatch_n", - Description: "Returns all successive matches of the expression.", - Decl: types.NewFunction( - types.Args( - types.Named("pattern", types.S).Description("regular expression"), - types.Named("value", types.S).Description("string to match"), - types.Named("number", types.N).Description("number of matches to return; `-1` means all matches"), - ), - types.Named("output", types.NewArray(nil, types.NewArray(nil, types.S))), - ), -} +var RegexFindAllStringSubmatch = v1.RegexFindAllStringSubmatch -var RegexTemplateMatch = &Builtin{ - Name: "regex.template_match", - Description: "Matches a string against a pattern, where there pattern may be glob-like", - Decl: types.NewFunction( - types.Args( - types.Named("template", types.S).Description("template expression containing `0..n` regular expressions"), - types.Named("value", types.S).Description("string to match"), - types.Named("delimiter_start", types.S).Description("start delimiter of the regular expression in `template`"), - types.Named("delimiter_end", types.S).Description("end delimiter of the regular expression in `template`"), - ), - types.Named("result", types.B), - ), -} // TODO(sr): example:`regex.template_match("urn:foo:{.*}", "urn:foo:bar:baz", "{", "}")`` returns ``true``. - -var RegexSplit = &Builtin{ - Name: "regex.split", - Description: "Splits the input string by the occurrences of the given pattern.", - Decl: types.NewFunction( - types.Args( - types.Named("pattern", types.S).Description("regular expression"), - types.Named("value", types.S).Description("string to match"), - ), - types.Named("output", types.NewArray(nil, types.S)).Description("the parts obtained by splitting `value`"), - ), -} +var RegexTemplateMatch = v1.RegexTemplateMatch + +var RegexSplit = v1.RegexSplit // RegexFind takes two strings and a number, the pattern, the value and number of match values to // return, -1 means all match values. -var RegexFind = &Builtin{ - Name: "regex.find_n", - Description: "Returns the specified number of matches when matching the input against the pattern.", - Decl: types.NewFunction( - types.Args( - types.Named("pattern", types.S).Description("regular expression"), - types.Named("value", types.S).Description("string to match"), - types.Named("number", types.N).Description("number of matches to return, if `-1`, returns all matches"), - ), - types.Named("output", types.NewArray(nil, types.S)).Description("collected matches"), - ), -} +var RegexFind = v1.RegexFind // GlobsMatch takes two strings regexp-style strings and evaluates to true if their // intersection matches a non-empty set of non-empty strings. // Examples: // - "a.a." and ".b.b" -> true. // - "[a-z]*" and [0-9]+" -> not true. -var GlobsMatch = &Builtin{ - Name: "regex.globs_match", - Description: `Checks if the intersection of two glob-style regular expressions matches a non-empty set of non-empty strings. -The set of regex symbols is limited for this builtin: only ` + "`.`, `*`, `+`, `[`, `-`, `]` and `\\` are treated as special symbols.", - Decl: types.NewFunction( - types.Args( - types.Named("glob1", types.S), - types.Named("glob2", types.S), - ), - types.Named("result", types.B), - ), -} +var GlobsMatch = v1.GlobsMatch /** * Strings */ -var stringsCat = category("strings") - -var AnyPrefixMatch = &Builtin{ - Name: "strings.any_prefix_match", - Description: "Returns true if any of the search strings begins with any of the base strings.", - Decl: types.NewFunction( - types.Args( - types.Named("search", types.NewAny( - types.S, - types.NewSet(types.S), - types.NewArray(nil, types.S), - )).Description("search string(s)"), - types.Named("base", types.NewAny( - types.S, - types.NewSet(types.S), - types.NewArray(nil, types.S), - )).Description("base string(s)"), - ), - types.Named("result", types.B).Description("result of the prefix check"), - ), - Categories: stringsCat, -} -var AnySuffixMatch = &Builtin{ - Name: "strings.any_suffix_match", - Description: "Returns true if any of the search strings ends with any of the base strings.", - Decl: types.NewFunction( - types.Args( - types.Named("search", types.NewAny( - types.S, - types.NewSet(types.S), - types.NewArray(nil, types.S), - )).Description("search string(s)"), - types.Named("base", types.NewAny( - types.S, - types.NewSet(types.S), - types.NewArray(nil, types.S), - )).Description("base string(s)"), - ), - types.Named("result", types.B).Description("result of the suffix check"), - ), - Categories: stringsCat, -} +var AnyPrefixMatch = v1.AnyPrefixMatch -var Concat = &Builtin{ - Name: "concat", - Description: "Joins a set or array of strings with a delimiter.", - Decl: types.NewFunction( - types.Args( - types.Named("delimiter", types.S), - types.Named("collection", types.NewAny( - types.NewSet(types.S), - types.NewArray(nil, types.S), - )).Description("strings to join"), - ), - types.Named("output", types.S), - ), - Categories: stringsCat, -} +var AnySuffixMatch = v1.AnySuffixMatch -var FormatInt = &Builtin{ - Name: "format_int", - Description: "Returns the string representation of the number in the given base after rounding it down to an integer value.", - Decl: types.NewFunction( - types.Args( - types.Named("number", types.N).Description("number to format"), - types.Named("base", types.N).Description("base of number representation to use"), - ), - types.Named("output", types.S).Description("formatted number"), - ), - Categories: stringsCat, -} +var Concat = v1.Concat -var IndexOf = &Builtin{ - Name: "indexof", - Description: "Returns the index of a substring contained inside a string.", - Decl: types.NewFunction( - types.Args( - types.Named("haystack", types.S).Description("string to search in"), - types.Named("needle", types.S).Description("substring to look for"), - ), - types.Named("output", types.N).Description("index of first occurrence, `-1` if not found"), - ), - Categories: stringsCat, -} +var FormatInt = v1.FormatInt -var IndexOfN = &Builtin{ - Name: "indexof_n", - Description: "Returns a list of all the indexes of a substring contained inside a string.", - Decl: types.NewFunction( - types.Args( - types.Named("haystack", types.S).Description("string to search in"), - types.Named("needle", types.S).Description("substring to look for"), - ), - types.Named("output", types.NewArray(nil, types.N)).Description("all indices at which `needle` occurs in `haystack`, may be empty"), - ), - Categories: stringsCat, -} +var IndexOf = v1.IndexOf -var Substring = &Builtin{ - Name: "substring", - Description: "Returns the portion of a string for a given `offset` and a `length`. If `length < 0`, `output` is the remainder of the string.", - Decl: types.NewFunction( - types.Args( - types.Named("value", types.S), - types.Named("offset", types.N).Description("offset, must be positive"), - types.Named("length", types.N).Description("length of the substring starting from `offset`"), - ), - types.Named("output", types.S).Description("substring of `value` from `offset`, of length `length`"), - ), - Categories: stringsCat, -} +var IndexOfN = v1.IndexOfN -var Contains = &Builtin{ - Name: "contains", - Description: "Returns `true` if the search string is included in the base string", - Decl: types.NewFunction( - types.Args( - types.Named("haystack", types.S).Description("string to search in"), - types.Named("needle", types.S).Description("substring to look for"), - ), - types.Named("result", types.B).Description("result of the containment check"), - ), - Categories: stringsCat, -} +var Substring = v1.Substring -var StringCount = &Builtin{ - Name: "strings.count", - Description: "Returns the number of non-overlapping instances of a substring in a string.", - Decl: types.NewFunction( - types.Args( - types.Named("search", types.S).Description("string to search in"), - types.Named("substring", types.S).Description("substring to look for"), - ), - types.Named("output", types.N).Description("count of occurrences, `0` if not found"), - ), - Categories: stringsCat, -} +var Contains = v1.Contains -var StartsWith = &Builtin{ - Name: "startswith", - Description: "Returns true if the search string begins with the base string.", - Decl: types.NewFunction( - types.Args( - types.Named("search", types.S).Description("search string"), - types.Named("base", types.S).Description("base string"), - ), - types.Named("result", types.B).Description("result of the prefix check"), - ), - Categories: stringsCat, -} +var StringCount = v1.StringCount -var EndsWith = &Builtin{ - Name: "endswith", - Description: "Returns true if the search string ends with the base string.", - Decl: types.NewFunction( - types.Args( - types.Named("search", types.S).Description("search string"), - types.Named("base", types.S).Description("base string"), - ), - types.Named("result", types.B).Description("result of the suffix check"), - ), - Categories: stringsCat, -} +var StartsWith = v1.StartsWith -var Lower = &Builtin{ - Name: "lower", - Description: "Returns the input string but with all characters in lower-case.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("string that is converted to lower-case"), - ), - types.Named("y", types.S).Description("lower-case of x"), - ), - Categories: stringsCat, -} +var EndsWith = v1.EndsWith -var Upper = &Builtin{ - Name: "upper", - Description: "Returns the input string but with all characters in upper-case.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("string that is converted to upper-case"), - ), - types.Named("y", types.S).Description("upper-case of x"), - ), - Categories: stringsCat, -} +var Lower = v1.Lower -var Split = &Builtin{ - Name: "split", - Description: "Split returns an array containing elements of the input string split on a delimiter.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("string that is split"), - types.Named("delimiter", types.S).Description("delimiter used for splitting"), - ), - types.Named("ys", types.NewArray(nil, types.S)).Description("split parts"), - ), - Categories: stringsCat, -} +var Upper = v1.Upper -var Replace = &Builtin{ - Name: "replace", - Description: "Replace replaces all instances of a sub-string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("string being processed"), - types.Named("old", types.S).Description("substring to replace"), - types.Named("new", types.S).Description("string to replace `old` with"), - ), - types.Named("y", types.S).Description("string with replaced substrings"), - ), - Categories: stringsCat, -} +var Split = v1.Split -var ReplaceN = &Builtin{ - Name: "strings.replace_n", - Description: `Replaces a string from a list of old, new string pairs. -Replacements are performed in the order they appear in the target string, without overlapping matches. -The old string comparisons are done in argument order.`, - Decl: types.NewFunction( - types.Args( - types.Named("patterns", types.NewObject( - nil, - types.NewDynamicProperty( - types.S, - types.S)), - ).Description("replacement pairs"), - types.Named("value", types.S).Description("string to replace substring matches in"), - ), - types.Named("output", types.S), - ), -} +var Replace = v1.Replace -var RegexReplace = &Builtin{ - Name: "regex.replace", - Description: `Find and replaces the text using the regular expression pattern.`, - Decl: types.NewFunction( - types.Args( - types.Named("s", types.S).Description("string being processed"), - types.Named("pattern", types.S).Description("regex pattern to be applied"), - types.Named("value", types.S).Description("regex value"), - ), - types.Named("output", types.S), - ), -} +var ReplaceN = v1.ReplaceN -var Trim = &Builtin{ - Name: "trim", - Description: "Returns `value` with all leading or trailing instances of the `cutset` characters removed.", - Decl: types.NewFunction( - types.Args( - types.Named("value", types.S).Description("string to trim"), - types.Named("cutset", types.S).Description("string of characters that are cut off"), - ), - types.Named("output", types.S).Description("string trimmed of `cutset` characters"), - ), - Categories: stringsCat, -} +var RegexReplace = v1.RegexReplace -var TrimLeft = &Builtin{ - Name: "trim_left", - Description: "Returns `value` with all leading instances of the `cutset` characters removed.", - Decl: types.NewFunction( - types.Args( - types.Named("value", types.S).Description("string to trim"), - types.Named("cutset", types.S).Description("string of characters that are cut off on the left"), - ), - types.Named("output", types.S).Description("string left-trimmed of `cutset` characters"), - ), - Categories: stringsCat, -} +var Trim = v1.Trim -var TrimPrefix = &Builtin{ - Name: "trim_prefix", - Description: "Returns `value` without the prefix. If `value` doesn't start with `prefix`, it is returned unchanged.", - Decl: types.NewFunction( - types.Args( - types.Named("value", types.S).Description("string to trim"), - types.Named("prefix", types.S).Description("prefix to cut off"), - ), - types.Named("output", types.S).Description("string with `prefix` cut off"), - ), - Categories: stringsCat, -} +var TrimLeft = v1.TrimLeft -var TrimRight = &Builtin{ - Name: "trim_right", - Description: "Returns `value` with all trailing instances of the `cutset` characters removed.", - Decl: types.NewFunction( - types.Args( - types.Named("value", types.S).Description("string to trim"), - types.Named("cutset", types.S).Description("string of characters that are cut off on the right"), - ), - types.Named("output", types.S).Description("string right-trimmed of `cutset` characters"), - ), - Categories: stringsCat, -} +var TrimPrefix = v1.TrimPrefix -var TrimSuffix = &Builtin{ - Name: "trim_suffix", - Description: "Returns `value` without the suffix. If `value` doesn't end with `suffix`, it is returned unchanged.", - Decl: types.NewFunction( - types.Args( - types.Named("value", types.S).Description("string to trim"), - types.Named("suffix", types.S).Description("suffix to cut off"), - ), - types.Named("output", types.S).Description("string with `suffix` cut off"), - ), - Categories: stringsCat, -} +var TrimRight = v1.TrimRight -var TrimSpace = &Builtin{ - Name: "trim_space", - Description: "Return the given string with all leading and trailing white space removed.", - Decl: types.NewFunction( - types.Args( - types.Named("value", types.S).Description("string to trim"), - ), - types.Named("output", types.S).Description("string leading and trailing white space cut off"), - ), - Categories: stringsCat, -} +var TrimSuffix = v1.TrimSuffix -var Sprintf = &Builtin{ - Name: "sprintf", - Description: "Returns the given string, formatted.", - Decl: types.NewFunction( - types.Args( - types.Named("format", types.S).Description("string with formatting verbs"), - types.Named("values", types.NewArray(nil, types.A)).Description("arguments to format into formatting verbs"), - ), - types.Named("output", types.S).Description("`format` formatted by the values in `values`"), - ), - Categories: stringsCat, -} +var TrimSpace = v1.TrimSpace -var StringReverse = &Builtin{ - Name: "strings.reverse", - Description: "Reverses a given string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S), - ), - Categories: stringsCat, -} +var Sprintf = v1.Sprintf -var RenderTemplate = &Builtin{ - Name: "strings.render_template", - Description: `Renders a templated string with given template variables injected. For a given templated string and key/value mapping, values will be injected into the template where they are referenced by key. - For examples of templating syntax, see https://pkg.go.dev/text/template`, - Decl: types.NewFunction( - types.Args( - types.Named("value", types.S).Description("a templated string"), - types.Named("vars", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("a mapping of template variable keys to values"), - ), - types.Named("result", types.S).Description("rendered template with template variables injected"), - ), - Categories: stringsCat, -} +var StringReverse = v1.StringReverse + +var RenderTemplate = v1.RenderTemplate /** * Numbers @@ -1354,82 +236,19 @@ var RenderTemplate = &Builtin{ // RandIntn returns a random number 0 - n // Marked non-deterministic because it relies on RNG internally. -var RandIntn = &Builtin{ - Name: "rand.intn", - Description: "Returns a random integer between `0` and `n` (`n` exclusive). If `n` is `0`, then `y` is always `0`. For any given argument pair (`str`, `n`), the output will be consistent throughout a query evaluation.", - Decl: types.NewFunction( - types.Args( - types.Named("str", types.S), - types.Named("n", types.N), - ), - types.Named("y", types.N).Description("random integer in the range `[0, abs(n))`"), - ), - Categories: number, - Nondeterministic: true, -} +var RandIntn = v1.RandIntn -var NumbersRange = &Builtin{ - Name: "numbers.range", - Description: "Returns an array of numbers in the given (inclusive) range. If `a==b`, then `range == [a]`; if `a > b`, then `range` is in descending order.", - Decl: types.NewFunction( - types.Args( - types.Named("a", types.N), - types.Named("b", types.N), - ), - types.Named("range", types.NewArray(nil, types.N)).Description("the range between `a` and `b`"), - ), -} +var NumbersRange = v1.NumbersRange -var NumbersRangeStep = &Builtin{ - Name: "numbers.range_step", - Description: `Returns an array of numbers in the given (inclusive) range incremented by a positive step. - If "a==b", then "range == [a]"; if "a > b", then "range" is in descending order. - If the provided "step" is less then 1, an error will be thrown. - If "b" is not in the range of the provided "step", "b" won't be included in the result. - `, - Decl: types.NewFunction( - types.Args( - types.Named("a", types.N), - types.Named("b", types.N), - types.Named("step", types.N), - ), - types.Named("range", types.NewArray(nil, types.N)).Description("the range between `a` and `b` in `step` increments"), - ), -} +var NumbersRangeStep = v1.NumbersRangeStep /** * Units */ -var UnitsParse = &Builtin{ - Name: "units.parse", - Description: `Converts strings like "10G", "5K", "4M", "1500m" and the like into a number. -This number can be a non-integer, such as 1.5, 0.22, etc. Supports standard metric decimal and -binary SI units (e.g., K, Ki, M, Mi, G, Gi etc.) m, K, M, G, T, P, and E are treated as decimal -units and Ki, Mi, Gi, Ti, Pi, and Ei are treated as binary units. - -Note that 'm' and 'M' are case-sensitive, to allow distinguishing between "milli" and "mega" units respectively. Other units are case-insensitive.`, - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("the unit to parse"), - ), - types.Named("y", types.N).Description("the parsed number"), - ), -} +var UnitsParse = v1.UnitsParse -var UnitsParseBytes = &Builtin{ - Name: "units.parse_bytes", - Description: `Converts strings like "10GB", "5K", "4mb" into an integer number of bytes. -Supports standard byte units (e.g., KB, KiB, etc.) KB, MB, GB, and TB are treated as decimal -units and KiB, MiB, GiB, and TiB are treated as binary units. The bytes symbol (b/B) in the -unit is optional and omitting it wil give the same result (e.g. Mi and MiB).`, - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("the byte unit to parse"), - ), - types.Named("y", types.N).Description("the parsed number"), - ), -} +var UnitsParseBytes = v1.UnitsParseBytes // /** @@ -1438,1372 +257,241 @@ unit is optional and omitting it wil give the same result (e.g. Mi and MiB).`, // UUIDRFC4122 returns a version 4 UUID string. // Marked non-deterministic because it relies on RNG internally. -var UUIDRFC4122 = &Builtin{ - Name: "uuid.rfc4122", - Description: "Returns a new UUIDv4.", - Decl: types.NewFunction( - types.Args( - types.Named("k", types.S), - ), - types.Named("output", types.S).Description("a version 4 UUID; for any given `k`, the output will be consistent throughout a query evaluation"), - ), - Nondeterministic: true, -} +var UUIDRFC4122 = v1.UUIDRFC4122 -var UUIDParse = &Builtin{ - Name: "uuid.parse", - Description: "Parses the string value as an UUID and returns an object with the well-defined fields of the UUID if valid.", - Categories: nil, - Decl: types.NewFunction( - types.Args( - types.Named("uuid", types.S), - ), - types.Named("result", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("Properties of UUID if valid (version, variant, etc). Undefined otherwise."), - ), - Relation: false, -} +var UUIDParse = v1.UUIDParse /** * JSON */ -var objectCat = category("object") - -var JSONFilter = &Builtin{ - Name: "json.filter", - Description: "Filters the object. " + - "For example: `json.filter({\"a\": {\"b\": \"x\", \"c\": \"y\"}}, [\"a/b\"])` will result in `{\"a\": {\"b\": \"x\"}}`). " + - "Paths are not filtered in-order and are deduplicated before being evaluated.", - Decl: types.NewFunction( - types.Args( - types.Named("object", types.NewObject( - nil, - types.NewDynamicProperty(types.A, types.A), - )), - types.Named("paths", types.NewAny( - types.NewArray( - nil, - types.NewAny( - types.S, - types.NewArray( - nil, - types.A, - ), - ), - ), - types.NewSet( - types.NewAny( - types.S, - types.NewArray( - nil, - types.A, - ), - ), - ), - )).Description("JSON string paths"), - ), - types.Named("filtered", types.A).Description("remaining data from `object` with only keys specified in `paths`"), - ), - Categories: objectCat, -} +var JSONFilter = v1.JSONFilter -var JSONRemove = &Builtin{ - Name: "json.remove", - Description: "Removes paths from an object. " + - "For example: `json.remove({\"a\": {\"b\": \"x\", \"c\": \"y\"}}, [\"a/b\"])` will result in `{\"a\": {\"c\": \"y\"}}`. " + - "Paths are not removed in-order and are deduplicated before being evaluated.", - Decl: types.NewFunction( - types.Args( - types.Named("object", types.NewObject( - nil, - types.NewDynamicProperty(types.A, types.A), - )), - types.Named("paths", types.NewAny( - types.NewArray( - nil, - types.NewAny( - types.S, - types.NewArray( - nil, - types.A, - ), - ), - ), - types.NewSet( - types.NewAny( - types.S, - types.NewArray( - nil, - types.A, - ), - ), - ), - )).Description("JSON string paths"), - ), - types.Named("output", types.A).Description("result of removing all keys specified in `paths`"), - ), - Categories: objectCat, -} +var JSONRemove = v1.JSONRemove -var JSONPatch = &Builtin{ - Name: "json.patch", - Description: "Patches an object according to RFC6902. " + - "For example: `json.patch({\"a\": {\"foo\": 1}}, [{\"op\": \"add\", \"path\": \"/a/bar\", \"value\": 2}])` results in `{\"a\": {\"foo\": 1, \"bar\": 2}`. " + - "The patches are applied atomically: if any of them fails, the result will be undefined. " + - "Additionally works on sets, where a value contained in the set is considered to be its path.", - Decl: types.NewFunction( - types.Args( - types.Named("object", types.A), // TODO(sr): types.A? - types.Named("patches", types.NewArray( - nil, - types.NewObject( - []*types.StaticProperty{ - {Key: "op", Value: types.S}, - {Key: "path", Value: types.A}, - }, - types.NewDynamicProperty(types.A, types.A), - ), - )), - ), - types.Named("output", types.A).Description("result obtained after consecutively applying all patch operations in `patches`"), - ), - Categories: objectCat, -} +var JSONPatch = v1.JSONPatch -var ObjectSubset = &Builtin{ - Name: "object.subset", - Description: "Determines if an object `sub` is a subset of another object `super`." + - "Object `sub` is a subset of object `super` if and only if every key in `sub` is also in `super`, " + - "**and** for all keys which `sub` and `super` share, they have the same value. " + - "This function works with objects, sets, arrays and a set of array and set." + - "If both arguments are objects, then the operation is recursive, e.g. " + - "`{\"c\": {\"x\": {10, 15, 20}}` is a subset of `{\"a\": \"b\", \"c\": {\"x\": {10, 15, 20, 25}, \"y\": \"z\"}`. " + - "If both arguments are sets, then this function checks if every element of `sub` is a member of `super`, " + - "but does not attempt to recurse. If both arguments are arrays, " + - "then this function checks if `sub` appears contiguously in order within `super`, " + - "and also does not attempt to recurse. If `super` is array and `sub` is set, " + - "then this function checks if `super` contains every element of `sub` with no consideration of ordering, " + - "and also does not attempt to recurse.", - Decl: types.NewFunction( - types.Args( - types.Named("super", types.NewAny(types.NewObject( - nil, - types.NewDynamicProperty(types.A, types.A), - ), - types.NewSet(types.A), - types.NewArray(nil, types.A), - )).Description("object to test if sub is a subset of"), - types.Named("sub", types.NewAny(types.NewObject( - nil, - types.NewDynamicProperty(types.A, types.A), - ), - types.NewSet(types.A), - types.NewArray(nil, types.A), - )).Description("object to test if super is a superset of"), - ), - types.Named("result", types.A).Description("`true` if `sub` is a subset of `super`"), - ), -} +var ObjectSubset = v1.ObjectSubset -var ObjectUnion = &Builtin{ - Name: "object.union", - Description: "Creates a new object of the asymmetric union of two objects. " + - "For example: `object.union({\"a\": 1, \"b\": 2, \"c\": {\"d\": 3}}, {\"a\": 7, \"c\": {\"d\": 4, \"e\": 5}})` will result in `{\"a\": 7, \"b\": 2, \"c\": {\"d\": 4, \"e\": 5}}`.", - Decl: types.NewFunction( - types.Args( - types.Named("a", types.NewObject( - nil, - types.NewDynamicProperty(types.A, types.A), - )), - types.Named("b", types.NewObject( - nil, - types.NewDynamicProperty(types.A, types.A), - )), - ), - types.Named("output", types.A).Description("a new object which is the result of an asymmetric recursive union of two objects where conflicts are resolved by choosing the key from the right-hand object `b`"), - ), // TODO(sr): types.A? ^^^^^^^ (also below) -} +var ObjectUnion = v1.ObjectUnion -var ObjectUnionN = &Builtin{ - Name: "object.union_n", - Description: "Creates a new object that is the asymmetric union of all objects merged from left to right. " + - "For example: `object.union_n([{\"a\": 1}, {\"b\": 2}, {\"a\": 3}])` will result in `{\"b\": 2, \"a\": 3}`.", - Decl: types.NewFunction( - types.Args( - types.Named("objects", types.NewArray( - nil, - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - )), - ), - types.Named("output", types.A).Description("asymmetric recursive union of all objects in `objects`, merged from left to right, where conflicts are resolved by choosing the key from the right-hand object"), - ), -} +var ObjectUnionN = v1.ObjectUnionN -var ObjectRemove = &Builtin{ - Name: "object.remove", - Description: "Removes specified keys from an object.", - Decl: types.NewFunction( - types.Args( - types.Named("object", types.NewObject( - nil, - types.NewDynamicProperty(types.A, types.A), - )).Description("object to remove keys from"), - types.Named("keys", types.NewAny( - types.NewArray(nil, types.A), - types.NewSet(types.A), - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - )).Description("keys to remove from x"), - ), - types.Named("output", types.A).Description("result of removing the specified `keys` from `object`"), - ), -} +var ObjectRemove = v1.ObjectRemove -var ObjectFilter = &Builtin{ - Name: "object.filter", - Description: "Filters the object by keeping only specified keys. " + - "For example: `object.filter({\"a\": {\"b\": \"x\", \"c\": \"y\"}, \"d\": \"z\"}, [\"a\"])` will result in `{\"a\": {\"b\": \"x\", \"c\": \"y\"}}`).", - Decl: types.NewFunction( - types.Args( - types.Named("object", types.NewObject( - nil, - types.NewDynamicProperty(types.A, types.A), - )).Description("object to filter keys"), - types.Named("keys", types.NewAny( - types.NewArray(nil, types.A), - types.NewSet(types.A), - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - )), - ), - types.Named("filtered", types.A).Description("remaining data from `object` with only keys specified in `keys`"), - ), -} +var ObjectFilter = v1.ObjectFilter -var ObjectGet = &Builtin{ - Name: "object.get", - Description: "Returns value of an object's key if present, otherwise a default. " + - "If the supplied `key` is an `array`, then `object.get` will search through a nested object or array using each key in turn. " + - "For example: `object.get({\"a\": [{ \"b\": true }]}, [\"a\", 0, \"b\"], false)` results in `true`.", - Decl: types.NewFunction( - types.Args( - types.Named("object", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("object to get `key` from"), - types.Named("key", types.A).Description("key to lookup in `object`"), - types.Named("default", types.A).Description("default to use when lookup fails"), - ), - types.Named("value", types.A).Description("`object[key]` if present, otherwise `default`"), - ), -} +var ObjectGet = v1.ObjectGet -var ObjectKeys = &Builtin{ - Name: "object.keys", - Description: "Returns a set of an object's keys. " + - "For example: `object.keys({\"a\": 1, \"b\": true, \"c\": \"d\")` results in `{\"a\", \"b\", \"c\"}`.", - Decl: types.NewFunction( - types.Args( - types.Named("object", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("object to get keys from"), - ), - types.Named("value", types.NewSet(types.A)).Description("set of `object`'s keys"), - ), -} +var ObjectKeys = v1.ObjectKeys /* * Encoding */ -var encoding = category("encoding") - -var JSONMarshal = &Builtin{ - Name: "json.marshal", - Description: "Serializes the input term to JSON.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A).Description("the term to serialize"), - ), - types.Named("y", types.S).Description("the JSON string representation of `x`"), - ), - Categories: encoding, -} -var JSONMarshalWithOptions = &Builtin{ - Name: "json.marshal_with_options", - Description: "Serializes the input term JSON, with additional formatting options via the `opts` parameter. " + - "`opts` accepts keys `pretty` (enable multi-line/formatted JSON), `prefix` (string to prefix lines with, default empty string) and `indent` (string to indent with, default `\\t`).", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A).Description("the term to serialize"), - types.Named("opts", types.NewObject( - []*types.StaticProperty{ - types.NewStaticProperty("pretty", types.B), - types.NewStaticProperty("indent", types.S), - types.NewStaticProperty("prefix", types.S), - }, - types.NewDynamicProperty(types.S, types.A), - )).Description("encoding options"), - ), - types.Named("y", types.S).Description("the JSON string representation of `x`, with configured prefix/indent string(s) as appropriate"), - ), - Categories: encoding, -} +var JSONMarshal = v1.JSONMarshal -var JSONUnmarshal = &Builtin{ - Name: "json.unmarshal", - Description: "Deserializes the input string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("a JSON string"), - ), - types.Named("y", types.A).Description("the term deserialized from `x`"), - ), - Categories: encoding, -} +var JSONMarshalWithOptions = v1.JSONMarshalWithOptions -var JSONIsValid = &Builtin{ - Name: "json.is_valid", - Description: "Verifies the input string is a valid JSON document.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("a JSON string"), - ), - types.Named("result", types.B).Description("`true` if `x` is valid JSON, `false` otherwise"), - ), - Categories: encoding, -} +var JSONUnmarshal = v1.JSONUnmarshal -var Base64Encode = &Builtin{ - Name: "base64.encode", - Description: "Serializes the input string into base64 encoding.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("base64 serialization of `x`"), - ), - Categories: encoding, -} +var JSONIsValid = v1.JSONIsValid -var Base64Decode = &Builtin{ - Name: "base64.decode", - Description: "Deserializes the base64 encoded input string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("base64 deserialization of `x`"), - ), - Categories: encoding, -} +var Base64Encode = v1.Base64Encode -var Base64IsValid = &Builtin{ - Name: "base64.is_valid", - Description: "Verifies the input string is base64 encoded.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("result", types.B).Description("`true` if `x` is valid base64 encoded value, `false` otherwise"), - ), - Categories: encoding, -} +var Base64Decode = v1.Base64Decode -var Base64UrlEncode = &Builtin{ - Name: "base64url.encode", - Description: "Serializes the input string into base64url encoding.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("base64url serialization of `x`"), - ), - Categories: encoding, -} +var Base64IsValid = v1.Base64IsValid -var Base64UrlEncodeNoPad = &Builtin{ - Name: "base64url.encode_no_pad", - Description: "Serializes the input string into base64url encoding without padding.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("base64url serialization of `x`"), - ), - Categories: encoding, -} +var Base64UrlEncode = v1.Base64UrlEncode -var Base64UrlDecode = &Builtin{ - Name: "base64url.decode", - Description: "Deserializes the base64url encoded input string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("base64url deserialization of `x`"), - ), - Categories: encoding, -} +var Base64UrlEncodeNoPad = v1.Base64UrlEncodeNoPad -var URLQueryDecode = &Builtin{ - Name: "urlquery.decode", - Description: "Decodes a URL-encoded input string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("URL-encoding deserialization of `x`"), - ), - Categories: encoding, -} +var Base64UrlDecode = v1.Base64UrlDecode -var URLQueryEncode = &Builtin{ - Name: "urlquery.encode", - Description: "Encodes the input string into a URL-encoded string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("URL-encoding serialization of `x`"), - ), - Categories: encoding, -} +var URLQueryDecode = v1.URLQueryDecode -var URLQueryEncodeObject = &Builtin{ - Name: "urlquery.encode_object", - Description: "Encodes the given object into a URL encoded query string.", - Decl: types.NewFunction( - types.Args( - types.Named("object", types.NewObject( - nil, - types.NewDynamicProperty( - types.S, - types.NewAny( - types.S, - types.NewArray(nil, types.S), - types.NewSet(types.S)))))), - types.Named("y", types.S).Description("the URL-encoded serialization of `object`"), - ), - Categories: encoding, -} +var URLQueryEncode = v1.URLQueryEncode -var URLQueryDecodeObject = &Builtin{ - Name: "urlquery.decode_object", - Description: "Decodes the given URL query string into an object.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("the query string"), - ), - types.Named("object", types.NewObject(nil, types.NewDynamicProperty( - types.S, - types.NewArray(nil, types.S)))).Description("the resulting object"), - ), - Categories: encoding, -} +var URLQueryEncodeObject = v1.URLQueryEncodeObject -var YAMLMarshal = &Builtin{ - Name: "yaml.marshal", - Description: "Serializes the input term to YAML.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A).Description("the term to serialize"), - ), - types.Named("y", types.S).Description("the YAML string representation of `x`"), - ), - Categories: encoding, -} +var URLQueryDecodeObject = v1.URLQueryDecodeObject -var YAMLUnmarshal = &Builtin{ - Name: "yaml.unmarshal", - Description: "Deserializes the input string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("a YAML string"), - ), - types.Named("y", types.A).Description("the term deserialized from `x`"), - ), - Categories: encoding, -} +var YAMLMarshal = v1.YAMLMarshal + +var YAMLUnmarshal = v1.YAMLUnmarshal // YAMLIsValid verifies the input string is a valid YAML document. -var YAMLIsValid = &Builtin{ - Name: "yaml.is_valid", - Description: "Verifies the input string is a valid YAML document.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("a YAML string"), - ), - types.Named("result", types.B).Description("`true` if `x` is valid YAML, `false` otherwise"), - ), - Categories: encoding, -} +var YAMLIsValid = v1.YAMLIsValid -var HexEncode = &Builtin{ - Name: "hex.encode", - Description: "Serializes the input string using hex-encoding.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("serialization of `x` using hex-encoding"), - ), - Categories: encoding, -} +var HexEncode = v1.HexEncode -var HexDecode = &Builtin{ - Name: "hex.decode", - Description: "Deserializes the hex-encoded input string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("a hex-encoded string"), - ), - types.Named("y", types.S).Description("deserialized from `x`"), - ), - Categories: encoding, -} +var HexDecode = v1.HexDecode /** * Tokens */ -var tokensCat = category("tokens") - -var JWTDecode = &Builtin{ - Name: "io.jwt.decode", - Description: "Decodes a JSON Web Token and outputs it as an object.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token to decode"), - ), - types.Named("output", types.NewArray([]types.Type{ - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - types.S, - }, nil)).Description("`[header, payload, sig]`, where `header` and `payload` are objects; `sig` is the hexadecimal representation of the signature on the token."), - ), - Categories: tokensCat, -} -var JWTVerifyRS256 = &Builtin{ - Name: "io.jwt.verify_rs256", - Description: "Verifies if a RS256 JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTDecode = v1.JWTDecode -var JWTVerifyRS384 = &Builtin{ - Name: "io.jwt.verify_rs384", - Description: "Verifies if a RS384 JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyRS256 = v1.JWTVerifyRS256 -var JWTVerifyRS512 = &Builtin{ - Name: "io.jwt.verify_rs512", - Description: "Verifies if a RS512 JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyRS384 = v1.JWTVerifyRS384 -var JWTVerifyPS256 = &Builtin{ - Name: "io.jwt.verify_ps256", - Description: "Verifies if a PS256 JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyRS512 = v1.JWTVerifyRS512 -var JWTVerifyPS384 = &Builtin{ - Name: "io.jwt.verify_ps384", - Description: "Verifies if a PS384 JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyPS256 = v1.JWTVerifyPS256 -var JWTVerifyPS512 = &Builtin{ - Name: "io.jwt.verify_ps512", - Description: "Verifies if a PS512 JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyPS384 = v1.JWTVerifyPS384 -var JWTVerifyES256 = &Builtin{ - Name: "io.jwt.verify_es256", - Description: "Verifies if a ES256 JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyPS512 = v1.JWTVerifyPS512 -var JWTVerifyES384 = &Builtin{ - Name: "io.jwt.verify_es384", - Description: "Verifies if a ES384 JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyES256 = v1.JWTVerifyES256 -var JWTVerifyES512 = &Builtin{ - Name: "io.jwt.verify_es512", - Description: "Verifies if a ES512 JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyES384 = v1.JWTVerifyES384 -var JWTVerifyHS256 = &Builtin{ - Name: "io.jwt.verify_hs256", - Description: "Verifies if a HS256 (secret) JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("secret", types.S).Description("plain text secret used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyES512 = v1.JWTVerifyES512 -var JWTVerifyHS384 = &Builtin{ - Name: "io.jwt.verify_hs384", - Description: "Verifies if a HS384 (secret) JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("secret", types.S).Description("plain text secret used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyHS256 = v1.JWTVerifyHS256 -var JWTVerifyHS512 = &Builtin{ - Name: "io.jwt.verify_hs512", - Description: "Verifies if a HS512 (secret) JWT signature is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), - types.Named("secret", types.S).Description("plain text secret used to verify the signature"), - ), - types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), - ), - Categories: tokensCat, -} +var JWTVerifyHS384 = v1.JWTVerifyHS384 -// Marked non-deterministic because it relies on time internally. -var JWTDecodeVerify = &Builtin{ - Name: "io.jwt.decode_verify", - Description: `Verifies a JWT signature under parameterized constraints and decodes the claims if it is valid. -Supports the following algorithms: HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384 and PS512.`, - Decl: types.NewFunction( - types.Args( - types.Named("jwt", types.S).Description("JWT token whose signature is to be verified and whose claims are to be checked"), - types.Named("constraints", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("claim verification constraints"), - ), - types.Named("output", types.NewArray([]types.Type{ - types.B, - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - }, nil)).Description("`[valid, header, payload]`: if the input token is verified and meets the requirements of `constraints` then `valid` is `true`; `header` and `payload` are objects containing the JOSE header and the JWT claim set; otherwise, `valid` is `false`, `header` and `payload` are `{}`"), - ), - Categories: tokensCat, - Nondeterministic: true, -} +var JWTVerifyHS512 = v1.JWTVerifyHS512 -var tokenSign = category("tokensign") +// Marked non-deterministic because it relies on time internally. +var JWTDecodeVerify = v1.JWTDecodeVerify // Marked non-deterministic because it relies on RNG internally. -var JWTEncodeSignRaw = &Builtin{ - Name: "io.jwt.encode_sign_raw", - Description: "Encodes and optionally signs a JSON Web Token.", - Decl: types.NewFunction( - types.Args( - types.Named("headers", types.S).Description("JWS Protected Header"), - types.Named("payload", types.S).Description("JWS Payload"), - types.Named("key", types.S).Description("JSON Web Key (RFC7517)"), - ), - types.Named("output", types.S).Description("signed JWT"), - ), - Categories: tokenSign, - Nondeterministic: true, -} +var JWTEncodeSignRaw = v1.JWTEncodeSignRaw // Marked non-deterministic because it relies on RNG internally. -var JWTEncodeSign = &Builtin{ - Name: "io.jwt.encode_sign", - Description: "Encodes and optionally signs a JSON Web Token. Inputs are taken as objects, not encoded strings (see `io.jwt.encode_sign_raw`).", - Decl: types.NewFunction( - types.Args( - types.Named("headers", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JWS Protected Header"), - types.Named("payload", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JWS Payload"), - types.Named("key", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JSON Web Key (RFC7517)"), - ), - types.Named("output", types.S).Description("signed JWT"), - ), - Categories: tokenSign, - Nondeterministic: true, -} +var JWTEncodeSign = v1.JWTEncodeSign /** * Time */ // Marked non-deterministic because it relies on time directly. -var NowNanos = &Builtin{ - Name: "time.now_ns", - Description: "Returns the current time since epoch in nanoseconds.", - Decl: types.NewFunction( - nil, - types.Named("now", types.N).Description("nanoseconds since epoch"), - ), - Nondeterministic: true, -} +var NowNanos = v1.NowNanos -var ParseNanos = &Builtin{ - Name: "time.parse_ns", - Description: "Returns the time in nanoseconds parsed from the string in the given format. `undefined` if the result would be outside the valid time range that can fit within an `int64`.", - Decl: types.NewFunction( - types.Args( - types.Named("layout", types.S).Description("format used for parsing, see the [Go `time` package documentation](https://golang.org/pkg/time/#Parse) for more details"), - types.Named("value", types.S).Description("input to parse according to `layout`"), - ), - types.Named("ns", types.N).Description("`value` in nanoseconds since epoch"), - ), -} +var ParseNanos = v1.ParseNanos -var ParseRFC3339Nanos = &Builtin{ - Name: "time.parse_rfc3339_ns", - Description: "Returns the time in nanoseconds parsed from the string in RFC3339 format. `undefined` if the result would be outside the valid time range that can fit within an `int64`.", - Decl: types.NewFunction( - types.Args( - types.Named("value", types.S), - ), - types.Named("ns", types.N).Description("`value` in nanoseconds since epoch"), - ), -} +var ParseRFC3339Nanos = v1.ParseRFC3339Nanos -var ParseDurationNanos = &Builtin{ - Name: "time.parse_duration_ns", - Description: "Returns the duration in nanoseconds represented by a string.", - Decl: types.NewFunction( - types.Args( - types.Named("duration", types.S).Description("a duration like \"3m\"; see the [Go `time` package documentation](https://golang.org/pkg/time/#ParseDuration) for more details"), - ), - types.Named("ns", types.N).Description("the `duration` in nanoseconds"), - ), -} +var ParseDurationNanos = v1.ParseDurationNanos -var Format = &Builtin{ - Name: "time.format", - Description: "Returns the formatted timestamp for the nanoseconds since epoch.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.NewAny( - types.N, - types.NewArray([]types.Type{types.N, types.S}, nil), - types.NewArray([]types.Type{types.N, types.S, types.S}, nil), - )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string; or a three-element array of ns, timezone string and a layout string or golang defined formatting constant (see golang supported time formats)"), - ), - types.Named("formatted timestamp", types.S).Description("the formatted timestamp represented for the nanoseconds since the epoch in the supplied timezone (or UTC)"), - ), -} +var Format = v1.Format -var Date = &Builtin{ - Name: "time.date", - Description: "Returns the `[year, month, day]` for the nanoseconds since epoch.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.NewAny( - types.N, - types.NewArray([]types.Type{types.N, types.S}, nil), - )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string"), - ), - types.Named("date", types.NewArray([]types.Type{types.N, types.N, types.N}, nil)).Description("an array of `year`, `month` (1-12), and `day` (1-31)"), - ), -} +var Date = v1.Date -var Clock = &Builtin{ - Name: "time.clock", - Description: "Returns the `[hour, minute, second]` of the day for the nanoseconds since epoch.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.NewAny( - types.N, - types.NewArray([]types.Type{types.N, types.S}, nil), - )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string"), - ), - types.Named("output", types.NewArray([]types.Type{types.N, types.N, types.N}, nil)). - Description("the `hour`, `minute` (0-59), and `second` (0-59) representing the time of day for the nanoseconds since epoch in the supplied timezone (or UTC)"), - ), -} +var Clock = v1.Clock -var Weekday = &Builtin{ - Name: "time.weekday", - Description: "Returns the day of the week (Monday, Tuesday, ...) for the nanoseconds since epoch.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.NewAny( - types.N, - types.NewArray([]types.Type{types.N, types.S}, nil), - )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string"), - ), - types.Named("day", types.S).Description("the weekday represented by `ns` nanoseconds since the epoch in the supplied timezone (or UTC)"), - ), -} +var Weekday = v1.Weekday -var AddDate = &Builtin{ - Name: "time.add_date", - Description: "Returns the nanoseconds since epoch after adding years, months and days to nanoseconds. Month & day values outside their usual ranges after the operation and will be normalized - for example, October 32 would become November 1. `undefined` if the result would be outside the valid time range that can fit within an `int64`.", - Decl: types.NewFunction( - types.Args( - types.Named("ns", types.N).Description("nanoseconds since the epoch"), - types.Named("years", types.N), - types.Named("months", types.N), - types.Named("days", types.N), - ), - types.Named("output", types.N).Description("nanoseconds since the epoch representing the input time, with years, months and days added"), - ), -} +var AddDate = v1.AddDate -var Diff = &Builtin{ - Name: "time.diff", - Description: "Returns the difference between two unix timestamps in nanoseconds (with optional timezone strings).", - Decl: types.NewFunction( - types.Args( - types.Named("ns1", types.NewAny( - types.N, - types.NewArray([]types.Type{types.N, types.S}, nil), - )), - types.Named("ns2", types.NewAny( - types.N, - types.NewArray([]types.Type{types.N, types.S}, nil), - )), - ), - types.Named("output", types.NewArray([]types.Type{types.N, types.N, types.N, types.N, types.N, types.N}, nil)).Description("difference between `ns1` and `ns2` (in their supplied timezones, if supplied, or UTC) as array of numbers: `[years, months, days, hours, minutes, seconds]`"), - ), -} +var Diff = v1.Diff /** * Crypto. */ -var CryptoX509ParseCertificates = &Builtin{ - Name: "crypto.x509.parse_certificates", - Description: `Returns zero or more certificates from the given encoded string containing -DER certificate data. - -If the input is empty, the function will return null. The input string should be a list of one or more -concatenated PEM blocks. The whole input of concatenated PEM blocks can optionally be Base64 encoded.`, - Decl: types.NewFunction( - types.Args( - types.Named("certs", types.S).Description("base64 encoded DER or PEM data containing one or more certificates or a PEM string of one or more certificates"), - ), - types.Named("output", types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)))).Description("parsed X.509 certificates represented as objects"), - ), -} +var CryptoX509ParseCertificates = v1.CryptoX509ParseCertificates -var CryptoX509ParseAndVerifyCertificates = &Builtin{ - Name: "crypto.x509.parse_and_verify_certificates", - Description: `Returns one or more certificates from the given string containing PEM -or base64 encoded DER certificates after verifying the supplied certificates form a complete -certificate chain back to a trusted root. - -The first certificate is treated as the root and the last is treated as the leaf, -with all others being treated as intermediates.`, - Decl: types.NewFunction( - types.Args( - types.Named("certs", types.S).Description("base64 encoded DER or PEM data containing two or more certificates where the first is a root CA, the last is a leaf certificate, and all others are intermediate CAs"), - ), - types.Named("output", types.NewArray([]types.Type{ - types.B, - types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), - }, nil)).Description("array of `[valid, certs]`: if the input certificate chain could be verified then `valid` is `true` and `certs` is an array of X.509 certificates represented as objects; if the input certificate chain could not be verified then `valid` is `false` and `certs` is `[]`"), - ), -} +var CryptoX509ParseAndVerifyCertificates = v1.CryptoX509ParseAndVerifyCertificates -var CryptoX509ParseAndVerifyCertificatesWithOptions = &Builtin{ - Name: "crypto.x509.parse_and_verify_certificates_with_options", - Description: `Returns one or more certificates from the given string containing PEM -or base64 encoded DER certificates after verifying the supplied certificates form a complete -certificate chain back to a trusted root. A config option passed as the second argument can -be used to configure the validation options used. - -The first certificate is treated as the root and the last is treated as the leaf, -with all others being treated as intermediates.`, - - Decl: types.NewFunction( - types.Args( - types.Named("certs", types.S).Description("base64 encoded DER or PEM data containing two or more certificates where the first is a root CA, the last is a leaf certificate, and all others are intermediate CAs"), - types.Named("options", types.NewObject( - nil, - types.NewDynamicProperty(types.S, types.A), - )).Description("object containing extra configs to verify the validity of certificates. `options` object supports four fields which maps to same fields in [x509.VerifyOptions struct](https://pkg.go.dev/crypto/x509#VerifyOptions). `DNSName`, `CurrentTime`: Nanoseconds since the Unix Epoch as a number, `MaxConstraintComparisons` and `KeyUsages`. `KeyUsages` is list and can have possible values as in: `\"KeyUsageAny\"`, `\"KeyUsageServerAuth\"`, `\"KeyUsageClientAuth\"`, `\"KeyUsageCodeSigning\"`, `\"KeyUsageEmailProtection\"`, `\"KeyUsageIPSECEndSystem\"`, `\"KeyUsageIPSECTunnel\"`, `\"KeyUsageIPSECUser\"`, `\"KeyUsageTimeStamping\"`, `\"KeyUsageOCSPSigning\"`, `\"KeyUsageMicrosoftServerGatedCrypto\"`, `\"KeyUsageNetscapeServerGatedCrypto\"`, `\"KeyUsageMicrosoftCommercialCodeSigning\"`, `\"KeyUsageMicrosoftKernelCodeSigning\"` "), - ), - types.Named("output", types.NewArray([]types.Type{ - types.B, - types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), - }, nil)).Description("array of `[valid, certs]`: if the input certificate chain could be verified then `valid` is `true` and `certs` is an array of X.509 certificates represented as objects; if the input certificate chain could not be verified then `valid` is `false` and `certs` is `[]`"), - ), -} +var CryptoX509ParseAndVerifyCertificatesWithOptions = v1.CryptoX509ParseAndVerifyCertificatesWithOptions -var CryptoX509ParseCertificateRequest = &Builtin{ - Name: "crypto.x509.parse_certificate_request", - Description: "Returns a PKCS #10 certificate signing request from the given PEM-encoded PKCS#10 certificate signing request.", - Decl: types.NewFunction( - types.Args( - types.Named("csr", types.S).Description("base64 string containing either a PEM encoded or DER CSR or a string containing a PEM CSR"), - ), - types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("X.509 CSR represented as an object"), - ), -} +var CryptoX509ParseCertificateRequest = v1.CryptoX509ParseCertificateRequest -var CryptoX509ParseKeyPair = &Builtin{ - Name: "crypto.x509.parse_keypair", - Description: "Returns a valid key pair", - Decl: types.NewFunction( - types.Args( - types.Named("cert", types.S).Description("string containing PEM or base64 encoded DER certificates"), - types.Named("pem", types.S).Description("string containing PEM or base64 encoded DER keys"), - ), - types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("if key pair is valid, returns the tls.certificate(https://pkg.go.dev/crypto/tls#Certificate) as an object. If the key pair is invalid, nil and an error are returned."), - ), -} -var CryptoX509ParseRSAPrivateKey = &Builtin{ - Name: "crypto.x509.parse_rsa_private_key", - Description: "Returns a JWK for signing a JWT from the given PEM-encoded RSA private key.", - Decl: types.NewFunction( - types.Args( - types.Named("pem", types.S).Description("base64 string containing a PEM encoded RSA private key"), - ), - types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JWK as an object"), - ), -} +var CryptoX509ParseKeyPair = v1.CryptoX509ParseKeyPair +var CryptoX509ParseRSAPrivateKey = v1.CryptoX509ParseRSAPrivateKey -var CryptoParsePrivateKeys = &Builtin{ - Name: "crypto.parse_private_keys", - Description: `Returns zero or more private keys from the given encoded string containing DER certificate data. - -If the input is empty, the function will return null. The input string should be a list of one or more concatenated PEM blocks. The whole input of concatenated PEM blocks can optionally be Base64 encoded.`, - Decl: types.NewFunction( - types.Args( - types.Named("keys", types.S).Description("PEM encoded data containing one or more private keys as concatenated blocks. Optionally Base64 encoded."), - ), - types.Named("output", types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)))).Description("parsed private keys represented as objects"), - ), -} +var CryptoParsePrivateKeys = v1.CryptoParsePrivateKeys -var CryptoMd5 = &Builtin{ - Name: "crypto.md5", - Description: "Returns a string representing the input string hashed with the MD5 function", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("MD5-hash of `x`"), - ), -} +var CryptoMd5 = v1.CryptoMd5 -var CryptoSha1 = &Builtin{ - Name: "crypto.sha1", - Description: "Returns a string representing the input string hashed with the SHA1 function", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("SHA1-hash of `x`"), - ), -} +var CryptoSha1 = v1.CryptoSha1 -var CryptoSha256 = &Builtin{ - Name: "crypto.sha256", - Description: "Returns a string representing the input string hashed with the SHA256 function", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S), - ), - types.Named("y", types.S).Description("SHA256-hash of `x`"), - ), -} +var CryptoSha256 = v1.CryptoSha256 -var CryptoHmacMd5 = &Builtin{ - Name: "crypto.hmac.md5", - Description: "Returns a string representing the MD5 HMAC of the input message using the input key.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("input string"), - types.Named("key", types.S).Description("key to use"), - ), - types.Named("y", types.S).Description("MD5-HMAC of `x`"), - ), -} +var CryptoHmacMd5 = v1.CryptoHmacMd5 -var CryptoHmacSha1 = &Builtin{ - Name: "crypto.hmac.sha1", - Description: "Returns a string representing the SHA1 HMAC of the input message using the input key.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("input string"), - types.Named("key", types.S).Description("key to use"), - ), - types.Named("y", types.S).Description("SHA1-HMAC of `x`"), - ), -} +var CryptoHmacSha1 = v1.CryptoHmacSha1 -var CryptoHmacSha256 = &Builtin{ - Name: "crypto.hmac.sha256", - Description: "Returns a string representing the SHA256 HMAC of the input message using the input key.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("input string"), - types.Named("key", types.S).Description("key to use"), - ), - types.Named("y", types.S).Description("SHA256-HMAC of `x`"), - ), -} +var CryptoHmacSha256 = v1.CryptoHmacSha256 -var CryptoHmacSha512 = &Builtin{ - Name: "crypto.hmac.sha512", - Description: "Returns a string representing the SHA512 HMAC of the input message using the input key.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.S).Description("input string"), - types.Named("key", types.S).Description("key to use"), - ), - types.Named("y", types.S).Description("SHA512-HMAC of `x`"), - ), -} +var CryptoHmacSha512 = v1.CryptoHmacSha512 -var CryptoHmacEqual = &Builtin{ - Name: "crypto.hmac.equal", - Description: "Returns a boolean representing the result of comparing two MACs for equality without leaking timing information.", - Decl: types.NewFunction( - types.Args( - types.Named("mac1", types.S).Description("mac1 to compare"), - types.Named("mac2", types.S).Description("mac2 to compare"), - ), - types.Named("result", types.B).Description("`true` if the MACs are equals, `false` otherwise"), - ), -} +var CryptoHmacEqual = v1.CryptoHmacEqual /** * Graphs. */ -var graphs = category("graph") - -var WalkBuiltin = &Builtin{ - Name: "walk", - Relation: true, - Description: "Generates `[path, value]` tuples for all nested documents of `x` (recursively). Queries can use `walk` to traverse documents nested under `x`.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - ), - types.Named("output", types.NewArray( - []types.Type{ - types.NewArray(nil, types.A), - types.A, - }, - nil, - )).Description("pairs of `path` and `value`: `path` is an array representing the pointer to `value` in `x`. If `path` is assigned a wildcard (`_`), the `walk` function will skip path creation entirely for faster evaluation."), - ), - Categories: graphs, -} -var ReachableBuiltin = &Builtin{ - Name: "graph.reachable", - Description: "Computes the set of reachable nodes in the graph from a set of starting nodes.", - Decl: types.NewFunction( - types.Args( - types.Named("graph", types.NewObject( - nil, - types.NewDynamicProperty( - types.A, - types.NewAny( - types.NewSet(types.A), - types.NewArray(nil, types.A)), - )), - ).Description("object containing a set or array of neighboring vertices"), - types.Named("initial", types.NewAny(types.NewSet(types.A), types.NewArray(nil, types.A))).Description("set or array of root vertices"), - ), - types.Named("output", types.NewSet(types.A)).Description("set of vertices reachable from the `initial` vertices in the directed `graph`"), - ), -} +var WalkBuiltin = v1.WalkBuiltin -var ReachablePathsBuiltin = &Builtin{ - Name: "graph.reachable_paths", - Description: "Computes the set of reachable paths in the graph from a set of starting nodes.", - Decl: types.NewFunction( - types.Args( - types.Named("graph", types.NewObject( - nil, - types.NewDynamicProperty( - types.A, - types.NewAny( - types.NewSet(types.A), - types.NewArray(nil, types.A)), - )), - ).Description("object containing a set or array of root vertices"), - types.Named("initial", types.NewAny(types.NewSet(types.A), types.NewArray(nil, types.A))).Description("initial paths"), // TODO(sr): copied. is that correct? - ), - types.Named("output", types.NewSet(types.NewArray(nil, types.A))).Description("paths reachable from the `initial` vertices in the directed `graph`"), - ), -} +var ReachableBuiltin = v1.ReachableBuiltin + +var ReachablePathsBuiltin = v1.ReachablePathsBuiltin /** * Type */ -var typesCat = category("types") - -var IsNumber = &Builtin{ - Name: "is_number", - Description: "Returns `true` if the input value is a number.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - ), - types.Named("result", types.B).Description("`true` if `x` is a number, `false` otherwise."), - ), - Categories: typesCat, -} -var IsString = &Builtin{ - Name: "is_string", - Description: "Returns `true` if the input value is a string.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - ), - types.Named("result", types.B).Description("`true` if `x` is a string, `false` otherwise."), - ), - Categories: typesCat, -} +var IsNumber = v1.IsNumber -var IsBoolean = &Builtin{ - Name: "is_boolean", - Description: "Returns `true` if the input value is a boolean.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - ), - types.Named("result", types.B).Description("`true` if `x` is an boolean, `false` otherwise."), - ), - Categories: typesCat, -} +var IsString = v1.IsString -var IsArray = &Builtin{ - Name: "is_array", - Description: "Returns `true` if the input value is an array.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - ), - types.Named("result", types.B).Description("`true` if `x` is an array, `false` otherwise."), - ), - Categories: typesCat, -} +var IsBoolean = v1.IsBoolean -var IsSet = &Builtin{ - Name: "is_set", - Description: "Returns `true` if the input value is a set.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - ), - types.Named("result", types.B).Description("`true` if `x` is a set, `false` otherwise."), - ), - Categories: typesCat, -} +var IsArray = v1.IsArray -var IsObject = &Builtin{ - Name: "is_object", - Description: "Returns true if the input value is an object", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - ), - types.Named("result", types.B).Description("`true` if `x` is an object, `false` otherwise."), - ), - Categories: typesCat, -} +var IsSet = v1.IsSet -var IsNull = &Builtin{ - Name: "is_null", - Description: "Returns `true` if the input value is null.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - ), - types.Named("result", types.B).Description("`true` if `x` is null, `false` otherwise."), - ), - Categories: typesCat, -} +var IsObject = v1.IsObject + +var IsNull = v1.IsNull /** * Type Name */ // TypeNameBuiltin returns the type of the input. -var TypeNameBuiltin = &Builtin{ - Name: "type_name", - Description: "Returns the type of its input value.", - Decl: types.NewFunction( - types.Args( - types.Named("x", types.A), - ), - types.Named("type", types.S).Description(`one of "null", "boolean", "number", "string", "array", "object", "set"`), - ), - Categories: typesCat, -} +var TypeNameBuiltin = v1.TypeNameBuiltin /** * HTTP Request */ // Marked non-deterministic because HTTP request results can be non-deterministic. -var HTTPSend = &Builtin{ - Name: "http.send", - Description: "Returns a HTTP response to the given HTTP request.", - Decl: types.NewFunction( - types.Args( - types.Named("request", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), - ), - types.Named("response", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))), - ), - Nondeterministic: true, -} +var HTTPSend = v1.HTTPSend /** * GraphQL */ // GraphQLParse returns a pair of AST objects from parsing/validation. -var GraphQLParse = &Builtin{ - Name: "graphql.parse", - Description: "Returns AST objects for a given GraphQL query and schema after validating the query against the schema. Returns undefined if errors were encountered during parsing or validation. The query and/or schema can be either GraphQL strings or AST objects from the other GraphQL builtin functions.", - Decl: types.NewFunction( - types.Args( - types.Named("query", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), - types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), - ), - types.Named("output", types.NewArray([]types.Type{ - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - }, nil)).Description("`output` is of the form `[query_ast, schema_ast]`. If the GraphQL query is valid given the provided schema, then `query_ast` and `schema_ast` are objects describing the ASTs for the query and schema."), - ), -} +var GraphQLParse = v1.GraphQLParse // GraphQLParseAndVerify returns a boolean and a pair of AST object from parsing/validation. -var GraphQLParseAndVerify = &Builtin{ - Name: "graphql.parse_and_verify", - Description: "Returns a boolean indicating success or failure alongside the parsed ASTs for a given GraphQL query and schema after validating the query against the schema. The query and/or schema can be either GraphQL strings or AST objects from the other GraphQL builtin functions.", - Decl: types.NewFunction( - types.Args( - types.Named("query", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), - types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), - ), - types.Named("output", types.NewArray([]types.Type{ - types.B, - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - }, nil)).Description(" `output` is of the form `[valid, query_ast, schema_ast]`. If the query is valid given the provided schema, then `valid` is `true`, and `query_ast` and `schema_ast` are objects describing the ASTs for the GraphQL query and schema. Otherwise, `valid` is `false` and `query_ast` and `schema_ast` are `{}`."), - ), -} +var GraphQLParseAndVerify = v1.GraphQLParseAndVerify // GraphQLParseQuery parses the input GraphQL query and returns a JSON // representation of its AST. -var GraphQLParseQuery = &Builtin{ - Name: "graphql.parse_query", - Description: "Returns an AST object for a GraphQL query.", - Decl: types.NewFunction( - types.Args( - types.Named("query", types.S), - ), - types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("AST object for the GraphQL query."), - ), -} +var GraphQLParseQuery = v1.GraphQLParseQuery // GraphQLParseSchema parses the input GraphQL schema and returns a JSON // representation of its AST. -var GraphQLParseSchema = &Builtin{ - Name: "graphql.parse_schema", - Description: "Returns an AST object for a GraphQL schema.", - Decl: types.NewFunction( - types.Args( - types.Named("schema", types.S), - ), - types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("AST object for the GraphQL schema."), - ), -} +var GraphQLParseSchema = v1.GraphQLParseSchema // GraphQLIsValid returns true if a GraphQL query is valid with a given // schema, and returns false for all other inputs. -var GraphQLIsValid = &Builtin{ - Name: "graphql.is_valid", - Description: "Checks that a GraphQL query is valid against a given schema. The query and/or schema can be either GraphQL strings or AST objects from the other GraphQL builtin functions.", - Decl: types.NewFunction( - types.Args( - types.Named("query", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), - types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), - ), - types.Named("output", types.B).Description("`true` if the query is valid under the given schema. `false` otherwise."), - ), -} +var GraphQLIsValid = v1.GraphQLIsValid // GraphQLSchemaIsValid returns true if the input is valid GraphQL schema, // and returns false for all other inputs. -var GraphQLSchemaIsValid = &Builtin{ - Name: "graphql.schema_is_valid", - Description: "Checks that the input is a valid GraphQL schema. The schema can be either a GraphQL string or an AST object from the other GraphQL builtin functions.", - Decl: types.NewFunction( - types.Args( - types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), - ), - types.Named("output", types.B).Description("`true` if the schema is a valid GraphQL schema. `false` otherwise."), - ), -} +var GraphQLSchemaIsValid = v1.GraphQLSchemaIsValid /** * JSON Schema @@ -2811,313 +499,76 @@ var GraphQLSchemaIsValid = &Builtin{ // JSONSchemaVerify returns empty string if the input is valid JSON schema // and returns error string for all other inputs. -var JSONSchemaVerify = &Builtin{ - Name: "json.verify_schema", - Description: "Checks that the input is a valid JSON schema object. The schema can be either a JSON string or an JSON object.", - Decl: types.NewFunction( - types.Args( - types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). - Description("the schema to verify"), - ), - types.Named("output", types.NewArray([]types.Type{ - types.B, - types.NewAny(types.S, types.Null{}), - }, nil)). - Description("`output` is of the form `[valid, error]`. If the schema is valid, then `valid` is `true`, and `error` is `null`. Otherwise, `valid` is `false` and `error` is a string describing the error."), - ), - Categories: objectCat, -} +var JSONSchemaVerify = v1.JSONSchemaVerify // JSONMatchSchema returns empty array if the document matches the JSON schema, // and returns non-empty array with error objects otherwise. -var JSONMatchSchema = &Builtin{ - Name: "json.match_schema", - Description: "Checks that the document matches the JSON schema.", - Decl: types.NewFunction( - types.Args( - types.Named("document", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). - Description("document to verify by schema"), - types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). - Description("schema to verify document by"), - ), - types.Named("output", types.NewArray([]types.Type{ - types.B, - types.NewArray( - nil, types.NewObject( - []*types.StaticProperty{ - {Key: "error", Value: types.S}, - {Key: "type", Value: types.S}, - {Key: "field", Value: types.S}, - {Key: "desc", Value: types.S}, - }, - nil, - ), - ), - }, nil)). - Description("`output` is of the form `[match, errors]`. If the document is valid given the schema, then `match` is `true`, and `errors` is an empty array. Otherwise, `match` is `false` and `errors` is an array of objects describing the error(s)."), - ), - Categories: objectCat, -} +var JSONMatchSchema = v1.JSONMatchSchema /** * Cloud Provider Helper Functions */ -var providersAWSCat = category("providers.aws") - -var ProvidersAWSSignReqObj = &Builtin{ - Name: "providers.aws.sign_req", - Description: "Signs an HTTP request object for Amazon Web Services. Currently implements [AWS Signature Version 4 request signing](https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-authenticating-requests.html) by the `Authorization` header method.", - Decl: types.NewFunction( - types.Args( - types.Named("request", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), - types.Named("aws_config", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), - types.Named("time_ns", types.N), - ), - types.Named("signed_request", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))), - ), - Categories: providersAWSCat, -} + +var ProvidersAWSSignReqObj = v1.ProvidersAWSSignReqObj /** * Rego */ -var RegoParseModule = &Builtin{ - Name: "rego.parse_module", - Description: "Parses the input Rego string and returns an object representation of the AST.", - Decl: types.NewFunction( - types.Args( - types.Named("filename", types.S).Description("file name to attach to AST nodes' locations"), - types.Named("rego", types.S).Description("Rego module"), - ), - types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), // TODO(tsandall): import AST schema - ), -} +var RegoParseModule = v1.RegoParseModule -var RegoMetadataChain = &Builtin{ - Name: "rego.metadata.chain", - Description: `Returns the chain of metadata for the active rule. -Ordered starting at the active rule, going outward to the most distant node in its package ancestry. -A chain entry is a JSON document with two members: "path", an array representing the path of the node; and "annotations", a JSON document containing the annotations declared for the node. -The first entry in the chain always points to the active rule, even if it has no declared annotations (in which case the "annotations" member is not present).`, - Decl: types.NewFunction( - types.Args(), - types.Named("chain", types.NewArray(nil, types.A)).Description("each array entry represents a node in the path ancestry (chain) of the active rule that also has declared annotations"), - ), -} +var RegoMetadataChain = v1.RegoMetadataChain // RegoMetadataRule returns the metadata for the active rule -var RegoMetadataRule = &Builtin{ - Name: "rego.metadata.rule", - Description: "Returns annotations declared for the active rule and using the _rule_ scope.", - Decl: types.NewFunction( - types.Args(), - types.Named("output", types.A).Description("\"rule\" scope annotations for this rule; empty object if no annotations exist"), - ), -} +var RegoMetadataRule = v1.RegoMetadataRule /** * OPA */ // Marked non-deterministic because of unpredictable config/environment-dependent results. -var OPARuntime = &Builtin{ - Name: "opa.runtime", - Description: "Returns an object that describes the runtime environment where OPA is deployed.", - Decl: types.NewFunction( - nil, - types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))). - Description("includes a `config` key if OPA was started with a configuration file; an `env` key containing the environment variables that the OPA process was started with; includes `version` and `commit` keys containing the version and build commit of OPA."), - ), - Nondeterministic: true, -} +var OPARuntime = v1.OPARuntime /** * Trace */ -var tracing = category("tracing") - -var Trace = &Builtin{ - Name: "trace", - Description: "Emits `note` as a `Note` event in the query explanation. Query explanations show the exact expressions evaluated by OPA during policy execution. For example, `trace(\"Hello There!\")` includes `Note \"Hello There!\"` in the query explanation. To include variables in the message, use `sprintf`. For example, `person := \"Bob\"; trace(sprintf(\"Hello There! %v\", [person]))` will emit `Note \"Hello There! Bob\"` inside of the explanation.", - Decl: types.NewFunction( - types.Args( - types.Named("note", types.S).Description("the note to include"), - ), - types.Named("result", types.B).Description("always `true`"), - ), - Categories: tracing, -} + +var Trace = v1.Trace /** * Glob */ -var GlobMatch = &Builtin{ - Name: "glob.match", - Description: "Parses and matches strings against the glob notation. Not to be confused with `regex.globs_match`.", - Decl: types.NewFunction( - types.Args( - types.Named("pattern", types.S), - types.Named("delimiters", types.NewAny( - types.NewArray(nil, types.S), - types.NewNull(), - )).Description("glob pattern delimiters, e.g. `[\".\", \":\"]`, defaults to `[\".\"]` if unset. If `delimiters` is `null`, glob match without delimiter."), - types.Named("match", types.S), - ), - types.Named("result", types.B).Description("true if `match` can be found in `pattern` which is separated by `delimiters`"), - ), -} +var GlobMatch = v1.GlobMatch -var GlobQuoteMeta = &Builtin{ - Name: "glob.quote_meta", - Description: "Returns a string which represents a version of the pattern where all asterisks have been escaped.", - Decl: types.NewFunction( - types.Args( - types.Named("pattern", types.S), - ), - types.Named("output", types.S).Description("the escaped string of `pattern`"), - ), - // TODO(sr): example for this was: Calling ``glob.quote_meta("*.github.com", output)`` returns ``\\*.github.com`` as ``output``. -} +var GlobQuoteMeta = v1.GlobQuoteMeta /** * Networking */ -var NetCIDRIntersects = &Builtin{ - Name: "net.cidr_intersects", - Description: "Checks if a CIDR intersects with another CIDR (e.g. `192.168.0.0/16` overlaps with `192.168.1.0/24`). Supports both IPv4 and IPv6 notations.", - Decl: types.NewFunction( - types.Args( - types.Named("cidr1", types.S), - types.Named("cidr2", types.S), - ), - types.Named("result", types.B), - ), -} +var NetCIDRIntersects = v1.NetCIDRIntersects -var NetCIDRExpand = &Builtin{ - Name: "net.cidr_expand", - Description: "Expands CIDR to set of hosts (e.g., `net.cidr_expand(\"192.168.0.0/30\")` generates 4 hosts: `{\"192.168.0.0\", \"192.168.0.1\", \"192.168.0.2\", \"192.168.0.3\"}`).", - Decl: types.NewFunction( - types.Args( - types.Named("cidr", types.S), - ), - types.Named("hosts", types.NewSet(types.S)).Description("set of IP addresses the CIDR `cidr` expands to"), - ), -} +var NetCIDRExpand = v1.NetCIDRExpand -var NetCIDRContains = &Builtin{ - Name: "net.cidr_contains", - Description: "Checks if a CIDR or IP is contained within another CIDR. `output` is `true` if `cidr_or_ip` (e.g. `127.0.0.64/26` or `127.0.0.1`) is contained within `cidr` (e.g. `127.0.0.1/24`) and `false` otherwise. Supports both IPv4 and IPv6 notations.", - Decl: types.NewFunction( - types.Args( - types.Named("cidr", types.S), - types.Named("cidr_or_ip", types.S), - ), - types.Named("result", types.B), - ), -} +var NetCIDRContains = v1.NetCIDRContains -var NetCIDRContainsMatches = &Builtin{ - Name: "net.cidr_contains_matches", - Description: "Checks if collections of cidrs or ips are contained within another collection of cidrs and returns matches. " + - "This function is similar to `net.cidr_contains` except it allows callers to pass collections of CIDRs or IPs as arguments and returns the matches (as opposed to a boolean result indicating a match between two CIDRs/IPs).", - Decl: types.NewFunction( - types.Args( - types.Named("cidrs", netCidrContainsMatchesOperandType), - types.Named("cidrs_or_ips", netCidrContainsMatchesOperandType), - ), - types.Named("output", types.NewSet(types.NewArray([]types.Type{types.A, types.A}, nil))).Description("tuples identifying matches where `cidrs_or_ips` are contained within `cidrs`"), - ), -} +var NetCIDRContainsMatches = v1.NetCIDRContainsMatches -var NetCIDRMerge = &Builtin{ - Name: "net.cidr_merge", - Description: "Merges IP addresses and subnets into the smallest possible list of CIDRs (e.g., `net.cidr_merge([\"192.0.128.0/24\", \"192.0.129.0/24\"])` generates `{\"192.0.128.0/23\"}`." + - `This function merges adjacent subnets where possible, those contained within others and also removes any duplicates. -Supports both IPv4 and IPv6 notations. IPv6 inputs need a prefix length (e.g. "/128").`, - Decl: types.NewFunction( - types.Args( - types.Named("addrs", types.NewAny( - types.NewArray(nil, types.NewAny(types.S)), - types.NewSet(types.S), - )).Description("CIDRs or IP addresses"), - ), - types.Named("output", types.NewSet(types.S)).Description("smallest possible set of CIDRs obtained after merging the provided list of IP addresses and subnets in `addrs`"), - ), -} +var NetCIDRMerge = v1.NetCIDRMerge -var NetCIDRIsValid = &Builtin{ - Name: "net.cidr_is_valid", - Description: "Parses an IPv4/IPv6 CIDR and returns a boolean indicating if the provided CIDR is valid.", - Decl: types.NewFunction( - types.Args( - types.Named("cidr", types.S), - ), - types.Named("result", types.B), - ), -} - -var netCidrContainsMatchesOperandType = types.NewAny( - types.S, - types.NewArray(nil, types.NewAny( - types.S, - types.NewArray(nil, types.A), - )), - types.NewSet(types.NewAny( - types.S, - types.NewArray(nil, types.A), - )), - types.NewObject(nil, types.NewDynamicProperty( - types.S, - types.NewAny( - types.S, - types.NewArray(nil, types.A), - ), - )), -) +var NetCIDRIsValid = v1.NetCIDRIsValid // Marked non-deterministic because DNS resolution results can be non-deterministic. -var NetLookupIPAddr = &Builtin{ - Name: "net.lookup_ip_addr", - Description: "Returns the set of IP addresses (both v4 and v6) that the passed-in `name` resolves to using the standard name resolution mechanisms available.", - Decl: types.NewFunction( - types.Args( - types.Named("name", types.S).Description("domain name to resolve"), - ), - types.Named("addrs", types.NewSet(types.S)).Description("IP addresses (v4 and v6) that `name` resolves to"), - ), - Nondeterministic: true, -} +var NetLookupIPAddr = v1.NetLookupIPAddr /** * Semantic Versions */ -var SemVerIsValid = &Builtin{ - Name: "semver.is_valid", - Description: "Validates that the input is a valid SemVer string.", - Decl: types.NewFunction( - types.Args( - types.Named("vsn", types.A), - ), - types.Named("result", types.B).Description("`true` if `vsn` is a valid SemVer; `false` otherwise"), - ), -} +var SemVerIsValid = v1.SemVerIsValid -var SemVerCompare = &Builtin{ - Name: "semver.compare", - Description: "Compares valid SemVer formatted version strings.", - Decl: types.NewFunction( - types.Args( - types.Named("a", types.S), - types.Named("b", types.S), - ), - types.Named("result", types.N).Description("`-1` if `a < b`; `1` if `a > b`; `0` if `a == b`"), - ), -} +var SemVerCompare = v1.SemVerCompare /** * Printing @@ -3128,248 +579,56 @@ var SemVerCompare = &Builtin{ // operands may be of any type. Furthermore, unlike other built-in functions, // undefined operands DO NOT cause the print() function to fail during // evaluation. -var Print = &Builtin{ - Name: "print", - Decl: types.NewVariadicFunction(nil, types.A, nil), -} +var Print = v1.Print // InternalPrint represents the internal implementation of the print() function. // The compiler rewrites print() calls to refer to the internal implementation. -var InternalPrint = &Builtin{ - Name: "internal.print", - Decl: types.NewFunction([]types.Type{types.NewArray(nil, types.NewSet(types.A))}, nil), -} +var InternalPrint = v1.InternalPrint /** * Deprecated built-ins. */ // SetDiff has been replaced by the minus built-in. -var SetDiff = &Builtin{ - Name: "set_diff", - Decl: types.NewFunction( - types.Args( - types.NewSet(types.A), - types.NewSet(types.A), - ), - types.NewSet(types.A), - ), - deprecated: true, -} +var SetDiff = v1.SetDiff // NetCIDROverlap has been replaced by the `net.cidr_contains` built-in. -var NetCIDROverlap = &Builtin{ - Name: "net.cidr_overlap", - Decl: types.NewFunction( - types.Args( - types.S, - types.S, - ), - types.B, - ), - deprecated: true, -} +var NetCIDROverlap = v1.NetCIDROverlap // CastArray checks the underlying type of the input. If it is array or set, an array // containing the values is returned. If it is not an array, an error is thrown. -var CastArray = &Builtin{ - Name: "cast_array", - Decl: types.NewFunction( - types.Args(types.A), - types.NewArray(nil, types.A), - ), - deprecated: true, -} +var CastArray = v1.CastArray // CastSet checks the underlying type of the input. // If it is a set, the set is returned. // If it is an array, the array is returned in set form (all duplicates removed) // If neither, an error is thrown -var CastSet = &Builtin{ - Name: "cast_set", - Decl: types.NewFunction( - types.Args(types.A), - types.NewSet(types.A), - ), - deprecated: true, -} +var CastSet = v1.CastSet // CastString returns input if it is a string; if not returns error. // For formatting variables, see sprintf -var CastString = &Builtin{ - Name: "cast_string", - Decl: types.NewFunction( - types.Args(types.A), - types.S, - ), - deprecated: true, -} +var CastString = v1.CastString // CastBoolean returns input if it is a boolean; if not returns error. -var CastBoolean = &Builtin{ - Name: "cast_boolean", - Decl: types.NewFunction( - types.Args(types.A), - types.B, - ), - deprecated: true, -} +var CastBoolean = v1.CastBoolean // CastNull returns null if input is null; if not returns error. -var CastNull = &Builtin{ - Name: "cast_null", - Decl: types.NewFunction( - types.Args(types.A), - types.NewNull(), - ), - deprecated: true, -} +var CastNull = v1.CastNull // CastObject returns the given object if it is null; throws an error otherwise -var CastObject = &Builtin{ - Name: "cast_object", - Decl: types.NewFunction( - types.Args(types.A), - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - ), - deprecated: true, -} +var CastObject = v1.CastObject // RegexMatchDeprecated declares `re_match` which has been deprecated. Use `regex.match` instead. -var RegexMatchDeprecated = &Builtin{ - Name: "re_match", - Decl: types.NewFunction( - types.Args( - types.S, - types.S, - ), - types.B, - ), - deprecated: true, -} +var RegexMatchDeprecated = v1.RegexMatchDeprecated // All takes a list and returns true if all of the items // are true. A collection of length 0 returns true. -var All = &Builtin{ - Name: "all", - Decl: types.NewFunction( - types.Args( - types.NewAny( - types.NewSet(types.A), - types.NewArray(nil, types.A), - ), - ), - types.B, - ), - deprecated: true, -} +var All = v1.All // Any takes a collection and returns true if any of the items // is true. A collection of length 0 returns false. -var Any = &Builtin{ - Name: "any", - Decl: types.NewFunction( - types.Args( - types.NewAny( - types.NewSet(types.A), - types.NewArray(nil, types.A), - ), - ), - types.B, - ), - deprecated: true, -} +var Any = v1.Any // Builtin represents a built-in function supported by OPA. Every built-in // function is uniquely identified by a name. -type Builtin struct { - Name string `json:"name"` // Unique name of built-in function, e.g., (arg1,arg2,...,argN) - Description string `json:"description,omitempty"` // Description of what the built-in function does. - - // Categories of the built-in function. Omitted for namespaced - // built-ins, i.e. "array.concat" is taken to be of the "array" category. - // "minus" for example, is part of two categories: numbers and sets. (NOTE(sr): aspirational) - Categories []string `json:"categories,omitempty"` - - Decl *types.Function `json:"decl"` // Built-in function type declaration. - Infix string `json:"infix,omitempty"` // Unique name of infix operator. Default should be unset. - Relation bool `json:"relation,omitempty"` // Indicates if the built-in acts as a relation. - deprecated bool // Indicates if the built-in has been deprecated. - Nondeterministic bool `json:"nondeterministic,omitempty"` // Indicates if the built-in returns non-deterministic results. -} - -// category is a helper for specifying a Builtin's Categories -func category(cs ...string) []string { - return cs -} - -// Minimal returns a shallow copy of b with the descriptions and categories and -// named arguments stripped out. -func (b *Builtin) Minimal() *Builtin { - cpy := *b - fargs := b.Decl.FuncArgs() - if fargs.Variadic != nil { - cpy.Decl = types.NewVariadicFunction(fargs.Args, fargs.Variadic, b.Decl.Result()) - } else { - cpy.Decl = types.NewFunction(fargs.Args, b.Decl.Result()) - } - cpy.Categories = nil - cpy.Description = "" - return &cpy -} - -// IsDeprecated returns true if the Builtin function is deprecated and will be removed in a future release. -func (b *Builtin) IsDeprecated() bool { - return b.deprecated -} - -// IsDeterministic returns true if the Builtin function returns non-deterministic results. -func (b *Builtin) IsNondeterministic() bool { - return b.Nondeterministic -} - -// Expr creates a new expression for the built-in with the given operands. -func (b *Builtin) Expr(operands ...*Term) *Expr { - ts := make([]*Term, len(operands)+1) - ts[0] = NewTerm(b.Ref()) - for i := range operands { - ts[i+1] = operands[i] - } - return &Expr{ - Terms: ts, - } -} - -// Call creates a new term for the built-in with the given operands. -func (b *Builtin) Call(operands ...*Term) *Term { - call := make(Call, len(operands)+1) - call[0] = NewTerm(b.Ref()) - for i := range operands { - call[i+1] = operands[i] - } - return NewTerm(call) -} - -// Ref returns a Ref that refers to the built-in function. -func (b *Builtin) Ref() Ref { - parts := strings.Split(b.Name, ".") - ref := make(Ref, len(parts)) - ref[0] = VarTerm(parts[0]) - for i := 1; i < len(parts); i++ { - ref[i] = StringTerm(parts[i]) - } - return ref -} - -// IsTargetPos returns true if a variable in the i-th position will be bound by -// evaluating the call expression. -func (b *Builtin) IsTargetPos(i int) bool { - return len(b.Decl.FuncArgs().Args) == i -} - -func init() { - BuiltinMap = map[string]*Builtin{} - for _, b := range DefaultBuiltins { - RegisterBuiltin(b) - } -} +type Builtin = v1.Builtin diff --git a/vendor/github.com/open-policy-agent/opa/ast/capabilities.go b/vendor/github.com/open-policy-agent/opa/ast/capabilities.go index 3b95d79e5..bc7278a88 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/capabilities.go +++ b/vendor/github.com/open-policy-agent/opa/ast/capabilities.go @@ -5,228 +5,54 @@ package ast import ( - "bytes" - _ "embed" - "encoding/json" - "fmt" "io" - "os" - "sort" - "strings" - caps "github.com/open-policy-agent/opa/capabilities" - "github.com/open-policy-agent/opa/internal/semver" - "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/capabilities" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // VersonIndex contains an index from built-in function name, language feature, // and future rego keyword to version number. During the build, this is used to // create an index of the minimum version required for the built-in/feature/kw. -type VersionIndex struct { - Builtins map[string]semver.Version `json:"builtins"` - Features map[string]semver.Version `json:"features"` - Keywords map[string]semver.Version `json:"keywords"` -} - -// NOTE(tsandall): this file is generated by internal/cmd/genversionindex/main.go -// and run as part of go:generate. We generate the version index as part of the -// build process because it's relatively expensive to build (it takes ~500ms on -// my machine) and never changes. -// -//go:embed version_index.json -var versionIndexBs []byte - -var minVersionIndex = func() VersionIndex { - var vi VersionIndex - err := json.Unmarshal(versionIndexBs, &vi) - if err != nil { - panic(err) - } - return vi -}() +type VersionIndex = v1.VersionIndex // In the compiler, we used this to check that we're OK working with ref heads. // If this isn't present, we'll fail. This is to ensure that older versions of // OPA can work with policies that we're compiling -- if they don't know ref // heads, they wouldn't be able to parse them. -const FeatureRefHeadStringPrefixes = "rule_head_ref_string_prefixes" -const FeatureRefHeads = "rule_head_refs" -const FeatureRegoV1Import = "rego_v1_import" +const FeatureRefHeadStringPrefixes = v1.FeatureRefHeadStringPrefixes +const FeatureRefHeads = v1.FeatureRefHeads +const FeatureRegoV1 = v1.FeatureRegoV1 +const FeatureRegoV1Import = v1.FeatureRegoV1Import // Capabilities defines a structure containing data that describes the capabilities // or features supported by a particular version of OPA. -type Capabilities struct { - Builtins []*Builtin `json:"builtins,omitempty"` - FutureKeywords []string `json:"future_keywords,omitempty"` - WasmABIVersions []WasmABIVersion `json:"wasm_abi_versions,omitempty"` - - // Features is a bit of a mixed bag for checking that an older version of OPA - // is able to do what needs to be done. - // TODO(sr): find better words ^^ - Features []string `json:"features,omitempty"` - - // allow_net is an array of hostnames or IP addresses, that an OPA instance is - // allowed to connect to. - // If omitted, ANY host can be connected to. If empty, NO host can be connected to. - // As of now, this only controls fetching remote refs for using JSON Schemas in - // the type checker. - // TODO(sr): support ports to further restrict connection peers - // TODO(sr): support restricting `http.send` using the same mechanism (see https://github.com/open-policy-agent/opa/issues/3665) - AllowNet []string `json:"allow_net,omitempty"` -} +type Capabilities = v1.Capabilities // WasmABIVersion captures the Wasm ABI version. Its `Minor` version is indicating // backwards-compatible changes. -type WasmABIVersion struct { - Version int `json:"version"` - Minor int `json:"minor_version"` -} +type WasmABIVersion = v1.WasmABIVersion // CapabilitiesForThisVersion returns the capabilities of this version of OPA. func CapabilitiesForThisVersion() *Capabilities { - f := &Capabilities{} - - for _, vers := range capabilities.ABIVersions() { - f.WasmABIVersions = append(f.WasmABIVersions, WasmABIVersion{Version: vers[0], Minor: vers[1]}) - } - - f.Builtins = make([]*Builtin, len(Builtins)) - copy(f.Builtins, Builtins) - sort.Slice(f.Builtins, func(i, j int) bool { - return f.Builtins[i].Name < f.Builtins[j].Name - }) - - for kw := range futureKeywords { - f.FutureKeywords = append(f.FutureKeywords, kw) - } - sort.Strings(f.FutureKeywords) - - f.Features = []string{ - FeatureRefHeadStringPrefixes, - FeatureRefHeads, - FeatureRegoV1Import, - } - - return f + return v1.CapabilitiesForThisVersion(v1.CapabilitiesRegoVersion(DefaultRegoVersion)) } // LoadCapabilitiesJSON loads a JSON serialized capabilities structure from the reader r. func LoadCapabilitiesJSON(r io.Reader) (*Capabilities, error) { - d := util.NewJSONDecoder(r) - var c Capabilities - return &c, d.Decode(&c) + return v1.LoadCapabilitiesJSON(r) } // LoadCapabilitiesVersion loads a JSON serialized capabilities structure from the specific version. func LoadCapabilitiesVersion(version string) (*Capabilities, error) { - cvs, err := LoadCapabilitiesVersions() - if err != nil { - return nil, err - } - - for _, cv := range cvs { - if cv == version { - cont, err := caps.FS.ReadFile(cv + ".json") - if err != nil { - return nil, err - } - - return LoadCapabilitiesJSON(bytes.NewReader(cont)) - } - - } - return nil, fmt.Errorf("no capabilities version found %v", version) + return v1.LoadCapabilitiesVersion(version) } // LoadCapabilitiesFile loads a JSON serialized capabilities structure from a file. func LoadCapabilitiesFile(file string) (*Capabilities, error) { - fd, err := os.Open(file) - if err != nil { - return nil, err - } - defer fd.Close() - return LoadCapabilitiesJSON(fd) + return v1.LoadCapabilitiesFile(file) } // LoadCapabilitiesVersions loads all capabilities versions func LoadCapabilitiesVersions() ([]string, error) { - ents, err := caps.FS.ReadDir(".") - if err != nil { - return nil, err - } - - capabilitiesVersions := make([]string, 0, len(ents)) - for _, ent := range ents { - capabilitiesVersions = append(capabilitiesVersions, strings.Replace(ent.Name(), ".json", "", 1)) - } - return capabilitiesVersions, nil -} - -// MinimumCompatibleVersion returns the minimum compatible OPA version based on -// the built-ins, features, and keywords in c. -func (c *Capabilities) MinimumCompatibleVersion() (string, bool) { - - var maxVersion semver.Version - - // this is the oldest OPA release that includes capabilities - if err := maxVersion.Set("0.17.0"); err != nil { - panic("unreachable") - } - - for _, bi := range c.Builtins { - v, ok := minVersionIndex.Builtins[bi.Name] - if !ok { - return "", false - } - if v.Compare(maxVersion) > 0 { - maxVersion = v - } - } - - for _, kw := range c.FutureKeywords { - v, ok := minVersionIndex.Keywords[kw] - if !ok { - return "", false - } - if v.Compare(maxVersion) > 0 { - maxVersion = v - } - } - - for _, feat := range c.Features { - v, ok := minVersionIndex.Features[feat] - if !ok { - return "", false - } - if v.Compare(maxVersion) > 0 { - maxVersion = v - } - } - - return maxVersion.String(), true -} - -func (c *Capabilities) ContainsFeature(feature string) bool { - for _, f := range c.Features { - if f == feature { - return true - } - } - return false -} - -// addBuiltinSorted inserts a built-in into c in sorted order. An existing built-in with the same name -// will be overwritten. -func (c *Capabilities) addBuiltinSorted(bi *Builtin) { - i := sort.Search(len(c.Builtins), func(x int) bool { - return c.Builtins[x].Name >= bi.Name - }) - if i < len(c.Builtins) && bi.Name == c.Builtins[i].Name { - c.Builtins[i] = bi - return - } - c.Builtins = append(c.Builtins, nil) - copy(c.Builtins[i+1:], c.Builtins[i:]) - c.Builtins[i] = bi + return v1.LoadCapabilitiesVersions() } diff --git a/vendor/github.com/open-policy-agent/opa/ast/check.go b/vendor/github.com/open-policy-agent/opa/ast/check.go index 03d31123c..4cf00436d 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/check.go +++ b/vendor/github.com/open-policy-agent/opa/ast/check.go @@ -5,1317 +5,18 @@ package ast import ( - "fmt" - "sort" - "strings" - - "github.com/open-policy-agent/opa/types" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/ast" ) -type varRewriter func(Ref) Ref - -// exprChecker defines the interface for executing type checking on a single -// expression. The exprChecker must update the provided TypeEnv with inferred -// types of vars. -type exprChecker func(*TypeEnv, *Expr) *Error - -// typeChecker implements type checking on queries and rules. Errors are -// accumulated on the typeChecker so that a single run can report multiple -// issues. -type typeChecker struct { - builtins map[string]*Builtin - required *Capabilities - errs Errors - exprCheckers map[string]exprChecker - varRewriter varRewriter - ss *SchemaSet - allowNet []string - input types.Type - allowUndefinedFuncs bool - schemaTypes map[string]types.Type -} - -// newTypeChecker returns a new typeChecker object that has no errors. -func newTypeChecker() *typeChecker { - return &typeChecker{ - builtins: make(map[string]*Builtin), - schemaTypes: make(map[string]types.Type), - exprCheckers: map[string]exprChecker{ - "eq": checkExprEq, - }, - } -} - -func (tc *typeChecker) newEnv(exist *TypeEnv) *TypeEnv { - if exist != nil { - return exist.wrap() - } - env := newTypeEnv(tc.copy) - if tc.input != nil { - env.tree.Put(InputRootRef, tc.input) - } - return env -} - -func (tc *typeChecker) copy() *typeChecker { - return newTypeChecker(). - WithVarRewriter(tc.varRewriter). - WithSchemaSet(tc.ss). - WithAllowNet(tc.allowNet). - WithInputType(tc.input). - WithAllowUndefinedFunctionCalls(tc.allowUndefinedFuncs). - WithBuiltins(tc.builtins). - WithRequiredCapabilities(tc.required) -} - -func (tc *typeChecker) WithRequiredCapabilities(c *Capabilities) *typeChecker { - tc.required = c - return tc -} - -func (tc *typeChecker) WithBuiltins(builtins map[string]*Builtin) *typeChecker { - tc.builtins = builtins - return tc -} - -func (tc *typeChecker) WithSchemaSet(ss *SchemaSet) *typeChecker { - tc.ss = ss - return tc -} - -func (tc *typeChecker) WithAllowNet(hosts []string) *typeChecker { - tc.allowNet = hosts - return tc -} - -func (tc *typeChecker) WithVarRewriter(f varRewriter) *typeChecker { - tc.varRewriter = f - return tc -} - -func (tc *typeChecker) WithInputType(tpe types.Type) *typeChecker { - tc.input = tpe - return tc -} - -// WithAllowUndefinedFunctionCalls sets the type checker to allow references to undefined functions. -// Additionally, the 'CheckUndefinedFuncs' and 'CheckSafetyRuleBodies' compiler stages are skipped. -func (tc *typeChecker) WithAllowUndefinedFunctionCalls(allow bool) *typeChecker { - tc.allowUndefinedFuncs = allow - return tc -} - -// Env returns a type environment for the specified built-ins with any other -// global types configured on the checker. In practice, this is the default -// environment that other statements will be checked against. -func (tc *typeChecker) Env(builtins map[string]*Builtin) *TypeEnv { - env := tc.newEnv(nil) - for _, bi := range builtins { - env.tree.Put(bi.Ref(), bi.Decl) - } - return env -} - -// CheckBody runs type checking on the body and returns a TypeEnv if no errors -// are found. The resulting TypeEnv wraps the provided one. The resulting -// TypeEnv will be able to resolve types of vars contained in the body. -func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) { - - errors := []*Error{} - env = tc.newEnv(env) - - WalkExprs(body, func(expr *Expr) bool { - - closureErrs := tc.checkClosures(env, expr) - for _, err := range closureErrs { - errors = append(errors, err) - } - - hasClosureErrors := len(closureErrs) > 0 - - vis := newRefChecker(env, tc.varRewriter) - NewGenericVisitor(vis.Visit).Walk(expr) - for _, err := range vis.errs { - errors = append(errors, err) - } - - hasRefErrors := len(vis.errs) > 0 - - if err := tc.checkExpr(env, expr); err != nil { - // Suppress this error if a more actionable one has occurred. In - // this case, if an error occurred in a ref or closure contained in - // this expression, and the error is due to a nil type, then it's - // likely to be the result of the more specific error. - skip := (hasClosureErrors || hasRefErrors) && causedByNilType(err) - if !skip { - errors = append(errors, err) - } - } - return true - }) - - tc.err(errors) - return env, errors -} - -// CheckTypes runs type checking on the rules returns a TypeEnv if no errors -// are found. The resulting TypeEnv wraps the provided one. The resulting -// TypeEnv will be able to resolve types of refs that refer to rules. -func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T, as *AnnotationSet) (*TypeEnv, Errors) { - env = tc.newEnv(env) - for _, s := range sorted { - tc.checkRule(env, as, s.(*Rule)) - } - tc.errs.Sort() - return env, tc.errs -} - -func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors { - var result Errors - WalkClosures(expr, func(x interface{}) bool { - switch x := x.(type) { - case *ArrayComprehension: - _, errs := tc.copy().CheckBody(env, x.Body) - if len(errs) > 0 { - result = errs - return true - } - case *SetComprehension: - _, errs := tc.copy().CheckBody(env, x.Body) - if len(errs) > 0 { - result = errs - return true - } - case *ObjectComprehension: - _, errs := tc.copy().CheckBody(env, x.Body) - if len(errs) > 0 { - result = errs - return true - } - } - return false - }) - return result -} - -func (tc *typeChecker) getSchemaType(schemaAnnot *SchemaAnnotation, rule *Rule) (types.Type, *Error) { - if refType, exists := tc.schemaTypes[schemaAnnot.Schema.String()]; exists { - return refType, nil - } - - refType, err := processAnnotation(tc.ss, schemaAnnot, rule, tc.allowNet) - if err != nil { - return nil, err - } - - if refType == nil { - return nil, nil - } - - tc.schemaTypes[schemaAnnot.Schema.String()] = refType - return refType, nil - -} - -func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { - - env = env.wrap() - - schemaAnnots := getRuleAnnotation(as, rule) - for _, schemaAnnot := range schemaAnnots { - refType, err := tc.getSchemaType(schemaAnnot, rule) - if err != nil { - tc.err([]*Error{err}) - continue - } - - ref := schemaAnnot.Path - // if we do not have a ref or a reftype, we should not evaluate this rule. - if ref == nil || refType == nil { - continue - } - - prefixRef, t := getPrefix(env, ref) - if t == nil || len(prefixRef) == len(ref) { - env.tree.Put(ref, refType) - } else { - newType, err := override(ref[len(prefixRef):], t, refType, rule) - if err != nil { - tc.err([]*Error{err}) - continue - } - env.tree.Put(prefixRef, newType) - } - } - - cpy, err := tc.CheckBody(env, rule.Body) - env = env.next - path := rule.Ref() - - if len(err) > 0 { - // if the rule/function contains an error, add it to the type env so - // that expressions that refer to this rule/function do not encounter - // type errors. - env.tree.Put(path, types.A) - return - } - - var tpe types.Type - - if len(rule.Head.Args) > 0 { - // If args are not referred to in body, infer as any. - WalkVars(rule.Head.Args, func(v Var) bool { - if cpy.Get(v) == nil { - cpy.tree.PutOne(v, types.A) - } - return false - }) - - // Construct function type. - args := make([]types.Type, len(rule.Head.Args)) - for i := 0; i < len(rule.Head.Args); i++ { - args[i] = cpy.Get(rule.Head.Args[i]) - } - - f := types.NewFunction(args, cpy.Get(rule.Head.Value)) - - tpe = f - } else { - switch rule.Head.RuleKind() { - case SingleValue: - typeV := cpy.Get(rule.Head.Value) - if !path.IsGround() { - // e.g. store object[string: whatever] at data.p.q.r, not data.p.q.r[x] or data.p.q.r[x].y[z] - objPath := path.DynamicSuffix() - path = path.GroundPrefix() - - var err error - tpe, err = nestedObject(cpy, objPath, typeV) - if err != nil { - tc.err([]*Error{NewError(TypeErr, rule.Head.Location, err.Error())}) - tpe = nil - } - } else { - if typeV != nil { - tpe = typeV - } - } - case MultiValue: - typeK := cpy.Get(rule.Head.Key) - if typeK != nil { - tpe = types.NewSet(typeK) - } - } - } - - if tpe != nil { - env.tree.Insert(path, tpe, env) - } -} - -// nestedObject creates a nested structure of object types, where each term on path corresponds to a level in the -// nesting. Each term in the path only contributes to the dynamic portion of its corresponding object. -func nestedObject(env *TypeEnv, path Ref, tpe types.Type) (types.Type, error) { - if len(path) == 0 { - return tpe, nil - } - - k := path[0] - typeV, err := nestedObject(env, path[1:], tpe) - if err != nil { - return nil, err - } - if typeV == nil { - return nil, nil - } - - var dynamicProperty *types.DynamicProperty - typeK := env.Get(k) - if typeK == nil { - return nil, nil - } - dynamicProperty = types.NewDynamicProperty(typeK, typeV) - - return types.NewObject(nil, dynamicProperty), nil -} - -func (tc *typeChecker) checkExpr(env *TypeEnv, expr *Expr) *Error { - if err := tc.checkExprWith(env, expr, 0); err != nil { - return err - } - if !expr.IsCall() { - return nil - } - - operator := expr.Operator().String() - - // If the type checker wasn't provided with a required capabilities - // structure then just skip. In some cases, type checking might be run - // without the need to record what builtins are required. - if tc.required != nil { - if bi, ok := tc.builtins[operator]; ok { - tc.required.addBuiltinSorted(bi) - } - } - - checker := tc.exprCheckers[operator] - if checker != nil { - return checker(env, expr) - } - - return tc.checkExprBuiltin(env, expr) -} - -func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error { - - args := expr.Operands() - pre := getArgTypes(env, args) - - // NOTE(tsandall): undefined functions will have been caught earlier in the - // compiler. We check for undefined functions before the safety check so - // that references to non-existent functions result in undefined function - // errors as opposed to unsafe var errors. - // - // We cannot run type checking before the safety check because part of the - // type checker relies on reordering (in particular for references to local - // vars). - name := expr.Operator() - tpe := env.Get(name) - - if tpe == nil { - if tc.allowUndefinedFuncs { - return nil - } - return NewError(TypeErr, expr.Location, "undefined function %v", name) - } - - // check if the expression refers to a function that contains an error - _, ok := tpe.(types.Any) - if ok { - return nil - } - - ftpe, ok := tpe.(*types.Function) - if !ok { - return NewError(TypeErr, expr.Location, "undefined function %v", name) - } - - fargs := ftpe.FuncArgs() - namedFargs := ftpe.NamedFuncArgs() - - if ftpe.Result() != nil { - fargs.Args = append(fargs.Args, ftpe.Result()) - namedFargs.Args = append(namedFargs.Args, ftpe.NamedResult()) - } - - if len(args) > len(fargs.Args) && fargs.Variadic == nil { - return newArgError(expr.Location, name, "too many arguments", pre, namedFargs) - } - - if len(args) < len(ftpe.FuncArgs().Args) { - return newArgError(expr.Location, name, "too few arguments", pre, namedFargs) - } - - for i := range args { - if !unify1(env, args[i], fargs.Arg(i), false) { - post := make([]types.Type, len(args)) - for i := range args { - post[i] = env.Get(args[i]) - } - return newArgError(expr.Location, name, "invalid argument(s)", post, namedFargs) - } - } - - return nil -} - -func checkExprEq(env *TypeEnv, expr *Expr) *Error { - - pre := getArgTypes(env, expr.Operands()) - exp := Equality.Decl.FuncArgs() - - if len(pre) < len(exp.Args) { - return newArgError(expr.Location, expr.Operator(), "too few arguments", pre, exp) - } - - if len(exp.Args) < len(pre) { - return newArgError(expr.Location, expr.Operator(), "too many arguments", pre, exp) - } - - a, b := expr.Operand(0), expr.Operand(1) - typeA, typeB := env.Get(a), env.Get(b) - - if !unify2(env, a, typeA, b, typeB) { - err := NewError(TypeErr, expr.Location, "match error") - err.Details = &UnificationErrDetail{ - Left: typeA, - Right: typeB, - } - return err - } - - return nil -} - -func (tc *typeChecker) checkExprWith(env *TypeEnv, expr *Expr, i int) *Error { - if i == len(expr.With) { - return nil - } - - target, value := expr.With[i].Target, expr.With[i].Value - targetType, valueType := env.Get(target), env.Get(value) - - if t, ok := targetType.(*types.Function); ok { // built-in function replacement - switch v := valueType.(type) { - case *types.Function: // ...by function - if !unifies(targetType, valueType) { - return newArgError(expr.With[i].Loc(), target.Value.(Ref), "arity mismatch", v.FuncArgs().Args, t.NamedFuncArgs()) - } - default: // ... by value, nothing to check - } - } - - return tc.checkExprWith(env, expr, i+1) -} - -func unify2(env *TypeEnv, a *Term, typeA types.Type, b *Term, typeB types.Type) bool { - - nilA := types.Nil(typeA) - nilB := types.Nil(typeB) - - if nilA && !nilB { - return unify1(env, a, typeB, false) - } else if nilB && !nilA { - return unify1(env, b, typeA, false) - } else if !nilA && !nilB { - return unifies(typeA, typeB) - } - - switch a.Value.(type) { - case *Array: - return unify2Array(env, a, b) - case *object: - return unify2Object(env, a, b) - case Var: - switch b.Value.(type) { - case Var: - return unify1(env, a, types.A, false) && unify1(env, b, env.Get(a), false) - case *Array: - return unify2Array(env, b, a) - case *object: - return unify2Object(env, b, a) - } - } - - return false -} - -func unify2Array(env *TypeEnv, a *Term, b *Term) bool { - arr := a.Value.(*Array) - switch bv := b.Value.(type) { - case *Array: - if arr.Len() == bv.Len() { - for i := 0; i < arr.Len(); i++ { - if !unify2(env, arr.Elem(i), env.Get(arr.Elem(i)), bv.Elem(i), env.Get(bv.Elem(i))) { - return false - } - } - return true - } - case Var: - return unify1(env, a, types.A, false) && unify1(env, b, env.Get(a), false) - } - return false -} - -func unify2Object(env *TypeEnv, a *Term, b *Term) bool { - obj := a.Value.(Object) - switch bv := b.Value.(type) { - case *object: - cv := obj.Intersect(bv) - if obj.Len() == bv.Len() && bv.Len() == len(cv) { - for i := range cv { - if !unify2(env, cv[i][1], env.Get(cv[i][1]), cv[i][2], env.Get(cv[i][2])) { - return false - } - } - return true - } - case Var: - return unify1(env, a, types.A, false) && unify1(env, b, env.Get(a), false) - } - return false -} - -func unify1(env *TypeEnv, term *Term, tpe types.Type, union bool) bool { - switch v := term.Value.(type) { - case *Array: - switch tpe := tpe.(type) { - case *types.Array: - return unify1Array(env, v, tpe, union) - case types.Any: - if types.Compare(tpe, types.A) == 0 { - for i := 0; i < v.Len(); i++ { - unify1(env, v.Elem(i), types.A, true) - } - return true - } - unifies := false - for i := range tpe { - unifies = unify1(env, term, tpe[i], true) || unifies - } - return unifies - } - return false - case *object: - switch tpe := tpe.(type) { - case *types.Object: - return unify1Object(env, v, tpe, union) - case types.Any: - if types.Compare(tpe, types.A) == 0 { - v.Foreach(func(key, value *Term) { - unify1(env, key, types.A, true) - unify1(env, value, types.A, true) - }) - return true - } - unifies := false - for i := range tpe { - unifies = unify1(env, term, tpe[i], true) || unifies - } - return unifies - } - return false - case Set: - switch tpe := tpe.(type) { - case *types.Set: - return unify1Set(env, v, tpe, union) - case types.Any: - if types.Compare(tpe, types.A) == 0 { - v.Foreach(func(elem *Term) { - unify1(env, elem, types.A, true) - }) - return true - } - unifies := false - for i := range tpe { - unifies = unify1(env, term, tpe[i], true) || unifies - } - return unifies - } - return false - case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension: - return unifies(env.Get(v), tpe) - case Var: - if !union { - if exist := env.Get(v); exist != nil { - return unifies(exist, tpe) - } - env.tree.PutOne(term.Value, tpe) - } else { - env.tree.PutOne(term.Value, types.Or(env.Get(v), tpe)) - } - return true - default: - if !IsConstant(v) { - panic("unreachable") - } - return unifies(env.Get(term), tpe) - } -} - -func unify1Array(env *TypeEnv, val *Array, tpe *types.Array, union bool) bool { - if val.Len() != tpe.Len() && tpe.Dynamic() == nil { - return false - } - for i := 0; i < val.Len(); i++ { - if !unify1(env, val.Elem(i), tpe.Select(i), union) { - return false - } - } - return true -} - -func unify1Object(env *TypeEnv, val Object, tpe *types.Object, union bool) bool { - if val.Len() != len(tpe.Keys()) && tpe.DynamicValue() == nil { - return false - } - stop := val.Until(func(k, v *Term) bool { - if IsConstant(k.Value) { - if child := selectConstant(tpe, k); child != nil { - if !unify1(env, v, child, union) { - return true - } - } else { - return true - } - } else { - // Inferring type of value under dynamic key would involve unioning - // with all property values of tpe whose keys unify. For now, type - // these values as Any. We can investigate stricter inference in - // the future. - unify1(env, v, types.A, union) - } - return false - }) - return !stop -} - -func unify1Set(env *TypeEnv, val Set, tpe *types.Set, union bool) bool { - of := types.Values(tpe) - return !val.Until(func(elem *Term) bool { - return !unify1(env, elem, of, union) - }) -} - -func (tc *typeChecker) err(errors []*Error) { - tc.errs = append(tc.errs, errors...) -} - -type refChecker struct { - env *TypeEnv - errs Errors - varRewriter varRewriter -} - -func rewriteVarsNop(node Ref) Ref { - return node -} - -func newRefChecker(env *TypeEnv, f varRewriter) *refChecker { - - if f == nil { - f = rewriteVarsNop - } - - return &refChecker{ - env: env, - errs: nil, - varRewriter: f, - } -} - -func (rc *refChecker) Visit(x interface{}) bool { - switch x := x.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension: - return true - case *Expr: - switch terms := x.Terms.(type) { - case []*Term: - for i := 1; i < len(terms); i++ { - NewGenericVisitor(rc.Visit).Walk(terms[i]) - } - return true - case *Term: - NewGenericVisitor(rc.Visit).Walk(terms) - return true - } - case Ref: - if err := rc.checkApply(rc.env, x); err != nil { - rc.errs = append(rc.errs, err) - return true - } - if err := rc.checkRef(rc.env, rc.env.tree, x, 0); err != nil { - rc.errs = append(rc.errs, err) - } - } - return false -} - -func (rc *refChecker) checkApply(curr *TypeEnv, ref Ref) *Error { - switch tpe := curr.Get(ref).(type) { - case *types.Function: // NOTE(sr): We don't support first-class functions, except for `with`. - return newRefErrUnsupported(ref[0].Location, rc.varRewriter(ref), len(ref)-1, tpe) - } - - return nil -} - -func (rc *refChecker) checkRef(curr *TypeEnv, node *typeTreeNode, ref Ref, idx int) *Error { - - if idx == len(ref) { - return nil - } - - head := ref[idx] - - // NOTE(sr): as long as package statements are required, this isn't possible: - // the shortest possible rule ref is data.a.b (b is idx 2), idx 1 and 2 need to - // be strings or vars. - if idx == 1 || idx == 2 { - switch head.Value.(type) { - case Var, String: // OK - default: - have := rc.env.Get(head.Value) - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, have, types.S, getOneOfForNode(node)) - } - } - - if v, ok := head.Value.(Var); ok && idx != 0 { - tpe := types.Keys(rc.env.getRefRecExtent(node)) - if exist := rc.env.Get(v); exist != nil { - if !unifies(tpe, exist) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, tpe, getOneOfForNode(node)) - } - } else { - rc.env.tree.PutOne(v, tpe) - } - } - - child := node.Child(head.Value) - if child == nil { - // NOTE(sr): idx is reset on purpose: we start over - switch { - case curr.next != nil: - next := curr.next - return rc.checkRef(next, next.tree, ref, 0) - - case RootDocumentNames.Contains(ref[0]): - if idx != 0 { - node.Children().Iter(func(_, child util.T) bool { - _ = rc.checkRef(curr, child.(*typeTreeNode), ref, idx+1) // ignore error - return false - }) - return nil - } - return rc.checkRefLeaf(types.A, ref, 1) - - default: - return rc.checkRefLeaf(types.A, ref, 0) - } - } - - if child.Leaf() { - return rc.checkRefLeaf(child.Value(), ref, idx+1) - } - - return rc.checkRef(curr, child, ref, idx+1) -} - -func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error { - - if idx == len(ref) { - return nil - } - - head := ref[idx] - - keys := types.Keys(tpe) - if keys == nil { - return newRefErrUnsupported(ref[0].Location, rc.varRewriter(ref), idx-1, tpe) - } - - switch value := head.Value.(type) { - - case Var: - if exist := rc.env.Get(value); exist != nil { - if !unifies(exist, keys) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, keys, getOneOfForType(tpe)) - } - } else { - rc.env.tree.PutOne(value, types.Keys(tpe)) - } - - case Ref: - if exist := rc.env.Get(value); exist != nil { - if !unifies(exist, keys) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, keys, getOneOfForType(tpe)) - } - } - - case *Array, Object, Set: - if !unify1(rc.env, head, keys, false) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, rc.env.Get(head), keys, nil) - } - - default: - child := selectConstant(tpe, head) - if child == nil { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, nil, types.Keys(tpe), getOneOfForType(tpe)) - } - return rc.checkRefLeaf(child, ref, idx+1) - } - - return rc.checkRefLeaf(types.Values(tpe), ref, idx+1) -} - -func unifies(a, b types.Type) bool { - - if a == nil || b == nil { - return false - } - - anyA, ok1 := a.(types.Any) - if ok1 { - if unifiesAny(anyA, b) { - return true - } - } - - anyB, ok2 := b.(types.Any) - if ok2 { - if unifiesAny(anyB, a) { - return true - } - } - - if ok1 || ok2 { - return false - } - - switch a := a.(type) { - case types.Null: - _, ok := b.(types.Null) - return ok - case types.Boolean: - _, ok := b.(types.Boolean) - return ok - case types.Number: - _, ok := b.(types.Number) - return ok - case types.String: - _, ok := b.(types.String) - return ok - case *types.Array: - b, ok := b.(*types.Array) - if !ok { - return false - } - return unifiesArrays(a, b) - case *types.Object: - b, ok := b.(*types.Object) - if !ok { - return false - } - return unifiesObjects(a, b) - case *types.Set: - b, ok := b.(*types.Set) - if !ok { - return false - } - return unifies(types.Values(a), types.Values(b)) - case *types.Function: - // NOTE(sr): variadic functions can only be internal ones, and we've forbidden - // their replacement via `with`; so we disregard variadic here - if types.Arity(a) == types.Arity(b) { - b := b.(*types.Function) - for i := range a.FuncArgs().Args { - if !unifies(a.FuncArgs().Arg(i), b.FuncArgs().Arg(i)) { - return false - } - } - return true - } - return false - default: - panic("unreachable") - } -} - -func unifiesAny(a types.Any, b types.Type) bool { - if _, ok := b.(*types.Function); ok { - return false - } - for i := range a { - if unifies(a[i], b) { - return true - } - } - return len(a) == 0 -} - -func unifiesArrays(a, b *types.Array) bool { - - if !unifiesArraysStatic(a, b) { - return false - } - - if !unifiesArraysStatic(b, a) { - return false - } - - return a.Dynamic() == nil || b.Dynamic() == nil || unifies(a.Dynamic(), b.Dynamic()) -} - -func unifiesArraysStatic(a, b *types.Array) bool { - if a.Len() != 0 { - for i := 0; i < a.Len(); i++ { - if !unifies(a.Select(i), b.Select(i)) { - return false - } - } - } - return true -} - -func unifiesObjects(a, b *types.Object) bool { - if !unifiesObjectsStatic(a, b) { - return false - } - - if !unifiesObjectsStatic(b, a) { - return false - } - - return a.DynamicValue() == nil || b.DynamicValue() == nil || unifies(a.DynamicValue(), b.DynamicValue()) -} - -func unifiesObjectsStatic(a, b *types.Object) bool { - for _, k := range a.Keys() { - if !unifies(a.Select(k), b.Select(k)) { - return false - } - } - return true -} - -// typeErrorCause defines an interface to determine the reason for a type -// error. The type error details implement this interface so that type checking -// can report more actionable errors. -type typeErrorCause interface { - nilType() bool -} - -func causedByNilType(err *Error) bool { - cause, ok := err.Details.(typeErrorCause) - if !ok { - return false - } - return cause.nilType() -} - -// ArgErrDetail represents a generic argument error. -type ArgErrDetail struct { - Have []types.Type `json:"have"` - Want types.FuncArgs `json:"want"` -} - -// Lines returns the string representation of the detail. -func (d *ArgErrDetail) Lines() []string { - lines := make([]string, 2) - lines[0] = "have: " + formatArgs(d.Have) - lines[1] = "want: " + fmt.Sprint(d.Want) - return lines -} - -func (d *ArgErrDetail) nilType() bool { - for i := range d.Have { - if types.Nil(d.Have[i]) { - return true - } - } - return false -} - // UnificationErrDetail describes a type mismatch error when two values are // unified (e.g., x = [1,2,y]). -type UnificationErrDetail struct { - Left types.Type `json:"a"` - Right types.Type `json:"b"` -} - -func (a *UnificationErrDetail) nilType() bool { - return types.Nil(a.Left) || types.Nil(a.Right) -} - -// Lines returns the string representation of the detail. -func (a *UnificationErrDetail) Lines() []string { - lines := make([]string, 2) - lines[0] = fmt.Sprint("left : ", types.Sprint(a.Left)) - lines[1] = fmt.Sprint("right : ", types.Sprint(a.Right)) - return lines -} +type UnificationErrDetail = v1.UnificationErrDetail // RefErrUnsupportedDetail describes an undefined reference error where the // referenced value does not support dereferencing (e.g., scalars). -type RefErrUnsupportedDetail struct { - Ref Ref `json:"ref"` // invalid ref - Pos int `json:"pos"` // invalid element - Have types.Type `json:"have"` // referenced type -} - -// Lines returns the string representation of the detail. -func (r *RefErrUnsupportedDetail) Lines() []string { - lines := []string{ - r.Ref.String(), - strings.Repeat("^", len(r.Ref[:r.Pos+1].String())), - fmt.Sprintf("have: %v", r.Have), - } - return lines -} +type RefErrUnsupportedDetail = v1.RefErrUnsupportedDetail // RefErrInvalidDetail describes an undefined reference error where the referenced // value does not support the reference operand (e.g., missing object key, // invalid key type, etc.) -type RefErrInvalidDetail struct { - Ref Ref `json:"ref"` // invalid ref - Pos int `json:"pos"` // invalid element - Have types.Type `json:"have,omitempty"` // type of invalid element (for var/ref elements) - Want types.Type `json:"want"` // allowed type (for non-object values) - OneOf []Value `json:"oneOf"` // allowed values (e.g., for object keys) -} - -// Lines returns the string representation of the detail. -func (r *RefErrInvalidDetail) Lines() []string { - lines := []string{r.Ref.String()} - offset := len(r.Ref[:r.Pos].String()) + 1 - pad := strings.Repeat(" ", offset) - lines = append(lines, fmt.Sprintf("%s^", pad)) - if r.Have != nil { - lines = append(lines, fmt.Sprintf("%shave (type): %v", pad, r.Have)) - } else { - lines = append(lines, fmt.Sprintf("%shave: %v", pad, r.Ref[r.Pos])) - } - if len(r.OneOf) > 0 { - lines = append(lines, fmt.Sprintf("%swant (one of): %v", pad, r.OneOf)) - } else { - lines = append(lines, fmt.Sprintf("%swant (type): %v", pad, r.Want)) - } - return lines -} - -func formatArgs(args []types.Type) string { - buf := make([]string, len(args)) - for i := range args { - buf[i] = types.Sprint(args[i]) - } - return "(" + strings.Join(buf, ", ") + ")" -} - -func newRefErrInvalid(loc *Location, ref Ref, idx int, have, want types.Type, oneOf []Value) *Error { - err := newRefError(loc, ref) - err.Details = &RefErrInvalidDetail{ - Ref: ref, - Pos: idx, - Have: have, - Want: want, - OneOf: oneOf, - } - return err -} - -func newRefErrUnsupported(loc *Location, ref Ref, idx int, have types.Type) *Error { - err := newRefError(loc, ref) - err.Details = &RefErrUnsupportedDetail{ - Ref: ref, - Pos: idx, - Have: have, - } - return err -} - -func newRefError(loc *Location, ref Ref) *Error { - return NewError(TypeErr, loc, "undefined ref: %v", ref) -} - -func newArgError(loc *Location, builtinName Ref, msg string, have []types.Type, want types.FuncArgs) *Error { - err := NewError(TypeErr, loc, "%v: %v", builtinName, msg) - err.Details = &ArgErrDetail{ - Have: have, - Want: want, - } - return err -} - -func getOneOfForNode(node *typeTreeNode) (result []Value) { - node.Children().Iter(func(k, _ util.T) bool { - result = append(result, k.(Value)) - return false - }) - - sortValueSlice(result) - return result -} - -func getOneOfForType(tpe types.Type) (result []Value) { - switch tpe := tpe.(type) { - case *types.Object: - for _, k := range tpe.Keys() { - v, err := InterfaceToValue(k) - if err != nil { - panic(err) - } - result = append(result, v) - } - - case types.Any: - for _, object := range tpe { - objRes := getOneOfForType(object) - result = append(result, objRes...) - } - } - - result = removeDuplicate(result) - sortValueSlice(result) - return result -} - -func sortValueSlice(sl []Value) { - sort.Slice(sl, func(i, j int) bool { - return sl[i].Compare(sl[j]) < 0 - }) -} - -func removeDuplicate(list []Value) []Value { - seen := make(map[Value]bool) - var newResult []Value - for _, item := range list { - if !seen[item] { - newResult = append(newResult, item) - seen[item] = true - } - } - return newResult -} - -func getArgTypes(env *TypeEnv, args []*Term) []types.Type { - pre := make([]types.Type, len(args)) - for i := range args { - pre[i] = env.Get(args[i]) - } - return pre -} - -// getPrefix returns the shortest prefix of ref that exists in env -func getPrefix(env *TypeEnv, ref Ref) (Ref, types.Type) { - if len(ref) == 1 { - t := env.Get(ref) - if t != nil { - return ref, t - } - } - for i := 1; i < len(ref); i++ { - t := env.Get(ref[:i]) - if t != nil { - return ref[:i], t - } - } - return nil, nil -} - -// override takes a type t and returns a type obtained from t where the path represented by ref within it has type o (overriding the original type of that path) -func override(ref Ref, t types.Type, o types.Type, rule *Rule) (types.Type, *Error) { - var newStaticProps []*types.StaticProperty - obj, ok := t.(*types.Object) - if !ok { - newType, err := getObjectType(ref, o, rule, types.NewDynamicProperty(types.A, types.A)) - if err != nil { - return nil, err - } - return newType, nil - } - found := false - if ok { - staticProps := obj.StaticProperties() - for _, prop := range staticProps { - valueCopy := prop.Value - key, err := InterfaceToValue(prop.Key) - if err != nil { - return nil, NewError(TypeErr, rule.Location, "unexpected error in override: %s", err.Error()) - } - if len(ref) > 0 && ref[0].Value.Compare(key) == 0 { - found = true - if len(ref) == 1 { - valueCopy = o - } else { - newVal, err := override(ref[1:], valueCopy, o, rule) - if err != nil { - return nil, err - } - valueCopy = newVal - } - } - newStaticProps = append(newStaticProps, types.NewStaticProperty(prop.Key, valueCopy)) - } - } - - // ref[0] is not a top-level key in staticProps, so it must be added - if !found { - newType, err := getObjectType(ref, o, rule, obj.DynamicProperties()) - if err != nil { - return nil, err - } - newStaticProps = append(newStaticProps, newType.StaticProperties()...) - } - return types.NewObject(newStaticProps, obj.DynamicProperties()), nil -} - -func getKeys(ref Ref, rule *Rule) ([]interface{}, *Error) { - keys := []interface{}{} - for _, refElem := range ref { - key, err := JSON(refElem.Value) - if err != nil { - return nil, NewError(TypeErr, rule.Location, "error getting key from value: %s", err.Error()) - } - keys = append(keys, key) - } - return keys, nil -} - -func getObjectTypeRec(keys []interface{}, o types.Type, d *types.DynamicProperty) *types.Object { - if len(keys) == 1 { - staticProps := []*types.StaticProperty{types.NewStaticProperty(keys[0], o)} - return types.NewObject(staticProps, d) - } - - staticProps := []*types.StaticProperty{types.NewStaticProperty(keys[0], getObjectTypeRec(keys[1:], o, d))} - return types.NewObject(staticProps, d) -} - -func getObjectType(ref Ref, o types.Type, rule *Rule, d *types.DynamicProperty) (*types.Object, *Error) { - keys, err := getKeys(ref, rule) - if err != nil { - return nil, err - } - return getObjectTypeRec(keys, o, d), nil -} - -func getRuleAnnotation(as *AnnotationSet, rule *Rule) (result []*SchemaAnnotation) { - - for _, x := range as.GetSubpackagesScope(rule.Module.Package.Path) { - result = append(result, x.Schemas...) - } - - if x := as.GetPackageScope(rule.Module.Package); x != nil { - result = append(result, x.Schemas...) - } - - if x := as.GetDocumentScope(rule.Ref().GroundPrefix()); x != nil { - result = append(result, x.Schemas...) - } - - for _, x := range as.GetRuleScope(rule) { - result = append(result, x.Schemas...) - } - - return result -} - -func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allowNet []string) (types.Type, *Error) { - - var schema interface{} - - if annot.Schema != nil { - if ss == nil { - return nil, nil - } - schema = ss.Get(annot.Schema) - if schema == nil { - return nil, NewError(TypeErr, rule.Location, "undefined schema: %v", annot.Schema) - } - } else if annot.Definition != nil { - schema = *annot.Definition - } - - tpe, err := loadSchema(schema, allowNet) - if err != nil { - return nil, NewError(TypeErr, rule.Location, err.Error()) - } - - return tpe, nil -} - -func errAnnotationRedeclared(a *Annotations, other *Location) *Error { - return NewError(TypeErr, a.Location, "%v annotation redeclared: %v", a.Scope, other) -} +type RefErrInvalidDetail = v1.RefErrInvalidDetail diff --git a/vendor/github.com/open-policy-agent/opa/ast/compare.go b/vendor/github.com/open-policy-agent/opa/ast/compare.go index 3bb6f2a75..d36078e33 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/compare.go +++ b/vendor/github.com/open-policy-agent/opa/ast/compare.go @@ -5,9 +5,7 @@ package ast import ( - "encoding/json" - "fmt" - "math/big" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // Compare returns an integer indicating whether two AST values are less than, @@ -37,360 +35,5 @@ import ( // is empty. // Other comparisons are consistent but not defined. func Compare(a, b interface{}) int { - - if t, ok := a.(*Term); ok { - if t == nil { - a = nil - } else { - a = t.Value - } - } - - if t, ok := b.(*Term); ok { - if t == nil { - b = nil - } else { - b = t.Value - } - } - - if a == nil { - if b == nil { - return 0 - } - return -1 - } - if b == nil { - return 1 - } - - sortA := sortOrder(a) - sortB := sortOrder(b) - - if sortA < sortB { - return -1 - } else if sortB < sortA { - return 1 - } - - switch a := a.(type) { - case Null: - return 0 - case Boolean: - b := b.(Boolean) - if a.Equal(b) { - return 0 - } - if !a { - return -1 - } - return 1 - case Number: - if ai, err := json.Number(a).Int64(); err == nil { - if bi, err := json.Number(b.(Number)).Int64(); err == nil { - if ai == bi { - return 0 - } - if ai < bi { - return -1 - } - return 1 - } - } - - // We use big.Rat for comparing big numbers. - // It replaces big.Float due to following reason: - // big.Float comes with a default precision of 64, and setting a - // larger precision results in more memory being allocated - // (regardless of the actual number we are parsing with SetString). - // - // Note: If we're so close to zero that big.Float says we are zero, do - // *not* big.Rat).SetString on the original string it'll potentially - // take very long. - var bigA, bigB *big.Rat - fa, ok := new(big.Float).SetString(string(a)) - if !ok { - panic("illegal value") - } - if fa.IsInt() { - if i, _ := fa.Int64(); i == 0 { - bigA = new(big.Rat).SetInt64(0) - } - } - if bigA == nil { - bigA, ok = new(big.Rat).SetString(string(a)) - if !ok { - panic("illegal value") - } - } - - fb, ok := new(big.Float).SetString(string(b.(Number))) - if !ok { - panic("illegal value") - } - if fb.IsInt() { - if i, _ := fb.Int64(); i == 0 { - bigB = new(big.Rat).SetInt64(0) - } - } - if bigB == nil { - bigB, ok = new(big.Rat).SetString(string(b.(Number))) - if !ok { - panic("illegal value") - } - } - - return bigA.Cmp(bigB) - case String: - b := b.(String) - if a.Equal(b) { - return 0 - } - if a < b { - return -1 - } - return 1 - case Var: - b := b.(Var) - if a.Equal(b) { - return 0 - } - if a < b { - return -1 - } - return 1 - case Ref: - b := b.(Ref) - return termSliceCompare(a, b) - case *Array: - b := b.(*Array) - return termSliceCompare(a.elems, b.elems) - case *lazyObj: - return Compare(a.force(), b) - case *object: - if x, ok := b.(*lazyObj); ok { - b = x.force() - } - b := b.(*object) - return a.Compare(b) - case Set: - b := b.(Set) - return a.Compare(b) - case *ArrayComprehension: - b := b.(*ArrayComprehension) - if cmp := Compare(a.Term, b.Term); cmp != 0 { - return cmp - } - return Compare(a.Body, b.Body) - case *ObjectComprehension: - b := b.(*ObjectComprehension) - if cmp := Compare(a.Key, b.Key); cmp != 0 { - return cmp - } - if cmp := Compare(a.Value, b.Value); cmp != 0 { - return cmp - } - return Compare(a.Body, b.Body) - case *SetComprehension: - b := b.(*SetComprehension) - if cmp := Compare(a.Term, b.Term); cmp != 0 { - return cmp - } - return Compare(a.Body, b.Body) - case Call: - b := b.(Call) - return termSliceCompare(a, b) - case *Expr: - b := b.(*Expr) - return a.Compare(b) - case *SomeDecl: - b := b.(*SomeDecl) - return a.Compare(b) - case *Every: - b := b.(*Every) - return a.Compare(b) - case *With: - b := b.(*With) - return a.Compare(b) - case Body: - b := b.(Body) - return a.Compare(b) - case *Head: - b := b.(*Head) - return a.Compare(b) - case *Rule: - b := b.(*Rule) - return a.Compare(b) - case Args: - b := b.(Args) - return termSliceCompare(a, b) - case *Import: - b := b.(*Import) - return a.Compare(b) - case *Package: - b := b.(*Package) - return a.Compare(b) - case *Annotations: - b := b.(*Annotations) - return a.Compare(b) - case *Module: - b := b.(*Module) - return a.Compare(b) - } - panic(fmt.Sprintf("illegal value: %T", a)) -} - -type termSlice []*Term - -func (s termSlice) Less(i, j int) bool { return Compare(s[i].Value, s[j].Value) < 0 } -func (s termSlice) Swap(i, j int) { x := s[i]; s[i] = s[j]; s[j] = x } -func (s termSlice) Len() int { return len(s) } - -func sortOrder(x interface{}) int { - switch x.(type) { - case Null: - return 0 - case Boolean: - return 1 - case Number: - return 2 - case String: - return 3 - case Var: - return 4 - case Ref: - return 5 - case *Array: - return 6 - case Object: - return 7 - case Set: - return 8 - case *ArrayComprehension: - return 9 - case *ObjectComprehension: - return 10 - case *SetComprehension: - return 11 - case Call: - return 12 - case Args: - return 13 - case *Expr: - return 100 - case *SomeDecl: - return 101 - case *Every: - return 102 - case *With: - return 110 - case *Head: - return 120 - case Body: - return 200 - case *Rule: - return 1000 - case *Import: - return 1001 - case *Package: - return 1002 - case *Annotations: - return 1003 - case *Module: - return 10000 - } - panic(fmt.Sprintf("illegal value: %T", x)) -} - -func importsCompare(a, b []*Import) int { - minLen := len(a) - if len(b) < minLen { - minLen = len(b) - } - for i := 0; i < minLen; i++ { - if cmp := a[i].Compare(b[i]); cmp != 0 { - return cmp - } - } - if len(a) < len(b) { - return -1 - } - if len(b) < len(a) { - return 1 - } - return 0 -} - -func annotationsCompare(a, b []*Annotations) int { - minLen := len(a) - if len(b) < minLen { - minLen = len(b) - } - for i := 0; i < minLen; i++ { - if cmp := a[i].Compare(b[i]); cmp != 0 { - return cmp - } - } - if len(a) < len(b) { - return -1 - } - if len(b) < len(a) { - return 1 - } - return 0 -} - -func rulesCompare(a, b []*Rule) int { - minLen := len(a) - if len(b) < minLen { - minLen = len(b) - } - for i := 0; i < minLen; i++ { - if cmp := a[i].Compare(b[i]); cmp != 0 { - return cmp - } - } - if len(a) < len(b) { - return -1 - } - if len(b) < len(a) { - return 1 - } - return 0 -} - -func termSliceCompare(a, b []*Term) int { - minLen := len(a) - if len(b) < minLen { - minLen = len(b) - } - for i := 0; i < minLen; i++ { - if cmp := Compare(a[i], b[i]); cmp != 0 { - return cmp - } - } - if len(a) < len(b) { - return -1 - } else if len(b) < len(a) { - return 1 - } - return 0 -} - -func withSliceCompare(a, b []*With) int { - minLen := len(a) - if len(b) < minLen { - minLen = len(b) - } - for i := 0; i < minLen; i++ { - if cmp := Compare(a[i], b[i]); cmp != 0 { - return cmp - } - } - if len(a) < len(b) { - return -1 - } else if len(b) < len(a) { - return 1 - } - return 0 + return v1.Compare(a, b) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/compile.go b/vendor/github.com/open-policy-agent/opa/ast/compile.go index 9025f862b..5a3daa910 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/compile.go +++ b/vendor/github.com/open-policy-agent/opa/ast/compile.go @@ -5,5882 +5,123 @@ package ast import ( - "errors" - "fmt" - "io" - "sort" - "strconv" - "strings" - - "github.com/open-policy-agent/opa/ast/location" - "github.com/open-policy-agent/opa/internal/debug" - "github.com/open-policy-agent/opa/internal/gojsonschema" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/types" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // CompileErrorLimitDefault is the default number errors a compiler will allow before // exiting. const CompileErrorLimitDefault = 10 -var errLimitReached = NewError(CompileErr, nil, "error limit reached") - // Compiler contains the state of a compilation process. -type Compiler struct { - - // Errors contains errors that occurred during the compilation process. - // If there are one or more errors, the compilation process is considered - // "failed". - Errors Errors - - // Modules contains the compiled modules. The compiled modules are the - // output of the compilation process. If the compilation process failed, - // there is no guarantee about the state of the modules. - Modules map[string]*Module - - // ModuleTree organizes the modules into a tree where each node is keyed by - // an element in the module's package path. E.g., given modules containing - // the following package directives: "a", "a.b", "a.c", and "a.b", the - // resulting module tree would be: - // - // root - // | - // +--- data (no modules) - // | - // +--- a (1 module) - // | - // +--- b (2 modules) - // | - // +--- c (1 module) - // - ModuleTree *ModuleTreeNode - - // RuleTree organizes rules into a tree where each node is keyed by an - // element in the rule's path. The rule path is the concatenation of the - // containing package and the stringified rule name. E.g., given the - // following module: - // - // package ex - // p[1] { true } - // p[2] { true } - // q = true - // a.b.c = 3 - // - // root - // | - // +--- data (no rules) - // | - // +--- ex (no rules) - // | - // +--- p (2 rules) - // | - // +--- q (1 rule) - // | - // +--- a - // | - // +--- b - // | - // +--- c (1 rule) - // - // Another example with general refs containing vars at arbitrary locations: - // - // package ex - // a.b[x].d { x := "c" } # R1 - // a.b.c[x] { x := "d" } # R2 - // a.b[x][y] { x := "c"; y := "d" } # R3 - // p := true # R4 - // - // root - // | - // +--- data (no rules) - // | - // +--- ex (no rules) - // | - // +--- a - // | | - // | +--- b (R1, R3) - // | | - // | +--- c (R2) - // | - // +--- p (R4) - RuleTree *TreeNode - - // Graph contains dependencies between rules. An edge (u,v) is added to the - // graph if rule 'u' refers to the virtual document defined by 'v'. - Graph *Graph - - // TypeEnv holds type information for values inferred by the compiler. - TypeEnv *TypeEnv - - // RewrittenVars is a mapping of variables that have been rewritten - // with the key being the generated name and value being the original. - RewrittenVars map[Var]Var - - // Capabliities required by the modules that were compiled. - Required *Capabilities - - localvargen *localVarGenerator - moduleLoader ModuleLoader - ruleIndices *util.HashMap - stages []stage - maxErrs int - sorted []string // list of sorted module names - pathExists func([]string) (bool, error) - after map[string][]CompilerStageDefinition - metrics metrics.Metrics - capabilities *Capabilities // user-supplied capabilities - imports map[string][]*Import // saved imports from stripping - builtins map[string]*Builtin // universe of built-in functions - customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities) - unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities) - deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions - enablePrintStatements bool // indicates if print statements should be elided (default) - comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index - initialized bool // indicates if init() has been called - debug debug.Debug // emits debug information produced during compilation - schemaSet *SchemaSet // user-supplied schemas for input and data documents - inputType types.Type // global input type retrieved from schema set - annotationSet *AnnotationSet // hierarchical set of annotations - strict bool // enforce strict compilation checks - keepModules bool // whether to keep the unprocessed, parse modules (below) - parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true - useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker - allowUndefinedFuncCalls bool // don't error on calls to unknown functions. - evalMode CompilerEvalMode // - rewriteTestRulesForTracing bool // rewrite test rules to capture dynamic values for tracing. -} +type Compiler = v1.Compiler // CompilerStage defines the interface for stages in the compiler. -type CompilerStage func(*Compiler) *Error +type CompilerStage = v1.CompilerStage // CompilerEvalMode allows toggling certain stages that are only // needed for certain modes, Concretely, only "topdown" mode will // have the compiler build comprehension and rule indices. -type CompilerEvalMode int +type CompilerEvalMode = v1.CompilerEvalMode const ( // EvalModeTopdown (default) instructs the compiler to build rule // and comprehension indices used by topdown evaluation. - EvalModeTopdown CompilerEvalMode = iota + EvalModeTopdown = v1.EvalModeTopdown // EvalModeIR makes the compiler skip the stages for comprehension // and rule indices. - EvalModeIR + EvalModeIR = v1.EvalModeIR ) // CompilerStageDefinition defines a compiler stage -type CompilerStageDefinition struct { - Name string - MetricName string - Stage CompilerStage -} +type CompilerStageDefinition = v1.CompilerStageDefinition // RulesOptions defines the options for retrieving rules by Ref from the // compiler. -type RulesOptions struct { - // IncludeHiddenModules determines if the result contains hidden modules, - // currently only the "system" namespace, i.e. "data.system.*". - IncludeHiddenModules bool -} +type RulesOptions = v1.RulesOptions // QueryContext contains contextual information for running an ad-hoc query. // // Ad-hoc queries can be run in the context of a package and imports may be // included to provide concise access to data. -type QueryContext struct { - Package *Package - Imports []*Import -} +type QueryContext = v1.QueryContext // NewQueryContext returns a new QueryContext object. func NewQueryContext() *QueryContext { - return &QueryContext{} -} - -// WithPackage sets the pkg on qc. -func (qc *QueryContext) WithPackage(pkg *Package) *QueryContext { - if qc == nil { - qc = NewQueryContext() - } - qc.Package = pkg - return qc -} - -// WithImports sets the imports on qc. -func (qc *QueryContext) WithImports(imports []*Import) *QueryContext { - if qc == nil { - qc = NewQueryContext() - } - qc.Imports = imports - return qc -} - -// Copy returns a deep copy of qc. -func (qc *QueryContext) Copy() *QueryContext { - if qc == nil { - return nil - } - cpy := *qc - if cpy.Package != nil { - cpy.Package = qc.Package.Copy() - } - cpy.Imports = make([]*Import, len(qc.Imports)) - for i := range qc.Imports { - cpy.Imports[i] = qc.Imports[i].Copy() - } - return &cpy + return v1.NewQueryContext() } // QueryCompiler defines the interface for compiling ad-hoc queries. -type QueryCompiler interface { - - // Compile should be called to compile ad-hoc queries. The return value is - // the compiled version of the query. - Compile(q Body) (Body, error) - - // TypeEnv returns the type environment built after running type checking - // on the query. - TypeEnv() *TypeEnv - - // WithContext sets the QueryContext on the QueryCompiler. Subsequent calls - // to Compile will take the QueryContext into account. - WithContext(qctx *QueryContext) QueryCompiler - - // WithEnablePrintStatements enables print statements in queries compiled - // with the QueryCompiler. - WithEnablePrintStatements(yes bool) QueryCompiler - - // WithUnsafeBuiltins sets the built-in functions to treat as unsafe and not - // allow inside of queries. By default the query compiler inherits the - // compiler's unsafe built-in functions. This function allows callers to - // override that set. If an empty (non-nil) map is provided, all built-ins - // are allowed. - WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler - - // WithStageAfter registers a stage to run during query compilation after - // the named stage. - WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler - - // RewrittenVars maps generated vars in the compiled query to vars from the - // parsed query. For example, given the query "input := 1" the rewritten - // query would be "__local0__ = 1". The mapping would then be {__local0__: input}. - RewrittenVars() map[Var]Var - - // ComprehensionIndex returns an index data structure for the given comprehension - // term. If no index is found, returns nil. - ComprehensionIndex(term *Term) *ComprehensionIndex - - // WithStrict enables strict mode for the query compiler. - WithStrict(strict bool) QueryCompiler -} +type QueryCompiler = v1.QueryCompiler // QueryCompilerStage defines the interface for stages in the query compiler. -type QueryCompilerStage func(QueryCompiler, Body) (Body, error) +type QueryCompilerStage = v1.QueryCompilerStage // QueryCompilerStageDefinition defines a QueryCompiler stage -type QueryCompilerStageDefinition struct { - Name string - MetricName string - Stage QueryCompilerStage -} - -type stage struct { - name string - metricName string - f func() -} +type QueryCompilerStageDefinition = v1.QueryCompilerStageDefinition // NewCompiler returns a new empty compiler. func NewCompiler() *Compiler { - - c := &Compiler{ - Modules: map[string]*Module{}, - RewrittenVars: map[Var]Var{}, - Required: &Capabilities{}, - ruleIndices: util.NewHashMap(func(a, b util.T) bool { - r1, r2 := a.(Ref), b.(Ref) - return r1.Equal(r2) - }, func(x util.T) int { - return x.(Ref).Hash() - }), - maxErrs: CompileErrorLimitDefault, - after: map[string][]CompilerStageDefinition{}, - unsafeBuiltinsMap: map[string]struct{}{}, - deprecatedBuiltinsMap: map[string]struct{}{}, - comprehensionIndices: map[*Term]*ComprehensionIndex{}, - debug: debug.Discard(), - } - - c.ModuleTree = NewModuleTree(nil) - c.RuleTree = NewRuleTree(c.ModuleTree) - - c.stages = []stage{ - // Reference resolution should run first as it may be used to lazily - // load additional modules. If any stages run before resolution, they - // need to be re-run after resolution. - {"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs}, - // The local variable generator must be initialized after references are - // resolved and the dynamic module loader has run but before subsequent - // stages that need to generate variables. - {"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen}, - {"RewriteRuleHeadRefs", "compile_stage_rewrite_rule_head_refs", c.rewriteRuleHeadRefs}, - {"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides}, - {"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports}, - {"RemoveImports", "compile_stage_remove_imports", c.removeImports}, - {"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree}, - {"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // depends on RewriteRuleHeadRefs - {"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars}, - {"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls}, - {"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls}, - {"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms}, - {"ParseMetadataBlocks", "compile_stage_parse_metadata_blocks", c.parseMetadataBlocks}, - {"SetAnnotationSet", "compile_stage_set_annotationset", c.setAnnotationSet}, - {"RewriteRegoMetadataCalls", "compile_stage_rewrite_rego_metadata_calls", c.rewriteRegoMetadataCalls}, - {"SetGraph", "compile_stage_set_graph", c.setGraph}, - {"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms}, - {"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead}, - {"RewriteWithValues", "compile_stage_rewrite_with_values", c.rewriteWithModifiers}, - {"CheckRuleConflicts", "compile_stage_check_rule_conflicts", c.checkRuleConflicts}, - {"CheckUndefinedFuncs", "compile_stage_check_undefined_funcs", c.checkUndefinedFuncs}, - {"CheckSafetyRuleHeads", "compile_stage_check_safety_rule_heads", c.checkSafetyRuleHeads}, - {"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies}, - {"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals}, - {"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms}, - {"RewriteTestRulesForTracing", "compile_stage_rewrite_test_rules_for_tracing", c.rewriteTestRuleEqualities}, // must run after RewriteDynamicTerms - {"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion}, - {"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion - {"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins}, - {"CheckDeprecatedBuiltins", "compile_state_check_deprecated_builtins", c.checkDeprecatedBuiltins}, - {"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices}, - {"BuildComprehensionIndices", "compile_stage_rebuild_comprehension_indices", c.buildComprehensionIndices}, - {"BuildRequiredCapabilities", "compile_stage_build_required_capabilities", c.buildRequiredCapabilities}, - } - - return c -} - -// SetErrorLimit sets the number of errors the compiler can encounter before it -// quits. Zero or a negative number indicates no limit. -func (c *Compiler) SetErrorLimit(limit int) *Compiler { - c.maxErrs = limit - return c -} - -// WithEnablePrintStatements enables print statements inside of modules compiled -// by the compiler. If print statements are not enabled, calls to print() are -// erased at compile-time. -func (c *Compiler) WithEnablePrintStatements(yes bool) *Compiler { - c.enablePrintStatements = yes - return c -} - -// WithPathConflictsCheck enables base-virtual document conflict -// detection. The compiler will check that rules don't overlap with -// paths that exist as determined by the provided callable. -func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Compiler { - c.pathExists = fn - return c -} - -// WithStageAfter registers a stage to run during compilation after -// the named stage. -func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler { - c.after[after] = append(c.after[after], stage) - return c -} - -// WithMetrics will set a metrics.Metrics and be used for profiling -// the Compiler instance. -func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler { - c.metrics = metrics - return c -} - -// WithCapabilities sets capabilities to enable during compilation. Capabilities allow the caller -// to specify the set of built-in functions available to the policy. In the future, capabilities -// may be able to restrict access to other language features. Capabilities allow callers to check -// if policies are compatible with a particular version of OPA. If policies are a compiled for a -// specific version of OPA, there is no guarantee that _this_ version of OPA can evaluate them -// successfully. -func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler { - c.capabilities = capabilities - return c -} - -// Capabilities returns the capabilities enabled during compilation. -func (c *Compiler) Capabilities() *Capabilities { - return c.capabilities -} - -// WithDebug sets where debug messages are written to. Passing `nil` has no -// effect. -func (c *Compiler) WithDebug(sink io.Writer) *Compiler { - if sink != nil { - c.debug = debug.New(sink) - } - return c -} - -// WithBuiltins is deprecated. Use WithCapabilities instead. -func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler { - c.customBuiltins = make(map[string]*Builtin) - for k, v := range builtins { - c.customBuiltins[k] = v - } - return c -} - -// WithUnsafeBuiltins is deprecated. Use WithCapabilities instead. -func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler { - for name := range unsafeBuiltins { - c.unsafeBuiltinsMap[name] = struct{}{} - } - return c -} - -// WithStrict enables strict mode in the compiler. -func (c *Compiler) WithStrict(strict bool) *Compiler { - c.strict = strict - return c -} - -// WithKeepModules enables retaining unprocessed modules in the compiler. -// Note that the modules aren't copied on the way in or out -- so when -// accessing them via ParsedModules(), mutations will occur in the module -// map that was passed into Compile().` -func (c *Compiler) WithKeepModules(y bool) *Compiler { - c.keepModules = y - return c -} - -// WithUseTypeCheckAnnotations use schema annotations during type checking -func (c *Compiler) WithUseTypeCheckAnnotations(enabled bool) *Compiler { - c.useTypeCheckAnnotations = enabled - return c -} - -func (c *Compiler) WithAllowUndefinedFunctionCalls(allow bool) *Compiler { - c.allowUndefinedFuncCalls = allow - return c -} - -// WithEvalMode allows setting the CompilerEvalMode of the compiler -func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler { - c.evalMode = e - return c -} - -// WithRewriteTestRules enables rewriting test rules to capture dynamic values in local variables, -// so they can be accessed by tracing. -func (c *Compiler) WithRewriteTestRules(rewrite bool) *Compiler { - c.rewriteTestRulesForTracing = rewrite - return c -} - -// ParsedModules returns the parsed, unprocessed modules from the compiler. -// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`. -// The map includes all modules loaded via the ModuleLoader, if one was used. -func (c *Compiler) ParsedModules() map[string]*Module { - return c.parsedModules -} - -func (c *Compiler) QueryCompiler() QueryCompiler { - c.init() - c0 := *c - return newQueryCompiler(&c0) -} - -// Compile runs the compilation process on the input modules. The compiled -// version of the modules and associated data structures are stored on the -// compiler. If the compilation process fails for any reason, the compiler will -// contain a slice of errors. -func (c *Compiler) Compile(modules map[string]*Module) { - - c.init() - - c.Modules = make(map[string]*Module, len(modules)) - c.sorted = make([]string, 0, len(modules)) - - if c.keepModules { - c.parsedModules = make(map[string]*Module, len(modules)) - } else { - c.parsedModules = nil - } - - for k, v := range modules { - c.Modules[k] = v.Copy() - c.sorted = append(c.sorted, k) - if c.parsedModules != nil { - c.parsedModules[k] = v - } - } - - sort.Strings(c.sorted) - - c.compile() -} - -// WithSchemas sets a schemaSet to the compiler -func (c *Compiler) WithSchemas(schemas *SchemaSet) *Compiler { - c.schemaSet = schemas - return c -} - -// Failed returns true if a compilation error has been encountered. -func (c *Compiler) Failed() bool { - return len(c.Errors) > 0 -} - -// ComprehensionIndex returns a data structure specifying how to index comprehension -// results so that callers do not have to recompute the comprehension more than once. -// If no index is found, returns nil. -func (c *Compiler) ComprehensionIndex(term *Term) *ComprehensionIndex { - return c.comprehensionIndices[term] -} - -// GetArity returns the number of args a function referred to by ref takes. If -// ref refers to built-in function, the built-in declaration is consulted, -// otherwise, the ref is used to perform a ruleset lookup. -func (c *Compiler) GetArity(ref Ref) int { - if bi := c.builtins[ref.String()]; bi != nil { - return len(bi.Decl.FuncArgs().Args) - } - rules := c.GetRulesExact(ref) - if len(rules) == 0 { - return -1 - } - return len(rules[0].Head.Args) -} - -// GetRulesExact returns a slice of rules referred to by the reference. -// -// E.g., given the following module: -// -// package a.b.c -// -// p[k] = v { ... } # rule1 -// p[k1] = v1 { ... } # rule2 -// -// The following calls yield the rules on the right. -// -// GetRulesExact("data.a.b.c.p") => [rule1, rule2] -// GetRulesExact("data.a.b.c.p.x") => nil -// GetRulesExact("data.a.b.c") => nil -func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) { - node := c.RuleTree - - for _, x := range ref { - if node = node.Child(x.Value); node == nil { - return nil - } - } - - return extractRules(node.Values) -} - -// GetRulesForVirtualDocument returns a slice of rules that produce the virtual -// document referred to by the reference. -// -// E.g., given the following module: -// -// package a.b.c -// -// p[k] = v { ... } # rule1 -// p[k1] = v1 { ... } # rule2 -// -// The following calls yield the rules on the right. -// -// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2] -// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2] -// GetRulesForVirtualDocument("data.a.b.c") => nil -func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) { - - node := c.RuleTree - - for _, x := range ref { - if node = node.Child(x.Value); node == nil { - return nil - } - if len(node.Values) > 0 { - return extractRules(node.Values) - } - } - - return extractRules(node.Values) -} - -// GetRulesWithPrefix returns a slice of rules that share the prefix ref. -// -// E.g., given the following module: -// -// package a.b.c -// -// p[x] = y { ... } # rule1 -// p[k] = v { ... } # rule2 -// q { ... } # rule3 -// -// The following calls yield the rules on the right. -// -// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2] -// GetRulesWithPrefix("data.a.b.c.p.a") => nil -// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3] -func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { - - node := c.RuleTree - - for _, x := range ref { - if node = node.Child(x.Value); node == nil { - return nil - } - } - - var acc func(node *TreeNode) - - acc = func(node *TreeNode) { - rules = append(rules, extractRules(node.Values)...) - for _, child := range node.Children { - if child.Hide { - continue - } - acc(child) - } - } - - acc(node) - - return rules -} - -func extractRules(s []util.T) []*Rule { - rules := make([]*Rule, len(s)) - for i := range s { - rules[i] = s[i].(*Rule) - } - return rules -} - -// GetRules returns a slice of rules that are referred to by ref. -// -// E.g., given the following module: -// -// package a.b.c -// -// p[x] = y { q[x] = y; ... } # rule1 -// q[x] = y { ... } # rule2 -// -// The following calls yield the rules on the right. -// -// GetRules("data.a.b.c.p") => [rule1] -// GetRules("data.a.b.c.p.x") => [rule1] -// GetRules("data.a.b.c.q") => [rule2] -// GetRules("data.a.b.c") => [rule1, rule2] -// GetRules("data.a.b.d") => nil -func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { - - set := map[*Rule]struct{}{} - - for _, rule := range c.GetRulesForVirtualDocument(ref) { - set[rule] = struct{}{} - } - - for _, rule := range c.GetRulesWithPrefix(ref) { - set[rule] = struct{}{} - } - - for rule := range set { - rules = append(rules, rule) - } - - return rules -} - -// GetRulesDynamic returns a slice of rules that could be referred to by a ref. -// -// Deprecated: use GetRulesDynamicWithOpts -func (c *Compiler) GetRulesDynamic(ref Ref) []*Rule { - return c.GetRulesDynamicWithOpts(ref, RulesOptions{}) -} - -// GetRulesDynamicWithOpts returns a slice of rules that could be referred to by -// a ref. -// When parts of the ref are statically known, we use that information to narrow -// down which rules the ref could refer to, but in the most general case this -// will be an over-approximation. -// -// E.g., given the following modules: -// -// package a.b.c -// -// r1 = 1 # rule1 -// -// and: -// -// package a.d.c -// -// r2 = 2 # rule2 -// -// The following calls yield the rules on the right. -// -// GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2] -// GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2] -// GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1] -// -// Using the RulesOptions parameter, the inclusion of hidden modules can be -// controlled: -// -// With -// -// package system.main -// -// r3 = 3 # rule3 -// -// We'd get this result: -// -// GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3] -// -// Without the options, it would be excluded. -func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule { - node := c.RuleTree - - set := map[*Rule]struct{}{} - var walk func(node *TreeNode, i int) - walk = func(node *TreeNode, i int) { - switch { - case i >= len(ref): - // We've reached the end of the reference and want to collect everything - // under this "prefix". - node.DepthFirst(func(descendant *TreeNode) bool { - insertRules(set, descendant.Values) - if opts.IncludeHiddenModules { - return false - } - return descendant.Hide - }) - - case i == 0 || IsConstant(ref[i].Value): - // The head of the ref is always grounded. In case another part of the - // ref is also grounded, we can lookup the exact child. If it's not found - // we can immediately return... - if child := node.Child(ref[i].Value); child != nil { - if len(child.Values) > 0 { - // Add any rules at this position - insertRules(set, child.Values) - } - // There might still be "sub-rules" contributing key-value "overrides" for e.g. partial object rules, continue walking - walk(child, i+1) - } else { - return - } - - default: - // This part of the ref is a dynamic term. We can't know what it refers - // to and will just need to try all of the children. - for _, child := range node.Children { - if child.Hide && !opts.IncludeHiddenModules { - continue - } - insertRules(set, child.Values) - walk(child, i+1) - } - } - } - - walk(node, 0) - rules := make([]*Rule, 0, len(set)) - for rule := range set { - rules = append(rules, rule) - } - return rules -} - -// Utility: add all rule values to the set. -func insertRules(set map[*Rule]struct{}, rules []util.T) { - for _, rule := range rules { - set[rule.(*Rule)] = struct{}{} - } -} - -// RuleIndex returns a RuleIndex built for the rule set referred to by path. -// The path must refer to the rule set exactly, i.e., given a rule set at path -// data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a -// RuleIndex built for the rule. -func (c *Compiler) RuleIndex(path Ref) RuleIndex { - r, ok := c.ruleIndices.Get(path) - if !ok { - return nil - } - return r.(RuleIndex) -} - -// PassesTypeCheck determines whether the given body passes type checking -func (c *Compiler) PassesTypeCheck(body Body) bool { - checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType) - env := c.TypeEnv - _, errs := checker.CheckBody(env, body) - return len(errs) == 0 -} - -// PassesTypeCheckRules determines whether the given rules passes type checking -func (c *Compiler) PassesTypeCheckRules(rules []*Rule) Errors { - elems := []util.T{} - - for _, rule := range rules { - elems = append(elems, rule) - } - - // Load the global input schema if one was provided. - if c.schemaSet != nil { - if schema := c.schemaSet.Get(SchemaRootRef); schema != nil { - - var allowNet []string - if c.capabilities != nil { - allowNet = c.capabilities.AllowNet - } - - tpe, err := loadSchema(schema, allowNet) - if err != nil { - return Errors{NewError(TypeErr, nil, err.Error())} - } - c.inputType = tpe - } - } - - var as *AnnotationSet - if c.useTypeCheckAnnotations { - as = c.annotationSet - } - - checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType) - - if c.TypeEnv == nil { - if c.capabilities == nil { - c.capabilities = CapabilitiesForThisVersion() - } - - c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins)) - - for _, bi := range c.capabilities.Builtins { - c.builtins[bi.Name] = bi - } - - for name, bi := range c.customBuiltins { - c.builtins[name] = bi - } - - c.TypeEnv = checker.Env(c.builtins) - } - - _, errs := checker.CheckTypes(c.TypeEnv, elems, as) - return errs + return v1.NewCompiler().WithDefaultRegoVersion(DefaultRegoVersion) } // ModuleLoader defines the interface that callers can implement to enable lazy // loading of modules during compilation. -type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error) - -// WithModuleLoader sets f as the ModuleLoader on the compiler. -// -// The compiler will invoke the ModuleLoader after resolving all references in -// the current set of input modules. The ModuleLoader can return a new -// collection of parsed modules that are to be included in the compilation -// process. This process will repeat until the ModuleLoader returns an empty -// collection or an error. If an error is returned, compilation will stop -// immediately. -func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler { - c.moduleLoader = f - return c -} - -func (c *Compiler) counterAdd(name string, n uint64) { - if c.metrics == nil { - return - } - c.metrics.Counter(name).Add(n) -} - -func (c *Compiler) buildRuleIndices() { - - c.RuleTree.DepthFirst(func(node *TreeNode) bool { - if len(node.Values) == 0 { - return false - } - rules := extractRules(node.Values) - hasNonGroundRef := false - for _, r := range rules { - hasNonGroundRef = !r.Head.Ref().IsGround() - } - if hasNonGroundRef { - // Collect children to ensure that all rules within the extent of a rule with a general ref - // are found on the same index. E.g. the following rules should be indexed under data.a.b.c: - // - // package a - // b.c[x].e := 1 { x := input.x } - // b.c.d := 2 - // b.c.d2.e[x] := 3 { x := input.x } - for _, child := range node.Children { - child.DepthFirst(func(c *TreeNode) bool { - rules = append(rules, extractRules(c.Values)...) - return false - }) - } - } - - index := newBaseDocEqIndex(func(ref Ref) bool { - return isVirtual(c.RuleTree, ref.GroundPrefix()) - }) - if index.Build(rules) { - c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index) - } - return hasNonGroundRef // currently, we don't allow those branches to go deeper - }) - -} - -func (c *Compiler) buildComprehensionIndices() { - for _, name := range c.sorted { - WalkRules(c.Modules[name], func(r *Rule) bool { - candidates := r.Head.Args.Vars() - candidates.Update(ReservedVars) - n := buildComprehensionIndices(c.debug, c.GetArity, candidates, c.RewrittenVars, r.Body, c.comprehensionIndices) - c.counterAdd(compileStageComprehensionIndexBuild, n) - return false - }) - } -} +type ModuleLoader = v1.ModuleLoader -// buildRequiredCapabilities updates the required capabilities on the compiler -// to include any keyword and feature dependencies present in the modules. The -// built-in function dependencies will have already been added by the type -// checker. -func (c *Compiler) buildRequiredCapabilities() { - - features := map[string]struct{}{} - - // extract required keywords from modules - keywords := map[string]struct{}{} - futureKeywordsPrefix := Ref{FutureRootDocument, StringTerm("keywords")} - for _, name := range c.sorted { - for _, imp := range c.imports[name] { - path := imp.Path.Value.(Ref) - switch { - case path.Equal(RegoV1CompatibleRef): - features[FeatureRegoV1Import] = struct{}{} - case path.HasPrefix(futureKeywordsPrefix): - if len(path) == 2 { - for kw := range futureKeywords { - keywords[kw] = struct{}{} - } - } else { - keywords[string(path[2].Value.(String))] = struct{}{} - } - } - } - } - - c.Required.FutureKeywords = stringMapToSortedSlice(keywords) - - // extract required features from modules - - for _, name := range c.sorted { - for _, rule := range c.Modules[name].Rules { - refLen := len(rule.Head.Reference) - if refLen >= 3 { - if refLen > len(rule.Head.Reference.ConstantPrefix()) { - features[FeatureRefHeads] = struct{}{} - } else { - features[FeatureRefHeadStringPrefixes] = struct{}{} - } - } - } - } - - c.Required.Features = stringMapToSortedSlice(features) - - for i, bi := range c.Required.Builtins { - c.Required.Builtins[i] = bi.Minimal() - } -} - -func stringMapToSortedSlice(xs map[string]struct{}) []string { - if len(xs) == 0 { - return nil - } - s := make([]string, 0, len(xs)) - for k := range xs { - s = append(s, k) - } - sort.Strings(s) - return s -} - -// checkRecursion ensures that there are no recursive definitions, i.e., there are -// no cycles in the Graph. -func (c *Compiler) checkRecursion() { - eq := func(a, b util.T) bool { - return a.(*Rule) == b.(*Rule) - } - - c.RuleTree.DepthFirst(func(node *TreeNode) bool { - for _, rule := range node.Values { - for node := rule.(*Rule); node != nil; node = node.Else { - c.checkSelfPath(node.Loc(), eq, node, node) - } - } - return false - }) -} - -func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) { - tr := NewGraphTraversal(c.Graph) - if p := util.DFSPath(tr, eq, a, b); len(p) > 0 { - n := make([]string, 0, len(p)) - for _, x := range p { - n = append(n, astNodeToString(x)) - } - c.err(NewError(RecursionErr, loc, "rule %v is recursive: %v", astNodeToString(a), strings.Join(n, " -> "))) - } -} - -func astNodeToString(x interface{}) string { - return x.(*Rule).Ref().String() -} - -// checkRuleConflicts ensures that rules definitions are not in conflict. -func (c *Compiler) checkRuleConflicts() { - rw := rewriteVarsInRef(c.RewrittenVars) - - c.RuleTree.DepthFirst(func(node *TreeNode) bool { - if len(node.Values) == 0 { - return false // go deeper - } - - kinds := make(map[RuleKind]struct{}, len(node.Values)) - defaultRules := 0 - completeRules := 0 - partialRules := 0 - arities := make(map[int]struct{}, len(node.Values)) - name := "" - var conflicts []Ref - - for _, rule := range node.Values { - r := rule.(*Rule) - ref := r.Ref() - name = rw(ref.Copy()).String() // varRewriter operates in-place - kinds[r.Head.RuleKind()] = struct{}{} - arities[len(r.Head.Args)] = struct{}{} - if r.Default { - defaultRules++ - } - - // Single-value rules may not have any other rules in their extent. - // Rules with vars in their ref are allowed to have rules inside their extent. - // Only the ground portion (terms before the first var term) of a rule's ref is considered when determining - // whether it's inside the extent of another (c.RuleTree is organized this way already). - // These pairs are invalid: - // - // data.p.q.r { true } # data.p.q is { "r": true } - // data.p.q.r.s { true } - // - // data.p.q.r { true } - // data.p.q.r[s].t { s = input.key } - // - // But this is allowed: - // - // data.p.q.r { true } - // data.p.q[r].s.t { r = input.key } - // - // data.p[r] := x { r = input.key; x = input.bar } - // data.p.q[r] := x { r = input.key; x = input.bar } - // - // data.p.q[r] { r := input.r } - // data.p.q.r.s { true } - // - // data.p.q[r] = 1 { r := "r" } - // data.p.q.s = 2 - // - // data.p[q][r] { q := input.q; r := input.r } - // data.p.q.r { true } - // - // data.p.q[r] { r := input.r } - // data.p[q].r { q := input.q } - // - // data.p.q[r][s] { r := input.r; s := input.s } - // data.p[q].r.s { q := input.q } - - if r.Ref().IsGround() && len(node.Children) > 0 { - conflicts = node.flattenChildren() - } - - if r.Head.RuleKind() == SingleValue && r.Head.Ref().IsGround() { - completeRules++ - } else { - partialRules++ - } - } - - switch { - case conflicts != nil: - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule %v conflicts with %v", name, conflicts)) - - case len(kinds) > 1 || len(arities) > 1 || (completeRules >= 1 && partialRules >= 1): - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name)) - - case defaultRules > 1: - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules %s found", name)) - } - - return false - }) +// SafetyCheckVisitorParams defines the AST visitor parameters to use for collecting +// variables during the safety check. This has to be exported because it's relied on +// by the copy propagation implementation in topdown. +var SafetyCheckVisitorParams = v1.SafetyCheckVisitorParams - if c.pathExists != nil { - for _, err := range CheckPathConflicts(c, c.pathExists) { - c.err(err) - } - } +// ComprehensionIndex specifies how the comprehension term can be indexed. The keys +// tell the evaluator what variables to use for indexing. In the future, the index +// could be expanded with more information that would allow the evaluator to index +// a larger fragment of comprehensions (e.g., by closing over variables in the outer +// query.) +type ComprehensionIndex = v1.ComprehensionIndex - // NOTE(sr): depthfirst might better use sorted for stable errs? - c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool { - for _, mod := range node.Modules { - for _, rule := range mod.Rules { - ref := rule.Head.Ref().GroundPrefix() - // Rules with a dynamic portion in their ref are exempted, as a conflict within the dynamic portion - // can only be detected at eval-time. - if len(ref) < len(rule.Head.Ref()) { - continue - } +// ModuleTreeNode represents a node in the module tree. The module +// tree is keyed by the package path. +type ModuleTreeNode = v1.ModuleTreeNode - childNode, tail := node.find(ref) - if childNode != nil && len(tail) == 0 { - for _, childMod := range childNode.Modules { - // Avoid recursively checking a module for equality unless we know it's a possible self-match. - if childMod.Equal(mod) { - continue // don't self-conflict - } - msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc()) - c.err(NewError(TypeErr, mod.Package.Loc(), msg)) - } - } - } - } - return false - }) -} +// TreeNode represents a node in the rule tree. The rule tree is keyed by +// rule path. +type TreeNode = v1.TreeNode -func (c *Compiler) checkUndefinedFuncs() { - for _, name := range c.sorted { - m := c.Modules[name] - for _, err := range checkUndefinedFuncs(c.TypeEnv, m, c.GetArity, c.RewrittenVars) { - c.err(err) - } - } +// NewRuleTree returns a new TreeNode that represents the root +// of the rule tree populated with the given rules. +func NewRuleTree(mtree *ModuleTreeNode) *TreeNode { + return v1.NewRuleTree(mtree) } -func checkUndefinedFuncs(env *TypeEnv, x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors { - - var errs Errors - - WalkExprs(x, func(expr *Expr) bool { - if !expr.IsCall() { - return false - } - ref := expr.Operator() - if arity := arity(ref); arity >= 0 { - operands := len(expr.Operands()) - if expr.Generated { // an output var was added - if !expr.IsEquality() && operands != arity+1 { - ref = rewriteVarsInRef(rwVars)(ref) - errs = append(errs, arityMismatchError(env, ref, expr, arity, operands-1)) - return true - } - } else { // either output var or not - if operands != arity && operands != arity+1 { - ref = rewriteVarsInRef(rwVars)(ref) - errs = append(errs, arityMismatchError(env, ref, expr, arity, operands)) - return true - } - } - return false - } - ref = rewriteVarsInRef(rwVars)(ref) - errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref)) - return true - }) - - return errs -} +// Graph represents the graph of dependencies between rules. +type Graph = v1.Graph -func arityMismatchError(env *TypeEnv, f Ref, expr *Expr, exp, act int) *Error { - if want, ok := env.Get(f).(*types.Function); ok { // generate richer error for built-in functions - have := make([]types.Type, len(expr.Operands())) - for i, op := range expr.Operands() { - have[i] = env.Get(op) - } - return newArgError(expr.Loc(), f, "arity mismatch", have, want.NamedFuncArgs()) - } - if act != 1 { - return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d arguments", f, exp, act) - } - return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d argument", f, exp, act) +// NewGraph returns a new Graph based on modules. The list function must return +// the rules referred to directly by the ref. +func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph { + return v1.NewGraph(modules, list) } -// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target -// positions of built-in expressions will be bound when evaluating the rule from left -// to right, re-ordering as necessary. -func (c *Compiler) checkSafetyRuleBodies() { - for _, name := range c.sorted { - m := c.Modules[name] - WalkRules(m, func(r *Rule) bool { - safe := ReservedVars.Copy() - safe.Update(r.Head.Args.Vars()) - r.Body = c.checkBodySafety(safe, r.Body) - return false - }) - } -} +// GraphTraversal is a Traversal that understands the dependency graph +type GraphTraversal = v1.GraphTraversal -func (c *Compiler) checkBodySafety(safe VarSet, b Body) Body { - reordered, unsafe := reorderBodyForSafety(c.builtins, c.GetArity, safe, b) - if errs := safetyErrorSlice(unsafe, c.RewrittenVars); len(errs) > 0 { - for _, err := range errs { - c.err(err) - } - return b - } - return reordered +// NewGraphTraversal returns a Traversal for the dependency graph +func NewGraphTraversal(graph *Graph) *GraphTraversal { + return v1.NewGraphTraversal(graph) } -// SafetyCheckVisitorParams defines the AST visitor parameters to use for collecting -// variables during the safety check. This has to be exported because it's relied on -// by the copy propagation implementation in topdown. -var SafetyCheckVisitorParams = VarVisitorParams{ - SkipRefCallHead: true, - SkipClosures: true, +// OutputVarsFromBody returns all variables which are the "output" for +// the given body. For safety checks this means that they would be +// made safe by the body. +func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet { + return v1.OutputVarsFromBody(c, body, safe) } -// checkSafetyRuleHeads ensures that variables appearing in the head of a -// rule also appear in the body. -func (c *Compiler) checkSafetyRuleHeads() { - - for _, name := range c.sorted { - m := c.Modules[name] - WalkRules(m, func(r *Rule) bool { - safe := r.Body.Vars(SafetyCheckVisitorParams) - safe.Update(r.Head.Args.Vars()) - unsafe := r.Head.Vars().Diff(safe) - for v := range unsafe { - if w, ok := c.RewrittenVars[v]; ok { - v = w - } - if !v.IsGenerated() { - c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v)) - } - } - return false - }) - } -} - -func compileSchema(goSchema interface{}, allowNet []string) (*gojsonschema.Schema, error) { - gojsonschema.SetAllowNet(allowNet) - - var refLoader gojsonschema.JSONLoader - sl := gojsonschema.NewSchemaLoader() - - if goSchema != nil { - refLoader = gojsonschema.NewGoLoader(goSchema) - } else { - return nil, fmt.Errorf("no schema as input to compile") - } - schemasCompiled, err := sl.Compile(refLoader) - if err != nil { - return nil, fmt.Errorf("unable to compile the schema: %w", err) - } - return schemasCompiled, nil -} - -func mergeSchemas(schemas ...*gojsonschema.SubSchema) (*gojsonschema.SubSchema, error) { - if len(schemas) == 0 { - return nil, nil - } - var result = schemas[0] - - for i := range schemas { - if len(schemas[i].PropertiesChildren) > 0 { - if !schemas[i].Types.Contains("object") { - if err := schemas[i].Types.Add("object"); err != nil { - return nil, fmt.Errorf("unable to set the type in schemas") - } - } - } else if len(schemas[i].ItemsChildren) > 0 { - if !schemas[i].Types.Contains("array") { - if err := schemas[i].Types.Add("array"); err != nil { - return nil, fmt.Errorf("unable to set the type in schemas") - } - } - } - } - - for i := 1; i < len(schemas); i++ { - if result.Types.String() != schemas[i].Types.String() { - return nil, fmt.Errorf("unable to merge these schemas: type mismatch: %v and %v", result.Types.String(), schemas[i].Types.String()) - } else if result.Types.Contains("object") && len(result.PropertiesChildren) > 0 && schemas[i].Types.Contains("object") && len(schemas[i].PropertiesChildren) > 0 { - result.PropertiesChildren = append(result.PropertiesChildren, schemas[i].PropertiesChildren...) - } else if result.Types.Contains("array") && len(result.ItemsChildren) > 0 && schemas[i].Types.Contains("array") && len(schemas[i].ItemsChildren) > 0 { - for j := 0; j < len(schemas[i].ItemsChildren); j++ { - if len(result.ItemsChildren)-1 < j && !(len(schemas[i].ItemsChildren)-1 < j) { - result.ItemsChildren = append(result.ItemsChildren, schemas[i].ItemsChildren[j]) - } - if result.ItemsChildren[j].Types.String() != schemas[i].ItemsChildren[j].Types.String() { - return nil, fmt.Errorf("unable to merge these schemas") - } - } - } - } - return result, nil -} - -type schemaParser struct { - definitionCache map[string]*cachedDef -} - -type cachedDef struct { - properties []*types.StaticProperty -} - -func newSchemaParser() *schemaParser { - return &schemaParser{ - definitionCache: map[string]*cachedDef{}, - } -} - -func (parser *schemaParser) parseSchema(schema interface{}) (types.Type, error) { - return parser.parseSchemaWithPropertyKey(schema, "") -} - -func (parser *schemaParser) parseSchemaWithPropertyKey(schema interface{}, propertyKey string) (types.Type, error) { - subSchema, ok := schema.(*gojsonschema.SubSchema) - if !ok { - return nil, fmt.Errorf("unexpected schema type %v", subSchema) - } - - // Handle referenced schemas, returns directly when a $ref is found - if subSchema.RefSchema != nil { - if existing, ok := parser.definitionCache[subSchema.Ref.String()]; ok { - return types.NewObject(existing.properties, nil), nil - } - return parser.parseSchemaWithPropertyKey(subSchema.RefSchema, subSchema.Ref.String()) - } - - // Handle anyOf - if subSchema.AnyOf != nil { - var orType types.Type - - // If there is a core schema, find its type first - if subSchema.Types.IsTyped() { - copySchema := *subSchema - copySchemaRef := ©Schema - copySchemaRef.AnyOf = nil - coreType, err := parser.parseSchema(copySchemaRef) - if err != nil { - return nil, fmt.Errorf("unexpected schema type %v: %w", subSchema, err) - } - - // Only add Object type with static props to orType - if objType, ok := coreType.(*types.Object); ok { - if objType.StaticProperties() != nil && objType.DynamicProperties() == nil { - orType = types.Or(orType, coreType) - } - } - } - - // Iterate through every property of AnyOf and add it to orType - for _, pSchema := range subSchema.AnyOf { - newtype, err := parser.parseSchema(pSchema) - if err != nil { - return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err) - } - orType = types.Or(newtype, orType) - } - - return orType, nil - } - - if subSchema.AllOf != nil { - subSchemaArray := subSchema.AllOf - allOfResult, err := mergeSchemas(subSchemaArray...) - if err != nil { - return nil, err - } - - if subSchema.Types.IsTyped() { - if (subSchema.Types.Contains("object") && allOfResult.Types.Contains("object")) || (subSchema.Types.Contains("array") && allOfResult.Types.Contains("array")) { - objectOrArrayResult, err := mergeSchemas(allOfResult, subSchema) - if err != nil { - return nil, err - } - return parser.parseSchema(objectOrArrayResult) - } else if subSchema.Types.String() != allOfResult.Types.String() { - return nil, fmt.Errorf("unable to merge these schemas") - } - } - return parser.parseSchema(allOfResult) - } - - if subSchema.Types.IsTyped() { - if subSchema.Types.Contains("boolean") { - return types.B, nil - - } else if subSchema.Types.Contains("string") { - return types.S, nil - - } else if subSchema.Types.Contains("integer") || subSchema.Types.Contains("number") { - return types.N, nil - - } else if subSchema.Types.Contains("object") { - if len(subSchema.PropertiesChildren) > 0 { - def := &cachedDef{ - properties: make([]*types.StaticProperty, 0, len(subSchema.PropertiesChildren)), - } - for _, pSchema := range subSchema.PropertiesChildren { - def.properties = append(def.properties, types.NewStaticProperty(pSchema.Property, nil)) - } - if propertyKey != "" { - parser.definitionCache[propertyKey] = def - } - for _, pSchema := range subSchema.PropertiesChildren { - newtype, err := parser.parseSchema(pSchema) - if err != nil { - return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err) - } - for i, prop := range def.properties { - if prop.Key == pSchema.Property { - def.properties[i].Value = newtype - break - } - } - } - return types.NewObject(def.properties, nil), nil - } - return types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), nil - - } else if subSchema.Types.Contains("array") { - if len(subSchema.ItemsChildren) > 0 { - if subSchema.ItemsChildrenIsSingleSchema { - iSchema := subSchema.ItemsChildren[0] - newtype, err := parser.parseSchema(iSchema) - if err != nil { - return nil, fmt.Errorf("unexpected schema type %v", iSchema) - } - return types.NewArray(nil, newtype), nil - } - newTypes := make([]types.Type, 0, len(subSchema.ItemsChildren)) - for i := 0; i != len(subSchema.ItemsChildren); i++ { - iSchema := subSchema.ItemsChildren[i] - newtype, err := parser.parseSchema(iSchema) - if err != nil { - return nil, fmt.Errorf("unexpected schema type %v", iSchema) - } - newTypes = append(newTypes, newtype) - } - return types.NewArray(newTypes, nil), nil - } - return types.NewArray(nil, types.A), nil - } - } - - // Assume types if not specified in schema - if len(subSchema.PropertiesChildren) > 0 { - if err := subSchema.Types.Add("object"); err == nil { - return parser.parseSchema(subSchema) - } - } else if len(subSchema.ItemsChildren) > 0 { - if err := subSchema.Types.Add("array"); err == nil { - return parser.parseSchema(subSchema) - } - } - - return types.A, nil -} - -func (c *Compiler) setAnnotationSet() { - // Sorting modules by name for stable error reporting - sorted := make([]*Module, 0, len(c.Modules)) - for _, mName := range c.sorted { - sorted = append(sorted, c.Modules[mName]) - } - - as, errs := BuildAnnotationSet(sorted) - for _, err := range errs { - c.err(err) - } - c.annotationSet = as -} - -// checkTypes runs the type checker on all rules. The type checker builds a -// TypeEnv that is stored on the compiler. -func (c *Compiler) checkTypes() { - // Recursion is caught in earlier step, so this cannot fail. - sorted, _ := c.Graph.Sort() - checker := newTypeChecker(). - WithAllowNet(c.capabilities.AllowNet). - WithSchemaSet(c.schemaSet). - WithInputType(c.inputType). - WithBuiltins(c.builtins). - WithRequiredCapabilities(c.Required). - WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)). - WithAllowUndefinedFunctionCalls(c.allowUndefinedFuncCalls) - var as *AnnotationSet - if c.useTypeCheckAnnotations { - as = c.annotationSet - } - env, errs := checker.CheckTypes(c.TypeEnv, sorted, as) - for _, err := range errs { - c.err(err) - } - c.TypeEnv = env -} - -func (c *Compiler) checkUnsafeBuiltins() { - for _, name := range c.sorted { - errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name]) - for _, err := range errs { - c.err(err) - } - } -} - -func (c *Compiler) checkDeprecatedBuiltins() { - for _, name := range c.sorted { - mod := c.Modules[name] - if c.strict || mod.regoV1Compatible() { - errs := checkDeprecatedBuiltins(c.deprecatedBuiltinsMap, mod) - for _, err := range errs { - c.err(err) - } - } - } -} - -func (c *Compiler) runStage(metricName string, f func()) { - if c.metrics != nil { - c.metrics.Timer(metricName).Start() - defer c.metrics.Timer(metricName).Stop() - } - f() -} - -func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error { - if c.metrics != nil { - c.metrics.Timer(metricName).Start() - defer c.metrics.Timer(metricName).Stop() - } - return s(c) -} - -func (c *Compiler) compile() { - - defer func() { - if r := recover(); r != nil && r != errLimitReached { - panic(r) - } - }() - - for _, s := range c.stages { - if c.evalMode == EvalModeIR { - switch s.name { - case "BuildRuleIndices", "BuildComprehensionIndices": - continue // skip these stages - } - } - - if c.allowUndefinedFuncCalls && (s.name == "CheckUndefinedFuncs" || s.name == "CheckSafetyRuleBodies") { - continue - } - - c.runStage(s.metricName, s.f) - if c.Failed() { - return - } - for _, a := range c.after[s.name] { - if err := c.runStageAfter(a.MetricName, a.Stage); err != nil { - c.err(err) - return - } - } - } -} - -func (c *Compiler) init() { - - if c.initialized { - return - } - - if c.capabilities == nil { - c.capabilities = CapabilitiesForThisVersion() - } - - c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins)) - - for _, bi := range c.capabilities.Builtins { - c.builtins[bi.Name] = bi - if bi.IsDeprecated() { - c.deprecatedBuiltinsMap[bi.Name] = struct{}{} - } - } - - for name, bi := range c.customBuiltins { - c.builtins[name] = bi - } - - // Load the global input schema if one was provided. - if c.schemaSet != nil { - if schema := c.schemaSet.Get(SchemaRootRef); schema != nil { - tpe, err := loadSchema(schema, c.capabilities.AllowNet) - if err != nil { - c.err(NewError(TypeErr, nil, err.Error())) - } else { - c.inputType = tpe - } - } - } - - c.TypeEnv = newTypeChecker(). - WithSchemaSet(c.schemaSet). - WithInputType(c.inputType). - Env(c.builtins) - - c.initialized = true -} - -func (c *Compiler) err(err *Error) { - if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs { - c.Errors = append(c.Errors, errLimitReached) - panic(errLimitReached) - } - c.Errors = append(c.Errors, err) -} - -func (c *Compiler) getExports() *util.HashMap { - - rules := util.NewHashMap(func(a, b util.T) bool { - return a.(Ref).Equal(b.(Ref)) - }, func(v util.T) int { - return v.(Ref).Hash() - }) - - for _, name := range c.sorted { - mod := c.Modules[name] - - for _, rule := range mod.Rules { - hashMapAdd(rules, mod.Package.Path, rule.Head.Ref().GroundPrefix()) - } - } - - return rules -} - -func hashMapAdd(rules *util.HashMap, pkg, rule Ref) { - prev, ok := rules.Get(pkg) - if !ok { - rules.Put(pkg, []Ref{rule}) - return - } - for _, p := range prev.([]Ref) { - if p.Equal(rule) { - return - } - } - rules.Put(pkg, append(prev.([]Ref), rule)) -} - -func (c *Compiler) GetAnnotationSet() *AnnotationSet { - return c.annotationSet -} - -func (c *Compiler) checkDuplicateImports() { - modules := make([]*Module, 0, len(c.Modules)) - - for _, name := range c.sorted { - mod := c.Modules[name] - if c.strict || mod.regoV1Compatible() { - modules = append(modules, mod) - } - } - - errs := checkDuplicateImports(modules) - for _, err := range errs { - c.err(err) - } -} - -func (c *Compiler) checkKeywordOverrides() { - for _, name := range c.sorted { - mod := c.Modules[name] - if c.strict || mod.regoV1Compatible() { - errs := checkRootDocumentOverrides(mod) - for _, err := range errs { - c.err(err) - } - } - } -} - -// resolveAllRefs resolves references in expressions to their fully qualified values. -// -// For instance, given the following module: -// -// package a.b -// import data.foo.bar -// p[x] { bar[_] = x } -// -// The reference "bar[_]" would be resolved to "data.foo.bar[_]". -// -// Ref rules are resolved, too: -// -// package a.b -// q { c.d.e == 1 } -// c.d[e] := 1 if e := "e" -// -// The reference "c.d.e" would be resolved to "data.a.b.c.d.e". -func (c *Compiler) resolveAllRefs() { - - rules := c.getExports() - - for _, name := range c.sorted { - mod := c.Modules[name] - - var ruleExports []Ref - if x, ok := rules.Get(mod.Package.Path); ok { - ruleExports = x.([]Ref) - } - - globals := getGlobals(mod.Package, ruleExports, mod.Imports) - - WalkRules(mod, func(rule *Rule) bool { - err := resolveRefsInRule(globals, rule) - if err != nil { - c.err(NewError(CompileErr, rule.Location, err.Error())) - } - return false - }) - - if c.strict { // check for unused imports - for _, imp := range mod.Imports { - path := imp.Path.Value.(Ref) - if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) { - continue // ignore future and rego imports - } - - for v, u := range globals { - if v.Equal(imp.Name()) && !u.used { - c.err(NewError(CompileErr, imp.Location, "%s unused", imp.String())) - } - } - } - } - } - - if c.moduleLoader != nil { - - parsed, err := c.moduleLoader(c.Modules) - if err != nil { - c.err(NewError(CompileErr, nil, err.Error())) - return - } - - if len(parsed) == 0 { - return - } - - for id, module := range parsed { - c.Modules[id] = module.Copy() - c.sorted = append(c.sorted, id) - if c.parsedModules != nil { - c.parsedModules[id] = module - } - } - - sort.Strings(c.sorted) - c.resolveAllRefs() - } -} - -func (c *Compiler) removeImports() { - c.imports = make(map[string][]*Import, len(c.Modules)) - for name := range c.Modules { - c.imports[name] = c.Modules[name].Imports - c.Modules[name].Imports = nil - } -} - -func (c *Compiler) initLocalVarGen() { - c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules) -} - -func (c *Compiler) rewriteComprehensionTerms() { - f := newEqualityFactory(c.localvargen) - for _, name := range c.sorted { - mod := c.Modules[name] - _, _ = rewriteComprehensionTerms(f, mod) // ignore error - } -} - -func (c *Compiler) rewriteExprTerms() { - for _, name := range c.sorted { - mod := c.Modules[name] - WalkRules(mod, func(rule *Rule) bool { - rewriteExprTermsInHead(c.localvargen, rule) - rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body) - return false - }) - } -} - -func (c *Compiler) rewriteRuleHeadRefs() { - f := newEqualityFactory(c.localvargen) - for _, name := range c.sorted { - WalkRules(c.Modules[name], func(rule *Rule) bool { - - ref := rule.Head.Ref() - // NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but - // it's possible to construct Module{} instances from Golang code, so we need - // to accommodate for that, too. - if len(rule.Head.Reference) == 0 { - rule.Head.Reference = ref - } - - cannotSpeakStringPrefixRefs := true - cannotSpeakGeneralRefs := true - for _, f := range c.capabilities.Features { - switch f { - case FeatureRefHeadStringPrefixes: - cannotSpeakStringPrefixRefs = false - case FeatureRefHeads: - cannotSpeakGeneralRefs = false - } - } - - if cannotSpeakStringPrefixRefs && cannotSpeakGeneralRefs && rule.Head.Name == "" { - c.err(NewError(CompileErr, rule.Loc(), "rule heads with refs are not supported: %v", rule.Head.Reference)) - return true - } - - for i := 1; i < len(ref); i++ { - if cannotSpeakGeneralRefs && (rule.Head.RuleKind() == MultiValue || i != len(ref)-1) { // last - if _, ok := ref[i].Value.(String); !ok { - c.err(NewError(TypeErr, rule.Loc(), "rule heads with general refs (containing variables) are not supported: %v", rule.Head.Reference)) - continue - } - } - - // Rewrite so that any non-scalar elements in the rule's ref are vars: - // p.q.r[y.z] { ... } => p.q.r[__local0__] { __local0__ = y.z } - // p.q[a.b][c.d] { ... } => p.q[__local0__] { __local0__ = a.b; __local1__ = c.d } - // because that's what the RuleTree knows how to deal with. - if _, ok := ref[i].Value.(Var); !ok && !IsScalar(ref[i].Value) { - expr := f.Generate(ref[i]) - if i == len(ref)-1 && rule.Head.Key.Equal(ref[i]) { - rule.Head.Key = expr.Operand(0) - } - rule.Head.Reference[i] = expr.Operand(0) - rule.Body.Append(expr) - } - } - - return true - }) - } -} - -func (c *Compiler) checkVoidCalls() { - for _, name := range c.sorted { - mod := c.Modules[name] - for _, err := range checkVoidCalls(c.TypeEnv, mod) { - c.err(err) - } - } -} - -func (c *Compiler) rewritePrintCalls() { - var modified bool - if !c.enablePrintStatements { - for _, name := range c.sorted { - if erasePrintCalls(c.Modules[name]) { - modified = true - } - } - } else { - for _, name := range c.sorted { - mod := c.Modules[name] - WalkRules(mod, func(r *Rule) bool { - safe := r.Head.Args.Vars() - safe.Update(ReservedVars) - vis := func(b Body) bool { - modrec, errs := rewritePrintCalls(c.localvargen, c.GetArity, safe, b) - if modrec { - modified = true - } - for _, err := range errs { - c.err(err) - } - return false - } - WalkBodies(r.Head, vis) - WalkBodies(r.Body, vis) - return false - }) - } - } - if modified { - c.Required.addBuiltinSorted(Print) - } -} - -// checkVoidCalls returns errors for any expressions that treat void function -// calls as values. The only void functions in Rego are specific built-ins like -// print(). -func checkVoidCalls(env *TypeEnv, x interface{}) Errors { - var errs Errors - WalkTerms(x, func(x *Term) bool { - if call, ok := x.Value.(Call); ok { - if tpe, ok := env.Get(call[0]).(*types.Function); ok && tpe.Result() == nil { - errs = append(errs, NewError(TypeErr, x.Loc(), "%v used as value", call)) - } - } - return false - }) - return errs -} - -// rewritePrintCalls will rewrite the body so that print operands are captured -// in local variables and their evaluation occurs within a comprehension. -// Wrapping the terms inside of a comprehension ensures that undefined values do -// not short-circuit evaluation. -// -// For example, given the following print statement: -// -// print("the value of x is:", input.x) -// -// The expression would be rewritten to: -// -// print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x}) -func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals VarSet, body Body) (bool, Errors) { - - var errs Errors - var modified bool - - // Visit comprehension bodies recursively to ensure print statements inside - // those bodies only close over variables that are safe. - for i := range body { - if ContainsClosures(body[i]) { - safe := outputVarsForBody(body[:i], getArity, globals) - safe.Update(globals) - WalkClosures(body[i], func(x interface{}) bool { - var modrec bool - var errsrec Errors - switch x := x.(type) { - case *SetComprehension: - modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body) - case *ArrayComprehension: - modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body) - case *ObjectComprehension: - modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body) - case *Every: - safe.Update(x.KeyValueVars()) - modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body) - } - if modrec { - modified = true - } - errs = append(errs, errsrec...) - return true - }) - if len(errs) > 0 { - return false, errs - } - } - } - - for i := range body { - - if !isPrintCall(body[i]) { - continue - } - - modified = true - - var errs Errors - safe := outputVarsForBody(body[:i], getArity, globals) - safe.Update(globals) - args := body[i].Operands() - - for j := range args { - vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams) - vis.Walk(args[j]) - unsafe := vis.Vars().Diff(safe) - for _, v := range unsafe.Sorted() { - errs = append(errs, NewError(CompileErr, args[j].Loc(), "var %v is undeclared", v)) - } - } - - if len(errs) > 0 { - return false, errs - } - - arr := NewArray() - - for j := range args { - x := NewTerm(gen.Generate()).SetLocation(args[j].Loc()) - capture := Equality.Expr(x, args[j]).SetLocation(args[j].Loc()) - arr = arr.Append(SetComprehensionTerm(x, NewBody(capture)).SetLocation(args[j].Loc())) - } - - body.Set(NewExpr([]*Term{ - NewTerm(InternalPrint.Ref()).SetLocation(body[i].Loc()), - NewTerm(arr).SetLocation(body[i].Loc()), - }).SetLocation(body[i].Loc()), i) - } - - return modified, nil -} - -func erasePrintCalls(node interface{}) bool { - var modified bool - NewGenericVisitor(func(x interface{}) bool { - var modrec bool - switch x := x.(type) { - case *Rule: - modrec, x.Body = erasePrintCallsInBody(x.Body) - case *ArrayComprehension: - modrec, x.Body = erasePrintCallsInBody(x.Body) - case *SetComprehension: - modrec, x.Body = erasePrintCallsInBody(x.Body) - case *ObjectComprehension: - modrec, x.Body = erasePrintCallsInBody(x.Body) - case *Every: - modrec, x.Body = erasePrintCallsInBody(x.Body) - } - if modrec { - modified = true - } - return false - }).Walk(node) - return modified -} - -func erasePrintCallsInBody(x Body) (bool, Body) { - - if !containsPrintCall(x) { - return false, x - } - - var cpy Body - - for i := range x { - - // Recursively visit any comprehensions contained in this expression. - erasePrintCalls(x[i]) - - if !isPrintCall(x[i]) { - cpy.Append(x[i]) - } - } - - if len(cpy) == 0 { - term := BooleanTerm(true).SetLocation(x.Loc()) - expr := NewExpr(term).SetLocation(x.Loc()) - cpy.Append(expr) - } - - return true, cpy -} - -func containsPrintCall(x interface{}) bool { - var found bool - WalkExprs(x, func(expr *Expr) bool { - if !found { - if isPrintCall(expr) { - found = true - } - } - return found - }) - return found -} - -func isPrintCall(x *Expr) bool { - return x.IsCall() && x.Operator().Equal(Print.Ref()) -} - -// rewriteRefsInHead will rewrite rules so that the head does not contain any -// terms that require evaluation (e.g., refs or comprehensions). If the key or -// value contains one or more of these terms, the key or value will be moved -// into the body and assigned to a new variable. The new variable will replace -// the key or value in the head. -// -// For instance, given the following rule: -// -// p[{"foo": data.foo[i]}] { i < 100 } -// -// The rule would be re-written as: -// -// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} } -func (c *Compiler) rewriteRefsInHead() { - f := newEqualityFactory(c.localvargen) - for _, name := range c.sorted { - mod := c.Modules[name] - WalkRules(mod, func(rule *Rule) bool { - if requiresEval(rule.Head.Key) { - expr := f.Generate(rule.Head.Key) - rule.Head.Key = expr.Operand(0) - rule.Body.Append(expr) - } - if requiresEval(rule.Head.Value) { - expr := f.Generate(rule.Head.Value) - rule.Head.Value = expr.Operand(0) - rule.Body.Append(expr) - } - for i := 0; i < len(rule.Head.Args); i++ { - if requiresEval(rule.Head.Args[i]) { - expr := f.Generate(rule.Head.Args[i]) - rule.Head.Args[i] = expr.Operand(0) - rule.Body.Append(expr) - } - } - return false - }) - } -} - -func (c *Compiler) rewriteEquals() { - modified := false - for _, name := range c.sorted { - mod := c.Modules[name] - modified = rewriteEquals(mod) || modified - } - if modified { - c.Required.addBuiltinSorted(Equal) - } -} - -func (c *Compiler) rewriteDynamicTerms() { - f := newEqualityFactory(c.localvargen) - for _, name := range c.sorted { - mod := c.Modules[name] - WalkRules(mod, func(rule *Rule) bool { - rule.Body = rewriteDynamics(f, rule.Body) - return false - }) - } -} - -// rewriteTestRuleEqualities rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise -// not have their values captured through tracing, such as refs and comprehensions not unified/assigned to a local var. -// For example, given the following module: -// -// package test -// -// p.q contains v if { -// some v in numbers.range(1, 3) -// } -// -// p.r := "foo" -// -// test_rule { -// p == { -// "q": {4, 5, 6} -// } -// } -// -// `p` in `test_rule` resolves to `data.test.p`, which won't be an entry in the virtual-cache and must therefore be calculated after-the-fact. -// If `p` isn't captured in a local var, there is no trivial way to retrieve its value for test reporting. -func (c *Compiler) rewriteTestRuleEqualities() { - if !c.rewriteTestRulesForTracing { - return - } - - f := newEqualityFactory(c.localvargen) - for _, name := range c.sorted { - mod := c.Modules[name] - WalkRules(mod, func(rule *Rule) bool { - if strings.HasPrefix(string(rule.Head.Name), "test_") { - rule.Body = rewriteTestEqualities(f, rule.Body) - } - return false - }) - } -} - -func (c *Compiler) parseMetadataBlocks() { - // Only parse annotations if rego.metadata built-ins are called - regoMetadataCalled := false - for _, name := range c.sorted { - mod := c.Modules[name] - WalkExprs(mod, func(expr *Expr) bool { - if isRegoMetadataChainCall(expr) || isRegoMetadataRuleCall(expr) { - regoMetadataCalled = true - } - return regoMetadataCalled - }) - - if regoMetadataCalled { - break - } - } - - if regoMetadataCalled { - // NOTE: Possible optimization: only parse annotations for modules on the path of rego.metadata-calling module - for _, name := range c.sorted { - mod := c.Modules[name] - - if len(mod.Annotations) == 0 { - var errs Errors - mod.Annotations, errs = parseAnnotations(mod.Comments) - errs = append(errs, attachAnnotationsNodes(mod)...) - for _, err := range errs { - c.err(err) - } - - attachRuleAnnotations(mod) - } - } - } -} - -func (c *Compiler) rewriteRegoMetadataCalls() { - eqFactory := newEqualityFactory(c.localvargen) - - _, chainFuncAllowed := c.builtins[RegoMetadataChain.Name] - _, ruleFuncAllowed := c.builtins[RegoMetadataRule.Name] - - for _, name := range c.sorted { - mod := c.Modules[name] - - WalkRules(mod, func(rule *Rule) bool { - var firstChainCall *Expr - var firstRuleCall *Expr - - WalkExprs(rule, func(expr *Expr) bool { - if chainFuncAllowed && firstChainCall == nil && isRegoMetadataChainCall(expr) { - firstChainCall = expr - } else if ruleFuncAllowed && firstRuleCall == nil && isRegoMetadataRuleCall(expr) { - firstRuleCall = expr - } - return firstChainCall != nil && firstRuleCall != nil - }) - - chainCalled := firstChainCall != nil - ruleCalled := firstRuleCall != nil - - if chainCalled || ruleCalled { - body := make(Body, 0, len(rule.Body)+2) - - var metadataChainVar Var - if chainCalled { - // Create and inject metadata chain for rule - - chain, err := createMetadataChain(c.annotationSet.Chain(rule)) - if err != nil { - c.err(err) - return false - } - - chain.Location = firstChainCall.Location - eq := eqFactory.Generate(chain) - metadataChainVar = eq.Operands()[0].Value.(Var) - body.Append(eq) - } - - var metadataRuleVar Var - if ruleCalled { - // Create and inject metadata for rule - - var metadataRuleTerm *Term - - a := getPrimaryRuleAnnotations(c.annotationSet, rule) - if a != nil { - annotObj, err := a.toObject() - if err != nil { - c.err(err) - return false - } - metadataRuleTerm = NewTerm(*annotObj) - } else { - // If rule has no annotations, assign an empty object - metadataRuleTerm = ObjectTerm() - } - - metadataRuleTerm.Location = firstRuleCall.Location - eq := eqFactory.Generate(metadataRuleTerm) - metadataRuleVar = eq.Operands()[0].Value.(Var) - body.Append(eq) - } - - for _, expr := range rule.Body { - body.Append(expr) - } - rule.Body = body - - vis := func(b Body) bool { - for _, err := range rewriteRegoMetadataCalls(&metadataChainVar, &metadataRuleVar, b, &c.RewrittenVars) { - c.err(err) - } - return false - } - WalkBodies(rule.Head, vis) - WalkBodies(rule.Body, vis) - } - - return false - }) - } -} - -func getPrimaryRuleAnnotations(as *AnnotationSet, rule *Rule) *Annotations { - annots := as.GetRuleScope(rule) - - if len(annots) == 0 { - return nil - } - - // Sort by annotation location; chain must start with annotations declared closest to rule, then going outward - sort.SliceStable(annots, func(i, j int) bool { - return annots[i].Location.Compare(annots[j].Location) > 0 - }) - - return annots[0] -} - -func rewriteRegoMetadataCalls(metadataChainVar *Var, metadataRuleVar *Var, body Body, rewrittenVars *map[Var]Var) Errors { - var errs Errors - - WalkClosures(body, func(x interface{}) bool { - switch x := x.(type) { - case *ArrayComprehension: - errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) - case *SetComprehension: - errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) - case *ObjectComprehension: - errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) - case *Every: - errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) - } - return true - }) - - for i := range body { - expr := body[i] - var metadataVar Var - - if metadataChainVar != nil && isRegoMetadataChainCall(expr) { - metadataVar = *metadataChainVar - } else if metadataRuleVar != nil && isRegoMetadataRuleCall(expr) { - metadataVar = *metadataRuleVar - } else { - continue - } - - // NOTE(johanfylling): An alternative strategy would be to walk the body and replace all operands[0] - // usages with *metadataChainVar - operands := expr.Operands() - var newExpr *Expr - if len(operands) > 0 { // There is an output var to rewrite - rewrittenVar := operands[0] - newExpr = Equality.Expr(rewrittenVar, NewTerm(metadataVar)) - } else { // No output var, just rewrite expr to metadataVar - newExpr = NewExpr(NewTerm(metadataVar)) - } - - newExpr.Generated = true - newExpr.Location = expr.Location - body.Set(newExpr, i) - } - - return errs -} - -func isRegoMetadataChainCall(x *Expr) bool { - return x.IsCall() && x.Operator().Equal(RegoMetadataChain.Ref()) -} - -func isRegoMetadataRuleCall(x *Expr) bool { - return x.IsCall() && x.Operator().Equal(RegoMetadataRule.Ref()) -} - -func createMetadataChain(chain []*AnnotationsRef) (*Term, *Error) { - - metaArray := NewArray() - for _, link := range chain { - p := link.Path.toArray(). - Slice(1, -1) // Dropping leading 'data' element of path - obj := NewObject( - Item(StringTerm("path"), NewTerm(p)), - ) - if link.Annotations != nil { - annotObj, err := link.Annotations.toObject() - if err != nil { - return nil, err - } - obj.Insert(StringTerm("annotations"), NewTerm(*annotObj)) - } - metaArray = metaArray.Append(NewTerm(obj)) - } - - return NewTerm(metaArray), nil -} - -func (c *Compiler) rewriteLocalVars() { - - var assignment bool - - for _, name := range c.sorted { - mod := c.Modules[name] - gen := c.localvargen - - WalkRules(mod, func(rule *Rule) bool { - argsStack := newLocalDeclaredVars() - - args := NewVarVisitor() - if c.strict { - args.Walk(rule.Head.Args) - } - unusedArgs := args.Vars() - - c.rewriteLocalArgVars(gen, argsStack, rule) - - // Rewrite local vars in each else-branch of the rule. - // Note: this is done instead of a walk so that we can capture any unused function arguments - // across else-branches. - for rule := rule; rule != nil; rule = rule.Else { - stack, errs := c.rewriteLocalVarsInRule(rule, unusedArgs, argsStack, gen) - if stack.assignment { - assignment = true - } - - for arg := range unusedArgs { - if stack.Count(arg) > 1 { - delete(unusedArgs, arg) - } - } - - for _, err := range errs { - c.err(err) - } - } - - if c.strict { - // Report an error for each unused function argument - for arg := range unusedArgs { - if !arg.IsWildcard() { - c.err(NewError(CompileErr, rule.Head.Location, "unused argument %v. (hint: use _ (wildcard variable) instead)", arg)) - } - } - } - - return true - }) - } - - if assignment { - c.Required.addBuiltinSorted(Assign) - } -} - -func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) { - // Rewrite assignments contained in head of rule. Assignments can - // occur in rule head if they're inside a comprehension. Note, - // assigned vars in comprehensions in the head will be rewritten - // first to preserve scoping rules. For example: - // - // p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 } - // - // This behaviour is consistent scoping inside the body. For example: - // - // p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] } - nestedXform := &rewriteNestedHeadVarLocalTransform{ - gen: gen, - RewrittenVars: c.RewrittenVars, - strict: c.strict, - } - - NewGenericVisitor(nestedXform.Visit).Walk(rule.Head) - - for _, err := range nestedXform.errs { - c.err(err) - } - - // Rewrite assignments in body. - used := NewVarSet() - - for _, t := range rule.Head.Ref()[1:] { - used.Update(t.Vars()) - } - - if rule.Head.Key != nil { - used.Update(rule.Head.Key.Vars()) - } - - if rule.Head.Value != nil { - valueVars := rule.Head.Value.Vars() - used.Update(valueVars) - for arg := range unusedArgs { - if valueVars.Contains(arg) { - delete(unusedArgs, arg) - } - } - } - - stack := argsStack.Copy() - - body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict) - - // For rewritten vars use the collection of all variables that - // were in the stack at some point in time. - for k, v := range stack.rewritten { - c.RewrittenVars[k] = v - } - - rule.Body = body - - // Rewrite vars in head that refer to locally declared vars in the body. - localXform := rewriteHeadVarLocalTransform{declared: declared} - - for i := range rule.Head.Args { - rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i]) - } - - for i := 1; i < len(rule.Head.Ref()); i++ { - rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i]) - } - if rule.Head.Key != nil { - rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key) - } - - if rule.Head.Value != nil { - rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value) - } - return stack, errs -} - -type rewriteNestedHeadVarLocalTransform struct { - gen *localVarGenerator - errs Errors - RewrittenVars map[Var]Var - strict bool -} - -func (xform *rewriteNestedHeadVarLocalTransform) Visit(x interface{}) bool { - - if term, ok := x.(*Term); ok { - - stop := false - stack := newLocalDeclaredVars() - - switch x := term.Value.(type) { - case *object: - cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) { - kcpy := k.Copy() - NewGenericVisitor(xform.Visit).Walk(kcpy) - vcpy := v.Copy() - NewGenericVisitor(xform.Visit).Walk(vcpy) - return kcpy, vcpy, nil - }) - term.Value = cpy - stop = true - case *set: - cpy, _ := x.Map(func(v *Term) (*Term, error) { - vcpy := v.Copy() - NewGenericVisitor(xform.Visit).Walk(vcpy) - return vcpy, nil - }) - term.Value = cpy - stop = true - case *ArrayComprehension: - xform.errs = rewriteDeclaredVarsInArrayComprehension(xform.gen, stack, x, xform.errs, xform.strict) - stop = true - case *SetComprehension: - xform.errs = rewriteDeclaredVarsInSetComprehension(xform.gen, stack, x, xform.errs, xform.strict) - stop = true - case *ObjectComprehension: - xform.errs = rewriteDeclaredVarsInObjectComprehension(xform.gen, stack, x, xform.errs, xform.strict) - stop = true - } - - for k, v := range stack.rewritten { - xform.RewrittenVars[k] = v - } - - return stop - } - - return false -} - -type rewriteHeadVarLocalTransform struct { - declared map[Var]Var -} - -func (xform rewriteHeadVarLocalTransform) Transform(x interface{}) (interface{}, error) { - if v, ok := x.(Var); ok { - if gv, ok := xform.declared[v]; ok { - return gv, nil - } - } - return x, nil -} - -func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) { - - vis := &ruleArgLocalRewriter{ - stack: stack, - gen: gen, - } - - for i := range rule.Head.Args { - Walk(vis, rule.Head.Args[i]) - } - - for i := range vis.errs { - c.err(vis.errs[i]) - } -} - -type ruleArgLocalRewriter struct { - stack *localDeclaredVars - gen *localVarGenerator - errs []*Error -} - -func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor { - - t, ok := x.(*Term) - if !ok { - return vis - } - - switch v := t.Value.(type) { - case Var: - gv, ok := vis.stack.Declared(v) - if ok { - vis.stack.Seen(v) - } else { - gv = vis.gen.Generate() - vis.stack.Insert(v, gv, argVar) - } - t.Value = gv - return nil - case *object: - if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) { - vcpy := v.Copy() - Walk(vis, vcpy) - return k, vcpy, nil - }); err != nil { - vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error())) - } else { - t.Value = cpy - } - return nil - case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set: - // Scalars are no-ops. Comprehensions are handled above. Sets must not - // contain variables. - return nil - case Call: - vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "rule arguments cannot contain calls")) - return nil - default: - // Recurse on refs and arrays. Any embedded - // variables can be rewritten. - return vis - } -} - -func (c *Compiler) rewriteWithModifiers() { - f := newEqualityFactory(c.localvargen) - for _, name := range c.sorted { - mod := c.Modules[name] - t := NewGenericTransformer(func(x interface{}) (interface{}, error) { - body, ok := x.(Body) - if !ok { - return x, nil - } - body, err := rewriteWithModifiersInBody(c, c.unsafeBuiltinsMap, f, body) - if err != nil { - c.err(err) - } - - return body, nil - }) - _, _ = Transform(t, mod) // ignore error - } -} - -func (c *Compiler) setModuleTree() { - c.ModuleTree = NewModuleTree(c.Modules) -} - -func (c *Compiler) setRuleTree() { - c.RuleTree = NewRuleTree(c.ModuleTree) -} - -func (c *Compiler) setGraph() { - list := func(r Ref) []*Rule { - return c.GetRulesDynamicWithOpts(r, RulesOptions{IncludeHiddenModules: true}) - } - c.Graph = NewGraph(c.Modules, list) -} - -type queryCompiler struct { - compiler *Compiler - qctx *QueryContext - typeEnv *TypeEnv - rewritten map[Var]Var - after map[string][]QueryCompilerStageDefinition - unsafeBuiltins map[string]struct{} - comprehensionIndices map[*Term]*ComprehensionIndex - enablePrintStatements bool -} - -func newQueryCompiler(compiler *Compiler) QueryCompiler { - qc := &queryCompiler{ - compiler: compiler, - qctx: nil, - after: map[string][]QueryCompilerStageDefinition{}, - comprehensionIndices: map[*Term]*ComprehensionIndex{}, - } - return qc -} - -func (qc *queryCompiler) WithStrict(strict bool) QueryCompiler { - qc.compiler.WithStrict(strict) - return qc -} - -func (qc *queryCompiler) WithEnablePrintStatements(yes bool) QueryCompiler { - qc.enablePrintStatements = yes - return qc -} - -func (qc *queryCompiler) WithContext(qctx *QueryContext) QueryCompiler { - qc.qctx = qctx - return qc -} - -func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler { - qc.after[after] = append(qc.after[after], stage) - return qc -} - -func (qc *queryCompiler) WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler { - qc.unsafeBuiltins = unsafe - return qc -} - -func (qc *queryCompiler) RewrittenVars() map[Var]Var { - return qc.rewritten -} - -func (qc *queryCompiler) ComprehensionIndex(term *Term) *ComprehensionIndex { - if result, ok := qc.comprehensionIndices[term]; ok { - return result - } else if result, ok := qc.compiler.comprehensionIndices[term]; ok { - return result - } - return nil -} - -func (qc *queryCompiler) runStage(metricName string, qctx *QueryContext, query Body, s func(*QueryContext, Body) (Body, error)) (Body, error) { - if qc.compiler.metrics != nil { - qc.compiler.metrics.Timer(metricName).Start() - defer qc.compiler.metrics.Timer(metricName).Stop() - } - return s(qctx, query) -} - -func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCompilerStage) (Body, error) { - if qc.compiler.metrics != nil { - qc.compiler.metrics.Timer(metricName).Start() - defer qc.compiler.metrics.Timer(metricName).Stop() - } - return s(qc, query) -} - -type queryStage = struct { - name string - metricName string - f func(*QueryContext, Body) (Body, error) -} - -func (qc *queryCompiler) Compile(query Body) (Body, error) { - if len(query) == 0 { - return nil, Errors{NewError(CompileErr, nil, "empty query cannot be compiled")} - } - - query = query.Copy() - - stages := []queryStage{ - {"CheckKeywordOverrides", "query_compile_stage_check_keyword_overrides", qc.checkKeywordOverrides}, - {"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs}, - {"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars}, - {"CheckVoidCalls", "query_compile_stage_check_void_calls", qc.checkVoidCalls}, - {"RewritePrintCalls", "query_compile_stage_rewrite_print_calls", qc.rewritePrintCalls}, - {"RewriteExprTerms", "query_compile_stage_rewrite_expr_terms", qc.rewriteExprTerms}, - {"RewriteComprehensionTerms", "query_compile_stage_rewrite_comprehension_terms", qc.rewriteComprehensionTerms}, - {"RewriteWithValues", "query_compile_stage_rewrite_with_values", qc.rewriteWithModifiers}, - {"CheckUndefinedFuncs", "query_compile_stage_check_undefined_funcs", qc.checkUndefinedFuncs}, - {"CheckSafety", "query_compile_stage_check_safety", qc.checkSafety}, - {"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms}, - {"CheckTypes", "query_compile_stage_check_types", qc.checkTypes}, - {"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins}, - {"CheckDeprecatedBuiltins", "query_compile_stage_check_deprecated_builtins", qc.checkDeprecatedBuiltins}, - } - if qc.compiler.evalMode == EvalModeTopdown { - stages = append(stages, queryStage{"BuildComprehensionIndex", "query_compile_stage_build_comprehension_index", qc.buildComprehensionIndices}) - } - - qctx := qc.qctx.Copy() - - for _, s := range stages { - var err error - query, err = qc.runStage(s.metricName, qctx, query, s.f) - if err != nil { - return nil, qc.applyErrorLimit(err) - } - for _, s := range qc.after[s.name] { - query, err = qc.runStageAfter(s.MetricName, query, s.Stage) - if err != nil { - return nil, qc.applyErrorLimit(err) - } - } - } - - return query, nil -} - -func (qc *queryCompiler) TypeEnv() *TypeEnv { - return qc.typeEnv -} - -func (qc *queryCompiler) applyErrorLimit(err error) error { - var errs Errors - if errors.As(err, &errs) { - if qc.compiler.maxErrs > 0 && len(errs) > qc.compiler.maxErrs { - err = append(errs[:qc.compiler.maxErrs], errLimitReached) - } - } - return err -} - -func (qc *queryCompiler) checkKeywordOverrides(_ *QueryContext, body Body) (Body, error) { - if qc.compiler.strict { - if errs := checkRootDocumentOverrides(body); len(errs) > 0 { - return nil, errs - } - } - return body, nil -} - -func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) { - - var globals map[Var]*usedRef - - if qctx != nil { - pkg := qctx.Package - // Query compiler ought to generate a package if one was not provided and one or more imports were provided. - // The generated package name could even be an empty string to avoid conflicts (it doesn't have to be valid syntactically) - if pkg == nil && len(qctx.Imports) > 0 { - pkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)} - } - if pkg != nil { - var ruleExports []Ref - rules := qc.compiler.getExports() - if exist, ok := rules.Get(pkg.Path); ok { - ruleExports = exist.([]Ref) - } - - globals = getGlobals(qctx.Package, ruleExports, qctx.Imports) - qctx.Imports = nil - } - } - - ignore := &declaredVarStack{declaredVars(body)} - - return resolveRefsInBody(globals, ignore, body), nil -} - -func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) { - gen := newLocalVarGenerator("q", body) - f := newEqualityFactory(gen) - node, err := rewriteComprehensionTerms(f, body) - if err != nil { - return nil, err - } - return node.(Body), nil -} - -func (qc *queryCompiler) rewriteDynamicTerms(_ *QueryContext, body Body) (Body, error) { - gen := newLocalVarGenerator("q", body) - f := newEqualityFactory(gen) - return rewriteDynamics(f, body), nil -} - -func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, error) { - gen := newLocalVarGenerator("q", body) - return rewriteExprTermsInBody(gen, body), nil -} - -func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) { - gen := newLocalVarGenerator("q", body) - stack := newLocalDeclaredVars() - body, _, err := rewriteLocalVars(gen, stack, nil, body, qc.compiler.strict) - if len(err) != 0 { - return nil, err - } - qc.rewritten = make(map[Var]Var, len(stack.rewritten)) - for k, v := range stack.rewritten { - // The vars returned during the rewrite will include all seen vars, - // even if they're not declared with an assignment operation. We don't - // want to include these inside the rewritten set though. - qc.rewritten[k] = v - } - return body, nil -} - -func (qc *queryCompiler) rewritePrintCalls(_ *QueryContext, body Body) (Body, error) { - if !qc.enablePrintStatements { - _, cpy := erasePrintCallsInBody(body) - return cpy, nil - } - gen := newLocalVarGenerator("q", body) - if _, errs := rewritePrintCalls(gen, qc.compiler.GetArity, ReservedVars, body); len(errs) > 0 { - return nil, errs - } - return body, nil -} - -func (qc *queryCompiler) checkVoidCalls(_ *QueryContext, body Body) (Body, error) { - if errs := checkVoidCalls(qc.compiler.TypeEnv, body); len(errs) > 0 { - return nil, errs - } - return body, nil -} - -func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) { - if errs := checkUndefinedFuncs(qc.compiler.TypeEnv, body, qc.compiler.GetArity, qc.rewritten); len(errs) > 0 { - return nil, errs - } - return body, nil -} - -func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) { - safe := ReservedVars.Copy() - reordered, unsafe := reorderBodyForSafety(qc.compiler.builtins, qc.compiler.GetArity, safe, body) - if errs := safetyErrorSlice(unsafe, qc.RewrittenVars()); len(errs) > 0 { - return nil, errs - } - return reordered, nil -} - -func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) { - var errs Errors - checker := newTypeChecker(). - WithSchemaSet(qc.compiler.schemaSet). - WithInputType(qc.compiler.inputType). - WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars)) - qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body) - if len(errs) > 0 { - return nil, errs - } - - return body, nil -} - -func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) { - errs := checkUnsafeBuiltins(qc.unsafeBuiltinsMap(), body) - if len(errs) > 0 { - return nil, errs - } - return body, nil -} - -func (qc *queryCompiler) unsafeBuiltinsMap() map[string]struct{} { - if qc.unsafeBuiltins != nil { - return qc.unsafeBuiltins - } - return qc.compiler.unsafeBuiltinsMap -} - -func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) { - if qc.compiler.strict { - errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body) - if len(errs) > 0 { - return nil, errs - } - } - return body, nil -} - -func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) { - f := newEqualityFactory(newLocalVarGenerator("q", body)) - body, err := rewriteWithModifiersInBody(qc.compiler, qc.unsafeBuiltinsMap(), f, body) - if err != nil { - return nil, Errors{err} - } - return body, nil -} - -func (qc *queryCompiler) buildComprehensionIndices(_ *QueryContext, body Body) (Body, error) { - // NOTE(tsandall): The query compiler does not have a metrics object so we - // cannot record index metrics currently. - _ = buildComprehensionIndices(qc.compiler.debug, qc.compiler.GetArity, ReservedVars, qc.RewrittenVars(), body, qc.comprehensionIndices) - return body, nil -} - -// ComprehensionIndex specifies how the comprehension term can be indexed. The keys -// tell the evaluator what variables to use for indexing. In the future, the index -// could be expanded with more information that would allow the evaluator to index -// a larger fragment of comprehensions (e.g., by closing over variables in the outer -// query.) -type ComprehensionIndex struct { - Term *Term - Keys []*Term -} - -func (ci *ComprehensionIndex) String() string { - if ci == nil { - return "" - } - return fmt.Sprintf("", NewArray(ci.Keys...)) -} - -func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node interface{}, result map[*Term]*ComprehensionIndex) uint64 { - var n uint64 - cpy := candidates.Copy() - WalkBodies(node, func(b Body) bool { - for _, expr := range b { - index := getComprehensionIndex(dbg, arity, cpy, rwVars, expr) - if index != nil { - result[index.Term] = index - n++ - } - // Any variables appearing in the expressions leading up to the comprehension - // are fair-game to be used as index keys. - cpy.Update(expr.Vars(VarVisitorParams{SkipClosures: true, SkipRefCallHead: true})) - } - return false - }) - return n -} - -func getComprehensionIndex(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, expr *Expr) *ComprehensionIndex { - - // Ignore everything except = expressions. Extract - // the comprehension term from the expression. - if !expr.IsEquality() || expr.Negated || len(expr.With) > 0 { - // No debug message, these are assumed to be known hinderances - // to comprehension indexing. - return nil - } - - var term *Term - - lhs, rhs := expr.Operand(0), expr.Operand(1) - - if _, ok := lhs.Value.(Var); ok && IsComprehension(rhs.Value) { - term = rhs - } else if _, ok := rhs.Value.(Var); ok && IsComprehension(lhs.Value) { - term = lhs - } - - if term == nil { - // no debug for this, it's the ordinary "nothing to do here" case - return nil - } - - // Ignore comprehensions that contain expressions that close over variables - // in the outer body if those variables are not also output variables in the - // comprehension body. In other words, ignore comprehensions that we cannot - // safely evaluate without bindings from the outer body. For example: - // - // x = [1] - // [true | data.y[z] = x] # safe to evaluate w/o outer body - // [true | data.y[z] = x[0]] # NOT safe to evaluate because 'x' would be unsafe. - // - // By identifying output variables in the body we also know what to index on by - // intersecting with candidate variables from the outer query. - // - // For example: - // - // x = data.foo[_] - // _ = [y | data.bar[y] = x] # index on 'x' - // - // This query goes from O(data.foo*data.bar) to O(data.foo+data.bar). - var body Body - - switch x := term.Value.(type) { - case *ArrayComprehension: - body = x.Body - case *SetComprehension: - body = x.Body - case *ObjectComprehension: - body = x.Body - } - - outputs := outputVarsForBody(body, arity, ReservedVars) - unsafe := body.Vars(SafetyCheckVisitorParams).Diff(outputs).Diff(ReservedVars) - - if len(unsafe) > 0 { - dbg.Printf("%s: comprehension index: unsafe vars: %v", expr.Location, unsafe) - return nil - } - - // Similarly, ignore comprehensions that contain references with output variables - // that intersect with the candidates. Indexing these comprehensions could worsen - // performance. - regressionVis := newComprehensionIndexRegressionCheckVisitor(candidates) - regressionVis.Walk(body) - if regressionVis.worse { - dbg.Printf("%s: comprehension index: output vars intersect candidates", expr.Location) - return nil - } - - // Check if any nested comprehensions close over candidates. If any intersection is found - // the comprehension cannot be cached because it would require closing over the candidates - // which the evaluator does not support today. - nestedVis := newComprehensionIndexNestedCandidateVisitor(candidates) - nestedVis.Walk(body) - if nestedVis.found { - dbg.Printf("%s: comprehension index: nested comprehensions close over candidates", expr.Location) - return nil - } - - // Make a sorted set of variable names that will serve as the index key set. - // Sort to ensure deterministic indexing. In future this could be relaxed - // if we can decide that one ordering is better than another. If the set is - // empty, there is no indexing to do. - indexVars := candidates.Intersect(outputs) - if len(indexVars) == 0 { - dbg.Printf("%s: comprehension index: no index vars", expr.Location) - return nil - } - - result := make([]*Term, 0, len(indexVars)) - - for v := range indexVars { - result = append(result, NewTerm(v)) - } - - sort.Slice(result, func(i, j int) bool { - return result[i].Value.Compare(result[j].Value) < 0 - }) - - debugRes := make([]*Term, len(result)) - for i, r := range result { - if o, ok := rwVars[r.Value.(Var)]; ok { - debugRes[i] = NewTerm(o) - } else { - debugRes[i] = r - } - } - dbg.Printf("%s: comprehension index: built with keys: %v", expr.Location, debugRes) - return &ComprehensionIndex{Term: term, Keys: result} -} - -type comprehensionIndexRegressionCheckVisitor struct { - candidates VarSet - seen VarSet - worse bool -} - -// TODO(tsandall): Improve this so that users can either supply this list explicitly -// or the information is maintained on the built-in function declaration. What we really -// need to know is whether the built-in function allows callers to push down output -// values or not. It's unlikely that anything outside of OPA does this today so this -// solution is fine for now. -var comprehensionIndexBlacklist = map[string]int{ - WalkBuiltin.Name: len(WalkBuiltin.Decl.FuncArgs().Args), -} - -func newComprehensionIndexRegressionCheckVisitor(candidates VarSet) *comprehensionIndexRegressionCheckVisitor { - return &comprehensionIndexRegressionCheckVisitor{ - candidates: candidates, - seen: NewVarSet(), - } -} - -func (vis *comprehensionIndexRegressionCheckVisitor) Walk(x interface{}) { - NewGenericVisitor(vis.visit).Walk(x) -} - -func (vis *comprehensionIndexRegressionCheckVisitor) visit(x interface{}) bool { - if !vis.worse { - switch x := x.(type) { - case *Expr: - operands := x.Operands() - if pos := comprehensionIndexBlacklist[x.Operator().String()]; pos > 0 && pos < len(operands) { - vis.assertEmptyIntersection(operands[pos].Vars()) - } - case Ref: - vis.assertEmptyIntersection(x.OutputVars()) - case Var: - vis.seen.Add(x) - // Always skip comprehensions. We do not have to visit their bodies here. - case *ArrayComprehension, *SetComprehension, *ObjectComprehension: - return true - } - } - return vis.worse -} - -func (vis *comprehensionIndexRegressionCheckVisitor) assertEmptyIntersection(vs VarSet) { - for v := range vs { - if vis.candidates.Contains(v) && !vis.seen.Contains(v) { - vis.worse = true - return - } - } -} - -type comprehensionIndexNestedCandidateVisitor struct { - candidates VarSet - found bool -} - -func newComprehensionIndexNestedCandidateVisitor(candidates VarSet) *comprehensionIndexNestedCandidateVisitor { - return &comprehensionIndexNestedCandidateVisitor{ - candidates: candidates, - } -} - -func (vis *comprehensionIndexNestedCandidateVisitor) Walk(x interface{}) { - NewGenericVisitor(vis.visit).Walk(x) -} - -func (vis *comprehensionIndexNestedCandidateVisitor) visit(x interface{}) bool { - - if vis.found { - return true - } - - if v, ok := x.(Value); ok && IsComprehension(v) { - varVis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true}) - varVis.Walk(v) - vis.found = len(varVis.Vars().Intersect(vis.candidates)) > 0 - return true - } - - return false -} - -// ModuleTreeNode represents a node in the module tree. The module -// tree is keyed by the package path. -type ModuleTreeNode struct { - Key Value - Modules []*Module - Children map[Value]*ModuleTreeNode - Hide bool -} - -func (n *ModuleTreeNode) String() string { - var rules []string - for _, m := range n.Modules { - for _, r := range m.Rules { - rules = append(rules, r.Head.String()) - } - } - return fmt.Sprintf("", n.Key, n.Children, rules, n.Hide) -} - -// NewModuleTree returns a new ModuleTreeNode that represents the root -// of the module tree populated with the given modules. -func NewModuleTree(mods map[string]*Module) *ModuleTreeNode { - root := &ModuleTreeNode{ - Children: map[Value]*ModuleTreeNode{}, - } - names := make([]string, 0, len(mods)) - for name := range mods { - names = append(names, name) - } - sort.Strings(names) - for _, name := range names { - m := mods[name] - node := root - for i, x := range m.Package.Path { - c, ok := node.Children[x.Value] - if !ok { - var hide bool - if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 { - hide = true - } - c = &ModuleTreeNode{ - Key: x.Value, - Children: map[Value]*ModuleTreeNode{}, - Hide: hide, - } - node.Children[x.Value] = c - } - node = c - } - node.Modules = append(node.Modules, m) - } - return root -} - -// Size returns the number of modules in the tree. -func (n *ModuleTreeNode) Size() int { - s := len(n.Modules) - for _, c := range n.Children { - s += c.Size() - } - return s -} - -// Child returns n's child with key k. -func (n *ModuleTreeNode) child(k Value) *ModuleTreeNode { - switch k.(type) { - case String, Var: - return n.Children[k] - } - return nil -} - -// Find dereferences ref along the tree. ref[0] is converted to a String -// for convenience. -func (n *ModuleTreeNode) find(ref Ref) (*ModuleTreeNode, Ref) { - if v, ok := ref[0].Value.(Var); ok { - ref = Ref{StringTerm(string(v))}.Concat(ref[1:]) - } - node := n - for i, r := range ref { - next := node.child(r.Value) - if next == nil { - tail := make(Ref, len(ref)-i) - tail[0] = VarTerm(string(ref[i].Value.(String))) - copy(tail[1:], ref[i+1:]) - return node, tail - } - node = next - } - return node, nil -} - -// DepthFirst performs a depth-first traversal of the module tree rooted at n. -// If f returns true, traversal will not continue to the children of n. -func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) { - if f(n) { - return - } - for _, node := range n.Children { - node.DepthFirst(f) - } -} - -// TreeNode represents a node in the rule tree. The rule tree is keyed by -// rule path. -type TreeNode struct { - Key Value - Values []util.T - Children map[Value]*TreeNode - Sorted []Value - Hide bool -} - -func (n *TreeNode) String() string { - return fmt.Sprintf("", n.Key, n.Values, n.Sorted, n.Hide) -} - -// NewRuleTree returns a new TreeNode that represents the root -// of the rule tree populated with the given rules. -func NewRuleTree(mtree *ModuleTreeNode) *TreeNode { - root := TreeNode{ - Key: mtree.Key, - } - - mtree.DepthFirst(func(m *ModuleTreeNode) bool { - for _, mod := range m.Modules { - if len(mod.Rules) == 0 { - root.add(mod.Package.Path, nil) - } - for _, rule := range mod.Rules { - root.add(rule.Ref().GroundPrefix(), rule) - } - } - return false - }) - - // ensure that data.system's TreeNode is hidden - node, tail := root.find(DefaultRootRef.Append(NewTerm(SystemDocumentKey))) - if len(tail) == 0 { // found - node.Hide = true - } - - root.DepthFirst(func(x *TreeNode) bool { - x.sort() - return false - }) - - return &root -} - -func (n *TreeNode) add(path Ref, rule *Rule) { - node, tail := n.find(path) - if len(tail) > 0 { - sub := treeNodeFromRef(tail, rule) - if node.Children == nil { - node.Children = make(map[Value]*TreeNode, 1) - } - node.Children[sub.Key] = sub - node.Sorted = append(node.Sorted, sub.Key) - } else { - if rule != nil { - node.Values = append(node.Values, rule) - } - } -} - -// Size returns the number of rules in the tree. -func (n *TreeNode) Size() int { - s := len(n.Values) - for _, c := range n.Children { - s += c.Size() - } - return s -} - -// Child returns n's child with key k. -func (n *TreeNode) Child(k Value) *TreeNode { - switch k.(type) { - case Ref, Call: - return nil - default: - return n.Children[k] - } -} - -// Find dereferences ref along the tree -func (n *TreeNode) Find(ref Ref) *TreeNode { - node := n - for _, r := range ref { - node = node.Child(r.Value) - if node == nil { - return nil - } - } - return node -} - -// Iteratively dereferences ref along the node's subtree. -// - If matching fails immediately, the tail will contain the full ref. -// - Partial matching will result in a tail of non-zero length. -// - A complete match will result in a 0 length tail. -func (n *TreeNode) find(ref Ref) (*TreeNode, Ref) { - node := n - for i := range ref { - next := node.Child(ref[i].Value) - if next == nil { - tail := make(Ref, len(ref)-i) - copy(tail, ref[i:]) - return node, tail - } - node = next - } - return node, nil -} - -// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If -// f returns true, traversal will not continue to the children of n. -func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) { - if f(n) { - return - } - for _, node := range n.Children { - node.DepthFirst(f) - } -} - -func (n *TreeNode) sort() { - sort.Slice(n.Sorted, func(i, j int) bool { - return n.Sorted[i].Compare(n.Sorted[j]) < 0 - }) -} - -func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode { - depth := len(ref) - 1 - key := ref[depth].Value - node := &TreeNode{ - Key: key, - Children: nil, - } - if rule != nil { - node.Values = []util.T{rule} - } - - for i := len(ref) - 2; i >= 0; i-- { - key := ref[i].Value - node = &TreeNode{ - Key: key, - Children: map[Value]*TreeNode{ref[i+1].Value: node}, - Sorted: []Value{ref[i+1].Value}, - } - } - return node -} - -// flattenChildren flattens all children's rule refs into a sorted array. -func (n *TreeNode) flattenChildren() []Ref { - ret := newRefSet() - for _, sub := range n.Children { // we only want the children, so don't use n.DepthFirst() right away - sub.DepthFirst(func(x *TreeNode) bool { - for _, r := range x.Values { - rule := r.(*Rule) - ret.AddPrefix(rule.Ref()) - } - return false - }) - } - - sort.Slice(ret.s, func(i, j int) bool { - return ret.s[i].Compare(ret.s[j]) < 0 - }) - return ret.s -} - -// Graph represents the graph of dependencies between rules. -type Graph struct { - adj map[util.T]map[util.T]struct{} - radj map[util.T]map[util.T]struct{} - nodes map[util.T]struct{} - sorted []util.T -} - -// NewGraph returns a new Graph based on modules. The list function must return -// the rules referred to directly by the ref. -func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph { - - graph := &Graph{ - adj: map[util.T]map[util.T]struct{}{}, - radj: map[util.T]map[util.T]struct{}{}, - nodes: map[util.T]struct{}{}, - sorted: nil, - } - - // Create visitor to walk a rule AST and add edges to the rule graph for - // each dependency. - vis := func(a *Rule) *GenericVisitor { - stop := false - return NewGenericVisitor(func(x interface{}) bool { - switch x := x.(type) { - case Ref: - for _, b := range list(x) { - for node := b; node != nil; node = node.Else { - graph.addDependency(a, node) - } - } - case *Rule: - if stop { - // Do not recurse into else clauses (which will be handled - // by the outer visitor.) - return true - } - stop = true - } - return false - }) - } - - // Walk over all rules, add them to graph, and build adjacency lists. - for _, module := range modules { - WalkRules(module, func(a *Rule) bool { - graph.addNode(a) - vis(a).Walk(a) - return false - }) - } - - return graph -} - -// Dependencies returns the set of rules that x depends on. -func (g *Graph) Dependencies(x util.T) map[util.T]struct{} { - return g.adj[x] -} - -// Dependents returns the set of rules that depend on x. -func (g *Graph) Dependents(x util.T) map[util.T]struct{} { - return g.radj[x] -} - -// Sort returns a slice of rules sorted by dependencies. If a cycle is found, -// ok is set to false. -func (g *Graph) Sort() (sorted []util.T, ok bool) { - if g.sorted != nil { - return g.sorted, true - } - - sorter := &graphSort{ - sorted: make([]util.T, 0, len(g.nodes)), - deps: g.Dependencies, - marked: map[util.T]struct{}{}, - temp: map[util.T]struct{}{}, - } - - for node := range g.nodes { - if !sorter.Visit(node) { - return nil, false - } - } - - g.sorted = sorter.sorted - return g.sorted, true -} - -func (g *Graph) addDependency(u util.T, v util.T) { - - if _, ok := g.nodes[u]; !ok { - g.addNode(u) - } - - if _, ok := g.nodes[v]; !ok { - g.addNode(v) - } - - edges, ok := g.adj[u] - if !ok { - edges = map[util.T]struct{}{} - g.adj[u] = edges - } - - edges[v] = struct{}{} - - edges, ok = g.radj[v] - if !ok { - edges = map[util.T]struct{}{} - g.radj[v] = edges - } - - edges[u] = struct{}{} -} - -func (g *Graph) addNode(n util.T) { - g.nodes[n] = struct{}{} -} - -type graphSort struct { - sorted []util.T - deps func(util.T) map[util.T]struct{} - marked map[util.T]struct{} - temp map[util.T]struct{} -} - -func (sort *graphSort) Marked(node util.T) bool { - _, marked := sort.marked[node] - return marked -} - -func (sort *graphSort) Visit(node util.T) (ok bool) { - if _, ok := sort.temp[node]; ok { - return false - } - if sort.Marked(node) { - return true - } - sort.temp[node] = struct{}{} - for other := range sort.deps(node) { - if !sort.Visit(other) { - return false - } - } - sort.marked[node] = struct{}{} - delete(sort.temp, node) - sort.sorted = append(sort.sorted, node) - return true -} - -// GraphTraversal is a Traversal that understands the dependency graph -type GraphTraversal struct { - graph *Graph - visited map[util.T]struct{} -} - -// NewGraphTraversal returns a Traversal for the dependency graph -func NewGraphTraversal(graph *Graph) *GraphTraversal { - return &GraphTraversal{ - graph: graph, - visited: map[util.T]struct{}{}, - } -} - -// Edges lists all dependency connections for a given node -func (g *GraphTraversal) Edges(x util.T) []util.T { - r := []util.T{} - for v := range g.graph.Dependencies(x) { - r = append(r, v) - } - return r -} - -// Visited returns whether a node has been visited, setting a node to visited if not -func (g *GraphTraversal) Visited(u util.T) bool { - _, ok := g.visited[u] - g.visited[u] = struct{}{} - return ok -} - -type unsafePair struct { - Expr *Expr - Vars VarSet -} - -type unsafeVarLoc struct { - Var Var - Loc *Location -} - -type unsafeVars map[*Expr]VarSet - -func (vs unsafeVars) Add(e *Expr, v Var) { - if u, ok := vs[e]; ok { - u[v] = struct{}{} - } else { - vs[e] = VarSet{v: struct{}{}} - } -} - -func (vs unsafeVars) Set(e *Expr, s VarSet) { - vs[e] = s -} - -func (vs unsafeVars) Update(o unsafeVars) { - for k, v := range o { - if _, ok := vs[k]; !ok { - vs[k] = VarSet{} - } - vs[k].Update(v) - } -} - -func (vs unsafeVars) Vars() (result []unsafeVarLoc) { - - locs := map[Var]*Location{} - - // If var appears in multiple sets then pick first by location. - for expr, vars := range vs { - for v := range vars { - if locs[v].Compare(expr.Location) > 0 { - locs[v] = expr.Location - } - } - } - - for v, loc := range locs { - result = append(result, unsafeVarLoc{ - Var: v, - Loc: loc, - }) - } - - sort.Slice(result, func(i, j int) bool { - return result[i].Loc.Compare(result[j].Loc) < 0 - }) - - return result -} - -func (vs unsafeVars) Slice() (result []unsafePair) { - for expr, vs := range vs { - result = append(result, unsafePair{ - Expr: expr, - Vars: vs, - }) - } - return -} - -// reorderBodyForSafety returns a copy of the body ordered such that -// left to right evaluation of the body will not encounter unbound variables -// in input positions or negated expressions. -// -// Expressions are added to the re-ordered body as soon as they are considered -// safe. If multiple expressions become safe in the same pass, they are added -// in their original order. This results in minimal re-ordering of the body. -// -// If the body cannot be reordered to ensure safety, the second return value -// contains a mapping of expressions to unsafe variables in those expressions. -func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { - - bodyVars := body.Vars(SafetyCheckVisitorParams) - reordered := make(Body, 0, len(body)) - safe := VarSet{} - unsafe := unsafeVars{} - - for _, e := range body { - for v := range e.Vars(SafetyCheckVisitorParams) { - if globals.Contains(v) { - safe.Add(v) - } else { - unsafe.Add(e, v) - } - } - } - - for { - n := len(reordered) - - for _, e := range body { - if reordered.Contains(e) { - continue - } - - ovs := outputVarsForExpr(e, arity, safe) - - // check closures: is this expression closing over variables that - // haven't been made safe by what's already included in `reordered`? - vs := unsafeVarsInClosures(e) - cv := vs.Intersect(bodyVars).Diff(globals) - uv := cv.Diff(outputVarsForBody(reordered, arity, safe)) - - if len(uv) > 0 { - if uv.Equal(ovs) { // special case "closure-self" - continue - } - unsafe.Set(e, uv) - } - - for v := range unsafe[e] { - if ovs.Contains(v) || safe.Contains(v) { - delete(unsafe[e], v) - } - } - - if len(unsafe[e]) == 0 { - delete(unsafe, e) - reordered.Append(e) - safe.Update(ovs) // this expression's outputs are safe - } - } - - if len(reordered) == n { // fixed point, could not add any expr of body - break - } - } - - // Recursively visit closures and perform the safety checks on them. - // Update the globals at each expression to include the variables that could - // be closed over. - g := globals.Copy() - for i, e := range reordered { - if i > 0 { - g.Update(reordered[i-1].Vars(SafetyCheckVisitorParams)) - } - xform := &bodySafetyTransformer{ - builtins: builtins, - arity: arity, - current: e, - globals: g, - unsafe: unsafe, - } - NewGenericVisitor(xform.Visit).Walk(e) - } - - return reordered, unsafe -} - -type bodySafetyTransformer struct { - builtins map[string]*Builtin - arity func(Ref) int - current *Expr - globals VarSet - unsafe unsafeVars -} - -func (xform *bodySafetyTransformer) Visit(x interface{}) bool { - switch term := x.(type) { - case *Term: - switch x := term.Value.(type) { - case *object: - cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) { - kcpy := k.Copy() - NewGenericVisitor(xform.Visit).Walk(kcpy) - vcpy := v.Copy() - NewGenericVisitor(xform.Visit).Walk(vcpy) - return kcpy, vcpy, nil - }) - term.Value = cpy - return true - case *set: - cpy, _ := x.Map(func(v *Term) (*Term, error) { - vcpy := v.Copy() - NewGenericVisitor(xform.Visit).Walk(vcpy) - return vcpy, nil - }) - term.Value = cpy - return true - case *ArrayComprehension: - xform.reorderArrayComprehensionSafety(x) - return true - case *ObjectComprehension: - xform.reorderObjectComprehensionSafety(x) - return true - case *SetComprehension: - xform.reorderSetComprehensionSafety(x) - return true - } - case *Expr: - if ev, ok := term.Terms.(*Every); ok { - xform.globals.Update(ev.KeyValueVars()) - ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body) - return true - } - } - return false -} - -func (xform *bodySafetyTransformer) reorderComprehensionSafety(tv VarSet, body Body) Body { - bv := body.Vars(SafetyCheckVisitorParams) - bv.Update(xform.globals) - uv := tv.Diff(bv) - for v := range uv { - xform.unsafe.Add(xform.current, v) - } - - r, u := reorderBodyForSafety(xform.builtins, xform.arity, xform.globals, body) - if len(u) == 0 { - return r - } - - xform.unsafe.Update(u) - return body -} - -func (xform *bodySafetyTransformer) reorderArrayComprehensionSafety(ac *ArrayComprehension) { - ac.Body = xform.reorderComprehensionSafety(ac.Term.Vars(), ac.Body) -} - -func (xform *bodySafetyTransformer) reorderObjectComprehensionSafety(oc *ObjectComprehension) { - tv := oc.Key.Vars() - tv.Update(oc.Value.Vars()) - oc.Body = xform.reorderComprehensionSafety(tv, oc.Body) -} - -func (xform *bodySafetyTransformer) reorderSetComprehensionSafety(sc *SetComprehension) { - sc.Body = xform.reorderComprehensionSafety(sc.Term.Vars(), sc.Body) -} - -// unsafeVarsInClosures collects vars that are contained in closures within -// this expression. -func unsafeVarsInClosures(e *Expr) VarSet { - vs := VarSet{} - WalkClosures(e, func(x interface{}) bool { - vis := &VarVisitor{vars: vs} - if ev, ok := x.(*Every); ok { - vis.Walk(ev.Body) - return true - } - vis.Walk(x) - return true - }) - return vs -} - -// OutputVarsFromBody returns all variables which are the "output" for -// the given body. For safety checks this means that they would be -// made safe by the body. -func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet { - return outputVarsForBody(body, c.GetArity, safe) -} - -func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet { - o := safe.Copy() - for _, e := range body { - o.Update(outputVarsForExpr(e, arity, o)) - } - return o.Diff(safe) -} - -// OutputVarsFromExpr returns all variables which are the "output" for -// the given expression. For safety checks this means that they would be -// made safe by the expr. -func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet { - return outputVarsForExpr(expr, c.GetArity, safe) -} - -func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet { - - // Negated expressions must be safe. - if expr.Negated { - return VarSet{} - } - - // With modifier inputs must be safe. - for _, with := range expr.With { - vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams) - vis.Walk(with) - vars := vis.Vars() - unsafe := vars.Diff(safe) - if len(unsafe) > 0 { - return VarSet{} - } - } - - switch terms := expr.Terms.(type) { - case *Term: - return outputVarsForTerms(expr, safe) - case []*Term: - if expr.IsEquality() { - return outputVarsForExprEq(expr, safe) - } - - operator, ok := terms[0].Value.(Ref) - if !ok { - return VarSet{} - } - - ar := arity(operator) - if ar < 0 { - return VarSet{} - } - - return outputVarsForExprCall(expr, ar, safe, terms) - case *Every: - return outputVarsForTerms(terms.Domain, safe) - default: - panic("illegal expression") - } -} - -func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet { - - if !validEqAssignArgCount(expr) { - return safe - } - - output := outputVarsForTerms(expr, safe) - output.Update(safe) - output.Update(Unify(output, expr.Operand(0), expr.Operand(1))) - - return output.Diff(safe) -} - -func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) VarSet { - - output := outputVarsForTerms(expr, safe) - - numInputTerms := arity + 1 - if numInputTerms >= len(terms) { - return output - } - - params := VarVisitorParams{ - SkipClosures: true, - SkipSets: true, - SkipObjectKeys: true, - SkipRefHead: true, - } - vis := NewVarVisitor().WithParams(params) - vis.Walk(Args(terms[:numInputTerms])) - unsafe := vis.Vars().Diff(output).Diff(safe) - - if len(unsafe) > 0 { - return VarSet{} - } - - vis = NewVarVisitor().WithParams(params) - vis.Walk(Args(terms[numInputTerms:])) - output.Update(vis.vars) - return output -} - -func outputVarsForTerms(expr interface{}, safe VarSet) VarSet { - output := VarSet{} - WalkTerms(expr, func(x *Term) bool { - switch r := x.Value.(type) { - case *SetComprehension, *ArrayComprehension, *ObjectComprehension: - return true - case Ref: - if !isRefSafe(r, safe) { - return true - } - output.Update(r.OutputVars()) - return false - } - return false - }) - return output -} - -type equalityFactory struct { - gen *localVarGenerator -} - -func newEqualityFactory(gen *localVarGenerator) *equalityFactory { - return &equalityFactory{gen} -} - -func (f *equalityFactory) Generate(other *Term) *Expr { - term := NewTerm(f.gen.Generate()).SetLocation(other.Location) - expr := Equality.Expr(term, other) - expr.Generated = true - expr.Location = other.Location - return expr -} - -type localVarGenerator struct { - exclude VarSet - suffix string - next int -} - -func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator { - exclude := NewVarSet() - vis := &VarVisitor{vars: exclude} - for _, key := range sorted { - vis.Walk(modules[key]) - } - return &localVarGenerator{exclude: exclude, next: 0} -} - -func newLocalVarGenerator(suffix string, node interface{}) *localVarGenerator { - exclude := NewVarSet() - vis := &VarVisitor{vars: exclude} - vis.Walk(node) - return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0} -} - -func (l *localVarGenerator) Generate() Var { - for { - result := Var("__local" + l.suffix + strconv.Itoa(l.next) + "__") - l.next++ - if !l.exclude.Contains(result) { - return result - } - } -} - -func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef { - - globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports - - // Populate globals with exports within the package. - for _, ref := range rules { - v := ref[0].Value.(Var) - globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))} - } - - // Populate globals with imports. - for _, imp := range imports { - path := imp.Path.Value.(Ref) - if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) { - continue // ignore future and rego imports - } - globals[imp.Name()] = &usedRef{ref: path} - } - - return globals -} - -func requiresEval(x *Term) bool { - if x == nil { - return false - } - return ContainsRefs(x) || ContainsComprehensions(x) -} - -func resolveRef(globals map[Var]*usedRef, ignore *declaredVarStack, ref Ref) Ref { - - r := Ref{} - for i, x := range ref { - switch v := x.Value.(type) { - case Var: - if g, ok := globals[v]; ok && !ignore.Contains(v) { - cpy := g.ref.Copy() - for i := range cpy { - cpy[i].SetLocation(x.Location) - } - if i == 0 { - r = cpy - } else { - r = append(r, NewTerm(cpy).SetLocation(x.Location)) - } - g.used = true - } else { - r = append(r, x) - } - case Ref, *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call: - r = append(r, resolveRefsInTerm(globals, ignore, x)) - default: - r = append(r, x) - } - } - - return r -} - -type usedRef struct { - ref Ref - used bool -} - -func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error { - ignore := &declaredVarStack{} - - vars := NewVarSet() - var vis *GenericVisitor - var err error - - // Walk args to collect vars and transform body so that callers can shadow - // root documents. - vis = NewGenericVisitor(func(x interface{}) bool { - if err != nil { - return true - } - switch x := x.(type) { - case Var: - vars.Add(x) - - // Object keys cannot be pattern matched so only walk values. - case *object: - x.Foreach(func(_, v *Term) { - vis.Walk(v) - }) - - // Skip terms that could contain vars that cannot be pattern matched. - case Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call: - return true - - case *Term: - if _, ok := x.Value.(Ref); ok { - if RootDocumentRefs.Contains(x) { - // We could support args named input, data, etc. however - // this would require rewriting terms in the head and body. - // Preventing root document shadowing is simpler, and - // arguably, will prevent confusing names from being used. - // NOTE: this check is also performed as part of strict-mode in - // checkRootDocumentOverrides. - err = fmt.Errorf("args must not shadow %v (use a different variable name)", x) - return true - } - } - } - return false - }) - - vis.Walk(rule.Head.Args) - - if err != nil { - return err - } - - ignore.Push(vars) - ignore.Push(declaredVars(rule.Body)) - - ref := rule.Head.Ref() - for i := 1; i < len(ref); i++ { - ref[i] = resolveRefsInTerm(globals, ignore, ref[i]) - } - if rule.Head.Key != nil { - rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key) - } - - if rule.Head.Value != nil { - rule.Head.Value = resolveRefsInTerm(globals, ignore, rule.Head.Value) - } - - rule.Body = resolveRefsInBody(globals, ignore, rule.Body) - return nil -} - -func resolveRefsInBody(globals map[Var]*usedRef, ignore *declaredVarStack, body Body) Body { - r := make([]*Expr, 0, len(body)) - for _, expr := range body { - r = append(r, resolveRefsInExpr(globals, ignore, expr)) - } - return r -} - -func resolveRefsInExpr(globals map[Var]*usedRef, ignore *declaredVarStack, expr *Expr) *Expr { - cpy := *expr - switch ts := expr.Terms.(type) { - case *Term: - cpy.Terms = resolveRefsInTerm(globals, ignore, ts) - case []*Term: - buf := make([]*Term, len(ts)) - for i := 0; i < len(ts); i++ { - buf[i] = resolveRefsInTerm(globals, ignore, ts[i]) - } - cpy.Terms = buf - case *SomeDecl: - if val, ok := ts.Symbols[0].Value.(Call); ok { - cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}} - } - case *Every: - locals := NewVarSet() - if ts.Key != nil { - locals.Update(ts.Key.Vars()) - } - locals.Update(ts.Value.Vars()) - ignore.Push(locals) - cpy.Terms = &Every{ - Key: ts.Key.Copy(), // TODO(sr): do more? - Value: ts.Value.Copy(), // TODO(sr): do more? - Domain: resolveRefsInTerm(globals, ignore, ts.Domain), - Body: resolveRefsInBody(globals, ignore, ts.Body), - } - ignore.Pop() - } - for _, w := range cpy.With { - w.Target = resolveRefsInTerm(globals, ignore, w.Target) - w.Value = resolveRefsInTerm(globals, ignore, w.Value) - } - return &cpy -} - -func resolveRefsInTerm(globals map[Var]*usedRef, ignore *declaredVarStack, term *Term) *Term { - switch v := term.Value.(type) { - case Var: - if g, ok := globals[v]; ok && !ignore.Contains(v) { - cpy := g.ref.Copy() - for i := range cpy { - cpy[i].SetLocation(term.Location) - } - g.used = true - return NewTerm(cpy).SetLocation(term.Location) - } - return term - case Ref: - fqn := resolveRef(globals, ignore, v) - cpy := *term - cpy.Value = fqn - return &cpy - case *object: - cpy := *term - cpy.Value, _ = v.Map(func(k, v *Term) (*Term, *Term, error) { - k = resolveRefsInTerm(globals, ignore, k) - v = resolveRefsInTerm(globals, ignore, v) - return k, v, nil - }) - return &cpy - case *Array: - cpy := *term - cpy.Value = NewArray(resolveRefsInTermArray(globals, ignore, v)...) - return &cpy - case Call: - cpy := *term - cpy.Value = Call(resolveRefsInTermSlice(globals, ignore, v)) - return &cpy - case Set: - s, _ := v.Map(func(e *Term) (*Term, error) { - return resolveRefsInTerm(globals, ignore, e), nil - }) - cpy := *term - cpy.Value = s - return &cpy - case *ArrayComprehension: - ac := &ArrayComprehension{} - ignore.Push(declaredVars(v.Body)) - ac.Term = resolveRefsInTerm(globals, ignore, v.Term) - ac.Body = resolveRefsInBody(globals, ignore, v.Body) - cpy := *term - cpy.Value = ac - ignore.Pop() - return &cpy - case *ObjectComprehension: - oc := &ObjectComprehension{} - ignore.Push(declaredVars(v.Body)) - oc.Key = resolveRefsInTerm(globals, ignore, v.Key) - oc.Value = resolveRefsInTerm(globals, ignore, v.Value) - oc.Body = resolveRefsInBody(globals, ignore, v.Body) - cpy := *term - cpy.Value = oc - ignore.Pop() - return &cpy - case *SetComprehension: - sc := &SetComprehension{} - ignore.Push(declaredVars(v.Body)) - sc.Term = resolveRefsInTerm(globals, ignore, v.Term) - sc.Body = resolveRefsInBody(globals, ignore, v.Body) - cpy := *term - cpy.Value = sc - ignore.Pop() - return &cpy - default: - return term - } -} - -func resolveRefsInTermArray(globals map[Var]*usedRef, ignore *declaredVarStack, terms *Array) []*Term { - cpy := make([]*Term, terms.Len()) - for i := 0; i < terms.Len(); i++ { - cpy[i] = resolveRefsInTerm(globals, ignore, terms.Elem(i)) - } - return cpy -} - -func resolveRefsInTermSlice(globals map[Var]*usedRef, ignore *declaredVarStack, terms []*Term) []*Term { - cpy := make([]*Term, len(terms)) - for i := 0; i < len(terms); i++ { - cpy[i] = resolveRefsInTerm(globals, ignore, terms[i]) - } - return cpy -} - -type declaredVarStack []VarSet - -func (s declaredVarStack) Contains(v Var) bool { - for i := len(s) - 1; i >= 0; i-- { - if _, ok := s[i][v]; ok { - return ok - } - } - return false -} - -func (s declaredVarStack) Add(v Var) { - s[len(s)-1].Add(v) -} - -func (s *declaredVarStack) Push(vs VarSet) { - *s = append(*s, vs) -} - -func (s *declaredVarStack) Pop() { - curr := *s - *s = curr[:len(curr)-1] -} - -func declaredVars(x interface{}) VarSet { - vars := NewVarSet() - vis := NewGenericVisitor(func(x interface{}) bool { - switch x := x.(type) { - case *Expr: - if x.IsAssignment() && validEqAssignArgCount(x) { - WalkVars(x.Operand(0), func(v Var) bool { - vars.Add(v) - return false - }) - } else if decl, ok := x.Terms.(*SomeDecl); ok { - for i := range decl.Symbols { - switch val := decl.Symbols[i].Value.(type) { - case Var: - vars.Add(val) - case Call: - args := val[1:] - if len(args) == 3 { // some x, y in xs - WalkVars(args[1], func(v Var) bool { - vars.Add(v) - return false - }) - } - // some x in xs - WalkVars(args[0], func(v Var) bool { - vars.Add(v) - return false - }) - } - } - } - case *ArrayComprehension, *SetComprehension, *ObjectComprehension: - return true - } - return false - }) - vis.Walk(x) - return vars -} - -// rewriteComprehensionTerms will rewrite comprehensions so that the term part -// is bound to a variable in the body. This allows any type of term to be used -// in the term part (even if the term requires evaluation.) -// -// For instance, given the following comprehension: -// -// [x[0] | x = y[_]; y = [1,2,3]] -// -// The comprehension would be rewritten as: -// -// [__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]] -func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) { - return TransformComprehensions(node, func(x interface{}) (Value, error) { - switch x := x.(type) { - case *ArrayComprehension: - if requiresEval(x.Term) { - expr := f.Generate(x.Term) - x.Term = expr.Operand(0) - x.Body.Append(expr) - } - return x, nil - case *SetComprehension: - if requiresEval(x.Term) { - expr := f.Generate(x.Term) - x.Term = expr.Operand(0) - x.Body.Append(expr) - } - return x, nil - case *ObjectComprehension: - if requiresEval(x.Key) { - expr := f.Generate(x.Key) - x.Key = expr.Operand(0) - x.Body.Append(expr) - } - if requiresEval(x.Value) { - expr := f.Generate(x.Value) - x.Value = expr.Operand(0) - x.Body.Append(expr) - } - return x, nil - } - panic("illegal type") - }) -} - -// rewriteEquals will rewrite exprs under x as unification calls instead of == -// calls. For example: -// -// data.foo == data.bar is rewritten as data.foo = data.bar -// -// This stage should only run the safety check (since == is a built-in with no -// outputs, so the inputs must not be marked as safe.) -// -// This stage is not executed by the query compiler by default because when -// callers specify == instead of = they expect to receive a true/false/undefined -// result back whereas with = the result is only ever true/undefined. For -// partial evaluation cases we do want to rewrite == to = to simplify the -// result. -func rewriteEquals(x interface{}) (modified bool) { - doubleEq := Equal.Ref() - unifyOp := Equality.Ref() - t := NewGenericTransformer(func(x interface{}) (interface{}, error) { - if x, ok := x.(*Expr); ok && x.IsCall() { - operator := x.Operator() - if operator.Equal(doubleEq) && len(x.Operands()) == 2 { - modified = true - x.SetOperator(NewTerm(unifyOp)) - } - } - return x, nil - }) - _, _ = Transform(t, x) // ignore error - return modified -} - -func rewriteTestEqualities(f *equalityFactory, body Body) Body { - result := make(Body, 0, len(body)) - for _, expr := range body { - // We can't rewrite negated expressions; if the extracted term is undefined, evaluation would fail before - // reaching the negation check. - if !expr.Negated && !expr.Generated { - switch { - case expr.IsEquality(): - terms := expr.Terms.([]*Term) - result, terms[1] = rewriteDynamicsShallow(expr, f, terms[1], result) - result, terms[2] = rewriteDynamicsShallow(expr, f, terms[2], result) - case expr.IsEvery(): - // We rewrite equalities inside of every-bodies as a fail here will be the cause of the test-rule fail. - // Failures inside other expressions with closures, such as comprehensions, won't cause the test-rule to fail, so we skip those. - every := expr.Terms.(*Every) - every.Body = rewriteTestEqualities(f, every.Body) - } - } - result = appendExpr(result, expr) - } - return result -} - -func rewriteDynamicsShallow(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) { - switch term.Value.(type) { - case Ref, *ArrayComprehension, *SetComprehension, *ObjectComprehension: - generated := f.Generate(term) - generated.With = original.With - result.Append(generated) - connectGeneratedExprs(original, generated) - return result, result[len(result)-1].Operand(0) - } - return result, term -} - -// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and -// comprehensions) are bound to vars earlier in the query. This translation -// results in eager evaluation. -// -// For instance, given the following query: -// -// foo(data.bar) = 1 -// -// The rewritten version will be: -// -// __local0__ = data.bar; foo(__local0__) = 1 -func rewriteDynamics(f *equalityFactory, body Body) Body { - result := make(Body, 0, len(body)) - for _, expr := range body { - switch { - case expr.IsEquality(): - result = rewriteDynamicsEqExpr(f, expr, result) - case expr.IsCall(): - result = rewriteDynamicsCallExpr(f, expr, result) - case expr.IsEvery(): - result = rewriteDynamicsEveryExpr(f, expr, result) - default: - result = rewriteDynamicsTermExpr(f, expr, result) - } - } - return result -} - -func appendExpr(body Body, expr *Expr) Body { - body.Append(expr) - return body -} - -func rewriteDynamicsEqExpr(f *equalityFactory, expr *Expr, result Body) Body { - if !validEqAssignArgCount(expr) { - return appendExpr(result, expr) - } - terms := expr.Terms.([]*Term) - result, terms[1] = rewriteDynamicsInTerm(expr, f, terms[1], result) - result, terms[2] = rewriteDynamicsInTerm(expr, f, terms[2], result) - return appendExpr(result, expr) -} - -func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body { - terms := expr.Terms.([]*Term) - for i := 1; i < len(terms); i++ { - result, terms[i] = rewriteDynamicsOne(expr, f, terms[i], result) - } - return appendExpr(result, expr) -} - -func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body { - ev := expr.Terms.(*Every) - result, ev.Domain = rewriteDynamicsOne(expr, f, ev.Domain, result) - ev.Body = rewriteDynamics(f, ev.Body) - return appendExpr(result, expr) -} - -func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body { - term := expr.Terms.(*Term) - result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result) - return appendExpr(result, expr) -} - -func rewriteDynamicsInTerm(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) { - switch v := term.Value.(type) { - case Ref: - for i := 1; i < len(v); i++ { - result, v[i] = rewriteDynamicsOne(original, f, v[i], result) - } - case *ArrayComprehension: - v.Body = rewriteDynamics(f, v.Body) - case *SetComprehension: - v.Body = rewriteDynamics(f, v.Body) - case *ObjectComprehension: - v.Body = rewriteDynamics(f, v.Body) - default: - result, term = rewriteDynamicsOne(original, f, term, result) - } - return result, term -} - -func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) { - switch v := term.Value.(type) { - case Ref: - for i := 1; i < len(v); i++ { - result, v[i] = rewriteDynamicsOne(original, f, v[i], result) - } - generated := f.Generate(term) - generated.With = original.With - result.Append(generated) - connectGeneratedExprs(original, generated) - return result, result[len(result)-1].Operand(0) - case *Array: - for i := 0; i < v.Len(); i++ { - var t *Term - result, t = rewriteDynamicsOne(original, f, v.Elem(i), result) - v.set(i, t) - } - return result, term - case *object: - cpy := NewObject() - v.Foreach(func(key, value *Term) { - result, key = rewriteDynamicsOne(original, f, key, result) - result, value = rewriteDynamicsOne(original, f, value, result) - cpy.Insert(key, value) - }) - return result, NewTerm(cpy).SetLocation(term.Location) - case Set: - cpy := NewSet() - for _, term := range v.Slice() { - var rw *Term - result, rw = rewriteDynamicsOne(original, f, term, result) - cpy.Add(rw) - } - return result, NewTerm(cpy).SetLocation(term.Location) - case *ArrayComprehension: - var extra *Expr - v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term) - result.Append(extra) - connectGeneratedExprs(original, extra) - return result, result[len(result)-1].Operand(0) - case *SetComprehension: - var extra *Expr - v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term) - result.Append(extra) - connectGeneratedExprs(original, extra) - return result, result[len(result)-1].Operand(0) - case *ObjectComprehension: - var extra *Expr - v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term) - result.Append(extra) - connectGeneratedExprs(original, extra) - return result, result[len(result)-1].Operand(0) - } - return result, term -} - -func rewriteDynamicsComprehensionBody(original *Expr, f *equalityFactory, body Body, term *Term) (Body, *Expr) { - body = rewriteDynamics(f, body) - generated := f.Generate(term) - generated.With = original.With - return body, generated -} - -func rewriteExprTermsInHead(gen *localVarGenerator, rule *Rule) { - for i := range rule.Head.Args { - support, output := expandExprTerm(gen, rule.Head.Args[i]) - for j := range support { - rule.Body.Append(support[j]) - } - rule.Head.Args[i] = output - } - if rule.Head.Key != nil { - support, output := expandExprTerm(gen, rule.Head.Key) - for i := range support { - rule.Body.Append(support[i]) - } - rule.Head.Key = output - } - if rule.Head.Value != nil { - support, output := expandExprTerm(gen, rule.Head.Value) - for i := range support { - rule.Body.Append(support[i]) - } - rule.Head.Value = output - } -} - -func rewriteExprTermsInBody(gen *localVarGenerator, body Body) Body { - cpy := make(Body, 0, len(body)) - for i := 0; i < len(body); i++ { - for _, expr := range expandExpr(gen, body[i]) { - cpy.Append(expr) - } - } - return cpy -} - -func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) { - for i := range expr.With { - extras, value := expandExprTerm(gen, expr.With[i].Value) - expr.With[i].Value = value - result = append(result, extras...) - } - switch terms := expr.Terms.(type) { - case *Term: - extras, term := expandExprTerm(gen, terms) - if len(expr.With) > 0 { - for i := range extras { - extras[i].With = expr.With - } - } - result = append(result, extras...) - expr.Terms = term - result = append(result, expr) - case []*Term: - for i := 1; i < len(terms); i++ { - var extras []*Expr - extras, terms[i] = expandExprTerm(gen, terms[i]) - connectGeneratedExprs(expr, extras...) - if len(expr.With) > 0 { - for i := range extras { - extras[i].With = expr.With - } - } - result = append(result, extras...) - } - result = append(result, expr) - case *Every: - var extras []*Expr - - term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location) - eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location) - eq.Generated = true - eq.With = expr.With - extras = expandExpr(gen, eq) - terms.Domain = term - - terms.Body = rewriteExprTermsInBody(gen, terms.Body) - result = append(result, extras...) - result = append(result, expr) - } - return -} - -func connectGeneratedExprs(parent *Expr, children ...*Expr) { - for _, child := range children { - child.generatedFrom = parent - parent.generates = append(parent.generates, child) - } -} - -func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) { - output = term - switch v := term.Value.(type) { - case Call: - for i := 1; i < len(v); i++ { - var extras []*Expr - extras, v[i] = expandExprTerm(gen, v[i]) - support = append(support, extras...) - } - output = NewTerm(gen.Generate()).SetLocation(term.Location) - expr := v.MakeExpr(output).SetLocation(term.Location) - expr.Generated = true - support = append(support, expr) - case Ref: - support = expandExprRef(gen, v) - case *Array: - support = expandExprTermArray(gen, v) - case *object: - cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { - extras1, expandedKey := expandExprTerm(gen, k) - extras2, expandedValue := expandExprTerm(gen, v) - support = append(support, extras1...) - support = append(support, extras2...) - return expandedKey, expandedValue, nil - }) - output = NewTerm(cpy).SetLocation(term.Location) - case Set: - cpy, _ := v.Map(func(x *Term) (*Term, error) { - extras, expanded := expandExprTerm(gen, x) - support = append(support, extras...) - return expanded, nil - }) - output = NewTerm(cpy).SetLocation(term.Location) - case *ArrayComprehension: - support, term := expandExprTerm(gen, v.Term) - for i := range support { - v.Body.Append(support[i]) - } - v.Term = term - v.Body = rewriteExprTermsInBody(gen, v.Body) - case *SetComprehension: - support, term := expandExprTerm(gen, v.Term) - for i := range support { - v.Body.Append(support[i]) - } - v.Term = term - v.Body = rewriteExprTermsInBody(gen, v.Body) - case *ObjectComprehension: - support, key := expandExprTerm(gen, v.Key) - for i := range support { - v.Body.Append(support[i]) - } - v.Key = key - support, value := expandExprTerm(gen, v.Value) - for i := range support { - v.Body.Append(support[i]) - } - v.Value = value - v.Body = rewriteExprTermsInBody(gen, v.Body) - } - return -} - -func expandExprRef(gen *localVarGenerator, v []*Term) (support []*Expr) { - // Start by calling a normal expandExprTerm on all terms. - support = expandExprTermSlice(gen, v) - - // Rewrite references in order to support indirect references. We rewrite - // e.g. - // - // [1, 2, 3][i] - // - // to - // - // __local_var = [1, 2, 3] - // __local_var[i] - // - // to support these. This only impacts the reference subject, i.e. the - // first item in the slice. - var subject = v[0] - switch subject.Value.(type) { - case *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call: - f := newEqualityFactory(gen) - assignToLocal := f.Generate(subject) - support = append(support, assignToLocal) - v[0] = assignToLocal.Operand(0) - } - return -} - -func expandExprTermArray(gen *localVarGenerator, arr *Array) (support []*Expr) { - for i := 0; i < arr.Len(); i++ { - extras, v := expandExprTerm(gen, arr.Elem(i)) - arr.set(i, v) - support = append(support, extras...) - } - return -} - -func expandExprTermSlice(gen *localVarGenerator, v []*Term) (support []*Expr) { - for i := 0; i < len(v); i++ { - var extras []*Expr - extras, v[i] = expandExprTerm(gen, v[i]) - support = append(support, extras...) - } - return -} - -type localDeclaredVars struct { - vars []*declaredVarSet - - // rewritten contains a mapping of *all* user-defined variables - // that have been rewritten whereas vars contains the state - // from the current query (not any nested queries, and all vars - // seen). - rewritten map[Var]Var - - // indicates if an assignment (:= operator) has been seen *ever* - assignment bool -} - -type varOccurrence int - -const ( - newVar varOccurrence = iota - argVar - seenVar - assignedVar - declaredVar -) - -type declaredVarSet struct { - vs map[Var]Var - reverse map[Var]Var - occurrence map[Var]varOccurrence - count map[Var]int -} - -func newDeclaredVarSet() *declaredVarSet { - return &declaredVarSet{ - vs: map[Var]Var{}, - reverse: map[Var]Var{}, - occurrence: map[Var]varOccurrence{}, - count: map[Var]int{}, - } -} - -func newLocalDeclaredVars() *localDeclaredVars { - return &localDeclaredVars{ - vars: []*declaredVarSet{newDeclaredVarSet()}, - rewritten: map[Var]Var{}, - } -} - -func (s *localDeclaredVars) Copy() *localDeclaredVars { - stack := &localDeclaredVars{ - vars: []*declaredVarSet{}, - rewritten: map[Var]Var{}, - } - - for i := range s.vars { - stack.vars = append(stack.vars, newDeclaredVarSet()) - for k, v := range s.vars[i].vs { - stack.vars[0].vs[k] = v - } - for k, v := range s.vars[i].reverse { - stack.vars[0].reverse[k] = v - } - for k, v := range s.vars[i].count { - stack.vars[0].count[k] = v - } - for k, v := range s.vars[i].occurrence { - stack.vars[0].occurrence[k] = v - } - } - - for k, v := range s.rewritten { - stack.rewritten[k] = v - } - - return stack -} - -func (s *localDeclaredVars) Push() { - s.vars = append(s.vars, newDeclaredVarSet()) -} - -func (s *localDeclaredVars) Pop() *declaredVarSet { - sl := s.vars - curr := sl[len(sl)-1] - s.vars = sl[:len(sl)-1] - return curr -} - -func (s localDeclaredVars) Peek() *declaredVarSet { - return s.vars[len(s.vars)-1] -} - -func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) { - elem := s.vars[len(s.vars)-1] - elem.vs[x] = y - elem.reverse[y] = x - elem.occurrence[x] = occurrence - - elem.count[x] = 1 - - // If the variable has been rewritten (where x != y, with y being - // the generated value), store it in the map of rewritten vars. - // Assume that the generated values are unique for the compilation. - if !x.Equal(y) { - s.rewritten[y] = x - } -} - -func (s localDeclaredVars) Declared(x Var) (y Var, ok bool) { - for i := len(s.vars) - 1; i >= 0; i-- { - if y, ok = s.vars[i].vs[x]; ok { - return - } - } - return -} - -// Occurrence returns a flag that indicates whether x has occurred in the -// current scope. -func (s localDeclaredVars) Occurrence(x Var) varOccurrence { - return s.vars[len(s.vars)-1].occurrence[x] -} - -// GlobalOccurrence returns a flag that indicates whether x has occurred in the -// global scope. -func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) { - for i := len(s.vars) - 1; i >= 0; i-- { - if occ, ok := s.vars[i].occurrence[x]; ok { - return occ, true - } - } - return newVar, false -} - -// Seen marks x as seen by incrementing its counter -func (s localDeclaredVars) Seen(x Var) { - for i := len(s.vars) - 1; i >= 0; i-- { - dvs := s.vars[i] - if c, ok := dvs.count[x]; ok { - dvs.count[x] = c + 1 - return - } - } - - s.vars[len(s.vars)-1].count[x] = 1 -} - -// Count returns how many times x has been seen -func (s localDeclaredVars) Count(x Var) int { - for i := len(s.vars) - 1; i >= 0; i-- { - if c, ok := s.vars[i].count[x]; ok { - return c - } - } - - return 0 -} - -// rewriteLocalVars rewrites bodies to remove assignment/declaration -// expressions. For example: -// -// a := 1; p[a] -// -// Is rewritten to: -// -// __local0__ = 1; p[__local0__] -// -// During rewriting, assignees are validated to prevent use before declaration. -func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, strict bool) (Body, map[Var]Var, Errors) { - var errs Errors - body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs, strict) - return body, stack.Peek().vs, errs -} - -func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) { - - var cpy Body - - for i := range body { - var expr *Expr - switch { - case body[i].IsAssignment(): - stack.assignment = true - expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs, strict) - case body[i].IsSome(): - expr, errs = rewriteSomeDeclStatement(g, stack, body[i], errs, strict) - case body[i].IsEvery(): - expr, errs = rewriteEveryStatement(g, stack, body[i], errs, strict) - default: - expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs, strict) - } - if expr != nil { - cpy.Append(expr) - } - } - - // If the body only contained a var statement it will be empty at this - // point. Append true to the body to ensure that it's non-empty (zero length - // bodies are not supported.) - if len(cpy) == 0 { - cpy.Append(NewExpr(BooleanTerm(true))) - } - - errs = checkUnusedAssignedVars(body, stack, used, errs, strict) - return cpy, checkUnusedDeclaredVars(body, stack, used, cpy, errs) -} - -func checkUnusedAssignedVars(body Body, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors { - - if !strict || len(errs) > 0 { - return errs - } - - dvs := stack.Peek() - unused := NewVarSet() - - for v, occ := range dvs.occurrence { - // A var that was assigned in this scope must have been seen (used) more than once (the time of assignment) in - // the same, or nested, scope to be counted as used. - if !v.IsWildcard() && stack.Count(v) <= 1 && occ == assignedVar { - unused.Add(dvs.vs[v]) - } - } - - rewrittenUsed := NewVarSet() - for v := range used { - if gv, ok := stack.Declared(v); ok { - rewrittenUsed.Add(gv) - } else { - rewrittenUsed.Add(v) - } - } - - unused = unused.Diff(rewrittenUsed) - - for _, gv := range unused.Sorted() { - found := false - for i := range body { - if body[i].Vars(VarVisitorParams{}).Contains(gv) { - errs = append(errs, NewError(CompileErr, body[i].Loc(), "assigned var %v unused", dvs.reverse[gv])) - found = true - break - } - } - if !found { - errs = append(errs, NewError(CompileErr, body[0].Loc(), "assigned var %v unused", dvs.reverse[gv])) - } - } - - return errs -} - -func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors { - - // NOTE(tsandall): Do not generate more errors if there are existing - // declaration errors. - if len(errs) > 0 { - return errs - } - - dvs := stack.Peek() - declared := NewVarSet() - - for v, occ := range dvs.occurrence { - if occ == declaredVar { - declared.Add(dvs.vs[v]) - } - } - - bodyvars := cpy.Vars(VarVisitorParams{}) - - for v := range used { - if gv, ok := stack.Declared(v); ok { - bodyvars.Add(gv) - } else { - bodyvars.Add(v) - } - } - - unused := declared.Diff(bodyvars).Diff(used) - - for _, gv := range unused.Sorted() { - rv := dvs.reverse[gv] - if !rv.IsGenerated() { - // Scan through body exprs, looking for a match between the - // bad var's original name, and each expr's declared vars. - foundUnusedVarByName := false - for i := range body { - varsDeclaredInExpr := declaredVars(body[i]) - if varsDeclaredInExpr.Contains(dvs.reverse[gv]) { - // TODO(philipc): Clean up the offset logic here when the parser - // reports more accurate locations. - errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", dvs.reverse[gv])) - foundUnusedVarByName = true - break - } - } - // Default error location returned. - if !foundUnusedVarByName { - errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", dvs.reverse[gv])) - } - } - } - - return errs -} - -func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { - e := expr.Copy() - every := e.Terms.(*Every) - - errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict) - - stack.Push() - defer stack.Pop() - - // if the key exists, rewrite - if every.Key != nil { - if v := every.Key.Value.(Var); !v.IsWildcard() { - gv, err := rewriteDeclaredVar(g, stack, v, declaredVar) - if err != nil { - return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) - } - every.Key.Value = gv - } - } else { // if the key doesn't exist, add dummy local - every.Key = NewTerm(g.Generate()) - } - - // value is always present - if v := every.Value.Value.(Var); !v.IsWildcard() { - gv, err := rewriteDeclaredVar(g, stack, v, declaredVar) - if err != nil { - return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) - } - every.Value.Value = gv - } - - used := NewVarSet() - every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict) - - return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict) -} - -func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { - e := expr.Copy() - decl := e.Terms.(*SomeDecl) - for i := range decl.Symbols { - switch v := decl.Symbols[i].Value.(type) { - case Var: - if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil { - return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error())) - } - case Call: - var key, val, container *Term - switch len(v) { - case 4: // member3 - key = v[1] - val = v[2] - container = v[3] - case 3: // member - key = NewTerm(g.Generate()) - val = v[1] - container = v[2] - } - - var rhs *Term - switch c := container.Value.(type) { - case Ref: - rhs = RefTerm(append(c, key)...) - default: - rhs = RefTerm(container, key) - } - e.Terms = []*Term{ - RefTerm(VarTerm(Equality.Name)), val, rhs, - } - - for _, v0 := range outputVarsForExprEq(e, container.Vars()).Sorted() { - if _, err := rewriteDeclaredVar(g, stack, v0, declaredVar); err != nil { - return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error())) - } - } - return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict) - } - } - return nil, errs -} - -func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { - vis := NewGenericVisitor(func(x interface{}) bool { - var stop bool - switch x := x.(type) { - case *Term: - stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs, strict) - case *With: - stop, errs = true, rewriteDeclaredVarsInWithRecursive(g, stack, x, errs, strict) - } - return stop - }) - vis.Walk(expr) - return expr, errs -} - -func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { - - if expr.Negated { - errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression")) - return expr, errs - } - - numErrsBefore := len(errs) - - if !validEqAssignArgCount(expr) { - return expr, errs - } - - // Rewrite terms on right hand side capture seen vars and recursively - // process comprehensions before left hand side is processed. Also - // rewrite with modifier. - errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs, strict) - - for _, w := range expr.With { - errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs, strict) - } - - // Rewrite vars on left hand side with unique names. Catch redeclaration - // and invalid term types here. - var vis func(t *Term) bool - - vis = func(t *Term) bool { - switch v := t.Value.(type) { - case Var: - if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil { - errs = append(errs, NewError(CompileErr, t.Location, err.Error())) - } else { - t.Value = gv - } - return true - case *Array: - return false - case *object: - v.Foreach(func(_, v *Term) { - WalkTerms(v, vis) - }) - return true - case Ref: - if RootDocumentRefs.Contains(t) { - if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil { - errs = append(errs, NewError(CompileErr, t.Location, err.Error())) - } else { - t.Value = gv - } - return true - } - } - errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value))) - return true - } - - WalkTerms(expr.Operand(0), vis) - - if len(errs) == numErrsBefore { - loc := expr.Operator()[0].Location - expr.SetOperator(RefTerm(VarTerm(Equality.Name).SetLocation(loc)).SetLocation(loc)) - } - - return expr, errs -} - -func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) (bool, Errors) { - switch v := term.Value.(type) { - case Var: - if gv, ok := stack.Declared(v); ok { - term.Value = gv - stack.Seen(v) - } else if stack.Occurrence(v) == newVar { - stack.Insert(v, v, seenVar) - } - case Ref: - if RootDocumentRefs.Contains(term) { - x := v[0].Value.(Var) - if occ, ok := stack.GlobalOccurrence(x); ok && occ != seenVar { - gv, _ := stack.Declared(x) - term.Value = gv - } - - return true, errs - } - return false, errs - case Call: - ref := v[0] - WalkVars(ref, func(v Var) bool { - if gv, ok := stack.Declared(v); ok && !gv.Equal(v) { - // We will rewrite the ref of a function call, which is never ok since we don't have first-class functions. - errs = append(errs, NewError(CompileErr, term.Location, "called function %s shadowed", ref)) - return true - } - return false - }) - return false, errs - case *object: - cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { - kcpy := k.Copy() - errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs, strict) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs, strict) - return kcpy, v, nil - }) - term.Value = cpy - case Set: - cpy, _ := v.Map(func(elem *Term) (*Term, error) { - elemcpy := elem.Copy() - errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs, strict) - return elemcpy, nil - }) - term.Value = cpy - case *ArrayComprehension: - errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs, strict) - case *SetComprehension: - errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs, strict) - case *ObjectComprehension: - errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs, strict) - default: - return false, errs - } - return true, errs -} - -func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) Errors { - WalkTerms(term, func(t *Term) bool { - var stop bool - stop, errs = rewriteDeclaredVarsInTerm(g, stack, t, errs, strict) - return stop - }) - return errs -} - -func rewriteDeclaredVarsInWithRecursive(g *localVarGenerator, stack *localDeclaredVars, w *With, errs Errors, strict bool) Errors { - // NOTE(sr): `with input as` and `with input.a.b.c as` are deliberately skipped here: `input` could - // have been shadowed by a local variable/argument but should NOT be replaced in the `with` target. - // - // We cannot drop `input` from the stack since it's conceivable to do `with input[input] as` where - // the second input is meant to be the local var. It's a terrible idea, but when you're shadowing - // `input` those might be your thing. - errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Target, errs, strict) - if sdwInput, ok := stack.Declared(InputRootDocument.Value.(Var)); ok { // Was "input" shadowed... - switch value := w.Target.Value.(type) { - case Var: - if sdwInput.Equal(value) { // ...and replaced? If so, fix it - w.Target.Value = InputRootRef - } - case Ref: - if sdwInput.Equal(value[0].Value.(Var)) { - w.Target.Value.(Ref)[0].Value = InputRootDocument.Value - } - } - } - // No special handling of the `with` value - return rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs, strict) -} - -func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors, strict bool) Errors { - used := NewVarSet() - used.Update(v.Term.Vars()) - - stack.Push() - v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict) - stack.Pop() - return errs -} - -func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors, strict bool) Errors { - used := NewVarSet() - used.Update(v.Term.Vars()) - - stack.Push() - v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict) - stack.Pop() - return errs -} - -func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors, strict bool) Errors { - used := NewVarSet() - used.Update(v.Key.Vars()) - used.Update(v.Value.Vars()) - - stack.Push() - v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs, strict) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs, strict) - stack.Pop() - return errs -} - -func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, occ varOccurrence) (gv Var, err error) { - switch stack.Occurrence(v) { - case seenVar: - return gv, fmt.Errorf("var %v referenced above", v) - case assignedVar: - return gv, fmt.Errorf("var %v assigned above", v) - case declaredVar: - return gv, fmt.Errorf("var %v declared above", v) - case argVar: - return gv, fmt.Errorf("arg %v redeclared", v) - } - gv = g.Generate() - stack.Insert(v, gv, occ) - return -} - -// rewriteWithModifiersInBody will rewrite the body so that with modifiers do -// not contain terms that require evaluation as values. If this function -// encounters an invalid with modifier target then it will raise an error. -func rewriteWithModifiersInBody(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, body Body) (Body, *Error) { - var result Body - for i := range body { - exprs, err := rewriteWithModifier(c, unsafeBuiltinsMap, f, body[i]) - if err != nil { - return nil, err - } - if len(exprs) > 0 { - for _, expr := range exprs { - result.Append(expr) - } - } else { - result.Append(body[i]) - } - } - return result, nil -} - -func rewriteWithModifier(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, expr *Expr) ([]*Expr, *Error) { - - var result []*Expr - for i := range expr.With { - eval, err := validateWith(c, unsafeBuiltinsMap, expr, i) - if err != nil { - return nil, err - } - - if eval { - eq := f.Generate(expr.With[i].Value) - result = append(result, eq) - expr.With[i].Value = eq.Operand(0) - } - } - - return append(result, expr), nil -} - -func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr, i int) (bool, *Error) { - target, value := expr.With[i].Target, expr.With[i].Value - - // Ensure that values that are built-ins are rewritten to Ref (not Var) - if v, ok := value.Value.(Var); ok { - if _, ok := c.builtins[v.String()]; ok { - value.Value = Ref([]*Term{NewTerm(v)}) - } - } - isBuiltinRefOrVar, err := isBuiltinRefOrVar(c.builtins, unsafeBuiltinsMap, target) - if err != nil { - return false, err - } - - isAllowedUnknownFuncCall := false - if c.allowUndefinedFuncCalls { - switch target.Value.(type) { - case Ref, Var: - isAllowedUnknownFuncCall = true - } - } - - switch { - case isDataRef(target): - ref := target.Value.(Ref) - targetNode := c.RuleTree - for i := 0; i < len(ref)-1; i++ { - child := targetNode.Child(ref[i].Value) - if child == nil { - break - } else if len(child.Values) > 0 { - return false, NewError(CompileErr, target.Loc(), "with keyword cannot partially replace virtual document(s)") - } - targetNode = child - } - - if targetNode != nil { - // NOTE(sr): at this point in the compiler stages, we don't have a fully-populated - // TypeEnv yet -- so we have to make do with this check to see if the replacement - // target is a function. It's probably wrong for arity-0 functions, but those are - // and edge case anyways. - if child := targetNode.Child(ref[len(ref)-1].Value); child != nil { - for _, v := range child.Values { - if len(v.(*Rule).Head.Args) > 0 { - if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok { - return false, err // err may be nil - } - } - } - } - } - - // If the with-value is a ref to a function, but not a call, we can't rewrite it - if r, ok := value.Value.(Ref); ok { - // TODO: check that target ref doesn't exist? - if valueNode := c.RuleTree.Find(r); valueNode != nil { - for _, v := range valueNode.Values { - if len(v.(*Rule).Head.Args) > 0 { - return false, nil - } - } - } - } - case isInputRef(target): // ok, valid - case isBuiltinRefOrVar: - - // NOTE(sr): first we ensure that parsed Var builtins (`count`, `concat`, etc) - // are rewritten to their proper Ref convention - if v, ok := target.Value.(Var); ok { - target.Value = Ref([]*Term{NewTerm(v)}) - } - - targetRef := target.Value.(Ref) - bi := c.builtins[targetRef.String()] // safe because isBuiltinRefOrVar checked this - if err := validateWithBuiltinTarget(bi, targetRef, target.Loc()); err != nil { - return false, err - } - - if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok { - return false, err // err may be nil - } - case isAllowedUnknownFuncCall: - // The target isn't a ref to the input doc, data doc, or a known built-in, but it might be a ref to an unknown built-in. - return false, nil - default: - return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument) - } - return requiresEval(value), nil -} - -func validateWithBuiltinTarget(bi *Builtin, target Ref, loc *location.Location) *Error { - switch bi.Name { - case Equality.Name, - RegoMetadataChain.Name, - RegoMetadataRule.Name: - return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of %q invalid", bi.Name) - } - - switch { - case target.HasPrefix(Ref([]*Term{VarTerm("internal")})): - return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of internal function %q invalid", target) - - case bi.Relation: - return NewError(CompileErr, loc, "with keyword replacing built-in function: target must not be a relation") - - case bi.Decl.Result() == nil: - return NewError(CompileErr, loc, "with keyword replacing built-in function: target must not be a void function") - } - return nil -} - -func validateWithFunctionValue(bs map[string]*Builtin, unsafeMap map[string]struct{}, ruleTree *TreeNode, value *Term) (bool, *Error) { - if v, ok := value.Value.(Ref); ok { - if ruleTree.Find(v) != nil { // ref exists in rule tree - return true, nil - } - } - return isBuiltinRefOrVar(bs, unsafeMap, value) -} - -func isInputRef(term *Term) bool { - if ref, ok := term.Value.(Ref); ok { - if ref.HasPrefix(InputRootRef) { - return true - } - } - return false -} - -func isDataRef(term *Term) bool { - if ref, ok := term.Value.(Ref); ok { - if ref.HasPrefix(DefaultRootRef) { - return true - } - } - return false -} - -func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]struct{}, term *Term) (bool, *Error) { - switch v := term.Value.(type) { - case Ref, Var: - if _, ok := unsafeBuiltinsMap[v.String()]; ok { - return false, NewError(CompileErr, term.Location, "with keyword replacing built-in function: target must not be unsafe: %q", v) - } - _, ok := bs[v.String()] - return ok, nil - } - return false, nil -} - -func isVirtual(node *TreeNode, ref Ref) bool { - for i := range ref { - child := node.Child(ref[i].Value) - if child == nil { - return false - } else if len(child.Values) > 0 { - return true - } - node = child - } - return true -} - -func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors) { - - if len(unsafe) == 0 { - return - } - - for _, pair := range unsafe.Vars() { - v := pair.Var - if w, ok := rewritten[v]; ok { - v = w - } - if !v.IsGenerated() { - if _, ok := futureKeywords[string(v)]; ok { - result = append(result, NewError(UnsafeVarErr, pair.Loc, - "var %[1]v is unsafe (hint: `import future.keywords.%[1]v` to import a future keyword)", v)) - continue - } - result = append(result, NewError(UnsafeVarErr, pair.Loc, "var %v is unsafe", v)) - } - } - - if len(result) > 0 { - return - } - - // If the expression contains unsafe generated variables, report which - // expressions are unsafe instead of the variables that are unsafe (since - // the latter are not meaningful to the user.) - pairs := unsafe.Slice() - - sort.Slice(pairs, func(i, j int) bool { - return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0 - }) - - // Report at most one error per generated variable. - seen := NewVarSet() - - for _, expr := range pairs { - before := len(seen) - for v := range expr.Vars { - if v.IsGenerated() { - seen.Add(v) - } - } - if len(seen) > before { - result = append(result, NewError(UnsafeVarErr, expr.Expr.Location, "expression is unsafe")) - } - } - - return -} - -func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors { - errs := make(Errors, 0) - WalkExprs(node, func(x *Expr) bool { - if x.IsCall() { - operator := x.Operator().String() - if _, ok := unsafeBuiltinsMap[operator]; ok { - errs = append(errs, NewError(TypeErr, x.Loc(), "unsafe built-in function calls in expression: %v", operator)) - } - } - return false - }) - return errs -} - -func rewriteVarsInRef(vars ...map[Var]Var) varRewriter { - return func(node Ref) Ref { - i, _ := TransformVars(node, func(v Var) (Value, error) { - for _, m := range vars { - if u, ok := m[v]; ok { - return u, nil - } - } - return v, nil - }) - return i.(Ref) - } -} - -// NOTE(sr): This is duplicated with compile/compile.go; but moving it into another location -// would cause a circular dependency -- the refSet definition needs ast.Ref. If we make it -// public in the ast package, the compile package could take it from there, but it would also -// increase our public interface. Let's reconsider if we need it in a third place. -type refSet struct { - s []Ref -} - -func newRefSet(x ...Ref) *refSet { - result := &refSet{} - for i := range x { - result.AddPrefix(x[i]) - } - return result -} - -// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set. -func (rs *refSet) ContainsPrefix(r Ref) bool { - for i := range rs.s { - if r.HasPrefix(rs.s[i]) { - return true - } - } - return false -} - -// AddPrefix inserts r into the set if r is not prefixed by any existing -// refs in the set. If any existing refs are prefixed by r, those existing -// refs are removed. -func (rs *refSet) AddPrefix(r Ref) { - if rs.ContainsPrefix(r) { - return - } - cpy := []Ref{r} - for i := range rs.s { - if !rs.s[i].HasPrefix(r) { - cpy = append(cpy, rs.s[i]) - } - } - rs.s = cpy -} - -// Sorted returns a sorted slice of terms for refs in the set. -func (rs *refSet) Sorted() []*Term { - terms := make([]*Term, len(rs.s)) - for i := range rs.s { - terms[i] = NewTerm(rs.s[i]) - } - sort.Slice(terms, func(i, j int) bool { - return terms[i].Value.Compare(terms[j].Value) < 0 - }) - return terms +// OutputVarsFromExpr returns all variables which are the "output" for +// the given expression. For safety checks this means that they would be +// made safe by the expr. +func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet { + return v1.OutputVarsFromExpr(c, expr, safe) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/compilehelper.go b/vendor/github.com/open-policy-agent/opa/ast/compilehelper.go index dd48884f9..37ede329e 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/compilehelper.go +++ b/vendor/github.com/open-policy-agent/opa/ast/compilehelper.go @@ -4,41 +4,29 @@ package ast +import v1 "github.com/open-policy-agent/opa/v1/ast" + // CompileModules takes a set of Rego modules represented as strings and // compiles them for evaluation. The keys of the map are used as filenames. func CompileModules(modules map[string]string) (*Compiler, error) { - return CompileModulesWithOpt(modules, CompileOpts{}) + return CompileModulesWithOpt(modules, CompileOpts{ + ParserOptions: ParserOptions{ + RegoVersion: DefaultRegoVersion, + }, + }) } // CompileOpts defines a set of options for the compiler. -type CompileOpts struct { - EnablePrintStatements bool - ParserOptions ParserOptions -} +type CompileOpts = v1.CompileOpts // CompileModulesWithOpt takes a set of Rego modules represented as strings and // compiles them for evaluation. The keys of the map are used as filenames. func CompileModulesWithOpt(modules map[string]string, opts CompileOpts) (*Compiler, error) { - - parsed := make(map[string]*Module, len(modules)) - - for f, module := range modules { - var pm *Module - var err error - if pm, err = ParseModuleWithOpts(f, module, opts.ParserOptions); err != nil { - return nil, err - } - parsed[f] = pm - } - - compiler := NewCompiler().WithEnablePrintStatements(opts.EnablePrintStatements) - compiler.Compile(parsed) - - if compiler.Failed() { - return nil, compiler.Errors + if opts.ParserOptions.RegoVersion == RegoUndefined { + opts.ParserOptions.RegoVersion = DefaultRegoVersion } - return compiler, nil + return v1.CompileModulesWithOpt(modules, opts) } // MustCompileModules compiles a set of Rego modules represented as strings. If diff --git a/vendor/github.com/open-policy-agent/opa/ast/conflicts.go b/vendor/github.com/open-policy-agent/opa/ast/conflicts.go index c2713ad57..10edce382 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/conflicts.go +++ b/vendor/github.com/open-policy-agent/opa/ast/conflicts.go @@ -5,49 +5,11 @@ package ast import ( - "strings" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // CheckPathConflicts returns a set of errors indicating paths that // are in conflict with the result of the provided callable. func CheckPathConflicts(c *Compiler, exists func([]string) (bool, error)) Errors { - var errs Errors - - root := c.RuleTree.Child(DefaultRootDocument.Value) - if root == nil { - return nil - } - - for _, node := range root.Children { - errs = append(errs, checkDocumentConflicts(node, exists, nil)...) - } - - return errs -} - -func checkDocumentConflicts(node *TreeNode, exists func([]string) (bool, error), path []string) Errors { - - switch key := node.Key.(type) { - case String: - path = append(path, string(key)) - default: // other key types cannot conflict with data - return nil - } - - if len(node.Values) > 0 { - s := strings.Join(path, "/") - if ok, err := exists(path); err != nil { - return Errors{NewError(CompileErr, node.Values[0].(*Rule).Loc(), "conflict check for data path %v: %v", s, err.Error())} - } else if ok { - return Errors{NewError(CompileErr, node.Values[0].(*Rule).Loc(), "conflicting rule for data path %v found", s)} - } - } - - var errs Errors - - for _, child := range node.Children { - errs = append(errs, checkDocumentConflicts(child, exists, path)...) - } - - return errs + return v1.CheckPathConflicts(c, exists) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/doc.go b/vendor/github.com/open-policy-agent/opa/ast/doc.go index 62b04e301..ba974e5ba 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/doc.go +++ b/vendor/github.com/open-policy-agent/opa/ast/doc.go @@ -1,36 +1,8 @@ -// Copyright 2016 The OPA Authors. All rights reserved. +// Copyright 2024 The OPA Authors. All rights reserved. // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. -// Package ast declares Rego syntax tree types and also includes a parser and compiler for preparing policies for execution in the policy engine. -// -// Rego policies are defined using a relatively small set of types: modules, package and import declarations, rules, expressions, and terms. At their core, policies consist of rules that are defined by one or more expressions over documents available to the policy engine. The expressions are defined by intrinsic values (terms) such as strings, objects, variables, etc. -// -// Rego policies are typically defined in text files and then parsed and compiled by the policy engine at runtime. The parsing stage takes the text or string representation of the policy and converts it into an abstract syntax tree (AST) that consists of the types mentioned above. The AST is organized as follows: -// -// Module -// | -// +--- Package (Reference) -// | -// +--- Imports -// | | -// | +--- Import (Term) -// | -// +--- Rules -// | -// +--- Rule -// | -// +--- Head -// | | -// | +--- Name (Variable) -// | | -// | +--- Key (Term) -// | | -// | +--- Value (Term) -// | -// +--- Body -// | -// +--- Expression (Term | Terms | Variable Declaration) -// -// At query time, the policy engine expects policies to have been compiled. The compilation stage takes one or more modules and compiles them into a format that the policy engine supports. +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. package ast diff --git a/vendor/github.com/open-policy-agent/opa/ast/env.go b/vendor/github.com/open-policy-agent/opa/ast/env.go index c767aafef..ef0ccf89c 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/env.go +++ b/vendor/github.com/open-policy-agent/opa/ast/env.go @@ -5,522 +5,8 @@ package ast import ( - "fmt" - "strings" - - "github.com/open-policy-agent/opa/types" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // TypeEnv contains type info for static analysis such as type checking. -type TypeEnv struct { - tree *typeTreeNode - next *TypeEnv - newChecker func() *typeChecker -} - -// newTypeEnv returns an empty TypeEnv. The constructor is not exported because -// type environments should only be created by the type checker. -func newTypeEnv(f func() *typeChecker) *TypeEnv { - return &TypeEnv{ - tree: newTypeTree(), - newChecker: f, - } -} - -// Get returns the type of x. -func (env *TypeEnv) Get(x interface{}) types.Type { - - if term, ok := x.(*Term); ok { - x = term.Value - } - - switch x := x.(type) { - - // Scalars. - case Null: - return types.NewNull() - case Boolean: - return types.NewBoolean() - case Number: - return types.NewNumber() - case String: - return types.NewString() - - // Composites. - case *Array: - static := make([]types.Type, x.Len()) - for i := range static { - tpe := env.Get(x.Elem(i).Value) - static[i] = tpe - } - - var dynamic types.Type - if len(static) == 0 { - dynamic = types.A - } - - return types.NewArray(static, dynamic) - - case *lazyObj: - return env.Get(x.force()) - case *object: - static := []*types.StaticProperty{} - var dynamic *types.DynamicProperty - - x.Foreach(func(k, v *Term) { - if IsConstant(k.Value) { - kjson, err := JSON(k.Value) - if err == nil { - tpe := env.Get(v) - static = append(static, types.NewStaticProperty(kjson, tpe)) - return - } - } - // Can't handle it as a static property, fallback to dynamic - typeK := env.Get(k.Value) - typeV := env.Get(v.Value) - dynamic = types.NewDynamicProperty(typeK, typeV) - }) - - if len(static) == 0 && dynamic == nil { - dynamic = types.NewDynamicProperty(types.A, types.A) - } - - return types.NewObject(static, dynamic) - - case Set: - var tpe types.Type - x.Foreach(func(elem *Term) { - other := env.Get(elem.Value) - tpe = types.Or(tpe, other) - }) - if tpe == nil { - tpe = types.A - } - return types.NewSet(tpe) - - // Comprehensions. - case *ArrayComprehension: - cpy, errs := env.newChecker().CheckBody(env, x.Body) - if len(errs) == 0 { - return types.NewArray(nil, cpy.Get(x.Term)) - } - return nil - case *ObjectComprehension: - cpy, errs := env.newChecker().CheckBody(env, x.Body) - if len(errs) == 0 { - return types.NewObject(nil, types.NewDynamicProperty(cpy.Get(x.Key), cpy.Get(x.Value))) - } - return nil - case *SetComprehension: - cpy, errs := env.newChecker().CheckBody(env, x.Body) - if len(errs) == 0 { - return types.NewSet(cpy.Get(x.Term)) - } - return nil - - // Refs. - case Ref: - return env.getRef(x) - - // Vars. - case Var: - if node := env.tree.Child(x); node != nil { - return node.Value() - } - if env.next != nil { - return env.next.Get(x) - } - return nil - - // Calls. - case Call: - return nil - - default: - panic("unreachable") - } -} - -func (env *TypeEnv) getRef(ref Ref) types.Type { - - node := env.tree.Child(ref[0].Value) - if node == nil { - return env.getRefFallback(ref) - } - - return env.getRefRec(node, ref, ref[1:]) -} - -func (env *TypeEnv) getRefFallback(ref Ref) types.Type { - - if env.next != nil { - return env.next.Get(ref) - } - - if RootDocumentNames.Contains(ref[0]) { - return types.A - } - - return nil -} - -func (env *TypeEnv) getRefRec(node *typeTreeNode, ref, tail Ref) types.Type { - if len(tail) == 0 { - return env.getRefRecExtent(node) - } - - if node.Leaf() { - if node.children.Len() > 0 { - if child := node.Child(tail[0].Value); child != nil { - return env.getRefRec(child, ref, tail[1:]) - } - } - return selectRef(node.Value(), tail) - } - - if !IsConstant(tail[0].Value) { - return selectRef(env.getRefRecExtent(node), tail) - } - - child := node.Child(tail[0].Value) - if child == nil { - return env.getRefFallback(ref) - } - - return env.getRefRec(child, ref, tail[1:]) -} - -func (env *TypeEnv) getRefRecExtent(node *typeTreeNode) types.Type { - - if node.Leaf() { - return node.Value() - } - - children := []*types.StaticProperty{} - - node.Children().Iter(func(k, v util.T) bool { - key := k.(Value) - child := v.(*typeTreeNode) - - tpe := env.getRefRecExtent(child) - - // NOTE(sr): Converting to Golang-native types here is an extension of what we did - // before -- only supporting strings. But since we cannot differentiate sets and arrays - // that way, we could reconsider. - switch key.(type) { - case String, Number, Boolean: // skip anything else - propKey, err := JSON(key) - if err != nil { - panic(fmt.Errorf("unreachable, ValueToInterface: %w", err)) - } - children = append(children, types.NewStaticProperty(propKey, tpe)) - } - return false - }) - - // TODO(tsandall): for now, these objects can have any dynamic properties - // because we don't have schema for base docs. Once schemas are supported - // we can improve this. - return types.NewObject(children, types.NewDynamicProperty(types.S, types.A)) -} - -func (env *TypeEnv) wrap() *TypeEnv { - cpy := *env - cpy.next = env - cpy.tree = newTypeTree() - return &cpy -} - -// typeTreeNode is used to store type information in a tree. -type typeTreeNode struct { - key Value - value types.Type - children *util.HashMap -} - -func newTypeTree() *typeTreeNode { - return &typeTreeNode{ - key: nil, - value: nil, - children: util.NewHashMap(valueEq, valueHash), - } -} - -func (n *typeTreeNode) Child(key Value) *typeTreeNode { - value, ok := n.children.Get(key) - if !ok { - return nil - } - return value.(*typeTreeNode) -} - -func (n *typeTreeNode) Children() *util.HashMap { - return n.children -} - -func (n *typeTreeNode) Get(path Ref) types.Type { - curr := n - for _, term := range path { - child, ok := curr.children.Get(term.Value) - if !ok { - return nil - } - curr = child.(*typeTreeNode) - } - return curr.Value() -} - -func (n *typeTreeNode) Leaf() bool { - return n.value != nil -} - -func (n *typeTreeNode) PutOne(key Value, tpe types.Type) { - c, ok := n.children.Get(key) - - var child *typeTreeNode - if !ok { - child = newTypeTree() - child.key = key - n.children.Put(key, child) - } else { - child = c.(*typeTreeNode) - } - - child.value = tpe -} - -func (n *typeTreeNode) Put(path Ref, tpe types.Type) { - curr := n - for _, term := range path { - c, ok := curr.children.Get(term.Value) - - var child *typeTreeNode - if !ok { - child = newTypeTree() - child.key = term.Value - curr.children.Put(child.key, child) - } else { - child = c.(*typeTreeNode) - } - - curr = child - } - curr.value = tpe -} - -// Insert inserts tpe at path in the tree, but also merges the value into any types.Object present along that path. -// If a types.Object is inserted, any leafs already present further down the tree are merged into the inserted object. -// path must be ground. -func (n *typeTreeNode) Insert(path Ref, tpe types.Type, env *TypeEnv) { - curr := n - for i, term := range path { - c, ok := curr.children.Get(term.Value) - - var child *typeTreeNode - if !ok { - child = newTypeTree() - child.key = term.Value - curr.children.Put(child.key, child) - } else { - child = c.(*typeTreeNode) - - if child.value != nil && i+1 < len(path) { - // If child has an object value, merge the new value into it. - if o, ok := child.value.(*types.Object); ok { - var err error - child.value, err = insertIntoObject(o, path[i+1:], tpe, env) - if err != nil { - panic(fmt.Errorf("unreachable, insertIntoObject: %w", err)) - } - } - } - } - - curr = child - } - - curr.value = mergeTypes(curr.value, tpe) - - if _, ok := tpe.(*types.Object); ok && curr.children.Len() > 0 { - // merge all leafs into the inserted object - leafs := curr.Leafs() - for p, t := range leafs { - var err error - curr.value, err = insertIntoObject(curr.value.(*types.Object), *p, t, env) - if err != nil { - panic(fmt.Errorf("unreachable, insertIntoObject: %w", err)) - } - } - } -} - -// mergeTypes merges the types of 'a' and 'b'. If both are sets, their 'of' types are joined with an types.Or. -// If both are objects, the key types of their dynamic properties are joined with types.Or:s, and their value types -// are recursively merged (using mergeTypes). -// If 'a' and 'b' are both objects, and at least one of them have static properties, they are joined -// with an types.Or, instead of being merged. -// If 'a' is an Any containing an Object, and 'b' is an Object (or vice versa); AND both objects have no -// static properties, they are merged. -// If 'a' and 'b' are different types, they are joined with an types.Or. -func mergeTypes(a, b types.Type) types.Type { - if a == nil { - return b - } - - if b == nil { - return a - } - - switch a := a.(type) { - case *types.Object: - if bObj, ok := b.(*types.Object); ok && len(a.StaticProperties()) == 0 && len(bObj.StaticProperties()) == 0 { - if len(a.StaticProperties()) > 0 || len(bObj.StaticProperties()) > 0 { - return types.Or(a, bObj) - } - - aDynProps := a.DynamicProperties() - bDynProps := bObj.DynamicProperties() - dynProps := types.NewDynamicProperty( - types.Or(aDynProps.Key, bDynProps.Key), - mergeTypes(aDynProps.Value, bDynProps.Value)) - return types.NewObject(nil, dynProps) - } else if bAny, ok := b.(types.Any); ok && len(a.StaticProperties()) == 0 { - // If a is an object type with no static components ... - for _, t := range bAny { - if tObj, ok := t.(*types.Object); ok && len(tObj.StaticProperties()) == 0 { - // ... and b is a types.Any containing an object with no static components, we merge them. - aDynProps := a.DynamicProperties() - tDynProps := tObj.DynamicProperties() - tDynProps.Key = types.Or(tDynProps.Key, aDynProps.Key) - tDynProps.Value = types.Or(tDynProps.Value, aDynProps.Value) - return bAny - } - } - } - case *types.Set: - if bSet, ok := b.(*types.Set); ok { - return types.NewSet(types.Or(a.Of(), bSet.Of())) - } - case types.Any: - if _, ok := b.(types.Any); !ok { - return mergeTypes(b, a) - } - } - - return types.Or(a, b) -} - -func (n *typeTreeNode) String() string { - b := strings.Builder{} - - if k := n.key; k != nil { - b.WriteString(k.String()) - } else { - b.WriteString("-") - } - - if v := n.value; v != nil { - b.WriteString(": ") - b.WriteString(v.String()) - } - - n.children.Iter(func(_, v util.T) bool { - if child, ok := v.(*typeTreeNode); ok { - b.WriteString("\n\t+ ") - s := child.String() - s = strings.ReplaceAll(s, "\n", "\n\t") - b.WriteString(s) - } - return false - }) - - return b.String() -} - -func insertIntoObject(o *types.Object, path Ref, tpe types.Type, env *TypeEnv) (*types.Object, error) { - if len(path) == 0 { - return o, nil - } - - key := env.Get(path[0].Value) - - if len(path) == 1 { - var dynamicProps *types.DynamicProperty - if dp := o.DynamicProperties(); dp != nil { - dynamicProps = types.NewDynamicProperty(types.Or(o.DynamicProperties().Key, key), types.Or(o.DynamicProperties().Value, tpe)) - } else { - dynamicProps = types.NewDynamicProperty(key, tpe) - } - return types.NewObject(o.StaticProperties(), dynamicProps), nil - } - - child, err := insertIntoObject(types.NewObject(nil, nil), path[1:], tpe, env) - if err != nil { - return nil, err - } - - var dynamicProps *types.DynamicProperty - if dp := o.DynamicProperties(); dp != nil { - dynamicProps = types.NewDynamicProperty(types.Or(o.DynamicProperties().Key, key), types.Or(o.DynamicProperties().Value, child)) - } else { - dynamicProps = types.NewDynamicProperty(key, child) - } - return types.NewObject(o.StaticProperties(), dynamicProps), nil -} - -func (n *typeTreeNode) Leafs() map[*Ref]types.Type { - leafs := map[*Ref]types.Type{} - n.children.Iter(func(_, v util.T) bool { - collectLeafs(v.(*typeTreeNode), nil, leafs) - return false - }) - return leafs -} - -func collectLeafs(n *typeTreeNode, path Ref, leafs map[*Ref]types.Type) { - nPath := append(path, NewTerm(n.key)) - if n.Leaf() { - leafs[&nPath] = n.Value() - return - } - n.children.Iter(func(_, v util.T) bool { - collectLeafs(v.(*typeTreeNode), nPath, leafs) - return false - }) -} - -func (n *typeTreeNode) Value() types.Type { - return n.value -} - -// selectConstant returns the attribute of the type referred to by the term. If -// the attribute type cannot be determined, nil is returned. -func selectConstant(tpe types.Type, term *Term) types.Type { - x, err := JSON(term.Value) - if err == nil { - return types.Select(tpe, x) - } - return nil -} - -// selectRef returns the type of the nested attribute referred to by ref. If -// the attribute type cannot be determined, nil is returned. If the ref -// contains vars or refs, then the returned type will be a union of the -// possible types. -func selectRef(tpe types.Type, ref Ref) types.Type { - - if tpe == nil || len(ref) == 0 { - return tpe - } - - head, tail := ref[0], ref[1:] - - switch head.Value.(type) { - case Var, Ref, *Array, Object, Set: - return selectRef(types.Values(tpe), tail) - default: - return selectRef(selectConstant(tpe, head), tail) - } -} +type TypeEnv = v1.TypeEnv diff --git a/vendor/github.com/open-policy-agent/opa/ast/errors.go b/vendor/github.com/open-policy-agent/opa/ast/errors.go index 066dfcdd6..0cb8ee28f 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/errors.go +++ b/vendor/github.com/open-policy-agent/opa/ast/errors.go @@ -5,119 +5,42 @@ package ast import ( - "fmt" - "sort" - "strings" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // Errors represents a series of errors encountered during parsing, compiling, // etc. -type Errors []*Error - -func (e Errors) Error() string { - - if len(e) == 0 { - return "no error(s)" - } - - if len(e) == 1 { - return fmt.Sprintf("1 error occurred: %v", e[0].Error()) - } - - s := make([]string, len(e)) - for i, err := range e { - s[i] = err.Error() - } - - return fmt.Sprintf("%d errors occurred:\n%s", len(e), strings.Join(s, "\n")) -} - -// Sort sorts the error slice by location. If the locations are equal then the -// error message is compared. -func (e Errors) Sort() { - sort.Slice(e, func(i, j int) bool { - a := e[i] - b := e[j] - - if cmp := a.Location.Compare(b.Location); cmp != 0 { - return cmp < 0 - } - - return a.Error() < b.Error() - }) -} +type Errors = v1.Errors const ( // ParseErr indicates an unclassified parse error occurred. - ParseErr = "rego_parse_error" + ParseErr = v1.ParseErr // CompileErr indicates an unclassified compile error occurred. - CompileErr = "rego_compile_error" + CompileErr = v1.CompileErr // TypeErr indicates a type error was caught. - TypeErr = "rego_type_error" + TypeErr = v1.TypeErr // UnsafeVarErr indicates an unsafe variable was found during compilation. - UnsafeVarErr = "rego_unsafe_var_error" + UnsafeVarErr = v1.UnsafeVarErr // RecursionErr indicates recursion was found during compilation. - RecursionErr = "rego_recursion_error" + RecursionErr = v1.RecursionErr ) // IsError returns true if err is an AST error with code. func IsError(code string, err error) bool { - if err, ok := err.(*Error); ok { - return err.Code == code - } - return false + return v1.IsError(code, err) } // ErrorDetails defines the interface for detailed error messages. -type ErrorDetails interface { - Lines() []string -} +type ErrorDetails = v1.ErrorDetails // Error represents a single error caught during parsing, compiling, etc. -type Error struct { - Code string `json:"code"` - Message string `json:"message"` - Location *Location `json:"location,omitempty"` - Details ErrorDetails `json:"details,omitempty"` -} - -func (e *Error) Error() string { - - var prefix string - - if e.Location != nil { - - if len(e.Location.File) > 0 { - prefix += e.Location.File + ":" + fmt.Sprint(e.Location.Row) - } else { - prefix += fmt.Sprint(e.Location.Row) + ":" + fmt.Sprint(e.Location.Col) - } - } - - msg := fmt.Sprintf("%v: %v", e.Code, e.Message) - - if len(prefix) > 0 { - msg = prefix + ": " + msg - } - - if e.Details != nil { - for _, line := range e.Details.Lines() { - msg += "\n\t" + line - } - } - - return msg -} +type Error = v1.Error // NewError returns a new Error object. func NewError(code string, loc *Location, f string, a ...interface{}) *Error { - return &Error{ - Code: code, - Location: loc, - Message: fmt.Sprintf(f, a...), - } + return v1.NewError(code, loc, f, a...) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/index.go b/vendor/github.com/open-policy-agent/opa/ast/index.go index cb0cbea32..7e80bb771 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/index.go +++ b/vendor/github.com/open-policy-agent/opa/ast/index.go @@ -5,904 +5,16 @@ package ast import ( - "fmt" - "sort" - "strings" - - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // RuleIndex defines the interface for rule indices. -type RuleIndex interface { - - // Build tries to construct an index for the given rules. If the index was - // constructed, it returns true, otherwise false. - Build(rules []*Rule) bool - - // Lookup searches the index for rules that will match the provided - // resolver. If the resolver returns an error, it is returned via err. - Lookup(resolver ValueResolver) (*IndexResult, error) - - // AllRules traverses the index and returns all rules that will match - // the provided resolver without any optimizations (effectively with - // indexing disabled). If the resolver returns an error, it is returned - // via err. - AllRules(resolver ValueResolver) (*IndexResult, error) -} +type RuleIndex v1.RuleIndex // IndexResult contains the result of an index lookup. -type IndexResult struct { - Kind RuleKind - Rules []*Rule - Else map[*Rule][]*Rule - Default *Rule - EarlyExit bool - OnlyGroundRefs bool -} +type IndexResult = v1.IndexResult // NewIndexResult returns a new IndexResult object. func NewIndexResult(kind RuleKind) *IndexResult { - return &IndexResult{ - Kind: kind, - Else: map[*Rule][]*Rule{}, - } -} - -// Empty returns true if there are no rules to evaluate. -func (ir *IndexResult) Empty() bool { - return len(ir.Rules) == 0 && ir.Default == nil -} - -type baseDocEqIndex struct { - skipIndexing Set - isVirtual func(Ref) bool - root *trieNode - defaultRule *Rule - kind RuleKind - onlyGroundRefs bool -} - -func newBaseDocEqIndex(isVirtual func(Ref) bool) *baseDocEqIndex { - return &baseDocEqIndex{ - skipIndexing: NewSet(NewTerm(InternalPrint.Ref())), - isVirtual: isVirtual, - root: newTrieNodeImpl(), - onlyGroundRefs: true, - } -} - -func (i *baseDocEqIndex) Build(rules []*Rule) bool { - if len(rules) == 0 { - return false - } - - i.kind = rules[0].Head.RuleKind() - indices := newrefindices(i.isVirtual) - - // build indices for each rule. - for idx := range rules { - WalkRules(rules[idx], func(rule *Rule) bool { - if rule.Default { - i.defaultRule = rule - return false - } - if i.onlyGroundRefs { - i.onlyGroundRefs = rule.Head.Reference.IsGround() - } - var skip bool - for _, expr := range rule.Body { - if op := expr.OperatorTerm(); op != nil && i.skipIndexing.Contains(op) { - skip = true - break - } - } - if !skip { - for _, expr := range rule.Body { - indices.Update(rule, expr) - } - } - return false - }) - } - - // build trie out of indices. - for idx := range rules { - var prio int - WalkRules(rules[idx], func(rule *Rule) bool { - if rule.Default { - return false - } - node := i.root - if indices.Indexed(rule) { - for _, ref := range indices.Sorted() { - node = node.Insert(ref, indices.Value(rule, ref), indices.Mapper(rule, ref)) - } - } - // Insert rule into trie with (insertion order, priority order) - // tuple. Retaining the insertion order allows us to return rules - // in the order they were passed to this function. - node.append([...]int{idx, prio}, rule) - prio++ - return false - }) - } - return true -} - -func (i *baseDocEqIndex) Lookup(resolver ValueResolver) (*IndexResult, error) { - - tr := newTrieTraversalResult() - - err := i.root.Traverse(resolver, tr) - if err != nil { - return nil, err - } - - result := NewIndexResult(i.kind) - result.Default = i.defaultRule - result.OnlyGroundRefs = i.onlyGroundRefs - result.Rules = make([]*Rule, 0, len(tr.ordering)) - - for _, pos := range tr.ordering { - sort.Slice(tr.unordered[pos], func(i, j int) bool { - return tr.unordered[pos][i].prio[1] < tr.unordered[pos][j].prio[1] - }) - nodes := tr.unordered[pos] - root := nodes[0].rule - - result.Rules = append(result.Rules, root) - if len(nodes) > 1 { - result.Else[root] = make([]*Rule, len(nodes)-1) - for i := 1; i < len(nodes); i++ { - result.Else[root][i-1] = nodes[i].rule - } - } - } - - result.EarlyExit = tr.values.Len() == 1 && tr.values.Slice()[0].IsGround() - - return result, nil -} - -func (i *baseDocEqIndex) AllRules(_ ValueResolver) (*IndexResult, error) { - tr := newTrieTraversalResult() - - // Walk over the rule trie and accumulate _all_ rules - rw := &ruleWalker{result: tr} - i.root.Do(rw) - - result := NewIndexResult(i.kind) - result.Default = i.defaultRule - result.OnlyGroundRefs = i.onlyGroundRefs - result.Rules = make([]*Rule, 0, len(tr.ordering)) - - for _, pos := range tr.ordering { - sort.Slice(tr.unordered[pos], func(i, j int) bool { - return tr.unordered[pos][i].prio[1] < tr.unordered[pos][j].prio[1] - }) - nodes := tr.unordered[pos] - root := nodes[0].rule - result.Rules = append(result.Rules, root) - if len(nodes) > 1 { - result.Else[root] = make([]*Rule, len(nodes)-1) - for i := 1; i < len(nodes); i++ { - result.Else[root][i-1] = nodes[i].rule - } - } - } - - result.EarlyExit = tr.values.Len() == 1 && tr.values.Slice()[0].IsGround() - - return result, nil -} - -type ruleWalker struct { - result *trieTraversalResult -} - -func (r *ruleWalker) Do(x interface{}) trieWalker { - tn := x.(*trieNode) - r.result.Add(tn) - return r -} - -type valueMapper struct { - Key string - MapValue func(Value) Value -} - -type refindex struct { - Ref Ref - Value Value - Mapper *valueMapper -} - -type refindices struct { - isVirtual func(Ref) bool - rules map[*Rule][]*refindex - frequency *util.HashMap - sorted []Ref -} - -func newrefindices(isVirtual func(Ref) bool) *refindices { - return &refindices{ - isVirtual: isVirtual, - rules: map[*Rule][]*refindex{}, - frequency: util.NewHashMap(func(a, b util.T) bool { - r1, r2 := a.(Ref), b.(Ref) - return r1.Equal(r2) - }, func(x util.T) int { - return x.(Ref).Hash() - }), - } -} - -// Update attempts to update the refindices for the given expression in the -// given rule. If the expression cannot be indexed the update does not affect -// the indices. -func (i *refindices) Update(rule *Rule, expr *Expr) { - - if expr.Negated { - return - } - - if len(expr.With) > 0 { - // NOTE(tsandall): In the future, we may need to consider expressions - // that have with statements applied to them. - return - } - - op := expr.Operator() - - switch { - case op.Equal(Equality.Ref()): - i.updateEq(rule, expr) - - case op.Equal(Equal.Ref()) && len(expr.Operands()) == 2: - // NOTE(tsandall): if equal() is called with more than two arguments the - // output value is being captured in which case the indexer cannot - // exclude the rule if the equal() call would return false (because the - // false value must still be produced.) - i.updateEq(rule, expr) - - case op.Equal(GlobMatch.Ref()) && len(expr.Operands()) == 3: - // NOTE(sr): Same as with equal() above -- 4 operands means the output - // of `glob.match` is captured and the rule can thus not be excluded. - i.updateGlobMatch(rule, expr) - } -} - -// Sorted returns a sorted list of references that the indices were built from. -// References that appear more frequently in the indexed rules are ordered -// before less frequently appearing references. -func (i *refindices) Sorted() []Ref { - - if i.sorted == nil { - counts := make([]int, 0, i.frequency.Len()) - i.sorted = make([]Ref, 0, i.frequency.Len()) - - i.frequency.Iter(func(k, v util.T) bool { - counts = append(counts, v.(int)) - i.sorted = append(i.sorted, k.(Ref)) - return false - }) - - sort.Slice(i.sorted, func(a, b int) bool { - if counts[a] > counts[b] { - return true - } else if counts[b] > counts[a] { - return false - } - return i.sorted[a][0].Loc().Compare(i.sorted[b][0].Loc()) < 0 - }) - } - - return i.sorted -} - -func (i *refindices) Indexed(rule *Rule) bool { - return len(i.rules[rule]) > 0 -} - -func (i *refindices) Value(rule *Rule, ref Ref) Value { - if index := i.index(rule, ref); index != nil { - return index.Value - } - return nil -} - -func (i *refindices) Mapper(rule *Rule, ref Ref) *valueMapper { - if index := i.index(rule, ref); index != nil { - return index.Mapper - } - return nil -} - -func (i *refindices) updateEq(rule *Rule, expr *Expr) { - a, b := expr.Operand(0), expr.Operand(1) - args := rule.Head.Args - if idx, ok := eqOperandsToRefAndValue(i.isVirtual, args, a, b); ok { - i.insert(rule, idx) - return - } - if idx, ok := eqOperandsToRefAndValue(i.isVirtual, args, b, a); ok { - i.insert(rule, idx) - return - } -} - -func (i *refindices) updateGlobMatch(rule *Rule, expr *Expr) { - args := rule.Head.Args - - delim, ok := globDelimiterToString(expr.Operand(1)) - if !ok { - return - } - - if arr := globPatternToArray(expr.Operand(0), delim); arr != nil { - // The 3rd operand of glob.match is the value to match. We assume the - // 3rd operand was a reference that has been rewritten and bound to a - // variable earlier in the query OR a function argument variable. - match := expr.Operand(2) - if _, ok := match.Value.(Var); ok { - var ref Ref - for _, other := range i.rules[rule] { - if _, ok := other.Value.(Var); ok && other.Value.Compare(match.Value) == 0 { - ref = other.Ref - } - } - if ref == nil { - for j, arg := range args { - if arg.Equal(match) { - ref = Ref{FunctionArgRootDocument, IntNumberTerm(j)} - } - } - } - if ref != nil { - i.insert(rule, &refindex{ - Ref: ref, - Value: arr.Value, - Mapper: &valueMapper{ - Key: delim, - MapValue: func(v Value) Value { - if s, ok := v.(String); ok { - return stringSliceToArray(splitStringEscaped(string(s), delim)) - } - return v - }, - }, - }) - } - } - } -} - -func (i *refindices) insert(rule *Rule, index *refindex) { - - count, ok := i.frequency.Get(index.Ref) - if !ok { - count = 0 - } - - i.frequency.Put(index.Ref, count.(int)+1) - - for pos, other := range i.rules[rule] { - if other.Ref.Equal(index.Ref) { - i.rules[rule][pos] = index - return - } - } - - i.rules[rule] = append(i.rules[rule], index) -} - -func (i *refindices) index(rule *Rule, ref Ref) *refindex { - for _, index := range i.rules[rule] { - if index.Ref.Equal(ref) { - return index - } - } - return nil -} - -type trieWalker interface { - Do(x interface{}) trieWalker -} - -type trieTraversalResult struct { - unordered map[int][]*ruleNode - ordering []int - values Set -} - -func newTrieTraversalResult() *trieTraversalResult { - return &trieTraversalResult{ - unordered: map[int][]*ruleNode{}, - values: NewSet(), - } -} - -func (tr *trieTraversalResult) Add(t *trieNode) { - for _, node := range t.rules { - root := node.prio[0] - nodes, ok := tr.unordered[root] - if !ok { - tr.ordering = append(tr.ordering, root) - } - tr.unordered[root] = append(nodes, node) - } - if t.values != nil { - t.values.Foreach(func(v *Term) { tr.values.Add(v) }) - } -} - -type trieNode struct { - ref Ref - values Set - mappers []*valueMapper - next *trieNode - any *trieNode - undefined *trieNode - scalars *util.HashMap - array *trieNode - rules []*ruleNode -} - -func (node *trieNode) String() string { - var flags []string - flags = append(flags, fmt.Sprintf("self:%p", node)) - if len(node.ref) > 0 { - flags = append(flags, node.ref.String()) - } - if node.next != nil { - flags = append(flags, fmt.Sprintf("next:%p", node.next)) - } - if node.any != nil { - flags = append(flags, fmt.Sprintf("any:%p", node.any)) - } - if node.undefined != nil { - flags = append(flags, fmt.Sprintf("undefined:%p", node.undefined)) - } - if node.array != nil { - flags = append(flags, fmt.Sprintf("array:%p", node.array)) - } - if node.scalars.Len() > 0 { - buf := make([]string, 0, node.scalars.Len()) - node.scalars.Iter(func(k, v util.T) bool { - key := k.(Value) - val := v.(*trieNode) - buf = append(buf, fmt.Sprintf("scalar(%v):%p", key, val)) - return false - }) - sort.Strings(buf) - flags = append(flags, strings.Join(buf, " ")) - } - if len(node.rules) > 0 { - flags = append(flags, fmt.Sprintf("%d rule(s)", len(node.rules))) - } - if len(node.mappers) > 0 { - flags = append(flags, fmt.Sprintf("%d mapper(s)", len(node.mappers))) - } - if node.values != nil { - if l := node.values.Len(); l > 0 { - flags = append(flags, fmt.Sprintf("%d value(s)", l)) - } - } - return strings.Join(flags, " ") -} - -func (node *trieNode) append(prio [2]int, rule *Rule) { - node.rules = append(node.rules, &ruleNode{prio, rule}) - - if node.values != nil && rule.Head.Value != nil { - node.values.Add(rule.Head.Value) - return - } - - if node.values == nil && rule.Head.DocKind() == CompleteDoc { - node.values = NewSet(rule.Head.Value) - } -} - -type ruleNode struct { - prio [2]int - rule *Rule -} - -func newTrieNodeImpl() *trieNode { - return &trieNode{ - scalars: util.NewHashMap(valueEq, valueHash), - } -} - -func (node *trieNode) Do(walker trieWalker) { - next := walker.Do(node) - if next == nil { - return - } - if node.any != nil { - node.any.Do(next) - } - if node.undefined != nil { - node.undefined.Do(next) - } - - node.scalars.Iter(func(_, v util.T) bool { - child := v.(*trieNode) - child.Do(next) - return false - }) - - if node.array != nil { - node.array.Do(next) - } - if node.next != nil { - node.next.Do(next) - } -} - -func (node *trieNode) Insert(ref Ref, value Value, mapper *valueMapper) *trieNode { - - if node.next == nil { - node.next = newTrieNodeImpl() - node.next.ref = ref - } - - if mapper != nil { - node.next.addMapper(mapper) - } - - return node.next.insertValue(value) -} - -func (node *trieNode) Traverse(resolver ValueResolver, tr *trieTraversalResult) error { - - if node == nil { - return nil - } - - tr.Add(node) - - return node.next.traverse(resolver, tr) -} - -func (node *trieNode) addMapper(mapper *valueMapper) { - for i := range node.mappers { - if node.mappers[i].Key == mapper.Key { - return - } - } - node.mappers = append(node.mappers, mapper) -} - -func (node *trieNode) insertValue(value Value) *trieNode { - - switch value := value.(type) { - case nil: - if node.undefined == nil { - node.undefined = newTrieNodeImpl() - } - return node.undefined - case Var: - if node.any == nil { - node.any = newTrieNodeImpl() - } - return node.any - case Null, Boolean, Number, String: - child, ok := node.scalars.Get(value) - if !ok { - child = newTrieNodeImpl() - node.scalars.Put(value, child) - } - return child.(*trieNode) - case *Array: - if node.array == nil { - node.array = newTrieNodeImpl() - } - return node.array.insertArray(value) - } - - panic("illegal value") -} - -func (node *trieNode) insertArray(arr *Array) *trieNode { - - if arr.Len() == 0 { - return node - } - - switch head := arr.Elem(0).Value.(type) { - case Var: - if node.any == nil { - node.any = newTrieNodeImpl() - } - return node.any.insertArray(arr.Slice(1, -1)) - case Null, Boolean, Number, String: - child, ok := node.scalars.Get(head) - if !ok { - child = newTrieNodeImpl() - node.scalars.Put(head, child) - } - return child.(*trieNode).insertArray(arr.Slice(1, -1)) - } - - panic("illegal value") -} - -func (node *trieNode) traverse(resolver ValueResolver, tr *trieTraversalResult) error { - - if node == nil { - return nil - } - - v, err := resolver.Resolve(node.ref) - if err != nil { - if IsUnknownValueErr(err) { - return node.traverseUnknown(resolver, tr) - } - return err - } - - if node.undefined != nil { - err = node.undefined.Traverse(resolver, tr) - if err != nil { - return err - } - } - - if v == nil { - return nil - } - - if node.any != nil { - err = node.any.Traverse(resolver, tr) - if err != nil { - return err - } - } - - if err := node.traverseValue(resolver, tr, v); err != nil { - return err - } - - for i := range node.mappers { - if err := node.traverseValue(resolver, tr, node.mappers[i].MapValue(v)); err != nil { - return err - } - } - - return nil -} - -func (node *trieNode) traverseValue(resolver ValueResolver, tr *trieTraversalResult, value Value) error { - - switch value := value.(type) { - case *Array: - if node.array == nil { - return nil - } - return node.array.traverseArray(resolver, tr, value) - - case Null, Boolean, Number, String: - child, ok := node.scalars.Get(value) - if !ok { - return nil - } - return child.(*trieNode).Traverse(resolver, tr) - } - - return nil -} - -func (node *trieNode) traverseArray(resolver ValueResolver, tr *trieTraversalResult, arr *Array) error { - - if arr.Len() == 0 { - return node.Traverse(resolver, tr) - } - - if node.any != nil { - err := node.any.traverseArray(resolver, tr, arr.Slice(1, -1)) - if err != nil { - return err - } - } - - head := arr.Elem(0).Value - - if !IsScalar(head) { - return nil - } - - child, ok := node.scalars.Get(head) - if !ok { - return nil - } - return child.(*trieNode).traverseArray(resolver, tr, arr.Slice(1, -1)) -} - -func (node *trieNode) traverseUnknown(resolver ValueResolver, tr *trieTraversalResult) error { - - if node == nil { - return nil - } - - if err := node.Traverse(resolver, tr); err != nil { - return err - } - - if err := node.undefined.traverseUnknown(resolver, tr); err != nil { - return err - } - - if err := node.any.traverseUnknown(resolver, tr); err != nil { - return err - } - - if err := node.array.traverseUnknown(resolver, tr); err != nil { - return err - } - - var iterErr error - node.scalars.Iter(func(_, v util.T) bool { - child := v.(*trieNode) - if iterErr = child.traverseUnknown(resolver, tr); iterErr != nil { - return true - } - return false - }) - - return iterErr -} - -// If term `a` is one of the function's operands, we store a Ref: `args[0]` -// for the argument number. So for `f(x, y) { x = 10; y = 12 }`, we'll -// bind `args[0]` and `args[1]` to this rule when called for (x=10) and -// (y=12) respectively. -func eqOperandsToRefAndValue(isVirtual func(Ref) bool, args []*Term, a, b *Term) (*refindex, bool) { - switch v := a.Value.(type) { - case Var: - for i, arg := range args { - if arg.Value.Compare(v) == 0 { - if bval, ok := indexValue(b); ok { - return &refindex{Ref: Ref{FunctionArgRootDocument, IntNumberTerm(i)}, Value: bval}, true - } - } - } - case Ref: - if !RootDocumentNames.Contains(v[0]) { - return nil, false - } - if isVirtual(v) { - return nil, false - } - if v.IsNested() || !v.IsGround() { - return nil, false - } - if bval, ok := indexValue(b); ok { - return &refindex{Ref: v, Value: bval}, true - } - } - return nil, false -} - -func indexValue(b *Term) (Value, bool) { - switch b := b.Value.(type) { - case Null, Boolean, Number, String, Var: - return b, true - case *Array: - stop := false - first := true - vis := NewGenericVisitor(func(x interface{}) bool { - if first { - first = false - return false - } - switch x.(type) { - // No nested structures or values that require evaluation (other than var). - case *Array, Object, Set, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Ref: - stop = true - } - return stop - }) - vis.Walk(b) - if !stop { - return b, true - } - } - - return nil, false -} - -func globDelimiterToString(delim *Term) (string, bool) { - - arr, ok := delim.Value.(*Array) - if !ok { - return "", false - } - - var result string - - if arr.Len() == 0 { - result = "." - } else { - for i := 0; i < arr.Len(); i++ { - term := arr.Elem(i) - s, ok := term.Value.(String) - if !ok { - return "", false - } - result += string(s) - } - } - - return result, true -} - -func globPatternToArray(pattern *Term, delim string) *Term { - - s, ok := pattern.Value.(String) - if !ok { - return nil - } - - parts := splitStringEscaped(string(s), delim) - arr := make([]*Term, len(parts)) - - for i := range parts { - if parts[i] == "*" { - arr[i] = VarTerm("$globwildcard") - } else { - var escaped bool - for _, c := range parts[i] { - if c == '\\' { - escaped = !escaped - continue - } - if !escaped { - switch c { - case '[', '?', '{', '*': - // TODO(tsandall): super glob and character pattern - // matching not supported yet. - return nil - } - } - escaped = false - } - arr[i] = StringTerm(parts[i]) - } - } - - return NewTerm(NewArray(arr...)) -} - -// splits s on characters in delim except if delim characters have been escaped -// with reverse solidus. -func splitStringEscaped(s string, delim string) []string { - - var last, curr int - var escaped bool - var result []string - - for ; curr < len(s); curr++ { - if s[curr] == '\\' || escaped { - escaped = !escaped - continue - } - if strings.ContainsRune(delim, rune(s[curr])) { - result = append(result, s[last:curr]) - last = curr + 1 - } - } - - result = append(result, s[last:]) - - return result -} - -func stringSliceToArray(s []string) *Array { - arr := make([]*Term, len(s)) - for i, v := range s { - arr[i] = StringTerm(v) - } - return NewArray(arr...) + return v1.NewIndexResult(kind) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/interning.go b/vendor/github.com/open-policy-agent/opa/ast/interning.go new file mode 100644 index 000000000..239293664 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/ast/interning.go @@ -0,0 +1,24 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + v1 "github.com/open-policy-agent/opa/v1/ast" +) + +func InternedBooleanTerm(b bool) *Term { + return v1.InternedBooleanTerm(b) +} + +// InternedIntNumberTerm returns a term with the given integer value. The term is +// cached between -1 to 512, and for values outside of that range, this function +// is equivalent to ast.IntNumberTerm. +func InternedIntNumberTerm(i int) *Term { + return v1.InternedIntNumberTerm(i) +} + +func HasInternedIntNumberTerm(i int) bool { + return v1.HasInternedIntNumberTerm(i) +} diff --git a/vendor/github.com/open-policy-agent/opa/ast/json/doc.go b/vendor/github.com/open-policy-agent/opa/ast/json/doc.go new file mode 100644 index 000000000..26aee9b99 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/ast/json/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package json diff --git a/vendor/github.com/open-policy-agent/opa/ast/json/json.go b/vendor/github.com/open-policy-agent/opa/ast/json/json.go index 565017d58..8a3a36bb9 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/json/json.go +++ b/vendor/github.com/open-policy-agent/opa/ast/json/json.go @@ -1,36 +1,15 @@ package json +import v1 "github.com/open-policy-agent/opa/v1/ast/json" + // Options defines the options for JSON operations, // currently only marshaling can be configured -type Options struct { - MarshalOptions MarshalOptions -} +type Options = v1.Options // MarshalOptions defines the options for JSON marshaling, // currently only toggling the marshaling of location information is supported -type MarshalOptions struct { - // IncludeLocation toggles the marshaling of location information - IncludeLocation NodeToggle - // IncludeLocationText additionally/optionally includes the text of the location - IncludeLocationText bool - // ExcludeLocationFile additionally/optionally excludes the file of the location - // Note that this is inverted (i.e. not "include" as the default needs to remain false) - ExcludeLocationFile bool -} +type MarshalOptions = v1.MarshalOptions // NodeToggle is a generic struct to allow the toggling of // settings for different ast node types -type NodeToggle struct { - Term bool - Package bool - Comment bool - Import bool - Rule bool - Head bool - Expr bool - SomeDecl bool - Every bool - With bool - Annotations bool - AnnotationsRef bool -} +type NodeToggle = v1.NodeToggle diff --git a/vendor/github.com/open-policy-agent/opa/ast/map.go b/vendor/github.com/open-policy-agent/opa/ast/map.go index b0cc9eb60..070ad3e5d 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/map.go +++ b/vendor/github.com/open-policy-agent/opa/ast/map.go @@ -5,129 +5,14 @@ package ast import ( - "encoding/json" - - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // ValueMap represents a key/value map between AST term values. Any type of term // can be used as a key in the map. -type ValueMap struct { - hashMap *util.HashMap -} +type ValueMap = v1.ValueMap // NewValueMap returns a new ValueMap. func NewValueMap() *ValueMap { - vs := &ValueMap{ - hashMap: util.NewHashMap(valueEq, valueHash), - } - return vs -} - -// MarshalJSON provides a custom marshaller for the ValueMap which -// will include the key, value, and value type. -func (vs *ValueMap) MarshalJSON() ([]byte, error) { - var tmp []map[string]interface{} - vs.Iter(func(k Value, v Value) bool { - tmp = append(tmp, map[string]interface{}{ - "name": k.String(), - "type": TypeName(v), - "value": v, - }) - return false - }) - return json.Marshal(tmp) -} - -// Copy returns a shallow copy of the ValueMap. -func (vs *ValueMap) Copy() *ValueMap { - if vs == nil { - return nil - } - cpy := NewValueMap() - cpy.hashMap = vs.hashMap.Copy() - return cpy -} - -// Equal returns true if this ValueMap equals the other. -func (vs *ValueMap) Equal(other *ValueMap) bool { - if vs == nil { - return other == nil || other.Len() == 0 - } - if other == nil { - return vs == nil || vs.Len() == 0 - } - return vs.hashMap.Equal(other.hashMap) -} - -// Len returns the number of elements in the map. -func (vs *ValueMap) Len() int { - if vs == nil { - return 0 - } - return vs.hashMap.Len() -} - -// Get returns the value in the map for k. -func (vs *ValueMap) Get(k Value) Value { - if vs != nil { - if v, ok := vs.hashMap.Get(k); ok { - return v.(Value) - } - } - return nil -} - -// Hash returns a hash code for this ValueMap. -func (vs *ValueMap) Hash() int { - if vs == nil { - return 0 - } - return vs.hashMap.Hash() -} - -// Iter calls the iter function for each key/value pair in the map. If the iter -// function returns true, iteration stops. -func (vs *ValueMap) Iter(iter func(Value, Value) bool) bool { - if vs == nil { - return false - } - return vs.hashMap.Iter(func(kt, vt util.T) bool { - k := kt.(Value) - v := vt.(Value) - return iter(k, v) - }) -} - -// Put inserts a key k into the map with value v. -func (vs *ValueMap) Put(k, v Value) { - if vs == nil { - panic("put on nil value map") - } - vs.hashMap.Put(k, v) -} - -// Delete removes a key k from the map. -func (vs *ValueMap) Delete(k Value) { - if vs == nil { - return - } - vs.hashMap.Delete(k) -} - -func (vs *ValueMap) String() string { - if vs == nil { - return "{}" - } - return vs.hashMap.String() -} - -func valueHash(v util.T) int { - return v.(Value).Hash() -} - -func valueEq(a, b util.T) bool { - av := a.(Value) - bv := b.(Value) - return av.Compare(bv) == 0 + return v1.NewValueMap() } diff --git a/vendor/github.com/open-policy-agent/opa/ast/parser.go b/vendor/github.com/open-policy-agent/opa/ast/parser.go index 09ede2bae..45cd4da06 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/parser.go +++ b/vendor/github.com/open-policy-agent/opa/ast/parser.go @@ -1,2733 +1,49 @@ -// Copyright 2020 The OPA Authors. All rights reserved. +// Copyright 2024 The OPA Authors. All rights reserved. // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. package ast import ( - "bytes" - "encoding/json" - "fmt" - "io" - "math/big" - "net/url" - "regexp" - "sort" - "strconv" - "strings" - "unicode/utf8" - - "gopkg.in/yaml.v3" - - "github.com/open-policy-agent/opa/ast/internal/scanner" - "github.com/open-policy-agent/opa/ast/internal/tokens" - astJSON "github.com/open-policy-agent/opa/ast/json" - "github.com/open-policy-agent/opa/ast/location" + v1 "github.com/open-policy-agent/opa/v1/ast" ) -var RegoV1CompatibleRef = Ref{VarTerm("rego"), StringTerm("v1")} +var RegoV1CompatibleRef = v1.RegoV1CompatibleRef // RegoVersion defines the Rego syntax requirements for a module. -type RegoVersion int +type RegoVersion = v1.RegoVersion -const DefaultRegoVersion = RegoVersion(0) +const DefaultRegoVersion = RegoV0 const ( + RegoUndefined = v1.RegoUndefined // RegoV0 is the default, original Rego syntax. - RegoV0 RegoVersion = iota + RegoV0 = v1.RegoV0 // RegoV0CompatV1 requires modules to comply with both the RegoV0 and RegoV1 syntax (as when 'rego.v1' is imported in a module). // Shortly, RegoV1 compatibility is required, but 'rego.v1' or 'future.keywords' must also be imported. - RegoV0CompatV1 + RegoV0CompatV1 = v1.RegoV0CompatV1 // RegoV1 is the Rego syntax enforced by OPA 1.0; e.g.: // future.keywords part of default keyword set, and don't require imports; // 'if' and 'contains' required in rule heads; // (some) strict checks on by default. - RegoV1 + RegoV1 = v1.RegoV1 ) -func (v RegoVersion) Int() int { - if v == RegoV1 { - return 1 - } - return 0 -} - -func (v RegoVersion) String() string { - switch v { - case RegoV0: - return "v0" - case RegoV1: - return "v1" - case RegoV0CompatV1: - return "v0v1" - default: - return "unknown" - } -} - func RegoVersionFromInt(i int) RegoVersion { - if i == 1 { - return RegoV1 - } - return RegoV0 -} - -// Note: This state is kept isolated from the parser so that we -// can do efficient shallow copies of these values when doing a -// save() and restore(). -type state struct { - s *scanner.Scanner - lastEnd int - skippedNL bool - tok tokens.Token - tokEnd int - lit string - loc Location - errors Errors - hints []string - comments []*Comment - wildcard int -} - -func (s *state) String() string { - return fmt.Sprintf("", s.s, s.tok, s.lit, s.loc, len(s.errors), len(s.comments)) -} - -func (s *state) Loc() *location.Location { - cpy := s.loc - return &cpy -} - -func (s *state) Text(offset, end int) []byte { - bs := s.s.Bytes() - if offset >= 0 && offset < len(bs) { - if end >= offset && end <= len(bs) { - return bs[offset:end] - } - } - return nil + return v1.RegoVersionFromInt(i) } // Parser is used to parse Rego statements. -type Parser struct { - r io.Reader - s *state - po ParserOptions - cache parsedTermCache -} - -type parsedTermCacheItem struct { - t *Term - post *state // post is the post-state that's restored on a cache-hit - offset int - next *parsedTermCacheItem -} - -type parsedTermCache struct { - m *parsedTermCacheItem -} - -func (c parsedTermCache) String() string { - s := strings.Builder{} - s.WriteRune('{') - var e *parsedTermCacheItem - for e = c.m; e != nil; e = e.next { - s.WriteString(fmt.Sprintf("%v", e)) - } - s.WriteRune('}') - return s.String() -} - -func (e *parsedTermCacheItem) String() string { - return fmt.Sprintf("<%d:%v>", e.offset, e.t) -} +type Parser = v1.Parser // ParserOptions defines the options for parsing Rego statements. -type ParserOptions struct { - Capabilities *Capabilities - ProcessAnnotation bool - AllFutureKeywords bool - FutureKeywords []string - SkipRules bool - JSONOptions *astJSON.Options - // RegoVersion is the version of Rego to parse for. - RegoVersion RegoVersion - unreleasedKeywords bool // TODO(sr): cleanup -} - -// EffectiveRegoVersion returns the effective RegoVersion to use for parsing. -// Deprecated: Use RegoVersion instead. -func (po *ParserOptions) EffectiveRegoVersion() RegoVersion { - return po.RegoVersion -} +type ParserOptions = v1.ParserOptions // NewParser creates and initializes a Parser. func NewParser() *Parser { - p := &Parser{ - s: &state{}, - po: ParserOptions{}, - } - return p -} - -// WithFilename provides the filename for Location details -// on parsed statements. -func (p *Parser) WithFilename(filename string) *Parser { - p.s.loc.File = filename - return p -} - -// WithReader provides the io.Reader that the parser will -// use as its source. -func (p *Parser) WithReader(r io.Reader) *Parser { - p.r = r - return p -} - -// WithProcessAnnotation enables or disables the processing of -// annotations by the Parser -func (p *Parser) WithProcessAnnotation(processAnnotation bool) *Parser { - p.po.ProcessAnnotation = processAnnotation - return p -} - -// WithFutureKeywords enables "future" keywords, i.e., keywords that can -// be imported via -// -// import future.keywords.kw -// import future.keywords.other -// -// but in a more direct way. The equivalent of this import would be -// -// WithFutureKeywords("kw", "other") -func (p *Parser) WithFutureKeywords(kws ...string) *Parser { - p.po.FutureKeywords = kws - return p -} - -// WithAllFutureKeywords enables all "future" keywords, i.e., the -// ParserOption equivalent of -// -// import future.keywords -func (p *Parser) WithAllFutureKeywords(yes bool) *Parser { - p.po.AllFutureKeywords = yes - return p -} - -// withUnreleasedKeywords allows using keywords that haven't surfaced -// as future keywords (see above) yet, but have tests that require -// them to be parsed -func (p *Parser) withUnreleasedKeywords(yes bool) *Parser { - p.po.unreleasedKeywords = yes - return p -} - -// WithCapabilities sets the capabilities structure on the parser. -func (p *Parser) WithCapabilities(c *Capabilities) *Parser { - p.po.Capabilities = c - return p -} - -// WithSkipRules instructs the parser not to attempt to parse Rule statements. -func (p *Parser) WithSkipRules(skip bool) *Parser { - p.po.SkipRules = skip - return p -} - -// WithJSONOptions sets the Options which will be set on nodes to configure -// their JSON marshaling behavior. -func (p *Parser) WithJSONOptions(jsonOptions *astJSON.Options) *Parser { - p.po.JSONOptions = jsonOptions - return p -} - -func (p *Parser) WithRegoVersion(version RegoVersion) *Parser { - p.po.RegoVersion = version - return p -} - -func (p *Parser) parsedTermCacheLookup() (*Term, *state) { - l := p.s.loc.Offset - // stop comparing once the cached offsets are lower than l - for h := p.cache.m; h != nil && h.offset >= l; h = h.next { - if h.offset == l { - return h.t, h.post - } - } - return nil, nil -} - -func (p *Parser) parsedTermCachePush(t *Term, s0 *state) { - s1 := p.save() - o0 := s0.loc.Offset - entry := parsedTermCacheItem{t: t, post: s1, offset: o0} - - // find the first one whose offset is smaller than ours - var e *parsedTermCacheItem - for e = p.cache.m; e != nil; e = e.next { - if e.offset < o0 { - break - } - } - entry.next = e - p.cache.m = &entry -} - -// futureParser returns a shallow copy of `p` with an empty -// cache, and a scanner that knows all future keywords. -// It's used to present hints in errors, when statements would -// only parse successfully if some future keyword is enabled. -func (p *Parser) futureParser() *Parser { - q := *p - q.s = p.save() - q.s.s = p.s.s.WithKeywords(futureKeywords) - q.cache = parsedTermCache{} - return &q -} - -// presentParser returns a shallow copy of `p` with an empty -// cache, and a scanner that knows none of the future keywords. -// It is used to successfully parse keyword imports, like -// -// import future.keywords.in -// -// even when the parser has already been informed about the -// future keyword "in". This parser won't error out because -// "in" is an identifier. -func (p *Parser) presentParser() (*Parser, map[string]tokens.Token) { - var cpy map[string]tokens.Token - q := *p - q.s = p.save() - q.s.s, cpy = p.s.s.WithoutKeywords(futureKeywords) - q.cache = parsedTermCache{} - return &q, cpy -} - -// Parse will read the Rego source and parse statements and -// comments as they are found. Any errors encountered while -// parsing will be accumulated and returned as a list of Errors. -func (p *Parser) Parse() ([]Statement, []*Comment, Errors) { - - if p.po.Capabilities == nil { - p.po.Capabilities = CapabilitiesForThisVersion() - } - - allowedFutureKeywords := map[string]tokens.Token{} - - if p.po.RegoVersion == RegoV1 { - // RegoV1 includes all future keywords in the default language definition - for k, v := range futureKeywords { - allowedFutureKeywords[k] = v - } - - // For sake of error reporting, we still need to check that keywords in capabilities are known, - for _, kw := range p.po.Capabilities.FutureKeywords { - if _, ok := futureKeywords[kw]; !ok { - return nil, nil, Errors{ - &Error{ - Code: ParseErr, - Message: fmt.Sprintf("illegal capabilities: unknown keyword: %v", kw), - Location: nil, - }, - } - } - } - // and that explicitly requested future keywords are known. - for _, kw := range p.po.FutureKeywords { - if _, ok := allowedFutureKeywords[kw]; !ok { - return nil, nil, Errors{ - &Error{ - Code: ParseErr, - Message: fmt.Sprintf("unknown future keyword: %v", kw), - Location: nil, - }, - } - } - } - } else { - for _, kw := range p.po.Capabilities.FutureKeywords { - var ok bool - allowedFutureKeywords[kw], ok = futureKeywords[kw] - if !ok { - return nil, nil, Errors{ - &Error{ - Code: ParseErr, - Message: fmt.Sprintf("illegal capabilities: unknown keyword: %v", kw), - Location: nil, - }, - } - } - } - } - - var err error - p.s.s, err = scanner.New(p.r) - if err != nil { - return nil, nil, Errors{ - &Error{ - Code: ParseErr, - Message: err.Error(), - Location: nil, - }, - } - } - - selected := map[string]tokens.Token{} - if p.po.AllFutureKeywords || p.po.RegoVersion == RegoV1 { - for kw, tok := range allowedFutureKeywords { - selected[kw] = tok - } - } else { - for _, kw := range p.po.FutureKeywords { - tok, ok := allowedFutureKeywords[kw] - if !ok { - return nil, nil, Errors{ - &Error{ - Code: ParseErr, - Message: fmt.Sprintf("unknown future keyword: %v", kw), - Location: nil, - }, - } - } - selected[kw] = tok - } - } - p.s.s = p.s.s.WithKeywords(selected) - - if p.po.RegoVersion == RegoV1 { - for kw, tok := range allowedFutureKeywords { - p.s.s.AddKeyword(kw, tok) - } - } - - // read the first token to initialize the parser - p.scan() - - var stmts []Statement - - // Read from the scanner until the last token is reached or no statements - // can be parsed. Attempt to parse package statements, import statements, - // rule statements, and then body/query statements (in that order). If a - // statement cannot be parsed, restore the parser state before trying the - // next type of statement. If a statement can be parsed, continue from that - // point trying to parse packages, imports, etc. in the same order. - for p.s.tok != tokens.EOF { - - s := p.save() - - if pkg := p.parsePackage(); pkg != nil { - stmts = append(stmts, pkg) - continue - } else if len(p.s.errors) > 0 { - break - } - - p.restore(s) - s = p.save() - - if imp := p.parseImport(); imp != nil { - if RegoRootDocument.Equal(imp.Path.Value.(Ref)[0]) { - p.regoV1Import(imp) - } - - if FutureRootDocument.Equal(imp.Path.Value.(Ref)[0]) { - p.futureImport(imp, allowedFutureKeywords) - } - - stmts = append(stmts, imp) - continue - } else if len(p.s.errors) > 0 { - break - } - - p.restore(s) - - if !p.po.SkipRules { - s = p.save() - - if rules := p.parseRules(); rules != nil { - for i := range rules { - stmts = append(stmts, rules[i]) - } - continue - } else if len(p.s.errors) > 0 { - break - } - - p.restore(s) - } - - if body := p.parseQuery(true, tokens.EOF); body != nil { - stmts = append(stmts, body) - continue - } - - break - } - - if p.po.ProcessAnnotation { - stmts = p.parseAnnotations(stmts) - } - - if p.po.JSONOptions != nil { - for i := range stmts { - vis := NewGenericVisitor(func(x interface{}) bool { - if x, ok := x.(customJSON); ok { - x.setJSONOptions(*p.po.JSONOptions) - } - return false - }) - - vis.Walk(stmts[i]) - } - } - - return stmts, p.s.comments, p.s.errors -} - -func (p *Parser) parseAnnotations(stmts []Statement) []Statement { - - annotStmts, errs := parseAnnotations(p.s.comments) - for _, err := range errs { - p.error(err.Location, err.Message) - } - - for _, annotStmt := range annotStmts { - stmts = append(stmts, annotStmt) - } - - return stmts -} - -func parseAnnotations(comments []*Comment) ([]*Annotations, Errors) { - - var hint = []byte("METADATA") - var curr *metadataParser - var blocks []*metadataParser - - for i := 0; i < len(comments); i++ { - if curr != nil { - if comments[i].Location.Row == comments[i-1].Location.Row+1 && comments[i].Location.Col == 1 { - curr.Append(comments[i]) - continue - } - curr = nil - } - if bytes.HasPrefix(bytes.TrimSpace(comments[i].Text), hint) { - curr = newMetadataParser(comments[i].Location) - blocks = append(blocks, curr) - } - } - - var stmts []*Annotations - var errs Errors - for _, b := range blocks { - a, err := b.Parse() - if err != nil { - errs = append(errs, &Error{ - Code: ParseErr, - Message: err.Error(), - Location: b.loc, - }) - } else { - stmts = append(stmts, a) - } - } - - return stmts, errs -} - -func (p *Parser) parsePackage() *Package { - - var pkg Package - pkg.SetLoc(p.s.Loc()) - - if p.s.tok != tokens.Package { - return nil - } - - p.scan() - if p.s.tok != tokens.Ident { - p.illegalToken() - return nil - } - - term := p.parseTerm() - - if term != nil { - switch v := term.Value.(type) { - case Var: - pkg.Path = Ref{ - DefaultRootDocument.Copy().SetLocation(term.Location), - StringTerm(string(v)).SetLocation(term.Location), - } - case Ref: - pkg.Path = make(Ref, len(v)+1) - pkg.Path[0] = DefaultRootDocument.Copy().SetLocation(v[0].Location) - first, ok := v[0].Value.(Var) - if !ok { - p.errorf(v[0].Location, "unexpected %v token: expecting var", TypeName(v[0].Value)) - return nil - } - pkg.Path[1] = StringTerm(string(first)).SetLocation(v[0].Location) - for i := 2; i < len(pkg.Path); i++ { - switch v[i-1].Value.(type) { - case String: - pkg.Path[i] = v[i-1] - default: - p.errorf(v[i-1].Location, "unexpected %v token: expecting string", TypeName(v[i-1].Value)) - return nil - } - } - default: - p.illegalToken() - return nil - } - } - - if pkg.Path == nil { - if len(p.s.errors) == 0 { - p.error(p.s.Loc(), "expected path") - } - return nil - } - - return &pkg -} - -func (p *Parser) parseImport() *Import { - - var imp Import - imp.SetLoc(p.s.Loc()) - - if p.s.tok != tokens.Import { - return nil - } - - p.scan() - if p.s.tok != tokens.Ident { - p.error(p.s.Loc(), "expected ident") - return nil - } - q, prev := p.presentParser() - term := q.parseTerm() - if term != nil { - switch v := term.Value.(type) { - case Var: - imp.Path = RefTerm(term).SetLocation(term.Location) - case Ref: - for i := 1; i < len(v); i++ { - if _, ok := v[i].Value.(String); !ok { - p.errorf(v[i].Location, "unexpected %v token: expecting string", TypeName(v[i].Value)) - return nil - } - } - imp.Path = term - } - } - // keep advanced parser state, reset known keywords - p.s = q.s - p.s.s = q.s.s.WithKeywords(prev) - - if imp.Path == nil { - p.error(p.s.Loc(), "expected path") - return nil - } - - path := imp.Path.Value.(Ref) - - switch { - case RootDocumentNames.Contains(path[0]): - case FutureRootDocument.Equal(path[0]): - case RegoRootDocument.Equal(path[0]): - default: - p.hint("if this is unexpected, try updating OPA") - p.errorf(imp.Path.Location, "unexpected import path, must begin with one of: %v, got: %v", - RootDocumentNames.Union(NewSet(FutureRootDocument, RegoRootDocument)), - path[0]) - return nil - } - - if p.s.tok == tokens.As { - p.scan() - - if p.s.tok != tokens.Ident { - p.illegal("expected var") - return nil - } - - if alias := p.parseTerm(); alias != nil { - v, ok := alias.Value.(Var) - if ok { - imp.Alias = v - return &imp - } - } - p.illegal("expected var") - return nil - } - - return &imp -} - -func (p *Parser) parseRules() []*Rule { - - var rule Rule - rule.SetLoc(p.s.Loc()) - - if p.s.tok == tokens.Default { - p.scan() - rule.Default = true - } - - if p.s.tok != tokens.Ident { - return nil - } - - usesContains := false - if rule.Head, usesContains = p.parseHead(rule.Default); rule.Head == nil { - return nil - } - - if usesContains { - rule.Head.keywords = append(rule.Head.keywords, tokens.Contains) - } - - if rule.Default { - if !p.validateDefaultRuleValue(&rule) { - return nil - } - - if len(rule.Head.Args) > 0 { - if !p.validateDefaultRuleArgs(&rule) { - return nil - } - } - - rule.Body = NewBody(NewExpr(BooleanTerm(true).SetLocation(rule.Location)).SetLocation(rule.Location)) - return []*Rule{&rule} - } - - // back-compat with `p[x] { ... }`` - hasIf := p.s.tok == tokens.If - - // p[x] if ... becomes a single-value rule p[x] - if hasIf && !usesContains && len(rule.Head.Ref()) == 2 { - if !rule.Head.Ref()[1].IsGround() && len(rule.Head.Args) == 0 { - rule.Head.Key = rule.Head.Ref()[1] - } - - if rule.Head.Value == nil { - rule.Head.generatedValue = true - rule.Head.Value = BooleanTerm(true).SetLocation(rule.Head.Location) - } else { - // p[x] = y if becomes a single-value rule p[x] with value y, but needs name for compat - v, ok := rule.Head.Ref()[0].Value.(Var) - if !ok { - return nil - } - rule.Head.Name = v - } - } - - // p[x] becomes a multi-value rule p - if !hasIf && !usesContains && - len(rule.Head.Args) == 0 && // not a function - len(rule.Head.Ref()) == 2 { // ref like 'p[x]' - v, ok := rule.Head.Ref()[0].Value.(Var) - if !ok { - return nil - } - rule.Head.Name = v - rule.Head.Key = rule.Head.Ref()[1] - if rule.Head.Value == nil { - rule.Head.SetRef(rule.Head.Ref()[:len(rule.Head.Ref())-1]) - } - } - - switch { - case hasIf: - rule.Head.keywords = append(rule.Head.keywords, tokens.If) - p.scan() - s := p.save() - if expr := p.parseLiteral(); expr != nil { - // NOTE(sr): set literals are never false or undefined, so parsing this as - // p if { true } - // ^^^^^^^^ set of one element, `true` - // isn't valid. - isSetLiteral := false - if t, ok := expr.Terms.(*Term); ok { - _, isSetLiteral = t.Value.(Set) - } - // expr.Term is []*Term or Every - if !isSetLiteral { - rule.Body.Append(expr) - break - } - } - - // parsing as literal didn't work out, expect '{ BODY }' - p.restore(s) - fallthrough - - case p.s.tok == tokens.LBrace: - p.scan() - if rule.Body = p.parseBody(tokens.RBrace); rule.Body == nil { - return nil - } - p.scan() - - case usesContains: - rule.Body = NewBody(NewExpr(BooleanTerm(true).SetLocation(rule.Location)).SetLocation(rule.Location)) - rule.generatedBody = true - rule.Location = rule.Head.Location - - return []*Rule{&rule} - - default: - return nil - } - - if p.s.tok == tokens.Else { - if r := rule.Head.Ref(); len(r) > 1 && !r.IsGround() { - p.error(p.s.Loc(), "else keyword cannot be used on rules with variables in head") - return nil - } - if rule.Head.Key != nil { - p.error(p.s.Loc(), "else keyword cannot be used on multi-value rules") - return nil - } - - if rule.Else = p.parseElse(rule.Head); rule.Else == nil { - return nil - } - } - - rule.Location.Text = p.s.Text(rule.Location.Offset, p.s.lastEnd) - - rules := []*Rule{&rule} - - for p.s.tok == tokens.LBrace { - - if rule.Else != nil { - p.error(p.s.Loc(), "expected else keyword") - return nil - } - - loc := p.s.Loc() - - p.scan() - var next Rule - - if next.Body = p.parseBody(tokens.RBrace); next.Body == nil { - return nil - } - p.scan() - - loc.Text = p.s.Text(loc.Offset, p.s.lastEnd) - next.SetLoc(loc) - - // Chained rule head's keep the original - // rule's head AST but have their location - // set to the rule body. - next.Head = rule.Head.Copy() - next.Head.keywords = rule.Head.keywords - for i := range next.Head.Args { - if v, ok := next.Head.Args[i].Value.(Var); ok && v.IsWildcard() { - next.Head.Args[i].Value = Var(p.genwildcard()) - } - } - setLocRecursive(next.Head, loc) - - rules = append(rules, &next) - } - - return rules -} - -func (p *Parser) parseElse(head *Head) *Rule { - - var rule Rule - rule.SetLoc(p.s.Loc()) - - rule.Head = head.Copy() - rule.Head.generatedValue = false - for i := range rule.Head.Args { - if v, ok := rule.Head.Args[i].Value.(Var); ok && v.IsWildcard() { - rule.Head.Args[i].Value = Var(p.genwildcard()) - } - } - rule.Head.SetLoc(p.s.Loc()) - - defer func() { - rule.Location.Text = p.s.Text(rule.Location.Offset, p.s.lastEnd) - }() - - p.scan() - - switch p.s.tok { - case tokens.LBrace, tokens.If: // no value, but a body follows directly - rule.Head.generatedValue = true - rule.Head.Value = BooleanTerm(true) - case tokens.Assign, tokens.Unify: - rule.Head.Assign = tokens.Assign == p.s.tok - p.scan() - rule.Head.Value = p.parseTermInfixCall() - if rule.Head.Value == nil { - return nil - } - rule.Head.Location.Text = p.s.Text(rule.Head.Location.Offset, p.s.lastEnd) - default: - p.illegal("expected else value term or rule body") - return nil - } - - hasIf := p.s.tok == tokens.If - hasLBrace := p.s.tok == tokens.LBrace - - if !hasIf && !hasLBrace { - rule.Body = NewBody(NewExpr(BooleanTerm(true))) - rule.generatedBody = true - setLocRecursive(rule.Body, rule.Location) - return &rule - } - - if hasIf { - rule.Head.keywords = append(rule.Head.keywords, tokens.If) - p.scan() - } - - if p.s.tok == tokens.LBrace { - p.scan() - if rule.Body = p.parseBody(tokens.RBrace); rule.Body == nil { - return nil - } - p.scan() - } else if p.s.tok != tokens.EOF { - expr := p.parseLiteral() - if expr == nil { - return nil - } - rule.Body.Append(expr) - setLocRecursive(rule.Body, rule.Location) - } else { - p.illegal("rule body expected") - return nil - } - - if p.s.tok == tokens.Else { - if rule.Else = p.parseElse(head); rule.Else == nil { - return nil - } - } - return &rule -} - -func (p *Parser) parseHead(defaultRule bool) (*Head, bool) { - head := &Head{} - loc := p.s.Loc() - defer func() { - if head != nil { - head.SetLoc(loc) - head.Location.Text = p.s.Text(head.Location.Offset, p.s.lastEnd) - } - }() - - term := p.parseVar() - if term == nil { - return nil, false - } - - ref := p.parseTermFinish(term, true) - if ref == nil { - p.illegal("expected rule head name") - return nil, false - } - - switch x := ref.Value.(type) { - case Var: - // Modify the code to add the location to the head ref - // and set the head ref's jsonOptions. - head = VarHead(x, ref.Location, p.po.JSONOptions) - case Ref: - head = RefHead(x) - case Call: - op, args := x[0], x[1:] - var ref Ref - switch y := op.Value.(type) { - case Var: - ref = Ref{op} - case Ref: - if _, ok := y[0].Value.(Var); !ok { - p.illegal("rule head ref %v invalid", y) - return nil, false - } - ref = y - } - head = RefHead(ref) - head.Args = append([]*Term{}, args...) - - default: - return nil, false - } - - name := head.Ref().String() - - switch p.s.tok { - case tokens.Contains: // NOTE: no Value for `contains` heads, we return here - // Catch error case of using 'contains' with a function definition rule head. - if head.Args != nil { - p.illegal("the contains keyword can only be used with multi-value rule definitions (e.g., %s contains { ... })", name) - } - p.scan() - head.Key = p.parseTermInfixCall() - if head.Key == nil { - p.illegal("expected rule key term (e.g., %s contains { ... })", name) - } - return head, true - - case tokens.Unify: - p.scan() - head.Value = p.parseTermInfixCall() - if head.Value == nil { - // FIX HEAD.String() - p.illegal("expected rule value term (e.g., %s[%s] = { ... })", name, head.Key) - } - case tokens.Assign: - p.scan() - head.Assign = true - head.Value = p.parseTermInfixCall() - if head.Value == nil { - switch { - case len(head.Args) > 0: - p.illegal("expected function value term (e.g., %s(...) := { ... })", name) - case head.Key != nil: - p.illegal("expected partial rule value term (e.g., %s[...] := { ... })", name) - case defaultRule: - p.illegal("expected default rule value term (e.g., default %s := )", name) - default: - p.illegal("expected rule value term (e.g., %s := { ... })", name) - } - } - } - - if head.Value == nil && head.Key == nil { - if len(head.Ref()) != 2 || len(head.Args) > 0 { - head.generatedValue = true - head.Value = BooleanTerm(true).SetLocation(head.Location) - } - } - return head, false -} - -func (p *Parser) parseBody(end tokens.Token) Body { - return p.parseQuery(false, end) -} - -func (p *Parser) parseQuery(requireSemi bool, end tokens.Token) Body { - body := Body{} - - if p.s.tok == end { - p.error(p.s.Loc(), "found empty body") - return nil - } - - for { - expr := p.parseLiteral() - if expr == nil { - return nil - } - - body.Append(expr) - - if p.s.tok == tokens.Semicolon { - p.scan() - continue - } - - if p.s.tok == end || requireSemi { - return body - } - - if !p.s.skippedNL { - // If there was already an error then don't pile this one on - if len(p.s.errors) == 0 { - p.illegal(`expected \n or %s or %s`, tokens.Semicolon, end) - } - return nil - } - } -} - -func (p *Parser) parseLiteral() (expr *Expr) { - - offset := p.s.loc.Offset - loc := p.s.Loc() - - defer func() { - if expr != nil { - loc.Text = p.s.Text(offset, p.s.lastEnd) - expr.SetLoc(loc) - } - }() - - var negated bool - if p.s.tok == tokens.Not { - p.scan() - negated = true - } - - switch p.s.tok { - case tokens.Some: - if negated { - p.illegal("illegal negation of 'some'") - return nil - } - return p.parseSome() - case tokens.Every: - if negated { - p.illegal("illegal negation of 'every'") - return nil - } - return p.parseEvery() - default: - s := p.save() - expr := p.parseExpr() - if expr != nil { - expr.Negated = negated - if p.s.tok == tokens.With { - if expr.With = p.parseWith(); expr.With == nil { - return nil - } - } - // If we find a plain `every` identifier, attempt to parse an every expression, - // add hint if it succeeds. - if term, ok := expr.Terms.(*Term); ok && Var("every").Equal(term.Value) { - var hint bool - t := p.save() - p.restore(s) - if expr := p.futureParser().parseEvery(); expr != nil { - _, hint = expr.Terms.(*Every) - } - p.restore(t) - if hint { - p.hint("`import future.keywords.every` for `every x in xs { ... }` expressions") - } - } - return expr - } - return nil - } -} - -func (p *Parser) parseWith() []*With { - - withs := []*With{} - - for { - - with := With{ - Location: p.s.Loc(), - } - p.scan() - - if p.s.tok != tokens.Ident { - p.illegal("expected ident") - return nil - } - - with.Target = p.parseTerm() - if with.Target == nil { - return nil - } - - switch with.Target.Value.(type) { - case Ref, Var: - break - default: - p.illegal("expected with target path") - } - - if p.s.tok != tokens.As { - p.illegal("expected as keyword") - return nil - } - - p.scan() - - if with.Value = p.parseTermInfixCall(); with.Value == nil { - return nil - } - - with.Location.Text = p.s.Text(with.Location.Offset, p.s.lastEnd) - - withs = append(withs, &with) - - if p.s.tok != tokens.With { - break - } - } - - return withs -} - -func (p *Parser) parseSome() *Expr { - - decl := &SomeDecl{} - decl.SetLoc(p.s.Loc()) - - // Attempt to parse "some x in xs", which will end up in - // SomeDecl{Symbols: ["member(x, xs)"]} - s := p.save() - p.scan() - if term := p.parseTermInfixCall(); term != nil { - if call, ok := term.Value.(Call); ok { - switch call[0].String() { - case Member.Name: - if len(call) != 3 { - p.illegal("illegal domain") - return nil - } - case MemberWithKey.Name: - if len(call) != 4 { - p.illegal("illegal domain") - return nil - } - default: - p.illegal("expected `x in xs` or `x, y in xs` expression") - return nil - } - - decl.Symbols = []*Term{term} - expr := NewExpr(decl).SetLocation(decl.Location) - if p.s.tok == tokens.With { - if expr.With = p.parseWith(); expr.With == nil { - return nil - } - } - return expr - } - } - - p.restore(s) - s = p.save() // new copy for later - var hint bool - p.scan() - if term := p.futureParser().parseTermInfixCall(); term != nil { - if call, ok := term.Value.(Call); ok { - switch call[0].String() { - case Member.Name, MemberWithKey.Name: - hint = true - } - } - } - - // go on as before, it's `some x[...]` or illegal - p.restore(s) - if hint { - p.hint("`import future.keywords.in` for `some x in xs` expressions") - } - - for { // collecting var args - - p.scan() - - if p.s.tok != tokens.Ident { - p.illegal("expected var") - return nil - } - - decl.Symbols = append(decl.Symbols, p.parseVar()) - - p.scan() - - if p.s.tok != tokens.Comma { - break - } - } - - return NewExpr(decl).SetLocation(decl.Location) -} - -func (p *Parser) parseEvery() *Expr { - qb := &Every{} - qb.SetLoc(p.s.Loc()) - - // TODO(sr): We'd get more accurate error messages if we didn't rely on - // parseTermInfixCall here, but parsed "var [, var] in term" manually. - p.scan() - term := p.parseTermInfixCall() - if term == nil { - return nil - } - call, ok := term.Value.(Call) - if !ok { - p.illegal("expected `x[, y] in xs { ... }` expression") - return nil - } - switch call[0].String() { - case Member.Name: // x in xs - if len(call) != 3 { - p.illegal("illegal domain") - return nil - } - qb.Value = call[1] - qb.Domain = call[2] - case MemberWithKey.Name: // k, v in xs - if len(call) != 4 { - p.illegal("illegal domain") - return nil - } - qb.Key = call[1] - qb.Value = call[2] - qb.Domain = call[3] - if _, ok := qb.Key.Value.(Var); !ok { - p.illegal("expected key to be a variable") - return nil - } - default: - p.illegal("expected `x[, y] in xs { ... }` expression") - return nil - } - if _, ok := qb.Value.Value.(Var); !ok { - p.illegal("expected value to be a variable") - return nil - } - if p.s.tok == tokens.LBrace { // every x in xs { ... } - p.scan() - body := p.parseBody(tokens.RBrace) - if body == nil { - return nil - } - p.scan() - qb.Body = body - expr := NewExpr(qb).SetLocation(qb.Location) - - if p.s.tok == tokens.With { - if expr.With = p.parseWith(); expr.With == nil { - return nil - } - } - return expr - } - - p.illegal("missing body") - return nil -} - -func (p *Parser) parseExpr() *Expr { - - lhs := p.parseTermInfixCall() - if lhs == nil { - return nil - } - - if op := p.parseTermOp(tokens.Assign, tokens.Unify); op != nil { - if rhs := p.parseTermInfixCall(); rhs != nil { - return NewExpr([]*Term{op, lhs, rhs}) - } - return nil - } - - // NOTE(tsandall): the top-level call term is converted to an expr because - // the evaluator does not support the call term type (nested calls are - // rewritten by the compiler.) - if call, ok := lhs.Value.(Call); ok { - return NewExpr([]*Term(call)) - } - - return NewExpr(lhs) -} - -// parseTermInfixCall consumes the next term from the input and returns it. If a -// term cannot be parsed the return value is nil and error will be recorded. The -// scanner will be advanced to the next token before returning. -// By starting out with infix relations (==, !=, <, etc) and further calling the -// other binary operators (|, &, arithmetics), it constitutes the binding -// precedence. -func (p *Parser) parseTermInfixCall() *Term { - return p.parseTermIn(nil, true, p.s.loc.Offset) -} - -func (p *Parser) parseTermInfixCallInList() *Term { - return p.parseTermIn(nil, false, p.s.loc.Offset) -} - -func (p *Parser) parseTermIn(lhs *Term, keyVal bool, offset int) *Term { - // NOTE(sr): `in` is a bit special: besides `lhs in rhs`, it also - // supports `key, val in rhs`, so it can have an optional second lhs. - // `keyVal` triggers if we attempt to parse a second lhs argument (`mhs`). - if lhs == nil { - lhs = p.parseTermRelation(nil, offset) - } - if lhs != nil { - if keyVal && p.s.tok == tokens.Comma { // second "lhs", or "middle hand side" - s := p.save() - p.scan() - if mhs := p.parseTermRelation(nil, offset); mhs != nil { - if op := p.parseTermOpName(MemberWithKey.Ref(), tokens.In); op != nil { - if rhs := p.parseTermRelation(nil, p.s.loc.Offset); rhs != nil { - call := p.setLoc(CallTerm(op, lhs, mhs, rhs), lhs.Location, offset, p.s.lastEnd) - switch p.s.tok { - case tokens.In: - return p.parseTermIn(call, keyVal, offset) - default: - return call - } - } - } - } - p.restore(s) - } - if op := p.parseTermOpName(Member.Ref(), tokens.In); op != nil { - if rhs := p.parseTermRelation(nil, p.s.loc.Offset); rhs != nil { - call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) - switch p.s.tok { - case tokens.In: - return p.parseTermIn(call, keyVal, offset) - default: - return call - } - } - } - } - return lhs -} - -func (p *Parser) parseTermRelation(lhs *Term, offset int) *Term { - if lhs == nil { - lhs = p.parseTermOr(nil, offset) - } - if lhs != nil { - if op := p.parseTermOp(tokens.Equal, tokens.Neq, tokens.Lt, tokens.Gt, tokens.Lte, tokens.Gte); op != nil { - if rhs := p.parseTermOr(nil, p.s.loc.Offset); rhs != nil { - call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) - switch p.s.tok { - case tokens.Equal, tokens.Neq, tokens.Lt, tokens.Gt, tokens.Lte, tokens.Gte: - return p.parseTermRelation(call, offset) - default: - return call - } - } - } - } - return lhs -} - -func (p *Parser) parseTermOr(lhs *Term, offset int) *Term { - if lhs == nil { - lhs = p.parseTermAnd(nil, offset) - } - if lhs != nil { - if op := p.parseTermOp(tokens.Or); op != nil { - if rhs := p.parseTermAnd(nil, p.s.loc.Offset); rhs != nil { - call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) - switch p.s.tok { - case tokens.Or: - return p.parseTermOr(call, offset) - default: - return call - } - } - } - return lhs - } - return nil -} - -func (p *Parser) parseTermAnd(lhs *Term, offset int) *Term { - if lhs == nil { - lhs = p.parseTermArith(nil, offset) - } - if lhs != nil { - if op := p.parseTermOp(tokens.And); op != nil { - if rhs := p.parseTermArith(nil, p.s.loc.Offset); rhs != nil { - call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) - switch p.s.tok { - case tokens.And: - return p.parseTermAnd(call, offset) - default: - return call - } - } - } - return lhs - } - return nil -} - -func (p *Parser) parseTermArith(lhs *Term, offset int) *Term { - if lhs == nil { - lhs = p.parseTermFactor(nil, offset) - } - if lhs != nil { - if op := p.parseTermOp(tokens.Add, tokens.Sub); op != nil { - if rhs := p.parseTermFactor(nil, p.s.loc.Offset); rhs != nil { - call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) - switch p.s.tok { - case tokens.Add, tokens.Sub: - return p.parseTermArith(call, offset) - default: - return call - } - } - } - } - return lhs -} - -func (p *Parser) parseTermFactor(lhs *Term, offset int) *Term { - if lhs == nil { - lhs = p.parseTerm() - } - if lhs != nil { - if op := p.parseTermOp(tokens.Mul, tokens.Quo, tokens.Rem); op != nil { - if rhs := p.parseTerm(); rhs != nil { - call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) - switch p.s.tok { - case tokens.Mul, tokens.Quo, tokens.Rem: - return p.parseTermFactor(call, offset) - default: - return call - } - } - } - } - return lhs -} - -func (p *Parser) parseTerm() *Term { - if term, s := p.parsedTermCacheLookup(); s != nil { - p.restore(s) - return term - } - s0 := p.save() - - var term *Term - switch p.s.tok { - case tokens.Null: - term = NullTerm().SetLocation(p.s.Loc()) - case tokens.True: - term = BooleanTerm(true).SetLocation(p.s.Loc()) - case tokens.False: - term = BooleanTerm(false).SetLocation(p.s.Loc()) - case tokens.Sub, tokens.Dot, tokens.Number: - term = p.parseNumber() - case tokens.String: - term = p.parseString() - case tokens.Ident, tokens.Contains: // NOTE(sr): contains anywhere BUT in rule heads gets no special treatment - term = p.parseVar() - case tokens.LBrack: - term = p.parseArray() - case tokens.LBrace: - term = p.parseSetOrObject() - case tokens.LParen: - offset := p.s.loc.Offset - p.scan() - if r := p.parseTermInfixCall(); r != nil { - if p.s.tok == tokens.RParen { - r.Location.Text = p.s.Text(offset, p.s.tokEnd) - term = r - } else { - p.error(p.s.Loc(), "non-terminated expression") - } - } - default: - p.illegalToken() - } - - term = p.parseTermFinish(term, false) - p.parsedTermCachePush(term, s0) - return term -} - -func (p *Parser) parseTermFinish(head *Term, skipws bool) *Term { - if head == nil { - return nil - } - offset := p.s.loc.Offset - p.doScan(skipws) - - switch p.s.tok { - case tokens.LParen, tokens.Dot, tokens.LBrack: - return p.parseRef(head, offset) - case tokens.Whitespace: - p.scan() - fallthrough - default: - if _, ok := head.Value.(Var); ok && RootDocumentNames.Contains(head) { - return RefTerm(head).SetLocation(head.Location) - } - return head - } -} - -func (p *Parser) parseNumber() *Term { - var prefix string - loc := p.s.Loc() - if p.s.tok == tokens.Sub { - prefix = "-" - p.scan() - switch p.s.tok { - case tokens.Number, tokens.Dot: - break - default: - p.illegal("expected number") - return nil - } - } - if p.s.tok == tokens.Dot { - prefix += "." - p.scan() - if p.s.tok != tokens.Number { - p.illegal("expected number") - return nil - } - } - - // Check for multiple leading 0's, parsed by math/big.Float.Parse as decimal 0: - // https://golang.org/pkg/math/big/#Float.Parse - if ((len(prefix) != 0 && prefix[0] == '-') || len(prefix) == 0) && - len(p.s.lit) > 1 && p.s.lit[0] == '0' && p.s.lit[1] == '0' { - p.illegal("expected number") - return nil - } - - // Ensure that the number is valid - s := prefix + p.s.lit - f, ok := new(big.Float).SetString(s) - if !ok { - p.illegal("invalid float") - return nil - } - - // Put limit on size of exponent to prevent non-linear cost of String() - // function on big.Float from causing denial of service: https://github.com/golang/go/issues/11068 - // - // n == sign * mantissa * 2^exp - // 0.5 <= mantissa < 1.0 - // - // The limit is arbitrary. - exp := f.MantExp(nil) - if exp > 1e5 || exp < -1e5 || f.IsInf() { // +/- inf, exp is 0 - p.error(p.s.Loc(), "number too big") - return nil - } - - // Note: Use the original string, do *not* round trip from - // the big.Float as it can cause precision loss. - r := NumberTerm(json.Number(s)).SetLocation(loc) - return r -} - -func (p *Parser) parseString() *Term { - if p.s.lit[0] == '"' { - var s string - err := json.Unmarshal([]byte(p.s.lit), &s) - if err != nil { - p.errorf(p.s.Loc(), "illegal string literal: %s", p.s.lit) - return nil - } - term := StringTerm(s).SetLocation(p.s.Loc()) - return term - } - return p.parseRawString() -} - -func (p *Parser) parseRawString() *Term { - if len(p.s.lit) < 2 { - return nil - } - term := StringTerm(p.s.lit[1 : len(p.s.lit)-1]).SetLocation(p.s.Loc()) - return term -} - -// this is the name to use for instantiating an empty set, e.g., `set()`. -var setConstructor = RefTerm(VarTerm("set")) - -func (p *Parser) parseCall(operator *Term, offset int) (term *Term) { - - loc := operator.Location - var end int - - defer func() { - p.setLoc(term, loc, offset, end) - }() - - p.scan() // steps over '(' - - if p.s.tok == tokens.RParen { // no args, i.e. set() or any.func() - end = p.s.tokEnd - p.scanWS() - if operator.Equal(setConstructor) { - return SetTerm() - } - return CallTerm(operator) - } - - if r := p.parseTermList(tokens.RParen, []*Term{operator}); r != nil { - end = p.s.tokEnd - p.scanWS() - return CallTerm(r...) - } - - return nil -} - -func (p *Parser) parseRef(head *Term, offset int) (term *Term) { - - loc := head.Location - var end int - - defer func() { - p.setLoc(term, loc, offset, end) - }() - - switch h := head.Value.(type) { - case Var, *Array, Object, Set, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Call: - // ok - default: - p.errorf(loc, "illegal ref (head cannot be %v)", TypeName(h)) - } - - ref := []*Term{head} - - for { - switch p.s.tok { - case tokens.Dot: - p.scanWS() - if p.s.tok != tokens.Ident { - p.illegal("expected %v", tokens.Ident) - return nil - } - ref = append(ref, StringTerm(p.s.lit).SetLocation(p.s.Loc())) - p.scanWS() - case tokens.LParen: - term = p.parseCall(p.setLoc(RefTerm(ref...), loc, offset, p.s.loc.Offset), offset) - if term != nil { - switch p.s.tok { - case tokens.Whitespace: - p.scan() - end = p.s.lastEnd - return term - case tokens.Dot, tokens.LBrack: - term = p.parseRef(term, offset) - } - } - end = p.s.tokEnd - return term - case tokens.LBrack: - p.scan() - if term := p.parseTermInfixCall(); term != nil { - if p.s.tok != tokens.RBrack { - p.illegal("expected %v", tokens.LBrack) - return nil - } - ref = append(ref, term) - p.scanWS() - } else { - return nil - } - case tokens.Whitespace: - end = p.s.lastEnd - p.scan() - return RefTerm(ref...) - default: - end = p.s.lastEnd - return RefTerm(ref...) - } - } -} - -func (p *Parser) parseArray() (term *Term) { - - loc := p.s.Loc() - offset := p.s.loc.Offset - - defer func() { - p.setLoc(term, loc, offset, p.s.tokEnd) - }() - - p.scan() - - if p.s.tok == tokens.RBrack { - return ArrayTerm() - } - - potentialComprehension := true - - // Skip leading commas, eg [, x, y] - // Supported for backwards compatibility. In the future - // we should make this a parse error. - if p.s.tok == tokens.Comma { - potentialComprehension = false - p.scan() - } - - s := p.save() - - // NOTE(tsandall): The parser cannot attempt a relational term here because - // of ambiguity around comprehensions. For example, given: - // - // {1 | 1} - // - // Does this represent a set comprehension or a set containing binary OR - // call? We resolve the ambiguity by prioritizing comprehensions. - head := p.parseTerm() - - if head == nil { - return nil - } - - switch p.s.tok { - case tokens.RBrack: - return ArrayTerm(head) - case tokens.Comma: - p.scan() - if terms := p.parseTermList(tokens.RBrack, []*Term{head}); terms != nil { - return NewTerm(NewArray(terms...)) - } - return nil - case tokens.Or: - if potentialComprehension { - // Try to parse as if it is an array comprehension - p.scan() - if body := p.parseBody(tokens.RBrack); body != nil { - return ArrayComprehensionTerm(head, body) - } - if p.s.tok != tokens.Comma { - return nil - } - } - // fall back to parsing as a normal array definition - } - - p.restore(s) - - if terms := p.parseTermList(tokens.RBrack, nil); terms != nil { - return NewTerm(NewArray(terms...)) - } - return nil -} - -func (p *Parser) parseSetOrObject() (term *Term) { - loc := p.s.Loc() - offset := p.s.loc.Offset - - defer func() { - p.setLoc(term, loc, offset, p.s.tokEnd) - }() - - p.scan() - - if p.s.tok == tokens.RBrace { - return ObjectTerm() - } - - potentialComprehension := true - - // Skip leading commas, eg {, x, y} - // Supported for backwards compatibility. In the future - // we should make this a parse error. - if p.s.tok == tokens.Comma { - potentialComprehension = false - p.scan() - } - - s := p.save() - - // Try parsing just a single term first to give comprehensions higher - // priority to "or" calls in ambiguous situations. Eg: { a | b } - // will be a set comprehension. - // - // Note: We don't know yet if it is a set or object being defined. - head := p.parseTerm() - if head == nil { - return nil - } - - switch p.s.tok { - case tokens.Or: - if potentialComprehension { - return p.parseSet(s, head, potentialComprehension) - } - case tokens.RBrace, tokens.Comma: - return p.parseSet(s, head, potentialComprehension) - case tokens.Colon: - return p.parseObject(head, potentialComprehension) - } - - p.restore(s) - - head = p.parseTermInfixCallInList() - if head == nil { - return nil - } - - switch p.s.tok { - case tokens.RBrace, tokens.Comma: - return p.parseSet(s, head, false) - case tokens.Colon: - // It still might be an object comprehension, eg { a+1: b | ... } - return p.parseObject(head, potentialComprehension) - } - - p.illegal("non-terminated set") - return nil -} - -func (p *Parser) parseSet(s *state, head *Term, potentialComprehension bool) *Term { - switch p.s.tok { - case tokens.RBrace: - return SetTerm(head) - case tokens.Comma: - p.scan() - if terms := p.parseTermList(tokens.RBrace, []*Term{head}); terms != nil { - return SetTerm(terms...) - } - case tokens.Or: - if potentialComprehension { - // Try to parse as if it is a set comprehension - p.scan() - if body := p.parseBody(tokens.RBrace); body != nil { - return SetComprehensionTerm(head, body) - } - if p.s.tok != tokens.Comma { - return nil - } - } - // Fall back to parsing as normal set definition - p.restore(s) - if terms := p.parseTermList(tokens.RBrace, nil); terms != nil { - return SetTerm(terms...) - } - } - return nil -} - -func (p *Parser) parseObject(k *Term, potentialComprehension bool) *Term { - // NOTE(tsandall): Assumption: this function is called after parsing the key - // of the head element and then receiving a colon token from the scanner. - // Advance beyond the colon and attempt to parse an object. - if p.s.tok != tokens.Colon { - panic("expected colon") - } - p.scan() - - s := p.save() - - // NOTE(sr): We first try to parse the value as a term (`v`), and see - // if we can parse `{ x: v | ...}` as a comprehension. - // However, if we encounter either a Comma or an RBace, it cannot be - // parsed as a comprehension -- so we save double work further down - // where `parseObjectFinish(k, v, false)` would only exercise the - // same code paths once more. - v := p.parseTerm() - if v == nil { - return nil - } - - potentialRelation := true - if potentialComprehension { - switch p.s.tok { - case tokens.RBrace, tokens.Comma: - potentialRelation = false - fallthrough - case tokens.Or: - if term := p.parseObjectFinish(k, v, true); term != nil { - return term - } - } - } - - p.restore(s) - - if potentialRelation { - v := p.parseTermInfixCallInList() - if v == nil { - return nil - } - - switch p.s.tok { - case tokens.RBrace, tokens.Comma: - return p.parseObjectFinish(k, v, false) - } - } - - p.illegal("non-terminated object") - return nil -} - -func (p *Parser) parseObjectFinish(key, val *Term, potentialComprehension bool) *Term { - switch p.s.tok { - case tokens.RBrace: - return ObjectTerm([2]*Term{key, val}) - case tokens.Or: - if potentialComprehension { - p.scan() - if body := p.parseBody(tokens.RBrace); body != nil { - return ObjectComprehensionTerm(key, val, body) - } - } else { - p.illegal("non-terminated object") - } - case tokens.Comma: - p.scan() - if r := p.parseTermPairList(tokens.RBrace, [][2]*Term{{key, val}}); r != nil { - return ObjectTerm(r...) - } - } - return nil -} - -func (p *Parser) parseTermList(end tokens.Token, r []*Term) []*Term { - if p.s.tok == end { - return r - } - for { - term := p.parseTermInfixCallInList() - if term != nil { - r = append(r, term) - switch p.s.tok { - case end: - return r - case tokens.Comma: - p.scan() - if p.s.tok == end { - return r - } - continue - default: - p.illegal(fmt.Sprintf("expected %q or %q", tokens.Comma, end)) - return nil - } - } - return nil - } -} - -func (p *Parser) parseTermPairList(end tokens.Token, r [][2]*Term) [][2]*Term { - if p.s.tok == end { - return r - } - for { - key := p.parseTermInfixCallInList() - if key != nil { - switch p.s.tok { - case tokens.Colon: - p.scan() - if val := p.parseTermInfixCallInList(); val != nil { - r = append(r, [2]*Term{key, val}) - switch p.s.tok { - case end: - return r - case tokens.Comma: - p.scan() - if p.s.tok == end { - return r - } - continue - default: - p.illegal(fmt.Sprintf("expected %q or %q", tokens.Comma, end)) - return nil - } - } - default: - p.illegal(fmt.Sprintf("expected %q", tokens.Colon)) - return nil - } - } - return nil - } -} - -func (p *Parser) parseTermOp(values ...tokens.Token) *Term { - for i := range values { - if p.s.tok == values[i] { - r := RefTerm(VarTerm(fmt.Sprint(p.s.tok)).SetLocation(p.s.Loc())).SetLocation(p.s.Loc()) - p.scan() - return r - } - } - return nil -} - -func (p *Parser) parseTermOpName(ref Ref, values ...tokens.Token) *Term { - for i := range values { - if p.s.tok == values[i] { - for _, r := range ref { - r.SetLocation(p.s.Loc()) - } - t := RefTerm(ref...) - t.SetLocation(p.s.Loc()) - p.scan() - return t - } - } - return nil -} - -func (p *Parser) parseVar() *Term { - - s := p.s.lit - - term := VarTerm(s).SetLocation(p.s.Loc()) - - // Update wildcard values with unique identifiers - if term.Equal(Wildcard) { - term.Value = Var(p.genwildcard()) - } - - return term -} - -func (p *Parser) genwildcard() string { - c := p.s.wildcard - p.s.wildcard++ - return fmt.Sprintf("%v%d", WildcardPrefix, c) -} - -func (p *Parser) error(loc *location.Location, reason string) { - p.errorf(loc, reason) -} - -func (p *Parser) errorf(loc *location.Location, f string, a ...interface{}) { - msg := strings.Builder{} - msg.WriteString(fmt.Sprintf(f, a...)) - - switch len(p.s.hints) { - case 0: // nothing to do - case 1: - msg.WriteString(" (hint: ") - msg.WriteString(p.s.hints[0]) - msg.WriteRune(')') - default: - msg.WriteString(" (hints: ") - for i, h := range p.s.hints { - if i > 0 { - msg.WriteString(", ") - } - msg.WriteString(h) - } - msg.WriteRune(')') - } - - p.s.errors = append(p.s.errors, &Error{ - Code: ParseErr, - Message: msg.String(), - Location: loc, - Details: newParserErrorDetail(p.s.s.Bytes(), loc.Offset), - }) - p.s.hints = nil -} - -func (p *Parser) hint(f string, a ...interface{}) { - p.s.hints = append(p.s.hints, fmt.Sprintf(f, a...)) -} - -func (p *Parser) illegal(note string, a ...interface{}) { - tok := p.s.tok.String() - - if p.s.tok == tokens.Illegal { - p.errorf(p.s.Loc(), "illegal token") - return - } - - tokType := "token" - if tokens.IsKeyword(p.s.tok) { - tokType = "keyword" - } - if _, ok := futureKeywords[p.s.tok.String()]; ok { - tokType = "keyword" - } - - note = fmt.Sprintf(note, a...) - if len(note) > 0 { - p.errorf(p.s.Loc(), "unexpected %s %s: %s", tok, tokType, note) - } else { - p.errorf(p.s.Loc(), "unexpected %s %s", tok, tokType) - } -} - -func (p *Parser) illegalToken() { - p.illegal("") -} - -func (p *Parser) scan() { - p.doScan(true) -} - -func (p *Parser) scanWS() { - p.doScan(false) -} - -func (p *Parser) doScan(skipws bool) { - - // NOTE(tsandall): the last position is used to compute the "text" field for - // complex AST nodes. Whitespace never affects the last position of an AST - // node so do not update it when scanning. - if p.s.tok != tokens.Whitespace { - p.s.lastEnd = p.s.tokEnd - p.s.skippedNL = false - } - - var errs []scanner.Error - for { - var pos scanner.Position - p.s.tok, pos, p.s.lit, errs = p.s.s.Scan() - - p.s.tokEnd = pos.End - p.s.loc.Row = pos.Row - p.s.loc.Col = pos.Col - p.s.loc.Offset = pos.Offset - p.s.loc.Text = p.s.Text(pos.Offset, pos.End) - p.s.loc.Tabs = pos.Tabs - - for _, err := range errs { - p.error(p.s.Loc(), err.Message) - } - - if len(errs) > 0 { - p.s.tok = tokens.Illegal - } - - if p.s.tok == tokens.Whitespace { - if p.s.lit == "\n" { - p.s.skippedNL = true - } - if skipws { - continue - } - } - - if p.s.tok != tokens.Comment { - break - } - - // For backwards compatibility leave a nil - // Text value if there is no text rather than - // an empty string. - var commentText []byte - if len(p.s.lit) > 1 { - commentText = []byte(p.s.lit[1:]) - } - comment := NewComment(commentText) - comment.SetLoc(p.s.Loc()) - p.s.comments = append(p.s.comments, comment) - } -} - -func (p *Parser) save() *state { - cpy := *p.s - s := *cpy.s - cpy.s = &s - return &cpy -} - -func (p *Parser) restore(s *state) { - p.s = s -} - -func setLocRecursive(x interface{}, loc *location.Location) { - NewGenericVisitor(func(x interface{}) bool { - if node, ok := x.(Node); ok { - node.SetLoc(loc) - } - return false - }).Walk(x) -} - -func (p *Parser) setLoc(term *Term, loc *location.Location, offset, end int) *Term { - if term != nil { - cpy := *loc - term.Location = &cpy - term.Location.Text = p.s.Text(offset, end) - } - return term -} - -func (p *Parser) validateDefaultRuleValue(rule *Rule) bool { - if rule.Head.Value == nil { - p.error(rule.Loc(), "illegal default rule (must have a value)") - return false - } - - valid := true - vis := NewGenericVisitor(func(x interface{}) bool { - switch x.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension: // skip closures - return true - case Ref, Var, Call: - p.error(rule.Loc(), fmt.Sprintf("illegal default rule (value cannot contain %v)", TypeName(x))) - valid = false - return true - } - return false - }) - - vis.Walk(rule.Head.Value.Value) - return valid -} - -func (p *Parser) validateDefaultRuleArgs(rule *Rule) bool { - - valid := true - vars := NewVarSet() - - vis := NewGenericVisitor(func(x interface{}) bool { - switch x := x.(type) { - case Var: - if vars.Contains(x) { - p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot be repeated %v)", x)) - valid = false - return true - } - vars.Add(x) - - case *Term: - switch v := x.Value.(type) { - case Var: // do nothing - default: - p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot contain %v)", TypeName(v))) - valid = false - return true - } - } - - return false - }) - - vis.Walk(rule.Head.Args) - return valid -} - -// We explicitly use yaml unmarshalling, to accommodate for the '_' in 'related_resources', -// which isn't handled properly by json for some reason. -type rawAnnotation struct { - Scope string `yaml:"scope"` - Title string `yaml:"title"` - Entrypoint bool `yaml:"entrypoint"` - Description string `yaml:"description"` - Organizations []string `yaml:"organizations"` - RelatedResources []interface{} `yaml:"related_resources"` - Authors []interface{} `yaml:"authors"` - Schemas []map[string]any `yaml:"schemas"` - Custom map[string]interface{} `yaml:"custom"` -} - -type metadataParser struct { - buf *bytes.Buffer - comments []*Comment - loc *location.Location -} - -func newMetadataParser(loc *Location) *metadataParser { - return &metadataParser{loc: loc, buf: bytes.NewBuffer(nil)} -} - -func (b *metadataParser) Append(c *Comment) { - b.buf.Write(bytes.TrimPrefix(c.Text, []byte(" "))) - b.buf.WriteByte('\n') - b.comments = append(b.comments, c) -} - -var yamlLineErrRegex = regexp.MustCompile(`^yaml:(?: unmarshal errors:[\n\s]*)? line ([[:digit:]]+):`) - -func (b *metadataParser) Parse() (*Annotations, error) { - - var raw rawAnnotation - - if len(bytes.TrimSpace(b.buf.Bytes())) == 0 { - return nil, fmt.Errorf("expected METADATA block, found whitespace") - } - - if err := yaml.Unmarshal(b.buf.Bytes(), &raw); err != nil { - var comment *Comment - match := yamlLineErrRegex.FindStringSubmatch(err.Error()) - if len(match) == 2 { - index, err2 := strconv.Atoi(match[1]) - if err2 == nil { - if index >= len(b.comments) { - comment = b.comments[len(b.comments)-1] - } else { - comment = b.comments[index] - } - b.loc = comment.Location - } - } - - if match == nil && len(b.comments) > 0 { - b.loc = b.comments[0].Location - } - - return nil, augmentYamlError(err, b.comments) - } - - var result Annotations - result.comments = b.comments - result.Scope = raw.Scope - result.Entrypoint = raw.Entrypoint - result.Title = raw.Title - result.Description = raw.Description - result.Organizations = raw.Organizations - - for _, v := range raw.RelatedResources { - rr, err := parseRelatedResource(v) - if err != nil { - return nil, fmt.Errorf("invalid related-resource definition %s: %w", v, err) - } - result.RelatedResources = append(result.RelatedResources, rr) - } - - for _, pair := range raw.Schemas { - k, v := unwrapPair(pair) - - var a SchemaAnnotation - var err error - - a.Path, err = ParseRef(k) - if err != nil { - return nil, fmt.Errorf("invalid document reference") - } - - switch v := v.(type) { - case string: - a.Schema, err = parseSchemaRef(v) - if err != nil { - return nil, err - } - case map[string]any: - w, err := convertYAMLMapKeyTypes(v, nil) - if err != nil { - return nil, fmt.Errorf("invalid schema definition: %w", err) - } - a.Definition = &w - default: - return nil, fmt.Errorf("invalid schema declaration for path %q", k) - } - - result.Schemas = append(result.Schemas, &a) - } - - for _, v := range raw.Authors { - author, err := parseAuthor(v) - if err != nil { - return nil, fmt.Errorf("invalid author definition %s: %w", v, err) - } - result.Authors = append(result.Authors, author) - } - - result.Custom = make(map[string]interface{}) - for k, v := range raw.Custom { - val, err := convertYAMLMapKeyTypes(v, nil) - if err != nil { - return nil, err - } - result.Custom[k] = val - } - - result.Location = b.loc - - // recreate original text of entire metadata block for location text attribute - sb := strings.Builder{} - sb.WriteString("# METADATA\n") - - lines := bytes.Split(b.buf.Bytes(), []byte{'\n'}) - - for _, line := range lines[:len(lines)-1] { - sb.WriteString("# ") - sb.Write(line) - sb.WriteByte('\n') - } - - result.Location.Text = []byte(strings.TrimSuffix(sb.String(), "\n")) - - return &result, nil -} - -// augmentYamlError augments a YAML error with hints intended to help the user figure out the cause of an otherwise -// cryptic error. These are hints, instead of proper errors, because they are educated guesses, and aren't guaranteed -// to be correct. -func augmentYamlError(err error, comments []*Comment) error { - // Adding hints for when key/value ':' separator isn't suffixed with a legal YAML space symbol - for _, comment := range comments { - txt := string(comment.Text) - parts := strings.Split(txt, ":") - if len(parts) > 1 { - parts = parts[1:] - var invalidSpaces []string - for partIndex, part := range parts { - if len(part) == 0 && partIndex == len(parts)-1 { - invalidSpaces = []string{} - break - } - - r, _ := utf8.DecodeRuneInString(part) - if r == ' ' || r == '\t' { - invalidSpaces = []string{} - break - } - - invalidSpaces = append(invalidSpaces, fmt.Sprintf("%+q", r)) - } - if len(invalidSpaces) > 0 { - err = fmt.Errorf( - "%s\n Hint: on line %d, symbol(s) %v immediately following a key/value separator ':' is not a legal yaml space character", - err.Error(), comment.Location.Row, invalidSpaces) - } - } - } - return err -} - -func unwrapPair(pair map[string]interface{}) (string, interface{}) { - for k, v := range pair { - return k, v - } - return "", nil -} - -var errInvalidSchemaRef = fmt.Errorf("invalid schema reference") - -// NOTE(tsandall): 'schema' is not registered as a root because it's not -// supported by the compiler or evaluator today. Once we fix that, we can remove -// this function. -func parseSchemaRef(s string) (Ref, error) { - - term, err := ParseTerm(s) - if err == nil { - switch v := term.Value.(type) { - case Var: - if term.Equal(SchemaRootDocument) { - return SchemaRootRef.Copy(), nil - } - case Ref: - if v.HasPrefix(SchemaRootRef) { - return v, nil - } - } - } - - return nil, errInvalidSchemaRef -} - -func parseRelatedResource(rr interface{}) (*RelatedResourceAnnotation, error) { - rr, err := convertYAMLMapKeyTypes(rr, nil) - if err != nil { - return nil, err - } - - switch rr := rr.(type) { - case string: - if len(rr) > 0 { - u, err := url.Parse(rr) - if err != nil { - return nil, err - } - return &RelatedResourceAnnotation{Ref: *u}, nil - } - return nil, fmt.Errorf("ref URL may not be empty string") - case map[string]interface{}: - description := strings.TrimSpace(getSafeString(rr, "description")) - ref := strings.TrimSpace(getSafeString(rr, "ref")) - if len(ref) > 0 { - u, err := url.Parse(ref) - if err != nil { - return nil, err - } - return &RelatedResourceAnnotation{Description: description, Ref: *u}, nil - } - return nil, fmt.Errorf("'ref' value required in object") - } - - return nil, fmt.Errorf("invalid value type, must be string or map") -} - -func parseAuthor(a interface{}) (*AuthorAnnotation, error) { - a, err := convertYAMLMapKeyTypes(a, nil) - if err != nil { - return nil, err - } - - switch a := a.(type) { - case string: - return parseAuthorString(a) - case map[string]interface{}: - name := strings.TrimSpace(getSafeString(a, "name")) - email := strings.TrimSpace(getSafeString(a, "email")) - if len(name) > 0 || len(email) > 0 { - return &AuthorAnnotation{name, email}, nil - } - return nil, fmt.Errorf("'name' and/or 'email' values required in object") - } - - return nil, fmt.Errorf("invalid value type, must be string or map") -} - -func getSafeString(m map[string]interface{}, k string) string { - if v, found := m[k]; found { - if s, ok := v.(string); ok { - return s - } - } - return "" -} - -const emailPrefix = "<" -const emailSuffix = ">" - -// parseAuthor parses a string into an AuthorAnnotation. If the last word of the input string is enclosed within <>, -// it is extracted as the author's email. The email may not contain whitelines, as it then will be interpreted as -// multiple words. -func parseAuthorString(s string) (*AuthorAnnotation, error) { - parts := strings.Fields(s) - - if len(parts) == 0 { - return nil, fmt.Errorf("author is an empty string") - } - - namePartCount := len(parts) - trailing := parts[namePartCount-1] - var email string - if len(trailing) >= len(emailPrefix)+len(emailSuffix) && strings.HasPrefix(trailing, emailPrefix) && - strings.HasSuffix(trailing, emailSuffix) { - email = trailing[len(emailPrefix):] - email = email[0 : len(email)-len(emailSuffix)] - namePartCount = namePartCount - 1 - } - - name := strings.Join(parts[0:namePartCount], " ") - - return &AuthorAnnotation{Name: name, Email: email}, nil -} - -func convertYAMLMapKeyTypes(x any, path []string) (any, error) { - var err error - switch x := x.(type) { - case map[any]any: - result := make(map[string]any, len(x)) - for k, v := range x { - str, ok := k.(string) - if !ok { - return nil, fmt.Errorf("invalid map key type(s): %v", strings.Join(path, "/")) - } - result[str], err = convertYAMLMapKeyTypes(v, append(path, str)) - if err != nil { - return nil, err - } - } - return result, nil - case []any: - for i := range x { - x[i], err = convertYAMLMapKeyTypes(x[i], append(path, fmt.Sprintf("%d", i))) - if err != nil { - return nil, err - } - } - return x, nil - default: - return x, nil - } -} - -// futureKeywords is the source of truth for future keywords that will -// eventually become standard keywords inside of Rego. -var futureKeywords = map[string]tokens.Token{ - "in": tokens.In, - "every": tokens.Every, - "contains": tokens.Contains, - "if": tokens.If, + return v1.NewParser().WithRegoVersion(DefaultRegoVersion) } func IsFutureKeyword(s string) bool { - _, ok := futureKeywords[s] - return ok -} - -func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]tokens.Token) { - path := imp.Path.Value.(Ref) - - if len(path) == 1 || !path[1].Equal(StringTerm("keywords")) { - p.errorf(imp.Path.Location, "invalid import, must be `future.keywords`") - return - } - - if imp.Alias != "" { - p.errorf(imp.Path.Location, "`future` imports cannot be aliased") - return - } - - if p.s.s.RegoV1Compatible() { - p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", RegoV1CompatibleRef) - return - } - - kwds := make([]string, 0, len(allowedFutureKeywords)) - for k := range allowedFutureKeywords { - kwds = append(kwds, k) - } - - switch len(path) { - case 2: // all keywords imported, nothing to do - case 3: // one keyword imported - kw, ok := path[2].Value.(String) - if !ok { - p.errorf(imp.Path.Location, "invalid import, must be `future.keywords.x`, e.g. `import future.keywords.in`") - return - } - keyword := string(kw) - _, ok = allowedFutureKeywords[keyword] - if !ok { - sort.Strings(kwds) // so the error message is stable - p.errorf(imp.Path.Location, "unexpected keyword, must be one of %v", kwds) - return - } - - kwds = []string{keyword} // overwrite - } - for _, kw := range kwds { - p.s.s.AddKeyword(kw, allowedFutureKeywords[kw]) - } -} - -func (p *Parser) regoV1Import(imp *Import) { - if !p.po.Capabilities.ContainsFeature(FeatureRegoV1Import) { - p.errorf(imp.Path.Location, "invalid import, `%s` is not supported by current capabilities", RegoV1CompatibleRef) - return - } - - path := imp.Path.Value.(Ref) - - // v1 is only valid option - if len(path) == 1 || !path[1].Equal(RegoV1CompatibleRef[1]) || len(path) > 2 { - p.errorf(imp.Path.Location, "invalid import `%s`, must be `%s`", path, RegoV1CompatibleRef) - return - } - - if p.po.RegoVersion == RegoV1 { - // We're parsing for Rego v1, where the 'rego.v1' import is a no-op. - return - } - - if imp.Alias != "" { - p.errorf(imp.Path.Location, "`rego` imports cannot be aliased") - return - } - - // import all future keywords with the rego.v1 import - kwds := make([]string, 0, len(futureKeywords)) - for k := range futureKeywords { - kwds = append(kwds, k) - } - - if p.s.s.HasKeyword(futureKeywords) && !p.s.s.RegoV1Compatible() { - // We have imported future keywords, but they didn't come from another `rego.v1` import. - p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", RegoV1CompatibleRef) - return - } - - p.s.s.SetRegoV1Compatible() - for _, kw := range kwds { - p.s.s.AddKeyword(kw, futureKeywords[kw]) - } + return v1.IsFutureKeywordForRegoVersion(s, RegoV0) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/parser_ext.go b/vendor/github.com/open-policy-agent/opa/ast/parser_ext.go index 83c87e47b..3b8b40682 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/parser_ext.go +++ b/vendor/github.com/open-policy-agent/opa/ast/parser_ext.go @@ -1,24 +1,13 @@ -// Copyright 2016 The OPA Authors. All rights reserved. +// Copyright 2024 The OPA Authors. All rights reserved. // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. -// This file contains extra functions for parsing Rego. -// Most of the parsing is handled by the code in parser.go, -// however, there are additional utilities that are -// helpful for dealing with Rego source inputs (e.g., REPL -// statements, source files, etc.) - package ast import ( - "bytes" - "errors" "fmt" - "strings" - "unicode" - "github.com/open-policy-agent/opa/ast/internal/tokens" - astJSON "github.com/open-policy-agent/opa/ast/json" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // MustParseBody returns a parsed body. @@ -30,11 +19,7 @@ func MustParseBody(input string) Body { // MustParseBodyWithOpts returns a parsed body. // If an error occurs during parsing, panic. func MustParseBodyWithOpts(input string, opts ParserOptions) Body { - parsed, err := ParseBodyWithOpts(input, opts) - if err != nil { - panic(err) - } - return parsed + return v1.MustParseBodyWithOpts(input, setDefaultRegoVersion(opts)) } // MustParseExpr returns a parsed expression. @@ -66,11 +51,7 @@ func MustParseModule(input string) *Module { // MustParseModuleWithOpts returns a parsed module. // If an error occurs during parsing, panic. func MustParseModuleWithOpts(input string, opts ParserOptions) *Module { - parsed, err := ParseModuleWithOpts("", input, opts) - if err != nil { - panic(err) - } - return parsed + return v1.MustParseModuleWithOpts(input, setDefaultRegoVersion(opts)) } // MustParsePackage returns a Package. @@ -104,11 +85,7 @@ func MustParseStatement(input string) Statement { } func MustParseStatementWithOpts(input string, popts ParserOptions) Statement { - parsed, err := ParseStatementWithOpts(input, popts) - if err != nil { - panic(err) - } - return parsed + return v1.MustParseStatementWithOpts(input, setDefaultRegoVersion(popts)) } // MustParseRef returns a parsed reference. @@ -134,11 +111,7 @@ func MustParseRule(input string) *Rule { // MustParseRuleWithOpts returns a parsed rule. // If an error occurs during parsing, panic. func MustParseRuleWithOpts(input string, opts ParserOptions) *Rule { - parsed, err := ParseRuleWithOpts(input, opts) - if err != nil { - panic(err) - } - return parsed + return v1.MustParseRuleWithOpts(input, setDefaultRegoVersion(opts)) } // MustParseTerm returns a parsed term. @@ -154,331 +127,59 @@ func MustParseTerm(input string) *Term { // ParseRuleFromBody returns a rule if the body can be interpreted as a rule // definition. Otherwise, an error is returned. func ParseRuleFromBody(module *Module, body Body) (*Rule, error) { - - if len(body) != 1 { - return nil, fmt.Errorf("multiple expressions cannot be used for rule head") - } - - return ParseRuleFromExpr(module, body[0]) + return v1.ParseRuleFromBody(module, body) } // ParseRuleFromExpr returns a rule if the expression can be interpreted as a // rule definition. func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { - - if len(expr.With) > 0 { - return nil, fmt.Errorf("expressions using with keyword cannot be used for rule head") - } - - if expr.Negated { - return nil, fmt.Errorf("negated expressions cannot be used for rule head") - } - - if _, ok := expr.Terms.(*SomeDecl); ok { - return nil, errors.New("'some' declarations cannot be used for rule head") - } - - if term, ok := expr.Terms.(*Term); ok { - switch v := term.Value.(type) { - case Ref: - if len(v) > 2 { // 2+ dots - return ParseCompleteDocRuleWithDotsFromTerm(module, term) - } - return ParsePartialSetDocRuleFromTerm(module, term) - default: - return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(v)) - } - } - - if _, ok := expr.Terms.([]*Term); !ok { - // This is a defensive check in case other kinds of expression terms are - // introduced in the future. - return nil, errors.New("expression cannot be used for rule head") - } - - if expr.IsEquality() { - return parseCompleteRuleFromEq(module, expr) - } else if expr.IsAssignment() { - rule, err := parseCompleteRuleFromEq(module, expr) - if err != nil { - return nil, err - } - rule.Head.Assign = true - return rule, nil - } - - if _, ok := BuiltinMap[expr.Operator().String()]; ok { - return nil, fmt.Errorf("rule name conflicts with built-in function") - } - - return ParseRuleFromCallExpr(module, expr.Terms.([]*Term)) -} - -func parseCompleteRuleFromEq(module *Module, expr *Expr) (rule *Rule, err error) { - - // ensure the rule location is set to the expr location - // the helper functions called below try to set the location based - // on the terms they've been provided but that is not as accurate. - defer func() { - if rule != nil { - rule.Location = expr.Location - rule.Head.Location = expr.Location - } - }() - - lhs, rhs := expr.Operand(0), expr.Operand(1) - if lhs == nil || rhs == nil { - return nil, errors.New("assignment requires two operands") - } - - rule, err = ParseRuleFromCallEqExpr(module, lhs, rhs) - if err == nil { - return rule, nil - } - - rule, err = ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs) - if err == nil { - return rule, nil - } - - return ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) + return v1.ParseRuleFromExpr(module, expr) } // ParseCompleteDocRuleFromAssignmentExpr returns a rule if the expression can // be interpreted as a complete document definition declared with the assignment // operator. func ParseCompleteDocRuleFromAssignmentExpr(module *Module, lhs, rhs *Term) (*Rule, error) { - - rule, err := ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) - if err != nil { - return nil, err - } - - rule.Head.Assign = true - - return rule, nil + return v1.ParseCompleteDocRuleFromAssignmentExpr(module, lhs, rhs) } // ParseCompleteDocRuleFromEqExpr returns a rule if the expression can be // interpreted as a complete document definition. func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { - var head *Head - - if v, ok := lhs.Value.(Var); ok { - // Modify the code to add the location to the head ref - // and set the head ref's jsonOptions. - head = VarHead(v, lhs.Location, &lhs.jsonOptions) - } else if r, ok := lhs.Value.(Ref); ok { // groundness ? - if _, ok := r[0].Value.(Var); !ok { - return nil, fmt.Errorf("invalid rule head: %v", r) - } - head = RefHead(r) - if len(r) > 1 && !r[len(r)-1].IsGround() { - return nil, fmt.Errorf("ref not ground") - } - } else { - return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(lhs.Value)) - } - head.Value = rhs - head.Location = lhs.Location - head.setJSONOptions(lhs.jsonOptions) - - body := NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)) - setJSONOptions(body, &rhs.jsonOptions) - - return &Rule{ - Location: lhs.Location, - Head: head, - Body: body, - Module: module, - jsonOptions: lhs.jsonOptions, - generatedBody: true, - }, nil + return v1.ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) } func ParseCompleteDocRuleWithDotsFromTerm(module *Module, term *Term) (*Rule, error) { - ref, ok := term.Value.(Ref) - if !ok { - return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(term.Value)) - } - - if _, ok := ref[0].Value.(Var); !ok { - return nil, fmt.Errorf("invalid rule head: %v", ref) - } - head := RefHead(ref, BooleanTerm(true).SetLocation(term.Location)) - head.generatedValue = true - head.Location = term.Location - head.jsonOptions = term.jsonOptions - - body := NewBody(NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location)) - setJSONOptions(body, &term.jsonOptions) - - return &Rule{ - Location: term.Location, - Head: head, - Body: body, - Module: module, - - jsonOptions: term.jsonOptions, - }, nil + return v1.ParseCompleteDocRuleWithDotsFromTerm(module, term) } // ParsePartialObjectDocRuleFromEqExpr returns a rule if the expression can be // interpreted as a partial object document definition. func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { - ref, ok := lhs.Value.(Ref) - if !ok { - return nil, fmt.Errorf("%v cannot be used as rule name", TypeName(lhs.Value)) - } - - if _, ok := ref[0].Value.(Var); !ok { - return nil, fmt.Errorf("invalid rule head: %v", ref) - } - - head := RefHead(ref, rhs) - if len(ref) == 2 { // backcompat for naked `foo.bar = "baz"` statements - head.Name = ref[0].Value.(Var) - head.Key = ref[1] - } - head.Location = rhs.Location - head.jsonOptions = rhs.jsonOptions - - body := NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)) - setJSONOptions(body, &rhs.jsonOptions) - - rule := &Rule{ - Location: rhs.Location, - Head: head, - Body: body, - Module: module, - jsonOptions: rhs.jsonOptions, - } - - return rule, nil + return v1.ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs) } // ParsePartialSetDocRuleFromTerm returns a rule if the term can be interpreted // as a partial set document definition. func ParsePartialSetDocRuleFromTerm(module *Module, term *Term) (*Rule, error) { - - ref, ok := term.Value.(Ref) - if !ok || len(ref) == 1 { - return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) - } - if _, ok := ref[0].Value.(Var); !ok { - return nil, fmt.Errorf("invalid rule head: %v", ref) - } - - head := RefHead(ref) - if len(ref) == 2 { - v, ok := ref[0].Value.(Var) - if !ok { - return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) - } - // Modify the code to add the location to the head ref - // and set the head ref's jsonOptions. - head = VarHead(v, ref[0].Location, &ref[0].jsonOptions) - head.Key = ref[1] - } - head.Location = term.Location - head.jsonOptions = term.jsonOptions - - body := NewBody(NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location)) - setJSONOptions(body, &term.jsonOptions) - - rule := &Rule{ - Location: term.Location, - Head: head, - Body: body, - Module: module, - jsonOptions: term.jsonOptions, - } - - return rule, nil + return v1.ParsePartialSetDocRuleFromTerm(module, term) } // ParseRuleFromCallEqExpr returns a rule if the term can be interpreted as a // function definition (e.g., f(x) = y => f(x) = y { true }). func ParseRuleFromCallEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { - - call, ok := lhs.Value.(Call) - if !ok { - return nil, fmt.Errorf("must be call") - } - - ref, ok := call[0].Value.(Ref) - if !ok { - return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(call[0].Value)) - } - if _, ok := ref[0].Value.(Var); !ok { - return nil, fmt.Errorf("invalid rule head: %v", ref) - } - - head := RefHead(ref, rhs) - head.Location = lhs.Location - head.Args = Args(call[1:]) - head.jsonOptions = lhs.jsonOptions - - body := NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)) - setJSONOptions(body, &rhs.jsonOptions) - - rule := &Rule{ - Location: lhs.Location, - Head: head, - Body: body, - Module: module, - jsonOptions: lhs.jsonOptions, - } - - return rule, nil + return v1.ParseRuleFromCallEqExpr(module, lhs, rhs) } // ParseRuleFromCallExpr returns a rule if the terms can be interpreted as a // function returning true or some value (e.g., f(x) => f(x) = true { true }). func ParseRuleFromCallExpr(module *Module, terms []*Term) (*Rule, error) { - - if len(terms) <= 1 { - return nil, fmt.Errorf("rule argument list must take at least one argument") - } - - loc := terms[0].Location - ref := terms[0].Value.(Ref) - if _, ok := ref[0].Value.(Var); !ok { - return nil, fmt.Errorf("invalid rule head: %v", ref) - } - head := RefHead(ref, BooleanTerm(true).SetLocation(loc)) - head.Location = loc - head.Args = terms[1:] - head.jsonOptions = terms[0].jsonOptions - - body := NewBody(NewExpr(BooleanTerm(true).SetLocation(loc)).SetLocation(loc)) - setJSONOptions(body, &terms[0].jsonOptions) - - rule := &Rule{ - Location: loc, - Head: head, - Module: module, - Body: body, - jsonOptions: terms[0].jsonOptions, - } - return rule, nil + return v1.ParseRuleFromCallExpr(module, terms) } // ParseImports returns a slice of Import objects. func ParseImports(input string) ([]*Import, error) { - stmts, _, err := ParseStatements("", input) - if err != nil { - return nil, err - } - result := []*Import{} - for _, stmt := range stmts { - if imp, ok := stmt.(*Import); ok { - result = append(result, imp) - } else { - return nil, fmt.Errorf("expected import but got %T", stmt) - } - } - return result, nil + return v1.ParseImports(input) } // ParseModule returns a parsed Module object. @@ -492,11 +193,7 @@ func ParseModule(filename, input string) (*Module, error) { // For details on Module objects and their fields, see policy.go. // Empty input will return nil, nil. func ParseModuleWithOpts(filename, input string, popts ParserOptions) (*Module, error) { - stmts, comments, err := ParseStatementsWithOpts(filename, input, popts) - if err != nil { - return nil, err - } - return parseModule(filename, stmts, comments, popts.RegoVersion) + return v1.ParseModuleWithOpts(filename, input, setDefaultRegoVersion(popts)) } // ParseBody returns exactly one body. @@ -508,28 +205,7 @@ func ParseBody(input string) (Body, error) { // ParseBodyWithOpts returns exactly one body. It does _not_ set SkipRules: true on its own, // but respects whatever ParserOptions it's been given. func ParseBodyWithOpts(input string, popts ParserOptions) (Body, error) { - - stmts, _, err := ParseStatementsWithOpts("", input, popts) - if err != nil { - return nil, err - } - - result := Body{} - - for _, stmt := range stmts { - switch stmt := stmt.(type) { - case Body: - for i := range stmt { - result.Append(stmt[i]) - } - case *Comment: - // skip - default: - return nil, fmt.Errorf("expected body but got %T", stmt) - } - } - - return result, nil + return v1.ParseBodyWithOpts(input, setDefaultRegoVersion(popts)) } // ParseExpr returns exactly one expression. @@ -548,15 +224,7 @@ func ParseExpr(input string) (*Expr, error) { // ParsePackage returns exactly one Package. // If multiple statements are parsed, an error is returned. func ParsePackage(input string) (*Package, error) { - stmt, err := ParseStatement(input) - if err != nil { - return nil, err - } - pkg, ok := stmt.(*Package) - if !ok { - return nil, fmt.Errorf("expected package but got %T", stmt) - } - return pkg, nil + return v1.ParsePackage(input) } // ParseTerm returns exactly one term. @@ -592,18 +260,7 @@ func ParseRef(input string) (Ref, error) { // ParseRuleWithOpts returns exactly one rule. // If multiple rules are parsed, an error is returned. func ParseRuleWithOpts(input string, opts ParserOptions) (*Rule, error) { - stmts, _, err := ParseStatementsWithOpts("", input, opts) - if err != nil { - return nil, err - } - if len(stmts) != 1 { - return nil, fmt.Errorf("expected exactly one statement (rule), got %v = %T, %T", stmts, stmts[0], stmts[1]) - } - rule, ok := stmts[0].(*Rule) - if !ok { - return nil, fmt.Errorf("expected rule but got %T", stmts[0]) - } - return rule, nil + return v1.ParseRuleWithOpts(input, setDefaultRegoVersion(opts)) } // ParseRule returns exactly one rule. @@ -628,14 +285,7 @@ func ParseStatement(input string) (Statement, error) { } func ParseStatementWithOpts(input string, popts ParserOptions) (Statement, error) { - stmts, _, err := ParseStatementsWithOpts("", input, popts) - if err != nil { - return nil, err - } - if len(stmts) != 1 { - return nil, fmt.Errorf("expected exactly one statement") - } - return stmts[0], nil + return v1.ParseStatementWithOpts(input, setDefaultRegoVersion(popts)) } // ParseStatements is deprecated. Use ParseStatementWithOpts instead. @@ -646,204 +296,15 @@ func ParseStatements(filename, input string) ([]Statement, []*Comment, error) { // ParseStatementsWithOpts returns a slice of parsed statements. This is the // default return value from the parser. func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Statement, []*Comment, error) { - - parser := NewParser(). - WithFilename(filename). - WithReader(bytes.NewBufferString(input)). - WithProcessAnnotation(popts.ProcessAnnotation). - WithFutureKeywords(popts.FutureKeywords...). - WithAllFutureKeywords(popts.AllFutureKeywords). - WithCapabilities(popts.Capabilities). - WithSkipRules(popts.SkipRules). - WithJSONOptions(popts.JSONOptions). - WithRegoVersion(popts.RegoVersion). - withUnreleasedKeywords(popts.unreleasedKeywords) - - stmts, comments, errs := parser.Parse() - - if len(errs) > 0 { - return nil, nil, errs - } - - return stmts, comments, nil -} - -func parseModule(filename string, stmts []Statement, comments []*Comment, regoCompatibilityMode RegoVersion) (*Module, error) { - - if len(stmts) == 0 { - return nil, NewError(ParseErr, &Location{File: filename}, "empty module") - } - - var errs Errors - - pkg, ok := stmts[0].(*Package) - if !ok { - loc := stmts[0].Loc() - errs = append(errs, NewError(ParseErr, loc, "package expected")) - } - - mod := &Module{ - Package: pkg, - stmts: stmts, - } - - // The comments slice only holds comments that were not their own statements. - mod.Comments = append(mod.Comments, comments...) - mod.regoVersion = regoCompatibilityMode - - for i, stmt := range stmts[1:] { - switch stmt := stmt.(type) { - case *Import: - mod.Imports = append(mod.Imports, stmt) - if mod.regoVersion == RegoV0 && Compare(stmt.Path.Value, RegoV1CompatibleRef) == 0 { - mod.regoVersion = RegoV0CompatV1 - } - case *Rule: - setRuleModule(stmt, mod) - mod.Rules = append(mod.Rules, stmt) - case Body: - rule, err := ParseRuleFromBody(mod, stmt) - if err != nil { - errs = append(errs, NewError(ParseErr, stmt[0].Location, err.Error())) - continue - } - rule.generatedBody = true - mod.Rules = append(mod.Rules, rule) - - // NOTE(tsandall): the statement should now be interpreted as a - // rule so update the statement list. This is important for the - // logic below that associates annotations with statements. - stmts[i+1] = rule - case *Package: - errs = append(errs, NewError(ParseErr, stmt.Loc(), "unexpected package")) - case *Annotations: - mod.Annotations = append(mod.Annotations, stmt) - case *Comment: - // Ignore comments, they're handled above. - default: - panic("illegal value") // Indicates grammar is out-of-sync with code. - } - } - - if mod.regoVersion == RegoV0CompatV1 || mod.regoVersion == RegoV1 { - for _, rule := range mod.Rules { - for r := rule; r != nil; r = r.Else { - errs = append(errs, CheckRegoV1(r)...) - } - } - } - - if len(errs) > 0 { - return nil, errs - } - - errs = append(errs, attachAnnotationsNodes(mod)...) - - if len(errs) > 0 { - return nil, errs - } - - attachRuleAnnotations(mod) - - return mod, nil -} - -func ruleDeclarationHasKeyword(rule *Rule, keyword tokens.Token) bool { - for _, kw := range rule.Head.keywords { - if kw == keyword { - return true - } - } - return false -} - -func newScopeAttachmentErr(a *Annotations, want string) *Error { - var have string - if a.node != nil { - have = fmt.Sprintf(" (have %v)", TypeName(a.node)) - } - return NewError(ParseErr, a.Loc(), "annotation scope '%v' must be applied to %v%v", a.Scope, want, have) -} - -func setRuleModule(rule *Rule, module *Module) { - rule.Module = module - if rule.Else != nil { - setRuleModule(rule.Else, module) - } -} - -func setJSONOptions(x interface{}, jsonOptions *astJSON.Options) { - vis := NewGenericVisitor(func(x interface{}) bool { - if x, ok := x.(customJSON); ok { - x.setJSONOptions(*jsonOptions) - } - return false - }) - vis.Walk(x) + return v1.ParseStatementsWithOpts(filename, input, setDefaultRegoVersion(popts)) } // ParserErrorDetail holds additional details for parser errors. -type ParserErrorDetail struct { - Line string `json:"line"` - Idx int `json:"idx"` -} - -func newParserErrorDetail(bs []byte, offset int) *ParserErrorDetail { - - // Find first non-space character at or before offset position. - if offset >= len(bs) { - offset = len(bs) - 1 - } else if offset < 0 { - offset = 0 - } - - for offset > 0 && unicode.IsSpace(rune(bs[offset])) { - offset-- - } - - // Find beginning of line containing offset. - begin := offset - - for begin > 0 && !isNewLineChar(bs[begin]) { - begin-- - } +type ParserErrorDetail = v1.ParserErrorDetail - if isNewLineChar(bs[begin]) { - begin++ +func setDefaultRegoVersion(opts ParserOptions) ParserOptions { + if opts.RegoVersion == RegoUndefined { + opts.RegoVersion = DefaultRegoVersion } - - // Find end of line containing offset. - end := offset - - for end < len(bs) && !isNewLineChar(bs[end]) { - end++ - } - - if begin > end { - begin = end - } - - // Extract line and compute index of offset byte in line. - line := bs[begin:end] - index := offset - begin - - return &ParserErrorDetail{ - Line: string(line), - Idx: index, - } -} - -// Lines returns the pretty formatted line output for the error details. -func (d ParserErrorDetail) Lines() []string { - line := strings.TrimLeft(d.Line, "\t") // remove leading tabs - tabCount := len(d.Line) - len(line) - indent := d.Idx - tabCount - if indent < 0 { - indent = 0 - } - return []string{line, strings.Repeat(" ", indent) + "^"} -} - -func isNewLineChar(b byte) bool { - return b == '\r' || b == '\n' + return opts } diff --git a/vendor/github.com/open-policy-agent/opa/ast/policy.go b/vendor/github.com/open-policy-agent/opa/ast/policy.go index 43e9bba4a..a29f0dcc7 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/policy.go +++ b/vendor/github.com/open-policy-agent/opa/ast/policy.go @@ -1,196 +1,113 @@ -// Copyright 2016 The OPA Authors. All rights reserved. +// Copyright 2024 The OPA Authors. All rights reserved. // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. package ast import ( - "bytes" - "encoding/json" - "fmt" - "math/rand" - "strings" - "time" - - "github.com/open-policy-agent/opa/ast/internal/tokens" astJSON "github.com/open-policy-agent/opa/ast/json" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/ast" ) -// Initialize seed for term hashing. This is intentionally placed before the -// root document sets are constructed to ensure they use the same hash seed as -// subsequent lookups. If the hash seeds are out of sync, lookups will fail. -var hashSeed = rand.New(rand.NewSource(time.Now().UnixNano())) -var hashSeed0 = (uint64(hashSeed.Uint32()) << 32) | uint64(hashSeed.Uint32()) - // DefaultRootDocument is the default root document. // // All package directives inside source files are implicitly prefixed with the // DefaultRootDocument value. -var DefaultRootDocument = VarTerm("data") +var DefaultRootDocument = v1.DefaultRootDocument // InputRootDocument names the document containing query arguments. -var InputRootDocument = VarTerm("input") +var InputRootDocument = v1.InputRootDocument // SchemaRootDocument names the document containing external data schemas. -var SchemaRootDocument = VarTerm("schema") +var SchemaRootDocument = v1.SchemaRootDocument // FunctionArgRootDocument names the document containing function arguments. // It's only for internal usage, for referencing function arguments between // the index and topdown. -var FunctionArgRootDocument = VarTerm("args") +var FunctionArgRootDocument = v1.FunctionArgRootDocument // FutureRootDocument names the document containing new, to-become-default, // features. -var FutureRootDocument = VarTerm("future") +var FutureRootDocument = v1.FutureRootDocument // RegoRootDocument names the document containing new, to-become-default, // features in a future versioned release. -var RegoRootDocument = VarTerm("rego") +var RegoRootDocument = v1.RegoRootDocument // RootDocumentNames contains the names of top-level documents that can be // referred to in modules and queries. // // Note, the schema document is not currently implemented in the evaluator so it // is not registered as a root document name (yet). -var RootDocumentNames = NewSet( - DefaultRootDocument, - InputRootDocument, -) +var RootDocumentNames = v1.RootDocumentNames // DefaultRootRef is a reference to the root of the default document. // // All refs to data in the policy engine's storage layer are prefixed with this ref. -var DefaultRootRef = Ref{DefaultRootDocument} +var DefaultRootRef = v1.DefaultRootRef // InputRootRef is a reference to the root of the input document. // // All refs to query arguments are prefixed with this ref. -var InputRootRef = Ref{InputRootDocument} +var InputRootRef = v1.InputRootRef // SchemaRootRef is a reference to the root of the schema document. // // All refs to schema documents are prefixed with this ref. Note, the schema // document is not currently implemented in the evaluator so it is not // registered as a root document ref (yet). -var SchemaRootRef = Ref{SchemaRootDocument} +var SchemaRootRef = v1.SchemaRootRef // RootDocumentRefs contains the prefixes of top-level documents that all // non-local references start with. -var RootDocumentRefs = NewSet( - NewTerm(DefaultRootRef), - NewTerm(InputRootRef), -) +var RootDocumentRefs = v1.RootDocumentRefs // SystemDocumentKey is the name of the top-level key that identifies the system // document. -var SystemDocumentKey = String("system") +const SystemDocumentKey = v1.SystemDocumentKey // ReservedVars is the set of names that refer to implicitly ground vars. -var ReservedVars = NewVarSet( - DefaultRootDocument.Value.(Var), - InputRootDocument.Value.(Var), -) +var ReservedVars = v1.ReservedVars // Wildcard represents the wildcard variable as defined in the language. -var Wildcard = &Term{Value: Var("_")} +var Wildcard = v1.Wildcard // WildcardPrefix is the special character that all wildcard variables are // prefixed with when the statement they are contained in is parsed. -var WildcardPrefix = "$" +const WildcardPrefix = v1.WildcardPrefix // Keywords contains strings that map to language keywords. -var Keywords = KeywordsForRegoVersion(DefaultRegoVersion) +var Keywords = v1.Keywords -var KeywordsV0 = [...]string{ - "not", - "package", - "import", - "as", - "default", - "else", - "with", - "null", - "true", - "false", - "some", -} +var KeywordsV0 = v1.KeywordsV0 -var KeywordsV1 = [...]string{ - "not", - "package", - "import", - "as", - "default", - "else", - "with", - "null", - "true", - "false", - "some", - "if", - "contains", - "in", - "every", -} +var KeywordsV1 = v1.KeywordsV1 func KeywordsForRegoVersion(v RegoVersion) []string { - switch v { - case RegoV0: - return KeywordsV0[:] - case RegoV1, RegoV0CompatV1: - return KeywordsV1[:] - } - return nil + return v1.KeywordsForRegoVersion(v) } // IsKeyword returns true if s is a language keyword. func IsKeyword(s string) bool { - return IsInKeywords(s, Keywords) + return v1.IsKeyword(s) } func IsInKeywords(s string, keywords []string) bool { - for _, x := range keywords { - if x == s { - return true - } - } - return false + return v1.IsInKeywords(s, keywords) } // IsKeywordInRegoVersion returns true if s is a language keyword. func IsKeywordInRegoVersion(s string, regoVersion RegoVersion) bool { - switch regoVersion { - case RegoV0: - for _, x := range KeywordsV0 { - if x == s { - return true - } - } - case RegoV1, RegoV0CompatV1: - for _, x := range KeywordsV1 { - if x == s { - return true - } - } - } - - return false + return v1.IsKeywordInRegoVersion(s, regoVersion) } type ( // Node represents a node in an AST. Nodes may be statements in a policy module // or elements of an ad-hoc query, expression, etc. - Node interface { - fmt.Stringer - Loc() *Location - SetLoc(*Location) - } + Node = v1.Node // Statement represents a single statement in a policy module. - Statement interface { - Node - } + Statement = v1.Statement ) type ( @@ -198,729 +115,72 @@ type ( // Module represents a collection of policies (defined by rules) // within a namespace (defined by the package) and optional // dependencies on external documents (defined by imports). - Module struct { - Package *Package `json:"package"` - Imports []*Import `json:"imports,omitempty"` - Annotations []*Annotations `json:"annotations,omitempty"` - Rules []*Rule `json:"rules,omitempty"` - Comments []*Comment `json:"comments,omitempty"` - stmts []Statement - regoVersion RegoVersion - } + Module = v1.Module // Comment contains the raw text from the comment in the definition. - Comment struct { - // TODO: these fields have inconsistent JSON keys with other structs in this package. - Text []byte - Location *Location - - jsonOptions astJSON.Options - } + Comment = v1.Comment // Package represents the namespace of the documents produced // by rules inside the module. - Package struct { - Path Ref `json:"path"` - Location *Location `json:"location,omitempty"` - - jsonOptions astJSON.Options - } + Package = v1.Package // Import represents a dependency on a document outside of the policy // namespace. Imports are optional. - Import struct { - Path *Term `json:"path"` - Alias Var `json:"alias,omitempty"` - Location *Location `json:"location,omitempty"` - - jsonOptions astJSON.Options - } + Import = v1.Import // Rule represents a rule as defined in the language. Rules define the // content of documents that represent policy decisions. - Rule struct { - Default bool `json:"default,omitempty"` - Head *Head `json:"head"` - Body Body `json:"body"` - Else *Rule `json:"else,omitempty"` - Location *Location `json:"location,omitempty"` - Annotations []*Annotations `json:"annotations,omitempty"` - - // Module is a pointer to the module containing this rule. If the rule - // was NOT created while parsing/constructing a module, this should be - // left unset. The pointer is not included in any standard operations - // on the rule (e.g., printing, comparison, visiting, etc.) - Module *Module `json:"-"` - - generatedBody bool - jsonOptions astJSON.Options - } + Rule = v1.Rule // Head represents the head of a rule. - Head struct { - Name Var `json:"name,omitempty"` - Reference Ref `json:"ref,omitempty"` - Args Args `json:"args,omitempty"` - Key *Term `json:"key,omitempty"` - Value *Term `json:"value,omitempty"` - Assign bool `json:"assign,omitempty"` - Location *Location `json:"location,omitempty"` - - keywords []tokens.Token - generatedValue bool - jsonOptions astJSON.Options - } + Head = v1.Head // Args represents zero or more arguments to a rule. - Args []*Term + Args = v1.Args // Body represents one or more expressions contained inside a rule or user // function. - Body []*Expr + Body = v1.Body // Expr represents a single expression contained inside the body of a rule. - Expr struct { - With []*With `json:"with,omitempty"` - Terms interface{} `json:"terms"` - Index int `json:"index"` - Generated bool `json:"generated,omitempty"` - Negated bool `json:"negated,omitempty"` - Location *Location `json:"location,omitempty"` - - jsonOptions astJSON.Options - generatedFrom *Expr - generates []*Expr - } + Expr = v1.Expr // SomeDecl represents a variable declaration statement. The symbols are variables. - SomeDecl struct { - Symbols []*Term `json:"symbols"` - Location *Location `json:"location,omitempty"` + SomeDecl = v1.SomeDecl - jsonOptions astJSON.Options - } - - Every struct { - Key *Term `json:"key"` - Value *Term `json:"value"` - Domain *Term `json:"domain"` - Body Body `json:"body"` - Location *Location `json:"location,omitempty"` - - jsonOptions astJSON.Options - } + Every = v1.Every // With represents a modifier on an expression. - With struct { - Target *Term `json:"target"` - Value *Term `json:"value"` - Location *Location `json:"location,omitempty"` - - jsonOptions astJSON.Options - } + With = v1.With ) -// Compare returns an integer indicating whether mod is less than, equal to, -// or greater than other. -func (mod *Module) Compare(other *Module) int { - if mod == nil { - if other == nil { - return 0 - } - return -1 - } else if other == nil { - return 1 - } - if cmp := mod.Package.Compare(other.Package); cmp != 0 { - return cmp - } - if cmp := importsCompare(mod.Imports, other.Imports); cmp != 0 { - return cmp - } - if cmp := annotationsCompare(mod.Annotations, other.Annotations); cmp != 0 { - return cmp - } - return rulesCompare(mod.Rules, other.Rules) -} - -// Copy returns a deep copy of mod. -func (mod *Module) Copy() *Module { - cpy := *mod - cpy.Rules = make([]*Rule, len(mod.Rules)) - - nodes := make(map[Node]Node, len(mod.Rules)+len(mod.Imports)+1 /* package */) - - for i := range mod.Rules { - cpy.Rules[i] = mod.Rules[i].Copy() - cpy.Rules[i].Module = &cpy - nodes[mod.Rules[i]] = cpy.Rules[i] - } - - cpy.Imports = make([]*Import, len(mod.Imports)) - for i := range mod.Imports { - cpy.Imports[i] = mod.Imports[i].Copy() - nodes[mod.Imports[i]] = cpy.Imports[i] - } - - cpy.Package = mod.Package.Copy() - nodes[mod.Package] = cpy.Package - - cpy.Annotations = make([]*Annotations, len(mod.Annotations)) - for i, a := range mod.Annotations { - cpy.Annotations[i] = a.Copy(nodes[a.node]) - } - - cpy.Comments = make([]*Comment, len(mod.Comments)) - for i := range mod.Comments { - cpy.Comments[i] = mod.Comments[i].Copy() - } - - cpy.stmts = make([]Statement, len(mod.stmts)) - for i := range mod.stmts { - cpy.stmts[i] = nodes[mod.stmts[i]] - } - - return &cpy -} - -// Equal returns true if mod equals other. -func (mod *Module) Equal(other *Module) bool { - return mod.Compare(other) == 0 -} - -func (mod *Module) String() string { - byNode := map[Node][]*Annotations{} - for _, a := range mod.Annotations { - byNode[a.node] = append(byNode[a.node], a) - } - - appendAnnotationStrings := func(buf []string, node Node) []string { - if as, ok := byNode[node]; ok { - for i := range as { - buf = append(buf, "# METADATA") - buf = append(buf, "# "+as[i].String()) - } - } - return buf - } - - buf := []string{} - buf = appendAnnotationStrings(buf, mod.Package) - buf = append(buf, mod.Package.String()) - - if len(mod.Imports) > 0 { - buf = append(buf, "") - for _, imp := range mod.Imports { - buf = appendAnnotationStrings(buf, imp) - buf = append(buf, imp.String()) - } - } - if len(mod.Rules) > 0 { - buf = append(buf, "") - for _, rule := range mod.Rules { - buf = appendAnnotationStrings(buf, rule) - buf = append(buf, rule.stringWithOpts(toStringOpts{regoVersion: mod.regoVersion})) - } - } - return strings.Join(buf, "\n") -} - -// RuleSet returns a RuleSet containing named rules in the mod. -func (mod *Module) RuleSet(name Var) RuleSet { - rs := NewRuleSet() - for _, rule := range mod.Rules { - if rule.Head.Name.Equal(name) { - rs.Add(rule) - } - } - return rs -} - -// UnmarshalJSON parses bs and stores the result in mod. The rules in the module -// will have their module pointer set to mod. -func (mod *Module) UnmarshalJSON(bs []byte) error { - - // Declare a new type and use a type conversion to avoid recursively calling - // Module#UnmarshalJSON. - type module Module - - if err := util.UnmarshalJSON(bs, (*module)(mod)); err != nil { - return err - } - - WalkRules(mod, func(rule *Rule) bool { - rule.Module = mod - return false - }) - - return nil -} - -func (mod *Module) regoV1Compatible() bool { - return mod.regoVersion == RegoV1 || mod.regoVersion == RegoV0CompatV1 -} - -func (mod *Module) RegoVersion() RegoVersion { - return mod.regoVersion -} - -// SetRegoVersion sets the RegoVersion for the module. -// Note: Setting a rego-version that does not match the module's rego-version might have unintended consequences. -func (mod *Module) SetRegoVersion(v RegoVersion) { - mod.regoVersion = v -} - // NewComment returns a new Comment object. func NewComment(text []byte) *Comment { - return &Comment{ - Text: text, - } -} - -// Loc returns the location of the comment in the definition. -func (c *Comment) Loc() *Location { - if c == nil { - return nil - } - return c.Location -} - -// SetLoc sets the location on c. -func (c *Comment) SetLoc(loc *Location) { - c.Location = loc -} - -func (c *Comment) String() string { - return "#" + string(c.Text) -} - -// Copy returns a deep copy of c. -func (c *Comment) Copy() *Comment { - cpy := *c - cpy.Text = make([]byte, len(c.Text)) - copy(cpy.Text, c.Text) - return &cpy -} - -// Equal returns true if this comment equals the other comment. -// Unlike other equality checks on AST nodes, comment equality -// depends on location. -func (c *Comment) Equal(other *Comment) bool { - return c.Location.Equal(other.Location) && bytes.Equal(c.Text, other.Text) -} - -func (c *Comment) setJSONOptions(opts astJSON.Options) { - // Note: this is not used for location since Comments use default JSON marshaling - // behavior with struct field names in JSON. - c.jsonOptions = opts - if c.Location != nil { - c.Location.JSONOptions = opts - } -} - -// Compare returns an integer indicating whether pkg is less than, equal to, -// or greater than other. -func (pkg *Package) Compare(other *Package) int { - return Compare(pkg.Path, other.Path) -} - -// Copy returns a deep copy of pkg. -func (pkg *Package) Copy() *Package { - cpy := *pkg - cpy.Path = pkg.Path.Copy() - return &cpy -} - -// Equal returns true if pkg is equal to other. -func (pkg *Package) Equal(other *Package) bool { - return pkg.Compare(other) == 0 -} - -// Loc returns the location of the Package in the definition. -func (pkg *Package) Loc() *Location { - if pkg == nil { - return nil - } - return pkg.Location -} - -// SetLoc sets the location on pkg. -func (pkg *Package) SetLoc(loc *Location) { - pkg.Location = loc -} - -func (pkg *Package) String() string { - if pkg == nil { - return "" - } else if len(pkg.Path) <= 1 { - return fmt.Sprintf("package ", pkg.Path) - } - // Omit head as all packages have the DefaultRootDocument prepended at parse time. - path := make(Ref, len(pkg.Path)-1) - path[0] = VarTerm(string(pkg.Path[1].Value.(String))) - copy(path[1:], pkg.Path[2:]) - return fmt.Sprintf("package %v", path) -} - -func (pkg *Package) setJSONOptions(opts astJSON.Options) { - pkg.jsonOptions = opts - if pkg.Location != nil { - pkg.Location.JSONOptions = opts - } -} - -func (pkg *Package) MarshalJSON() ([]byte, error) { - data := map[string]interface{}{ - "path": pkg.Path, - } - - if pkg.jsonOptions.MarshalOptions.IncludeLocation.Package { - if pkg.Location != nil { - data["location"] = pkg.Location - } - } - - return json.Marshal(data) + return v1.NewComment(text) } // IsValidImportPath returns an error indicating if the import path is invalid. // If the import path is valid, err is nil. func IsValidImportPath(v Value) (err error) { - switch v := v.(type) { - case Var: - if !v.Equal(DefaultRootDocument.Value) && !v.Equal(InputRootDocument.Value) { - return fmt.Errorf("invalid path %v: path must begin with input or data", v) - } - case Ref: - if err := IsValidImportPath(v[0].Value); err != nil { - return fmt.Errorf("invalid path %v: path must begin with input or data", v) - } - for _, e := range v[1:] { - if _, ok := e.Value.(String); !ok { - return fmt.Errorf("invalid path %v: path elements must be strings", v) - } - } - default: - return fmt.Errorf("invalid path %v: path must be ref or var", v) - } - return nil -} - -// Compare returns an integer indicating whether imp is less than, equal to, -// or greater than other. -func (imp *Import) Compare(other *Import) int { - if imp == nil { - if other == nil { - return 0 - } - return -1 - } else if other == nil { - return 1 - } - if cmp := Compare(imp.Path, other.Path); cmp != 0 { - return cmp - } - return Compare(imp.Alias, other.Alias) -} - -// Copy returns a deep copy of imp. -func (imp *Import) Copy() *Import { - cpy := *imp - cpy.Path = imp.Path.Copy() - return &cpy -} - -// Equal returns true if imp is equal to other. -func (imp *Import) Equal(other *Import) bool { - return imp.Compare(other) == 0 -} - -// Loc returns the location of the Import in the definition. -func (imp *Import) Loc() *Location { - if imp == nil { - return nil - } - return imp.Location -} - -// SetLoc sets the location on imp. -func (imp *Import) SetLoc(loc *Location) { - imp.Location = loc -} - -// Name returns the variable that is used to refer to the imported virtual -// document. This is the alias if defined otherwise the last element in the -// path. -func (imp *Import) Name() Var { - if len(imp.Alias) != 0 { - return imp.Alias - } - switch v := imp.Path.Value.(type) { - case Var: - return v - case Ref: - if len(v) == 1 { - return v[0].Value.(Var) - } - return Var(v[len(v)-1].Value.(String)) - } - panic("illegal import") -} - -func (imp *Import) String() string { - buf := []string{"import", imp.Path.String()} - if len(imp.Alias) > 0 { - buf = append(buf, "as "+imp.Alias.String()) - } - return strings.Join(buf, " ") -} - -func (imp *Import) setJSONOptions(opts astJSON.Options) { - imp.jsonOptions = opts - if imp.Location != nil { - imp.Location.JSONOptions = opts - } -} - -func (imp *Import) MarshalJSON() ([]byte, error) { - data := map[string]interface{}{ - "path": imp.Path, - } - - if len(imp.Alias) != 0 { - data["alias"] = imp.Alias - } - - if imp.jsonOptions.MarshalOptions.IncludeLocation.Import { - if imp.Location != nil { - data["location"] = imp.Location - } - } - - return json.Marshal(data) -} - -// Compare returns an integer indicating whether rule is less than, equal to, -// or greater than other. -func (rule *Rule) Compare(other *Rule) int { - if rule == nil { - if other == nil { - return 0 - } - return -1 - } else if other == nil { - return 1 - } - if cmp := rule.Head.Compare(other.Head); cmp != 0 { - return cmp - } - if cmp := util.Compare(rule.Default, other.Default); cmp != 0 { - return cmp - } - if cmp := rule.Body.Compare(other.Body); cmp != 0 { - return cmp - } - - if cmp := annotationsCompare(rule.Annotations, other.Annotations); cmp != 0 { - return cmp - } - - return rule.Else.Compare(other.Else) -} - -// Copy returns a deep copy of rule. -func (rule *Rule) Copy() *Rule { - cpy := *rule - cpy.Head = rule.Head.Copy() - cpy.Body = rule.Body.Copy() - - cpy.Annotations = make([]*Annotations, len(rule.Annotations)) - for i, a := range rule.Annotations { - cpy.Annotations[i] = a.Copy(&cpy) - } - - if cpy.Else != nil { - cpy.Else = rule.Else.Copy() - } - return &cpy -} - -// Equal returns true if rule is equal to other. -func (rule *Rule) Equal(other *Rule) bool { - return rule.Compare(other) == 0 -} - -// Loc returns the location of the Rule in the definition. -func (rule *Rule) Loc() *Location { - if rule == nil { - return nil - } - return rule.Location -} - -// SetLoc sets the location on rule. -func (rule *Rule) SetLoc(loc *Location) { - rule.Location = loc -} - -// Path returns a ref referring to the document produced by this rule. If rule -// is not contained in a module, this function panics. -// Deprecated: Poor handling of ref rules. Use `(*Rule).Ref()` instead. -func (rule *Rule) Path() Ref { - if rule.Module == nil { - panic("assertion failed") - } - return rule.Module.Package.Path.Extend(rule.Head.Ref().GroundPrefix()) -} - -// Ref returns a ref referring to the document produced by this rule. If rule -// is not contained in a module, this function panics. The returned ref may -// contain variables in the last position. -func (rule *Rule) Ref() Ref { - if rule.Module == nil { - panic("assertion failed") - } - return rule.Module.Package.Path.Extend(rule.Head.Ref()) -} - -func (rule *Rule) String() string { - return rule.stringWithOpts(toStringOpts{}) -} - -type toStringOpts struct { - regoVersion RegoVersion -} - -func (rule *Rule) stringWithOpts(opts toStringOpts) string { - buf := []string{} - if rule.Default { - buf = append(buf, "default") - } - buf = append(buf, rule.Head.stringWithOpts(opts)) - if !rule.Default { - switch opts.regoVersion { - case RegoV1, RegoV0CompatV1: - buf = append(buf, "if") - } - buf = append(buf, "{") - buf = append(buf, rule.Body.String()) - buf = append(buf, "}") - } - if rule.Else != nil { - buf = append(buf, rule.Else.elseString(opts)) - } - return strings.Join(buf, " ") -} - -func (rule *Rule) isFunction() bool { - return len(rule.Head.Args) > 0 -} - -func (rule *Rule) setJSONOptions(opts astJSON.Options) { - rule.jsonOptions = opts - if rule.Location != nil { - rule.Location.JSONOptions = opts - } -} - -func (rule *Rule) MarshalJSON() ([]byte, error) { - data := map[string]interface{}{ - "head": rule.Head, - "body": rule.Body, - } - - if rule.Default { - data["default"] = true - } - - if rule.Else != nil { - data["else"] = rule.Else - } - - if rule.jsonOptions.MarshalOptions.IncludeLocation.Rule { - if rule.Location != nil { - data["location"] = rule.Location - } - } - - if len(rule.Annotations) != 0 { - data["annotations"] = rule.Annotations - } - - return json.Marshal(data) -} - -func (rule *Rule) elseString(opts toStringOpts) string { - var buf []string - - buf = append(buf, "else") - - value := rule.Head.Value - if value != nil { - buf = append(buf, "=") - buf = append(buf, value.String()) - } - - switch opts.regoVersion { - case RegoV1, RegoV0CompatV1: - buf = append(buf, "if") - } - - buf = append(buf, "{") - buf = append(buf, rule.Body.String()) - buf = append(buf, "}") - - if rule.Else != nil { - buf = append(buf, rule.Else.elseString(opts)) - } - - return strings.Join(buf, " ") + return v1.IsValidImportPath(v) } // NewHead returns a new Head object. If args are provided, the first will be // used for the key and the second will be used for the value. func NewHead(name Var, args ...*Term) *Head { - head := &Head{ - Name: name, // backcompat - Reference: []*Term{NewTerm(name)}, - } - if len(args) == 0 { - return head - } - head.Key = args[0] - if len(args) == 1 { - return head - } - head.Value = args[1] - if head.Key != nil && head.Value != nil { - head.Reference = head.Reference.Append(args[0]) - } - return head + return v1.NewHead(name, args...) } // VarHead creates a head object, initializes its Name, Location, and Options, // and returns the new head. func VarHead(name Var, location *Location, jsonOpts *astJSON.Options) *Head { - h := NewHead(name) - h.Reference[0].Location = location - if jsonOpts != nil { - h.Reference[0].setJSONOptions(*jsonOpts) - } - return h + return v1.VarHead(name, location, jsonOpts) } // RefHead returns a new Head object with the passed Ref. If args are provided, // the first will be used for the value. func RefHead(ref Ref, args ...*Term) *Head { - head := &Head{} - head.SetRef(ref) - if len(ref) < 2 { - head.Name = ref[0].Value.(Var) - } - if len(args) >= 1 { - head.Value = args[0] - } - return head + return v1.RefHead(ref, args...) } // DocKind represents the collection of document types that can be produced by rules. @@ -928,1164 +188,48 @@ type DocKind int const ( // CompleteDoc represents a document that is completely defined by the rule. - CompleteDoc = iota + CompleteDoc = v1.CompleteDoc // PartialSetDoc represents a set document that is partially defined by the rule. - PartialSetDoc + PartialSetDoc = v1.PartialSetDoc // PartialObjectDoc represents an object document that is partially defined by the rule. - PartialObjectDoc -) // TODO(sr): Deprecate? - -// DocKind returns the type of document produced by this rule. -func (head *Head) DocKind() DocKind { - if head.Key != nil { - if head.Value != nil { - return PartialObjectDoc - } - return PartialSetDoc - } - return CompleteDoc -} + PartialObjectDoc = v1.PartialObjectDoc +) -type RuleKind int +type RuleKind = v1.RuleKind const ( - SingleValue = iota - MultiValue + SingleValue = v1.SingleValue + MultiValue = v1.MultiValue ) -// RuleKind returns the type of rule this is -func (head *Head) RuleKind() RuleKind { - // NOTE(sr): This is bit verbose, since the key is irrelevant for single vs - // multi value, but as good a spot as to assert the invariant. - switch { - case head.Value != nil: - return SingleValue - case head.Key != nil: - return MultiValue - default: - panic("unreachable") - } -} - -// Ref returns the Ref of the rule. If it doesn't have one, it's filled in -// via the Head's Name. -func (head *Head) Ref() Ref { - if len(head.Reference) > 0 { - return head.Reference - } - return Ref{&Term{Value: head.Name}} -} - -// SetRef can be used to set a rule head's Reference -func (head *Head) SetRef(r Ref) { - head.Reference = r -} - -// Compare returns an integer indicating whether head is less than, equal to, -// or greater than other. -func (head *Head) Compare(other *Head) int { - if head == nil { - if other == nil { - return 0 - } - return -1 - } else if other == nil { - return 1 - } - if head.Assign && !other.Assign { - return -1 - } else if !head.Assign && other.Assign { - return 1 - } - if cmp := Compare(head.Args, other.Args); cmp != 0 { - return cmp - } - if cmp := Compare(head.Reference, other.Reference); cmp != 0 { - return cmp - } - if cmp := Compare(head.Name, other.Name); cmp != 0 { - return cmp - } - if cmp := Compare(head.Key, other.Key); cmp != 0 { - return cmp - } - return Compare(head.Value, other.Value) -} - -// Copy returns a deep copy of head. -func (head *Head) Copy() *Head { - cpy := *head - cpy.Reference = head.Reference.Copy() - cpy.Args = head.Args.Copy() - cpy.Key = head.Key.Copy() - cpy.Value = head.Value.Copy() - cpy.keywords = nil - return &cpy -} - -// Equal returns true if this head equals other. -func (head *Head) Equal(other *Head) bool { - return head.Compare(other) == 0 -} - -func (head *Head) String() string { - return head.stringWithOpts(toStringOpts{}) -} - -func (head *Head) stringWithOpts(opts toStringOpts) string { - buf := strings.Builder{} - buf.WriteString(head.Ref().String()) - containsAdded := false - - switch { - case len(head.Args) != 0: - buf.WriteString(head.Args.String()) - case len(head.Reference) == 1 && head.Key != nil: - switch opts.regoVersion { - case RegoV0: - buf.WriteRune('[') - buf.WriteString(head.Key.String()) - buf.WriteRune(']') - default: - containsAdded = true - buf.WriteString(" contains ") - buf.WriteString(head.Key.String()) - } - } - if head.Value != nil { - if head.Assign { - buf.WriteString(" := ") - } else { - buf.WriteString(" = ") - } - buf.WriteString(head.Value.String()) - } else if !containsAdded && head.Name == "" && head.Key != nil { - buf.WriteString(" contains ") - buf.WriteString(head.Key.String()) - } - return buf.String() -} - -func (head *Head) setJSONOptions(opts astJSON.Options) { - head.jsonOptions = opts - if head.Location != nil { - head.Location.JSONOptions = opts - } -} - -func (head *Head) MarshalJSON() ([]byte, error) { - var loc *Location - includeLoc := head.jsonOptions.MarshalOptions.IncludeLocation - if includeLoc.Head { - if head.Location != nil { - loc = head.Location - } - - for _, term := range head.Reference { - if term.Location != nil { - term.jsonOptions.MarshalOptions.IncludeLocation.Term = includeLoc.Term - } - } - } - - // NOTE(sr): we do this to override the rendering of `head.Reference`. - // It's still what'll be used via the default means of encoding/json - // for unmarshaling a json object into a Head struct! - type h Head - return json.Marshal(struct { - h - Ref Ref `json:"ref"` - Location *Location `json:"location,omitempty"` - }{ - h: h(*head), - Ref: head.Ref(), - Location: loc, - }) -} - -// Vars returns a set of vars found in the head. -func (head *Head) Vars() VarSet { - vis := &VarVisitor{vars: VarSet{}} - // TODO: improve test coverage for this. - if head.Args != nil { - vis.Walk(head.Args) - } - if head.Key != nil { - vis.Walk(head.Key) - } - if head.Value != nil { - vis.Walk(head.Value) - } - if len(head.Reference) > 0 { - vis.Walk(head.Reference[1:]) - } - return vis.vars -} - -// Loc returns the Location of head. -func (head *Head) Loc() *Location { - if head == nil { - return nil - } - return head.Location -} - -// SetLoc sets the location on head. -func (head *Head) SetLoc(loc *Location) { - head.Location = loc -} - -func (head *Head) HasDynamicRef() bool { - pos := head.Reference.Dynamic() - // Ref is dynamic if it has one non-constant term that isn't the first or last term or if it's a partial set rule. - return pos > 0 && (pos < len(head.Reference)-1 || head.RuleKind() == MultiValue) -} - -// Copy returns a deep copy of a. -func (a Args) Copy() Args { - cpy := Args{} - for _, t := range a { - cpy = append(cpy, t.Copy()) - } - return cpy -} - -func (a Args) String() string { - buf := make([]string, 0, len(a)) - for _, t := range a { - buf = append(buf, t.String()) - } - return "(" + strings.Join(buf, ", ") + ")" -} - -// Loc returns the Location of a. -func (a Args) Loc() *Location { - if len(a) == 0 { - return nil - } - return a[0].Location -} - -// SetLoc sets the location on a. -func (a Args) SetLoc(loc *Location) { - if len(a) != 0 { - a[0].SetLocation(loc) - } -} - -// Vars returns a set of vars that appear in a. -func (a Args) Vars() VarSet { - vis := &VarVisitor{vars: VarSet{}} - vis.Walk(a) - return vis.vars -} - // NewBody returns a new Body containing the given expressions. The indices of // the immediate expressions will be reset. func NewBody(exprs ...*Expr) Body { - for i, expr := range exprs { - expr.Index = i - } - return Body(exprs) -} - -// MarshalJSON returns JSON encoded bytes representing body. -func (body Body) MarshalJSON() ([]byte, error) { - // Serialize empty Body to empty array. This handles both the empty case and the - // nil case (whereas by default the result would be null if body was nil.) - if len(body) == 0 { - return []byte(`[]`), nil - } - ret, err := json.Marshal([]*Expr(body)) - return ret, err -} - -// Append adds the expr to the body and updates the expr's index accordingly. -func (body *Body) Append(expr *Expr) { - n := len(*body) - expr.Index = n - *body = append(*body, expr) -} - -// Set sets the expr in the body at the specified position and updates the -// expr's index accordingly. -func (body Body) Set(expr *Expr, pos int) { - body[pos] = expr - expr.Index = pos -} - -// Compare returns an integer indicating whether body is less than, equal to, -// or greater than other. -// -// If body is a subset of other, it is considered less than (and vice versa). -func (body Body) Compare(other Body) int { - minLen := len(body) - if len(other) < minLen { - minLen = len(other) - } - for i := 0; i < minLen; i++ { - if cmp := body[i].Compare(other[i]); cmp != 0 { - return cmp - } - } - if len(body) < len(other) { - return -1 - } - if len(other) < len(body) { - return 1 - } - return 0 -} - -// Copy returns a deep copy of body. -func (body Body) Copy() Body { - cpy := make(Body, len(body)) - for i := range body { - cpy[i] = body[i].Copy() - } - return cpy -} - -// Contains returns true if this body contains the given expression. -func (body Body) Contains(x *Expr) bool { - for _, e := range body { - if e.Equal(x) { - return true - } - } - return false -} - -// Equal returns true if this Body is equal to the other Body. -func (body Body) Equal(other Body) bool { - return body.Compare(other) == 0 -} - -// Hash returns the hash code for the Body. -func (body Body) Hash() int { - s := 0 - for _, e := range body { - s += e.Hash() - } - return s -} - -// IsGround returns true if all of the expressions in the Body are ground. -func (body Body) IsGround() bool { - for _, e := range body { - if !e.IsGround() { - return false - } - } - return true -} - -// Loc returns the location of the Body in the definition. -func (body Body) Loc() *Location { - if len(body) == 0 { - return nil - } - return body[0].Location -} - -// SetLoc sets the location on body. -func (body Body) SetLoc(loc *Location) { - if len(body) != 0 { - body[0].SetLocation(loc) - } -} - -func (body Body) String() string { - buf := make([]string, 0, len(body)) - for _, v := range body { - buf = append(buf, v.String()) - } - return strings.Join(buf, "; ") -} - -// Vars returns a VarSet containing variables in body. The params can be set to -// control which vars are included. -func (body Body) Vars(params VarVisitorParams) VarSet { - vis := NewVarVisitor().WithParams(params) - vis.Walk(body) - return vis.Vars() + return v1.NewBody(exprs...) } // NewExpr returns a new Expr object. func NewExpr(terms interface{}) *Expr { - switch terms.(type) { - case *SomeDecl, *Every, *Term, []*Term: // ok - default: - panic("unreachable") - } - return &Expr{ - Negated: false, - Terms: terms, - Index: 0, - With: nil, - } -} - -// Complement returns a copy of this expression with the negation flag flipped. -func (expr *Expr) Complement() *Expr { - cpy := *expr - cpy.Negated = !cpy.Negated - return &cpy -} - -// Equal returns true if this Expr equals the other Expr. -func (expr *Expr) Equal(other *Expr) bool { - return expr.Compare(other) == 0 -} - -// Compare returns an integer indicating whether expr is less than, equal to, -// or greater than other. -// -// Expressions are compared as follows: -// -// 1. Declarations are always less than other expressions. -// 2. Preceding expression (by Index) is always less than the other expression. -// 3. Non-negated expressions are always less than negated expressions. -// 4. Single term expressions are always less than built-in expressions. -// -// Otherwise, the expression terms are compared normally. If both expressions -// have the same terms, the modifiers are compared. -func (expr *Expr) Compare(other *Expr) int { - - if expr == nil { - if other == nil { - return 0 - } - return -1 - } else if other == nil { - return 1 - } - - o1 := expr.sortOrder() - o2 := other.sortOrder() - if o1 < o2 { - return -1 - } else if o2 < o1 { - return 1 - } - - switch { - case expr.Index < other.Index: - return -1 - case expr.Index > other.Index: - return 1 - } - - switch { - case expr.Negated && !other.Negated: - return 1 - case !expr.Negated && other.Negated: - return -1 - } - - switch t := expr.Terms.(type) { - case *Term: - if cmp := Compare(t.Value, other.Terms.(*Term).Value); cmp != 0 { - return cmp - } - case []*Term: - if cmp := termSliceCompare(t, other.Terms.([]*Term)); cmp != 0 { - return cmp - } - case *SomeDecl: - if cmp := Compare(t, other.Terms.(*SomeDecl)); cmp != 0 { - return cmp - } - case *Every: - if cmp := Compare(t, other.Terms.(*Every)); cmp != 0 { - return cmp - } - } - - return withSliceCompare(expr.With, other.With) -} - -func (expr *Expr) sortOrder() int { - switch expr.Terms.(type) { - case *SomeDecl: - return 0 - case *Term: - return 1 - case []*Term: - return 2 - case *Every: - return 3 - } - return -1 -} - -// CopyWithoutTerms returns a deep copy of expr without its Terms -func (expr *Expr) CopyWithoutTerms() *Expr { - cpy := *expr - - cpy.With = make([]*With, len(expr.With)) - for i := range expr.With { - cpy.With[i] = expr.With[i].Copy() - } - - return &cpy -} - -// Copy returns a deep copy of expr. -func (expr *Expr) Copy() *Expr { - - cpy := expr.CopyWithoutTerms() - - switch ts := expr.Terms.(type) { - case *SomeDecl: - cpy.Terms = ts.Copy() - case []*Term: - cpyTs := make([]*Term, len(ts)) - for i := range ts { - cpyTs[i] = ts[i].Copy() - } - cpy.Terms = cpyTs - case *Term: - cpy.Terms = ts.Copy() - case *Every: - cpy.Terms = ts.Copy() - } - - return cpy -} - -// Hash returns the hash code of the Expr. -func (expr *Expr) Hash() int { - s := expr.Index - switch ts := expr.Terms.(type) { - case *SomeDecl: - s += ts.Hash() - case []*Term: - for _, t := range ts { - s += t.Value.Hash() - } - case *Term: - s += ts.Value.Hash() - } - if expr.Negated { - s++ - } - for _, w := range expr.With { - s += w.Hash() - } - return s -} - -// IncludeWith returns a copy of expr with the with modifier appended. -func (expr *Expr) IncludeWith(target *Term, value *Term) *Expr { - cpy := *expr - cpy.With = append(cpy.With, &With{Target: target, Value: value}) - return &cpy -} - -// NoWith returns a copy of expr where the with modifier has been removed. -func (expr *Expr) NoWith() *Expr { - cpy := *expr - cpy.With = nil - return &cpy -} - -// IsEquality returns true if this is an equality expression. -func (expr *Expr) IsEquality() bool { - return isGlobalBuiltin(expr, Var(Equality.Name)) -} - -// IsAssignment returns true if this an assignment expression. -func (expr *Expr) IsAssignment() bool { - return isGlobalBuiltin(expr, Var(Assign.Name)) -} - -// IsCall returns true if this expression calls a function. -func (expr *Expr) IsCall() bool { - _, ok := expr.Terms.([]*Term) - return ok -} - -// IsEvery returns true if this expression is an 'every' expression. -func (expr *Expr) IsEvery() bool { - _, ok := expr.Terms.(*Every) - return ok -} - -// IsSome returns true if this expression is a 'some' expression. -func (expr *Expr) IsSome() bool { - _, ok := expr.Terms.(*SomeDecl) - return ok -} - -// Operator returns the name of the function or built-in this expression refers -// to. If this expression is not a function call, returns nil. -func (expr *Expr) Operator() Ref { - op := expr.OperatorTerm() - if op == nil { - return nil - } - return op.Value.(Ref) -} - -// OperatorTerm returns the name of the function or built-in this expression -// refers to. If this expression is not a function call, returns nil. -func (expr *Expr) OperatorTerm() *Term { - terms, ok := expr.Terms.([]*Term) - if !ok || len(terms) == 0 { - return nil - } - return terms[0] -} - -// Operand returns the term at the zero-based pos. If the expr does not include -// at least pos+1 terms, this function returns nil. -func (expr *Expr) Operand(pos int) *Term { - terms, ok := expr.Terms.([]*Term) - if !ok { - return nil - } - idx := pos + 1 - if idx < len(terms) { - return terms[idx] - } - return nil -} - -// Operands returns the built-in function operands. -func (expr *Expr) Operands() []*Term { - terms, ok := expr.Terms.([]*Term) - if !ok { - return nil - } - return terms[1:] -} - -// IsGround returns true if all of the expression terms are ground. -func (expr *Expr) IsGround() bool { - switch ts := expr.Terms.(type) { - case []*Term: - for _, t := range ts[1:] { - if !t.IsGround() { - return false - } - } - case *Term: - return ts.IsGround() - } - return true -} - -// SetOperator sets the expr's operator and returns the expr itself. If expr is -// not a call expr, this function will panic. -func (expr *Expr) SetOperator(term *Term) *Expr { - expr.Terms.([]*Term)[0] = term - return expr -} - -// SetLocation sets the expr's location and returns the expr itself. -func (expr *Expr) SetLocation(loc *Location) *Expr { - expr.Location = loc - return expr -} - -// Loc returns the Location of expr. -func (expr *Expr) Loc() *Location { - if expr == nil { - return nil - } - return expr.Location -} - -// SetLoc sets the location on expr. -func (expr *Expr) SetLoc(loc *Location) { - expr.SetLocation(loc) -} - -func (expr *Expr) String() string { - buf := make([]string, 0, 2+len(expr.With)) - if expr.Negated { - buf = append(buf, "not") - } - switch t := expr.Terms.(type) { - case []*Term: - if expr.IsEquality() && validEqAssignArgCount(expr) { - buf = append(buf, fmt.Sprintf("%v %v %v", t[1], Equality.Infix, t[2])) - } else { - buf = append(buf, Call(t).String()) - } - case fmt.Stringer: - buf = append(buf, t.String()) - } - - for i := range expr.With { - buf = append(buf, expr.With[i].String()) - } - - return strings.Join(buf, " ") -} - -func (expr *Expr) setJSONOptions(opts astJSON.Options) { - expr.jsonOptions = opts - if expr.Location != nil { - expr.Location.JSONOptions = opts - } -} - -func (expr *Expr) MarshalJSON() ([]byte, error) { - data := map[string]interface{}{ - "terms": expr.Terms, - "index": expr.Index, - } - - if len(expr.With) > 0 { - data["with"] = expr.With - } - - if expr.Generated { - data["generated"] = true - } - - if expr.Negated { - data["negated"] = true - } - - if expr.jsonOptions.MarshalOptions.IncludeLocation.Expr { - if expr.Location != nil { - data["location"] = expr.Location - } - } - - return json.Marshal(data) -} - -// UnmarshalJSON parses the byte array and stores the result in expr. -func (expr *Expr) UnmarshalJSON(bs []byte) error { - v := map[string]interface{}{} - if err := util.UnmarshalJSON(bs, &v); err != nil { - return err - } - return unmarshalExpr(expr, v) -} - -// Vars returns a VarSet containing variables in expr. The params can be set to -// control which vars are included. -func (expr *Expr) Vars(params VarVisitorParams) VarSet { - vis := NewVarVisitor().WithParams(params) - vis.Walk(expr) - return vis.Vars() + return v1.NewExpr(terms) } // NewBuiltinExpr creates a new Expr object with the supplied terms. // The builtin operator must be the first term. func NewBuiltinExpr(terms ...*Term) *Expr { - return &Expr{Terms: terms} -} - -func (expr *Expr) CogeneratedExprs() []*Expr { - visited := map[*Expr]struct{}{} - visitCogeneratedExprs(expr, func(e *Expr) bool { - if expr.Equal(e) { - return true - } - if _, ok := visited[e]; ok { - return true - } - visited[e] = struct{}{} - return false - }) - - result := make([]*Expr, 0, len(visited)) - for e := range visited { - result = append(result, e) - } - return result -} - -func (expr *Expr) BaseCogeneratedExpr() *Expr { - if expr.generatedFrom == nil { - return expr - } - return expr.generatedFrom.BaseCogeneratedExpr() -} - -func visitCogeneratedExprs(expr *Expr, f func(*Expr) bool) { - if parent := expr.generatedFrom; parent != nil { - if stop := f(parent); !stop { - visitCogeneratedExprs(parent, f) - } - } - for _, child := range expr.generates { - if stop := f(child); !stop { - visitCogeneratedExprs(child, f) - } - } -} - -func (d *SomeDecl) String() string { - if call, ok := d.Symbols[0].Value.(Call); ok { - if len(call) == 4 { - return "some " + call[1].String() + ", " + call[2].String() + " in " + call[3].String() - } - return "some " + call[1].String() + " in " + call[2].String() - } - buf := make([]string, len(d.Symbols)) - for i := range buf { - buf[i] = d.Symbols[i].String() - } - return "some " + strings.Join(buf, ", ") -} - -// SetLoc sets the Location on d. -func (d *SomeDecl) SetLoc(loc *Location) { - d.Location = loc -} - -// Loc returns the Location of d. -func (d *SomeDecl) Loc() *Location { - return d.Location -} - -// Copy returns a deep copy of d. -func (d *SomeDecl) Copy() *SomeDecl { - cpy := *d - cpy.Symbols = termSliceCopy(d.Symbols) - return &cpy -} - -// Compare returns an integer indicating whether d is less than, equal to, or -// greater than other. -func (d *SomeDecl) Compare(other *SomeDecl) int { - return termSliceCompare(d.Symbols, other.Symbols) -} - -// Hash returns a hash code of d. -func (d *SomeDecl) Hash() int { - return termSliceHash(d.Symbols) -} - -func (d *SomeDecl) setJSONOptions(opts astJSON.Options) { - d.jsonOptions = opts - if d.Location != nil { - d.Location.JSONOptions = opts - } -} - -func (d *SomeDecl) MarshalJSON() ([]byte, error) { - data := map[string]interface{}{ - "symbols": d.Symbols, - } - - if d.jsonOptions.MarshalOptions.IncludeLocation.SomeDecl { - if d.Location != nil { - data["location"] = d.Location - } - } - - return json.Marshal(data) -} - -func (q *Every) String() string { - if q.Key != nil { - return fmt.Sprintf("every %s, %s in %s { %s }", - q.Key, - q.Value, - q.Domain, - q.Body) - } - return fmt.Sprintf("every %s in %s { %s }", - q.Value, - q.Domain, - q.Body) -} - -func (q *Every) Loc() *Location { - return q.Location -} - -func (q *Every) SetLoc(l *Location) { - q.Location = l -} - -// Copy returns a deep copy of d. -func (q *Every) Copy() *Every { - cpy := *q - cpy.Key = q.Key.Copy() - cpy.Value = q.Value.Copy() - cpy.Domain = q.Domain.Copy() - cpy.Body = q.Body.Copy() - return &cpy -} - -func (q *Every) Compare(other *Every) int { - for _, terms := range [][2]*Term{ - {q.Key, other.Key}, - {q.Value, other.Value}, - {q.Domain, other.Domain}, - } { - if d := Compare(terms[0], terms[1]); d != 0 { - return d - } - } - return q.Body.Compare(other.Body) -} - -// KeyValueVars returns the key and val arguments of an `every` -// expression, if they are non-nil and not wildcards. -func (q *Every) KeyValueVars() VarSet { - vis := &VarVisitor{vars: VarSet{}} - if q.Key != nil { - vis.Walk(q.Key) - } - vis.Walk(q.Value) - return vis.vars -} - -func (q *Every) setJSONOptions(opts astJSON.Options) { - q.jsonOptions = opts - if q.Location != nil { - q.Location.JSONOptions = opts - } -} - -func (q *Every) MarshalJSON() ([]byte, error) { - data := map[string]interface{}{ - "key": q.Key, - "value": q.Value, - "domain": q.Domain, - "body": q.Body, - } - - if q.jsonOptions.MarshalOptions.IncludeLocation.Every { - if q.Location != nil { - data["location"] = q.Location - } - } - - return json.Marshal(data) -} - -func (w *With) String() string { - return "with " + w.Target.String() + " as " + w.Value.String() -} - -// Equal returns true if this With is equals the other With. -func (w *With) Equal(other *With) bool { - return Compare(w, other) == 0 -} - -// Compare returns an integer indicating whether w is less than, equal to, or -// greater than other. -func (w *With) Compare(other *With) int { - if w == nil { - if other == nil { - return 0 - } - return -1 - } else if other == nil { - return 1 - } - if cmp := Compare(w.Target, other.Target); cmp != 0 { - return cmp - } - return Compare(w.Value, other.Value) -} - -// Copy returns a deep copy of w. -func (w *With) Copy() *With { - cpy := *w - cpy.Value = w.Value.Copy() - cpy.Target = w.Target.Copy() - return &cpy -} - -// Hash returns the hash code of the With. -func (w With) Hash() int { - return w.Target.Hash() + w.Value.Hash() -} - -// SetLocation sets the location on w. -func (w *With) SetLocation(loc *Location) *With { - w.Location = loc - return w -} - -// Loc returns the Location of w. -func (w *With) Loc() *Location { - if w == nil { - return nil - } - return w.Location -} - -// SetLoc sets the location on w. -func (w *With) SetLoc(loc *Location) { - w.Location = loc -} - -func (w *With) setJSONOptions(opts astJSON.Options) { - w.jsonOptions = opts - if w.Location != nil { - w.Location.JSONOptions = opts - } -} - -func (w *With) MarshalJSON() ([]byte, error) { - data := map[string]interface{}{ - "target": w.Target, - "value": w.Value, - } - - if w.jsonOptions.MarshalOptions.IncludeLocation.With { - if w.Location != nil { - data["location"] = w.Location - } - } - - return json.Marshal(data) + return v1.NewBuiltinExpr(terms...) } // Copy returns a deep copy of the AST node x. If x is not an AST node, x is returned unmodified. func Copy(x interface{}) interface{} { - switch x := x.(type) { - case *Module: - return x.Copy() - case *Package: - return x.Copy() - case *Import: - return x.Copy() - case *Rule: - return x.Copy() - case *Head: - return x.Copy() - case Args: - return x.Copy() - case Body: - return x.Copy() - case *Expr: - return x.Copy() - case *With: - return x.Copy() - case *SomeDecl: - return x.Copy() - case *Every: - return x.Copy() - case *Term: - return x.Copy() - case *ArrayComprehension: - return x.Copy() - case *SetComprehension: - return x.Copy() - case *ObjectComprehension: - return x.Copy() - case Set: - return x.Copy() - case *object: - return x.Copy() - case *Array: - return x.Copy() - case Ref: - return x.Copy() - case Call: - return x.Copy() - case *Comment: - return x.Copy() - } - return x + return v1.Copy(x) } // RuleSet represents a collection of rules that produce a virtual document. -type RuleSet []*Rule +type RuleSet = v1.RuleSet // NewRuleSet returns a new RuleSet containing the given rules. func NewRuleSet(rules ...*Rule) RuleSet { - rs := make(RuleSet, 0, len(rules)) - for _, rule := range rules { - rs.Add(rule) - } - return rs -} - -// Add inserts the rule into rs. -func (rs *RuleSet) Add(rule *Rule) { - for _, exist := range *rs { - if exist.Equal(rule) { - return - } - } - *rs = append(*rs, rule) -} - -// Contains returns true if rs contains rule. -func (rs RuleSet) Contains(rule *Rule) bool { - for i := range rs { - if rs[i].Equal(rule) { - return true - } - } - return false -} - -// Diff returns a new RuleSet containing rules in rs that are not in other. -func (rs RuleSet) Diff(other RuleSet) RuleSet { - result := NewRuleSet() - for i := range rs { - if !other.Contains(rs[i]) { - result.Add(rs[i]) - } - } - return result -} - -// Equal returns true if rs equals other. -func (rs RuleSet) Equal(other RuleSet) bool { - return len(rs.Diff(other)) == 0 && len(other.Diff(rs)) == 0 -} - -// Merge returns a ruleset containing the union of rules from rs an other. -func (rs RuleSet) Merge(other RuleSet) RuleSet { - result := NewRuleSet() - for i := range rs { - result.Add(rs[i]) - } - for i := range other { - result.Add(other[i]) - } - return result -} - -func (rs RuleSet) String() string { - buf := make([]string, 0, len(rs)) - for _, rule := range rs { - buf = append(buf, rule.String()) - } - return "{" + strings.Join(buf, ", ") + "}" -} - -// Returns true if the equality or assignment expression referred to by expr -// has a valid number of arguments. -func validEqAssignArgCount(expr *Expr) bool { - return len(expr.Operands()) == 2 -} - -// this function checks if the expr refers to a non-namespaced (global) built-in -// function like eq, gt, plus, etc. -func isGlobalBuiltin(expr *Expr, name Var) bool { - terms, ok := expr.Terms.([]*Term) - if !ok { - return false - } - - // NOTE(tsandall): do not use Term#Equal or Value#Compare to avoid - // allocation here. - ref, ok := terms[0].Value.(Ref) - if !ok || len(ref) != 1 { - return false - } - if head, ok := ref[0].Value.(Var); ok { - return head.Equal(name) - } - return false + return v1.NewRuleSet(rules...) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/pretty.go b/vendor/github.com/open-policy-agent/opa/ast/pretty.go index b4f05ad50..f2b8104e0 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/pretty.go +++ b/vendor/github.com/open-policy-agent/opa/ast/pretty.go @@ -5,78 +5,14 @@ package ast import ( - "fmt" "io" - "strings" + + v1 "github.com/open-policy-agent/opa/v1/ast" ) // Pretty writes a pretty representation of the AST rooted at x to w. // // This is function is intended for debug purposes when inspecting ASTs. func Pretty(w io.Writer, x interface{}) { - pp := &prettyPrinter{ - depth: -1, - w: w, - } - NewBeforeAfterVisitor(pp.Before, pp.After).Walk(x) -} - -type prettyPrinter struct { - depth int - w io.Writer -} - -func (pp *prettyPrinter) Before(x interface{}) bool { - switch x.(type) { - case *Term: - default: - pp.depth++ - } - - switch x := x.(type) { - case *Term: - return false - case Args: - if len(x) == 0 { - return false - } - pp.writeType(x) - case *Expr: - extras := []string{} - if x.Negated { - extras = append(extras, "negated") - } - extras = append(extras, fmt.Sprintf("index=%d", x.Index)) - pp.writeIndent("%v %v", TypeName(x), strings.Join(extras, " ")) - case Null, Boolean, Number, String, Var: - pp.writeValue(x) - default: - pp.writeType(x) - } - return false -} - -func (pp *prettyPrinter) After(x interface{}) { - switch x.(type) { - case *Term: - default: - pp.depth-- - } -} - -func (pp *prettyPrinter) writeValue(x interface{}) { - pp.writeIndent(fmt.Sprint(x)) -} - -func (pp *prettyPrinter) writeType(x interface{}) { - pp.writeIndent(TypeName(x)) -} - -func (pp *prettyPrinter) writeIndent(f string, a ...interface{}) { - pad := strings.Repeat(" ", pp.depth) - pp.write(pad+f, a...) -} - -func (pp *prettyPrinter) write(f string, a ...interface{}) { - fmt.Fprintf(pp.w, f+"\n", a...) + v1.Pretty(w, x) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/schema.go b/vendor/github.com/open-policy-agent/opa/ast/schema.go index 8c96ac624..979958a3c 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/schema.go +++ b/vendor/github.com/open-policy-agent/opa/ast/schema.go @@ -5,59 +5,13 @@ package ast import ( - "fmt" - - "github.com/open-policy-agent/opa/types" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // SchemaSet holds a map from a path to a schema. -type SchemaSet struct { - m *util.HashMap -} +type SchemaSet = v1.SchemaSet // NewSchemaSet returns an empty SchemaSet. func NewSchemaSet() *SchemaSet { - - eqFunc := func(a, b util.T) bool { - return a.(Ref).Equal(b.(Ref)) - } - - hashFunc := func(x util.T) int { return x.(Ref).Hash() } - - return &SchemaSet{ - m: util.NewHashMap(eqFunc, hashFunc), - } -} - -// Put inserts a raw schema into the set. -func (ss *SchemaSet) Put(path Ref, raw interface{}) { - ss.m.Put(path, raw) -} - -// Get returns the raw schema identified by the path. -func (ss *SchemaSet) Get(path Ref) interface{} { - if ss == nil { - return nil - } - x, ok := ss.m.Get(path) - if !ok { - return nil - } - return x -} - -func loadSchema(raw interface{}, allowNet []string) (types.Type, error) { - - jsonSchema, err := compileSchema(raw, allowNet) - if err != nil { - return nil, err - } - - tpe, err := newSchemaParser().parseSchema(jsonSchema.RootSchema) - if err != nil { - return nil, fmt.Errorf("type checking: %w", err) - } - - return tpe, nil + return v1.NewSchemaSet() } diff --git a/vendor/github.com/open-policy-agent/opa/ast/strings.go b/vendor/github.com/open-policy-agent/opa/ast/strings.go index e489f6977..ef9354bf7 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/strings.go +++ b/vendor/github.com/open-policy-agent/opa/ast/strings.go @@ -5,14 +5,10 @@ package ast import ( - "reflect" - "strings" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // TypeName returns a human readable name for the AST element type. func TypeName(x interface{}) string { - if _, ok := x.(*lazyObj); ok { - return "object" - } - return strings.ToLower(reflect.Indirect(reflect.ValueOf(x)).Type().Name()) + return v1.TypeName(x) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/term.go b/vendor/github.com/open-policy-agent/opa/ast/term.go index ce8ee4853..a5d146ea2 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/term.go +++ b/vendor/github.com/open-policy-agent/opa/ast/term.go @@ -1,40 +1,22 @@ -// Copyright 2016 The OPA Authors. All rights reserved. +// Copyright 2024 The OPA Authors. All rights reserved. // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. -// nolint: deadcode // Public API. package ast import ( - "bytes" "encoding/json" - "errors" - "fmt" "io" - "math" - "math/big" - "net/url" - "regexp" - "sort" - "strconv" - "strings" - "sync" - - "github.com/OneOfOne/xxhash" - - astJSON "github.com/open-policy-agent/opa/ast/json" - "github.com/open-policy-agent/opa/ast/location" - "github.com/open-policy-agent/opa/util" -) -var errFindNotFound = fmt.Errorf("find: not found") + v1 "github.com/open-policy-agent/opa/v1/ast" +) // Location records a position in source code. -type Location = location.Location +type Location = v1.Location // NewLocation returns a new Location object. func NewLocation(text []byte, file string, row int, col int) *Location { - return location.NewLocation(text, file, row, col) + return v1.NewLocation(text, file, row, col) } // Value declares the common interface for all Term values. Every kind of Term value @@ -45,226 +27,58 @@ func NewLocation(text []byte, file string, row int, col int) *Location { // - Variables, References // - Array, Set, and Object Comprehensions // - Calls -type Value interface { - Compare(other Value) int // Compare returns <0, 0, or >0 if this Value is less than, equal to, or greater than other, respectively. - Find(path Ref) (Value, error) // Find returns value referred to by path or an error if path is not found. - Hash() int // Returns hash code of the value. - IsGround() bool // IsGround returns true if this value is not a variable or contains no variables. - String() string // String returns a human readable string representation of the value. -} +type Value = v1.Value // InterfaceToValue converts a native Go value x to a Value. func InterfaceToValue(x interface{}) (Value, error) { - switch x := x.(type) { - case nil: - return Null{}, nil - case bool: - return Boolean(x), nil - case json.Number: - return Number(x), nil - case int64: - return int64Number(x), nil - case uint64: - return uint64Number(x), nil - case float64: - return floatNumber(x), nil - case int: - return intNumber(x), nil - case string: - return String(x), nil - case []interface{}: - r := make([]*Term, len(x)) - for i, e := range x { - e, err := InterfaceToValue(e) - if err != nil { - return nil, err - } - r[i] = &Term{Value: e} - } - return NewArray(r...), nil - case map[string]interface{}: - r := newobject(len(x)) - for k, v := range x { - k, err := InterfaceToValue(k) - if err != nil { - return nil, err - } - v, err := InterfaceToValue(v) - if err != nil { - return nil, err - } - r.Insert(NewTerm(k), NewTerm(v)) - } - return r, nil - case map[string]string: - r := newobject(len(x)) - for k, v := range x { - k, err := InterfaceToValue(k) - if err != nil { - return nil, err - } - v, err := InterfaceToValue(v) - if err != nil { - return nil, err - } - r.Insert(NewTerm(k), NewTerm(v)) - } - return r, nil - default: - ptr := util.Reference(x) - if err := util.RoundTrip(ptr); err != nil { - return nil, fmt.Errorf("ast: interface conversion: %w", err) - } - return InterfaceToValue(*ptr) - } + return v1.InterfaceToValue(x) } // ValueFromReader returns an AST value from a JSON serialized value in the reader. func ValueFromReader(r io.Reader) (Value, error) { - var x interface{} - if err := util.NewJSONDecoder(r).Decode(&x); err != nil { - return nil, err - } - return InterfaceToValue(x) + return v1.ValueFromReader(r) } // As converts v into a Go native type referred to by x. func As(v Value, x interface{}) error { - return util.NewJSONDecoder(bytes.NewBufferString(v.String())).Decode(x) + return v1.As(v, x) } // Resolver defines the interface for resolving references to native Go values. -type Resolver interface { - Resolve(Ref) (interface{}, error) -} +type Resolver = v1.Resolver // ValueResolver defines the interface for resolving references to AST values. -type ValueResolver interface { - Resolve(Ref) (Value, error) -} +type ValueResolver = v1.ValueResolver // UnknownValueErr indicates a ValueResolver was unable to resolve a reference // because the reference refers to an unknown value. -type UnknownValueErr struct{} - -func (UnknownValueErr) Error() string { - return "unknown value" -} +type UnknownValueErr = v1.UnknownValueErr // IsUnknownValueErr returns true if the err is an UnknownValueErr. func IsUnknownValueErr(err error) bool { - _, ok := err.(UnknownValueErr) - return ok -} - -type illegalResolver struct{} - -func (illegalResolver) Resolve(ref Ref) (interface{}, error) { - return nil, fmt.Errorf("illegal value: %v", ref) + return v1.IsUnknownValueErr(err) } // ValueToInterface returns the Go representation of an AST value. The AST // value should not contain any values that require evaluation (e.g., vars, // comprehensions, etc.) func ValueToInterface(v Value, resolver Resolver) (interface{}, error) { - return valueToInterface(v, resolver, JSONOpt{}) -} - -func valueToInterface(v Value, resolver Resolver, opt JSONOpt) (interface{}, error) { - switch v := v.(type) { - case Null: - return nil, nil - case Boolean: - return bool(v), nil - case Number: - return json.Number(v), nil - case String: - return string(v), nil - case *Array: - buf := []interface{}{} - for i := 0; i < v.Len(); i++ { - x1, err := valueToInterface(v.Elem(i).Value, resolver, opt) - if err != nil { - return nil, err - } - buf = append(buf, x1) - } - return buf, nil - case *object: - buf := make(map[string]interface{}, v.Len()) - err := v.Iter(func(k, v *Term) error { - ki, err := valueToInterface(k.Value, resolver, opt) - if err != nil { - return err - } - var str string - var ok bool - if str, ok = ki.(string); !ok { - var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(ki); err != nil { - return err - } - str = strings.TrimSpace(buf.String()) - } - vi, err := valueToInterface(v.Value, resolver, opt) - if err != nil { - return err - } - buf[str] = vi - return nil - }) - if err != nil { - return nil, err - } - return buf, nil - case *lazyObj: - if opt.CopyMaps { - return valueToInterface(v.force(), resolver, opt) - } - return v.native, nil - case Set: - buf := []interface{}{} - iter := func(x *Term) error { - x1, err := valueToInterface(x.Value, resolver, opt) - if err != nil { - return err - } - buf = append(buf, x1) - return nil - } - var err error - if opt.SortSets { - err = v.Sorted().Iter(iter) - } else { - err = v.Iter(iter) - } - if err != nil { - return nil, err - } - return buf, nil - case Ref: - return resolver.Resolve(v) - default: - return nil, fmt.Errorf("%v requires evaluation", TypeName(v)) - } + return v1.ValueToInterface(v, resolver) } // JSON returns the JSON representation of v. The value must not contain any // refs or terms that require evaluation (e.g., vars, comprehensions, etc.) func JSON(v Value) (interface{}, error) { - return JSONWithOpt(v, JSONOpt{}) + return v1.JSON(v) } // JSONOpt defines parameters for AST to JSON conversion. -type JSONOpt struct { - SortSets bool // sort sets before serializing (this makes conversion more expensive) - CopyMaps bool // enforces copying of map[string]interface{} read from the store -} +type JSONOpt = v1.JSONOpt // JSONWithOpt returns the JSON representation of v. The value must not contain any // refs or terms that require evaluation (e.g., vars, comprehensions, etc.) func JSONWithOpt(v Value, opt JSONOpt) (interface{}, error) { - return valueToInterface(v, illegalResolver{}, opt) + return v1.JSONWithOpt(v, opt) } // MustJSON returns the JSON representation of v. The value must not contain any @@ -272,3003 +86,221 @@ func JSONWithOpt(v Value, opt JSONOpt) (interface{}, error) { // the conversion fails, this function will panic. This function is mostly for // test purposes. func MustJSON(v Value) interface{} { - r, err := JSON(v) - if err != nil { - panic(err) - } - return r + return v1.MustJSON(v) } // MustInterfaceToValue converts a native Go value x to a Value. If the // conversion fails, this function will panic. This function is mostly for test // purposes. func MustInterfaceToValue(x interface{}) Value { - v, err := InterfaceToValue(x) - if err != nil { - panic(err) - } - return v + return v1.MustInterfaceToValue(x) } // Term is an argument to a function. -type Term struct { - Value Value `json:"value"` // the value of the Term as represented in Go - Location *Location `json:"location,omitempty"` // the location of the Term in the source - - jsonOptions astJSON.Options -} +type Term = v1.Term // NewTerm returns a new Term object. func NewTerm(v Value) *Term { - return &Term{ - Value: v, - } -} - -// SetLocation updates the term's Location and returns the term itself. -func (term *Term) SetLocation(loc *Location) *Term { - term.Location = loc - return term -} - -// Loc returns the Location of term. -func (term *Term) Loc() *Location { - if term == nil { - return nil - } - return term.Location -} - -// SetLoc sets the location on term. -func (term *Term) SetLoc(loc *Location) { - term.SetLocation(loc) -} - -// Copy returns a deep copy of term. -func (term *Term) Copy() *Term { - - if term == nil { - return nil - } - - cpy := *term - - switch v := term.Value.(type) { - case Null, Boolean, Number, String, Var: - cpy.Value = v - case Ref: - cpy.Value = v.Copy() - case *Array: - cpy.Value = v.Copy() - case Set: - cpy.Value = v.Copy() - case *object: - cpy.Value = v.Copy() - case *ArrayComprehension: - cpy.Value = v.Copy() - case *ObjectComprehension: - cpy.Value = v.Copy() - case *SetComprehension: - cpy.Value = v.Copy() - case Call: - cpy.Value = v.Copy() - } - - return &cpy -} - -// Equal returns true if this term equals the other term. Equality is -// defined for each kind of term. -func (term *Term) Equal(other *Term) bool { - if term == nil && other != nil { - return false - } - if term != nil && other == nil { - return false - } - if term == other { - return true - } - - // TODO(tsandall): This early-exit avoids allocations for types that have - // Equal() functions that just use == underneath. We should revisit the - // other types and implement Equal() functions that do not require - // allocations. - switch v := term.Value.(type) { - case Null: - return v.Equal(other.Value) - case Boolean: - return v.Equal(other.Value) - case Number: - return v.Equal(other.Value) - case String: - return v.Equal(other.Value) - case Var: - return v.Equal(other.Value) - } - - return term.Value.Compare(other.Value) == 0 -} - -// Get returns a value referred to by name from the term. -func (term *Term) Get(name *Term) *Term { - switch v := term.Value.(type) { - case *object: - return v.Get(name) - case *Array: - return v.Get(name) - case interface { - Get(*Term) *Term - }: - return v.Get(name) - case Set: - if v.Contains(name) { - return name - } - } - return nil -} - -// Hash returns the hash code of the Term's Value. Its Location -// is ignored. -func (term *Term) Hash() int { - return term.Value.Hash() -} - -// IsGround returns true if this term's Value is ground. -func (term *Term) IsGround() bool { - return term.Value.IsGround() -} - -func (term *Term) setJSONOptions(opts astJSON.Options) { - term.jsonOptions = opts - if term.Location != nil { - term.Location.JSONOptions = opts - } -} - -// MarshalJSON returns the JSON encoding of the term. -// -// Specialized marshalling logic is required to include a type hint for Value. -func (term *Term) MarshalJSON() ([]byte, error) { - d := map[string]interface{}{ - "type": TypeName(term.Value), - "value": term.Value, - } - if term.jsonOptions.MarshalOptions.IncludeLocation.Term { - if term.Location != nil { - d["location"] = term.Location - } - } - return json.Marshal(d) -} - -func (term *Term) String() string { - return term.Value.String() -} - -// UnmarshalJSON parses the byte array and stores the result in term. -// Specialized unmarshalling is required to handle Value and Location. -func (term *Term) UnmarshalJSON(bs []byte) error { - v := map[string]interface{}{} - if err := util.UnmarshalJSON(bs, &v); err != nil { - return err - } - val, err := unmarshalValue(v) - if err != nil { - return err - } - term.Value = val - - if loc, ok := v["location"].(map[string]interface{}); ok { - term.Location = &Location{} - err := unmarshalLocation(term.Location, loc) - if err != nil { - return err - } - } - return nil -} - -// Vars returns a VarSet with variables contained in this term. -func (term *Term) Vars() VarSet { - vis := &VarVisitor{vars: VarSet{}} - vis.Walk(term) - return vis.vars + return v1.NewTerm(v) } // IsConstant returns true if the AST value is constant. func IsConstant(v Value) bool { - found := false - vis := GenericVisitor{ - func(x interface{}) bool { - switch x.(type) { - case Var, Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Call: - found = true - return true - } - return false - }, - } - vis.Walk(v) - return !found + return v1.IsConstant(v) } // IsComprehension returns true if the supplied value is a comprehension. func IsComprehension(x Value) bool { - switch x.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension: - return true - } - return false + return v1.IsComprehension(x) } // ContainsRefs returns true if the Value v contains refs. func ContainsRefs(v interface{}) bool { - found := false - WalkRefs(v, func(Ref) bool { - found = true - return found - }) - return found + return v1.ContainsRefs(v) } // ContainsComprehensions returns true if the Value v contains comprehensions. func ContainsComprehensions(v interface{}) bool { - found := false - WalkClosures(v, func(x interface{}) bool { - switch x.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension: - found = true - return found - } - return found - }) - return found + return v1.ContainsComprehensions(v) } // ContainsClosures returns true if the Value v contains closures. func ContainsClosures(v interface{}) bool { - found := false - WalkClosures(v, func(x interface{}) bool { - switch x.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: - found = true - return found - } - return found - }) - return found + return v1.ContainsClosures(v) } // IsScalar returns true if the AST value is a scalar. func IsScalar(v Value) bool { - switch v.(type) { - case String: - return true - case Number: - return true - case Boolean: - return true - case Null: - return true - } - return false + return v1.IsScalar(v) } // Null represents the null value defined by JSON. -type Null struct{} +type Null = v1.Null // NullTerm creates a new Term with a Null value. func NullTerm() *Term { - return &Term{Value: Null{}} -} - -// Equal returns true if the other term Value is also Null. -func (null Null) Equal(other Value) bool { - switch other.(type) { - case Null: - return true - default: - return false - } -} - -// Compare compares null to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (null Null) Compare(other Value) int { - return Compare(null, other) -} - -// Find returns the current value or a not found error. -func (null Null) Find(path Ref) (Value, error) { - if len(path) == 0 { - return null, nil - } - return nil, errFindNotFound -} - -// Hash returns the hash code for the Value. -func (null Null) Hash() int { - return 0 -} - -// IsGround always returns true. -func (Null) IsGround() bool { - return true -} - -func (null Null) String() string { - return "null" + return v1.NullTerm() } // Boolean represents a boolean value defined by JSON. -type Boolean bool +type Boolean = v1.Boolean // BooleanTerm creates a new Term with a Boolean value. func BooleanTerm(b bool) *Term { - return &Term{Value: Boolean(b)} -} - -// Equal returns true if the other Value is a Boolean and is equal. -func (bol Boolean) Equal(other Value) bool { - switch other := other.(type) { - case Boolean: - return bol == other - default: - return false - } -} - -// Compare compares bol to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (bol Boolean) Compare(other Value) int { - return Compare(bol, other) -} - -// Find returns the current value or a not found error. -func (bol Boolean) Find(path Ref) (Value, error) { - if len(path) == 0 { - return bol, nil - } - return nil, errFindNotFound -} - -// Hash returns the hash code for the Value. -func (bol Boolean) Hash() int { - if bol { - return 1 - } - return 0 -} - -// IsGround always returns true. -func (Boolean) IsGround() bool { - return true -} - -func (bol Boolean) String() string { - return strconv.FormatBool(bool(bol)) + return v1.BooleanTerm(b) } // Number represents a numeric value as defined by JSON. -type Number json.Number +type Number = v1.Number // NumberTerm creates a new Term with a Number value. func NumberTerm(n json.Number) *Term { - return &Term{Value: Number(n)} + return v1.NumberTerm(n) } // IntNumberTerm creates a new Term with an integer Number value. func IntNumberTerm(i int) *Term { - return &Term{Value: Number(strconv.Itoa(i))} + return v1.IntNumberTerm(i) } // UIntNumberTerm creates a new Term with an unsigned integer Number value. func UIntNumberTerm(u uint64) *Term { - return &Term{Value: uint64Number(u)} + return v1.UIntNumberTerm(u) } // FloatNumberTerm creates a new Term with a floating point Number value. func FloatNumberTerm(f float64) *Term { - s := strconv.FormatFloat(f, 'g', -1, 64) - return &Term{Value: Number(s)} -} - -// Equal returns true if the other Value is a Number and is equal. -func (num Number) Equal(other Value) bool { - switch other := other.(type) { - case Number: - return Compare(num, other) == 0 - default: - return false - } -} - -// Compare compares num to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (num Number) Compare(other Value) int { - return Compare(num, other) -} - -// Find returns the current value or a not found error. -func (num Number) Find(path Ref) (Value, error) { - if len(path) == 0 { - return num, nil - } - return nil, errFindNotFound -} - -// Hash returns the hash code for the Value. -func (num Number) Hash() int { - f, err := json.Number(num).Float64() - if err != nil { - bs := []byte(num) - h := xxhash.Checksum64(bs) - return int(h) - } - return int(f) -} - -// Int returns the int representation of num if possible. -func (num Number) Int() (int, bool) { - i64, ok := num.Int64() - return int(i64), ok -} - -// Int64 returns the int64 representation of num if possible. -func (num Number) Int64() (int64, bool) { - i, err := json.Number(num).Int64() - if err != nil { - return 0, false - } - return i, true -} - -// Float64 returns the float64 representation of num if possible. -func (num Number) Float64() (float64, bool) { - f, err := json.Number(num).Float64() - if err != nil { - return 0, false - } - return f, true -} - -// IsGround always returns true. -func (Number) IsGround() bool { - return true -} - -// MarshalJSON returns JSON encoded bytes representing num. -func (num Number) MarshalJSON() ([]byte, error) { - return json.Marshal(json.Number(num)) -} - -func (num Number) String() string { - return string(num) -} - -func intNumber(i int) Number { - return Number(strconv.Itoa(i)) -} - -func int64Number(i int64) Number { - return Number(strconv.FormatInt(i, 10)) -} - -func uint64Number(u uint64) Number { - return Number(strconv.FormatUint(u, 10)) -} - -func floatNumber(f float64) Number { - return Number(strconv.FormatFloat(f, 'g', -1, 64)) + return v1.FloatNumberTerm(f) } // String represents a string value as defined by JSON. -type String string +type String = v1.String // StringTerm creates a new Term with a String value. func StringTerm(s string) *Term { - return &Term{Value: String(s)} -} - -// Equal returns true if the other Value is a String and is equal. -func (str String) Equal(other Value) bool { - switch other := other.(type) { - case String: - return str == other - default: - return false - } -} - -// Compare compares str to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (str String) Compare(other Value) int { - return Compare(str, other) -} - -// Find returns the current value or a not found error. -func (str String) Find(path Ref) (Value, error) { - if len(path) == 0 { - return str, nil - } - return nil, errFindNotFound -} - -// IsGround always returns true. -func (String) IsGround() bool { - return true -} - -func (str String) String() string { - return strconv.Quote(string(str)) -} - -// Hash returns the hash code for the Value. -func (str String) Hash() int { - h := xxhash.ChecksumString64S(string(str), hashSeed0) - return int(h) + return v1.StringTerm(s) } // Var represents a variable as defined by the language. -type Var string +type Var = v1.Var // VarTerm creates a new Term with a Variable value. func VarTerm(v string) *Term { - return &Term{Value: Var(v)} -} - -// Equal returns true if the other Value is a Variable and has the same value -// (name). -func (v Var) Equal(other Value) bool { - switch other := other.(type) { - case Var: - return v == other - default: - return false - } -} - -// Compare compares v to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (v Var) Compare(other Value) int { - return Compare(v, other) -} - -// Find returns the current value or a not found error. -func (v Var) Find(path Ref) (Value, error) { - if len(path) == 0 { - return v, nil - } - return nil, errFindNotFound -} - -// Hash returns the hash code for the Value. -func (v Var) Hash() int { - h := xxhash.ChecksumString64S(string(v), hashSeed0) - return int(h) -} - -// IsGround always returns false. -func (Var) IsGround() bool { - return false -} - -// IsWildcard returns true if this is a wildcard variable. -func (v Var) IsWildcard() bool { - return strings.HasPrefix(string(v), WildcardPrefix) -} - -// IsGenerated returns true if this variable was generated during compilation. -func (v Var) IsGenerated() bool { - return strings.HasPrefix(string(v), "__local") -} - -func (v Var) String() string { - // Special case for wildcard so that string representation is parseable. The - // parser mangles wildcard variables to make their names unique and uses an - // illegal variable name character (WildcardPrefix) to avoid conflicts. When - // we serialize the variable here, we need to make sure it's parseable. - if v.IsWildcard() { - return Wildcard.String() - } - return string(v) + return v1.VarTerm(v) } // Ref represents a reference as defined by the language. -type Ref []*Term +type Ref = v1.Ref // EmptyRef returns a new, empty reference. func EmptyRef() Ref { - return Ref([]*Term{}) + return v1.EmptyRef() } // PtrRef returns a new reference against the head for the pointer // s. Path components in the pointer are unescaped. func PtrRef(head *Term, s string) (Ref, error) { - s = strings.Trim(s, "/") - if s == "" { - return Ref{head}, nil - } - parts := strings.Split(s, "/") - if maxLen := math.MaxInt32; len(parts) >= maxLen { - return nil, fmt.Errorf("path too long: %s, %d > %d (max)", s, len(parts), maxLen) - } - ref := make(Ref, uint(len(parts))+1) - ref[0] = head - for i := 0; i < len(parts); i++ { - var err error - parts[i], err = url.PathUnescape(parts[i]) - if err != nil { - return nil, err - } - ref[i+1] = StringTerm(parts[i]) - } - return ref, nil + return v1.PtrRef(head, s) } // RefTerm creates a new Term with a Ref value. func RefTerm(r ...*Term) *Term { - return &Term{Value: Ref(r)} -} - -// Append returns a copy of ref with the term appended to the end. -func (ref Ref) Append(term *Term) Ref { - n := len(ref) - dst := make(Ref, n+1) - copy(dst, ref) - dst[n] = term - return dst -} - -// Insert returns a copy of the ref with x inserted at pos. If pos < len(ref), -// existing elements are shifted to the right. If pos > len(ref)+1 this -// function panics. -func (ref Ref) Insert(x *Term, pos int) Ref { - switch { - case pos == len(ref): - return ref.Append(x) - case pos > len(ref)+1: - panic("illegal index") - } - cpy := make(Ref, len(ref)+1) - copy(cpy, ref[:pos]) - cpy[pos] = x - copy(cpy[pos+1:], ref[pos:]) - return cpy -} - -// Extend returns a copy of ref with the terms from other appended. The head of -// other will be converted to a string. -func (ref Ref) Extend(other Ref) Ref { - dst := make(Ref, len(ref)+len(other)) - copy(dst, ref) - - head := other[0].Copy() - head.Value = String(head.Value.(Var)) - offset := len(ref) - dst[offset] = head - - copy(dst[offset+1:], other[1:]) - return dst -} - -// Concat returns a ref with the terms appended. -func (ref Ref) Concat(terms []*Term) Ref { - if len(terms) == 0 { - return ref - } - cpy := make(Ref, len(ref)+len(terms)) - copy(cpy, ref) - copy(cpy[len(ref):], terms) - return cpy -} - -// Dynamic returns the offset of the first non-constant operand of ref. -func (ref Ref) Dynamic() int { - switch ref[0].Value.(type) { - case Call: - return 0 - } - for i := 1; i < len(ref); i++ { - if !IsConstant(ref[i].Value) { - return i - } - } - return -1 -} - -// Copy returns a deep copy of ref. -func (ref Ref) Copy() Ref { - return termSliceCopy(ref) -} - -// Equal returns true if ref is equal to other. -func (ref Ref) Equal(other Value) bool { - return Compare(ref, other) == 0 -} - -// Compare compares ref to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (ref Ref) Compare(other Value) int { - return Compare(ref, other) -} - -// Find returns the current value or a "not found" error. -func (ref Ref) Find(path Ref) (Value, error) { - if len(path) == 0 { - return ref, nil - } - return nil, errFindNotFound -} - -// Hash returns the hash code for the Value. -func (ref Ref) Hash() int { - return termSliceHash(ref) -} - -// HasPrefix returns true if the other ref is a prefix of this ref. -func (ref Ref) HasPrefix(other Ref) bool { - if len(other) > len(ref) { - return false - } - for i := range other { - if !ref[i].Equal(other[i]) { - return false - } - } - return true -} - -// ConstantPrefix returns the constant portion of the ref starting from the head. -func (ref Ref) ConstantPrefix() Ref { - ref = ref.Copy() - - i := ref.Dynamic() - if i < 0 { - return ref - } - return ref[:i] -} - -func (ref Ref) StringPrefix() Ref { - r := ref.Copy() - - for i := 1; i < len(ref); i++ { - switch r[i].Value.(type) { - case String: // pass - default: // cut off - return r[:i] - } - } - - return r -} - -// GroundPrefix returns the ground portion of the ref starting from the head. By -// definition, the head of the reference is always ground. -func (ref Ref) GroundPrefix() Ref { - prefix := make(Ref, 0, len(ref)) - - for i, x := range ref { - if i > 0 && !x.IsGround() { - break - } - prefix = append(prefix, x) - } - - return prefix -} - -func (ref Ref) DynamicSuffix() Ref { - i := ref.Dynamic() - if i < 0 { - return nil - } - return ref[i:] -} - -// IsGround returns true if all of the parts of the Ref are ground. -func (ref Ref) IsGround() bool { - if len(ref) == 0 { - return true - } - return termSliceIsGround(ref[1:]) -} - -// IsNested returns true if this ref contains other Refs. -func (ref Ref) IsNested() bool { - for _, x := range ref { - if _, ok := x.Value.(Ref); ok { - return true - } - } - return false -} - -// Ptr returns a slash-separated path string for this ref. If the ref -// contains non-string terms this function returns an error. Path -// components are escaped. -func (ref Ref) Ptr() (string, error) { - parts := make([]string, 0, len(ref)-1) - for _, term := range ref[1:] { - if str, ok := term.Value.(String); ok { - parts = append(parts, url.PathEscape(string(str))) - } else { - return "", fmt.Errorf("invalid path value type") - } - } - return strings.Join(parts, "/"), nil + return v1.RefTerm(r...) } -var varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$") - func IsVarCompatibleString(s string) bool { - return varRegexp.MatchString(s) -} - -func (ref Ref) String() string { - if len(ref) == 0 { - return "" - } - buf := []string{ref[0].Value.String()} - path := ref[1:] - for _, p := range path { - switch p := p.Value.(type) { - case String: - str := string(p) - if varRegexp.MatchString(str) && len(buf) > 0 && !IsKeyword(str) { - buf = append(buf, "."+str) - } else { - buf = append(buf, "["+p.String()+"]") - } - default: - buf = append(buf, "["+p.String()+"]") - } - } - return strings.Join(buf, "") -} - -// OutputVars returns a VarSet containing variables that would be bound by evaluating -// this expression in isolation. -func (ref Ref) OutputVars() VarSet { - vis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true}) - vis.Walk(ref) - return vis.Vars() -} - -func (ref Ref) toArray() *Array { - a := NewArray() - for _, term := range ref { - if _, ok := term.Value.(String); ok { - a = a.Append(term) - } else { - a = a.Append(StringTerm(term.Value.String())) - } - } - return a + return v1.IsVarCompatibleString(s) } // QueryIterator defines the interface for querying AST documents with references. -type QueryIterator func(map[Var]Value, Value) error +type QueryIterator = v1.QueryIterator // ArrayTerm creates a new Term with an Array value. func ArrayTerm(a ...*Term) *Term { - return NewTerm(NewArray(a...)) + return v1.ArrayTerm(a...) } // NewArray creates an Array with the terms provided. The array will // use the provided term slice. func NewArray(a ...*Term) *Array { - hs := make([]int, len(a)) - for i, e := range a { - hs[i] = e.Value.Hash() - } - arr := &Array{elems: a, hashs: hs, ground: termSliceIsGround(a)} - arr.rehash() - return arr + return v1.NewArray(a...) } // Array represents an array as defined by the language. Arrays are similar to the // same types as defined by JSON with the exception that they can contain Vars // and References. -type Array struct { - elems []*Term - hashs []int // element hashes - hash int - ground bool -} - -// Copy returns a deep copy of arr. -func (arr *Array) Copy() *Array { - cpy := make([]int, len(arr.elems)) - copy(cpy, arr.hashs) - return &Array{ - elems: termSliceCopy(arr.elems), - hashs: cpy, - hash: arr.hash, - ground: arr.IsGround()} -} - -// Equal returns true if arr is equal to other. -func (arr *Array) Equal(other Value) bool { - return Compare(arr, other) == 0 -} - -// Compare compares arr to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (arr *Array) Compare(other Value) int { - return Compare(arr, other) -} - -// Find returns the value at the index or an out-of-range error. -func (arr *Array) Find(path Ref) (Value, error) { - if len(path) == 0 { - return arr, nil - } - num, ok := path[0].Value.(Number) - if !ok { - return nil, errFindNotFound - } - i, ok := num.Int() - if !ok { - return nil, errFindNotFound - } - if i < 0 || i >= arr.Len() { - return nil, errFindNotFound - } - return arr.Elem(i).Value.Find(path[1:]) -} - -// Get returns the element at pos or nil if not possible. -func (arr *Array) Get(pos *Term) *Term { - num, ok := pos.Value.(Number) - if !ok { - return nil - } - - i, ok := num.Int() - if !ok { - return nil - } - - if i >= 0 && i < len(arr.elems) { - return arr.elems[i] - } - - return nil -} - -// Sorted returns a new Array that contains the sorted elements of arr. -func (arr *Array) Sorted() *Array { - cpy := make([]*Term, len(arr.elems)) - for i := range cpy { - cpy[i] = arr.elems[i] - } - sort.Sort(termSlice(cpy)) - a := NewArray(cpy...) - a.hashs = arr.hashs - return a -} - -// Hash returns the hash code for the Value. -func (arr *Array) Hash() int { - return arr.hash -} - -// IsGround returns true if all of the Array elements are ground. -func (arr *Array) IsGround() bool { - return arr.ground -} - -// MarshalJSON returns JSON encoded bytes representing arr. -func (arr *Array) MarshalJSON() ([]byte, error) { - if len(arr.elems) == 0 { - return []byte(`[]`), nil - } - return json.Marshal(arr.elems) -} - -func (arr *Array) String() string { - var b strings.Builder - b.WriteRune('[') - for i, e := range arr.elems { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(e.String()) - } - b.WriteRune(']') - return b.String() -} - -// Len returns the number of elements in the array. -func (arr *Array) Len() int { - return len(arr.elems) -} - -// Elem returns the element i of arr. -func (arr *Array) Elem(i int) *Term { - return arr.elems[i] -} - -// Set sets the element i of arr. -func (arr *Array) Set(i int, v *Term) { - arr.set(i, v) -} - -// rehash updates the cached hash of arr. -func (arr *Array) rehash() { - arr.hash = 0 - for _, h := range arr.hashs { - arr.hash += h - } -} - -// set sets the element i of arr. -func (arr *Array) set(i int, v *Term) { - arr.ground = arr.ground && v.IsGround() - arr.elems[i] = v - arr.hashs[i] = v.Value.Hash() - arr.rehash() -} - -// Slice returns a slice of arr starting from i index to j. -1 -// indicates the end of the array. The returned value array is not a -// copy and any modifications to either of arrays may be reflected to -// the other. -func (arr *Array) Slice(i, j int) *Array { - var elems []*Term - var hashs []int - if j == -1 { - elems = arr.elems[i:] - hashs = arr.hashs[i:] - } else { - elems = arr.elems[i:j] - hashs = arr.hashs[i:j] - } - // If arr is ground, the slice is, too. - // If it's not, the slice could still be. - gr := arr.ground || termSliceIsGround(elems) - - s := &Array{elems: elems, hashs: hashs, ground: gr} - s.rehash() - return s -} - -// Iter calls f on each element in arr. If f returns an error, -// iteration stops and the return value is the error. -func (arr *Array) Iter(f func(*Term) error) error { - for i := range arr.elems { - if err := f(arr.elems[i]); err != nil { - return err - } - } - return nil -} - -// Until calls f on each element in arr. If f returns true, iteration stops. -func (arr *Array) Until(f func(*Term) bool) bool { - err := arr.Iter(func(t *Term) error { - if f(t) { - return errStop - } - return nil - }) - return err != nil -} - -// Foreach calls f on each element in arr. -func (arr *Array) Foreach(f func(*Term)) { - _ = arr.Iter(func(t *Term) error { - f(t) - return nil - }) // ignore error -} - -// Append appends a term to arr, returning the appended array. -func (arr *Array) Append(v *Term) *Array { - cpy := *arr - cpy.elems = append(arr.elems, v) - cpy.hashs = append(arr.hashs, v.Value.Hash()) - cpy.hash = arr.hash + v.Value.Hash() - cpy.ground = arr.ground && v.IsGround() - return &cpy -} +type Array = v1.Array // Set represents a set as defined by the language. -type Set interface { - Value - Len() int - Copy() Set - Diff(Set) Set - Intersect(Set) Set - Union(Set) Set - Add(*Term) - Iter(func(*Term) error) error - Until(func(*Term) bool) bool - Foreach(func(*Term)) - Contains(*Term) bool - Map(func(*Term) (*Term, error)) (Set, error) - Reduce(*Term, func(*Term, *Term) (*Term, error)) (*Term, error) - Sorted() *Array - Slice() []*Term -} +type Set = v1.Set // NewSet returns a new Set containing t. func NewSet(t ...*Term) Set { - s := newset(len(t)) - for i := range t { - s.Add(t[i]) - } - return s + return v1.NewSet(t...) } -func newset(n int) *set { - var keys []*Term - if n > 0 { - keys = make([]*Term, 0, n) - } - return &set{ - elems: make(map[int]*Term, n), - keys: keys, - hash: 0, - ground: true, - sortGuard: new(sync.Once), - } -} - -// SetTerm returns a new Term representing a set containing terms t. func SetTerm(t ...*Term) *Term { - set := NewSet(t...) - return &Term{ - Value: set, - } -} - -type set struct { - elems map[int]*Term - keys []*Term - hash int - ground bool - sortGuard *sync.Once // Prevents race condition around sorting. -} - -// Copy returns a deep copy of s. -func (s *set) Copy() Set { - cpy := newset(s.Len()) - s.Foreach(func(x *Term) { - cpy.Add(x.Copy()) - }) - cpy.hash = s.hash - cpy.ground = s.ground - return cpy -} - -// IsGround returns true if all terms in s are ground. -func (s *set) IsGround() bool { - return s.ground -} - -// Hash returns a hash code for s. -func (s *set) Hash() int { - return s.hash -} - -func (s *set) String() string { - if s.Len() == 0 { - return "set()" - } - var b strings.Builder - b.WriteRune('{') - for i := range s.sortedKeys() { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(s.keys[i].Value.String()) - } - b.WriteRune('}') - return b.String() -} - -func (s *set) sortedKeys() []*Term { - s.sortGuard.Do(func() { - sort.Sort(termSlice(s.keys)) - }) - return s.keys -} - -// Compare compares s to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (s *set) Compare(other Value) int { - o1 := sortOrder(s) - o2 := sortOrder(other) - if o1 < o2 { - return -1 - } else if o1 > o2 { - return 1 - } - t := other.(*set) - return termSliceCompare(s.sortedKeys(), t.sortedKeys()) -} - -// Find returns the set or dereferences the element itself. -func (s *set) Find(path Ref) (Value, error) { - if len(path) == 0 { - return s, nil - } - if !s.Contains(path[0]) { - return nil, errFindNotFound - } - return path[0].Value.Find(path[1:]) -} - -// Diff returns elements in s that are not in other. -func (s *set) Diff(other Set) Set { - r := NewSet() - s.Foreach(func(x *Term) { - if !other.Contains(x) { - r.Add(x) - } - }) - return r -} - -// Intersect returns the set containing elements in both s and other. -func (s *set) Intersect(other Set) Set { - o := other.(*set) - n, m := s.Len(), o.Len() - ss := s - so := o - if m < n { - ss = o - so = s - n = m - } - - r := newset(n) - ss.Foreach(func(x *Term) { - if so.Contains(x) { - r.Add(x) - } - }) - return r -} - -// Union returns the set containing all elements of s and other. -func (s *set) Union(other Set) Set { - r := NewSet() - s.Foreach(func(x *Term) { - r.Add(x) - }) - other.Foreach(func(x *Term) { - r.Add(x) - }) - return r -} - -// Add updates s to include t. -func (s *set) Add(t *Term) { - s.insert(t) -} - -// Iter calls f on each element in s. If f returns an error, iteration stops -// and the return value is the error. -func (s *set) Iter(f func(*Term) error) error { - for i := range s.sortedKeys() { - if err := f(s.keys[i]); err != nil { - return err - } - } - return nil -} - -var errStop = errors.New("stop") - -// Until calls f on each element in s. If f returns true, iteration stops. -func (s *set) Until(f func(*Term) bool) bool { - err := s.Iter(func(t *Term) error { - if f(t) { - return errStop - } - return nil - }) - return err != nil -} - -// Foreach calls f on each element in s. -func (s *set) Foreach(f func(*Term)) { - _ = s.Iter(func(t *Term) error { - f(t) - return nil - }) // ignore error -} - -// Map returns a new Set obtained by applying f to each value in s. -func (s *set) Map(f func(*Term) (*Term, error)) (Set, error) { - set := NewSet() - err := s.Iter(func(x *Term) error { - term, err := f(x) - if err != nil { - return err - } - set.Add(term) - return nil - }) - if err != nil { - return nil, err - } - return set, nil -} - -// Reduce returns a Term produced by applying f to each value in s. The first -// argument to f is the reduced value (starting with i) and the second argument -// to f is the element in s. -func (s *set) Reduce(i *Term, f func(*Term, *Term) (*Term, error)) (*Term, error) { - err := s.Iter(func(x *Term) error { - var err error - i, err = f(i, x) - if err != nil { - return err - } - return nil - }) - return i, err -} - -// Contains returns true if t is in s. -func (s *set) Contains(t *Term) bool { - return s.get(t) != nil -} - -// Len returns the number of elements in the set. -func (s *set) Len() int { - return len(s.keys) -} - -// MarshalJSON returns JSON encoded bytes representing s. -func (s *set) MarshalJSON() ([]byte, error) { - if s.keys == nil { - return []byte(`[]`), nil - } - return json.Marshal(s.sortedKeys()) -} - -// Sorted returns an Array that contains the sorted elements of s. -func (s *set) Sorted() *Array { - cpy := make([]*Term, len(s.keys)) - copy(cpy, s.sortedKeys()) - return NewArray(cpy...) -} - -// Slice returns a slice of terms contained in the set. -func (s *set) Slice() []*Term { - return s.sortedKeys() -} - -// NOTE(philipc): We assume a many-readers, single-writer model here. -// This method should NOT be used concurrently, or else we risk data races. -func (s *set) insert(x *Term) { - hash := x.Hash() - insertHash := hash - // This `equal` utility is duplicated and manually inlined a number of - // time in this file. Inlining it avoids heap allocations, so it makes - // a big performance difference: some operations like lookup become twice - // as slow without it. - var equal func(v Value) bool - - switch x := x.Value.(type) { - case Null, Boolean, String, Var: - equal = func(y Value) bool { return x == y } - case Number: - if xi, err := json.Number(x).Int64(); err == nil { - equal = func(y Value) bool { - if y, ok := y.(Number); ok { - if yi, err := json.Number(y).Int64(); err == nil { - return xi == yi - } - } - - return false - } - break - } - - // We use big.Rat for comparing big numbers. - // It replaces big.Float due to following reason: - // big.Float comes with a default precision of 64, and setting a - // larger precision results in more memory being allocated - // (regardless of the actual number we are parsing with SetString). - // - // Note: If we're so close to zero that big.Float says we are zero, do - // *not* big.Rat).SetString on the original string it'll potentially - // take very long. - var a *big.Rat - fa, ok := new(big.Float).SetString(string(x)) - if !ok { - panic("illegal value") - } - if fa.IsInt() { - if i, _ := fa.Int64(); i == 0 { - a = new(big.Rat).SetInt64(0) - } - } - if a == nil { - a, ok = new(big.Rat).SetString(string(x)) - if !ok { - panic("illegal value") - } - } - - equal = func(b Value) bool { - if bNum, ok := b.(Number); ok { - var b *big.Rat - fb, ok := new(big.Float).SetString(string(bNum)) - if !ok { - panic("illegal value") - } - if fb.IsInt() { - if i, _ := fb.Int64(); i == 0 { - b = new(big.Rat).SetInt64(0) - } - } - if b == nil { - b, ok = new(big.Rat).SetString(string(bNum)) - if !ok { - panic("illegal value") - } - } - - return a.Cmp(b) == 0 - } - - return false - } - default: - equal = func(y Value) bool { return Compare(x, y) == 0 } - } - - for curr, ok := s.elems[insertHash]; ok; { - if equal(curr.Value) { - return - } - - insertHash++ - curr, ok = s.elems[insertHash] - } - - s.elems[insertHash] = x - // O(1) insertion, but we'll have to re-sort the keys later. - s.keys = append(s.keys, x) - // Reset the sync.Once instance. - // See https://github.com/golang/go/issues/25955 for why we do it this way. - s.sortGuard = new(sync.Once) - - s.hash += hash - s.ground = s.ground && x.IsGround() -} - -func (s *set) get(x *Term) *Term { - hash := x.Hash() - // This `equal` utility is duplicated and manually inlined a number of - // time in this file. Inlining it avoids heap allocations, so it makes - // a big performance difference: some operations like lookup become twice - // as slow without it. - var equal func(v Value) bool - - switch x := x.Value.(type) { - case Null, Boolean, String, Var: - equal = func(y Value) bool { return x == y } - case Number: - if xi, err := json.Number(x).Int64(); err == nil { - equal = func(y Value) bool { - if y, ok := y.(Number); ok { - if yi, err := json.Number(y).Int64(); err == nil { - return xi == yi - } - } - - return false - } - break - } - - // We use big.Rat for comparing big numbers. - // It replaces big.Float due to following reason: - // big.Float comes with a default precision of 64, and setting a - // larger precision results in more memory being allocated - // (regardless of the actual number we are parsing with SetString). - // - // Note: If we're so close to zero that big.Float says we are zero, do - // *not* big.Rat).SetString on the original string it'll potentially - // take very long. - var a *big.Rat - fa, ok := new(big.Float).SetString(string(x)) - if !ok { - panic("illegal value") - } - if fa.IsInt() { - if i, _ := fa.Int64(); i == 0 { - a = new(big.Rat).SetInt64(0) - } - } - if a == nil { - a, ok = new(big.Rat).SetString(string(x)) - if !ok { - panic("illegal value") - } - } - - equal = func(b Value) bool { - if bNum, ok := b.(Number); ok { - var b *big.Rat - fb, ok := new(big.Float).SetString(string(bNum)) - if !ok { - panic("illegal value") - } - if fb.IsInt() { - if i, _ := fb.Int64(); i == 0 { - b = new(big.Rat).SetInt64(0) - } - } - if b == nil { - b, ok = new(big.Rat).SetString(string(bNum)) - if !ok { - panic("illegal value") - } - } - - return a.Cmp(b) == 0 - } - return false - - } - - default: - equal = func(y Value) bool { return Compare(x, y) == 0 } - } - - for curr, ok := s.elems[hash]; ok; { - if equal(curr.Value) { - return curr - } - - hash++ - curr, ok = s.elems[hash] - } - return nil + return v1.SetTerm(t...) } // Object represents an object as defined by the language. -type Object interface { - Value - Len() int - Get(*Term) *Term - Copy() Object - Insert(*Term, *Term) - Iter(func(*Term, *Term) error) error - Until(func(*Term, *Term) bool) bool - Foreach(func(*Term, *Term)) - Map(func(*Term, *Term) (*Term, *Term, error)) (Object, error) - Diff(other Object) Object - Intersect(other Object) [][3]*Term - Merge(other Object) (Object, bool) - MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) - Filter(filter Object) (Object, error) - Keys() []*Term - KeysIterator() ObjectKeysIterator - get(k *Term) *objectElem // To prevent external implementations -} +type Object = v1.Object // NewObject creates a new Object with t. func NewObject(t ...[2]*Term) Object { - obj := newobject(len(t)) - for i := range t { - obj.Insert(t[i][0], t[i][1]) - } - return obj + return v1.NewObject(t...) } // ObjectTerm creates a new Term with an Object value. func ObjectTerm(o ...[2]*Term) *Term { - return &Term{Value: NewObject(o...)} + return v1.ObjectTerm(o...) } func LazyObject(blob map[string]interface{}) Object { - return &lazyObj{native: blob, cache: map[string]Value{}} -} - -type lazyObj struct { - strict Object - cache map[string]Value - native map[string]interface{} -} - -func (l *lazyObj) force() Object { - if l.strict == nil { - l.strict = MustInterfaceToValue(l.native).(Object) - // NOTE(jf): a possible performance improvement here would be to check how many - // entries have been realized to AST in the cache, and if some threshold compared to the - // total number of keys is exceeded, realize the remaining entries and set l.strict to l.cache. - l.cache = map[string]Value{} // We don't need the cache anymore; drop it to free up memory. - } - return l.strict -} - -func (l *lazyObj) Compare(other Value) int { - o1 := sortOrder(l) - o2 := sortOrder(other) - if o1 < o2 { - return -1 - } else if o2 < o1 { - return 1 - } - return l.force().Compare(other) -} - -func (l *lazyObj) Copy() Object { - return l -} - -func (l *lazyObj) Diff(other Object) Object { - return l.force().Diff(other) -} - -func (l *lazyObj) Intersect(other Object) [][3]*Term { - return l.force().Intersect(other) -} - -func (l *lazyObj) Iter(f func(*Term, *Term) error) error { - return l.force().Iter(f) -} - -func (l *lazyObj) Until(f func(*Term, *Term) bool) bool { - // NOTE(sr): there could be benefits in not forcing here -- if we abort because - // `f` returns true, we could save us from converting the rest of the object. - return l.force().Until(f) -} - -func (l *lazyObj) Foreach(f func(*Term, *Term)) { - l.force().Foreach(f) -} - -func (l *lazyObj) Filter(filter Object) (Object, error) { - return l.force().Filter(filter) -} - -func (l *lazyObj) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, error) { - return l.force().Map(f) -} - -func (l *lazyObj) MarshalJSON() ([]byte, error) { - return l.force().(*object).MarshalJSON() -} - -func (l *lazyObj) Merge(other Object) (Object, bool) { - return l.force().Merge(other) -} - -func (l *lazyObj) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { - return l.force().MergeWith(other, conflictResolver) -} - -func (l *lazyObj) Len() int { - return len(l.native) -} - -func (l *lazyObj) String() string { - return l.force().String() + return v1.LazyObject(blob) } -// get is merely there to implement the Object interface -- `get` there serves the -// purpose of prohibiting external implementations. It's never called for lazyObj. -func (*lazyObj) get(*Term) *objectElem { - return nil -} - -func (l *lazyObj) Get(k *Term) *Term { - if l.strict != nil { - return l.strict.Get(k) - } - if s, ok := k.Value.(String); ok { - if v, ok := l.cache[string(s)]; ok { - return NewTerm(v) - } - - if val, ok := l.native[string(s)]; ok { - var converted Value - switch val := val.(type) { - case map[string]interface{}: - converted = LazyObject(val) - default: - converted = MustInterfaceToValue(val) - } - l.cache[string(s)] = converted - return NewTerm(converted) - } - } - return nil -} - -func (l *lazyObj) Insert(k, v *Term) { - l.force().Insert(k, v) -} - -func (*lazyObj) IsGround() bool { - return true -} - -func (l *lazyObj) Hash() int { - return l.force().Hash() -} - -func (l *lazyObj) Keys() []*Term { - if l.strict != nil { - return l.strict.Keys() - } - ret := make([]*Term, 0, len(l.native)) - for k := range l.native { - ret = append(ret, StringTerm(k)) - } - sort.Sort(termSlice(ret)) - return ret -} - -func (l *lazyObj) KeysIterator() ObjectKeysIterator { - return &lazyObjKeysIterator{keys: l.Keys()} -} - -type lazyObjKeysIterator struct { - current int - keys []*Term -} - -func (ki *lazyObjKeysIterator) Next() (*Term, bool) { - if ki.current == len(ki.keys) { - return nil, false - } - ki.current++ - return ki.keys[ki.current-1], true -} - -func (l *lazyObj) Find(path Ref) (Value, error) { - if l.strict != nil { - return l.strict.Find(path) - } - if len(path) == 0 { - return l, nil - } - if p0, ok := path[0].Value.(String); ok { - if v, ok := l.cache[string(p0)]; ok { - return v.Find(path[1:]) - } - - if v, ok := l.native[string(p0)]; ok { - var converted Value - switch v := v.(type) { - case map[string]interface{}: - converted = LazyObject(v) - default: - converted = MustInterfaceToValue(v) - } - l.cache[string(p0)] = converted - return converted.Find(path[1:]) - } - } - return nil, errFindNotFound -} - -type object struct { - elems map[int]*objectElem - keys objectElemSlice - ground int // number of key and value grounds. Counting is - // required to support insert's key-value replace. - hash int - sortGuard *sync.Once // Prevents race condition around sorting. -} - -func newobject(n int) *object { - var keys objectElemSlice - if n > 0 { - keys = make(objectElemSlice, 0, n) - } - return &object{ - elems: make(map[int]*objectElem, n), - keys: keys, - ground: 0, - hash: 0, - sortGuard: new(sync.Once), - } -} - -type objectElem struct { - key *Term - value *Term - next *objectElem -} - -type objectElemSlice []*objectElem - -func (s objectElemSlice) Less(i, j int) bool { return Compare(s[i].key.Value, s[j].key.Value) < 0 } -func (s objectElemSlice) Swap(i, j int) { x := s[i]; s[i] = s[j]; s[j] = x } -func (s objectElemSlice) Len() int { return len(s) } - // Item is a helper for constructing an tuple containing two Terms // representing a key/value pair in an Object. func Item(key, value *Term) [2]*Term { - return [2]*Term{key, value} -} - -func (obj *object) sortedKeys() objectElemSlice { - obj.sortGuard.Do(func() { - sort.Sort(obj.keys) - }) - return obj.keys -} - -// Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (obj *object) Compare(other Value) int { - if x, ok := other.(*lazyObj); ok { - other = x.force() - } - o1 := sortOrder(obj) - o2 := sortOrder(other) - if o1 < o2 { - return -1 - } else if o2 < o1 { - return 1 - } - a := obj - b := other.(*object) - // Ensure that keys are in canonical sorted order before use! - akeys := a.sortedKeys() - bkeys := b.sortedKeys() - minLen := len(akeys) - if len(b.keys) < len(akeys) { - minLen = len(bkeys) - } - for i := 0; i < minLen; i++ { - keysCmp := Compare(akeys[i].key, bkeys[i].key) - if keysCmp < 0 { - return -1 - } - if keysCmp > 0 { - return 1 - } - valA := akeys[i].value - valB := bkeys[i].value - valCmp := Compare(valA, valB) - if valCmp != 0 { - return valCmp - } - } - if len(akeys) < len(bkeys) { - return -1 - } - if len(bkeys) < len(akeys) { - return 1 - } - return 0 -} - -// Find returns the value at the key or undefined. -func (obj *object) Find(path Ref) (Value, error) { - if len(path) == 0 { - return obj, nil - } - value := obj.Get(path[0]) - if value == nil { - return nil, errFindNotFound - } - return value.Value.Find(path[1:]) -} - -func (obj *object) Insert(k, v *Term) { - obj.insert(k, v) -} - -// Get returns the value of k in obj if k exists, otherwise nil. -func (obj *object) Get(k *Term) *Term { - if elem := obj.get(k); elem != nil { - return elem.value - } - return nil -} - -// Hash returns the hash code for the Value. -func (obj *object) Hash() int { - return obj.hash -} - -// IsGround returns true if all of the Object key/value pairs are ground. -func (obj *object) IsGround() bool { - return obj.ground == 2*len(obj.keys) -} - -// Copy returns a deep copy of obj. -func (obj *object) Copy() Object { - cpy, _ := obj.Map(func(k, v *Term) (*Term, *Term, error) { - return k.Copy(), v.Copy(), nil - }) - cpy.(*object).hash = obj.hash - return cpy -} - -// Diff returns a new Object that contains only the key/value pairs that exist in obj. -func (obj *object) Diff(other Object) Object { - r := NewObject() - obj.Foreach(func(k, v *Term) { - if other.Get(k) == nil { - r.Insert(k, v) - } - }) - return r -} - -// Intersect returns a slice of term triplets that represent the intersection of keys -// between obj and other. For each intersecting key, the values from obj and other are included -// as the last two terms in the triplet (respectively). -func (obj *object) Intersect(other Object) [][3]*Term { - r := [][3]*Term{} - obj.Foreach(func(k, v *Term) { - if v2 := other.Get(k); v2 != nil { - r = append(r, [3]*Term{k, v, v2}) - } - }) - return r -} - -// Iter calls the function f for each key-value pair in the object. If f -// returns an error, iteration stops and the error is returned. -func (obj *object) Iter(f func(*Term, *Term) error) error { - for _, node := range obj.sortedKeys() { - if err := f(node.key, node.value); err != nil { - return err - } - } - return nil -} - -// Until calls f for each key-value pair in the object. If f returns -// true, iteration stops and Until returns true. Otherwise, return -// false. -func (obj *object) Until(f func(*Term, *Term) bool) bool { - err := obj.Iter(func(k, v *Term) error { - if f(k, v) { - return errStop - } - return nil - }) - return err != nil -} - -// Foreach calls f for each key-value pair in the object. -func (obj *object) Foreach(f func(*Term, *Term)) { - _ = obj.Iter(func(k, v *Term) error { - f(k, v) - return nil - }) // ignore error -} - -// Map returns a new Object constructed by mapping each element in the object -// using the function f. -func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, error) { - cpy := newobject(obj.Len()) - err := obj.Iter(func(k, v *Term) error { - var err error - k, v, err = f(k, v) - if err != nil { - return err - } - cpy.insert(k, v) - return nil - }) - if err != nil { - return nil, err - } - return cpy, nil -} - -// Keys returns the keys of obj. -func (obj *object) Keys() []*Term { - keys := make([]*Term, len(obj.keys)) - - for i, elem := range obj.sortedKeys() { - keys[i] = elem.key - } - - return keys -} - -// Returns an iterator over the obj's keys. -func (obj *object) KeysIterator() ObjectKeysIterator { - return newobjectKeysIterator(obj) -} - -// MarshalJSON returns JSON encoded bytes representing obj. -func (obj *object) MarshalJSON() ([]byte, error) { - sl := make([][2]*Term, obj.Len()) - for i, node := range obj.sortedKeys() { - sl[i] = Item(node.key, node.value) - } - return json.Marshal(sl) -} - -// Merge returns a new Object containing the non-overlapping keys of obj and other. If there are -// overlapping keys between obj and other, the values of associated with the keys are merged. Only -// objects can be merged with other objects. If the values cannot be merged, the second turn value -// will be false. -func (obj object) Merge(other Object) (Object, bool) { - return obj.MergeWith(other, func(v1, v2 *Term) (*Term, bool) { - obj1, ok1 := v1.Value.(Object) - obj2, ok2 := v2.Value.(Object) - if !ok1 || !ok2 { - return nil, true - } - obj3, ok := obj1.Merge(obj2) - if !ok { - return nil, true - } - return NewTerm(obj3), false - }) -} - -// MergeWith returns a new Object containing the merged keys of obj and other. -// If there are overlapping keys between obj and other, the conflictResolver -// is called. The conflictResolver can return a merged value and a boolean -// indicating if the merge has failed and should stop. -func (obj object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { - result := NewObject() - stop := obj.Until(func(k, v *Term) bool { - v2 := other.Get(k) - // The key didn't exist in other, keep the original value - if v2 == nil { - result.Insert(k, v) - return false - } - - // The key exists in both, resolve the conflict if possible - merged, stop := conflictResolver(v, v2) - if !stop { - result.Insert(k, merged) - } - return stop - }) - - if stop { - return nil, false - } - - // Copy in any values from other for keys that don't exist in obj - other.Foreach(func(k, v *Term) { - if v2 := obj.Get(k); v2 == nil { - result.Insert(k, v) - } - }) - return result, true -} - -// Filter returns a new object from values in obj where the keys are -// found in filter. Array indices for values can be specified as -// number strings. -func (obj *object) Filter(filter Object) (Object, error) { - filtered, err := filterObject(obj, filter) - if err != nil { - return nil, err - } - return filtered.(Object), nil -} - -// Len returns the number of elements in the object. -func (obj object) Len() int { - return len(obj.keys) -} - -func (obj object) String() string { - var b strings.Builder - b.WriteRune('{') - - for i, elem := range obj.sortedKeys() { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(elem.key.String()) - b.WriteString(": ") - b.WriteString(elem.value.String()) - } - b.WriteRune('}') - return b.String() -} - -func (obj *object) get(k *Term) *objectElem { - hash := k.Hash() - - // This `equal` utility is duplicated and manually inlined a number of - // time in this file. Inlining it avoids heap allocations, so it makes - // a big performance difference: some operations like lookup become twice - // as slow without it. - var equal func(v Value) bool - - switch x := k.Value.(type) { - case Null, Boolean, String, Var: - equal = func(y Value) bool { return x == y } - case Number: - if xi, err := json.Number(x).Int64(); err == nil { - equal = func(y Value) bool { - if y, ok := y.(Number); ok { - if yi, err := json.Number(y).Int64(); err == nil { - return xi == yi - } - } - - return false - } - break - } - - // We use big.Rat for comparing big numbers. - // It replaces big.Float due to following reason: - // big.Float comes with a default precision of 64, and setting a - // larger precision results in more memory being allocated - // (regardless of the actual number we are parsing with SetString). - // - // Note: If we're so close to zero that big.Float says we are zero, do - // *not* big.Rat).SetString on the original string it'll potentially - // take very long. - var a *big.Rat - fa, ok := new(big.Float).SetString(string(x)) - if !ok { - panic("illegal value") - } - if fa.IsInt() { - if i, _ := fa.Int64(); i == 0 { - a = new(big.Rat).SetInt64(0) - } - } - if a == nil { - a, ok = new(big.Rat).SetString(string(x)) - if !ok { - panic("illegal value") - } - } - - equal = func(b Value) bool { - if bNum, ok := b.(Number); ok { - var b *big.Rat - fb, ok := new(big.Float).SetString(string(bNum)) - if !ok { - panic("illegal value") - } - if fb.IsInt() { - if i, _ := fb.Int64(); i == 0 { - b = new(big.Rat).SetInt64(0) - } - } - if b == nil { - b, ok = new(big.Rat).SetString(string(bNum)) - if !ok { - panic("illegal value") - } - } - - return a.Cmp(b) == 0 - } - - return false - } - default: - equal = func(y Value) bool { return Compare(x, y) == 0 } - } - - for curr := obj.elems[hash]; curr != nil; curr = curr.next { - if equal(curr.key.Value) { - return curr - } - } - return nil -} - -// NOTE(philipc): We assume a many-readers, single-writer model here. -// This method should NOT be used concurrently, or else we risk data races. -func (obj *object) insert(k, v *Term) { - hash := k.Hash() - head := obj.elems[hash] - // This `equal` utility is duplicated and manually inlined a number of - // time in this file. Inlining it avoids heap allocations, so it makes - // a big performance difference: some operations like lookup become twice - // as slow without it. - var equal func(v Value) bool - - switch x := k.Value.(type) { - case Null, Boolean, String, Var: - equal = func(y Value) bool { return x == y } - case Number: - if xi, err := json.Number(x).Int64(); err == nil { - equal = func(y Value) bool { - if y, ok := y.(Number); ok { - if yi, err := json.Number(y).Int64(); err == nil { - return xi == yi - } - } - - return false - } - break - } - - // We use big.Rat for comparing big numbers. - // It replaces big.Float due to following reason: - // big.Float comes with a default precision of 64, and setting a - // larger precision results in more memory being allocated - // (regardless of the actual number we are parsing with SetString). - // - // Note: If we're so close to zero that big.Float says we are zero, do - // *not* big.Rat).SetString on the original string it'll potentially - // take very long. - var a *big.Rat - fa, ok := new(big.Float).SetString(string(x)) - if !ok { - panic("illegal value") - } - if fa.IsInt() { - if i, _ := fa.Int64(); i == 0 { - a = new(big.Rat).SetInt64(0) - } - } - if a == nil { - a, ok = new(big.Rat).SetString(string(x)) - if !ok { - panic("illegal value") - } - } - - equal = func(b Value) bool { - if bNum, ok := b.(Number); ok { - var b *big.Rat - fb, ok := new(big.Float).SetString(string(bNum)) - if !ok { - panic("illegal value") - } - if fb.IsInt() { - if i, _ := fb.Int64(); i == 0 { - b = new(big.Rat).SetInt64(0) - } - } - if b == nil { - b, ok = new(big.Rat).SetString(string(bNum)) - if !ok { - panic("illegal value") - } - } - - return a.Cmp(b) == 0 - } - - return false - } - default: - equal = func(y Value) bool { return Compare(x, y) == 0 } - } - - for curr := head; curr != nil; curr = curr.next { - if equal(curr.key.Value) { - // The ground bit of the value may change in - // replace, hence adjust the counter per old - // and new value. - - if curr.value.IsGround() { - obj.ground-- - } - if v.IsGround() { - obj.ground++ - } - - curr.value = v - - obj.rehash() - return - } - } - elem := &objectElem{ - key: k, - value: v, - next: head, - } - obj.elems[hash] = elem - // O(1) insertion, but we'll have to re-sort the keys later. - obj.keys = append(obj.keys, elem) - // Reset the sync.Once instance. - // See https://github.com/golang/go/issues/25955 for why we do it this way. - obj.sortGuard = new(sync.Once) - obj.hash += hash + v.Hash() - - if k.IsGround() { - obj.ground++ - } - if v.IsGround() { - obj.ground++ - } -} - -func (obj *object) rehash() { - // obj.keys is considered truth, from which obj.hash and obj.elems are recalculated. - - obj.hash = 0 - obj.elems = make(map[int]*objectElem, len(obj.keys)) - - for _, elem := range obj.keys { - hash := elem.key.Hash() - obj.hash += hash + elem.value.Hash() - obj.elems[hash] = elem - } -} - -func filterObject(o Value, filter Value) (Value, error) { - if filter.Compare(Null{}) == 0 { - return o, nil - } - - filteredObj, ok := filter.(*object) - if !ok { - return nil, fmt.Errorf("invalid filter value %q, expected an object", filter) - } - - switch v := o.(type) { - case String, Number, Boolean, Null: - return o, nil - case *Array: - values := NewArray() - for i := 0; i < v.Len(); i++ { - subFilter := filteredObj.Get(StringTerm(strconv.Itoa(i))) - if subFilter != nil { - filteredValue, err := filterObject(v.Elem(i).Value, subFilter.Value) - if err != nil { - return nil, err - } - values = values.Append(NewTerm(filteredValue)) - } - } - return values, nil - case Set: - values := NewSet() - err := v.Iter(func(t *Term) error { - if filteredObj.Get(t) != nil { - filteredValue, err := filterObject(t.Value, filteredObj.Get(t).Value) - if err != nil { - return err - } - values.Add(NewTerm(filteredValue)) - } - return nil - }) - return values, err - case *object: - values := NewObject() - - iterObj := v - other := filteredObj - if v.Len() < filteredObj.Len() { - iterObj = filteredObj - other = v - } - - err := iterObj.Iter(func(key *Term, _ *Term) error { - if other.Get(key) != nil { - filteredValue, err := filterObject(v.Get(key).Value, filteredObj.Get(key).Value) - if err != nil { - return err - } - values.Insert(key, NewTerm(filteredValue)) - } - return nil - }) - return values, err - default: - return nil, fmt.Errorf("invalid object value type %q", v) - } + return v1.Item(key, value) } // NOTE(philipc): The only way to get an ObjectKeyIterator should be // from an Object. This ensures that the iterator can have implementation- // specific details internally, with no contracts except to the very // limited interface. -type ObjectKeysIterator interface { - Next() (*Term, bool) -} - -type objectKeysIterator struct { - obj *object - numKeys int - index int -} - -func newobjectKeysIterator(o *object) ObjectKeysIterator { - return &objectKeysIterator{ - obj: o, - numKeys: o.Len(), - index: 0, - } -} - -func (oki *objectKeysIterator) Next() (*Term, bool) { - if oki.index == oki.numKeys || oki.numKeys == 0 { - return nil, false - } - oki.index++ - return oki.obj.sortedKeys()[oki.index-1].key, true -} +type ObjectKeysIterator = v1.ObjectKeysIterator // ArrayComprehension represents an array comprehension as defined in the language. -type ArrayComprehension struct { - Term *Term `json:"term"` - Body Body `json:"body"` -} +type ArrayComprehension = v1.ArrayComprehension // ArrayComprehensionTerm creates a new Term with an ArrayComprehension value. func ArrayComprehensionTerm(term *Term, body Body) *Term { - return &Term{ - Value: &ArrayComprehension{ - Term: term, - Body: body, - }, - } -} - -// Copy returns a deep copy of ac. -func (ac *ArrayComprehension) Copy() *ArrayComprehension { - cpy := *ac - cpy.Body = ac.Body.Copy() - cpy.Term = ac.Term.Copy() - return &cpy -} - -// Equal returns true if ac is equal to other. -func (ac *ArrayComprehension) Equal(other Value) bool { - return Compare(ac, other) == 0 -} - -// Compare compares ac to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (ac *ArrayComprehension) Compare(other Value) int { - return Compare(ac, other) -} - -// Find returns the current value or a not found error. -func (ac *ArrayComprehension) Find(path Ref) (Value, error) { - if len(path) == 0 { - return ac, nil - } - return nil, errFindNotFound -} - -// Hash returns the hash code of the Value. -func (ac *ArrayComprehension) Hash() int { - return ac.Term.Hash() + ac.Body.Hash() -} - -// IsGround returns true if the Term and Body are ground. -func (ac *ArrayComprehension) IsGround() bool { - return ac.Term.IsGround() && ac.Body.IsGround() -} - -func (ac *ArrayComprehension) String() string { - return "[" + ac.Term.String() + " | " + ac.Body.String() + "]" + return v1.ArrayComprehensionTerm(term, body) } // ObjectComprehension represents an object comprehension as defined in the language. -type ObjectComprehension struct { - Key *Term `json:"key"` - Value *Term `json:"value"` - Body Body `json:"body"` -} +type ObjectComprehension = v1.ObjectComprehension // ObjectComprehensionTerm creates a new Term with an ObjectComprehension value. func ObjectComprehensionTerm(key, value *Term, body Body) *Term { - return &Term{ - Value: &ObjectComprehension{ - Key: key, - Value: value, - Body: body, - }, - } -} - -// Copy returns a deep copy of oc. -func (oc *ObjectComprehension) Copy() *ObjectComprehension { - cpy := *oc - cpy.Body = oc.Body.Copy() - cpy.Key = oc.Key.Copy() - cpy.Value = oc.Value.Copy() - return &cpy -} - -// Equal returns true if oc is equal to other. -func (oc *ObjectComprehension) Equal(other Value) bool { - return Compare(oc, other) == 0 -} - -// Compare compares oc to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (oc *ObjectComprehension) Compare(other Value) int { - return Compare(oc, other) -} - -// Find returns the current value or a not found error. -func (oc *ObjectComprehension) Find(path Ref) (Value, error) { - if len(path) == 0 { - return oc, nil - } - return nil, errFindNotFound -} - -// Hash returns the hash code of the Value. -func (oc *ObjectComprehension) Hash() int { - return oc.Key.Hash() + oc.Value.Hash() + oc.Body.Hash() -} - -// IsGround returns true if the Key, Value and Body are ground. -func (oc *ObjectComprehension) IsGround() bool { - return oc.Key.IsGround() && oc.Value.IsGround() && oc.Body.IsGround() -} - -func (oc *ObjectComprehension) String() string { - return "{" + oc.Key.String() + ": " + oc.Value.String() + " | " + oc.Body.String() + "}" + return v1.ObjectComprehensionTerm(key, value, body) } // SetComprehension represents a set comprehension as defined in the language. -type SetComprehension struct { - Term *Term `json:"term"` - Body Body `json:"body"` -} +type SetComprehension = v1.SetComprehension // SetComprehensionTerm creates a new Term with an SetComprehension value. func SetComprehensionTerm(term *Term, body Body) *Term { - return &Term{ - Value: &SetComprehension{ - Term: term, - Body: body, - }, - } -} - -// Copy returns a deep copy of sc. -func (sc *SetComprehension) Copy() *SetComprehension { - cpy := *sc - cpy.Body = sc.Body.Copy() - cpy.Term = sc.Term.Copy() - return &cpy -} - -// Equal returns true if sc is equal to other. -func (sc *SetComprehension) Equal(other Value) bool { - return Compare(sc, other) == 0 -} - -// Compare compares sc to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (sc *SetComprehension) Compare(other Value) int { - return Compare(sc, other) -} - -// Find returns the current value or a not found error. -func (sc *SetComprehension) Find(path Ref) (Value, error) { - if len(path) == 0 { - return sc, nil - } - return nil, errFindNotFound -} - -// Hash returns the hash code of the Value. -func (sc *SetComprehension) Hash() int { - return sc.Term.Hash() + sc.Body.Hash() -} - -// IsGround returns true if the Term and Body are ground. -func (sc *SetComprehension) IsGround() bool { - return sc.Term.IsGround() && sc.Body.IsGround() -} - -func (sc *SetComprehension) String() string { - return "{" + sc.Term.String() + " | " + sc.Body.String() + "}" + return v1.SetComprehensionTerm(term, body) } // Call represents as function call in the language. -type Call []*Term +type Call = v1.Call // CallTerm returns a new Term with a Call value defined by terms. The first // term is the operator and the rest are operands. func CallTerm(terms ...*Term) *Term { - return NewTerm(Call(terms)) -} - -// Copy returns a deep copy of c. -func (c Call) Copy() Call { - return termSliceCopy(c) -} - -// Compare compares c to other, return <0, 0, or >0 if it is less than, equal to, -// or greater than other. -func (c Call) Compare(other Value) int { - return Compare(c, other) -} - -// Find returns the current value or a not found error. -func (c Call) Find(Ref) (Value, error) { - return nil, errFindNotFound -} - -// Hash returns the hash code for the Value. -func (c Call) Hash() int { - return termSliceHash(c) -} - -// IsGround returns true if the Value is ground. -func (c Call) IsGround() bool { - return termSliceIsGround(c) -} - -// MakeExpr returns an ew Expr from this call. -func (c Call) MakeExpr(output *Term) *Expr { - terms := []*Term(c) - return NewExpr(append(terms, output)) -} - -func (c Call) String() string { - args := make([]string, len(c)-1) - for i := 1; i < len(c); i++ { - args[i-1] = c[i].String() - } - return fmt.Sprintf("%v(%v)", c[0], strings.Join(args, ", ")) -} - -func termSliceCopy(a []*Term) []*Term { - cpy := make([]*Term, len(a)) - for i := range a { - cpy[i] = a[i].Copy() - } - return cpy -} - -func termSliceEqual(a, b []*Term) bool { - if len(a) == len(b) { - for i := range a { - if !a[i].Equal(b[i]) { - return false - } - } - return true - } - return false -} - -func termSliceHash(a []*Term) int { - var hash int - for _, v := range a { - hash += v.Value.Hash() - } - return hash -} - -func termSliceIsGround(a []*Term) bool { - for _, v := range a { - if !v.IsGround() { - return false - } - } - return true -} - -// NOTE(tsandall): The unmarshalling errors in these functions are not -// helpful for callers because they do not identify the source of the -// unmarshalling error. Because OPA doesn't accept JSON describing ASTs -// from callers, this is acceptable (for now). If that changes in the future, -// the error messages should be revisited. The current approach focuses -// on the happy path and treats all errors the same. If better error -// reporting is needed, the error paths will need to be fleshed out. - -func unmarshalBody(b []interface{}) (Body, error) { - buf := Body{} - for _, e := range b { - if m, ok := e.(map[string]interface{}); ok { - expr := &Expr{} - if err := unmarshalExpr(expr, m); err == nil { - buf = append(buf, expr) - continue - } - } - goto unmarshal_error - } - return buf, nil -unmarshal_error: - return nil, fmt.Errorf("ast: unable to unmarshal body") -} - -func unmarshalExpr(expr *Expr, v map[string]interface{}) error { - if x, ok := v["negated"]; ok { - if b, ok := x.(bool); ok { - expr.Negated = b - } else { - return fmt.Errorf("ast: unable to unmarshal negated field with type: %T (expected true or false)", v["negated"]) - } - } - if generatedRaw, ok := v["generated"]; ok { - if b, ok := generatedRaw.(bool); ok { - expr.Generated = b - } else { - return fmt.Errorf("ast: unable to unmarshal generated field with type: %T (expected true or false)", v["generated"]) - } - } - - if err := unmarshalExprIndex(expr, v); err != nil { - return err - } - switch ts := v["terms"].(type) { - case map[string]interface{}: - t, err := unmarshalTerm(ts) - if err != nil { - return err - } - expr.Terms = t - case []interface{}: - terms, err := unmarshalTermSlice(ts) - if err != nil { - return err - } - expr.Terms = terms - default: - return fmt.Errorf(`ast: unable to unmarshal terms field with type: %T (expected {"value": ..., "type": ...} or [{"value": ..., "type": ...}, ...])`, v["terms"]) - } - if x, ok := v["with"]; ok { - if sl, ok := x.([]interface{}); ok { - ws := make([]*With, len(sl)) - for i := range sl { - var err error - ws[i], err = unmarshalWith(sl[i]) - if err != nil { - return err - } - } - expr.With = ws - } - } - if loc, ok := v["location"].(map[string]interface{}); ok { - expr.Location = &Location{} - if err := unmarshalLocation(expr.Location, loc); err != nil { - return err - } - } - return nil -} - -func unmarshalLocation(loc *Location, v map[string]interface{}) error { - if x, ok := v["file"]; ok { - if s, ok := x.(string); ok { - loc.File = s - } else { - return fmt.Errorf("ast: unable to unmarshal file field with type: %T (expected string)", v["file"]) - } - } - if x, ok := v["row"]; ok { - if n, ok := x.(json.Number); ok { - i64, err := n.Int64() - if err != nil { - return err - } - loc.Row = int(i64) - } else { - return fmt.Errorf("ast: unable to unmarshal row field with type: %T (expected number)", v["row"]) - } - } - if x, ok := v["col"]; ok { - if n, ok := x.(json.Number); ok { - i64, err := n.Int64() - if err != nil { - return err - } - loc.Col = int(i64) - } else { - return fmt.Errorf("ast: unable to unmarshal col field with type: %T (expected number)", v["col"]) - } - } - - return nil -} - -func unmarshalExprIndex(expr *Expr, v map[string]interface{}) error { - if x, ok := v["index"]; ok { - if n, ok := x.(json.Number); ok { - i, err := n.Int64() - if err == nil { - expr.Index = int(i) - return nil - } - } - } - return fmt.Errorf("ast: unable to unmarshal index field with type: %T (expected integer)", v["index"]) -} - -func unmarshalTerm(m map[string]interface{}) (*Term, error) { - var term Term - - v, err := unmarshalValue(m) - if err != nil { - return nil, err - } - term.Value = v - - if loc, ok := m["location"].(map[string]interface{}); ok { - term.Location = &Location{} - if err := unmarshalLocation(term.Location, loc); err != nil { - return nil, err - } - } - - return &term, nil -} - -func unmarshalTermSlice(s []interface{}) ([]*Term, error) { - buf := []*Term{} - for _, x := range s { - if m, ok := x.(map[string]interface{}); ok { - t, err := unmarshalTerm(m) - if err == nil { - buf = append(buf, t) - continue - } - return nil, err - } - return nil, fmt.Errorf("ast: unable to unmarshal term") - } - return buf, nil -} - -func unmarshalTermSliceValue(d map[string]interface{}) ([]*Term, error) { - if s, ok := d["value"].([]interface{}); ok { - return unmarshalTermSlice(s) - } - return nil, fmt.Errorf(`ast: unable to unmarshal term (expected {"value": [...], "type": ...} where type is one of: ref, array, or set)`) -} - -func unmarshalWith(i interface{}) (*With, error) { - if m, ok := i.(map[string]interface{}); ok { - tgt, _ := m["target"].(map[string]interface{}) - target, err := unmarshalTerm(tgt) - if err == nil { - val, _ := m["value"].(map[string]interface{}) - value, err := unmarshalTerm(val) - if err == nil { - return &With{ - Target: target, - Value: value, - }, nil - } - return nil, err - } - return nil, err - } - return nil, fmt.Errorf(`ast: unable to unmarshal with modifier (expected {"target": {...}, "value": {...}})`) -} - -func unmarshalValue(d map[string]interface{}) (Value, error) { - v := d["value"] - switch d["type"] { - case "null": - return Null{}, nil - case "boolean": - if b, ok := v.(bool); ok { - return Boolean(b), nil - } - case "number": - if n, ok := v.(json.Number); ok { - return Number(n), nil - } - case "string": - if s, ok := v.(string); ok { - return String(s), nil - } - case "var": - if s, ok := v.(string); ok { - return Var(s), nil - } - case "ref": - if s, err := unmarshalTermSliceValue(d); err == nil { - return Ref(s), nil - } - case "array": - if s, err := unmarshalTermSliceValue(d); err == nil { - return NewArray(s...), nil - } - case "set": - if s, err := unmarshalTermSliceValue(d); err == nil { - set := NewSet() - for _, x := range s { - set.Add(x) - } - return set, nil - } - case "object": - if s, ok := v.([]interface{}); ok { - buf := NewObject() - for _, x := range s { - if i, ok := x.([]interface{}); ok && len(i) == 2 { - p, err := unmarshalTermSlice(i) - if err == nil { - buf.Insert(p[0], p[1]) - continue - } - } - goto unmarshal_error - } - return buf, nil - } - case "arraycomprehension", "setcomprehension": - if m, ok := v.(map[string]interface{}); ok { - t, ok := m["term"].(map[string]interface{}) - if !ok { - goto unmarshal_error - } - - term, err := unmarshalTerm(t) - if err != nil { - goto unmarshal_error - } - - b, ok := m["body"].([]interface{}) - if !ok { - goto unmarshal_error - } - - body, err := unmarshalBody(b) - if err != nil { - goto unmarshal_error - } - - if d["type"] == "arraycomprehension" { - return &ArrayComprehension{Term: term, Body: body}, nil - } - return &SetComprehension{Term: term, Body: body}, nil - } - case "objectcomprehension": - if m, ok := v.(map[string]interface{}); ok { - k, ok := m["key"].(map[string]interface{}) - if !ok { - goto unmarshal_error - } - - key, err := unmarshalTerm(k) - if err != nil { - goto unmarshal_error - } - - v, ok := m["value"].(map[string]interface{}) - if !ok { - goto unmarshal_error - } - - value, err := unmarshalTerm(v) - if err != nil { - goto unmarshal_error - } - - b, ok := m["body"].([]interface{}) - if !ok { - goto unmarshal_error - } - - body, err := unmarshalBody(b) - if err != nil { - goto unmarshal_error - } - - return &ObjectComprehension{Key: key, Value: value, Body: body}, nil - } - case "call": - if s, err := unmarshalTermSliceValue(d); err == nil { - return Call(s), nil - } - } -unmarshal_error: - return nil, fmt.Errorf("ast: unable to unmarshal term") + return v1.CallTerm(terms...) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/transform.go b/vendor/github.com/open-policy-agent/opa/ast/transform.go index 391a16486..cfb137813 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/transform.go +++ b/vendor/github.com/open-policy-agent/opa/ast/transform.go @@ -5,427 +5,42 @@ package ast import ( - "fmt" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // Transformer defines the interface for transforming AST elements. If the // transformer returns nil and does not indicate an error, the AST element will // be set to nil and no transformations will be applied to children of the // element. -type Transformer interface { - Transform(interface{}) (interface{}, error) -} +type Transformer = v1.Transformer // Transform iterates the AST and calls the Transform function on the // Transformer t for x before recursing. func Transform(t Transformer, x interface{}) (interface{}, error) { - - if term, ok := x.(*Term); ok { - return Transform(t, term.Value) - } - - y, err := t.Transform(x) - if err != nil { - return x, err - } - - if y == nil { - return nil, nil - } - - var ok bool - switch y := y.(type) { - case *Module: - p, err := Transform(t, y.Package) - if err != nil { - return nil, err - } - if y.Package, ok = p.(*Package); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y.Package, p) - } - for i := range y.Imports { - imp, err := Transform(t, y.Imports[i]) - if err != nil { - return nil, err - } - if y.Imports[i], ok = imp.(*Import); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y.Imports[i], imp) - } - } - for i := range y.Rules { - rule, err := Transform(t, y.Rules[i]) - if err != nil { - return nil, err - } - if y.Rules[i], ok = rule.(*Rule); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y.Rules[i], rule) - } - } - for i := range y.Annotations { - a, err := Transform(t, y.Annotations[i]) - if err != nil { - return nil, err - } - if y.Annotations[i], ok = a.(*Annotations); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y.Annotations[i], a) - } - } - for i := range y.Comments { - comment, err := Transform(t, y.Comments[i]) - if err != nil { - return nil, err - } - if y.Comments[i], ok = comment.(*Comment); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y.Comments[i], comment) - } - } - return y, nil - case *Package: - ref, err := Transform(t, y.Path) - if err != nil { - return nil, err - } - if y.Path, ok = ref.(Ref); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y.Path, ref) - } - return y, nil - case *Import: - y.Path, err = transformTerm(t, y.Path) - if err != nil { - return nil, err - } - if y.Alias, err = transformVar(t, y.Alias); err != nil { - return nil, err - } - return y, nil - case *Rule: - if y.Head, err = transformHead(t, y.Head); err != nil { - return nil, err - } - if y.Body, err = transformBody(t, y.Body); err != nil { - return nil, err - } - if y.Else != nil { - rule, err := Transform(t, y.Else) - if err != nil { - return nil, err - } - if y.Else, ok = rule.(*Rule); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y.Else, rule) - } - } - return y, nil - case *Head: - if y.Reference, err = transformRef(t, y.Reference); err != nil { - return nil, err - } - if y.Name, err = transformVar(t, y.Name); err != nil { - return nil, err - } - if y.Args, err = transformArgs(t, y.Args); err != nil { - return nil, err - } - if y.Key != nil { - if y.Key, err = transformTerm(t, y.Key); err != nil { - return nil, err - } - } - if y.Value != nil { - if y.Value, err = transformTerm(t, y.Value); err != nil { - return nil, err - } - } - return y, nil - case Args: - for i := range y { - if y[i], err = transformTerm(t, y[i]); err != nil { - return nil, err - } - } - return y, nil - case Body: - for i, e := range y { - e, err := Transform(t, e) - if err != nil { - return nil, err - } - if y[i], ok = e.(*Expr); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y[i], e) - } - } - return y, nil - case *Expr: - switch ts := y.Terms.(type) { - case *SomeDecl: - decl, err := Transform(t, ts) - if err != nil { - return nil, err - } - if y.Terms, ok = decl.(*SomeDecl); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y, decl) - } - return y, nil - case []*Term: - for i := range ts { - if ts[i], err = transformTerm(t, ts[i]); err != nil { - return nil, err - } - } - case *Term: - if y.Terms, err = transformTerm(t, ts); err != nil { - return nil, err - } - case *Every: - if ts.Key != nil { - ts.Key, err = transformTerm(t, ts.Key) - if err != nil { - return nil, err - } - } - ts.Value, err = transformTerm(t, ts.Value) - if err != nil { - return nil, err - } - ts.Domain, err = transformTerm(t, ts.Domain) - if err != nil { - return nil, err - } - ts.Body, err = transformBody(t, ts.Body) - if err != nil { - return nil, err - } - y.Terms = ts - } - for i, w := range y.With { - w, err := Transform(t, w) - if err != nil { - return nil, err - } - if y.With[i], ok = w.(*With); !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", y.With[i], w) - } - } - return y, nil - case *With: - if y.Target, err = transformTerm(t, y.Target); err != nil { - return nil, err - } - if y.Value, err = transformTerm(t, y.Value); err != nil { - return nil, err - } - return y, nil - case Ref: - for i, term := range y { - if y[i], err = transformTerm(t, term); err != nil { - return nil, err - } - } - return y, nil - case *object: - return y.Map(func(k, v *Term) (*Term, *Term, error) { - k, err := transformTerm(t, k) - if err != nil { - return nil, nil, err - } - v, err = transformTerm(t, v) - if err != nil { - return nil, nil, err - } - return k, v, nil - }) - case *Array: - for i := 0; i < y.Len(); i++ { - v, err := transformTerm(t, y.Elem(i)) - if err != nil { - return nil, err - } - y.set(i, v) - } - return y, nil - case Set: - y, err = y.Map(func(term *Term) (*Term, error) { - return transformTerm(t, term) - }) - if err != nil { - return nil, err - } - return y, nil - case *ArrayComprehension: - if y.Term, err = transformTerm(t, y.Term); err != nil { - return nil, err - } - if y.Body, err = transformBody(t, y.Body); err != nil { - return nil, err - } - return y, nil - case *ObjectComprehension: - if y.Key, err = transformTerm(t, y.Key); err != nil { - return nil, err - } - if y.Value, err = transformTerm(t, y.Value); err != nil { - return nil, err - } - if y.Body, err = transformBody(t, y.Body); err != nil { - return nil, err - } - return y, nil - case *SetComprehension: - if y.Term, err = transformTerm(t, y.Term); err != nil { - return nil, err - } - if y.Body, err = transformBody(t, y.Body); err != nil { - return nil, err - } - return y, nil - case Call: - for i := range y { - if y[i], err = transformTerm(t, y[i]); err != nil { - return nil, err - } - } - return y, nil - default: - return y, nil - } + return v1.Transform(t, x) } // TransformRefs calls the function f on all references under x. func TransformRefs(x interface{}, f func(Ref) (Value, error)) (interface{}, error) { - t := &GenericTransformer{func(x interface{}) (interface{}, error) { - if r, ok := x.(Ref); ok { - return f(r) - } - return x, nil - }} - return Transform(t, x) + return v1.TransformRefs(x, f) } // TransformVars calls the function f on all vars under x. func TransformVars(x interface{}, f func(Var) (Value, error)) (interface{}, error) { - t := &GenericTransformer{func(x interface{}) (interface{}, error) { - if v, ok := x.(Var); ok { - return f(v) - } - return x, nil - }} - return Transform(t, x) + return v1.TransformVars(x, f) } // TransformComprehensions calls the functio nf on all comprehensions under x. func TransformComprehensions(x interface{}, f func(interface{}) (Value, error)) (interface{}, error) { - t := &GenericTransformer{func(x interface{}) (interface{}, error) { - switch x := x.(type) { - case *ArrayComprehension: - return f(x) - case *SetComprehension: - return f(x) - case *ObjectComprehension: - return f(x) - } - return x, nil - }} - return Transform(t, x) + return v1.TransformComprehensions(x, f) } // GenericTransformer implements the Transformer interface to provide a utility // to transform AST nodes using a closure. -type GenericTransformer struct { - f func(interface{}) (interface{}, error) -} +type GenericTransformer = v1.GenericTransformer // NewGenericTransformer returns a new GenericTransformer that will transform // AST nodes using the function f. func NewGenericTransformer(f func(x interface{}) (interface{}, error)) *GenericTransformer { - return &GenericTransformer{ - f: f, - } -} - -// Transform calls the function f on the GenericTransformer. -func (t *GenericTransformer) Transform(x interface{}) (interface{}, error) { - return t.f(x) -} - -func transformHead(t Transformer, head *Head) (*Head, error) { - y, err := Transform(t, head) - if err != nil { - return nil, err - } - h, ok := y.(*Head) - if !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", head, y) - } - return h, nil -} - -func transformArgs(t Transformer, args Args) (Args, error) { - y, err := Transform(t, args) - if err != nil { - return nil, err - } - a, ok := y.(Args) - if !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", args, y) - } - return a, nil -} - -func transformBody(t Transformer, body Body) (Body, error) { - y, err := Transform(t, body) - if err != nil { - return nil, err - } - r, ok := y.(Body) - if !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", body, y) - } - return r, nil -} - -func transformTerm(t Transformer, term *Term) (*Term, error) { - v, err := transformValue(t, term.Value) - if err != nil { - return nil, err - } - r := &Term{ - Value: v, - Location: term.Location, - } - return r, nil -} - -func transformValue(t Transformer, v Value) (Value, error) { - v1, err := Transform(t, v) - if err != nil { - return nil, err - } - r, ok := v1.(Value) - if !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", v, v1) - } - return r, nil -} - -func transformVar(t Transformer, v Var) (Var, error) { - v1, err := Transform(t, v) - if err != nil { - return "", err - } - r, ok := v1.(Var) - if !ok { - return "", fmt.Errorf("illegal transform: %T != %T", v, v1) - } - return r, nil -} - -func transformRef(t Transformer, r Ref) (Ref, error) { - r1, err := Transform(t, r) - if err != nil { - return nil, err - } - r2, ok := r1.(Ref) - if !ok { - return nil, fmt.Errorf("illegal transform: %T != %T", r, r2) - } - return r2, nil + return v1.NewGenericTransformer(f) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/unify.go b/vendor/github.com/open-policy-agent/opa/ast/unify.go index 60244974a..3cb260272 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/unify.go +++ b/vendor/github.com/open-policy-agent/opa/ast/unify.go @@ -4,232 +4,11 @@ package ast -func isRefSafe(ref Ref, safe VarSet) bool { - switch head := ref[0].Value.(type) { - case Var: - return safe.Contains(head) - case Call: - return isCallSafe(head, safe) - default: - for v := range ref[0].Vars() { - if !safe.Contains(v) { - return false - } - } - return true - } -} - -func isCallSafe(call Call, safe VarSet) bool { - vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams) - vis.Walk(call) - unsafe := vis.Vars().Diff(safe) - return len(unsafe) == 0 -} +import v1 "github.com/open-policy-agent/opa/v1/ast" // Unify returns a set of variables that will be unified when the equality expression defined by // terms a and b is evaluated. The unifier assumes that variables in the VarSet safe are already // unified. func Unify(safe VarSet, a *Term, b *Term) VarSet { - u := &unifier{ - safe: safe, - unified: VarSet{}, - unknown: map[Var]VarSet{}, - } - u.unify(a, b) - return u.unified -} - -type unifier struct { - safe VarSet - unified VarSet - unknown map[Var]VarSet -} - -func (u *unifier) isSafe(x Var) bool { - return u.safe.Contains(x) || u.unified.Contains(x) -} - -func (u *unifier) unify(a *Term, b *Term) { - - switch a := a.Value.(type) { - - case Var: - switch b := b.Value.(type) { - case Var: - if u.isSafe(b) { - u.markSafe(a) - } else if u.isSafe(a) { - u.markSafe(b) - } else { - u.markUnknown(a, b) - u.markUnknown(b, a) - } - case *Array, Object: - u.unifyAll(a, b) - case Ref: - if isRefSafe(b, u.safe) { - u.markSafe(a) - } - case Call: - if isCallSafe(b, u.safe) { - u.markSafe(a) - } - default: - u.markSafe(a) - } - - case Ref: - if isRefSafe(a, u.safe) { - switch b := b.Value.(type) { - case Var: - u.markSafe(b) - case *Array, Object: - u.markAllSafe(b) - } - } - - case Call: - if isCallSafe(a, u.safe) { - switch b := b.Value.(type) { - case Var: - u.markSafe(b) - case *Array, Object: - u.markAllSafe(b) - } - } - - case *ArrayComprehension: - switch b := b.Value.(type) { - case Var: - u.markSafe(b) - case *Array: - u.markAllSafe(b) - } - case *ObjectComprehension: - switch b := b.Value.(type) { - case Var: - u.markSafe(b) - case *object: - u.markAllSafe(b) - } - case *SetComprehension: - switch b := b.Value.(type) { - case Var: - u.markSafe(b) - } - - case *Array: - switch b := b.Value.(type) { - case Var: - u.unifyAll(b, a) - case *ArrayComprehension, *ObjectComprehension, *SetComprehension: - u.markAllSafe(a) - case Ref: - if isRefSafe(b, u.safe) { - u.markAllSafe(a) - } - case Call: - if isCallSafe(b, u.safe) { - u.markAllSafe(a) - } - case *Array: - if a.Len() == b.Len() { - for i := 0; i < a.Len(); i++ { - u.unify(a.Elem(i), b.Elem(i)) - } - } - } - - case *object: - switch b := b.Value.(type) { - case Var: - u.unifyAll(b, a) - case Ref: - if isRefSafe(b, u.safe) { - u.markAllSafe(a) - } - case Call: - if isCallSafe(b, u.safe) { - u.markAllSafe(a) - } - case *object: - if a.Len() == b.Len() { - _ = a.Iter(func(k, v *Term) error { - if v2 := b.Get(k); v2 != nil { - u.unify(v, v2) - } - return nil - }) // impossible to return error - } - } - - default: - switch b := b.Value.(type) { - case Var: - u.markSafe(b) - } - } -} - -func (u *unifier) markAllSafe(x Value) { - vis := u.varVisitor() - vis.Walk(x) - for v := range vis.Vars() { - u.markSafe(v) - } -} - -func (u *unifier) markSafe(x Var) { - u.unified.Add(x) - - // Add dependencies of 'x' to safe set - vs := u.unknown[x] - delete(u.unknown, x) - for v := range vs { - u.markSafe(v) - } - - // Add dependants of 'x' to safe set if they have no more - // dependencies. - for v, deps := range u.unknown { - if deps.Contains(x) { - delete(deps, x) - if len(deps) == 0 { - u.markSafe(v) - } - } - } -} - -func (u *unifier) markUnknown(a, b Var) { - if _, ok := u.unknown[a]; !ok { - u.unknown[a] = NewVarSet() - } - u.unknown[a].Add(b) -} - -func (u *unifier) unifyAll(a Var, b Value) { - if u.isSafe(a) { - u.markAllSafe(b) - } else { - vis := u.varVisitor() - vis.Walk(b) - unsafe := vis.Vars().Diff(u.safe).Diff(u.unified) - if len(unsafe) == 0 { - u.markSafe(a) - } else { - for v := range unsafe { - u.markUnknown(a, v) - } - } - } -} - -func (u *unifier) varVisitor() *VarVisitor { - return NewVarVisitor().WithParams(VarVisitorParams{ - SkipRefHead: true, - SkipObjectKeys: true, - SkipClosures: true, - }) + return v1.Unify(safe, a, b) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/varset.go b/vendor/github.com/open-policy-agent/opa/ast/varset.go index 14f531494..9e7db8efd 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/varset.go +++ b/vendor/github.com/open-policy-agent/opa/ast/varset.go @@ -5,96 +5,13 @@ package ast import ( - "fmt" - "sort" + v1 "github.com/open-policy-agent/opa/v1/ast" ) // VarSet represents a set of variables. -type VarSet map[Var]struct{} +type VarSet = v1.VarSet // NewVarSet returns a new VarSet containing the specified variables. func NewVarSet(vs ...Var) VarSet { - s := VarSet{} - for _, v := range vs { - s.Add(v) - } - return s -} - -// Add updates the set to include the variable "v". -func (s VarSet) Add(v Var) { - s[v] = struct{}{} -} - -// Contains returns true if the set contains the variable "v". -func (s VarSet) Contains(v Var) bool { - _, ok := s[v] - return ok -} - -// Copy returns a shallow copy of the VarSet. -func (s VarSet) Copy() VarSet { - cpy := VarSet{} - for v := range s { - cpy.Add(v) - } - return cpy -} - -// Diff returns a VarSet containing variables in s that are not in vs. -func (s VarSet) Diff(vs VarSet) VarSet { - r := VarSet{} - for v := range s { - if !vs.Contains(v) { - r.Add(v) - } - } - return r -} - -// Equal returns true if s contains exactly the same elements as vs. -func (s VarSet) Equal(vs VarSet) bool { - if len(s.Diff(vs)) > 0 { - return false - } - return len(vs.Diff(s)) == 0 -} - -// Intersect returns a VarSet containing variables in s that are in vs. -func (s VarSet) Intersect(vs VarSet) VarSet { - r := VarSet{} - for v := range s { - if vs.Contains(v) { - r.Add(v) - } - } - return r -} - -// Sorted returns a sorted slice of vars from s. -func (s VarSet) Sorted() []Var { - sorted := make([]Var, 0, len(s)) - for v := range s { - sorted = append(sorted, v) - } - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Compare(sorted[j]) < 0 - }) - return sorted -} - -// Update merges the other VarSet into this VarSet. -func (s VarSet) Update(vs VarSet) { - for v := range vs { - s.Add(v) - } -} - -func (s VarSet) String() string { - tmp := make([]string, 0, len(s)) - for v := range s { - tmp = append(tmp, string(v)) - } - sort.Strings(tmp) - return fmt.Sprintf("%v", tmp) + return v1.NewVarSet(vs...) } diff --git a/vendor/github.com/open-policy-agent/opa/ast/visit.go b/vendor/github.com/open-policy-agent/opa/ast/visit.go index d83c31149..94823c6cc 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/visit.go +++ b/vendor/github.com/open-policy-agent/opa/ast/visit.go @@ -4,780 +4,120 @@ package ast +import v1 "github.com/open-policy-agent/opa/v1/ast" + // Visitor defines the interface for iterating AST elements. The Visit function // can return a Visitor w which will be used to visit the children of the AST // element v. If the Visit function returns nil, the children will not be // visited. // Deprecated: use GenericVisitor or another visitor implementation -type Visitor interface { - Visit(v interface{}) (w Visitor) -} +type Visitor = v1.Visitor // BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before // and after the AST has been visited. // Deprecated: use GenericVisitor or another visitor implementation -type BeforeAndAfterVisitor interface { - Visitor - Before(x interface{}) - After(x interface{}) -} +type BeforeAndAfterVisitor = v1.BeforeAndAfterVisitor // Walk iterates the AST by calling the Visit function on the Visitor // v for x before recursing. // Deprecated: use GenericVisitor.Walk func Walk(v Visitor, x interface{}) { - if bav, ok := v.(BeforeAndAfterVisitor); !ok { - walk(v, x) - } else { - bav.Before(x) - defer bav.After(x) - walk(bav, x) - } + v1.Walk(v, x) } // WalkBeforeAndAfter iterates the AST by calling the Visit function on the // Visitor v for x before recursing. // Deprecated: use GenericVisitor.Walk func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x interface{}) { - Walk(v, x) -} - -func walk(v Visitor, x interface{}) { - w := v.Visit(x) - if w == nil { - return - } - switch x := x.(type) { - case *Module: - Walk(w, x.Package) - for i := range x.Imports { - Walk(w, x.Imports[i]) - } - for i := range x.Rules { - Walk(w, x.Rules[i]) - } - for i := range x.Annotations { - Walk(w, x.Annotations[i]) - } - for i := range x.Comments { - Walk(w, x.Comments[i]) - } - case *Package: - Walk(w, x.Path) - case *Import: - Walk(w, x.Path) - Walk(w, x.Alias) - case *Rule: - Walk(w, x.Head) - Walk(w, x.Body) - if x.Else != nil { - Walk(w, x.Else) - } - case *Head: - Walk(w, x.Name) - Walk(w, x.Args) - if x.Key != nil { - Walk(w, x.Key) - } - if x.Value != nil { - Walk(w, x.Value) - } - case Body: - for i := range x { - Walk(w, x[i]) - } - case Args: - for i := range x { - Walk(w, x[i]) - } - case *Expr: - switch ts := x.Terms.(type) { - case *Term, *SomeDecl, *Every: - Walk(w, ts) - case []*Term: - for i := range ts { - Walk(w, ts[i]) - } - } - for i := range x.With { - Walk(w, x.With[i]) - } - case *With: - Walk(w, x.Target) - Walk(w, x.Value) - case *Term: - Walk(w, x.Value) - case Ref: - for i := range x { - Walk(w, x[i]) - } - case *object: - x.Foreach(func(k, vv *Term) { - Walk(w, k) - Walk(w, vv) - }) - case *Array: - x.Foreach(func(t *Term) { - Walk(w, t) - }) - case Set: - x.Foreach(func(t *Term) { - Walk(w, t) - }) - case *ArrayComprehension: - Walk(w, x.Term) - Walk(w, x.Body) - case *ObjectComprehension: - Walk(w, x.Key) - Walk(w, x.Value) - Walk(w, x.Body) - case *SetComprehension: - Walk(w, x.Term) - Walk(w, x.Body) - case Call: - for i := range x { - Walk(w, x[i]) - } - case *Every: - if x.Key != nil { - Walk(w, x.Key) - } - Walk(w, x.Value) - Walk(w, x.Domain) - Walk(w, x.Body) - case *SomeDecl: - for i := range x.Symbols { - Walk(w, x.Symbols[i]) - } - } + v1.WalkBeforeAndAfter(v, x) } // WalkVars calls the function f on all vars under x. If the function f // returns true, AST nodes under the last node will not be visited. func WalkVars(x interface{}, f func(Var) bool) { - vis := &GenericVisitor{func(x interface{}) bool { - if v, ok := x.(Var); ok { - return f(v) - } - return false - }} - vis.Walk(x) + v1.WalkVars(x, f) } // WalkClosures calls the function f on all closures under x. If the function f // returns true, AST nodes under the last node will not be visited. func WalkClosures(x interface{}, f func(interface{}) bool) { - vis := &GenericVisitor{func(x interface{}) bool { - switch x := x.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: - return f(x) - } - return false - }} - vis.Walk(x) + v1.WalkClosures(x, f) } // WalkRefs calls the function f on all references under x. If the function f // returns true, AST nodes under the last node will not be visited. func WalkRefs(x interface{}, f func(Ref) bool) { - vis := &GenericVisitor{func(x interface{}) bool { - if r, ok := x.(Ref); ok { - return f(r) - } - return false - }} - vis.Walk(x) + v1.WalkRefs(x, f) } // WalkTerms calls the function f on all terms under x. If the function f // returns true, AST nodes under the last node will not be visited. func WalkTerms(x interface{}, f func(*Term) bool) { - vis := &GenericVisitor{func(x interface{}) bool { - if term, ok := x.(*Term); ok { - return f(term) - } - return false - }} - vis.Walk(x) + v1.WalkTerms(x, f) } // WalkWiths calls the function f on all with modifiers under x. If the function f // returns true, AST nodes under the last node will not be visited. func WalkWiths(x interface{}, f func(*With) bool) { - vis := &GenericVisitor{func(x interface{}) bool { - if w, ok := x.(*With); ok { - return f(w) - } - return false - }} - vis.Walk(x) + v1.WalkWiths(x, f) } // WalkExprs calls the function f on all expressions under x. If the function f // returns true, AST nodes under the last node will not be visited. func WalkExprs(x interface{}, f func(*Expr) bool) { - vis := &GenericVisitor{func(x interface{}) bool { - if r, ok := x.(*Expr); ok { - return f(r) - } - return false - }} - vis.Walk(x) + v1.WalkExprs(x, f) } // WalkBodies calls the function f on all bodies under x. If the function f // returns true, AST nodes under the last node will not be visited. func WalkBodies(x interface{}, f func(Body) bool) { - vis := &GenericVisitor{func(x interface{}) bool { - if b, ok := x.(Body); ok { - return f(b) - } - return false - }} - vis.Walk(x) + v1.WalkBodies(x, f) } // WalkRules calls the function f on all rules under x. If the function f // returns true, AST nodes under the last node will not be visited. func WalkRules(x interface{}, f func(*Rule) bool) { - vis := &GenericVisitor{func(x interface{}) bool { - if r, ok := x.(*Rule); ok { - stop := f(r) - // NOTE(tsandall): since rules cannot be embedded inside of queries - // we can stop early if there is no else block. - if stop || r.Else == nil { - return true - } - } - return false - }} - vis.Walk(x) + v1.WalkRules(x, f) } // WalkNodes calls the function f on all nodes under x. If the function f // returns true, AST nodes under the last node will not be visited. func WalkNodes(x interface{}, f func(Node) bool) { - vis := &GenericVisitor{func(x interface{}) bool { - if n, ok := x.(Node); ok { - return f(n) - } - return false - }} - vis.Walk(x) + v1.WalkNodes(x, f) } // GenericVisitor provides a utility to walk over AST nodes using a // closure. If the closure returns true, the visitor will not walk // over AST nodes under x. -type GenericVisitor struct { - f func(x interface{}) bool -} +type GenericVisitor = v1.GenericVisitor // NewGenericVisitor returns a new GenericVisitor that will invoke the function // f on AST nodes. func NewGenericVisitor(f func(x interface{}) bool) *GenericVisitor { - return &GenericVisitor{f} -} - -// Walk iterates the AST by calling the function f on the -// GenericVisitor before recursing. Contrary to the generic Walk, this -// does not require allocating the visitor from heap. -func (vis *GenericVisitor) Walk(x interface{}) { - if vis.f(x) { - return - } - - switch x := x.(type) { - case *Module: - vis.Walk(x.Package) - for i := range x.Imports { - vis.Walk(x.Imports[i]) - } - for i := range x.Rules { - vis.Walk(x.Rules[i]) - } - for i := range x.Annotations { - vis.Walk(x.Annotations[i]) - } - for i := range x.Comments { - vis.Walk(x.Comments[i]) - } - case *Package: - vis.Walk(x.Path) - case *Import: - vis.Walk(x.Path) - vis.Walk(x.Alias) - case *Rule: - vis.Walk(x.Head) - vis.Walk(x.Body) - if x.Else != nil { - vis.Walk(x.Else) - } - case *Head: - vis.Walk(x.Name) - vis.Walk(x.Args) - if x.Key != nil { - vis.Walk(x.Key) - } - if x.Value != nil { - vis.Walk(x.Value) - } - case Body: - for i := range x { - vis.Walk(x[i]) - } - case Args: - for i := range x { - vis.Walk(x[i]) - } - case *Expr: - switch ts := x.Terms.(type) { - case *Term, *SomeDecl, *Every: - vis.Walk(ts) - case []*Term: - for i := range ts { - vis.Walk(ts[i]) - } - } - for i := range x.With { - vis.Walk(x.With[i]) - } - case *With: - vis.Walk(x.Target) - vis.Walk(x.Value) - case *Term: - vis.Walk(x.Value) - case Ref: - for i := range x { - vis.Walk(x[i]) - } - case *object: - x.Foreach(func(k, _ *Term) { - vis.Walk(k) - vis.Walk(x.Get(k)) - }) - case Object: - x.Foreach(func(k, _ *Term) { - vis.Walk(k) - vis.Walk(x.Get(k)) - }) - case *Array: - x.Foreach(func(t *Term) { - vis.Walk(t) - }) - case Set: - xSlice := x.Slice() - for i := range xSlice { - vis.Walk(xSlice[i]) - } - case *ArrayComprehension: - vis.Walk(x.Term) - vis.Walk(x.Body) - case *ObjectComprehension: - vis.Walk(x.Key) - vis.Walk(x.Value) - vis.Walk(x.Body) - case *SetComprehension: - vis.Walk(x.Term) - vis.Walk(x.Body) - case Call: - for i := range x { - vis.Walk(x[i]) - } - case *Every: - if x.Key != nil { - vis.Walk(x.Key) - } - vis.Walk(x.Value) - vis.Walk(x.Domain) - vis.Walk(x.Body) - case *SomeDecl: - for i := range x.Symbols { - vis.Walk(x.Symbols[i]) - } - } + return v1.NewGenericVisitor(f) } // BeforeAfterVisitor provides a utility to walk over AST nodes using // closures. If the before closure returns true, the visitor will not // walk over AST nodes under x. The after closure is invoked always // after visiting a node. -type BeforeAfterVisitor struct { - before func(x interface{}) bool - after func(x interface{}) -} +type BeforeAfterVisitor = v1.BeforeAfterVisitor // NewBeforeAfterVisitor returns a new BeforeAndAfterVisitor that // will invoke the functions before and after AST nodes. func NewBeforeAfterVisitor(before func(x interface{}) bool, after func(x interface{})) *BeforeAfterVisitor { - return &BeforeAfterVisitor{before, after} -} - -// Walk iterates the AST by calling the functions on the -// BeforeAndAfterVisitor before and after recursing. Contrary to the -// generic Walk, this does not require allocating the visitor from -// heap. -func (vis *BeforeAfterVisitor) Walk(x interface{}) { - defer vis.after(x) - if vis.before(x) { - return - } - - switch x := x.(type) { - case *Module: - vis.Walk(x.Package) - for i := range x.Imports { - vis.Walk(x.Imports[i]) - } - for i := range x.Rules { - vis.Walk(x.Rules[i]) - } - for i := range x.Annotations { - vis.Walk(x.Annotations[i]) - } - for i := range x.Comments { - vis.Walk(x.Comments[i]) - } - case *Package: - vis.Walk(x.Path) - case *Import: - vis.Walk(x.Path) - vis.Walk(x.Alias) - case *Rule: - vis.Walk(x.Head) - vis.Walk(x.Body) - if x.Else != nil { - vis.Walk(x.Else) - } - case *Head: - if len(x.Reference) > 0 { - vis.Walk(x.Reference) - } else { - vis.Walk(x.Name) - if x.Key != nil { - vis.Walk(x.Key) - } - } - vis.Walk(x.Args) - if x.Value != nil { - vis.Walk(x.Value) - } - case Body: - for i := range x { - vis.Walk(x[i]) - } - case Args: - for i := range x { - vis.Walk(x[i]) - } - case *Expr: - switch ts := x.Terms.(type) { - case *Term, *SomeDecl, *Every: - vis.Walk(ts) - case []*Term: - for i := range ts { - vis.Walk(ts[i]) - } - } - for i := range x.With { - vis.Walk(x.With[i]) - } - case *With: - vis.Walk(x.Target) - vis.Walk(x.Value) - case *Term: - vis.Walk(x.Value) - case Ref: - for i := range x { - vis.Walk(x[i]) - } - case *object: - x.Foreach(func(k, _ *Term) { - vis.Walk(k) - vis.Walk(x.Get(k)) - }) - case Object: - x.Foreach(func(k, _ *Term) { - vis.Walk(k) - vis.Walk(x.Get(k)) - }) - case *Array: - x.Foreach(func(t *Term) { - vis.Walk(t) - }) - case Set: - xSlice := x.Slice() - for i := range xSlice { - vis.Walk(xSlice[i]) - } - case *ArrayComprehension: - vis.Walk(x.Term) - vis.Walk(x.Body) - case *ObjectComprehension: - vis.Walk(x.Key) - vis.Walk(x.Value) - vis.Walk(x.Body) - case *SetComprehension: - vis.Walk(x.Term) - vis.Walk(x.Body) - case Call: - for i := range x { - vis.Walk(x[i]) - } - case *Every: - if x.Key != nil { - vis.Walk(x.Key) - } - vis.Walk(x.Value) - vis.Walk(x.Domain) - vis.Walk(x.Body) - case *SomeDecl: - for i := range x.Symbols { - vis.Walk(x.Symbols[i]) - } - } + return v1.NewBeforeAfterVisitor(before, after) } // VarVisitor walks AST nodes under a given node and collects all encountered // variables. The collected variables can be controlled by specifying // VarVisitorParams when creating the visitor. -type VarVisitor struct { - params VarVisitorParams - vars VarSet -} +type VarVisitor = v1.VarVisitor // VarVisitorParams contains settings for a VarVisitor. -type VarVisitorParams struct { - SkipRefHead bool - SkipRefCallHead bool - SkipObjectKeys bool - SkipClosures bool - SkipWithTarget bool - SkipSets bool -} +type VarVisitorParams = v1.VarVisitorParams // NewVarVisitor returns a new VarVisitor object. func NewVarVisitor() *VarVisitor { - return &VarVisitor{ - vars: NewVarSet(), - } -} - -// WithParams sets the parameters in params on vis. -func (vis *VarVisitor) WithParams(params VarVisitorParams) *VarVisitor { - vis.params = params - return vis -} - -// Vars returns a VarSet that contains collected vars. -func (vis *VarVisitor) Vars() VarSet { - return vis.vars -} - -// visit determines if the VarVisitor will recurse into x: if it returns `true`, -// the visitor will _skip_ that branch of the AST -func (vis *VarVisitor) visit(v interface{}) bool { - if vis.params.SkipObjectKeys { - if o, ok := v.(Object); ok { - o.Foreach(func(_, v *Term) { - vis.Walk(v) - }) - return true - } - } - if vis.params.SkipRefHead { - if r, ok := v.(Ref); ok { - rSlice := r[1:] - for i := range rSlice { - vis.Walk(rSlice[i]) - } - return true - } - } - if vis.params.SkipClosures { - switch v := v.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension: - return true - case *Expr: - if ev, ok := v.Terms.(*Every); ok { - vis.Walk(ev.Domain) - // We're _not_ walking ev.Body -- that's the closure here - return true - } - } - } - if vis.params.SkipWithTarget { - if v, ok := v.(*With); ok { - vis.Walk(v.Value) - return true - } - } - if vis.params.SkipSets { - if _, ok := v.(Set); ok { - return true - } - } - if vis.params.SkipRefCallHead { - switch v := v.(type) { - case *Expr: - if terms, ok := v.Terms.([]*Term); ok { - termSlice := terms[0].Value.(Ref)[1:] - for i := range termSlice { - vis.Walk(termSlice[i]) - } - for i := 1; i < len(terms); i++ { - vis.Walk(terms[i]) - } - for i := range v.With { - vis.Walk(v.With[i]) - } - return true - } - case Call: - operator := v[0].Value.(Ref) - for i := 1; i < len(operator); i++ { - vis.Walk(operator[i]) - } - for i := 1; i < len(v); i++ { - vis.Walk(v[i]) - } - return true - case *With: - if ref, ok := v.Target.Value.(Ref); ok { - refSlice := ref[1:] - for i := range refSlice { - vis.Walk(refSlice[i]) - } - } - if ref, ok := v.Value.Value.(Ref); ok { - refSlice := ref[1:] - for i := range refSlice { - vis.Walk(refSlice[i]) - } - } else { - vis.Walk(v.Value) - } - return true - } - } - if v, ok := v.(Var); ok { - vis.vars.Add(v) - } - return false -} - -// Walk iterates the AST by calling the function f on the -// GenericVisitor before recursing. Contrary to the generic Walk, this -// does not require allocating the visitor from heap. -func (vis *VarVisitor) Walk(x interface{}) { - if vis.visit(x) { - return - } - - switch x := x.(type) { - case *Module: - vis.Walk(x.Package) - for i := range x.Imports { - vis.Walk(x.Imports[i]) - } - for i := range x.Rules { - vis.Walk(x.Rules[i]) - } - for i := range x.Comments { - vis.Walk(x.Comments[i]) - } - case *Package: - vis.Walk(x.Path) - case *Import: - vis.Walk(x.Path) - vis.Walk(x.Alias) - case *Rule: - vis.Walk(x.Head) - vis.Walk(x.Body) - if x.Else != nil { - vis.Walk(x.Else) - } - case *Head: - if len(x.Reference) > 0 { - vis.Walk(x.Reference) - } else { - vis.Walk(x.Name) - if x.Key != nil { - vis.Walk(x.Key) - } - } - vis.Walk(x.Args) - - if x.Value != nil { - vis.Walk(x.Value) - } - case Body: - for i := range x { - vis.Walk(x[i]) - } - case Args: - for i := range x { - vis.Walk(x[i]) - } - case *Expr: - switch ts := x.Terms.(type) { - case *Term, *SomeDecl, *Every: - vis.Walk(ts) - case []*Term: - for i := range ts { - vis.Walk(ts[i]) - } - } - for i := range x.With { - vis.Walk(x.With[i]) - } - case *With: - vis.Walk(x.Target) - vis.Walk(x.Value) - case *Term: - vis.Walk(x.Value) - case Ref: - for i := range x { - vis.Walk(x[i]) - } - case *object: - x.Foreach(func(k, _ *Term) { - vis.Walk(k) - vis.Walk(x.Get(k)) - }) - case *Array: - x.Foreach(func(t *Term) { - vis.Walk(t) - }) - case Set: - xSlice := x.Slice() - for i := range xSlice { - vis.Walk(xSlice[i]) - } - case *ArrayComprehension: - vis.Walk(x.Term) - vis.Walk(x.Body) - case *ObjectComprehension: - vis.Walk(x.Key) - vis.Walk(x.Value) - vis.Walk(x.Body) - case *SetComprehension: - vis.Walk(x.Term) - vis.Walk(x.Body) - case Call: - for i := range x { - vis.Walk(x[i]) - } - case *Every: - if x.Key != nil { - vis.Walk(x.Key) - } - vis.Walk(x.Value) - vis.Walk(x.Domain) - vis.Walk(x.Body) - case *SomeDecl: - for i := range x.Symbols { - vis.Walk(x.Symbols[i]) - } - } + return v1.NewVarVisitor() } diff --git a/vendor/github.com/open-policy-agent/opa/bundle/bundle.go b/vendor/github.com/open-policy-agent/opa/bundle/bundle.go index 0e159384e..50ad97349 100644 --- a/vendor/github.com/open-policy-agent/opa/bundle/bundle.go +++ b/vendor/github.com/open-policy-agent/opa/bundle/bundle.go @@ -6,1386 +6,97 @@ package bundle import ( - "archive/tar" - "bytes" - "compress/gzip" - "encoding/hex" - "encoding/json" - "errors" - "fmt" "io" - "net/url" - "os" - "path" - "path/filepath" - "reflect" - "strings" - "github.com/gobwas/glob" "github.com/open-policy-agent/opa/ast" - astJSON "github.com/open-policy-agent/opa/ast/json" - "github.com/open-policy-agent/opa/format" - "github.com/open-policy-agent/opa/internal/file/archive" - "github.com/open-policy-agent/opa/internal/merge" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/bundle" ) // Common file extensions and file names. const ( - RegoExt = ".rego" - WasmFile = "policy.wasm" - PlanFile = "plan.json" - ManifestExt = ".manifest" - SignaturesFile = "signatures.json" - patchFile = "patch.json" - dataFile = "data.json" - yamlDataFile = "data.yaml" - ymlDataFile = "data.yml" - defaultHashingAlg = "SHA-256" - DefaultSizeLimitBytes = (1024 * 1024 * 1024) // limit bundle reads to 1GB to protect against gzip bombs - DeltaBundleType = "delta" - SnapshotBundleType = "snapshot" + RegoExt = v1.RegoExt + WasmFile = v1.WasmFile + PlanFile = v1.PlanFile + ManifestExt = v1.ManifestExt + SignaturesFile = v1.SignaturesFile + + DefaultSizeLimitBytes = v1.DefaultSizeLimitBytes + DeltaBundleType = v1.DeltaBundleType + SnapshotBundleType = v1.SnapshotBundleType ) // Bundle represents a loaded bundle. The bundle can contain data and policies. -type Bundle struct { - Signatures SignaturesConfig - Manifest Manifest - Data map[string]interface{} - Modules []ModuleFile - Wasm []byte // Deprecated. Use WasmModules instead - WasmModules []WasmModuleFile - PlanModules []PlanModuleFile - Patch Patch - Etag string - Raw []Raw - - lazyLoadingMode bool - sizeLimitBytes int64 -} +type Bundle = v1.Bundle // Raw contains raw bytes representing the bundle's content -type Raw struct { - Path string - Value []byte -} +type Raw = v1.Raw // Patch contains an array of objects wherein each object represents the patch operation to be // applied to the bundle data. -type Patch struct { - Data []PatchOperation `json:"data,omitempty"` -} +type Patch = v1.Patch // PatchOperation models a single patch operation against a document. -type PatchOperation struct { - Op string `json:"op"` - Path string `json:"path"` - Value interface{} `json:"value"` -} +type PatchOperation = v1.PatchOperation // SignaturesConfig represents an array of JWTs that encapsulate the signatures for the bundle. -type SignaturesConfig struct { - Signatures []string `json:"signatures,omitempty"` - Plugin string `json:"plugin,omitempty"` -} - -// isEmpty returns if the SignaturesConfig is empty. -func (s SignaturesConfig) isEmpty() bool { - return reflect.DeepEqual(s, SignaturesConfig{}) -} +type SignaturesConfig = v1.SignaturesConfig // DecodedSignature represents the decoded JWT payload. -type DecodedSignature struct { - Files []FileInfo `json:"files"` - KeyID string `json:"keyid"` // Deprecated, use kid in the JWT header instead. - Scope string `json:"scope"` - IssuedAt int64 `json:"iat"` - Issuer string `json:"iss"` -} +type DecodedSignature = v1.DecodedSignature // FileInfo contains the hashing algorithm used, resulting digest etc. -type FileInfo struct { - Name string `json:"name"` - Hash string `json:"hash"` - Algorithm string `json:"algorithm"` -} +type FileInfo = v1.FileInfo // NewFile returns a new FileInfo. func NewFile(name, hash, alg string) FileInfo { - return FileInfo{ - Name: name, - Hash: hash, - Algorithm: alg, - } + return v1.NewFile(name, hash, alg) } // Manifest represents the manifest from a bundle. The manifest may contain // metadata such as the bundle revision. -type Manifest struct { - Revision string `json:"revision"` - Roots *[]string `json:"roots,omitempty"` - WasmResolvers []WasmResolver `json:"wasm,omitempty"` - // RegoVersion is the global Rego version for the bundle described by this Manifest. - // The Rego version of individual files can be overridden in FileRegoVersions. - // We don't use ast.RegoVersion here, as this iota type's order isn't guaranteed to be stable over time. - // We use a pointer so that we can support hand-made bundles that don't have an explicit version appropriately. - // E.g. in OPA 0.x if --v1-compatible is used when consuming the bundle, and there is no specified version, - // we should default to v1; if --v1-compatible isn't used, we should default to v0. In OPA 1.0, no --x-compatible - // flag and no explicit bundle version should default to v1. - RegoVersion *int `json:"rego_version,omitempty"` - // FileRegoVersions is a map from file paths to Rego versions. - // This allows individual files to override the global Rego version specified by RegoVersion. - FileRegoVersions map[string]int `json:"file_rego_versions,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - - compiledFileRegoVersions []fileRegoVersion -} - -type fileRegoVersion struct { - path glob.Glob - version int -} +type Manifest = v1.Manifest // WasmResolver maps a wasm module to an entrypoint ref. -type WasmResolver struct { - Entrypoint string `json:"entrypoint,omitempty"` - Module string `json:"module,omitempty"` - Annotations []*ast.Annotations `json:"annotations,omitempty"` -} - -// Init initializes the manifest. If you instantiate a manifest -// manually, call Init to ensure that the roots are set properly. -func (m *Manifest) Init() { - if m.Roots == nil { - defaultRoots := []string{""} - m.Roots = &defaultRoots - } -} - -// AddRoot adds r to the roots of m. This function is idempotent. -func (m *Manifest) AddRoot(r string) { - m.Init() - if !RootPathsContain(*m.Roots, r) { - *m.Roots = append(*m.Roots, r) - } -} - -func (m *Manifest) SetRegoVersion(v ast.RegoVersion) { - m.Init() - regoVersion := 0 - if v == ast.RegoV1 { - regoVersion = 1 - } - m.RegoVersion = ®oVersion -} - -// Equal returns true if m is semantically equivalent to other. -func (m Manifest) Equal(other Manifest) bool { - - // This is safe since both are passed by value. - m.Init() - other.Init() - - if m.Revision != other.Revision { - return false - } - - if m.RegoVersion == nil && other.RegoVersion != nil { - return false - } - if m.RegoVersion != nil && other.RegoVersion == nil { - return false - } - if m.RegoVersion != nil && other.RegoVersion != nil && *m.RegoVersion != *other.RegoVersion { - return false - } - - // If both are nil, or both are empty, we consider them equal. - if !(len(m.FileRegoVersions) == 0 && len(other.FileRegoVersions) == 0) && - !reflect.DeepEqual(m.FileRegoVersions, other.FileRegoVersions) { - return false - } - - if !reflect.DeepEqual(m.Metadata, other.Metadata) { - return false - } - - return m.equalWasmResolversAndRoots(other) -} - -func (m Manifest) Empty() bool { - return m.Equal(Manifest{}) -} - -// Copy returns a deep copy of the manifest. -func (m Manifest) Copy() Manifest { - m.Init() - roots := make([]string, len(*m.Roots)) - copy(roots, *m.Roots) - m.Roots = &roots - - wasmModules := make([]WasmResolver, len(m.WasmResolvers)) - copy(wasmModules, m.WasmResolvers) - m.WasmResolvers = wasmModules - - metadata := m.Metadata - - if metadata != nil { - m.Metadata = make(map[string]interface{}) - for k, v := range metadata { - m.Metadata[k] = v - } - } - - return m -} - -func (m Manifest) String() string { - m.Init() - if m.RegoVersion != nil { - return fmt.Sprintf("", - m.Revision, *m.RegoVersion, *m.Roots, m.WasmResolvers, m.Metadata) - } - return fmt.Sprintf("", - m.Revision, *m.Roots, m.WasmResolvers, m.Metadata) -} - -func (m Manifest) rootSet() stringSet { - rs := map[string]struct{}{} - - for _, r := range *m.Roots { - rs[r] = struct{}{} - } - - return stringSet(rs) -} - -func (m Manifest) equalWasmResolversAndRoots(other Manifest) bool { - if len(m.WasmResolvers) != len(other.WasmResolvers) { - return false - } - - for i := 0; i < len(m.WasmResolvers); i++ { - if !m.WasmResolvers[i].Equal(&other.WasmResolvers[i]) { - return false - } - } - - return m.rootSet().Equal(other.rootSet()) -} - -func (wr *WasmResolver) Equal(other *WasmResolver) bool { - if wr == nil && other == nil { - return true - } - - if wr == nil || other == nil { - return false - } - - if wr.Module != other.Module { - return false - } - - if wr.Entrypoint != other.Entrypoint { - return false - } - - annotLen := len(wr.Annotations) - if annotLen != len(other.Annotations) { - return false - } - - for i := 0; i < annotLen; i++ { - if wr.Annotations[i].Compare(other.Annotations[i]) != 0 { - return false - } - } - - return true -} - -type stringSet map[string]struct{} - -func (ss stringSet) Equal(other stringSet) bool { - if len(ss) != len(other) { - return false - } - for k := range other { - if _, ok := ss[k]; !ok { - return false - } - } - return true -} - -func (m *Manifest) validateAndInjectDefaults(b Bundle) error { - - m.Init() - - // Validate roots in bundle. - roots := *m.Roots - - // Standardize the roots (no starting or trailing slash) - for i := range roots { - roots[i] = strings.Trim(roots[i], "/") - } - - for i := 0; i < len(roots)-1; i++ { - for j := i + 1; j < len(roots); j++ { - if RootPathsOverlap(roots[i], roots[j]) { - return fmt.Errorf("manifest has overlapped roots: '%v' and '%v'", roots[i], roots[j]) - } - } - } - - // Validate modules in bundle. - for _, module := range b.Modules { - found := false - if path, err := module.Parsed.Package.Path.Ptr(); err == nil { - found = RootPathsContain(roots, path) - } - if !found { - return fmt.Errorf("manifest roots %v do not permit '%v' in module '%v'", roots, module.Parsed.Package, module.Path) - } - } - - // Build a set of wasm module entrypoints to validate - wasmModuleToEps := map[string]string{} - seenEps := map[string]struct{}{} - for _, wm := range b.WasmModules { - wasmModuleToEps[wm.Path] = "" - } - - for _, wmConfig := range b.Manifest.WasmResolvers { - _, ok := wasmModuleToEps[wmConfig.Module] - if !ok { - return fmt.Errorf("manifest references wasm module '%s' but the module file does not exist", wmConfig.Module) - } - - // Ensure wasm module entrypoint in within bundle roots - if !RootPathsContain(roots, wmConfig.Entrypoint) { - return fmt.Errorf("manifest roots %v do not permit '%v' entrypoint for wasm module '%v'", roots, wmConfig.Entrypoint, wmConfig.Module) - } - - if _, ok := seenEps[wmConfig.Entrypoint]; ok { - return fmt.Errorf("entrypoint '%s' cannot be used by more than one wasm module", wmConfig.Entrypoint) - } - seenEps[wmConfig.Entrypoint] = struct{}{} - - wasmModuleToEps[wmConfig.Module] = wmConfig.Entrypoint - } - - // Validate data patches in bundle. - for _, patch := range b.Patch.Data { - path := strings.Trim(patch.Path, "/") - if !RootPathsContain(roots, path) { - return fmt.Errorf("manifest roots %v do not permit data patch at path '%s'", roots, path) - } - } - - if b.lazyLoadingMode { - return nil - } - - // Validate data in bundle. - return dfs(b.Data, "", func(path string, node interface{}) (bool, error) { - path = strings.Trim(path, "/") - if RootPathsContain(roots, path) { - return true, nil - } - - if _, ok := node.(map[string]interface{}); ok { - for i := range roots { - if RootPathsContain(strings.Split(path, "/"), roots[i]) { - return false, nil - } - } - } - return false, fmt.Errorf("manifest roots %v do not permit data at path '/%s' (hint: check bundle directory structure)", roots, path) - }) -} +type WasmResolver = v1.WasmResolver // ModuleFile represents a single module contained in a bundle. -type ModuleFile struct { - URL string - Path string - RelativePath string - Raw []byte - Parsed *ast.Module -} +type ModuleFile = v1.ModuleFile // WasmModuleFile represents a single wasm module contained in a bundle. -type WasmModuleFile struct { - URL string - Path string - Entrypoints []ast.Ref - Raw []byte -} +type WasmModuleFile = v1.WasmModuleFile // PlanModuleFile represents a single plan module contained in a bundle. // // NOTE(tsandall): currently the plans are just opaque binary blobs. In the // future we could inject the entrypoints so that the plans could be executed // inside of OPA proper like we do for Wasm modules. -type PlanModuleFile struct { - URL string - Path string - Raw []byte -} +type PlanModuleFile = v1.PlanModuleFile // Reader contains the reader to load the bundle from. -type Reader struct { - loader DirectoryLoader - includeManifestInData bool - metrics metrics.Metrics - baseDir string - verificationConfig *VerificationConfig - skipVerify bool - processAnnotations bool - jsonOptions *astJSON.Options - capabilities *ast.Capabilities - files map[string]FileInfo // files in the bundle signature payload - sizeLimitBytes int64 - etag string - lazyLoadingMode bool - name string - persist bool - regoVersion ast.RegoVersion - followSymlinks bool -} +type Reader = v1.Reader // NewReader is deprecated. Use NewCustomReader instead. func NewReader(r io.Reader) *Reader { - return NewCustomReader(NewTarballLoader(r)) + return v1.NewReader(r).WithRegoVersion(ast.DefaultRegoVersion) } // NewCustomReader returns a new Reader configured to use the // specified DirectoryLoader. func NewCustomReader(loader DirectoryLoader) *Reader { - nr := Reader{ - loader: loader, - metrics: metrics.New(), - files: make(map[string]FileInfo), - sizeLimitBytes: DefaultSizeLimitBytes + 1, - } - return &nr -} - -// IncludeManifestInData sets whether the manifest metadata should be -// included in the bundle's data. -func (r *Reader) IncludeManifestInData(includeManifestInData bool) *Reader { - r.includeManifestInData = includeManifestInData - return r -} - -// WithMetrics sets the metrics object to be used while loading bundles -func (r *Reader) WithMetrics(m metrics.Metrics) *Reader { - r.metrics = m - return r -} - -// WithBaseDir sets a base directory for file paths of loaded Rego -// modules. This will *NOT* affect the loaded path of data files. -func (r *Reader) WithBaseDir(dir string) *Reader { - r.baseDir = dir - return r -} - -// WithBundleVerificationConfig sets the key configuration used to verify a signed bundle -func (r *Reader) WithBundleVerificationConfig(config *VerificationConfig) *Reader { - r.verificationConfig = config - return r -} - -// WithSkipBundleVerification skips verification of a signed bundle -func (r *Reader) WithSkipBundleVerification(skipVerify bool) *Reader { - r.skipVerify = skipVerify - return r -} - -// WithProcessAnnotations enables annotation processing during .rego file parsing. -func (r *Reader) WithProcessAnnotations(yes bool) *Reader { - r.processAnnotations = yes - return r -} - -// WithCapabilities sets the supported capabilities when loading the files -func (r *Reader) WithCapabilities(caps *ast.Capabilities) *Reader { - r.capabilities = caps - return r -} - -// WithJSONOptions sets the JSONOptions to use when parsing policy files -func (r *Reader) WithJSONOptions(opts *astJSON.Options) *Reader { - r.jsonOptions = opts - return r -} - -// WithSizeLimitBytes sets the size limit to apply to files in the bundle. If files are larger -// than this, an error will be returned by the reader. -func (r *Reader) WithSizeLimitBytes(n int64) *Reader { - r.sizeLimitBytes = n + 1 - return r -} - -// WithBundleEtag sets the given etag value on the bundle -func (r *Reader) WithBundleEtag(etag string) *Reader { - r.etag = etag - return r -} - -// WithBundleName specifies the bundle name -func (r *Reader) WithBundleName(name string) *Reader { - r.name = name - return r -} - -func (r *Reader) WithFollowSymlinks(yes bool) *Reader { - r.followSymlinks = yes - return r -} - -// WithLazyLoadingMode sets the bundle loading mode. If true, -// bundles will be read in lazy mode. In this mode, data files in the bundle will not be -// deserialized and the check to validate that the bundle data does not contain paths -// outside the bundle's roots will not be performed while reading the bundle. -func (r *Reader) WithLazyLoadingMode(yes bool) *Reader { - r.lazyLoadingMode = yes - return r -} - -// WithBundlePersistence specifies if the downloaded bundle will eventually be persisted to disk. -func (r *Reader) WithBundlePersistence(persist bool) *Reader { - r.persist = persist - return r -} - -func (r *Reader) WithRegoVersion(version ast.RegoVersion) *Reader { - r.regoVersion = version - return r -} - -func (r *Reader) ParserOptions() ast.ParserOptions { - return ast.ParserOptions{ - ProcessAnnotation: r.processAnnotations, - Capabilities: r.capabilities, - JSONOptions: r.jsonOptions, - RegoVersion: r.regoVersion, - } -} - -// Read returns a new Bundle loaded from the reader. -func (r *Reader) Read() (Bundle, error) { - - var bundle Bundle - var descriptors []*Descriptor - var err error - var raw []Raw - - bundle.Signatures, bundle.Patch, descriptors, err = preProcessBundle(r.loader, r.skipVerify, r.sizeLimitBytes) - if err != nil { - return bundle, err - } - - bundle.lazyLoadingMode = r.lazyLoadingMode - bundle.sizeLimitBytes = r.sizeLimitBytes - - if bundle.Type() == SnapshotBundleType { - err = r.checkSignaturesAndDescriptors(bundle.Signatures) - if err != nil { - return bundle, err - } - - bundle.Data = map[string]interface{}{} - } - - var modules []ModuleFile - for _, f := range descriptors { - buf, err := readFile(f, r.sizeLimitBytes) - if err != nil { - return bundle, err - } - - // verify the file content - if bundle.Type() == SnapshotBundleType && !bundle.Signatures.isEmpty() { - path := f.Path() - if r.baseDir != "" { - path = f.URL() - } - path = strings.TrimPrefix(path, "/") - - // check if the file is to be excluded from bundle verification - if r.isFileExcluded(path) { - delete(r.files, path) - } else { - if err = r.verifyBundleFile(path, buf); err != nil { - return bundle, err - } - } - } - - // Normalize the paths to use `/` separators - path := filepath.ToSlash(f.Path()) - - if strings.HasSuffix(path, RegoExt) { - fullPath := r.fullPath(path) - bs := buf.Bytes() - - if r.lazyLoadingMode { - p := fullPath - if r.name != "" { - p = modulePathWithPrefix(r.name, fullPath) - } - - raw = append(raw, Raw{Path: p, Value: bs}) - } - - // Modules are parsed after we've had a chance to read the manifest - mf := ModuleFile{ - URL: f.URL(), - Path: fullPath, - RelativePath: path, - Raw: bs, - } - modules = append(modules, mf) - } else if filepath.Base(path) == WasmFile { - bundle.WasmModules = append(bundle.WasmModules, WasmModuleFile{ - URL: f.URL(), - Path: r.fullPath(path), - Raw: buf.Bytes(), - }) - } else if filepath.Base(path) == PlanFile { - bundle.PlanModules = append(bundle.PlanModules, PlanModuleFile{ - URL: f.URL(), - Path: r.fullPath(path), - Raw: buf.Bytes(), - }) - } else if filepath.Base(path) == dataFile { - if r.lazyLoadingMode { - raw = append(raw, Raw{Path: path, Value: buf.Bytes()}) - continue - } - - var value interface{} - - r.metrics.Timer(metrics.RegoDataParse).Start() - err := util.UnmarshalJSON(buf.Bytes(), &value) - r.metrics.Timer(metrics.RegoDataParse).Stop() - - if err != nil { - return bundle, fmt.Errorf("bundle load failed on %v: %w", r.fullPath(path), err) - } - - if err := insertValue(&bundle, path, value); err != nil { - return bundle, err - } - - } else if filepath.Base(path) == yamlDataFile || filepath.Base(path) == ymlDataFile { - if r.lazyLoadingMode { - raw = append(raw, Raw{Path: path, Value: buf.Bytes()}) - continue - } - - var value interface{} - - r.metrics.Timer(metrics.RegoDataParse).Start() - err := util.Unmarshal(buf.Bytes(), &value) - r.metrics.Timer(metrics.RegoDataParse).Stop() - - if err != nil { - return bundle, fmt.Errorf("bundle load failed on %v: %w", r.fullPath(path), err) - } - - if err := insertValue(&bundle, path, value); err != nil { - return bundle, err - } - - } else if strings.HasSuffix(path, ManifestExt) { - if err := util.NewJSONDecoder(&buf).Decode(&bundle.Manifest); err != nil { - return bundle, fmt.Errorf("bundle load failed on manifest decode: %w", err) - } - } - } - - // Parse modules - popts := r.ParserOptions() - popts.RegoVersion = bundle.RegoVersion(popts.RegoVersion) - for _, mf := range modules { - modulePopts := popts - if modulePopts.RegoVersion, err = bundle.RegoVersionForFile(mf.RelativePath, popts.RegoVersion); err != nil { - return bundle, err - } - r.metrics.Timer(metrics.RegoModuleParse).Start() - mf.Parsed, err = ast.ParseModuleWithOpts(mf.Path, string(mf.Raw), modulePopts) - r.metrics.Timer(metrics.RegoModuleParse).Stop() - if err != nil { - return bundle, err - } - bundle.Modules = append(bundle.Modules, mf) - } - - if bundle.Type() == DeltaBundleType { - if len(bundle.Data) != 0 { - return bundle, fmt.Errorf("delta bundle expected to contain only patch file but data files found") - } - - if len(bundle.Modules) != 0 { - return bundle, fmt.Errorf("delta bundle expected to contain only patch file but policy files found") - } - - if len(bundle.WasmModules) != 0 { - return bundle, fmt.Errorf("delta bundle expected to contain only patch file but wasm files found") - } - - if r.persist { - return bundle, fmt.Errorf("'persist' property is true in config. persisting delta bundle to disk is not supported") - } - } - - // check if the bundle signatures specify any files that weren't found in the bundle - if bundle.Type() == SnapshotBundleType && len(r.files) != 0 { - extra := []string{} - for k := range r.files { - extra = append(extra, k) - } - return bundle, fmt.Errorf("file(s) %v specified in bundle signatures but not found in the target bundle", extra) - } - - if err := bundle.Manifest.validateAndInjectDefaults(bundle); err != nil { - return bundle, err - } - - // Inject the wasm module entrypoint refs into the WasmModuleFile structs - epMap := map[string][]string{} - for _, r := range bundle.Manifest.WasmResolvers { - epMap[r.Module] = append(epMap[r.Module], r.Entrypoint) - } - for i := 0; i < len(bundle.WasmModules); i++ { - entrypoints := epMap[bundle.WasmModules[i].Path] - for _, entrypoint := range entrypoints { - ref, err := ast.PtrRef(ast.DefaultRootDocument, entrypoint) - if err != nil { - return bundle, fmt.Errorf("failed to parse wasm module entrypoint '%s': %s", entrypoint, err) - } - bundle.WasmModules[i].Entrypoints = append(bundle.WasmModules[i].Entrypoints, ref) - } - } - - if r.includeManifestInData { - var metadata map[string]interface{} - - b, err := json.Marshal(&bundle.Manifest) - if err != nil { - return bundle, fmt.Errorf("bundle load failed on manifest marshal: %w", err) - } - - err = util.UnmarshalJSON(b, &metadata) - if err != nil { - return bundle, fmt.Errorf("bundle load failed on manifest unmarshal: %w", err) - } - - // For backwards compatibility always write to the old unnamed manifest path - // This will *not* be correct if >1 bundle is in use... - if err := bundle.insertData(legacyManifestStoragePath, metadata); err != nil { - return bundle, fmt.Errorf("bundle load failed on %v: %w", legacyRevisionStoragePath, err) - } - } - - bundle.Etag = r.etag - bundle.Raw = raw - - return bundle, nil -} - -func (r *Reader) isFileExcluded(path string) bool { - for _, e := range r.verificationConfig.Exclude { - match, _ := filepath.Match(e, path) - if match { - return true - } - } - return false -} - -func (r *Reader) checkSignaturesAndDescriptors(signatures SignaturesConfig) error { - if r.skipVerify { - return nil - } - - if signatures.isEmpty() && r.verificationConfig != nil && r.verificationConfig.KeyID != "" { - return fmt.Errorf("bundle missing .signatures.json file") - } - - if !signatures.isEmpty() { - if r.verificationConfig == nil { - return fmt.Errorf("verification key not provided") - } - - // verify the JWT signatures included in the `.signatures.json` file - if err := r.verifyBundleSignature(signatures); err != nil { - return err - } - } - return nil -} - -func (r *Reader) verifyBundleSignature(sc SignaturesConfig) error { - var err error - r.files, err = VerifyBundleSignature(sc, r.verificationConfig) - return err -} - -func (r *Reader) verifyBundleFile(path string, data bytes.Buffer) error { - return VerifyBundleFile(path, data, r.files) -} - -func (r *Reader) fullPath(path string) string { - if r.baseDir != "" { - path = filepath.Join(r.baseDir, path) - } - return path + return v1.NewCustomReader(loader).WithRegoVersion(ast.DefaultRegoVersion) } // Write is deprecated. Use NewWriter instead. func Write(w io.Writer, bundle Bundle) error { - return NewWriter(w). - UseModulePath(true). - DisableFormat(true). - Write(bundle) + return v1.Write(w, bundle) } // Writer implements bundle serialization. -type Writer struct { - usePath bool - disableFormat bool - w io.Writer -} +type Writer = v1.Writer // NewWriter returns a bundle writer that writes to w. func NewWriter(w io.Writer) *Writer { - return &Writer{ - w: w, - } -} - -// UseModulePath configures the writer to use the module file path instead of the -// module file URL during serialization. This is for backwards compatibility. -func (w *Writer) UseModulePath(yes bool) *Writer { - w.usePath = yes - return w -} - -// DisableFormat configures the writer to just write out raw bytes instead -// of formatting modules before serialization. -func (w *Writer) DisableFormat(yes bool) *Writer { - w.disableFormat = yes - return w -} - -// Write writes the bundle to the writer's output stream. -func (w *Writer) Write(bundle Bundle) error { - gw := gzip.NewWriter(w.w) - tw := tar.NewWriter(gw) - - bundleType := bundle.Type() - - if bundleType == SnapshotBundleType { - var buf bytes.Buffer - - if err := json.NewEncoder(&buf).Encode(bundle.Data); err != nil { - return err - } - - if err := archive.WriteFile(tw, "data.json", buf.Bytes()); err != nil { - return err - } - - for _, module := range bundle.Modules { - path := module.URL - if w.usePath { - path = module.Path - } - - if err := archive.WriteFile(tw, path, module.Raw); err != nil { - return err - } - } - - if err := w.writeWasm(tw, bundle); err != nil { - return err - } - - if err := writeSignatures(tw, bundle); err != nil { - return err - } - - if err := w.writePlan(tw, bundle); err != nil { - return err - } - } else if bundleType == DeltaBundleType { - if err := writePatch(tw, bundle); err != nil { - return err - } - } - - if err := writeManifest(tw, bundle); err != nil { - return err - } - - if err := tw.Close(); err != nil { - return err - } - - return gw.Close() -} - -func (w *Writer) writeWasm(tw *tar.Writer, bundle Bundle) error { - for _, wm := range bundle.WasmModules { - path := wm.URL - if w.usePath { - path = wm.Path - } - - err := archive.WriteFile(tw, path, wm.Raw) - if err != nil { - return err - } - } - - if len(bundle.Wasm) > 0 { - err := archive.WriteFile(tw, "/"+WasmFile, bundle.Wasm) - if err != nil { - return err - } - } - - return nil -} - -func (w *Writer) writePlan(tw *tar.Writer, bundle Bundle) error { - for _, wm := range bundle.PlanModules { - path := wm.URL - if w.usePath { - path = wm.Path - } - - err := archive.WriteFile(tw, path, wm.Raw) - if err != nil { - return err - } - } - - return nil -} - -func writeManifest(tw *tar.Writer, bundle Bundle) error { - - if bundle.Manifest.Empty() { - return nil - } - - var buf bytes.Buffer - - if err := json.NewEncoder(&buf).Encode(bundle.Manifest); err != nil { - return err - } - - return archive.WriteFile(tw, ManifestExt, buf.Bytes()) -} - -func writePatch(tw *tar.Writer, bundle Bundle) error { - - var buf bytes.Buffer - - if err := json.NewEncoder(&buf).Encode(bundle.Patch); err != nil { - return err - } - - return archive.WriteFile(tw, patchFile, buf.Bytes()) -} - -func writeSignatures(tw *tar.Writer, bundle Bundle) error { - - if bundle.Signatures.isEmpty() { - return nil - } - - bs, err := json.MarshalIndent(bundle.Signatures, "", " ") - if err != nil { - return err - } - - return archive.WriteFile(tw, fmt.Sprintf(".%v", SignaturesFile), bs) -} - -func hashBundleFiles(hash SignatureHasher, b *Bundle) ([]FileInfo, error) { - - files := []FileInfo{} - - bs, err := hash.HashFile(b.Data) - if err != nil { - return files, err - } - files = append(files, NewFile(strings.TrimPrefix("data.json", "/"), hex.EncodeToString(bs), defaultHashingAlg)) - - if len(b.Wasm) != 0 { - bs, err := hash.HashFile(b.Wasm) - if err != nil { - return files, err - } - files = append(files, NewFile(strings.TrimPrefix(WasmFile, "/"), hex.EncodeToString(bs), defaultHashingAlg)) - } - - for _, wasmModule := range b.WasmModules { - bs, err := hash.HashFile(wasmModule.Raw) - if err != nil { - return files, err - } - files = append(files, NewFile(strings.TrimPrefix(wasmModule.Path, "/"), hex.EncodeToString(bs), defaultHashingAlg)) - } - - for _, planmodule := range b.PlanModules { - bs, err := hash.HashFile(planmodule.Raw) - if err != nil { - return files, err - } - files = append(files, NewFile(strings.TrimPrefix(planmodule.Path, "/"), hex.EncodeToString(bs), defaultHashingAlg)) - } - - // If the manifest is essentially empty, don't add it to the signatures since it - // won't be written to the bundle. Otherwise: - // parse the manifest into a JSON structure; - // then recursively order the fields of all objects alphabetically and then apply - // the hash function to result to compute the hash. - if !b.Manifest.Empty() { - mbs, err := json.Marshal(b.Manifest) - if err != nil { - return files, err - } - - var result map[string]interface{} - if err := util.Unmarshal(mbs, &result); err != nil { - return files, err - } - - bs, err = hash.HashFile(result) - if err != nil { - return files, err - } - - files = append(files, NewFile(strings.TrimPrefix(ManifestExt, "/"), hex.EncodeToString(bs), defaultHashingAlg)) - } - - return files, err -} - -// FormatModules formats Rego modules -// Modules will be formatted to comply with rego-v0, but Rego compatibility of individual parsed modules will be respected (e.g. if 'rego.v1' is imported). -func (b *Bundle) FormatModules(useModulePath bool) error { - return b.FormatModulesForRegoVersion(ast.RegoV0, true, useModulePath) -} - -// FormatModulesForRegoVersion formats Rego modules to comply with a given Rego version -func (b *Bundle) FormatModulesForRegoVersion(version ast.RegoVersion, preserveModuleRegoVersion bool, useModulePath bool) error { - var err error - - for i, module := range b.Modules { - opts := format.Opts{} - if preserveModuleRegoVersion { - opts.RegoVersion = module.Parsed.RegoVersion() - opts.ParserOptions = &ast.ParserOptions{ - RegoVersion: opts.RegoVersion, - } - } else { - opts.RegoVersion = version - } - - if module.Raw == nil { - module.Raw, err = format.AstWithOpts(module.Parsed, opts) - if err != nil { - return err - } - } else { - path := module.URL - if useModulePath { - path = module.Path - } - - module.Raw, err = format.SourceWithOpts(path, module.Raw, opts) - if err != nil { - return err - } - } - b.Modules[i].Raw = module.Raw - } - return nil -} - -// GenerateSignature generates the signature for the given bundle. -func (b *Bundle) GenerateSignature(signingConfig *SigningConfig, keyID string, useModulePath bool) error { - - hash, err := NewSignatureHasher(HashingAlgorithm(defaultHashingAlg)) - if err != nil { - return err - } - - files := []FileInfo{} - - for _, module := range b.Modules { - bytes, err := hash.HashFile(module.Raw) - if err != nil { - return err - } - - path := module.URL - if useModulePath { - path = module.Path - } - files = append(files, NewFile(strings.TrimPrefix(path, "/"), hex.EncodeToString(bytes), defaultHashingAlg)) - } - - result, err := hashBundleFiles(hash, b) - if err != nil { - return err - } - files = append(files, result...) - - // generate signed token - token, err := GenerateSignedToken(files, signingConfig, keyID) - if err != nil { - return err - } - - if b.Signatures.isEmpty() { - b.Signatures = SignaturesConfig{} - } - - if signingConfig.Plugin != "" { - b.Signatures.Plugin = signingConfig.Plugin - } - - b.Signatures.Signatures = []string{token} - - return nil -} - -// ParsedModules returns a map of parsed modules with names that are -// unique and human readable for the given a bundle name. -func (b *Bundle) ParsedModules(bundleName string) map[string]*ast.Module { - - mods := make(map[string]*ast.Module, len(b.Modules)) - - for _, mf := range b.Modules { - mods[modulePathWithPrefix(bundleName, mf.Path)] = mf.Parsed - } - - return mods -} - -func (b *Bundle) RegoVersion(def ast.RegoVersion) ast.RegoVersion { - if v := b.Manifest.RegoVersion; v != nil { - if *v == 0 { - return ast.RegoV0 - } else if *v == 1 { - return ast.RegoV1 - } - } - return def -} - -func (b *Bundle) SetRegoVersion(v ast.RegoVersion) { - b.Manifest.SetRegoVersion(v) -} - -// RegoVersionForFile returns the rego-version for the specified file path. -// If there is no defined version for the given path, the default version def is returned. -// If the version does not correspond to ast.RegoV0 or ast.RegoV1, an error is returned. -func (b *Bundle) RegoVersionForFile(path string, def ast.RegoVersion) (ast.RegoVersion, error) { - version, err := b.Manifest.numericRegoVersionForFile(path) - if err != nil { - return def, err - } else if version == nil { - return def, nil - } else if *version == 0 { - return ast.RegoV0, nil - } else if *version == 1 { - return ast.RegoV1, nil - } - return def, fmt.Errorf("unknown bundle rego-version %d for file '%s'", *version, path) -} - -func (m *Manifest) numericRegoVersionForFile(path string) (*int, error) { - var version *int - - if len(m.FileRegoVersions) != len(m.compiledFileRegoVersions) { - m.compiledFileRegoVersions = make([]fileRegoVersion, 0, len(m.FileRegoVersions)) - for pattern, v := range m.FileRegoVersions { - compiled, err := glob.Compile(pattern) - if err != nil { - return nil, fmt.Errorf("failed to compile glob pattern %s: %s", pattern, err) - } - m.compiledFileRegoVersions = append(m.compiledFileRegoVersions, fileRegoVersion{compiled, v}) - } - } - - for _, fv := range m.compiledFileRegoVersions { - if fv.path.Match(path) { - version = &fv.version - break - } - } - - if version == nil { - version = m.RegoVersion - } - return version, nil -} - -// Equal returns true if this bundle's contents equal the other bundle's -// contents. -func (b Bundle) Equal(other Bundle) bool { - if !reflect.DeepEqual(b.Data, other.Data) { - return false - } - - if len(b.Modules) != len(other.Modules) { - return false - } - for i := range b.Modules { - // To support bundles built from rootless filesystems we ignore a "/" prefix - // for URLs and Paths, such that "/file" and "file" are equivalent - if strings.TrimPrefix(b.Modules[i].URL, string(filepath.Separator)) != - strings.TrimPrefix(other.Modules[i].URL, string(filepath.Separator)) { - return false - } - if strings.TrimPrefix(b.Modules[i].Path, string(filepath.Separator)) != - strings.TrimPrefix(other.Modules[i].Path, string(filepath.Separator)) { - return false - } - if !b.Modules[i].Parsed.Equal(other.Modules[i].Parsed) { - return false - } - if !bytes.Equal(b.Modules[i].Raw, other.Modules[i].Raw) { - return false - } - } - if (b.Wasm == nil && other.Wasm != nil) || (b.Wasm != nil && other.Wasm == nil) { - return false - } - - return bytes.Equal(b.Wasm, other.Wasm) -} - -// Copy returns a deep copy of the bundle. -func (b Bundle) Copy() Bundle { - - // Copy data. - var x interface{} = b.Data - - if err := util.RoundTrip(&x); err != nil { - panic(err) - } - - if x != nil { - b.Data = x.(map[string]interface{}) - } - - // Copy modules. - for i := range b.Modules { - bs := make([]byte, len(b.Modules[i].Raw)) - copy(bs, b.Modules[i].Raw) - b.Modules[i].Raw = bs - b.Modules[i].Parsed = b.Modules[i].Parsed.Copy() - } - - // Copy manifest. - b.Manifest = b.Manifest.Copy() - - return b -} - -func (b *Bundle) insertData(key []string, value interface{}) error { - // Build an object with the full structure for the value - obj, err := mktree(key, value) - if err != nil { - return err - } - - // Merge the new data in with the current bundle data object - merged, ok := merge.InterfaceMaps(b.Data, obj) - if !ok { - return fmt.Errorf("failed to insert data file from path %s", filepath.Join(key...)) - } - - b.Data = merged - - return nil -} - -func (b *Bundle) readData(key []string) *interface{} { - - if len(key) == 0 { - if len(b.Data) == 0 { - return nil - } - var result interface{} = b.Data - return &result - } - - node := b.Data - - for i := 0; i < len(key)-1; i++ { - - child, ok := node[key[i]] - if !ok { - return nil - } - - childObj, ok := child.(map[string]interface{}) - if !ok { - return nil - } - - node = childObj - } - - child, ok := node[key[len(key)-1]] - if !ok { - return nil - } - - return &child -} - -// Type returns the type of the bundle. -func (b *Bundle) Type() string { - if len(b.Patch.Data) != 0 { - return DeltaBundleType - } - return SnapshotBundleType -} - -func mktree(path []string, value interface{}) (map[string]interface{}, error) { - if len(path) == 0 { - // For 0 length path the value is the full tree. - obj, ok := value.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("root value must be object") - } - return obj, nil - } - - dir := map[string]interface{}{} - for i := len(path) - 1; i > 0; i-- { - dir[path[i]] = value - value = dir - dir = map[string]interface{}{} - } - dir[path[0]] = value - - return dir, nil + return v1.NewWriter(w) } // Merge accepts a set of bundles and merges them into a single result bundle. If there are @@ -1393,7 +104,7 @@ func mktree(path []string, value interface{}) (map[string]interface{}, error) { // will have an empty revision except in the special case where a single bundle is provided // (and in that case the bundle is just returned unmodified.) func Merge(bundles []*Bundle) (*Bundle, error) { - return MergeWithRegoVersion(bundles, ast.RegoV0, false) + return MergeWithRegoVersion(bundles, ast.DefaultRegoVersion, false) } // MergeWithRegoVersion creates a merged bundle from the provided bundles, similar to Merge. @@ -1405,348 +116,19 @@ func Merge(bundles []*Bundle) (*Bundle, error) { // If usePath is true, per-file rego-versions will be calculated using the file's ModuleFile.Path; otherwise, the file's // ModuleFile.URL will be used. func MergeWithRegoVersion(bundles []*Bundle, regoVersion ast.RegoVersion, usePath bool) (*Bundle, error) { - - if len(bundles) == 0 { - return nil, errors.New("expected at least one bundle") - } - - if len(bundles) == 1 { - result := bundles[0] - // We respect the bundle rego-version, defaulting to the provided rego version if not set. - result.SetRegoVersion(result.RegoVersion(regoVersion)) - fileRegoVersions, err := bundleRegoVersions(result, result.RegoVersion(regoVersion), usePath) - if err != nil { - return nil, err - } - result.Manifest.FileRegoVersions = fileRegoVersions - return result, nil + if regoVersion == ast.RegoUndefined { + regoVersion = ast.DefaultRegoVersion } - var roots []string - var result Bundle - - for _, b := range bundles { - - if b.Manifest.Roots == nil { - return nil, errors.New("bundle manifest not initialized") - } - - roots = append(roots, *b.Manifest.Roots...) - - result.Modules = append(result.Modules, b.Modules...) - - for _, root := range *b.Manifest.Roots { - key := strings.Split(root, "/") - if val := b.readData(key); val != nil { - if err := result.insertData(key, *val); err != nil { - return nil, err - } - } - } - - result.Manifest.WasmResolvers = append(result.Manifest.WasmResolvers, b.Manifest.WasmResolvers...) - result.WasmModules = append(result.WasmModules, b.WasmModules...) - result.PlanModules = append(result.PlanModules, b.PlanModules...) - - if b.Manifest.RegoVersion != nil || len(b.Manifest.FileRegoVersions) > 0 { - if result.Manifest.FileRegoVersions == nil { - result.Manifest.FileRegoVersions = map[string]int{} - } - - fileRegoVersions, err := bundleRegoVersions(b, regoVersion, usePath) - if err != nil { - return nil, err - } - for k, v := range fileRegoVersions { - result.Manifest.FileRegoVersions[k] = v - } - } - } - - // We respect the bundle rego-version, defaulting to the provided rego version if not set. - result.SetRegoVersion(result.RegoVersion(regoVersion)) - - if result.Data == nil { - result.Data = map[string]interface{}{} - } - - result.Manifest.Roots = &roots - - if err := result.Manifest.validateAndInjectDefaults(result); err != nil { - return nil, err - } - - return &result, nil -} - -func bundleRegoVersions(bundle *Bundle, regoVersion ast.RegoVersion, usePath bool) (map[string]int, error) { - fileRegoVersions := map[string]int{} - - // we drop the bundle-global rego versions and record individual rego versions for each module. - for _, m := range bundle.Modules { - // We fetch rego-version by the path relative to the bundle root, as the complete path of the module might - // contain the path between OPA working directory and the bundle root. - v, err := bundle.RegoVersionForFile(bundleRelativePath(m, usePath), bundle.RegoVersion(regoVersion)) - if err != nil { - return nil, err - } - // only record the rego version if it's different from one applied globally to the result bundle - if v != regoVersion { - // We store the rego version by the absolute path to the bundle root, as this will be the - possibly new - path - // to the module inside the merged bundle. - fileRegoVersions[bundleAbsolutePath(m, usePath)] = v.Int() - } - } - - return fileRegoVersions, nil -} - -func bundleRelativePath(m ModuleFile, usePath bool) string { - p := m.RelativePath - if p == "" { - if usePath { - p = m.Path - } else { - p = m.URL - } - } - return p -} - -func bundleAbsolutePath(m ModuleFile, usePath bool) string { - var p string - if usePath { - p = m.Path - } else { - p = m.URL - } - if !path.IsAbs(p) { - p = "/" + p - } - return path.Clean(p) + return v1.MergeWithRegoVersion(bundles, regoVersion, usePath) } // RootPathsOverlap takes in two bundle root paths and returns true if they overlap. func RootPathsOverlap(pathA string, pathB string) bool { - a := rootPathSegments(pathA) - b := rootPathSegments(pathB) - return rootContains(a, b) || rootContains(b, a) + return v1.RootPathsOverlap(pathA, pathB) } // RootPathsContain takes a set of bundle root paths and returns true if the path is contained. func RootPathsContain(roots []string, path string) bool { - segments := rootPathSegments(path) - for i := range roots { - if rootContains(rootPathSegments(roots[i]), segments) { - return true - } - } - return false -} - -func rootPathSegments(path string) []string { - return strings.Split(path, "/") -} - -func rootContains(root []string, other []string) bool { - - // A single segment, empty string root always contains the other. - if len(root) == 1 && root[0] == "" { - return true - } - - if len(root) > len(other) { - return false - } - - for j := range root { - if root[j] != other[j] { - return false - } - } - - return true -} - -func insertValue(b *Bundle, path string, value interface{}) error { - if err := b.insertData(getNormalizedPath(path), value); err != nil { - return fmt.Errorf("bundle load failed on %v: %w", path, err) - } - return nil -} - -func getNormalizedPath(path string) []string { - // Remove leading / and . characters from the directory path. If the bundle - // was written with OPA then the paths will contain a leading slash. On the - // other hand, if the path is empty, filepath.Dir will return '.'. - // Note: filepath.Dir can return paths with '\' separators, always use - // filepath.ToSlash to keep them normalized. - dirpath := strings.TrimLeft(normalizePath(filepath.Dir(path)), "/.") - var key []string - if dirpath != "" { - key = strings.Split(dirpath, "/") - } - return key -} - -func dfs(value interface{}, path string, fn func(string, interface{}) (bool, error)) error { - if stop, err := fn(path, value); err != nil { - return err - } else if stop { - return nil - } - obj, ok := value.(map[string]interface{}) - if !ok { - return nil - } - for key := range obj { - if err := dfs(obj[key], path+"/"+key, fn); err != nil { - return err - } - } - return nil -} - -func modulePathWithPrefix(bundleName string, modulePath string) string { - // Default prefix is just the bundle name - prefix := bundleName - - // Bundle names are sometimes just file paths, some of which - // are full urls (file:///foo/). Parse these and only use the path. - parsed, err := url.Parse(bundleName) - if err == nil { - prefix = filepath.Join(parsed.Host, parsed.Path) - } - - // Note: filepath.Join can return paths with '\' separators, always use - // filepath.ToSlash to keep them normalized. - return normalizePath(filepath.Join(prefix, modulePath)) -} - -// IsStructuredDoc checks if the file name equals a structured file extension ex. ".json" -func IsStructuredDoc(name string) bool { - return filepath.Base(name) == dataFile || filepath.Base(name) == yamlDataFile || - filepath.Base(name) == SignaturesFile || filepath.Base(name) == ManifestExt -} - -func preProcessBundle(loader DirectoryLoader, skipVerify bool, sizeLimitBytes int64) (SignaturesConfig, Patch, []*Descriptor, error) { - descriptors := []*Descriptor{} - var signatures SignaturesConfig - var patch Patch - - for { - f, err := loader.NextFile() - if err == io.EOF { - break - } - - if err != nil { - return signatures, patch, nil, fmt.Errorf("bundle read failed: %w", err) - } - - // check for the signatures file - if !skipVerify && strings.HasSuffix(f.Path(), SignaturesFile) { - buf, err := readFile(f, sizeLimitBytes) - if err != nil { - return signatures, patch, nil, err - } - - if err := util.NewJSONDecoder(&buf).Decode(&signatures); err != nil { - return signatures, patch, nil, fmt.Errorf("bundle load failed on signatures decode: %w", err) - } - } else if !strings.HasSuffix(f.Path(), SignaturesFile) { - descriptors = append(descriptors, f) - - if filepath.Base(f.Path()) == patchFile { - - var b bytes.Buffer - tee := io.TeeReader(f.reader, &b) - f.reader = tee - - buf, err := readFile(f, sizeLimitBytes) - if err != nil { - return signatures, patch, nil, err - } - - if err := util.NewJSONDecoder(&buf).Decode(&patch); err != nil { - return signatures, patch, nil, fmt.Errorf("bundle load failed on patch decode: %w", err) - } - - f.reader = &b - } - } - } - return signatures, patch, descriptors, nil -} - -func readFile(f *Descriptor, sizeLimitBytes int64) (bytes.Buffer, error) { - // Case for pre-loaded byte buffers, like those from the tarballLoader. - if bb, ok := f.reader.(*bytes.Buffer); ok { - _ = f.Close() // always close, even on error - - if int64(bb.Len()) >= sizeLimitBytes { - return *bb, fmt.Errorf("bundle file '%v' size (%d bytes) exceeded max size (%v bytes)", - strings.TrimPrefix(f.Path(), "/"), bb.Len(), sizeLimitBytes-1) - } - - return *bb, nil - } - - // Case for *lazyFile readers: - if lf, ok := f.reader.(*lazyFile); ok { - var buf bytes.Buffer - if lf.file == nil { - var err error - if lf.file, err = os.Open(lf.path); err != nil { - return buf, fmt.Errorf("failed to open file %s: %w", f.path, err) - } - } - // Bail out if we can't read the whole file-- there's nothing useful we can do at that point! - fileSize, _ := fstatFileSize(lf.file) - if fileSize > sizeLimitBytes { - return buf, fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(f.Path(), "/"), fileSize, sizeLimitBytes-1) - } - // Prealloc the buffer for the file read. - buffer := make([]byte, fileSize) - _, err := io.ReadFull(lf.file, buffer) - if err != nil { - return buf, err - } - _ = lf.file.Close() // always close, even on error - - // Note(philipc): Replace the lazyFile reader in the *Descriptor with a - // pointer to the wrapping bytes.Buffer, so that we don't re-read the - // file on disk again by accident. - buf = *bytes.NewBuffer(buffer) - f.reader = &buf - return buf, nil - } - - // Fallback case: - var buf bytes.Buffer - n, err := f.Read(&buf, sizeLimitBytes) - _ = f.Close() // always close, even on error - - if err != nil && err != io.EOF { - return buf, err - } else if err == nil && n >= sizeLimitBytes { - return buf, fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(f.Path(), "/"), n, sizeLimitBytes-1) - } - - return buf, nil -} - -// Takes an already open file handle and invokes the os.Stat system call on it -// to determine the file's size. Passes any errors from *File.Stat on up to the -// caller. -func fstatFileSize(f *os.File) (int64, error) { - fileInfo, err := f.Stat() - if err != nil { - return 0, err - } - return fileInfo.Size(), nil -} - -func normalizePath(p string) string { - return filepath.ToSlash(p) + return v1.RootPathsContain(roots, path) } diff --git a/vendor/github.com/open-policy-agent/opa/bundle/doc.go b/vendor/github.com/open-policy-agent/opa/bundle/doc.go new file mode 100644 index 000000000..7ec7c9b33 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/bundle/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package bundle diff --git a/vendor/github.com/open-policy-agent/opa/bundle/file.go b/vendor/github.com/open-policy-agent/opa/bundle/file.go index 80b1a87eb..ccb7b2351 100644 --- a/vendor/github.com/open-policy-agent/opa/bundle/file.go +++ b/vendor/github.com/open-policy-agent/opa/bundle/file.go @@ -1,508 +1,50 @@ package bundle import ( - "archive/tar" - "bytes" - "compress/gzip" - "fmt" "io" - "io/fs" - "os" - "path/filepath" - "sort" - "strings" - "sync" - - "github.com/open-policy-agent/opa/loader/filter" "github.com/open-policy-agent/opa/storage" + v1 "github.com/open-policy-agent/opa/v1/bundle" ) -const maxSizeLimitBytesErrMsg = "bundle file %s size (%d bytes) exceeds configured size_limit_bytes (%d bytes)" - // Descriptor contains information about a file and // can be used to read the file contents. -type Descriptor struct { - url string - path string - reader io.Reader - closer io.Closer - closeOnce *sync.Once -} - -// lazyFile defers reading the file until the first call of Read -type lazyFile struct { - path string - file *os.File -} - -// newLazyFile creates a new instance of lazyFile -func newLazyFile(path string) *lazyFile { - return &lazyFile{path: path} -} - -// Read implements io.Reader. It will check if the file has been opened -// and open it if it has not before attempting to read using the file's -// read method -func (f *lazyFile) Read(b []byte) (int, error) { - var err error - - if f.file == nil { - if f.file, err = os.Open(f.path); err != nil { - return 0, fmt.Errorf("failed to open file %s: %w", f.path, err) - } - } - - return f.file.Read(b) -} - -// Close closes the lazy file if it has been opened using the file's -// close method -func (f *lazyFile) Close() error { - if f.file != nil { - return f.file.Close() - } - - return nil -} +type Descriptor = v1.Descriptor func NewDescriptor(url, path string, reader io.Reader) *Descriptor { - return &Descriptor{ - url: url, - path: path, - reader: reader, - } -} - -func (d *Descriptor) WithCloser(closer io.Closer) *Descriptor { - d.closer = closer - d.closeOnce = new(sync.Once) - return d -} - -// Path returns the path of the file. -func (d *Descriptor) Path() string { - return d.path -} - -// URL returns the url of the file. -func (d *Descriptor) URL() string { - return d.url -} - -// Read will read all the contents from the file the Descriptor refers to -// into the dest writer up n bytes. Will return an io.EOF error -// if EOF is encountered before n bytes are read. -func (d *Descriptor) Read(dest io.Writer, n int64) (int64, error) { - n, err := io.CopyN(dest, d.reader, n) - return n, err + return v1.NewDescriptor(url, path, reader) } -// Close the file, on some Loader implementations this might be a no-op. -// It should *always* be called regardless of file. -func (d *Descriptor) Close() error { - var err error - if d.closer != nil { - d.closeOnce.Do(func() { - err = d.closer.Close() - }) - } - return err -} - -type PathFormat int64 +type PathFormat = v1.PathFormat const ( - Chrooted PathFormat = iota - SlashRooted - Passthrough + Chrooted = v1.Chrooted + SlashRooted = v1.SlashRooted + Passthrough = v1.Passthrough ) // DirectoryLoader defines an interface which can be used to load // files from a directory by iterating over each one in the tree. -type DirectoryLoader interface { - // NextFile must return io.EOF if there is no next value. The returned - // descriptor should *always* be closed when no longer needed. - NextFile() (*Descriptor, error) - WithFilter(filter filter.LoaderFilter) DirectoryLoader - WithPathFormat(PathFormat) DirectoryLoader - WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader - WithFollowSymlinks(followSymlinks bool) DirectoryLoader -} - -type dirLoader struct { - root string - files []string - idx int - filter filter.LoaderFilter - pathFormat PathFormat - maxSizeLimitBytes int64 - followSymlinks bool -} - -// Normalize root directory, ex "./src/bundle" -> "src/bundle" -// We don't need an absolute path, but this makes the joined/trimmed -// paths more uniform. -func normalizeRootDirectory(root string) string { - if len(root) > 1 { - if root[0] == '.' && root[1] == filepath.Separator { - if len(root) == 2 { - root = root[:1] // "./" -> "." - } else { - root = root[2:] // remove leading "./" - } - } - } - return root -} +type DirectoryLoader = v1.DirectoryLoader // NewDirectoryLoader returns a basic DirectoryLoader implementation // that will load files from a given root directory path. func NewDirectoryLoader(root string) DirectoryLoader { - d := dirLoader{ - root: normalizeRootDirectory(root), - pathFormat: Chrooted, - } - return &d -} - -// WithFilter specifies the filter object to use to filter files while loading bundles -func (d *dirLoader) WithFilter(filter filter.LoaderFilter) DirectoryLoader { - d.filter = filter - return d -} - -// WithPathFormat specifies how a path is formatted in a Descriptor -func (d *dirLoader) WithPathFormat(pathFormat PathFormat) DirectoryLoader { - d.pathFormat = pathFormat - return d -} - -// WithSizeLimitBytes specifies the maximum size of any file in the directory to read -func (d *dirLoader) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { - d.maxSizeLimitBytes = sizeLimitBytes - return d -} - -// WithFollowSymlinks specifies whether to follow symlinks when loading files from the directory -func (d *dirLoader) WithFollowSymlinks(followSymlinks bool) DirectoryLoader { - d.followSymlinks = followSymlinks - return d -} - -func formatPath(fileName string, root string, pathFormat PathFormat) string { - switch pathFormat { - case SlashRooted: - if !strings.HasPrefix(fileName, string(filepath.Separator)) { - return string(filepath.Separator) + fileName - } - return fileName - case Chrooted: - // Trim off the root directory and return path as if chrooted - result := strings.TrimPrefix(fileName, filepath.FromSlash(root)) - if root == "." && filepath.Base(fileName) == ManifestExt { - result = fileName - } - if !strings.HasPrefix(result, string(filepath.Separator)) { - result = string(filepath.Separator) + result - } - return result - case Passthrough: - fallthrough - default: - return fileName - } -} - -// NextFile iterates to the next file in the directory tree -// and returns a file Descriptor for the file. -func (d *dirLoader) NextFile() (*Descriptor, error) { - // build a list of all files we will iterate over and read, but only one time - if d.files == nil { - d.files = []string{} - err := filepath.Walk(d.root, func(path string, info os.FileInfo, _ error) error { - if info == nil { - return nil - } - - if info.Mode().IsRegular() { - if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { - return nil - } - if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { - return fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(path, "/"), info.Size(), d.maxSizeLimitBytes) - } - d.files = append(d.files, path) - } else if d.followSymlinks && info.Mode().Type()&fs.ModeSymlink == fs.ModeSymlink { - if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { - return nil - } - if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { - return fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(path, "/"), info.Size(), d.maxSizeLimitBytes) - } - d.files = append(d.files, path) - } else if info.Mode().IsDir() { - if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, true)) { - return filepath.SkipDir - } - } - return nil - }) - if err != nil { - return nil, fmt.Errorf("failed to list files: %w", err) - } - } - - // If done reading files then just return io.EOF - // errors for each NextFile() call - if d.idx >= len(d.files) { - return nil, io.EOF - } - - fileName := d.files[d.idx] - d.idx++ - fh := newLazyFile(fileName) - - cleanedPath := formatPath(fileName, d.root, d.pathFormat) - f := NewDescriptor(filepath.Join(d.root, cleanedPath), cleanedPath, fh).WithCloser(fh) - return f, nil -} - -type tarballLoader struct { - baseURL string - r io.Reader - tr *tar.Reader - files []file - idx int - filter filter.LoaderFilter - skipDir map[string]struct{} - pathFormat PathFormat - maxSizeLimitBytes int64 -} - -type file struct { - name string - reader io.Reader - path storage.Path - raw []byte + return v1.NewDirectoryLoader(root) } // NewTarballLoader is deprecated. Use NewTarballLoaderWithBaseURL instead. func NewTarballLoader(r io.Reader) DirectoryLoader { - l := tarballLoader{ - r: r, - pathFormat: Passthrough, - } - return &l + return v1.NewTarballLoader(r) } // NewTarballLoaderWithBaseURL returns a new DirectoryLoader that reads // files out of a gzipped tar archive. The file URLs will be prefixed // with the baseURL. func NewTarballLoaderWithBaseURL(r io.Reader, baseURL string) DirectoryLoader { - l := tarballLoader{ - baseURL: strings.TrimSuffix(baseURL, "/"), - r: r, - pathFormat: Passthrough, - } - return &l -} - -// WithFilter specifies the filter object to use to filter files while loading bundles -func (t *tarballLoader) WithFilter(filter filter.LoaderFilter) DirectoryLoader { - t.filter = filter - return t -} - -// WithPathFormat specifies how a path is formatted in a Descriptor -func (t *tarballLoader) WithPathFormat(pathFormat PathFormat) DirectoryLoader { - t.pathFormat = pathFormat - return t -} - -// WithSizeLimitBytes specifies the maximum size of any file in the tarball to read -func (t *tarballLoader) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { - t.maxSizeLimitBytes = sizeLimitBytes - return t -} - -// WithFollowSymlinks is a no-op for tarballLoader -func (t *tarballLoader) WithFollowSymlinks(_ bool) DirectoryLoader { - return t -} - -// NextFile iterates to the next file in the directory tree -// and returns a file Descriptor for the file. -func (t *tarballLoader) NextFile() (*Descriptor, error) { - if t.tr == nil { - gr, err := gzip.NewReader(t.r) - if err != nil { - return nil, fmt.Errorf("archive read failed: %w", err) - } - - t.tr = tar.NewReader(gr) - } - - if t.files == nil { - t.files = []file{} - - if t.skipDir == nil { - t.skipDir = map[string]struct{}{} - } - - for { - header, err := t.tr.Next() - - if err == io.EOF { - break - } - - if err != nil { - return nil, err - } - - // Keep iterating on the archive until we find a normal file - if header.Typeflag == tar.TypeReg { - - if t.filter != nil { - - if t.filter(filepath.ToSlash(header.Name), header.FileInfo(), getdepth(header.Name, false)) { - continue - } - - basePath := strings.Trim(filepath.Dir(filepath.ToSlash(header.Name)), "/") - - // check if the directory is to be skipped - if _, ok := t.skipDir[basePath]; ok { - continue - } - - match := false - for p := range t.skipDir { - if strings.HasPrefix(basePath, p) { - match = true - break - } - } - - if match { - continue - } - } - - if t.maxSizeLimitBytes > 0 && header.Size > t.maxSizeLimitBytes { - return nil, fmt.Errorf(maxSizeLimitBytesErrMsg, header.Name, header.Size, t.maxSizeLimitBytes) - } - - f := file{name: header.Name} - - // Note(philipc): We rely on the previous size check in this loop for safety. - buf := bytes.NewBuffer(make([]byte, 0, header.Size)) - if _, err := io.Copy(buf, t.tr); err != nil { - return nil, fmt.Errorf("failed to copy file %s: %w", header.Name, err) - } - - f.reader = buf - - t.files = append(t.files, f) - } else if header.Typeflag == tar.TypeDir { - cleanedPath := filepath.ToSlash(header.Name) - if t.filter != nil && t.filter(cleanedPath, header.FileInfo(), getdepth(header.Name, true)) { - t.skipDir[strings.Trim(cleanedPath, "/")] = struct{}{} - } - } - } - } - - // If done reading files then just return io.EOF - // errors for each NextFile() call - if t.idx >= len(t.files) { - return nil, io.EOF - } - - f := t.files[t.idx] - t.idx++ - - cleanedPath := formatPath(f.name, "", t.pathFormat) - d := NewDescriptor(filepath.Join(t.baseURL, cleanedPath), cleanedPath, f.reader) - return d, nil -} - -// Next implements the storage.Iterator interface. -// It iterates to the next policy or data file in the directory tree -// and returns a storage.Update for the file. -func (it *iterator) Next() (*storage.Update, error) { - if it.files == nil { - it.files = []file{} - - for _, item := range it.raw { - f := file{name: item.Path} - - fpath := strings.TrimLeft(normalizePath(filepath.Dir(f.name)), "/.") - if strings.HasSuffix(f.name, RegoExt) { - fpath = strings.Trim(normalizePath(f.name), "/") - } - - p, ok := storage.ParsePathEscaped("/" + fpath) - if !ok { - return nil, fmt.Errorf("storage path invalid: %v", f.name) - } - f.path = p - - f.raw = item.Value - - it.files = append(it.files, f) - } - - sortFilePathAscend(it.files) - } - - // If done reading files then just return io.EOF - // errors for each NextFile() call - if it.idx >= len(it.files) { - return nil, io.EOF - } - - f := it.files[it.idx] - it.idx++ - - isPolicy := false - if strings.HasSuffix(f.name, RegoExt) { - isPolicy = true - } - - return &storage.Update{ - Path: f.path, - Value: f.raw, - IsPolicy: isPolicy, - }, nil -} - -type iterator struct { - raw []Raw - files []file - idx int + return v1.NewTarballLoaderWithBaseURL(r, baseURL) } func NewIterator(raw []Raw) storage.Iterator { - it := iterator{ - raw: raw, - } - return &it -} - -func sortFilePathAscend(files []file) { - sort.Slice(files, func(i, j int) bool { - return len(files[i].path) < len(files[j].path) - }) -} - -func getdepth(path string, isDir bool) int { - if isDir { - cleanedPath := strings.Trim(filepath.ToSlash(path), "/") - return len(strings.Split(cleanedPath, "/")) - } - - basePath := strings.Trim(filepath.Dir(filepath.ToSlash(path)), "/") - return len(strings.Split(basePath, "/")) + return v1.NewIterator(raw) } diff --git a/vendor/github.com/open-policy-agent/opa/bundle/filefs.go b/vendor/github.com/open-policy-agent/opa/bundle/filefs.go index a3a0dbf20..16e00928d 100644 --- a/vendor/github.com/open-policy-agent/opa/bundle/filefs.go +++ b/vendor/github.com/open-policy-agent/opa/bundle/filefs.go @@ -4,140 +4,19 @@ package bundle import ( - "fmt" - "io" "io/fs" - "path/filepath" - "sync" - "github.com/open-policy-agent/opa/loader/filter" + v1 "github.com/open-policy-agent/opa/v1/bundle" ) -const ( - defaultFSLoaderRoot = "." -) - -type dirLoaderFS struct { - sync.Mutex - filesystem fs.FS - files []string - idx int - filter filter.LoaderFilter - root string - pathFormat PathFormat - maxSizeLimitBytes int64 - followSymlinks bool -} - // NewFSLoader returns a basic DirectoryLoader implementation // that will load files from a fs.FS interface func NewFSLoader(filesystem fs.FS) (DirectoryLoader, error) { - return NewFSLoaderWithRoot(filesystem, defaultFSLoaderRoot), nil + return v1.NewFSLoader(filesystem) } // NewFSLoaderWithRoot returns a basic DirectoryLoader implementation // that will load files from a fs.FS interface at the supplied root func NewFSLoaderWithRoot(filesystem fs.FS, root string) DirectoryLoader { - d := dirLoaderFS{ - filesystem: filesystem, - root: normalizeRootDirectory(root), - pathFormat: Chrooted, - } - - return &d -} - -func (d *dirLoaderFS) walkDir(path string, dirEntry fs.DirEntry, err error) error { - if err != nil { - return err - } - - if dirEntry != nil { - info, err := dirEntry.Info() - if err != nil { - return err - } - - if dirEntry.Type().IsRegular() { - if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { - return nil - } - - if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { - return fmt.Errorf("file %s size %d exceeds limit of %d", path, info.Size(), d.maxSizeLimitBytes) - } - - d.files = append(d.files, path) - } else if dirEntry.Type()&fs.ModeSymlink != 0 && d.followSymlinks { - if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { - return nil - } - - if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { - return fmt.Errorf("file %s size %d exceeds limit of %d", path, info.Size(), d.maxSizeLimitBytes) - } - - d.files = append(d.files, path) - } else if dirEntry.Type().IsDir() { - if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, true)) { - return fs.SkipDir - } - } - } - return nil -} - -// WithFilter specifies the filter object to use to filter files while loading bundles -func (d *dirLoaderFS) WithFilter(filter filter.LoaderFilter) DirectoryLoader { - d.filter = filter - return d -} - -// WithPathFormat specifies how a path is formatted in a Descriptor -func (d *dirLoaderFS) WithPathFormat(pathFormat PathFormat) DirectoryLoader { - d.pathFormat = pathFormat - return d -} - -// WithSizeLimitBytes specifies the maximum size of any file in the filesystem directory to read -func (d *dirLoaderFS) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { - d.maxSizeLimitBytes = sizeLimitBytes - return d -} - -func (d *dirLoaderFS) WithFollowSymlinks(followSymlinks bool) DirectoryLoader { - d.followSymlinks = followSymlinks - return d -} - -// NextFile iterates to the next file in the directory tree -// and returns a file Descriptor for the file. -func (d *dirLoaderFS) NextFile() (*Descriptor, error) { - d.Lock() - defer d.Unlock() - - if d.files == nil { - err := fs.WalkDir(d.filesystem, d.root, d.walkDir) - if err != nil { - return nil, fmt.Errorf("failed to list files: %w", err) - } - } - - // If done reading files then just return io.EOF - // errors for each NextFile() call - if d.idx >= len(d.files) { - return nil, io.EOF - } - - fileName := d.files[d.idx] - d.idx++ - - fh, err := d.filesystem.Open(fileName) - if err != nil { - return nil, fmt.Errorf("failed to open file %s: %w", fileName, err) - } - - cleanedPath := formatPath(fileName, d.root, d.pathFormat) - f := NewDescriptor(cleanedPath, cleanedPath, fh).WithCloser(fh) - return f, nil + return v1.NewFSLoaderWithRoot(filesystem, root) } diff --git a/vendor/github.com/open-policy-agent/opa/bundle/hash.go b/vendor/github.com/open-policy-agent/opa/bundle/hash.go index 021801bb0..d4cc601de 100644 --- a/vendor/github.com/open-policy-agent/opa/bundle/hash.go +++ b/vendor/github.com/open-policy-agent/opa/bundle/hash.go @@ -5,137 +5,28 @@ package bundle import ( - "bytes" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "encoding/json" - "fmt" - "hash" - "io" - "sort" - "strings" + v1 "github.com/open-policy-agent/opa/v1/bundle" ) // HashingAlgorithm represents a subset of hashing algorithms implemented in Go -type HashingAlgorithm string +type HashingAlgorithm = v1.HashingAlgorithm // Supported values for HashingAlgorithm const ( - MD5 HashingAlgorithm = "MD5" - SHA1 HashingAlgorithm = "SHA-1" - SHA224 HashingAlgorithm = "SHA-224" - SHA256 HashingAlgorithm = "SHA-256" - SHA384 HashingAlgorithm = "SHA-384" - SHA512 HashingAlgorithm = "SHA-512" - SHA512224 HashingAlgorithm = "SHA-512-224" - SHA512256 HashingAlgorithm = "SHA-512-256" + MD5 = v1.MD5 + SHA1 = v1.SHA1 + SHA224 = v1.SHA224 + SHA256 = v1.SHA256 + SHA384 = v1.SHA384 + SHA512 = v1.SHA512 + SHA512224 = v1.SHA512224 + SHA512256 = v1.SHA512256 ) -// String returns the string representation of a HashingAlgorithm -func (alg HashingAlgorithm) String() string { - return string(alg) -} - // SignatureHasher computes a signature digest for a file with (structured or unstructured) data and policy -type SignatureHasher interface { - HashFile(v interface{}) ([]byte, error) -} - -type hasher struct { - h func() hash.Hash // hash function factory -} +type SignatureHasher = v1.SignatureHasher // NewSignatureHasher returns a signature hasher suitable for a particular hashing algorithm func NewSignatureHasher(alg HashingAlgorithm) (SignatureHasher, error) { - h := &hasher{} - - switch alg { - case MD5: - h.h = md5.New - case SHA1: - h.h = sha1.New - case SHA224: - h.h = sha256.New224 - case SHA256: - h.h = sha256.New - case SHA384: - h.h = sha512.New384 - case SHA512: - h.h = sha512.New - case SHA512224: - h.h = sha512.New512_224 - case SHA512256: - h.h = sha512.New512_256 - default: - return nil, fmt.Errorf("unsupported hashing algorithm: %s", alg) - } - - return h, nil -} - -// HashFile hashes the file content, JSON or binary, both in golang native format. -func (h *hasher) HashFile(v interface{}) ([]byte, error) { - hf := h.h() - walk(v, hf) - return hf.Sum(nil), nil -} - -// walk hashes the file content, JSON or binary, both in golang native format. -// -// Computation for unstructured documents is a hash of the document. -// -// Computation for the types of structured JSON document is as follows: -// -// object: Hash {, then each key (in alphabetical order) and digest of the value, then comma (between items) and finally }. -// -// array: Hash [, then digest of the value, then comma (between items) and finally ]. -func walk(v interface{}, h io.Writer) { - - switch x := v.(type) { - case map[string]interface{}: - _, _ = h.Write([]byte("{")) - - var keys []string - for k := range x { - keys = append(keys, k) - } - sort.Strings(keys) - - for i, key := range keys { - if i > 0 { - _, _ = h.Write([]byte(",")) - } - - _, _ = h.Write(encodePrimitive(key)) - _, _ = h.Write([]byte(":")) - walk(x[key], h) - } - - _, _ = h.Write([]byte("}")) - case []interface{}: - _, _ = h.Write([]byte("[")) - - for i, e := range x { - if i > 0 { - _, _ = h.Write([]byte(",")) - } - walk(e, h) - } - - _, _ = h.Write([]byte("]")) - case []byte: - _, _ = h.Write(x) - default: - _, _ = h.Write(encodePrimitive(x)) - } -} - -func encodePrimitive(v interface{}) []byte { - var buf bytes.Buffer - encoder := json.NewEncoder(&buf) - encoder.SetEscapeHTML(false) - _ = encoder.Encode(v) - return []byte(strings.Trim(buf.String(), "\n")) + return v1.NewSignatureHasher(alg) } diff --git a/vendor/github.com/open-policy-agent/opa/bundle/keys.go b/vendor/github.com/open-policy-agent/opa/bundle/keys.go index 810bee4b7..99f9b0f16 100644 --- a/vendor/github.com/open-policy-agent/opa/bundle/keys.go +++ b/vendor/github.com/open-policy-agent/opa/bundle/keys.go @@ -6,139 +6,25 @@ package bundle import ( - "encoding/pem" - "fmt" - "os" - - "github.com/open-policy-agent/opa/internal/jwx/jwa" - "github.com/open-policy-agent/opa/internal/jwx/jws/sign" - "github.com/open-policy-agent/opa/keys" - - "github.com/open-policy-agent/opa/util" -) - -const ( - defaultTokenSigningAlg = "RS256" + v1 "github.com/open-policy-agent/opa/v1/bundle" ) // KeyConfig holds the keys used to sign or verify bundles and tokens // Moved to own package, alias kept for backwards compatibility -type KeyConfig = keys.Config +type KeyConfig = v1.KeyConfig // VerificationConfig represents the key configuration used to verify a signed bundle -type VerificationConfig struct { - PublicKeys map[string]*KeyConfig - KeyID string `json:"keyid"` - Scope string `json:"scope"` - Exclude []string `json:"exclude_files"` -} +type VerificationConfig = v1.VerificationConfig // NewVerificationConfig return a new VerificationConfig func NewVerificationConfig(keys map[string]*KeyConfig, id, scope string, exclude []string) *VerificationConfig { - return &VerificationConfig{ - PublicKeys: keys, - KeyID: id, - Scope: scope, - Exclude: exclude, - } -} - -// ValidateAndInjectDefaults validates the config and inserts default values -func (vc *VerificationConfig) ValidateAndInjectDefaults(keys map[string]*KeyConfig) error { - vc.PublicKeys = keys - - if vc.KeyID != "" { - found := false - for key := range keys { - if key == vc.KeyID { - found = true - break - } - } - - if !found { - return fmt.Errorf("key id %s not found", vc.KeyID) - } - } - return nil -} - -// GetPublicKey returns the public key corresponding to the given key id -func (vc *VerificationConfig) GetPublicKey(id string) (*KeyConfig, error) { - var kc *KeyConfig - var ok bool - - if kc, ok = vc.PublicKeys[id]; !ok { - return nil, fmt.Errorf("verification key corresponding to ID %v not found", id) - } - return kc, nil + return v1.NewVerificationConfig(keys, id, scope, exclude) } // SigningConfig represents the key configuration used to generate a signed bundle -type SigningConfig struct { - Plugin string - Key string - Algorithm string - ClaimsPath string -} +type SigningConfig = v1.SigningConfig // NewSigningConfig return a new SigningConfig func NewSigningConfig(key, alg, claimsPath string) *SigningConfig { - if alg == "" { - alg = defaultTokenSigningAlg - } - - return &SigningConfig{ - Plugin: defaultSignerID, - Key: key, - Algorithm: alg, - ClaimsPath: claimsPath, - } -} - -// WithPlugin sets the signing plugin in the signing config -func (s *SigningConfig) WithPlugin(plugin string) *SigningConfig { - if plugin != "" { - s.Plugin = plugin - } - return s -} - -// GetPrivateKey returns the private key or secret from the signing config -func (s *SigningConfig) GetPrivateKey() (interface{}, error) { - - block, _ := pem.Decode([]byte(s.Key)) - if block != nil { - return sign.GetSigningKey(s.Key, jwa.SignatureAlgorithm(s.Algorithm)) - } - - var priv string - if _, err := os.Stat(s.Key); err == nil { - bs, err := os.ReadFile(s.Key) - if err != nil { - return nil, err - } - priv = string(bs) - } else if os.IsNotExist(err) { - priv = s.Key - } else { - return nil, err - } - - return sign.GetSigningKey(priv, jwa.SignatureAlgorithm(s.Algorithm)) -} - -// GetClaims returns the claims by reading the file specified in the signing config -func (s *SigningConfig) GetClaims() (map[string]interface{}, error) { - var claims map[string]interface{} - - bs, err := os.ReadFile(s.ClaimsPath) - if err != nil { - return claims, err - } - - if err := util.UnmarshalJSON(bs, &claims); err != nil { - return claims, err - } - return claims, nil + return v1.NewSigningConfig(key, alg, claimsPath) } diff --git a/vendor/github.com/open-policy-agent/opa/bundle/sign.go b/vendor/github.com/open-policy-agent/opa/bundle/sign.go index cf9a3e183..56e25eec9 100644 --- a/vendor/github.com/open-policy-agent/opa/bundle/sign.go +++ b/vendor/github.com/open-policy-agent/opa/bundle/sign.go @@ -6,130 +6,30 @@ package bundle import ( - "crypto/rand" - "encoding/json" - "fmt" - - "github.com/open-policy-agent/opa/internal/jwx/jwa" - "github.com/open-policy-agent/opa/internal/jwx/jws" + v1 "github.com/open-policy-agent/opa/v1/bundle" ) -const defaultSignerID = "_default" - -var signers map[string]Signer - // Signer is the interface expected for implementations that generate bundle signatures. -type Signer interface { - GenerateSignedToken([]FileInfo, *SigningConfig, string) (string, error) -} +type Signer v1.Signer // GenerateSignedToken will retrieve the Signer implementation based on the Plugin specified // in SigningConfig, and call its implementation of GenerateSignedToken. The signer generates // a signed token given the list of files to be included in the payload and the bundle // signing config. The keyID if non-empty, represents the value for the "keyid" claim in the token. func GenerateSignedToken(files []FileInfo, sc *SigningConfig, keyID string) (string, error) { - var plugin string - // for backwards compatibility, check if there is no plugin specified, and use default - if sc.Plugin == "" { - plugin = defaultSignerID - } else { - plugin = sc.Plugin - } - signer, err := GetSigner(plugin) - if err != nil { - return "", err - } - return signer.GenerateSignedToken(files, sc, keyID) + return v1.GenerateSignedToken(files, sc, keyID) } // DefaultSigner is the default bundle signing implementation. It signs bundles by generating // a JWT and signing it using a locally-accessible private key. -type DefaultSigner struct{} - -// GenerateSignedToken generates a signed token given the list of files to be -// included in the payload and the bundle signing config. The keyID if non-empty, -// represents the value for the "keyid" claim in the token -func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, keyID string) (string, error) { - payload, err := generatePayload(files, sc, keyID) - if err != nil { - return "", err - } - - privateKey, err := sc.GetPrivateKey() - if err != nil { - return "", err - } - - var headers jws.StandardHeaders - - if err := headers.Set(jws.AlgorithmKey, jwa.SignatureAlgorithm(sc.Algorithm)); err != nil { - return "", err - } - - if keyID != "" { - if err := headers.Set(jws.KeyIDKey, keyID); err != nil { - return "", err - } - } - - hdr, err := json.Marshal(headers) - if err != nil { - return "", err - } - - token, err := jws.SignLiteral(payload, - jwa.SignatureAlgorithm(sc.Algorithm), - privateKey, - hdr, - rand.Reader) - if err != nil { - return "", err - } - return string(token), nil -} - -func generatePayload(files []FileInfo, sc *SigningConfig, keyID string) ([]byte, error) { - payload := make(map[string]interface{}) - payload["files"] = files - - if sc.ClaimsPath != "" { - claims, err := sc.GetClaims() - if err != nil { - return nil, err - } - - for claim, value := range claims { - payload[claim] = value - } - } else { - if keyID != "" { - // keyid claim is deprecated but include it for backwards compatibility. - payload["keyid"] = keyID - } - } - return json.Marshal(payload) -} +type DefaultSigner v1.DefaultSigner // GetSigner returns the Signer registered under the given id func GetSigner(id string) (Signer, error) { - signer, ok := signers[id] - if !ok { - return nil, fmt.Errorf("no signer exists under id %s", id) - } - return signer, nil + return v1.GetSigner(id) } // RegisterSigner registers a Signer under the given id func RegisterSigner(id string, s Signer) error { - if id == defaultSignerID { - return fmt.Errorf("signer id %s is reserved, use a different id", id) - } - signers[id] = s - return nil -} - -func init() { - signers = map[string]Signer{ - defaultSignerID: &DefaultSigner{}, - } + return v1.RegisterSigner(id, s) } diff --git a/vendor/github.com/open-policy-agent/opa/bundle/store.go b/vendor/github.com/open-policy-agent/opa/bundle/store.go index 9a49f025e..d73cc7742 100644 --- a/vendor/github.com/open-policy-agent/opa/bundle/store.go +++ b/vendor/github.com/open-policy-agent/opa/bundle/store.go @@ -6,1031 +6,118 @@ package bundle import ( "context" - "encoding/base64" - "encoding/json" - "fmt" - "path/filepath" - "strings" - "github.com/open-policy-agent/opa/ast" - iCompiler "github.com/open-policy-agent/opa/internal/compiler" - "github.com/open-policy-agent/opa/internal/json/patch" - "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/bundle" ) // BundlesBasePath is the storage path used for storing bundle metadata -var BundlesBasePath = storage.MustParsePath("/system/bundles") +var BundlesBasePath = v1.BundlesBasePath // Note: As needed these helpers could be memoized. // ManifestStoragePath is the storage path used for the given named bundle manifest. func ManifestStoragePath(name string) storage.Path { - return append(BundlesBasePath, name, "manifest") + return v1.ManifestStoragePath(name) } // EtagStoragePath is the storage path used for the given named bundle etag. func EtagStoragePath(name string) storage.Path { - return append(BundlesBasePath, name, "etag") -} - -func namedBundlePath(name string) storage.Path { - return append(BundlesBasePath, name) -} - -func rootsPath(name string) storage.Path { - return append(BundlesBasePath, name, "manifest", "roots") -} - -func revisionPath(name string) storage.Path { - return append(BundlesBasePath, name, "manifest", "revision") -} - -func wasmModulePath(name string) storage.Path { - return append(BundlesBasePath, name, "wasm") -} - -func wasmEntrypointsPath(name string) storage.Path { - return append(BundlesBasePath, name, "manifest", "wasm") -} - -func metadataPath(name string) storage.Path { - return append(BundlesBasePath, name, "manifest", "metadata") -} - -func read(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (interface{}, error) { - value, err := store.Read(ctx, txn, path) - if err != nil { - return nil, err - } - - if astValue, ok := value.(ast.Value); ok { - value, err = ast.JSON(astValue) - if err != nil { - return nil, err - } - } - - return value, nil + return v1.EtagStoragePath(name) } // ReadBundleNamesFromStore will return a list of bundle names which have had their metadata stored. func ReadBundleNamesFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) ([]string, error) { - value, err := read(ctx, store, txn, BundlesBasePath) - if err != nil { - return nil, err - } - - bundleMap, ok := value.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("corrupt manifest roots") - } - - bundles := make([]string, len(bundleMap)) - idx := 0 - for name := range bundleMap { - bundles[idx] = name - idx++ - } - return bundles, nil + return v1.ReadBundleNamesFromStore(ctx, store, txn) } // WriteManifestToStore will write the manifest into the storage. This function is called when // the bundle is activated. func WriteManifestToStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string, manifest Manifest) error { - return write(ctx, store, txn, ManifestStoragePath(name), manifest) + return v1.WriteManifestToStore(ctx, store, txn, name, manifest) } // WriteEtagToStore will write the bundle etag into the storage. This function is called when the bundle is activated. func WriteEtagToStore(ctx context.Context, store storage.Store, txn storage.Transaction, name, etag string) error { - return write(ctx, store, txn, EtagStoragePath(name), etag) -} - -func write(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path, value interface{}) error { - if err := util.RoundTrip(&value); err != nil { - return err - } - - var dir []string - if len(path) > 1 { - dir = path[:len(path)-1] - } - - if err := storage.MakeDir(ctx, store, txn, dir); err != nil { - return err - } - - return store.Write(ctx, txn, storage.AddOp, path, value) + return v1.WriteEtagToStore(ctx, store, txn, name, etag) } // EraseManifestFromStore will remove the manifest from storage. This function is called // when the bundle is deactivated. func EraseManifestFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) error { - path := namedBundlePath(name) - err := store.Write(ctx, txn, storage.RemoveOp, path, nil) - return suppressNotFound(err) -} - -// eraseBundleEtagFromStore will remove the bundle etag from storage. This function is called -// when the bundle is deactivated. -func eraseBundleEtagFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) error { - path := EtagStoragePath(name) - err := store.Write(ctx, txn, storage.RemoveOp, path, nil) - return suppressNotFound(err) -} - -func suppressNotFound(err error) error { - if err == nil || storage.IsNotFound(err) { - return nil - } - return err -} - -func writeWasmModulesToStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string, b *Bundle) error { - basePath := wasmModulePath(name) - for _, wm := range b.WasmModules { - path := append(basePath, wm.Path) - err := write(ctx, store, txn, path, base64.StdEncoding.EncodeToString(wm.Raw)) - if err != nil { - return err - } - } - return nil -} - -func eraseWasmModulesFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) error { - path := wasmModulePath(name) - - err := store.Write(ctx, txn, storage.RemoveOp, path, nil) - return suppressNotFound(err) -} - -// ReadWasmMetadataFromStore will read Wasm module resolver metadata from the store. -func ReadWasmMetadataFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) ([]WasmResolver, error) { - path := wasmEntrypointsPath(name) - value, err := read(ctx, store, txn, path) - if err != nil { - return nil, err - } - - bs, err := json.Marshal(value) - if err != nil { - return nil, fmt.Errorf("corrupt wasm manifest data") - } - - var wasmMetadata []WasmResolver - - err = util.UnmarshalJSON(bs, &wasmMetadata) - if err != nil { - return nil, fmt.Errorf("corrupt wasm manifest data") - } - - return wasmMetadata, nil + return v1.EraseManifestFromStore(ctx, store, txn, name) } // ReadWasmModulesFromStore will write Wasm module resolver metadata from the store. func ReadWasmModulesFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) (map[string][]byte, error) { - path := wasmModulePath(name) - value, err := read(ctx, store, txn, path) - if err != nil { - return nil, err - } - - encodedModules, ok := value.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("corrupt wasm modules") - } - - rawModules := map[string][]byte{} - for path, enc := range encodedModules { - encStr, ok := enc.(string) - if !ok { - return nil, fmt.Errorf("corrupt wasm modules") - } - bs, err := base64.StdEncoding.DecodeString(encStr) - if err != nil { - return nil, err - } - rawModules[path] = bs - } - return rawModules, nil + return v1.ReadWasmModulesFromStore(ctx, store, txn, name) } // ReadBundleRootsFromStore returns the roots in the specified bundle. // If the bundle is not activated, this function will return // storage NotFound error. func ReadBundleRootsFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) ([]string, error) { - value, err := read(ctx, store, txn, rootsPath(name)) - if err != nil { - return nil, err - } - - sl, ok := value.([]interface{}) - if !ok { - return nil, fmt.Errorf("corrupt manifest roots") - } - - roots := make([]string, len(sl)) - - for i := range sl { - roots[i], ok = sl[i].(string) - if !ok { - return nil, fmt.Errorf("corrupt manifest root") - } - } - - return roots, nil + return v1.ReadBundleRootsFromStore(ctx, store, txn, name) } // ReadBundleRevisionFromStore returns the revision in the specified bundle. // If the bundle is not activated, this function will return // storage NotFound error. func ReadBundleRevisionFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) (string, error) { - return readRevisionFromStore(ctx, store, txn, revisionPath(name)) -} - -func readRevisionFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (string, error) { - value, err := read(ctx, store, txn, path) - if err != nil { - return "", err - } - - str, ok := value.(string) - if !ok { - return "", fmt.Errorf("corrupt manifest revision") - } - - return str, nil + return v1.ReadBundleRevisionFromStore(ctx, store, txn, name) } // ReadBundleMetadataFromStore returns the metadata in the specified bundle. // If the bundle is not activated, this function will return // storage NotFound error. func ReadBundleMetadataFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) (map[string]interface{}, error) { - return readMetadataFromStore(ctx, store, txn, metadataPath(name)) -} - -func readMetadataFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (map[string]interface{}, error) { - value, err := read(ctx, store, txn, path) - if err != nil { - return nil, suppressNotFound(err) - } - - data, ok := value.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("corrupt manifest metadata") - } - - return data, nil + return v1.ReadBundleMetadataFromStore(ctx, store, txn, name) } // ReadBundleEtagFromStore returns the etag for the specified bundle. // If the bundle is not activated, this function will return // storage NotFound error. func ReadBundleEtagFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) (string, error) { - return readEtagFromStore(ctx, store, txn, EtagStoragePath(name)) -} - -func readEtagFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (string, error) { - value, err := read(ctx, store, txn, path) - if err != nil { - return "", err - } - - str, ok := value.(string) - if !ok { - return "", fmt.Errorf("corrupt bundle etag") - } - - return str, nil + return v1.ReadBundleEtagFromStore(ctx, store, txn, name) } // ActivateOpts defines options for the Activate API call. -type ActivateOpts struct { - Ctx context.Context - Store storage.Store - Txn storage.Transaction - TxnCtx *storage.Context - Compiler *ast.Compiler - Metrics metrics.Metrics - Bundles map[string]*Bundle // Optional - ExtraModules map[string]*ast.Module // Optional - AuthorizationDecisionRef ast.Ref - ParserOptions ast.ParserOptions - - legacy bool -} +type ActivateOpts = v1.ActivateOpts // Activate the bundle(s) by loading into the given Store. This will load policies, data, and record // the manifest in storage. The compiler provided will have had the polices compiled on it. func Activate(opts *ActivateOpts) error { - opts.legacy = false - return activateBundles(opts) + return v1.Activate(opts) } // DeactivateOpts defines options for the Deactivate API call -type DeactivateOpts struct { - Ctx context.Context - Store storage.Store - Txn storage.Transaction - BundleNames map[string]struct{} - ParserOptions ast.ParserOptions -} +type DeactivateOpts = v1.DeactivateOpts // Deactivate the bundle(s). This will erase associated data, policies, and the manifest entry from the store. func Deactivate(opts *DeactivateOpts) error { - erase := map[string]struct{}{} - for name := range opts.BundleNames { - roots, err := ReadBundleRootsFromStore(opts.Ctx, opts.Store, opts.Txn, name) - if suppressNotFound(err) != nil { - return err - } - for _, root := range roots { - erase[root] = struct{}{} - } - } - _, err := eraseBundles(opts.Ctx, opts.Store, opts.Txn, opts.ParserOptions, opts.BundleNames, erase) - return err -} - -func activateBundles(opts *ActivateOpts) error { - - // Build collections of bundle names, modules, and roots to erase - erase := map[string]struct{}{} - names := map[string]struct{}{} - deltaBundles := map[string]*Bundle{} - snapshotBundles := map[string]*Bundle{} - - for name, b := range opts.Bundles { - if b.Type() == DeltaBundleType { - deltaBundles[name] = b - } else { - snapshotBundles[name] = b - names[name] = struct{}{} - - roots, err := ReadBundleRootsFromStore(opts.Ctx, opts.Store, opts.Txn, name) - if suppressNotFound(err) != nil { - return err - } - for _, root := range roots { - erase[root] = struct{}{} - } - - // Erase data at new roots to prepare for writing the new data - for _, root := range *b.Manifest.Roots { - erase[root] = struct{}{} - } - } - } - - // Before changing anything make sure the roots don't collide with any - // other bundles that already are activated or other bundles being activated. - err := hasRootsOverlap(opts.Ctx, opts.Store, opts.Txn, opts.Bundles) - if err != nil { - return err - } - - if len(deltaBundles) != 0 { - err := activateDeltaBundles(opts, deltaBundles) - if err != nil { - return err - } - } - - // Erase data and policies at new + old roots, and remove the old - // manifests before activating a new snapshot bundle. - remaining, err := eraseBundles(opts.Ctx, opts.Store, opts.Txn, opts.ParserOptions, names, erase) - if err != nil { - return err - } - - // Validate data in bundle does not contain paths outside the bundle's roots. - for _, b := range snapshotBundles { - - if b.lazyLoadingMode { - - for _, item := range b.Raw { - path := filepath.ToSlash(item.Path) - - if filepath.Base(path) == dataFile || filepath.Base(path) == yamlDataFile { - var val map[string]json.RawMessage - err = util.Unmarshal(item.Value, &val) - if err == nil { - err = doDFS(val, filepath.Dir(strings.Trim(path, "/")), *b.Manifest.Roots) - if err != nil { - return err - } - } else { - // Build an object for the value - p := getNormalizedPath(path) - - if len(p) == 0 { - return fmt.Errorf("root value must be object") - } - - // verify valid YAML or JSON value - var x interface{} - err := util.Unmarshal(item.Value, &x) - if err != nil { - return err - } - - value := item.Value - dir := map[string]json.RawMessage{} - for i := len(p) - 1; i > 0; i-- { - dir[p[i]] = value - - bs, err := json.Marshal(dir) - if err != nil { - return err - } - - value = bs - dir = map[string]json.RawMessage{} - } - dir[p[0]] = value - - err = doDFS(dir, filepath.Dir(strings.Trim(path, "/")), *b.Manifest.Roots) - if err != nil { - return err - } - } - } - } - } - } - - // Compile the modules all at once to avoid having to re-do work. - remainingAndExtra := make(map[string]*ast.Module) - for name, mod := range remaining { - remainingAndExtra[name] = mod - } - for name, mod := range opts.ExtraModules { - remainingAndExtra[name] = mod - } - - err = compileModules(opts.Compiler, opts.Metrics, snapshotBundles, remainingAndExtra, opts.legacy, opts.AuthorizationDecisionRef) - if err != nil { - return err - } - - if err := writeDataAndModules(opts.Ctx, opts.Store, opts.Txn, opts.TxnCtx, snapshotBundles, opts.legacy); err != nil { - return err - } - - if err := ast.CheckPathConflicts(opts.Compiler, storage.NonEmpty(opts.Ctx, opts.Store, opts.Txn)); len(err) > 0 { - return err - } - - for name, b := range snapshotBundles { - if err := writeManifestToStore(opts, name, b.Manifest); err != nil { - return err - } - - if err := writeEtagToStore(opts, name, b.Etag); err != nil { - return err - } - - if err := writeWasmModulesToStore(opts.Ctx, opts.Store, opts.Txn, name, b); err != nil { - return err - } - } - - return nil -} - -func doDFS(obj map[string]json.RawMessage, path string, roots []string) error { - if len(roots) == 1 && roots[0] == "" { - return nil - } - - for key := range obj { - - newPath := filepath.Join(strings.Trim(path, "/"), key) - - // Note: filepath.Join can return paths with '\' separators, always use - // filepath.ToSlash to keep them normalized. - newPath = strings.TrimLeft(normalizePath(newPath), "/.") - - contains := false - prefix := false - if RootPathsContain(roots, newPath) { - contains = true - } else { - for i := range roots { - if strings.HasPrefix(strings.Trim(roots[i], "/"), newPath) { - prefix = true - break - } - } - } - - if !contains && !prefix { - return fmt.Errorf("manifest roots %v do not permit data at path '/%s' (hint: check bundle directory structure)", roots, newPath) - } - - if contains { - continue - } - - var next map[string]json.RawMessage - err := util.Unmarshal(obj[key], &next) - if err != nil { - return fmt.Errorf("manifest roots %v do not permit data at path '/%s' (hint: check bundle directory structure)", roots, newPath) - } - - if err := doDFS(next, newPath, roots); err != nil { - return err - } - } - return nil -} - -func activateDeltaBundles(opts *ActivateOpts, bundles map[string]*Bundle) error { - - // Check that the manifest roots and wasm resolvers in the delta bundle - // match with those currently in the store - for name, b := range bundles { - value, err := opts.Store.Read(opts.Ctx, opts.Txn, ManifestStoragePath(name)) - if err != nil { - if storage.IsNotFound(err) { - continue - } - return err - } - - manifest, err := valueToManifest(value) - if err != nil { - return fmt.Errorf("corrupt manifest data: %w", err) - } - - if !b.Manifest.equalWasmResolversAndRoots(manifest) { - return fmt.Errorf("delta bundle '%s' has wasm resolvers or manifest roots that are different from those in the store", name) - } - } - - for _, b := range bundles { - err := applyPatches(opts.Ctx, opts.Store, opts.Txn, b.Patch.Data) - if err != nil { - return err - } - } - - if err := ast.CheckPathConflicts(opts.Compiler, storage.NonEmpty(opts.Ctx, opts.Store, opts.Txn)); len(err) > 0 { - return err - } - - for name, b := range bundles { - if err := writeManifestToStore(opts, name, b.Manifest); err != nil { - return err - } - - if err := writeEtagToStore(opts, name, b.Etag); err != nil { - return err - } - } - - return nil -} - -func valueToManifest(v interface{}) (Manifest, error) { - if astV, ok := v.(ast.Value); ok { - var err error - v, err = ast.JSON(astV) - if err != nil { - return Manifest{}, err - } - } - - var manifest Manifest - - bs, err := json.Marshal(v) - if err != nil { - return Manifest{}, err - } - - err = util.UnmarshalJSON(bs, &manifest) - if err != nil { - return Manifest{}, err - } - - return manifest, nil -} - -// erase bundles by name and roots. This will clear all policies and data at its roots and remove its -// manifest from storage. -func eraseBundles(ctx context.Context, store storage.Store, txn storage.Transaction, parserOpts ast.ParserOptions, names map[string]struct{}, roots map[string]struct{}) (map[string]*ast.Module, error) { - - if err := eraseData(ctx, store, txn, roots); err != nil { - return nil, err - } - - remaining, err := erasePolicies(ctx, store, txn, parserOpts, roots) - if err != nil { - return nil, err - } - - for name := range names { - if err := EraseManifestFromStore(ctx, store, txn, name); suppressNotFound(err) != nil { - return nil, err - } - - if err := LegacyEraseManifestFromStore(ctx, store, txn); suppressNotFound(err) != nil { - return nil, err - } - - if err := eraseBundleEtagFromStore(ctx, store, txn, name); suppressNotFound(err) != nil { - return nil, err - } - - if err := eraseWasmModulesFromStore(ctx, store, txn, name); suppressNotFound(err) != nil { - return nil, err - } - } - - return remaining, nil -} - -func eraseData(ctx context.Context, store storage.Store, txn storage.Transaction, roots map[string]struct{}) error { - for root := range roots { - path, ok := storage.ParsePathEscaped("/" + root) - if !ok { - return fmt.Errorf("manifest root path invalid: %v", root) - } - - if len(path) > 0 { - if err := store.Write(ctx, txn, storage.RemoveOp, path, nil); suppressNotFound(err) != nil { - return err - } - } - } - return nil -} - -func erasePolicies(ctx context.Context, store storage.Store, txn storage.Transaction, parserOpts ast.ParserOptions, roots map[string]struct{}) (map[string]*ast.Module, error) { - - ids, err := store.ListPolicies(ctx, txn) - if err != nil { - return nil, err - } - - remaining := map[string]*ast.Module{} - - for _, id := range ids { - bs, err := store.GetPolicy(ctx, txn, id) - if err != nil { - return nil, err - } - module, err := ast.ParseModuleWithOpts(id, string(bs), parserOpts) - if err != nil { - return nil, err - } - path, err := module.Package.Path.Ptr() - if err != nil { - return nil, err - } - deleted := false - for root := range roots { - if RootPathsContain([]string{root}, path) { - if err := store.DeletePolicy(ctx, txn, id); err != nil { - return nil, err - } - deleted = true - break - } - } - if !deleted { - remaining[id] = module - } - } - - return remaining, nil -} - -func writeManifestToStore(opts *ActivateOpts, name string, manifest Manifest) error { - // Always write manifests to the named location. If the plugin is in the older style config - // then also write to the old legacy unnamed location. - if err := WriteManifestToStore(opts.Ctx, opts.Store, opts.Txn, name, manifest); err != nil { - return err - } - - if opts.legacy { - if err := LegacyWriteManifestToStore(opts.Ctx, opts.Store, opts.Txn, manifest); err != nil { - return err - } - } - - return nil -} - -func writeEtagToStore(opts *ActivateOpts, name, etag string) error { - if err := WriteEtagToStore(opts.Ctx, opts.Store, opts.Txn, name, etag); err != nil { - return err - } - - return nil -} - -func writeDataAndModules(ctx context.Context, store storage.Store, txn storage.Transaction, txnCtx *storage.Context, bundles map[string]*Bundle, legacy bool) error { - params := storage.WriteParams - params.Context = txnCtx - - for name, b := range bundles { - if len(b.Raw) == 0 { - // Write data from each new bundle into the store. Only write under the - // roots contained in their manifest. - if err := writeData(ctx, store, txn, *b.Manifest.Roots, b.Data); err != nil { - return err - } - - for _, mf := range b.Modules { - var path string - - // For backwards compatibility, in legacy mode, upsert policies to - // the unprefixed path. - if legacy { - path = mf.Path - } else { - path = modulePathWithPrefix(name, mf.Path) - } - - if err := store.UpsertPolicy(ctx, txn, path, mf.Raw); err != nil { - return err - } - } - } else { - params.BasePaths = *b.Manifest.Roots - - err := store.Truncate(ctx, txn, params, NewIterator(b.Raw)) - if err != nil { - return fmt.Errorf("store truncate failed for bundle '%s': %v", name, err) - } - } - } - - return nil -} - -func writeData(ctx context.Context, store storage.Store, txn storage.Transaction, roots []string, data map[string]interface{}) error { - for _, root := range roots { - path, ok := storage.ParsePathEscaped("/" + root) - if !ok { - return fmt.Errorf("manifest root path invalid: %v", root) - } - if value, ok := lookup(path, data); ok { - if len(path) > 0 { - if err := storage.MakeDir(ctx, store, txn, path[:len(path)-1]); err != nil { - return err - } - } - if err := store.Write(ctx, txn, storage.AddOp, path, value); err != nil { - return err - } - } - } - return nil -} - -func compileModules(compiler *ast.Compiler, m metrics.Metrics, bundles map[string]*Bundle, extraModules map[string]*ast.Module, legacy bool, authorizationDecisionRef ast.Ref) error { - - m.Timer(metrics.RegoModuleCompile).Start() - defer m.Timer(metrics.RegoModuleCompile).Stop() - - modules := map[string]*ast.Module{} - - // preserve any modules already on the compiler - for name, module := range compiler.Modules { - modules[name] = module - } - - // preserve any modules passed in from the store - for name, module := range extraModules { - modules[name] = module - } - - // include all the new bundle modules - for bundleName, b := range bundles { - if legacy { - for _, mf := range b.Modules { - modules[mf.Path] = mf.Parsed - } - } else { - for name, module := range b.ParsedModules(bundleName) { - modules[name] = module - } - } - } - - if compiler.Compile(modules); compiler.Failed() { - return compiler.Errors - } - - if authorizationDecisionRef.Equal(ast.EmptyRef()) { - return nil - } - - return iCompiler.VerifyAuthorizationPolicySchema(compiler, authorizationDecisionRef) + return v1.Deactivate(opts) } -func writeModules(ctx context.Context, store storage.Store, txn storage.Transaction, compiler *ast.Compiler, m metrics.Metrics, bundles map[string]*Bundle, extraModules map[string]*ast.Module, legacy bool) error { - - m.Timer(metrics.RegoModuleCompile).Start() - defer m.Timer(metrics.RegoModuleCompile).Stop() - - modules := map[string]*ast.Module{} - - // preserve any modules already on the compiler - for name, module := range compiler.Modules { - modules[name] = module - } - - // preserve any modules passed in from the store - for name, module := range extraModules { - modules[name] = module - } - - // include all the new bundle modules - for bundleName, b := range bundles { - if legacy { - for _, mf := range b.Modules { - modules[mf.Path] = mf.Parsed - } - } else { - for name, module := range b.ParsedModules(bundleName) { - modules[name] = module - } - } - } - - if compiler.Compile(modules); compiler.Failed() { - return compiler.Errors - } - for bundleName, b := range bundles { - for _, mf := range b.Modules { - var path string - - // For backwards compatibility, in legacy mode, upsert policies to - // the unprefixed path. - if legacy { - path = mf.Path - } else { - path = modulePathWithPrefix(bundleName, mf.Path) - } - - if err := store.UpsertPolicy(ctx, txn, path, mf.Raw); err != nil { - return err - } - } - } - return nil -} - -func lookup(path storage.Path, data map[string]interface{}) (interface{}, bool) { - if len(path) == 0 { - return data, true - } - for i := 0; i < len(path)-1; i++ { - value, ok := data[path[i]] - if !ok { - return nil, false - } - obj, ok := value.(map[string]interface{}) - if !ok { - return nil, false - } - data = obj - } - value, ok := data[path[len(path)-1]] - return value, ok -} - -func hasRootsOverlap(ctx context.Context, store storage.Store, txn storage.Transaction, bundles map[string]*Bundle) error { - collisions := map[string][]string{} - allBundles, err := ReadBundleNamesFromStore(ctx, store, txn) - if suppressNotFound(err) != nil { - return err - } - - allRoots := map[string][]string{} - - // Build a map of roots for existing bundles already in the system - for _, name := range allBundles { - roots, err := ReadBundleRootsFromStore(ctx, store, txn, name) - if suppressNotFound(err) != nil { - return err - } - allRoots[name] = roots - } - - // Add in any bundles that are being activated, overwrite existing roots - // with new ones where bundles are in both groups. - for name, bundle := range bundles { - allRoots[name] = *bundle.Manifest.Roots - } - - // Now check for each new bundle if it conflicts with any of the others - for name, bundle := range bundles { - for otherBundle, otherRoots := range allRoots { - if name == otherBundle { - // Skip the current bundle being checked - continue - } - - // Compare the "new" roots with other existing (or a different bundles new roots) - for _, newRoot := range *bundle.Manifest.Roots { - for _, otherRoot := range otherRoots { - if RootPathsOverlap(newRoot, otherRoot) { - collisions[otherBundle] = append(collisions[otherBundle], newRoot) - } - } - } - } - } - - if len(collisions) > 0 { - var bundleNames []string - for name := range collisions { - bundleNames = append(bundleNames, name) - } - return fmt.Errorf("detected overlapping roots in bundle manifest with: %s", bundleNames) - } - return nil -} - -func applyPatches(ctx context.Context, store storage.Store, txn storage.Transaction, patches []PatchOperation) error { - for _, pat := range patches { - - // construct patch path - path, ok := patch.ParsePatchPathEscaped("/" + strings.Trim(pat.Path, "/")) - if !ok { - return fmt.Errorf("error parsing patch path") - } - - var op storage.PatchOp - switch pat.Op { - case "upsert": - op = storage.AddOp - - _, err := store.Read(ctx, txn, path[:len(path)-1]) - if err != nil { - if !storage.IsNotFound(err) { - return err - } - - if err := storage.MakeDir(ctx, store, txn, path[:len(path)-1]); err != nil { - return err - } - } - case "remove": - op = storage.RemoveOp - case "replace": - op = storage.ReplaceOp - default: - return fmt.Errorf("bad patch operation: %v", pat.Op) - } - - // apply the patch - if err := store.Write(ctx, txn, op, path, pat.Value); err != nil { - return err - } - } - - return nil -} - -// Helpers for the older single (unnamed) bundle style manifest storage. - -// LegacyManifestStoragePath is the older unnamed bundle path for manifests to be stored. -// Deprecated: Use ManifestStoragePath and named bundles instead. -var legacyManifestStoragePath = storage.MustParsePath("/system/bundle/manifest") -var legacyRevisionStoragePath = append(legacyManifestStoragePath, "revision") - // LegacyWriteManifestToStore will write the bundle manifest to the older single (unnamed) bundle manifest location. // Deprecated: Use WriteManifestToStore and named bundles instead. func LegacyWriteManifestToStore(ctx context.Context, store storage.Store, txn storage.Transaction, manifest Manifest) error { - return write(ctx, store, txn, legacyManifestStoragePath, manifest) + return v1.LegacyWriteManifestToStore(ctx, store, txn, manifest) } // LegacyEraseManifestFromStore will erase the bundle manifest from the older single (unnamed) bundle manifest location. // Deprecated: Use WriteManifestToStore and named bundles instead. func LegacyEraseManifestFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) error { - err := store.Write(ctx, txn, storage.RemoveOp, legacyManifestStoragePath, nil) - if err != nil { - return err - } - return nil + return v1.LegacyEraseManifestFromStore(ctx, store, txn) } // LegacyReadRevisionFromStore will read the bundle manifest revision from the older single (unnamed) bundle manifest location. // Deprecated: Use ReadBundleRevisionFromStore and named bundles instead. func LegacyReadRevisionFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) (string, error) { - return readRevisionFromStore(ctx, store, txn, legacyRevisionStoragePath) + return v1.LegacyReadRevisionFromStore(ctx, store, txn) } // ActivateLegacy calls Activate for the bundles but will also write their manifest to the older unnamed store location. // Deprecated: Use Activate with named bundles instead. func ActivateLegacy(opts *ActivateOpts) error { - opts.legacy = true - return activateBundles(opts) + return v1.ActivateLegacy(opts) } diff --git a/vendor/github.com/open-policy-agent/opa/bundle/verify.go b/vendor/github.com/open-policy-agent/opa/bundle/verify.go index e85be835b..ef2e1e32d 100644 --- a/vendor/github.com/open-policy-agent/opa/bundle/verify.go +++ b/vendor/github.com/open-policy-agent/opa/bundle/verify.go @@ -6,26 +6,11 @@ package bundle import ( - "bytes" - "encoding/base64" - "encoding/hex" - "encoding/json" - "fmt" - - "github.com/open-policy-agent/opa/internal/jwx/jwa" - "github.com/open-policy-agent/opa/internal/jwx/jws" - "github.com/open-policy-agent/opa/internal/jwx/jws/verify" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/bundle" ) -const defaultVerifierID = "_default" - -var verifiers map[string]Verifier - // Verifier is the interface expected for implementations that verify bundle signatures. -type Verifier interface { - VerifyBundleSignature(SignaturesConfig, *VerificationConfig) (map[string]FileInfo, error) -} +type Verifier v1.Verifier // VerifyBundleSignature will retrieve the Verifier implementation based // on the Plugin specified in SignaturesConfig, and call its implementation @@ -33,199 +18,19 @@ type Verifier interface { // using the given public keys or secret. If a signature is verified, it keeps // track of the files specified in the JWT payload func VerifyBundleSignature(sc SignaturesConfig, bvc *VerificationConfig) (map[string]FileInfo, error) { - // default implementation does not return a nil for map, so don't - // do it here either - files := make(map[string]FileInfo) - var plugin string - // for backwards compatibility, check if there is no plugin specified, and use default - if sc.Plugin == "" { - plugin = defaultVerifierID - } else { - plugin = sc.Plugin - } - verifier, err := GetVerifier(plugin) - if err != nil { - return files, err - } - return verifier.VerifyBundleSignature(sc, bvc) + return v1.VerifyBundleSignature(sc, bvc) } // DefaultVerifier is the default bundle verification implementation. It verifies bundles by checking // the JWT signature using a locally-accessible public key. -type DefaultVerifier struct{} - -// VerifyBundleSignature verifies the bundle signature using the given public keys or secret. -// If a signature is verified, it keeps track of the files specified in the JWT payload -func (*DefaultVerifier) VerifyBundleSignature(sc SignaturesConfig, bvc *VerificationConfig) (map[string]FileInfo, error) { - files := make(map[string]FileInfo) - - if len(sc.Signatures) == 0 { - return files, fmt.Errorf(".signatures.json: missing JWT (expected exactly one)") - } - - if len(sc.Signatures) > 1 { - return files, fmt.Errorf(".signatures.json: multiple JWTs not supported (expected exactly one)") - } - - for _, token := range sc.Signatures { - payload, err := verifyJWTSignature(token, bvc) - if err != nil { - return files, err - } - - for _, file := range payload.Files { - files[file.Name] = file - } - } - return files, nil -} - -func verifyJWTSignature(token string, bvc *VerificationConfig) (*DecodedSignature, error) { - // decode JWT to check if the header specifies the key to use and/or if claims have the scope. - - parts, err := jws.SplitCompact(token) - if err != nil { - return nil, err - } - - var decodedHeader []byte - if decodedHeader, err = base64.RawURLEncoding.DecodeString(parts[0]); err != nil { - return nil, fmt.Errorf("failed to base64 decode JWT headers: %w", err) - } - - var hdr jws.StandardHeaders - if err := json.Unmarshal(decodedHeader, &hdr); err != nil { - return nil, fmt.Errorf("failed to parse JWT headers: %w", err) - } - - payload, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, err - } - - var ds DecodedSignature - if err := json.Unmarshal(payload, &ds); err != nil { - return nil, err - } - - // check for the id of the key to use for JWT signature verification - // first in the OPA config. If not found, then check the JWT kid. - keyID := bvc.KeyID - if keyID == "" { - keyID = hdr.KeyID - } - if keyID == "" { - // If header has no key id, check the deprecated key claim. - keyID = ds.KeyID - } - - if keyID == "" { - return nil, fmt.Errorf("verification key ID is empty") - } - - // now that we have the keyID, fetch the actual key - keyConfig, err := bvc.GetPublicKey(keyID) - if err != nil { - return nil, err - } - - // verify JWT signature - alg := jwa.SignatureAlgorithm(keyConfig.Algorithm) - key, err := verify.GetSigningKey(keyConfig.Key, alg) - if err != nil { - return nil, err - } - - _, err = jws.Verify([]byte(token), alg, key) - if err != nil { - return nil, err - } - - // verify the scope - scope := bvc.Scope - if scope == "" { - scope = keyConfig.Scope - } - - if ds.Scope != scope { - return nil, fmt.Errorf("scope mismatch") - } - return &ds, nil -} - -// VerifyBundleFile verifies the hash of a file in the bundle matches to that provided in the bundle's signature -func VerifyBundleFile(path string, data bytes.Buffer, files map[string]FileInfo) error { - var file FileInfo - var ok bool - - if file, ok = files[path]; !ok { - return fmt.Errorf("file %v not included in bundle signature", path) - } - - if file.Algorithm == "" { - return fmt.Errorf("no hashing algorithm provided for file %v", path) - } - - hash, err := NewSignatureHasher(HashingAlgorithm(file.Algorithm)) - if err != nil { - return err - } - - // hash the file content - // For unstructured files, hash the byte stream of the file - // For structured files, read the byte stream and parse into a JSON structure; - // then recursively order the fields of all objects alphabetically and then apply - // the hash function to result to compute the hash. This ensures that the digital signature is - // independent of whitespace and other non-semantic JSON features. - var value interface{} - if IsStructuredDoc(path) { - err := util.Unmarshal(data.Bytes(), &value) - if err != nil { - return err - } - } else { - value = data.Bytes() - } - - bs, err := hash.HashFile(value) - if err != nil { - return err - } - - // compare file hash with same file in the JWT payloads - fb, err := hex.DecodeString(file.Hash) - if err != nil { - return err - } - - if !bytes.Equal(fb, bs) { - return fmt.Errorf("%v: digest mismatch (want: %x, got: %x)", path, fb, bs) - } - - delete(files, path) - return nil -} +type DefaultVerifier = v1.DefaultVerifier // GetVerifier returns the Verifier registered under the given id func GetVerifier(id string) (Verifier, error) { - verifier, ok := verifiers[id] - if !ok { - return nil, fmt.Errorf("no verifier exists under id %s", id) - } - return verifier, nil + return v1.GetVerifier(id) } // RegisterVerifier registers a Verifier under the given id func RegisterVerifier(id string, v Verifier) error { - if id == defaultVerifierID { - return fmt.Errorf("verifier id %s is reserved, use a different id", id) - } - verifiers[id] = v - return nil -} - -func init() { - verifiers = map[string]Verifier{ - defaultVerifierID: &DefaultVerifier{}, - } + return v1.RegisterVerifier(id, v) } diff --git a/vendor/github.com/open-policy-agent/opa/capabilities/doc.go b/vendor/github.com/open-policy-agent/opa/capabilities/doc.go new file mode 100644 index 000000000..189c2e727 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/capabilities/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package capabilities diff --git a/vendor/github.com/open-policy-agent/opa/capabilities/v1.0.0.json b/vendor/github.com/open-policy-agent/opa/capabilities/v1.0.0.json new file mode 100644 index 000000000..48a87b0c3 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/capabilities/v1.0.0.json @@ -0,0 +1,4835 @@ +{ + "builtins": [ + { + "name": "abs", + "decl": { + "args": [ + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "all", + "decl": { + "args": [ + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "and", + "decl": { + "args": [ + { + "of": { + "type": "any" + }, + "type": "set" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "result": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "function" + }, + "infix": "\u0026" + }, + { + "name": "any", + "decl": { + "args": [ + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "array.concat", + "decl": { + "args": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "result": { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "array.reverse", + "decl": { + "args": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "result": { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "array.slice", + "decl": { + "args": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "assign", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": ":=" + }, + { + "name": "base64.decode", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "base64.encode", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "base64.is_valid", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "base64url.decode", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "base64url.encode", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "base64url.encode_no_pad", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "bits.and", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "bits.lsh", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "bits.negate", + "decl": { + "args": [ + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "bits.or", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "bits.rsh", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "bits.xor", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "cast_array", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "cast_boolean", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "cast_null", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "null" + }, + "type": "function" + } + }, + { + "name": "cast_object", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "cast_set", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "cast_string", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "ceil", + "decl": { + "args": [ + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "concat", + "decl": { + "args": [ + { + "type": "string" + }, + { + "of": [ + { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + { + "of": { + "type": "string" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "contains", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "count", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "crypto.hmac.equal", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "crypto.hmac.md5", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "crypto.hmac.sha1", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "crypto.hmac.sha256", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "crypto.hmac.sha512", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "crypto.md5", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "crypto.parse_private_keys", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "dynamic": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "crypto.sha1", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "crypto.sha256", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "crypto.x509.parse_and_verify_certificates", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "static": [ + { + "type": "boolean" + }, + { + "dynamic": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "array" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "crypto.x509.parse_and_verify_certificates_with_options", + "decl": { + "args": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "result": { + "static": [ + { + "type": "boolean" + }, + { + "dynamic": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "array" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "crypto.x509.parse_certificate_request", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "crypto.x509.parse_certificates", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "dynamic": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "crypto.x509.parse_keypair", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "crypto.x509.parse_rsa_private_key", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "div", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + }, + "infix": "/" + }, + { + "name": "endswith", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "eq", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": "=" + }, + { + "name": "equal", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": "==" + }, + { + "name": "floor", + "decl": { + "args": [ + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "format_int", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "glob.match", + "decl": { + "args": [ + { + "type": "string" + }, + { + "of": [ + { + "type": "null" + }, + { + "dynamic": { + "type": "string" + }, + "type": "array" + } + ], + "type": "any" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "glob.quote_meta", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "graph.reachable", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + }, + "type": "object" + }, + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "graph.reachable_paths", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + }, + "type": "object" + }, + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "of": { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "graphql.is_valid", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + }, + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "graphql.parse", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + }, + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + } + ], + "result": { + "static": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "graphql.parse_and_verify", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + }, + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + } + ], + "result": { + "static": [ + { + "type": "boolean" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "graphql.parse_query", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "graphql.parse_schema", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "graphql.schema_is_valid", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "gt", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": "\u003e" + }, + { + "name": "gte", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": "\u003e=" + }, + { + "name": "hex.decode", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "hex.encode", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "http.send", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "result": { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + }, + "nondeterministic": true + }, + { + "name": "indexof", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "indexof_n", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "dynamic": { + "type": "number" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "internal.member_2", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": "in" + }, + { + "name": "internal.member_3", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": "in" + }, + { + "name": "internal.print", + "decl": { + "args": [ + { + "dynamic": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "array" + } + ], + "type": "function" + } + }, + { + "name": "intersection", + "decl": { + "args": [ + { + "of": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "set" + } + ], + "result": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "io.jwt.decode", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "static": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "type": "string" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "io.jwt.decode_verify", + "decl": { + "args": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "result": { + "static": [ + { + "type": "boolean" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "array" + }, + "type": "function" + }, + "nondeterministic": true + }, + { + "name": "io.jwt.encode_sign", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "result": { + "type": "string" + }, + "type": "function" + }, + "nondeterministic": true + }, + { + "name": "io.jwt.encode_sign_raw", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + }, + "nondeterministic": true + }, + { + "name": "io.jwt.verify_es256", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_es384", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_es512", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_hs256", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_hs384", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_hs512", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_ps256", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_ps384", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_ps512", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_rs256", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_rs384", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "io.jwt.verify_rs512", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "is_array", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "is_boolean", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "is_null", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "is_number", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "is_object", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "is_set", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "is_string", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "json.filter", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "of": [ + { + "dynamic": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + }, + "type": "array" + }, + { + "of": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "json.is_valid", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "json.marshal", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "json.marshal_with_options", + "decl": { + "args": [ + { + "type": "any" + }, + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "static": [ + { + "key": "indent", + "value": { + "type": "string" + } + }, + { + "key": "prefix", + "value": { + "type": "string" + } + }, + { + "key": "pretty", + "value": { + "type": "boolean" + } + } + ], + "type": "object" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "json.match_schema", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + }, + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + } + ], + "result": { + "static": [ + { + "type": "boolean" + }, + { + "dynamic": { + "static": [ + { + "key": "desc", + "value": { + "type": "string" + } + }, + { + "key": "error", + "value": { + "type": "string" + } + }, + { + "key": "field", + "value": { + "type": "string" + } + }, + { + "key": "type", + "value": { + "type": "string" + } + } + ], + "type": "object" + }, + "type": "array" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "json.patch", + "decl": { + "args": [ + { + "type": "any" + }, + { + "dynamic": { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "static": [ + { + "key": "op", + "value": { + "type": "string" + } + }, + { + "key": "path", + "value": { + "type": "any" + } + } + ], + "type": "object" + }, + "type": "array" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "json.remove", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "of": [ + { + "dynamic": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + }, + "type": "array" + }, + { + "of": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "json.unmarshal", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "json.verify_schema", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "type": "any" + } + ], + "result": { + "static": [ + { + "type": "boolean" + }, + { + "of": [ + { + "type": "null" + }, + { + "type": "string" + } + ], + "type": "any" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "lower", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "lt", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": "\u003c" + }, + { + "name": "lte", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": "\u003c=" + }, + { + "name": "max", + "decl": { + "args": [ + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "min", + "decl": { + "args": [ + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "minus", + "decl": { + "args": [ + { + "of": [ + { + "type": "number" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + }, + { + "of": [ + { + "type": "number" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "of": [ + { + "type": "number" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + }, + "type": "function" + }, + "infix": "-" + }, + { + "name": "mul", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + }, + "infix": "*" + }, + { + "name": "neq", + "decl": { + "args": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + }, + "infix": "!=" + }, + { + "name": "net.cidr_contains", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "net.cidr_contains_matches", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + }, + "type": "array" + }, + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + } + }, + "type": "object" + }, + { + "of": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + }, + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + }, + "type": "array" + }, + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + } + }, + "type": "object" + }, + { + "of": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "of": { + "static": [ + { + "type": "any" + }, + { + "type": "any" + } + ], + "type": "array" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "net.cidr_expand", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "of": { + "type": "string" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "net.cidr_intersects", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "net.cidr_is_valid", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "net.cidr_merge", + "decl": { + "args": [ + { + "of": [ + { + "dynamic": { + "of": [ + { + "type": "string" + } + ], + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "string" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "of": { + "type": "string" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "net.cidr_overlap", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "net.lookup_ip_addr", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "of": { + "type": "string" + }, + "type": "set" + }, + "type": "function" + }, + "nondeterministic": true + }, + { + "name": "numbers.range", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "dynamic": { + "type": "number" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "numbers.range_step", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "dynamic": { + "type": "number" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "object.filter", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "object.get", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "type": "any" + }, + { + "type": "any" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "object.keys", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "result": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "object.remove", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "object.subset", + "decl": { + "args": [ + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + }, + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "object.union", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "object.union_n", + "decl": { + "args": [ + { + "dynamic": { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "array" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "opa.runtime", + "decl": { + "result": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + }, + "nondeterministic": true + }, + { + "name": "or", + "decl": { + "args": [ + { + "of": { + "type": "any" + }, + "type": "set" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "result": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "function" + }, + "infix": "|" + }, + { + "name": "plus", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + }, + "infix": "+" + }, + { + "name": "print", + "decl": { + "type": "function", + "variadic": { + "type": "any" + } + } + }, + { + "name": "product", + "decl": { + "args": [ + { + "of": [ + { + "dynamic": { + "type": "number" + }, + "type": "array" + }, + { + "of": { + "type": "number" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "providers.aws.sign_req", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + { + "type": "number" + } + ], + "result": { + "dynamic": { + "key": { + "type": "any" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "rand.intn", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + }, + "nondeterministic": true + }, + { + "name": "re_match", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "regex.find_all_string_submatch_n", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + }, + { + "type": "number" + } + ], + "result": { + "dynamic": { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "regex.find_n", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + }, + { + "type": "number" + } + ], + "result": { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "regex.globs_match", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "regex.is_valid", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "regex.match", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "regex.replace", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "regex.split", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "regex.template_match", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + }, + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "rego.metadata.chain", + "decl": { + "result": { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "rego.metadata.rule", + "decl": { + "result": { + "type": "any" + }, + "type": "function" + } + }, + { + "name": "rego.parse_module", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "rem", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + }, + "infix": "%" + }, + { + "name": "replace", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "round", + "decl": { + "args": [ + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "semver.compare", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "semver.is_valid", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "set_diff", + "decl": { + "args": [ + { + "of": { + "type": "any" + }, + "type": "set" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "result": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "sort", + "decl": { + "args": [ + { + "of": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "of": { + "type": "any" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "split", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + "type": "function" + } + }, + { + "name": "sprintf", + "decl": { + "args": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "any" + }, + "type": "array" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "startswith", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "strings.any_prefix_match", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + { + "of": { + "type": "string" + }, + "type": "set" + } + ], + "type": "any" + }, + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + { + "of": { + "type": "string" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "strings.any_suffix_match", + "decl": { + "args": [ + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + { + "of": { + "type": "string" + }, + "type": "set" + } + ], + "type": "any" + }, + { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + { + "of": { + "type": "string" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "strings.count", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "strings.render_template", + "decl": { + "args": [ + { + "type": "string" + }, + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "strings.replace_n", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "string" + } + }, + "type": "object" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "strings.reverse", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "substring", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "sum", + "decl": { + "args": [ + { + "of": [ + { + "dynamic": { + "type": "number" + }, + "type": "array" + }, + { + "of": { + "type": "number" + }, + "type": "set" + } + ], + "type": "any" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "time.add_date", + "decl": { + "args": [ + { + "type": "number" + }, + { + "type": "number" + }, + { + "type": "number" + }, + { + "type": "number" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "time.clock", + "decl": { + "args": [ + { + "of": [ + { + "type": "number" + }, + { + "static": [ + { + "type": "number" + }, + { + "type": "string" + } + ], + "type": "array" + } + ], + "type": "any" + } + ], + "result": { + "static": [ + { + "type": "number" + }, + { + "type": "number" + }, + { + "type": "number" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "time.date", + "decl": { + "args": [ + { + "of": [ + { + "type": "number" + }, + { + "static": [ + { + "type": "number" + }, + { + "type": "string" + } + ], + "type": "array" + } + ], + "type": "any" + } + ], + "result": { + "static": [ + { + "type": "number" + }, + { + "type": "number" + }, + { + "type": "number" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "time.diff", + "decl": { + "args": [ + { + "of": [ + { + "type": "number" + }, + { + "static": [ + { + "type": "number" + }, + { + "type": "string" + } + ], + "type": "array" + } + ], + "type": "any" + }, + { + "of": [ + { + "type": "number" + }, + { + "static": [ + { + "type": "number" + }, + { + "type": "string" + } + ], + "type": "array" + } + ], + "type": "any" + } + ], + "result": { + "static": [ + { + "type": "number" + }, + { + "type": "number" + }, + { + "type": "number" + }, + { + "type": "number" + }, + { + "type": "number" + }, + { + "type": "number" + } + ], + "type": "array" + }, + "type": "function" + } + }, + { + "name": "time.format", + "decl": { + "args": [ + { + "of": [ + { + "type": "number" + }, + { + "static": [ + { + "type": "number" + }, + { + "type": "string" + } + ], + "type": "array" + }, + { + "static": [ + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "string" + } + ], + "type": "array" + } + ], + "type": "any" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "time.now_ns", + "decl": { + "result": { + "type": "number" + }, + "type": "function" + }, + "nondeterministic": true + }, + { + "name": "time.parse_duration_ns", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "time.parse_ns", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "time.parse_rfc3339_ns", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "time.weekday", + "decl": { + "args": [ + { + "of": [ + { + "type": "number" + }, + { + "static": [ + { + "type": "number" + }, + { + "type": "string" + } + ], + "type": "array" + } + ], + "type": "any" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "to_number", + "decl": { + "args": [ + { + "of": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + } + ], + "type": "any" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "trace", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "trim", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "trim_left", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "trim_prefix", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "trim_right", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "trim_space", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "trim_suffix", + "decl": { + "args": [ + { + "type": "string" + }, + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "type_name", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "union", + "decl": { + "args": [ + { + "of": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "set" + } + ], + "result": { + "of": { + "type": "any" + }, + "type": "set" + }, + "type": "function" + } + }, + { + "name": "units.parse", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "units.parse_bytes", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "number" + }, + "type": "function" + } + }, + { + "name": "upper", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "urlquery.decode", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "urlquery.decode_object", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "dynamic": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "urlquery.encode", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "urlquery.encode_object", + "decl": { + "args": [ + { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "of": [ + { + "type": "string" + }, + { + "dynamic": { + "type": "string" + }, + "type": "array" + }, + { + "of": { + "type": "string" + }, + "type": "set" + } + ], + "type": "any" + } + }, + "type": "object" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "uuid.parse", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "dynamic": { + "key": { + "type": "string" + }, + "value": { + "type": "any" + } + }, + "type": "object" + }, + "type": "function" + } + }, + { + "name": "uuid.rfc4122", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "string" + }, + "type": "function" + }, + "nondeterministic": true + }, + { + "name": "walk", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "static": [ + { + "dynamic": { + "type": "any" + }, + "type": "array" + }, + { + "type": "any" + } + ], + "type": "array" + }, + "type": "function" + }, + "relation": true + }, + { + "name": "yaml.is_valid", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "boolean" + }, + "type": "function" + } + }, + { + "name": "yaml.marshal", + "decl": { + "args": [ + { + "type": "any" + } + ], + "result": { + "type": "string" + }, + "type": "function" + } + }, + { + "name": "yaml.unmarshal", + "decl": { + "args": [ + { + "type": "string" + } + ], + "result": { + "type": "any" + }, + "type": "function" + } + } + ], + "wasm_abi_versions": [ + { + "version": 1, + "minor_version": 1 + }, + { + "version": 1, + "minor_version": 2 + } + ], + "features": [ + "rego_v1" + ] +} diff --git a/vendor/github.com/open-policy-agent/opa/cmd/bench.go b/vendor/github.com/open-policy-agent/opa/cmd/bench.go index 74bc4073c..2ec1e9a12 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/bench.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/bench.go @@ -19,22 +19,22 @@ import ( "testing" "time" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" - "github.com/open-policy-agent/opa/server/types" + "github.com/open-policy-agent/opa/v1/server/types" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/runtime" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/runtime" "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "github.com/open-policy-agent/opa/cmd/internal/env" - "github.com/open-policy-agent/opa/compile" "github.com/open-policy-agent/opa/internal/presentation" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/compile" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/util" ) // benchmarkCommandParams are a superset of evalCommandParams @@ -62,8 +62,9 @@ func newBenchmarkEvalParams() benchmarkCommandParams { evalPrettyOutput, benchmarkGoBenchOutput, }), - target: util.NewEnumFlag(compile.TargetRego, []string{compile.TargetRego, compile.TargetWasm}), - schema: &schemaFlags{}, + target: util.NewEnumFlag(compile.TargetRego, []string{compile.TargetRego, compile.TargetWasm}), + schema: &schemaFlags{}, + capabilities: newcapabilitiesFlag(), }, gracefulShutdownPeriod: 10, } diff --git a/vendor/github.com/open-policy-agent/opa/cmd/build.go b/vendor/github.com/open-policy-agent/opa/cmd/build.go index 7e33eefdf..f6655c37a 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/build.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/build.go @@ -14,12 +14,12 @@ import ( "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/cmd/internal/env" - "github.com/open-policy-agent/opa/compile" - "github.com/open-policy-agent/opa/keys" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/compile" + "github.com/open-policy-agent/opa/v1/keys" + "github.com/open-policy-agent/opa/v1/util" ) const defaultPublicKeyID = "default" @@ -47,6 +47,7 @@ type buildParams struct { v0Compatible bool v1Compatible bool followSymlinks bool + wasmIncludePrint bool } func newBuildParams() buildParams { @@ -56,6 +57,17 @@ func newBuildParams() buildParams { } } +func (p *buildParams) regoVersion() ast.RegoVersion { + if p.v0Compatible { + // v0 takes precedence over v1 + return ast.RegoV0 + } + if p.v1Compatible { + return ast.RegoV1 + } + return ast.DefaultRegoVersion +} + func init() { buildParams := newBuildParams() @@ -241,6 +253,7 @@ against OPA v0.22.0: buildCommand.Flags().StringVarP(&buildParams.outputFile, "output", "o", "bundle.tar.gz", "set the output filename") buildCommand.Flags().StringVar(&buildParams.ns, "partial-namespace", "partial", "set the namespace to use for partially evaluated files in an optimized bundle") buildCommand.Flags().BoolVar(&buildParams.followSymlinks, "follow-symlinks", false, "follow symlinks in the input set of paths when building the bundle") + buildCommand.Flags().BoolVar(&buildParams.wasmIncludePrint, "wasm-include-print", false, "enable print statements inside of WebAssembly modules compiled by the compiler") addBundleModeFlag(buildCommand.Flags(), &buildParams.bundleMode, false) addIgnoreFlag(buildCommand.Flags(), &buildParams.ignore) @@ -290,7 +303,7 @@ func dobuild(params buildParams, args []string) error { if params.capabilities.C != nil { capabilities = params.capabilities.C } else { - capabilities = ast.CapabilitiesForThisVersion() + capabilities = ast.CapabilitiesForThisVersion(ast.CapabilitiesRegoVersion(params.regoVersion())) } compiler := compile.New(). @@ -309,14 +322,7 @@ func dobuild(params buildParams, args []string) error { WithPartialNamespace(params.ns). WithFollowSymlinks(params.followSymlinks) - regoVersion := ast.DefaultRegoVersion - if params.v0Compatible { - // v0 takes precedence over v1 - regoVersion = ast.RegoV0 - } else if params.v1Compatible { - regoVersion = ast.RegoV1 - } - compiler = compiler.WithRegoVersion(regoVersion) + compiler = compiler.WithRegoVersion(params.regoVersion()) if params.revision.isSet { compiler = compiler.WithRevision(*params.revision.v) @@ -334,6 +340,10 @@ func dobuild(params buildParams, args []string) error { compiler = compiler.WithEnablePrintStatements(true) } + if params.target.String() == compile.TargetWasm { + compiler = compiler.WithEnablePrintStatements(params.wasmIncludePrint) + } + err = compiler.Build(context.Background()) if err != nil { return err diff --git a/vendor/github.com/open-policy-agent/opa/cmd/capabilities.go b/vendor/github.com/open-policy-agent/opa/cmd/capabilities.go index 4fdc3685a..e031671a1 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/capabilities.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/capabilities.go @@ -9,15 +9,23 @@ import ( "fmt" "strings" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/cmd/internal/env" + "github.com/open-policy-agent/opa/v1/ast" "github.com/spf13/cobra" ) type capabilitiesParams struct { - showCurrent bool - version string - file string + showCurrent bool + version string + file string + v0Compatible bool +} + +func (p *capabilitiesParams) regoVersion() ast.RegoVersion { + if p.v0Compatible { + return ast.RegoV0 + } + return ast.DefaultRegoVersion } func init() { @@ -84,7 +92,8 @@ Print the capabilities of a capabilities file } capabilitiesCommand.Flags().BoolVar(&capabilitiesParams.showCurrent, "current", false, "print current capabilities") capabilitiesCommand.Flags().StringVar(&capabilitiesParams.version, "version", "", "print capabilities of a specific version") - capabilitiesCommand.Flags().StringVar(&capabilitiesParams.file, "file", "", "print current capabilities") + capabilitiesCommand.Flags().StringVar(&capabilitiesParams.file, "file", "", "print capabilities defined by a file") + addV0CompatibleFlag(capabilitiesCommand.Flags(), &capabilitiesParams.v0Compatible, false) RootCommand.AddCommand(capabilitiesCommand) } @@ -100,7 +109,7 @@ func doCapabilities(params capabilitiesParams) (string, error) { } else if len(params.file) > 0 { c, err = ast.LoadCapabilitiesFile(params.file) } else if params.showCurrent { - c = ast.CapabilitiesForThisVersion() + c = ast.CapabilitiesForThisVersion(ast.CapabilitiesRegoVersion(params.regoVersion())) } else { return showVersions() } diff --git a/vendor/github.com/open-policy-agent/opa/cmd/check.go b/vendor/github.com/open-policy-agent/opa/cmd/check.go index db874812b..6805411a7 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/check.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/check.go @@ -12,11 +12,11 @@ import ( "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/cmd/internal/env" pr "github.com/open-policy-agent/opa/internal/presentation" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/util" ) type checkParams struct { @@ -73,7 +73,7 @@ func checkModules(params checkParams, args []string) error { if params.capabilities.C != nil { capabilities = params.capabilities.C } else { - capabilities = ast.CapabilitiesForThisVersion() + capabilities = ast.CapabilitiesForThisVersion(ast.CapabilitiesRegoVersion(params.regoVersion())) } ss, err := loader.Schemas(params.schema.path) @@ -194,8 +194,8 @@ func init() { addCapabilitiesFlag(checkCommand.Flags(), checkParams.capabilities) addSchemaFlags(checkCommand.Flags(), checkParams.schema) addStrictFlag(checkCommand.Flags(), &checkParams.strict, false) - addRegoV1FlagWithDescription(checkCommand.Flags(), &checkParams.regoV1, false, - "check for Rego v1 compatibility (policies must also be compatible with current OPA version)") + addRegoV0V1FlagWithDescription(checkCommand.Flags(), &checkParams.regoV1, false, + "check for Rego v0 and v1 compatibility (policies must be compatible with both Rego versions)") addV0CompatibleFlag(checkCommand.Flags(), &checkParams.v0Compatible, false) addV1CompatibleFlag(checkCommand.Flags(), &checkParams.v1Compatible, false) RootCommand.AddCommand(checkCommand) diff --git a/vendor/github.com/open-policy-agent/opa/cmd/deps.go b/vendor/github.com/open-policy-agent/opa/cmd/deps.go index 72c29b278..dba62cdee 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/deps.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/deps.go @@ -10,15 +10,15 @@ import ( "io" "os" - "github.com/open-policy-agent/opa/dependencies" "github.com/open-policy-agent/opa/internal/presentation" + "github.com/open-policy-agent/opa/v1/dependencies" "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/cmd/internal/env" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/util" ) type depsCommandParams struct { @@ -74,8 +74,6 @@ Given a policy like this: package policy - import rego.v1 - allow if is_admin is_admin if "admin" in input.user.roles diff --git a/vendor/github.com/open-policy-agent/opa/cmd/doc.go b/vendor/github.com/open-policy-agent/opa/cmd/doc.go index 387bdff80..34ad90ade 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/doc.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/doc.go @@ -2,5 +2,4 @@ // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. -// Package cmd contains the entry points for OPA commands. package cmd diff --git a/vendor/github.com/open-policy-agent/opa/cmd/eval.go b/vendor/github.com/open-policy-agent/opa/cmd/eval.go index 89bfead2a..be9261b2c 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/eval.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/eval.go @@ -17,22 +17,26 @@ import ( "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/ast/location" - "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/cmd/internal/env" - "github.com/open-policy-agent/opa/compile" - "github.com/open-policy-agent/opa/cover" fileurl "github.com/open-policy-agent/opa/internal/file/url" pr "github.com/open-policy-agent/opa/internal/presentation" "github.com/open-policy-agent/opa/internal/runtime" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/profiler" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/topdown" - "github.com/open-policy-agent/opa/topdown/lineage" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/ast/location" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/compile" + "github.com/open-policy-agent/opa/v1/cover" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/profiler" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/topdown/lineage" + "github.com/open-policy-agent/opa/v1/util" +) + +var ( + errIllegalUnknownsArg = errors.New("illegal argument with --unknowns, specify string with one or more --unknowns") ) type evalCommandParams struct { @@ -109,6 +113,7 @@ func newEvalCommandParams() evalCommandParams { } func validateEvalParams(p *evalCommandParams, cmdArgs []string) error { + if len(cmdArgs) > 0 && p.stdin { return errors.New("specify query argument or --stdin but not both") } else if len(cmdArgs) == 0 && !p.stdin { @@ -132,6 +137,21 @@ func validateEvalParams(p *evalCommandParams, cmdArgs []string) error { return errors.New("invalid output format for evaluation") } + // check if illegal arguments is passed with unknowns flag + for _, unknwn := range p.unknowns { + term, err := ast.ParseTerm(unknwn) + if err != nil { + return err + } + + switch term.Value.(type) { + case ast.Ref: + return nil + default: + return errIllegalUnknownsArg + } + } + if p.optimizationLevel > 0 { if len(p.dataPaths.v) > 0 && p.bundlePaths.isFlagSet() { return fmt.Errorf("specify either --data or --bundle flag with optimization level greater than 0") @@ -392,7 +412,7 @@ func eval(args []string, params evalCommandParams, w io.Writer) (bool, error) { for _, t := range timers { val, ok := t[name].(int64) if !ok { - return false, fmt.Errorf("missing timer for %s" + name) + return false, fmt.Errorf("missing timer for %s", name) } vals = append(vals, val) } @@ -479,7 +499,6 @@ func evalOnce(ctx context.Context, ectx *evalContext) pr.Output { result.Errors = pr.NewOutputErrors(resultErr) if ectx.builtInErrorList != nil { for _, err := range *(ectx.builtInErrorList) { - err := err result.Errors = append(result.Errors, pr.NewOutputErrors(&err)...) } } @@ -686,8 +705,10 @@ func setupEval(args []string, params evalCommandParams) (*evalContext, error) { regoArgs = append(regoArgs, rego.BuiltinErrorList(&builtInErrors)) } - if params.capabilities != nil { + if params.capabilities.C != nil { regoArgs = append(regoArgs, rego.Capabilities(params.capabilities.C)) + } else { + regoArgs = append(regoArgs, rego.Capabilities(ast.CapabilitiesForThisVersion(ast.CapabilitiesRegoVersion(params.regoVersion())))) } if params.strict { @@ -859,7 +880,7 @@ func generateOptimizedBundle(params evalCommandParams, asBundle bool, filter loa if params.capabilities.C != nil { capabilities = params.capabilities.C } else { - capabilities = ast.CapabilitiesForThisVersion() + capabilities = ast.CapabilitiesForThisVersion(ast.CapabilitiesRegoVersion(params.regoVersion())) } compiler := compile.New(). diff --git a/vendor/github.com/open-policy-agent/opa/cmd/exec.go b/vendor/github.com/open-policy-agent/opa/cmd/exec.go index a26065d61..ecabc2b84 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/exec.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/exec.go @@ -15,14 +15,14 @@ import ( "github.com/open-policy-agent/opa/cmd/internal/exec" "github.com/open-policy-agent/opa/internal/config" internal_logging "github.com/open-policy-agent/opa/internal/logging" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/plugins/bundle" - "github.com/open-policy-agent/opa/plugins/discovery" - "github.com/open-policy-agent/opa/plugins/logs" - "github.com/open-policy-agent/opa/plugins/status" - "github.com/open-policy-agent/opa/sdk" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/plugins/bundle" + "github.com/open-policy-agent/opa/v1/plugins/discovery" + "github.com/open-policy-agent/opa/v1/plugins/logs" + "github.com/open-policy-agent/opa/v1/plugins/status" + "github.com/open-policy-agent/opa/v1/sdk" + "github.com/open-policy-agent/opa/v1/util" ) func init() { diff --git a/vendor/github.com/open-policy-agent/opa/cmd/features.go b/vendor/github.com/open-policy-agent/opa/cmd/features.go index ed3de8e64..379abd392 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/features.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/features.go @@ -7,4 +7,4 @@ package cmd -import _ "github.com/open-policy-agent/opa/features/wasm" +import _ "github.com/open-policy-agent/opa/v1/features/wasm" diff --git a/vendor/github.com/open-policy-agent/opa/cmd/filters.go b/vendor/github.com/open-policy-agent/opa/cmd/filters.go index a830aafd1..83a3e92ad 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/filters.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/filters.go @@ -8,8 +8,8 @@ import ( "os" "path/filepath" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/loader" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/loader" ) type loaderFilter struct { diff --git a/vendor/github.com/open-policy-agent/opa/cmd/flags.go b/vendor/github.com/open-policy-agent/opa/cmd/flags.go index c0e65b521..20a7ec742 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/flags.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/flags.go @@ -9,8 +9,8 @@ import ( "github.com/spf13/pflag" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" ) func addConfigFileFlag(fs *pflag.FlagSet, file *string) { @@ -153,16 +153,21 @@ func addStrictFlag(fs *pflag.FlagSet, strict *bool, value bool) { fs.BoolVarP(strict, "strict", "S", value, "enable compiler strict mode") } -func addRegoV1FlagWithDescription(fs *pflag.FlagSet, regoV1 *bool, value bool, description string) { +func addRegoV0V1FlagWithDescription(fs *pflag.FlagSet, regoV1 *bool, value bool, description string) { + fs.BoolVar(regoV1, "v0-v1", value, description) + + // For backwards compatibility fs.BoolVar(regoV1, "rego-v1", value, description) + _ = fs.MarkHidden("rego-v1") } func addV0CompatibleFlag(fs *pflag.FlagSet, v1Compatible *bool, value bool) { - fs.BoolVar(v1Compatible, "v0-compatible", value, "opt-in to OPA features and behaviors prior to the OPA v1.0 release. Takes precedence over --v1-compatible") + fs.BoolVar(v1Compatible, "v0-compatible", value, "opt-in to OPA features and behaviors prior to the OPA v1.0 release") } func addV1CompatibleFlag(fs *pflag.FlagSet, v1Compatible *bool, value bool) { fs.BoolVar(v1Compatible, "v1-compatible", value, "opt-in to OPA features and behaviors that are enabled by default in OPA v1.0") + _ = fs.MarkHidden("v1-compatible") } func addReadAstValuesFromStoreFlag(fs *pflag.FlagSet, readAstValuesFromStore *bool, value bool) { diff --git a/vendor/github.com/open-policy-agent/opa/cmd/fmt.go b/vendor/github.com/open-policy-agent/opa/cmd/fmt.go index fb8353f38..e9bb13a3c 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/fmt.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/fmt.go @@ -14,21 +14,22 @@ import ( "github.com/sergi/go-diff/diffmatchpatch" "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/cmd/internal/env" - "github.com/open-policy-agent/opa/format" fileurl "github.com/open-policy-agent/opa/internal/file/url" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/format" ) type fmtCommandParams struct { - overwrite bool - list bool - diff bool - fail bool - regoV1 bool - v0Compatible bool - v1Compatible bool - checkResult bool + overwrite bool + list bool + diff bool + fail bool + regoV1 bool + v0Compatible bool + v1Compatible bool + checkResult bool + dropV0Imports bool } var fmtParams = fmtCommandParams{} @@ -58,7 +59,7 @@ is provided - this tool will use stdin. The format of the output is not defined specifically; whatever this tool outputs is considered correct format (with the exception of bugs). -If the '-w' option is supplied, the 'fmt' command with overwrite the source file +If the '-w' option is supplied, the 'fmt' command will overwrite the source file instead of printing to stdout. If the '-d' option is supplied, the 'fmt' command will output a diff between the @@ -69,7 +70,26 @@ that would change if formatted. The '-l' option will suppress any other output to stdout from the 'fmt' command. If the '--fail' option is supplied, the 'fmt' command will return a non zero exit -code if a file would be reformatted.`, +code if a file would be reformatted. + +The 'fmt' command can be run in several compatibility modes for consuming and outputting +different Rego versions: + +* 'opa fmt': + * v1 Rego is formatted to v1 + * 'rego.v1'/'future.keywords' imports are NOT removed + * 'rego.v1'/'future.keywords' imports are NOT added if missing + * v0 rego is rejected +* 'opa fmt --v0-compatible': + * v0 Rego is formatted to v0 + * v1 Rego is rejected +* 'opa fmt --v0-v1': + * v0 Rego is formatted to be compatible with v0 AND v1 + * v1 Rego is rejected +* 'opa fmt --v0-v1 --v1-compatible': + * v1 Rego is formatted to be compatible with v0 AND v1 + * v0 Rego is rejected +`, PreRunE: func(cmd *cobra.Command, _ []string) error { return env.CmdFlags.CheckEnvironmentVariables(cmd) }, @@ -133,7 +153,12 @@ func formatFile(params *fmtCommandParams, out io.Writer, filename string, info o } opts := format.Opts{ - RegoVersion: params.regoVersion(), + RegoVersion: params.regoVersion(), + DropV0Imports: params.dropV0Imports, + } + + if params.regoV1 { + opts.ParserOptions = &ast.ParserOptions{RegoVersion: ast.RegoV0} } if params.v0Compatible { @@ -214,6 +239,11 @@ func formatStdin(params *fmtCommandParams, r io.Reader, w io.Writer) error { opts := format.Opts{} opts.RegoVersion = params.regoVersion() + + if params.regoV1 { + opts.ParserOptions = &ast.ParserOptions{RegoVersion: ast.RegoV0} + } + formatted, err := format.SourceWithOpts("stdin", contents, opts) if err != nil { return err @@ -223,9 +253,9 @@ func formatStdin(params *fmtCommandParams, r io.Reader, w io.Writer) error { return err } -func doDiff(old, new []byte) (diffString string) { +func doDiff(a, b []byte) (diffString string) { // "a" is old, "b" is new dmp := diffmatchpatch.New() - diffs := dmp.DiffMain(string(old), string(new), false) + diffs := dmp.DiffMain(string(a), string(b), false) return dmp.DiffPrettyText(diffs) } @@ -250,10 +280,11 @@ func init() { formatCommand.Flags().BoolVarP(&fmtParams.list, "list", "l", false, "list all files who would change when formatted") formatCommand.Flags().BoolVarP(&fmtParams.diff, "diff", "d", false, "only display a diff of the changes") formatCommand.Flags().BoolVar(&fmtParams.fail, "fail", false, "non zero exit code on reformat") - addRegoV1FlagWithDescription(formatCommand.Flags(), &fmtParams.regoV1, false, "format module(s) to be compatible with both Rego v1 and current OPA version)") + addRegoV0V1FlagWithDescription(formatCommand.Flags(), &fmtParams.regoV1, false, "format module(s) to be compatible with both Rego v0 and v1") addV0CompatibleFlag(formatCommand.Flags(), &fmtParams.v0Compatible, false) addV1CompatibleFlag(formatCommand.Flags(), &fmtParams.v1Compatible, false) formatCommand.Flags().BoolVar(&fmtParams.checkResult, "check-result", true, "assert that the formatted code is valid and can be successfully parsed (default true)") + formatCommand.Flags().BoolVar(&fmtParams.dropV0Imports, "drop-v0-imports", false, "drop v0 imports from the formatted code, such as 'rego.v1' and 'future.keywords'") RootCommand.AddCommand(formatCommand) } diff --git a/vendor/github.com/open-policy-agent/opa/cmd/inspect.go b/vendor/github.com/open-policy-agent/opa/cmd/inspect.go index 8b70bdfbe..cd5125dde 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/inspect.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/inspect.go @@ -13,13 +13,13 @@ import ( "strconv" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/cmd/internal/env" ib "github.com/open-policy-agent/opa/internal/bundle/inspect" pr "github.com/open-policy-agent/opa/internal/presentation" iStrs "github.com/open-policy-agent/opa/internal/strings" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/util" "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" @@ -31,14 +31,18 @@ const pageWidth = 80 type inspectCommandParams struct { outputFormat *util.EnumFlag listAnnotations bool + v0Compatible bool v1Compatible bool } func (p *inspectCommandParams) regoVersion() ast.RegoVersion { + if p.v0Compatible { + return ast.RegoV0 + } if p.v1Compatible { return ast.RegoV1 } - return ast.RegoV0 + return ast.DefaultRegoVersion } func newInspectCommandParams() inspectCommandParams { @@ -98,6 +102,7 @@ that file and summarize its structure and contents. addOutputFormat(inspectCommand.Flags(), params.outputFormat) addListAnnotations(inspectCommand.Flags(), ¶ms.listAnnotations) + addV0CompatibleFlag(inspectCommand.Flags(), ¶ms.v0Compatible, false) addV1CompatibleFlag(inspectCommand.Flags(), ¶ms.v1Compatible, false) RootCommand.AddCommand(inspectCommand) } diff --git a/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/exec.go b/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/exec.go index 8748b4dda..974b755c9 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/exec.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/exec.go @@ -8,7 +8,7 @@ import ( "path/filepath" "strings" - "github.com/open-policy-agent/opa/sdk" + "github.com/open-policy-agent/opa/v1/sdk" ) var ( diff --git a/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/json_reporter.go b/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/json_reporter.go index c0b23dbe7..ba40a60b3 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/json_reporter.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/json_reporter.go @@ -7,7 +7,7 @@ import ( "io" "time" - "github.com/open-policy-agent/opa/sdk" + "github.com/open-policy-agent/opa/v1/sdk" ) type result struct { diff --git a/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/params.go b/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/params.go index 28f0521e9..34c0a7733 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/params.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/params.go @@ -5,8 +5,8 @@ import ( "io" "time" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/util" ) type Params struct { diff --git a/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/parser.go b/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/parser.go index 46ebf08dc..1f857c6ac 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/parser.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/internal/exec/parser.go @@ -3,7 +3,7 @@ package exec import ( "io" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/util" ) type parser interface { diff --git a/vendor/github.com/open-policy-agent/opa/cmd/oracle.go b/vendor/github.com/open-policy-agent/opa/cmd/oracle.go index 31f0cf5a9..d9772103a 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/oracle.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/oracle.go @@ -14,12 +14,12 @@ import ( "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/cmd/internal/env" "github.com/open-policy-agent/opa/internal/oracle" "github.com/open-policy-agent/opa/internal/presentation" - "github.com/open-policy-agent/opa/loader" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/loader" ) type findDefinitionParams struct { @@ -162,6 +162,8 @@ func dofindDefinition(params findDefinitionParams, stdin io.Reader, stdout io.Wr } } + // FindDefinition() will instantiate a new compiler, but we don't need to set the + // default rego-version because the passed modules already have the rego-version from parsing. result, err := oracle.New().FindDefinition(oracle.DefinitionQuery{ Buffer: bs, Filename: filename, diff --git a/vendor/github.com/open-policy-agent/opa/cmd/parse.go b/vendor/github.com/open-policy-agent/opa/cmd/parse.go index e525cff6d..8eddefdc3 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/parse.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/parse.go @@ -13,12 +13,12 @@ import ( "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/ast" - astJSON "github.com/open-policy-agent/opa/ast/json" "github.com/open-policy-agent/opa/cmd/internal/env" pr "github.com/open-policy-agent/opa/internal/presentation" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + astJSON "github.com/open-policy-agent/opa/v1/ast/json" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/util" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/cmd/refactor.go b/vendor/github.com/open-policy-agent/opa/cmd/refactor.go index 7ef187c91..b92621b49 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/refactor.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/refactor.go @@ -13,12 +13,12 @@ import ( "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/cmd/internal/env" - "github.com/open-policy-agent/opa/format" fileurl "github.com/open-policy-agent/opa/internal/file/url" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/refactor" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/format" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/refactor" ) type moveCommandParams struct { @@ -149,7 +149,7 @@ func doMove(params moveCommandParams, args []string, out io.Writer) error { return err } - formatted, err := format.Ast(mod) + formatted, err := format.AstWithOpts(mod, format.Opts{RegoVersion: params.regoVersion()}) if err != nil { return newError("failed to parse Rego source file: %v", err) } diff --git a/vendor/github.com/open-policy-agent/opa/cmd/run.go b/vendor/github.com/open-policy-agent/opa/cmd/run.go index 9e8cb0eac..68703655f 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/run.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/run.go @@ -17,9 +17,9 @@ import ( "github.com/open-policy-agent/opa/cmd/internal/env" fileurl "github.com/open-policy-agent/opa/internal/file/url" - "github.com/open-policy-agent/opa/runtime" - "github.com/open-policy-agent/opa/server" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/runtime" + "github.com/open-policy-agent/opa/v1/server" + "github.com/open-policy-agent/opa/v1/util" ) const ( @@ -215,7 +215,7 @@ See https://godoc.org/crypto/tls#pkg-constants for more information. runCommand.Flags().BoolVarP(&cmdParams.serverMode, "server", "s", false, "start the runtime in server mode") runCommand.Flags().IntVar(&cmdParams.rt.ReadyTimeout, "ready-timeout", 0, "wait (in seconds) for configured plugins before starting server (value <= 0 disables ready check)") runCommand.Flags().StringVarP(&cmdParams.rt.HistoryPath, "history", "H", historyPath(), "set path of history file") - cmdParams.rt.Addrs = runCommand.Flags().StringSliceP("addr", "a", []string{defaultAddr}, "set listening address of the server (e.g., [ip]: for TCP, unix:// for UNIX domain socket)") + cmdParams.rt.Addrs = runCommand.Flags().StringSliceP("addr", "a", []string{defaultLocalAddr}, "set listening address of the server (e.g., [ip]: for TCP, unix:// for UNIX domain socket)") cmdParams.rt.DiagnosticAddrs = runCommand.Flags().StringSlice("diagnostic-addr", []string{}, "set read-only diagnostic listening address of the server for /health and /metric APIs (e.g., [ip]: for TCP, unix:// for UNIX domain socket)") cmdParams.rt.UnixSocketPerm = runCommand.Flags().String("unix-socket-perm", "755", "specify the permissions for the Unix domain socket if used to listen for incoming connections") runCommand.Flags().BoolVar(&cmdParams.rt.H2CEnabled, "h2c", false, "enable H2C for HTTP listeners") @@ -381,9 +381,8 @@ func initRuntime(ctx context.Context, params runCmdParams, args []string, addrSe rt.SetDistributedTracingLogging() rt.Params.AddrSetByUser = addrSetByUser - // v0 negates v1 - if !addrSetByUser && !rt.Params.V0Compatible && rt.Params.V1Compatible { - rt.Params.Addrs = &[]string{defaultLocalAddr} + if !addrSetByUser && rt.Params.V0Compatible { + rt.Params.Addrs = &[]string{defaultAddr} } return rt, nil diff --git a/vendor/github.com/open-policy-agent/opa/cmd/sign.go b/vendor/github.com/open-policy-agent/opa/cmd/sign.go index 1580b13ad..5e9676c12 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/sign.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/sign.go @@ -16,10 +16,10 @@ import ( "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/cmd/internal/env" initload "github.com/open-policy-agent/opa/internal/runtime/init" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/util" ) type signCmdParams struct { diff --git a/vendor/github.com/open-policy-agent/opa/cmd/test.go b/vendor/github.com/open-policy-agent/opa/cmd/test.go index 7b7dfefe9..94cdd173e 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/test.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/test.go @@ -6,6 +6,7 @@ package cmd import ( "context" + "errors" "fmt" "io" "os" @@ -17,21 +18,21 @@ import ( "github.com/fsnotify/fsnotify" "github.com/open-policy-agent/opa/internal/pathwatcher" initload "github.com/open-policy-agent/opa/internal/runtime/init" - "github.com/open-policy-agent/opa/loader" "github.com/spf13/cobra" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/cmd/internal/env" - "github.com/open-policy-agent/opa/compile" - "github.com/open-policy-agent/opa/cover" "github.com/open-policy-agent/opa/internal/runtime" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/inmem" - "github.com/open-policy-agent/opa/tester" - "github.com/open-policy-agent/opa/topdown" - "github.com/open-policy-agent/opa/topdown/lineage" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/compile" + "github.com/open-policy-agent/opa/v1/cover" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/inmem" + "github.com/open-policy-agent/opa/v1/tester" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/topdown/lineage" + "github.com/open-policy-agent/opa/v1/util" ) const ( @@ -115,11 +116,17 @@ func opaTest(args []string, testParams testCommandParams) (int, error) { var bundles map[string]*bundle.Bundle var store storage.Store + popts := ast.ParserOptions{ + RegoVersion: testParams.RegoVersion(), + Capabilities: testParams.capabilities.C, + ProcessAnnotation: true, + } + if testParams.bundleMode { - bundles, err = tester.LoadBundlesWithRegoVersion(args, filter.Apply, testParams.RegoVersion()) + bundles, err = tester.LoadBundlesWithParserOptions(args, filter.Apply, popts) store = inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false)) } else { - modules, store, err = tester.LoadWithRegoVersion(args, filter.Apply, testParams.RegoVersion()) + modules, store, err = tester.LoadWithParserOptions(args, filter.Apply, popts) } if err != nil { @@ -286,7 +293,7 @@ func processWatcherUpdate(ctx context.Context, testParams testCommandParams, pat var loadResult *initload.LoadPathsResult - err := pathwatcher.ProcessWatcherUpdate(ctx, paths, removed, store, filter.Apply, testParams.bundleMode, + err := pathwatcher.ProcessWatcherUpdateForRegoVersion(ctx, testParams.RegoVersion(), paths, removed, store, filter.Apply, testParams.bundleMode, func(ctx context.Context, txn storage.Transaction, loaded *initload.LoadPathsResult) error { if len(loaded.Files.Documents) > 0 || removed != "" { if err := store.Write(ctx, txn, storage.AddOp, storage.Path{}, loaded.Files.Documents); err != nil { @@ -373,7 +380,7 @@ func compileAndSetupTests(ctx context.Context, testParams testCommandParams, sto if testParams.benchmark { errMsg := "coverage reporting is not supported when benchmarking tests" fmt.Fprintln(testParams.errOutput, errMsg) - return nil, nil, fmt.Errorf(errMsg) + return nil, nil, errors.New(errMsg) } cov = cover.New() coverTracer = cov @@ -460,8 +467,6 @@ Example policy (example/authz.rego): package authz - import rego.v1 - allow if { input.path == ["users"] input.method == "POST" @@ -476,8 +481,6 @@ Example test (example/authz_test.rego): package authz_test - import rego.v1 - import data.authz.allow test_post_allowed if { diff --git a/vendor/github.com/open-policy-agent/opa/cmd/version.go b/vendor/github.com/open-policy-agent/opa/cmd/version.go index 3d3b0de4a..500f16e17 100644 --- a/vendor/github.com/open-policy-agent/opa/cmd/version.go +++ b/vendor/github.com/open-policy-agent/opa/cmd/version.go @@ -11,12 +11,13 @@ import ( "io" "os" + "github.com/open-policy-agent/opa/v1/ast" + version2 "github.com/open-policy-agent/opa/v1/version" "github.com/spf13/cobra" "github.com/open-policy-agent/opa/cmd/internal/env" "github.com/open-policy-agent/opa/internal/report" "github.com/open-policy-agent/opa/internal/uuid" - "github.com/open-policy-agent/opa/version" ) func init() { @@ -41,16 +42,17 @@ func init() { } func generateCmdOutput(out io.Writer, check bool) { - fmt.Fprintln(out, "Version: "+version.Version) - fmt.Fprintln(out, "Build Commit: "+version.Vcs) - fmt.Fprintln(out, "Build Timestamp: "+version.Timestamp) - fmt.Fprintln(out, "Build Hostname: "+version.Hostname) - fmt.Fprintln(out, "Go Version: "+version.GoVersion) - fmt.Fprintln(out, "Platform: "+version.Platform) + fmt.Fprintln(out, "Version: "+version2.Version) + fmt.Fprintln(out, "Build Commit: "+version2.Vcs) + fmt.Fprintln(out, "Build Timestamp: "+version2.Timestamp) + fmt.Fprintln(out, "Build Hostname: "+version2.Hostname) + fmt.Fprintln(out, "Go Version: "+version2.GoVersion) + fmt.Fprintln(out, "Platform: "+version2.Platform) + fmt.Fprintln(out, "Rego Version: "+ast.DefaultRegoVersion.String()) var wasmAvailable string - if version.WasmRuntimeAvailable() { + if version2.WasmRuntimeAvailable() { wasmAvailable = "available" } else { wasmAvailable = "unavailable" diff --git a/vendor/github.com/open-policy-agent/opa/config/config.go b/vendor/github.com/open-policy-agent/opa/config/config.go index 87ab10911..e612df0a0 100644 --- a/vendor/github.com/open-policy-agent/opa/config/config.go +++ b/vendor/github.com/open-policy-agent/opa/config/config.go @@ -6,254 +6,14 @@ package config import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - "reflect" - "sort" - "strings" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/internal/ref" - "github.com/open-policy-agent/opa/util" - "github.com/open-policy-agent/opa/version" + v1 "github.com/open-policy-agent/opa/v1/config" ) // Config represents the configuration file that OPA can be started with. -type Config struct { - Services json.RawMessage `json:"services,omitempty"` - Labels map[string]string `json:"labels,omitempty"` - Discovery json.RawMessage `json:"discovery,omitempty"` - Bundle json.RawMessage `json:"bundle,omitempty"` // Deprecated: Use `bundles` instead - Bundles json.RawMessage `json:"bundles,omitempty"` - DecisionLogs json.RawMessage `json:"decision_logs,omitempty"` - Status json.RawMessage `json:"status,omitempty"` - Plugins map[string]json.RawMessage `json:"plugins,omitempty"` - Keys json.RawMessage `json:"keys,omitempty"` - DefaultDecision *string `json:"default_decision,omitempty"` - DefaultAuthorizationDecision *string `json:"default_authorization_decision,omitempty"` - Caching json.RawMessage `json:"caching,omitempty"` - NDBuiltinCache bool `json:"nd_builtin_cache,omitempty"` - PersistenceDirectory *string `json:"persistence_directory,omitempty"` - DistributedTracing json.RawMessage `json:"distributed_tracing,omitempty"` - Server *struct { - Encoding json.RawMessage `json:"encoding,omitempty"` - Decoding json.RawMessage `json:"decoding,omitempty"` - Metrics json.RawMessage `json:"metrics,omitempty"` - } `json:"server,omitempty"` - Storage *struct { - Disk json.RawMessage `json:"disk,omitempty"` - } `json:"storage,omitempty"` - Extra map[string]json.RawMessage `json:"-"` -} +type Config = v1.Config // ParseConfig returns a valid Config object with defaults injected. The id // and version parameters will be set in the labels map. func ParseConfig(raw []byte, id string) (*Config, error) { - // NOTE(sr): based on https://stackoverflow.com/a/33499066/993018 - var result Config - objValue := reflect.ValueOf(&result).Elem() - knownFields := map[string]reflect.Value{} - for i := 0; i != objValue.NumField(); i++ { - jsonName := strings.Split(objValue.Type().Field(i).Tag.Get("json"), ",")[0] - knownFields[jsonName] = objValue.Field(i) - } - - if err := util.Unmarshal(raw, &result.Extra); err != nil { - return nil, err - } - - for key, chunk := range result.Extra { - if field, found := knownFields[key]; found { - if err := util.Unmarshal(chunk, field.Addr().Interface()); err != nil { - return nil, err - } - delete(result.Extra, key) - } - } - if len(result.Extra) == 0 { - result.Extra = nil - } - return &result, result.validateAndInjectDefaults(id) -} - -// PluginNames returns a sorted list of names of enabled plugins. -func (c Config) PluginNames() (result []string) { - if c.Bundle != nil || c.Bundles != nil { - result = append(result, "bundles") - } - if c.Status != nil { - result = append(result, "status") - } - if c.DecisionLogs != nil { - result = append(result, "decision_logs") - } - for name := range c.Plugins { - result = append(result, name) - } - sort.Strings(result) - return result -} - -// PluginsEnabled returns true if one or more plugin features are enabled. -// -// Deprecated. Use PluginNames instead. -func (c Config) PluginsEnabled() bool { - return c.Bundle != nil || c.Bundles != nil || c.DecisionLogs != nil || c.Status != nil || len(c.Plugins) > 0 -} - -// DefaultDecisionRef returns the default decision as a reference. -func (c Config) DefaultDecisionRef() ast.Ref { - r, _ := ref.ParseDataPath(*c.DefaultDecision) - return r -} - -// DefaultAuthorizationDecisionRef returns the default authorization decision -// as a reference. -func (c Config) DefaultAuthorizationDecisionRef() ast.Ref { - r, _ := ref.ParseDataPath(*c.DefaultAuthorizationDecision) - return r -} - -// NDBuiltinCacheEnabled returns if the ND builtins cache should be used. -func (c Config) NDBuiltinCacheEnabled() bool { - return c.NDBuiltinCache -} - -func (c *Config) validateAndInjectDefaults(id string) error { - - if c.DefaultDecision == nil { - s := defaultDecisionPath - c.DefaultDecision = &s - } - - _, err := ref.ParseDataPath(*c.DefaultDecision) - if err != nil { - return err - } - - if c.DefaultAuthorizationDecision == nil { - s := defaultAuthorizationDecisionPath - c.DefaultAuthorizationDecision = &s - } - - _, err = ref.ParseDataPath(*c.DefaultAuthorizationDecision) - if err != nil { - return err - } - - if c.Labels == nil { - c.Labels = map[string]string{} - } - - c.Labels["id"] = id - c.Labels["version"] = version.Version - - return nil -} - -// GetPersistenceDirectory returns the configured persistence directory, or $PWD/.opa if none is configured -func (c Config) GetPersistenceDirectory() (string, error) { - if c.PersistenceDirectory == nil { - pwd, err := os.Getwd() - if err != nil { - return "", err - } - return filepath.Join(pwd, ".opa"), nil - } - return *c.PersistenceDirectory, nil -} - -// ActiveConfig returns OPA's active configuration -// with the credentials and crypto keys removed -func (c *Config) ActiveConfig() (interface{}, error) { - bs, err := json.Marshal(c) - if err != nil { - return nil, err - } - - var result map[string]interface{} - if err := util.UnmarshalJSON(bs, &result); err != nil { - return nil, err - } - for k, e := range c.Extra { - var v any - if err := util.UnmarshalJSON(e, &v); err != nil { - return nil, err - } - result[k] = v - } - - if err := removeServiceCredentials(result["services"]); err != nil { - return nil, err - } - - if err := removeCryptoKeys(result["keys"]); err != nil { - return nil, err - } - - return result, nil -} - -func removeServiceCredentials(x interface{}) error { - switch x := x.(type) { - case nil: - return nil - case []interface{}: - for _, v := range x { - err := removeKey(v, "credentials") - if err != nil { - return err - } - } - - case map[string]interface{}: - for _, v := range x { - err := removeKey(v, "credentials") - if err != nil { - return err - } - } - default: - return fmt.Errorf("illegal service config type: %T", x) - } - - return nil -} - -func removeCryptoKeys(x interface{}) error { - switch x := x.(type) { - case nil: - return nil - case map[string]interface{}: - for _, v := range x { - err := removeKey(v, "key", "private_key") - if err != nil { - return err - } - } - default: - return fmt.Errorf("illegal keys config type: %T", x) - } - - return nil + return v1.ParseConfig(raw, id) } - -func removeKey(x interface{}, keys ...string) error { - val, ok := x.(map[string]interface{}) - if !ok { - return fmt.Errorf("type assertion error") - } - - for _, key := range keys { - delete(val, key) - } - - return nil -} - -const ( - defaultDecisionPath = "/system/main" - defaultAuthorizationDecisionPath = "/system/authz/allow" -) diff --git a/vendor/github.com/open-policy-agent/opa/config/doc.go b/vendor/github.com/open-policy-agent/opa/config/doc.go new file mode 100644 index 000000000..c6dd968dc --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/config/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package config diff --git a/vendor/github.com/open-policy-agent/opa/hooks/doc.go b/vendor/github.com/open-policy-agent/opa/hooks/doc.go new file mode 100644 index 000000000..6c5092427 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/hooks/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package hooks diff --git a/vendor/github.com/open-policy-agent/opa/hooks/hooks.go b/vendor/github.com/open-policy-agent/opa/hooks/hooks.go index 9659d7b49..110398873 100644 --- a/vendor/github.com/open-policy-agent/opa/hooks/hooks.go +++ b/vendor/github.com/open-policy-agent/opa/hooks/hooks.go @@ -5,10 +5,7 @@ package hooks import ( - "context" - "fmt" - - "github.com/open-policy-agent/opa/config" + v1 "github.com/open-policy-agent/opa/v1/hooks" ) // Hook is a hook to be called in some select places in OPA's operation. @@ -27,26 +24,14 @@ import ( // When multiple instances of a hook are provided, they are all going to be executed // in an unspecified order (it's a map-range call underneath). If you need hooks to // be run in order, you can wrap them into another hook, and configure that one. -type Hook any +type Hook = v1.Hook // Hooks is the type used for every struct in OPA that can work with hooks. -type Hooks struct { - m map[Hook]struct{} // we are NOT providing a stable invocation ordering -} +type Hooks = v1.Hooks // New creates a new instance of Hooks. func New(hs ...Hook) Hooks { - h := Hooks{m: make(map[Hook]struct{}, len(hs))} - for i := range hs { - h.m[hs[i]] = struct{}{} - } - return h -} - -func (hs Hooks) Each(fn func(Hook)) { - for h := range hs.m { - fn(h) - } + return v1.New(hs...) } // ConfigHook allows inspecting or rewriting the configuration when the plugin @@ -54,24 +39,8 @@ func (hs Hooks) Each(fn func(Hook)) { // Note that this hook is not run when the plugin manager is reconfigured. This // usually only happens when there's a new config from a discovery bundle, and // for processing _that_, there's `ConfigDiscoveryHook`. -type ConfigHook interface { - OnConfig(context.Context, *config.Config) (*config.Config, error) -} +type ConfigHook = v1.ConfigHook // ConfigHook allows inspecting or rewriting the discovered configuration when // the discovery plugin is processing it. -type ConfigDiscoveryHook interface { - OnConfigDiscovery(context.Context, *config.Config) (*config.Config, error) -} - -func (hs Hooks) Validate() error { - for h := range hs.m { - switch h.(type) { - case ConfigHook, - ConfigDiscoveryHook: // OK - default: - return fmt.Errorf("unknown hook type %T", h) - } - } - return nil -} +type ConfigDiscoveryHook = v1.ConfigDiscoveryHook diff --git a/vendor/github.com/open-policy-agent/opa/internal/bundle/inspect/inspect.go b/vendor/github.com/open-policy-agent/opa/internal/bundle/inspect/inspect.go index 39cef42fe..43f203faf 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/bundle/inspect/inspect.go +++ b/vendor/github.com/open-policy-agent/opa/internal/bundle/inspect/inspect.go @@ -12,12 +12,12 @@ import ( "path/filepath" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/ast/json" - "github.com/open-policy-agent/opa/bundle" initload "github.com/open-policy-agent/opa/internal/runtime/init" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/ast/json" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/util" ) // Info represents information about a bundle. diff --git a/vendor/github.com/open-policy-agent/opa/internal/bundle/utils.go b/vendor/github.com/open-policy-agent/opa/internal/bundle/utils.go index 064649733..3d67d5692 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/bundle/utils.go +++ b/vendor/github.com/open-policy-agent/opa/internal/bundle/utils.go @@ -11,10 +11,10 @@ import ( "os" "path/filepath" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/resolver/wasm" - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/resolver/wasm" + "github.com/open-policy-agent/opa/v1/storage" ) // LoadWasmResolversFromStore will lookup all Wasm modules from the store along with the diff --git a/vendor/github.com/open-policy-agent/opa/internal/compiler/utils.go b/vendor/github.com/open-policy-agent/opa/internal/compiler/utils.go index 4d80aeeef..dfb781e19 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/compiler/utils.go +++ b/vendor/github.com/open-policy-agent/opa/internal/compiler/utils.go @@ -5,9 +5,9 @@ package compiler import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/schemas" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/schemas" + "github.com/open-policy-agent/opa/v1/util" ) type SchemaFile string @@ -32,7 +32,10 @@ func VerifyAuthorizationPolicySchema(compiler *ast.Compiler, ref ast.Ref) error schemaSet := ast.NewSchemaSet() schemaSet.Put(ast.SchemaRootRef, schemaDefinitions[AuthorizationPolicySchema]) - errs := ast.NewCompiler().WithSchemas(schemaSet).PassesTypeCheckRules(rules) + errs := ast.NewCompiler(). + WithDefaultRegoVersion(compiler.DefaultRegoVersion()). + WithSchemas(schemaSet). + PassesTypeCheckRules(rules) if len(errs) > 0 { return errs diff --git a/vendor/github.com/open-policy-agent/opa/internal/compiler/wasm/wasm.go b/vendor/github.com/open-policy-agent/opa/internal/compiler/wasm/wasm.go index 9a5cebec5..08dfe4486 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/compiler/wasm/wasm.go +++ b/vendor/github.com/open-policy-agent/opa/internal/compiler/wasm/wasm.go @@ -12,7 +12,6 @@ import ( "fmt" "io" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/compiler/wasm/opa" "github.com/open-policy-agent/opa/internal/debug" "github.com/open-policy-agent/opa/internal/wasm/encoding" @@ -20,8 +19,9 @@ import ( "github.com/open-policy-agent/opa/internal/wasm/module" "github.com/open-policy-agent/opa/internal/wasm/types" "github.com/open-policy-agent/opa/internal/wasm/util" - "github.com/open-policy-agent/opa/ir" - opatypes "github.com/open-policy-agent/opa/types" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/ir" + opatypes "github.com/open-policy-agent/opa/v1/types" ) // Record Wasm ABI version in exported global variable diff --git a/vendor/github.com/open-policy-agent/opa/internal/config/config.go b/vendor/github.com/open-policy-agent/opa/internal/config/config.go index b1a9731f6..fdac48772 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/config/config.go +++ b/vendor/github.com/open-policy-agent/opa/internal/config/config.go @@ -15,11 +15,11 @@ import ( "sigs.k8s.io/yaml" "github.com/open-policy-agent/opa/internal/strvals" - "github.com/open-policy-agent/opa/keys" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/plugins/rest" - "github.com/open-policy-agent/opa/tracing" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/keys" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/plugins/rest" + "github.com/open-policy-agent/opa/v1/tracing" + "github.com/open-policy-agent/opa/v1/util" ) // ServiceOptions stores the options passed to ParseServicesConfig diff --git a/vendor/github.com/open-policy-agent/opa/internal/distributedtracing/distributedtracing.go b/vendor/github.com/open-policy-agent/opa/internal/distributedtracing/distributedtracing.go index 956a5025d..e4254e265 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/distributedtracing/distributedtracing.go +++ b/vendor/github.com/open-policy-agent/opa/internal/distributedtracing/distributedtracing.go @@ -22,14 +22,14 @@ import ( semconv "go.opentelemetry.io/otel/semconv/v1.7.0" "google.golang.org/grpc/credentials" - "github.com/open-policy-agent/opa/config" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/config" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/util" // The import registers opentelemetry with the top-level `tracing` package, // so the latter can be used from rego/topdown without an explicit build-time // dependency. - _ "github.com/open-policy-agent/opa/features/tracing" + _ "github.com/open-policy-agent/opa/v1/features/tracing" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/internal/edittree/edittree.go b/vendor/github.com/open-policy-agent/opa/internal/edittree/edittree.go index 9cfaee8ba..4a4f8101f 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/edittree/edittree.go +++ b/vendor/github.com/open-policy-agent/opa/internal/edittree/edittree.go @@ -146,14 +146,13 @@ package edittree import ( - "encoding/json" "fmt" "math/big" "sort" "strings" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/edittree/bitvector" + "github.com/open-policy-agent/opa/v1/ast" ) // Deletions are encoded with a nil value pointer. @@ -213,10 +212,10 @@ func (e *EditTree) getKeyHash(key *ast.Term) (int, bool) { case ast.Null, ast.Boolean, ast.String, ast.Var: equal = func(y ast.Value) bool { return x == y } case ast.Number: - if xi, err := json.Number(x).Int64(); err == nil { + if xi, ok := x.Int64(); ok { equal = func(y ast.Value) bool { if y, ok := y.(ast.Number); ok { - if yi, err := json.Number(y).Int64(); err == nil { + if yi, ok := y.Int64(); ok { return xi == yi } } @@ -725,9 +724,9 @@ func (e *EditTree) Unfold(path ast.Ref) (*EditTree, error) { // Fall back to looking up the key in e.value. // Extend the tree if key is present. Error otherwise. - if v, err := x.Find(ast.Ref{ast.IntNumberTerm(idx)}); err == nil { + if v, err := x.Find(ast.Ref{ast.InternedIntNumberTerm(idx)}); err == nil { // TODO: Consider a more efficient "Replace" function that special-cases this for arrays instead? - _, err := e.Delete(ast.IntNumberTerm(idx)) + _, err := e.Delete(ast.InternedIntNumberTerm(idx)) if err != nil { return nil, err } @@ -1026,8 +1025,7 @@ func (e *EditTree) Exists(path ast.Ref) bool { } // Fallback if child lookup failed. // We have to ensure that the lookup term is a number here, or Find will fail. - k := ast.Ref{ast.IntNumberTerm(idx)}.Concat(path[1:]) - _, err = x.Find(k) + _, err = x.Find(ast.Ref{ast.InternedIntNumberTerm(idx)}.Concat(path[1:])) return err == nil default: // Catch all primitive types. diff --git a/vendor/github.com/open-policy-agent/opa/internal/future/filter_imports.go b/vendor/github.com/open-policy-agent/opa/internal/future/filter_imports.go index cf5721101..eb6091cc6 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/future/filter_imports.go +++ b/vendor/github.com/open-policy-agent/opa/internal/future/filter_imports.go @@ -4,7 +4,7 @@ package future -import "github.com/open-policy-agent/opa/ast" +import "github.com/open-policy-agent/opa/v1/ast" // FilterFutureImports filters OUT any future imports from the passed slice of // `*ast.Import`s. diff --git a/vendor/github.com/open-policy-agent/opa/internal/future/parser_opts.go b/vendor/github.com/open-policy-agent/opa/internal/future/parser_opts.go index 804702b94..84a529287 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/future/parser_opts.go +++ b/vendor/github.com/open-policy-agent/opa/internal/future/parser_opts.go @@ -7,7 +7,7 @@ package future import ( "fmt" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) // ParserOptionsFromFutureImports transforms a slice of `ast.Import`s into the diff --git a/vendor/github.com/open-policy-agent/opa/internal/gojsonschema/schemaReferencePool.go b/vendor/github.com/open-policy-agent/opa/internal/gojsonschema/schemaReferencePool.go index 876419f56..515702095 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/gojsonschema/schemaReferencePool.go +++ b/vendor/github.com/open-policy-agent/opa/internal/gojsonschema/schemaReferencePool.go @@ -25,10 +25,6 @@ package gojsonschema -import ( - "fmt" -) - type schemaReferencePool struct { documents map[string]*SubSchema } @@ -44,7 +40,7 @@ func newSchemaReferencePool() *schemaReferencePool { func (p *schemaReferencePool) Get(ref string) (r *SubSchema, o bool) { if internalLogEnabled { - internalLog(fmt.Sprintf("Schema Reference ( %s )", ref)) + internalLog("Schema Reference ( %s )", ref) } if sch, ok := p.documents[ref]; ok { @@ -60,7 +56,7 @@ func (p *schemaReferencePool) Get(ref string) (r *SubSchema, o bool) { func (p *schemaReferencePool) Add(ref string, sch *SubSchema) { if internalLogEnabled { - internalLog(fmt.Sprintf("Add Schema Reference %s to pool", ref)) + internalLog("Add Schema Reference %s to pool", ref) } if _, ok := p.documents[ref]; !ok { p.documents[ref] = sch diff --git a/vendor/github.com/open-policy-agent/opa/internal/gojsonschema/validation.go b/vendor/github.com/open-policy-agent/opa/internal/gojsonschema/validation.go index 7c86e3724..efdea58b6 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/gojsonschema/validation.go +++ b/vendor/github.com/open-policy-agent/opa/internal/gojsonschema/validation.go @@ -348,7 +348,7 @@ func (v *SubSchema) validateSchema(currentSubSchema *SubSchema, currentNode inte } } - if currentSubSchema.dependencies != nil && len(currentSubSchema.dependencies) > 0 { + if len(currentSubSchema.dependencies) > 0 { if currentNodeMap, ok := currentNode.(map[string]interface{}); ok { for elementKey := range currentNodeMap { if dependency, ok := currentSubSchema.dependencies[elementKey]; ok { @@ -469,7 +469,7 @@ func (v *SubSchema) validateArray(currentSubSchema *SubSchema, value []interface result.mergeErrors(validationResult) } } else { - if currentSubSchema.ItemsChildren != nil && len(currentSubSchema.ItemsChildren) > 0 { + if len(currentSubSchema.ItemsChildren) > 0 { nbItems := len(currentSubSchema.ItemsChildren) diff --git a/vendor/github.com/open-policy-agent/opa/internal/gqlparser/validator/rules/fields_on_correct_type.go b/vendor/github.com/open-policy-agent/opa/internal/gqlparser/validator/rules/fields_on_correct_type.go index d536e5e5f..f68176747 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/gqlparser/validator/rules/fields_on_correct_type.go +++ b/vendor/github.com/open-policy-agent/opa/internal/gqlparser/validator/rules/fields_on_correct_type.go @@ -27,7 +27,7 @@ func init() { } addError( - Message(message), + Message(message), //nolint:govet At(field.Position), ) }) diff --git a/vendor/github.com/open-policy-agent/opa/internal/gqlparser/validator/rules/fragments_on_composite_types.go b/vendor/github.com/open-policy-agent/opa/internal/gqlparser/validator/rules/fragments_on_composite_types.go index 66bd348c4..861e3b16c 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/gqlparser/validator/rules/fragments_on_composite_types.go +++ b/vendor/github.com/open-policy-agent/opa/internal/gqlparser/validator/rules/fragments_on_composite_types.go @@ -20,7 +20,7 @@ func init() { message := fmt.Sprintf(`Fragment cannot condition on non composite type "%s".`, inlineFragment.TypeCondition) addError( - Message(message), + Message(message), //nolint:govet At(inlineFragment.Position), ) }) @@ -33,7 +33,7 @@ func init() { message := fmt.Sprintf(`Fragment "%s" cannot condition on non composite type "%s".`, fragment.Name, fragment.TypeCondition) addError( - Message(message), + Message(message), //nolint:govet At(fragment.Position), ) }) diff --git a/vendor/github.com/open-policy-agent/opa/internal/json/patch/patch.go b/vendor/github.com/open-policy-agent/opa/internal/json/patch/patch.go index 31c89869d..550618079 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/json/patch/patch.go +++ b/vendor/github.com/open-policy-agent/opa/internal/json/patch/patch.go @@ -7,7 +7,7 @@ package patch import ( "strings" - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/storage" ) // ParsePatchPathEscaped returns a new path for the given escaped str. diff --git a/vendor/github.com/open-policy-agent/opa/internal/logging/logging.go b/vendor/github.com/open-policy-agent/opa/internal/logging/logging.go index 2158a6e59..dc28b1b35 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/logging/logging.go +++ b/vendor/github.com/open-policy-agent/opa/internal/logging/logging.go @@ -12,7 +12,7 @@ import ( "github.com/sirupsen/logrus" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/logging" ) func GetLevel(level string) (logging.Level, error) { diff --git a/vendor/github.com/open-policy-agent/opa/internal/oracle/oracle.go b/vendor/github.com/open-policy-agent/opa/internal/oracle/oracle.go index baa74fae7..f32d01b32 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/oracle/oracle.go +++ b/vendor/github.com/open-policy-agent/opa/internal/oracle/oracle.go @@ -3,7 +3,7 @@ package oracle import ( "errors" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) // Error defines the structure of errors returned by the oracle. @@ -198,9 +198,9 @@ func findContainingNodeStack(module *ast.Module, pos int) []ast.Node { ast.WalkNodes(module, func(x ast.Node) bool { - min, max := getLocMinMax(x) + minLoc, maxLoc := getLocMinMax(x) - if pos < min || pos >= max { + if pos < minLoc || pos >= maxLoc { return true } @@ -218,7 +218,7 @@ func getLocMinMax(x ast.Node) (int, int) { } loc := x.Loc() - min := loc.Offset + minOff := loc.Offset // Special case bodies because location text is only for the first expr. if body, ok := x.(ast.Body); ok { @@ -227,10 +227,10 @@ func getLocMinMax(x ast.Node) (int, int) { if extraLoc == nil { return -1, -1 } - return min, extraLoc.Offset + len(extraLoc.Text) + return minOff, extraLoc.Offset + len(extraLoc.Text) } - return min, min + len(loc.Text) + return minOff, minOff + len(loc.Text) } // findLastExpr returns the last expression in an ast.Body that has not been generated diff --git a/vendor/github.com/open-policy-agent/opa/internal/pathwatcher/utils.go b/vendor/github.com/open-policy-agent/opa/internal/pathwatcher/utils.go index 31319a9ce..ee7fb794c 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/pathwatcher/utils.go +++ b/vendor/github.com/open-policy-agent/opa/internal/pathwatcher/utils.go @@ -12,10 +12,10 @@ import ( "sort" "github.com/fsnotify/fsnotify" - "github.com/open-policy-agent/opa/ast" initload "github.com/open-policy-agent/opa/internal/runtime/init" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/storage" ) // CreatePathWatcher creates watchers to monitor for path changes @@ -42,7 +42,7 @@ func CreatePathWatcher(rootPaths []string) (*fsnotify.Watcher, error) { // ProcessWatcherUpdate handles an occurrence of a watcher event func ProcessWatcherUpdate(ctx context.Context, paths []string, removed string, store storage.Store, filter loader.Filter, asBundle bool, f func(context.Context, storage.Transaction, *initload.LoadPathsResult) error) error { - return ProcessWatcherUpdateForRegoVersion(ctx, ast.RegoV0, paths, removed, store, filter, asBundle, f) + return ProcessWatcherUpdateForRegoVersion(ctx, ast.DefaultRegoVersion, paths, removed, store, filter, asBundle, f) } func ProcessWatcherUpdateForRegoVersion(ctx context.Context, regoVersion ast.RegoVersion, paths []string, removed string, store storage.Store, filter loader.Filter, asBundle bool, @@ -75,7 +75,7 @@ func ProcessWatcherUpdateForRegoVersion(ctx context.Context, regoVersion ast.Reg if err != nil { return err } - module, err := ast.ParseModule(id, string(bs)) + module, err := ast.ParseModuleWithOpts(id, string(bs), ast.ParserOptions{RegoVersion: regoVersion}) if err != nil { return err } diff --git a/vendor/github.com/open-policy-agent/opa/internal/planner/planner.go b/vendor/github.com/open-policy-agent/opa/internal/planner/planner.go index b75d26dda..d6b302041 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/planner/planner.go +++ b/vendor/github.com/open-policy-agent/opa/internal/planner/planner.go @@ -11,10 +11,10 @@ import ( "io" "sort" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/ast/location" "github.com/open-policy-agent/opa/internal/debug" - "github.com/open-policy-agent/opa/ir" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/ast/location" + "github.com/open-policy-agent/opa/v1/ir" ) // QuerySet represents the input to the planner. @@ -1037,7 +1037,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { args = p.defaultOperands() } else if decl, ok := p.decls[operator]; ok { relation = decl.Relation - arity = len(decl.Decl.Args()) + arity = decl.Decl.Arity() void = decl.Decl.Result() == nil name = operator p.externs[operator] = decl diff --git a/vendor/github.com/open-policy-agent/opa/internal/planner/rules.go b/vendor/github.com/open-policy-agent/opa/internal/planner/rules.go index f5d6f3fc6..2f424da52 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/planner/rules.go +++ b/vendor/github.com/open-policy-agent/opa/internal/planner/rules.go @@ -4,7 +4,7 @@ import ( "fmt" "sort" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) // funcstack implements a simple map structure used to keep track of virtual diff --git a/vendor/github.com/open-policy-agent/opa/internal/planner/varstack.go b/vendor/github.com/open-policy-agent/opa/internal/planner/varstack.go index dccff1b5c..0df6bcd8b 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/planner/varstack.go +++ b/vendor/github.com/open-policy-agent/opa/internal/planner/varstack.go @@ -5,8 +5,8 @@ package planner import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/ir" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/ir" ) type varstack []map[ast.Var]ir.Local diff --git a/vendor/github.com/open-policy-agent/opa/internal/presentation/presentation.go b/vendor/github.com/open-policy-agent/opa/internal/presentation/presentation.go index b69f93a27..5d70d3bd5 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/presentation/presentation.go +++ b/vendor/github.com/open-policy-agent/opa/internal/presentation/presentation.go @@ -18,15 +18,15 @@ import ( "github.com/olekukonko/tablewriter" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/cover" - "github.com/open-policy-agent/opa/format" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/profiler" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/cover" + "github.com/open-policy-agent/opa/v1/format" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/profiler" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown" ) // DefaultProfileSortOrder is the default ordering unless something is specified in the CLI diff --git a/vendor/github.com/open-policy-agent/opa/internal/prometheus/prometheus.go b/vendor/github.com/open-policy-agent/opa/internal/prometheus/prometheus.go index 5646a01d2..b733053de 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/prometheus/prometheus.go +++ b/vendor/github.com/open-policy-agent/opa/internal/prometheus/prometheus.go @@ -21,7 +21,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/open-policy-agent/opa/metrics" + "github.com/open-policy-agent/opa/v1/metrics" ) // Provider wraps a metrics.Metrics provider with a Prometheus registry that can diff --git a/vendor/github.com/open-policy-agent/opa/internal/providers/aws/ecr.go b/vendor/github.com/open-policy-agent/opa/internal/providers/aws/ecr.go index 179b5b5d5..55e587e9f 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/providers/aws/ecr.go +++ b/vendor/github.com/open-policy-agent/opa/internal/providers/aws/ecr.go @@ -11,7 +11,7 @@ import ( "time" "github.com/open-policy-agent/opa/internal/version" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/logging" ) // Values taken from diff --git a/vendor/github.com/open-policy-agent/opa/internal/providers/aws/kms.go b/vendor/github.com/open-policy-agent/opa/internal/providers/aws/kms.go index 77c0bc934..6dfb06a49 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/providers/aws/kms.go +++ b/vendor/github.com/open-policy-agent/opa/internal/providers/aws/kms.go @@ -10,7 +10,7 @@ import ( "time" "github.com/open-policy-agent/opa/internal/version" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/logging" ) // Values taken from diff --git a/vendor/github.com/open-policy-agent/opa/internal/providers/aws/signing_v4.go b/vendor/github.com/open-policy-agent/opa/internal/providers/aws/signing_v4.go index bfb780754..3c152831b 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/providers/aws/signing_v4.go +++ b/vendor/github.com/open-policy-agent/opa/internal/providers/aws/signing_v4.go @@ -19,7 +19,7 @@ import ( v4 "github.com/open-policy-agent/opa/internal/providers/aws/v4" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) func stringFromTerm(t *ast.Term) string { diff --git a/vendor/github.com/open-policy-agent/opa/internal/providers/aws/util.go b/vendor/github.com/open-policy-agent/opa/internal/providers/aws/util.go index e033da746..9ce9af90d 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/providers/aws/util.go +++ b/vendor/github.com/open-policy-agent/opa/internal/providers/aws/util.go @@ -5,7 +5,7 @@ import ( "io" "net/http" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/logging" ) // DoRequestWithClient is a convenience function to get the body of an HTTP response with diff --git a/vendor/github.com/open-policy-agent/opa/internal/ref/ref.go b/vendor/github.com/open-policy-agent/opa/internal/ref/ref.go index 6e84df4b0..173b5a3c1 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/ref/ref.go +++ b/vendor/github.com/open-policy-agent/opa/internal/ref/ref.go @@ -9,8 +9,8 @@ import ( "errors" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/storage" ) // ParseDataPath returns a ref from the slash separated path s rooted at data. diff --git a/vendor/github.com/open-policy-agent/opa/internal/rego/opa/options.go b/vendor/github.com/open-policy-agent/opa/internal/rego/opa/options.go index b58a05ee8..072e37667 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/rego/opa/options.go +++ b/vendor/github.com/open-policy-agent/opa/internal/rego/opa/options.go @@ -4,11 +4,11 @@ import ( "io" "time" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/print" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" ) // Result holds the evaluation result. diff --git a/vendor/github.com/open-policy-agent/opa/internal/report/report.go b/vendor/github.com/open-policy-agent/opa/internal/report/report.go index 145d0a946..55f4cfe21 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/report/report.go +++ b/vendor/github.com/open-policy-agent/opa/internal/report/report.go @@ -17,12 +17,12 @@ import ( "sync" "time" - "github.com/open-policy-agent/opa/keys" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/keys" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/version" - "github.com/open-policy-agent/opa/plugins/rest" - "github.com/open-policy-agent/opa/util" - "github.com/open-policy-agent/opa/version" + "github.com/open-policy-agent/opa/v1/plugins/rest" + "github.com/open-policy-agent/opa/v1/util" ) // ExternalServiceURL is the base HTTP URL for a telemetry service. diff --git a/vendor/github.com/open-policy-agent/opa/internal/runtime/init/init.go b/vendor/github.com/open-policy-agent/opa/internal/runtime/init/init.go index b1a5b7157..814847a12 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/runtime/init/init.go +++ b/vendor/github.com/open-policy-agent/opa/internal/runtime/init/init.go @@ -12,12 +12,12 @@ import ( "path/filepath" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" storedversion "github.com/open-policy-agent/opa/internal/version" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/storage" ) // InsertAndCompileOptions contains the input for the operation. @@ -53,6 +53,7 @@ func InsertAndCompile(ctx context.Context, opts InsertAndCompileOptions) (*Inser } compiler := ast.NewCompiler(). + WithDefaultRegoVersion(opts.ParserOptions.RegoVersion). SetErrorLimit(opts.MaxErrors). WithPathConflictsCheck(storage.NonEmpty(ctx, opts.Store, opts.Txn)). WithEnablePrintStatements(opts.EnablePrintStatements) diff --git a/vendor/github.com/open-policy-agent/opa/internal/runtime/runtime.go b/vendor/github.com/open-policy-agent/opa/internal/runtime/runtime.go index 217afcede..85b49e307 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/runtime/runtime.go +++ b/vendor/github.com/open-policy-agent/opa/internal/runtime/runtime.go @@ -9,9 +9,9 @@ import ( "os" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/util" - "github.com/open-policy-agent/opa/version" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" + "github.com/open-policy-agent/opa/v1/version" ) // Params controls the types of runtime information to return. diff --git a/vendor/github.com/open-policy-agent/opa/internal/strvals/parser.go b/vendor/github.com/open-policy-agent/opa/internal/strvals/parser.go index 1fc07f68c..1eceb83df 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/strvals/parser.go +++ b/vendor/github.com/open-policy-agent/opa/internal/strvals/parser.go @@ -31,7 +31,7 @@ var ErrNotList = errors.New("not a list") // MaxIndex is the maximum index that will be allowed by setIndex. // The default value 65536 = 1024 * 64 -var MaxIndex = 65536 +const MaxIndex = 65536 // ToYAML takes a string of arguments and converts to a YAML document. func ToYAML(s string) (string, error) { diff --git a/vendor/github.com/open-policy-agent/opa/internal/version/version.go b/vendor/github.com/open-policy-agent/opa/internal/version/version.go index 1c2e9ecd0..dc52733fc 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/version/version.go +++ b/vendor/github.com/open-policy-agent/opa/internal/version/version.go @@ -10,8 +10,8 @@ import ( "fmt" "runtime" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/version" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/version" ) var versionPath = storage.MustParsePath("/system/version") diff --git a/vendor/github.com/open-policy-agent/opa/internal/wasm/encoding/reader.go b/vendor/github.com/open-policy-agent/opa/internal/wasm/encoding/reader.go index 35e6059c7..7120392ce 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/wasm/encoding/reader.go +++ b/vendor/github.com/open-policy-agent/opa/internal/wasm/encoding/reader.go @@ -809,19 +809,19 @@ func readLimits(r io.Reader, l *module.Limit) error { return err } - min, err := leb128.ReadVarUint32(r) + minLim, err := leb128.ReadVarUint32(r) if err != nil { return err } - l.Min = min + l.Min = minLim if b == 1 { - max, err := leb128.ReadVarUint32(r) + maxLim, err := leb128.ReadVarUint32(r) if err != nil { return err } - l.Max = &max + l.Max = &maxLim } else if b != 0 { return fmt.Errorf("illegal limit flag") } diff --git a/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/bindings.go b/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/bindings.go index f90031c53..33ee7e451 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/bindings.go +++ b/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/bindings.go @@ -18,12 +18,12 @@ import ( wasmtime "github.com/bytecodealliance/wasmtime-go/v3" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/topdown" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/print" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" ) func opaFunctions(dispatcher *builtinDispatcher, store *wasmtime.Store) map[string]wasmtime.AsExtern { @@ -222,8 +222,8 @@ func getExports(c *wasmtime.Caller) exports { return e } -func (e exports) Malloc(caller *wasmtime.Caller, len int32) (int32, error) { - ptr, err := e.mallocFn.Call(caller, len) +func (e exports) Malloc(caller *wasmtime.Caller, length int32) (int32, error) { + ptr, err := e.mallocFn.Call(caller, length) if err != nil { return 0, err } @@ -238,8 +238,8 @@ func (e exports) ValueDump(caller *wasmtime.Caller, addr int32) (int32, error) { return result.(int32), nil } -func (e exports) ValueParse(caller *wasmtime.Caller, addr int32, len int32) (int32, error) { - result, err := e.valueParseFn.Call(caller, addr, len) +func (e exports) ValueParse(caller *wasmtime.Caller, addr int32, length int32) (int32, error) { + result, err := e.valueParseFn.Call(caller, addr, length) if err != nil { return 0, err } diff --git a/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/pool.go b/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/pool.go index 40970be81..0bb38b423 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/pool.go +++ b/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/pool.go @@ -13,7 +13,7 @@ import ( "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors" "github.com/open-policy-agent/opa/internal/wasm/util" - "github.com/open-policy-agent/opa/metrics" + "github.com/open-policy-agent/opa/v1/metrics" ) var errNotReady = errors.New(errors.NotReadyErr, "") diff --git a/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/vm.go b/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/vm.go index b089a2042..e5f589623 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/vm.go +++ b/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm/vm.go @@ -16,14 +16,14 @@ import ( wasmtime "github.com/bytecodealliance/wasmtime-go/v3" - "github.com/open-policy-agent/opa/ast" sdk_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors" "github.com/open-policy-agent/opa/internal/wasm/util" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/topdown" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/print" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" ) // VM is a wrapper around a Wasm VM instance diff --git a/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/opa/opa.go b/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/opa/opa.go index 97ca58753..8b6b9db5c 100644 --- a/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/opa/opa.go +++ b/vendor/github.com/open-policy-agent/opa/internal/wasm/sdk/opa/opa.go @@ -12,13 +12,13 @@ import ( "sync" "time" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm" sdk_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/print" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" ) var errNotReady = sdk_errors.New(sdk_errors.NotReadyErr, "") diff --git a/vendor/github.com/open-policy-agent/opa/loader/doc.go b/vendor/github.com/open-policy-agent/opa/loader/doc.go new file mode 100644 index 000000000..9f60920d9 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/loader/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package loader diff --git a/vendor/github.com/open-policy-agent/opa/loader/errors.go b/vendor/github.com/open-policy-agent/opa/loader/errors.go index b8aafb142..8dc70b867 100644 --- a/vendor/github.com/open-policy-agent/opa/loader/errors.go +++ b/vendor/github.com/open-policy-agent/opa/loader/errors.go @@ -5,58 +5,8 @@ package loader import ( - "fmt" - "strings" - - "github.com/open-policy-agent/opa/ast" + v1 "github.com/open-policy-agent/opa/v1/loader" ) // Errors is a wrapper for multiple loader errors. -type Errors []error - -func (e Errors) Error() string { - if len(e) == 0 { - return "no error(s)" - } - if len(e) == 1 { - return "1 error occurred during loading: " + e[0].Error() - } - buf := make([]string, len(e)) - for i := range buf { - buf[i] = e[i].Error() - } - return fmt.Sprintf("%v errors occurred during loading:\n", len(e)) + strings.Join(buf, "\n") -} - -func (e *Errors) add(err error) { - if errs, ok := err.(ast.Errors); ok { - for i := range errs { - *e = append(*e, errs[i]) - } - } else { - *e = append(*e, err) - } -} - -type unsupportedDocumentType string - -func (path unsupportedDocumentType) Error() string { - return string(path) + ": document must be of type object" -} - -type unrecognizedFile string - -func (path unrecognizedFile) Error() string { - return string(path) + ": can't recognize file type" -} - -func isUnrecognizedFile(err error) bool { - _, ok := err.(unrecognizedFile) - return ok -} - -type mergeError string - -func (e mergeError) Error() string { - return string(e) + ": merge error" -} +type Errors = v1.Errors diff --git a/vendor/github.com/open-policy-agent/opa/loader/loader.go b/vendor/github.com/open-policy-agent/opa/loader/loader.go index 461639ed1..9b2f91d4e 100644 --- a/vendor/github.com/open-policy-agent/opa/loader/loader.go +++ b/vendor/github.com/open-policy-agent/opa/loader/loader.go @@ -6,478 +6,74 @@ package loader import ( - "bytes" - "fmt" - "io" "io/fs" "os" - "path/filepath" - "sort" "strings" - "sigs.k8s.io/yaml" - "github.com/open-policy-agent/opa/ast" - astJSON "github.com/open-policy-agent/opa/ast/json" "github.com/open-policy-agent/opa/bundle" - fileurl "github.com/open-policy-agent/opa/internal/file/url" - "github.com/open-policy-agent/opa/internal/merge" - "github.com/open-policy-agent/opa/loader/filter" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/inmem" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/loader" ) // Result represents the result of successfully loading zero or more files. -type Result struct { - Documents map[string]interface{} - Modules map[string]*RegoFile - path []string -} - -// ParsedModules returns the parsed modules stored on the result. -func (l *Result) ParsedModules() map[string]*ast.Module { - modules := make(map[string]*ast.Module) - for _, module := range l.Modules { - modules[module.Name] = module.Parsed - } - return modules -} - -// Compiler returns a Compiler object with the compiled modules from this loader -// result. -func (l *Result) Compiler() (*ast.Compiler, error) { - compiler := ast.NewCompiler() - compiler.Compile(l.ParsedModules()) - if compiler.Failed() { - return nil, compiler.Errors - } - return compiler, nil -} - -// Store returns a Store object with the documents from this loader result. -func (l *Result) Store() (storage.Store, error) { - return l.StoreWithOpts() -} - -// StoreWithOpts returns a Store object with the documents from this loader result, -// instantiated with the passed options. -func (l *Result) StoreWithOpts(opts ...inmem.Opt) (storage.Store, error) { - return inmem.NewFromObjectWithOpts(l.Documents, opts...), nil -} +type Result = v1.Result // RegoFile represents the result of loading a single Rego source file. -type RegoFile struct { - Name string - Parsed *ast.Module - Raw []byte -} +type RegoFile = v1.RegoFile // Filter defines the interface for filtering files during loading. If the // filter returns true, the file should be excluded from the result. -type Filter = filter.LoaderFilter +type Filter = v1.Filter // GlobExcludeName excludes files and directories whose names do not match the // shell style pattern at minDepth or greater. func GlobExcludeName(pattern string, minDepth int) Filter { - return func(_ string, info fs.FileInfo, depth int) bool { - match, _ := filepath.Match(pattern, info.Name()) - return match && depth >= minDepth - } + return v1.GlobExcludeName(pattern, minDepth) } // FileLoader defines an interface for loading OPA data files // and Rego policies. -type FileLoader interface { - All(paths []string) (*Result, error) - Filtered(paths []string, filter Filter) (*Result, error) - AsBundle(path string) (*bundle.Bundle, error) - WithReader(io.Reader) FileLoader - WithFS(fs.FS) FileLoader - WithMetrics(metrics.Metrics) FileLoader - WithFilter(Filter) FileLoader - WithBundleVerificationConfig(*bundle.VerificationConfig) FileLoader - WithSkipBundleVerification(bool) FileLoader - WithProcessAnnotation(bool) FileLoader - WithCapabilities(*ast.Capabilities) FileLoader - WithJSONOptions(*astJSON.Options) FileLoader - WithRegoVersion(ast.RegoVersion) FileLoader - WithFollowSymlinks(bool) FileLoader -} +type FileLoader = v1.FileLoader // NewFileLoader returns a new FileLoader instance. func NewFileLoader() FileLoader { - return &fileLoader{ - metrics: metrics.New(), - files: make(map[string]bundle.FileInfo), - } -} - -type fileLoader struct { - metrics metrics.Metrics - filter Filter - bvc *bundle.VerificationConfig - skipVerify bool - files map[string]bundle.FileInfo - opts ast.ParserOptions - fsys fs.FS - reader io.Reader - followSymlinks bool -} - -// WithFS provides an fs.FS to use for loading files. You can pass nil to -// use plain IO calls (e.g. os.Open, os.Stat, etc.), this is the default -// behaviour. -func (fl *fileLoader) WithFS(fsys fs.FS) FileLoader { - fl.fsys = fsys - return fl -} - -// WithReader provides an io.Reader to use for loading the bundle tarball. -// An io.Reader passed via WithReader takes precedence over an fs.FS passed -// via WithFS. -func (fl *fileLoader) WithReader(rdr io.Reader) FileLoader { - fl.reader = rdr - return fl -} - -// WithMetrics provides the metrics instance to use while loading -func (fl *fileLoader) WithMetrics(m metrics.Metrics) FileLoader { - fl.metrics = m - return fl -} - -// WithFilter specifies the filter object to use to filter files while loading -func (fl *fileLoader) WithFilter(filter Filter) FileLoader { - fl.filter = filter - return fl -} - -// WithBundleVerificationConfig sets the key configuration used to verify a signed bundle -func (fl *fileLoader) WithBundleVerificationConfig(config *bundle.VerificationConfig) FileLoader { - fl.bvc = config - return fl -} - -// WithSkipBundleVerification skips verification of a signed bundle -func (fl *fileLoader) WithSkipBundleVerification(skipVerify bool) FileLoader { - fl.skipVerify = skipVerify - return fl -} - -// WithProcessAnnotation enables or disables processing of schema annotations on rules -func (fl *fileLoader) WithProcessAnnotation(processAnnotation bool) FileLoader { - fl.opts.ProcessAnnotation = processAnnotation - return fl -} - -// WithCapabilities sets the supported capabilities when loading the files -func (fl *fileLoader) WithCapabilities(caps *ast.Capabilities) FileLoader { - fl.opts.Capabilities = caps - return fl -} - -// WithJSONOptions sets the JSONOptions for use when parsing files -func (fl *fileLoader) WithJSONOptions(opts *astJSON.Options) FileLoader { - fl.opts.JSONOptions = opts - return fl -} - -// WithRegoVersion sets the ast.RegoVersion to use when parsing and compiling modules. -func (fl *fileLoader) WithRegoVersion(version ast.RegoVersion) FileLoader { - fl.opts.RegoVersion = version - return fl -} - -// WithFollowSymlinks enables or disables following symlinks when loading files -func (fl *fileLoader) WithFollowSymlinks(followSymlinks bool) FileLoader { - fl.followSymlinks = followSymlinks - return fl -} - -// All returns a Result object loaded (recursively) from the specified paths. -func (fl fileLoader) All(paths []string) (*Result, error) { - return fl.Filtered(paths, nil) -} - -// Filtered returns a Result object loaded (recursively) from the specified -// paths while applying the given filters. If any filter returns true, the -// file/directory is excluded. -func (fl fileLoader) Filtered(paths []string, filter Filter) (*Result, error) { - return all(fl.fsys, paths, filter, func(curr *Result, path string, depth int) error { - - var ( - bs []byte - err error - ) - if fl.fsys != nil { - bs, err = fs.ReadFile(fl.fsys, path) - } else { - bs, err = os.ReadFile(path) - } - if err != nil { - return err - } - - result, err := loadKnownTypes(path, bs, fl.metrics, fl.opts) - if err != nil { - if !isUnrecognizedFile(err) { - return err - } - if depth > 0 { - return nil - } - result, err = loadFileForAnyType(path, bs, fl.metrics, fl.opts) - if err != nil { - return err - } - } - - return curr.merge(path, result) - }) -} - -// AsBundle loads a path as a bundle. If it is a single file -// it will be treated as a normal tarball bundle. If a directory -// is supplied it will be loaded as an unzipped bundle tree. -func (fl fileLoader) AsBundle(path string) (*bundle.Bundle, error) { - path, err := fileurl.Clean(path) - if err != nil { - return nil, err - } - - if err := checkForUNCPath(path); err != nil { - return nil, err - } - - var bundleLoader bundle.DirectoryLoader - var isDir bool - if fl.reader != nil { - bundleLoader = bundle.NewTarballLoaderWithBaseURL(fl.reader, path).WithFilter(fl.filter) - } else { - bundleLoader, isDir, err = GetBundleDirectoryLoaderFS(fl.fsys, path, fl.filter) - } - - if err != nil { - return nil, err - } - bundleLoader = bundleLoader.WithFollowSymlinks(fl.followSymlinks) - - br := bundle.NewCustomReader(bundleLoader). - WithMetrics(fl.metrics). - WithBundleVerificationConfig(fl.bvc). - WithSkipBundleVerification(fl.skipVerify). - WithProcessAnnotations(fl.opts.ProcessAnnotation). - WithCapabilities(fl.opts.Capabilities). - WithJSONOptions(fl.opts.JSONOptions). - WithFollowSymlinks(fl.followSymlinks). - WithRegoVersion(fl.opts.RegoVersion) - - // For bundle directories add the full path in front of module file names - // to simplify debugging. - if isDir { - br.WithBaseDir(path) - } - - b, err := br.Read() - if err != nil { - err = fmt.Errorf("bundle %s: %w", path, err) - } - - return &b, err + return v1.NewFileLoader().WithRegoVersion(ast.DefaultRegoVersion) } // GetBundleDirectoryLoader returns a bundle directory loader which can be used to load // files in the directory func GetBundleDirectoryLoader(path string) (bundle.DirectoryLoader, bool, error) { - return GetBundleDirectoryLoaderFS(nil, path, nil) + return v1.GetBundleDirectoryLoader(path) } // GetBundleDirectoryLoaderWithFilter returns a bundle directory loader which can be used to load // files in the directory after applying the given filter. func GetBundleDirectoryLoaderWithFilter(path string, filter Filter) (bundle.DirectoryLoader, bool, error) { - return GetBundleDirectoryLoaderFS(nil, path, filter) + return v1.GetBundleDirectoryLoaderWithFilter(path, filter) } // GetBundleDirectoryLoaderFS returns a bundle directory loader which can be used to load // files in the directory. func GetBundleDirectoryLoaderFS(fsys fs.FS, path string, filter Filter) (bundle.DirectoryLoader, bool, error) { - path, err := fileurl.Clean(path) - if err != nil { - return nil, false, err - } - - if err := checkForUNCPath(path); err != nil { - return nil, false, err - } - - var fi fs.FileInfo - if fsys != nil { - fi, err = fs.Stat(fsys, path) - } else { - fi, err = os.Stat(path) - } - if err != nil { - return nil, false, fmt.Errorf("error reading %q: %s", path, err) - } - - var bundleLoader bundle.DirectoryLoader - if fi.IsDir() { - if fsys != nil { - bundleLoader = bundle.NewFSLoaderWithRoot(fsys, path) - } else { - bundleLoader = bundle.NewDirectoryLoader(path) - } - } else { - var fh fs.File - if fsys != nil { - fh, err = fsys.Open(path) - } else { - fh, err = os.Open(path) - } - if err != nil { - return nil, false, err - } - bundleLoader = bundle.NewTarballLoaderWithBaseURL(fh, path) - } - - if filter != nil { - bundleLoader = bundleLoader.WithFilter(filter) - } - return bundleLoader, fi.IsDir(), nil + return v1.GetBundleDirectoryLoaderFS(fsys, path, filter) } // FilteredPaths is the same as FilterPathsFS using the current diretory file // system func FilteredPaths(paths []string, filter Filter) ([]string, error) { - return FilteredPathsFS(nil, paths, filter) + return v1.FilteredPaths(paths, filter) } // FilteredPathsFS return a list of files from the specified // paths while applying the given filters. If any filter returns true, the // file/directory is excluded. func FilteredPathsFS(fsys fs.FS, paths []string, filter Filter) ([]string, error) { - result := []string{} - - _, err := all(fsys, paths, filter, func(_ *Result, path string, _ int) error { - result = append(result, path) - return nil - }) - if err != nil { - return nil, err - } - return result, nil + return v1.FilteredPathsFS(fsys, paths, filter) } // Schemas loads a schema set from the specified file path. func Schemas(schemaPath string) (*ast.SchemaSet, error) { - - var errs Errors - ss, err := loadSchemas(schemaPath) - if err != nil { - errs.add(err) - return nil, errs - } - - return ss, nil -} - -func loadSchemas(schemaPath string) (*ast.SchemaSet, error) { - - if schemaPath == "" { - return nil, nil - } - - ss := ast.NewSchemaSet() - path, err := fileurl.Clean(schemaPath) - if err != nil { - return nil, err - } - - info, err := os.Stat(path) - if err != nil { - return nil, err - } - - // Handle single file case. - if !info.IsDir() { - schema, err := loadOneSchema(path) - if err != nil { - return nil, err - } - ss.Put(ast.SchemaRootRef, schema) - return ss, nil - - } - - // Handle directory case. - rootDir := path - - err = filepath.Walk(path, - func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } else if info.IsDir() { - return nil - } - - schema, err := loadOneSchema(path) - if err != nil { - return err - } - - relPath, err := filepath.Rel(rootDir, path) - if err != nil { - return err - } - - key := getSchemaSetByPathKey(relPath) - ss.Put(key, schema) - return nil - }) - - if err != nil { - return nil, err - } - - return ss, nil -} - -func getSchemaSetByPathKey(path string) ast.Ref { - - front := filepath.Dir(path) - last := strings.TrimSuffix(filepath.Base(path), filepath.Ext(path)) - - var parts []string - - if front != "." { - parts = append(strings.Split(filepath.ToSlash(front), "/"), last) - } else { - parts = []string{last} - } - - key := make(ast.Ref, 1+len(parts)) - key[0] = ast.SchemaRootDocument - for i := range parts { - key[i+1] = ast.StringTerm(parts[i]) - } - - return key -} - -func loadOneSchema(path string) (interface{}, error) { - bs, err := os.ReadFile(path) - if err != nil { - return nil, err - } - - var schema interface{} - if err := util.Unmarshal(bs, &schema); err != nil { - return nil, fmt.Errorf("%s: %w", path, err) - } - - return schema, nil + return v1.Schemas(schemaPath) } // All returns a Result object loaded (recursively) from the specified paths. @@ -517,321 +113,33 @@ func Rego(path string) (*RegoFile, error) { // RegoWithOpts returns a RegoFile object loaded from the given path. func RegoWithOpts(path string, opts ast.ParserOptions) (*RegoFile, error) { - path, err := fileurl.Clean(path) - if err != nil { - return nil, err + if opts.RegoVersion == ast.RegoUndefined { + opts.RegoVersion = ast.DefaultRegoVersion } - bs, err := os.ReadFile(path) - if err != nil { - return nil, err - } - return loadRego(path, bs, metrics.New(), opts) + + return v1.RegoWithOpts(path, opts) } // CleanPath returns the normalized version of a path that can be used as an identifier. func CleanPath(path string) string { - return strings.Trim(path, "/") + return v1.CleanPath(path) } // Paths returns a sorted list of files contained at path. If recurse is true // and path is a directory, then Paths will walk the directory structure // recursively and list files at each level. func Paths(path string, recurse bool) (paths []string, err error) { - path, err = fileurl.Clean(path) - if err != nil { - return nil, err - } - err = filepath.Walk(path, func(f string, _ os.FileInfo, _ error) error { - if !recurse { - if path != f && path != filepath.Dir(f) { - return filepath.SkipDir - } - } - paths = append(paths, f) - return nil - }) - return paths, err + return v1.Paths(path, recurse) } // Dirs resolves filepaths to directories. It will return a list of unique // directories. func Dirs(paths []string) []string { - unique := map[string]struct{}{} - - for _, path := range paths { - // TODO: /dir/dir will register top level directory /dir - dir := filepath.Dir(path) - unique[dir] = struct{}{} - } - - u := make([]string, 0, len(unique)) - for k := range unique { - u = append(u, k) - } - sort.Strings(u) - return u + return v1.Dirs(paths) } // SplitPrefix returns a tuple specifying the document prefix and the file // path. func SplitPrefix(path string) ([]string, string) { - // Non-prefixed URLs can be returned without modification and their contents - // can be rooted directly under data. - if strings.Index(path, "://") == strings.Index(path, ":") { - return nil, path - } - parts := strings.SplitN(path, ":", 2) - if len(parts) == 2 && len(parts[0]) > 0 { - return strings.Split(parts[0], "."), parts[1] - } - return nil, path -} - -func (l *Result) merge(path string, result interface{}) error { - switch result := result.(type) { - case bundle.Bundle: - for _, module := range result.Modules { - l.Modules[module.Path] = &RegoFile{ - Name: module.Path, - Parsed: module.Parsed, - Raw: module.Raw, - } - } - return l.mergeDocument(path, result.Data) - case *RegoFile: - l.Modules[CleanPath(path)] = result - return nil - default: - return l.mergeDocument(path, result) - } -} - -func (l *Result) mergeDocument(path string, doc interface{}) error { - obj, ok := makeDir(l.path, doc) - if !ok { - return unsupportedDocumentType(path) - } - merged, ok := merge.InterfaceMaps(l.Documents, obj) - if !ok { - return mergeError(path) - } - for k := range merged { - l.Documents[k] = merged[k] - } - return nil -} - -func (l *Result) withParent(p string) *Result { - path := append(l.path, p) - return &Result{ - Documents: l.Documents, - Modules: l.Modules, - path: path, - } -} - -func newResult() *Result { - return &Result{ - Documents: map[string]interface{}{}, - Modules: map[string]*RegoFile{}, - } -} - -func all(fsys fs.FS, paths []string, filter Filter, f func(*Result, string, int) error) (*Result, error) { - errs := Errors{} - root := newResult() - - for _, path := range paths { - - // Paths can be prefixed with a string that specifies where content should be - // loaded under data. E.g., foo.bar:/path/to/some.json will load the content - // of some.json under {"foo": {"bar": ...}}. - loaded := root - prefix, path := SplitPrefix(path) - if len(prefix) > 0 { - for _, part := range prefix { - loaded = loaded.withParent(part) - } - } - - allRec(fsys, path, filter, &errs, loaded, 0, f) - } - - if len(errs) > 0 { - return nil, errs - } - - return root, nil -} - -func allRec(fsys fs.FS, path string, filter Filter, errors *Errors, loaded *Result, depth int, f func(*Result, string, int) error) { - - path, err := fileurl.Clean(path) - if err != nil { - errors.add(err) - return - } - - if err := checkForUNCPath(path); err != nil { - errors.add(err) - return - } - - var info fs.FileInfo - if fsys != nil { - info, err = fs.Stat(fsys, path) - } else { - info, err = os.Stat(path) - } - - if err != nil { - errors.add(err) - return - } - - if filter != nil && filter(path, info, depth) { - return - } - - if !info.IsDir() { - if err := f(loaded, path, depth); err != nil { - errors.add(err) - } - return - } - - // If we are recursing on directories then content must be loaded under path - // specified by directory hierarchy. - if depth > 0 { - loaded = loaded.withParent(info.Name()) - } - - var files []fs.DirEntry - if fsys != nil { - files, err = fs.ReadDir(fsys, path) - } else { - files, err = os.ReadDir(path) - } - if err != nil { - errors.add(err) - return - } - - for _, file := range files { - allRec(fsys, filepath.Join(path, file.Name()), filter, errors, loaded, depth+1, f) - } -} - -func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (interface{}, error) { - switch filepath.Ext(path) { - case ".json": - return loadJSON(path, bs, m) - case ".rego": - return loadRego(path, bs, m, opts) - case ".yaml", ".yml": - return loadYAML(path, bs, m) - default: - if strings.HasSuffix(path, ".tar.gz") { - r, err := loadBundleFile(path, bs, m, opts) - if err != nil { - err = fmt.Errorf("bundle %s: %w", path, err) - } - return r, err - } - } - return nil, unrecognizedFile(path) -} - -func loadFileForAnyType(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (interface{}, error) { - module, err := loadRego(path, bs, m, opts) - if err == nil { - return module, nil - } - doc, err := loadJSON(path, bs, m) - if err == nil { - return doc, nil - } - doc, err = loadYAML(path, bs, m) - if err == nil { - return doc, nil - } - return nil, unrecognizedFile(path) -} - -func loadBundleFile(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (bundle.Bundle, error) { - tl := bundle.NewTarballLoaderWithBaseURL(bytes.NewBuffer(bs), path) - br := bundle.NewCustomReader(tl). - WithRegoVersion(opts.RegoVersion). - WithJSONOptions(opts.JSONOptions). - WithProcessAnnotations(opts.ProcessAnnotation). - WithMetrics(m). - WithSkipBundleVerification(true). - IncludeManifestInData(true) - return br.Read() -} - -func loadRego(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (*RegoFile, error) { - m.Timer(metrics.RegoModuleParse).Start() - var module *ast.Module - var err error - module, err = ast.ParseModuleWithOpts(path, string(bs), opts) - m.Timer(metrics.RegoModuleParse).Stop() - if err != nil { - return nil, err - } - result := &RegoFile{ - Name: path, - Parsed: module, - Raw: bs, - } - return result, nil -} - -func loadJSON(path string, bs []byte, m metrics.Metrics) (interface{}, error) { - m.Timer(metrics.RegoDataParse).Start() - var x interface{} - err := util.UnmarshalJSON(bs, &x) - m.Timer(metrics.RegoDataParse).Stop() - - if err != nil { - return nil, fmt.Errorf("%s: %w", path, err) - } - return x, nil -} - -func loadYAML(path string, bs []byte, m metrics.Metrics) (interface{}, error) { - m.Timer(metrics.RegoDataParse).Start() - bs, err := yaml.YAMLToJSON(bs) - m.Timer(metrics.RegoDataParse).Stop() - if err != nil { - return nil, fmt.Errorf("%v: error converting YAML to JSON: %v", path, err) - } - return loadJSON(path, bs, m) -} - -func makeDir(path []string, x interface{}) (map[string]interface{}, bool) { - if len(path) == 0 { - obj, ok := x.(map[string]interface{}) - if !ok { - return nil, false - } - return obj, true - } - return makeDir(path[:len(path)-1], map[string]interface{}{path[len(path)-1]: x}) -} - -// isUNC reports whether path is a UNC path. -func isUNC(path string) bool { - return len(path) > 1 && isSlash(path[0]) && isSlash(path[1]) -} - -func isSlash(c uint8) bool { - return c == '\\' || c == '/' -} - -func checkForUNCPath(path string) error { - if isUNC(path) { - return fmt.Errorf("UNC path read is not allowed: %s", path) - } - return nil + return v1.SplitPrefix(path) } diff --git a/vendor/github.com/open-policy-agent/opa/logging/doc.go b/vendor/github.com/open-policy-agent/opa/logging/doc.go new file mode 100644 index 000000000..665cb369a --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/logging/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package logging diff --git a/vendor/github.com/open-policy-agent/opa/logging/logging.go b/vendor/github.com/open-policy-agent/opa/logging/logging.go index 7a1edfb56..f84deb86b 100644 --- a/vendor/github.com/open-policy-agent/opa/logging/logging.go +++ b/vendor/github.com/open-policy-agent/opa/logging/logging.go @@ -2,265 +2,78 @@ package logging import ( "context" - "io" - "net/http" - "github.com/sirupsen/logrus" + v1 "github.com/open-policy-agent/opa/v1/logging" ) // Level log level for Logger -type Level uint8 +type Level = v1.Level const ( // Error error log level - Error Level = iota + Error = v1.Error // Warn warn log level - Warn + Warn = v1.Warn // Info info log level - Info + Info = v1.Info // Debug debug log level - Debug + Debug = v1.Debug ) // Logger provides interface for OPA logger implementations -type Logger interface { - Debug(fmt string, a ...interface{}) - Info(fmt string, a ...interface{}) - Error(fmt string, a ...interface{}) - Warn(fmt string, a ...interface{}) - - WithFields(map[string]interface{}) Logger - - GetLevel() Level - SetLevel(Level) -} +type Logger = v1.Logger // StandardLogger is the default OPA logger implementation. -type StandardLogger struct { - logger *logrus.Logger - fields map[string]interface{} -} +type StandardLogger = v1.StandardLogger // New returns a new standard logger. func New() *StandardLogger { - return &StandardLogger{ - logger: logrus.New(), - } + return v1.New() } // Get returns the standard logger used throughout OPA. // // Deprecated. Do not rely on the global logger. func Get() *StandardLogger { - return &StandardLogger{ - logger: logrus.StandardLogger(), - } -} - -// SetOutput sets the underlying logrus output. -func (l *StandardLogger) SetOutput(w io.Writer) { - l.logger.SetOutput(w) -} - -// SetFormatter sets the underlying logrus formatter. -func (l *StandardLogger) SetFormatter(formatter logrus.Formatter) { - l.logger.SetFormatter(formatter) -} - -// WithFields provides additional fields to include in log output -func (l *StandardLogger) WithFields(fields map[string]interface{}) Logger { - cp := *l - cp.fields = make(map[string]interface{}) - for k, v := range l.fields { - cp.fields[k] = v - } - for k, v := range fields { - cp.fields[k] = v - } - return &cp -} - -// getFields returns additional fields of this logger -func (l *StandardLogger) getFields() map[string]interface{} { - return l.fields -} - -// SetLevel sets the standard logger level. -func (l *StandardLogger) SetLevel(level Level) { - var logrusLevel logrus.Level - switch level { - case Error: // set logging level report Warn or higher (includes Error) - logrusLevel = logrus.WarnLevel - case Warn: - logrusLevel = logrus.WarnLevel - case Info: - logrusLevel = logrus.InfoLevel - case Debug: - logrusLevel = logrus.DebugLevel - default: - l.Warn("unknown log level %v", level) - logrusLevel = logrus.InfoLevel - } - - l.logger.SetLevel(logrusLevel) -} - -// GetLevel returns the standard logger level. -func (l *StandardLogger) GetLevel() Level { - logrusLevel := l.logger.GetLevel() - - var level Level - switch logrusLevel { - case logrus.WarnLevel: - level = Error - case logrus.InfoLevel: - level = Info - case logrus.DebugLevel: - level = Debug - default: - l.Warn("unknown log level %v", logrusLevel) - level = Info - } - - return level -} - -// Debug logs at debug level -func (l *StandardLogger) Debug(fmt string, a ...interface{}) { - if len(a) == 0 { - l.logger.WithFields(l.getFields()).Debug(fmt) - return - } - l.logger.WithFields(l.getFields()).Debugf(fmt, a...) -} - -// Info logs at info level -func (l *StandardLogger) Info(fmt string, a ...interface{}) { - if len(a) == 0 { - l.logger.WithFields(l.getFields()).Info(fmt) - return - } - l.logger.WithFields(l.getFields()).Infof(fmt, a...) -} - -// Error logs at error level -func (l *StandardLogger) Error(fmt string, a ...interface{}) { - if len(a) == 0 { - l.logger.WithFields(l.getFields()).Error(fmt) - return - } - l.logger.WithFields(l.getFields()).Errorf(fmt, a...) -} - -// Warn logs at warn level -func (l *StandardLogger) Warn(fmt string, a ...interface{}) { - if len(a) == 0 { - l.logger.WithFields(l.getFields()).Warn(fmt) - return - } - l.logger.WithFields(l.getFields()).Warnf(fmt, a...) + return v1.Get() } // NoOpLogger logging implementation that does nothing -type NoOpLogger struct { - level Level - fields map[string]interface{} -} +type NoOpLogger = v1.NoOpLogger // NewNoOpLogger instantiates new NoOpLogger func NewNoOpLogger() *NoOpLogger { - return &NoOpLogger{ - level: Info, - } -} - -// WithFields provides additional fields to include in log output. -// Implemented here primarily to be able to switch between implementations without loss of data. -func (l *NoOpLogger) WithFields(fields map[string]interface{}) Logger { - cp := *l - cp.fields = fields - return &cp + return v1.NewNoOpLogger() } -// Debug noop -func (*NoOpLogger) Debug(string, ...interface{}) {} - -// Info noop -func (*NoOpLogger) Info(string, ...interface{}) {} - -// Error noop -func (*NoOpLogger) Error(string, ...interface{}) {} - -// Warn noop -func (*NoOpLogger) Warn(string, ...interface{}) {} - -// SetLevel set log level -func (l *NoOpLogger) SetLevel(level Level) { - l.level = level -} - -// GetLevel get log level -func (l *NoOpLogger) GetLevel() Level { - return l.level -} - -type requestContextKey string - -const reqCtxKey = requestContextKey("request-context-key") - // RequestContext represents the request context used to store data // related to the request that could be used on logs. -type RequestContext struct { - ClientAddr string - ReqID uint64 - ReqMethod string - ReqPath string - HTTPRequestContext HTTPRequestContext -} +type RequestContext = v1.RequestContext -type HTTPRequestContext struct { - Header http.Header -} - -// Fields adapts the RequestContext fields to logrus.Fields. -func (rctx RequestContext) Fields() logrus.Fields { - return logrus.Fields{ - "client_addr": rctx.ClientAddr, - "req_id": rctx.ReqID, - "req_method": rctx.ReqMethod, - "req_path": rctx.ReqPath, - } -} +type HTTPRequestContext = v1.HTTPRequestContext // NewContext returns a copy of parent with an associated RequestContext. func NewContext(parent context.Context, val *RequestContext) context.Context { - return context.WithValue(parent, reqCtxKey, val) + return v1.NewContext(parent, val) } // FromContext returns the RequestContext associated with ctx, if any. func FromContext(ctx context.Context) (*RequestContext, bool) { - requestContext, ok := ctx.Value(reqCtxKey).(*RequestContext) - return requestContext, ok + return v1.FromContext(ctx) } -const httpReqCtxKey = requestContextKey("http-request-context-key") - func WithHTTPRequestContext(parent context.Context, val *HTTPRequestContext) context.Context { - return context.WithValue(parent, httpReqCtxKey, val) + return v1.WithHTTPRequestContext(parent, val) } func HTTPRequestContextFromContext(ctx context.Context) (*HTTPRequestContext, bool) { - requestContext, ok := ctx.Value(httpReqCtxKey).(*HTTPRequestContext) - return requestContext, ok + return v1.HTTPRequestContextFromContext(ctx) } -const decisionCtxKey = requestContextKey("decision_id") - func WithDecisionID(parent context.Context, id string) context.Context { - return context.WithValue(parent, decisionCtxKey, id) + return v1.WithDecisionID(parent, id) } func DecisionIDFromContext(ctx context.Context) (string, bool) { - s, ok := ctx.Value(decisionCtxKey).(string) - return s, ok + return v1.DecisionIDFromContext(ctx) } diff --git a/vendor/github.com/open-policy-agent/opa/logging/test/doc.go b/vendor/github.com/open-policy-agent/opa/logging/test/doc.go new file mode 100644 index 000000000..dffe11ff7 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/logging/test/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package test diff --git a/vendor/github.com/open-policy-agent/opa/logging/test/test.go b/vendor/github.com/open-policy-agent/opa/logging/test/test.go index 5fd2dd47c..1379c1eea 100644 --- a/vendor/github.com/open-policy-agent/opa/logging/test/test.go +++ b/vendor/github.com/open-policy-agent/opa/logging/test/test.go @@ -1,101 +1,16 @@ package test import ( - "fmt" - "sync" - - "github.com/open-policy-agent/opa/logging" + v1 "github.com/open-policy-agent/opa/v1/logging/test" ) // LogEntry represents a log message. -type LogEntry struct { - Level logging.Level - Fields map[string]interface{} - Message string -} +type LogEntry = v1.LogEntry // Logger implementation that buffers messages for test purposes. -type Logger struct { - level logging.Level - fields map[string]interface{} - entries *[]LogEntry - mtx *sync.Mutex -} +type Logger = v1.Logger // New instantiates new Logger. func New() *Logger { - return &Logger{ - level: logging.Info, - entries: &[]LogEntry{}, - mtx: &sync.Mutex{}, - } -} - -// WithFields provides additional fields to include in log output. -// Implemented here primarily to be able to switch between implementations without loss of data. -func (l *Logger) WithFields(fields map[string]interface{}) logging.Logger { - l.mtx.Lock() - defer l.mtx.Unlock() - cp := Logger{ - level: l.level, - entries: l.entries, - fields: l.fields, - mtx: l.mtx, - } - flds := make(map[string]interface{}) - for k, v := range cp.fields { - flds[k] = v - } - for k, v := range fields { - flds[k] = v - } - cp.fields = flds - return &cp -} - -// Debug buffers a log message. -func (l *Logger) Debug(f string, a ...interface{}) { - l.append(logging.Debug, f, a...) -} - -// Info buffers a log message. -func (l *Logger) Info(f string, a ...interface{}) { - l.append(logging.Info, f, a...) -} - -// Error buffers a log message. -func (l *Logger) Error(f string, a ...interface{}) { - l.append(logging.Error, f, a...) -} - -// Warn buffers a log message. -func (l *Logger) Warn(f string, a ...interface{}) { - l.append(logging.Warn, f, a...) -} - -// SetLevel set log level. -func (l *Logger) SetLevel(level logging.Level) { - l.level = level -} - -// GetLevel get log level. -func (l *Logger) GetLevel() logging.Level { - return l.level -} - -// Entries returns buffered log entries. -func (l *Logger) Entries() []LogEntry { - l.mtx.Lock() - defer l.mtx.Unlock() - return *l.entries -} - -func (l *Logger) append(lvl logging.Level, f string, a ...interface{}) { - l.mtx.Lock() - defer l.mtx.Unlock() - *l.entries = append(*l.entries, LogEntry{ - Level: lvl, - Fields: l.fields, - Message: fmt.Sprintf(f, a...), - }) + return v1.New() } diff --git a/vendor/github.com/open-policy-agent/opa/metrics/doc.go b/vendor/github.com/open-policy-agent/opa/metrics/doc.go new file mode 100644 index 000000000..6a306991e --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/metrics/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package metrics diff --git a/vendor/github.com/open-policy-agent/opa/metrics/metrics.go b/vendor/github.com/open-policy-agent/opa/metrics/metrics.go index 53cd606a3..2d2ae78b1 100644 --- a/vendor/github.com/open-policy-agent/opa/metrics/metrics.go +++ b/vendor/github.com/open-policy-agent/opa/metrics/metrics.go @@ -6,307 +6,52 @@ package metrics import ( - "encoding/json" - "fmt" - "sort" - "strings" - "sync" - "sync/atomic" - "time" - - go_metrics "github.com/rcrowley/go-metrics" + v1 "github.com/open-policy-agent/opa/v1/metrics" ) // Well-known metric names. const ( - BundleRequest = "bundle_request" - ServerHandler = "server_handler" - ServerQueryCacheHit = "server_query_cache_hit" - SDKDecisionEval = "sdk_decision_eval" - RegoQueryCompile = "rego_query_compile" - RegoQueryEval = "rego_query_eval" - RegoQueryParse = "rego_query_parse" - RegoModuleParse = "rego_module_parse" - RegoDataParse = "rego_data_parse" - RegoModuleCompile = "rego_module_compile" - RegoPartialEval = "rego_partial_eval" - RegoInputParse = "rego_input_parse" - RegoLoadFiles = "rego_load_files" - RegoLoadBundles = "rego_load_bundles" - RegoExternalResolve = "rego_external_resolve" + BundleRequest = v1.BundleRequest + ServerHandler = v1.ServerHandler + ServerQueryCacheHit = v1.ServerQueryCacheHit + SDKDecisionEval = v1.SDKDecisionEval + RegoQueryCompile = v1.RegoQueryCompile + RegoQueryEval = v1.RegoQueryEval + RegoQueryParse = v1.RegoQueryParse + RegoModuleParse = v1.RegoModuleParse + RegoDataParse = v1.RegoDataParse + RegoModuleCompile = v1.RegoModuleCompile + RegoPartialEval = v1.RegoPartialEval + RegoInputParse = v1.RegoInputParse + RegoLoadFiles = v1.RegoLoadFiles + RegoLoadBundles = v1.RegoLoadBundles + RegoExternalResolve = v1.RegoExternalResolve ) // Info contains attributes describing the underlying metrics provider. -type Info struct { - Name string `json:"name"` // name is a unique human-readable identifier for the provider. -} +type Info = v1.Info // Metrics defines the interface for a collection of performance metrics in the // policy engine. -type Metrics interface { - Info() Info - Timer(name string) Timer - Histogram(name string) Histogram - Counter(name string) Counter - All() map[string]interface{} - Clear() - json.Marshaler -} - -type TimerMetrics interface { - Timers() map[string]interface{} -} +type Metrics = v1.Metrics -type metrics struct { - mtx sync.Mutex - timers map[string]Timer - histograms map[string]Histogram - counters map[string]Counter -} +type TimerMetrics = v1.TimerMetrics // New returns a new Metrics object. func New() Metrics { - m := &metrics{} - m.Clear() - return m -} - -type metric struct { - Key string - Value interface{} -} - -func (*metrics) Info() Info { - return Info{ - Name: "", - } -} - -func (m *metrics) String() string { - - all := m.All() - sorted := make([]metric, 0, len(all)) - - for key, value := range all { - sorted = append(sorted, metric{ - Key: key, - Value: value, - }) - } - - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Key < sorted[j].Key - }) - - buf := make([]string, len(sorted)) - for i := range sorted { - buf[i] = fmt.Sprintf("%v:%v", sorted[i].Key, sorted[i].Value) - } - - return strings.Join(buf, " ") -} - -func (m *metrics) MarshalJSON() ([]byte, error) { - return json.Marshal(m.All()) -} - -func (m *metrics) Timer(name string) Timer { - m.mtx.Lock() - defer m.mtx.Unlock() - t, ok := m.timers[name] - if !ok { - t = &timer{} - m.timers[name] = t - } - return t -} - -func (m *metrics) Histogram(name string) Histogram { - m.mtx.Lock() - defer m.mtx.Unlock() - h, ok := m.histograms[name] - if !ok { - h = newHistogram() - m.histograms[name] = h - } - return h -} - -func (m *metrics) Counter(name string) Counter { - m.mtx.Lock() - defer m.mtx.Unlock() - c, ok := m.counters[name] - if !ok { - zero := counter{} - c = &zero - m.counters[name] = c - } - return c -} - -func (m *metrics) All() map[string]interface{} { - m.mtx.Lock() - defer m.mtx.Unlock() - result := map[string]interface{}{} - for name, timer := range m.timers { - result[m.formatKey(name, timer)] = timer.Value() - } - for name, hist := range m.histograms { - result[m.formatKey(name, hist)] = hist.Value() - } - for name, cntr := range m.counters { - result[m.formatKey(name, cntr)] = cntr.Value() - } - return result -} - -func (m *metrics) Timers() map[string]interface{} { - m.mtx.Lock() - defer m.mtx.Unlock() - ts := map[string]interface{}{} - for n, t := range m.timers { - ts[m.formatKey(n, t)] = t.Value() - } - return ts -} - -func (m *metrics) Clear() { - m.mtx.Lock() - defer m.mtx.Unlock() - m.timers = map[string]Timer{} - m.histograms = map[string]Histogram{} - m.counters = map[string]Counter{} -} - -func (m *metrics) formatKey(name string, metrics interface{}) string { - switch metrics.(type) { - case Timer: - return "timer_" + name + "_ns" - case Histogram: - return "histogram_" + name - case Counter: - return "counter_" + name - default: - return name - } + return v1.New() } // Timer defines the interface for a restartable timer that accumulates elapsed // time. -type Timer interface { - Value() interface{} - Int64() int64 - Start() - Stop() int64 -} - -type timer struct { - mtx sync.Mutex - start time.Time - value int64 -} - -func (t *timer) Start() { - t.mtx.Lock() - defer t.mtx.Unlock() - t.start = time.Now() -} - -func (t *timer) Stop() int64 { - t.mtx.Lock() - defer t.mtx.Unlock() - delta := time.Since(t.start).Nanoseconds() - t.value += delta - return delta -} - -func (t *timer) Value() interface{} { - return t.Int64() -} - -func (t *timer) Int64() int64 { - t.mtx.Lock() - defer t.mtx.Unlock() - return t.value -} +type Timer = v1.Timer // Histogram defines the interface for a histogram with hardcoded percentiles. -type Histogram interface { - Value() interface{} - Update(int64) -} - -type histogram struct { - hist go_metrics.Histogram // is thread-safe because of the underlying ExpDecaySample -} - -func newHistogram() Histogram { - // NOTE(tsandall): the reservoir size and alpha factor are taken from - // https://github.com/rcrowley/go-metrics. They may need to be tweaked in - // the future. - sample := go_metrics.NewExpDecaySample(1028, 0.015) - hist := go_metrics.NewHistogram(sample) - return &histogram{hist} -} - -func (h *histogram) Update(v int64) { - h.hist.Update(v) -} - -func (h *histogram) Value() interface{} { - values := map[string]interface{}{} - snap := h.hist.Snapshot() - percentiles := snap.Percentiles([]float64{ - 0.5, - 0.75, - 0.9, - 0.95, - 0.99, - 0.999, - 0.9999, - }) - values["count"] = snap.Count() - values["min"] = snap.Min() - values["max"] = snap.Max() - values["mean"] = snap.Mean() - values["stddev"] = snap.StdDev() - values["median"] = percentiles[0] - values["75%"] = percentiles[1] - values["90%"] = percentiles[2] - values["95%"] = percentiles[3] - values["99%"] = percentiles[4] - values["99.9%"] = percentiles[5] - values["99.99%"] = percentiles[6] - return values -} +type Histogram = v1.Histogram // Counter defines the interface for a monotonic increasing counter. -type Counter interface { - Value() interface{} - Incr() - Add(n uint64) -} - -type counter struct { - c uint64 -} - -func (c *counter) Incr() { - atomic.AddUint64(&c.c, 1) -} - -func (c *counter) Add(n uint64) { - atomic.AddUint64(&c.c, n) -} - -func (c *counter) Value() interface{} { - return atomic.LoadUint64(&c.c) -} +type Counter = v1.Counter func Statistics(num ...int64) interface{} { - t := newHistogram() - for _, n := range num { - t.Update(n) - } - return t.Value() + return v1.Statistics(num...) } diff --git a/vendor/github.com/open-policy-agent/opa/plugins/doc.go b/vendor/github.com/open-policy-agent/opa/plugins/doc.go new file mode 100644 index 000000000..77dade9b7 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/plugins/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package plugins diff --git a/vendor/github.com/open-policy-agent/opa/plugins/logs/doc.go b/vendor/github.com/open-policy-agent/opa/plugins/logs/doc.go new file mode 100644 index 000000000..947df3601 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/plugins/logs/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package logs diff --git a/vendor/github.com/open-policy-agent/opa/plugins/logs/plugin.go b/vendor/github.com/open-policy-agent/opa/plugins/logs/plugin.go index 8d31798eb..206dc0c8c 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/logs/plugin.go +++ b/vendor/github.com/open-policy-agent/opa/plugins/logs/plugin.go @@ -6,1125 +6,60 @@ package logs import ( - "bytes" - "context" - "encoding/json" - "fmt" - "math" - "math/rand" - "net/url" - "reflect" - "strings" - "sync" - "time" - - "golang.org/x/time/rate" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/internal/ref" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/plugins" - lstat "github.com/open-policy-agent/opa/plugins/logs/status" - "github.com/open-policy-agent/opa/plugins/rest" - "github.com/open-policy-agent/opa/plugins/status" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/server" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/plugins/logs" ) // Logger defines the interface for decision logging plugins. -type Logger interface { - plugins.Plugin - - Log(context.Context, EventV1) error -} +type Logger = v1.Logger // EventV1 represents a decision log event. // WARNING: The AST() function for EventV1 must be kept in sync with // the struct. Any changes here MUST be reflected in the AST() // implementation below. -type EventV1 struct { - Labels map[string]string `json:"labels"` - DecisionID string `json:"decision_id"` - TraceID string `json:"trace_id,omitempty"` - SpanID string `json:"span_id,omitempty"` - Revision string `json:"revision,omitempty"` // Deprecated: Use Bundles instead - Bundles map[string]BundleInfoV1 `json:"bundles,omitempty"` - Path string `json:"path,omitempty"` - Query string `json:"query,omitempty"` - Input *interface{} `json:"input,omitempty"` - Result *interface{} `json:"result,omitempty"` - MappedResult *interface{} `json:"mapped_result,omitempty"` - NDBuiltinCache *interface{} `json:"nd_builtin_cache,omitempty"` - Erased []string `json:"erased,omitempty"` - Masked []string `json:"masked,omitempty"` - Error error `json:"error,omitempty"` - RequestedBy string `json:"requested_by,omitempty"` - Timestamp time.Time `json:"timestamp"` - Metrics map[string]interface{} `json:"metrics,omitempty"` - RequestID uint64 `json:"req_id,omitempty"` - RequestContext *RequestContext `json:"request_context,omitempty"` - - inputAST ast.Value -} +type EventV1 = v1.EventV1 // BundleInfoV1 describes a bundle associated with a decision log event. -type BundleInfoV1 struct { - Revision string `json:"revision,omitempty"` -} - -type RequestContext struct { - HTTPRequest *HTTPRequestContext `json:"http,omitempty"` -} - -type HTTPRequestContext struct { - Headers map[string][]string `json:"headers,omitempty"` -} - -// AST returns the BundleInfoV1 as an AST value -func (b *BundleInfoV1) AST() ast.Value { - result := ast.NewObject() - if len(b.Revision) > 0 { - result.Insert(ast.StringTerm("revision"), ast.StringTerm(b.Revision)) - } - return result -} - -// Key ast.Term values for the Rego AST representation of the EventV1 -var labelsKey = ast.StringTerm("labels") -var decisionIDKey = ast.StringTerm("decision_id") -var revisionKey = ast.StringTerm("revision") -var bundlesKey = ast.StringTerm("bundles") -var pathKey = ast.StringTerm("path") -var queryKey = ast.StringTerm("query") -var inputKey = ast.StringTerm("input") -var resultKey = ast.StringTerm("result") -var mappedResultKey = ast.StringTerm("mapped_result") -var ndBuiltinCacheKey = ast.StringTerm("nd_builtin_cache") -var erasedKey = ast.StringTerm("erased") -var maskedKey = ast.StringTerm("masked") -var errorKey = ast.StringTerm("error") -var requestedByKey = ast.StringTerm("requested_by") -var timestampKey = ast.StringTerm("timestamp") -var metricsKey = ast.StringTerm("metrics") -var requestIDKey = ast.StringTerm("req_id") - -// AST returns the Rego AST representation for a given EventV1 object. -// This avoids having to round trip through JSON while applying a decision log -// mask policy to the event. -func (e *EventV1) AST() (ast.Value, error) { - var err error - event := ast.NewObject() - - if e.Labels != nil { - labelsObj := ast.NewObject() - for k, v := range e.Labels { - labelsObj.Insert(ast.StringTerm(k), ast.StringTerm(v)) - } - event.Insert(labelsKey, ast.NewTerm(labelsObj)) - } else { - event.Insert(labelsKey, ast.NullTerm()) - } - - event.Insert(decisionIDKey, ast.StringTerm(e.DecisionID)) - - if len(e.Revision) > 0 { - event.Insert(revisionKey, ast.StringTerm(e.Revision)) - } - - if len(e.Bundles) > 0 { - bundlesObj := ast.NewObject() - for k, v := range e.Bundles { - bundlesObj.Insert(ast.StringTerm(k), ast.NewTerm(v.AST())) - } - event.Insert(bundlesKey, ast.NewTerm(bundlesObj)) - } - - if len(e.Path) > 0 { - event.Insert(pathKey, ast.StringTerm(e.Path)) - } - - if len(e.Query) > 0 { - event.Insert(queryKey, ast.StringTerm(e.Query)) - } - - if e.Input != nil { - if e.inputAST == nil { - e.inputAST, err = roundtripJSONToAST(e.Input) - if err != nil { - return nil, err - } - } - event.Insert(inputKey, ast.NewTerm(e.inputAST)) - } - - if e.Result != nil { - results, err := roundtripJSONToAST(e.Result) - if err != nil { - return nil, err - } - event.Insert(resultKey, ast.NewTerm(results)) - } - - if e.MappedResult != nil { - mResults, err := roundtripJSONToAST(e.MappedResult) - if err != nil { - return nil, err - } - event.Insert(mappedResultKey, ast.NewTerm(mResults)) - } - - if e.NDBuiltinCache != nil { - ndbCache, err := roundtripJSONToAST(e.NDBuiltinCache) - if err != nil { - return nil, err - } - event.Insert(ndBuiltinCacheKey, ast.NewTerm(ndbCache)) - } - - if len(e.Erased) > 0 { - erased := make([]*ast.Term, len(e.Erased)) - for i, v := range e.Erased { - erased[i] = ast.StringTerm(v) - } - event.Insert(erasedKey, ast.NewTerm(ast.NewArray(erased...))) - } - - if len(e.Masked) > 0 { - masked := make([]*ast.Term, len(e.Masked)) - for i, v := range e.Masked { - masked[i] = ast.StringTerm(v) - } - event.Insert(maskedKey, ast.NewTerm(ast.NewArray(masked...))) - } +type BundleInfoV1 = v1.BundleInfoV1 - if e.Error != nil { - evalErr, err := roundtripJSONToAST(e.Error) - if err != nil { - return nil, err - } - event.Insert(errorKey, ast.NewTerm(evalErr)) - } +type RequestContext = v1.RequestContext - if len(e.RequestedBy) > 0 { - event.Insert(requestedByKey, ast.StringTerm(e.RequestedBy)) - } - - // Use the timestamp JSON marshaller to ensure the format is the same as - // round tripping through JSON. - timeBytes, err := e.Timestamp.MarshalJSON() - if err != nil { - return nil, err - } - event.Insert(timestampKey, ast.StringTerm(strings.Trim(string(timeBytes), "\""))) - - if e.Metrics != nil { - m, err := ast.InterfaceToValue(e.Metrics) - if err != nil { - return nil, err - } - event.Insert(metricsKey, ast.NewTerm(m)) - } - - if e.RequestID > 0 { - event.Insert(requestIDKey, ast.UIntNumberTerm(e.RequestID)) - } - - return event, nil -} - -func roundtripJSONToAST(x interface{}) (ast.Value, error) { - rawPtr := util.Reference(x) - // roundtrip through json: this turns slices (e.g. []string, []bool) into - // []interface{}, the only array type ast.InterfaceToValue can work with - if err := util.RoundTrip(rawPtr); err != nil { - return nil, err - } - - return ast.InterfaceToValue(*rawPtr) -} - -const ( - // min amount of time to wait following a failure - minRetryDelay = time.Millisecond * 100 - defaultMinDelaySeconds = int64(300) - defaultMaxDelaySeconds = int64(600) - defaultUploadSizeLimitBytes = int64(32768) // 32KB limit - defaultBufferSizeLimitBytes = int64(0) // unlimited - defaultMaskDecisionPath = "/system/log/mask" - defaultDropDecisionPath = "/system/log/drop" - logRateLimitExDropCounterName = "decision_logs_dropped_rate_limit_exceeded" - logNDBDropCounterName = "decision_logs_nd_builtin_cache_dropped" - logBufferSizeLimitExDropCounterName = "decision_logs_dropped_buffer_size_limit_bytes_exceeded" - logEncodingFailureCounterName = "decision_logs_encoding_failure" - defaultResourcePath = "/logs" -) +type HTTPRequestContext = v1.HTTPRequestContext // ReportingConfig represents configuration for the plugin's reporting behaviour. -type ReportingConfig struct { - BufferSizeLimitBytes *int64 `json:"buffer_size_limit_bytes,omitempty"` // max size of in-memory buffer - UploadSizeLimitBytes *int64 `json:"upload_size_limit_bytes,omitempty"` // max size of upload payload - MinDelaySeconds *int64 `json:"min_delay_seconds,omitempty"` // min amount of time to wait between successful poll attempts - MaxDelaySeconds *int64 `json:"max_delay_seconds,omitempty"` // max amount of time to wait between poll attempts - MaxDecisionsPerSecond *float64 `json:"max_decisions_per_second,omitempty"` // max number of decision logs to buffer per second - Trigger *plugins.TriggerMode `json:"trigger,omitempty"` // trigger mode -} +type ReportingConfig = v1.ReportingConfig -type RequestContextConfig struct { - HTTPRequest *HTTPRequestContextConfig `json:"http,omitempty"` -} +type RequestContextConfig = v1.RequestContextConfig -type HTTPRequestContextConfig struct { - Headers []string `json:"headers,omitempty"` -} +type HTTPRequestContextConfig = v1.HTTPRequestContextConfig // Config represents the plugin configuration. -type Config struct { - Plugin *string `json:"plugin"` - Service string `json:"service"` - PartitionName string `json:"partition_name,omitempty"` - Reporting ReportingConfig `json:"reporting"` - RequestContext RequestContextConfig `json:"request_context"` - MaskDecision *string `json:"mask_decision"` - DropDecision *string `json:"drop_decision"` - ConsoleLogs bool `json:"console"` - Resource *string `json:"resource"` - NDBuiltinCache bool `json:"nd_builtin_cache,omitempty"` - maskDecisionRef ast.Ref - dropDecisionRef ast.Ref -} - -func (c *Config) validateAndInjectDefaults(services []string, pluginsList []string, trigger *plugins.TriggerMode) error { - - if c.Plugin != nil { - var found bool - for _, other := range pluginsList { - if other == *c.Plugin { - found = true - break - } - } - if !found { - return fmt.Errorf("invalid plugin name %q in decision_logs", *c.Plugin) - } - } else if c.Service == "" && len(services) != 0 && !c.ConsoleLogs { - // For backwards compatibility allow defaulting to the first - // service listed, but only if console logging is disabled. If enabled - // we can't tell if the deployer wanted to use only console logs or - // both console logs and the default service option. - c.Service = services[0] - } else if c.Service != "" { - found := false - - for _, svc := range services { - if svc == c.Service { - found = true - break - } - } - - if !found { - return fmt.Errorf("invalid service name %q in decision_logs", c.Service) - } - } - - t, err := plugins.ValidateAndInjectDefaultsForTriggerMode(trigger, c.Reporting.Trigger) - if err != nil { - return fmt.Errorf("invalid decision_log config: %w", err) - } - c.Reporting.Trigger = t - - min := defaultMinDelaySeconds - max := defaultMaxDelaySeconds - - // reject bad min/max values - if c.Reporting.MaxDelaySeconds != nil && c.Reporting.MinDelaySeconds != nil { - if *c.Reporting.MaxDelaySeconds < *c.Reporting.MinDelaySeconds { - return fmt.Errorf("max reporting delay must be >= min reporting delay in decision_logs") - } - min = *c.Reporting.MinDelaySeconds - max = *c.Reporting.MaxDelaySeconds - } else if c.Reporting.MaxDelaySeconds == nil && c.Reporting.MinDelaySeconds != nil { - return fmt.Errorf("reporting configuration missing 'max_delay_seconds' in decision_logs") - } else if c.Reporting.MinDelaySeconds == nil && c.Reporting.MaxDelaySeconds != nil { - return fmt.Errorf("reporting configuration missing 'min_delay_seconds' in decision_logs") - } - - // scale to seconds - minSeconds := int64(time.Duration(min) * time.Second) - c.Reporting.MinDelaySeconds = &minSeconds - - maxSeconds := int64(time.Duration(max) * time.Second) - c.Reporting.MaxDelaySeconds = &maxSeconds - - // default the upload size limit - uploadLimit := defaultUploadSizeLimitBytes - if c.Reporting.UploadSizeLimitBytes != nil { - uploadLimit = *c.Reporting.UploadSizeLimitBytes - } - - c.Reporting.UploadSizeLimitBytes = &uploadLimit - - if c.Reporting.BufferSizeLimitBytes != nil && c.Reporting.MaxDecisionsPerSecond != nil { - return fmt.Errorf("invalid decision_log config, specify either 'buffer_size_limit_bytes' or 'max_decisions_per_second'") - } - - // default the buffer size limit - bufferLimit := defaultBufferSizeLimitBytes - if c.Reporting.BufferSizeLimitBytes != nil { - bufferLimit = *c.Reporting.BufferSizeLimitBytes - } - - c.Reporting.BufferSizeLimitBytes = &bufferLimit - - if c.MaskDecision == nil { - maskDecision := defaultMaskDecisionPath - c.MaskDecision = &maskDecision - } - - c.maskDecisionRef, err = ref.ParseDataPath(*c.MaskDecision) - if err != nil { - return fmt.Errorf("invalid mask_decision in decision_logs: %w", err) - } - - if c.DropDecision == nil { - dropDecision := defaultDropDecisionPath - c.DropDecision = &dropDecision - } - - c.dropDecisionRef, err = ref.ParseDataPath(*c.DropDecision) - if err != nil { - return fmt.Errorf("invalid drop_decision in decision_logs: %w", err) - } - - if c.PartitionName != "" { - resourcePath := fmt.Sprintf("/logs/%v", c.PartitionName) - c.Resource = &resourcePath - } else if c.Resource == nil { - resourcePath := defaultResourcePath - c.Resource = &resourcePath - } else { - if _, err := url.Parse(*c.Resource); err != nil { - return fmt.Errorf("invalid resource path %q: %w", *c.Resource, err) - } - } - - return nil -} +type Config = v1.Config // Plugin implements decision log buffering and uploading. -type Plugin struct { - manager *plugins.Manager - config Config - buffer *logBuffer - enc *chunkEncoder - mtx sync.Mutex - statusMtx sync.Mutex - stop chan chan struct{} - reconfig chan reconfigure - preparedMask prepareOnce - preparedDrop prepareOnce - limiter *rate.Limiter - metrics metrics.Metrics - logger logging.Logger - status *lstat.Status -} - -type prepareOnce struct { - once *sync.Once - preparedQuery *rego.PreparedEvalQuery - err error -} - -func newPrepareOnce() *prepareOnce { - return &prepareOnce{ - once: new(sync.Once), - } -} - -func (po *prepareOnce) drop() { - po.once = new(sync.Once) -} - -func (po *prepareOnce) prepareOnce(f func() (*rego.PreparedEvalQuery, error)) (*rego.PreparedEvalQuery, error) { - po.once.Do(func() { - po.preparedQuery, po.err = f() - }) - return po.preparedQuery, po.err -} - -type reconfigure struct { - config interface{} - done chan struct{} -} +type Plugin = v1.Plugin -// ParseConfig validates the config and injects default values. func ParseConfig(config []byte, services []string, pluginList []string) (*Config, error) { - t := plugins.DefaultTriggerMode - return NewConfigBuilder(). - WithBytes(config). - WithServices(services). - WithPlugins(pluginList). - WithTriggerMode(&t). - Parse() + return v1.ParseConfig(config, services, pluginList) } // ConfigBuilder assists in the construction of the plugin configuration. -type ConfigBuilder struct { - raw []byte - services []string - plugins []string - trigger *plugins.TriggerMode -} +type ConfigBuilder = v1.ConfigBuilder // NewConfigBuilder returns a new ConfigBuilder to build and parse the plugin config. func NewConfigBuilder() *ConfigBuilder { - return &ConfigBuilder{} -} - -// WithBytes sets the raw plugin config. -func (b *ConfigBuilder) WithBytes(config []byte) *ConfigBuilder { - b.raw = config - return b -} - -// WithServices sets the services that implement control plane APIs. -func (b *ConfigBuilder) WithServices(services []string) *ConfigBuilder { - b.services = services - return b -} - -// WithPlugins sets the list of named plugins for decision logging. -func (b *ConfigBuilder) WithPlugins(plugins []string) *ConfigBuilder { - b.plugins = plugins - return b -} - -// WithTriggerMode sets the plugin trigger mode. -func (b *ConfigBuilder) WithTriggerMode(trigger *plugins.TriggerMode) *ConfigBuilder { - b.trigger = trigger - return b -} - -// Parse validates the config and injects default values. -func (b *ConfigBuilder) Parse() (*Config, error) { - if b.raw == nil { - return nil, nil - } - - var parsedConfig Config - - if err := util.Unmarshal(b.raw, &parsedConfig); err != nil { - return nil, err - } - - if parsedConfig.Plugin == nil && parsedConfig.Service == "" && len(b.services) == 0 && !parsedConfig.ConsoleLogs { - // Nothing to validate or inject - return nil, nil - } - - if err := parsedConfig.validateAndInjectDefaults(b.services, b.plugins, b.trigger); err != nil { - return nil, err - } - - return &parsedConfig, nil + return v1.NewConfigBuilder() } // New returns a new Plugin with the given config. func New(parsedConfig *Config, manager *plugins.Manager) *Plugin { - - plugin := &Plugin{ - manager: manager, - config: *parsedConfig, - stop: make(chan chan struct{}), - buffer: newLogBuffer(*parsedConfig.Reporting.BufferSizeLimitBytes), - enc: newChunkEncoder(*parsedConfig.Reporting.UploadSizeLimitBytes), - reconfig: make(chan reconfigure), - logger: manager.Logger().WithFields(map[string]interface{}{"plugin": Name}), - status: &lstat.Status{}, - preparedDrop: *newPrepareOnce(), - preparedMask: *newPrepareOnce(), - } - - if parsedConfig.Reporting.MaxDecisionsPerSecond != nil { - limit := *parsedConfig.Reporting.MaxDecisionsPerSecond - plugin.limiter = rate.NewLimiter(rate.Limit(limit), int(math.Max(1, limit))) - } - - manager.RegisterCompilerTrigger(plugin.compilerUpdated) - - manager.UpdatePluginStatus(Name, &plugins.Status{State: plugins.StateNotReady}) - - return plugin -} - -// WithMetrics sets the global metrics provider to be used by the plugin. -func (p *Plugin) WithMetrics(m metrics.Metrics) *Plugin { - p.metrics = m - p.enc.WithMetrics(m) - return p + return v1.New(parsedConfig, manager) } // Name identifies the plugin on manager. -const Name = "decision_logs" +const Name = v1.Name // Lookup returns the decision logs plugin registered with the manager. func Lookup(manager *plugins.Manager) *Plugin { - if p := manager.Plugin(Name); p != nil { - return p.(*Plugin) - } - return nil -} - -// Start starts the plugin. -func (p *Plugin) Start(_ context.Context) error { - p.logger.Info("Starting decision logger.") - go p.loop() - p.manager.UpdatePluginStatus(Name, &plugins.Status{State: plugins.StateOK}) - return nil -} - -// Stop stops the plugin. -func (p *Plugin) Stop(ctx context.Context) { - p.logger.Info("Stopping decision logger.") - - if *p.config.Reporting.Trigger == plugins.TriggerPeriodic { - if _, ok := ctx.Deadline(); ok && p.config.Service != "" { - p.flushDecisions(ctx) - } - } - - done := make(chan struct{}) - p.stop <- done - <-done - p.manager.UpdatePluginStatus(Name, &plugins.Status{State: plugins.StateNotReady}) -} - -// Config returns the plugin's current configuration -func (p *Plugin) Config() *Config { - return &p.config -} - -func (p *Plugin) flushDecisions(ctx context.Context) { - p.logger.Info("Flushing decision logs.") - - done := make(chan bool) - - go func(ctx context.Context, done chan bool) { - for ctx.Err() == nil { - if _, err := p.oneShot(ctx); err != nil { - p.logger.Error("Error flushing decisions: %s", err) - // Wait some before retrying, but skip incrementing interval since we are shutting down - time.Sleep(1 * time.Second) - } else { - done <- true - return - } - } - }(ctx, done) - - select { - case <-done: - p.logger.Info("All decisions in buffer uploaded.") - case <-ctx.Done(): - switch ctx.Err() { - case context.DeadlineExceeded, context.Canceled: - p.logger.Error("Plugin stopped with decisions possibly still in buffer.") - } - } -} - -// Log appends a decision log event to the buffer for uploading. -func (p *Plugin) Log(ctx context.Context, decision *server.Info) error { - - bundles := map[string]BundleInfoV1{} - for name, info := range decision.Bundles { - bundles[name] = BundleInfoV1{Revision: info.Revision} - } - - event := EventV1{ - Labels: p.manager.Labels(), - DecisionID: decision.DecisionID, - TraceID: decision.TraceID, - SpanID: decision.SpanID, - Revision: decision.Revision, - Bundles: bundles, - Path: decision.Path, - Query: decision.Query, - Input: decision.Input, - Result: decision.Results, - MappedResult: decision.MappedResults, - NDBuiltinCache: decision.NDBuiltinCache, - RequestedBy: decision.RemoteAddr, - Timestamp: decision.Timestamp, - RequestID: decision.RequestID, - inputAST: decision.InputAST, - } - - headers := map[string][]string{} - rctx := p.config.RequestContext - - if rctx.HTTPRequest != nil && len(rctx.HTTPRequest.Headers) > 0 && decision.HTTPRequestContext.Header != nil { - for _, h := range rctx.HTTPRequest.Headers { - values := decision.HTTPRequestContext.Header.Values(h) - if len(values) > 0 { - headers[h] = decision.HTTPRequestContext.Header.Values(h) - } - } - } - - if len(headers) > 0 { - event.RequestContext = &RequestContext{HTTPRequest: &HTTPRequestContext{Headers: headers}} - } - - input, err := event.AST() - if err != nil { - return err - } - - drop, err := p.dropEvent(ctx, decision.Txn, input) - if err != nil { - p.logger.Error("Log drop decision failed: %v.", err) - return nil - } - - if drop { - p.logger.Debug("Decision log event to path %v dropped", event.Path) - return nil - } - - if decision.Metrics != nil { - event.Metrics = decision.Metrics.All() - } - - if decision.Error != nil { - event.Error = decision.Error - } - - if err := p.maskEvent(ctx, decision.Txn, input, &event); err != nil { - // TODO(tsandall): see note below about error handling. - p.logger.Error("Log event masking failed: %v.", err) - return nil - } - - if p.config.ConsoleLogs { - if err := p.logEvent(event); err != nil { - p.logger.Error("Failed to log to console: %v.", err) - } - } - - if p.config.Service != "" { - p.encodeAndBufferEvent(event) - } - - if p.config.Plugin != nil { - proxy, ok := p.manager.Plugin(*p.config.Plugin).(Logger) - if !ok { - return fmt.Errorf("plugin does not implement Logger interface") - } - return proxy.Log(ctx, event) - } - - return nil -} - -// Reconfigure notifies the plugin with a new configuration. -func (p *Plugin) Reconfigure(_ context.Context, config interface{}) { - - done := make(chan struct{}) - p.reconfig <- reconfigure{config: config, done: done} - - p.preparedMask.drop() - p.preparedDrop.drop() - - <-done -} - -// Trigger can be used to control when the plugin attempts to upload -// a new decision log in manual triggering mode. -func (p *Plugin) Trigger(ctx context.Context) error { - done := make(chan error) - - go func() { - if p.config.Service != "" { - err := p.doOneShot(ctx) - if err != nil { - if ctx.Err() == nil { - done <- err - } - } - } - close(done) - }() - - select { - case err := <-done: - return err - case <-ctx.Done(): - return ctx.Err() - } -} - -// compilerUpdated is called when a compiler trigger on the plugin manager -// fires. This indicates a new compiler instance is available. The decision -// logger needs to prepare a new masking query. -func (p *Plugin) compilerUpdated(storage.Transaction) { - p.preparedMask.drop() - p.preparedDrop.drop() -} - -func (p *Plugin) loop() { - - ctx, cancel := context.WithCancel(context.Background()) - - var retry int - - for { - - var waitC chan struct{} - - if *p.config.Reporting.Trigger == plugins.TriggerPeriodic && p.config.Service != "" { - err := p.doOneShot(ctx) - - var delay time.Duration - - if err == nil { - min := float64(*p.config.Reporting.MinDelaySeconds) - max := float64(*p.config.Reporting.MaxDelaySeconds) - delay = time.Duration(((max - min) * rand.Float64()) + min) - } else { - delay = util.DefaultBackoff(float64(minRetryDelay), float64(*p.config.Reporting.MaxDelaySeconds), retry) - } - - p.logger.Debug("Waiting %v before next upload/retry.", delay) - - waitC = make(chan struct{}) - go func() { - timer, timerCancel := util.TimerWithCancel(delay) - select { - case <-timer.C: - if err != nil { - retry++ - } else { - retry = 0 - } - close(waitC) - case <-ctx.Done(): - timerCancel() // explicitly cancel the timer. - } - }() - } - - select { - case <-waitC: - case update := <-p.reconfig: - p.reconfigure(update.config) - update.done <- struct{}{} - case done := <-p.stop: - cancel() - done <- struct{}{} - return - } - } -} - -func (p *Plugin) doOneShot(ctx context.Context) error { - uploaded, err := p.oneShot(ctx) - - // Make a local copy of the plugins's status. - p.statusMtx.Lock() - p.status.SetError(err) - oldStatus := p.status - p.statusMtx.Unlock() - - if s := status.Lookup(p.manager); s != nil { - s.UpdateDecisionLogsStatus(*oldStatus) - } - - if err != nil { - p.logger.Error("%v.", err) - } else if uploaded { - p.logger.Info("Logs uploaded successfully.") - } else { - p.logger.Debug("Log upload queue was empty.") - } - return err -} - -func (p *Plugin) oneShot(ctx context.Context) (ok bool, err error) { - // Make a local copy of the plugins's encoder and buffer and create - // a new encoder and buffer. This is needed as locking the buffer for - // the upload duration will block policy evaluation and result in - // increased latency for OPA clients - p.mtx.Lock() - oldChunkEnc := p.enc - oldBuffer := p.buffer - p.buffer = newLogBuffer(*p.config.Reporting.BufferSizeLimitBytes) - p.enc = newChunkEncoder(*p.config.Reporting.UploadSizeLimitBytes).WithMetrics(p.metrics) - p.mtx.Unlock() - - // Along with uploading the compressed events in the buffer - // to the remote server, flush any pending compressed data to the - // underlying writer and add to the buffer. - chunk, err := oldChunkEnc.Flush() - if err != nil { - return false, err - } - - for _, ch := range chunk { - p.bufferChunk(oldBuffer, ch) - } - - if oldBuffer.Len() == 0 { - return false, nil - } - - for bs := oldBuffer.Pop(); bs != nil; bs = oldBuffer.Pop() { - if err == nil { - err = uploadChunk(ctx, p.manager.Client(p.config.Service), *p.config.Resource, bs) - } - if err != nil { - if p.limiter != nil { - events, decErr := newChunkDecoder(bs).decode() - if decErr != nil { - continue - } - - for _, event := range events { - p.encodeAndBufferEvent(event) - } - } else { - // requeue the chunk - p.mtx.Lock() - p.bufferChunk(p.buffer, bs) - p.mtx.Unlock() - } - } - } - - return err == nil, err -} - -func (p *Plugin) reconfigure(config interface{}) { - - newConfig := config.(*Config) - - if reflect.DeepEqual(p.config, *newConfig) { - p.logger.Debug("Decision log uploader configuration unchanged.") - return - } - - p.logger.Info("Decision log uploader configuration changed.") - p.config = *newConfig -} - -// NOTE(philipc): Because ND builtins caching can cause unbounded growth in -// decision log entry size, we do best-effort event encoding here, and when we -// run out of space, we drop the ND builtins cache, and try encoding again. -func (p *Plugin) encodeAndBufferEvent(event EventV1) { - if p.limiter != nil { - if !p.limiter.Allow() { - if p.metrics != nil { - p.metrics.Counter(logRateLimitExDropCounterName).Incr() - } - - p.logger.Error("Decision log dropped as rate limit exceeded. Reduce reporting interval or increase rate limit.") - return - } - } - - result, err := p.encodeEvent(event) - if err != nil { - // If there's no ND builtins cache in the event, then we don't - // need to retry encoding anything. - if event.NDBuiltinCache == nil { - // TODO(tsandall): revisit this now that we have an API that - // can return an error. Should the default behaviour be to - // fail-closed as we do for plugins? - - if p.metrics != nil { - p.metrics.Counter(logEncodingFailureCounterName).Incr() - } - p.logger.Error("Log encoding failed: %v.", err) - return - } - - // Attempt to encode the event again, dropping the ND builtins cache. - newEvent := event - newEvent.NDBuiltinCache = nil - - result, err = p.encodeEvent(newEvent) - if err != nil { - if p.metrics != nil { - p.metrics.Counter(logEncodingFailureCounterName).Incr() - } - p.logger.Error("Log encoding failed: %v.", err) - return - } - - // Re-encoding was successful, but we still need to alert users. - p.logger.Error("ND builtins cache dropped from this event to fit under maximum upload size limits. Increase upload size limit or change usage of non-deterministic builtins.") - p.metrics.Counter(logNDBDropCounterName).Incr() - } - - p.mtx.Lock() - defer p.mtx.Unlock() - for _, chunk := range result { - p.bufferChunk(p.buffer, chunk) - } -} - -func (p *Plugin) encodeEvent(event EventV1) ([][]byte, error) { - var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(event); err != nil { - return nil, err - } - - p.mtx.Lock() - defer p.mtx.Unlock() - return p.enc.WriteBytes(buf.Bytes()) -} - -func (p *Plugin) bufferChunk(buffer *logBuffer, bs []byte) { - dropped := buffer.Push(bs) - if dropped > 0 { - if p.metrics != nil { - p.metrics.Counter(logBufferSizeLimitExDropCounterName).Incr() - } - p.logger.Error("Dropped %v chunks from buffer. Reduce reporting interval or increase buffer size.", dropped) - } -} - -func (p *Plugin) maskEvent(ctx context.Context, txn storage.Transaction, input ast.Value, event *EventV1) error { - pq, err := p.preparedMask.prepareOnce(func() (*rego.PreparedEvalQuery, error) { - var pq rego.PreparedEvalQuery - - query := ast.NewBody(ast.NewExpr(ast.NewTerm(p.config.maskDecisionRef))) - - r := rego.New( - rego.ParsedQuery(query), - rego.Compiler(p.manager.GetCompiler()), - rego.Store(p.manager.Store), - rego.Transaction(txn), - rego.Runtime(p.manager.Info), - rego.EnablePrintStatements(p.manager.EnablePrintStatements()), - rego.PrintHook(p.manager.PrintHook()), - ) - - pq, err := r.PrepareForEval(context.Background()) - if err != nil { - return nil, err - } - return &pq, nil - }) - - if err != nil { - return err - } - - rs, err := pq.Eval( - ctx, - rego.EvalParsedInput(input), - rego.EvalTransaction(txn), - ) - - if err != nil { - return err - } else if len(rs) == 0 { - return nil - } - - mRuleSet, err := newMaskRuleSet( - rs[0].Expressions[0].Value, - func(mRule *maskRule, err error) { - p.logger.Error("mask rule skipped: %s: %s", mRule.String(), err.Error()) - }, - ) - if err != nil { - return err - } - - mRuleSet.Mask(event) - - return nil -} - -func (p *Plugin) dropEvent(ctx context.Context, txn storage.Transaction, input ast.Value) (bool, error) { - var err error - - pq, err := p.preparedDrop.prepareOnce(func() (*rego.PreparedEvalQuery, error) { - var pq rego.PreparedEvalQuery - - query := ast.NewBody(ast.NewExpr(ast.NewTerm(p.config.dropDecisionRef))) - r := rego.New( - rego.ParsedQuery(query), - rego.Compiler(p.manager.GetCompiler()), - rego.Store(p.manager.Store), - rego.Transaction(txn), - rego.Runtime(p.manager.Info), - rego.EnablePrintStatements(p.manager.EnablePrintStatements()), - rego.PrintHook(p.manager.PrintHook()), - ) - - pq, err := r.PrepareForEval(context.Background()) - if err != nil { - return nil, err - } - return &pq, nil - }) - - if err != nil { - return false, err - } - - rs, err := pq.Eval( - ctx, - rego.EvalParsedInput(input), - rego.EvalTransaction(txn), - ) - - if err != nil { - return false, err - } - - return rs.Allowed(), nil -} - -func uploadChunk(ctx context.Context, client rest.Client, uploadPath string, data []byte) error { - - resp, err := client. - WithHeader("Content-Type", "application/json"). - WithHeader("Content-Encoding", "gzip"). - WithBytes(data). - Do(ctx, "POST", uploadPath) - - if err != nil { - return fmt.Errorf("log upload failed: %w", err) - } - - defer util.Close(resp) - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return lstat.HTTPError{StatusCode: resp.StatusCode} - } - - return nil -} - -func (p *Plugin) logEvent(event EventV1) error { - eventBuf, err := json.Marshal(&event) - if err != nil { - return err - } - fields := map[string]interface{}{} - err = util.UnmarshalJSON(eventBuf, &fields) - if err != nil { - return err - } - p.manager.ConsoleLogger().WithFields(fields).WithFields(map[string]interface{}{ - "type": "openpolicyagent.org/decision_logs", - }).Info("Decision Log") - return nil + return v1.Lookup(manager) } diff --git a/vendor/github.com/open-policy-agent/opa/plugins/plugins.go b/vendor/github.com/open-policy-agent/opa/plugins/plugins.go index 567acfb81..b050979e3 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/plugins.go +++ b/vendor/github.com/open-policy-agent/opa/plugins/plugins.go @@ -6,13 +6,6 @@ package plugins import ( - "context" - "errors" - "fmt" - mr "math/rand" - "sync" - "time" - "github.com/open-policy-agent/opa/internal/report" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel/sdk/trace" @@ -21,20 +14,14 @@ import ( "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/config" "github.com/open-policy-agent/opa/hooks" - bundleUtils "github.com/open-policy-agent/opa/internal/bundle" - cfg "github.com/open-policy-agent/opa/internal/config" - initload "github.com/open-policy-agent/opa/internal/runtime/init" - "github.com/open-policy-agent/opa/keys" "github.com/open-policy-agent/opa/loader" "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/plugins/rest" "github.com/open-policy-agent/opa/resolver/wasm" "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown/cache" "github.com/open-policy-agent/opa/topdown/print" "github.com/open-policy-agent/opa/tracing" + v1 "github.com/open-policy-agent/opa/v1/plugins" ) // Factory defines the interface OPA uses to instantiate your plugin. @@ -84,10 +71,7 @@ import ( // // After a plugin has been created subsequent status updates can be // send anytime the plugin enters a ready or error state. -type Factory interface { - Validate(manager *Manager, config []byte) (interface{}, error) - New(manager *Manager, config interface{}) Plugin -} +type Factory = v1.Factory // Plugin defines the interface OPA uses to manage your plugin. // @@ -101,139 +85,70 @@ type Factory interface { // When OPA receives new configuration for your plugin via discovery // it will first Validate the configuration using your factory and // then call Reconfigure. -type Plugin interface { - Start(ctx context.Context) error - Stop(ctx context.Context) - Reconfigure(ctx context.Context, config interface{}) -} +type Plugin = v1.Plugin // Triggerable defines the interface plugins use for manual plugin triggers. -type Triggerable interface { - Trigger(context.Context) error -} +type Triggerable = v1.Triggerable // State defines the state that a Plugin instance is currently // in with pre-defined states. -type State string +type State = v1.State const ( // StateNotReady indicates that the Plugin is not in an error state, but isn't // ready for normal operation yet. This should only happen at // initialization time. - StateNotReady State = "NOT_READY" + StateNotReady = v1.StateNotReady // StateOK signifies that the Plugin is operating normally. - StateOK State = "OK" + StateOK = v1.StateOK // StateErr indicates that the Plugin is in an error state and should not // be considered as functional. - StateErr State = "ERROR" + StateErr = v1.StateErr // StateWarn indicates the Plugin is operating, but in a potentially dangerous or // degraded state. It may be used to indicate manual remediation is needed, or to // alert admins of some other noteworthy state. - StateWarn State = "WARN" + StateWarn = v1.StateWarn ) // TriggerMode defines the trigger mode utilized by a Plugin for bundle download, // log upload etc. -type TriggerMode string +type TriggerMode = v1.TriggerMode const ( // TriggerPeriodic represents periodic polling mechanism - TriggerPeriodic TriggerMode = "periodic" + TriggerPeriodic = v1.TriggerPeriodic // TriggerManual represents manual triggering mechanism - TriggerManual TriggerMode = "manual" + TriggerManual = v1.TriggerManual // DefaultTriggerMode represents default trigger mechanism - DefaultTriggerMode TriggerMode = "periodic" + DefaultTriggerMode = v1.DefaultTriggerMode ) -// default interval between OPA report uploads -var defaultUploadIntervalSec = int64(3600) - // Status has a Plugin's current status plus an optional Message. -type Status struct { - State State `json:"state"` - Message string `json:"message,omitempty"` -} - -func (s *Status) String() string { - return fmt.Sprintf("{%v %q}", s.State, s.Message) -} +type Status = v1.Status // StatusListener defines a handler to register for status updates. -type StatusListener func(status map[string]*Status) +type StatusListener v1.StatusListener // Manager implements lifecycle management of plugins and gives plugins access // to engine-wide components like storage. -type Manager struct { - Store storage.Store - Config *config.Config - Info *ast.Term - ID string - - compiler *ast.Compiler - compilerMux sync.RWMutex - wasmResolvers []*wasm.Resolver - wasmResolversMtx sync.RWMutex - services map[string]rest.Client - keys map[string]*keys.Config - plugins []namedplugin - registeredTriggers []func(storage.Transaction) - mtx sync.Mutex - pluginStatus map[string]*Status - pluginStatusListeners map[string]StatusListener - initBundles map[string]*bundle.Bundle - initFiles loader.Result - maxErrors int - initialized bool - interQueryBuiltinCacheConfig *cache.Config - gracefulShutdownPeriod int - registeredCacheTriggers []func(*cache.Config) - logger logging.Logger - consoleLogger logging.Logger - serverInitialized chan struct{} - serverInitializedOnce sync.Once - printHook print.Hook - enablePrintStatements bool - router *mux.Router - prometheusRegister prometheus.Registerer - tracerProvider *trace.TracerProvider - distributedTacingOpts tracing.Options - registeredNDCacheTriggers []func(bool) - registeredTelemetryGatherers map[string]report.Gatherer - bootstrapConfigLabels map[string]string - hooks hooks.Hooks - enableTelemetry bool - reporter *report.Reporter - opaReportNotifyCh chan struct{} - stop chan chan struct{} - parserOptions ast.ParserOptions -} - -type managerContextKey string -type managerWasmResolverKey string - -const managerCompilerContextKey = managerContextKey("compiler") -const managerWasmResolverContextKey = managerWasmResolverKey("wasmResolvers") +type Manager = v1.Manager // SetCompilerOnContext puts the compiler into the storage context. Calling this // function before committing updated policies to storage allows the manager to // skip parsing and compiling of modules. Instead, the manager will use the // compiler that was stored on the context. func SetCompilerOnContext(context *storage.Context, compiler *ast.Compiler) { - context.Put(managerCompilerContextKey, compiler) + v1.SetCompilerOnContext(context, compiler) } // GetCompilerOnContext gets the compiler cached on the storage context. func GetCompilerOnContext(context *storage.Context) *ast.Compiler { - compiler, ok := context.Get(managerCompilerContextKey).(*ast.Compiler) - if !ok { - return nil - } - return compiler + return v1.GetCompilerOnContext(context) } // SetWasmResolversOnContext puts a set of Wasm Resolvers into the storage @@ -241,858 +156,113 @@ func GetCompilerOnContext(context *storage.Context) *ast.Compiler { // storage allows the manager to skip initializing modules before using them. // Instead, the manager will use the compiler that was stored on the context. func SetWasmResolversOnContext(context *storage.Context, rs []*wasm.Resolver) { - context.Put(managerWasmResolverContextKey, rs) -} - -// getWasmResolversOnContext gets the resolvers cached on the storage context. -func getWasmResolversOnContext(context *storage.Context) []*wasm.Resolver { - resolvers, ok := context.Get(managerWasmResolverContextKey).([]*wasm.Resolver) - if !ok { - return nil - } - return resolvers -} - -func validateTriggerMode(mode TriggerMode) error { - switch mode { - case TriggerPeriodic, TriggerManual: - return nil - default: - return fmt.Errorf("invalid trigger mode %q (want %q or %q)", mode, TriggerPeriodic, TriggerManual) - } + v1.SetWasmResolversOnContext(context, rs) } // ValidateAndInjectDefaultsForTriggerMode validates the trigger mode and injects default values func ValidateAndInjectDefaultsForTriggerMode(a, b *TriggerMode) (*TriggerMode, error) { - - if a == nil && b != nil { - err := validateTriggerMode(*b) - if err != nil { - return nil, err - } - return b, nil - } else if a != nil && b == nil { - err := validateTriggerMode(*a) - if err != nil { - return nil, err - } - return a, nil - } else if a != nil && b != nil { - if *a != *b { - return nil, fmt.Errorf("trigger mode mismatch: %s and %s (hint: check discovery configuration)", *a, *b) - } - err := validateTriggerMode(*a) - if err != nil { - return nil, err - } - return a, nil - } - - t := DefaultTriggerMode - return &t, nil -} - -type namedplugin struct { - name string - plugin Plugin + return v1.ValidateAndInjectDefaultsForTriggerMode(a, b) } // Info sets the runtime information on the manager. The runtime information is // propagated to opa.runtime() built-in function calls. func Info(term *ast.Term) func(*Manager) { - return func(m *Manager) { - m.Info = term - } + return v1.Info(term) } // InitBundles provides the initial set of bundles to load. func InitBundles(b map[string]*bundle.Bundle) func(*Manager) { - return func(m *Manager) { - m.initBundles = b - } + return v1.InitBundles(b) } // InitFiles provides the initial set of other data/policy files to load. func InitFiles(f loader.Result) func(*Manager) { - return func(m *Manager) { - m.initFiles = f - } + return v1.InitFiles(f) } // MaxErrors sets the error limit for the manager's shared compiler. func MaxErrors(n int) func(*Manager) { - return func(m *Manager) { - m.maxErrors = n - } + return v1.MaxErrors(n) } // GracefulShutdownPeriod passes the configured graceful shutdown period to plugins func GracefulShutdownPeriod(gracefulShutdownPeriod int) func(*Manager) { - return func(m *Manager) { - m.gracefulShutdownPeriod = gracefulShutdownPeriod - } + return v1.GracefulShutdownPeriod(gracefulShutdownPeriod) } // Logger configures the passed logger on the plugin manager (useful to // configure default fields) func Logger(logger logging.Logger) func(*Manager) { - return func(m *Manager) { - m.logger = logger - } + return v1.Logger(logger) } // ConsoleLogger sets the passed logger to be used by plugins that are // configured with console logging enabled. func ConsoleLogger(logger logging.Logger) func(*Manager) { - return func(m *Manager) { - m.consoleLogger = logger - } + return v1.ConsoleLogger(logger) } func EnablePrintStatements(yes bool) func(*Manager) { - return func(m *Manager) { - m.enablePrintStatements = yes - } + return v1.EnablePrintStatements(yes) } func PrintHook(h print.Hook) func(*Manager) { - return func(m *Manager) { - m.printHook = h - } + return v1.PrintHook(h) } func WithRouter(r *mux.Router) func(*Manager) { - return func(m *Manager) { - m.router = r - } + return v1.WithRouter(r) } // WithPrometheusRegister sets the passed prometheus.Registerer to be used by plugins func WithPrometheusRegister(prometheusRegister prometheus.Registerer) func(*Manager) { - return func(m *Manager) { - m.prometheusRegister = prometheusRegister - } + return v1.WithPrometheusRegister(prometheusRegister) } // WithTracerProvider sets the passed *trace.TracerProvider to be used by plugins func WithTracerProvider(tracerProvider *trace.TracerProvider) func(*Manager) { - return func(m *Manager) { - m.tracerProvider = tracerProvider - } + return v1.WithTracerProvider(tracerProvider) } // WithDistributedTracingOpts sets the options to be used by distributed tracing. func WithDistributedTracingOpts(tr tracing.Options) func(*Manager) { - return func(m *Manager) { - m.distributedTacingOpts = tr - } + return v1.WithDistributedTracingOpts(tr) } // WithHooks allows passing hooks to the plugin manager. func WithHooks(hs hooks.Hooks) func(*Manager) { - return func(m *Manager) { - m.hooks = hs - } + return v1.WithHooks(hs) } // WithParserOptions sets the parser options to be used by the plugin manager. func WithParserOptions(opts ast.ParserOptions) func(*Manager) { - return func(m *Manager) { - m.parserOptions = opts - } + return v1.WithParserOptions(opts) } // WithEnableTelemetry controls whether OPA will send telemetry reports to an external service. func WithEnableTelemetry(enableTelemetry bool) func(*Manager) { - return func(m *Manager) { - m.enableTelemetry = enableTelemetry - } + return v1.WithEnableTelemetry(enableTelemetry) } // WithTelemetryGatherers allows registration of telemetry gatherers which enable injection of additional data in the // telemetry report func WithTelemetryGatherers(gs map[string]report.Gatherer) func(*Manager) { - return func(m *Manager) { - m.registeredTelemetryGatherers = gs - } + return v1.WithTelemetryGatherers(gs) } // New creates a new Manager using config. func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*Manager, error) { - - parsedConfig, err := config.ParseConfig(raw, id) - if err != nil { - return nil, err - } - - m := &Manager{ - Store: store, - Config: parsedConfig, - ID: id, - pluginStatus: map[string]*Status{}, - pluginStatusListeners: map[string]StatusListener{}, - maxErrors: -1, - serverInitialized: make(chan struct{}), - bootstrapConfigLabels: parsedConfig.Labels, - } - - for _, f := range opts { - f(m) - } - - if m.logger == nil { - m.logger = logging.Get() - } - - if m.consoleLogger == nil { - m.consoleLogger = logging.New() - } - - m.hooks.Each(func(h hooks.Hook) { - if f, ok := h.(hooks.ConfigHook); ok { - if c, e := f.OnConfig(context.Background(), parsedConfig); e != nil { - err = errors.Join(err, e) - } else { - parsedConfig = c - } - } - }) - if err != nil { - return nil, err - } - - // do after options and overrides - m.keys, err = keys.ParseKeysConfig(parsedConfig.Keys) - if err != nil { - return nil, err - } - - m.interQueryBuiltinCacheConfig, err = cache.ParseCachingConfig(parsedConfig.Caching) - if err != nil { - return nil, err - } - - serviceOpts := cfg.ServiceOptions{ - Raw: parsedConfig.Services, - AuthPlugin: m.AuthPlugin, - Keys: m.keys, - Logger: m.logger, - DistributedTacingOpts: m.distributedTacingOpts, - } - - m.services, err = cfg.ParseServicesConfig(serviceOpts) - if err != nil { - return nil, err - } - - if m.enableTelemetry { - reporter, err := report.New(id, report.Options{Logger: m.logger}) - if err != nil { - return nil, err - } - m.reporter = reporter - - m.reporter.RegisterGatherer("min_compatible_version", func(_ context.Context) (any, error) { - var minimumCompatibleVersion string - if m.compiler != nil && m.compiler.Required != nil { - minimumCompatibleVersion, _ = m.compiler.Required.MinimumCompatibleVersion() - } - return minimumCompatibleVersion, nil - }) - - // register any additional gatherers - for k, g := range m.registeredTelemetryGatherers { - m.reporter.RegisterGatherer(k, g) - } - } - - return m, nil -} - -// Init returns an error if the manager could not initialize itself. Init() should -// be called before Start(). Init() is idempotent. -func (m *Manager) Init(ctx context.Context) error { - - if m.initialized { - return nil - } - - params := storage.TransactionParams{ - Write: true, - Context: storage.NewContext(), - } - - if m.enableTelemetry { - m.opaReportNotifyCh = make(chan struct{}) - m.stop = make(chan chan struct{}) - go m.sendOPAUpdateLoop(ctx) - } - - err := storage.Txn(ctx, m.Store, params, func(txn storage.Transaction) error { - - result, err := initload.InsertAndCompile(ctx, initload.InsertAndCompileOptions{ - Store: m.Store, - Txn: txn, - Files: m.initFiles, - Bundles: m.initBundles, - MaxErrors: m.maxErrors, - EnablePrintStatements: m.enablePrintStatements, - }) - - if err != nil { - return err - } - - SetCompilerOnContext(params.Context, result.Compiler) - - resolvers, err := bundleUtils.LoadWasmResolversFromStore(ctx, m.Store, txn, nil) - if err != nil { - return err + options := make([]func(*Manager), 0, len(opts)+1) + options = append(options, opts...) + + // Add option to apply default Rego version if not set. Must be last in list of options. + options = append(options, func(m *Manager) { + if m.ParserOptions().RegoVersion == ast.RegoUndefined { + cpy := m.ParserOptions() + cpy.RegoVersion = ast.DefaultRegoVersion + WithParserOptions(cpy)(m) } - SetWasmResolversOnContext(params.Context, resolvers) - - _, err = m.Store.Register(ctx, txn, storage.TriggerConfig{OnCommit: m.onCommit}) - return err }) - if err != nil { - if m.stop != nil { - done := make(chan struct{}) - m.stop <- done - <-done - } - - return err - } - - m.initialized = true - return nil -} - -// Labels returns the set of labels from the configuration. -func (m *Manager) Labels() map[string]string { - m.mtx.Lock() - defer m.mtx.Unlock() - return m.Config.Labels -} - -// InterQueryBuiltinCacheConfig returns the configuration for the inter-query caches. -func (m *Manager) InterQueryBuiltinCacheConfig() *cache.Config { - m.mtx.Lock() - defer m.mtx.Unlock() - return m.interQueryBuiltinCacheConfig -} - -// Register adds a plugin to the manager. When the manager is started, all of -// the plugins will be started. -func (m *Manager) Register(name string, plugin Plugin) { - m.mtx.Lock() - defer m.mtx.Unlock() - m.plugins = append(m.plugins, namedplugin{ - name: name, - plugin: plugin, - }) - if _, ok := m.pluginStatus[name]; !ok { - m.pluginStatus[name] = &Status{State: StateNotReady} - } -} - -// Plugins returns the list of plugins registered with the manager. -func (m *Manager) Plugins() []string { - m.mtx.Lock() - defer m.mtx.Unlock() - result := make([]string, len(m.plugins)) - for i := range m.plugins { - result[i] = m.plugins[i].name - } - return result -} - -// Plugin returns the plugin registered with name or nil if name is not found. -func (m *Manager) Plugin(name string) Plugin { - m.mtx.Lock() - defer m.mtx.Unlock() - for i := range m.plugins { - if m.plugins[i].name == name { - return m.plugins[i].plugin - } - } - return nil -} - -// AuthPlugin returns the HTTPAuthPlugin registered with name or nil if name is not found. -func (m *Manager) AuthPlugin(name string) rest.HTTPAuthPlugin { - m.mtx.Lock() - defer m.mtx.Unlock() - for i := range m.plugins { - if m.plugins[i].name == name { - return m.plugins[i].plugin.(rest.HTTPAuthPlugin) - } - } - return nil -} - -// GetCompiler returns the manager's compiler. -func (m *Manager) GetCompiler() *ast.Compiler { - m.compilerMux.RLock() - defer m.compilerMux.RUnlock() - return m.compiler -} - -func (m *Manager) setCompiler(compiler *ast.Compiler) { - m.compilerMux.Lock() - defer m.compilerMux.Unlock() - m.compiler = compiler -} - -// GetRouter returns the managers router if set -func (m *Manager) GetRouter() *mux.Router { - m.mtx.Lock() - defer m.mtx.Unlock() - return m.router -} - -// RegisterCompilerTrigger registers for change notifications when the compiler -// is changed. -func (m *Manager) RegisterCompilerTrigger(f func(storage.Transaction)) { - m.mtx.Lock() - defer m.mtx.Unlock() - m.registeredTriggers = append(m.registeredTriggers, f) -} - -// GetWasmResolvers returns the manager's set of Wasm Resolvers. -func (m *Manager) GetWasmResolvers() []*wasm.Resolver { - m.wasmResolversMtx.RLock() - defer m.wasmResolversMtx.RUnlock() - return m.wasmResolvers -} - -func (m *Manager) setWasmResolvers(rs []*wasm.Resolver) { - m.wasmResolversMtx.Lock() - defer m.wasmResolversMtx.Unlock() - m.wasmResolvers = rs -} - -// Start starts the manager. Init() should be called once before Start(). -func (m *Manager) Start(ctx context.Context) error { - - if m == nil { - return nil - } - - if !m.initialized { - if err := m.Init(ctx); err != nil { - return err - } - } - - var toStart []Plugin - - func() { - m.mtx.Lock() - defer m.mtx.Unlock() - toStart = make([]Plugin, len(m.plugins)) - for i := range m.plugins { - toStart[i] = m.plugins[i].plugin - } - }() - - for i := range toStart { - if err := toStart[i].Start(ctx); err != nil { - return err - } - } - - return nil -} - -// Stop stops the manager, stopping all the plugins registered with it. -// Any plugin that needs to perform cleanup should do so within the duration -// of the graceful shutdown period passed with the context as a timeout. -// Note that a graceful shutdown period configured with the Manager instance -// will override the timeout of the passed in context (if applicable). -func (m *Manager) Stop(ctx context.Context) { - var toStop []Plugin - - func() { - m.mtx.Lock() - defer m.mtx.Unlock() - toStop = make([]Plugin, len(m.plugins)) - for i := range m.plugins { - toStop[i] = m.plugins[i].plugin - } - }() - - var cancel context.CancelFunc - if m.gracefulShutdownPeriod > 0 { - ctx, cancel = context.WithTimeout(ctx, time.Duration(m.gracefulShutdownPeriod)*time.Second) - } else { - ctx, cancel = context.WithCancel(ctx) - } - defer cancel() - for i := range toStop { - toStop[i].Stop(ctx) - } - if c, ok := m.Store.(interface{ Close(context.Context) error }); ok { - if err := c.Close(ctx); err != nil { - m.logger.Error("Error closing store: %v", err) - } - } - - if m.stop != nil { - done := make(chan struct{}) - m.stop <- done - <-done - } -} - -// Reconfigure updates the configuration on the manager. -func (m *Manager) Reconfigure(config *config.Config) error { - opts := cfg.ServiceOptions{ - Raw: config.Services, - AuthPlugin: m.AuthPlugin, - Logger: m.logger, - DistributedTacingOpts: m.distributedTacingOpts, - } - - keys, err := keys.ParseKeysConfig(config.Keys) - if err != nil { - return err - } - opts.Keys = keys - - services, err := cfg.ParseServicesConfig(opts) - if err != nil { - return err - } - - interQueryBuiltinCacheConfig, err := cache.ParseCachingConfig(config.Caching) - if err != nil { - return err - } - - m.mtx.Lock() - defer m.mtx.Unlock() - - // don't overwrite existing labels, only allow additions - always based on the boostrap config - if config.Labels == nil { - config.Labels = m.bootstrapConfigLabels - } else { - for label, value := range m.bootstrapConfigLabels { - config.Labels[label] = value - } - } - - // don't erase persistence directory - if config.PersistenceDirectory == nil { - config.PersistenceDirectory = m.Config.PersistenceDirectory - } - - m.Config = config - m.interQueryBuiltinCacheConfig = interQueryBuiltinCacheConfig - for name, client := range services { - m.services[name] = client - } - - for name, key := range keys { - m.keys[name] = key - } - - for _, trigger := range m.registeredCacheTriggers { - trigger(interQueryBuiltinCacheConfig) - } - - for _, trigger := range m.registeredNDCacheTriggers { - trigger(config.NDBuiltinCache) - } - - return nil -} - -// PluginStatus returns the current statuses of any plugins registered. -func (m *Manager) PluginStatus() map[string]*Status { - m.mtx.Lock() - defer m.mtx.Unlock() - - return m.copyPluginStatus() -} - -// RegisterPluginStatusListener registers a StatusListener to be -// called when plugin status updates occur. -func (m *Manager) RegisterPluginStatusListener(name string, listener StatusListener) { - m.mtx.Lock() - defer m.mtx.Unlock() - - m.pluginStatusListeners[name] = listener -} - -// UnregisterPluginStatusListener removes a StatusListener registered with the -// same name. -func (m *Manager) UnregisterPluginStatusListener(name string) { - m.mtx.Lock() - defer m.mtx.Unlock() - - delete(m.pluginStatusListeners, name) -} - -// UpdatePluginStatus updates a named plugins status. Any registered -// listeners will be called with a copy of the new state of all -// plugins. -func (m *Manager) UpdatePluginStatus(pluginName string, status *Status) { - - var toNotify map[string]StatusListener - var statuses map[string]*Status - - func() { - m.mtx.Lock() - defer m.mtx.Unlock() - m.pluginStatus[pluginName] = status - toNotify = make(map[string]StatusListener, len(m.pluginStatusListeners)) - for k, v := range m.pluginStatusListeners { - toNotify[k] = v - } - statuses = m.copyPluginStatus() - }() - - for _, l := range toNotify { - l(statuses) - } -} - -func (m *Manager) copyPluginStatus() map[string]*Status { - statusCpy := map[string]*Status{} - for k, v := range m.pluginStatus { - var cpy *Status - if v != nil { - cpy = &Status{ - State: v.State, - Message: v.Message, - } - } - statusCpy[k] = cpy - } - return statusCpy -} - -func (m *Manager) onCommit(ctx context.Context, txn storage.Transaction, event storage.TriggerEvent) { - - compiler := GetCompilerOnContext(event.Context) - - // If the context does not contain the compiler fallback to loading the - // compiler from the store. Currently the bundle plugin sets the - // compiler on the context but the server does not (nor would users - // implementing their own policy loading.) - if compiler == nil && event.PolicyChanged() { - compiler, _ = loadCompilerFromStore(ctx, m.Store, txn, m.enablePrintStatements, m.ParserOptions()) - } - - if compiler != nil { - m.setCompiler(compiler) - - if m.enableTelemetry && event.PolicyChanged() { - m.opaReportNotifyCh <- struct{}{} - } - - for _, f := range m.registeredTriggers { - f(txn) - } - } - - // Similar to the compiler, look for a set of resolvers on the transaction - // context. If they are not set we may need to reload from the store. - resolvers := getWasmResolversOnContext(event.Context) - if resolvers != nil { - m.setWasmResolvers(resolvers) - - } else if event.DataChanged() { - if requiresWasmResolverReload(event) { - resolvers, err := bundleUtils.LoadWasmResolversFromStore(ctx, m.Store, txn, nil) - if err != nil { - panic(err) - } - m.setWasmResolvers(resolvers) - } else { - err := m.updateWasmResolversData(ctx, event) - if err != nil { - panic(err) - } - } - } -} - -func loadCompilerFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, enablePrintStatements bool, popts ast.ParserOptions) (*ast.Compiler, error) { - policies, err := store.ListPolicies(ctx, txn) - if err != nil { - return nil, err - } - modules := map[string]*ast.Module{} - - for _, policy := range policies { - bs, err := store.GetPolicy(ctx, txn, policy) - if err != nil { - return nil, err - } - module, err := ast.ParseModuleWithOpts(policy, string(bs), popts) - if err != nil { - return nil, err - } - modules[policy] = module - } - - compiler := ast.NewCompiler().WithEnablePrintStatements(enablePrintStatements) - compiler.Compile(modules) - return compiler, nil -} - -func requiresWasmResolverReload(event storage.TriggerEvent) bool { - // If the data changes touched the bundle path (which includes - // the wasm modules) we will reload them. Otherwise update - // data for each module already on the manager. - for _, dataEvent := range event.Data { - if dataEvent.Path.HasPrefix(bundle.BundlesBasePath) { - return true - } - } - return false -} - -func (m *Manager) updateWasmResolversData(ctx context.Context, event storage.TriggerEvent) error { - m.wasmResolversMtx.Lock() - defer m.wasmResolversMtx.Unlock() - - for _, resolver := range m.wasmResolvers { - for _, dataEvent := range event.Data { - var err error - if dataEvent.Removed { - err = resolver.RemoveDataPath(ctx, dataEvent.Path) - } else { - err = resolver.SetDataPath(ctx, dataEvent.Path, dataEvent.Data) - } - if err != nil { - return fmt.Errorf("failed to update wasm runtime data: %s", err) - } - } - } - return nil -} - -// PublicKeys returns a public keys that can be used for verifying signed bundles. -func (m *Manager) PublicKeys() map[string]*keys.Config { - m.mtx.Lock() - defer m.mtx.Unlock() - return m.keys -} - -// Client returns a client for communicating with a remote service. -func (m *Manager) Client(name string) rest.Client { - m.mtx.Lock() - defer m.mtx.Unlock() - return m.services[name] -} - -// Services returns a list of services that m can provide clients for. -func (m *Manager) Services() []string { - m.mtx.Lock() - defer m.mtx.Unlock() - s := make([]string, 0, len(m.services)) - for name := range m.services { - s = append(s, name) - } - return s -} - -// Logger gets the standard logger for this plugin manager. -func (m *Manager) Logger() logging.Logger { - return m.logger -} - -// ConsoleLogger gets the console logger for this plugin manager. -func (m *Manager) ConsoleLogger() logging.Logger { - return m.consoleLogger -} - -func (m *Manager) PrintHook() print.Hook { - return m.printHook -} - -func (m *Manager) EnablePrintStatements() bool { - return m.enablePrintStatements -} - -// ServerInitialized signals a channel indicating that the OPA -// server has finished initialization. -func (m *Manager) ServerInitialized() { - m.serverInitializedOnce.Do(func() { close(m.serverInitialized) }) -} - -// ServerInitializedChannel returns a receive-only channel that -// is closed when the OPA server has finished initialization. -// Be aware that the socket of the server listener may not be -// open by the time this channel is closed. There is a very -// small window where the socket may still be closed, due to -// a race condition. -func (m *Manager) ServerInitializedChannel() <-chan struct{} { - return m.serverInitialized -} - -// RegisterCacheTrigger accepts a func that receives new inter-query cache config generated by -// a reconfigure of the plugin manager, so that it can be propagated to existing inter-query caches. -func (m *Manager) RegisterCacheTrigger(trigger func(*cache.Config)) { - m.mtx.Lock() - defer m.mtx.Unlock() - m.registeredCacheTriggers = append(m.registeredCacheTriggers, trigger) -} - -// PrometheusRegister gets the prometheus.Registerer for this plugin manager. -func (m *Manager) PrometheusRegister() prometheus.Registerer { - return m.prometheusRegister -} - -// TracerProvider gets the *trace.TracerProvider for this plugin manager. -func (m *Manager) TracerProvider() *trace.TracerProvider { - return m.tracerProvider -} - -func (m *Manager) RegisterNDCacheTrigger(trigger func(bool)) { - m.mtx.Lock() - defer m.mtx.Unlock() - m.registeredNDCacheTriggers = append(m.registeredNDCacheTriggers, trigger) -} - -func (m *Manager) sendOPAUpdateLoop(ctx context.Context) { - ticker := time.NewTicker(time.Duration(int64(time.Second) * defaultUploadIntervalSec)) - mr.New(mr.NewSource(time.Now().UnixNano())) - - ctx, cancel := context.WithCancel(ctx) - - var opaReportNotify bool - - for { - select { - case <-m.opaReportNotifyCh: - opaReportNotify = true - case <-ticker.C: - ticker.Stop() - - if opaReportNotify { - opaReportNotify = false - _, err := m.reporter.SendReport(ctx) - if err != nil { - m.logger.WithFields(map[string]interface{}{"err": err}).Debug("Unable to send OPA telemetry report.") - } - } - - newInterval := mr.Int63n(defaultUploadIntervalSec) + defaultUploadIntervalSec - ticker = time.NewTicker(time.Duration(int64(time.Second) * newInterval)) - case done := <-m.stop: - cancel() - ticker.Stop() - done <- struct{}{} - return - } - } -} - -func (m *Manager) ParserOptions() ast.ParserOptions { - return m.parserOptions + return v1.New(raw, id, store, options...) } diff --git a/vendor/github.com/open-policy-agent/opa/plugins/status/metrics.go b/vendor/github.com/open-policy-agent/opa/plugins/status/metrics.go deleted file mode 100644 index aa27e1cd3..000000000 --- a/vendor/github.com/open-policy-agent/opa/plugins/status/metrics.go +++ /dev/null @@ -1,82 +0,0 @@ -package status - -import ( - "github.com/open-policy-agent/opa/version" - "github.com/prometheus/client_golang/prometheus" -) - -var ( - opaInfo = prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "opa_info", - Help: "Information about the OPA environment.", - ConstLabels: map[string]string{"version": version.Version}, - }, - ) - pluginStatus = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "plugin_status_gauge", - Help: "Gauge for the plugin by status."}, - []string{"name", "status"}, - ) - loaded = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "bundle_loaded_counter", - Help: "Counter for the bundle loaded."}, - []string{"name"}, - ) - failLoad = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "bundle_failed_load_counter", - Help: "Counter for the failed bundle load."}, - []string{"name", "code", "message"}, - ) - lastRequest = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "last_bundle_request", - Help: "Gauge for the last bundle request."}, - []string{"name"}, - ) - lastSuccessfulActivation = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "last_success_bundle_activation", - Help: "Gauge for the last success bundle activation."}, - []string{"name", "active_revision"}, - ) - lastSuccessfulDownload = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "last_success_bundle_download", - Help: "Gauge for the last success bundle download."}, - []string{"name"}, - ) - lastSuccessfulRequest = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "last_success_bundle_request", - Help: "Gauge for the last success bundle request."}, - []string{"name"}, - ) - bundleLoadDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Name: "bundle_loading_duration_ns", - Help: "Histogram for the bundle loading duration by stage.", - Buckets: prometheus.ExponentialBuckets(1000, 2, 20), - }, []string{"name", "stage"}) - - // allCollectors is a list of all collectors maintained by the status plugin. - // Note: when adding a new collector, make sure to also add it to this list, - // or it won't survive status plugin reconfigure events. - allCollectors = []prometheus.Collector{ - opaInfo, - pluginStatus, - loaded, - failLoad, - lastRequest, - lastSuccessfulActivation, - lastSuccessfulDownload, - lastSuccessfulRequest, - bundleLoadDuration, - } -) - -func init() { - opaInfo.Set(1) -} diff --git a/vendor/github.com/open-policy-agent/opa/rego/doc.go b/vendor/github.com/open-policy-agent/opa/rego/doc.go new file mode 100644 index 000000000..febe75696 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/rego/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package rego diff --git a/vendor/github.com/open-policy-agent/opa/rego/errors.go b/vendor/github.com/open-policy-agent/opa/rego/errors.go index dcc5e2679..bcbd2efed 100644 --- a/vendor/github.com/open-policy-agent/opa/rego/errors.go +++ b/vendor/github.com/open-policy-agent/opa/rego/errors.go @@ -1,24 +1,17 @@ package rego +import v1 "github.com/open-policy-agent/opa/v1/rego" + // HaltError is an error type to return from a custom function implementation // that will abort the evaluation process (analogous to topdown.Halt). -type HaltError struct { - err error -} - -// Error delegates to the wrapped error -func (h *HaltError) Error() string { - return h.err.Error() -} +type HaltError = v1.HaltError // NewHaltError wraps an error such that the evaluation process will stop // when it occurs. func NewHaltError(err error) error { - return &HaltError{err: err} + return v1.NewHaltError(err) } // ErrorDetails interface is satisfied by an error that provides further // details. -type ErrorDetails interface { - Lines() []string -} +type ErrorDetails = v1.ErrorDetails diff --git a/vendor/github.com/open-policy-agent/opa/rego/plugins.go b/vendor/github.com/open-policy-agent/opa/rego/plugins.go index abaa91034..38ef84416 100644 --- a/vendor/github.com/open-policy-agent/opa/rego/plugins.go +++ b/vendor/github.com/open-policy-agent/opa/rego/plugins.go @@ -5,39 +5,13 @@ package rego import ( - "context" - "sync" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/ir" + v1 "github.com/open-policy-agent/opa/v1/rego" ) -var targetPlugins = map[string]TargetPlugin{} -var pluginMtx sync.Mutex - -type TargetPlugin interface { - IsTarget(string) bool - PrepareForEval(context.Context, *ir.Policy, ...PrepareOption) (TargetPluginEval, error) -} - -type TargetPluginEval interface { - Eval(context.Context, *EvalContext, ast.Value) (ast.Value, error) -} +type TargetPlugin = v1.TargetPlugin -func (r *Rego) targetPlugin(tgt string) TargetPlugin { - for _, p := range targetPlugins { - if p.IsTarget(tgt) { - return p - } - } - return nil -} +type TargetPluginEval = v1.TargetPluginEval func RegisterPlugin(name string, p TargetPlugin) { - pluginMtx.Lock() - defer pluginMtx.Unlock() - if _, ok := targetPlugins[name]; ok { - panic("plugin already registered " + name) - } - targetPlugins[name] = p + v1.RegisterPlugin(name, p) } diff --git a/vendor/github.com/open-policy-agent/opa/rego/rego.go b/vendor/github.com/open-policy-agent/opa/rego/rego.go index 64b4b9b93..e6af30c39 100644 --- a/vendor/github.com/open-policy-agent/opa/rego/rego.go +++ b/vendor/github.com/open-policy-agent/opa/rego/rego.go @@ -6,958 +6,367 @@ package rego import ( - "bytes" - "context" - "errors" - "fmt" "io" - "strings" "time" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/bundle" - bundleUtils "github.com/open-policy-agent/opa/internal/bundle" - "github.com/open-policy-agent/opa/internal/compiler/wasm" - "github.com/open-policy-agent/opa/internal/future" - "github.com/open-policy-agent/opa/internal/planner" - "github.com/open-policy-agent/opa/internal/rego/opa" - "github.com/open-policy-agent/opa/internal/wasm/encoding" - "github.com/open-policy-agent/opa/ir" "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/resolver" "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/inmem" - "github.com/open-policy-agent/opa/topdown" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/print" - "github.com/open-policy-agent/opa/tracing" - "github.com/open-policy-agent/opa/types" - "github.com/open-policy-agent/opa/util" -) - -const ( - defaultPartialNamespace = "partial" - wasmVarPrefix = "^" -) - -// nolint: deadcode,varcheck -const ( - targetWasm = "wasm" - targetRego = "rego" + "github.com/open-policy-agent/opa/v1/metrics" + v1 "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/resolver" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" + "github.com/open-policy-agent/opa/v1/tracing" ) // CompileResult represents the result of compiling a Rego query, zero or more // Rego modules, and arbitrary contextual data into an executable. -type CompileResult struct { - Bytes []byte `json:"bytes"` -} +type CompileResult = v1.CompileResult // PartialQueries contains the queries and support modules produced by partial // evaluation. -type PartialQueries struct { - Queries []ast.Body `json:"queries,omitempty"` - Support []*ast.Module `json:"modules,omitempty"` -} +type PartialQueries = v1.PartialQueries // PartialResult represents the result of partial evaluation. The result can be // used to generate a new query that can be run when inputs are known. -type PartialResult struct { - compiler *ast.Compiler - store storage.Store - body ast.Body - builtinDecls map[string]*ast.Builtin - builtinFuncs map[string]*topdown.Builtin -} - -// Rego returns an object that can be evaluated to produce a query result. -func (pr PartialResult) Rego(options ...func(*Rego)) *Rego { - options = append(options, Compiler(pr.compiler), Store(pr.store), ParsedQuery(pr.body)) - r := New(options...) - - // Propagate any custom builtins. - for k, v := range pr.builtinDecls { - r.builtinDecls[k] = v - } - for k, v := range pr.builtinFuncs { - r.builtinFuncs[k] = v - } - return r -} - -// preparedQuery is a wrapper around a Rego object which has pre-processed -// state stored on it. Once prepared there are a more limited number of actions -// that can be taken with it. It will, however, be able to evaluate faster since -// it will not have to re-parse or compile as much. -type preparedQuery struct { - r *Rego - cfg *PrepareConfig -} +type PartialResult = v1.PartialResult // EvalContext defines the set of options allowed to be set at evaluation // time. Any other options will need to be set on a new Rego object. -type EvalContext struct { - hasInput bool - time time.Time - seed io.Reader - rawInput *interface{} - parsedInput ast.Value - metrics metrics.Metrics - txn storage.Transaction - instrument bool - instrumentation *topdown.Instrumentation - partialNamespace string - queryTracers []topdown.QueryTracer - compiledQuery compiledQuery - unknowns []string - disableInlining []ast.Ref - parsedUnknowns []*ast.Term - indexing bool - earlyExit bool - interQueryBuiltinCache cache.InterQueryCache - interQueryBuiltinValueCache cache.InterQueryValueCache - ndBuiltinCache builtins.NDBCache - resolvers []refResolver - sortSets bool - copyMaps bool - printHook print.Hook - capabilities *ast.Capabilities - strictBuiltinErrors bool - virtualCache topdown.VirtualCache -} - -func (e *EvalContext) RawInput() *interface{} { - return e.rawInput -} - -func (e *EvalContext) ParsedInput() ast.Value { - return e.parsedInput -} - -func (e *EvalContext) Time() time.Time { - return e.time -} - -func (e *EvalContext) Seed() io.Reader { - return e.seed -} - -func (e *EvalContext) InterQueryBuiltinCache() cache.InterQueryCache { - return e.interQueryBuiltinCache -} - -func (e *EvalContext) InterQueryBuiltinValueCache() cache.InterQueryValueCache { - return e.interQueryBuiltinValueCache -} - -func (e *EvalContext) PrintHook() print.Hook { - return e.printHook -} - -func (e *EvalContext) Metrics() metrics.Metrics { - return e.metrics -} - -func (e *EvalContext) StrictBuiltinErrors() bool { - return e.strictBuiltinErrors -} - -func (e *EvalContext) NDBCache() builtins.NDBCache { - return e.ndBuiltinCache -} - -func (e *EvalContext) CompiledQuery() ast.Body { - return e.compiledQuery.query -} - -func (e *EvalContext) Capabilities() *ast.Capabilities { - return e.capabilities -} - -func (e *EvalContext) Transaction() storage.Transaction { - return e.txn -} +type EvalContext = v1.EvalContext // EvalOption defines a function to set an option on an EvalConfig -type EvalOption func(*EvalContext) +type EvalOption = v1.EvalOption // EvalInput configures the input for a Prepared Query's evaluation func EvalInput(input interface{}) EvalOption { - return func(e *EvalContext) { - e.rawInput = &input - e.hasInput = true - } + return v1.EvalInput(input) } // EvalParsedInput configures the input for a Prepared Query's evaluation func EvalParsedInput(input ast.Value) EvalOption { - return func(e *EvalContext) { - e.parsedInput = input - e.hasInput = true - } + return v1.EvalParsedInput(input) } // EvalMetrics configures the metrics for a Prepared Query's evaluation func EvalMetrics(metric metrics.Metrics) EvalOption { - return func(e *EvalContext) { - e.metrics = metric - } + return v1.EvalMetrics(metric) } // EvalTransaction configures the Transaction for a Prepared Query's evaluation func EvalTransaction(txn storage.Transaction) EvalOption { - return func(e *EvalContext) { - e.txn = txn - } + return v1.EvalTransaction(txn) } // EvalInstrument enables or disables instrumenting for a Prepared Query's evaluation func EvalInstrument(instrument bool) EvalOption { - return func(e *EvalContext) { - e.instrument = instrument - } + return v1.EvalInstrument(instrument) } // EvalTracer configures a tracer for a Prepared Query's evaluation // Deprecated: Use EvalQueryTracer instead. func EvalTracer(tracer topdown.Tracer) EvalOption { - return func(e *EvalContext) { - if tracer != nil { - e.queryTracers = append(e.queryTracers, topdown.WrapLegacyTracer(tracer)) - } - } + return v1.EvalTracer(tracer) } // EvalQueryTracer configures a tracer for a Prepared Query's evaluation func EvalQueryTracer(tracer topdown.QueryTracer) EvalOption { - return func(e *EvalContext) { - if tracer != nil { - e.queryTracers = append(e.queryTracers, tracer) - } - } + return v1.EvalQueryTracer(tracer) } // EvalPartialNamespace returns an argument that sets the namespace to use for // partial evaluation results. The namespace must be a valid package path // component. func EvalPartialNamespace(ns string) EvalOption { - return func(e *EvalContext) { - e.partialNamespace = ns - } + return v1.EvalPartialNamespace(ns) } // EvalUnknowns returns an argument that sets the values to treat as // unknown during partial evaluation. func EvalUnknowns(unknowns []string) EvalOption { - return func(e *EvalContext) { - e.unknowns = unknowns - } + return v1.EvalUnknowns(unknowns) } // EvalDisableInlining returns an argument that adds a set of paths to exclude from // partial evaluation inlining. func EvalDisableInlining(paths []ast.Ref) EvalOption { - return func(e *EvalContext) { - e.disableInlining = paths - } + return v1.EvalDisableInlining(paths) } // EvalParsedUnknowns returns an argument that sets the values to treat // as unknown during partial evaluation. func EvalParsedUnknowns(unknowns []*ast.Term) EvalOption { - return func(e *EvalContext) { - e.parsedUnknowns = unknowns - } + return v1.EvalParsedUnknowns(unknowns) } // EvalRuleIndexing will disable indexing optimizations for the // evaluation. This should only be used when tracing in debug mode. func EvalRuleIndexing(enabled bool) EvalOption { - return func(e *EvalContext) { - e.indexing = enabled - } + return v1.EvalRuleIndexing(enabled) } // EvalEarlyExit will disable 'early exit' optimizations for the // evaluation. This should only be used when tracing in debug mode. func EvalEarlyExit(enabled bool) EvalOption { - return func(e *EvalContext) { - e.earlyExit = enabled - } + return v1.EvalEarlyExit(enabled) } // EvalTime sets the wall clock time to use during policy evaluation. // time.now_ns() calls will return this value. func EvalTime(x time.Time) EvalOption { - return func(e *EvalContext) { - e.time = x - } + return v1.EvalTime(x) } // EvalSeed sets a reader that will seed randomization required by built-in functions. // If a seed is not provided crypto/rand.Reader is used. func EvalSeed(r io.Reader) EvalOption { - return func(e *EvalContext) { - e.seed = r - } + return v1.EvalSeed(r) } // EvalInterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize // during evaluation. func EvalInterQueryBuiltinCache(c cache.InterQueryCache) EvalOption { - return func(e *EvalContext) { - e.interQueryBuiltinCache = c - } + return v1.EvalInterQueryBuiltinCache(c) } // EvalInterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize // during evaluation. func EvalInterQueryBuiltinValueCache(c cache.InterQueryValueCache) EvalOption { - return func(e *EvalContext) { - e.interQueryBuiltinValueCache = c - } + return v1.EvalInterQueryBuiltinValueCache(c) } // EvalNDBuiltinCache sets the non-deterministic builtin cache that built-in functions can // use during evaluation. func EvalNDBuiltinCache(c builtins.NDBCache) EvalOption { - return func(e *EvalContext) { - e.ndBuiltinCache = c - } + return v1.EvalNDBuiltinCache(c) } // EvalResolver sets a Resolver for a specified ref path for this evaluation. func EvalResolver(ref ast.Ref, r resolver.Resolver) EvalOption { - return func(e *EvalContext) { - e.resolvers = append(e.resolvers, refResolver{ref, r}) - } + return v1.EvalResolver(ref, r) } // EvalSortSets causes the evaluator to sort sets before returning them as JSON arrays. func EvalSortSets(yes bool) EvalOption { - return func(e *EvalContext) { - e.sortSets = yes - } + return v1.EvalSortSets(yes) } // EvalCopyMaps causes the evaluator to copy `map[string]interface{}`s before returning them. func EvalCopyMaps(yes bool) EvalOption { - return func(e *EvalContext) { - e.copyMaps = yes - } + return v1.EvalCopyMaps(yes) } // EvalPrintHook sets the object to use for handling print statement outputs. func EvalPrintHook(ph print.Hook) EvalOption { - return func(e *EvalContext) { - e.printHook = ph - } + return v1.EvalPrintHook(ph) } // EvalVirtualCache sets the topdown.VirtualCache to use for evaluation. This is // optional, and if not set, the default cache is used. func EvalVirtualCache(vc topdown.VirtualCache) EvalOption { - return func(e *EvalContext) { - e.virtualCache = vc - } -} - -func (pq preparedQuery) Modules() map[string]*ast.Module { - mods := make(map[string]*ast.Module) - - for name, mod := range pq.r.parsedModules { - mods[name] = mod - } - - for _, b := range pq.r.bundles { - for _, mod := range b.Modules { - mods[mod.Path] = mod.Parsed - } - } - - return mods -} - -// newEvalContext creates a new EvalContext overlaying any EvalOptions over top -// the Rego object on the preparedQuery. The returned function should be called -// once the evaluation is complete to close any transactions that might have -// been opened. -func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption) (*EvalContext, func(context.Context), error) { - ectx := &EvalContext{ - hasInput: false, - rawInput: nil, - parsedInput: nil, - metrics: nil, - txn: nil, - instrument: false, - instrumentation: nil, - partialNamespace: pq.r.partialNamespace, - queryTracers: nil, - unknowns: pq.r.unknowns, - parsedUnknowns: pq.r.parsedUnknowns, - compiledQuery: compiledQuery{}, - indexing: true, - earlyExit: true, - resolvers: pq.r.resolvers, - printHook: pq.r.printHook, - capabilities: pq.r.capabilities, - strictBuiltinErrors: pq.r.strictBuiltinErrors, - } - - for _, o := range options { - o(ectx) - } - - if ectx.metrics == nil { - ectx.metrics = metrics.New() - } - - if ectx.instrument { - ectx.instrumentation = topdown.NewInstrumentation(ectx.metrics) - } - - // Default to an empty "finish" function - finishFunc := func(context.Context) {} - - var err error - ectx.disableInlining, err = parseStringsToRefs(pq.r.disableInlining) - if err != nil { - return nil, finishFunc, err - } - - if ectx.txn == nil { - ectx.txn, err = pq.r.store.NewTransaction(ctx) - if err != nil { - return nil, finishFunc, err - } - finishFunc = func(ctx context.Context) { - pq.r.store.Abort(ctx, ectx.txn) - } - } - - // If we didn't get an input specified in the Eval options - // then fall back to the Rego object's input fields. - if !ectx.hasInput { - ectx.rawInput = pq.r.rawInput - ectx.parsedInput = pq.r.parsedInput - } - - if ectx.parsedInput == nil { - if ectx.rawInput == nil { - // Fall back to the original Rego objects input if none was specified - // Note that it could still be nil - ectx.rawInput = pq.r.rawInput - } - - if pq.r.targetPlugin(pq.r.target) == nil && // no plugin claims this target - pq.r.target != targetWasm { - ectx.parsedInput, err = pq.r.parseRawInput(ectx.rawInput, ectx.metrics) - if err != nil { - return nil, finishFunc, err - } - } - } - - return ectx, finishFunc, nil + return v1.EvalVirtualCache(vc) } // PreparedEvalQuery holds the prepared Rego state that has been pre-processed // for subsequent evaluations. -type PreparedEvalQuery struct { - preparedQuery -} - -// Eval evaluates this PartialResult's Rego object with additional eval options -// and returns a ResultSet. -// If options are provided they will override the original Rego options respective value. -// The original Rego object transaction will *not* be re-used. A new transaction will be opened -// if one is not provided with an EvalOption. -func (pq PreparedEvalQuery) Eval(ctx context.Context, options ...EvalOption) (ResultSet, error) { - ectx, finish, err := pq.newEvalContext(ctx, options) - if err != nil { - return nil, err - } - defer finish(ctx) - - ectx.compiledQuery = pq.r.compiledQueries[evalQueryType] - - return pq.r.eval(ctx, ectx) -} +type PreparedEvalQuery = v1.PreparedEvalQuery // PreparedPartialQuery holds the prepared Rego state that has been pre-processed // for partial evaluations. -type PreparedPartialQuery struct { - preparedQuery -} - -// Partial runs partial evaluation on the prepared query and returns the result. -// The original Rego object transaction will *not* be re-used. A new transaction will be opened -// if one is not provided with an EvalOption. -func (pq PreparedPartialQuery) Partial(ctx context.Context, options ...EvalOption) (*PartialQueries, error) { - ectx, finish, err := pq.newEvalContext(ctx, options) - if err != nil { - return nil, err - } - defer finish(ctx) - - ectx.compiledQuery = pq.r.compiledQueries[partialQueryType] - - return pq.r.partial(ctx, ectx) -} +type PreparedPartialQuery = v1.PreparedPartialQuery // Errors represents a collection of errors returned when evaluating Rego. -type Errors []error - -func (errs Errors) Error() string { - if len(errs) == 0 { - return "no error" - } - if len(errs) == 1 { - return fmt.Sprintf("1 error occurred: %v", errs[0].Error()) - } - buf := []string{fmt.Sprintf("%v errors occurred", len(errs))} - for _, err := range errs { - buf = append(buf, err.Error()) - } - return strings.Join(buf, "\n") -} - -var errPartialEvaluationNotEffective = errors.New("partial evaluation not effective") +type Errors = v1.Errors // IsPartialEvaluationNotEffectiveErr returns true if err is an error returned by // this package to indicate that partial evaluation was ineffective. func IsPartialEvaluationNotEffectiveErr(err error) bool { - errs, ok := err.(Errors) - if !ok { - return false - } - return len(errs) == 1 && errs[0] == errPartialEvaluationNotEffective -} - -type compiledQuery struct { - query ast.Body - compiler ast.QueryCompiler -} - -type queryType int - -// Define a query type for each of the top level Rego -// API's that compile queries differently. -const ( - evalQueryType queryType = iota - partialResultQueryType - partialQueryType - compileQueryType -) - -type loadPaths struct { - paths []string - filter loader.Filter + return v1.IsPartialEvaluationNotEffectiveErr(err) } // Rego constructs a query and can be evaluated to obtain results. -type Rego struct { - query string - parsedQuery ast.Body - compiledQueries map[queryType]compiledQuery - pkg string - parsedPackage *ast.Package - imports []string - parsedImports []*ast.Import - rawInput *interface{} - parsedInput ast.Value - unknowns []string - parsedUnknowns []*ast.Term - disableInlining []string - shallowInlining bool - skipPartialNamespace bool - partialNamespace string - modules []rawModule - parsedModules map[string]*ast.Module - compiler *ast.Compiler - store storage.Store - ownStore bool - ownStoreReadAst bool - txn storage.Transaction - metrics metrics.Metrics - queryTracers []topdown.QueryTracer - tracebuf *topdown.BufferTracer - trace bool - instrumentation *topdown.Instrumentation - instrument bool - capture map[*ast.Expr]ast.Var // map exprs to generated capture vars - termVarID int - dump io.Writer - runtime *ast.Term - time time.Time - seed io.Reader - capabilities *ast.Capabilities - builtinDecls map[string]*ast.Builtin - builtinFuncs map[string]*topdown.Builtin - unsafeBuiltins map[string]struct{} - loadPaths loadPaths - bundlePaths []string - bundles map[string]*bundle.Bundle - skipBundleVerification bool - interQueryBuiltinCache cache.InterQueryCache - interQueryBuiltinValueCache cache.InterQueryValueCache - ndBuiltinCache builtins.NDBCache - strictBuiltinErrors bool - builtinErrorList *[]topdown.Error - resolvers []refResolver - schemaSet *ast.SchemaSet - target string // target type (wasm, rego, etc.) - opa opa.EvalEngine - generateJSON func(*ast.Term, *EvalContext) (interface{}, error) - printHook print.Hook - enablePrintStatements bool - distributedTacingOpts tracing.Options - strict bool - pluginMgr *plugins.Manager - plugins []TargetPlugin - targetPrepState TargetPluginEval - regoVersion ast.RegoVersion -} +type Rego = v1.Rego // Function represents a built-in function that is callable in Rego. -type Function struct { - Name string - Description string - Decl *types.Function - Memoize bool - Nondeterministic bool -} +type Function = v1.Function // BuiltinContext contains additional attributes from the evaluator that // built-in functions can use, e.g., the request context.Context, caches, etc. -type BuiltinContext = topdown.BuiltinContext +type BuiltinContext = v1.BuiltinContext type ( // Builtin1 defines a built-in function that accepts 1 argument. - Builtin1 func(bctx BuiltinContext, op1 *ast.Term) (*ast.Term, error) + Builtin1 = v1.Builtin1 // Builtin2 defines a built-in function that accepts 2 arguments. - Builtin2 func(bctx BuiltinContext, op1, op2 *ast.Term) (*ast.Term, error) + Builtin2 = v1.Builtin2 // Builtin3 defines a built-in function that accepts 3 argument. - Builtin3 func(bctx BuiltinContext, op1, op2, op3 *ast.Term) (*ast.Term, error) + Builtin3 = v1.Builtin3 // Builtin4 defines a built-in function that accepts 4 argument. - Builtin4 func(bctx BuiltinContext, op1, op2, op3, op4 *ast.Term) (*ast.Term, error) + Builtin4 = v1.Builtin4 // BuiltinDyn defines a built-in function that accepts a list of arguments. - BuiltinDyn func(bctx BuiltinContext, terms []*ast.Term) (*ast.Term, error) + BuiltinDyn = v1.BuiltinDyn ) // RegisterBuiltin1 adds a built-in function globally inside the OPA runtime. func RegisterBuiltin1(decl *Function, impl Builtin1) { - ast.RegisterBuiltin(&ast.Builtin{ - Name: decl.Name, - Description: decl.Description, - Decl: decl.Decl, - Nondeterministic: decl.Nondeterministic, - }) - topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms[0]) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + v1.RegisterBuiltin1(decl, impl) } // RegisterBuiltin2 adds a built-in function globally inside the OPA runtime. func RegisterBuiltin2(decl *Function, impl Builtin2) { - ast.RegisterBuiltin(&ast.Builtin{ - Name: decl.Name, - Description: decl.Description, - Decl: decl.Decl, - Nondeterministic: decl.Nondeterministic, - }) - topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms[0], terms[1]) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + v1.RegisterBuiltin2(decl, impl) } // RegisterBuiltin3 adds a built-in function globally inside the OPA runtime. func RegisterBuiltin3(decl *Function, impl Builtin3) { - ast.RegisterBuiltin(&ast.Builtin{ - Name: decl.Name, - Description: decl.Description, - Decl: decl.Decl, - Nondeterministic: decl.Nondeterministic, - }) - topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms[0], terms[1], terms[2]) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + v1.RegisterBuiltin3(decl, impl) } // RegisterBuiltin4 adds a built-in function globally inside the OPA runtime. func RegisterBuiltin4(decl *Function, impl Builtin4) { - ast.RegisterBuiltin(&ast.Builtin{ - Name: decl.Name, - Description: decl.Description, - Decl: decl.Decl, - Nondeterministic: decl.Nondeterministic, - }) - topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms[0], terms[1], terms[2], terms[3]) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + v1.RegisterBuiltin4(decl, impl) } // RegisterBuiltinDyn adds a built-in function globally inside the OPA runtime. func RegisterBuiltinDyn(decl *Function, impl BuiltinDyn) { - ast.RegisterBuiltin(&ast.Builtin{ - Name: decl.Name, - Description: decl.Description, - Decl: decl.Decl, - Nondeterministic: decl.Nondeterministic, - }) - topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + v1.RegisterBuiltinDyn(decl, impl) } // Function1 returns an option that adds a built-in function to the Rego object. func Function1(decl *Function, f Builtin1) func(*Rego) { - return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms[0]) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + return v1.Function1(decl, f) } // Function2 returns an option that adds a built-in function to the Rego object. func Function2(decl *Function, f Builtin2) func(*Rego) { - return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms[0], terms[1]) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + return v1.Function2(decl, f) } // Function3 returns an option that adds a built-in function to the Rego object. func Function3(decl *Function, f Builtin3) func(*Rego) { - return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms[0], terms[1], terms[2]) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + return v1.Function3(decl, f) } // Function4 returns an option that adds a built-in function to the Rego object. func Function4(decl *Function, f Builtin4) func(*Rego) { - return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms[0], terms[1], terms[2], terms[3]) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + return v1.Function4(decl, f) } // FunctionDyn returns an option that adds a built-in function to the Rego object. func FunctionDyn(decl *Function, f BuiltinDyn) func(*Rego) { - return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { - result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms) }) - return finishFunction(decl.Name, bctx, result, err, iter) - }) + return v1.FunctionDyn(decl, f) } // FunctionDecl returns an option that adds a custom-built-in function // __declaration__. NO implementation is provided. This is used for // non-interpreter execution envs (e.g., Wasm). func FunctionDecl(decl *Function) func(*Rego) { - return newDecl(decl) -} - -func newDecl(decl *Function) func(*Rego) { - return func(r *Rego) { - r.builtinDecls[decl.Name] = &ast.Builtin{ - Name: decl.Name, - Decl: decl.Decl, - } - } -} - -type memo struct { - term *ast.Term - err error -} - -type memokey string - -func memoize(decl *Function, bctx BuiltinContext, terms []*ast.Term, ifEmpty func() (*ast.Term, error)) (*ast.Term, error) { - - if !decl.Memoize { - return ifEmpty() - } - - // NOTE(tsandall): we assume memoization is applied to infrequent built-in - // calls that do things like fetch data from remote locations. As such, - // converting the terms to strings is acceptable for now. - var b strings.Builder - if _, err := b.WriteString(decl.Name); err != nil { - return nil, err - } - - // The term slice _may_ include an output term depending on how the caller - // referred to the built-in function. Only use the arguments as the cache - // key. Unification ensures we don't get false positive matches. - for i := 0; i < len(decl.Decl.Args()); i++ { - if _, err := b.WriteString(terms[i].String()); err != nil { - return nil, err - } - } - - key := memokey(b.String()) - hit, ok := bctx.Cache.Get(key) - var m memo - if ok { - m = hit.(memo) - } else { - m.term, m.err = ifEmpty() - bctx.Cache.Put(key, m) - } - - return m.term, m.err + return v1.FunctionDecl(decl) } // Dump returns an argument that sets the writer to dump debugging information to. func Dump(w io.Writer) func(r *Rego) { - return func(r *Rego) { - r.dump = w - } + return v1.Dump(w) } // Query returns an argument that sets the Rego query. func Query(q string) func(r *Rego) { - return func(r *Rego) { - r.query = q - } + return v1.Query(q) } // ParsedQuery returns an argument that sets the Rego query. func ParsedQuery(q ast.Body) func(r *Rego) { - return func(r *Rego) { - r.parsedQuery = q - } + return v1.ParsedQuery(q) } // Package returns an argument that sets the Rego package on the query's // context. func Package(p string) func(r *Rego) { - return func(r *Rego) { - r.pkg = p - } + return v1.Package(p) } // ParsedPackage returns an argument that sets the Rego package on the query's // context. func ParsedPackage(pkg *ast.Package) func(r *Rego) { - return func(r *Rego) { - r.parsedPackage = pkg - } + return v1.ParsedPackage(pkg) } // Imports returns an argument that adds a Rego import to the query's context. func Imports(p []string) func(r *Rego) { - return func(r *Rego) { - r.imports = append(r.imports, p...) - } + return v1.Imports(p) } // ParsedImports returns an argument that adds Rego imports to the query's // context. func ParsedImports(imp []*ast.Import) func(r *Rego) { - return func(r *Rego) { - r.parsedImports = append(r.parsedImports, imp...) - } + return v1.ParsedImports(imp) } // Input returns an argument that sets the Rego input document. Input should be // a native Go value representing the input document. func Input(x interface{}) func(r *Rego) { - return func(r *Rego) { - r.rawInput = &x - } + return v1.Input(x) } // ParsedInput returns an argument that sets the Rego input document. func ParsedInput(x ast.Value) func(r *Rego) { - return func(r *Rego) { - r.parsedInput = x - } + return v1.ParsedInput(x) } // Unknowns returns an argument that sets the values to treat as unknown during // partial evaluation. func Unknowns(unknowns []string) func(r *Rego) { - return func(r *Rego) { - r.unknowns = unknowns - } + return v1.Unknowns(unknowns) } // ParsedUnknowns returns an argument that sets the values to treat as unknown // during partial evaluation. func ParsedUnknowns(unknowns []*ast.Term) func(r *Rego) { - return func(r *Rego) { - r.parsedUnknowns = unknowns - } + return v1.ParsedUnknowns(unknowns) } // DisableInlining adds a set of paths to exclude from partial evaluation inlining. func DisableInlining(paths []string) func(r *Rego) { - return func(r *Rego) { - r.disableInlining = paths - } + return v1.DisableInlining(paths) } // ShallowInlining prevents rules that depend on unknown values from being inlined. // Rules that only depend on known values are inlined. func ShallowInlining(yes bool) func(r *Rego) { - return func(r *Rego) { - r.shallowInlining = yes - } + return v1.ShallowInlining(yes) } // SkipPartialNamespace disables namespacing of partial evalution results for support // rules generated from policy. Synthetic support rules are still namespaced. func SkipPartialNamespace(yes bool) func(r *Rego) { - return func(r *Rego) { - r.skipPartialNamespace = yes - } + return v1.SkipPartialNamespace(yes) } // PartialNamespace returns an argument that sets the namespace to use for // partial evaluation results. The namespace must be a valid package path // component. func PartialNamespace(ns string) func(r *Rego) { - return func(r *Rego) { - r.partialNamespace = ns - } + return v1.PartialNamespace(ns) } // Module returns an argument that adds a Rego module. func Module(filename, input string) func(r *Rego) { - return func(r *Rego) { - r.modules = append(r.modules, rawModule{ - filename: filename, - module: input, - }) - } + return v1.Module(filename, input) } // ParsedModule returns an argument that adds a parsed Rego module. If a string // module with the same filename name is added, it will override the parsed // module. func ParsedModule(module *ast.Module) func(*Rego) { - return func(r *Rego) { - var filename string - if module.Package.Location != nil { - filename = module.Package.Location.File - } else { - filename = fmt.Sprintf("module_%p.rego", module) - } - r.parsedModules[filename] = module - } + return v1.ParsedModule(module) } // Load returns an argument that adds a filesystem path to load data @@ -968,9 +377,7 @@ func ParsedModule(module *ast.Module) func(*Rego) { // The Load option can only be used once. // Note: Loading files will require a write transaction on the store. func Load(paths []string, filter loader.Filter) func(r *Rego) { - return func(r *Rego) { - r.loadPaths = loadPaths{paths, filter} - } + return v1.Load(paths, filter) } // LoadBundle returns an argument that adds a filesystem path to load @@ -978,23 +385,17 @@ func Load(paths []string, filter loader.Filter) func(r *Rego) { // to be loaded as a bundle. // Note: Loading bundles will require a write transaction on the store. func LoadBundle(path string) func(r *Rego) { - return func(r *Rego) { - r.bundlePaths = append(r.bundlePaths, path) - } + return v1.LoadBundle(path) } // ParsedBundle returns an argument that adds a bundle to be loaded. func ParsedBundle(name string, b *bundle.Bundle) func(r *Rego) { - return func(r *Rego) { - r.bundles[name] = b - } + return v1.ParsedBundle(name, b) } // Compiler returns an argument that sets the Rego compiler. func Compiler(c *ast.Compiler) func(r *Rego) { - return func(r *Rego) { - r.compiler = c - } + return v1.Compiler(c) } // Store returns an argument that sets the policy engine's data storage layer. @@ -1003,18 +404,14 @@ func Compiler(c *ast.Compiler) func(r *Rego) { // must also be provided via the Transaction() option. After loading files // or bundles the transaction should be aborted or committed. func Store(s storage.Store) func(r *Rego) { - return func(r *Rego) { - r.store = s - } + return v1.Store(s) } // StoreReadAST returns an argument that sets whether the store should eagerly convert data to AST values. // // Only applicable when no store has been set on the Rego object through the Store option. func StoreReadAST(enabled bool) func(r *Rego) { - return func(r *Rego) { - r.ownStoreReadAst = enabled - } + return v1.StoreReadAST(enabled) } // Transaction returns an argument that sets the transaction to use for storage @@ -1024,93 +421,65 @@ func StoreReadAST(enabled bool) func(r *Rego) { // Store() option. If using Load(), LoadBundle(), or ParsedBundle() options // the transaction will likely require write params. func Transaction(txn storage.Transaction) func(r *Rego) { - return func(r *Rego) { - r.txn = txn - } + return v1.Transaction(txn) } // Metrics returns an argument that sets the metrics collection. func Metrics(m metrics.Metrics) func(r *Rego) { - return func(r *Rego) { - r.metrics = m - } + return v1.Metrics(m) } // Instrument returns an argument that enables instrumentation for diagnosing // performance issues. func Instrument(yes bool) func(r *Rego) { - return func(r *Rego) { - r.instrument = yes - } + return v1.Instrument(yes) } // Trace returns an argument that enables tracing on r. func Trace(yes bool) func(r *Rego) { - return func(r *Rego) { - r.trace = yes - } + return v1.Trace(yes) } // Tracer returns an argument that adds a query tracer to r. // Deprecated: Use QueryTracer instead. func Tracer(t topdown.Tracer) func(r *Rego) { - return func(r *Rego) { - if t != nil { - r.queryTracers = append(r.queryTracers, topdown.WrapLegacyTracer(t)) - } - } + return v1.Tracer(t) } // QueryTracer returns an argument that adds a query tracer to r. func QueryTracer(t topdown.QueryTracer) func(r *Rego) { - return func(r *Rego) { - if t != nil { - r.queryTracers = append(r.queryTracers, t) - } - } + return v1.QueryTracer(t) } // Runtime returns an argument that sets the runtime data to provide to the // evaluation engine. func Runtime(term *ast.Term) func(r *Rego) { - return func(r *Rego) { - r.runtime = term - } + return v1.Runtime(term) } // Time sets the wall clock time to use during policy evaluation. Prepared queries // do not inherit this parameter. Use EvalTime to set the wall clock time when // executing a prepared query. func Time(x time.Time) func(r *Rego) { - return func(r *Rego) { - r.time = x - } + return v1.Time(x) } // Seed sets a reader that will seed randomization required by built-in functions. // If a seed is not provided crypto/rand.Reader is used. func Seed(r io.Reader) func(*Rego) { - return func(e *Rego) { - e.seed = r - } + return v1.Seed(r) } // PrintTrace is a helper function to write a human-readable version of the // trace to the writer w. func PrintTrace(w io.Writer, r *Rego) { - if r == nil || r.tracebuf == nil { - return - } - topdown.PrettyTrace(w, *r.tracebuf) + v1.PrintTrace(w, r) } // PrintTraceWithLocation is a helper function to write a human-readable version of the // trace to the writer w. func PrintTraceWithLocation(w io.Writer, r *Rego) { - if r == nil || r.tracebuf == nil { - return - } - topdown.PrettyTraceWithLocation(w, *r.tracebuf) + v1.PrintTraceWithLocation(w, r) } // UnsafeBuiltins sets the built-in functions to treat as unsafe and not allow. @@ -1118,104 +487,76 @@ func PrintTraceWithLocation(w io.Writer, r *Rego) { // compiler. This option is always honored for query compilation. Provide an // empty (non-nil) map to disable checks on queries. func UnsafeBuiltins(unsafeBuiltins map[string]struct{}) func(r *Rego) { - return func(r *Rego) { - r.unsafeBuiltins = unsafeBuiltins - } + return v1.UnsafeBuiltins(unsafeBuiltins) } // SkipBundleVerification skips verification of a signed bundle. func SkipBundleVerification(yes bool) func(r *Rego) { - return func(r *Rego) { - r.skipBundleVerification = yes - } + return v1.SkipBundleVerification(yes) } // InterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize // during evaluation. func InterQueryBuiltinCache(c cache.InterQueryCache) func(r *Rego) { - return func(r *Rego) { - r.interQueryBuiltinCache = c - } + return v1.InterQueryBuiltinCache(c) } // InterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize // during evaluation. func InterQueryBuiltinValueCache(c cache.InterQueryValueCache) func(r *Rego) { - return func(r *Rego) { - r.interQueryBuiltinValueCache = c - } + return v1.InterQueryBuiltinValueCache(c) } // NDBuiltinCache sets the non-deterministic builtins cache. func NDBuiltinCache(c builtins.NDBCache) func(r *Rego) { - return func(r *Rego) { - r.ndBuiltinCache = c - } + return v1.NDBuiltinCache(c) } // StrictBuiltinErrors tells the evaluator to treat all built-in function errors as fatal errors. func StrictBuiltinErrors(yes bool) func(r *Rego) { - return func(r *Rego) { - r.strictBuiltinErrors = yes - } + return v1.StrictBuiltinErrors(yes) } // BuiltinErrorList supplies an error slice to store built-in function errors. func BuiltinErrorList(list *[]topdown.Error) func(r *Rego) { - return func(r *Rego) { - r.builtinErrorList = list - } + return v1.BuiltinErrorList(list) } // Resolver sets a Resolver for a specified ref path. func Resolver(ref ast.Ref, r resolver.Resolver) func(r *Rego) { - return func(rego *Rego) { - rego.resolvers = append(rego.resolvers, refResolver{ref, r}) - } + return v1.Resolver(ref, r) } // Schemas sets the schemaSet func Schemas(x *ast.SchemaSet) func(r *Rego) { - return func(r *Rego) { - r.schemaSet = x - } + return v1.Schemas(x) } // Capabilities configures the underlying compiler's capabilities. // This option is ignored for module compilation if the caller supplies the // compiler. func Capabilities(c *ast.Capabilities) func(r *Rego) { - return func(r *Rego) { - r.capabilities = c - } + return v1.Capabilities(c) } // Target sets the runtime to exercise. func Target(t string) func(r *Rego) { - return func(r *Rego) { - r.target = t - } + return v1.Target(t) } // GenerateJSON sets the AST to JSON converter for the results. func GenerateJSON(f func(*ast.Term, *EvalContext) (interface{}, error)) func(r *Rego) { - return func(r *Rego) { - r.generateJSON = f - } + return v1.GenerateJSON(f) } // PrintHook sets the object to use for handling print statement outputs. func PrintHook(h print.Hook) func(r *Rego) { - return func(r *Rego) { - r.printHook = h - } + return v1.PrintHook(h) } // DistributedTracingOpts sets the options to be used by distributed tracing. func DistributedTracingOpts(tr tracing.Options) func(r *Rego) { - return func(r *Rego) { - r.distributedTacingOpts = tr - } + return v1.DistributedTracingOpts(tr) } // EnablePrintStatements enables print() calls. If this option is not provided, @@ -1223,1667 +564,65 @@ func DistributedTracingOpts(tr tracing.Options) func(r *Rego) { // queries and policies that passed as raw strings, i.e., this function will not // have any affect if the caller supplies the ast.Compiler instance. func EnablePrintStatements(yes bool) func(r *Rego) { - return func(r *Rego) { - r.enablePrintStatements = yes - } + return v1.EnablePrintStatements(yes) } // Strict enables or disables strict-mode in the compiler func Strict(yes bool) func(r *Rego) { - return func(r *Rego) { - r.strict = yes - } + return v1.Strict(yes) } func SetRegoVersion(version ast.RegoVersion) func(r *Rego) { - return func(r *Rego) { - r.regoVersion = version - } + return v1.SetRegoVersion(version) } // New returns a new Rego object. func New(options ...func(r *Rego)) *Rego { - - r := &Rego{ - parsedModules: map[string]*ast.Module{}, - capture: map[*ast.Expr]ast.Var{}, - compiledQueries: map[queryType]compiledQuery{}, - builtinDecls: map[string]*ast.Builtin{}, - builtinFuncs: map[string]*topdown.Builtin{}, - bundles: map[string]*bundle.Bundle{}, - } - - for _, option := range options { - option(r) - } - - if r.compiler == nil { - r.compiler = ast.NewCompiler(). - WithUnsafeBuiltins(r.unsafeBuiltins). - WithBuiltins(r.builtinDecls). - WithDebug(r.dump). - WithSchemas(r.schemaSet). - WithCapabilities(r.capabilities). - WithEnablePrintStatements(r.enablePrintStatements). - WithStrict(r.strict). - WithUseTypeCheckAnnotations(true) - - // topdown could be target "" or "rego", but both could be overridden by - // a target plugin (checked below) - if r.target == targetWasm { - r.compiler = r.compiler.WithEvalMode(ast.EvalModeIR) + opts := make([]func(r *Rego), 0, len(options)+1) + opts = append(opts, options...) + opts = append(opts, func(r *Rego) { + if r.RegoVersion() == ast.RegoUndefined { + SetRegoVersion(ast.DefaultRegoVersion)(r) } - } - - if r.store == nil { - r.store = inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(r.ownStoreReadAst)) - r.ownStore = true - } else { - r.ownStore = false - } - - if r.metrics == nil { - r.metrics = metrics.New() - } - - if r.instrument { - r.instrumentation = topdown.NewInstrumentation(r.metrics) - r.compiler.WithMetrics(r.metrics) - } - - if r.trace { - r.tracebuf = topdown.NewBufferTracer() - r.queryTracers = append(r.queryTracers, r.tracebuf) - } - - if r.partialNamespace == "" { - r.partialNamespace = defaultPartialNamespace - } - - if r.generateJSON == nil { - r.generateJSON = generateJSON - } - - if r.pluginMgr != nil { - for _, name := range r.pluginMgr.Plugins() { - p := r.pluginMgr.Plugin(name) - if p0, ok := p.(TargetPlugin); ok { - r.plugins = append(r.plugins, p0) - } - } - } - - if t := r.targetPlugin(r.target); t != nil { - r.compiler = r.compiler.WithEvalMode(ast.EvalModeIR) - } - - return r -} - -// Eval evaluates this Rego object and returns a ResultSet. -func (r *Rego) Eval(ctx context.Context) (ResultSet, error) { - var err error - var txnClose transactionCloser - r.txn, txnClose, err = r.getTxn(ctx) - if err != nil { - return nil, err - } - - pq, err := r.PrepareForEval(ctx) - if err != nil { - _ = txnClose(ctx, err) // Ignore error - return nil, err - } - - evalArgs := []EvalOption{ - EvalTransaction(r.txn), - EvalMetrics(r.metrics), - EvalInstrument(r.instrument), - EvalTime(r.time), - EvalInterQueryBuiltinCache(r.interQueryBuiltinCache), - EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache), - EvalSeed(r.seed), - } - - if r.ndBuiltinCache != nil { - evalArgs = append(evalArgs, EvalNDBuiltinCache(r.ndBuiltinCache)) - } - - for _, qt := range r.queryTracers { - evalArgs = append(evalArgs, EvalQueryTracer(qt)) - } - - for i := range r.resolvers { - evalArgs = append(evalArgs, EvalResolver(r.resolvers[i].ref, r.resolvers[i].r)) - } - - rs, err := pq.Eval(ctx, evalArgs...) - txnErr := txnClose(ctx, err) // Always call closer - if err == nil { - err = txnErr - } - return rs, err -} - -// PartialEval has been deprecated and renamed to PartialResult. -func (r *Rego) PartialEval(ctx context.Context) (PartialResult, error) { - return r.PartialResult(ctx) -} - -// PartialResult partially evaluates this Rego object and returns a PartialResult. -func (r *Rego) PartialResult(ctx context.Context) (PartialResult, error) { - var err error - var txnClose transactionCloser - r.txn, txnClose, err = r.getTxn(ctx) - if err != nil { - return PartialResult{}, err - } - - pq, err := r.PrepareForEval(ctx, WithPartialEval()) - txnErr := txnClose(ctx, err) // Always call closer - if err != nil { - return PartialResult{}, err - } - if txnErr != nil { - return PartialResult{}, txnErr - } - - pr := PartialResult{ - compiler: pq.r.compiler, - store: pq.r.store, - body: pq.r.parsedQuery, - builtinDecls: pq.r.builtinDecls, - builtinFuncs: pq.r.builtinFuncs, - } - - return pr, nil -} - -// Partial runs partial evaluation on r and returns the result. -func (r *Rego) Partial(ctx context.Context) (*PartialQueries, error) { - var err error - var txnClose transactionCloser - r.txn, txnClose, err = r.getTxn(ctx) - if err != nil { - return nil, err - } - - pq, err := r.PrepareForPartial(ctx) - if err != nil { - _ = txnClose(ctx, err) // Ignore error - return nil, err - } - - evalArgs := []EvalOption{ - EvalTransaction(r.txn), - EvalMetrics(r.metrics), - EvalInstrument(r.instrument), - EvalInterQueryBuiltinCache(r.interQueryBuiltinCache), - EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache), - } - - if r.ndBuiltinCache != nil { - evalArgs = append(evalArgs, EvalNDBuiltinCache(r.ndBuiltinCache)) - } - - for _, t := range r.queryTracers { - evalArgs = append(evalArgs, EvalQueryTracer(t)) - } - - for i := range r.resolvers { - evalArgs = append(evalArgs, EvalResolver(r.resolvers[i].ref, r.resolvers[i].r)) - } - - pqs, err := pq.Partial(ctx, evalArgs...) - txnErr := txnClose(ctx, err) // Always call closer - if err == nil { - err = txnErr - } - return pqs, err + }) + + return v1.New(opts...) } // CompileOption defines a function to set options on Compile calls. -type CompileOption func(*CompileContext) +type CompileOption = v1.CompileOption // CompileContext contains options for Compile calls. -type CompileContext struct { - partial bool -} +type CompileContext = v1.CompileContext // CompilePartial defines an option to control whether partial evaluation is run // before the query is planned and compiled. func CompilePartial(yes bool) CompileOption { - return func(cfg *CompileContext) { - cfg.partial = yes - } -} - -// Compile returns a compiled policy query. -func (r *Rego) Compile(ctx context.Context, opts ...CompileOption) (*CompileResult, error) { - - var cfg CompileContext - - for _, opt := range opts { - opt(&cfg) - } - - var queries []ast.Body - modules := make([]*ast.Module, 0, len(r.compiler.Modules)) - - if cfg.partial { - - pq, err := r.Partial(ctx) - if err != nil { - return nil, err - } - if r.dump != nil { - if len(pq.Queries) != 0 { - msg := fmt.Sprintf("QUERIES (%d total):", len(pq.Queries)) - fmt.Fprintln(r.dump, msg) - fmt.Fprintln(r.dump, strings.Repeat("-", len(msg))) - for i := range pq.Queries { - fmt.Println(pq.Queries[i]) - } - fmt.Fprintln(r.dump) - } - if len(pq.Support) != 0 { - msg := fmt.Sprintf("SUPPORT (%d total):", len(pq.Support)) - fmt.Fprintln(r.dump, msg) - fmt.Fprintln(r.dump, strings.Repeat("-", len(msg))) - for i := range pq.Support { - fmt.Println(pq.Support[i]) - } - fmt.Fprintln(r.dump) - } - } - - queries = pq.Queries - modules = pq.Support - - for _, module := range r.compiler.Modules { - modules = append(modules, module) - } - } else { - var err error - // If creating a new transaction it should be closed before calling the - // planner to avoid holding open the transaction longer than needed. - // - // TODO(tsandall): in future, planner could make use of store, in which - // case this will need to change. - var txnClose transactionCloser - r.txn, txnClose, err = r.getTxn(ctx) - if err != nil { - return nil, err - } - - err = r.prepare(ctx, compileQueryType, nil) - txnErr := txnClose(ctx, err) // Always call closer - if err != nil { - return nil, err - } - if txnErr != nil { - return nil, err - } - - for _, module := range r.compiler.Modules { - modules = append(modules, module) - } - - queries = []ast.Body{r.compiledQueries[compileQueryType].query} - } - - if tgt := r.targetPlugin(r.target); tgt != nil { - return nil, fmt.Errorf("unsupported for rego target plugins") - } - - return r.compileWasm(modules, queries, compileQueryType) // TODO(sr) control flow is funky here -} - -func (r *Rego) compileWasm(_ []*ast.Module, queries []ast.Body, qType queryType) (*CompileResult, error) { - policy, err := r.planQuery(queries, qType) - if err != nil { - return nil, err - } - - m, err := wasm.New().WithPolicy(policy).Compile() - if err != nil { - return nil, err - } - - var out bytes.Buffer - if err := encoding.WriteModule(&out, m); err != nil { - return nil, err - } - - return &CompileResult{ - Bytes: out.Bytes(), - }, nil + return v1.CompilePartial(yes) } // PrepareOption defines a function to set an option to control // the behavior of the Prepare call. -type PrepareOption func(*PrepareConfig) +type PrepareOption = v1.PrepareOption // PrepareConfig holds settings to control the behavior of the // Prepare call. -type PrepareConfig struct { - doPartialEval bool - disableInlining *[]string - builtinFuncs map[string]*topdown.Builtin -} +type PrepareConfig = v1.PrepareConfig // WithPartialEval configures an option for PrepareForEval // which will have it perform partial evaluation while preparing // the query (similar to rego.Rego#PartialResult) func WithPartialEval() PrepareOption { - return func(p *PrepareConfig) { - p.doPartialEval = true - } + return v1.WithPartialEval() } // WithNoInline adds a set of paths to exclude from partial evaluation inlining. func WithNoInline(paths []string) PrepareOption { - return func(p *PrepareConfig) { - p.disableInlining = &paths - } + return v1.WithNoInline(paths) } // WithBuiltinFuncs carries the rego.Function{1,2,3} per-query function definitions // to the target plugins. func WithBuiltinFuncs(bis map[string]*topdown.Builtin) PrepareOption { - return func(p *PrepareConfig) { - if p.builtinFuncs == nil { - p.builtinFuncs = make(map[string]*topdown.Builtin, len(bis)) - } - for k, v := range bis { - p.builtinFuncs[k] = v - } - } -} - -// BuiltinFuncs allows retrieving the builtin funcs set via PrepareOption -// WithBuiltinFuncs. -func (p *PrepareConfig) BuiltinFuncs() map[string]*topdown.Builtin { - return p.builtinFuncs -} - -// PrepareForEval will parse inputs, modules, and query arguments in preparation -// of evaluating them. -func (r *Rego) PrepareForEval(ctx context.Context, opts ...PrepareOption) (PreparedEvalQuery, error) { - if !r.hasQuery() { - return PreparedEvalQuery{}, fmt.Errorf("cannot evaluate empty query") - } - - pCfg := &PrepareConfig{} - for _, o := range opts { - o(pCfg) - } - - var err error - var txnClose transactionCloser - r.txn, txnClose, err = r.getTxn(ctx) - if err != nil { - return PreparedEvalQuery{}, err - } - - // If the caller wanted to do partial evaluation as part of preparation - // do it now and use the new Rego object. - if pCfg.doPartialEval { - - pr, err := r.partialResult(ctx, pCfg) - if err != nil { - _ = txnClose(ctx, err) // Ignore error - return PreparedEvalQuery{}, err - } - - // Prepare the new query using the result of partial evaluation - pq, err := pr.Rego(Transaction(r.txn)).PrepareForEval(ctx) - txnErr := txnClose(ctx, err) - if err != nil { - return pq, err - } - return pq, txnErr - } - - err = r.prepare(ctx, evalQueryType, []extraStage{ - { - after: "ResolveRefs", - stage: ast.QueryCompilerStageDefinition{ - Name: "RewriteToCaptureValue", - MetricName: "query_compile_stage_rewrite_to_capture_value", - Stage: r.rewriteQueryToCaptureValue, - }, - }, - }) - if err != nil { - _ = txnClose(ctx, err) // Ignore error - return PreparedEvalQuery{}, err - } - - switch r.target { - case targetWasm: // TODO(sr): make wasm a target plugin, too - - if r.hasWasmModule() { - _ = txnClose(ctx, err) // Ignore error - return PreparedEvalQuery{}, fmt.Errorf("wasm target not supported") - } - - var modules []*ast.Module - for _, module := range r.compiler.Modules { - modules = append(modules, module) - } - - queries := []ast.Body{r.compiledQueries[evalQueryType].query} - - e, err := opa.LookupEngine(targetWasm) - if err != nil { - return PreparedEvalQuery{}, err - } - - // nolint: staticcheck // SA4006 false positive - cr, err := r.compileWasm(modules, queries, evalQueryType) - if err != nil { - _ = txnClose(ctx, err) // Ignore error - return PreparedEvalQuery{}, err - } - - // nolint: staticcheck // SA4006 false positive - data, err := r.store.Read(ctx, r.txn, storage.Path{}) - if err != nil { - _ = txnClose(ctx, err) // Ignore error - return PreparedEvalQuery{}, err - } - - o, err := e.New().WithPolicyBytes(cr.Bytes).WithDataJSON(data).Init() - if err != nil { - _ = txnClose(ctx, err) // Ignore error - return PreparedEvalQuery{}, err - } - r.opa = o - - case targetRego: // do nothing, don't lookup default plugin - default: // either a specific plugin target, or one that is default - if tgt := r.targetPlugin(r.target); tgt != nil { - queries := []ast.Body{r.compiledQueries[evalQueryType].query} - pol, err := r.planQuery(queries, evalQueryType) - if err != nil { - return PreparedEvalQuery{}, err - } - // always add the builtins provided via rego.FunctionN options - opts = append(opts, WithBuiltinFuncs(r.builtinFuncs)) - r.targetPrepState, err = tgt.PrepareForEval(ctx, pol, opts...) - if err != nil { - return PreparedEvalQuery{}, err - } - } - } - - txnErr := txnClose(ctx, err) // Always call closer - if err != nil { - return PreparedEvalQuery{}, err - } - if txnErr != nil { - return PreparedEvalQuery{}, txnErr - } - - return PreparedEvalQuery{preparedQuery{r, pCfg}}, err -} - -// PrepareForPartial will parse inputs, modules, and query arguments in preparation -// of partially evaluating them. -func (r *Rego) PrepareForPartial(ctx context.Context, opts ...PrepareOption) (PreparedPartialQuery, error) { - if !r.hasQuery() { - return PreparedPartialQuery{}, fmt.Errorf("cannot evaluate empty query") - } - - pCfg := &PrepareConfig{} - for _, o := range opts { - o(pCfg) - } - - var err error - var txnClose transactionCloser - r.txn, txnClose, err = r.getTxn(ctx) - if err != nil { - return PreparedPartialQuery{}, err - } - - err = r.prepare(ctx, partialQueryType, []extraStage{ - { - after: "CheckSafety", - stage: ast.QueryCompilerStageDefinition{ - Name: "RewriteEquals", - MetricName: "query_compile_stage_rewrite_equals", - Stage: r.rewriteEqualsForPartialQueryCompile, - }, - }, - }) - txnErr := txnClose(ctx, err) // Always call closer - if err != nil { - return PreparedPartialQuery{}, err - } - if txnErr != nil { - return PreparedPartialQuery{}, txnErr - } - - return PreparedPartialQuery{preparedQuery{r, pCfg}}, err -} - -func (r *Rego) prepare(ctx context.Context, qType queryType, extras []extraStage) error { - var err error - - r.parsedInput, err = r.parseInput() - if err != nil { - return err - } - - err = r.loadFiles(ctx, r.txn, r.metrics) - if err != nil { - return err - } - - err = r.loadBundles(ctx, r.txn, r.metrics) - if err != nil { - return err - } - - err = r.parseModules(ctx, r.txn, r.metrics) - if err != nil { - return err - } - - // Compile the modules *before* the query, else functions - // defined in the module won't be found... - err = r.compileModules(ctx, r.txn, r.metrics) - if err != nil { - return err - } - - imports, err := r.prepareImports() - if err != nil { - return err - } - - queryImports := []*ast.Import{} - for _, imp := range imports { - path := imp.Path.Value.(ast.Ref) - if path.HasPrefix([]*ast.Term{ast.FutureRootDocument}) || path.HasPrefix([]*ast.Term{ast.RegoRootDocument}) { - queryImports = append(queryImports, imp) - } - } - - r.parsedQuery, err = r.parseQuery(queryImports, r.metrics) - if err != nil { - return err - } - - err = r.compileAndCacheQuery(qType, r.parsedQuery, imports, r.metrics, extras) - if err != nil { - return err - } - - return nil -} - -func (r *Rego) parseModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error { - if len(r.modules) == 0 { - return nil - } - - ids, err := r.store.ListPolicies(ctx, txn) - if err != nil { - return err - } - - m.Timer(metrics.RegoModuleParse).Start() - defer m.Timer(metrics.RegoModuleParse).Stop() - var errs Errors - - // Parse any modules that are saved to the store, but only if - // another compile step is going to occur (ie. we have parsed modules - // that need to be compiled). - for _, id := range ids { - // if it is already on the compiler we're using - // then don't bother to re-parse it from source - if _, haveMod := r.compiler.Modules[id]; haveMod { - continue - } - - bs, err := r.store.GetPolicy(ctx, txn, id) - if err != nil { - return err - } - - parsed, err := ast.ParseModuleWithOpts(id, string(bs), ast.ParserOptions{RegoVersion: r.regoVersion}) - if err != nil { - errs = append(errs, err) - } - - r.parsedModules[id] = parsed - } - - // Parse any passed in as arguments to the Rego object - for _, module := range r.modules { - p, err := module.ParseWithOpts(ast.ParserOptions{RegoVersion: r.regoVersion}) - if err != nil { - switch errorWithType := err.(type) { - case ast.Errors: - for _, e := range errorWithType { - errs = append(errs, e) - } - default: - errs = append(errs, errorWithType) - } - } - r.parsedModules[module.filename] = p - } - - if len(errs) > 0 { - return errs - } - - return nil -} - -func (r *Rego) loadFiles(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error { - if len(r.loadPaths.paths) == 0 { - return nil - } - - m.Timer(metrics.RegoLoadFiles).Start() - defer m.Timer(metrics.RegoLoadFiles).Stop() - - result, err := loader.NewFileLoader(). - WithMetrics(m). - WithProcessAnnotation(true). - WithRegoVersion(r.regoVersion). - Filtered(r.loadPaths.paths, r.loadPaths.filter) - if err != nil { - return err - } - for name, mod := range result.Modules { - r.parsedModules[name] = mod.Parsed - } - - if len(result.Documents) > 0 { - err = r.store.Write(ctx, txn, storage.AddOp, storage.Path{}, result.Documents) - if err != nil { - return err - } - } - return nil -} - -func (r *Rego) loadBundles(_ context.Context, _ storage.Transaction, m metrics.Metrics) error { - if len(r.bundlePaths) == 0 { - return nil - } - - m.Timer(metrics.RegoLoadBundles).Start() - defer m.Timer(metrics.RegoLoadBundles).Stop() - - for _, path := range r.bundlePaths { - bndl, err := loader.NewFileLoader(). - WithMetrics(m). - WithProcessAnnotation(true). - WithSkipBundleVerification(r.skipBundleVerification). - WithRegoVersion(r.regoVersion). - AsBundle(path) - if err != nil { - return fmt.Errorf("loading error: %s", err) - } - r.bundles[path] = bndl - } - return nil -} - -func (r *Rego) parseInput() (ast.Value, error) { - if r.parsedInput != nil { - return r.parsedInput, nil - } - return r.parseRawInput(r.rawInput, r.metrics) -} - -func (r *Rego) parseRawInput(rawInput *interface{}, m metrics.Metrics) (ast.Value, error) { - var input ast.Value - - if rawInput == nil { - return input, nil - } - - m.Timer(metrics.RegoInputParse).Start() - defer m.Timer(metrics.RegoInputParse).Stop() - - rawPtr := util.Reference(rawInput) - - // roundtrip through json: this turns slices (e.g. []string, []bool) into - // []interface{}, the only array type ast.InterfaceToValue can work with - if err := util.RoundTrip(rawPtr); err != nil { - return nil, err - } - - return ast.InterfaceToValue(*rawPtr) -} - -func (r *Rego) parseQuery(queryImports []*ast.Import, m metrics.Metrics) (ast.Body, error) { - if r.parsedQuery != nil { - return r.parsedQuery, nil - } - - m.Timer(metrics.RegoQueryParse).Start() - defer m.Timer(metrics.RegoQueryParse).Stop() - - popts, err := future.ParserOptionsFromFutureImports(queryImports) - if err != nil { - return nil, err - } - popts.RegoVersion = r.regoVersion - popts, err = parserOptionsFromRegoVersionImport(queryImports, popts) - if err != nil { - return nil, err - } - popts.SkipRules = true - return ast.ParseBodyWithOpts(r.query, popts) -} - -func parserOptionsFromRegoVersionImport(imports []*ast.Import, popts ast.ParserOptions) (ast.ParserOptions, error) { - for _, imp := range imports { - path := imp.Path.Value.(ast.Ref) - if ast.Compare(path, ast.RegoV1CompatibleRef) == 0 { - popts.RegoVersion = ast.RegoV1 - return popts, nil - } - } - return popts, nil -} - -func (r *Rego) compileModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error { - - // Only compile again if there are new modules. - if len(r.bundles) > 0 || len(r.parsedModules) > 0 { - - // The bundle.Activate call will activate any bundles passed in - // (ie compile + handle data store changes), and include any of - // the additional modules passed in. If no bundles are provided - // it will only compile the passed in modules. - // Use this as the single-point of compiling everything only a - // single time. - opts := &bundle.ActivateOpts{ - Ctx: ctx, - Store: r.store, - Txn: txn, - Compiler: r.compilerForTxn(ctx, r.store, txn), - Metrics: m, - Bundles: r.bundles, - ExtraModules: r.parsedModules, - ParserOptions: ast.ParserOptions{RegoVersion: r.regoVersion}, - } - err := bundle.Activate(opts) - if err != nil { - return err - } - } - - // Ensure all configured resolvers from the store are loaded. Skip if any were explicitly provided. - if len(r.resolvers) == 0 { - resolvers, err := bundleUtils.LoadWasmResolversFromStore(ctx, r.store, txn, r.bundles) - if err != nil { - return err - } - - for _, rslvr := range resolvers { - for _, ep := range rslvr.Entrypoints() { - r.resolvers = append(r.resolvers, refResolver{ep, rslvr}) - } - } - } - return nil -} - -func (r *Rego) compileAndCacheQuery(qType queryType, query ast.Body, imports []*ast.Import, m metrics.Metrics, extras []extraStage) error { - m.Timer(metrics.RegoQueryCompile).Start() - defer m.Timer(metrics.RegoQueryCompile).Stop() - - cachedQuery, ok := r.compiledQueries[qType] - if ok && cachedQuery.query != nil && cachedQuery.compiler != nil { - return nil - } - - qc, compiled, err := r.compileQuery(query, imports, m, extras) - if err != nil { - return err - } - - // cache the query for future use - r.compiledQueries[qType] = compiledQuery{ - query: compiled, - compiler: qc, - } - return nil -} - -func (r *Rego) prepareImports() ([]*ast.Import, error) { - imports := r.parsedImports - - if len(r.imports) > 0 { - s := make([]string, len(r.imports)) - for i := range r.imports { - s[i] = fmt.Sprintf("import %v", r.imports[i]) - } - parsed, err := ast.ParseImports(strings.Join(s, "\n")) - if err != nil { - return nil, err - } - imports = append(imports, parsed...) - } - return imports, nil -} - -func (r *Rego) compileQuery(query ast.Body, imports []*ast.Import, _ metrics.Metrics, extras []extraStage) (ast.QueryCompiler, ast.Body, error) { - var pkg *ast.Package - - if r.pkg != "" { - var err error - pkg, err = ast.ParsePackage(fmt.Sprintf("package %v", r.pkg)) - if err != nil { - return nil, nil, err - } - } else { - pkg = r.parsedPackage - } - - qctx := ast.NewQueryContext(). - WithPackage(pkg). - WithImports(imports) - - qc := r.compiler.QueryCompiler(). - WithContext(qctx). - WithUnsafeBuiltins(r.unsafeBuiltins). - WithEnablePrintStatements(r.enablePrintStatements). - WithStrict(false) - - for _, extra := range extras { - qc = qc.WithStageAfter(extra.after, extra.stage) - } - - compiled, err := qc.Compile(query) - - return qc, compiled, err - -} - -func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) { - switch { - case r.targetPrepState != nil: // target plugin flow - var val ast.Value - if r.runtime != nil { - val = r.runtime.Value - } - s, err := r.targetPrepState.Eval(ctx, ectx, val) - if err != nil { - return nil, err - } - return r.valueToQueryResult(s, ectx) - case r.target == targetWasm: - return r.evalWasm(ctx, ectx) - case r.target == targetRego: // continue - } - - q := topdown.NewQuery(ectx.compiledQuery.query). - WithQueryCompiler(ectx.compiledQuery.compiler). - WithCompiler(r.compiler). - WithStore(r.store). - WithTransaction(ectx.txn). - WithBuiltins(r.builtinFuncs). - WithMetrics(ectx.metrics). - WithInstrumentation(ectx.instrumentation). - WithRuntime(r.runtime). - WithIndexing(ectx.indexing). - WithEarlyExit(ectx.earlyExit). - WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache). - WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache). - WithStrictBuiltinErrors(r.strictBuiltinErrors). - WithBuiltinErrorList(r.builtinErrorList). - WithSeed(ectx.seed). - WithPrintHook(ectx.printHook). - WithDistributedTracingOpts(r.distributedTacingOpts). - WithVirtualCache(ectx.virtualCache) - - if !ectx.time.IsZero() { - q = q.WithTime(ectx.time) - } - - if ectx.ndBuiltinCache != nil { - q = q.WithNDBuiltinCache(ectx.ndBuiltinCache) - } - - for i := range ectx.queryTracers { - q = q.WithQueryTracer(ectx.queryTracers[i]) - } - - if ectx.parsedInput != nil { - q = q.WithInput(ast.NewTerm(ectx.parsedInput)) - } - - for i := range ectx.resolvers { - q = q.WithResolver(ectx.resolvers[i].ref, ectx.resolvers[i].r) - } - - // Cancel query if context is cancelled or deadline is reached. - c := topdown.NewCancel() - q = q.WithCancel(c) - exit := make(chan struct{}) - defer close(exit) - go waitForDone(ctx, exit, func() { - c.Cancel() - }) - - var rs ResultSet - err := q.Iter(ctx, func(qr topdown.QueryResult) error { - result, err := r.generateResult(qr, ectx) - if err != nil { - return err - } - rs = append(rs, result) - return nil - }) - - if err != nil { - return nil, err - } - - if len(rs) == 0 { - return nil, nil - } - - return rs, nil -} - -func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, error) { - input := ectx.rawInput - if ectx.parsedInput != nil { - i := interface{}(ectx.parsedInput) - input = &i - } - result, err := r.opa.Eval(ctx, opa.EvalOpts{ - Metrics: r.metrics, - Input: input, - Time: ectx.time, - Seed: ectx.seed, - InterQueryBuiltinCache: ectx.interQueryBuiltinCache, - NDBuiltinCache: ectx.ndBuiltinCache, - PrintHook: ectx.printHook, - Capabilities: ectx.capabilities, - }) - if err != nil { - return nil, err - } - - parsed, err := ast.ParseTerm(string(result.Result)) - if err != nil { - return nil, err - } - - return r.valueToQueryResult(parsed.Value, ectx) -} - -func (r *Rego) valueToQueryResult(res ast.Value, ectx *EvalContext) (ResultSet, error) { - resultSet, ok := res.(ast.Set) - if !ok { - return nil, fmt.Errorf("illegal result type") - } - - if resultSet.Len() == 0 { - return nil, nil - } - - var rs ResultSet - err := resultSet.Iter(func(term *ast.Term) error { - obj, ok := term.Value.(ast.Object) - if !ok { - return fmt.Errorf("illegal result type") - } - qr := topdown.QueryResult{} - obj.Foreach(func(k, v *ast.Term) { - kvt := ast.VarTerm(string(k.Value.(ast.String))) - qr[kvt.Value.(ast.Var)] = v - }) - result, err := r.generateResult(qr, ectx) - if err != nil { - return err - } - rs = append(rs, result) - return nil - }) - - return rs, err -} - -func (r *Rego) generateResult(qr topdown.QueryResult, ectx *EvalContext) (Result, error) { - - rewritten := ectx.compiledQuery.compiler.RewrittenVars() - - result := newResult() - for k, term := range qr { - v, err := r.generateJSON(term, ectx) - if err != nil { - return result, err - } - - if rw, ok := rewritten[k]; ok { - k = rw - } - if isTermVar(k) || isTermWasmVar(k) || k.IsGenerated() || k.IsWildcard() { - continue - } - result.Bindings[string(k)] = v - } - - for _, expr := range ectx.compiledQuery.query { - if expr.Generated { - continue - } - - if k, ok := r.capture[expr]; ok { - v, err := r.generateJSON(qr[k], ectx) - if err != nil { - return result, err - } - result.Expressions = append(result.Expressions, newExpressionValue(expr, v)) - } else { - result.Expressions = append(result.Expressions, newExpressionValue(expr, true)) - } - - } - return result, nil -} - -func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialResult, error) { - - err := r.prepare(ctx, partialResultQueryType, []extraStage{ - { - after: "ResolveRefs", - stage: ast.QueryCompilerStageDefinition{ - Name: "RewriteForPartialEval", - MetricName: "query_compile_stage_rewrite_for_partial_eval", - Stage: r.rewriteQueryForPartialEval, - }, - }, - }) - if err != nil { - return PartialResult{}, err - } - - ectx := &EvalContext{ - parsedInput: r.parsedInput, - metrics: r.metrics, - txn: r.txn, - partialNamespace: r.partialNamespace, - queryTracers: r.queryTracers, - compiledQuery: r.compiledQueries[partialResultQueryType], - instrumentation: r.instrumentation, - indexing: true, - resolvers: r.resolvers, - capabilities: r.capabilities, - strictBuiltinErrors: r.strictBuiltinErrors, - } - - disableInlining := r.disableInlining - - if pCfg.disableInlining != nil { - disableInlining = *pCfg.disableInlining - } - - ectx.disableInlining, err = parseStringsToRefs(disableInlining) - if err != nil { - return PartialResult{}, err - } - - pq, err := r.partial(ctx, ectx) - if err != nil { - return PartialResult{}, err - } - - // Construct module for queries. - id := fmt.Sprintf("__partialresult__%s__", ectx.partialNamespace) - - module, err := ast.ParseModule(id, "package "+ectx.partialNamespace) - if err != nil { - return PartialResult{}, fmt.Errorf("bad partial namespace") - } - - module.Rules = make([]*ast.Rule, len(pq.Queries)) - for i, body := range pq.Queries { - rule := &ast.Rule{ - Head: ast.NewHead(ast.Var("__result__"), nil, ast.Wildcard), - Body: body, - Module: module, - } - module.Rules[i] = rule - if checkPartialResultForRecursiveRefs(body, rule.Path()) { - return PartialResult{}, Errors{errPartialEvaluationNotEffective} - } - } - - // Update compiler with partial evaluation output. - r.compiler.Modules[id] = module - for i, module := range pq.Support { - r.compiler.Modules[fmt.Sprintf("__partialsupport__%s__%d__", ectx.partialNamespace, i)] = module - } - - r.metrics.Timer(metrics.RegoModuleCompile).Start() - r.compilerForTxn(ctx, r.store, r.txn).Compile(r.compiler.Modules) - r.metrics.Timer(metrics.RegoModuleCompile).Stop() - - if r.compiler.Failed() { - return PartialResult{}, r.compiler.Errors - } - - result := PartialResult{ - compiler: r.compiler, - store: r.store, - body: ast.MustParseBody(fmt.Sprintf("data.%v.__result__", ectx.partialNamespace)), - builtinDecls: r.builtinDecls, - builtinFuncs: r.builtinFuncs, - } - - return result, nil -} - -func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, error) { - - var unknowns []*ast.Term - - switch { - case ectx.parsedUnknowns != nil: - unknowns = ectx.parsedUnknowns - case ectx.unknowns != nil: - unknowns = make([]*ast.Term, len(ectx.unknowns)) - for i := range ectx.unknowns { - var err error - unknowns[i], err = ast.ParseTerm(ectx.unknowns[i]) - if err != nil { - return nil, err - } - } - default: - // Use input document as unknown if caller has not specified any. - unknowns = []*ast.Term{ast.NewTerm(ast.InputRootRef)} - } - - q := topdown.NewQuery(ectx.compiledQuery.query). - WithQueryCompiler(ectx.compiledQuery.compiler). - WithCompiler(r.compiler). - WithStore(r.store). - WithTransaction(ectx.txn). - WithBuiltins(r.builtinFuncs). - WithMetrics(ectx.metrics). - WithInstrumentation(ectx.instrumentation). - WithUnknowns(unknowns). - WithDisableInlining(ectx.disableInlining). - WithRuntime(r.runtime). - WithIndexing(ectx.indexing). - WithEarlyExit(ectx.earlyExit). - WithPartialNamespace(ectx.partialNamespace). - WithSkipPartialNamespace(r.skipPartialNamespace). - WithShallowInlining(r.shallowInlining). - WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache). - WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache). - WithStrictBuiltinErrors(ectx.strictBuiltinErrors). - WithSeed(ectx.seed). - WithPrintHook(ectx.printHook) - - if !ectx.time.IsZero() { - q = q.WithTime(ectx.time) - } - - if ectx.ndBuiltinCache != nil { - q = q.WithNDBuiltinCache(ectx.ndBuiltinCache) - } - - for i := range ectx.queryTracers { - q = q.WithQueryTracer(ectx.queryTracers[i]) - } - - if ectx.parsedInput != nil { - q = q.WithInput(ast.NewTerm(ectx.parsedInput)) - } - - for i := range ectx.resolvers { - q = q.WithResolver(ectx.resolvers[i].ref, ectx.resolvers[i].r) - } - - // Cancel query if context is cancelled or deadline is reached. - c := topdown.NewCancel() - q = q.WithCancel(c) - exit := make(chan struct{}) - defer close(exit) - go waitForDone(ctx, exit, func() { - c.Cancel() - }) - - queries, support, err := q.PartialRun(ctx) - if err != nil { - return nil, err - } - - // If the target rego-version is v0, and the rego.v1 import is available, then we attempt to apply it to support modules. - if r.regoVersion == ast.RegoV0 && (r.capabilities == nil || r.capabilities.ContainsFeature(ast.FeatureRegoV1Import)) { - - for i, mod := range support { - // We can't apply the RegoV0CompatV1 version to the support module if it contains rules or vars that - // conflict with future keywords. - applyRegoVersion := true - - ast.WalkRules(mod, func(r *ast.Rule) bool { - name := r.Head.Name - if name == "" && len(r.Head.Reference) > 0 { - name = r.Head.Reference[0].Value.(ast.Var) - } - if ast.IsFutureKeyword(name.String()) { - applyRegoVersion = false - return true - } - return false - }) - - if applyRegoVersion { - ast.WalkVars(mod, func(v ast.Var) bool { - if ast.IsFutureKeyword(v.String()) { - applyRegoVersion = false - return true - } - return false - }) - } - - if applyRegoVersion { - support[i].SetRegoVersion(ast.RegoV0CompatV1) - } else { - support[i].SetRegoVersion(r.regoVersion) - } - } - } else { - // If the target rego-version is not v0, then we apply the target rego-version to the support modules. - for i := range support { - support[i].SetRegoVersion(r.regoVersion) - } - } - - pq := &PartialQueries{ - Queries: queries, - Support: support, - } - - return pq, nil -} - -func (r *Rego) rewriteQueryToCaptureValue(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) { - - checkCapture := iteration(query) || len(query) > 1 - - for _, expr := range query { - - if expr.Negated { - continue - } - - if expr.IsAssignment() || expr.IsEquality() { - continue - } - - var capture *ast.Term - - // If the expression can be evaluated as a function, rewrite it to - // capture the return value. E.g., neq(1,2) becomes neq(1,2,x) but - // plus(1,2,x) does not get rewritten. - switch terms := expr.Terms.(type) { - case *ast.Term: - capture = r.generateTermVar() - expr.Terms = ast.Equality.Expr(terms, capture).Terms - r.capture[expr] = capture.Value.(ast.Var) - case []*ast.Term: - tpe := r.compiler.TypeEnv.Get(terms[0]) - if !types.Void(tpe) && types.Arity(tpe) == len(terms)-1 { - capture = r.generateTermVar() - expr.Terms = append(terms, capture) - r.capture[expr] = capture.Value.(ast.Var) - } - } - - if capture != nil && checkCapture { - cpy := expr.Copy() - cpy.Terms = capture - cpy.Generated = true - cpy.With = nil - query.Append(cpy) - } - } - - return query, nil -} - -func (r *Rego) rewriteQueryForPartialEval(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) { - if len(query) != 1 { - return nil, fmt.Errorf("partial evaluation requires single ref (not multiple expressions)") - } - - term, ok := query[0].Terms.(*ast.Term) - if !ok { - return nil, fmt.Errorf("partial evaluation requires ref (not expression)") - } - - ref, ok := term.Value.(ast.Ref) - if !ok { - return nil, fmt.Errorf("partial evaluation requires ref (not %v)", ast.TypeName(term.Value)) - } - - if !ref.IsGround() { - return nil, fmt.Errorf("partial evaluation requires ground ref") - } - - return ast.NewBody(ast.Equality.Expr(ast.Wildcard, term)), nil -} - -// rewriteEqualsForPartialQueryCompile will rewrite == to = in queries. Normally -// this wouldn't be done, except for handling queries with the `Partial` API -// where rewriting them can substantially simplify the result, and it is unlikely -// that the caller would need expression values. -func (r *Rego) rewriteEqualsForPartialQueryCompile(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) { - doubleEq := ast.Equal.Ref() - unifyOp := ast.Equality.Ref() - ast.WalkExprs(query, func(x *ast.Expr) bool { - if x.IsCall() { - operator := x.Operator() - if operator.Equal(doubleEq) && len(x.Operands()) == 2 { - x.SetOperator(ast.NewTerm(unifyOp)) - } - } - return false - }) - return query, nil -} - -func (r *Rego) generateTermVar() *ast.Term { - r.termVarID++ - prefix := ast.WildcardPrefix - if p := r.targetPlugin(r.target); p != nil { - prefix = wasmVarPrefix - } else if r.target == targetWasm { - prefix = wasmVarPrefix - } - return ast.VarTerm(fmt.Sprintf("%sterm%v", prefix, r.termVarID)) -} - -func (r Rego) hasQuery() bool { - return len(r.query) != 0 || len(r.parsedQuery) != 0 -} - -func (r Rego) hasWasmModule() bool { - for _, b := range r.bundles { - if len(b.WasmModules) > 0 { - return true - } - } - return false -} - -type transactionCloser func(ctx context.Context, err error) error - -// getTxn will conditionally create a read or write transaction suitable for -// the configured Rego object. The returned function should be used to close the txn -// regardless of status. -func (r *Rego) getTxn(ctx context.Context) (storage.Transaction, transactionCloser, error) { - - noopCloser := func(_ context.Context, _ error) error { - return nil // no-op default - } - - if r.txn != nil { - // Externally provided txn - return r.txn, noopCloser, nil - } - - // Create a new transaction.. - params := storage.TransactionParams{} - - // Bundles and data paths may require writing data files or manifests to storage - if len(r.bundles) > 0 || len(r.bundlePaths) > 0 || len(r.loadPaths.paths) > 0 { - - // If we were given a store we will *not* write to it, only do that on one - // which was created automatically on behalf of the user. - if !r.ownStore { - return nil, noopCloser, errors.New("unable to start write transaction when store was provided") - } - - params.Write = true - } - - txn, err := r.store.NewTransaction(ctx, params) - if err != nil { - return nil, noopCloser, err - } - - // Setup a closer function that will abort or commit as needed. - closer := func(ctx context.Context, txnErr error) error { - var err error - - if txnErr == nil && params.Write { - err = r.store.Commit(ctx, txn) - } else { - r.store.Abort(ctx, txn) - } - - // Clear the auto created transaction now that it is closed. - r.txn = nil - - return err - } - - return txn, closer, nil -} - -func (r *Rego) compilerForTxn(ctx context.Context, store storage.Store, txn storage.Transaction) *ast.Compiler { - // Update the compiler to have a valid path conflict check - // for the current context and transaction. - return r.compiler.WithPathConflictsCheck(storage.NonEmpty(ctx, store, txn)) -} - -func checkPartialResultForRecursiveRefs(body ast.Body, path ast.Ref) bool { - var stop bool - ast.WalkRefs(body, func(x ast.Ref) bool { - if !stop { - if path.HasPrefix(x) { - stop = true - } - } - return stop - }) - return stop -} - -func isTermVar(v ast.Var) bool { - return strings.HasPrefix(string(v), ast.WildcardPrefix+"term") -} - -func isTermWasmVar(v ast.Var) bool { - return strings.HasPrefix(string(v), wasmVarPrefix+"term") -} - -func waitForDone(ctx context.Context, exit chan struct{}, f func()) { - select { - case <-exit: - return - case <-ctx.Done(): - f() - return - } -} - -type rawModule struct { - filename string - module string -} - -func (m rawModule) Parse() (*ast.Module, error) { - return ast.ParseModule(m.filename, m.module) -} - -func (m rawModule) ParseWithOpts(opts ast.ParserOptions) (*ast.Module, error) { - return ast.ParseModuleWithOpts(m.filename, m.module, opts) -} - -type extraStage struct { - after string - stage ast.QueryCompilerStageDefinition -} - -type refResolver struct { - ref ast.Ref - r resolver.Resolver -} - -func iteration(x interface{}) bool { - - var stopped bool - - vis := ast.NewGenericVisitor(func(x interface{}) bool { - switch x := x.(type) { - case *ast.Term: - if ast.IsComprehension(x.Value) { - return true - } - case ast.Ref: - if !stopped { - if bi := ast.BuiltinMap[x.String()]; bi != nil { - if bi.Relation { - stopped = true - return stopped - } - } - for i := 1; i < len(x); i++ { - if _, ok := x[i].Value.(ast.Var); ok { - stopped = true - return stopped - } - } - } - return stopped - } - return stopped - }) - - vis.Walk(x) - - return stopped -} - -func parseStringsToRefs(s []string) ([]ast.Ref, error) { - - refs := make([]ast.Ref, len(s)) - for i := range refs { - var err error - refs[i], err = ast.ParseRef(s[i]) - if err != nil { - return nil, err - } - } - - return refs, nil -} - -// helper function to finish a built-in function call. If an error occurred, -// wrap the error and return it. Otherwise, invoke the iterator if the result -// was defined. -func finishFunction(name string, bctx topdown.BuiltinContext, result *ast.Term, err error, iter func(*ast.Term) error) error { - if err != nil { - var e *HaltError - if errors.As(err, &e) { - tdErr := &topdown.Error{ - Code: topdown.BuiltinErr, - Message: fmt.Sprintf("%v: %v", name, e.Error()), - Location: bctx.Location, - } - return topdown.Halt{Err: tdErr.Wrap(e)} - } - tdErr := &topdown.Error{ - Code: topdown.BuiltinErr, - Message: fmt.Sprintf("%v: %v", name, err.Error()), - Location: bctx.Location, - } - return tdErr.Wrap(err) - } - if result == nil { - return nil - } - return iter(result) -} - -// helper function to return an option that sets a custom built-in function. -func newFunction(decl *Function, f topdown.BuiltinFunc) func(*Rego) { - return func(r *Rego) { - r.builtinDecls[decl.Name] = &ast.Builtin{ - Name: decl.Name, - Decl: decl.Decl, - Nondeterministic: decl.Nondeterministic, - } - r.builtinFuncs[decl.Name] = &topdown.Builtin{ - Decl: r.builtinDecls[decl.Name], - Func: f, - } - } -} - -func generateJSON(term *ast.Term, ectx *EvalContext) (interface{}, error) { - return ast.JSONWithOpt(term.Value, - ast.JSONOpt{ - SortSets: ectx.sortSets, - CopyMaps: ectx.copyMaps, - }) -} - -func (r *Rego) planQuery(queries []ast.Body, evalQueryType queryType) (*ir.Policy, error) { - modules := make([]*ast.Module, 0, len(r.compiler.Modules)) - for _, module := range r.compiler.Modules { - modules = append(modules, module) - } - - decls := make(map[string]*ast.Builtin, len(r.builtinDecls)+len(ast.BuiltinMap)) - - for k, v := range ast.BuiltinMap { - decls[k] = v - } - - for k, v := range r.builtinDecls { - decls[k] = v - } - - const queryName = "eval" // NOTE(tsandall): the query name is arbitrary - - p := planner.New(). - WithQueries([]planner.QuerySet{ - { - Name: queryName, - Queries: queries, - RewrittenVars: r.compiledQueries[evalQueryType].compiler.RewrittenVars(), - }, - }). - WithModules(modules). - WithBuiltinDecls(decls). - WithDebug(r.dump) - - policy, err := p.Plan() - if err != nil { - return nil, err - } - if r.dump != nil { - fmt.Fprintln(r.dump, "PLAN:") - fmt.Fprintln(r.dump, "-----") - err = ir.Pretty(r.dump, policy) - if err != nil { - return nil, err - } - fmt.Fprintln(r.dump) - } - return policy, nil + return v1.WithBuiltinFuncs(bis) } diff --git a/vendor/github.com/open-policy-agent/opa/rego/resultset.go b/vendor/github.com/open-policy-agent/opa/rego/resultset.go index e60fa6fbe..5c03360df 100644 --- a/vendor/github.com/open-policy-agent/opa/rego/resultset.go +++ b/vendor/github.com/open-policy-agent/opa/rego/resultset.go @@ -1,90 +1,22 @@ package rego import ( - "fmt" - - "github.com/open-policy-agent/opa/ast" + v1 "github.com/open-policy-agent/opa/v1/rego" ) // ResultSet represents a collection of output from Rego evaluation. An empty // result set represents an undefined query. -type ResultSet []Result +type ResultSet = v1.ResultSet // Vars represents a collection of variable bindings. The keys are the variable // names and the values are the binding values. -type Vars map[string]interface{} - -// WithoutWildcards returns a copy of v with wildcard variables removed. -func (v Vars) WithoutWildcards() Vars { - n := Vars{} - for k, v := range v { - if ast.Var(k).IsWildcard() || ast.Var(k).IsGenerated() { - continue - } - n[k] = v - } - return n -} +type Vars = v1.Vars // Result defines the output of Rego evaluation. -type Result struct { - Expressions []*ExpressionValue `json:"expressions"` - Bindings Vars `json:"bindings,omitempty"` -} - -func newResult() Result { - return Result{ - Bindings: Vars{}, - } -} +type Result = v1.Result // Location defines a position in a Rego query or module. -type Location struct { - Row int `json:"row"` - Col int `json:"col"` -} +type Location = v1.Location // ExpressionValue defines the value of an expression in a Rego query. -type ExpressionValue struct { - Value interface{} `json:"value"` - Text string `json:"text"` - Location *Location `json:"location"` -} - -func newExpressionValue(expr *ast.Expr, value interface{}) *ExpressionValue { - result := &ExpressionValue{ - Value: value, - } - if expr.Location != nil { - result.Text = string(expr.Location.Text) - result.Location = &Location{ - Row: expr.Location.Row, - Col: expr.Location.Col, - } - } - return result -} - -func (ev *ExpressionValue) String() string { - return fmt.Sprint(ev.Value) -} - -// Allowed is a helper method that'll return true if all of these conditions hold: -// - the result set only has one element -// - there is only one expression in the result set's only element -// - that expression has the value `true` -// - there are no bindings. -// -// If bindings are present, this will yield `false`: it would be a pitfall to -// return `true` for a query like `data.authz.allow = x`, which always has result -// set element with value true, but could also have a binding `x: false`. -func (rs ResultSet) Allowed() bool { - if len(rs) == 1 && len(rs[0].Bindings) == 0 { - if exprs := rs[0].Expressions; len(exprs) == 1 { - if b, ok := exprs[0].Value.(bool); ok { - return b - } - } - } - return false -} +type ExpressionValue = v1.ExpressionValue diff --git a/vendor/github.com/open-policy-agent/opa/resolver/wasm/doc.go b/vendor/github.com/open-policy-agent/opa/resolver/wasm/doc.go new file mode 100644 index 000000000..165997e4a --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/resolver/wasm/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package wasm diff --git a/vendor/github.com/open-policy-agent/opa/resolver/wasm/wasm.go b/vendor/github.com/open-policy-agent/opa/resolver/wasm/wasm.go index 9c13879dc..ff8b9b820 100644 --- a/vendor/github.com/open-policy-agent/opa/resolver/wasm/wasm.go +++ b/vendor/github.com/open-policy-agent/opa/resolver/wasm/wasm.go @@ -5,169 +5,16 @@ package wasm import ( - "context" - "fmt" - "strconv" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/internal/rego/opa" - "github.com/open-policy-agent/opa/resolver" + v1 "github.com/open-policy-agent/opa/v1/resolver/wasm" ) // New creates a new Resolver instance which is using the Wasm module // policy for the given entrypoint ref. func New(entrypoints []ast.Ref, policy []byte, data interface{}) (*Resolver, error) { - e, err := opa.LookupEngine("wasm") - if err != nil { - return nil, err - } - o, err := e.New(). - WithPolicyBytes(policy). - WithDataJSON(data). - Init() - if err != nil { - return nil, err - } - - // Construct a quick lookup table of ref -> entrypoint ID - // for handling evaluations. Only the entrypoints provided - // by the caller will be constructed, this may be a subset - // of entrypoints available in the Wasm module, however - // only the configured ones will be used when Eval() is - // called. - entrypointRefToID := ast.NewValueMap() - epIDs, err := o.Entrypoints(context.Background()) - if err != nil { - return nil, err - } - for path, id := range epIDs { - for _, ref := range entrypoints { - refPtr, err := ref.Ptr() - if err != nil { - return nil, err - } - if refPtr == path { - entrypointRefToID.Put(ref, ast.Number(strconv.Itoa(int(id)))) - } - } - } - - return &Resolver{ - entrypoints: entrypoints, - entrypointIDs: entrypointRefToID, - o: o, - }, nil + return v1.New(entrypoints, policy, data) } // Resolver implements the resolver.Resolver interface // using Wasm modules to perform an evaluation. -type Resolver struct { - entrypoints []ast.Ref - entrypointIDs *ast.ValueMap - o opa.EvalEngine -} - -// Entrypoints returns a list of entrypoints this resolver is configured to -// perform evaluations on. -func (r *Resolver) Entrypoints() []ast.Ref { - return r.entrypoints -} - -// Close shuts down the resolver. -func (r *Resolver) Close() { - r.o.Close() -} - -// Eval performs an evaluation using the provided input and the Wasm module -// associated with this Resolver instance. -func (r *Resolver) Eval(ctx context.Context, input resolver.Input) (resolver.Result, error) { - v := r.entrypointIDs.Get(input.Ref) - if v == nil { - return resolver.Result{}, fmt.Errorf("unknown entrypoint %s", input.Ref) - } - - numValue, ok := v.(ast.Number) - if !ok { - return resolver.Result{}, fmt.Errorf("internal error: invalid entrypoint id %s", numValue) - } - - epID, ok := numValue.Int() - if !ok { - return resolver.Result{}, fmt.Errorf("internal error: invalid entrypoint id %s", numValue) - } - - var in *interface{} - if input.Input != nil { - var str interface{} = []byte(input.Input.String()) - in = &str - } - - opts := opa.EvalOpts{ - Input: in, - Entrypoint: int32(epID), - Metrics: input.Metrics, - } - out, err := r.o.Eval(ctx, opts) - if err != nil { - return resolver.Result{}, err - } - - result, err := getResult(out) - if err != nil { - return resolver.Result{}, err - } - - return resolver.Result{Value: result}, nil -} - -// SetData will update the external data for the Wasm instance. -func (r *Resolver) SetData(ctx context.Context, data interface{}) error { - return r.o.SetData(ctx, data) -} - -// SetDataPath will set the provided data on the wasm instance at the specified path. -func (r *Resolver) SetDataPath(ctx context.Context, path []string, data interface{}) error { - return r.o.SetDataPath(ctx, path, data) -} - -// RemoveDataPath will remove any data at the specified path. -func (r *Resolver) RemoveDataPath(ctx context.Context, path []string) error { - return r.o.RemoveDataPath(ctx, path) -} - -func getResult(evalResult *opa.Result) (ast.Value, error) { - - parsed, err := ast.ParseTerm(string(evalResult.Result)) - if err != nil { - return nil, fmt.Errorf("failed to parse wasm result: %s", err) - } - - resultSet, ok := parsed.Value.(ast.Set) - if !ok { - return nil, fmt.Errorf("illegal result type") - } - - if resultSet.Len() == 0 { - return nil, nil - } - - if resultSet.Len() > 1 { - return nil, fmt.Errorf("illegal result type") - } - - var obj ast.Object - err = resultSet.Iter(func(term *ast.Term) error { - obj, ok = term.Value.(ast.Object) - if !ok || obj.Len() != 1 { - return fmt.Errorf("illegal result type") - } - return nil - }) - if err != nil { - return nil, err - } - - result := obj.Get(ast.StringTerm("result")) - - return result.Value, nil -} +type Resolver = v1.Resolver diff --git a/vendor/github.com/open-policy-agent/opa/runtime/doc.go b/vendor/github.com/open-policy-agent/opa/runtime/doc.go index 4e047af6a..7d72c66a7 100644 --- a/vendor/github.com/open-policy-agent/opa/runtime/doc.go +++ b/vendor/github.com/open-policy-agent/opa/runtime/doc.go @@ -3,4 +3,8 @@ // license that can be found in the LICENSE file. // Package runtime contains the entry point to the policy engine. +// +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. package runtime diff --git a/vendor/github.com/open-policy-agent/opa/runtime/logging.go b/vendor/github.com/open-policy-agent/opa/runtime/logging.go index 2ee65a046..5b84ea6a5 100644 --- a/vendor/github.com/open-policy-agent/opa/runtime/logging.go +++ b/vendor/github.com/open-policy-agent/opa/runtime/logging.go @@ -5,262 +5,17 @@ package runtime import ( - "bytes" - "compress/gzip" - "io" "net/http" - "strings" - "sync/atomic" - "time" "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/topdown/print" + v1 "github.com/open-policy-agent/opa/v1/runtime" ) -type loggingPrintHook struct { - logger logging.Logger -} - -func (h loggingPrintHook) Print(pctx print.Context, msg string) error { - // NOTE(tsandall): if the request context is not present then do not panic, - // just log the print message without the additional context. - var fields map[string]any - rctx, ok := logging.FromContext(pctx.Context) - if ok { - fields = rctx.Fields() - } else { - fields = make(map[string]any, 1) - } - fields["line"] = pctx.Location.String() - h.logger.WithFields(fields).Info(msg) - return nil -} - // LoggingHandler returns an http.Handler that will print log messages // containing the request information as well as response status and latency. -type LoggingHandler struct { - logger logging.Logger - inner http.Handler - requestID uint64 -} +type LoggingHandler = v1.LoggingHandler // NewLoggingHandler returns a new http.Handler. func NewLoggingHandler(logger logging.Logger, inner http.Handler) http.Handler { - return &LoggingHandler{ - logger: logger, - inner: inner, - requestID: uint64(0), - } -} - -func (h *LoggingHandler) loggingEnabled(level logging.Level) bool { - return level <= h.logger.GetLevel() -} - -func (h *LoggingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - - cloneHeaders := r.Header.Clone() - - // set the HTTP headers via a dedicated key on the parent context irrespective of logging level - r = r.WithContext(logging.WithHTTPRequestContext(r.Context(), &logging.HTTPRequestContext{Header: cloneHeaders})) - - var rctx logging.RequestContext - rctx.ReqID = atomic.AddUint64(&h.requestID, uint64(1)) - - recorder := newRecorder(h.logger, w, r, rctx.ReqID, h.loggingEnabled(logging.Debug)) - t0 := time.Now() - - if h.loggingEnabled(logging.Info) { - - rctx.ClientAddr = r.RemoteAddr - rctx.ReqMethod = r.Method - rctx.ReqPath = r.URL.EscapedPath() - r = r.WithContext(logging.NewContext(r.Context(), &rctx)) - - var err error - fields := rctx.Fields() - - if h.loggingEnabled(logging.Debug) { - var bs []byte - if r.Body != nil { - bs, r.Body, err = readBody(r.Body) - } - if err == nil { - if gzipReceived(r.Header) { - // the request is compressed - var gzReader *gzip.Reader - var plainOutput []byte - reader := bytes.NewReader(bs) - gzReader, err = gzip.NewReader(reader) - if err == nil { - plainOutput, err = io.ReadAll(gzReader) - if err == nil { - defer gzReader.Close() - fields["req_body"] = string(plainOutput) - } - } - } else { - fields["req_body"] = string(bs) - } - } - - // err can be thrown on different statements - if err != nil { - fields["err"] = err - } - - fields["req_params"] = r.URL.Query() - } - - if err == nil { - h.logger.WithFields(fields).Info("Received request.") - } else { - h.logger.WithFields(fields).Error("Failed to read body.") - } - } - - h.inner.ServeHTTP(recorder, r) - - dt := time.Since(t0) - statusCode := 200 - if recorder.statusCode != 0 { - statusCode = recorder.statusCode - } - - if h.loggingEnabled(logging.Info) { - fields := map[string]interface{}{ - "client_addr": rctx.ClientAddr, - "req_id": rctx.ReqID, - "req_method": rctx.ReqMethod, - "req_path": rctx.ReqPath, - "resp_status": statusCode, - "resp_bytes": recorder.bytesWritten, - "resp_duration": float64(dt.Nanoseconds()) / 1e6, - } - - if h.loggingEnabled(logging.Debug) { - switch { - case isPprofEndpoint(r): - // pprof always sends binary data (protobuf) - fields["resp_body"] = "[binary payload]" - - case gzipAccepted(r.Header) && isMetricsEndpoint(r): - // metrics endpoint does so when the client accepts it (e.g. prometheus) - fields["resp_body"] = "[compressed payload]" - - case gzipAccepted(r.Header) && gzipReceived(w.Header()) && (isDataEndpoint(r) || isCompileEndpoint(r)): - // data and compile endpoints might compress the response - gzReader, gzErr := gzip.NewReader(recorder.buf) - if gzErr == nil { - plainOutput, readErr := io.ReadAll(gzReader) - if readErr == nil { - defer gzReader.Close() - fields["resp_body"] = string(plainOutput) - } else { - h.logger.Error("Failed to decompressed the payload: %v", readErr.Error()) - } - } else { - h.logger.Error("Failed to read the compressed payload: %v", gzErr.Error()) - } - - default: - fields["resp_body"] = recorder.buf.String() - } - } - - h.logger.WithFields(fields).Info("Sent response.") - } -} - -func gzipAccepted(header http.Header) bool { - a := header.Get("Accept-Encoding") - parts := strings.Split(a, ",") - for _, part := range parts { - part = strings.TrimSpace(part) - if part == "gzip" || strings.HasPrefix(part, "gzip;") { - return true - } - } - return false -} - -func gzipReceived(header http.Header) bool { - a := header.Get("Content-Encoding") - parts := strings.Split(a, ",") - for _, part := range parts { - part = strings.TrimSpace(part) - if part == "gzip" || strings.HasPrefix(part, "gzip;") { - return true - } - } - return false -} - -func isPprofEndpoint(req *http.Request) bool { - return strings.HasPrefix(req.URL.Path, "/debug/pprof/") -} - -func isMetricsEndpoint(req *http.Request) bool { - return strings.HasPrefix(req.URL.Path, "/metrics") -} - -func isDataEndpoint(req *http.Request) bool { - return strings.HasPrefix(req.URL.Path, "/v1/data") || strings.HasPrefix(req.URL.Path, "/v0/data") -} - -func isCompileEndpoint(req *http.Request) bool { - return strings.HasPrefix(req.URL.Path, "/v1/compile") -} - -type recorder struct { - logger logging.Logger - inner http.ResponseWriter - req *http.Request - id uint64 - - buf *bytes.Buffer - bytesWritten int - statusCode int -} - -func newRecorder(logger logging.Logger, w http.ResponseWriter, r *http.Request, id uint64, buffer bool) *recorder { - var buf *bytes.Buffer - if buffer { - buf = new(bytes.Buffer) - } - return &recorder{ - logger: logger, - buf: buf, - inner: w, - req: r, - id: id, - } -} - -func (r *recorder) Header() http.Header { - return r.inner.Header() -} - -func (r *recorder) Write(bs []byte) (int, error) { - r.bytesWritten += len(bs) - if r.buf != nil { - r.buf.Write(bs) - } - return r.inner.Write(bs) -} - -func (r *recorder) WriteHeader(s int) { - r.statusCode = s - r.inner.WriteHeader(s) -} - -func readBody(r io.ReadCloser) ([]byte, io.ReadCloser, error) { - if r == http.NoBody { - return nil, r, nil - } - var buf bytes.Buffer - if _, err := buf.ReadFrom(r); err != nil { - return nil, r, err - } - return buf.Bytes(), io.NopCloser(bytes.NewReader(buf.Bytes())), nil + return v1.NewLoggingHandler(logger, inner) } diff --git a/vendor/github.com/open-policy-agent/opa/runtime/runtime.go b/vendor/github.com/open-policy-agent/opa/runtime/runtime.go index d39c9efb5..42b4bda2c 100644 --- a/vendor/github.com/open-policy-agent/opa/runtime/runtime.go +++ b/vendor/github.com/open-policy-agent/opa/runtime/runtime.go @@ -5,68 +5,10 @@ package runtime import ( - "bytes" "context" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "fmt" - "io" - mr "math/rand" - "net/url" - "os" - "os/signal" - "strings" - "sync" - "syscall" - "time" - "github.com/fsnotify/fsnotify" - "github.com/gorilla/mux" - "github.com/open-policy-agent/opa/storage/inmem" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" - "go.opentelemetry.io/otel/exporters/otlp/otlptrace" - "go.opentelemetry.io/otel/propagation" - "go.uber.org/automaxprocs/maxprocs" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" - opa_config "github.com/open-policy-agent/opa/config" - "github.com/open-policy-agent/opa/internal/compiler" - "github.com/open-policy-agent/opa/internal/config" - internal_tracing "github.com/open-policy-agent/opa/internal/distributedtracing" - internal_logging "github.com/open-policy-agent/opa/internal/logging" - "github.com/open-policy-agent/opa/internal/pathwatcher" - "github.com/open-policy-agent/opa/internal/prometheus" - "github.com/open-policy-agent/opa/internal/ref" - "github.com/open-policy-agent/opa/internal/report" - "github.com/open-policy-agent/opa/internal/runtime" - initload "github.com/open-policy-agent/opa/internal/runtime/init" - "github.com/open-policy-agent/opa/internal/uuid" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/plugins/discovery" - "github.com/open-policy-agent/opa/plugins/logs" - metrics_config "github.com/open-policy-agent/opa/plugins/server/metrics" - "github.com/open-policy-agent/opa/repl" - "github.com/open-policy-agent/opa/server" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/disk" - "github.com/open-policy-agent/opa/tracing" - "github.com/open-policy-agent/opa/util" - "github.com/open-policy-agent/opa/version" -) - -var ( - registeredPlugins map[string]plugins.Factory - registeredPluginsMux sync.Mutex -) - -const ( - // default interval between OPA version report uploads - defaultUploadIntervalSec = int64(3600) + v1 "github.com/open-policy-agent/opa/v1/runtime" ) // RegisterPlugin registers a plugin factory with the runtime @@ -74,919 +16,25 @@ const ( // plugin configuration and instantiate plugins. If no configuration is // provided, plugins are not instantiated. This function is idempotent. func RegisterPlugin(name string, factory plugins.Factory) { - registeredPluginsMux.Lock() - defer registeredPluginsMux.Unlock() - registeredPlugins[name] = factory + v1.RegisterPlugin(name, factory) } // Params stores the configuration for an OPA instance. -type Params struct { - // Globally unique identifier for this OPA instance. If an ID is not specified, - // the runtime will generate one. - ID string - - // Addrs are the listening addresses that the OPA server will bind to. - Addrs *[]string - - // DiagnosticAddrs are the listening addresses that the OPA server will bind to - // for read-only diagnostic API's (/health, /metrics, etc) - DiagnosticAddrs *[]string - - // H2CEnabled flag controls whether OPA will allow H2C (HTTP/2 cleartext) on - // HTTP listeners. - H2CEnabled bool - - // Authentication is the type of authentication scheme to use. - Authentication server.AuthenticationScheme - - // Authorization is the type of authorization scheme to use. - Authorization server.AuthorizationScheme - - // Certificate is the certificate to use in server-mode. If the certificate - // is nil, the server will NOT use TLS. - Certificate *tls.Certificate - - // CertificateFile and CertificateKeyFile are the paths to the cert and its - // keyfile. It'll be used to periodically reload the files from disk if they - // have changed. The server will attempt to refresh every 5 minutes, unless - // a different CertificateRefresh time.Duration is provided - CertificateFile string - CertificateKeyFile string - CertificateRefresh time.Duration - - // CertPool holds the CA certs trusted by the OPA server. - CertPool *x509.CertPool - // CertPoolFile, if set permits the reloading of the CA cert pool from disk - CertPoolFile string - - // MinVersion contains the minimum TLS version that is acceptable. - // If zero, TLS 1.2 is currently taken as the minimum. - MinTLSVersion uint16 - - // HistoryPath is the filename to store the interactive shell user - // input history. - HistoryPath string - - // Output format controls how the REPL will print query results. - // Default: "pretty". - OutputFormat string - - // Paths contains filenames of base documents and policy modules to load on - // startup. Data files may be prefixed with ":" to indicate - // where the contained document should be loaded. - Paths []string - - // Optional filter that will be passed to the file loader. - Filter loader.Filter - - // BundleMode will enable treating the Paths provided as bundles rather than - // loading all data & policy files. - BundleMode bool - - // Watch flag controls whether OPA will watch the Paths files for changes. - // If this flag is true, OPA will watch the Paths files for changes and - // reload the storage layer each time they change. This is useful for - // interactive development. - Watch bool - - // ErrorLimit is the number of errors the compiler will allow to occur before - // exiting early. - ErrorLimit int - - // PprofEnabled flag controls whether pprof endpoints are enabled - PprofEnabled bool - - // DecisionIDFactory generates decision IDs to include in API responses - // sent by the server (in response to Data API queries.) - DecisionIDFactory func() string - - // Logging configures the logging behaviour. - Logging LoggingConfig - - // Logger sets the logger implementation to use for debug logs. - Logger logging.Logger - - // ConsoleLogger sets the logger implementation to use for console logs. - ConsoleLogger logging.Logger - - // ConfigFile refers to the OPA configuration to load on startup. - ConfigFile string - - // ConfigOverrides are overrides for the OPA configuration that are applied - // over top the config file They are in a list of key=value syntax that - // conform to the syntax defined in the `strval` package - ConfigOverrides []string - - // ConfigOverrideFiles Similar to `ConfigOverrides` except they are in the - // form of `key=path/to/file`where the file contains the value to be used. - ConfigOverrideFiles []string - - // Output is the output stream used when run as an interactive shell. This - // is mostly for test purposes. - Output io.Writer - - // GracefulShutdownPeriod is the time (in seconds) to wait for the http - // server to shutdown gracefully. - GracefulShutdownPeriod int - - // ShutdownWaitPeriod is the time (in seconds) to wait before initiating shutdown. - ShutdownWaitPeriod int - - // EnableVersionCheck flag controls whether OPA will report its version to an external service. - // If this flag is true, OPA will report its version to the external service - EnableVersionCheck bool - - // BundleVerificationConfig sets the key configuration used to verify a signed bundle - BundleVerificationConfig *bundle.VerificationConfig - - // SkipBundleVerification flag controls whether OPA will verify a signed bundle - SkipBundleVerification bool - - // SkipKnownSchemaCheck flag controls whether OPA will perform type checking on known input schemas - SkipKnownSchemaCheck bool - - // ReadyTimeout flag controls if and for how long OPA server will wait (in seconds) for - // configured bundles and plugins to be activated/ready before listening for traffic. - // A value of 0 or less means no wait is exercised. - ReadyTimeout int - - // Router is the router to which handlers for the REST API are added. - // Router uses a first-matching-route-wins strategy, so no existing routes are overridden - // If it is nil, a new mux.Router will be created - Router *mux.Router - - // DiskStorage, if set, will make the runtime instantiate a disk-backed storage - // implementation (instead of the default, in-memory store). - // It can also be enabled via config, and this runtime field takes precedence. - DiskStorage *disk.Options - - DistributedTracingOpts tracing.Options - - // Check if default Addr is set or the user has changed it. - AddrSetByUser bool - - // UnixSocketPerm specifies the permission for the Unix domain socket if used to listen for connections - UnixSocketPerm *string - - // V0Compatible will enable OPA features and behaviors that were enabled by default in OPA v0.x releases. - // Takes precedence over V1Compatible. - V0Compatible bool - - // V1Compatible will enable OPA features and behaviors that will be enabled by default in a future OPA v1.0 release. - // This flag allows users to opt-in to the new behavior and helps transition to the future release upon which - // the new behavior will be enabled by default. - // If V0Compatible is set, V1Compatible will be ignored. - V1Compatible bool - - // CipherSuites specifies the list of enabled TLS 1.0–1.2 cipher suites - CipherSuites *[]uint16 - - // ReadAstValuesFromStore controls whether the storage layer should return AST values when reading from the store. - // This is an eager conversion, that comes with an upfront performance cost when updating the store (e.g. bundle updates). - // Evaluation performance is affected in that data doesn't need to be converted to AST during evaluation. - // Only applicable when using the default in-memory store, and not when used together with the DiskStorage option. - ReadAstValuesFromStore bool -} - -func (p *Params) regoVersion() ast.RegoVersion { - // v0 takes precedence over v1 - if p.V0Compatible { - return ast.RegoV0 - } - if p.V1Compatible { - return ast.RegoV1 - } - return ast.DefaultRegoVersion -} +type Params = v1.Params // LoggingConfig stores the configuration for OPA's logging behaviour. -type LoggingConfig struct { - Level string - Format string - TimestampFormat string -} +type LoggingConfig = v1.LoggingConfig // NewParams returns a new Params object. func NewParams() Params { - return Params{ - Output: os.Stdout, - BundleMode: false, - EnableVersionCheck: false, - } + return v1.NewParams() } // Runtime represents a single OPA instance. -type Runtime struct { - Params Params - Store storage.Store - Manager *plugins.Manager - - logger logging.Logger - server *server.Server - metrics *prometheus.Provider - reporter *report.Reporter - traceExporter *otlptrace.Exporter - loadedPathsResult *initload.LoadPathsResult - - serverInitialized bool - serverInitMtx sync.RWMutex - done chan struct{} - repl *repl.REPL -} +type Runtime = v1.Runtime // NewRuntime returns a new Runtime object initialized with params. Clients must // call StartServer() or StartREPL() to start the runtime in either mode. func NewRuntime(ctx context.Context, params Params) (*Runtime, error) { - if params.ID == "" { - var err error - params.ID, err = generateInstanceID() - if err != nil { - return nil, err - } - } - - level, err := internal_logging.GetLevel(params.Logging.Level) - if err != nil { - return nil, err - } - - // NOTE(tsandall): This is a temporary hack to ensure that log formatting - // and leveling is applied correctly. Currently there are a few places where - // the global logger is used as a fallback, however, that fallback _should_ - // never be used. This ensures that _if_ the fallback is used accidentally, - // that the logging configuration is applied. Once we remove all usage of - // the global logger and we remove the API that allows callers to access the - // global logger, we can remove this. - logging.Get().SetFormatter(internal_logging.GetFormatter(params.Logging.Format, params.Logging.TimestampFormat)) - logging.Get().SetLevel(level) - - var logger logging.Logger - - if params.Logger != nil { - logger = params.Logger - } else { - stdLogger := logging.New() - stdLogger.SetLevel(level) - stdLogger.SetFormatter(internal_logging.GetFormatter(params.Logging.Format, params.Logging.TimestampFormat)) - logger = stdLogger - } - - var filePaths []string - urlPathCount := 0 - for _, path := range params.Paths { - if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { - urlPathCount++ - override, err := urlPathToConfigOverride(urlPathCount, path) - if err != nil { - return nil, err - } - params.ConfigOverrides = append(params.ConfigOverrides, override...) - } else { - filePaths = append(filePaths, path) - } - } - params.Paths = filePaths - - config, err := config.Load(params.ConfigFile, params.ConfigOverrides, params.ConfigOverrideFiles) - if err != nil { - return nil, fmt.Errorf("config error: %w", err) - } - - var reporter *report.Reporter - if params.EnableVersionCheck { - var err error - reporter, err = report.New(params.ID, report.Options{Logger: logger}) - if err != nil { - return nil, fmt.Errorf("config error: %w", err) - } - } - - regoVersion := params.regoVersion() - - loaded, err := initload.LoadPathsForRegoVersion(regoVersion, params.Paths, params.Filter, params.BundleMode, params.BundleVerificationConfig, params.SkipBundleVerification, false, false, nil, nil) - if err != nil { - return nil, fmt.Errorf("load error: %w", err) - } - - isAuthorizationEnabled := params.Authorization != server.AuthorizationOff - - info, err := runtime.Term(runtime.Params{Config: config, IsAuthorizationEnabled: isAuthorizationEnabled, SkipKnownSchemaCheck: params.SkipKnownSchemaCheck}) - if err != nil { - return nil, err - } - - consoleLogger := params.ConsoleLogger - if consoleLogger == nil { - l := logging.New() - l.SetFormatter(internal_logging.GetFormatter(params.Logging.Format, params.Logging.TimestampFormat)) - consoleLogger = l - } - - if params.Router == nil { - params.Router = mux.NewRouter() - } - - metricsConfig, parseConfigErr := extractMetricsConfig(config, params) - if parseConfigErr != nil { - return nil, parseConfigErr - } - metrics := prometheus.New(metrics.New(), errorLogger(logger), metricsConfig.Prom.HTTPRequestDurationSeconds.Buckets) - - var store storage.Store - if params.DiskStorage == nil { - params.DiskStorage, err = disk.OptionsFromConfig(config, params.ID) - if err != nil { - return nil, fmt.Errorf("parse disk store configuration: %w", err) - } - } - - if params.DiskStorage != nil { - store, err = disk.New(ctx, logger, metrics, *params.DiskStorage) - if err != nil { - return nil, fmt.Errorf("initialize disk store: %w", err) - } - } else { - store = inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false), - inmem.OptReturnASTValuesOnRead(params.ReadAstValuesFromStore)) - } - - traceExporter, tracerProvider, _, err := internal_tracing.Init(ctx, config, params.ID) - if err != nil { - return nil, fmt.Errorf("config error: %w", err) - } - if tracerProvider != nil { - params.DistributedTracingOpts = tracing.NewOptions( - otelhttp.WithTracerProvider(tracerProvider), - otelhttp.WithPropagators(propagation.TraceContext{}), - ) - } - - manager, err := plugins.New(config, - params.ID, - store, - plugins.Info(info), - plugins.InitBundles(loaded.Bundles), - plugins.InitFiles(loaded.Files), - plugins.MaxErrors(params.ErrorLimit), - plugins.GracefulShutdownPeriod(params.GracefulShutdownPeriod), - plugins.ConsoleLogger(consoleLogger), - plugins.Logger(logger), - plugins.EnablePrintStatements(logger.GetLevel() >= logging.Info), - plugins.PrintHook(loggingPrintHook{logger: logger}), - plugins.WithRouter(params.Router), - plugins.WithPrometheusRegister(metrics), - plugins.WithTracerProvider(tracerProvider), - plugins.WithEnableTelemetry(params.EnableVersionCheck), - plugins.WithParserOptions(ast.ParserOptions{RegoVersion: regoVersion})) - if err != nil { - return nil, fmt.Errorf("config error: %w", err) - } - - if err := manager.Init(ctx); err != nil { - return nil, fmt.Errorf("initialization error: %w", err) - } - - if isAuthorizationEnabled && !params.SkipKnownSchemaCheck { - if err := verifyAuthorizationPolicySchema(manager); err != nil { - return nil, fmt.Errorf("initialization error: %w", err) - } - } - - var bootConfig map[string]interface{} - err = util.Unmarshal(config, &bootConfig) - if err != nil { - return nil, fmt.Errorf("config error: %w", err) - } - - disco, err := discovery.New(manager, discovery.Factories(registeredPlugins), discovery.Metrics(metrics), discovery.BootConfig(bootConfig)) - if err != nil { - return nil, fmt.Errorf("config error: %w", err) - } - - manager.Register(discovery.Name, disco) - - rt := &Runtime{ - Store: manager.Store, - Params: params, - Manager: manager, - logger: logger, - metrics: metrics, - reporter: reporter, - serverInitialized: false, - traceExporter: traceExporter, - loadedPathsResult: loaded, - } - - return rt, nil -} - -// extractMetricsConfig returns the configuration for server metrics and parsing errors if any -func extractMetricsConfig(config []byte, params Params) (*metrics_config.Config, error) { - var opaParsedConfig, opaParsedConfigErr = opa_config.ParseConfig(config, params.ID) - if opaParsedConfigErr != nil { - return nil, opaParsedConfigErr - } - - var serverMetricsData []byte - if opaParsedConfig.Server != nil { - serverMetricsData = opaParsedConfig.Server.Metrics - } - - var configBuilder = metrics_config.NewConfigBuilder() - var metricsParsedConfig, metricsParsedConfigErr = configBuilder.WithBytes(serverMetricsData).Parse() - if metricsParsedConfigErr != nil { - return nil, fmt.Errorf("server metrics configuration parse error: %w", metricsParsedConfigErr) - } - - return metricsParsedConfig, nil -} - -// StartServer starts the runtime in server mode. This function will block the -// calling goroutine and will exit the program on error. -func (rt *Runtime) StartServer(ctx context.Context) { - err := rt.Serve(ctx) - if err != nil { - os.Exit(1) - } -} - -// Serve will start a new REST API server and listen for requests. This -// will block until either: an error occurs, the context is canceled, or -// a SIGTERM or SIGKILL signal is sent. -func (rt *Runtime) Serve(ctx context.Context) error { - if rt.Params.Addrs == nil { - return fmt.Errorf("at least one address must be configured in runtime parameters") - } - - serverInitializingMessage := "Initializing server." - if !rt.Params.AddrSetByUser && (rt.Params.V0Compatible || !rt.Params.V1Compatible) { - serverInitializingMessage += " OPA is running on a public (0.0.0.0) network interface. Unless you intend to expose OPA outside of the host, binding to the localhost interface (--addr localhost:8181) is recommended. See https://www.openpolicyagent.org/docs/latest/security/#interface-binding" - } - - if rt.Params.DiagnosticAddrs == nil { - rt.Params.DiagnosticAddrs = &[]string{} - } - - rt.logger.WithFields(map[string]interface{}{ - "addrs": *rt.Params.Addrs, - "diagnostic-addrs": *rt.Params.DiagnosticAddrs, - }).Info(serverInitializingMessage) - - if rt.Params.Authorization == server.AuthorizationOff && rt.Params.Authentication == server.AuthenticationToken { - rt.logger.Error("Token authentication enabled without authorization. Authentication will be ineffective. See https://www.openpolicyagent.org/docs/latest/security/#authentication-and-authorization for more information.") - } - - checkUserPrivileges(rt.logger) - - // NOTE(tsandall): at some point, hopefully we can remove this because the - // Go runtime will just do the right thing. Until then, try to set - // GOMAXPROCS based on the CPU quota applied to the process. - undo, err := maxprocs.Set(maxprocs.Logger(func(f string, a ...interface{}) { - rt.logger.Debug(f, a...) - })) - if err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Debug("Failed to set GOMAXPROCS from CPU quota.") - } - - defer undo() - - if err := rt.Manager.Start(ctx); err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to start plugins.") - return err - } - - defer rt.Manager.Stop(ctx) - - if rt.traceExporter != nil { - if err := rt.traceExporter.Start(ctx); err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to start OpenTelemetry trace exporter.") - return err - } - } - - rt.server = server.New(). - WithRouter(rt.Params.Router). - WithStore(rt.Store). - WithManager(rt.Manager). - WithCompilerErrorLimit(rt.Params.ErrorLimit). - WithPprofEnabled(rt.Params.PprofEnabled). - WithAddresses(*rt.Params.Addrs). - WithH2CEnabled(rt.Params.H2CEnabled). - // always use the initial values for the certificate and ca pool, reloading behavior is configured below - WithCertificate(rt.Params.Certificate). - WithCertPool(rt.Params.CertPool). - WithAuthentication(rt.Params.Authentication). - WithAuthorization(rt.Params.Authorization). - WithDecisionIDFactory(rt.decisionIDFactory). - WithDecisionLoggerWithErr(rt.decisionLogger). - WithRuntime(rt.Manager.Info). - WithMetrics(rt.metrics). - WithMinTLSVersion(rt.Params.MinTLSVersion). - WithCipherSuites(rt.Params.CipherSuites). - WithDistributedTracingOpts(rt.Params.DistributedTracingOpts) - - // If decision_logging plugin enabled, check to see if we opted in to the ND builtins cache. - if lp := logs.Lookup(rt.Manager); lp != nil { - rt.server = rt.server.WithNDBCacheEnabled(rt.Manager.Config.NDBuiltinCacheEnabled()) - } - - if rt.Params.DiagnosticAddrs != nil { - rt.server = rt.server.WithDiagnosticAddresses(*rt.Params.DiagnosticAddrs) - } - - if rt.Params.UnixSocketPerm != nil { - rt.server = rt.server.WithUnixSocketPermission(rt.Params.UnixSocketPerm) - } - - // If a refresh period is set, then we will periodically reload the certificate and ca pool. Otherwise, we will only - // reload cert, key and ca pool files when they change on disk. - if rt.Params.CertificateRefresh > 0 { - rt.server = rt.server.WithCertRefresh(rt.Params.CertificateRefresh) - } - - // if either the cert or the ca pool file is set then these fields will be set on the server and reloaded when they - // change on disk. - if rt.Params.CertificateFile != "" || rt.Params.CertPoolFile != "" { - rt.server = rt.server.WithTLSConfig(&server.TLSConfig{ - CertFile: rt.Params.CertificateFile, - KeyFile: rt.Params.CertificateKeyFile, - CertPoolFile: rt.Params.CertPoolFile, - }) - } - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - rt.server, err = rt.server.Init(ctx) - if err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to initialize server.") - return err - } - - if rt.Params.Watch { - if err := rt.startWatcher(ctx, rt.Params.Paths, rt.onReloadLogger); err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to open watch.") - return err - } - } - - if rt.Params.EnableVersionCheck { - d := time.Duration(int64(time.Second) * defaultUploadIntervalSec) - rt.done = make(chan struct{}) - go rt.checkOPAUpdateLoop(ctx, d, rt.done) - } - - defer func() { - if rt.done != nil { - rt.done <- struct{}{} - } - }() - - rt.server.Handler = NewLoggingHandler(rt.logger, rt.server.Handler) - rt.server.DiagnosticHandler = NewLoggingHandler(rt.logger, rt.server.DiagnosticHandler) - - if err := rt.waitPluginsReady( - 100*time.Millisecond, - time.Second*time.Duration(rt.Params.ReadyTimeout)); err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to wait for plugins activation.") - return err - } - - loops, err := rt.server.Listeners() - if err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to create listeners.") - return err - } - - errc := make(chan error) - for _, loop := range loops { - go func(serverLoop func() error) { - errc <- serverLoop() - }(loop) - } - - // Buffer one element as os/signal uses non-blocking channel sends. - // This prevents potentially dropping the first element and failing to shut - // down gracefully. A buffer of 1 is sufficient as we're just looking for a - // one-time shutdown signal. - signalc := make(chan os.Signal, 1) - signal.Notify(signalc, syscall.SIGINT, syscall.SIGTERM) - - // Note that there is a small chance the socket of the server listener is still - // closed by the time this block is executed, due to the serverLoop above - // executing in a goroutine. - rt.serverInitMtx.Lock() - rt.serverInitialized = true - rt.serverInitMtx.Unlock() - rt.Manager.ServerInitialized() - - rt.logger.Debug("Server initialized.") - - for { - select { - case <-ctx.Done(): - return rt.gracefulServerShutdown(rt.server) - case <-signalc: - return rt.gracefulServerShutdown(rt.server) - case err := <-errc: - rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Listener failed.") - os.Exit(1) - } - } -} - -// Addrs returns a list of addresses that the runtime is listening on (when -// in server mode). Returns an empty list if it hasn't started listening. -func (rt *Runtime) Addrs() []string { - rt.serverInitMtx.RLock() - defer rt.serverInitMtx.RUnlock() - - if !rt.serverInitialized { - return nil - } - - return rt.server.Addrs() -} - -// DiagnosticAddrs returns a list of diagnostic addresses that the runtime is -// listening on (when in server mode). Returns an empty list if it hasn't -// started listening. -func (rt *Runtime) DiagnosticAddrs() []string { - if rt.server == nil { - return nil - } - - return rt.server.DiagnosticAddrs() -} - -// StartREPL starts the runtime in REPL mode. This function will block the calling goroutine. -func (rt *Runtime) StartREPL(ctx context.Context) { - if err := rt.Manager.Start(ctx); err != nil { - fmt.Fprintln(rt.Params.Output, "error starting plugins:", err) - os.Exit(1) - } - - defer rt.Manager.Stop(ctx) - - banner := rt.getBanner() - repl := repl.New(rt.Store, rt.Params.HistoryPath, rt.Params.Output, rt.Params.OutputFormat, rt.Params.ErrorLimit, banner). - WithRuntime(rt.Manager.Info). - WithRegoVersion(rt.Params.regoVersion()). - WithInitBundles(rt.loadedPathsResult.Bundles) - - if rt.Params.Watch { - if err := rt.startWatcher(ctx, rt.Params.Paths, onReloadPrinter(rt.Params.Output)); err != nil { - fmt.Fprintln(rt.Params.Output, "error opening watch:", err) - os.Exit(1) - } - } - - if rt.Params.EnableVersionCheck { - go func() { - repl.SetOPAVersionReport(rt.checkOPAUpdate(ctx).Slice()) - }() - } - - rt.repl = repl - repl.Loop(ctx) -} - -// SetDistributedTracingLogging configures the distributed tracing's ErrorHandler, -// and logger instances. -func (rt *Runtime) SetDistributedTracingLogging() { - internal_tracing.SetupLogging(rt.logger) -} - -func (rt *Runtime) checkOPAUpdate(ctx context.Context) *report.DataResponse { - resp, _ := rt.reporter.SendReport(ctx) - return resp -} - -func (rt *Runtime) checkOPAUpdateLoop(ctx context.Context, uploadDuration time.Duration, done chan struct{}) { - ticker := time.NewTicker(uploadDuration) - mr.New(mr.NewSource(time.Now().UnixNano())) // Seed the PRNG. - - for { - resp, err := rt.reporter.SendReport(ctx) - if err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Debug("Unable to send OPA version report.") - } else { - if resp.Latest.OPAUpToDate { - rt.logger.WithFields(map[string]interface{}{ - "current_version": version.Version, - }).Debug("OPA is up to date.") - } else { - rt.logger.WithFields(map[string]interface{}{ - "download_opa": resp.Latest.Download, - "release_notes": resp.Latest.ReleaseNotes, - "current_version": version.Version, - "latest_version": strings.TrimPrefix(resp.Latest.LatestRelease, "v"), - }).Info("OPA is out of date.") - } - } - select { - case <-ticker.C: - ticker.Stop() - newInterval := mr.Int63n(defaultUploadIntervalSec) + defaultUploadIntervalSec - ticker = time.NewTicker(time.Duration(int64(time.Second) * newInterval)) - case <-done: - ticker.Stop() - return - } - } -} - -func (rt *Runtime) decisionIDFactory() string { - if rt.Params.DecisionIDFactory != nil { - return rt.Params.DecisionIDFactory() - } - if logs.Lookup(rt.Manager) != nil { - return generateDecisionID() - } - return "" -} - -func (rt *Runtime) decisionLogger(ctx context.Context, event *server.Info) error { - plugin := logs.Lookup(rt.Manager) - if plugin == nil { - return nil - } - - return plugin.Log(ctx, event) -} - -func (rt *Runtime) startWatcher(ctx context.Context, paths []string, onReload func(time.Duration, error)) error { - watcher, err := rt.getWatcher(paths) - if err != nil { - return err - } - go rt.readWatcher(ctx, watcher, paths, onReload) - return nil -} - -func (rt *Runtime) readWatcher(ctx context.Context, watcher *fsnotify.Watcher, paths []string, onReload func(time.Duration, error)) { - for evt := range watcher.Events { - removalMask := fsnotify.Remove | fsnotify.Rename - mask := fsnotify.Create | fsnotify.Write | removalMask - if (evt.Op & mask) != 0 { - rt.logger.WithFields(map[string]interface{}{ - "event": evt.String(), - }).Debug("Registered file event.") - t0 := time.Now() - removed := "" - if (evt.Op & removalMask) != 0 { - removed = evt.Name - } - err := rt.processWatcherUpdate(ctx, paths, removed) - onReload(time.Since(t0), err) - } - } -} - -func (rt *Runtime) processWatcherUpdate(ctx context.Context, paths []string, removed string) error { - - return pathwatcher.ProcessWatcherUpdateForRegoVersion(ctx, rt.Manager.ParserOptions().RegoVersion, paths, removed, rt.Store, rt.Params.Filter, rt.Params.BundleMode, func(ctx context.Context, txn storage.Transaction, loaded *initload.LoadPathsResult) error { - _, err := initload.InsertAndCompile(ctx, initload.InsertAndCompileOptions{ - Store: rt.Store, - Txn: txn, - Files: loaded.Files, - Bundles: loaded.Bundles, - MaxErrors: -1, - ParserOptions: rt.Manager.ParserOptions(), - }) - - return err - }) -} - -func (rt *Runtime) getBanner() string { - var buf bytes.Buffer - fmt.Fprintf(&buf, "OPA %v (commit %v, built at %v)\n", version.Version, version.Vcs, version.Timestamp) - fmt.Fprintf(&buf, "\n") - fmt.Fprintf(&buf, "Run 'help' to see a list of commands and check for updates.\n") - return buf.String() -} - -func (rt *Runtime) gracefulServerShutdown(s *server.Server) error { - if rt.Params.ShutdownWaitPeriod > 0 { - rt.logger.Info("Waiting %vs before initiating shutdown...", rt.Params.ShutdownWaitPeriod) - time.Sleep(time.Duration(rt.Params.ShutdownWaitPeriod) * time.Second) - } - - rt.logger.Info("Shutting down...") - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(rt.Params.GracefulShutdownPeriod)*time.Second) - defer cancel() - err := s.Shutdown(ctx) - if err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to shutdown server gracefully.") - return err - } - rt.logger.Info("Server shutdown.") - - if rt.traceExporter != nil { - err = rt.traceExporter.Shutdown(ctx) - if err != nil { - rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to shutdown OpenTelemetry trace exporter gracefully.") - } - } - return nil -} - -func (rt *Runtime) waitPluginsReady(checkInterval, timeout time.Duration) error { - if timeout <= 0 { - return nil - } - - // check readiness of all plugins - pluginsReady := func() bool { - for _, status := range rt.Manager.PluginStatus() { - if status != nil && status.State != plugins.StateOK { - return false - } - } - return true - } - - rt.logger.Debug("Waiting for plugins activation (%v).", timeout) - - return util.WaitFunc(pluginsReady, checkInterval, timeout) -} - -func (rt *Runtime) onReloadLogger(d time.Duration, err error) { - rt.logger.WithFields(map[string]interface{}{ - "duration": d, - "err": err, - }).Info("Processed file watch event.") -} - -func (rt *Runtime) getWatcher(rootPaths []string) (*fsnotify.Watcher, error) { - watcher, err := pathwatcher.CreatePathWatcher(rootPaths) - if err != nil { - return nil, err - } - - for _, path := range watcher.WatchList() { - rt.logger.WithFields(map[string]interface{}{"path": path}).Debug("watching path") - } - - return watcher, nil -} - -func urlPathToConfigOverride(pathCount int, path string) ([]string, error) { - uri, err := url.Parse(path) - if err != nil { - return nil, err - } - baseURL := uri.Scheme + "://" + uri.Host - urlPath := uri.Path - if uri.RawQuery != "" { - urlPath += "?" + uri.RawQuery - } - - return []string{ - fmt.Sprintf("services.cli%d.url=%s", pathCount, baseURL), - fmt.Sprintf("bundles.cli%d.service=cli%d", pathCount, pathCount), - fmt.Sprintf("bundles.cli%d.resource=%s", pathCount, urlPath), - fmt.Sprintf("bundles.cli%d.persist=true", pathCount), - }, nil -} - -func errorLogger(logger logging.Logger) func(attrs map[string]interface{}, f string, a ...interface{}) { - return func(attrs map[string]interface{}, f string, a ...interface{}) { - logger.WithFields(attrs).Error(f, a...) - } -} - -func onReloadPrinter(output io.Writer) func(time.Duration, error) { - return func(d time.Duration, err error) { - if err != nil { - fmt.Fprintf(output, "\n# reload error (took %v): %v", d, err) - } else { - fmt.Fprintf(output, "\n# reloaded files (took %v)", d) - } - } -} - -func generateInstanceID() (string, error) { - return uuid.New(rand.Reader) -} - -func generateDecisionID() string { - id, err := uuid.New(rand.Reader) - if err != nil { - return "" - } - return id -} - -func verifyAuthorizationPolicySchema(m *plugins.Manager) error { - authorizationDecisionRef, err := ref.ParseDataPath(*m.Config.DefaultAuthorizationDecision) - if err != nil { - return err - } - - return compiler.VerifyAuthorizationPolicySchema(m.GetCompiler(), authorizationDecisionRef) -} - -func init() { - registeredPlugins = make(map[string]plugins.Factory) + return v1.NewRuntime(ctx, params) } diff --git a/vendor/github.com/open-policy-agent/opa/server/buffer.go b/vendor/github.com/open-policy-agent/opa/server/buffer.go index 412abe1c4..be44f3a9c 100644 --- a/vendor/github.com/open-policy-agent/opa/server/buffer.go +++ b/vendor/github.com/open-policy-agent/opa/server/buffer.go @@ -5,40 +5,11 @@ package server import ( - "time" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown" + v1 "github.com/open-policy-agent/opa/v1/server" ) // Info contains information describing a policy decision. -type Info struct { - Txn storage.Transaction - Revision string // Deprecated: Use `Bundles` instead - Bundles map[string]BundleInfo - DecisionID string - TraceID string - SpanID string - RemoteAddr string - HTTPRequestContext logging.HTTPRequestContext - Query string - Path string - Timestamp time.Time - Input *interface{} - InputAST ast.Value - Results *interface{} - MappedResults *interface{} - NDBuiltinCache *interface{} - Error error - Metrics metrics.Metrics - Trace []*topdown.Event - RequestID uint64 -} +type Info = v1.Info // BundleInfo contains information describing a bundle. -type BundleInfo struct { - Revision string -} +type BundleInfo = v1.BundleInfo diff --git a/vendor/github.com/open-policy-agent/opa/server/doc.go b/vendor/github.com/open-policy-agent/opa/server/doc.go index 378194cd4..4eb7efad2 100644 --- a/vendor/github.com/open-policy-agent/opa/server/doc.go +++ b/vendor/github.com/open-policy-agent/opa/server/doc.go @@ -3,4 +3,8 @@ // license that can be found in the LICENSE file. // Package server contains the policy engine's server handlers. +// +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. package server diff --git a/vendor/github.com/open-policy-agent/opa/server/features.go b/vendor/github.com/open-policy-agent/opa/server/features.go index ea846a427..3b0153722 100644 --- a/vendor/github.com/open-policy-agent/opa/server/features.go +++ b/vendor/github.com/open-policy-agent/opa/server/features.go @@ -7,4 +7,4 @@ package server -import _ "github.com/open-policy-agent/opa/features/wasm" +import _ "github.com/open-policy-agent/opa/v1/features/wasm" diff --git a/vendor/github.com/open-policy-agent/opa/server/server.go b/vendor/github.com/open-policy-agent/opa/server/server.go index bb5c70767..fb19a9bc0 100644 --- a/vendor/github.com/open-policy-agent/opa/server/server.go +++ b/vendor/github.com/open-policy-agent/opa/server/server.go @@ -5,3064 +5,61 @@ package server import ( - "bytes" - "context" - "crypto/tls" - "crypto/x509" - "encoding/json" - "errors" - "fmt" - "html/template" - "io" - "net" - "net/http" - "net/http/pprof" - "net/url" - "os" - "strconv" - "strings" - "sync" - "time" - - serverDecodingPlugin "github.com/open-policy-agent/opa/plugins/server/decoding" - serverEncodingPlugin "github.com/open-policy-agent/opa/plugins/server/encoding" - - "github.com/gorilla/mux" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/internal/json/patch" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/plugins" - bundlePlugin "github.com/open-policy-agent/opa/plugins/bundle" - "github.com/open-policy-agent/opa/plugins/status" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/server/authorizer" - "github.com/open-policy-agent/opa/server/handlers" - "github.com/open-policy-agent/opa/server/identifier" - "github.com/open-policy-agent/opa/server/types" - "github.com/open-policy-agent/opa/server/writer" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown" - "github.com/open-policy-agent/opa/topdown/builtins" - iCache "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/lineage" - "github.com/open-policy-agent/opa/tracing" - "github.com/open-policy-agent/opa/util" - "github.com/open-policy-agent/opa/version" + v1 "github.com/open-policy-agent/opa/v1/server" ) // AuthenticationScheme enumerates the supported authentication schemes. The // authentication scheme determines how client identities are established. -type AuthenticationScheme int +type AuthenticationScheme = v1.AuthenticationScheme // Set of supported authentication schemes. const ( - AuthenticationOff AuthenticationScheme = iota - AuthenticationToken - AuthenticationTLS + AuthenticationOff = v1.AuthenticationOff + AuthenticationToken = v1.AuthenticationToken + AuthenticationTLS = v1.AuthenticationTLS ) -var supportedTLSVersions = []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} - // AuthorizationScheme enumerates the supported authorization schemes. The authorization // scheme determines how access to OPA is controlled. -type AuthorizationScheme int +type AuthorizationScheme = v1.AuthorizationScheme // Set of supported authorization schemes. const ( - AuthorizationOff AuthorizationScheme = iota - AuthorizationBasic + AuthorizationOff = v1.AuthorizationOff + AuthorizationBasic = v1.AuthorizationBasic ) -const defaultMinTLSVersion = tls.VersionTLS12 - // Set of handlers for use in the "handler" dimension of the duration metric. const ( - PromHandlerV0Data = "v0/data" - PromHandlerV1Data = "v1/data" - PromHandlerV1Query = "v1/query" - PromHandlerV1Policies = "v1/policies" - PromHandlerV1Compile = "v1/compile" - PromHandlerV1Config = "v1/config" - PromHandlerV1Status = "v1/status" - PromHandlerIndex = "index" - PromHandlerCatch = "catchall" - PromHandlerHealth = "health" - PromHandlerAPIAuthz = "authz" + PromHandlerV0Data = v1.PromHandlerV0Data + PromHandlerV1Data = v1.PromHandlerV1Data + PromHandlerV1Query = v1.PromHandlerV1Query + PromHandlerV1Policies = v1.PromHandlerV1Policies + PromHandlerV1Compile = v1.PromHandlerV1Compile + PromHandlerV1Config = v1.PromHandlerV1Config + PromHandlerV1Status = v1.PromHandlerV1Status + PromHandlerIndex = v1.PromHandlerIndex + PromHandlerCatch = v1.PromHandlerCatch + PromHandlerHealth = v1.PromHandlerHealth + PromHandlerAPIAuthz = v1.PromHandlerAPIAuthz ) -const pqMaxCacheSize = 100 - -// OpenTelemetry attributes -const otelDecisionIDAttr = "opa.decision_id" - -// map of unsafe builtins -var unsafeBuiltinsMap = map[string]struct{}{ast.HTTPSend.Name: {}} - // Server represents an instance of OPA running in server mode. -type Server struct { - Handler http.Handler - DiagnosticHandler http.Handler - - router *mux.Router - addrs []string - diagAddrs []string - h2cEnabled bool - authentication AuthenticationScheme - authorization AuthorizationScheme - cert *tls.Certificate - tlsConfigMtx sync.RWMutex - certFile string - certFileHash []byte - certKeyFile string - certKeyFileHash []byte - certRefresh time.Duration - certPool *x509.CertPool - certPoolFile string - certPoolFileHash []byte - minTLSVersion uint16 - mtx sync.RWMutex - partials map[string]rego.PartialResult - preparedEvalQueries *cache - store storage.Store - manager *plugins.Manager - decisionIDFactory func() string - logger func(context.Context, *Info) error - errLimit int - pprofEnabled bool - runtime *ast.Term - httpListeners []httpListener - metrics Metrics - defaultDecisionPath string - interQueryBuiltinCache iCache.InterQueryCache - interQueryBuiltinValueCache iCache.InterQueryValueCache - allPluginsOkOnce bool - distributedTracingOpts tracing.Options - ndbCacheEnabled bool - unixSocketPerm *string - cipherSuites *[]uint16 -} +type Server = v1.Server // Metrics defines the interface that the server requires for recording HTTP // handler metrics. -type Metrics interface { - RegisterEndpoints(registrar func(path, method string, handler http.Handler)) - InstrumentHandler(handler http.Handler, label string) http.Handler -} +type Metrics = v1.Metrics // TLSConfig represents the TLS configuration for the server. // This configuration is used to configure file watchers to reload each file as it // changes on disk. -type TLSConfig struct { - // CertFile is the path to the server's serving certificate file. - CertFile string - - // KeyFile is the path to the server's key file, completing the key pair for the - // CertFile certificate. - KeyFile string - - // CertPoolFile is the path to the CA cert pool file. The contents of this file will be - // reloaded when the file changes on disk and used in as trusted client CAs in the TLS config - // for new connections to the server. - CertPoolFile string -} +type TLSConfig = v1.TLSConfig // Loop will contain all the calls from the server that we'll be listening on. -type Loop func() error +type Loop = v1.Loop // New returns a new Server. func New() *Server { - s := Server{} - return &s -} - -// Init initializes the server. This function MUST be called before starting any loops -// from s.Listeners(). -func (s *Server) Init(ctx context.Context) (*Server, error) { - s.initRouters(ctx) - - txn, err := s.store.NewTransaction(ctx, storage.WriteParams) - if err != nil { - return nil, err - } - - // Register triggers so that if runtime reloads the policies, the - // server sees the change. - config := storage.TriggerConfig{ - OnCommit: s.reload, - } - if _, err := s.store.Register(ctx, txn, config); err != nil { - s.store.Abort(ctx, txn) - return nil, err - } - - s.partials = map[string]rego.PartialResult{} - s.preparedEvalQueries = newCache(pqMaxCacheSize) - s.defaultDecisionPath = s.generateDefaultDecisionPath() - s.manager.RegisterNDCacheTrigger(s.updateNDCache) - - s.Handler = s.initHandlerAuthn(s.Handler) - - // compression handler - s.Handler, err = s.initHandlerCompression(s.Handler) - if err != nil { - return nil, err - } - s.DiagnosticHandler = s.initHandlerAuthn(s.DiagnosticHandler) - - s.Handler, err = s.initHandlerDecodingLimits(s.Handler) - if err != nil { - return nil, err - } - - return s, s.store.Commit(ctx, txn) -} - -// Shutdown will attempt to gracefully shutdown each of the http servers -// currently in use by the OPA Server. If any exceed the deadline specified -// by the context an error will be returned. -func (s *Server) Shutdown(ctx context.Context) error { - errChan := make(chan error) - for _, srvr := range s.httpListeners { - go func(s httpListener) { - errChan <- s.Shutdown(ctx) - }(srvr) - } - // wait until each server has finished shutting down - var errorList []error - for i := 0; i < len(s.httpListeners); i++ { - err := <-errChan - if err != nil { - errorList = append(errorList, err) - } - } - - if len(errorList) > 0 { - errMsg := "error while shutting down: " - for i, err := range errorList { - errMsg += fmt.Sprintf("(%d) %s. ", i, err.Error()) - } - return errors.New(errMsg) - } - return nil -} - -// WithAddresses sets the listening addresses that the server will bind to. -func (s *Server) WithAddresses(addrs []string) *Server { - s.addrs = addrs - return s -} - -// WithDiagnosticAddresses sets the listening addresses that the server will -// bind to and *only* serve read-only diagnostic API's. -func (s *Server) WithDiagnosticAddresses(addrs []string) *Server { - s.diagAddrs = addrs - return s -} - -// WithAuthentication sets authentication scheme to use on the server. -func (s *Server) WithAuthentication(scheme AuthenticationScheme) *Server { - s.authentication = scheme - return s -} - -// WithAuthorization sets authorization scheme to use on the server. -func (s *Server) WithAuthorization(scheme AuthorizationScheme) *Server { - s.authorization = scheme - return s -} - -// WithCertificate sets the server-side certificate that the server will use. -func (s *Server) WithCertificate(cert *tls.Certificate) *Server { - s.cert = cert - return s -} - -// WithCertificatePaths sets the server-side certificate and keyfile paths -// that the server will periodically check for changes, and reload if necessary. -func (s *Server) WithCertificatePaths(certFile, keyFile string, refresh time.Duration) *Server { - s.certFile = certFile - s.certKeyFile = keyFile - s.certRefresh = refresh - return s -} - -// WithCertPool sets the server-side cert pool that the server will use. -func (s *Server) WithCertPool(pool *x509.CertPool) *Server { - s.certPool = pool - return s -} - -// WithTLSConfig sets the TLS configuration used by the server. -func (s *Server) WithTLSConfig(tlsConfig *TLSConfig) *Server { - s.certFile = tlsConfig.CertFile - s.certKeyFile = tlsConfig.KeyFile - s.certPoolFile = tlsConfig.CertPoolFile - return s -} - -// WithCertRefresh sets the period on which certs, keys and cert pools are reloaded from disk. -func (s *Server) WithCertRefresh(refresh time.Duration) *Server { - s.certRefresh = refresh - return s -} - -// WithStore sets the storage used by the server. -func (s *Server) WithStore(store storage.Store) *Server { - s.store = store - return s -} - -// WithMetrics sets the metrics provider used by the server. -func (s *Server) WithMetrics(m Metrics) *Server { - s.metrics = m - return s -} - -// WithManager sets the plugins manager used by the server. -func (s *Server) WithManager(manager *plugins.Manager) *Server { - s.manager = manager - return s -} - -// WithCompilerErrorLimit sets the limit on the number of compiler errors the server will -// allow. -func (s *Server) WithCompilerErrorLimit(limit int) *Server { - s.errLimit = limit - return s -} - -// WithPprofEnabled sets whether pprof endpoints are enabled -func (s *Server) WithPprofEnabled(pprofEnabled bool) *Server { - s.pprofEnabled = pprofEnabled - return s -} - -// WithH2CEnabled sets whether h2c ("HTTP/2 cleartext") is enabled for the http listener -func (s *Server) WithH2CEnabled(enabled bool) *Server { - s.h2cEnabled = enabled - return s -} - -// WithDecisionLogger sets the decision logger used by the -// server. DEPRECATED. Use WithDecisionLoggerWithErr instead. -func (s *Server) WithDecisionLogger(logger func(context.Context, *Info)) *Server { - s.logger = func(ctx context.Context, info *Info) error { - logger(ctx, info) - return nil - } - return s -} - -// WithDecisionLoggerWithErr sets the decision logger used by the server. -func (s *Server) WithDecisionLoggerWithErr(logger func(context.Context, *Info) error) *Server { - s.logger = logger - return s -} - -// WithDecisionIDFactory sets a function on the server to generate decision IDs. -func (s *Server) WithDecisionIDFactory(f func() string) *Server { - s.decisionIDFactory = f - return s -} - -// WithRuntime sets the runtime data to provide to the evaluation engine. -func (s *Server) WithRuntime(term *ast.Term) *Server { - s.runtime = term - return s -} - -// WithRouter sets the mux.Router to attach OPA's HTTP API routes onto. If a -// router is not supplied, the server will create it's own. -func (s *Server) WithRouter(router *mux.Router) *Server { - s.router = router - return s -} - -func (s *Server) WithMinTLSVersion(minTLSVersion uint16) *Server { - if isMinTLSVersionSupported(minTLSVersion) { - s.minTLSVersion = minTLSVersion - } else { - s.minTLSVersion = defaultMinTLSVersion - } - return s -} - -// WithDistributedTracingOpts sets the options to be used by distributed tracing. -func (s *Server) WithDistributedTracingOpts(opts tracing.Options) *Server { - s.distributedTracingOpts = opts - return s -} - -// WithNDBCacheEnabled sets whether the ND builtins cache is to be used. -func (s *Server) WithNDBCacheEnabled(ndbCacheEnabled bool) *Server { - s.ndbCacheEnabled = ndbCacheEnabled - return s -} - -// WithCipherSuites sets the list of enabled TLS 1.0–1.2 cipher suites. -func (s *Server) WithCipherSuites(cipherSuites *[]uint16) *Server { - s.cipherSuites = cipherSuites - return s -} - -// WithUnixSocketPermission sets the permission for the Unix domain socket if used to listen for -// incoming connections. Applies to the sockets the server is listening on including diagnostic API's. -func (s *Server) WithUnixSocketPermission(unixSocketPerm *string) *Server { - s.unixSocketPerm = unixSocketPerm - return s -} - -// Listeners returns functions that listen and serve connections. -func (s *Server) Listeners() ([]Loop, error) { - loops := []Loop{} - - handlerBindings := map[httpListenerType]struct { - addrs []string - handler http.Handler - }{ - defaultListenerType: {s.addrs, s.Handler}, - diagnosticListenerType: {s.diagAddrs, s.DiagnosticHandler}, - } - - for t, binding := range handlerBindings { - for _, addr := range binding.addrs { - l, listener, err := s.getListener(addr, binding.handler, t) - if err != nil { - return nil, err - } - s.httpListeners = append(s.httpListeners, listener) - loops = append(loops, l...) - } - } - - return loops, nil -} - -// Addrs returns a list of addresses that the server is listening on. -// If the server hasn't been started it will not return an address. -func (s *Server) Addrs() []string { - return s.addrsForType(defaultListenerType) -} - -// DiagnosticAddrs returns a list of addresses that the server is listening on -// for the read-only diagnostic API's (eg /health, /metrics, etc) -// If the server hasn't been started it will not return an address. -func (s *Server) DiagnosticAddrs() []string { - return s.addrsForType(diagnosticListenerType) -} - -func (s *Server) addrsForType(t httpListenerType) []string { - var addrs []string - for _, l := range s.httpListeners { - a := l.Addr() - if a != "" && l.Type() == t { - addrs = append(addrs, a) - } - } - return addrs -} - -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { - tc, err := ln.AcceptTCP() - if err != nil { - return nil, err - } - err = tc.SetKeepAlive(true) - if err != nil { - return nil, err - } - err = tc.SetKeepAlivePeriod(3 * time.Minute) - if err != nil { - return nil, err - } - return tc, nil -} - -type httpListenerType int - -const ( - defaultListenerType httpListenerType = iota - diagnosticListenerType -) - -type httpListener interface { - Addr() string - ListenAndServe() error - ListenAndServeTLS(certFile, keyFile string) error - Shutdown(context.Context) error - Type() httpListenerType -} - -// baseHTTPListener is just a wrapper around http.Server -type baseHTTPListener struct { - s *http.Server - l net.Listener - t httpListenerType - addr string - addrMtx sync.RWMutex -} - -var _ httpListener = (*baseHTTPListener)(nil) - -func newHTTPListener(srvr *http.Server, t httpListenerType) httpListener { - return &baseHTTPListener{s: srvr, t: t} -} - -func newHTTPUnixSocketListener(srvr *http.Server, l net.Listener, t httpListenerType) httpListener { - return &baseHTTPListener{s: srvr, l: l, t: t} -} - -func (b *baseHTTPListener) ListenAndServe() error { - addr := b.s.Addr - if addr == "" { - addr = ":http" - } - var err error - b.l, err = net.Listen("tcp", addr) - if err != nil { - return err - } - - b.initAddr() - - return b.s.Serve(tcpKeepAliveListener{b.l.(*net.TCPListener)}) -} - -func (b *baseHTTPListener) initAddr() { - b.addrMtx.Lock() - if addr := b.l.(*net.TCPListener).Addr(); addr != nil { - b.addr = addr.String() - } - b.addrMtx.Unlock() -} - -func (b *baseHTTPListener) Addr() string { - b.addrMtx.Lock() - defer b.addrMtx.Unlock() - return b.addr -} - -func (b *baseHTTPListener) ListenAndServeTLS(certFile, keyFile string) error { - addr := b.s.Addr - if addr == "" { - addr = ":https" - } - - var err error - b.l, err = net.Listen("tcp", addr) - if err != nil { - return err - } - - b.initAddr() - - defer b.l.Close() - - return b.s.ServeTLS(tcpKeepAliveListener{b.l.(*net.TCPListener)}, certFile, keyFile) -} - -func (b *baseHTTPListener) Shutdown(ctx context.Context) error { - return b.s.Shutdown(ctx) -} - -func (b *baseHTTPListener) Type() httpListenerType { - return b.t -} - -func isMinTLSVersionSupported(TLSVersion uint16) bool { - for _, version := range supportedTLSVersions { - if TLSVersion == version { - return true - } - } - return false -} - -func (s *Server) getListener(addr string, h http.Handler, t httpListenerType) ([]Loop, httpListener, error) { - parsedURL, err := parseURL(addr, s.cert != nil) - if err != nil { - return nil, nil, err - } - - var loops []Loop - var loop Loop - var listener httpListener - switch parsedURL.Scheme { - case "unix": - loop, listener, err = s.getListenerForUNIXSocket(parsedURL, h, t) - loops = []Loop{loop} - case "http": - loop, listener, err = s.getListenerForHTTPServer(parsedURL, h, t) - loops = []Loop{loop} - case "https": - loop, listener, err = s.getListenerForHTTPSServer(parsedURL, h, t) - logger := s.manager.Logger().WithFields(map[string]interface{}{ - "cert-file": s.certFile, - "cert-key-file": s.certKeyFile, - }) - - // if a manual cert refresh period has been set, then use the polling behavior, - // otherwise use the fsnotify default behavior - if s.certRefresh > 0 { - loops = []Loop{loop, s.certLoopPolling(logger)} - } else if s.certFile != "" || s.certPoolFile != "" { - loops = []Loop{loop, s.certLoopNotify(logger)} - } - default: - err = fmt.Errorf("invalid url scheme %q", parsedURL.Scheme) - } - - return loops, listener, err -} - -func (s *Server) getListenerForHTTPServer(u *url.URL, h http.Handler, t httpListenerType) (Loop, httpListener, error) { - if s.h2cEnabled { - h2s := &http2.Server{} - h = h2c.NewHandler(h, h2s) - } - h1s := http.Server{ - Addr: u.Host, - Handler: h, - } - - l := newHTTPListener(&h1s, t) - - return l.ListenAndServe, l, nil -} - -func (s *Server) getListenerForHTTPSServer(u *url.URL, h http.Handler, t httpListenerType) (Loop, httpListener, error) { - - if s.cert == nil { - return nil, nil, fmt.Errorf("TLS certificate required but not supplied") - } - - tlsConfig := tls.Config{ - GetCertificate: s.getCertificate, - // GetConfigForClient is used to ensure that a fresh config is provided containing the latest cert pool. - // This is not required, but appears to be how connect time updates config should be done: - // https://github.com/golang/go/issues/16066#issuecomment-250606132 - GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) { - s.tlsConfigMtx.Lock() - defer s.tlsConfigMtx.Unlock() - - cfg := &tls.Config{ - GetCertificate: s.getCertificate, - ClientCAs: s.certPool, - } - - if s.authentication == AuthenticationTLS { - cfg.ClientAuth = tls.RequireAndVerifyClientCert - } - - if s.minTLSVersion != 0 { - cfg.MinVersion = s.minTLSVersion - } else { - cfg.MinVersion = defaultMinTLSVersion - } - - if s.cipherSuites != nil { - cfg.CipherSuites = *s.cipherSuites - } - - return cfg, nil - }, - } - - httpsServer := http.Server{ - Addr: u.Host, - Handler: h, - TLSConfig: &tlsConfig, - } - - l := newHTTPListener(&httpsServer, t) - - httpsLoop := func() error { return l.ListenAndServeTLS("", "") } - - return httpsLoop, l, nil -} - -func (s *Server) getListenerForUNIXSocket(u *url.URL, h http.Handler, t httpListenerType) (Loop, httpListener, error) { - socketPath := u.Host + u.Path - - // Recover @ prefix for abstract Unix sockets. - if strings.HasPrefix(u.String(), u.Scheme+"://@") { - socketPath = "@" + socketPath - } else { - // Remove domain socket file in case it already exists. - os.Remove(socketPath) - } - - domainSocketServer := http.Server{Handler: h} - unixListener, err := net.Listen("unix", socketPath) - if err != nil { - return nil, nil, err - } - - if s.unixSocketPerm != nil { - modeVal, err := strconv.ParseUint(*s.unixSocketPerm, 8, 32) - if err != nil { - return nil, nil, err - } - - if err := os.Chmod(socketPath, os.FileMode(modeVal)); err != nil { - return nil, nil, err - } - } - - l := newHTTPUnixSocketListener(&domainSocketServer, unixListener, t) - - domainSocketLoop := func() error { return domainSocketServer.Serve(unixListener) } - return domainSocketLoop, l, nil -} - -func (s *Server) initHandlerAuthn(handler http.Handler) http.Handler { - switch s.authentication { - case AuthenticationToken: - handler = identifier.NewTokenBased(handler) - case AuthenticationTLS: - handler = identifier.NewTLSBased(handler) - } - - return handler -} - -func (s *Server) initHandlerAuthz(handler http.Handler) http.Handler { - switch s.authorization { - case AuthorizationBasic: - handler = authorizer.NewBasic( - handler, - s.getCompiler, - s.store, - authorizer.Runtime(s.runtime), - authorizer.Decision(s.manager.Config.DefaultAuthorizationDecisionRef), - authorizer.PrintHook(s.manager.PrintHook()), - authorizer.EnablePrintStatements(s.manager.EnablePrintStatements()), - authorizer.InterQueryCache(s.interQueryBuiltinCache), - authorizer.InterQueryValueCache(s.interQueryBuiltinValueCache)) - - if s.metrics != nil { - handler = s.instrumentHandler(handler.ServeHTTP, PromHandlerAPIAuthz) - } - } - - return handler -} - -// Enforces request body size limits on incoming requests. For gzipped requests, -// it passes the size limit down the body-reading method via the request -// context. -func (s *Server) initHandlerDecodingLimits(handler http.Handler) (http.Handler, error) { - var decodingRawConfig json.RawMessage - serverConfig := s.manager.Config.Server - if serverConfig != nil { - decodingRawConfig = serverConfig.Decoding - } - decodingConfig, err := serverDecodingPlugin.NewConfigBuilder().WithBytes(decodingRawConfig).Parse() - if err != nil { - return nil, err - } - decodingHandler := handlers.DecodingLimitsHandler(handler, *decodingConfig.MaxLength, *decodingConfig.Gzip.MaxLength) - - return decodingHandler, nil -} - -func (s *Server) initHandlerCompression(handler http.Handler) (http.Handler, error) { - var encodingRawConfig json.RawMessage - serverConfig := s.manager.Config.Server - if serverConfig != nil { - encodingRawConfig = serverConfig.Encoding - } - encodingConfig, err := serverEncodingPlugin.NewConfigBuilder().WithBytes(encodingRawConfig).Parse() - if err != nil { - return nil, err - } - compressHandler := handlers.CompressHandler(handler, *encodingConfig.Gzip.MinLength, *encodingConfig.Gzip.CompressionLevel) - - return compressHandler, nil -} - -func (s *Server) initRouters(ctx context.Context) { - mainRouter := s.router - if mainRouter == nil { - mainRouter = mux.NewRouter() - } - - diagRouter := mux.NewRouter() - - // authorizer, if configured, needs the iCache to be set up already - - cacheConfig := s.manager.InterQueryBuiltinCacheConfig() - - s.interQueryBuiltinCache = iCache.NewInterQueryCacheWithContext(ctx, cacheConfig) - s.interQueryBuiltinValueCache = iCache.NewInterQueryValueCache(ctx, cacheConfig) - - s.manager.RegisterCacheTrigger(s.updateCacheConfig) - - // Add authorization handler. This must come BEFORE authentication handler - // so that the latter can run first. - handlerAuthz := s.initHandlerAuthz(mainRouter) - - handlerAuthzDiag := s.initHandlerAuthz(diagRouter) - - // All routers get the same base configuration *and* diagnostic API's - for _, router := range []*mux.Router{mainRouter, diagRouter} { - router.StrictSlash(true) - router.UseEncodedPath() - router.StrictSlash(true) - - if s.metrics != nil { - s.metrics.RegisterEndpoints(func(path, method string, handler http.Handler) { - router.Handle(path, handler).Methods(method) - }) - } - - router.Handle("/health", s.instrumentHandler(s.unversionedGetHealth, PromHandlerHealth)).Methods(http.MethodGet) - // Use this route to evaluate health policy defined at system.health - // By convention, policy is typically defined at system.health.live and system.health.ready, and is - // evaluated by calling /health/live and /health/ready respectively. - router.Handle("/health/{path:.+}", s.instrumentHandler(s.unversionedGetHealthWithPolicy, PromHandlerHealth)).Methods(http.MethodGet) - } - - if s.pprofEnabled { - mainRouter.HandleFunc("/debug/pprof/", pprof.Index) - mainRouter.Handle("/debug/pprof/allocs", pprof.Handler("allocs")) - mainRouter.Handle("/debug/pprof/block", pprof.Handler("block")) - mainRouter.Handle("/debug/pprof/heap", pprof.Handler("heap")) - mainRouter.Handle("/debug/pprof/mutex", pprof.Handler("mutex")) - mainRouter.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mainRouter.HandleFunc("/debug/pprof/profile", pprof.Profile) - mainRouter.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mainRouter.HandleFunc("/debug/pprof/trace", pprof.Trace) - } - - // Only the main mainRouter gets the OPA API's (data, policies, query, etc) - mainRouter.Handle("/v0/data/{path:.+}", s.instrumentHandler(s.v0DataPost, PromHandlerV0Data)).Methods(http.MethodPost) - mainRouter.Handle("/v0/data", s.instrumentHandler(s.v0DataPost, PromHandlerV0Data)).Methods(http.MethodPost) - mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataDelete, PromHandlerV1Data)).Methods(http.MethodDelete) - mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataPut, PromHandlerV1Data)).Methods(http.MethodPut) - mainRouter.Handle("/v1/data", s.instrumentHandler(s.v1DataPut, PromHandlerV1Data)).Methods(http.MethodPut) - mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataGet, PromHandlerV1Data)).Methods(http.MethodGet) - mainRouter.Handle("/v1/data", s.instrumentHandler(s.v1DataGet, PromHandlerV1Data)).Methods(http.MethodGet) - mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataPatch, PromHandlerV1Data)).Methods(http.MethodPatch) - mainRouter.Handle("/v1/data", s.instrumentHandler(s.v1DataPatch, PromHandlerV1Data)).Methods(http.MethodPatch) - mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataPost, PromHandlerV1Data)).Methods(http.MethodPost) - mainRouter.Handle("/v1/data", s.instrumentHandler(s.v1DataPost, PromHandlerV1Data)).Methods(http.MethodPost) - mainRouter.Handle("/v1/policies", s.instrumentHandler(s.v1PoliciesList, PromHandlerV1Policies)).Methods(http.MethodGet) - mainRouter.Handle("/v1/policies/{path:.+}", s.instrumentHandler(s.v1PoliciesDelete, PromHandlerV1Policies)).Methods(http.MethodDelete) - mainRouter.Handle("/v1/policies/{path:.+}", s.instrumentHandler(s.v1PoliciesGet, PromHandlerV1Policies)).Methods(http.MethodGet) - mainRouter.Handle("/v1/policies/{path:.+}", s.instrumentHandler(s.v1PoliciesPut, PromHandlerV1Policies)).Methods(http.MethodPut) - mainRouter.Handle("/v1/query", s.instrumentHandler(s.v1QueryGet, PromHandlerV1Query)).Methods(http.MethodGet) - mainRouter.Handle("/v1/query", s.instrumentHandler(s.v1QueryPost, PromHandlerV1Query)).Methods(http.MethodPost) - mainRouter.Handle("/v1/compile", s.instrumentHandler(s.v1CompilePost, PromHandlerV1Compile)).Methods(http.MethodPost) - mainRouter.Handle("/v1/config", s.instrumentHandler(s.v1ConfigGet, PromHandlerV1Config)).Methods(http.MethodGet) - mainRouter.Handle("/v1/status", s.instrumentHandler(s.v1StatusGet, PromHandlerV1Status)).Methods(http.MethodGet) - mainRouter.Handle("/", s.instrumentHandler(s.unversionedPost, PromHandlerIndex)).Methods(http.MethodPost) - mainRouter.Handle("/", s.instrumentHandler(s.indexGet, PromHandlerIndex)).Methods(http.MethodGet) - - // These are catch all handlers that respond http.StatusMethodNotAllowed for resources that exist but the method is not allowed - mainRouter.Handle("/v0/data/{path:.*}", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodGet, http.MethodHead, - http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodPatch, http.MethodPut, http.MethodTrace) - mainRouter.Handle("/v0/data", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodGet, http.MethodHead, - http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodPatch, http.MethodPut, - http.MethodTrace) - // v1 Data catch all - mainRouter.Handle("/v1/data/{path:.*}", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, - http.MethodConnect, http.MethodOptions, http.MethodTrace) - mainRouter.Handle("/v1/data", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, - http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodTrace) - // Policies catch all - mainRouter.Handle("/v1/policies", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, - http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodTrace, http.MethodPost, http.MethodPut, - http.MethodPatch) - // Policies (/policies/{path.+} catch all - mainRouter.Handle("/v1/policies/{path:.*}", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, - http.MethodConnect, http.MethodOptions, http.MethodTrace, http.MethodPost) - // Query catch all - mainRouter.Handle("/v1/query/{path:.*}", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, - http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodTrace, http.MethodPost, http.MethodPut, http.MethodPatch) - mainRouter.Handle("/v1/query", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, - http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodTrace, http.MethodPut, http.MethodPatch) - - s.Handler = mainRouter - s.DiagnosticHandler = diagRouter - - // Add authorization handler in the end so that it can run first - s.Handler = handlerAuthz - s.DiagnosticHandler = handlerAuthzDiag -} - -func (s *Server) instrumentHandler(handler func(http.ResponseWriter, *http.Request), label string) http.Handler { - var httpHandler http.Handler = http.HandlerFunc(handler) - if len(s.distributedTracingOpts) > 0 { - httpHandler = tracing.NewHandler(httpHandler, label, s.distributedTracingOpts) - } - if s.metrics != nil { - return s.metrics.InstrumentHandler(httpHandler, label) - } - return httpHandler -} - -func (s *Server) execQuery(ctx context.Context, br bundleRevisions, txn storage.Transaction, parsedQuery ast.Body, input ast.Value, rawInput *interface{}, m metrics.Metrics, explainMode types.ExplainModeV1, includeMetrics, includeInstrumentation, pretty bool) (*types.QueryResponseV1, error) { - results := types.QueryResponseV1{} - logger := s.getDecisionLogger(br) - - var buf *topdown.BufferTracer - if explainMode != types.ExplainOffV1 { - buf = topdown.NewBufferTracer() - } - - var ndbCache builtins.NDBCache - if s.ndbCacheEnabled { - ndbCache = builtins.NDBCache{} - } - - opts := []func(*rego.Rego){ - rego.Store(s.store), - rego.Transaction(txn), - rego.Compiler(s.getCompiler()), - rego.ParsedQuery(parsedQuery), - rego.ParsedInput(input), - rego.Metrics(m), - rego.Instrument(includeInstrumentation), - rego.QueryTracer(buf), - rego.Runtime(s.runtime), - rego.UnsafeBuiltins(unsafeBuiltinsMap), - rego.InterQueryBuiltinCache(s.interQueryBuiltinCache), - rego.InterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), - rego.PrintHook(s.manager.PrintHook()), - rego.EnablePrintStatements(s.manager.EnablePrintStatements()), - rego.DistributedTracingOpts(s.distributedTracingOpts), - rego.NDBuiltinCache(ndbCache), - } - - for _, r := range s.manager.GetWasmResolvers() { - for _, entrypoint := range r.Entrypoints() { - opts = append(opts, rego.Resolver(entrypoint, r)) - } - } - - rego := rego.New(opts...) - - output, err := rego.Eval(ctx) - if err != nil { - _ = logger.Log(ctx, txn, "", parsedQuery.String(), rawInput, input, nil, ndbCache, err, m) - return nil, err - } - - for _, result := range output { - results.Result = append(results.Result, result.Bindings.WithoutWildcards()) - } - - if includeMetrics || includeInstrumentation { - results.Metrics = m.All() - } - - if explainMode != types.ExplainOffV1 { - results.Explanation = s.getExplainResponse(explainMode, *buf, pretty) - } - - var x interface{} = results.Result - if err := logger.Log(ctx, txn, "", parsedQuery.String(), rawInput, input, &x, ndbCache, nil, m); err != nil { - return nil, err - } - return &results, nil -} - -func (s *Server) indexGet(w http.ResponseWriter, _ *http.Request) { - _ = indexHTML.Execute(w, struct { - Version string - BuildCommit string - BuildTimestamp string - BuildHostname string - }{ - Version: version.Version, - BuildCommit: version.Vcs, - BuildTimestamp: version.Timestamp, - BuildHostname: version.Hostname, - }) -} - -type bundleRevisions struct { - LegacyRevision string - Revisions map[string]string -} - -func getRevisions(ctx context.Context, store storage.Store, txn storage.Transaction) (bundleRevisions, error) { - - var err error - var br bundleRevisions - br.Revisions = map[string]string{} - - // Check if we still have a legacy bundle manifest in the store - br.LegacyRevision, err = bundle.LegacyReadRevisionFromStore(ctx, store, txn) - if err != nil && !storage.IsNotFound(err) { - return br, err - } - - // read all bundle revisions from storage (if any exist) - names, err := bundle.ReadBundleNamesFromStore(ctx, store, txn) - if err != nil && !storage.IsNotFound(err) { - return br, err - } - - for _, name := range names { - r, err := bundle.ReadBundleRevisionFromStore(ctx, store, txn, name) - if err != nil && !storage.IsNotFound(err) { - return br, err - } - br.Revisions[name] = r - } - - return br, nil -} - -func (s *Server) reload(context.Context, storage.Transaction, storage.TriggerEvent) { - - // NOTE(tsandall): We currently rely on the storage txn to provide - // critical sections in the server. - // - // If you modify this function to change any other state on the server, you must - // review the other places in the server where that state is accessed to avoid data - // races--the state must be accessed _after_ a txn has been opened. - - // reset some cached info - s.partials = map[string]rego.PartialResult{} - s.preparedEvalQueries = newCache(pqMaxCacheSize) - s.defaultDecisionPath = s.generateDefaultDecisionPath() -} - -func (s *Server) unversionedPost(w http.ResponseWriter, r *http.Request) { - s.v0QueryPath(w, r, "", true) -} - -func (s *Server) v0DataPost(w http.ResponseWriter, r *http.Request) { - s.v0QueryPath(w, r, mux.Vars(r)["path"], false) -} - -func (s *Server) v0QueryPath(w http.ResponseWriter, r *http.Request, urlPath string, useDefaultDecisionPath bool) { - m := metrics.New() - m.Timer(metrics.ServerHandler).Start() - - decisionID := s.generateDecisionID() - ctx := logging.WithDecisionID(r.Context(), decisionID) - annotateSpan(ctx, decisionID) - - input, goInput, err := readInputV0(r) - if err != nil { - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, fmt.Errorf("unexpected parse error for input: %w", err)) - return - } - - // Prepare for query. - txn, err := s.store.NewTransaction(ctx) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - defer s.store.Abort(ctx, txn) - - br, err := getRevisions(ctx, s.store, txn) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - if useDefaultDecisionPath { - urlPath = s.generateDefaultDecisionPath() - } - - logger := s.getDecisionLogger(br) - - var ndbCache builtins.NDBCache - if s.ndbCacheEnabled { - ndbCache = builtins.NDBCache{} - } - - pqID := "v0QueryPath::" + urlPath - preparedQuery, ok := s.getCachedPreparedEvalQuery(pqID, m) - if !ok { - opts := []func(*rego.Rego){ - rego.Compiler(s.getCompiler()), - rego.Store(s.store), - } - - // Set resolvers on the base Rego object to avoid having them get - // re-initialized, and to propagate them to the prepared query. - for _, r := range s.manager.GetWasmResolvers() { - for _, entrypoint := range r.Entrypoints() { - opts = append(opts, rego.Resolver(entrypoint, r)) - } - } - - rego, err := s.makeRego(ctx, false, txn, input, urlPath, m, false, nil, opts) - if err != nil { - _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) - writer.ErrorAuto(w, err) - return - } - - pq, err := rego.PrepareForEval(ctx) - if err != nil { - _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) - writer.ErrorAuto(w, err) - return - } - preparedQuery = &pq - s.preparedEvalQueries.Insert(pqID, preparedQuery) - } - - evalOpts := []rego.EvalOption{ - rego.EvalTransaction(txn), - rego.EvalParsedInput(input), - rego.EvalMetrics(m), - rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), - rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), - rego.EvalNDBuiltinCache(ndbCache), - } - - rs, err := preparedQuery.Eval( - ctx, - evalOpts..., - ) - - m.Timer(metrics.ServerHandler).Stop() - - // Handle results. - if err != nil { - _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) - writer.ErrorAuto(w, err) - return - } - - if len(rs) == 0 { - ref := stringPathToDataRef(urlPath) - - var messageType = types.MsgMissingError - if len(s.getCompiler().GetRulesForVirtualDocument(ref)) > 0 { - messageType = types.MsgFoundUndefinedError - } - err := types.NewErrorV1(types.CodeUndefinedDocument, fmt.Sprintf("%v: %v", messageType, ref)) - if err := logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m); err != nil { - writer.ErrorAuto(w, err) - return - } - - writer.Error(w, http.StatusNotFound, err) - return - } - err = logger.Log(ctx, txn, urlPath, "", goInput, input, &rs[0].Expressions[0].Value, ndbCache, nil, m) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - writer.JSONOK(w, rs[0].Expressions[0].Value, pretty(r)) -} - -func (s *Server) getCachedPreparedEvalQuery(key string, m metrics.Metrics) (*rego.PreparedEvalQuery, bool) { - pq, ok := s.preparedEvalQueries.Get(key) - m.Counter(metrics.ServerQueryCacheHit) // Creates the counter on the metrics if it doesn't exist, starts at 0 - if ok { - m.Counter(metrics.ServerQueryCacheHit).Incr() // Increment counter on hit - return pq.(*rego.PreparedEvalQuery), true - } - return nil, false -} - -func (s *Server) canEval(ctx context.Context) bool { - // Create very simple query that binds a single variable. - opts := []func(*rego.Rego){ - rego.Compiler(s.getCompiler()), - rego.Store(s.store), - rego.Query("x = 1"), - } - - for _, r := range s.manager.GetWasmResolvers() { - for _, ep := range r.Entrypoints() { - opts = append(opts, rego.Resolver(ep, r)) - } - } - - eval := rego.New(opts...) - // Run evaluation. - rs, err := eval.Eval(ctx) - if err != nil { - return false - } - - v, ok := rs[0].Bindings["x"] - if ok { - jsonNumber, ok := v.(json.Number) - if ok && jsonNumber.String() == "1" { - return true - } - } - return false -} - -func (s *Server) bundlesReady(pluginStatuses map[string]*plugins.Status) bool { - - // Look for a discovery plugin first, if it exists and isn't ready - // then don't bother with the others. - // Note: use "discovery" instead of `discovery.Name` to avoid import - // cycle problems.. - dpStatus, ok := pluginStatuses["discovery"] - if ok && dpStatus != nil && (dpStatus.State != plugins.StateOK) { - return false - } - - // The bundle plugin won't return "OK" until the first activation - // of each configured bundle. - bpStatus, ok := pluginStatuses[bundlePlugin.Name] - if ok && bpStatus != nil && (bpStatus.State != plugins.StateOK) { - return false - } - - return true -} - -func (s *Server) unversionedGetHealth(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - includeBundleStatus := getBoolParam(r.URL, types.ParamBundleActivationV1, true) || - getBoolParam(r.URL, types.ParamBundlesActivationV1, true) - includePluginStatus := getBoolParam(r.URL, types.ParamPluginsV1, true) - excludePlugin := getStringSliceParam(r.URL, types.ParamExcludePluginV1) - excludePluginMap := map[string]struct{}{} - for _, name := range excludePlugin { - excludePluginMap[name] = struct{}{} - } - - // Ensure the server can evaluate a simple query - if !s.canEval(ctx) { - writeHealthResponse(w, errors.New("unable to perform evaluation")) - return - } - - pluginStatuses := s.manager.PluginStatus() - - // Ensure that bundles (if configured, and requested to be included in the result) - // have been activated successfully. This will include discovery bundles as well as - // normal bundles that are configured. - if includeBundleStatus && !s.bundlesReady(pluginStatuses) { - // For backwards compatibility we don't return a payload with statuses for the bundle endpoint - writeHealthResponse(w, errors.New("one or more bundles are not activated")) - return - } - - if includePluginStatus { - // Ensure that all plugins (if requested to be included in the result) have an OK status. - hasErr := false - for name, status := range pluginStatuses { - if _, exclude := excludePluginMap[name]; exclude { - continue - } - if status != nil && status.State != plugins.StateOK { - hasErr = true - break - } - } - if hasErr { - writeHealthResponse(w, errors.New("one or more plugins are not up")) - return - } - } - writeHealthResponse(w, nil) -} - -func (s *Server) unversionedGetHealthWithPolicy(w http.ResponseWriter, r *http.Request) { - pluginStatus := s.manager.PluginStatus() - pluginState := map[string]string{} - - // optimistically assume all plugins are ok - allPluginsOk := true - - // build input document for health check query - input := func() map[string]interface{} { - s.mtx.Lock() - defer s.mtx.Unlock() - - // iterate over plugin status to extract state - for name, status := range pluginStatus { - if status != nil { - pluginState[name] = string(status.State) - // if all plugins have not been in OK state yet, then check to see if plugin state is OKx - if !s.allPluginsOkOnce && status.State != plugins.StateOK { - allPluginsOk = false - } - } - } - // once all plugins are OK, set the allPluginsOkOnce flag to true, indicating that all - // plugins have achieved a "ready" state at least once on the server. - if allPluginsOk { - s.allPluginsOkOnce = true - } - - return map[string]interface{}{ - "plugin_state": pluginState, - "plugins_ready": s.allPluginsOkOnce, - } - }() - - vars := mux.Vars(r) - urlPath := vars["path"] - healthDataPath := fmt.Sprintf("/system/health/%s", urlPath) - healthDataPath = stringPathToDataRef(healthDataPath).String() - - rego := rego.New( - rego.Query(healthDataPath), - rego.Compiler(s.getCompiler()), - rego.Store(s.store), - rego.Input(input), - rego.Runtime(s.runtime), - rego.PrintHook(s.manager.PrintHook()), - ) - - rs, err := rego.Eval(r.Context()) - if err != nil { - writeHealthResponse(w, err) - return - } - - if len(rs) == 0 { - writeHealthResponse(w, fmt.Errorf("health check (%v) was undefined", healthDataPath)) - return - } - - result, ok := rs[0].Expressions[0].Value.(bool) - if ok && result { - writeHealthResponse(w, nil) - return - } - - writeHealthResponse(w, fmt.Errorf("health check (%v) returned unexpected value", healthDataPath)) -} - -func writeHealthResponse(w http.ResponseWriter, err error) { - if err != nil { - writer.JSON(w, http.StatusInternalServerError, types.HealthResponseV1{Error: err.Error()}, false) - return - } - - writer.JSONOK(w, types.HealthResponseV1{}, false) -} - -func (s *Server) v1CompilePost(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - explainMode := getExplain(r.URL.Query()[types.ParamExplainV1], types.ExplainOffV1) - includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) - - m := metrics.New() - m.Timer(metrics.ServerHandler).Start() - m.Timer(metrics.RegoQueryParse).Start() - - // decompress the input if sent as zip - body, err := util.ReadMaybeCompressedBody(r) - if err != nil { - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "could not decompress the body")) - return - } - - request, reqErr := readInputCompilePostV1(body) - if reqErr != nil { - writer.Error(w, http.StatusBadRequest, reqErr) - return - } - - m.Timer(metrics.RegoQueryParse).Stop() - - c := storage.NewContext().WithMetrics(m) - txn, err := s.store.NewTransaction(ctx, storage.TransactionParams{Context: c}) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - defer s.store.Abort(ctx, txn) - - var buf *topdown.BufferTracer - if explainMode != types.ExplainOffV1 { - buf = topdown.NewBufferTracer() - } - - eval := rego.New( - rego.Compiler(s.getCompiler()), - rego.Store(s.store), - rego.Transaction(txn), - rego.ParsedQuery(request.Query), - rego.ParsedInput(request.Input), - rego.ParsedUnknowns(request.Unknowns), - rego.DisableInlining(request.Options.DisableInlining), - rego.QueryTracer(buf), - rego.Instrument(includeInstrumentation), - rego.Metrics(m), - rego.Runtime(s.runtime), - rego.UnsafeBuiltins(unsafeBuiltinsMap), - rego.InterQueryBuiltinCache(s.interQueryBuiltinCache), - rego.InterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), - rego.PrintHook(s.manager.PrintHook()), - ) - - pq, err := eval.Partial(ctx) - if err != nil { - switch err := err.(type) { - case ast.Errors: - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileModuleError).WithASTErrors(err)) - default: - writer.ErrorAuto(w, err) - } - return - } - - m.Timer(metrics.ServerHandler).Stop() - - result := types.CompileResponseV1{} - - if includeMetrics(r) || includeInstrumentation { - result.Metrics = m.All() - } - - if explainMode != types.ExplainOffV1 { - result.Explanation = s.getExplainResponse(explainMode, *buf, pretty(r)) - } - - var i interface{} = types.PartialEvaluationResultV1{ - Queries: pq.Queries, - Support: pq.Support, - } - - result.Result = &i - - writer.JSONOK(w, result, pretty(r)) -} - -func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) { - m := metrics.New() - - m.Timer(metrics.ServerHandler).Start() - - decisionID := s.generateDecisionID() - ctx := logging.WithDecisionID(r.Context(), decisionID) - annotateSpan(ctx, decisionID) - - vars := mux.Vars(r) - urlPath := vars["path"] - explainMode := getExplain(r.URL.Query()["explain"], types.ExplainOffV1) - includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) - provenance := getBoolParam(r.URL, types.ParamProvenanceV1, true) - strictBuiltinErrors := getBoolParam(r.URL, types.ParamStrictBuiltinErrors, true) - - m.Timer(metrics.RegoInputParse).Start() - - inputs := r.URL.Query()[types.ParamInputV1] - - var input ast.Value - var goInput *interface{} - - if len(inputs) > 0 { - var err error - input, goInput, err = readInputGetV1(inputs[len(inputs)-1]) - if err != nil { - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - } - - m.Timer(metrics.RegoInputParse).Stop() - - // Prepare for query. - c := storage.NewContext().WithMetrics(m) - txn, err := s.store.NewTransaction(ctx, storage.TransactionParams{Context: c}) - if err != nil { - writer.ErrorAuto(w, err) - return - } - defer s.store.Abort(ctx, txn) - - br, err := getRevisions(ctx, s.store, txn) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - logger := s.getDecisionLogger(br) - - var ndbCache builtins.NDBCache - if s.ndbCacheEnabled { - ndbCache = builtins.NDBCache{} - } - - var buf *topdown.BufferTracer - - if explainMode != types.ExplainOffV1 { - buf = topdown.NewBufferTracer() - } - - pqID := "v1DataGet::" - if strictBuiltinErrors { - pqID += "strict-builtin-errors::" - } - pqID += urlPath - preparedQuery, ok := s.getCachedPreparedEvalQuery(pqID, m) - if !ok { - opts := []func(*rego.Rego){ - rego.Compiler(s.getCompiler()), - rego.Store(s.store), - } - - for _, r := range s.manager.GetWasmResolvers() { - for _, entrypoint := range r.Entrypoints() { - opts = append(opts, rego.Resolver(entrypoint, r)) - } - } - - rego, err := s.makeRego(ctx, strictBuiltinErrors, txn, input, urlPath, m, includeInstrumentation, buf, opts) - if err != nil { - _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) - writer.ErrorAuto(w, err) - return - } - - pq, err := rego.PrepareForEval(ctx) - if err != nil { - _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) - writer.ErrorAuto(w, err) - return - } - preparedQuery = &pq - s.preparedEvalQueries.Insert(pqID, preparedQuery) - } - - evalOpts := []rego.EvalOption{ - rego.EvalTransaction(txn), - rego.EvalParsedInput(input), - rego.EvalMetrics(m), - rego.EvalQueryTracer(buf), - rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), - rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), - rego.EvalInstrument(includeInstrumentation), - rego.EvalNDBuiltinCache(ndbCache), - } - - rs, err := preparedQuery.Eval( - ctx, - evalOpts..., - ) - - m.Timer(metrics.ServerHandler).Stop() - - // Handle results. - if err != nil { - _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) - writer.ErrorAuto(w, err) - return - } - - result := types.DataResponseV1{ - DecisionID: decisionID, - } - - if includeMetrics(r) || includeInstrumentation { - result.Metrics = m.All() - } - - if provenance { - result.Provenance = s.getProvenance(br) - } - - if len(rs) == 0 { - if explainMode == types.ExplainFullV1 { - result.Explanation, err = types.NewTraceV1(lineage.Full(*buf), pretty(r)) - if err != nil { - writer.ErrorAuto(w, err) - return - } - } - - if err := logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, nil, m); err != nil { - writer.ErrorAuto(w, err) - return - } - writer.JSONOK(w, result, pretty(r)) - return - } - - result.Result = &rs[0].Expressions[0].Value - - if explainMode != types.ExplainOffV1 { - result.Explanation = s.getExplainResponse(explainMode, *buf, pretty(r)) - } - - if err := logger.Log(ctx, txn, urlPath, "", goInput, input, result.Result, ndbCache, nil, m); err != nil { - writer.ErrorAuto(w, err) - return - } - writer.JSONOK(w, result, pretty(r)) -} - -func (s *Server) v1DataPatch(w http.ResponseWriter, r *http.Request) { - m := metrics.New() - m.Timer(metrics.ServerHandler).Start() - defer m.Timer(metrics.ServerHandler).Stop() - - ctx := r.Context() - vars := mux.Vars(r) - var ops []types.PatchV1 - - m.Timer(metrics.RegoInputParse).Start() - if err := util.NewJSONDecoder(r.Body).Decode(&ops); err != nil { - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - m.Timer(metrics.RegoInputParse).Stop() - - patches, err := s.prepareV1PatchSlice(vars["path"], ops) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - params := storage.WriteParams - params.Context = storage.NewContext().WithMetrics(m) - txn, err := s.store.NewTransaction(ctx, params) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - for _, patch := range patches { - if err := s.checkPathScope(ctx, txn, patch.path); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - if err := s.store.Write(ctx, txn, patch.op, patch.path, patch.value); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - } - - if err := ast.CheckPathConflicts(s.getCompiler(), storage.NonEmpty(ctx, s.store, txn)); len(err) > 0 { - s.store.Abort(ctx, txn) - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - - if err := s.store.Commit(ctx, txn); err != nil { - writer.ErrorAuto(w, err) - return - } - - if includeMetrics(r) { - result := types.DataResponseV1{ - Metrics: m.All(), - } - writer.JSONOK(w, result, false) - return - } - - w.WriteHeader(http.StatusNoContent) -} - -func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { - m := metrics.New() - m.Timer(metrics.ServerHandler).Start() - - decisionID := s.generateDecisionID() - ctx := logging.WithDecisionID(r.Context(), decisionID) - annotateSpan(ctx, decisionID) - - vars := mux.Vars(r) - urlPath := vars["path"] - explainMode := getExplain(r.URL.Query()[types.ParamExplainV1], types.ExplainOffV1) - includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) - provenance := getBoolParam(r.URL, types.ParamProvenanceV1, true) - strictBuiltinErrors := getBoolParam(r.URL, types.ParamStrictBuiltinErrors, true) - - m.Timer(metrics.RegoInputParse).Start() - - input, goInput, err := readInputPostV1(r) - if err != nil { - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - - m.Timer(metrics.RegoInputParse).Stop() - - txn, err := s.store.NewTransaction(ctx, storage.TransactionParams{Context: storage.NewContext().WithMetrics(m)}) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - defer s.store.Abort(ctx, txn) - - br, err := getRevisions(ctx, s.store, txn) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - logger := s.getDecisionLogger(br) - - var ndbCache builtins.NDBCache - if s.ndbCacheEnabled { - ndbCache = builtins.NDBCache{} - } - - var buf *topdown.BufferTracer - - if explainMode != types.ExplainOffV1 { - buf = topdown.NewBufferTracer() - } - - pqID := "v1DataPost::" - if strictBuiltinErrors { - pqID += "strict-builtin-errors::" - } - pqID += urlPath - preparedQuery, ok := s.getCachedPreparedEvalQuery(pqID, m) - if !ok { - opts := []func(*rego.Rego){ - rego.Compiler(s.getCompiler()), - rego.Store(s.store), - } - - // Set resolvers on the base Rego object to avoid having them get - // re-initialized, and to propagate them to the prepared query. - for _, r := range s.manager.GetWasmResolvers() { - for _, entrypoint := range r.Entrypoints() { - opts = append(opts, rego.Resolver(entrypoint, r)) - } - } - - rego, err := s.makeRego(ctx, strictBuiltinErrors, txn, input, urlPath, m, includeInstrumentation, buf, opts) - if err != nil { - _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) - writer.ErrorAuto(w, err) - return - } - - pq, err := rego.PrepareForEval(ctx) - if err != nil { - _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) - writer.ErrorAuto(w, err) - return - } - preparedQuery = &pq - s.preparedEvalQueries.Insert(pqID, preparedQuery) - } - - evalOpts := []rego.EvalOption{ - rego.EvalTransaction(txn), - rego.EvalParsedInput(input), - rego.EvalMetrics(m), - rego.EvalQueryTracer(buf), - rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), - rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), - rego.EvalInstrument(includeInstrumentation), - rego.EvalNDBuiltinCache(ndbCache), - } - - rs, err := preparedQuery.Eval( - ctx, - evalOpts..., - ) - - m.Timer(metrics.ServerHandler).Stop() - - // Handle results. - if err != nil { - _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) - writer.ErrorAuto(w, err) - return - } - - result := types.DataResponseV1{ - DecisionID: decisionID, - } - - if input == nil { - result.Warning = types.NewWarning(types.CodeAPIUsageWarn, types.MsgInputKeyMissing) - } - - if includeMetrics(r) || includeInstrumentation { - result.Metrics = m.All() - } - - if provenance { - result.Provenance = s.getProvenance(br) - } - - if len(rs) == 0 { - if explainMode == types.ExplainFullV1 { - result.Explanation, err = types.NewTraceV1(lineage.Full(*buf), pretty(r)) - if err != nil { - writer.ErrorAuto(w, err) - return - } - } - err = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, nil, m) - if err != nil { - writer.ErrorAuto(w, err) - return - } - writer.JSONOK(w, result, pretty(r)) - return - } - - result.Result = &rs[0].Expressions[0].Value - - if explainMode != types.ExplainOffV1 { - result.Explanation = s.getExplainResponse(explainMode, *buf, pretty(r)) - } - - if err := logger.Log(ctx, txn, urlPath, "", goInput, input, result.Result, ndbCache, nil, m); err != nil { - writer.ErrorAuto(w, err) - return - } - writer.JSONOK(w, result, pretty(r)) -} - -func (s *Server) v1DataPut(w http.ResponseWriter, r *http.Request) { - m := metrics.New() - m.Timer(metrics.ServerHandler).Start() - defer m.Timer(metrics.ServerHandler).Stop() - - ctx := r.Context() - vars := mux.Vars(r) - - m.Timer(metrics.RegoInputParse).Start() - var value interface{} - if err := util.NewJSONDecoder(r.Body).Decode(&value); err != nil { - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - m.Timer(metrics.RegoInputParse).Stop() - - path, ok := storage.ParsePathEscaped("/" + strings.Trim(vars["path"], "/")) - if !ok { - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "bad path: %v", vars["path"])) - return - } - - params := storage.WriteParams - params.Context = storage.NewContext().WithMetrics(m) - txn, err := s.store.NewTransaction(ctx, params) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - if err := s.checkPathScope(ctx, txn, path); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - _, err = s.store.Read(ctx, txn, path) - if err != nil { - if !storage.IsNotFound(err) { - s.abortAuto(ctx, txn, w, err) - return - } - if len(path) > 0 { - if err := storage.MakeDir(ctx, s.store, txn, path[:len(path)-1]); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - } - } else if r.Header.Get("If-None-Match") == "*" { - s.store.Abort(ctx, txn) - w.WriteHeader(http.StatusNotModified) - return - } - - if err := s.store.Write(ctx, txn, storage.AddOp, path, value); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - if err := ast.CheckPathConflicts(s.getCompiler(), storage.NonEmpty(ctx, s.store, txn)); len(err) > 0 { - s.store.Abort(ctx, txn) - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - - if err := s.store.Commit(ctx, txn); err != nil { - writer.ErrorAuto(w, err) - return - } - - if includeMetrics(r) { - result := types.DataResponseV1{ - Metrics: m.All(), - } - writer.JSONOK(w, result, false) - return - } - - w.WriteHeader(http.StatusNoContent) -} - -func (s *Server) v1DataDelete(w http.ResponseWriter, r *http.Request) { - m := metrics.New() - m.Timer(metrics.ServerHandler).Start() - defer m.Timer(metrics.ServerHandler).Stop() - - ctx := r.Context() - vars := mux.Vars(r) - - path, ok := storage.ParsePathEscaped("/" + strings.Trim(vars["path"], "/")) - if !ok { - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "bad path: %v", vars["path"])) - return - } - - params := storage.WriteParams - params.Context = storage.NewContext().WithMetrics(m) - txn, err := s.store.NewTransaction(ctx, params) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - if err := s.checkPathScope(ctx, txn, path); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - _, err = s.store.Read(ctx, txn, path) - if err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - if err := s.store.Write(ctx, txn, storage.RemoveOp, path, nil); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - if err := s.store.Commit(ctx, txn); err != nil { - writer.ErrorAuto(w, err) - return - } - - if includeMetrics(r) { - result := types.DataResponseV1{ - Metrics: m.All(), - } - writer.JSONOK(w, result, false) - return - } - - w.WriteHeader(http.StatusNoContent) -} - -func (s *Server) v1PoliciesDelete(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - vars := mux.Vars(r) - - id, err := url.PathUnescape(vars["path"]) - if err != nil { - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - - m := metrics.New() - params := storage.WriteParams - params.Context = storage.NewContext().WithMetrics(m) - txn, err := s.store.NewTransaction(ctx, params) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - if err := s.checkPolicyIDScope(ctx, txn, id); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - modules, err := s.loadModules(ctx, txn) - if err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - delete(modules, id) - - c := ast.NewCompiler().SetErrorLimit(s.errLimit) - - m.Timer(metrics.RegoModuleCompile).Start() - - if c.Compile(modules); c.Failed() { - s.abort(ctx, txn, func() { - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidOperation, types.MsgCompileModuleError).WithASTErrors(c.Errors)) - }) - return - } - - m.Timer(metrics.RegoModuleCompile).Stop() - - if err := s.store.DeletePolicy(ctx, txn, id); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - if err := s.store.Commit(ctx, txn); err != nil { - writer.ErrorAuto(w, err) - return - } - - resp := types.PolicyDeleteResponseV1{} - if includeMetrics(r) { - resp.Metrics = m.All() - } - - writer.JSONOK(w, resp, pretty(r)) -} - -func (s *Server) v1PoliciesGet(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - vars := mux.Vars(r) - - path, err := url.PathUnescape(vars["path"]) - if err != nil { - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - - txn, err := s.store.NewTransaction(ctx) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - defer s.store.Abort(ctx, txn) - - bs, err := s.store.GetPolicy(ctx, txn, path) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - c := s.getCompiler() - - resp := types.PolicyGetResponseV1{ - Result: types.PolicyV1{ - ID: path, - Raw: string(bs), - AST: c.Modules[path], - }, - } - - writer.JSONOK(w, resp, pretty(r)) -} - -func (s *Server) v1PoliciesList(w http.ResponseWriter, r *http.Request) { - - ctx := r.Context() - - txn, err := s.store.NewTransaction(ctx) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - defer s.store.Abort(ctx, txn) - - policies := []types.PolicyV1{} - c := s.getCompiler() - - // Only return policies from the store, the compiler - // may contain additional partially compiled modules. - ids, err := s.store.ListPolicies(ctx, txn) - if err != nil { - writer.ErrorAuto(w, err) - return - } - for _, id := range ids { - bs, err := s.store.GetPolicy(ctx, txn, id) - if err != nil { - writer.ErrorAuto(w, err) - return - } - policy := types.PolicyV1{ - ID: id, - Raw: string(bs), - AST: c.Modules[id], - } - policies = append(policies, policy) - } - - writer.JSONOK(w, types.PolicyListResponseV1{Result: policies}, pretty(r)) -} - -func (s *Server) v1PoliciesPut(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - vars := mux.Vars(r) - - id, err := url.PathUnescape(vars["path"]) - if err != nil { - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - - includeMetrics := includeMetrics(r) - m := metrics.New() - - m.Timer("server_read_bytes").Start() - - buf, err := io.ReadAll(r.Body) - if err != nil { - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - return - } - - m.Timer("server_read_bytes").Stop() - - params := storage.WriteParams - params.Context = storage.NewContext().WithMetrics(m) - txn, err := s.store.NewTransaction(ctx, params) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - if err := s.checkPolicyIDScope(ctx, txn, id); err != nil && !storage.IsNotFound(err) { - s.abortAuto(ctx, txn, w, err) - return - } - - if bs, err := s.store.GetPolicy(ctx, txn, id); err != nil { - if !storage.IsNotFound(err) { - s.abortAuto(ctx, txn, w, err) - return - } - } else if bytes.Equal(buf, bs) { - s.store.Abort(ctx, txn) - resp := types.PolicyPutResponseV1{} - if includeMetrics { - resp.Metrics = m.All() - } - writer.JSONOK(w, resp, pretty(r)) - return - } - - m.Timer(metrics.RegoModuleParse).Start() - parsedMod, err := ast.ParseModule(id, string(buf)) - m.Timer(metrics.RegoModuleParse).Stop() - - if err != nil { - s.store.Abort(ctx, txn) - switch err := err.(type) { - case ast.Errors: - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileModuleError).WithASTErrors(err)) - default: - writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) - } - return - } - - if parsedMod == nil { - s.store.Abort(ctx, txn) - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "empty module")) - return - } - - if err := s.checkPolicyPackageScope(ctx, txn, parsedMod.Package); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - modules, err := s.loadModules(ctx, txn) - if err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - modules[id] = parsedMod - - c := ast.NewCompiler(). - SetErrorLimit(s.errLimit). - WithPathConflictsCheck(storage.NonEmpty(ctx, s.store, txn)). - WithEnablePrintStatements(s.manager.EnablePrintStatements()) - - m.Timer(metrics.RegoModuleCompile).Start() - - if c.Compile(modules); c.Failed() { - s.abort(ctx, txn, func() { - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileModuleError).WithASTErrors(c.Errors)) - }) - return - } - - m.Timer(metrics.RegoModuleCompile).Stop() - - if err := s.store.UpsertPolicy(ctx, txn, id, buf); err != nil { - s.abortAuto(ctx, txn, w, err) - return - } - - if err := s.store.Commit(ctx, txn); err != nil { - writer.ErrorAuto(w, err) - return - } - - resp := types.PolicyPutResponseV1{} - - if includeMetrics { - resp.Metrics = m.All() - } - - writer.JSONOK(w, resp, pretty(r)) -} - -func (s *Server) v1QueryGet(w http.ResponseWriter, r *http.Request) { - m := metrics.New() - - decisionID := s.generateDecisionID() - ctx := logging.WithDecisionID(r.Context(), decisionID) - annotateSpan(ctx, decisionID) - - values := r.URL.Query() - - qStrs := values[types.ParamQueryV1] - if len(qStrs) == 0 { - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "missing parameter 'q'")) - return - } - qStr := qStrs[len(qStrs)-1] - - parsedQuery, err := validateQuery(qStr) - if err != nil { - switch err := err.(type) { - case ast.Errors: - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgParseQueryError).WithASTErrors(err)) - default: - writer.ErrorAuto(w, err) - } - return - } - - explainMode := getExplain(r.URL.Query()["explain"], types.ExplainOffV1) - includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) - - params := storage.TransactionParams{Context: storage.NewContext().WithMetrics(m)} - txn, err := s.store.NewTransaction(ctx, params) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - defer s.store.Abort(ctx, txn) - - br, err := getRevisions(ctx, s.store, txn) - if err != nil { - writer.ErrorAuto(w, err) - return - } - pretty := pretty(r) - results, err := s.execQuery(ctx, br, txn, parsedQuery, nil, nil, m, explainMode, includeMetrics(r), includeInstrumentation, pretty) - if err != nil { - switch err := err.(type) { - case ast.Errors: - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileQueryError).WithASTErrors(err)) - default: - writer.ErrorAuto(w, err) - } - return - } - - writer.JSONOK(w, results, pretty) -} - -func (s *Server) v1QueryPost(w http.ResponseWriter, r *http.Request) { - m := metrics.New() - m.Timer(metrics.ServerHandler).Start() - - decisionID := s.generateDecisionID() - ctx := logging.WithDecisionID(r.Context(), decisionID) - annotateSpan(ctx, decisionID) - - var request types.QueryRequestV1 - err := util.NewJSONDecoder(r.Body).Decode(&request) - if err != nil { - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "error(s) occurred while decoding request: %v", err.Error())) - return - } - qStr := request.Query - parsedQuery, err := validateQuery(qStr) - if err != nil { - switch err := err.(type) { - case ast.Errors: - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgParseQueryError).WithASTErrors(err)) - default: - writer.ErrorAuto(w, err) - } - return - } - - pretty := pretty(r) - explainMode := getExplain(r.URL.Query()["explain"], types.ExplainOffV1) - includeMetrics := includeMetrics(r) - includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) - - var input ast.Value - - if request.Input != nil { - input, err = ast.InterfaceToValue(*request.Input) - if err != nil { - writer.ErrorAuto(w, err) - return - } - } - - params := storage.TransactionParams{Context: storage.NewContext().WithMetrics(m)} - txn, err := s.store.NewTransaction(ctx, params) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - defer s.store.Abort(ctx, txn) - - br, err := getRevisions(ctx, s.store, txn) - if err != nil { - writer.ErrorAuto(w, err) - return - } - - results, err := s.execQuery(ctx, br, txn, parsedQuery, input, request.Input, m, explainMode, includeMetrics, includeInstrumentation, pretty) - if err != nil { - switch err := err.(type) { - case ast.Errors: - writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileQueryError).WithASTErrors(err)) - default: - writer.ErrorAuto(w, err) - } - return - } - - m.Timer(metrics.ServerHandler).Stop() - - if includeMetrics || includeInstrumentation { - results.Metrics = m.All() - } - - writer.JSONOK(w, results, pretty) -} - -func (s *Server) v1ConfigGet(w http.ResponseWriter, r *http.Request) { - result, err := s.manager.Config.ActiveConfig() - if err != nil { - writer.ErrorAuto(w, err) - return - } - writer.JSONOK(w, types.ConfigResponseV1{Result: &result}, pretty(r)) -} - -func (s *Server) v1StatusGet(w http.ResponseWriter, r *http.Request) { - p := status.Lookup(s.manager) - if p == nil { - writer.ErrorString(w, http.StatusInternalServerError, types.CodeInternal, errors.New("status plugin not enabled")) - return - } - - var st interface{} = p.Snapshot() - writer.JSONOK(w, types.StatusResponseV1{Result: &st}, pretty(r)) -} - -func (s *Server) checkPolicyIDScope(ctx context.Context, txn storage.Transaction, id string) error { - - bs, err := s.store.GetPolicy(ctx, txn, id) - if err != nil { - return err - } - - module, err := ast.ParseModule(id, string(bs)) - if err != nil { - return err - } - - return s.checkPolicyPackageScope(ctx, txn, module.Package) -} - -func (s *Server) checkPolicyPackageScope(ctx context.Context, txn storage.Transaction, pkg *ast.Package) error { - - path, err := pkg.Path.Ptr() - if err != nil { - return err - } - - spath, ok := storage.ParsePathEscaped("/" + path) - if !ok { - return types.BadRequestErr("invalid package path: cannot determine scope") - } - - return s.checkPathScope(ctx, txn, spath) -} - -func (s *Server) checkPathScope(ctx context.Context, txn storage.Transaction, path storage.Path) error { - - names, err := bundle.ReadBundleNamesFromStore(ctx, s.store, txn) - if err != nil { - if !storage.IsNotFound(err) { - return err - } - return nil - } - - bundleRoots := map[string][]string{} - for _, name := range names { - roots, err := bundle.ReadBundleRootsFromStore(ctx, s.store, txn, name) - if err != nil && !storage.IsNotFound(err) { - return err - } - bundleRoots[name] = roots - } - - spath := strings.Trim(path.String(), "/") - - if spath == "" && len(bundleRoots) > 0 { - return types.BadRequestErr("can't write to document root with bundle roots configured") - } - - spathParts := strings.Split(spath, "/") - - for name, roots := range bundleRoots { - if roots == nil { - return types.BadRequestErr(fmt.Sprintf("all paths owned by bundle %q", name)) - } - for _, root := range roots { - if root == "" { - return types.BadRequestErr(fmt.Sprintf("all paths owned by bundle %q", name)) - } - if isPathOwned(spathParts, strings.Split(root, "/")) { - return types.BadRequestErr(fmt.Sprintf("path %v is owned by bundle %q", spath, name)) - } - } - } - - return nil -} - -func (s *Server) getDecisionLogger(br bundleRevisions) (logger decisionLogger) { - // For backwards compatibility use `revision` as needed. - if s.hasLegacyBundle(br) { - logger.revision = br.LegacyRevision - } else { - logger.revisions = br.Revisions - } - logger.logger = s.logger - return logger -} - -func (s *Server) getExplainResponse(explainMode types.ExplainModeV1, trace []*topdown.Event, pretty bool) (explanation types.TraceV1) { - switch explainMode { - case types.ExplainNotesV1: - var err error - explanation, err = types.NewTraceV1(lineage.Notes(trace), pretty) - if err != nil { - break - } - case types.ExplainFailsV1: - var err error - explanation, err = types.NewTraceV1(lineage.Fails(trace), pretty) - if err != nil { - break - } - case types.ExplainFullV1: - var err error - explanation, err = types.NewTraceV1(lineage.Full(trace), pretty) - if err != nil { - break - } - case types.ExplainDebugV1: - var err error - explanation, err = types.NewTraceV1(lineage.Debug(trace), pretty) - if err != nil { - break - } - } - return explanation -} - -func (s *Server) abort(ctx context.Context, txn storage.Transaction, finish func()) { - s.store.Abort(ctx, txn) - finish() -} - -func (s *Server) abortAuto(ctx context.Context, txn storage.Transaction, w http.ResponseWriter, err error) { - s.abort(ctx, txn, func() { writer.ErrorAuto(w, err) }) -} - -func (s *Server) loadModules(ctx context.Context, txn storage.Transaction) (map[string]*ast.Module, error) { - - ids, err := s.store.ListPolicies(ctx, txn) - if err != nil { - return nil, err - } - - modules := make(map[string]*ast.Module, len(ids)) - - for _, id := range ids { - bs, err := s.store.GetPolicy(ctx, txn, id) - if err != nil { - return nil, err - } - - parsed, err := ast.ParseModule(id, string(bs)) - if err != nil { - return nil, err - } - - modules[id] = parsed - } - - return modules, nil -} - -func (s *Server) getCompiler() *ast.Compiler { - return s.manager.GetCompiler() -} - -func (s *Server) makeRego(_ context.Context, - strictBuiltinErrors bool, - txn storage.Transaction, - input ast.Value, - urlPath string, - m metrics.Metrics, - instrument bool, - tracer topdown.QueryTracer, - opts []func(*rego.Rego), -) (*rego.Rego, error) { - queryPath := stringPathToDataRef(urlPath).String() - - opts = append( - opts, - rego.Transaction(txn), - rego.Query(queryPath), - rego.ParsedInput(input), - rego.Metrics(m), - rego.QueryTracer(tracer), - rego.Instrument(instrument), - rego.Runtime(s.runtime), - rego.UnsafeBuiltins(unsafeBuiltinsMap), - rego.StrictBuiltinErrors(strictBuiltinErrors), - rego.PrintHook(s.manager.PrintHook()), - rego.DistributedTracingOpts(s.distributedTracingOpts), - ) - - return rego.New(opts...), nil -} - -func (s *Server) prepareV1PatchSlice(root string, ops []types.PatchV1) (result []patchImpl, err error) { - - root = "/" + strings.Trim(root, "/") - - for _, op := range ops { - - impl := patchImpl{ - value: op.Value, - } - - // Map patch operation. - switch op.Op { - case "add": - impl.op = storage.AddOp - case "remove": - impl.op = storage.RemoveOp - case "replace": - impl.op = storage.ReplaceOp - default: - return nil, types.BadPatchOperationErr(op.Op) - } - - // Construct patch path. - path := strings.Trim(op.Path, "/") - if len(path) > 0 { - if root == "/" { - path = root + path - } else { - path = root + "/" + path - } - } else { - path = root - } - - var ok bool - impl.path, ok = patch.ParsePatchPathEscaped(path) - if !ok { - return nil, types.BadPatchPathErr(op.Path) - } - - result = append(result, impl) - } - - return result, nil -} - -func (s *Server) generateDecisionID() string { - if s.decisionIDFactory != nil { - return s.decisionIDFactory() - } - return "" -} - -func (s *Server) getProvenance(br bundleRevisions) *types.ProvenanceV1 { - - p := &types.ProvenanceV1{ - Version: version.Version, - Vcs: version.Vcs, - Timestamp: version.Timestamp, - Hostname: version.Hostname, - } - - // For backwards compatibility, if the bundles are using the old - // style config we need to fill in the older `Revision` field. - // Otherwise use the newer `Bundles` keyword. - if s.hasLegacyBundle(br) { - p.Revision = br.LegacyRevision - } else { - p.Bundles = map[string]types.ProvenanceBundleV1{} - for name, revision := range br.Revisions { - p.Bundles[name] = types.ProvenanceBundleV1{Revision: revision} - } - } - - return p -} - -func (s *Server) hasLegacyBundle(br bundleRevisions) bool { - bp := bundlePlugin.Lookup(s.manager) - return br.LegacyRevision != "" || (bp != nil && !bp.Config().IsMultiBundle()) -} - -func (s *Server) generateDefaultDecisionPath() string { - // Assume the path is safe to transition back to a url - p, _ := s.manager.Config.DefaultDecisionRef().Ptr() - return p -} - -func isPathOwned(path, root []string) bool { - for i := 0; i < len(path) && i < len(root); i++ { - if path[i] != root[i] { - return false - } - } - return true -} - -func (s *Server) updateCacheConfig(cacheConfig *iCache.Config) { - s.interQueryBuiltinCache.UpdateConfig(cacheConfig) - s.interQueryBuiltinValueCache.UpdateConfig(cacheConfig) -} - -func (s *Server) updateNDCache(enabled bool) { - s.mtx.Lock() - defer s.mtx.Unlock() - s.ndbCacheEnabled = enabled -} - -func stringPathToDataRef(s string) (r ast.Ref) { - result := ast.Ref{ast.DefaultRootDocument} - return append(result, stringPathToRef(s)...) -} - -func stringPathToRef(s string) (r ast.Ref) { - if len(s) == 0 { - return r - } - p := strings.Split(s, "/") - for _, x := range p { - if x == "" { - continue - } - if y, err := url.PathUnescape(x); err == nil { - x = y - } - i, err := strconv.Atoi(x) - if err != nil { - r = append(r, ast.StringTerm(x)) - } else { - r = append(r, ast.IntNumberTerm(i)) - } - } - return r -} - -func validateQuery(query string) (ast.Body, error) { - return ast.ParseBody(query) -} - -func getBoolParam(url *url.URL, name string, ifEmpty bool) bool { - - p, ok := url.Query()[name] - if !ok { - return false - } - - // Query params w/o values are represented as slice (of len 1) with an - // empty string. - if len(p) == 1 && p[0] == "" { - return ifEmpty - } - - for _, x := range p { - if strings.ToLower(x) == "true" { - return true - } - } - - return false -} - -func getStringSliceParam(url *url.URL, name string) []string { - - p, ok := url.Query()[name] - if !ok { - return nil - } - - // Query params w/o values are represented as slice (of len 1) with an - // empty string. - if len(p) == 1 && p[0] == "" { - return nil - } - - return p -} - -func getExplain(p []string, zero types.ExplainModeV1) types.ExplainModeV1 { - for _, x := range p { - switch x { - case string(types.ExplainNotesV1): - return types.ExplainNotesV1 - case string(types.ExplainFailsV1): - return types.ExplainFailsV1 - case string(types.ExplainFullV1): - return types.ExplainFullV1 - case string(types.ExplainDebugV1): - return types.ExplainDebugV1 - } - } - return zero -} - -func readInputV0(r *http.Request) (ast.Value, *interface{}, error) { - - parsed, ok := authorizer.GetBodyOnContext(r.Context()) - if ok { - v, err := ast.InterfaceToValue(parsed) - return v, &parsed, err - } - - // decompress the input if sent as zip - bodyBytes, err := util.ReadMaybeCompressedBody(r) - if err != nil { - return nil, nil, fmt.Errorf("could not decompress the body: %w", err) - } - - var x interface{} - - if strings.Contains(r.Header.Get("Content-Type"), "yaml") { - if len(bodyBytes) > 0 { - if err = util.Unmarshal(bodyBytes, &x); err != nil { - return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) - } - } - } else { - dec := util.NewJSONDecoder(bytes.NewBuffer(bodyBytes)) - if err := dec.Decode(&x); err != nil && err != io.EOF { - return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) - } - } - - v, err := ast.InterfaceToValue(x) - return v, &x, err -} - -func readInputGetV1(str string) (ast.Value, *interface{}, error) { - var input interface{} - if err := util.UnmarshalJSON([]byte(str), &input); err != nil { - return nil, nil, fmt.Errorf("parameter contains malformed input document: %w", err) - } - v, err := ast.InterfaceToValue(input) - return v, &input, err -} - -func readInputPostV1(r *http.Request) (ast.Value, *interface{}, error) { - - parsed, ok := authorizer.GetBodyOnContext(r.Context()) - if ok { - if obj, ok := parsed.(map[string]interface{}); ok { - if input, ok := obj["input"]; ok { - v, err := ast.InterfaceToValue(input) - return v, &input, err - } - } - return nil, nil, nil - } - - var request types.DataRequestV1 - - // decompress the input if sent as zip - bodyBytes, err := util.ReadMaybeCompressedBody(r) - if err != nil { - return nil, nil, fmt.Errorf("could not decompress the body: %w", err) - } - - ct := r.Header.Get("Content-Type") - // There is no standard for yaml mime-type so we just look for - // anything related - if strings.Contains(ct, "yaml") { - if len(bodyBytes) > 0 { - if err = util.Unmarshal(bodyBytes, &request); err != nil { - return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) - } - } - } else { - dec := util.NewJSONDecoder(bytes.NewBuffer(bodyBytes)) - if err := dec.Decode(&request); err != nil && err != io.EOF { - return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) - } - } - - if request.Input == nil { - return nil, nil, nil - } - - v, err := ast.InterfaceToValue(*request.Input) - return v, request.Input, err -} - -type compileRequest struct { - Query ast.Body - Input ast.Value - Unknowns []*ast.Term - Options compileRequestOptions -} - -type compileRequestOptions struct { - DisableInlining []string -} - -func readInputCompilePostV1(reqBytes []byte) (*compileRequest, *types.ErrorV1) { - var request types.CompileRequestV1 - - err := util.NewJSONDecoder(bytes.NewBuffer(reqBytes)).Decode(&request) - if err != nil { - return nil, types.NewErrorV1(types.CodeInvalidParameter, "error(s) occurred while decoding request: %v", err.Error()) - } - - query, err := ast.ParseBody(request.Query) - if err != nil { - switch err := err.(type) { - case ast.Errors: - return nil, types.NewErrorV1(types.CodeInvalidParameter, types.MsgParseQueryError).WithASTErrors(err) - default: - return nil, types.NewErrorV1(types.CodeInvalidParameter, "%v: %v", types.MsgParseQueryError, err) - } - } else if len(query) == 0 { - return nil, types.NewErrorV1(types.CodeInvalidParameter, "missing required 'query' value") - } - - var input ast.Value - if request.Input != nil { - input, err = ast.InterfaceToValue(*request.Input) - if err != nil { - return nil, types.NewErrorV1(types.CodeInvalidParameter, "error(s) occurred while converting input: %v", err) - } - } - - var unknowns []*ast.Term - if request.Unknowns != nil { - unknowns = make([]*ast.Term, len(*request.Unknowns)) - for i, s := range *request.Unknowns { - unknowns[i], err = ast.ParseTerm(s) - if err != nil { - return nil, types.NewErrorV1(types.CodeInvalidParameter, "error(s) occurred while parsing unknowns: %v", err) - } - } - } - - result := &compileRequest{ - Query: query, - Input: input, - Unknowns: unknowns, - Options: compileRequestOptions{ - DisableInlining: request.Options.DisableInlining, - }, - } - - return result, nil -} - -var indexHTML, _ = template.New("index").Parse(` - - - - - -
- ________      ________    ________
-|\   __  \    |\   __  \  |\   __  \
-\ \  \|\  \   \ \  \|\  \ \ \  \|\  \
- \ \  \\\  \   \ \   ____\ \ \   __  \
-  \ \  \\\  \   \ \  \___|  \ \  \ \  \
-   \ \_______\   \ \__\      \ \__\ \__\
-    \|_______|    \|__|       \|__|\|__|
-
-Open Policy Agent - An open source project to policy-enable your service.
-
-Version: {{ .Version }}
-Build Commit: {{ .BuildCommit }}
-Build Timestamp: {{ .BuildTimestamp }}
-Build Hostname: {{ .BuildHostname }}
-
-Query:
-
-
Input Data (JSON):
-
-
-
- - -`) - -type decisionLogger struct { - revisions map[string]string - revision string // Deprecated: Use `revisions` instead. - logger func(context.Context, *Info) error -} - -func (l decisionLogger) Log(ctx context.Context, txn storage.Transaction, path string, query string, goInput *interface{}, astInput ast.Value, goResults *interface{}, ndbCache builtins.NDBCache, err error, m metrics.Metrics) error { - - bundles := map[string]BundleInfo{} - for name, rev := range l.revisions { - bundles[name] = BundleInfo{Revision: rev} - } - - rctx := logging.RequestContext{} - if r, ok := logging.FromContext(ctx); ok { - rctx = *r - } - decisionID, _ := logging.DecisionIDFromContext(ctx) - - var httpRctx logging.HTTPRequestContext - - httpRctxVal, _ := logging.HTTPRequestContextFromContext(ctx) - if httpRctxVal != nil { - httpRctx = *httpRctxVal - } - - info := &Info{ - Txn: txn, - Revision: l.revision, - Bundles: bundles, - Timestamp: time.Now().UTC(), - DecisionID: decisionID, - RemoteAddr: rctx.ClientAddr, - HTTPRequestContext: httpRctx, - Path: path, - Query: query, - Input: goInput, - InputAST: astInput, - Results: goResults, - Error: err, - Metrics: m, - RequestID: rctx.ReqID, - } - - if ndbCache != nil { - x, err := ast.JSON(ndbCache.AsValue()) - if err != nil { - return err - } - info.NDBuiltinCache = &x - } - - sctx := trace.SpanFromContext(ctx).SpanContext() - if sctx.IsValid() { - info.TraceID = sctx.TraceID().String() - info.SpanID = sctx.SpanID().String() - } - - if l.logger != nil { - if err := l.logger(ctx, info); err != nil { - return fmt.Errorf("decision_logs: %w", err) - } - } - - return nil -} - -type patchImpl struct { - path storage.Path - op storage.PatchOp - value interface{} -} - -func parseURL(s string, useHTTPSByDefault bool) (*url.URL, error) { - if !strings.Contains(s, "://") { - scheme := "http://" - if useHTTPSByDefault { - scheme = "https://" - } - s = scheme + s - } - return url.Parse(s) -} - -func annotateSpan(ctx context.Context, decisionID string) { - if decisionID == "" { - return - } - trace.SpanFromContext(ctx). - SetAttributes(attribute.String(otelDecisionIDAttr, decisionID)) -} - -func pretty(r *http.Request) bool { - return getBoolParam(r.URL, types.ParamPrettyV1, true) -} - -func includeMetrics(r *http.Request) bool { - return getBoolParam(r.URL, types.ParamMetricsV1, true) + return v1.New() } diff --git a/vendor/github.com/open-policy-agent/opa/storage/doc.go b/vendor/github.com/open-policy-agent/opa/storage/doc.go index 6fa2f86d9..c33db689e 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/doc.go +++ b/vendor/github.com/open-policy-agent/opa/storage/doc.go @@ -3,4 +3,8 @@ // license that can be found in the LICENSE file. // Package storage exposes the policy engine's storage layer. +// +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. package storage diff --git a/vendor/github.com/open-policy-agent/opa/storage/errors.go b/vendor/github.com/open-policy-agent/opa/storage/errors.go index 8c789052e..1403b3a98 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/errors.go +++ b/vendor/github.com/open-policy-agent/opa/storage/errors.go @@ -5,118 +5,69 @@ package storage import ( - "fmt" + v1 "github.com/open-policy-agent/opa/v1/storage" ) const ( // InternalErr indicates an unknown, internal error has occurred. - InternalErr = "storage_internal_error" + InternalErr = v1.InternalErr // NotFoundErr indicates the path used in the storage operation does not // locate a document. - NotFoundErr = "storage_not_found_error" + NotFoundErr = v1.NotFoundErr // WriteConflictErr indicates a write on the path enocuntered a conflicting // value inside the transaction. - WriteConflictErr = "storage_write_conflict_error" + WriteConflictErr = v1.WriteConflictErr // InvalidPatchErr indicates an invalid patch/write was issued. The patch // was rejected. - InvalidPatchErr = "storage_invalid_patch_error" + InvalidPatchErr = v1.InvalidPatchErr // InvalidTransactionErr indicates an invalid operation was performed // inside of the transaction. - InvalidTransactionErr = "storage_invalid_txn_error" + InvalidTransactionErr = v1.InvalidTransactionErr // TriggersNotSupportedErr indicates the caller attempted to register a // trigger against a store that does not support them. - TriggersNotSupportedErr = "storage_triggers_not_supported_error" + TriggersNotSupportedErr = v1.TriggersNotSupportedErr // WritesNotSupportedErr indicate the caller attempted to perform a write // against a store that does not support them. - WritesNotSupportedErr = "storage_writes_not_supported_error" + WritesNotSupportedErr = v1.WritesNotSupportedErr // PolicyNotSupportedErr indicate the caller attempted to perform a policy // management operation against a store that does not support them. - PolicyNotSupportedErr = "storage_policy_not_supported_error" + PolicyNotSupportedErr = v1.PolicyNotSupportedErr ) // Error is the error type returned by the storage layer. -type Error struct { - Code string `json:"code"` - Message string `json:"message"` -} - -func (err *Error) Error() string { - if err.Message != "" { - return fmt.Sprintf("%v: %v", err.Code, err.Message) - } - return err.Code -} +type Error = v1.Error // IsNotFound returns true if this error is a NotFoundErr. func IsNotFound(err error) bool { - switch err := err.(type) { - case *Error: - return err.Code == NotFoundErr - } - return false + return v1.IsNotFound(err) } // IsWriteConflictError returns true if this error a WriteConflictErr. func IsWriteConflictError(err error) bool { - switch err := err.(type) { - case *Error: - return err.Code == WriteConflictErr - } - return false + return v1.IsWriteConflictError(err) } // IsInvalidPatch returns true if this error is a InvalidPatchErr. func IsInvalidPatch(err error) bool { - switch err := err.(type) { - case *Error: - return err.Code == InvalidPatchErr - } - return false + return v1.IsInvalidPatch(err) } // IsInvalidTransaction returns true if this error is a InvalidTransactionErr. func IsInvalidTransaction(err error) bool { - switch err := err.(type) { - case *Error: - return err.Code == InvalidTransactionErr - } - return false + return v1.IsInvalidTransaction(err) } // IsIndexingNotSupported is a stub for backwards-compatibility. // // Deprecated: We no longer return IndexingNotSupported errors, so it is // unnecessary to check for them. -func IsIndexingNotSupported(error) bool { return false } - -func writeConflictError(path Path) *Error { - return &Error{ - Code: WriteConflictErr, - Message: path.String(), - } -} - -func triggersNotSupportedError() *Error { - return &Error{ - Code: TriggersNotSupportedErr, - } -} - -func writesNotSupportedError() *Error { - return &Error{ - Code: WritesNotSupportedErr, - } -} - -func policyNotSupportedError() *Error { - return &Error{ - Code: PolicyNotSupportedErr, - } +func IsIndexingNotSupported(err error) bool { + return v1.IsIndexingNotSupported(err) } diff --git a/vendor/github.com/open-policy-agent/opa/storage/inmem/doc.go b/vendor/github.com/open-policy-agent/opa/storage/inmem/doc.go new file mode 100644 index 000000000..5f536b66d --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/storage/inmem/doc.go @@ -0,0 +1,8 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package inmem diff --git a/vendor/github.com/open-policy-agent/opa/storage/inmem/inmem.go b/vendor/github.com/open-policy-agent/opa/storage/inmem/inmem.go index 9f5b8ba25..0a41b9d0d 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/inmem/inmem.go +++ b/vendor/github.com/open-policy-agent/opa/storage/inmem/inmem.go @@ -16,443 +16,41 @@ package inmem import ( - "context" - "fmt" "io" - "path/filepath" - "strings" - "sync" - "sync/atomic" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/internal/merge" "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/storage/inmem" ) // New returns an empty in-memory store. func New() storage.Store { - return NewWithOpts() + return v1.New() } // NewWithOpts returns an empty in-memory store, with extra options passed. func NewWithOpts(opts ...Opt) storage.Store { - s := &store{ - triggers: map[*handle]storage.TriggerConfig{}, - policies: map[string][]byte{}, - roundTripOnWrite: true, - returnASTValuesOnRead: false, - } - - for _, opt := range opts { - opt(s) - } - - if s.returnASTValuesOnRead { - s.data = ast.NewObject() - } else { - s.data = map[string]interface{}{} - } - - return s + return v1.NewWithOpts(opts...) } // NewFromObject returns a new in-memory store from the supplied data object. func NewFromObject(data map[string]interface{}) storage.Store { - return NewFromObjectWithOpts(data) + return v1.NewFromObject(data) } // NewFromObjectWithOpts returns a new in-memory store from the supplied data object, with the // options passed. func NewFromObjectWithOpts(data map[string]interface{}, opts ...Opt) storage.Store { - db := NewWithOpts(opts...) - ctx := context.Background() - txn, err := db.NewTransaction(ctx, storage.WriteParams) - if err != nil { - panic(err) - } - if err := db.Write(ctx, txn, storage.AddOp, storage.Path{}, data); err != nil { - panic(err) - } - if err := db.Commit(ctx, txn); err != nil { - panic(err) - } - return db + return v1.NewFromObjectWithOpts(data, opts...) } // NewFromReader returns a new in-memory store from a reader that produces a // JSON serialized object. This function is for test purposes. func NewFromReader(r io.Reader) storage.Store { - return NewFromReaderWithOpts(r) + return v1.NewFromReader(r) } // NewFromReader returns a new in-memory store from a reader that produces a // JSON serialized object, with extra options. This function is for test purposes. func NewFromReaderWithOpts(r io.Reader, opts ...Opt) storage.Store { - d := util.NewJSONDecoder(r) - var data map[string]interface{} - if err := d.Decode(&data); err != nil { - panic(err) - } - return NewFromObjectWithOpts(data, opts...) -} - -type store struct { - rmu sync.RWMutex // reader-writer lock - wmu sync.Mutex // writer lock - xid uint64 // last generated transaction id - data interface{} // raw or AST data - policies map[string][]byte // raw policies - triggers map[*handle]storage.TriggerConfig // registered triggers - - // roundTripOnWrite, if true, means that every call to Write round trips the - // data through JSON before adding the data to the store. Defaults to true. - roundTripOnWrite bool - - // returnASTValuesOnRead, if true, means that the store will eagerly convert data to AST values, - // and return them on Read. - // FIXME: naming(?) - returnASTValuesOnRead bool -} - -type handle struct { - db *store -} - -func (db *store) NewTransaction(_ context.Context, params ...storage.TransactionParams) (storage.Transaction, error) { - var write bool - var ctx *storage.Context - if len(params) > 0 { - write = params[0].Write - ctx = params[0].Context - } - xid := atomic.AddUint64(&db.xid, uint64(1)) - if write { - db.wmu.Lock() - } else { - db.rmu.RLock() - } - return newTransaction(xid, write, ctx, db), nil -} - -// Truncate implements the storage.Store interface. This method must be called within a transaction. -func (db *store) Truncate(ctx context.Context, txn storage.Transaction, params storage.TransactionParams, it storage.Iterator) error { - var update *storage.Update - var err error - mergedData := map[string]interface{}{} - - underlying, err := db.underlying(txn) - if err != nil { - return err - } - - for { - update, err = it.Next() - if err != nil { - break - } - - if update.IsPolicy { - err = underlying.UpsertPolicy(strings.TrimLeft(update.Path.String(), "/"), update.Value) - if err != nil { - return err - } - } else { - var value interface{} - err = util.Unmarshal(update.Value, &value) - if err != nil { - return err - } - - var key []string - dirpath := strings.TrimLeft(update.Path.String(), "/") - if len(dirpath) > 0 { - key = strings.Split(dirpath, "/") - } - - if value != nil { - obj, err := mktree(key, value) - if err != nil { - return err - } - - merged, ok := merge.InterfaceMaps(mergedData, obj) - if !ok { - return fmt.Errorf("failed to insert data file from path %s", filepath.Join(key...)) - } - mergedData = merged - } - } - } - - if err != nil && err != io.EOF { - return err - } - - // For backwards compatibility, check if `RootOverwrite` was configured. - if params.RootOverwrite { - newPath, ok := storage.ParsePathEscaped("/") - if !ok { - return fmt.Errorf("storage path invalid: %v", newPath) - } - return underlying.Write(storage.AddOp, newPath, mergedData) - } - - for _, root := range params.BasePaths { - newPath, ok := storage.ParsePathEscaped("/" + root) - if !ok { - return fmt.Errorf("storage path invalid: %v", newPath) - } - - if value, ok := lookup(newPath, mergedData); ok { - if len(newPath) > 0 { - if err := storage.MakeDir(ctx, db, txn, newPath[:len(newPath)-1]); err != nil { - return err - } - } - if err := underlying.Write(storage.AddOp, newPath, value); err != nil { - return err - } - } - } - return nil -} - -func (db *store) Commit(ctx context.Context, txn storage.Transaction) error { - underlying, err := db.underlying(txn) - if err != nil { - return err - } - if underlying.write { - db.rmu.Lock() - event := underlying.Commit() - db.runOnCommitTriggers(ctx, txn, event) - // Mark the transaction stale after executing triggers, so they can - // perform store operations if needed. - underlying.stale = true - db.rmu.Unlock() - db.wmu.Unlock() - } else { - db.rmu.RUnlock() - } - return nil -} - -func (db *store) Abort(_ context.Context, txn storage.Transaction) { - underlying, err := db.underlying(txn) - if err != nil { - panic(err) - } - underlying.stale = true - if underlying.write { - db.wmu.Unlock() - } else { - db.rmu.RUnlock() - } -} - -func (db *store) ListPolicies(_ context.Context, txn storage.Transaction) ([]string, error) { - underlying, err := db.underlying(txn) - if err != nil { - return nil, err - } - return underlying.ListPolicies(), nil -} - -func (db *store) GetPolicy(_ context.Context, txn storage.Transaction, id string) ([]byte, error) { - underlying, err := db.underlying(txn) - if err != nil { - return nil, err - } - return underlying.GetPolicy(id) -} - -func (db *store) UpsertPolicy(_ context.Context, txn storage.Transaction, id string, bs []byte) error { - underlying, err := db.underlying(txn) - if err != nil { - return err - } - return underlying.UpsertPolicy(id, bs) -} - -func (db *store) DeletePolicy(_ context.Context, txn storage.Transaction, id string) error { - underlying, err := db.underlying(txn) - if err != nil { - return err - } - if _, err := underlying.GetPolicy(id); err != nil { - return err - } - return underlying.DeletePolicy(id) -} - -func (db *store) Register(_ context.Context, txn storage.Transaction, config storage.TriggerConfig) (storage.TriggerHandle, error) { - underlying, err := db.underlying(txn) - if err != nil { - return nil, err - } - if !underlying.write { - return nil, &storage.Error{ - Code: storage.InvalidTransactionErr, - Message: "triggers must be registered with a write transaction", - } - } - h := &handle{db} - db.triggers[h] = config - return h, nil -} - -func (db *store) Read(_ context.Context, txn storage.Transaction, path storage.Path) (interface{}, error) { - underlying, err := db.underlying(txn) - if err != nil { - return nil, err - } - - v, err := underlying.Read(path) - if err != nil { - return nil, err - } - - return v, nil -} - -func (db *store) Write(_ context.Context, txn storage.Transaction, op storage.PatchOp, path storage.Path, value interface{}) error { - underlying, err := db.underlying(txn) - if err != nil { - return err - } - val := util.Reference(value) - if db.roundTripOnWrite { - if err := util.RoundTrip(val); err != nil { - return err - } - } - return underlying.Write(op, path, *val) -} - -func (h *handle) Unregister(_ context.Context, txn storage.Transaction) { - underlying, err := h.db.underlying(txn) - if err != nil { - panic(err) - } - if !underlying.write { - panic(&storage.Error{ - Code: storage.InvalidTransactionErr, - Message: "triggers must be unregistered with a write transaction", - }) - } - delete(h.db.triggers, h) -} - -func (db *store) runOnCommitTriggers(ctx context.Context, txn storage.Transaction, event storage.TriggerEvent) { - if db.returnASTValuesOnRead && len(db.triggers) > 0 { - // FIXME: Not very performant for large data. - - dataEvents := make([]storage.DataEvent, 0, len(event.Data)) - - for _, dataEvent := range event.Data { - if astData, ok := dataEvent.Data.(ast.Value); ok { - jsn, err := ast.ValueToInterface(astData, illegalResolver{}) - if err != nil { - panic(err) - } - dataEvents = append(dataEvents, storage.DataEvent{ - Path: dataEvent.Path, - Data: jsn, - Removed: dataEvent.Removed, - }) - } else { - dataEvents = append(dataEvents, dataEvent) - } - } - - event = storage.TriggerEvent{ - Policy: event.Policy, - Data: dataEvents, - Context: event.Context, - } - } - - for _, t := range db.triggers { - t.OnCommit(ctx, txn, event) - } -} - -type illegalResolver struct{} - -func (illegalResolver) Resolve(ref ast.Ref) (interface{}, error) { - return nil, fmt.Errorf("illegal value: %v", ref) -} - -func (db *store) underlying(txn storage.Transaction) (*transaction, error) { - underlying, ok := txn.(*transaction) - if !ok { - return nil, &storage.Error{ - Code: storage.InvalidTransactionErr, - Message: fmt.Sprintf("unexpected transaction type %T", txn), - } - } - if underlying.db != db { - return nil, &storage.Error{ - Code: storage.InvalidTransactionErr, - Message: "unknown transaction", - } - } - if underlying.stale { - return nil, &storage.Error{ - Code: storage.InvalidTransactionErr, - Message: "stale transaction", - } - } - return underlying, nil -} - -const rootMustBeObjectMsg = "root must be object" -const rootCannotBeRemovedMsg = "root cannot be removed" - -func invalidPatchError(f string, a ...interface{}) *storage.Error { - return &storage.Error{ - Code: storage.InvalidPatchErr, - Message: fmt.Sprintf(f, a...), - } -} - -func mktree(path []string, value interface{}) (map[string]interface{}, error) { - if len(path) == 0 { - // For 0 length path the value is the full tree. - obj, ok := value.(map[string]interface{}) - if !ok { - return nil, invalidPatchError(rootMustBeObjectMsg) - } - return obj, nil - } - - dir := map[string]interface{}{} - for i := len(path) - 1; i > 0; i-- { - dir[path[i]] = value - value = dir - dir = map[string]interface{}{} - } - dir[path[0]] = value - - return dir, nil -} - -func lookup(path storage.Path, data map[string]interface{}) (interface{}, bool) { - if len(path) == 0 { - return data, true - } - for i := 0; i < len(path)-1; i++ { - value, ok := data[path[i]] - if !ok { - return nil, false - } - obj, ok := value.(map[string]interface{}) - if !ok { - return nil, false - } - data = obj - } - value, ok := data[path[len(path)-1]] - return value, ok + return v1.NewFromReaderWithOpts(r, opts...) } diff --git a/vendor/github.com/open-policy-agent/opa/storage/inmem/opts.go b/vendor/github.com/open-policy-agent/opa/storage/inmem/opts.go index 2239fc73a..43f03ef27 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/inmem/opts.go +++ b/vendor/github.com/open-policy-agent/opa/storage/inmem/opts.go @@ -1,7 +1,9 @@ package inmem +import v1 "github.com/open-policy-agent/opa/v1/storage/inmem" + // An Opt modifies store at instantiation. -type Opt func(*store) +type Opt = v1.Opt // OptRoundTripOnWrite sets whether incoming objects written to store are // round-tripped through JSON to ensure they are serializable to JSON. @@ -19,9 +21,7 @@ type Opt func(*store) // and that mutations happening to the objects after they have been passed into // Write() don't affect their logic. func OptRoundTripOnWrite(enabled bool) Opt { - return func(s *store) { - s.roundTripOnWrite = enabled - } + return v1.OptRoundTripOnWrite(enabled) } // OptReturnASTValuesOnRead sets whether data values added to the store should be @@ -31,7 +31,5 @@ func OptRoundTripOnWrite(enabled bool) Opt { // which may result in panics if the data is not valid. Callers should ensure that passed data // can be serialized to AST values; otherwise, it's recommended to also enable OptRoundTripOnWrite. func OptReturnASTValuesOnRead(enabled bool) Opt { - return func(s *store) { - s.returnASTValuesOnRead = enabled - } + return v1.OptReturnASTValuesOnRead(enabled) } diff --git a/vendor/github.com/open-policy-agent/opa/storage/interface.go b/vendor/github.com/open-policy-agent/opa/storage/interface.go index 6baca9a59..0192c459c 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/interface.go +++ b/vendor/github.com/open-policy-agent/opa/storage/interface.go @@ -5,243 +5,82 @@ package storage import ( - "context" - - "github.com/open-policy-agent/opa/metrics" + v1 "github.com/open-policy-agent/opa/v1/storage" ) // Transaction defines the interface that identifies a consistent snapshot over // the policy engine's storage layer. -type Transaction interface { - ID() uint64 -} +type Transaction = v1.Transaction // Store defines the interface for the storage layer's backend. -type Store interface { - Trigger - Policy - - // NewTransaction is called create a new transaction in the store. - NewTransaction(context.Context, ...TransactionParams) (Transaction, error) - - // Read is called to fetch a document referred to by path. - Read(context.Context, Transaction, Path) (interface{}, error) - - // Write is called to modify a document referred to by path. - Write(context.Context, Transaction, PatchOp, Path, interface{}) error - - // Commit is called to finish the transaction. If Commit returns an error, the - // transaction must be automatically aborted by the Store implementation. - Commit(context.Context, Transaction) error - - // Truncate is called to make a copy of the underlying store, write documents in the new store - // by creating multiple transactions in the new store as needed and finally swapping - // over to the new storage instance. This method must be called within a transaction on the original store. - Truncate(context.Context, Transaction, TransactionParams, Iterator) error - - // Abort is called to cancel the transaction. - Abort(context.Context, Transaction) -} +type Store = v1.Store // MakeDirer defines the interface a Store could realize to override the // generic MakeDir functionality in storage.MakeDir -type MakeDirer interface { - MakeDir(context.Context, Transaction, Path) error -} +type MakeDirer = v1.MakeDirer // TransactionParams describes a new transaction. -type TransactionParams struct { - - // BasePaths indicates the top-level paths where write operations will be performed in this transaction. - BasePaths []string - - // RootOverwrite is deprecated. Use BasePaths instead. - RootOverwrite bool - - // Write indicates if this transaction will perform any write operations. - Write bool - - // Context contains key/value pairs passed to triggers. - Context *Context -} +type TransactionParams = v1.TransactionParams // Context is a simple container for key/value pairs. -type Context struct { - values map[interface{}]interface{} -} +type Context = v1.Context // NewContext returns a new context object. func NewContext() *Context { - return &Context{ - values: map[interface{}]interface{}{}, - } -} - -// Get returns the key value in the context. -func (ctx *Context) Get(key interface{}) interface{} { - if ctx == nil { - return nil - } - return ctx.values[key] -} - -// Put adds a key/value pair to the context. -func (ctx *Context) Put(key, value interface{}) { - ctx.values[key] = value -} - -var metricsKey = struct{}{} - -// WithMetrics allows passing metrics via the Context. -// It puts the metrics object in the ctx, and returns the same -// ctx (not a copy) for convenience. -func (ctx *Context) WithMetrics(m metrics.Metrics) *Context { - ctx.values[metricsKey] = m - return ctx -} - -// Metrics() allows using a Context's metrics. Returns nil if metrics -// were not attached to the Context. -func (ctx *Context) Metrics() metrics.Metrics { - if m, ok := ctx.values[metricsKey]; ok { - if met, ok := m.(metrics.Metrics); ok { - return met - } - } - return nil + return v1.NewContext() } // WriteParams specifies the TransactionParams for a write transaction. -var WriteParams = TransactionParams{ - Write: true, -} +var WriteParams = v1.WriteParams // PatchOp is the enumeration of supposed modifications. -type PatchOp int +type PatchOp = v1.PatchOp // Patch supports add, remove, and replace operations. const ( - AddOp PatchOp = iota - RemoveOp = iota - ReplaceOp = iota + AddOp = v1.AddOp + RemoveOp = v1.RemoveOp + ReplaceOp = v1.ReplaceOp ) // WritesNotSupported provides a default implementation of the write // interface which may be used if the backend does not support writes. -type WritesNotSupported struct{} - -func (WritesNotSupported) Write(context.Context, Transaction, PatchOp, Path, interface{}) error { - return writesNotSupportedError() -} +type WritesNotSupported = v1.WritesNotSupported // Policy defines the interface for policy module storage. -type Policy interface { - ListPolicies(context.Context, Transaction) ([]string, error) - GetPolicy(context.Context, Transaction, string) ([]byte, error) - UpsertPolicy(context.Context, Transaction, string, []byte) error - DeletePolicy(context.Context, Transaction, string) error -} +type Policy = v1.Policy // PolicyNotSupported provides a default implementation of the policy interface // which may be used if the backend does not support policy storage. -type PolicyNotSupported struct{} - -// ListPolicies always returns a PolicyNotSupportedErr. -func (PolicyNotSupported) ListPolicies(context.Context, Transaction) ([]string, error) { - return nil, policyNotSupportedError() -} - -// GetPolicy always returns a PolicyNotSupportedErr. -func (PolicyNotSupported) GetPolicy(context.Context, Transaction, string) ([]byte, error) { - return nil, policyNotSupportedError() -} - -// UpsertPolicy always returns a PolicyNotSupportedErr. -func (PolicyNotSupported) UpsertPolicy(context.Context, Transaction, string, []byte) error { - return policyNotSupportedError() -} - -// DeletePolicy always returns a PolicyNotSupportedErr. -func (PolicyNotSupported) DeletePolicy(context.Context, Transaction, string) error { - return policyNotSupportedError() -} +type PolicyNotSupported = v1.PolicyNotSupported // PolicyEvent describes a change to a policy. -type PolicyEvent struct { - ID string - Data []byte - Removed bool -} +type PolicyEvent = v1.PolicyEvent // DataEvent describes a change to a base data document. -type DataEvent struct { - Path Path - Data interface{} - Removed bool -} +type DataEvent = v1.DataEvent // TriggerEvent describes the changes that caused the trigger to be invoked. -type TriggerEvent struct { - Policy []PolicyEvent - Data []DataEvent - Context *Context -} - -// IsZero returns true if the TriggerEvent indicates no changes occurred. This -// function is primarily for test purposes. -func (e TriggerEvent) IsZero() bool { - return !e.PolicyChanged() && !e.DataChanged() -} - -// PolicyChanged returns true if the trigger was caused by a policy change. -func (e TriggerEvent) PolicyChanged() bool { - return len(e.Policy) > 0 -} - -// DataChanged returns true if the trigger was caused by a data change. -func (e TriggerEvent) DataChanged() bool { - return len(e.Data) > 0 -} +type TriggerEvent = v1.TriggerEvent // TriggerConfig contains the trigger registration configuration. -type TriggerConfig struct { - - // OnCommit is invoked when a transaction is successfully committed. The - // callback is invoked with a handle to the write transaction that - // successfully committed before other clients see the changes. - OnCommit func(context.Context, Transaction, TriggerEvent) -} +type TriggerConfig = v1.TriggerConfig // Trigger defines the interface that stores implement to register for change // notifications when the store is changed. -type Trigger interface { - Register(context.Context, Transaction, TriggerConfig) (TriggerHandle, error) -} +type Trigger = v1.Trigger // TriggersNotSupported provides default implementations of the Trigger // interface which may be used if the backend does not support triggers. -type TriggersNotSupported struct{} - -// Register always returns an error indicating triggers are not supported. -func (TriggersNotSupported) Register(context.Context, Transaction, TriggerConfig) (TriggerHandle, error) { - return nil, triggersNotSupportedError() -} +type TriggersNotSupported = v1.TriggersNotSupported // TriggerHandle defines the interface that can be used to unregister triggers that have // been registered on a Store. -type TriggerHandle interface { - Unregister(context.Context, Transaction) -} +type TriggerHandle = v1.TriggerHandle // Iterator defines the interface that can be used to read files from a directory starting with // files at the base of the directory, then sub-directories etc. -type Iterator interface { - Next() (*Update, error) -} +type Iterator = v1.Iterator // Update contains information about a file -type Update struct { - Path Path - Value []byte - IsPolicy bool -} +type Update = v1.Update diff --git a/vendor/github.com/open-policy-agent/opa/storage/path.go b/vendor/github.com/open-policy-agent/opa/storage/path.go index 02ef4cab4..91d4f34f2 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/path.go +++ b/vendor/github.com/open-policy-agent/opa/storage/path.go @@ -5,150 +5,30 @@ package storage import ( - "fmt" - "net/url" - "strconv" - "strings" - "github.com/open-policy-agent/opa/ast" + v1 "github.com/open-policy-agent/opa/v1/storage" ) // Path refers to a document in storage. -type Path []string +type Path = v1.Path // ParsePath returns a new path for the given str. func ParsePath(str string) (path Path, ok bool) { - if len(str) == 0 { - return nil, false - } - if str[0] != '/' { - return nil, false - } - if len(str) == 1 { - return Path{}, true - } - parts := strings.Split(str[1:], "/") - return parts, true + return v1.ParsePath(str) } // ParsePathEscaped returns a new path for the given escaped str. func ParsePathEscaped(str string) (path Path, ok bool) { - path, ok = ParsePath(str) - if !ok { - return - } - for i := range path { - segment, err := url.PathUnescape(path[i]) - if err == nil { - path[i] = segment - } - } - return + return v1.ParsePathEscaped(str) } // NewPathForRef returns a new path for the given ref. func NewPathForRef(ref ast.Ref) (path Path, err error) { - - if len(ref) == 0 { - return nil, fmt.Errorf("empty reference (indicates error in caller)") - } - - if len(ref) == 1 { - return Path{}, nil - } - - path = make(Path, 0, len(ref)-1) - - for _, term := range ref[1:] { - switch v := term.Value.(type) { - case ast.String: - path = append(path, string(v)) - case ast.Number: - path = append(path, v.String()) - case ast.Boolean, ast.Null: - return nil, &Error{ - Code: NotFoundErr, - Message: fmt.Sprintf("%v: does not exist", ref), - } - case *ast.Array, ast.Object, ast.Set: - return nil, fmt.Errorf("composites cannot be base document keys: %v", ref) - default: - return nil, fmt.Errorf("unresolved reference (indicates error in caller): %v", ref) - } - } - - return path, nil -} - -// Compare performs lexigraphical comparison on p and other and returns -1 if p -// is less than other, 0 if p is equal to other, or 1 if p is greater than -// other. -func (p Path) Compare(other Path) (cmp int) { - min := len(p) - if len(other) < min { - min = len(other) - } - for i := 0; i < min; i++ { - if cmp := strings.Compare(p[i], other[i]); cmp != 0 { - return cmp - } - } - if len(p) < len(other) { - return -1 - } - if len(p) == len(other) { - return 0 - } - return 1 -} - -// Equal returns true if p is the same as other. -func (p Path) Equal(other Path) bool { - return p.Compare(other) == 0 -} - -// HasPrefix returns true if p starts with other. -func (p Path) HasPrefix(other Path) bool { - if len(other) > len(p) { - return false - } - for i := range other { - if p[i] != other[i] { - return false - } - } - return true -} - -// Ref returns a ref that represents p rooted at head. -func (p Path) Ref(head *ast.Term) (ref ast.Ref) { - ref = make(ast.Ref, len(p)+1) - ref[0] = head - for i := range p { - idx, err := strconv.ParseInt(p[i], 10, 64) - if err == nil { - ref[i+1] = ast.UIntNumberTerm(uint64(idx)) - } else { - ref[i+1] = ast.StringTerm(p[i]) - } - } - return ref -} - -func (p Path) String() string { - buf := make([]string, len(p)) - for i := range buf { - buf[i] = url.PathEscape(p[i]) - } - return "/" + strings.Join(buf, "/") + return v1.NewPathForRef(ref) } // MustParsePath returns a new Path for s. If s cannot be parsed, this function // will panic. This is mostly for test purposes. func MustParsePath(s string) Path { - path, ok := ParsePath(s) - if !ok { - panic(s) - } - return path + return v1.MustParsePath(s) } diff --git a/vendor/github.com/open-policy-agent/opa/storage/storage.go b/vendor/github.com/open-policy-agent/opa/storage/storage.go index 2f8a39c59..c02773d98 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/storage.go +++ b/vendor/github.com/open-policy-agent/opa/storage/storage.go @@ -7,85 +7,34 @@ package storage import ( "context" - "github.com/open-policy-agent/opa/ast" + v1 "github.com/open-policy-agent/opa/v1/storage" ) // NewTransactionOrDie is a helper function to create a new transaction. If the // storage layer cannot create a new transaction, this function will panic. This // function should only be used for tests. func NewTransactionOrDie(ctx context.Context, store Store, params ...TransactionParams) Transaction { - txn, err := store.NewTransaction(ctx, params...) - if err != nil { - panic(err) - } - return txn + return v1.NewTransactionOrDie(ctx, store, params...) } // ReadOne is a convenience function to read a single value from the provided Store. It // will create a new Transaction to perform the read with, and clean up after itself // should an error occur. func ReadOne(ctx context.Context, store Store, path Path) (interface{}, error) { - txn, err := store.NewTransaction(ctx) - if err != nil { - return nil, err - } - defer store.Abort(ctx, txn) - - return store.Read(ctx, txn, path) + return v1.ReadOne(ctx, store, path) } // WriteOne is a convenience function to write a single value to the provided Store. It // will create a new Transaction to perform the write with, and clean up after itself // should an error occur. func WriteOne(ctx context.Context, store Store, op PatchOp, path Path, value interface{}) error { - txn, err := store.NewTransaction(ctx, WriteParams) - if err != nil { - return err - } - - if err := store.Write(ctx, txn, op, path, value); err != nil { - store.Abort(ctx, txn) - return err - } - - return store.Commit(ctx, txn) + return v1.WriteOne(ctx, store, op, path, value) } // MakeDir inserts an empty object at path. If the parent path does not exist, // MakeDir will create it recursively. func MakeDir(ctx context.Context, store Store, txn Transaction, path Path) error { - - // Allow the Store implementation to deal with this in its own way. - if md, ok := store.(MakeDirer); ok { - return md.MakeDir(ctx, txn, path) - } - - if len(path) == 0 { - return nil - } - - node, err := store.Read(ctx, txn, path) - if err != nil { - if !IsNotFound(err) { - return err - } - - if err := MakeDir(ctx, store, txn, path[:len(path)-1]); err != nil { - return err - } - - return store.Write(ctx, txn, AddOp, path, map[string]interface{}{}) - } - - if _, ok := node.(map[string]interface{}); ok { - return nil - } - - if _, ok := node.(ast.Object); ok { - return nil - } - - return writeConflictError(path) + return v1.MakeDir(ctx, store, txn, path) } // Txn is a convenience function that executes f inside a new transaction @@ -93,44 +42,12 @@ func MakeDir(ctx context.Context, store Store, txn Transaction, path Path) error // aborted and the error is returned. Otherwise, the transaction is committed // and the result of the commit is returned. func Txn(ctx context.Context, store Store, params TransactionParams, f func(Transaction) error) error { - - txn, err := store.NewTransaction(ctx, params) - if err != nil { - return err - } - - if err := f(txn); err != nil { - store.Abort(ctx, txn) - return err - } - - return store.Commit(ctx, txn) + return v1.Txn(ctx, store, params, f) } // NonEmpty returns a function that tests if a path is non-empty. A // path is non-empty if a Read on the path returns a value or a Read // on any of the path prefixes returns a non-object value. func NonEmpty(ctx context.Context, store Store, txn Transaction) func([]string) (bool, error) { - return func(path []string) (bool, error) { - if _, err := store.Read(ctx, txn, Path(path)); err == nil { - return true, nil - } else if !IsNotFound(err) { - return false, err - } - for i := len(path) - 1; i > 0; i-- { - val, err := store.Read(ctx, txn, Path(path[:i])) - if err != nil && !IsNotFound(err) { - return false, err - } else if err == nil { - if _, ok := val.(map[string]interface{}); ok { - return false, nil - } - if _, ok := val.(ast.Object); ok { - return false, nil - } - return true, nil - } - } - return false, nil - } + return v1.NonEmpty(ctx, store, txn) } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/builtins.go b/vendor/github.com/open-policy-agent/opa/topdown/builtins.go index cf694d433..f28c6c795 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/builtins.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/builtins.go @@ -5,219 +5,63 @@ package topdown import ( - "context" - "encoding/binary" - "fmt" - "io" - "math/rand" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/print" - "github.com/open-policy-agent/opa/tracing" + v1 "github.com/open-policy-agent/opa/v1/topdown" ) type ( // Deprecated: Functional-style builtins are deprecated. Use BuiltinFunc instead. - FunctionalBuiltin1 func(op1 ast.Value) (output ast.Value, err error) + FunctionalBuiltin1 = v1.FunctionalBuiltin1 //nolint:staticcheck // SA1019: Intentional use of deprecated type. // Deprecated: Functional-style builtins are deprecated. Use BuiltinFunc instead. - FunctionalBuiltin2 func(op1, op2 ast.Value) (output ast.Value, err error) + FunctionalBuiltin2 = v1.FunctionalBuiltin2 //nolint:staticcheck // SA1019: Intentional use of deprecated type. // Deprecated: Functional-style builtins are deprecated. Use BuiltinFunc instead. - FunctionalBuiltin3 func(op1, op2, op3 ast.Value) (output ast.Value, err error) + FunctionalBuiltin3 = v1.FunctionalBuiltin3 //nolint:staticcheck // SA1019: Intentional use of deprecated type. // Deprecated: Functional-style builtins are deprecated. Use BuiltinFunc instead. - FunctionalBuiltin4 func(op1, op2, op3, op4 ast.Value) (output ast.Value, err error) + FunctionalBuiltin4 = v1.FunctionalBuiltin4 //nolint:staticcheck // SA1019: Intentional use of deprecated type. // BuiltinContext contains context from the evaluator that may be used by // built-in functions. - BuiltinContext struct { - Context context.Context // request context that was passed when query started - Metrics metrics.Metrics // metrics registry for recording built-in specific metrics - Seed io.Reader // randomization source - Time *ast.Term // wall clock time - Cancel Cancel // atomic value that signals evaluation to halt - Runtime *ast.Term // runtime information on the OPA instance - Cache builtins.Cache // built-in function state cache - InterQueryBuiltinCache cache.InterQueryCache // cross-query built-in function state cache - InterQueryBuiltinValueCache cache.InterQueryValueCache // cross-query built-in function state value cache. this cache is useful for scenarios where the entry size cannot be calculated - NDBuiltinCache builtins.NDBCache // cache for non-deterministic built-in state - Location *ast.Location // location of built-in call - Tracers []Tracer // Deprecated: Use QueryTracers instead - QueryTracers []QueryTracer // tracer objects for trace() built-in function - TraceEnabled bool // indicates whether tracing is enabled for the evaluation - QueryID uint64 // identifies query being evaluated - ParentID uint64 // identifies parent of query being evaluated - PrintHook print.Hook // provides callback function to use for printing - DistributedTracingOpts tracing.Options // options to be used by distributed tracing. - rand *rand.Rand // randomization source for non-security-sensitive operations - Capabilities *ast.Capabilities - } + BuiltinContext = v1.BuiltinContext // BuiltinFunc defines an interface for implementing built-in functions. // The built-in function is called with the plugged operands from the call // (including the output operands.) The implementation should evaluate the // operands and invoke the iterator for each successful/defined output // value. - BuiltinFunc func(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error + BuiltinFunc = v1.BuiltinFunc ) -// Rand returns a random number generator based on the Seed for this built-in -// context. The random number will be re-used across multiple calls to this -// function. If a random number generator cannot be created, an error is -// returned. -func (bctx *BuiltinContext) Rand() (*rand.Rand, error) { - - if bctx.rand != nil { - return bctx.rand, nil - } - - seed, err := readInt64(bctx.Seed) - if err != nil { - return nil, err - } - - bctx.rand = rand.New(rand.NewSource(seed)) - return bctx.rand, nil -} - // RegisterBuiltinFunc adds a new built-in function to the evaluation engine. func RegisterBuiltinFunc(name string, f BuiltinFunc) { - builtinFunctions[name] = builtinErrorWrapper(name, f) + v1.RegisterBuiltinFunc(name, f) } // Deprecated: Functional-style builtins are deprecated. Use RegisterBuiltinFunc instead. func RegisterFunctionalBuiltin1(name string, fun FunctionalBuiltin1) { - builtinFunctions[name] = functionalWrapper1(name, fun) + v1.RegisterFunctionalBuiltin1(name, fun) } // Deprecated: Functional-style builtins are deprecated. Use RegisterBuiltinFunc instead. func RegisterFunctionalBuiltin2(name string, fun FunctionalBuiltin2) { - builtinFunctions[name] = functionalWrapper2(name, fun) + v1.RegisterFunctionalBuiltin2(name, fun) } // Deprecated: Functional-style builtins are deprecated. Use RegisterBuiltinFunc instead. func RegisterFunctionalBuiltin3(name string, fun FunctionalBuiltin3) { - builtinFunctions[name] = functionalWrapper3(name, fun) + v1.RegisterFunctionalBuiltin3(name, fun) } // Deprecated: Functional-style builtins are deprecated. Use RegisterBuiltinFunc instead. func RegisterFunctionalBuiltin4(name string, fun FunctionalBuiltin4) { - builtinFunctions[name] = functionalWrapper4(name, fun) + v1.RegisterFunctionalBuiltin4(name, fun) } // GetBuiltin returns a built-in function implementation, nil if no built-in found. func GetBuiltin(name string) BuiltinFunc { - return builtinFunctions[name] + return v1.GetBuiltin(name) } // Deprecated: The BuiltinEmpty type is no longer needed. Use nil return values instead. -type BuiltinEmpty struct{} - -func (BuiltinEmpty) Error() string { - return "" -} - -var builtinFunctions = map[string]BuiltinFunc{} - -func builtinErrorWrapper(name string, fn BuiltinFunc) BuiltinFunc { - return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { - err := fn(bctx, args, iter) - if err == nil { - return nil - } - return handleBuiltinErr(name, bctx.Location, err) - } -} - -func functionalWrapper1(name string, fn FunctionalBuiltin1) BuiltinFunc { - return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { - result, err := fn(args[0].Value) - if err == nil { - return iter(ast.NewTerm(result)) - } - return handleBuiltinErr(name, bctx.Location, err) - } -} - -func functionalWrapper2(name string, fn FunctionalBuiltin2) BuiltinFunc { - return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { - result, err := fn(args[0].Value, args[1].Value) - if err == nil { - return iter(ast.NewTerm(result)) - } - return handleBuiltinErr(name, bctx.Location, err) - } -} - -func functionalWrapper3(name string, fn FunctionalBuiltin3) BuiltinFunc { - return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { - result, err := fn(args[0].Value, args[1].Value, args[2].Value) - if err == nil { - return iter(ast.NewTerm(result)) - } - return handleBuiltinErr(name, bctx.Location, err) - } -} - -func functionalWrapper4(name string, fn FunctionalBuiltin4) BuiltinFunc { - return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { - result, err := fn(args[0].Value, args[1].Value, args[2].Value, args[3].Value) - if err == nil { - return iter(ast.NewTerm(result)) - } - if _, empty := err.(BuiltinEmpty); empty { - return nil - } - return handleBuiltinErr(name, bctx.Location, err) - } -} - -func handleBuiltinErr(name string, loc *ast.Location, err error) error { - switch err := err.(type) { - case BuiltinEmpty: - return nil - case *Error, Halt: - return err - case builtins.ErrOperand: - e := &Error{ - Code: TypeErr, - Message: fmt.Sprintf("%v: %v", name, err.Error()), - Location: loc, - } - return e.Wrap(err) - default: - e := &Error{ - Code: BuiltinErr, - Message: fmt.Sprintf("%v: %v", name, err.Error()), - Location: loc, - } - return e.Wrap(err) - } -} - -func readInt64(r io.Reader) (int64, error) { - bs := make([]byte, 8) - n, err := io.ReadFull(r, bs) - if n != len(bs) || err != nil { - return 0, err - } - return int64(binary.BigEndian.Uint64(bs)), nil -} - -// Used to get older-style (ast.Term, error) tuples out of newer functions. -func getResult(fn BuiltinFunc, operands ...*ast.Term) (*ast.Term, error) { - var result *ast.Term - extractionFn := func(r *ast.Term) error { - result = r - return nil - } - err := fn(BuiltinContext{}, operands, extractionFn) - if err != nil { - return nil, err - } - return result, nil -} +type BuiltinEmpty = v1.Builtin diff --git a/vendor/github.com/open-policy-agent/opa/topdown/builtins/builtins.go b/vendor/github.com/open-policy-agent/opa/topdown/builtins/builtins.go index 353f95684..152c37717 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/builtins/builtins.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/builtins/builtins.go @@ -6,323 +6,118 @@ package builtins import ( - "encoding/json" - "fmt" "math/big" - "strings" "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/topdown/builtins" ) // Cache defines the built-in cache used by the top-down evaluation. The keys // must be comparable and should not be of type string. -type Cache map[interface{}]interface{} - -// Put updates the cache for the named built-in. -func (c Cache) Put(k, v interface{}) { - c[k] = v -} - -// Get returns the cached value for k. -func (c Cache) Get(k interface{}) (interface{}, bool) { - v, ok := c[k] - return v, ok -} +type Cache = v1.Cache // We use an ast.Object for the cached keys/values because a naive // map[ast.Value]ast.Value will not correctly detect value equality of // the member keys. -type NDBCache map[string]ast.Object - -func (c NDBCache) AsValue() ast.Value { - out := ast.NewObject() - for bname, obj := range c { - out.Insert(ast.StringTerm(bname), ast.NewTerm(obj)) - } - return out -} - -// Put updates the cache for the named built-in. -// Automatically creates the 2-level hierarchy as needed. -func (c NDBCache) Put(name string, k, v ast.Value) { - if _, ok := c[name]; !ok { - c[name] = ast.NewObject() - } - c[name].Insert(ast.NewTerm(k), ast.NewTerm(v)) -} - -// Get returns the cached value for k for the named builtin. -func (c NDBCache) Get(name string, k ast.Value) (ast.Value, bool) { - if m, ok := c[name]; ok { - v := m.Get(ast.NewTerm(k)) - if v != nil { - return v.Value, true - } - return nil, false - } - return nil, false -} - -// Convenience functions for serializing the data structure. -func (c NDBCache) MarshalJSON() ([]byte, error) { - v, err := ast.JSON(c.AsValue()) - if err != nil { - return nil, err - } - return json.Marshal(v) -} - -func (c *NDBCache) UnmarshalJSON(data []byte) error { - out := map[string]ast.Object{} - var incoming interface{} - - // Note: We use util.Unmarshal instead of json.Unmarshal to get - // correct deserialization of number types. - err := util.Unmarshal(data, &incoming) - if err != nil { - return err - } - - // Convert interface types back into ast.Value types. - nestedObject, err := ast.InterfaceToValue(incoming) - if err != nil { - return err - } - - // Reconstruct NDBCache from nested ast.Object structure. - if source, ok := nestedObject.(ast.Object); ok { - err = source.Iter(func(k, v *ast.Term) error { - if obj, ok := v.Value.(ast.Object); ok { - out[string(k.Value.(ast.String))] = obj - return nil - } - return fmt.Errorf("expected Object, got other Value type in conversion") - }) - if err != nil { - return err - } - } - - *c = out - - return nil -} +type NDBCache = v1.NDBCache // ErrOperand represents an invalid operand has been passed to a built-in // function. Built-ins should return ErrOperand to indicate a type error has // occurred. -type ErrOperand string - -func (err ErrOperand) Error() string { - return string(err) -} +type ErrOperand = v1.ErrOperand // NewOperandErr returns a generic operand error. func NewOperandErr(pos int, f string, a ...interface{}) error { - f = fmt.Sprintf("operand %v ", pos) + f - return ErrOperand(fmt.Sprintf(f, a...)) + return v1.NewOperandErr(pos, f, a...) } // NewOperandTypeErr returns an operand error indicating the operand's type was wrong. func NewOperandTypeErr(pos int, got ast.Value, expected ...string) error { - - if len(expected) == 1 { - return NewOperandErr(pos, "must be %v but got %v", expected[0], ast.TypeName(got)) - } - - return NewOperandErr(pos, "must be one of {%v} but got %v", strings.Join(expected, ", "), ast.TypeName(got)) + return v1.NewOperandTypeErr(pos, got, expected...) } // NewOperandElementErr returns an operand error indicating an element in the // composite operand was wrong. func NewOperandElementErr(pos int, composite ast.Value, got ast.Value, expected ...string) error { - - tpe := ast.TypeName(composite) - - if len(expected) == 1 { - return NewOperandErr(pos, "must be %v of %vs but got %v containing %v", tpe, expected[0], tpe, ast.TypeName(got)) - } - - return NewOperandErr(pos, "must be %v of (any of) {%v} but got %v containing %v", tpe, strings.Join(expected, ", "), tpe, ast.TypeName(got)) + return v1.NewOperandElementErr(pos, composite, got, expected...) } // NewOperandEnumErr returns an operand error indicating a value was wrong. func NewOperandEnumErr(pos int, expected ...string) error { - - if len(expected) == 1 { - return NewOperandErr(pos, "must be %v", expected[0]) - } - - return NewOperandErr(pos, "must be one of {%v}", strings.Join(expected, ", ")) + return v1.NewOperandEnumErr(pos, expected...) } // IntOperand converts x to an int. If the cast fails, a descriptive error is // returned. func IntOperand(x ast.Value, pos int) (int, error) { - n, ok := x.(ast.Number) - if !ok { - return 0, NewOperandTypeErr(pos, x, "number") - } - - i, ok := n.Int() - if !ok { - return 0, NewOperandErr(pos, "must be integer number but got floating-point number") - } - - return i, nil + return v1.IntOperand(x, pos) } // BigIntOperand converts x to a big int. If the cast fails, a descriptive error // is returned. func BigIntOperand(x ast.Value, pos int) (*big.Int, error) { - n, err := NumberOperand(x, 1) - if err != nil { - return nil, NewOperandTypeErr(pos, x, "integer") - } - bi, err := NumberToInt(n) - if err != nil { - return nil, NewOperandErr(pos, "must be integer number but got floating-point number") - } - - return bi, nil + return v1.BigIntOperand(x, pos) } // NumberOperand converts x to a number. If the cast fails, a descriptive error is // returned. func NumberOperand(x ast.Value, pos int) (ast.Number, error) { - n, ok := x.(ast.Number) - if !ok { - return ast.Number(""), NewOperandTypeErr(pos, x, "number") - } - return n, nil + return v1.NumberOperand(x, pos) } // SetOperand converts x to a set. If the cast fails, a descriptive error is // returned. func SetOperand(x ast.Value, pos int) (ast.Set, error) { - s, ok := x.(ast.Set) - if !ok { - return nil, NewOperandTypeErr(pos, x, "set") - } - return s, nil + return v1.SetOperand(x, pos) } // StringOperand converts x to a string. If the cast fails, a descriptive error is // returned. func StringOperand(x ast.Value, pos int) (ast.String, error) { - s, ok := x.(ast.String) - if !ok { - return ast.String(""), NewOperandTypeErr(pos, x, "string") - } - return s, nil + return v1.StringOperand(x, pos) } // ObjectOperand converts x to an object. If the cast fails, a descriptive // error is returned. func ObjectOperand(x ast.Value, pos int) (ast.Object, error) { - o, ok := x.(ast.Object) - if !ok { - return nil, NewOperandTypeErr(pos, x, "object") - } - return o, nil + return v1.ObjectOperand(x, pos) } // ArrayOperand converts x to an array. If the cast fails, a descriptive // error is returned. func ArrayOperand(x ast.Value, pos int) (*ast.Array, error) { - a, ok := x.(*ast.Array) - if !ok { - return ast.NewArray(), NewOperandTypeErr(pos, x, "array") - } - return a, nil + return v1.ArrayOperand(x, pos) } // NumberToFloat converts n to a big float. func NumberToFloat(n ast.Number) *big.Float { - r, ok := new(big.Float).SetString(string(n)) - if !ok { - panic("illegal value") - } - return r + return v1.NumberToFloat(n) } // FloatToNumber converts f to a number. func FloatToNumber(f *big.Float) ast.Number { - var format byte = 'g' - if f.IsInt() { - format = 'f' - } - return ast.Number(f.Text(format, -1)) + return v1.FloatToNumber(f) } // NumberToInt converts n to a big int. // If n cannot be converted to an big int, an error is returned. func NumberToInt(n ast.Number) (*big.Int, error) { - f := NumberToFloat(n) - r, accuracy := f.Int(nil) - if accuracy != big.Exact { - return nil, fmt.Errorf("illegal value") - } - return r, nil + return v1.NumberToInt(n) } // IntToNumber converts i to a number. func IntToNumber(i *big.Int) ast.Number { - return ast.Number(i.String()) + return v1.IntToNumber(i) } // StringSliceOperand converts x to a []string. If the cast fails, a descriptive error is // returned. func StringSliceOperand(a ast.Value, pos int) ([]string, error) { - type iterable interface { - Iter(func(*ast.Term) error) error - Len() int - } - - strs, ok := a.(iterable) - if !ok { - return nil, NewOperandTypeErr(pos, a, "array", "set") - } - - var outStrs = make([]string, 0, strs.Len()) - if err := strs.Iter(func(x *ast.Term) error { - s, ok := x.Value.(ast.String) - if !ok { - return NewOperandElementErr(pos, a, x.Value, "string") - } - outStrs = append(outStrs, string(s)) - return nil - }); err != nil { - return nil, err - } - - return outStrs, nil + return v1.StringSliceOperand(a, pos) } // RuneSliceOperand converts x to a []rune. If the cast fails, a descriptive error is // returned. func RuneSliceOperand(x ast.Value, pos int) ([]rune, error) { - a, err := ArrayOperand(x, pos) - if err != nil { - return nil, err - } - - var f = make([]rune, a.Len()) - for k := 0; k < a.Len(); k++ { - b := a.Elem(k) - c, ok := b.Value.(ast.String) - if !ok { - return nil, NewOperandElementErr(pos, x, b.Value, "string") - } - - d := []rune(string(c)) - if len(d) != 1 { - return nil, NewOperandElementErr(pos, x, b.Value, "rune") - } - - f[k] = d[0] - } - - return f, nil + return v1.RuneSliceOperand(x, pos) } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/builtins/doc.go b/vendor/github.com/open-policy-agent/opa/topdown/builtins/doc.go new file mode 100644 index 000000000..1eec53694 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/topdown/builtins/doc.go @@ -0,0 +1,8 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package builtins diff --git a/vendor/github.com/open-policy-agent/opa/topdown/cache.go b/vendor/github.com/open-policy-agent/opa/topdown/cache.go index 265457e02..bb39df03e 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/cache.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/cache.go @@ -5,348 +5,15 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/util" + v1 "github.com/open-policy-agent/opa/v1/topdown" ) // VirtualCache defines the interface for a cache that stores the results of // evaluated virtual documents (rules). // The cache is a stack of frames, where each frame is a mapping from references // to values. -type VirtualCache interface { - // Push pushes a new, empty frame of value mappings onto the stack. - Push() - - // Pop pops the top frame of value mappings from the stack, removing all associated entries. - Pop() - - // Get returns the value associated with the given reference. The second return value - // indicates whether the reference has a recorded 'undefined' result. - Get(ref ast.Ref) (*ast.Term, bool) - - // Put associates the given reference with the given value. If the value is nil, the reference - // is marked as having an 'undefined' result. - Put(ref ast.Ref, value *ast.Term) - - // Keys returns the set of keys that have been cached for the active frame. - Keys() []ast.Ref -} - -type virtualCache struct { - stack []*virtualCacheElem -} - -type virtualCacheElem struct { - value *ast.Term - children *util.HashMap - undefined bool -} +type VirtualCache = v1.VirtualCache func NewVirtualCache() VirtualCache { - cache := &virtualCache{} - cache.Push() - return cache -} - -func (c *virtualCache) Push() { - c.stack = append(c.stack, newVirtualCacheElem()) -} - -func (c *virtualCache) Pop() { - c.stack = c.stack[:len(c.stack)-1] -} - -// Returns the resolved value of the AST term and a flag indicating if the value -// should be interpretted as undefined: -// -// nil, true indicates the ref is undefined -// ast.Term, false indicates the ref is defined -// nil, false indicates the ref has not been cached -// ast.Term, true is impossible -func (c *virtualCache) Get(ref ast.Ref) (*ast.Term, bool) { - node := c.stack[len(c.stack)-1] - for i := 0; i < len(ref); i++ { - x, ok := node.children.Get(ref[i]) - if !ok { - return nil, false - } - node = x.(*virtualCacheElem) - } - if node.undefined { - return nil, true - } - - return node.value, false -} - -// If value is a nil pointer, set the 'undefined' flag on the cache element to -// indicate that the Ref has resolved to undefined. -func (c *virtualCache) Put(ref ast.Ref, value *ast.Term) { - node := c.stack[len(c.stack)-1] - for i := 0; i < len(ref); i++ { - x, ok := node.children.Get(ref[i]) - if ok { - node = x.(*virtualCacheElem) - } else { - next := newVirtualCacheElem() - node.children.Put(ref[i], next) - node = next - } - } - if value != nil { - node.value = value - } else { - node.undefined = true - } -} - -func (c *virtualCache) Keys() []ast.Ref { - node := c.stack[len(c.stack)-1] - return keysRecursive(nil, node) -} - -func keysRecursive(root ast.Ref, node *virtualCacheElem) []ast.Ref { - var keys []ast.Ref - node.children.Iter(func(k, v util.T) bool { - ref := root.Append(k.(*ast.Term)) - if v.(*virtualCacheElem).value != nil { - keys = append(keys, ref) - } - if v.(*virtualCacheElem).children.Len() > 0 { - keys = append(keys, keysRecursive(ref, v.(*virtualCacheElem))...) - } - return false - }) - return keys -} - -func newVirtualCacheElem() *virtualCacheElem { - return &virtualCacheElem{children: newVirtualCacheHashMap()} -} - -func newVirtualCacheHashMap() *util.HashMap { - return util.NewHashMap(func(a, b util.T) bool { - return a.(*ast.Term).Equal(b.(*ast.Term)) - }, func(x util.T) int { - return x.(*ast.Term).Hash() - }) -} - -// baseCache implements a trie structure to cache base documents read out of -// storage. Values inserted into the cache may contain other values that were -// previously inserted. In this case, the previous values are erased from the -// structure. -type baseCache struct { - root *baseCacheElem -} - -func newBaseCache() *baseCache { - return &baseCache{ - root: newBaseCacheElem(), - } -} - -func (c *baseCache) Get(ref ast.Ref) ast.Value { - node := c.root - for i := 0; i < len(ref); i++ { - node = node.children[ref[i].Value] - if node == nil { - return nil - } else if node.value != nil { - result, err := node.value.Find(ref[i+1:]) - if err != nil { - return nil - } - return result - } - } - return nil -} - -func (c *baseCache) Put(ref ast.Ref, value ast.Value) { - node := c.root - for i := 0; i < len(ref); i++ { - if child, ok := node.children[ref[i].Value]; ok { - node = child - } else { - child := newBaseCacheElem() - node.children[ref[i].Value] = child - node = child - } - } - node.set(value) -} - -type baseCacheElem struct { - value ast.Value - children map[ast.Value]*baseCacheElem -} - -func newBaseCacheElem() *baseCacheElem { - return &baseCacheElem{ - children: map[ast.Value]*baseCacheElem{}, - } -} - -func (e *baseCacheElem) set(value ast.Value) { - e.value = value - e.children = map[ast.Value]*baseCacheElem{} -} - -type refStack struct { - sl []refStackElem -} - -type refStackElem struct { - refs []ast.Ref -} - -func newRefStack() *refStack { - return &refStack{} -} - -func (s *refStack) Push(refs []ast.Ref) { - s.sl = append(s.sl, refStackElem{refs: refs}) -} - -func (s *refStack) Pop() { - s.sl = s.sl[:len(s.sl)-1] -} - -func (s *refStack) Prefixed(ref ast.Ref) bool { - if s != nil { - for i := len(s.sl) - 1; i >= 0; i-- { - for j := range s.sl[i].refs { - if ref.HasPrefix(s.sl[i].refs[j]) { - return true - } - } - } - } - return false -} - -type comprehensionCache struct { - stack []map[*ast.Term]*comprehensionCacheElem -} - -type comprehensionCacheElem struct { - value *ast.Term - children *util.HashMap -} - -func newComprehensionCache() *comprehensionCache { - cache := &comprehensionCache{} - cache.Push() - return cache -} - -func (c *comprehensionCache) Push() { - c.stack = append(c.stack, map[*ast.Term]*comprehensionCacheElem{}) -} - -func (c *comprehensionCache) Pop() { - c.stack = c.stack[:len(c.stack)-1] -} - -func (c *comprehensionCache) Elem(t *ast.Term) (*comprehensionCacheElem, bool) { - elem, ok := c.stack[len(c.stack)-1][t] - return elem, ok -} - -func (c *comprehensionCache) Set(t *ast.Term, elem *comprehensionCacheElem) { - c.stack[len(c.stack)-1][t] = elem -} - -func newComprehensionCacheElem() *comprehensionCacheElem { - return &comprehensionCacheElem{children: newComprehensionCacheHashMap()} -} - -func (c *comprehensionCacheElem) Get(key []*ast.Term) *ast.Term { - node := c - for i := 0; i < len(key); i++ { - x, ok := node.children.Get(key[i]) - if !ok { - return nil - } - node = x.(*comprehensionCacheElem) - } - return node.value -} - -func (c *comprehensionCacheElem) Put(key []*ast.Term, value *ast.Term) { - node := c - for i := 0; i < len(key); i++ { - x, ok := node.children.Get(key[i]) - if ok { - node = x.(*comprehensionCacheElem) - } else { - next := newComprehensionCacheElem() - node.children.Put(key[i], next) - node = next - } - } - node.value = value -} - -func newComprehensionCacheHashMap() *util.HashMap { - return util.NewHashMap(func(a, b util.T) bool { - return a.(*ast.Term).Equal(b.(*ast.Term)) - }, func(x util.T) int { - return x.(*ast.Term).Hash() - }) -} - -type functionMocksStack struct { - stack []*functionMocksElem -} - -type functionMocksElem []frame - -type frame map[string]*ast.Term - -func newFunctionMocksStack() *functionMocksStack { - stack := &functionMocksStack{} - stack.Push() - return stack -} - -func newFunctionMocksElem() *functionMocksElem { - return &functionMocksElem{} -} - -func (s *functionMocksStack) Push() { - s.stack = append(s.stack, newFunctionMocksElem()) -} - -func (s *functionMocksStack) Pop() { - s.stack = s.stack[:len(s.stack)-1] -} - -func (s *functionMocksStack) PopPairs() { - current := s.stack[len(s.stack)-1] - *current = (*current)[:len(*current)-1] -} - -func (s *functionMocksStack) PutPairs(mocks [][2]*ast.Term) { - el := frame{} - for i := range mocks { - el[mocks[i][0].Value.String()] = mocks[i][1] - } - s.Put(el) -} - -func (s *functionMocksStack) Put(el frame) { - current := s.stack[len(s.stack)-1] - *current = append(*current, el) -} - -func (s *functionMocksStack) Get(f ast.Ref) (*ast.Term, bool) { - current := *s.stack[len(s.stack)-1] - for i := len(current) - 1; i >= 0; i-- { - if r, ok := current[i][f.String()]; ok { - return r, true - } - } - return nil, false + return v1.NewVirtualCache() } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/cache/cache.go b/vendor/github.com/open-policy-agent/opa/topdown/cache/cache.go index 55ed34061..e95617b22 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/cache/cache.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/cache/cache.go @@ -6,132 +6,34 @@ package cache import ( - "container/list" "context" - "fmt" - "math" - "sync" - "time" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/util" -) - -const ( - defaultInterQueryBuiltinValueCacheSize = int(0) // unlimited - defaultMaxSizeBytes = int64(0) // unlimited - defaultForcedEvictionThresholdPercentage = int64(100) // trigger at max_size_bytes - defaultStaleEntryEvictionPeriodSeconds = int64(0) // never + v1 "github.com/open-policy-agent/opa/v1/topdown/cache" ) // Config represents the configuration for the inter-query builtin cache. -type Config struct { - InterQueryBuiltinCache InterQueryBuiltinCacheConfig `json:"inter_query_builtin_cache"` - InterQueryBuiltinValueCache InterQueryBuiltinValueCacheConfig `json:"inter_query_builtin_value_cache"` -} +type Config = v1.Config // InterQueryBuiltinValueCacheConfig represents the configuration of the inter-query value cache that built-in functions can utilize. // MaxNumEntries - max number of cache entries -type InterQueryBuiltinValueCacheConfig struct { - MaxNumEntries *int `json:"max_num_entries,omitempty"` -} +type InterQueryBuiltinValueCacheConfig = v1.InterQueryBuiltinValueCacheConfig // InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize. // MaxSizeBytes - max capacity of cache in bytes // ForcedEvictionThresholdPercentage - capacity usage in percentage after which forced FIFO eviction starts // StaleEntryEvictionPeriodSeconds - time period between end of previous and start of new stale entry eviction routine -type InterQueryBuiltinCacheConfig struct { - MaxSizeBytes *int64 `json:"max_size_bytes,omitempty"` - ForcedEvictionThresholdPercentage *int64 `json:"forced_eviction_threshold_percentage,omitempty"` - StaleEntryEvictionPeriodSeconds *int64 `json:"stale_entry_eviction_period_seconds,omitempty"` -} +type InterQueryBuiltinCacheConfig = v1.InterQueryBuiltinCacheConfig // ParseCachingConfig returns the config for the inter-query cache. func ParseCachingConfig(raw []byte) (*Config, error) { - if raw == nil { - maxSize := new(int64) - *maxSize = defaultMaxSizeBytes - threshold := new(int64) - *threshold = defaultForcedEvictionThresholdPercentage - period := new(int64) - *period = defaultStaleEntryEvictionPeriodSeconds - - maxInterQueryBuiltinValueCacheSize := new(int) - *maxInterQueryBuiltinValueCacheSize = defaultInterQueryBuiltinValueCacheSize - - return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, ForcedEvictionThresholdPercentage: threshold, StaleEntryEvictionPeriodSeconds: period}, - InterQueryBuiltinValueCache: InterQueryBuiltinValueCacheConfig{MaxNumEntries: maxInterQueryBuiltinValueCacheSize}}, nil - } - - var config Config - - if err := util.Unmarshal(raw, &config); err == nil { - if err = config.validateAndInjectDefaults(); err != nil { - return nil, err - } - } else { - return nil, err - } - - return &config, nil -} - -func (c *Config) validateAndInjectDefaults() error { - if c.InterQueryBuiltinCache.MaxSizeBytes == nil { - maxSize := new(int64) - *maxSize = defaultMaxSizeBytes - c.InterQueryBuiltinCache.MaxSizeBytes = maxSize - } - if c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage == nil { - threshold := new(int64) - *threshold = defaultForcedEvictionThresholdPercentage - c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage = threshold - } else { - threshold := *c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage - if threshold < 0 || threshold > 100 { - return fmt.Errorf("invalid forced_eviction_threshold_percentage %v", threshold) - } - } - if c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds == nil { - period := new(int64) - *period = defaultStaleEntryEvictionPeriodSeconds - c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds = period - } else { - period := *c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds - if period < 0 { - return fmt.Errorf("invalid stale_entry_eviction_period_seconds %v", period) - } - } - - if c.InterQueryBuiltinValueCache.MaxNumEntries == nil { - maxSize := new(int) - *maxSize = defaultInterQueryBuiltinValueCacheSize - c.InterQueryBuiltinValueCache.MaxNumEntries = maxSize - } else { - numEntries := *c.InterQueryBuiltinValueCache.MaxNumEntries - if numEntries < 0 { - return fmt.Errorf("invalid max_num_entries %v", numEntries) - } - } - - return nil + return v1.ParseCachingConfig(raw) } // InterQueryCacheValue defines the interface for the data that the inter-query cache holds. -type InterQueryCacheValue interface { - SizeInBytes() int64 - Clone() (InterQueryCacheValue, error) -} +type InterQueryCacheValue = v1.InterQueryCacheValue // InterQueryCache defines the interface for the inter-query cache. -type InterQueryCache interface { - Get(key ast.Value) (value InterQueryCacheValue, found bool) - Insert(key ast.Value, value InterQueryCacheValue) int - InsertWithExpiry(key ast.Value, value InterQueryCacheValue, expiresAt time.Time) int - Delete(key ast.Value) - UpdateConfig(config *Config) - Clone(value InterQueryCacheValue) (InterQueryCacheValue, error) -} +type InterQueryCache = v1.InterQueryCache // NewInterQueryCache returns a new inter-query cache. // The cache uses a FIFO eviction policy when it reaches the forced eviction threshold. @@ -139,7 +41,7 @@ type InterQueryCache interface { // // config - to configure the InterQueryCache func NewInterQueryCache(config *Config) InterQueryCache { - return newCache(config) + return v1.NewInterQueryCache(config) } // NewInterQueryCacheWithContext returns a new inter-query cache with context. @@ -152,255 +54,11 @@ func NewInterQueryCache(config *Config) InterQueryCache { // ctx - used to control lifecycle of the stale entry cleanup routine // config - to configure the InterQueryCache func NewInterQueryCacheWithContext(ctx context.Context, config *Config) InterQueryCache { - iqCache := newCache(config) - if iqCache.staleEntryEvictionTimePeriodSeconds() > 0 { - cleanupTicker := time.NewTicker(time.Duration(iqCache.staleEntryEvictionTimePeriodSeconds()) * time.Second) - go func() { - for { - select { - case <-cleanupTicker.C: - cleanupTicker.Stop() - iqCache.cleanStaleValues() - cleanupTicker = time.NewTicker(time.Duration(iqCache.staleEntryEvictionTimePeriodSeconds()) * time.Second) - case <-ctx.Done(): - cleanupTicker.Stop() - return - } - } - }() - } - - return iqCache -} - -type cacheItem struct { - value InterQueryCacheValue - expiresAt time.Time - keyElement *list.Element + return v1.NewInterQueryCacheWithContext(ctx, config) } -type cache struct { - items map[string]cacheItem - usage int64 - config *Config - l *list.List - mtx sync.Mutex -} - -func newCache(config *Config) *cache { - return &cache{ - items: map[string]cacheItem{}, - usage: 0, - config: config, - l: list.New(), - } -} - -// InsertWithExpiry inserts a key k into the cache with value v with an expiration time expiresAt. -// A zero time value for expiresAt indicates no expiry -func (c *cache) InsertWithExpiry(k ast.Value, v InterQueryCacheValue, expiresAt time.Time) (dropped int) { - c.mtx.Lock() - defer c.mtx.Unlock() - return c.unsafeInsert(k, v, expiresAt) -} - -// Insert inserts a key k into the cache with value v with no expiration time. -func (c *cache) Insert(k ast.Value, v InterQueryCacheValue) (dropped int) { - return c.InsertWithExpiry(k, v, time.Time{}) -} - -// Get returns the value in the cache for k. -func (c *cache) Get(k ast.Value) (InterQueryCacheValue, bool) { - c.mtx.Lock() - defer c.mtx.Unlock() - cacheItem, ok := c.unsafeGet(k) - - if ok { - return cacheItem.value, true - } - return nil, false -} - -// Delete deletes the value in the cache for k. -func (c *cache) Delete(k ast.Value) { - c.mtx.Lock() - defer c.mtx.Unlock() - c.unsafeDelete(k) -} - -func (c *cache) UpdateConfig(config *Config) { - if config == nil { - return - } - c.mtx.Lock() - defer c.mtx.Unlock() - c.config = config -} - -func (c *cache) Clone(value InterQueryCacheValue) (InterQueryCacheValue, error) { - c.mtx.Lock() - defer c.mtx.Unlock() - return c.unsafeClone(value) -} - -func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue, expiresAt time.Time) (dropped int) { - size := v.SizeInBytes() - limit := int64(math.Ceil(float64(c.forcedEvictionThresholdPercentage())/100.0) * (float64(c.maxSizeBytes()))) - if limit > 0 { - if size > limit { - dropped++ - return dropped - } - - for key := c.l.Front(); key != nil && (c.usage+size > limit); key = c.l.Front() { - dropKey := key.Value.(ast.Value) - c.unsafeDelete(dropKey) - dropped++ - } - } - - // By deleting the old value, if it exists, we ensure the usage variable stays correct - c.unsafeDelete(k) - - c.items[k.String()] = cacheItem{ - value: v, - expiresAt: expiresAt, - keyElement: c.l.PushBack(k), - } - c.usage += size - return dropped -} - -func (c *cache) unsafeGet(k ast.Value) (cacheItem, bool) { - value, ok := c.items[k.String()] - return value, ok -} - -func (c *cache) unsafeDelete(k ast.Value) { - cacheItem, ok := c.unsafeGet(k) - if !ok { - return - } - - c.usage -= cacheItem.value.SizeInBytes() - delete(c.items, k.String()) - c.l.Remove(cacheItem.keyElement) -} - -func (c *cache) unsafeClone(value InterQueryCacheValue) (InterQueryCacheValue, error) { - return value.Clone() -} - -func (c *cache) maxSizeBytes() int64 { - if c.config == nil { - return defaultMaxSizeBytes - } - return *c.config.InterQueryBuiltinCache.MaxSizeBytes -} - -func (c *cache) forcedEvictionThresholdPercentage() int64 { - if c.config == nil { - return defaultForcedEvictionThresholdPercentage - } - return *c.config.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage -} - -func (c *cache) staleEntryEvictionTimePeriodSeconds() int64 { - if c.config == nil { - return defaultStaleEntryEvictionPeriodSeconds - } - return *c.config.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds -} - -func (c *cache) cleanStaleValues() (dropped int) { - c.mtx.Lock() - defer c.mtx.Unlock() - for key := c.l.Front(); key != nil; { - nextKey := key.Next() - // if expiresAt is zero, the item doesn't have an expiry - if ea := c.items[(key.Value.(ast.Value)).String()].expiresAt; !ea.IsZero() && ea.Before(time.Now()) { - c.unsafeDelete(key.Value.(ast.Value)) - dropped++ - } - key = nextKey - } - return dropped -} - -type InterQueryValueCache interface { - Get(key ast.Value) (value any, found bool) - Insert(key ast.Value, value any) int - Delete(key ast.Value) - UpdateConfig(config *Config) -} - -type interQueryValueCache struct { - items map[string]any - config *Config - mtx sync.RWMutex -} - -// Get returns the value in the cache for k. -func (c *interQueryValueCache) Get(k ast.Value) (any, bool) { - c.mtx.RLock() - defer c.mtx.RUnlock() - value, ok := c.items[k.String()] - return value, ok -} - -// Insert inserts a key k into the cache with value v. -func (c *interQueryValueCache) Insert(k ast.Value, v any) (dropped int) { - c.mtx.Lock() - defer c.mtx.Unlock() - - maxEntries := c.maxNumEntries() - if maxEntries > 0 { - if len(c.items) >= maxEntries { - itemsToRemove := len(c.items) - maxEntries + 1 - - // Delete a (semi-)random key to make room for the new one. - for k := range c.items { - delete(c.items, k) - dropped++ - - if itemsToRemove == dropped { - break - } - } - } - } - - c.items[k.String()] = v - return dropped -} - -// Delete deletes the value in the cache for k. -func (c *interQueryValueCache) Delete(k ast.Value) { - c.mtx.Lock() - defer c.mtx.Unlock() - delete(c.items, k.String()) -} - -// UpdateConfig updates the cache config. -func (c *interQueryValueCache) UpdateConfig(config *Config) { - if config == nil { - return - } - c.mtx.Lock() - defer c.mtx.Unlock() - c.config = config -} - -func (c *interQueryValueCache) maxNumEntries() int { - if c.config == nil { - return defaultInterQueryBuiltinValueCacheSize - } - return *c.config.InterQueryBuiltinValueCache.MaxNumEntries -} +type InterQueryValueCache = v1.InterQueryValueCache -func NewInterQueryValueCache(_ context.Context, config *Config) InterQueryValueCache { - return &interQueryValueCache{ - items: map[string]any{}, - config: config, - } +func NewInterQueryValueCache(ctx context.Context, config *Config) InterQueryValueCache { + return v1.NewInterQueryValueCache(ctx, config) } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/cache/doc.go b/vendor/github.com/open-policy-agent/opa/topdown/cache/doc.go new file mode 100644 index 000000000..640530c08 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/topdown/cache/doc.go @@ -0,0 +1,8 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package cache diff --git a/vendor/github.com/open-policy-agent/opa/topdown/cancel.go b/vendor/github.com/open-policy-agent/opa/topdown/cancel.go index 534e0799a..395a14a80 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/cancel.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/cancel.go @@ -5,29 +5,14 @@ package topdown import ( - "sync/atomic" + v1 "github.com/open-policy-agent/opa/v1/topdown" ) // Cancel defines the interface for cancelling topdown queries. Cancel // operations are thread-safe and idempotent. -type Cancel interface { - Cancel() - Cancelled() bool -} - -type cancel struct { - flag int32 -} +type Cancel = v1.Cancel // NewCancel returns a new Cancel object. func NewCancel() Cancel { - return &cancel{} -} - -func (c *cancel) Cancel() { - atomic.StoreInt32(&c.flag, 1) -} - -func (c *cancel) Cancelled() bool { - return atomic.LoadInt32(&c.flag) != 0 + return v1.NewCancel() } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/doc.go b/vendor/github.com/open-policy-agent/opa/topdown/doc.go index 9aa7aa45c..a303ef788 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/doc.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/doc.go @@ -7,4 +7,8 @@ // The topdown implementation is a modified version of the standard top-down // evaluation algorithm used in Datalog. References and comprehensions are // evaluated eagerly while all other terms are evaluated lazily. +// +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. package topdown diff --git a/vendor/github.com/open-policy-agent/opa/topdown/errors.go b/vendor/github.com/open-policy-agent/opa/topdown/errors.go index 918df6c85..47853ec6d 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/errors.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/errors.go @@ -5,145 +5,50 @@ package topdown import ( - "errors" - "fmt" - - "github.com/open-policy-agent/opa/ast" + v1 "github.com/open-policy-agent/opa/v1/topdown" ) // Halt is a special error type that built-in function implementations return to indicate // that policy evaluation should stop immediately. -type Halt struct { - Err error -} - -func (h Halt) Error() string { - return h.Err.Error() -} - -func (h Halt) Unwrap() error { return h.Err } +type Halt = v1.Halt // Error is the error type returned by the Eval and Query functions when // an evaluation error occurs. -type Error struct { - Code string `json:"code"` - Message string `json:"message"` - Location *ast.Location `json:"location,omitempty"` - err error `json:"-"` -} +type Error = v1.Error const ( // InternalErr represents an unknown evaluation error. - InternalErr string = "eval_internal_error" + InternalErr = v1.InternalErr // CancelErr indicates the evaluation process was cancelled. - CancelErr string = "eval_cancel_error" + CancelErr = v1.CancelErr // ConflictErr indicates a conflict was encountered during evaluation. For // instance, a conflict occurs if a rule produces multiple, differing values // for the same key in an object. Conflict errors indicate the policy does // not account for the data loaded into the policy engine. - ConflictErr string = "eval_conflict_error" + ConflictErr = v1.ConflictErr // TypeErr indicates evaluation stopped because an expression was applied to // a value of an inappropriate type. - TypeErr string = "eval_type_error" + TypeErr = v1.TypeErr // BuiltinErr indicates a built-in function received a semantically invalid // input or encountered some kind of runtime error, e.g., connection // timeout, connection refused, etc. - BuiltinErr string = "eval_builtin_error" + BuiltinErr = v1.BuiltinErr // WithMergeErr indicates that the real and replacement data could not be merged. - WithMergeErr string = "eval_with_merge_error" + WithMergeErr = v1.WithMergeErr ) // IsError returns true if the err is an Error. func IsError(err error) bool { - var e *Error - return errors.As(err, &e) + return v1.IsError(err) } // IsCancel returns true if err was caused by cancellation. func IsCancel(err error) bool { - return errors.Is(err, &Error{Code: CancelErr}) -} - -// Is allows matching topdown errors using errors.Is (see IsCancel). -func (e *Error) Is(target error) bool { - var t *Error - if errors.As(target, &t) { - return (t.Code == "" || e.Code == t.Code) && - (t.Message == "" || e.Message == t.Message) && - (t.Location == nil || t.Location.Compare(e.Location) == 0) - } - return false -} - -func (e *Error) Error() string { - msg := fmt.Sprintf("%v: %v", e.Code, e.Message) - - if e.Location != nil { - msg = e.Location.String() + ": " + msg - } - - return msg -} - -func (e *Error) Wrap(err error) *Error { - e.err = err - return e -} - -func (e *Error) Unwrap() error { - return e.err -} - -func functionConflictErr(loc *ast.Location) error { - return &Error{ - Code: ConflictErr, - Location: loc, - Message: "functions must not produce multiple outputs for same inputs", - } -} - -func completeDocConflictErr(loc *ast.Location) error { - return &Error{ - Code: ConflictErr, - Location: loc, - Message: "complete rules must not produce multiple outputs", - } -} - -func objectDocKeyConflictErr(loc *ast.Location) error { - return &Error{ - Code: ConflictErr, - Location: loc, - Message: "object keys must be unique", - } -} - -func unsupportedBuiltinErr(loc *ast.Location) error { - return &Error{ - Code: InternalErr, - Location: loc, - Message: "unsupported built-in", - } -} - -func mergeConflictErr(loc *ast.Location) error { - return &Error{ - Code: WithMergeErr, - Location: loc, - Message: "real and replacement data could not be merged", - } -} - -func internalErr(loc *ast.Location, msg string) error { - return &Error{ - Code: InternalErr, - Location: loc, - Message: msg, - } + return v1.IsCancel(err) } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/graphql.go b/vendor/github.com/open-policy-agent/opa/topdown/graphql.go index 8fb1b58a7..0d6ebda0a 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/graphql.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/graphql.go @@ -16,8 +16,8 @@ import ( // Side-effecting import. Triggers GraphQL library's validation rule init() functions. _ "github.com/open-policy-agent/opa/internal/gqlparser/validator/rules" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) // Parses a GraphQL schema, and returns the GraphQL AST for the schema. @@ -295,7 +295,7 @@ func builtinGraphQLParseAndVerify(_ BuiltinContext, operands []*ast.Term, iter f var err error unverified := ast.ArrayTerm( - ast.BooleanTerm(false), + ast.InternedBooleanTerm(false), ast.NewTerm(ast.NewObject()), ast.NewTerm(ast.NewObject()), ) @@ -353,7 +353,7 @@ func builtinGraphQLParseAndVerify(_ BuiltinContext, operands []*ast.Term, iter f // Construct return value. verified := ast.ArrayTerm( - ast.BooleanTerm(true), + ast.InternedBooleanTerm(true), ast.NewTerm(queryResult), ast.NewTerm(querySchema), ) @@ -421,10 +421,10 @@ func builtinGraphQLIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*as queryDoc, err = objectToQueryDocument(x) default: // Error if wrong type. - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } switch x := operands[1].Value.(type) { @@ -434,23 +434,23 @@ func builtinGraphQLIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*as schemaDoc, err = objectToSchemaDocument(x) default: // Error if wrong type. - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } // Validate the query against the schema, erroring if there's an issue. schema, err := convertSchema(schemaDoc) if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } if err := validateQuery(schema, queryDoc); err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } // If we got this far, the GraphQL query passed validation. - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) } func builtinGraphQLSchemaIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -464,15 +464,15 @@ func builtinGraphQLSchemaIsValid(_ BuiltinContext, operands []*ast.Term, iter fu schemaDoc, err = objectToSchemaDocument(x) default: // Error if wrong type. - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } // Validate the schema, this determines the result _, err = convertSchema(schemaDoc) - return iter(ast.BooleanTerm(err == nil)) + return iter(ast.InternedBooleanTerm(err == nil)) } func init() { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/http.go b/vendor/github.com/open-policy-agent/opa/topdown/http.go index 18bfd3c72..693ea4048 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/http.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/http.go @@ -5,1616 +5,13 @@ package topdown import ( - "bytes" - "context" - "crypto/tls" - "crypto/x509" - "encoding/json" - "fmt" - "io" - "math" - "net" - "net/http" - "net/url" - "os" - "runtime" - "strconv" - "strings" - "time" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/internal/version" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/tracing" - "github.com/open-policy-agent/opa/util" -) - -type cachingMode string - -const ( - defaultHTTPRequestTimeoutEnv = "HTTP_SEND_TIMEOUT" - defaultCachingMode cachingMode = "serialized" - cachingModeDeserialized cachingMode = "deserialized" -) - -var defaultHTTPRequestTimeout = time.Second * 5 - -var allowedKeyNames = [...]string{ - "method", - "url", - "body", - "enable_redirect", - "force_json_decode", - "force_yaml_decode", - "headers", - "raw_body", - "tls_use_system_certs", - "tls_ca_cert", - "tls_ca_cert_file", - "tls_ca_cert_env_variable", - "tls_client_cert", - "tls_client_cert_file", - "tls_client_cert_env_variable", - "tls_client_key", - "tls_client_key_file", - "tls_client_key_env_variable", - "tls_insecure_skip_verify", - "tls_server_name", - "timeout", - "cache", - "force_cache", - "force_cache_duration_seconds", - "raise_error", - "caching_mode", - "max_retry_attempts", - "cache_ignored_headers", -} - -// ref: https://www.rfc-editor.org/rfc/rfc7231#section-6.1 -var cacheableHTTPStatusCodes = [...]int{ - http.StatusOK, - http.StatusNonAuthoritativeInfo, - http.StatusNoContent, - http.StatusPartialContent, - http.StatusMultipleChoices, - http.StatusMovedPermanently, - http.StatusNotFound, - http.StatusMethodNotAllowed, - http.StatusGone, - http.StatusRequestURITooLong, - http.StatusNotImplemented, -} - -var ( - allowedKeys = ast.NewSet() - cacheableCodes = ast.NewSet() - requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url")) - httpSendLatencyMetricKey = "rego_builtin_" + strings.ReplaceAll(ast.HTTPSend.Name, ".", "_") - httpSendInterQueryCacheHits = httpSendLatencyMetricKey + "_interquery_cache_hits" + v1 "github.com/open-policy-agent/opa/v1/topdown" ) -type httpSendKey string - const ( - // httpSendBuiltinCacheKey is the key in the builtin context cache that - // points to the http.send() specific cache resides at. - httpSendBuiltinCacheKey httpSendKey = "HTTP_SEND_CACHE_KEY" - // HTTPSendInternalErr represents a runtime evaluation error. - HTTPSendInternalErr string = "eval_http_send_internal_error" + HTTPSendInternalErr = v1.HTTPSendInternalErr // HTTPSendNetworkErr represents a network error. - HTTPSendNetworkErr string = "eval_http_send_network_error" - - // minRetryDelay is amount of time to backoff after the first failure. - minRetryDelay = time.Millisecond * 100 - - // maxRetryDelay is the upper bound of backoff delay. - maxRetryDelay = time.Second * 60 + HTTPSendNetworkErr = v1.HTTPSendNetworkErr ) - -func builtinHTTPSend(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - - obj, err := builtins.ObjectOperand(operands[0].Value, 1) - if err != nil { - return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) - } - - raiseError, err := getRaiseErrorValue(obj) - if err != nil { - return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) - } - - req, err := validateHTTPRequestOperand(operands[0], 1) - if err != nil { - if raiseError { - return handleHTTPSendErr(bctx, err) - } - - return iter(generateRaiseErrorResult(handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err))) - } - - result, err := getHTTPResponse(bctx, req) - if err != nil { - if raiseError { - return handleHTTPSendErr(bctx, err) - } - - result = generateRaiseErrorResult(err) - } - return iter(result) -} - -func generateRaiseErrorResult(err error) *ast.Term { - obj := ast.NewObject() - obj.Insert(ast.StringTerm("status_code"), ast.IntNumberTerm(0)) - - errObj := ast.NewObject() - - switch err.(type) { - case *url.Error: - errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendNetworkErr)) - default: - errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendInternalErr)) - } - - errObj.Insert(ast.StringTerm("message"), ast.StringTerm(err.Error())) - obj.Insert(ast.StringTerm("error"), ast.NewTerm(errObj)) - - return ast.NewTerm(obj) -} - -func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) { - - bctx.Metrics.Timer(httpSendLatencyMetricKey).Start() - defer bctx.Metrics.Timer(httpSendLatencyMetricKey).Stop() - - key, err := getKeyFromRequest(req) - if err != nil { - return nil, err - } - - reqExecutor, err := newHTTPRequestExecutor(bctx, req, key) - if err != nil { - return nil, err - } - // Check if cache already has a response for this query - // set headers to exclude cache_ignored_headers - resp, err := reqExecutor.CheckCache() - if err != nil { - return nil, err - } - - if resp == nil { - httpResp, err := reqExecutor.ExecuteHTTPRequest() - if err != nil { - reqExecutor.InsertErrorIntoCache(err) - return nil, err - } - defer util.Close(httpResp) - // Add result to intra/inter-query cache. - resp, err = reqExecutor.InsertIntoCache(httpResp) - if err != nil { - return nil, err - } - } - - return ast.NewTerm(resp), nil -} - -// getKeyFromRequest returns a key to be used for caching HTTP responses -// deletes headers from request object mentioned in cache_ignored_headers -func getKeyFromRequest(req ast.Object) (ast.Object, error) { - // deep copy so changes to key do not reflect in the request object - key := req.Copy() - cacheIgnoredHeadersTerm := req.Get(ast.StringTerm("cache_ignored_headers")) - allHeadersTerm := req.Get(ast.StringTerm("headers")) - // skip because no headers to delete - if cacheIgnoredHeadersTerm == nil || allHeadersTerm == nil { - // need to explicitly set cache_ignored_headers to null - // equivalent requests might have different sets of exclusion lists - key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm()) - return key, nil - } - var cacheIgnoredHeaders []string - var allHeaders map[string]interface{} - err := ast.As(cacheIgnoredHeadersTerm.Value, &cacheIgnoredHeaders) - if err != nil { - return nil, err - } - err = ast.As(allHeadersTerm.Value, &allHeaders) - if err != nil { - return nil, err - } - for _, header := range cacheIgnoredHeaders { - delete(allHeaders, header) - } - val, err := ast.InterfaceToValue(allHeaders) - if err != nil { - return nil, err - } - key.Insert(ast.StringTerm("headers"), ast.NewTerm(val)) - // remove cache_ignored_headers key - key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm()) - return key, nil -} - -func init() { - createAllowedKeys() - createCacheableHTTPStatusCodes() - initDefaults() - RegisterBuiltinFunc(ast.HTTPSend.Name, builtinHTTPSend) -} - -func handleHTTPSendErr(bctx BuiltinContext, err error) error { - // Return HTTP client timeout errors in a generic error message to avoid confusion about what happened. - // Do not do this if the builtin context was cancelled and is what caused the request to stop. - if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() && bctx.Context.Err() == nil { - err = fmt.Errorf("%s %s: request timed out", urlErr.Op, urlErr.URL) - } - if err := bctx.Context.Err(); err != nil { - return Halt{ - Err: &Error{ - Code: CancelErr, - Message: fmt.Sprintf("http.send: timed out (%s)", err.Error()), - }, - } - } - return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) -} - -func initDefaults() { - timeoutDuration := os.Getenv(defaultHTTPRequestTimeoutEnv) - if timeoutDuration != "" { - var err error - defaultHTTPRequestTimeout, err = time.ParseDuration(timeoutDuration) - if err != nil { - // If it is set to something not valid don't let the process continue in a state - // that will almost definitely give unexpected results by having it set at 0 - // which means no timeout.. - // This environment variable isn't considered part of the public API. - // TODO(patrick-east): Remove the environment variable - panic(fmt.Sprintf("invalid value for HTTP_SEND_TIMEOUT: %s", err)) - } - } -} - -func validateHTTPRequestOperand(term *ast.Term, pos int) (ast.Object, error) { - - obj, err := builtins.ObjectOperand(term.Value, pos) - if err != nil { - return nil, err - } - - requestKeys := ast.NewSet(obj.Keys()...) - - invalidKeys := requestKeys.Diff(allowedKeys) - if invalidKeys.Len() != 0 { - return nil, builtins.NewOperandErr(pos, "invalid request parameters(s): %v", invalidKeys) - } - - missingKeys := requiredKeys.Diff(requestKeys) - if missingKeys.Len() != 0 { - return nil, builtins.NewOperandErr(pos, "missing required request parameters(s): %v", missingKeys) - } - - return obj, nil - -} - -// canonicalizeHeaders returns a copy of the headers where the keys are in -// canonical HTTP form. -func canonicalizeHeaders(headers map[string]interface{}) map[string]interface{} { - canonicalized := map[string]interface{}{} - - for k, v := range headers { - canonicalized[http.CanonicalHeaderKey(k)] = v - } - - return canonicalized -} - -// useSocket examines the url for "unix://" and returns a *http.Transport with -// a DialContext that opens a socket (specified in the http call). -// The url is expected to contain socket=/path/to/socket (url encoded) -// Ex. "unix://localhost/end/point?socket=%2Ftmp%2Fhttp.sock" -func useSocket(rawURL string, tlsConfig *tls.Config) (bool, string, *http.Transport) { - u, err := url.Parse(rawURL) - if err != nil { - return false, "", nil - } - - if u.Scheme != "unix" || u.RawQuery == "" { - return false, rawURL, nil - } - - v, err := url.ParseQuery(u.RawQuery) - if err != nil { - return false, rawURL, nil - } - - // Rewrite URL targeting the UNIX domain socket. - u.Scheme = "http" - - // Extract the path to the socket. - // Only retrieve the first value. Subsequent values are ignored and removed - // to prevent HTTP parameter pollution. - socket := v.Get("socket") - v.Del("socket") - u.RawQuery = v.Encode() - - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - return http.DefaultTransport.(*http.Transport).DialContext(ctx, "unix", socket) - } - tr.TLSClientConfig = tlsConfig - tr.DisableKeepAlives = true - - return true, u.String(), tr -} - -func verifyHost(bctx BuiltinContext, host string) error { - if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil { - return nil - } - - for _, allowed := range bctx.Capabilities.AllowNet { - if allowed == host { - return nil - } - } - - return fmt.Errorf("unallowed host: %s", host) -} - -func verifyURLHost(bctx BuiltinContext, unverifiedURL string) error { - // Eager return to avoid unnecessary URL parsing - if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil { - return nil - } - - parsedURL, err := url.Parse(unverifiedURL) - if err != nil { - return err - } - - host := strings.Split(parsedURL.Host, ":")[0] - - return verifyHost(bctx, host) -} - -func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *http.Client, error) { - var url string - var method string - - // Additional CA certificates loading options. - var tlsCaCert []byte - var tlsCaCertEnvVar string - var tlsCaCertFile string - - // Client TLS certificate and key options. Each input source - // comes in a matched pair. - var tlsClientCert []byte - var tlsClientKey []byte - - var tlsClientCertEnvVar string - var tlsClientKeyEnvVar string - - var tlsClientCertFile string - var tlsClientKeyFile string - - var tlsServerName string - var body *bytes.Buffer - var rawBody *bytes.Buffer - var enableRedirect bool - var tlsUseSystemCerts *bool - var tlsConfig tls.Config - var customHeaders map[string]interface{} - var tlsInsecureSkipVerify bool - timeout := defaultHTTPRequestTimeout - - for _, val := range obj.Keys() { - key, err := ast.JSON(val.Value) - if err != nil { - return nil, nil, err - } - - key = key.(string) - - var strVal string - - if s, ok := obj.Get(val).Value.(ast.String); ok { - strVal = strings.Trim(string(s), "\"") - } else { - // Most parameters are strings, so consolidate the type checking. - switch key { - case "method", - "url", - "raw_body", - "tls_ca_cert", - "tls_ca_cert_file", - "tls_ca_cert_env_variable", - "tls_client_cert", - "tls_client_cert_file", - "tls_client_cert_env_variable", - "tls_client_key", - "tls_client_key_file", - "tls_client_key_env_variable", - "tls_server_name": - return nil, nil, fmt.Errorf("%q must be a string", key) - } - } - - switch key { - case "method": - method = strings.ToUpper(strVal) - case "url": - err := verifyURLHost(bctx, strVal) - if err != nil { - return nil, nil, err - } - url = strVal - case "enable_redirect": - enableRedirect, err = strconv.ParseBool(obj.Get(val).String()) - if err != nil { - return nil, nil, err - } - case "body": - bodyVal := obj.Get(val).Value - bodyValInterface, err := ast.JSON(bodyVal) - if err != nil { - return nil, nil, err - } - - bodyValBytes, err := json.Marshal(bodyValInterface) - if err != nil { - return nil, nil, err - } - body = bytes.NewBuffer(bodyValBytes) - case "raw_body": - rawBody = bytes.NewBufferString(strVal) - case "tls_use_system_certs": - tempTLSUseSystemCerts, err := strconv.ParseBool(obj.Get(val).String()) - if err != nil { - return nil, nil, err - } - tlsUseSystemCerts = &tempTLSUseSystemCerts - case "tls_ca_cert": - tlsCaCert = []byte(strVal) - case "tls_ca_cert_file": - tlsCaCertFile = strVal - case "tls_ca_cert_env_variable": - tlsCaCertEnvVar = strVal - case "tls_client_cert": - tlsClientCert = []byte(strVal) - case "tls_client_cert_file": - tlsClientCertFile = strVal - case "tls_client_cert_env_variable": - tlsClientCertEnvVar = strVal - case "tls_client_key": - tlsClientKey = []byte(strVal) - case "tls_client_key_file": - tlsClientKeyFile = strVal - case "tls_client_key_env_variable": - tlsClientKeyEnvVar = strVal - case "tls_server_name": - tlsServerName = strVal - case "headers": - headersVal := obj.Get(val).Value - headersValInterface, err := ast.JSON(headersVal) - if err != nil { - return nil, nil, err - } - var ok bool - customHeaders, ok = headersValInterface.(map[string]interface{}) - if !ok { - return nil, nil, fmt.Errorf("invalid type for headers key") - } - case "tls_insecure_skip_verify": - tlsInsecureSkipVerify, err = strconv.ParseBool(obj.Get(val).String()) - if err != nil { - return nil, nil, err - } - case "timeout": - timeout, err = parseTimeout(obj.Get(val).Value) - if err != nil { - return nil, nil, err - } - case "cache", "caching_mode", - "force_cache", "force_cache_duration_seconds", - "force_json_decode", "force_yaml_decode", - "raise_error", "max_retry_attempts", "cache_ignored_headers": // no-op - default: - return nil, nil, fmt.Errorf("invalid parameter %q", key) - } - } - - isTLS := false - client := &http.Client{ - Timeout: timeout, - CheckRedirect: func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - if tlsInsecureSkipVerify { - isTLS = true - tlsConfig.InsecureSkipVerify = tlsInsecureSkipVerify - } - - if len(tlsClientCert) > 0 && len(tlsClientKey) > 0 { - cert, err := tls.X509KeyPair(tlsClientCert, tlsClientKey) - if err != nil { - return nil, nil, err - } - - isTLS = true - tlsConfig.Certificates = append(tlsConfig.Certificates, cert) - } - - if tlsClientCertFile != "" && tlsClientKeyFile != "" { - cert, err := tls.LoadX509KeyPair(tlsClientCertFile, tlsClientKeyFile) - if err != nil { - return nil, nil, err - } - - isTLS = true - tlsConfig.Certificates = append(tlsConfig.Certificates, cert) - } - - if tlsClientCertEnvVar != "" && tlsClientKeyEnvVar != "" { - cert, err := tls.X509KeyPair( - []byte(os.Getenv(tlsClientCertEnvVar)), - []byte(os.Getenv(tlsClientKeyEnvVar))) - if err != nil { - return nil, nil, fmt.Errorf("cannot extract public/private key pair from envvars %q, %q: %w", - tlsClientCertEnvVar, tlsClientKeyEnvVar, err) - } - - isTLS = true - tlsConfig.Certificates = append(tlsConfig.Certificates, cert) - } - - // Use system certs if no CA cert is provided - // or system certs flag is not set - if len(tlsCaCert) == 0 && tlsCaCertFile == "" && tlsCaCertEnvVar == "" && tlsUseSystemCerts == nil { - trueValue := true - tlsUseSystemCerts = &trueValue - } - - // Check the system certificates config first so that we - // load additional certificated into the correct pool. - if tlsUseSystemCerts != nil && *tlsUseSystemCerts && runtime.GOOS != "windows" { - pool, err := x509.SystemCertPool() - if err != nil { - return nil, nil, err - } - - isTLS = true - tlsConfig.RootCAs = pool - } - - if len(tlsCaCert) != 0 { - tlsCaCert = bytes.Replace(tlsCaCert, []byte("\\n"), []byte("\n"), -1) - pool, err := addCACertsFromBytes(tlsConfig.RootCAs, tlsCaCert) - if err != nil { - return nil, nil, err - } - - isTLS = true - tlsConfig.RootCAs = pool - } - - if tlsCaCertFile != "" { - pool, err := addCACertsFromFile(tlsConfig.RootCAs, tlsCaCertFile) - if err != nil { - return nil, nil, err - } - - isTLS = true - tlsConfig.RootCAs = pool - } - - if tlsCaCertEnvVar != "" { - pool, err := addCACertsFromEnv(tlsConfig.RootCAs, tlsCaCertEnvVar) - if err != nil { - return nil, nil, err - } - - isTLS = true - tlsConfig.RootCAs = pool - } - - if isTLS { - if ok, parsedURL, tr := useSocket(url, &tlsConfig); ok { - client.Transport = tr - url = parsedURL - } else { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = &tlsConfig - tr.DisableKeepAlives = true - client.Transport = tr - } - } else { - if ok, parsedURL, tr := useSocket(url, nil); ok { - client.Transport = tr - url = parsedURL - } - } - - // check if redirects are enabled - if enableRedirect { - client.CheckRedirect = func(req *http.Request, _ []*http.Request) error { - return verifyURLHost(bctx, req.URL.String()) - } - } - - if rawBody != nil { - body = rawBody - } else if body == nil { - body = bytes.NewBufferString("") - } - - // create the http request, use the builtin context's context to ensure - // the request is cancelled if evaluation is cancelled. - req, err := http.NewRequest(method, url, body) - if err != nil { - return nil, nil, err - } - - req = req.WithContext(bctx.Context) - - // Add custom headers - if len(customHeaders) != 0 { - customHeaders = canonicalizeHeaders(customHeaders) - - for k, v := range customHeaders { - header, ok := v.(string) - if !ok { - return nil, nil, fmt.Errorf("invalid type for headers value %q", v) - } - - req.Header.Add(k, header) - } - - // Don't overwrite or append to one that was set in the custom headers - if _, hasUA := customHeaders["User-Agent"]; !hasUA { - req.Header.Add("User-Agent", version.UserAgent) - } - - // If the caller specifies the Host header, use it for the HTTP - // request host and the TLS server name. - if host, hasHost := customHeaders["Host"]; hasHost { - host := host.(string) // We already checked that it's a string. - req.Host = host - - // Only default the ServerName if the caller has - // specified the host. If we don't specify anything, - // Go will default to the target hostname. This name - // is not the same as the default that Go populates - // `req.Host` with, which is why we don't just set - // this unconditionally. - tlsConfig.ServerName = host - } - } - - if tlsServerName != "" { - tlsConfig.ServerName = tlsServerName - } - - if len(bctx.DistributedTracingOpts) > 0 { - client.Transport = tracing.NewTransport(client.Transport, bctx.DistributedTracingOpts) - } - - return req, client, nil -} - -func executeHTTPRequest(req *http.Request, client *http.Client, inputReqObj ast.Object) (*http.Response, error) { - var err error - var retry int - - retry, err = getNumberValFromReqObj(inputReqObj, ast.StringTerm("max_retry_attempts")) - if err != nil { - return nil, err - } - - for i := 0; true; i++ { - - var resp *http.Response - resp, err = client.Do(req) - if err == nil { - return resp, nil - } - - // final attempt - if i == retry { - break - } - - if err == context.Canceled { - return nil, err - } - - delay := util.DefaultBackoff(float64(minRetryDelay), float64(maxRetryDelay), i) - timer, timerCancel := util.TimerWithCancel(delay) - select { - case <-timer.C: - case <-req.Context().Done(): - timerCancel() // explicitly cancel the timer. - return nil, context.Canceled - } - } - return nil, err -} - -func isContentType(header http.Header, typ ...string) bool { - for _, t := range typ { - if strings.Contains(header.Get("Content-Type"), t) { - return true - } - } - return false -} - -type httpSendCacheEntry struct { - response *ast.Value - error error -} - -// The httpSendCache is used for intra-query caching of http.send results. -type httpSendCache struct { - entries *util.HashMap -} - -func newHTTPSendCache() *httpSendCache { - return &httpSendCache{ - entries: util.NewHashMap(valueEq, valueHash), - } -} - -func valueHash(v util.T) int { - return ast.StringTerm(v.(ast.Value).String()).Hash() -} - -func valueEq(a, b util.T) bool { - av := a.(ast.Value) - bv := b.(ast.Value) - return av.String() == bv.String() -} - -func (cache *httpSendCache) get(k ast.Value) *httpSendCacheEntry { - if v, ok := cache.entries.Get(k); ok { - v := v.(httpSendCacheEntry) - return &v - } - return nil -} - -func (cache *httpSendCache) putResponse(k ast.Value, v *ast.Value) { - cache.entries.Put(k, httpSendCacheEntry{response: v}) -} - -func (cache *httpSendCache) putError(k ast.Value, v error) { - cache.entries.Put(k, httpSendCacheEntry{error: v}) -} - -// In the BuiltinContext cache we only store a single entry that points to -// our ValueMap which is the "real" http.send() cache. -func getHTTPSendCache(bctx BuiltinContext) *httpSendCache { - raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey) - if !ok { - // Initialize if it isn't there - c := newHTTPSendCache() - bctx.Cache.Put(httpSendBuiltinCacheKey, c) - return c - } - - c, ok := raw.(*httpSendCache) - if !ok { - return nil - } - return c -} - -// checkHTTPSendCache checks for the given key's value in the cache -func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) (ast.Value, error) { - requestCache := getHTTPSendCache(bctx) - if requestCache == nil { - return nil, nil - } - - v := requestCache.get(key) - if v != nil { - if v.error != nil { - return nil, v.error - } - if v.response != nil { - return *v.response, nil - } - // This should never happen - } - - return nil, nil -} - -func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) { - requestCache := getHTTPSendCache(bctx) - if requestCache == nil { - // Should never happen.. if it does just skip caching the value - // FIXME: return error instead, to prevent inconsistencies? - return - } - requestCache.putResponse(key, &value) -} - -func insertErrorIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, err error) { - requestCache := getHTTPSendCache(bctx) - if requestCache == nil { - // Should never happen.. if it does just skip caching the value - // FIXME: return error instead, to prevent inconsistencies? - return - } - requestCache.putError(key, err) -} - -// checkHTTPSendInterQueryCache checks for the given key's value in the inter-query cache -func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) { - requestCache := c.bctx.InterQueryBuiltinCache - - cachedValue, found := requestCache.Get(c.key) - if !found { - return nil, nil - } - - value, cerr := requestCache.Clone(cachedValue) - if cerr != nil { - return nil, handleHTTPSendErr(c.bctx, cerr) - } - - c.bctx.Metrics.Counter(httpSendInterQueryCacheHits).Incr() - var cachedRespData *interQueryCacheData - - switch v := value.(type) { - case *interQueryCacheValue: - var err error - cachedRespData, err = v.copyCacheData() - if err != nil { - return nil, err - } - case *interQueryCacheData: - cachedRespData = v - default: - return nil, nil - } - - if getCurrentTime(c.bctx).Before(cachedRespData.ExpiresAt) { - return cachedRespData.formatToAST(c.forceJSONDecode, c.forceYAMLDecode) - } - - var err error - c.httpReq, c.httpClient, err = createHTTPRequest(c.bctx, c.key) - if err != nil { - return nil, handleHTTPSendErr(c.bctx, err) - } - - headers := parseResponseHeaders(cachedRespData.Headers) - - // check with the server if the stale response is still up-to-date. - // If server returns a new response (ie. status_code=200), update the cache with the new response - // If server returns an unmodified response (ie. status_code=304), update the headers for the existing response - result, modified, err := revalidateCachedResponse(c.httpReq, c.httpClient, c.key, headers) - requestCache.Delete(c.key) - if err != nil || result == nil { - return nil, err - } - - defer result.Body.Close() - - if !modified { - // update the headers in the cached response with their corresponding values from the 304 (Not Modified) response - for headerName, values := range result.Header { - cachedRespData.Headers.Del(headerName) - for _, v := range values { - cachedRespData.Headers.Add(headerName, v) - } - } - - if forceCaching(c.forceCacheParams) { - createdAt := getCurrentTime(c.bctx) - cachedRespData.ExpiresAt = createdAt.Add(time.Second * time.Duration(c.forceCacheParams.forceCacheDurationSeconds)) - } else { - expiresAt, err := expiryFromHeaders(result.Header) - if err != nil { - return nil, err - } - cachedRespData.ExpiresAt = expiresAt - } - - cachingMode, err := getCachingMode(c.key) - if err != nil { - return nil, err - } - - var pcv cache.InterQueryCacheValue - - if cachingMode == defaultCachingMode { - pcv, err = cachedRespData.toCacheValue() - if err != nil { - return nil, err - } - } else { - pcv = cachedRespData - } - - c.bctx.InterQueryBuiltinCache.InsertWithExpiry(c.key, pcv, cachedRespData.ExpiresAt) - - return cachedRespData.formatToAST(c.forceJSONDecode, c.forceYAMLDecode) - } - - newValue, respBody, err := formatHTTPResponseToAST(result, c.forceJSONDecode, c.forceYAMLDecode) - if err != nil { - return nil, err - } - - if err := insertIntoHTTPSendInterQueryCache(c.bctx, c.key, result, respBody, c.forceCacheParams); err != nil { - return nil, err - } - - return newValue, nil -} - -// insertIntoHTTPSendInterQueryCache inserts given key and value in the inter-query cache -func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) error { - if resp == nil || (!forceCaching(cacheParams) && !canStore(resp.Header)) || !cacheableCodes.Contains(ast.IntNumberTerm(resp.StatusCode)) { - return nil - } - - requestCache := bctx.InterQueryBuiltinCache - - obj, ok := key.(ast.Object) - if !ok { - return fmt.Errorf("interface conversion error") - } - - cachingMode, err := getCachingMode(obj) - if err != nil { - return err - } - - var pcv cache.InterQueryCacheValue - var pcvData *interQueryCacheData - if cachingMode == defaultCachingMode { - pcv, pcvData, err = newInterQueryCacheValue(bctx, resp, respBody, cacheParams) - } else { - pcvData, err = newInterQueryCacheData(bctx, resp, respBody, cacheParams) - pcv = pcvData - } - - if err != nil { - return err - } - - requestCache.InsertWithExpiry(key, pcv, pcvData.ExpiresAt) - return nil -} - -func createAllowedKeys() { - for _, element := range allowedKeyNames { - allowedKeys.Add(ast.StringTerm(element)) - } -} - -func createCacheableHTTPStatusCodes() { - for _, element := range cacheableHTTPStatusCodes { - cacheableCodes.Add(ast.IntNumberTerm(element)) - } -} - -func parseTimeout(timeoutVal ast.Value) (time.Duration, error) { - var timeout time.Duration - switch t := timeoutVal.(type) { - case ast.Number: - timeoutInt, ok := t.Int64() - if !ok { - return timeout, fmt.Errorf("invalid timeout number value %v, must be int64", timeoutVal) - } - return time.Duration(timeoutInt), nil - case ast.String: - // Support strings without a unit, treat them the same as just a number value (ns) - var err error - timeoutInt, err := strconv.ParseInt(string(t), 10, 64) - if err == nil { - return time.Duration(timeoutInt), nil - } - - // Try parsing it as a duration (requires a supported units suffix) - timeout, err = time.ParseDuration(string(t)) - if err != nil { - return timeout, fmt.Errorf("invalid timeout value %v: %s", timeoutVal, err) - } - return timeout, nil - default: - return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.TypeName(t)) - } -} - -func getBoolValFromReqObj(req ast.Object, key *ast.Term) (bool, error) { - var b ast.Boolean - var ok bool - if v := req.Get(key); v != nil { - if b, ok = v.Value.(ast.Boolean); !ok { - return false, fmt.Errorf("invalid value for %v field", key.String()) - } - } - return bool(b), nil -} - -func getNumberValFromReqObj(req ast.Object, key *ast.Term) (int, error) { - term := req.Get(key) - if term == nil { - return 0, nil - } - - if t, ok := term.Value.(ast.Number); ok { - num, ok := t.Int() - if !ok || num < 0 { - return 0, fmt.Errorf("invalid value %v for field %v", t.String(), key.String()) - } - return num, nil - } - - return 0, fmt.Errorf("invalid value %v for field %v", term.String(), key.String()) -} - -func getCachingMode(req ast.Object) (cachingMode, error) { - key := ast.StringTerm("caching_mode") - var s ast.String - var ok bool - if v := req.Get(key); v != nil { - if s, ok = v.Value.(ast.String); !ok { - return "", fmt.Errorf("invalid value for %v field", key.String()) - } - - switch cachingMode(s) { - case defaultCachingMode, cachingModeDeserialized: - return cachingMode(s), nil - default: - return "", fmt.Errorf("invalid value specified for %v field: %v", key.String(), string(s)) - } - } - return defaultCachingMode, nil -} - -type interQueryCacheValue struct { - Data []byte -} - -func newInterQueryCacheValue(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheValue, *interQueryCacheData, error) { - data, err := newInterQueryCacheData(bctx, resp, respBody, cacheParams) - if err != nil { - return nil, nil, err - } - - b, err := json.Marshal(data) - if err != nil { - return nil, nil, err - } - return &interQueryCacheValue{Data: b}, data, nil -} - -func (cb interQueryCacheValue) Clone() (cache.InterQueryCacheValue, error) { - dup := make([]byte, len(cb.Data)) - copy(dup, cb.Data) - return &interQueryCacheValue{Data: dup}, nil -} - -func (cb interQueryCacheValue) SizeInBytes() int64 { - return int64(len(cb.Data)) -} - -func (cb *interQueryCacheValue) copyCacheData() (*interQueryCacheData, error) { - var res interQueryCacheData - err := util.UnmarshalJSON(cb.Data, &res) - if err != nil { - return nil, err - } - return &res, nil -} - -type interQueryCacheData struct { - RespBody []byte - Status string - StatusCode int - Headers http.Header - ExpiresAt time.Time -} - -func forceCaching(cacheParams *forceCacheParams) bool { - return cacheParams != nil && cacheParams.forceCacheDurationSeconds > 0 -} - -func expiryFromHeaders(headers http.Header) (time.Time, error) { - var expiresAt time.Time - maxAge, err := parseMaxAgeCacheDirective(parseCacheControlHeader(headers)) - if err != nil { - return time.Time{}, err - } - if maxAge != -1 { - createdAt, err := getResponseHeaderDate(headers) - if err != nil { - return time.Time{}, err - } - expiresAt = createdAt.Add(time.Second * time.Duration(maxAge)) - } else { - expiresAt = getResponseHeaderExpires(headers) - } - return expiresAt, nil -} - -func newInterQueryCacheData(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheData, error) { - var expiresAt time.Time - - if forceCaching(cacheParams) { - createdAt := getCurrentTime(bctx) - expiresAt = createdAt.Add(time.Second * time.Duration(cacheParams.forceCacheDurationSeconds)) - } else { - var err error - expiresAt, err = expiryFromHeaders(resp.Header) - if err != nil { - return nil, err - } - } - - cv := interQueryCacheData{ - ExpiresAt: expiresAt, - RespBody: respBody, - Status: resp.Status, - StatusCode: resp.StatusCode, - Headers: resp.Header} - - return &cv, nil -} - -func (c *interQueryCacheData) formatToAST(forceJSONDecode, forceYAMLDecode bool) (ast.Value, error) { - return prepareASTResult(c.Headers, forceJSONDecode, forceYAMLDecode, c.RespBody, c.Status, c.StatusCode) -} - -func (c *interQueryCacheData) toCacheValue() (*interQueryCacheValue, error) { - b, err := json.Marshal(c) - if err != nil { - return nil, err - } - return &interQueryCacheValue{Data: b}, nil -} - -func (c *interQueryCacheData) SizeInBytes() int64 { - return 0 -} - -func (c *interQueryCacheData) Clone() (cache.InterQueryCacheValue, error) { - dup := make([]byte, len(c.RespBody)) - copy(dup, c.RespBody) - - return &interQueryCacheData{ - ExpiresAt: c.ExpiresAt, - RespBody: dup, - Status: c.Status, - StatusCode: c.StatusCode, - Headers: c.Headers.Clone()}, nil -} - -type responseHeaders struct { - etag string // identifier for a specific version of the response - lastModified string // date and time response was last modified as per origin server -} - -// deltaSeconds specifies a non-negative integer, representing -// time in seconds: http://tools.ietf.org/html/rfc7234#section-1.2.1 -type deltaSeconds int32 - -func parseResponseHeaders(headers http.Header) *responseHeaders { - result := responseHeaders{} - - result.etag = headers.Get("etag") - - result.lastModified = headers.Get("last-modified") - - return &result -} - -func revalidateCachedResponse(req *http.Request, client *http.Client, inputReqObj ast.Object, headers *responseHeaders) (*http.Response, bool, error) { - etag := headers.etag - lastModified := headers.lastModified - - if etag == "" && lastModified == "" { - return nil, false, nil - } - - cloneReq := req.Clone(req.Context()) - - if etag != "" { - cloneReq.Header.Set("if-none-match", etag) - } - - if lastModified != "" { - cloneReq.Header.Set("if-modified-since", lastModified) - } - - response, err := executeHTTPRequest(cloneReq, client, inputReqObj) - if err != nil { - return nil, false, err - } - - switch response.StatusCode { - case http.StatusOK: - return response, true, nil - - case http.StatusNotModified: - return response, false, nil - } - util.Close(response) - return nil, false, nil -} - -func canStore(headers http.Header) bool { - ccHeaders := parseCacheControlHeader(headers) - - // Check "no-store" cache directive - // The "no-store" response directive indicates that a cache MUST NOT - // store any part of either the immediate request or response. - if _, ok := ccHeaders["no-store"]; ok { - return false - } - return true -} - -func getCurrentTime(bctx BuiltinContext) time.Time { - var current time.Time - - value, err := ast.JSON(bctx.Time.Value) - if err != nil { - return current - } - - valueNum, ok := value.(json.Number) - if !ok { - return current - } - - valueNumInt, err := valueNum.Int64() - if err != nil { - return current - } - - current = time.Unix(0, valueNumInt).UTC() - return current -} - -func parseCacheControlHeader(headers http.Header) map[string]string { - ccDirectives := map[string]string{} - ccHeader := headers.Get("cache-control") - - for _, part := range strings.Split(ccHeader, ",") { - part = strings.Trim(part, " ") - if part == "" { - continue - } - if strings.ContainsRune(part, '=') { - items := strings.Split(part, "=") - if len(items) != 2 { - continue - } - ccDirectives[strings.Trim(items[0], " ")] = strings.Trim(items[1], ",") - } else { - ccDirectives[part] = "" - } - } - - return ccDirectives -} - -func getResponseHeaderDate(headers http.Header) (date time.Time, err error) { - dateHeader := headers.Get("date") - if dateHeader == "" { - err = fmt.Errorf("no date header") - return - } - return http.ParseTime(dateHeader) -} - -func getResponseHeaderExpires(headers http.Header) time.Time { - expiresHeader := headers.Get("expires") - if expiresHeader == "" { - return time.Time{} - } - - date, err := http.ParseTime(expiresHeader) - if err != nil { - // servers can set `Expires: 0` which is an invalid date to indicate expired content - return time.Time{} - } - - return date -} - -// parseMaxAgeCacheDirective parses the max-age directive expressed in delta-seconds as per -// https://tools.ietf.org/html/rfc7234#section-1.2.1 -func parseMaxAgeCacheDirective(cc map[string]string) (deltaSeconds, error) { - maxAge, ok := cc["max-age"] - if !ok { - return deltaSeconds(-1), nil - } - - val, err := strconv.ParseUint(maxAge, 10, 32) - if err != nil { - if numError, ok := err.(*strconv.NumError); ok { - if numError.Err == strconv.ErrRange { - return deltaSeconds(math.MaxInt32), nil - } - } - return deltaSeconds(-1), err - } - - if val > math.MaxInt32 { - return deltaSeconds(math.MaxInt32), nil - } - return deltaSeconds(val), nil -} - -func formatHTTPResponseToAST(resp *http.Response, forceJSONDecode, forceYAMLDecode bool) (ast.Value, []byte, error) { - - resultRawBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, err - } - - resultObj, err := prepareASTResult(resp.Header, forceJSONDecode, forceYAMLDecode, resultRawBody, resp.Status, resp.StatusCode) - if err != nil { - return nil, nil, err - } - - return resultObj, resultRawBody, nil -} - -func prepareASTResult(headers http.Header, forceJSONDecode, forceYAMLDecode bool, body []byte, status string, statusCode int) (ast.Value, error) { - var resultBody interface{} - - // If the response body cannot be JSON/YAML decoded, - // an error will not be returned. Instead, the "body" field - // in the result will be null. - switch { - case forceJSONDecode || isContentType(headers, "application/json"): - _ = util.UnmarshalJSON(body, &resultBody) - case forceYAMLDecode || isContentType(headers, "application/yaml", "application/x-yaml"): - _ = util.Unmarshal(body, &resultBody) - } - - result := make(map[string]interface{}) - result["status"] = status - result["status_code"] = statusCode - result["body"] = resultBody - result["raw_body"] = string(body) - result["headers"] = getResponseHeaders(headers) - - resultObj, err := ast.InterfaceToValue(result) - if err != nil { - return nil, err - } - - return resultObj, nil -} - -func getResponseHeaders(headers http.Header) map[string]interface{} { - respHeaders := map[string]interface{}{} - for headerName, values := range headers { - var respValues []interface{} - for _, v := range values { - respValues = append(respValues, v) - } - respHeaders[strings.ToLower(headerName)] = respValues - } - return respHeaders -} - -// httpRequestExecutor defines an interface for the http send cache -type httpRequestExecutor interface { - CheckCache() (ast.Value, error) - InsertIntoCache(value *http.Response) (ast.Value, error) - InsertErrorIntoCache(err error) - ExecuteHTTPRequest() (*http.Response, error) -} - -// newHTTPRequestExecutor returns a new HTTP request executor that wraps either an inter-query or -// intra-query cache implementation -func newHTTPRequestExecutor(bctx BuiltinContext, req ast.Object, key ast.Object) (httpRequestExecutor, error) { - useInterQueryCache, forceCacheParams, err := useInterQueryCache(req) - if err != nil { - return nil, handleHTTPSendErr(bctx, err) - } - - if useInterQueryCache && bctx.InterQueryBuiltinCache != nil { - return newInterQueryCache(bctx, req, key, forceCacheParams) - } - return newIntraQueryCache(bctx, req, key) -} - -type interQueryCache struct { - bctx BuiltinContext - req ast.Object - key ast.Object - httpReq *http.Request - httpClient *http.Client - forceJSONDecode bool - forceYAMLDecode bool - forceCacheParams *forceCacheParams -} - -func newInterQueryCache(bctx BuiltinContext, req ast.Object, key ast.Object, forceCacheParams *forceCacheParams) (*interQueryCache, error) { - return &interQueryCache{bctx: bctx, req: req, key: key, forceCacheParams: forceCacheParams}, nil -} - -// CheckCache checks the cache for the value of the key set on this object -func (c *interQueryCache) CheckCache() (ast.Value, error) { - var err error - - // Checking the intra-query cache first ensures consistency of errors and HTTP responses within a query. - resp, err := checkHTTPSendCache(c.bctx, c.key) - if err != nil { - return nil, err - } - if resp != nil { - return resp, nil - } - - c.forceJSONDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode")) - if err != nil { - return nil, handleHTTPSendErr(c.bctx, err) - } - c.forceYAMLDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode")) - if err != nil { - return nil, handleHTTPSendErr(c.bctx, err) - } - - resp, err = c.checkHTTPSendInterQueryCache() - // Always insert the result of the inter-query cache into the intra-query cache, to maintain consistency within the same query. - if err != nil { - insertErrorIntoHTTPSendCache(c.bctx, c.key, err) - } - if resp != nil { - insertIntoHTTPSendCache(c.bctx, c.key, resp) - } - return resp, err -} - -// InsertIntoCache inserts the key set on this object into the cache with the given value -func (c *interQueryCache) InsertIntoCache(value *http.Response) (ast.Value, error) { - result, respBody, err := formatHTTPResponseToAST(value, c.forceJSONDecode, c.forceYAMLDecode) - if err != nil { - return nil, handleHTTPSendErr(c.bctx, err) - } - - // Always insert into the intra-query cache, to maintain consistency within the same query. - insertIntoHTTPSendCache(c.bctx, c.key, result) - - // We ignore errors when populating the inter-query cache, because we've already populated the intra-cache, - // and query consistency is our primary concern. - _ = insertIntoHTTPSendInterQueryCache(c.bctx, c.key, value, respBody, c.forceCacheParams) - return result, nil -} - -func (c *interQueryCache) InsertErrorIntoCache(err error) { - insertErrorIntoHTTPSendCache(c.bctx, c.key, err) -} - -// ExecuteHTTPRequest executes a HTTP request -func (c *interQueryCache) ExecuteHTTPRequest() (*http.Response, error) { - var err error - c.httpReq, c.httpClient, err = createHTTPRequest(c.bctx, c.req) - if err != nil { - return nil, handleHTTPSendErr(c.bctx, err) - } - - return executeHTTPRequest(c.httpReq, c.httpClient, c.req) -} - -type intraQueryCache struct { - bctx BuiltinContext - req ast.Object - key ast.Object -} - -func newIntraQueryCache(bctx BuiltinContext, req ast.Object, key ast.Object) (*intraQueryCache, error) { - return &intraQueryCache{bctx: bctx, req: req, key: key}, nil -} - -// CheckCache checks the cache for the value of the key set on this object -func (c *intraQueryCache) CheckCache() (ast.Value, error) { - return checkHTTPSendCache(c.bctx, c.key) -} - -// InsertIntoCache inserts the key set on this object into the cache with the given value -func (c *intraQueryCache) InsertIntoCache(value *http.Response) (ast.Value, error) { - forceJSONDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode")) - if err != nil { - return nil, handleHTTPSendErr(c.bctx, err) - } - forceYAMLDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode")) - if err != nil { - return nil, handleHTTPSendErr(c.bctx, err) - } - - result, _, err := formatHTTPResponseToAST(value, forceJSONDecode, forceYAMLDecode) - if err != nil { - return nil, handleHTTPSendErr(c.bctx, err) - } - - if cacheableCodes.Contains(ast.IntNumberTerm(value.StatusCode)) { - insertIntoHTTPSendCache(c.bctx, c.key, result) - } - - return result, nil -} - -func (c *intraQueryCache) InsertErrorIntoCache(err error) { - insertErrorIntoHTTPSendCache(c.bctx, c.key, err) -} - -// ExecuteHTTPRequest executes a HTTP request -func (c *intraQueryCache) ExecuteHTTPRequest() (*http.Response, error) { - httpReq, httpClient, err := createHTTPRequest(c.bctx, c.req) - if err != nil { - return nil, handleHTTPSendErr(c.bctx, err) - } - return executeHTTPRequest(httpReq, httpClient, c.req) -} - -func useInterQueryCache(req ast.Object) (bool, *forceCacheParams, error) { - value, err := getBoolValFromReqObj(req, ast.StringTerm("cache")) - if err != nil { - return false, nil, err - } - - valueForceCache, err := getBoolValFromReqObj(req, ast.StringTerm("force_cache")) - if err != nil { - return false, nil, err - } - - if valueForceCache { - forceCacheParams, err := newForceCacheParams(req) - return true, forceCacheParams, err - } - - return value, nil, nil -} - -type forceCacheParams struct { - forceCacheDurationSeconds int32 -} - -func newForceCacheParams(req ast.Object) (*forceCacheParams, error) { - term := req.Get(ast.StringTerm("force_cache_duration_seconds")) - if term == nil { - return nil, fmt.Errorf("'force_cache' set but 'force_cache_duration_seconds' parameter is missing") - } - - forceCacheDurationSeconds := term.String() - - value, err := strconv.ParseInt(forceCacheDurationSeconds, 10, 32) - if err != nil { - return nil, err - } - - return &forceCacheParams{forceCacheDurationSeconds: int32(value)}, nil -} - -func getRaiseErrorValue(req ast.Object) (bool, error) { - result := ast.Boolean(true) - var ok bool - if v := req.Get(ast.StringTerm("raise_error")); v != nil { - if result, ok = v.Value.(ast.Boolean); !ok { - return false, fmt.Errorf("invalid value for raise_error field") - } - } - return bool(result), nil -} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/instrumentation.go b/vendor/github.com/open-policy-agent/opa/topdown/instrumentation.go index 6eacc338e..845f8da61 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/instrumentation.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/instrumentation.go @@ -4,60 +4,18 @@ package topdown -import "github.com/open-policy-agent/opa/metrics" - -const ( - evalOpPlug = "eval_op_plug" - evalOpResolve = "eval_op_resolve" - evalOpRuleIndex = "eval_op_rule_index" - evalOpBuiltinCall = "eval_op_builtin_call" - evalOpVirtualCacheHit = "eval_op_virtual_cache_hit" - evalOpVirtualCacheMiss = "eval_op_virtual_cache_miss" - evalOpBaseCacheHit = "eval_op_base_cache_hit" - evalOpBaseCacheMiss = "eval_op_base_cache_miss" - evalOpComprehensionCacheSkip = "eval_op_comprehension_cache_skip" - evalOpComprehensionCacheBuild = "eval_op_comprehension_cache_build" - evalOpComprehensionCacheHit = "eval_op_comprehension_cache_hit" - evalOpComprehensionCacheMiss = "eval_op_comprehension_cache_miss" - partialOpSaveUnify = "partial_op_save_unify" - partialOpSaveSetContains = "partial_op_save_set_contains" - partialOpSaveSetContainsRec = "partial_op_save_set_contains_rec" - partialOpCopyPropagation = "partial_op_copy_propagation" +import ( + "github.com/open-policy-agent/opa/v1/metrics" + v1 "github.com/open-policy-agent/opa/v1/topdown" ) // Instrumentation implements helper functions to instrument query evaluation // to diagnose performance issues. Instrumentation may be expensive in some // cases, so it is disabled by default. -type Instrumentation struct { - m metrics.Metrics -} +type Instrumentation = v1.Instrumentation // NewInstrumentation returns a new Instrumentation object. Performance // diagnostics recorded on this Instrumentation object will stored in m. func NewInstrumentation(m metrics.Metrics) *Instrumentation { - return &Instrumentation{ - m: m, - } -} - -func (instr *Instrumentation) startTimer(name string) { - if instr == nil { - return - } - instr.m.Timer(name).Start() -} - -func (instr *Instrumentation) stopTimer(name string) { - if instr == nil { - return - } - delta := instr.m.Timer(name).Stop() - instr.m.Histogram(name).Update(delta) -} - -func (instr *Instrumentation) counterIncr(name string) { - if instr == nil { - return - } - instr.m.Counter(name).Incr() + return v1.NewInstrumentation(m) } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/print.go b/vendor/github.com/open-policy-agent/opa/topdown/print.go index 765b344b3..5eacd180d 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/print.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/print.go @@ -5,82 +5,12 @@ package topdown import ( - "fmt" "io" - "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" "github.com/open-policy-agent/opa/topdown/print" + v1 "github.com/open-policy-agent/opa/v1/topdown" ) func NewPrintHook(w io.Writer) print.Hook { - return printHook{w: w} -} - -type printHook struct { - w io.Writer -} - -func (h printHook) Print(_ print.Context, msg string) error { - _, err := fmt.Fprintln(h.w, msg) - return err -} - -func builtinPrint(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - - if bctx.PrintHook == nil { - return iter(nil) - } - - arr, err := builtins.ArrayOperand(operands[0].Value, 1) - if err != nil { - return err - } - - buf := make([]string, arr.Len()) - - err = builtinPrintCrossProductOperands(bctx, buf, arr, 0, func(buf []string) error { - pctx := print.Context{ - Context: bctx.Context, - Location: bctx.Location, - } - return bctx.PrintHook.Print(pctx, strings.Join(buf, " ")) - }) - if err != nil { - return err - } - - return iter(nil) -} - -func builtinPrintCrossProductOperands(bctx BuiltinContext, buf []string, operands *ast.Array, i int, f func([]string) error) error { - - if i >= operands.Len() { - return f(buf) - } - - xs, ok := operands.Elem(i).Value.(ast.Set) - if !ok { - return Halt{Err: internalErr(bctx.Location, fmt.Sprintf("illegal argument type: %v", ast.TypeName(operands.Elem(i).Value)))} - } - - if xs.Len() == 0 { - buf[i] = "" - return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f) - } - - return xs.Iter(func(x *ast.Term) error { - switch v := x.Value.(type) { - case ast.String: - buf[i] = string(v) - default: - buf[i] = v.String() - } - return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f) - }) -} - -func init() { - RegisterBuiltinFunc(ast.InternalPrint.Name, builtinPrint) + return v1.NewPrintHook(w) } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/print/doc.go b/vendor/github.com/open-policy-agent/opa/topdown/print/doc.go new file mode 100644 index 000000000..c2ee0eca7 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/topdown/print/doc.go @@ -0,0 +1,8 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package print diff --git a/vendor/github.com/open-policy-agent/opa/topdown/print/print.go b/vendor/github.com/open-policy-agent/opa/topdown/print/print.go index 0fb6abdca..66ffbb176 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/print/print.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/print/print.go @@ -1,21 +1,14 @@ package print import ( - "context" - - "github.com/open-policy-agent/opa/ast" + v1 "github.com/open-policy-agent/opa/v1/topdown/print" ) // Context provides the Hook implementation context about the print() call. -type Context struct { - Context context.Context // request context passed when query executed - Location *ast.Location // location of print call -} +type Context = v1.Context // Hook defines the interface that callers can implement to receive print // statement outputs. If the hook returns an error, it will be surfaced if // strict builtin error checking is enabled (otherwise, it will not halt // execution.) -type Hook interface { - Print(Context, string) error -} +type Hook = v1.Hook diff --git a/vendor/github.com/open-policy-agent/opa/topdown/query.go b/vendor/github.com/open-policy-agent/opa/topdown/query.go index 8406cfdd8..d24060991 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/query.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/query.go @@ -1,599 +1,24 @@ package topdown import ( - "context" - "crypto/rand" - "io" - "sort" - "time" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/resolver" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/copypropagation" - "github.com/open-policy-agent/opa/topdown/print" - "github.com/open-policy-agent/opa/tracing" + "github.com/open-policy-agent/opa/v1/ast" + v1 "github.com/open-policy-agent/opa/v1/topdown" ) // QueryResultSet represents a collection of results returned by a query. -type QueryResultSet []QueryResult +type QueryResultSet = v1.QueryResultSet // QueryResult represents a single result returned by a query. The result // contains bindings for all variables that appear in the query. -type QueryResult map[ast.Var]*ast.Term +type QueryResult = v1.QueryResult // Query provides a configurable interface for performing query evaluation. -type Query struct { - seed io.Reader - time time.Time - cancel Cancel - query ast.Body - queryCompiler ast.QueryCompiler - compiler *ast.Compiler - store storage.Store - txn storage.Transaction - input *ast.Term - external *resolverTrie - tracers []QueryTracer - plugTraceVars bool - unknowns []*ast.Term - partialNamespace string - skipSaveNamespace bool - metrics metrics.Metrics - instr *Instrumentation - disableInlining []ast.Ref - shallowInlining bool - genvarprefix string - runtime *ast.Term - builtins map[string]*Builtin - indexing bool - earlyExit bool - interQueryBuiltinCache cache.InterQueryCache - interQueryBuiltinValueCache cache.InterQueryValueCache - ndBuiltinCache builtins.NDBCache - strictBuiltinErrors bool - builtinErrorList *[]Error - strictObjects bool - printHook print.Hook - tracingOpts tracing.Options - virtualCache VirtualCache -} +type Query = v1.Query // Builtin represents a built-in function that queries can call. -type Builtin struct { - Decl *ast.Builtin - Func BuiltinFunc -} +type Builtin = v1.Builtin // NewQuery returns a new Query object that can be run. func NewQuery(query ast.Body) *Query { - return &Query{ - query: query, - genvarprefix: ast.WildcardPrefix, - indexing: true, - earlyExit: true, - external: newResolverTrie(), - } -} - -// WithQueryCompiler sets the queryCompiler used for the query. -func (q *Query) WithQueryCompiler(queryCompiler ast.QueryCompiler) *Query { - q.queryCompiler = queryCompiler - return q -} - -// WithCompiler sets the compiler to use for the query. -func (q *Query) WithCompiler(compiler *ast.Compiler) *Query { - q.compiler = compiler - return q -} - -// WithStore sets the store to use for the query. -func (q *Query) WithStore(store storage.Store) *Query { - q.store = store - return q -} - -// WithTransaction sets the transaction to use for the query. All queries -// should be performed over a consistent snapshot of the storage layer. -func (q *Query) WithTransaction(txn storage.Transaction) *Query { - q.txn = txn - return q -} - -// WithCancel sets the cancellation object to use for the query. Set this if -// you need to abort queries based on a deadline. This is optional. -func (q *Query) WithCancel(cancel Cancel) *Query { - q.cancel = cancel - return q -} - -// WithInput sets the input object to use for the query. References rooted at -// input will be evaluated against this value. This is optional. -func (q *Query) WithInput(input *ast.Term) *Query { - q.input = input - return q -} - -// WithTracer adds a query tracer to use during evaluation. This is optional. -// Deprecated: Use WithQueryTracer instead. -func (q *Query) WithTracer(tracer Tracer) *Query { - qt, ok := tracer.(QueryTracer) - if !ok { - qt = WrapLegacyTracer(tracer) - } - return q.WithQueryTracer(qt) -} - -// WithQueryTracer adds a query tracer to use during evaluation. This is optional. -// Disabled QueryTracers will be ignored. -func (q *Query) WithQueryTracer(tracer QueryTracer) *Query { - if !tracer.Enabled() { - return q - } - - q.tracers = append(q.tracers, tracer) - - // If *any* of the tracers require local variable metadata we need to - // enabled plugging local trace variables. - conf := tracer.Config() - if conf.PlugLocalVars { - q.plugTraceVars = true - } - - return q -} - -// WithMetrics sets the metrics collection to add evaluation metrics to. This -// is optional. -func (q *Query) WithMetrics(m metrics.Metrics) *Query { - q.metrics = m - return q -} - -// WithInstrumentation sets the instrumentation configuration to enable on the -// evaluation process. By default, instrumentation is turned off. -func (q *Query) WithInstrumentation(instr *Instrumentation) *Query { - q.instr = instr - return q -} - -// WithUnknowns sets the initial set of variables or references to treat as -// unknown during query evaluation. This is required for partial evaluation. -func (q *Query) WithUnknowns(terms []*ast.Term) *Query { - q.unknowns = terms - return q -} - -// WithPartialNamespace sets the namespace to use for supporting rules -// generated as part of the partial evaluation process. The ns value must be a -// valid package path component. -func (q *Query) WithPartialNamespace(ns string) *Query { - q.partialNamespace = ns - return q -} - -// WithSkipPartialNamespace disables namespacing of saved support rules that are generated -// from the original policy (rules which are completely synthetic are still namespaced.) -func (q *Query) WithSkipPartialNamespace(yes bool) *Query { - q.skipSaveNamespace = yes - return q -} - -// WithDisableInlining adds a set of paths to the query that should be excluded from -// inlining. Inlining during partial evaluation can be expensive in some cases -// (e.g., when a cross-product is computed.) Disabling inlining avoids expensive -// computation at the cost of generating support rules. -func (q *Query) WithDisableInlining(paths []ast.Ref) *Query { - q.disableInlining = paths - return q -} - -// WithShallowInlining disables aggressive inlining performed during partial evaluation. -// When shallow inlining is enabled rules that depend (transitively) on unknowns are not inlined. -// Only rules/values that are completely known will be inlined. -func (q *Query) WithShallowInlining(yes bool) *Query { - q.shallowInlining = yes - return q -} - -// WithRuntime sets the runtime data to execute the query with. The runtime data -// can be returned by the `opa.runtime` built-in function. -func (q *Query) WithRuntime(runtime *ast.Term) *Query { - q.runtime = runtime - return q -} - -// WithBuiltins adds a set of built-in functions that can be called by the -// query. -func (q *Query) WithBuiltins(builtins map[string]*Builtin) *Query { - q.builtins = builtins - return q -} - -// WithIndexing will enable or disable using rule indexing for the evaluation -// of the query. The default is enabled. -func (q *Query) WithIndexing(enabled bool) *Query { - q.indexing = enabled - return q -} - -// WithEarlyExit will enable or disable using 'early exit' for the evaluation -// of the query. The default is enabled. -func (q *Query) WithEarlyExit(enabled bool) *Query { - q.earlyExit = enabled - return q -} - -// WithSeed sets a reader that will seed randomization required by built-in functions. -// If a seed is not provided crypto/rand.Reader is used. -func (q *Query) WithSeed(r io.Reader) *Query { - q.seed = r - return q -} - -// WithTime sets the time that will be returned by the time.now_ns() built-in function. -func (q *Query) WithTime(x time.Time) *Query { - q.time = x - return q -} - -// WithInterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize. -func (q *Query) WithInterQueryBuiltinCache(c cache.InterQueryCache) *Query { - q.interQueryBuiltinCache = c - return q -} - -// WithInterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize. -func (q *Query) WithInterQueryBuiltinValueCache(c cache.InterQueryValueCache) *Query { - q.interQueryBuiltinValueCache = c - return q -} - -// WithNDBuiltinCache sets the non-deterministic builtin cache. -func (q *Query) WithNDBuiltinCache(c builtins.NDBCache) *Query { - q.ndBuiltinCache = c - return q -} - -// WithStrictBuiltinErrors tells the evaluator to treat all built-in function errors as fatal errors. -func (q *Query) WithStrictBuiltinErrors(yes bool) *Query { - q.strictBuiltinErrors = yes - return q -} - -// WithBuiltinErrorList supplies a pointer to an Error slice to store built-in function errors -// encountered during evaluation. This error slice can be inspected after evaluation to determine -// which built-in function errors occurred. -func (q *Query) WithBuiltinErrorList(list *[]Error) *Query { - q.builtinErrorList = list - return q -} - -// WithResolver configures an external resolver to use for the given ref. -func (q *Query) WithResolver(ref ast.Ref, r resolver.Resolver) *Query { - q.external.Put(ref, r) - return q -} - -func (q *Query) WithPrintHook(h print.Hook) *Query { - q.printHook = h - return q -} - -// WithDistributedTracingOpts sets the options to be used by distributed tracing. -func (q *Query) WithDistributedTracingOpts(tr tracing.Options) *Query { - q.tracingOpts = tr - return q -} - -// WithStrictObjects tells the evaluator to avoid the "lazy object" optimization -// applied when reading objects from the store. It will result in higher memory -// usage and should only be used temporarily while adjusting code that breaks -// because of the optimization. -func (q *Query) WithStrictObjects(yes bool) *Query { - q.strictObjects = yes - return q -} - -// WithVirtualCache sets the VirtualCache to use during evaluation. This is -// optional, and if not set, the default cache is used. -func (q *Query) WithVirtualCache(vc VirtualCache) *Query { - q.virtualCache = vc - return q -} - -// PartialRun executes partial evaluation on the query with respect to unknown -// values. Partial evaluation attempts to evaluate as much of the query as -// possible without requiring values for the unknowns set on the query. The -// result of partial evaluation is a new set of queries that can be evaluated -// once the unknown value is known. In addition to new queries, partial -// evaluation may produce additional support modules that should be used in -// conjunction with the partially evaluated queries. -func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []*ast.Module, err error) { - if q.partialNamespace == "" { - q.partialNamespace = "partial" // lazily initialize partial namespace - } - if q.seed == nil { - q.seed = rand.Reader - } - if !q.time.IsZero() { - q.time = time.Now() - } - if q.metrics == nil { - q.metrics = metrics.New() - } - - f := &queryIDFactory{} - b := newBindings(0, q.instr) - - var vc VirtualCache - if q.virtualCache != nil { - vc = q.virtualCache - } else { - vc = NewVirtualCache() - } - - e := &eval{ - ctx: ctx, - metrics: q.metrics, - seed: q.seed, - time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), - cancel: q.cancel, - query: q.query, - queryCompiler: q.queryCompiler, - queryIDFact: f, - queryID: f.Next(), - bindings: b, - compiler: q.compiler, - store: q.store, - baseCache: newBaseCache(), - targetStack: newRefStack(), - txn: q.txn, - input: q.input, - external: q.external, - tracers: q.tracers, - traceEnabled: len(q.tracers) > 0, - plugTraceVars: q.plugTraceVars, - instr: q.instr, - builtins: q.builtins, - builtinCache: builtins.Cache{}, - functionMocks: newFunctionMocksStack(), - interQueryBuiltinCache: q.interQueryBuiltinCache, - interQueryBuiltinValueCache: q.interQueryBuiltinValueCache, - ndBuiltinCache: q.ndBuiltinCache, - virtualCache: vc, - comprehensionCache: newComprehensionCache(), - saveSet: newSaveSet(q.unknowns, b, q.instr), - saveStack: newSaveStack(), - saveSupport: newSaveSupport(), - saveNamespace: ast.StringTerm(q.partialNamespace), - skipSaveNamespace: q.skipSaveNamespace, - inliningControl: &inliningControl{ - shallow: q.shallowInlining, - }, - genvarprefix: q.genvarprefix, - runtime: q.runtime, - indexing: q.indexing, - earlyExit: q.earlyExit, - builtinErrors: &builtinErrors{}, - printHook: q.printHook, - strictObjects: q.strictObjects, - } - - if len(q.disableInlining) > 0 { - e.inliningControl.PushDisable(q.disableInlining, false) - } - - e.caller = e - q.metrics.Timer(metrics.RegoPartialEval).Start() - defer q.metrics.Timer(metrics.RegoPartialEval).Stop() - - livevars := ast.NewVarSet() - for _, t := range q.unknowns { - switch v := t.Value.(type) { - case ast.Var: - livevars.Add(v) - case ast.Ref: - livevars.Add(v[0].Value.(ast.Var)) - } - } - - ast.WalkVars(q.query, func(x ast.Var) bool { - if !x.IsGenerated() { - livevars.Add(x) - } - return false - }) - - p := copypropagation.New(livevars).WithCompiler(q.compiler) - - err = e.Run(func(e *eval) error { - - // Build output from saved expressions. - body := ast.NewBody() - - for _, elem := range e.saveStack.Stack[len(e.saveStack.Stack)-1] { - body.Append(elem.Plug(e.bindings)) - } - - // Include bindings as exprs so that when caller evals the result, they - // can obtain values for the vars in their query. - bindingExprs := []*ast.Expr{} - _ = e.bindings.Iter(e.bindings, func(a, b *ast.Term) error { - bindingExprs = append(bindingExprs, ast.Equality.Expr(a, b)) - return nil - }) // cannot return error - - // Sort binding expressions so that results are deterministic. - sort.Slice(bindingExprs, func(i, j int) bool { - return bindingExprs[i].Compare(bindingExprs[j]) < 0 - }) - - for i := range bindingExprs { - body.Append(bindingExprs[i]) - } - - // Skip this rule body if it fails to type-check. - // Type-checking failure means the rule body will never succeed. - if !e.compiler.PassesTypeCheck(body) { - return nil - } - - if !q.shallowInlining { - body = applyCopyPropagation(p, e.instr, body) - } - - partials = append(partials, body) - return nil - }) - - support = e.saveSupport.List() - - if len(e.builtinErrors.errs) > 0 { - if q.strictBuiltinErrors { - err = e.builtinErrors.errs[0] - } else if q.builtinErrorList != nil { - // If a builtinErrorList has been supplied, we must use pointer indirection - // to append to it. builtinErrorList is a slice pointer so that errors can be - // appended to it without returning a new slice and changing the interface - // of PartialRun. - for _, err := range e.builtinErrors.errs { - if tdError, ok := err.(*Error); ok { - *(q.builtinErrorList) = append(*(q.builtinErrorList), *tdError) - } else { - *(q.builtinErrorList) = append(*(q.builtinErrorList), Error{ - Code: BuiltinErr, - Message: err.Error(), - }) - } - } - } - } - - for i := range support { - sort.Slice(support[i].Rules, func(j, k int) bool { - return support[i].Rules[j].Compare(support[i].Rules[k]) < 0 - }) - } - - return partials, support, err -} - -// Run is a wrapper around Iter that accumulates query results and returns them -// in one shot. -func (q *Query) Run(ctx context.Context) (QueryResultSet, error) { - qrs := QueryResultSet{} - return qrs, q.Iter(ctx, func(qr QueryResult) error { - qrs = append(qrs, qr) - return nil - }) -} - -// Iter executes the query and invokes the iter function with query results -// produced by evaluating the query. -func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error { - // Query evaluation must not be allowed if the compiler has errors and is in an undefined, possibly inconsistent state - if q.compiler != nil && len(q.compiler.Errors) > 0 { - return &Error{ - Code: InternalErr, - Message: "compiler has errors", - } - } - - if q.seed == nil { - q.seed = rand.Reader - } - if q.time.IsZero() { - q.time = time.Now() - } - if q.metrics == nil { - q.metrics = metrics.New() - } - - f := &queryIDFactory{} - - var vc VirtualCache - if q.virtualCache != nil { - vc = q.virtualCache - } else { - vc = NewVirtualCache() - } - - e := &eval{ - ctx: ctx, - metrics: q.metrics, - seed: q.seed, - time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), - cancel: q.cancel, - query: q.query, - queryCompiler: q.queryCompiler, - queryIDFact: f, - queryID: f.Next(), - bindings: newBindings(0, q.instr), - compiler: q.compiler, - store: q.store, - baseCache: newBaseCache(), - targetStack: newRefStack(), - txn: q.txn, - input: q.input, - external: q.external, - tracers: q.tracers, - traceEnabled: len(q.tracers) > 0, - plugTraceVars: q.plugTraceVars, - instr: q.instr, - builtins: q.builtins, - builtinCache: builtins.Cache{}, - functionMocks: newFunctionMocksStack(), - interQueryBuiltinCache: q.interQueryBuiltinCache, - interQueryBuiltinValueCache: q.interQueryBuiltinValueCache, - ndBuiltinCache: q.ndBuiltinCache, - virtualCache: vc, - comprehensionCache: newComprehensionCache(), - genvarprefix: q.genvarprefix, - runtime: q.runtime, - indexing: q.indexing, - earlyExit: q.earlyExit, - builtinErrors: &builtinErrors{}, - printHook: q.printHook, - tracingOpts: q.tracingOpts, - strictObjects: q.strictObjects, - } - e.caller = e - q.metrics.Timer(metrics.RegoQueryEval).Start() - err := e.Run(func(e *eval) error { - qr := QueryResult{} - _ = e.bindings.Iter(nil, func(k, v *ast.Term) error { - qr[k.Value.(ast.Var)] = v - return nil - }) // cannot return error - return iter(qr) - }) - - if len(e.builtinErrors.errs) > 0 { - if q.strictBuiltinErrors { - err = e.builtinErrors.errs[0] - } else if q.builtinErrorList != nil { - // If a builtinErrorList has been supplied, we must use pointer indirection - // to append to it. builtinErrorList is a slice pointer so that errors can be - // appended to it without returning a new slice and changing the interface - // of Iter. - for _, err := range e.builtinErrors.errs { - if tdError, ok := err.(*Error); ok { - *(q.builtinErrorList) = append(*(q.builtinErrorList), *tdError) - } else { - *(q.builtinErrorList) = append(*(q.builtinErrorList), Error{ - Code: BuiltinErr, - Message: err.Error(), - }) - } - } - } - } - - q.metrics.Timer(metrics.RegoQueryEval).Stop() - return err + return v1.NewQuery(query) } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/trace.go b/vendor/github.com/open-policy-agent/opa/topdown/trace.go index 277c94b62..4d4cc295e 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/trace.go +++ b/vendor/github.com/open-policy-agent/opa/topdown/trace.go @@ -5,898 +5,108 @@ package topdown import ( - "bytes" - "fmt" "io" - "slices" - "strings" - iStrs "github.com/open-policy-agent/opa/internal/strings" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" -) - -const ( - minLocationWidth = 5 // len("query") - maxIdealLocationWidth = 64 - columnPadding = 4 - maxExprVarWidth = 32 - maxPrettyExprVarWidth = 64 + v1 "github.com/open-policy-agent/opa/v1/topdown" ) // Op defines the types of tracing events. -type Op string +type Op = v1.Op const ( // EnterOp is emitted when a new query is about to be evaluated. - EnterOp Op = "Enter" + EnterOp = v1.EnterOp // ExitOp is emitted when a query has evaluated to true. - ExitOp Op = "Exit" + ExitOp = v1.ExitOp // EvalOp is emitted when an expression is about to be evaluated. - EvalOp Op = "Eval" + EvalOp = v1.EvalOp // RedoOp is emitted when an expression, rule, or query is being re-evaluated. - RedoOp Op = "Redo" + RedoOp = v1.RedoOp // SaveOp is emitted when an expression is saved instead of evaluated // during partial evaluation. - SaveOp Op = "Save" + SaveOp = v1.SaveOp // FailOp is emitted when an expression evaluates to false. - FailOp Op = "Fail" + FailOp = v1.FailOp // DuplicateOp is emitted when a query has produced a duplicate value. The search // will stop at the point where the duplicate was emitted and backtrack. - DuplicateOp Op = "Duplicate" + DuplicateOp = v1.DuplicateOp // NoteOp is emitted when an expression invokes a tracing built-in function. - NoteOp Op = "Note" + NoteOp = v1.NoteOp // IndexOp is emitted during an expression evaluation to represent lookup // matches. - IndexOp Op = "Index" + IndexOp = v1.IndexOp // WasmOp is emitted when resolving a ref using an external // Resolver. - WasmOp Op = "Wasm" + WasmOp = v1.WasmOp // UnifyOp is emitted when two terms are unified. Node will be set to an // equality expression with the two terms. This Node will not have location // info. - UnifyOp Op = "Unify" - FailedAssertionOp Op = "FailedAssertion" + UnifyOp = v1.UnifyOp + FailedAssertionOp = v1.FailedAssertionOp ) // VarMetadata provides some user facing information about // a variable in some policy. -type VarMetadata struct { - Name ast.Var `json:"name"` - Location *ast.Location `json:"location"` -} +type VarMetadata = v1.VarMetadata // Event contains state associated with a tracing event. -type Event struct { - Op Op // Identifies type of event. - Node ast.Node // Contains AST node relevant to the event. - Location *ast.Location // The location of the Node this event relates to. - QueryID uint64 // Identifies the query this event belongs to. - ParentID uint64 // Identifies the parent query this event belongs to. - Locals *ast.ValueMap // Contains local variable bindings from the query context. Nil if variables were not included in the trace event. - LocalMetadata map[ast.Var]VarMetadata // Contains metadata for the local variable bindings. Nil if variables were not included in the trace event. - Message string // Contains message for Note events. - Ref *ast.Ref // Identifies the subject ref for the event. Only applies to Index and Wasm operations. - - input *ast.Term - bindings *bindings - localVirtualCacheSnapshot *ast.ValueMap -} - -func (evt *Event) WithInput(input *ast.Term) *Event { - evt.input = input - return evt -} - -// HasRule returns true if the Event contains an ast.Rule. -func (evt *Event) HasRule() bool { - _, ok := evt.Node.(*ast.Rule) - return ok -} - -// HasBody returns true if the Event contains an ast.Body. -func (evt *Event) HasBody() bool { - _, ok := evt.Node.(ast.Body) - return ok -} - -// HasExpr returns true if the Event contains an ast.Expr. -func (evt *Event) HasExpr() bool { - _, ok := evt.Node.(*ast.Expr) - return ok -} - -// Equal returns true if this event is equal to the other event. -func (evt *Event) Equal(other *Event) bool { - if evt.Op != other.Op { - return false - } - if evt.QueryID != other.QueryID { - return false - } - if evt.ParentID != other.ParentID { - return false - } - if !evt.equalNodes(other) { - return false - } - return evt.Locals.Equal(other.Locals) -} - -func (evt *Event) String() string { - return fmt.Sprintf("%v %v %v (qid=%v, pqid=%v)", evt.Op, evt.Node, evt.Locals, evt.QueryID, evt.ParentID) -} - -// Input returns the input object as it was at the event. -func (evt *Event) Input() *ast.Term { - return evt.input -} - -// Plug plugs event bindings into the provided ast.Term. Because bindings are mutable, this only makes sense to do when -// the event is emitted rather than on recorded trace events as the bindings are going to be different by then. -func (evt *Event) Plug(term *ast.Term) *ast.Term { - return evt.bindings.Plug(term) -} - -func (evt *Event) equalNodes(other *Event) bool { - switch a := evt.Node.(type) { - case ast.Body: - if b, ok := other.Node.(ast.Body); ok { - return a.Equal(b) - } - case *ast.Rule: - if b, ok := other.Node.(*ast.Rule); ok { - return a.Equal(b) - } - case *ast.Expr: - if b, ok := other.Node.(*ast.Expr); ok { - return a.Equal(b) - } - case nil: - return other.Node == nil - } - return false -} +type Event = v1.Event // Tracer defines the interface for tracing in the top-down evaluation engine. // Deprecated: Use QueryTracer instead. -type Tracer interface { - Enabled() bool - Trace(*Event) -} +type Tracer = v1.Tracer // QueryTracer defines the interface for tracing in the top-down evaluation engine. // The implementation can provide additional configuration to modify the tracing // behavior for query evaluations. -type QueryTracer interface { - Enabled() bool - TraceEvent(Event) - Config() TraceConfig -} +type QueryTracer = v1.QueryTracer // TraceConfig defines some common configuration for Tracer implementations -type TraceConfig struct { - PlugLocalVars bool // Indicate whether to plug local variable bindings before calling into the tracer. -} - -// legacyTracer Implements the QueryTracer interface by wrapping an older Tracer instance. -type legacyTracer struct { - t Tracer -} - -func (l *legacyTracer) Enabled() bool { - return l.t.Enabled() -} - -func (l *legacyTracer) Config() TraceConfig { - return TraceConfig{ - PlugLocalVars: true, // For backwards compatibility old tracers will plug local variables - } -} - -func (l *legacyTracer) TraceEvent(evt Event) { - l.t.Trace(&evt) -} +type TraceConfig = v1.TraceConfig // WrapLegacyTracer will create a new QueryTracer which wraps an // older Tracer instance. func WrapLegacyTracer(tracer Tracer) QueryTracer { - return &legacyTracer{t: tracer} + return v1.WrapLegacyTracer(tracer) } // BufferTracer implements the Tracer and QueryTracer interface by // simply buffering all events received. -type BufferTracer []*Event +type BufferTracer = v1.BufferTracer // NewBufferTracer returns a new BufferTracer. func NewBufferTracer() *BufferTracer { - return &BufferTracer{} -} - -// Enabled always returns true if the BufferTracer is instantiated. -func (b *BufferTracer) Enabled() bool { - return b != nil -} - -// Trace adds the event to the buffer. -// Deprecated: Use TraceEvent instead. -func (b *BufferTracer) Trace(evt *Event) { - *b = append(*b, evt) -} - -// TraceEvent adds the event to the buffer. -func (b *BufferTracer) TraceEvent(evt Event) { - *b = append(*b, &evt) -} - -// Config returns the Tracers standard configuration -func (b *BufferTracer) Config() TraceConfig { - return TraceConfig{PlugLocalVars: true} + return v1.NewBufferTracer() } // PrettyTrace pretty prints the trace to the writer. func PrettyTrace(w io.Writer, trace []*Event) { - PrettyTraceWithOpts(w, trace, PrettyTraceOptions{}) + v1.PrettyTrace(w, trace) } // PrettyTraceWithLocation prints the trace to the writer and includes location information func PrettyTraceWithLocation(w io.Writer, trace []*Event) { - PrettyTraceWithOpts(w, trace, PrettyTraceOptions{Locations: true}) -} - -type PrettyTraceOptions struct { - Locations bool // Include location information - ExprVariables bool // Include variables found in the expression - LocalVariables bool // Include all local variables -} - -type traceRow []string - -func (r *traceRow) add(s string) { - *r = append(*r, s) -} - -type traceTable struct { - rows []traceRow - maxWidths []int + v1.PrettyTraceWithLocation(w, trace) } -func (t *traceTable) add(row traceRow) { - t.rows = append(t.rows, row) - for i := range row { - if i >= len(t.maxWidths) { - t.maxWidths = append(t.maxWidths, len(row[i])) - } else if len(row[i]) > t.maxWidths[i] { - t.maxWidths[i] = len(row[i]) - } - } -} - -func (t *traceTable) write(w io.Writer, padding int) { - for _, row := range t.rows { - for i, cell := range row { - width := t.maxWidths[i] + padding - if i < len(row)-1 { - _, _ = fmt.Fprintf(w, "%-*s ", width, cell) - } else { - _, _ = fmt.Fprintf(w, "%s", cell) - } - } - _, _ = fmt.Fprintln(w) - } -} +type PrettyTraceOptions = v1.PrettyTraceOptions func PrettyTraceWithOpts(w io.Writer, trace []*Event, opts PrettyTraceOptions) { - depths := depths{} - - // FIXME: Can we shorten each location as we process each trace event instead of beforehand? - filePathAliases, _ := getShortenedFileNames(trace) - - table := traceTable{} - - for _, event := range trace { - depth := depths.GetOrSet(event.QueryID, event.ParentID) - row := traceRow{} - - if opts.Locations { - location := formatLocation(event, filePathAliases) - row.add(location) - } - - row.add(formatEvent(event, depth)) - - if opts.ExprVariables { - vars := exprLocalVars(event) - keys := sortedKeys(vars) - - buf := new(bytes.Buffer) - buf.WriteString("{") - for i, k := range keys { - if i > 0 { - buf.WriteString(", ") - } - _, _ = fmt.Fprintf(buf, "%v: %s", k, iStrs.Truncate(vars.Get(k).String(), maxExprVarWidth)) - } - buf.WriteString("}") - row.add(buf.String()) - } - - if opts.LocalVariables { - if locals := event.Locals; locals != nil { - keys := sortedKeys(locals) - - buf := new(bytes.Buffer) - buf.WriteString("{") - for i, k := range keys { - if i > 0 { - buf.WriteString(", ") - } - _, _ = fmt.Fprintf(buf, "%v: %s", k, iStrs.Truncate(locals.Get(k).String(), maxExprVarWidth)) - } - buf.WriteString("}") - row.add(buf.String()) - } else { - row.add("{}") - } - } - - table.add(row) - } - - table.write(w, columnPadding) -} - -func sortedKeys(vm *ast.ValueMap) []ast.Value { - keys := make([]ast.Value, 0, vm.Len()) - vm.Iter(func(k, _ ast.Value) bool { - keys = append(keys, k) - return false - }) - slices.SortFunc(keys, func(a, b ast.Value) int { - return strings.Compare(a.String(), b.String()) - }) - return keys -} - -func exprLocalVars(e *Event) *ast.ValueMap { - vars := ast.NewValueMap() - - findVars := func(term *ast.Term) bool { - //if r, ok := term.Value.(ast.Ref); ok { - // fmt.Printf("ref: %v\n", r) - // //return true - //} - if name, ok := term.Value.(ast.Var); ok { - if meta, ok := e.LocalMetadata[name]; ok { - if val := e.Locals.Get(name); val != nil { - vars.Put(meta.Name, val) - } - } - } - return false - } - - if r, ok := e.Node.(*ast.Rule); ok { - // We're only interested in vars in the head, not the body - ast.WalkTerms(r.Head, findVars) - return vars - } - - // The local cache snapshot only contains a snapshot for those refs present in the event node, - // so they can all be added to the vars map. - e.localVirtualCacheSnapshot.Iter(func(k, v ast.Value) bool { - vars.Put(k, v) - return false - }) - - ast.WalkTerms(e.Node, findVars) - - return vars -} - -func formatEvent(event *Event, depth int) string { - padding := formatEventPadding(event, depth) - if event.Op == NoteOp { - return fmt.Sprintf("%v%v %q", padding, event.Op, event.Message) - } - - var details interface{} - if node, ok := event.Node.(*ast.Rule); ok { - details = node.Path() - } else if event.Ref != nil { - details = event.Ref - } else { - details = rewrite(event).Node - } - - template := "%v%v %v" - opts := []interface{}{padding, event.Op, details} - - if event.Message != "" { - template += " %v" - opts = append(opts, event.Message) - } - - return fmt.Sprintf(template, opts...) + v1.PrettyTraceWithOpts(w, trace, opts) } -func formatEventPadding(event *Event, depth int) string { - spaces := formatEventSpaces(event, depth) - if spaces > 1 { - return strings.Repeat("| ", spaces-1) - } - return "" -} - -func formatEventSpaces(event *Event, depth int) int { - switch event.Op { - case EnterOp: - return depth - case RedoOp: - if _, ok := event.Node.(*ast.Expr); !ok { - return depth - } - } - return depth + 1 -} - -// getShortenedFileNames will return a map of file paths to shortened aliases -// that were found in the trace. It also returns the longest location expected -func getShortenedFileNames(trace []*Event) (map[string]string, int) { - // Get a deduplicated list of all file paths - // and the longest file path size - fpAliases := map[string]string{} - var canShorten []string - longestLocation := 0 - for _, event := range trace { - if event.Location != nil { - if event.Location.File != "" { - // length of ":" - curLen := len(event.Location.File) + numDigits10(event.Location.Row) + 1 - if curLen > longestLocation { - longestLocation = curLen - } - - if _, ok := fpAliases[event.Location.File]; ok { - continue - } - - canShorten = append(canShorten, event.Location.File) - - // Default to just alias their full path - fpAliases[event.Location.File] = event.Location.File - } else { - // length of ":" - curLen := minLocationWidth + numDigits10(event.Location.Row) + 1 - if curLen > longestLocation { - longestLocation = curLen - } - } - } - } - - if len(canShorten) > 0 && longestLocation > maxIdealLocationWidth { - fpAliases, longestLocation = iStrs.TruncateFilePaths(maxIdealLocationWidth, longestLocation, canShorten...) - } - - return fpAliases, longestLocation -} - -func numDigits10(n int) int { - if n < 10 { - return 1 - } - return numDigits10(n/10) + 1 -} - -func formatLocation(event *Event, fileAliases map[string]string) string { - - location := event.Location - if location == nil { - return "" - } - - if location.File == "" { - return fmt.Sprintf("query:%v", location.Row) - } - - return fmt.Sprintf("%v:%v", fileAliases[location.File], location.Row) -} - -// depths is a helper for computing the depth of an event. Events within the -// same query all have the same depth. The depth of query is -// depth(parent(query))+1. -type depths map[uint64]int - -func (ds depths) GetOrSet(qid uint64, pqid uint64) int { - depth := ds[qid] - if depth == 0 { - depth = ds[pqid] - depth++ - ds[qid] = depth - } - return depth -} - -func builtinTrace(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - - str, err := builtins.StringOperand(operands[0].Value, 1) - if err != nil { - return handleBuiltinErr(ast.Trace.Name, bctx.Location, err) - } - - if !bctx.TraceEnabled { - return iter(ast.BooleanTerm(true)) - } - - evt := Event{ - Op: NoteOp, - Location: bctx.Location, - QueryID: bctx.QueryID, - ParentID: bctx.ParentID, - Message: string(str), - } - - for i := range bctx.QueryTracers { - bctx.QueryTracers[i].TraceEvent(evt) - } - - return iter(ast.BooleanTerm(true)) -} - -func rewrite(event *Event) *Event { - - cpy := *event - - var node ast.Node - - switch v := event.Node.(type) { - case *ast.Expr: - expr := v.Copy() - - // Hide generated local vars in 'key' position that have not been - // rewritten. - if ev, ok := v.Terms.(*ast.Every); ok { - if kv, ok := ev.Key.Value.(ast.Var); ok { - if rw, ok := cpy.LocalMetadata[kv]; !ok || rw.Name.IsGenerated() { - expr.Terms.(*ast.Every).Key = nil - } - } - } - node = expr - case ast.Body: - node = v.Copy() - case *ast.Rule: - node = v.Copy() - } - - _, _ = ast.TransformVars(node, func(v ast.Var) (ast.Value, error) { - if meta, ok := cpy.LocalMetadata[v]; ok { - return meta.Name, nil - } - return v, nil - }) - - cpy.Node = node - - return &cpy -} - -type varInfo struct { - VarMetadata - val ast.Value - exprLoc *ast.Location - col int // 0-indexed column -} - -func (v varInfo) Value() string { - if v.val != nil { - return v.val.String() - } - return "undefined" -} - -func (v varInfo) Title() string { - if v.exprLoc != nil && v.exprLoc.Text != nil { - return string(v.exprLoc.Text) - } - return string(v.Name) -} - -func padLocationText(loc *ast.Location) string { - if loc == nil { - return "" - } - - text := string(loc.Text) - - if loc.Col == 0 { - return text - } - - buf := new(bytes.Buffer) - j := 0 - for i := 1; i < loc.Col; i++ { - if len(loc.Tabs) > 0 && j < len(loc.Tabs) && loc.Tabs[j] == i { - buf.WriteString("\t") - j++ - } else { - buf.WriteString(" ") - } - } - - buf.WriteString(text) - return buf.String() -} - -type PrettyEventOpts struct { - PrettyVars bool -} - -func walkTestTerms(x interface{}, f func(*ast.Term) bool) { - var vis *ast.GenericVisitor - vis = ast.NewGenericVisitor(func(x interface{}) bool { - switch x := x.(type) { - case ast.Call: - for _, t := range x[1:] { - vis.Walk(t) - } - return true - case *ast.Expr: - if x.IsCall() { - for _, o := range x.Operands() { - vis.Walk(o) - } - for i := range x.With { - vis.Walk(x.With[i]) - } - return true - } - case *ast.Term: - return f(x) - case *ast.With: - vis.Walk(x.Value) - return true - } - return false - }) - vis.Walk(x) -} +type PrettyEventOpts = v1.PrettyEventOpts func PrettyEvent(w io.Writer, e *Event, opts PrettyEventOpts) error { - if !opts.PrettyVars { - _, _ = fmt.Fprintln(w, padLocationText(e.Location)) - return nil - } - - buf := new(bytes.Buffer) - exprVars := map[string]varInfo{} - - findVars := func(unknownAreUndefined bool) func(term *ast.Term) bool { - return func(term *ast.Term) bool { - if term.Location == nil { - return false - } - - switch v := term.Value.(type) { - case *ast.ArrayComprehension, *ast.SetComprehension, *ast.ObjectComprehension: - // we don't report on the internals of a comprehension, as it's already evaluated, and we won't have the local vars. - return true - case ast.Var: - var info *varInfo - if meta, ok := e.LocalMetadata[v]; ok { - info = &varInfo{ - VarMetadata: meta, - val: e.Locals.Get(v), - exprLoc: term.Location, - } - } else if unknownAreUndefined { - info = &varInfo{ - VarMetadata: VarMetadata{Name: v}, - exprLoc: term.Location, - col: term.Location.Col, - } - } - - if info != nil { - if v, exists := exprVars[info.Title()]; !exists || v.val == nil { - if term.Location != nil { - info.col = term.Location.Col - } - exprVars[info.Title()] = *info - } - } - } - return false - } - } - - expr, ok := e.Node.(*ast.Expr) - if !ok || expr == nil { - return nil - } - - base := expr.BaseCogeneratedExpr() - exprText := padLocationText(base.Location) - buf.WriteString(exprText) - - e.localVirtualCacheSnapshot.Iter(func(k, v ast.Value) bool { - var info *varInfo - switch k := k.(type) { - case ast.Ref: - info = &varInfo{ - VarMetadata: VarMetadata{Name: ast.Var(k.String())}, - val: v, - exprLoc: k[0].Location, - col: k[0].Location.Col, - } - case *ast.ArrayComprehension: - info = &varInfo{ - VarMetadata: VarMetadata{Name: ast.Var(k.String())}, - val: v, - exprLoc: k.Term.Location, - col: k.Term.Location.Col, - } - case *ast.SetComprehension: - info = &varInfo{ - VarMetadata: VarMetadata{Name: ast.Var(k.String())}, - val: v, - exprLoc: k.Term.Location, - col: k.Term.Location.Col, - } - case *ast.ObjectComprehension: - info = &varInfo{ - VarMetadata: VarMetadata{Name: ast.Var(k.String())}, - val: v, - exprLoc: k.Key.Location, - col: k.Key.Location.Col, - } - } - - if info != nil { - exprVars[info.Title()] = *info - } - - return false - }) - - // If the expression is negated, we can't confidently assert that vars with unknown values are 'undefined', - // since the compiler might have opted out of the necessary rewrite. - walkTestTerms(expr, findVars(!expr.Negated)) - coExprs := expr.CogeneratedExprs() - for _, coExpr := range coExprs { - // Only the current "co-expr" can have undefined vars, if we don't know the value for a var in any other co-expr, - // it's unknown, not undefined. A var can be unknown if it hasn't been assigned a value yet, because the co-expr - // hasn't been evaluated yet (the fail happened before it). - walkTestTerms(coExpr, findVars(false)) - } - - printPrettyVars(buf, exprVars) - _, _ = fmt.Fprint(w, buf.String()) - return nil -} - -func printPrettyVars(w *bytes.Buffer, exprVars map[string]varInfo) { - containsTabs := false - varRows := make(map[int]interface{}) - for _, info := range exprVars { - if len(info.exprLoc.Tabs) > 0 { - containsTabs = true - } - varRows[info.exprLoc.Row] = nil - } - - if containsTabs && len(varRows) > 1 { - // We can't (currently) reliably point to var locations when they are on different rows that contain tabs. - // So we'll just print them in alphabetical order instead. - byName := make([]varInfo, 0, len(exprVars)) - for _, info := range exprVars { - byName = append(byName, info) - } - slices.SortStableFunc(byName, func(a, b varInfo) int { - return strings.Compare(a.Title(), b.Title()) - }) - - w.WriteString("\n\nWhere:\n") - for _, info := range byName { - w.WriteString(fmt.Sprintf("\n%s: %s", info.Title(), iStrs.Truncate(info.Value(), maxPrettyExprVarWidth))) - } - - return - } - - byCol := make([]varInfo, 0, len(exprVars)) - for _, info := range exprVars { - byCol = append(byCol, info) - } - slices.SortFunc(byCol, func(a, b varInfo) int { - // sort first by column, then by reverse row (to present vars in the same order they appear in the expr) - if a.col == b.col { - if a.exprLoc.Row == b.exprLoc.Row { - return strings.Compare(a.Title(), b.Title()) - } - return b.exprLoc.Row - a.exprLoc.Row - } - return a.col - b.col - }) - - if len(byCol) == 0 { - return - } - - w.WriteString("\n") - printArrows(w, byCol, -1) - for i := len(byCol) - 1; i >= 0; i-- { - w.WriteString("\n") - printArrows(w, byCol, i) - } -} - -func printArrows(w *bytes.Buffer, l []varInfo, printValueAt int) { - prevCol := 0 - var slice []varInfo - if printValueAt >= 0 { - slice = l[:printValueAt+1] - } else { - slice = l - } - isFirst := true - for i, info := range slice { - - isLast := i >= len(slice)-1 - col := info.col - - if !isLast && col == l[i+1].col { - // We're sharing the same column with another, subsequent var - continue - } - - spaces := col - 1 - if i > 0 && !isFirst { - spaces = (col - prevCol) - 1 - } - - for j := 0; j < spaces; j++ { - tab := false - for _, t := range info.exprLoc.Tabs { - if t == j+prevCol+1 { - w.WriteString("\t") - tab = true - break - } - } - if !tab { - w.WriteString(" ") - } - } - - if isLast && printValueAt >= 0 { - valueStr := iStrs.Truncate(info.Value(), maxPrettyExprVarWidth) - if (i > 0 && col == l[i-1].col) || (i < len(l)-1 && col == l[i+1].col) { - // There is another var on this column, so we need to include the name to differentiate them. - w.WriteString(fmt.Sprintf("%s: %s", info.Title(), valueStr)) - } else { - w.WriteString(valueStr) - } - } else { - w.WriteString("|") - } - prevCol = col - isFirst = false - } -} - -func init() { - RegisterBuiltinFunc(ast.Trace.Name, builtinTrace) + return v1.PrettyEvent(w, e, opts) } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/type_name.go b/vendor/github.com/open-policy-agent/opa/topdown/type_name.go deleted file mode 100644 index 0a8b44aed..000000000 --- a/vendor/github.com/open-policy-agent/opa/topdown/type_name.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2018 The OPA Authors. All rights reserved. -// Use of this source code is governed by an Apache2 -// license that can be found in the LICENSE file. - -package topdown - -import ( - "fmt" - - "github.com/open-policy-agent/opa/ast" -) - -func builtinTypeName(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - switch operands[0].Value.(type) { - case ast.Null: - return iter(ast.StringTerm("null")) - case ast.Boolean: - return iter(ast.StringTerm("boolean")) - case ast.Number: - return iter(ast.StringTerm("number")) - case ast.String: - return iter(ast.StringTerm("string")) - case *ast.Array: - return iter(ast.StringTerm("array")) - case ast.Object: - return iter(ast.StringTerm("object")) - case ast.Set: - return iter(ast.StringTerm("set")) - } - - return fmt.Errorf("illegal value") -} - -func init() { - RegisterBuiltinFunc(ast.TypeNameBuiltin.Name, builtinTypeName) -} diff --git a/vendor/github.com/open-policy-agent/opa/tracing/doc.go b/vendor/github.com/open-policy-agent/opa/tracing/doc.go new file mode 100644 index 000000000..161a3d0ce --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/tracing/doc.go @@ -0,0 +1,8 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. +package tracing diff --git a/vendor/github.com/open-policy-agent/opa/tracing/tracing.go b/vendor/github.com/open-policy-agent/opa/tracing/tracing.go index 2708b78e2..ad6ac668e 100644 --- a/vendor/github.com/open-policy-agent/opa/tracing/tracing.go +++ b/vendor/github.com/open-policy-agent/opa/tracing/tracing.go @@ -8,48 +8,38 @@ // configured sink. package tracing -import "net/http" +import ( + "net/http" + + v1 "github.com/open-policy-agent/opa/v1/tracing" +) // Options are options for the HTTPTracingService, passed along as-is. -type Options []interface{} +type Options = v1.Options // NewOptions is a helper method for constructing `tracing.Options` func NewOptions(opts ...interface{}) Options { - return opts + return v1.NewOptions(opts...) } // HTTPTracingService defines how distributed tracing comes in, server- and client-side -type HTTPTracingService interface { - // NewTransport is used when setting up an HTTP client - NewTransport(http.RoundTripper, Options) http.RoundTripper - - // NewHandler is used to wrap an http.Handler in the server - NewHandler(http.Handler, string, Options) http.Handler -} - -var tracing HTTPTracingService +type HTTPTracingService = v1.HTTPTracingService // RegisterHTTPTracing enables a HTTPTracingService for further use. func RegisterHTTPTracing(ht HTTPTracingService) { - tracing = ht + v1.RegisterHTTPTracing(ht) } // NewTransport returns another http.RoundTripper, instrumented to emit tracing // spans according to Options. Provided by the HTTPTracingService registered with // this package via RegisterHTTPTracing. func NewTransport(tr http.RoundTripper, opts Options) http.RoundTripper { - if tracing == nil { - return tr - } - return tracing.NewTransport(tr, opts) + return v1.NewTransport(tr, opts) } // NewHandler returns another http.Handler, instrumented to emit tracing spans // according to Options. Provided by the HTTPTracingService registered with // this package via RegisterHTTPTracing. func NewHandler(f http.Handler, label string, opts Options) http.Handler { - if tracing == nil { - return f - } - return tracing.NewHandler(f, label, opts) + return v1.NewHandler(f, label, opts) } diff --git a/vendor/github.com/open-policy-agent/opa/util/backoff.go b/vendor/github.com/open-policy-agent/opa/util/backoff.go index 6fbf63ef7..11da67d92 100644 --- a/vendor/github.com/open-policy-agent/opa/util/backoff.go +++ b/vendor/github.com/open-policy-agent/opa/util/backoff.go @@ -5,49 +5,19 @@ package util import ( - "math/rand" "time" -) -func init() { - // NOTE(sr): We don't need good random numbers here; it's used for jittering - // the backup timing a bit. But anyways, let's make it random enough; without - // a call to rand.Seed() we'd get the same stream of numbers for each program - // run. (Or not, if some other packages happens to seed the global randomness - // source.) - // Note(philipc): rand.Seed() was deprecated in Go 1.20, so we've switched to - // using the recommended rand.New(rand.NewSource(seed)) style. - rand.New(rand.NewSource(time.Now().UnixNano())) -} + v1 "github.com/open-policy-agent/opa/v1/util" +) // DefaultBackoff returns a delay with an exponential backoff based on the // number of retries. -func DefaultBackoff(base, max float64, retries int) time.Duration { - return Backoff(base, max, .2, 1.6, retries) +func DefaultBackoff(base, maxNS float64, retries int) time.Duration { + return v1.DefaultBackoff(base, maxNS, retries) } // Backoff returns a delay with an exponential backoff based on the number of // retries. Same algorithm used in gRPC. -func Backoff(base, max, jitter, factor float64, retries int) time.Duration { - if retries == 0 { - return 0 - } - - backoff, max := base, max - for backoff < max && retries > 0 { - backoff *= factor - retries-- - } - if backoff > max { - backoff = max - } - - // Randomize backoff delays so that if a cluster of requests start at - // the same time, they won't operate in lockstep. - backoff *= 1 + jitter*(rand.Float64()*2-1) - if backoff < 0 { - return 0 - } - - return time.Duration(backoff) +func Backoff(base, maxNS, jitter, factor float64, retries int) time.Duration { + return v1.Backoff(base, maxNS, jitter, factor, retries) } diff --git a/vendor/github.com/open-policy-agent/opa/util/close.go b/vendor/github.com/open-policy-agent/opa/util/close.go index c3c177557..7f14cf070 100644 --- a/vendor/github.com/open-policy-agent/opa/util/close.go +++ b/vendor/github.com/open-policy-agent/opa/util/close.go @@ -5,18 +5,14 @@ package util import ( - "io" "net/http" + + v1 "github.com/open-policy-agent/opa/v1/util" ) // Close reads the remaining bytes from the response and then closes it to // ensure that the connection is freed. If the body is not read and closed, a // leak can occur. func Close(resp *http.Response) { - if resp != nil && resp.Body != nil { - if _, err := io.Copy(io.Discard, resp.Body); err != nil { - return - } - resp.Body.Close() - } + v1.Close(resp) } diff --git a/vendor/github.com/open-policy-agent/opa/util/compare.go b/vendor/github.com/open-policy-agent/opa/util/compare.go index 8ae775369..e74d1d49b 100644 --- a/vendor/github.com/open-policy-agent/opa/util/compare.go +++ b/vendor/github.com/open-policy-agent/opa/util/compare.go @@ -5,10 +5,7 @@ package util import ( - "encoding/json" - "fmt" - "math/big" - "sort" + v1 "github.com/open-policy-agent/opa/v1/util" ) // Compare returns 0 if a equals b, -1 if a is less than b, and 1 if b is than a. @@ -18,167 +15,5 @@ import ( // are compared recursively. If one slice or map is a subset of the other slice or map // it is considered "less than". Nil is always equal to nil. func Compare(a, b interface{}) int { - aSortOrder := sortOrder(a) - bSortOrder := sortOrder(b) - if aSortOrder < bSortOrder { - return -1 - } else if bSortOrder < aSortOrder { - return 1 - } - switch a := a.(type) { - case nil: - return 0 - case bool: - switch b := b.(type) { - case bool: - if a == b { - return 0 - } - if !a { - return -1 - } - return 1 - } - case json.Number: - switch b := b.(type) { - case json.Number: - return compareJSONNumber(a, b) - } - case int: - switch b := b.(type) { - case int: - if a == b { - return 0 - } else if a < b { - return -1 - } - return 1 - } - case float64: - switch b := b.(type) { - case float64: - if a == b { - return 0 - } else if a < b { - return -1 - } - return 1 - } - case string: - switch b := b.(type) { - case string: - if a == b { - return 0 - } else if a < b { - return -1 - } - return 1 - } - case []interface{}: - switch b := b.(type) { - case []interface{}: - bLen := len(b) - aLen := len(a) - minLen := aLen - if bLen < minLen { - minLen = bLen - } - for i := 0; i < minLen; i++ { - cmp := Compare(a[i], b[i]) - if cmp != 0 { - return cmp - } - } - if aLen == bLen { - return 0 - } else if aLen < bLen { - return -1 - } - return 1 - } - case map[string]interface{}: - switch b := b.(type) { - case map[string]interface{}: - var aKeys []string - for k := range a { - aKeys = append(aKeys, k) - } - var bKeys []string - for k := range b { - bKeys = append(bKeys, k) - } - sort.Strings(aKeys) - sort.Strings(bKeys) - aLen := len(aKeys) - bLen := len(bKeys) - minLen := aLen - if bLen < minLen { - minLen = bLen - } - for i := 0; i < minLen; i++ { - if aKeys[i] < bKeys[i] { - return -1 - } else if bKeys[i] < aKeys[i] { - return 1 - } - aVal := a[aKeys[i]] - bVal := b[bKeys[i]] - cmp := Compare(aVal, bVal) - if cmp != 0 { - return cmp - } - } - if aLen == bLen { - return 0 - } else if aLen < bLen { - return -1 - } - return 1 - } - } - - panic(fmt.Sprintf("illegal arguments of type %T and type %T", a, b)) -} - -const ( - nilSort = iota - boolSort = iota - numberSort = iota - stringSort = iota - arraySort = iota - objectSort = iota -) - -func compareJSONNumber(a, b json.Number) int { - bigA, ok := new(big.Float).SetString(string(a)) - if !ok { - panic("illegal value") - } - bigB, ok := new(big.Float).SetString(string(b)) - if !ok { - panic("illegal value") - } - return bigA.Cmp(bigB) -} - -func sortOrder(v interface{}) int { - switch v.(type) { - case nil: - return nilSort - case bool: - return boolSort - case json.Number: - return numberSort - case int: - return numberSort - case float64: - return numberSort - case string: - return stringSort - case []interface{}: - return arraySort - case map[string]interface{}: - return objectSort - } - panic(fmt.Sprintf("illegal argument of type %T", v)) + return v1.Compare(a, b) } diff --git a/vendor/github.com/open-policy-agent/opa/util/doc.go b/vendor/github.com/open-policy-agent/opa/util/doc.go index 900dff8c1..25f5f53bd 100644 --- a/vendor/github.com/open-policy-agent/opa/util/doc.go +++ b/vendor/github.com/open-policy-agent/opa/util/doc.go @@ -3,4 +3,8 @@ // license that can be found in the LICENSE file. // Package util provides generic utilities used throughout the policy engine. +// +// Deprecated: This package is intended for older projects transitioning from OPA v0.x and will remain for the lifetime of OPA v1.x, but its use is not recommended. +// For newer features and behaviours, such as defaulting to the Rego v1 syntax, use the corresponding components in the [github.com/open-policy-agent/opa/v1] package instead. +// See https://www.openpolicyagent.org/docs/latest/v0-compatibility/ for more information. package util diff --git a/vendor/github.com/open-policy-agent/opa/util/enumflag.go b/vendor/github.com/open-policy-agent/opa/util/enumflag.go index 4796f0269..6c28d4df4 100644 --- a/vendor/github.com/open-policy-agent/opa/util/enumflag.go +++ b/vendor/github.com/open-policy-agent/opa/util/enumflag.go @@ -5,55 +5,15 @@ package util import ( - "fmt" - "strings" + v1 "github.com/open-policy-agent/opa/v1/util" ) // EnumFlag implements the pflag.Value interface to provide enumerated command // line parameter values. -type EnumFlag struct { - defaultValue string - vs []string - i int -} +type EnumFlag = v1.EnumFlag // NewEnumFlag returns a new EnumFlag that has a defaultValue and vs enumerated // values. func NewEnumFlag(defaultValue string, vs []string) *EnumFlag { - f := &EnumFlag{ - i: -1, - vs: vs, - defaultValue: defaultValue, - } - return f -} - -// Type returns the valid enumeration values. -func (f *EnumFlag) Type() string { - return "{" + strings.Join(f.vs, ",") + "}" -} - -// String returns the EnumValue's value as string. -func (f *EnumFlag) String() string { - if f.i == -1 { - return f.defaultValue - } - return f.vs[f.i] -} - -// IsSet will return true if the EnumFlag has been set. -func (f *EnumFlag) IsSet() bool { - return f.i != -1 -} - -// Set sets the enum value. If s is not a valid enum value, an error is -// returned. -func (f *EnumFlag) Set(s string) error { - for i := range f.vs { - if f.vs[i] == s { - f.i = i - return nil - } - } - return fmt.Errorf("must be one of %v", f.Type()) + return v1.NewEnumFlag(defaultValue, vs) } diff --git a/vendor/github.com/open-policy-agent/opa/util/graph.go b/vendor/github.com/open-policy-agent/opa/util/graph.go index f0e824245..edf59da91 100644 --- a/vendor/github.com/open-policy-agent/opa/util/graph.go +++ b/vendor/github.com/open-policy-agent/opa/util/graph.go @@ -4,87 +4,31 @@ package util -// Traversal defines a basic interface to perform traversals. -type Traversal interface { - - // Edges should return the neighbours of node "u". - Edges(u T) []T +import v1 "github.com/open-policy-agent/opa/v1/util" - // Visited should return true if node "u" has already been visited in this - // traversal. If the same traversal is used multiple times, the state that - // tracks visited nodes should be reset. - Visited(u T) bool -} +// Traversal defines a basic interface to perform traversals. +type Traversal = v1.Traversal // Equals should return true if node "u" equals node "v". -type Equals func(u T, v T) bool +type Equals = v1.Equals // Iter should return true to indicate stop. -type Iter func(u T) bool +type Iter = v1.Iter // DFS performs a depth first traversal calling f for each node starting from u. // If f returns true, traversal stops and DFS returns true. func DFS(t Traversal, f Iter, u T) bool { - lifo := NewLIFO(u) - for lifo.Size() > 0 { - next, _ := lifo.Pop() - if t.Visited(next) { - continue - } - if f(next) { - return true - } - for _, v := range t.Edges(next) { - lifo.Push(v) - } - } - return false + return v1.DFS(t, f, u) } // BFS performs a breadth first traversal calling f for each node starting from // u. If f returns true, traversal stops and BFS returns true. func BFS(t Traversal, f Iter, u T) bool { - fifo := NewFIFO(u) - for fifo.Size() > 0 { - next, _ := fifo.Pop() - if t.Visited(next) { - continue - } - if f(next) { - return true - } - for _, v := range t.Edges(next) { - fifo.Push(v) - } - } - return false + return v1.BFS(t, f, u) } // DFSPath returns a path from node a to node z found by performing // a depth first traversal. If no path is found, an empty slice is returned. func DFSPath(t Traversal, eq Equals, a, z T) []T { - p := dfsRecursive(t, eq, a, z, []T{}) - for i := len(p)/2 - 1; i >= 0; i-- { - o := len(p) - i - 1 - p[i], p[o] = p[o], p[i] - } - return p -} - -func dfsRecursive(t Traversal, eq Equals, u, z T, path []T) []T { - if t.Visited(u) { - return path - } - for _, v := range t.Edges(u) { - if eq(v, z) { - path = append(path, z) - path = append(path, u) - return path - } - if p := dfsRecursive(t, eq, v, z, path); len(p) > 0 { - path = append(p, u) - return path - } - } - return path + return v1.DFSPath(t, eq, a, z) } diff --git a/vendor/github.com/open-policy-agent/opa/util/hashmap.go b/vendor/github.com/open-policy-agent/opa/util/hashmap.go index 8875a6323..5f030c77b 100644 --- a/vendor/github.com/open-policy-agent/opa/util/hashmap.go +++ b/vendor/github.com/open-policy-agent/opa/util/hashmap.go @@ -5,153 +5,16 @@ package util import ( - "fmt" - "strings" + v1 "github.com/open-policy-agent/opa/v1/util" ) // T is a concise way to refer to T. -type T interface{} - -type hashEntry struct { - k T - v T - next *hashEntry -} +type T = v1.T // HashMap represents a key/value map. -type HashMap struct { - eq func(T, T) bool - hash func(T) int - table map[int]*hashEntry - size int -} +type HashMap = v1.HashMap // NewHashMap returns a new empty HashMap. func NewHashMap(eq func(T, T) bool, hash func(T) int) *HashMap { - return &HashMap{ - eq: eq, - hash: hash, - table: make(map[int]*hashEntry), - size: 0, - } -} - -// Copy returns a shallow copy of this HashMap. -func (h *HashMap) Copy() *HashMap { - cpy := NewHashMap(h.eq, h.hash) - h.Iter(func(k, v T) bool { - cpy.Put(k, v) - return false - }) - return cpy -} - -// Equal returns true if this HashMap equals the other HashMap. -// Two hash maps are equal if they contain the same key/value pairs. -func (h *HashMap) Equal(other *HashMap) bool { - if h.Len() != other.Len() { - return false - } - return !h.Iter(func(k, v T) bool { - ov, ok := other.Get(k) - if !ok { - return true - } - return !h.eq(v, ov) - }) -} - -// Get returns the value for k. -func (h *HashMap) Get(k T) (T, bool) { - hash := h.hash(k) - for entry := h.table[hash]; entry != nil; entry = entry.next { - if h.eq(entry.k, k) { - return entry.v, true - } - } - return nil, false -} - -// Delete removes the key k. -func (h *HashMap) Delete(k T) { - hash := h.hash(k) - var prev *hashEntry - for entry := h.table[hash]; entry != nil; entry = entry.next { - if h.eq(entry.k, k) { - if prev != nil { - prev.next = entry.next - } else { - h.table[hash] = entry.next - } - h.size-- - return - } - prev = entry - } -} - -// Hash returns the hash code for this hash map. -func (h *HashMap) Hash() int { - var hash int - h.Iter(func(k, v T) bool { - hash += h.hash(k) + h.hash(v) - return false - }) - return hash -} - -// Iter invokes the iter function for each element in the HashMap. -// If the iter function returns true, iteration stops and the return value is true. -// If the iter function never returns true, iteration proceeds through all elements -// and the return value is false. -func (h *HashMap) Iter(iter func(T, T) bool) bool { - for _, entry := range h.table { - for ; entry != nil; entry = entry.next { - if iter(entry.k, entry.v) { - return true - } - } - } - return false -} - -// Len returns the current size of this HashMap. -func (h *HashMap) Len() int { - return h.size -} - -// Put inserts a key/value pair into this HashMap. If the key is already present, the existing -// value is overwritten. -func (h *HashMap) Put(k T, v T) { - hash := h.hash(k) - head := h.table[hash] - for entry := head; entry != nil; entry = entry.next { - if h.eq(entry.k, k) { - entry.v = v - return - } - } - h.table[hash] = &hashEntry{k: k, v: v, next: head} - h.size++ -} - -func (h *HashMap) String() string { - var buf []string - h.Iter(func(k T, v T) bool { - buf = append(buf, fmt.Sprintf("%v: %v", k, v)) - return false - }) - return "{" + strings.Join(buf, ", ") + "}" -} - -// Update returns a new HashMap with elements from the other HashMap put into this HashMap. -// If the other HashMap contains elements with the same key as this HashMap, the value -// from the other HashMap overwrites the value from this HashMap. -func (h *HashMap) Update(other *HashMap) *HashMap { - updated := h.Copy() - other.Iter(func(k, v T) bool { - updated.Put(k, v) - return false - }) - return updated + return v1.NewHashMap(eq, hash) } diff --git a/vendor/github.com/open-policy-agent/opa/util/json.go b/vendor/github.com/open-policy-agent/opa/util/json.go index 0b7fd2ed6..9b19a967b 100644 --- a/vendor/github.com/open-policy-agent/opa/util/json.go +++ b/vendor/github.com/open-policy-agent/opa/util/json.go @@ -5,15 +5,10 @@ package util import ( - "bytes" "encoding/json" - "fmt" "io" - "reflect" - "sigs.k8s.io/yaml" - - "github.com/open-policy-agent/opa/loader/extension" + v1 "github.com/open-policy-agent/opa/v1/util" ) // UnmarshalJSON parses the JSON encoded data and stores the result in the value @@ -22,29 +17,7 @@ import ( // This function is intended to be used in place of the standard json.Marshal // function when json.Number is required. func UnmarshalJSON(bs []byte, x interface{}) error { - return unmarshalJSON(bs, x, true) -} - -func unmarshalJSON(bs []byte, x interface{}, ext bool) error { - buf := bytes.NewBuffer(bs) - decoder := NewJSONDecoder(buf) - if err := decoder.Decode(x); err != nil { - if handler := extension.FindExtension(".json"); handler != nil && ext { - return handler(bs, x) - } - return err - } - - // Since decoder.Decode validates only the first json structure in bytes, - // check if decoder has more bytes to consume to validate whole input bytes. - tok, err := decoder.Token() - if tok != nil { - return fmt.Errorf("error: invalid character '%s' after top-level value", tok) - } - if err != nil && err != io.EOF { - return err - } - return nil + return v1.UnmarshalJSON(bs, x) } // NewJSONDecoder returns a new decoder that reads from r. @@ -52,9 +25,7 @@ func unmarshalJSON(bs []byte, x interface{}, ext bool) error { // This function is intended to be used in place of the standard json.NewDecoder // when json.Number is required. func NewJSONDecoder(r io.Reader) *json.Decoder { - decoder := json.NewDecoder(r) - decoder.UseNumber() - return decoder + return v1.NewJSONDecoder(r) } // MustUnmarshalJSON parse the JSON encoded data and returns the result. @@ -62,11 +33,7 @@ func NewJSONDecoder(r io.Reader) *json.Decoder { // If the data cannot be decoded, this function will panic. This function is for // test purposes. func MustUnmarshalJSON(bs []byte) interface{} { - var x interface{} - if err := UnmarshalJSON(bs, &x); err != nil { - panic(err) - } - return x + return v1.MustUnmarshalJSON(bs) } // MustMarshalJSON returns the JSON encoding of x @@ -74,11 +41,7 @@ func MustUnmarshalJSON(bs []byte) interface{} { // If the data cannot be encoded, this function will panic. This function is for // test purposes. func MustMarshalJSON(x interface{}) []byte { - bs, err := json.Marshal(x) - if err != nil { - panic(err) - } - return bs + return v1.MustMarshalJSON(x) } // RoundTrip encodes to JSON, and decodes the result again. @@ -87,11 +50,7 @@ func MustMarshalJSON(x interface{}) []byte { // rego.Input and inmem's Write operations. Works with both references and // values. func RoundTrip(x *interface{}) error { - bs, err := json.Marshal(x) - if err != nil { - return err - } - return UnmarshalJSON(bs, x) + return v1.RoundTrip(x) } // Reference returns a pointer to its argument unless the argument already is @@ -100,34 +59,10 @@ func RoundTrip(x *interface{}) error { // Used for preparing Go types (including pointers to structs) into values to be // put through util.RoundTrip(). func Reference(x interface{}) *interface{} { - var y interface{} - rv := reflect.ValueOf(x) - if rv.Kind() == reflect.Ptr { - return Reference(rv.Elem().Interface()) - } - if rv.Kind() != reflect.Invalid { - y = rv.Interface() - return &y - } - return &x + return v1.Reference(x) } // Unmarshal decodes a YAML, JSON or JSON extension value into the specified type. func Unmarshal(bs []byte, v interface{}) error { - if len(bs) > 2 && bs[0] == 0xef && bs[1] == 0xbb && bs[2] == 0xbf { - bs = bs[3:] // Strip UTF-8 BOM, see https://www.rfc-editor.org/rfc/rfc8259#section-8.1 - } - - if json.Valid(bs) { - return unmarshalJSON(bs, v, false) - } - nbs, err := yaml.YAMLToJSON(bs) - if err == nil { - return unmarshalJSON(nbs, v, false) - } - // not json or yaml: try extensions - if handler := extension.FindExtension(".json"); handler != nil { - return handler(bs, v) - } - return err + return v1.Unmarshal(bs, v) } diff --git a/vendor/github.com/open-policy-agent/opa/util/maps.go b/vendor/github.com/open-policy-agent/opa/util/maps.go index d943b4d0a..1a9c71f28 100644 --- a/vendor/github.com/open-policy-agent/opa/util/maps.go +++ b/vendor/github.com/open-policy-agent/opa/util/maps.go @@ -1,10 +1,8 @@ package util +import v1 "github.com/open-policy-agent/opa/v1/util" + // Values returns a slice of values from any map. Copied from golang.org/x/exp/maps. func Values[M ~map[K]V, K comparable, V any](m M) []V { - r := make([]V, 0, len(m)) - for _, v := range m { - r = append(r, v) - } - return r + return v1.Values(m) } diff --git a/vendor/github.com/open-policy-agent/opa/util/queue.go b/vendor/github.com/open-policy-agent/opa/util/queue.go index 63a2ffc16..d130a97ab 100644 --- a/vendor/github.com/open-policy-agent/opa/util/queue.go +++ b/vendor/github.com/open-policy-agent/opa/util/queue.go @@ -4,110 +4,22 @@ package util -// LIFO represents a simple LIFO queue. -type LIFO struct { - top *queueNode - size int -} +import v1 "github.com/open-policy-agent/opa/v1/util" -type queueNode struct { - v T - next *queueNode -} +// LIFO represents a simple LIFO queue. +type LIFO = v1.LIFO // NewLIFO returns a new LIFO queue containing elements ts starting with the // left-most argument at the bottom. func NewLIFO(ts ...T) *LIFO { - s := &LIFO{} - for i := range ts { - s.Push(ts[i]) - } - return s -} - -// Push adds a new element onto the LIFO. -func (s *LIFO) Push(t T) { - node := &queueNode{v: t, next: s.top} - s.top = node - s.size++ -} - -// Peek returns the top of the LIFO. If LIFO is empty, returns nil, false. -func (s *LIFO) Peek() (T, bool) { - if s.top == nil { - return nil, false - } - return s.top.v, true -} - -// Pop returns the top of the LIFO and removes it. If LIFO is empty returns -// nil, false. -func (s *LIFO) Pop() (T, bool) { - if s.top == nil { - return nil, false - } - node := s.top - s.top = node.next - s.size-- - return node.v, true -} - -// Size returns the size of the LIFO. -func (s *LIFO) Size() int { - return s.size + return v1.NewLIFO(ts...) } // FIFO represents a simple FIFO queue. -type FIFO struct { - front *queueNode - back *queueNode - size int -} +type FIFO = v1.FIFO // NewFIFO returns a new FIFO queue containing elements ts starting with the // left-most argument at the front. func NewFIFO(ts ...T) *FIFO { - s := &FIFO{} - for i := range ts { - s.Push(ts[i]) - } - return s -} - -// Push adds a new element onto the LIFO. -func (s *FIFO) Push(t T) { - node := &queueNode{v: t, next: nil} - if s.front == nil { - s.front = node - s.back = node - } else { - s.back.next = node - s.back = node - } - s.size++ -} - -// Peek returns the top of the LIFO. If LIFO is empty, returns nil, false. -func (s *FIFO) Peek() (T, bool) { - if s.front == nil { - return nil, false - } - return s.front.v, true -} - -// Pop returns the top of the LIFO and removes it. If LIFO is empty returns -// nil, false. -func (s *FIFO) Pop() (T, bool) { - if s.front == nil { - return nil, false - } - node := s.front - s.front = node.next - s.size-- - return node.v, true -} - -// Size returns the size of the LIFO. -func (s *FIFO) Size() int { - return s.size + return v1.NewFIFO(ts...) } diff --git a/vendor/github.com/open-policy-agent/opa/util/read_gzip_body.go b/vendor/github.com/open-policy-agent/opa/util/read_gzip_body.go index 217638b36..70b615a4e 100644 --- a/vendor/github.com/open-policy-agent/opa/util/read_gzip_body.go +++ b/vendor/github.com/open-policy-agent/opa/util/read_gzip_body.go @@ -1,25 +1,11 @@ package util import ( - "bytes" - "compress/gzip" - "encoding/binary" - "fmt" - "io" "net/http" - "strings" - "sync" - "github.com/open-policy-agent/opa/util/decoding" + v1 "github.com/open-policy-agent/opa/v1/util" ) -var gzipReaderPool = sync.Pool{ - New: func() interface{} { - reader := new(gzip.Reader) - return reader - }, -} - // Note(philipc): Originally taken from server/server.go // The DecodingLimitHandler handles validating that the gzip payload is within the // allowed max size limit. Thus, in the event of a forged payload size trailer, @@ -27,55 +13,5 @@ var gzipReaderPool = sync.Pool{ // payload size, but not an unbounded amount of memory, as was potentially // possible before. func ReadMaybeCompressedBody(r *http.Request) ([]byte, error) { - var content *bytes.Buffer - // Note(philipc): If the request body is of unknown length (such as what - // happens when 'Transfer-Encoding: chunked' is set), we have to do an - // incremental read of the body. In this case, we can't be too clever, we - // just do the best we can with whatever is streamed over to us. - // Fetch gzip payload size limit from request context. - if maxLength, ok := decoding.GetServerDecodingMaxLen(r.Context()); ok { - bs, err := io.ReadAll(io.LimitReader(r.Body, maxLength)) - if err != nil { - return bs, err - } - content = bytes.NewBuffer(bs) - } else { - // Read content from the request body into a buffer of known size. - content = bytes.NewBuffer(make([]byte, 0, r.ContentLength)) - if _, err := io.CopyN(content, r.Body, r.ContentLength); err != nil { - return content.Bytes(), err - } - } - - // Decompress gzip content by reading from the buffer. - if strings.Contains(r.Header.Get("Content-Encoding"), "gzip") { - // Fetch gzip payload size limit from request context. - gzipMaxLength, _ := decoding.GetServerDecodingGzipMaxLen(r.Context()) - - // Note(philipc): The last 4 bytes of a well-formed gzip blob will - // always be a little-endian uint32, representing the decompressed - // content size, modulo 2^32. We validate that the size is safe, - // earlier in DecodingLimitHandler. - sizeTrailerField := binary.LittleEndian.Uint32(content.Bytes()[content.Len()-4:]) - if sizeTrailerField > uint32(gzipMaxLength) { - return content.Bytes(), fmt.Errorf("gzip payload too large") - } - // Pull a gzip decompressor from the pool, and assign it to the current - // buffer, using Reset(). Later, return it back to the pool for another - // request to use. - gzReader := gzipReaderPool.Get().(*gzip.Reader) - if err := gzReader.Reset(content); err != nil { - return nil, err - } - defer gzReader.Close() - defer gzipReaderPool.Put(gzReader) - decompressedContent := bytes.NewBuffer(make([]byte, 0, sizeTrailerField)) - if _, err := io.CopyN(decompressedContent, gzReader, int64(sizeTrailerField)); err != nil { - return decompressedContent.Bytes(), err - } - return decompressedContent.Bytes(), nil - } - - // Request was not compressed; return the content bytes. - return content.Bytes(), nil + return v1.ReadMaybeCompressedBody(r) } diff --git a/vendor/github.com/open-policy-agent/opa/util/time.go b/vendor/github.com/open-policy-agent/opa/util/time.go index 93ef03939..364197470 100644 --- a/vendor/github.com/open-policy-agent/opa/util/time.go +++ b/vendor/github.com/open-policy-agent/opa/util/time.go @@ -1,6 +1,10 @@ package util -import "time" +import ( + "time" + + v1 "github.com/open-policy-agent/opa/v1/util" +) // TimerWithCancel exists because of memory leaks when using // time.After in select statements. Instead, we now manually create timers, @@ -30,19 +34,5 @@ import "time" // } // } func TimerWithCancel(delay time.Duration) (*time.Timer, func()) { - timer := time.NewTimer(delay) - - return timer, func() { - // Note: The Stop function returns: - // - true: if the timer is active. (no draining required) - // - false: if the timer was already stopped or fired/expired. - // In this case the channel should be drained to prevent memory - // leaks only if it is not empty. - // This operation is safe only if the cancel function is - // used in same goroutine. Concurrent reading or canceling may - // cause deadlock. - if !timer.Stop() && len(timer.C) > 0 { - <-timer.C - } - } + return v1.TimerWithCancel(delay) } diff --git a/vendor/github.com/open-policy-agent/opa/util/wait.go b/vendor/github.com/open-policy-agent/opa/util/wait.go index b70ab6fcf..235f84a0e 100644 --- a/vendor/github.com/open-policy-agent/opa/util/wait.go +++ b/vendor/github.com/open-policy-agent/opa/util/wait.go @@ -5,8 +5,9 @@ package util import ( - "fmt" "time" + + v1 "github.com/open-policy-agent/opa/v1/util" ) // WaitFunc will call passed function at an interval and return nil @@ -14,21 +15,5 @@ import ( // If timeout is reached before the passed in function returns true // an error is returned. func WaitFunc(fun func() bool, interval, timeout time.Duration) error { - if fun() { - return nil - } - ticker := time.NewTicker(interval) - timer := time.NewTimer(timeout) - defer ticker.Stop() - defer timer.Stop() - for { - select { - case <-timer.C: - return fmt.Errorf("timeout") - case <-ticker.C: - if fun() { - return nil - } - } - } + return v1.WaitFunc(fun, interval, timeout) } diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/annotations.go b/vendor/github.com/open-policy-agent/opa/v1/ast/annotations.go new file mode 100644 index 000000000..cfa1c5dab --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/annotations.go @@ -0,0 +1,977 @@ +// Copyright 2022 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "encoding/json" + "fmt" + "net/url" + "sort" + "strings" + + "github.com/open-policy-agent/opa/internal/deepcopy" + astJSON "github.com/open-policy-agent/opa/v1/ast/json" + "github.com/open-policy-agent/opa/v1/util" +) + +const ( + annotationScopePackage = "package" + annotationScopeImport = "import" + annotationScopeRule = "rule" + annotationScopeDocument = "document" + annotationScopeSubpackages = "subpackages" +) + +type ( + // Annotations represents metadata attached to other AST nodes such as rules. + Annotations struct { + Scope string `json:"scope"` + Title string `json:"title,omitempty"` + Entrypoint bool `json:"entrypoint,omitempty"` + Description string `json:"description,omitempty"` + Organizations []string `json:"organizations,omitempty"` + RelatedResources []*RelatedResourceAnnotation `json:"related_resources,omitempty"` + Authors []*AuthorAnnotation `json:"authors,omitempty"` + Schemas []*SchemaAnnotation `json:"schemas,omitempty"` + Custom map[string]interface{} `json:"custom,omitempty"` + Location *Location `json:"location,omitempty"` + + comments []*Comment + node Node + jsonOptions astJSON.Options + } + + // SchemaAnnotation contains a schema declaration for the document identified by the path. + SchemaAnnotation struct { + Path Ref `json:"path"` + Schema Ref `json:"schema,omitempty"` + Definition *interface{} `json:"definition,omitempty"` + } + + AuthorAnnotation struct { + Name string `json:"name"` + Email string `json:"email,omitempty"` + } + + RelatedResourceAnnotation struct { + Ref url.URL `json:"ref"` + Description string `json:"description,omitempty"` + } + + AnnotationSet struct { + byRule map[*Rule][]*Annotations + byPackage map[int]*Annotations + byPath *annotationTreeNode + modules []*Module // Modules this set was constructed from + } + + annotationTreeNode struct { + Value *Annotations + Children map[Value]*annotationTreeNode // we assume key elements are hashable (vars and strings only!) + } + + AnnotationsRef struct { + Path Ref `json:"path"` // The path of the node the annotations are applied to + Annotations *Annotations `json:"annotations,omitempty"` + Location *Location `json:"location,omitempty"` // The location of the node the annotations are applied to + + jsonOptions astJSON.Options + + node Node // The node the annotations are applied to + } + + AnnotationsRefSet []*AnnotationsRef + + FlatAnnotationsRefSet AnnotationsRefSet +) + +func (a *Annotations) String() string { + bs, _ := a.MarshalJSON() + return string(bs) +} + +// Loc returns the location of this annotation. +func (a *Annotations) Loc() *Location { + return a.Location +} + +// SetLoc updates the location of this annotation. +func (a *Annotations) SetLoc(l *Location) { + a.Location = l +} + +// EndLoc returns the location of this annotation's last comment line. +func (a *Annotations) EndLoc() *Location { + count := len(a.comments) + if count == 0 { + return a.Location + } + return a.comments[count-1].Location +} + +// Compare returns an integer indicating if a is less than, equal to, or greater +// than other. +func (a *Annotations) Compare(other *Annotations) int { + + if a == nil && other == nil { + return 0 + } + + if a == nil { + return -1 + } + + if other == nil { + return 1 + } + + if cmp := scopeCompare(a.Scope, other.Scope); cmp != 0 { + return cmp + } + + if cmp := strings.Compare(a.Title, other.Title); cmp != 0 { + return cmp + } + + if cmp := strings.Compare(a.Description, other.Description); cmp != 0 { + return cmp + } + + if cmp := compareStringLists(a.Organizations, other.Organizations); cmp != 0 { + return cmp + } + + if cmp := compareRelatedResources(a.RelatedResources, other.RelatedResources); cmp != 0 { + return cmp + } + + if cmp := compareAuthors(a.Authors, other.Authors); cmp != 0 { + return cmp + } + + if cmp := compareSchemas(a.Schemas, other.Schemas); cmp != 0 { + return cmp + } + + if a.Entrypoint != other.Entrypoint { + if a.Entrypoint { + return 1 + } + return -1 + } + + if cmp := util.Compare(a.Custom, other.Custom); cmp != 0 { + return cmp + } + + return 0 +} + +// GetTargetPath returns the path of the node these Annotations are applied to (the target) +func (a *Annotations) GetTargetPath() Ref { + switch n := a.node.(type) { + case *Package: + return n.Path + case *Rule: + return n.Ref().GroundPrefix() + default: + return nil + } +} + +func (a *Annotations) setJSONOptions(opts astJSON.Options) { + a.jsonOptions = opts + if a.Location != nil { + a.Location.JSONOptions = opts + } +} + +func (a *Annotations) MarshalJSON() ([]byte, error) { + if a == nil { + return []byte(`{"scope":""}`), nil + } + + data := map[string]interface{}{ + "scope": a.Scope, + } + + if a.Title != "" { + data["title"] = a.Title + } + + if a.Description != "" { + data["description"] = a.Description + } + + if a.Entrypoint { + data["entrypoint"] = a.Entrypoint + } + + if len(a.Organizations) > 0 { + data["organizations"] = a.Organizations + } + + if len(a.RelatedResources) > 0 { + data["related_resources"] = a.RelatedResources + } + + if len(a.Authors) > 0 { + data["authors"] = a.Authors + } + + if len(a.Schemas) > 0 { + data["schemas"] = a.Schemas + } + + if len(a.Custom) > 0 { + data["custom"] = a.Custom + } + + if a.jsonOptions.MarshalOptions.IncludeLocation.Annotations { + if a.Location != nil { + data["location"] = a.Location + } + } + + return json.Marshal(data) +} + +func NewAnnotationsRef(a *Annotations) *AnnotationsRef { + var loc *Location + if a.node != nil { + loc = a.node.Loc() + } + + return &AnnotationsRef{ + Location: loc, + Path: a.GetTargetPath(), + Annotations: a, + node: a.node, + jsonOptions: a.jsonOptions, + } +} + +func (ar *AnnotationsRef) GetPackage() *Package { + switch n := ar.node.(type) { + case *Package: + return n + case *Rule: + return n.Module.Package + default: + return nil + } +} + +func (ar *AnnotationsRef) GetRule() *Rule { + switch n := ar.node.(type) { + case *Rule: + return n + default: + return nil + } +} + +func (ar *AnnotationsRef) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "path": ar.Path, + } + + if ar.Annotations != nil { + data["annotations"] = ar.Annotations + } + + if ar.jsonOptions.MarshalOptions.IncludeLocation.AnnotationsRef { + if ar.Location != nil { + data["location"] = ar.Location + } + } + + return json.Marshal(data) +} + +func scopeCompare(s1, s2 string) int { + + o1 := scopeOrder(s1) + o2 := scopeOrder(s2) + + if o2 < o1 { + return 1 + } else if o2 > o1 { + return -1 + } + + if s1 < s2 { + return -1 + } else if s2 < s1 { + return 1 + } + + return 0 +} + +func scopeOrder(s string) int { + switch s { + case annotationScopeRule: + return 1 + } + return 0 +} + +func compareAuthors(a, b []*AuthorAnnotation) int { + if len(a) > len(b) { + return 1 + } else if len(a) < len(b) { + return -1 + } + + for i := 0; i < len(a); i++ { + if cmp := a[i].Compare(b[i]); cmp != 0 { + return cmp + } + } + + return 0 +} + +func compareRelatedResources(a, b []*RelatedResourceAnnotation) int { + if len(a) > len(b) { + return 1 + } else if len(a) < len(b) { + return -1 + } + + for i := 0; i < len(a); i++ { + if cmp := strings.Compare(a[i].String(), b[i].String()); cmp != 0 { + return cmp + } + } + + return 0 +} + +func compareSchemas(a, b []*SchemaAnnotation) int { + maxLen := len(a) + if len(b) < maxLen { + maxLen = len(b) + } + + for i := 0; i < maxLen; i++ { + if cmp := a[i].Compare(b[i]); cmp != 0 { + return cmp + } + } + + if len(a) > len(b) { + return 1 + } else if len(a) < len(b) { + return -1 + } + + return 0 +} + +func compareStringLists(a, b []string) int { + if len(a) > len(b) { + return 1 + } else if len(a) < len(b) { + return -1 + } + + for i := 0; i < len(a); i++ { + if cmp := strings.Compare(a[i], b[i]); cmp != 0 { + return cmp + } + } + + return 0 +} + +// Copy returns a deep copy of s. +func (a *Annotations) Copy(node Node) *Annotations { + cpy := *a + + cpy.Organizations = make([]string, len(a.Organizations)) + copy(cpy.Organizations, a.Organizations) + + cpy.RelatedResources = make([]*RelatedResourceAnnotation, len(a.RelatedResources)) + for i := range a.RelatedResources { + cpy.RelatedResources[i] = a.RelatedResources[i].Copy() + } + + cpy.Authors = make([]*AuthorAnnotation, len(a.Authors)) + for i := range a.Authors { + cpy.Authors[i] = a.Authors[i].Copy() + } + + cpy.Schemas = make([]*SchemaAnnotation, len(a.Schemas)) + for i := range a.Schemas { + cpy.Schemas[i] = a.Schemas[i].Copy() + } + + cpy.Custom = deepcopy.Map(a.Custom) + + cpy.node = node + + return &cpy +} + +// toObject constructs an AST Object from the annotation. +func (a *Annotations) toObject() (*Object, *Error) { + obj := NewObject() + + if a == nil { + return &obj, nil + } + + if len(a.Scope) > 0 { + obj.Insert(StringTerm("scope"), StringTerm(a.Scope)) + } + + if len(a.Title) > 0 { + obj.Insert(StringTerm("title"), StringTerm(a.Title)) + } + + if a.Entrypoint { + obj.Insert(StringTerm("entrypoint"), BooleanTerm(true)) + } + + if len(a.Description) > 0 { + obj.Insert(StringTerm("description"), StringTerm(a.Description)) + } + + if len(a.Organizations) > 0 { + orgs := make([]*Term, 0, len(a.Organizations)) + for _, org := range a.Organizations { + orgs = append(orgs, StringTerm(org)) + } + obj.Insert(StringTerm("organizations"), ArrayTerm(orgs...)) + } + + if len(a.RelatedResources) > 0 { + rrs := make([]*Term, 0, len(a.RelatedResources)) + for _, rr := range a.RelatedResources { + rrObj := NewObject(Item(StringTerm("ref"), StringTerm(rr.Ref.String()))) + if len(rr.Description) > 0 { + rrObj.Insert(StringTerm("description"), StringTerm(rr.Description)) + } + rrs = append(rrs, NewTerm(rrObj)) + } + obj.Insert(StringTerm("related_resources"), ArrayTerm(rrs...)) + } + + if len(a.Authors) > 0 { + as := make([]*Term, 0, len(a.Authors)) + for _, author := range a.Authors { + aObj := NewObject() + if len(author.Name) > 0 { + aObj.Insert(StringTerm("name"), StringTerm(author.Name)) + } + if len(author.Email) > 0 { + aObj.Insert(StringTerm("email"), StringTerm(author.Email)) + } + as = append(as, NewTerm(aObj)) + } + obj.Insert(StringTerm("authors"), ArrayTerm(as...)) + } + + if len(a.Schemas) > 0 { + ss := make([]*Term, 0, len(a.Schemas)) + for _, s := range a.Schemas { + sObj := NewObject() + if len(s.Path) > 0 { + sObj.Insert(StringTerm("path"), NewTerm(s.Path.toArray())) + } + if len(s.Schema) > 0 { + sObj.Insert(StringTerm("schema"), NewTerm(s.Schema.toArray())) + } + if s.Definition != nil { + def, err := InterfaceToValue(s.Definition) + if err != nil { + return nil, NewError(CompileErr, a.Location, "invalid definition in schema annotation: %s", err.Error()) + } + sObj.Insert(StringTerm("definition"), NewTerm(def)) + } + ss = append(ss, NewTerm(sObj)) + } + obj.Insert(StringTerm("schemas"), ArrayTerm(ss...)) + } + + if len(a.Custom) > 0 { + c, err := InterfaceToValue(a.Custom) + if err != nil { + return nil, NewError(CompileErr, a.Location, "invalid custom annotation %s", err.Error()) + } + obj.Insert(StringTerm("custom"), NewTerm(c)) + } + + return &obj, nil +} + +func attachRuleAnnotations(mod *Module) { + // make a copy of the annotations + cpy := make([]*Annotations, len(mod.Annotations)) + for i, a := range mod.Annotations { + cpy[i] = a.Copy(a.node) + } + + for _, rule := range mod.Rules { + var j int + var found bool + for i, a := range cpy { + if rule.Ref().GroundPrefix().Equal(a.GetTargetPath()) { + if a.Scope == annotationScopeDocument { + rule.Annotations = append(rule.Annotations, a) + } else if a.Scope == annotationScopeRule && rule.Loc().Row > a.Location.Row { + j = i + found = true + rule.Annotations = append(rule.Annotations, a) + } + } + } + + if found && j < len(cpy) { + cpy = append(cpy[:j], cpy[j+1:]...) + } + } +} + +func attachAnnotationsNodes(mod *Module) Errors { + var errs Errors + + // Find first non-annotation statement following each annotation and attach + // the annotation to that statement. + for _, a := range mod.Annotations { + for _, stmt := range mod.stmts { + _, ok := stmt.(*Annotations) + if !ok { + if stmt.Loc().Row > a.Location.Row { + a.node = stmt + break + } + } + } + + if a.Scope == "" { + switch a.node.(type) { + case *Rule: + if a.Entrypoint { + a.Scope = annotationScopeDocument + } else { + a.Scope = annotationScopeRule + } + case *Package: + a.Scope = annotationScopePackage + case *Import: + a.Scope = annotationScopeImport + } + } + + if err := validateAnnotationScopeAttachment(a); err != nil { + errs = append(errs, err) + } + + if err := validateAnnotationEntrypointAttachment(a); err != nil { + errs = append(errs, err) + } + } + + return errs +} + +func validateAnnotationScopeAttachment(a *Annotations) *Error { + + switch a.Scope { + case annotationScopeRule, annotationScopeDocument: + if _, ok := a.node.(*Rule); ok { + return nil + } + return newScopeAttachmentErr(a, "rule") + case annotationScopePackage, annotationScopeSubpackages: + if _, ok := a.node.(*Package); ok { + return nil + } + return newScopeAttachmentErr(a, "package") + } + + return NewError(ParseErr, a.Loc(), "invalid annotation scope '%v'. Use one of '%s', '%s', '%s', or '%s'", + a.Scope, annotationScopeRule, annotationScopeDocument, annotationScopePackage, annotationScopeSubpackages) +} + +func validateAnnotationEntrypointAttachment(a *Annotations) *Error { + if a.Entrypoint && !(a.Scope == annotationScopeDocument || a.Scope == annotationScopePackage) { + return NewError( + ParseErr, a.Loc(), "annotation entrypoint applied to non-document or package scope '%v'", a.Scope) + } + return nil +} + +// Copy returns a deep copy of a. +func (a *AuthorAnnotation) Copy() *AuthorAnnotation { + cpy := *a + return &cpy +} + +// Compare returns an integer indicating if s is less than, equal to, or greater +// than other. +func (a *AuthorAnnotation) Compare(other *AuthorAnnotation) int { + if cmp := strings.Compare(a.Name, other.Name); cmp != 0 { + return cmp + } + + if cmp := strings.Compare(a.Email, other.Email); cmp != 0 { + return cmp + } + + return 0 +} + +func (a *AuthorAnnotation) String() string { + if len(a.Email) == 0 { + return a.Name + } else if len(a.Name) == 0 { + return fmt.Sprintf("<%s>", a.Email) + } + return fmt.Sprintf("%s <%s>", a.Name, a.Email) +} + +// Copy returns a deep copy of rr. +func (rr *RelatedResourceAnnotation) Copy() *RelatedResourceAnnotation { + cpy := *rr + return &cpy +} + +// Compare returns an integer indicating if s is less than, equal to, or greater +// than other. +func (rr *RelatedResourceAnnotation) Compare(other *RelatedResourceAnnotation) int { + if cmp := strings.Compare(rr.Description, other.Description); cmp != 0 { + return cmp + } + + if cmp := strings.Compare(rr.Ref.String(), other.Ref.String()); cmp != 0 { + return cmp + } + + return 0 +} + +func (rr *RelatedResourceAnnotation) String() string { + bs, _ := json.Marshal(rr) + return string(bs) +} + +func (rr *RelatedResourceAnnotation) MarshalJSON() ([]byte, error) { + d := map[string]interface{}{ + "ref": rr.Ref.String(), + } + + if len(rr.Description) > 0 { + d["description"] = rr.Description + } + + return json.Marshal(d) +} + +// Copy returns a deep copy of s. +func (s *SchemaAnnotation) Copy() *SchemaAnnotation { + cpy := *s + return &cpy +} + +// Compare returns an integer indicating if s is less than, equal to, or greater +// than other. +func (s *SchemaAnnotation) Compare(other *SchemaAnnotation) int { + + if cmp := s.Path.Compare(other.Path); cmp != 0 { + return cmp + } + + if cmp := s.Schema.Compare(other.Schema); cmp != 0 { + return cmp + } + + if s.Definition != nil && other.Definition == nil { + return -1 + } else if s.Definition == nil && other.Definition != nil { + return 1 + } else if s.Definition != nil && other.Definition != nil { + return util.Compare(*s.Definition, *other.Definition) + } + + return 0 +} + +func (s *SchemaAnnotation) String() string { + bs, _ := json.Marshal(s) + return string(bs) +} + +func newAnnotationSet() *AnnotationSet { + return &AnnotationSet{ + byRule: map[*Rule][]*Annotations{}, + byPackage: map[int]*Annotations{}, + byPath: newAnnotationTree(), + } +} + +func BuildAnnotationSet(modules []*Module) (*AnnotationSet, Errors) { + as := newAnnotationSet() + var errs Errors + for _, m := range modules { + for _, a := range m.Annotations { + if err := as.add(a); err != nil { + errs = append(errs, err) + } + } + } + if len(errs) > 0 { + return nil, errs + } + as.modules = modules + return as, nil +} + +// NOTE(philipc): During copy propagation, the underlying Nodes can be +// stripped away from the annotations, leading to nil deref panics. We +// silently ignore these cases for now, as a workaround. +func (as *AnnotationSet) add(a *Annotations) *Error { + switch a.Scope { + case annotationScopeRule: + if rule, ok := a.node.(*Rule); ok { + as.byRule[rule] = append(as.byRule[rule], a) + } + case annotationScopePackage: + if pkg, ok := a.node.(*Package); ok { + hash := pkg.Path.Hash() + if exist, ok := as.byPackage[hash]; ok { + return errAnnotationRedeclared(a, exist.Location) + } + as.byPackage[hash] = a + } + case annotationScopeDocument: + if rule, ok := a.node.(*Rule); ok { + path := rule.Ref().GroundPrefix() + x := as.byPath.get(path) + if x != nil { + return errAnnotationRedeclared(a, x.Value.Location) + } + as.byPath.insert(path, a) + } + case annotationScopeSubpackages: + if pkg, ok := a.node.(*Package); ok { + x := as.byPath.get(pkg.Path) + if x != nil && x.Value != nil { + return errAnnotationRedeclared(a, x.Value.Location) + } + as.byPath.insert(pkg.Path, a) + } + } + return nil +} + +func (as *AnnotationSet) GetRuleScope(r *Rule) []*Annotations { + if as == nil { + return nil + } + return as.byRule[r] +} + +func (as *AnnotationSet) GetSubpackagesScope(path Ref) []*Annotations { + if as == nil { + return nil + } + return as.byPath.ancestors(path) +} + +func (as *AnnotationSet) GetDocumentScope(path Ref) *Annotations { + if as == nil { + return nil + } + if node := as.byPath.get(path); node != nil { + return node.Value + } + return nil +} + +func (as *AnnotationSet) GetPackageScope(pkg *Package) *Annotations { + if as == nil { + return nil + } + return as.byPackage[pkg.Path.Hash()] +} + +// Flatten returns a flattened list view of this AnnotationSet. +// The returned slice is sorted, first by the annotations' target path, then by their target location +func (as *AnnotationSet) Flatten() FlatAnnotationsRefSet { + // This preallocation often won't be optimal, but it's superior to starting with a nil slice. + refs := make([]*AnnotationsRef, 0, len(as.byPath.Children)+len(as.byRule)+len(as.byPackage)) + + refs = as.byPath.flatten(refs) + + for _, a := range as.byPackage { + refs = append(refs, NewAnnotationsRef(a)) + } + + for _, as := range as.byRule { + for _, a := range as { + refs = append(refs, NewAnnotationsRef(a)) + } + } + + // Sort by path, then annotation location, for stable output + sort.SliceStable(refs, func(i, j int) bool { + return refs[i].Compare(refs[j]) < 0 + }) + + return refs +} + +// Chain returns the chain of annotations leading up to the given rule. +// The returned slice is ordered as follows +// 0. Entries for the given rule, ordered from the METADATA block declared immediately above the rule, to the block declared farthest away (always at least one entry) +// 1. The 'document' scope entry, if any +// 2. The 'package' scope entry, if any +// 3. Entries for the 'subpackages' scope, if any; ordered from the closest package path to the fartest. E.g.: 'do.re.mi', 'do.re', 'do' +// The returned slice is guaranteed to always contain at least one entry, corresponding to the given rule. +func (as *AnnotationSet) Chain(rule *Rule) AnnotationsRefSet { + var refs []*AnnotationsRef + + ruleAnnots := as.GetRuleScope(rule) + + if len(ruleAnnots) >= 1 { + for _, a := range ruleAnnots { + refs = append(refs, NewAnnotationsRef(a)) + } + } else { + // Make sure there is always a leading entry representing the passed rule, even if it has no annotations + refs = append(refs, &AnnotationsRef{ + Location: rule.Location, + Path: rule.Ref().GroundPrefix(), + node: rule, + }) + } + + if len(refs) > 1 { + // Sort by annotation location; chain must start with annotations declared closest to rule, then going outward + sort.SliceStable(refs, func(i, j int) bool { + return refs[i].Annotations.Location.Compare(refs[j].Annotations.Location) > 0 + }) + } + + docAnnots := as.GetDocumentScope(rule.Ref().GroundPrefix()) + if docAnnots != nil { + refs = append(refs, NewAnnotationsRef(docAnnots)) + } + + pkg := rule.Module.Package + pkgAnnots := as.GetPackageScope(pkg) + if pkgAnnots != nil { + refs = append(refs, NewAnnotationsRef(pkgAnnots)) + } + + subPkgAnnots := as.GetSubpackagesScope(pkg.Path) + // We need to reverse the order, as subPkgAnnots ordering will start at the root, + // whereas we want to end at the root. + for i := len(subPkgAnnots) - 1; i >= 0; i-- { + refs = append(refs, NewAnnotationsRef(subPkgAnnots[i])) + } + + return refs +} + +func (ars FlatAnnotationsRefSet) Insert(ar *AnnotationsRef) FlatAnnotationsRefSet { + result := make(FlatAnnotationsRefSet, 0, len(ars)+1) + + // insertion sort, first by path, then location + for i, current := range ars { + if ar.Compare(current) < 0 { + result = append(result, ar) + result = append(result, ars[i:]...) + break + } + result = append(result, current) + } + + if len(result) < len(ars)+1 { + result = append(result, ar) + } + + return result +} + +func newAnnotationTree() *annotationTreeNode { + return &annotationTreeNode{ + Value: nil, + Children: map[Value]*annotationTreeNode{}, + } +} + +func (t *annotationTreeNode) insert(path Ref, value *Annotations) { + node := t + for _, k := range path { + child, ok := node.Children[k.Value] + if !ok { + child = newAnnotationTree() + node.Children[k.Value] = child + } + node = child + } + node.Value = value +} + +func (t *annotationTreeNode) get(path Ref) *annotationTreeNode { + node := t + for _, k := range path { + if node == nil { + return nil + } + child, ok := node.Children[k.Value] + if !ok { + return nil + } + node = child + } + return node +} + +// ancestors returns a slice of annotations in ascending order, starting with the root of ref; e.g.: 'root', 'root.foo', 'root.foo.bar'. +func (t *annotationTreeNode) ancestors(path Ref) (result []*Annotations) { + node := t + for _, k := range path { + if node == nil { + return result + } + child, ok := node.Children[k.Value] + if !ok { + return result + } + if child.Value != nil { + result = append(result, child.Value) + } + node = child + } + return result +} + +func (t *annotationTreeNode) flatten(refs []*AnnotationsRef) []*AnnotationsRef { + if a := t.Value; a != nil { + refs = append(refs, NewAnnotationsRef(a)) + } + for _, c := range t.Children { + refs = c.flatten(refs) + } + return refs +} + +func (ar *AnnotationsRef) Compare(other *AnnotationsRef) int { + if c := ar.Path.Compare(other.Path); c != 0 { + return c + } + + if c := ar.Annotations.Location.Compare(other.Annotations.Location); c != 0 { + return c + } + + return ar.Annotations.Compare(other.Annotations) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/builtins.go b/vendor/github.com/open-policy-agent/opa/v1/ast/builtins.go new file mode 100644 index 000000000..9585620dc --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/builtins.go @@ -0,0 +1,3398 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "strings" + + "github.com/open-policy-agent/opa/v1/types" +) + +// Builtins is the registry of built-in functions supported by OPA. +// Call RegisterBuiltin to add a new built-in. +var Builtins []*Builtin + +// RegisterBuiltin adds a new built-in function to the registry. +func RegisterBuiltin(b *Builtin) { + Builtins = append(Builtins, b) + BuiltinMap[b.Name] = b + if len(b.Infix) > 0 { + BuiltinMap[b.Infix] = b + } +} + +// DefaultBuiltins is the registry of built-in functions supported in OPA +// by default. When adding a new built-in function to OPA, update this +// list. +var DefaultBuiltins = [...]*Builtin{ + // Unification/equality ("=") + Equality, + + // Assignment (":=") + Assign, + + // Membership, infix "in": `x in xs` + Member, + MemberWithKey, + + // Comparisons + GreaterThan, + GreaterThanEq, + LessThan, + LessThanEq, + NotEqual, + Equal, + + // Arithmetic + Plus, + Minus, + Multiply, + Divide, + Ceil, + Floor, + Round, + Abs, + Rem, + + // Bitwise Arithmetic + BitsOr, + BitsAnd, + BitsNegate, + BitsXOr, + BitsShiftLeft, + BitsShiftRight, + + // Binary + And, + Or, + + // Aggregates + Count, + Sum, + Product, + Max, + Min, + Any, + All, + + // Arrays + ArrayConcat, + ArraySlice, + ArrayReverse, + + // Conversions + ToNumber, + + // Casts (DEPRECATED) + CastObject, + CastNull, + CastBoolean, + CastString, + CastSet, + CastArray, + + // Regular Expressions + RegexIsValid, + RegexMatch, + RegexMatchDeprecated, + RegexSplit, + GlobsMatch, + RegexTemplateMatch, + RegexFind, + RegexFindAllStringSubmatch, + RegexReplace, + + // Sets + SetDiff, + Intersection, + Union, + + // Strings + AnyPrefixMatch, + AnySuffixMatch, + Concat, + FormatInt, + IndexOf, + IndexOfN, + Substring, + Lower, + Upper, + Contains, + StringCount, + StartsWith, + EndsWith, + Split, + Replace, + ReplaceN, + Trim, + TrimLeft, + TrimPrefix, + TrimRight, + TrimSuffix, + TrimSpace, + Sprintf, + StringReverse, + RenderTemplate, + + // Numbers + NumbersRange, + NumbersRangeStep, + RandIntn, + + // Encoding + JSONMarshal, + JSONMarshalWithOptions, + JSONUnmarshal, + JSONIsValid, + Base64Encode, + Base64Decode, + Base64IsValid, + Base64UrlEncode, + Base64UrlEncodeNoPad, + Base64UrlDecode, + URLQueryDecode, + URLQueryEncode, + URLQueryEncodeObject, + URLQueryDecodeObject, + YAMLMarshal, + YAMLUnmarshal, + YAMLIsValid, + HexEncode, + HexDecode, + + // Object Manipulation + ObjectUnion, + ObjectUnionN, + ObjectRemove, + ObjectFilter, + ObjectGet, + ObjectKeys, + ObjectSubset, + + // JSON Object Manipulation + JSONFilter, + JSONRemove, + JSONPatch, + + // Tokens + JWTDecode, + JWTVerifyRS256, + JWTVerifyRS384, + JWTVerifyRS512, + JWTVerifyPS256, + JWTVerifyPS384, + JWTVerifyPS512, + JWTVerifyES256, + JWTVerifyES384, + JWTVerifyES512, + JWTVerifyHS256, + JWTVerifyHS384, + JWTVerifyHS512, + JWTDecodeVerify, + JWTEncodeSignRaw, + JWTEncodeSign, + + // Time + NowNanos, + ParseNanos, + ParseRFC3339Nanos, + ParseDurationNanos, + Format, + Date, + Clock, + Weekday, + AddDate, + Diff, + + // Crypto + CryptoX509ParseCertificates, + CryptoX509ParseAndVerifyCertificates, + CryptoX509ParseAndVerifyCertificatesWithOptions, + CryptoMd5, + CryptoSha1, + CryptoSha256, + CryptoX509ParseCertificateRequest, + CryptoX509ParseRSAPrivateKey, + CryptoX509ParseKeyPair, + CryptoParsePrivateKeys, + CryptoHmacMd5, + CryptoHmacSha1, + CryptoHmacSha256, + CryptoHmacSha512, + CryptoHmacEqual, + + // Graphs + WalkBuiltin, + ReachableBuiltin, + ReachablePathsBuiltin, + + // Sort + Sort, + + // Types + IsNumber, + IsString, + IsBoolean, + IsArray, + IsSet, + IsObject, + IsNull, + TypeNameBuiltin, + + // HTTP + HTTPSend, + + // GraphQL + GraphQLParse, + GraphQLParseAndVerify, + GraphQLParseQuery, + GraphQLParseSchema, + GraphQLIsValid, + GraphQLSchemaIsValid, + + // JSON Schema + JSONSchemaVerify, + JSONMatchSchema, + + // Cloud Provider Helpers + ProvidersAWSSignReqObj, + + // Rego + RegoParseModule, + RegoMetadataChain, + RegoMetadataRule, + + // OPA + OPARuntime, + + // Tracing + Trace, + + // Networking + NetCIDROverlap, + NetCIDRIntersects, + NetCIDRContains, + NetCIDRContainsMatches, + NetCIDRExpand, + NetCIDRMerge, + NetLookupIPAddr, + NetCIDRIsValid, + + // Glob + GlobMatch, + GlobQuoteMeta, + + // Units + UnitsParse, + UnitsParseBytes, + + // UUIDs + UUIDRFC4122, + UUIDParse, + + // SemVers + SemVerIsValid, + SemVerCompare, + + // Printing + Print, + InternalPrint, +} + +// BuiltinMap provides a convenient mapping of built-in names to +// built-in definitions. +var BuiltinMap map[string]*Builtin + +// Deprecated: Builtins can now be directly annotated with the +// Nondeterministic property, and when set to true, will be ignored +// for partial evaluation. +var IgnoreDuringPartialEval = []*Builtin{ + RandIntn, + UUIDRFC4122, + JWTDecodeVerify, + JWTEncodeSignRaw, + JWTEncodeSign, + NowNanos, + HTTPSend, + OPARuntime, + NetLookupIPAddr, +} + +/** + * Unification + */ + +// Equality represents the "=" operator. +var Equality = &Builtin{ + Name: "eq", + Infix: "=", + Decl: types.NewFunction( + types.Args(types.A, types.A), + types.B, + ), +} + +/** + * Assignment + */ + +// Assign represents the assignment (":=") operator. +var Assign = &Builtin{ + Name: "assign", + Infix: ":=", + Decl: types.NewFunction( + types.Args(types.A, types.A), + types.B, + ), +} + +// Member represents the `in` (infix) operator. +var Member = &Builtin{ + Name: "internal.member_2", + Infix: "in", + Decl: types.NewFunction( + types.Args( + types.A, + types.A, + ), + types.B, + ), +} + +// MemberWithKey represents the `in` (infix) operator when used +// with two terms on the lhs, i.e., `k, v in obj`. +var MemberWithKey = &Builtin{ + Name: "internal.member_3", + Infix: "in", + Decl: types.NewFunction( + types.Args( + types.A, + types.A, + types.A, + ), + types.B, + ), +} + +/** + * Comparisons + */ +var comparison = category("comparison") + +var GreaterThan = &Builtin{ + Name: "gt", + Infix: ">", + Categories: comparison, + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A), + types.Named("y", types.A), + ), + types.Named("result", types.B).Description("true if `x` is greater than `y`; false otherwise"), + ), +} + +var GreaterThanEq = &Builtin{ + Name: "gte", + Infix: ">=", + Categories: comparison, + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A), + types.Named("y", types.A), + ), + types.Named("result", types.B).Description("true if `x` is greater or equal to `y`; false otherwise"), + ), +} + +// LessThan represents the "<" comparison operator. +var LessThan = &Builtin{ + Name: "lt", + Infix: "<", + Categories: comparison, + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A), + types.Named("y", types.A), + ), + types.Named("result", types.B).Description("true if `x` is less than `y`; false otherwise"), + ), +} + +var LessThanEq = &Builtin{ + Name: "lte", + Infix: "<=", + Categories: comparison, + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A), + types.Named("y", types.A), + ), + types.Named("result", types.B).Description("true if `x` is less than or equal to `y`; false otherwise"), + ), +} + +var NotEqual = &Builtin{ + Name: "neq", + Infix: "!=", + Categories: comparison, + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A), + types.Named("y", types.A), + ), + types.Named("result", types.B).Description("true if `x` is not equal to `y`; false otherwise"), + ), +} + +// Equal represents the "==" comparison operator. +var Equal = &Builtin{ + Name: "equal", + Infix: "==", + Categories: comparison, + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A), + types.Named("y", types.A), + ), + types.Named("result", types.B).Description("true if `x` is equal to `y`; false otherwise"), + ), +} + +/** + * Arithmetic + */ +var number = category("numbers") + +var Plus = &Builtin{ + Name: "plus", + Infix: "+", + Description: "Plus adds two numbers together.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N), + types.Named("y", types.N), + ), + types.Named("z", types.N).Description("the sum of `x` and `y`"), + ), + Categories: number, +} + +var Minus = &Builtin{ + Name: "minus", + Infix: "-", + Description: "Minus subtracts the second number from the first number or computes the difference between two sets.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewAny(types.N, types.NewSet(types.A))), + types.Named("y", types.NewAny(types.N, types.NewSet(types.A))), + ), + types.Named("z", types.NewAny(types.N, types.NewSet(types.A))).Description("the difference of `x` and `y`"), + ), + Categories: category("sets", "numbers"), +} + +var Multiply = &Builtin{ + Name: "mul", + Infix: "*", + Description: "Multiplies two numbers.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N), + types.Named("y", types.N), + ), + types.Named("z", types.N).Description("the product of `x` and `y`"), + ), + Categories: number, +} + +var Divide = &Builtin{ + Name: "div", + Infix: "/", + Description: "Divides the first number by the second number.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the dividend"), + types.Named("y", types.N).Description("the divisor"), + ), + types.Named("z", types.N).Description("the result of `x` divided by `y`"), + ), + Categories: number, +} + +var Round = &Builtin{ + Name: "round", + Description: "Rounds the number to the nearest integer.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the number to round"), + ), + types.Named("y", types.N).Description("the result of rounding `x`"), + ), + Categories: number, +} + +var Ceil = &Builtin{ + Name: "ceil", + Description: "Rounds the number _up_ to the nearest integer.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the number to round"), + ), + types.Named("y", types.N).Description("the result of rounding `x` _up_"), + ), + Categories: number, +} + +var Floor = &Builtin{ + Name: "floor", + Description: "Rounds the number _down_ to the nearest integer.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the number to round"), + ), + types.Named("y", types.N).Description("the result of rounding `x` _down_"), + ), + Categories: number, +} + +var Abs = &Builtin{ + Name: "abs", + Description: "Returns the number without its sign.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the number to take the absolute value of"), + ), + types.Named("y", types.N).Description("the absolute value of `x`"), + ), + Categories: number, +} + +var Rem = &Builtin{ + Name: "rem", + Infix: "%", + Description: "Returns the remainder for of `x` divided by `y`, for `y != 0`.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N), + types.Named("y", types.N), + ), + types.Named("z", types.N).Description("the remainder"), + ), + Categories: number, +} + +/** + * Bitwise + */ + +var BitsOr = &Builtin{ + Name: "bits.or", + Description: "Returns the bitwise \"OR\" of two integers.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the first integer"), + types.Named("y", types.N).Description("the second integer"), + ), + types.Named("z", types.N).Description("the bitwise OR of `x` and `y`"), + ), +} + +var BitsAnd = &Builtin{ + Name: "bits.and", + Description: "Returns the bitwise \"AND\" of two integers.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the first integer"), + types.Named("y", types.N).Description("the second integer"), + ), + types.Named("z", types.N).Description("the bitwise AND of `x` and `y`"), + ), +} + +var BitsNegate = &Builtin{ + Name: "bits.negate", + Description: "Returns the bitwise negation (flip) of an integer.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the integer to negate"), + ), + types.Named("z", types.N).Description("the bitwise negation of `x`"), + ), +} + +var BitsXOr = &Builtin{ + Name: "bits.xor", + Description: "Returns the bitwise \"XOR\" (exclusive-or) of two integers.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the first integer"), + types.Named("y", types.N).Description("the second integer"), + ), + types.Named("z", types.N).Description("the bitwise XOR of `x` and `y`"), + ), +} + +var BitsShiftLeft = &Builtin{ + Name: "bits.lsh", + Description: "Returns a new integer with its bits shifted `s` bits to the left.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the integer to shift"), + types.Named("s", types.N).Description("the number of bits to shift"), + ), + types.Named("z", types.N).Description("the result of shifting `x` `s` bits to the left"), + ), +} + +var BitsShiftRight = &Builtin{ + Name: "bits.rsh", + Description: "Returns a new integer with its bits shifted `s` bits to the right.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.N).Description("the integer to shift"), + types.Named("s", types.N).Description("the number of bits to shift"), + ), + types.Named("z", types.N).Description("the result of shifting `x` `s` bits to the right"), + ), +} + +/** + * Sets + */ + +var sets = category("sets") + +var And = &Builtin{ + Name: "and", + Infix: "&", + Description: "Returns the intersection of two sets.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewSet(types.A)).Description("the first set"), + types.Named("y", types.NewSet(types.A)).Description("the second set"), + ), + types.Named("z", types.NewSet(types.A)).Description("the intersection of `x` and `y`"), + ), + Categories: sets, +} + +// Or performs a union operation on sets. +var Or = &Builtin{ + Name: "or", + Infix: "|", + Description: "Returns the union of two sets.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewSet(types.A)), + types.Named("y", types.NewSet(types.A)), + ), + types.Named("z", types.NewSet(types.A)).Description("the union of `x` and `y`"), + ), + Categories: sets, +} + +var Intersection = &Builtin{ + Name: "intersection", + Description: "Returns the intersection of the given input sets.", + Decl: types.NewFunction( + types.Args( + types.Named("xs", types.NewSet(types.NewSet(types.A))).Description("set of sets to intersect"), + ), + types.Named("y", types.NewSet(types.A)).Description("the intersection of all `xs` sets"), + ), + Categories: sets, +} + +var Union = &Builtin{ + Name: "union", + Description: "Returns the union of the given input sets.", + Decl: types.NewFunction( + types.Args( + types.Named("xs", types.NewSet(types.NewSet(types.A))).Description("set of sets to merge"), + ), + types.Named("y", types.NewSet(types.A)).Description("the union of all `xs` sets"), + ), + Categories: sets, +} + +/** + * Aggregates + */ + +var aggregates = category("aggregates") + +var Count = &Builtin{ + Name: "count", + Description: " Count takes a collection or string and returns the number of elements (or characters) in it.", + Decl: types.NewFunction( + types.Args( + types.Named("collection", types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + types.S, + )).Description("the set/array/object/string to be counted"), + ), + types.Named("n", types.N).Description("the count of elements, key/val pairs, or characters, respectively."), + ), + Categories: aggregates, +} + +var Sum = &Builtin{ + Name: "sum", + Description: "Sums elements of an array or set of numbers.", + Decl: types.NewFunction( + types.Args( + types.Named("collection", types.NewAny( + types.NewSet(types.N), + types.NewArray(nil, types.N), + )).Description("the set or array of numbers to sum"), + ), + types.Named("n", types.N).Description("the sum of all elements"), + ), + Categories: aggregates, +} + +var Product = &Builtin{ + Name: "product", + Description: "Multiplies elements of an array or set of numbers", + Decl: types.NewFunction( + types.Args( + types.Named("collection", types.NewAny( + types.NewSet(types.N), + types.NewArray(nil, types.N), + )).Description("the set or array of numbers to multiply"), + ), + types.Named("n", types.N).Description("the product of all elements"), + ), + Categories: aggregates, +} + +var Max = &Builtin{ + Name: "max", + Description: "Returns the maximum value in a collection.", + Decl: types.NewFunction( + types.Args( + types.Named("collection", types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A), + )).Description("the set or array to be searched"), + ), + types.Named("n", types.A).Description("the maximum of all elements"), + ), + Categories: aggregates, +} + +var Min = &Builtin{ + Name: "min", + Description: "Returns the minimum value in a collection.", + Decl: types.NewFunction( + types.Args( + types.Named("collection", types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A), + )).Description("the set or array to be searched"), + ), + types.Named("n", types.A).Description("the minimum of all elements"), + ), + Categories: aggregates, +} + +/** + * Sorting + */ + +var Sort = &Builtin{ + Name: "sort", + Description: "Returns a sorted array.", + Decl: types.NewFunction( + types.Args( + types.Named("collection", types.NewAny( + types.NewArray(nil, types.A), + types.NewSet(types.A), + )).Description("the array or set to be sorted"), + ), + types.Named("n", types.NewArray(nil, types.A)).Description("the sorted array"), + ), + Categories: aggregates, +} + +/** + * Arrays + */ + +var ArrayConcat = &Builtin{ + Name: "array.concat", + Description: "Concatenates two arrays.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewArray(nil, types.A)).Description("the first array"), + types.Named("y", types.NewArray(nil, types.A)).Description("the second array"), + ), + types.Named("z", types.NewArray(nil, types.A)).Description("the concatenation of `x` and `y`"), + ), +} + +var ArraySlice = &Builtin{ + Name: "array.slice", + Description: "Returns a slice of a given array. If `start` is greater or equal than `stop`, `slice` is `[]`.", + Decl: types.NewFunction( + types.Args( + types.Named("arr", types.NewArray(nil, types.A)).Description("the array to be sliced"), + types.Named("start", types.NewNumber()).Description("the start index of the returned slice; if less than zero, it's clamped to 0"), + types.Named("stop", types.NewNumber()).Description("the stop index of the returned slice; if larger than `count(arr)`, it's clamped to `count(arr)`"), + ), + types.Named("slice", types.NewArray(nil, types.A)).Description("the subslice of `array`, from `start` to `end`, including `arr[start]`, but excluding `arr[end]`"), + ), +} // NOTE(sr): this function really needs examples + +var ArrayReverse = &Builtin{ + Name: "array.reverse", + Description: "Returns the reverse of a given array.", + Decl: types.NewFunction( + types.Args( + types.Named("arr", types.NewArray(nil, types.A)).Description("the array to be reversed"), + ), + types.Named("rev", types.NewArray(nil, types.A)).Description("an array containing the elements of `arr` in reverse order"), + ), +} + +/** + * Conversions + */ +var conversions = category("conversions") + +var ToNumber = &Builtin{ + Name: "to_number", + Description: "Converts a string, bool, or number value to a number: Strings are converted to numbers using `strconv.Atoi`, Boolean `false` is converted to 0 and `true` is converted to 1.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewAny( + types.N, + types.S, + types.B, + types.NewNull(), + )).Description("value to convert"), + ), + types.Named("num", types.N).Description("the numeric representation of `x`"), + ), + Categories: conversions, +} + +/** + * Regular Expressions + */ + +var RegexMatch = &Builtin{ + Name: "regex.match", + Description: "Matches a string against a regular expression.", + Decl: types.NewFunction( + types.Args( + types.Named("pattern", types.S).Description("regular expression"), + types.Named("value", types.S).Description("value to match against `pattern`"), + ), + types.Named("result", types.B).Description("true if `value` matches `pattern`"), + ), +} + +var RegexIsValid = &Builtin{ + Name: "regex.is_valid", + Description: "Checks if a string is a valid regular expression: the detailed syntax for patterns is defined by https://github.com/google/re2/wiki/Syntax.", + Decl: types.NewFunction( + types.Args( + types.Named("pattern", types.S).Description("regular expression"), + ), + types.Named("result", types.B).Description("true if `pattern` is a valid regular expression"), + ), +} + +var RegexFindAllStringSubmatch = &Builtin{ + Name: "regex.find_all_string_submatch_n", + Description: "Returns all successive matches of the expression.", + Decl: types.NewFunction( + types.Args( + types.Named("pattern", types.S).Description("regular expression"), + types.Named("value", types.S).Description("string to match"), + types.Named("number", types.N).Description("number of matches to return; `-1` means all matches"), + ), + types.Named("output", types.NewArray(nil, types.NewArray(nil, types.S))).Description("array of all matches"), + ), +} + +var RegexTemplateMatch = &Builtin{ + Name: "regex.template_match", + Description: "Matches a string against a pattern, where there pattern may be glob-like", + Decl: types.NewFunction( + types.Args( + types.Named("template", types.S).Description("template expression containing `0..n` regular expressions"), + types.Named("value", types.S).Description("string to match"), + types.Named("delimiter_start", types.S).Description("start delimiter of the regular expression in `template`"), + types.Named("delimiter_end", types.S).Description("end delimiter of the regular expression in `template`"), + ), + types.Named("result", types.B).Description("true if `value` matches the `template`"), + ), +} // TODO(sr): example:`regex.template_match("urn:foo:{.*}", "urn:foo:bar:baz", "{", "}")`` returns ``true``. + +var RegexSplit = &Builtin{ + Name: "regex.split", + Description: "Splits the input string by the occurrences of the given pattern.", + Decl: types.NewFunction( + types.Args( + types.Named("pattern", types.S).Description("regular expression"), + types.Named("value", types.S).Description("string to match"), + ), + types.Named("output", types.NewArray(nil, types.S)).Description("the parts obtained by splitting `value`"), + ), +} + +// RegexFind takes two strings and a number, the pattern, the value and number of match values to +// return, -1 means all match values. +var RegexFind = &Builtin{ + Name: "regex.find_n", + Description: "Returns the specified number of matches when matching the input against the pattern.", + Decl: types.NewFunction( + types.Args( + types.Named("pattern", types.S).Description("regular expression"), + types.Named("value", types.S).Description("string to match"), + types.Named("number", types.N).Description("number of matches to return, if `-1`, returns all matches"), + ), + types.Named("output", types.NewArray(nil, types.S)).Description("collected matches"), + ), +} + +// GlobsMatch takes two strings regexp-style strings and evaluates to true if their +// intersection matches a non-empty set of non-empty strings. +// Examples: +// - "a.a." and ".b.b" -> true. +// - "[a-z]*" and [0-9]+" -> not true. +var GlobsMatch = &Builtin{ + Name: "regex.globs_match", + Description: `Checks if the intersection of two glob-style regular expressions matches a non-empty set of non-empty strings. +The set of regex symbols is limited for this builtin: only ` + "`.`, `*`, `+`, `[`, `-`, `]` and `\\` are treated as special symbols.", + Decl: types.NewFunction( + types.Args( + types.Named("glob1", types.S).Description("first glob-style regular expression"), + types.Named("glob2", types.S).Description("second glob-style regular expression"), + ), + types.Named("result", types.B).Description("true if the intersection of `glob1` and `glob2` matches a non-empty set of non-empty strings"), + ), +} + +/** + * Strings + */ +var stringsCat = category("strings") + +var AnyPrefixMatch = &Builtin{ + Name: "strings.any_prefix_match", + Description: "Returns true if any of the search strings begins with any of the base strings.", + Decl: types.NewFunction( + types.Args( + types.Named("search", types.NewAny( + types.S, + types.NewSet(types.S), + types.NewArray(nil, types.S), + )).Description("search string(s)"), + types.Named("base", types.NewAny( + types.S, + types.NewSet(types.S), + types.NewArray(nil, types.S), + )).Description("base string(s)"), + ), + types.Named("result", types.B).Description("result of the prefix check"), + ), + Categories: stringsCat, +} + +var AnySuffixMatch = &Builtin{ + Name: "strings.any_suffix_match", + Description: "Returns true if any of the search strings ends with any of the base strings.", + Decl: types.NewFunction( + types.Args( + types.Named("search", types.NewAny( + types.S, + types.NewSet(types.S), + types.NewArray(nil, types.S), + )).Description("search string(s)"), + types.Named("base", types.NewAny( + types.S, + types.NewSet(types.S), + types.NewArray(nil, types.S), + )).Description("base string(s)"), + ), + types.Named("result", types.B).Description("result of the suffix check"), + ), + Categories: stringsCat, +} + +var Concat = &Builtin{ + Name: "concat", + Description: "Joins a set or array of strings with a delimiter.", + Decl: types.NewFunction( + types.Args( + types.Named("delimiter", types.S).Description("string to use as a delimiter"), + types.Named("collection", types.NewAny( + types.NewSet(types.S), + types.NewArray(nil, types.S), + )).Description("strings to join"), + ), + types.Named("output", types.S).Description("the joined string"), + ), + Categories: stringsCat, +} + +var FormatInt = &Builtin{ + Name: "format_int", + Description: "Returns the string representation of the number in the given base after rounding it down to an integer value.", + Decl: types.NewFunction( + types.Args( + types.Named("number", types.N).Description("number to format"), + types.Named("base", types.N).Description("base of number representation to use"), + ), + types.Named("output", types.S).Description("formatted number"), + ), + Categories: stringsCat, +} + +var IndexOf = &Builtin{ + Name: "indexof", + Description: "Returns the index of a substring contained inside a string.", + Decl: types.NewFunction( + types.Args( + types.Named("haystack", types.S).Description("string to search in"), + types.Named("needle", types.S).Description("substring to look for"), + ), + types.Named("output", types.N).Description("index of first occurrence, `-1` if not found"), + ), + Categories: stringsCat, +} + +var IndexOfN = &Builtin{ + Name: "indexof_n", + Description: "Returns a list of all the indexes of a substring contained inside a string.", + Decl: types.NewFunction( + types.Args( + types.Named("haystack", types.S).Description("string to search in"), + types.Named("needle", types.S).Description("substring to look for"), + ), + types.Named("output", types.NewArray(nil, types.N)).Description("all indices at which `needle` occurs in `haystack`, may be empty"), + ), + Categories: stringsCat, +} + +var Substring = &Builtin{ + Name: "substring", + Description: "Returns the portion of a string for a given `offset` and a `length`. If `length < 0`, `output` is the remainder of the string.", + Decl: types.NewFunction( + types.Args( + types.Named("value", types.S).Description("string to extract substring from"), + types.Named("offset", types.N).Description("offset, must be positive"), + types.Named("length", types.N).Description("length of the substring starting from `offset`"), + ), + types.Named("output", types.S).Description("substring of `value` from `offset`, of length `length`"), + ), + Categories: stringsCat, +} + +var Contains = &Builtin{ + Name: "contains", + Description: "Returns `true` if the search string is included in the base string", + Decl: types.NewFunction( + types.Args( + types.Named("haystack", types.S).Description("string to search in"), + types.Named("needle", types.S).Description("substring to look for"), + ), + types.Named("result", types.B).Description("result of the containment check"), + ), + Categories: stringsCat, +} + +var StringCount = &Builtin{ + Name: "strings.count", + Description: "Returns the number of non-overlapping instances of a substring in a string.", + Decl: types.NewFunction( + types.Args( + types.Named("search", types.S).Description("string to search in"), + types.Named("substring", types.S).Description("substring to look for"), + ), + types.Named("output", types.N).Description("count of occurrences, `0` if not found"), + ), + Categories: stringsCat, +} + +var StartsWith = &Builtin{ + Name: "startswith", + Description: "Returns true if the search string begins with the base string.", + Decl: types.NewFunction( + types.Args( + types.Named("search", types.S).Description("search string"), + types.Named("base", types.S).Description("base string"), + ), + types.Named("result", types.B).Description("result of the prefix check"), + ), + Categories: stringsCat, +} + +var EndsWith = &Builtin{ + Name: "endswith", + Description: "Returns true if the search string ends with the base string.", + Decl: types.NewFunction( + types.Args( + types.Named("search", types.S).Description("search string"), + types.Named("base", types.S).Description("base string"), + ), + types.Named("result", types.B).Description("result of the suffix check"), + ), + Categories: stringsCat, +} + +var Lower = &Builtin{ + Name: "lower", + Description: "Returns the input string but with all characters in lower-case.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string that is converted to lower-case"), + ), + types.Named("y", types.S).Description("lower-case of x"), + ), + Categories: stringsCat, +} + +var Upper = &Builtin{ + Name: "upper", + Description: "Returns the input string but with all characters in upper-case.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string that is converted to upper-case"), + ), + types.Named("y", types.S).Description("upper-case of x"), + ), + Categories: stringsCat, +} + +var Split = &Builtin{ + Name: "split", + Description: "Split returns an array containing elements of the input string split on a delimiter.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string that is split"), + types.Named("delimiter", types.S).Description("delimiter used for splitting"), + ), + types.Named("ys", types.NewArray(nil, types.S)).Description("split parts"), + ), + Categories: stringsCat, +} + +var Replace = &Builtin{ + Name: "replace", + Description: "Replace replaces all instances of a sub-string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string being processed"), + types.Named("old", types.S).Description("substring to replace"), + types.Named("new", types.S).Description("string to replace `old` with"), + ), + types.Named("y", types.S).Description("string with replaced substrings"), + ), + Categories: stringsCat, +} + +var ReplaceN = &Builtin{ + Name: "strings.replace_n", + Description: `Replaces a string from a list of old, new string pairs. +Replacements are performed in the order they appear in the target string, without overlapping matches. +The old string comparisons are done in argument order.`, + Decl: types.NewFunction( + types.Args( + types.Named("patterns", types.NewObject( + nil, + types.NewDynamicProperty( + types.S, + types.S)), + ).Description("replacement pairs"), + types.Named("value", types.S).Description("string to replace substring matches in"), + ), + types.Named("output", types.S).Description("string with replaced substrings"), + ), +} + +var RegexReplace = &Builtin{ + Name: "regex.replace", + Description: `Find and replaces the text using the regular expression pattern.`, + Decl: types.NewFunction( + types.Args( + types.Named("s", types.S).Description("string being processed"), + types.Named("pattern", types.S).Description("regex pattern to be applied"), + types.Named("value", types.S).Description("regex value"), + ), + types.Named("output", types.S).Description("string with replaced substrings"), + ), +} + +var Trim = &Builtin{ + Name: "trim", + Description: "Returns `value` with all leading or trailing instances of the `cutset` characters removed.", + Decl: types.NewFunction( + types.Args( + types.Named("value", types.S).Description("string to trim"), + types.Named("cutset", types.S).Description("string of characters that are cut off"), + ), + types.Named("output", types.S).Description("string trimmed of `cutset` characters"), + ), + Categories: stringsCat, +} + +var TrimLeft = &Builtin{ + Name: "trim_left", + Description: "Returns `value` with all leading instances of the `cutset` characters removed.", + Decl: types.NewFunction( + types.Args( + types.Named("value", types.S).Description("string to trim"), + types.Named("cutset", types.S).Description("string of characters that are cut off on the left"), + ), + types.Named("output", types.S).Description("string left-trimmed of `cutset` characters"), + ), + Categories: stringsCat, +} + +var TrimPrefix = &Builtin{ + Name: "trim_prefix", + Description: "Returns `value` without the prefix. If `value` doesn't start with `prefix`, it is returned unchanged.", + Decl: types.NewFunction( + types.Args( + types.Named("value", types.S).Description("string to trim"), + types.Named("prefix", types.S).Description("prefix to cut off"), + ), + types.Named("output", types.S).Description("string with `prefix` cut off"), + ), + Categories: stringsCat, +} + +var TrimRight = &Builtin{ + Name: "trim_right", + Description: "Returns `value` with all trailing instances of the `cutset` characters removed.", + Decl: types.NewFunction( + types.Args( + types.Named("value", types.S).Description("string to trim"), + types.Named("cutset", types.S).Description("string of characters that are cut off on the right"), + ), + types.Named("output", types.S).Description("string right-trimmed of `cutset` characters"), + ), + Categories: stringsCat, +} + +var TrimSuffix = &Builtin{ + Name: "trim_suffix", + Description: "Returns `value` without the suffix. If `value` doesn't end with `suffix`, it is returned unchanged.", + Decl: types.NewFunction( + types.Args( + types.Named("value", types.S).Description("string to trim"), + types.Named("suffix", types.S).Description("suffix to cut off"), + ), + types.Named("output", types.S).Description("string with `suffix` cut off"), + ), + Categories: stringsCat, +} + +var TrimSpace = &Builtin{ + Name: "trim_space", + Description: "Return the given string with all leading and trailing white space removed.", + Decl: types.NewFunction( + types.Args( + types.Named("value", types.S).Description("string to trim"), + ), + types.Named("output", types.S).Description("string leading and trailing white space cut off"), + ), + Categories: stringsCat, +} + +var Sprintf = &Builtin{ + Name: "sprintf", + Description: "Returns the given string, formatted.", + Decl: types.NewFunction( + types.Args( + types.Named("format", types.S).Description("string with formatting verbs"), + types.Named("values", types.NewArray(nil, types.A)).Description("arguments to format into formatting verbs"), + ), + types.Named("output", types.S).Description("`format` formatted by the values in `values`"), + ), + Categories: stringsCat, +} + +var StringReverse = &Builtin{ + Name: "strings.reverse", + Description: "Reverses a given string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string to reverse"), + ), + types.Named("y", types.S).Description("reversed string"), + ), + Categories: stringsCat, +} + +var RenderTemplate = &Builtin{ + Name: "strings.render_template", + Description: `Renders a templated string with given template variables injected. For a given templated string and key/value mapping, values will be injected into the template where they are referenced by key. + For examples of templating syntax, see https://pkg.go.dev/text/template`, + Decl: types.NewFunction( + types.Args( + types.Named("value", types.S).Description("a templated string"), + types.Named("vars", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("a mapping of template variable keys to values"), + ), + types.Named("result", types.S).Description("rendered template with template variables injected"), + ), + Categories: stringsCat, +} + +/** + * Numbers + */ + +// RandIntn returns a random number 0 - n +// Marked non-deterministic because it relies on RNG internally. +var RandIntn = &Builtin{ + Name: "rand.intn", + Description: "Returns a random integer between `0` and `n` (`n` exclusive). If `n` is `0`, then `y` is always `0`. For any given argument pair (`str`, `n`), the output will be consistent throughout a query evaluation.", + Decl: types.NewFunction( + types.Args( + types.Named("str", types.S).Description("seed string for the random number"), + types.Named("n", types.N).Description("upper bound of the random number (exclusive)"), + ), + types.Named("y", types.N).Description("random integer in the range `[0, abs(n))`"), + ), + Categories: number, + Nondeterministic: true, +} + +var NumbersRange = &Builtin{ + Name: "numbers.range", + Description: "Returns an array of numbers in the given (inclusive) range. If `a==b`, then `range == [a]`; if `a > b`, then `range` is in descending order.", + Decl: types.NewFunction( + types.Args( + types.Named("a", types.N).Description("the start of the range"), + types.Named("b", types.N).Description("the end of the range (inclusive)"), + ), + types.Named("range", types.NewArray(nil, types.N)).Description("the range between `a` and `b`"), + ), +} + +var NumbersRangeStep = &Builtin{ + Name: "numbers.range_step", + Description: `Returns an array of numbers in the given (inclusive) range incremented by a positive step. + If "a==b", then "range == [a]"; if "a > b", then "range" is in descending order. + If the provided "step" is less then 1, an error will be thrown. + If "b" is not in the range of the provided "step", "b" won't be included in the result. + `, + Decl: types.NewFunction( + types.Args( + types.Named("a", types.N).Description("the start of the range"), + types.Named("b", types.N).Description("the end of the range (inclusive)"), + types.Named("step", types.N).Description("the step between numbers in the range"), + ), + types.Named("range", types.NewArray(nil, types.N)).Description("the range between `a` and `b` in `step` increments"), + ), +} + +/** + * Units + */ + +var UnitsParse = &Builtin{ + Name: "units.parse", + Description: `Converts strings like "10G", "5K", "4M", "1500m", and the like into a number. +This number can be a non-integer, such as 1.5, 0.22, etc. Scientific notation is supported, +allowing values such as "1e-3K" (1) or "2.5e6M" (2.5 million M). + +Supports standard metric decimal and binary SI units (e.g., K, Ki, M, Mi, G, Gi, etc.) where +m, K, M, G, T, P, and E are treated as decimal units and Ki, Mi, Gi, Ti, Pi, and Ei are treated as +binary units. + +Note that 'm' and 'M' are case-sensitive to allow distinguishing between "milli" and "mega" units +respectively. Other units are case-insensitive.`, + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("the unit to parse"), + ), + types.Named("y", types.N).Description("the parsed number"), + ), +} + +var UnitsParseBytes = &Builtin{ + Name: "units.parse_bytes", + Description: `Converts strings like "10GB", "5K", "4mb", or "1e6KB" into an integer number of bytes. + +Supports standard byte units (e.g., KB, KiB, etc.) where KB, MB, GB, and TB are treated as decimal +units, and KiB, MiB, GiB, and TiB are treated as binary units. Scientific notation is supported, +enabling values like "1.5e3MB" (1500MB) or "2e6GiB" (2 million GiB). + +The bytes symbol (b/B) in the unit is optional; omitting it will yield the same result (e.g., "Mi" +and "MiB" are equivalent).`, + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("the byte unit to parse"), + ), + types.Named("y", types.N).Description("the parsed number"), + ), +} + +// +/** + * Type + */ + +// UUIDRFC4122 returns a version 4 UUID string. +// Marked non-deterministic because it relies on RNG internally. +var UUIDRFC4122 = &Builtin{ + Name: "uuid.rfc4122", + Description: "Returns a new UUIDv4.", + Decl: types.NewFunction( + types.Args( + types.Named("k", types.S).Description("seed string"), + ), + types.Named("output", types.S).Description("a version 4 UUID; for any given `k`, the output will be consistent throughout a query evaluation"), + ), + Nondeterministic: true, +} + +var UUIDParse = &Builtin{ + Name: "uuid.parse", + Description: "Parses the string value as an UUID and returns an object with the well-defined fields of the UUID if valid.", + Categories: nil, + Decl: types.NewFunction( + types.Args( + types.Named("uuid", types.S).Description("UUID string to parse"), + ), + types.Named("result", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("Properties of UUID if valid (version, variant, etc). Undefined otherwise."), + ), + Relation: false, +} + +/** + * JSON + */ + +var objectCat = category("object") + +var JSONFilter = &Builtin{ + Name: "json.filter", + Description: "Filters the object. " + + "For example: `json.filter({\"a\": {\"b\": \"x\", \"c\": \"y\"}}, [\"a/b\"])` will result in `{\"a\": {\"b\": \"x\"}}`). " + + "Paths are not filtered in-order and are deduplicated before being evaluated.", + Decl: types.NewFunction( + types.Args( + types.Named("object", types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + )).Description("object to filter"), + types.Named("paths", types.NewAny( + types.NewArray( + nil, + types.NewAny( + types.S, + types.NewArray( + nil, + types.A, + ), + ), + ), + types.NewSet( + types.NewAny( + types.S, + types.NewArray( + nil, + types.A, + ), + ), + ), + )).Description("JSON string paths"), + ), + types.Named("filtered", types.A).Description("remaining data from `object` with only keys specified in `paths`"), + ), + Categories: objectCat, +} + +var JSONRemove = &Builtin{ + Name: "json.remove", + Description: "Removes paths from an object. " + + "For example: `json.remove({\"a\": {\"b\": \"x\", \"c\": \"y\"}}, [\"a/b\"])` will result in `{\"a\": {\"c\": \"y\"}}`. " + + "Paths are not removed in-order and are deduplicated before being evaluated.", + Decl: types.NewFunction( + types.Args( + types.Named("object", types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + )).Description("object to remove paths from"), + types.Named("paths", types.NewAny( + types.NewArray( + nil, + types.NewAny( + types.S, + types.NewArray( + nil, + types.A, + ), + ), + ), + types.NewSet( + types.NewAny( + types.S, + types.NewArray( + nil, + types.A, + ), + ), + ), + )).Description("JSON string paths"), + ), + types.Named("output", types.A).Description("result of removing all keys specified in `paths`"), + ), + Categories: objectCat, +} + +var JSONPatch = &Builtin{ + Name: "json.patch", + Description: "Patches an object according to RFC6902. " + + "For example: `json.patch({\"a\": {\"foo\": 1}}, [{\"op\": \"add\", \"path\": \"/a/bar\", \"value\": 2}])` results in `{\"a\": {\"foo\": 1, \"bar\": 2}`. " + + "The patches are applied atomically: if any of them fails, the result will be undefined. " + + "Additionally works on sets, where a value contained in the set is considered to be its path.", + Decl: types.NewFunction( + types.Args( + types.Named("object", types.A).Description("the object to patch"), // TODO(sr): types.A? + types.Named("patches", types.NewArray( + nil, + types.NewObject( + []*types.StaticProperty{ + {Key: "op", Value: types.S}, + {Key: "path", Value: types.A}, + }, + types.NewDynamicProperty(types.A, types.A), + ), + )).Description("the JSON patches to apply"), + ), + types.Named("output", types.A).Description("result obtained after consecutively applying all patch operations in `patches`"), + ), + Categories: objectCat, +} + +var ObjectSubset = &Builtin{ + Name: "object.subset", + Description: "Determines if an object `sub` is a subset of another object `super`." + + "Object `sub` is a subset of object `super` if and only if every key in `sub` is also in `super`, " + + "**and** for all keys which `sub` and `super` share, they have the same value. " + + "This function works with objects, sets, arrays and a set of array and set." + + "If both arguments are objects, then the operation is recursive, e.g. " + + "`{\"c\": {\"x\": {10, 15, 20}}` is a subset of `{\"a\": \"b\", \"c\": {\"x\": {10, 15, 20, 25}, \"y\": \"z\"}`. " + + "If both arguments are sets, then this function checks if every element of `sub` is a member of `super`, " + + "but does not attempt to recurse. If both arguments are arrays, " + + "then this function checks if `sub` appears contiguously in order within `super`, " + + "and also does not attempt to recurse. If `super` is array and `sub` is set, " + + "then this function checks if `super` contains every element of `sub` with no consideration of ordering, " + + "and also does not attempt to recurse.", + Decl: types.NewFunction( + types.Args( + types.Named("super", types.NewAny(types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + ), types.NewSet(types.A), + types.NewArray(nil, types.A), + )).Description("object to test if sub is a subset of"), + types.Named("sub", types.NewAny(types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + ), types.NewSet(types.A), + types.NewArray(nil, types.A), + )).Description("object to test if super is a superset of"), + ), + types.Named("result", types.A).Description("`true` if `sub` is a subset of `super`"), + ), +} + +var ObjectUnion = &Builtin{ + Name: "object.union", + Description: "Creates a new object of the asymmetric union of two objects. " + + "For example: `object.union({\"a\": 1, \"b\": 2, \"c\": {\"d\": 3}}, {\"a\": 7, \"c\": {\"d\": 4, \"e\": 5}})` will result in `{\"a\": 7, \"b\": 2, \"c\": {\"d\": 4, \"e\": 5}}`.", + Decl: types.NewFunction( + types.Args( + types.Named("a", types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + )).Description("left-hand object"), + types.Named("b", types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + )).Description("right-hand object"), + ), + types.Named("output", types.A).Description("a new object which is the result of an asymmetric recursive union of two objects where conflicts are resolved by choosing the key from the right-hand object `b`"), + ), // TODO(sr): types.A? ^^^^^^^ (also below) +} + +var ObjectUnionN = &Builtin{ + Name: "object.union_n", + Description: "Creates a new object that is the asymmetric union of all objects merged from left to right. " + + "For example: `object.union_n([{\"a\": 1}, {\"b\": 2}, {\"a\": 3}])` will result in `{\"b\": 2, \"a\": 3}`.", + Decl: types.NewFunction( + types.Args( + types.Named("objects", types.NewArray( + nil, + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + )).Description("list of objects to merge"), + ), + types.Named("output", types.A).Description("asymmetric recursive union of all objects in `objects`, merged from left to right, where conflicts are resolved by choosing the key from the right-hand object"), + ), +} + +var ObjectRemove = &Builtin{ + Name: "object.remove", + Description: "Removes specified keys from an object.", + Decl: types.NewFunction( + types.Args( + types.Named("object", types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + )).Description("object to remove keys from"), + types.Named("keys", types.NewAny( + types.NewArray(nil, types.A), + types.NewSet(types.A), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + )).Description("keys to remove from x"), + ), + types.Named("output", types.A).Description("result of removing the specified `keys` from `object`"), + ), +} + +var ObjectFilter = &Builtin{ + Name: "object.filter", + Description: "Filters the object by keeping only specified keys. " + + "For example: `object.filter({\"a\": {\"b\": \"x\", \"c\": \"y\"}, \"d\": \"z\"}, [\"a\"])` will result in `{\"a\": {\"b\": \"x\", \"c\": \"y\"}}`).", + Decl: types.NewFunction( + types.Args( + types.Named("object", types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + )).Description("object to filter keys"), + types.Named("keys", types.NewAny( + types.NewArray(nil, types.A), + types.NewSet(types.A), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + )).Description("keys to keep in `object`"), + ), + types.Named("filtered", types.A).Description("remaining data from `object` with only keys specified in `keys`"), + ), +} + +var ObjectGet = &Builtin{ + Name: "object.get", + Description: "Returns value of an object's key if present, otherwise a default. " + + "If the supplied `key` is an `array`, then `object.get` will search through a nested object or array using each key in turn. " + + "For example: `object.get({\"a\": [{ \"b\": true }]}, [\"a\", 0, \"b\"], false)` results in `true`.", + Decl: types.NewFunction( + types.Args( + types.Named("object", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("object to get `key` from"), + types.Named("key", types.A).Description("key to lookup in `object`"), + types.Named("default", types.A).Description("default to use when lookup fails"), + ), + types.Named("value", types.A).Description("`object[key]` if present, otherwise `default`"), + ), +} + +var ObjectKeys = &Builtin{ + Name: "object.keys", + Description: "Returns a set of an object's keys. " + + "For example: `object.keys({\"a\": 1, \"b\": true, \"c\": \"d\")` results in `{\"a\", \"b\", \"c\"}`.", + Decl: types.NewFunction( + types.Args( + types.Named("object", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("object to get keys from"), + ), + types.Named("value", types.NewSet(types.A)).Description("set of `object`'s keys"), + ), +} + +/* + * Encoding + */ +var encoding = category("encoding") + +var JSONMarshal = &Builtin{ + Name: "json.marshal", + Description: "Serializes the input term to JSON.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("the term to serialize"), + ), + types.Named("y", types.S).Description("the JSON string representation of `x`"), + ), + Categories: encoding, +} + +var JSONMarshalWithOptions = &Builtin{ + Name: "json.marshal_with_options", + Description: "Serializes the input term JSON, with additional formatting options via the `opts` parameter. " + + "`opts` accepts keys `pretty` (enable multi-line/formatted JSON), `prefix` (string to prefix lines with, default empty string) and `indent` (string to indent with, default `\\t`).", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("the term to serialize"), + types.Named("opts", types.NewObject( + []*types.StaticProperty{ + types.NewStaticProperty("pretty", types.B), + types.NewStaticProperty("indent", types.S), + types.NewStaticProperty("prefix", types.S), + }, + types.NewDynamicProperty(types.S, types.A), + )).Description("encoding options"), + ), + types.Named("y", types.S).Description("the JSON string representation of `x`, with configured prefix/indent string(s) as appropriate"), + ), + Categories: encoding, +} + +var JSONUnmarshal = &Builtin{ + Name: "json.unmarshal", + Description: "Deserializes the input string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("a JSON string"), + ), + types.Named("y", types.A).Description("the term deserialized from `x`"), + ), + Categories: encoding, +} + +var JSONIsValid = &Builtin{ + Name: "json.is_valid", + Description: "Verifies the input string is a valid JSON document.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("a JSON string"), + ), + types.Named("result", types.B).Description("`true` if `x` is valid JSON, `false` otherwise"), + ), + Categories: encoding, +} + +var Base64Encode = &Builtin{ + Name: "base64.encode", + Description: "Serializes the input string into base64 encoding.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string to encode"), + ), + types.Named("y", types.S).Description("base64 serialization of `x`"), + ), + Categories: encoding, +} + +var Base64Decode = &Builtin{ + Name: "base64.decode", + Description: "Deserializes the base64 encoded input string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string to decode"), + ), + types.Named("y", types.S).Description("base64 deserialization of `x`"), + ), + Categories: encoding, +} + +var Base64IsValid = &Builtin{ + Name: "base64.is_valid", + Description: "Verifies the input string is base64 encoded.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string to check"), + ), + types.Named("result", types.B).Description("`true` if `x` is valid base64 encoded value, `false` otherwise"), + ), + Categories: encoding, +} + +var Base64UrlEncode = &Builtin{ + Name: "base64url.encode", + Description: "Serializes the input string into base64url encoding.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string to encode"), + ), + types.Named("y", types.S).Description("base64url serialization of `x`"), + ), + Categories: encoding, +} + +var Base64UrlEncodeNoPad = &Builtin{ + Name: "base64url.encode_no_pad", + Description: "Serializes the input string into base64url encoding without padding.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string to encode"), + ), + types.Named("y", types.S).Description("base64url serialization of `x`"), + ), + Categories: encoding, +} + +var Base64UrlDecode = &Builtin{ + Name: "base64url.decode", + Description: "Deserializes the base64url encoded input string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string to decode"), + ), + types.Named("y", types.S).Description("base64url deserialization of `x`"), + ), + Categories: encoding, +} + +var URLQueryDecode = &Builtin{ + Name: "urlquery.decode", + Description: "Decodes a URL-encoded input string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("the URL-encoded string"), + ), + types.Named("y", types.S).Description("URL-encoding deserialization of `x`"), + ), + Categories: encoding, +} + +var URLQueryEncode = &Builtin{ + Name: "urlquery.encode", + Description: "Encodes the input string into a URL-encoded string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("the string to encode"), + ), + types.Named("y", types.S).Description("URL-encoding serialization of `x`"), + ), + Categories: encoding, +} + +var URLQueryEncodeObject = &Builtin{ + Name: "urlquery.encode_object", + Description: "Encodes the given object into a URL encoded query string.", + Decl: types.NewFunction( + types.Args( + types.Named("object", types.NewObject( + nil, + types.NewDynamicProperty( + types.S, + types.NewAny( + types.S, + types.NewArray(nil, types.S), + types.NewSet(types.S)), + ), + ), + ).Description("the object to encode"), + ), + types.Named("y", types.S).Description("the URL-encoded serialization of `object`"), + ), + Categories: encoding, +} + +var URLQueryDecodeObject = &Builtin{ + Name: "urlquery.decode_object", + Description: "Decodes the given URL query string into an object.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("the query string"), + ), + types.Named("object", types.NewObject(nil, types.NewDynamicProperty( + types.S, + types.NewArray(nil, types.S)))).Description("the resulting object"), + ), + Categories: encoding, +} + +var YAMLMarshal = &Builtin{ + Name: "yaml.marshal", + Description: "Serializes the input term to YAML.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("the term to serialize"), + ), + types.Named("y", types.S).Description("the YAML string representation of `x`"), + ), + Categories: encoding, +} + +var YAMLUnmarshal = &Builtin{ + Name: "yaml.unmarshal", + Description: "Deserializes the input string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("a YAML string"), + ), + types.Named("y", types.A).Description("the term deserialized from `x`"), + ), + Categories: encoding, +} + +// YAMLIsValid verifies the input string is a valid YAML document. +var YAMLIsValid = &Builtin{ + Name: "yaml.is_valid", + Description: "Verifies the input string is a valid YAML document.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("a YAML string"), + ), + types.Named("result", types.B).Description("`true` if `x` is valid YAML, `false` otherwise"), + ), + Categories: encoding, +} + +var HexEncode = &Builtin{ + Name: "hex.encode", + Description: "Serializes the input string using hex-encoding.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("string to encode"), + ), + types.Named("y", types.S).Description("serialization of `x` using hex-encoding"), + ), + Categories: encoding, +} + +var HexDecode = &Builtin{ + Name: "hex.decode", + Description: "Deserializes the hex-encoded input string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("a hex-encoded string"), + ), + types.Named("y", types.S).Description("deserialized from `x`"), + ), + Categories: encoding, +} + +/** + * Tokens + */ +var tokensCat = category("tokens") + +var JWTDecode = &Builtin{ + Name: "io.jwt.decode", + Description: "Decodes a JSON Web Token and outputs it as an object.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token to decode"), + ), + types.Named("output", types.NewArray([]types.Type{ + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + types.S, + }, nil)).Description("`[header, payload, sig]`, where `header` and `payload` are objects; `sig` is the hexadecimal representation of the signature on the token."), + ), + Categories: tokensCat, +} + +var JWTVerifyRS256 = &Builtin{ + Name: "io.jwt.verify_rs256", + Description: "Verifies if a RS256 JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyRS384 = &Builtin{ + Name: "io.jwt.verify_rs384", + Description: "Verifies if a RS384 JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyRS512 = &Builtin{ + Name: "io.jwt.verify_rs512", + Description: "Verifies if a RS512 JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyPS256 = &Builtin{ + Name: "io.jwt.verify_ps256", + Description: "Verifies if a PS256 JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyPS384 = &Builtin{ + Name: "io.jwt.verify_ps384", + Description: "Verifies if a PS384 JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyPS512 = &Builtin{ + Name: "io.jwt.verify_ps512", + Description: "Verifies if a PS512 JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyES256 = &Builtin{ + Name: "io.jwt.verify_es256", + Description: "Verifies if a ES256 JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyES384 = &Builtin{ + Name: "io.jwt.verify_es384", + Description: "Verifies if a ES384 JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyES512 = &Builtin{ + Name: "io.jwt.verify_es512", + Description: "Verifies if a ES512 JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyHS256 = &Builtin{ + Name: "io.jwt.verify_hs256", + Description: "Verifies if a HS256 (secret) JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("secret", types.S).Description("plain text secret used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyHS384 = &Builtin{ + Name: "io.jwt.verify_hs384", + Description: "Verifies if a HS384 (secret) JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("secret", types.S).Description("plain text secret used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +var JWTVerifyHS512 = &Builtin{ + Name: "io.jwt.verify_hs512", + Description: "Verifies if a HS512 (secret) JWT signature is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("secret", types.S).Description("plain text secret used to verify the signature"), + ), + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), + ), + Categories: tokensCat, +} + +// Marked non-deterministic because it relies on time internally. +var JWTDecodeVerify = &Builtin{ + Name: "io.jwt.decode_verify", + Description: `Verifies a JWT signature under parameterized constraints and decodes the claims if it is valid. +Supports the following algorithms: HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384 and PS512.`, + Decl: types.NewFunction( + types.Args( + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified and whose claims are to be checked"), + types.Named("constraints", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("claim verification constraints"), + ), + types.Named("output", types.NewArray([]types.Type{ + types.B, + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + }, nil)).Description("`[valid, header, payload]`: if the input token is verified and meets the requirements of `constraints` then `valid` is `true`; `header` and `payload` are objects containing the JOSE header and the JWT claim set; otherwise, `valid` is `false`, `header` and `payload` are `{}`"), + ), + Categories: tokensCat, + Nondeterministic: true, +} + +var tokenSign = category("tokensign") + +// Marked non-deterministic because it relies on RNG internally. +var JWTEncodeSignRaw = &Builtin{ + Name: "io.jwt.encode_sign_raw", + Description: "Encodes and optionally signs a JSON Web Token.", + Decl: types.NewFunction( + types.Args( + types.Named("headers", types.S).Description("JWS Protected Header"), + types.Named("payload", types.S).Description("JWS Payload"), + types.Named("key", types.S).Description("JSON Web Key (RFC7517)"), + ), + types.Named("output", types.S).Description("signed JWT"), + ), + Categories: tokenSign, + Nondeterministic: true, +} + +// Marked non-deterministic because it relies on RNG internally. +var JWTEncodeSign = &Builtin{ + Name: "io.jwt.encode_sign", + Description: "Encodes and optionally signs a JSON Web Token. Inputs are taken as objects, not encoded strings (see `io.jwt.encode_sign_raw`).", + Decl: types.NewFunction( + types.Args( + types.Named("headers", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JWS Protected Header"), + types.Named("payload", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JWS Payload"), + types.Named("key", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JSON Web Key (RFC7517)"), + ), + types.Named("output", types.S).Description("signed JWT"), + ), + Categories: tokenSign, + Nondeterministic: true, +} + +/** + * Time + */ + +// Marked non-deterministic because it relies on time directly. +var NowNanos = &Builtin{ + Name: "time.now_ns", + Description: "Returns the current time since epoch in nanoseconds.", + Decl: types.NewFunction( + nil, + types.Named("now", types.N).Description("nanoseconds since epoch"), + ), + Nondeterministic: true, +} + +var ParseNanos = &Builtin{ + Name: "time.parse_ns", + Description: "Returns the time in nanoseconds parsed from the string in the given format. `undefined` if the result would be outside the valid time range that can fit within an `int64`.", + Decl: types.NewFunction( + types.Args( + types.Named("layout", types.S).Description("format used for parsing, see the [Go `time` package documentation](https://golang.org/pkg/time/#Parse) for more details"), + types.Named("value", types.S).Description("input to parse according to `layout`"), + ), + types.Named("ns", types.N).Description("`value` in nanoseconds since epoch"), + ), +} + +var ParseRFC3339Nanos = &Builtin{ + Name: "time.parse_rfc3339_ns", + Description: "Returns the time in nanoseconds parsed from the string in RFC3339 format. `undefined` if the result would be outside the valid time range that can fit within an `int64`.", + Decl: types.NewFunction( + types.Args( + types.Named("value", types.S).Description("input string to parse in RFC3339 format"), + ), + types.Named("ns", types.N).Description("`value` in nanoseconds since epoch"), + ), +} + +var ParseDurationNanos = &Builtin{ + Name: "time.parse_duration_ns", + Description: "Returns the duration in nanoseconds represented by a string.", + Decl: types.NewFunction( + types.Args( + types.Named("duration", types.S).Description("a duration like \"3m\"; see the [Go `time` package documentation](https://golang.org/pkg/time/#ParseDuration) for more details"), + ), + types.Named("ns", types.N).Description("the `duration` in nanoseconds"), + ), +} + +var Format = &Builtin{ + Name: "time.format", + Description: "Returns the formatted timestamp for the nanoseconds since epoch.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewAny( + types.N, + types.NewArray([]types.Type{types.N, types.S}, nil), + types.NewArray([]types.Type{types.N, types.S, types.S}, nil), + )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string; or a three-element array of ns, timezone string and a layout string or golang defined formatting constant (see golang supported time formats)"), + ), + types.Named("formatted timestamp", types.S).Description("the formatted timestamp represented for the nanoseconds since the epoch in the supplied timezone (or UTC)"), + ), +} + +var Date = &Builtin{ + Name: "time.date", + Description: "Returns the `[year, month, day]` for the nanoseconds since epoch.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewAny( + types.N, + types.NewArray([]types.Type{types.N, types.S}, nil), + )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string"), + ), + types.Named("date", types.NewArray([]types.Type{types.N, types.N, types.N}, nil)).Description("an array of `year`, `month` (1-12), and `day` (1-31)"), + ), +} + +var Clock = &Builtin{ + Name: "time.clock", + Description: "Returns the `[hour, minute, second]` of the day for the nanoseconds since epoch.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewAny( + types.N, + types.NewArray([]types.Type{types.N, types.S}, nil), + )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string"), + ), + types.Named("output", types.NewArray([]types.Type{types.N, types.N, types.N}, nil)). + Description("the `hour`, `minute` (0-59), and `second` (0-59) representing the time of day for the nanoseconds since epoch in the supplied timezone (or UTC)"), + ), +} + +var Weekday = &Builtin{ + Name: "time.weekday", + Description: "Returns the day of the week (Monday, Tuesday, ...) for the nanoseconds since epoch.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewAny( + types.N, + types.NewArray([]types.Type{types.N, types.S}, nil), + )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string"), + ), + types.Named("day", types.S).Description("the weekday represented by `ns` nanoseconds since the epoch in the supplied timezone (or UTC)"), + ), +} + +var AddDate = &Builtin{ + Name: "time.add_date", + Description: "Returns the nanoseconds since epoch after adding years, months and days to nanoseconds. Month & day values outside their usual ranges after the operation and will be normalized - for example, October 32 would become November 1. `undefined` if the result would be outside the valid time range that can fit within an `int64`.", + Decl: types.NewFunction( + types.Args( + types.Named("ns", types.N).Description("nanoseconds since the epoch"), + types.Named("years", types.N).Description("number of years to add"), + types.Named("months", types.N).Description("number of months to add"), + types.Named("days", types.N).Description("number of days to add"), + ), + types.Named("output", types.N).Description("nanoseconds since the epoch representing the input time, with years, months and days added"), + ), +} + +var Diff = &Builtin{ + Name: "time.diff", + Description: "Returns the difference between two unix timestamps in nanoseconds (with optional timezone strings).", + Decl: types.NewFunction( + types.Args( + types.Named("ns1", types.NewAny( + types.N, + types.NewArray([]types.Type{types.N, types.S}, nil), + )).Description("nanoseconds since the epoch; or a two-element array of the nanoseconds, and a timezone string"), + types.Named("ns2", types.NewAny( + types.N, + types.NewArray([]types.Type{types.N, types.S}, nil), + )).Description("nanoseconds since the epoch; or a two-element array of the nanoseconds, and a timezone string"), + ), + types.Named("output", types.NewArray([]types.Type{types.N, types.N, types.N, types.N, types.N, types.N}, nil)).Description("difference between `ns1` and `ns2` (in their supplied timezones, if supplied, or UTC) as array of numbers: `[years, months, days, hours, minutes, seconds]`"), + ), +} + +/** + * Crypto. + */ + +var CryptoX509ParseCertificates = &Builtin{ + Name: "crypto.x509.parse_certificates", + Description: `Returns zero or more certificates from the given encoded string containing +DER certificate data. + +If the input is empty, the function will return null. The input string should be a list of one or more +concatenated PEM blocks. The whole input of concatenated PEM blocks can optionally be Base64 encoded.`, + Decl: types.NewFunction( + types.Args( + types.Named("certs", types.S).Description("base64 encoded DER or PEM data containing one or more certificates or a PEM string of one or more certificates"), + ), + types.Named("output", types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)))).Description("parsed X.509 certificates represented as objects"), + ), +} + +var CryptoX509ParseAndVerifyCertificates = &Builtin{ + Name: "crypto.x509.parse_and_verify_certificates", + Description: `Returns one or more certificates from the given string containing PEM +or base64 encoded DER certificates after verifying the supplied certificates form a complete +certificate chain back to a trusted root. + +The first certificate is treated as the root and the last is treated as the leaf, +with all others being treated as intermediates.`, + Decl: types.NewFunction( + types.Args( + types.Named("certs", types.S).Description("base64 encoded DER or PEM data containing two or more certificates where the first is a root CA, the last is a leaf certificate, and all others are intermediate CAs"), + ), + types.Named("output", types.NewArray([]types.Type{ + types.B, + types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), + }, nil)).Description("array of `[valid, certs]`: if the input certificate chain could be verified then `valid` is `true` and `certs` is an array of X.509 certificates represented as objects; if the input certificate chain could not be verified then `valid` is `false` and `certs` is `[]`"), + ), +} + +var CryptoX509ParseAndVerifyCertificatesWithOptions = &Builtin{ + Name: "crypto.x509.parse_and_verify_certificates_with_options", + Description: `Returns one or more certificates from the given string containing PEM +or base64 encoded DER certificates after verifying the supplied certificates form a complete +certificate chain back to a trusted root. A config option passed as the second argument can +be used to configure the validation options used. + +The first certificate is treated as the root and the last is treated as the leaf, +with all others being treated as intermediates.`, + + Decl: types.NewFunction( + types.Args( + types.Named("certs", types.S).Description("base64 encoded DER or PEM data containing two or more certificates where the first is a root CA, the last is a leaf certificate, and all others are intermediate CAs"), + types.Named("options", types.NewObject( + nil, + types.NewDynamicProperty(types.S, types.A), + )).Description("object containing extra configs to verify the validity of certificates. `options` object supports four fields which maps to same fields in [x509.VerifyOptions struct](https://pkg.go.dev/crypto/x509#VerifyOptions). `DNSName`, `CurrentTime`: Nanoseconds since the Unix Epoch as a number, `MaxConstraintComparisons` and `KeyUsages`. `KeyUsages` is list and can have possible values as in: `\"KeyUsageAny\"`, `\"KeyUsageServerAuth\"`, `\"KeyUsageClientAuth\"`, `\"KeyUsageCodeSigning\"`, `\"KeyUsageEmailProtection\"`, `\"KeyUsageIPSECEndSystem\"`, `\"KeyUsageIPSECTunnel\"`, `\"KeyUsageIPSECUser\"`, `\"KeyUsageTimeStamping\"`, `\"KeyUsageOCSPSigning\"`, `\"KeyUsageMicrosoftServerGatedCrypto\"`, `\"KeyUsageNetscapeServerGatedCrypto\"`, `\"KeyUsageMicrosoftCommercialCodeSigning\"`, `\"KeyUsageMicrosoftKernelCodeSigning\"` "), + ), + types.Named("output", types.NewArray([]types.Type{ + types.B, + types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), + }, nil)).Description("array of `[valid, certs]`: if the input certificate chain could be verified then `valid` is `true` and `certs` is an array of X.509 certificates represented as objects; if the input certificate chain could not be verified then `valid` is `false` and `certs` is `[]`"), + ), +} + +var CryptoX509ParseCertificateRequest = &Builtin{ + Name: "crypto.x509.parse_certificate_request", + Description: "Returns a PKCS #10 certificate signing request from the given PEM-encoded PKCS#10 certificate signing request.", + Decl: types.NewFunction( + types.Args( + types.Named("csr", types.S).Description("base64 string containing either a PEM encoded or DER CSR or a string containing a PEM CSR"), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("X.509 CSR represented as an object"), + ), +} + +var CryptoX509ParseKeyPair = &Builtin{ + Name: "crypto.x509.parse_keypair", + Description: "Returns a valid key pair", + Decl: types.NewFunction( + types.Args( + types.Named("cert", types.S).Description("string containing PEM or base64 encoded DER certificates"), + types.Named("pem", types.S).Description("string containing PEM or base64 encoded DER keys"), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("if key pair is valid, returns the tls.certificate(https://pkg.go.dev/crypto/tls#Certificate) as an object. If the key pair is invalid, nil and an error are returned."), + ), +} +var CryptoX509ParseRSAPrivateKey = &Builtin{ + Name: "crypto.x509.parse_rsa_private_key", + Description: "Returns a JWK for signing a JWT from the given PEM-encoded RSA private key.", + Decl: types.NewFunction( + types.Args( + types.Named("pem", types.S).Description("base64 string containing a PEM encoded RSA private key"), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JWK as an object"), + ), +} + +var CryptoParsePrivateKeys = &Builtin{ + Name: "crypto.parse_private_keys", + Description: `Returns zero or more private keys from the given encoded string containing DER certificate data. + +If the input is empty, the function will return null. The input string should be a list of one or more concatenated PEM blocks. The whole input of concatenated PEM blocks can optionally be Base64 encoded.`, + Decl: types.NewFunction( + types.Args( + types.Named("keys", types.S).Description("PEM encoded data containing one or more private keys as concatenated blocks. Optionally Base64 encoded."), + ), + types.Named("output", types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)))).Description("parsed private keys represented as objects"), + ), +} + +var CryptoMd5 = &Builtin{ + Name: "crypto.md5", + Description: "Returns a string representing the input string hashed with the MD5 function", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("input string"), + ), + types.Named("y", types.S).Description("MD5-hash of `x`"), + ), +} + +var CryptoSha1 = &Builtin{ + Name: "crypto.sha1", + Description: "Returns a string representing the input string hashed with the SHA1 function", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("input string"), + ), + types.Named("y", types.S).Description("SHA1-hash of `x`"), + ), +} + +var CryptoSha256 = &Builtin{ + Name: "crypto.sha256", + Description: "Returns a string representing the input string hashed with the SHA256 function", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("input string"), + ), + types.Named("y", types.S).Description("SHA256-hash of `x`"), + ), +} + +var CryptoHmacMd5 = &Builtin{ + Name: "crypto.hmac.md5", + Description: "Returns a string representing the MD5 HMAC of the input message using the input key.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("input string"), + types.Named("key", types.S).Description("key to use"), + ), + types.Named("y", types.S).Description("MD5-HMAC of `x`"), + ), +} + +var CryptoHmacSha1 = &Builtin{ + Name: "crypto.hmac.sha1", + Description: "Returns a string representing the SHA1 HMAC of the input message using the input key.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("input string"), + types.Named("key", types.S).Description("key to use"), + ), + types.Named("y", types.S).Description("SHA1-HMAC of `x`"), + ), +} + +var CryptoHmacSha256 = &Builtin{ + Name: "crypto.hmac.sha256", + Description: "Returns a string representing the SHA256 HMAC of the input message using the input key.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("input string"), + types.Named("key", types.S).Description("key to use"), + ), + types.Named("y", types.S).Description("SHA256-HMAC of `x`"), + ), +} + +var CryptoHmacSha512 = &Builtin{ + Name: "crypto.hmac.sha512", + Description: "Returns a string representing the SHA512 HMAC of the input message using the input key.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.S).Description("input string"), + types.Named("key", types.S).Description("key to use"), + ), + types.Named("y", types.S).Description("SHA512-HMAC of `x`"), + ), +} + +var CryptoHmacEqual = &Builtin{ + Name: "crypto.hmac.equal", + Description: "Returns a boolean representing the result of comparing two MACs for equality without leaking timing information.", + Decl: types.NewFunction( + types.Args( + types.Named("mac1", types.S).Description("mac1 to compare"), + types.Named("mac2", types.S).Description("mac2 to compare"), + ), + types.Named("result", types.B).Description("`true` if the MACs are equals, `false` otherwise"), + ), +} + +/** + * Graphs. + */ +var graphs = category("graph") + +var WalkBuiltin = &Builtin{ + Name: "walk", + Relation: true, + Description: "Generates `[path, value]` tuples for all nested documents of `x` (recursively). Queries can use `walk` to traverse documents nested under `x`.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("value to walk"), + ), + types.Named("output", types.NewArray( + []types.Type{ + types.NewArray(nil, types.A), + types.A, + }, + nil, + )).Description("pairs of `path` and `value`: `path` is an array representing the pointer to `value` in `x`. If `path` is assigned a wildcard (`_`), the `walk` function will skip path creation entirely for faster evaluation."), + ), + Categories: graphs, +} + +var ReachableBuiltin = &Builtin{ + Name: "graph.reachable", + Description: "Computes the set of reachable nodes in the graph from a set of starting nodes.", + Decl: types.NewFunction( + types.Args( + types.Named("graph", types.NewObject( + nil, + types.NewDynamicProperty( + types.A, + types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A)), + )), + ).Description("object containing a set or array of neighboring vertices"), + types.Named("initial", types.NewAny(types.NewSet(types.A), types.NewArray(nil, types.A))).Description("set or array of root vertices"), + ), + types.Named("output", types.NewSet(types.A)).Description("set of vertices reachable from the `initial` vertices in the directed `graph`"), + ), +} + +var ReachablePathsBuiltin = &Builtin{ + Name: "graph.reachable_paths", + Description: "Computes the set of reachable paths in the graph from a set of starting nodes.", + Decl: types.NewFunction( + types.Args( + types.Named("graph", types.NewObject( + nil, + types.NewDynamicProperty( + types.A, + types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A)), + )), + ).Description("object containing a set or array of root vertices"), + types.Named("initial", types.NewAny(types.NewSet(types.A), types.NewArray(nil, types.A))).Description("initial paths"), // TODO(sr): copied. is that correct? + ), + types.Named("output", types.NewSet(types.NewArray(nil, types.A))).Description("paths reachable from the `initial` vertices in the directed `graph`"), + ), +} + +/** + * Type + */ +var typesCat = category("types") + +var IsNumber = &Builtin{ + Name: "is_number", + Description: "Returns `true` if the input value is a number.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("input value"), + ), + types.Named("result", types.B).Description("`true` if `x` is a number, `false` otherwise."), + ), + Categories: typesCat, +} + +var IsString = &Builtin{ + Name: "is_string", + Description: "Returns `true` if the input value is a string.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("input value"), + ), + types.Named("result", types.B).Description("`true` if `x` is a string, `false` otherwise."), + ), + Categories: typesCat, +} + +var IsBoolean = &Builtin{ + Name: "is_boolean", + Description: "Returns `true` if the input value is a boolean.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("input value"), + ), + types.Named("result", types.B).Description("`true` if `x` is an boolean, `false` otherwise."), + ), + Categories: typesCat, +} + +var IsArray = &Builtin{ + Name: "is_array", + Description: "Returns `true` if the input value is an array.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("input value"), + ), + types.Named("result", types.B).Description("`true` if `x` is an array, `false` otherwise."), + ), + Categories: typesCat, +} + +var IsSet = &Builtin{ + Name: "is_set", + Description: "Returns `true` if the input value is a set.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("input value"), + ), + types.Named("result", types.B).Description("`true` if `x` is a set, `false` otherwise."), + ), + Categories: typesCat, +} + +var IsObject = &Builtin{ + Name: "is_object", + Description: "Returns true if the input value is an object", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("input value"), + ), + types.Named("result", types.B).Description("`true` if `x` is an object, `false` otherwise."), + ), + Categories: typesCat, +} + +var IsNull = &Builtin{ + Name: "is_null", + Description: "Returns `true` if the input value is null.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("input value"), + ), + types.Named("result", types.B).Description("`true` if `x` is null, `false` otherwise."), + ), + Categories: typesCat, +} + +/** + * Type Name + */ + +// TypeNameBuiltin returns the type of the input. +var TypeNameBuiltin = &Builtin{ + Name: "type_name", + Description: "Returns the type of its input value.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.A).Description("input value"), + ), + types.Named("type", types.S).Description(`one of "null", "boolean", "number", "string", "array", "object", "set"`), + ), + Categories: typesCat, +} + +/** + * HTTP Request + */ + +// Marked non-deterministic because HTTP request results can be non-deterministic. +var HTTPSend = &Builtin{ + Name: "http.send", + Description: "Returns a HTTP response to the given HTTP request.", + Decl: types.NewFunction( + types.Args( + types.Named("request", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))). + Description("the HTTP request object"), + ), + types.Named("response", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))). + Description("the HTTP response object"), + ), + Nondeterministic: true, +} + +/** + * GraphQL + */ + +// GraphQLParse returns a pair of AST objects from parsing/validation. +var GraphQLParse = &Builtin{ + Name: "graphql.parse", + Description: "Returns AST objects for a given GraphQL query and schema after validating the query against the schema. Returns undefined if errors were encountered during parsing or validation. The query and/or schema can be either GraphQL strings or AST objects from the other GraphQL builtin functions.", + Decl: types.NewFunction( + types.Args( + types.Named("query", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("the GraphQL query"), + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("the GraphQL schema"), + ), + types.Named("output", types.NewArray([]types.Type{ + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + }, nil)).Description("`output` is of the form `[query_ast, schema_ast]`. If the GraphQL query is valid given the provided schema, then `query_ast` and `schema_ast` are objects describing the ASTs for the query and schema."), + ), +} + +// GraphQLParseAndVerify returns a boolean and a pair of AST object from parsing/validation. +var GraphQLParseAndVerify = &Builtin{ + Name: "graphql.parse_and_verify", + Description: "Returns a boolean indicating success or failure alongside the parsed ASTs for a given GraphQL query and schema after validating the query against the schema. The query and/or schema can be either GraphQL strings or AST objects from the other GraphQL builtin functions.", + Decl: types.NewFunction( + types.Args( + types.Named("query", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("the GraphQL query"), + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("the GraphQL schema"), + ), + types.Named("output", types.NewArray([]types.Type{ + types.B, + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + }, nil)).Description(" `output` is of the form `[valid, query_ast, schema_ast]`. If the query is valid given the provided schema, then `valid` is `true`, and `query_ast` and `schema_ast` are objects describing the ASTs for the GraphQL query and schema. Otherwise, `valid` is `false` and `query_ast` and `schema_ast` are `{}`."), + ), +} + +// GraphQLParseQuery parses the input GraphQL query and returns a JSON +// representation of its AST. +var GraphQLParseQuery = &Builtin{ + Name: "graphql.parse_query", + Description: "Returns an AST object for a GraphQL query.", + Decl: types.NewFunction( + types.Args( + types.Named("query", types.S).Description("GraphQL query string"), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("AST object for the GraphQL query."), + ), +} + +// GraphQLParseSchema parses the input GraphQL schema and returns a JSON +// representation of its AST. +var GraphQLParseSchema = &Builtin{ + Name: "graphql.parse_schema", + Description: "Returns an AST object for a GraphQL schema.", + Decl: types.NewFunction( + types.Args( + types.Named("schema", types.S).Description("GraphQL schema string"), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("AST object for the GraphQL schema."), + ), +} + +// GraphQLIsValid returns true if a GraphQL query is valid with a given +// schema, and returns false for all other inputs. +var GraphQLIsValid = &Builtin{ + Name: "graphql.is_valid", + Description: "Checks that a GraphQL query is valid against a given schema. The query and/or schema can be either GraphQL strings or AST objects from the other GraphQL builtin functions.", + Decl: types.NewFunction( + types.Args( + types.Named("query", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("the GraphQL query"), + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("the GraphQL schema"), + ), + types.Named("output", types.B).Description("`true` if the query is valid under the given schema. `false` otherwise."), + ), +} + +// GraphQLSchemaIsValid returns true if the input is valid GraphQL schema, +// and returns false for all other inputs. +var GraphQLSchemaIsValid = &Builtin{ + Name: "graphql.schema_is_valid", + Description: "Checks that the input is a valid GraphQL schema. The schema can be either a GraphQL string or an AST object from the other GraphQL builtin functions.", + Decl: types.NewFunction( + types.Args( + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("the schema to verify"), + ), + types.Named("output", types.B).Description("`true` if the schema is a valid GraphQL schema. `false` otherwise."), + ), +} + +/** + * JSON Schema + */ + +// JSONSchemaVerify returns empty string if the input is valid JSON schema +// and returns error string for all other inputs. +var JSONSchemaVerify = &Builtin{ + Name: "json.verify_schema", + Description: "Checks that the input is a valid JSON schema object. The schema can be either a JSON string or an JSON object.", + Decl: types.NewFunction( + types.Args( + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("the schema to verify"), + ), + types.Named("output", types.NewArray([]types.Type{ + types.B, + types.NewAny(types.S, types.Null{}), + }, nil)). + Description("`output` is of the form `[valid, error]`. If the schema is valid, then `valid` is `true`, and `error` is `null`. Otherwise, `valid` is `false` and `error` is a string describing the error."), + ), + Categories: objectCat, +} + +// JSONMatchSchema returns empty array if the document matches the JSON schema, +// and returns non-empty array with error objects otherwise. +var JSONMatchSchema = &Builtin{ + Name: "json.match_schema", + Description: "Checks that the document matches the JSON schema.", + Decl: types.NewFunction( + types.Args( + types.Named("document", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("document to verify by schema"), + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("schema to verify document by"), + ), + types.Named("output", types.NewArray([]types.Type{ + types.B, + types.NewArray( + nil, types.NewObject( + []*types.StaticProperty{ + {Key: "error", Value: types.S}, + {Key: "type", Value: types.S}, + {Key: "field", Value: types.S}, + {Key: "desc", Value: types.S}, + }, + nil, + ), + ), + }, nil)). + Description("`output` is of the form `[match, errors]`. If the document is valid given the schema, then `match` is `true`, and `errors` is an empty array. Otherwise, `match` is `false` and `errors` is an array of objects describing the error(s)."), + ), + Categories: objectCat, +} + +/** + * Cloud Provider Helper Functions + */ +var providersAWSCat = category("providers.aws") + +var ProvidersAWSSignReqObj = &Builtin{ + Name: "providers.aws.sign_req", + Description: "Signs an HTTP request object for Amazon Web Services. Currently implements [AWS Signature Version 4 request signing](https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-authenticating-requests.html) by the `Authorization` header method.", + Decl: types.NewFunction( + types.Args( + types.Named("request", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))). + Description("HTTP request object"), + types.Named("aws_config", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))). + Description("AWS configuration object"), + types.Named("time_ns", types.N).Description("nanoseconds since the epoch"), + ), + types.Named("signed_request", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))). + Description("HTTP request object with `Authorization` header"), + ), + Categories: providersAWSCat, +} + +/** + * Rego + */ + +var RegoParseModule = &Builtin{ + Name: "rego.parse_module", + Description: "Parses the input Rego string and returns an object representation of the AST.", + Decl: types.NewFunction( + types.Args( + types.Named("filename", types.S).Description("file name to attach to AST nodes' locations"), + types.Named("rego", types.S).Description("Rego module"), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))). + Description("AST object for the Rego module"), + ), +} + +var RegoMetadataChain = &Builtin{ + Name: "rego.metadata.chain", + Description: `Returns the chain of metadata for the active rule. +Ordered starting at the active rule, going outward to the most distant node in its package ancestry. +A chain entry is a JSON document with two members: "path", an array representing the path of the node; and "annotations", a JSON document containing the annotations declared for the node. +The first entry in the chain always points to the active rule, even if it has no declared annotations (in which case the "annotations" member is not present).`, + Decl: types.NewFunction( + types.Args(), + types.Named("chain", types.NewArray(nil, types.A)).Description("each array entry represents a node in the path ancestry (chain) of the active rule that also has declared annotations"), + ), +} + +// RegoMetadataRule returns the metadata for the active rule +var RegoMetadataRule = &Builtin{ + Name: "rego.metadata.rule", + Description: "Returns annotations declared for the active rule and using the _rule_ scope.", + Decl: types.NewFunction( + types.Args(), + types.Named("output", types.A).Description("\"rule\" scope annotations for this rule; empty object if no annotations exist"), + ), +} + +/** + * OPA + */ + +// Marked non-deterministic because of unpredictable config/environment-dependent results. +var OPARuntime = &Builtin{ + Name: "opa.runtime", + Description: "Returns an object that describes the runtime environment where OPA is deployed.", + Decl: types.NewFunction( + nil, + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))). + Description("includes a `config` key if OPA was started with a configuration file; an `env` key containing the environment variables that the OPA process was started with; includes `version` and `commit` keys containing the version and build commit of OPA."), + ), + Nondeterministic: true, +} + +/** + * Trace + */ +var tracing = category("tracing") + +var Trace = &Builtin{ + Name: "trace", + Description: "Emits `note` as a `Note` event in the query explanation. Query explanations show the exact expressions evaluated by OPA during policy execution. For example, `trace(\"Hello There!\")` includes `Note \"Hello There!\"` in the query explanation. To include variables in the message, use `sprintf`. For example, `person := \"Bob\"; trace(sprintf(\"Hello There! %v\", [person]))` will emit `Note \"Hello There! Bob\"` inside of the explanation.", + Decl: types.NewFunction( + types.Args( + types.Named("note", types.S).Description("the note to include"), + ), + types.Named("result", types.B).Description("always `true`"), + ), + Categories: tracing, +} + +/** + * Glob + */ + +var GlobMatch = &Builtin{ + Name: "glob.match", + Description: "Parses and matches strings against the glob notation. Not to be confused with `regex.globs_match`.", + Decl: types.NewFunction( + types.Args( + types.Named("pattern", types.S).Description("glob pattern"), + types.Named("delimiters", types.NewAny( + types.NewArray(nil, types.S), + types.NewNull(), + )).Description("glob pattern delimiters, e.g. `[\".\", \":\"]`, defaults to `[\".\"]` if unset. If `delimiters` is `null`, glob match without delimiter."), + types.Named("match", types.S).Description("string to match against `pattern`"), + ), + types.Named("result", types.B).Description("true if `match` can be found in `pattern` which is separated by `delimiters`"), + ), +} + +var GlobQuoteMeta = &Builtin{ + Name: "glob.quote_meta", + Description: "Returns a string which represents a version of the pattern where all asterisks have been escaped.", + Decl: types.NewFunction( + types.Args( + types.Named("pattern", types.S).Description("glob pattern"), + ), + types.Named("output", types.S).Description("the escaped string of `pattern`"), + ), + // TODO(sr): example for this was: Calling ``glob.quote_meta("*.github.com", output)`` returns ``\\*.github.com`` as ``output``. +} + +/** + * Networking + */ + +var NetCIDRIntersects = &Builtin{ + Name: "net.cidr_intersects", + Description: "Checks if a CIDR intersects with another CIDR (e.g. `192.168.0.0/16` overlaps with `192.168.1.0/24`). Supports both IPv4 and IPv6 notations.", + Decl: types.NewFunction( + types.Args( + types.Named("cidr1", types.S).Description("first CIDR"), + types.Named("cidr2", types.S).Description("second CIDR"), + ), + types.Named("result", types.B).Description("`true` if `cidr1` intersects with `cidr2`"), + ), +} + +var NetCIDRExpand = &Builtin{ + Name: "net.cidr_expand", + Description: "Expands CIDR to set of hosts (e.g., `net.cidr_expand(\"192.168.0.0/30\")` generates 4 hosts: `{\"192.168.0.0\", \"192.168.0.1\", \"192.168.0.2\", \"192.168.0.3\"}`).", + Decl: types.NewFunction( + types.Args( + types.Named("cidr", types.S).Description("CIDR to expand"), + ), + types.Named("hosts", types.NewSet(types.S)).Description("set of IP addresses the CIDR `cidr` expands to"), + ), +} + +var NetCIDRContains = &Builtin{ + Name: "net.cidr_contains", + Description: "Checks if a CIDR or IP is contained within another CIDR. `output` is `true` if `cidr_or_ip` (e.g. `127.0.0.64/26` or `127.0.0.1`) is contained within `cidr` (e.g. `127.0.0.1/24`) and `false` otherwise. Supports both IPv4 and IPv6 notations.", + Decl: types.NewFunction( + types.Args( + types.Named("cidr", types.S).Description("CIDR to check against"), + types.Named("cidr_or_ip", types.S).Description("CIDR or IP to check"), + ), + types.Named("result", types.B).Description("`true` if `cidr_or_ip` is contained within `cidr`"), + ), +} + +var NetCIDRContainsMatches = &Builtin{ + Name: "net.cidr_contains_matches", + Description: "Checks if collections of cidrs or ips are contained within another collection of cidrs and returns matches. " + + "This function is similar to `net.cidr_contains` except it allows callers to pass collections of CIDRs or IPs as arguments and returns the matches (as opposed to a boolean result indicating a match between two CIDRs/IPs).", + Decl: types.NewFunction( + types.Args( + types.Named("cidrs", netCidrContainsMatchesOperandType).Description("CIDRs to check against"), + types.Named("cidrs_or_ips", netCidrContainsMatchesOperandType).Description("CIDRs or IPs to check"), + ), + types.Named("output", types.NewSet(types.NewArray([]types.Type{types.A, types.A}, nil))).Description("tuples identifying matches where `cidrs_or_ips` are contained within `cidrs`"), + ), +} + +var NetCIDRMerge = &Builtin{ + Name: "net.cidr_merge", + Description: "Merges IP addresses and subnets into the smallest possible list of CIDRs (e.g., `net.cidr_merge([\"192.0.128.0/24\", \"192.0.129.0/24\"])` generates `{\"192.0.128.0/23\"}`." + + `This function merges adjacent subnets where possible, those contained within others and also removes any duplicates. +Supports both IPv4 and IPv6 notations. IPv6 inputs need a prefix length (e.g. "/128").`, + Decl: types.NewFunction( + types.Args( + types.Named("addrs", types.NewAny( + types.NewArray(nil, types.NewAny(types.S)), + types.NewSet(types.S), + )).Description("CIDRs or IP addresses"), + ), + types.Named("output", types.NewSet(types.S)).Description("smallest possible set of CIDRs obtained after merging the provided list of IP addresses and subnets in `addrs`"), + ), +} + +var NetCIDRIsValid = &Builtin{ + Name: "net.cidr_is_valid", + Description: "Parses an IPv4/IPv6 CIDR and returns a boolean indicating if the provided CIDR is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("cidr", types.S).Description("CIDR to validate"), + ), + types.Named("result", types.B).Description("`true` if `cidr` is a valid CIDR"), + ), +} + +var netCidrContainsMatchesOperandType = types.NewAny( + types.S, + types.NewArray(nil, types.NewAny( + types.S, + types.NewArray(nil, types.A), + )), + types.NewSet(types.NewAny( + types.S, + types.NewArray(nil, types.A), + )), + types.NewObject(nil, types.NewDynamicProperty( + types.S, + types.NewAny( + types.S, + types.NewArray(nil, types.A), + ), + )), +) + +// Marked non-deterministic because DNS resolution results can be non-deterministic. +var NetLookupIPAddr = &Builtin{ + Name: "net.lookup_ip_addr", + Description: "Returns the set of IP addresses (both v4 and v6) that the passed-in `name` resolves to using the standard name resolution mechanisms available.", + Decl: types.NewFunction( + types.Args( + types.Named("name", types.S).Description("domain name to resolve"), + ), + types.Named("addrs", types.NewSet(types.S)).Description("IP addresses (v4 and v6) that `name` resolves to"), + ), + Nondeterministic: true, +} + +/** + * Semantic Versions + */ + +var SemVerIsValid = &Builtin{ + Name: "semver.is_valid", + Description: "Validates that the input is a valid SemVer string.", + Decl: types.NewFunction( + types.Args( + types.Named("vsn", types.A).Description("input to validate"), + ), + types.Named("result", types.B).Description("`true` if `vsn` is a valid SemVer; `false` otherwise"), + ), +} + +var SemVerCompare = &Builtin{ + Name: "semver.compare", + Description: "Compares valid SemVer formatted version strings.", + Decl: types.NewFunction( + types.Args( + types.Named("a", types.S).Description("first version string"), + types.Named("b", types.S).Description("second version string"), + ), + types.Named("result", types.N).Description("`-1` if `a < b`; `1` if `a > b`; `0` if `a == b`"), + ), +} + +/** + * Printing + */ + +// Print is a special built-in function that writes zero or more operands +// to a message buffer. The caller controls how the buffer is displayed. The +// operands may be of any type. Furthermore, unlike other built-in functions, +// undefined operands DO NOT cause the print() function to fail during +// evaluation. +var Print = &Builtin{ + Name: "print", + Decl: types.NewVariadicFunction(nil, types.A, nil), +} + +// InternalPrint represents the internal implementation of the print() function. +// The compiler rewrites print() calls to refer to the internal implementation. +var InternalPrint = &Builtin{ + Name: "internal.print", + Decl: types.NewFunction([]types.Type{types.NewArray(nil, types.NewSet(types.A))}, nil), +} + +/** + * Deprecated built-ins. + */ + +// SetDiff has been replaced by the minus built-in. +var SetDiff = &Builtin{ + Name: "set_diff", + Decl: types.NewFunction( + types.Args( + types.NewSet(types.A), + types.NewSet(types.A), + ), + types.NewSet(types.A), + ), + deprecated: true, +} + +// NetCIDROverlap has been replaced by the `net.cidr_contains` built-in. +var NetCIDROverlap = &Builtin{ + Name: "net.cidr_overlap", + Decl: types.NewFunction( + types.Args( + types.S, + types.S, + ), + types.B, + ), + deprecated: true, +} + +// CastArray checks the underlying type of the input. If it is array or set, an array +// containing the values is returned. If it is not an array, an error is thrown. +var CastArray = &Builtin{ + Name: "cast_array", + Decl: types.NewFunction( + types.Args(types.A), + types.NewArray(nil, types.A), + ), + deprecated: true, +} + +// CastSet checks the underlying type of the input. +// If it is a set, the set is returned. +// If it is an array, the array is returned in set form (all duplicates removed) +// If neither, an error is thrown +var CastSet = &Builtin{ + Name: "cast_set", + Decl: types.NewFunction( + types.Args(types.A), + types.NewSet(types.A), + ), + deprecated: true, +} + +// CastString returns input if it is a string; if not returns error. +// For formatting variables, see sprintf +var CastString = &Builtin{ + Name: "cast_string", + Decl: types.NewFunction( + types.Args(types.A), + types.S, + ), + deprecated: true, +} + +// CastBoolean returns input if it is a boolean; if not returns error. +var CastBoolean = &Builtin{ + Name: "cast_boolean", + Decl: types.NewFunction( + types.Args(types.A), + types.B, + ), + deprecated: true, +} + +// CastNull returns null if input is null; if not returns error. +var CastNull = &Builtin{ + Name: "cast_null", + Decl: types.NewFunction( + types.Args(types.A), + types.NewNull(), + ), + deprecated: true, +} + +// CastObject returns the given object if it is null; throws an error otherwise +var CastObject = &Builtin{ + Name: "cast_object", + Decl: types.NewFunction( + types.Args(types.A), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + ), + deprecated: true, +} + +// RegexMatchDeprecated declares `re_match` which has been deprecated. Use `regex.match` instead. +var RegexMatchDeprecated = &Builtin{ + Name: "re_match", + Decl: types.NewFunction( + types.Args( + types.S, + types.S, + ), + types.B, + ), + deprecated: true, +} + +// All takes a list and returns true if all of the items +// are true. A collection of length 0 returns true. +var All = &Builtin{ + Name: "all", + Decl: types.NewFunction( + types.Args( + types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A), + ), + ), + types.B, + ), + deprecated: true, +} + +// Any takes a collection and returns true if any of the items +// is true. A collection of length 0 returns false. +var Any = &Builtin{ + Name: "any", + Decl: types.NewFunction( + types.Args( + types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A), + ), + ), + types.B, + ), + deprecated: true, +} + +// Builtin represents a built-in function supported by OPA. Every built-in +// function is uniquely identified by a name. +type Builtin struct { + Name string `json:"name"` // Unique name of built-in function, e.g., (arg1,arg2,...,argN) + Description string `json:"description,omitempty"` // Description of what the built-in function does. + + // Categories of the built-in function. Omitted for namespaced + // built-ins, i.e. "array.concat" is taken to be of the "array" category. + // "minus" for example, is part of two categories: numbers and sets. (NOTE(sr): aspirational) + Categories []string `json:"categories,omitempty"` + + Decl *types.Function `json:"decl"` // Built-in function type declaration. + Infix string `json:"infix,omitempty"` // Unique name of infix operator. Default should be unset. + Relation bool `json:"relation,omitempty"` // Indicates if the built-in acts as a relation. + deprecated bool // Indicates if the built-in has been deprecated. + Nondeterministic bool `json:"nondeterministic,omitempty"` // Indicates if the built-in returns non-deterministic results. +} + +// category is a helper for specifying a Builtin's Categories +func category(cs ...string) []string { + return cs +} + +// Minimal returns a shallow copy of b with the descriptions and categories and +// named arguments stripped out. +func (b *Builtin) Minimal() *Builtin { + cpy := *b + fargs := b.Decl.FuncArgs() + if fargs.Variadic != nil { + cpy.Decl = types.NewVariadicFunction(fargs.Args, fargs.Variadic, b.Decl.Result()) + } else { + cpy.Decl = types.NewFunction(fargs.Args, b.Decl.Result()) + } + cpy.Categories = nil + cpy.Description = "" + return &cpy +} + +// IsDeprecated returns true if the Builtin function is deprecated and will be removed in a future release. +func (b *Builtin) IsDeprecated() bool { + return b.deprecated +} + +// IsDeterministic returns true if the Builtin function returns non-deterministic results. +func (b *Builtin) IsNondeterministic() bool { + return b.Nondeterministic +} + +// Expr creates a new expression for the built-in with the given operands. +func (b *Builtin) Expr(operands ...*Term) *Expr { + ts := make([]*Term, len(operands)+1) + ts[0] = NewTerm(b.Ref()) + for i := range operands { + ts[i+1] = operands[i] + } + return &Expr{ + Terms: ts, + } +} + +// Call creates a new term for the built-in with the given operands. +func (b *Builtin) Call(operands ...*Term) *Term { + call := make(Call, len(operands)+1) + call[0] = NewTerm(b.Ref()) + for i := range operands { + call[i+1] = operands[i] + } + return NewTerm(call) +} + +// Ref returns a Ref that refers to the built-in function. +func (b *Builtin) Ref() Ref { + parts := strings.Split(b.Name, ".") + ref := make(Ref, len(parts)) + ref[0] = VarTerm(parts[0]) + for i := 1; i < len(parts); i++ { + ref[i] = StringTerm(parts[i]) + } + return ref +} + +// IsTargetPos returns true if a variable in the i-th position will be bound by +// evaluating the call expression. +func (b *Builtin) IsTargetPos(i int) bool { + return b.Decl.Arity() == i +} + +func init() { + BuiltinMap = map[string]*Builtin{} + for _, b := range DefaultBuiltins { + RegisterBuiltin(b) + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/capabilities.go b/vendor/github.com/open-policy-agent/opa/v1/ast/capabilities.go new file mode 100644 index 000000000..e7d561d9e --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/capabilities.go @@ -0,0 +1,267 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "bytes" + _ "embed" + "encoding/json" + "fmt" + "io" + "os" + "sort" + "strings" + + "github.com/open-policy-agent/opa/internal/semver" + "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/capabilities" + caps "github.com/open-policy-agent/opa/v1/capabilities" + "github.com/open-policy-agent/opa/v1/util" +) + +// VersonIndex contains an index from built-in function name, language feature, +// and future rego keyword to version number. During the build, this is used to +// create an index of the minimum version required for the built-in/feature/kw. +type VersionIndex struct { + Builtins map[string]semver.Version `json:"builtins"` + Features map[string]semver.Version `json:"features"` + Keywords map[string]semver.Version `json:"keywords"` +} + +// NOTE(tsandall): this file is generated by internal/cmd/genversionindex/main.go +// and run as part of go:generate. We generate the version index as part of the +// build process because it's relatively expensive to build (it takes ~500ms on +// my machine) and never changes. +// +//go:embed version_index.json +var versionIndexBs []byte + +var minVersionIndex = func() VersionIndex { + var vi VersionIndex + err := json.Unmarshal(versionIndexBs, &vi) + if err != nil { + panic(err) + } + return vi +}() + +// In the compiler, we used this to check that we're OK working with ref heads. +// If this isn't present, we'll fail. This is to ensure that older versions of +// OPA can work with policies that we're compiling -- if they don't know ref +// heads, they wouldn't be able to parse them. +const FeatureRefHeadStringPrefixes = "rule_head_ref_string_prefixes" +const FeatureRefHeads = "rule_head_refs" +const FeatureRegoV1 = "rego_v1" +const FeatureRegoV1Import = "rego_v1_import" + +// Capabilities defines a structure containing data that describes the capabilities +// or features supported by a particular version of OPA. +type Capabilities struct { + Builtins []*Builtin `json:"builtins,omitempty"` + FutureKeywords []string `json:"future_keywords,omitempty"` + WasmABIVersions []WasmABIVersion `json:"wasm_abi_versions,omitempty"` + + // Features is a bit of a mixed bag for checking that an older version of OPA + // is able to do what needs to be done. + // TODO(sr): find better words ^^ + Features []string `json:"features,omitempty"` + + // allow_net is an array of hostnames or IP addresses, that an OPA instance is + // allowed to connect to. + // If omitted, ANY host can be connected to. If empty, NO host can be connected to. + // As of now, this only controls fetching remote refs for using JSON Schemas in + // the type checker. + // TODO(sr): support ports to further restrict connection peers + // TODO(sr): support restricting `http.send` using the same mechanism (see https://github.com/open-policy-agent/opa/issues/3665) + AllowNet []string `json:"allow_net,omitempty"` +} + +// WasmABIVersion captures the Wasm ABI version. Its `Minor` version is indicating +// backwards-compatible changes. +type WasmABIVersion struct { + Version int `json:"version"` + Minor int `json:"minor_version"` +} + +type CapabilitiesOptions struct { + regoVersion RegoVersion +} + +func newCapabilitiesOptions(opts []CapabilitiesOption) CapabilitiesOptions { + co := CapabilitiesOptions{} + for _, opt := range opts { + opt(&co) + } + return co +} + +type CapabilitiesOption func(*CapabilitiesOptions) + +func CapabilitiesRegoVersion(regoVersion RegoVersion) CapabilitiesOption { + return func(o *CapabilitiesOptions) { + o.regoVersion = regoVersion + } +} + +// CapabilitiesForThisVersion returns the capabilities of this version of OPA. +func CapabilitiesForThisVersion(opts ...CapabilitiesOption) *Capabilities { + co := newCapabilitiesOptions(opts) + + f := &Capabilities{} + + for _, vers := range capabilities.ABIVersions() { + f.WasmABIVersions = append(f.WasmABIVersions, WasmABIVersion{Version: vers[0], Minor: vers[1]}) + } + + f.Builtins = make([]*Builtin, len(Builtins)) + copy(f.Builtins, Builtins) + sort.Slice(f.Builtins, func(i, j int) bool { + return f.Builtins[i].Name < f.Builtins[j].Name + }) + + if co.regoVersion == RegoV0 || co.regoVersion == RegoV0CompatV1 { + for kw := range allFutureKeywords { + f.FutureKeywords = append(f.FutureKeywords, kw) + } + + f.Features = []string{ + FeatureRefHeadStringPrefixes, + FeatureRefHeads, + FeatureRegoV1Import, + } + } else { + for kw := range futureKeywords { + f.FutureKeywords = append(f.FutureKeywords, kw) + } + + f.Features = []string{ + FeatureRegoV1, + } + } + + sort.Strings(f.FutureKeywords) + sort.Strings(f.Features) + + return f +} + +// LoadCapabilitiesJSON loads a JSON serialized capabilities structure from the reader r. +func LoadCapabilitiesJSON(r io.Reader) (*Capabilities, error) { + d := util.NewJSONDecoder(r) + var c Capabilities + return &c, d.Decode(&c) +} + +// LoadCapabilitiesVersion loads a JSON serialized capabilities structure from the specific version. +func LoadCapabilitiesVersion(version string) (*Capabilities, error) { + cvs, err := LoadCapabilitiesVersions() + if err != nil { + return nil, err + } + + for _, cv := range cvs { + if cv == version { + cont, err := caps.FS.ReadFile(cv + ".json") + if err != nil { + return nil, err + } + + return LoadCapabilitiesJSON(bytes.NewReader(cont)) + } + + } + return nil, fmt.Errorf("no capabilities version found %v", version) +} + +// LoadCapabilitiesFile loads a JSON serialized capabilities structure from a file. +func LoadCapabilitiesFile(file string) (*Capabilities, error) { + fd, err := os.Open(file) + if err != nil { + return nil, err + } + defer fd.Close() + return LoadCapabilitiesJSON(fd) +} + +// LoadCapabilitiesVersions loads all capabilities versions +func LoadCapabilitiesVersions() ([]string, error) { + ents, err := caps.FS.ReadDir(".") + if err != nil { + return nil, err + } + + capabilitiesVersions := make([]string, 0, len(ents)) + for _, ent := range ents { + capabilitiesVersions = append(capabilitiesVersions, strings.Replace(ent.Name(), ".json", "", 1)) + } + return capabilitiesVersions, nil +} + +// MinimumCompatibleVersion returns the minimum compatible OPA version based on +// the built-ins, features, and keywords in c. +func (c *Capabilities) MinimumCompatibleVersion() (string, bool) { + + var maxVersion semver.Version + + // this is the oldest OPA release that includes capabilities + if err := maxVersion.Set("0.17.0"); err != nil { + panic("unreachable") + } + + for _, bi := range c.Builtins { + v, ok := minVersionIndex.Builtins[bi.Name] + if !ok { + return "", false + } + if v.Compare(maxVersion) > 0 { + maxVersion = v + } + } + + for _, kw := range c.FutureKeywords { + v, ok := minVersionIndex.Keywords[kw] + if !ok { + return "", false + } + if v.Compare(maxVersion) > 0 { + maxVersion = v + } + } + + for _, feat := range c.Features { + v, ok := minVersionIndex.Features[feat] + if !ok { + return "", false + } + if v.Compare(maxVersion) > 0 { + maxVersion = v + } + } + + return maxVersion.String(), true +} + +func (c *Capabilities) ContainsFeature(feature string) bool { + for _, f := range c.Features { + if f == feature { + return true + } + } + return false +} + +// addBuiltinSorted inserts a built-in into c in sorted order. An existing built-in with the same name +// will be overwritten. +func (c *Capabilities) addBuiltinSorted(bi *Builtin) { + i := sort.Search(len(c.Builtins), func(x int) bool { + return c.Builtins[x].Name >= bi.Name + }) + if i < len(c.Builtins) && bi.Name == c.Builtins[i].Name { + c.Builtins[i] = bi + return + } + c.Builtins = append(c.Builtins, nil) + copy(c.Builtins[i+1:], c.Builtins[i:]) + c.Builtins[i] = bi +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/check.go b/vendor/github.com/open-policy-agent/opa/v1/ast/check.go new file mode 100644 index 000000000..57c2fa5d7 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/check.go @@ -0,0 +1,1329 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "sort" + "strings" + + "github.com/open-policy-agent/opa/v1/types" + "github.com/open-policy-agent/opa/v1/util" +) + +type varRewriter func(Ref) Ref + +// exprChecker defines the interface for executing type checking on a single +// expression. The exprChecker must update the provided TypeEnv with inferred +// types of vars. +type exprChecker func(*TypeEnv, *Expr) *Error + +// typeChecker implements type checking on queries and rules. Errors are +// accumulated on the typeChecker so that a single run can report multiple +// issues. +type typeChecker struct { + builtins map[string]*Builtin + required *Capabilities + errs Errors + exprCheckers map[string]exprChecker + varRewriter varRewriter + ss *SchemaSet + allowNet []string + input types.Type + allowUndefinedFuncs bool + schemaTypes map[string]types.Type +} + +// newTypeChecker returns a new typeChecker object that has no errors. +func newTypeChecker() *typeChecker { + return &typeChecker{ + exprCheckers: map[string]exprChecker{ + "eq": checkExprEq, + }, + } +} + +func (tc *typeChecker) newEnv(exist *TypeEnv) *TypeEnv { + if exist != nil { + return exist.wrap() + } + env := newTypeEnv(tc.copy) + if tc.input != nil { + env.tree.Put(InputRootRef, tc.input) + } + return env +} + +func (tc *typeChecker) copy() *typeChecker { + return newTypeChecker(). + WithVarRewriter(tc.varRewriter). + WithSchemaSet(tc.ss). + WithSchemaTypes(tc.schemaTypes). + WithAllowNet(tc.allowNet). + WithInputType(tc.input). + WithAllowUndefinedFunctionCalls(tc.allowUndefinedFuncs). + WithBuiltins(tc.builtins). + WithRequiredCapabilities(tc.required) +} + +func (tc *typeChecker) WithRequiredCapabilities(c *Capabilities) *typeChecker { + tc.required = c + return tc +} + +func (tc *typeChecker) WithBuiltins(builtins map[string]*Builtin) *typeChecker { + tc.builtins = builtins + return tc +} + +func (tc *typeChecker) WithSchemaSet(ss *SchemaSet) *typeChecker { + tc.ss = ss + return tc +} + +func (tc *typeChecker) WithSchemaTypes(schemaTypes map[string]types.Type) *typeChecker { + tc.schemaTypes = schemaTypes + return tc +} + +func (tc *typeChecker) WithAllowNet(hosts []string) *typeChecker { + tc.allowNet = hosts + return tc +} + +func (tc *typeChecker) WithVarRewriter(f varRewriter) *typeChecker { + tc.varRewriter = f + return tc +} + +func (tc *typeChecker) WithInputType(tpe types.Type) *typeChecker { + tc.input = tpe + return tc +} + +// WithAllowUndefinedFunctionCalls sets the type checker to allow references to undefined functions. +// Additionally, the 'CheckUndefinedFuncs' and 'CheckSafetyRuleBodies' compiler stages are skipped. +func (tc *typeChecker) WithAllowUndefinedFunctionCalls(allow bool) *typeChecker { + tc.allowUndefinedFuncs = allow + return tc +} + +// Env returns a type environment for the specified built-ins with any other +// global types configured on the checker. In practice, this is the default +// environment that other statements will be checked against. +func (tc *typeChecker) Env(builtins map[string]*Builtin) *TypeEnv { + env := tc.newEnv(nil) + for _, bi := range builtins { + env.tree.Put(bi.Ref(), bi.Decl) + } + return env +} + +// CheckBody runs type checking on the body and returns a TypeEnv if no errors +// are found. The resulting TypeEnv wraps the provided one. The resulting +// TypeEnv will be able to resolve types of vars contained in the body. +func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) { + + errors := []*Error{} + env = tc.newEnv(env) + vis := newRefChecker(env, tc.varRewriter) + + WalkExprs(body, func(expr *Expr) bool { + + closureErrs := tc.checkClosures(env, expr) + for _, err := range closureErrs { + errors = append(errors, err) + } + + hasClosureErrors := len(closureErrs) > 0 + + // reset errors from previous iteration + vis.errs = nil + NewGenericVisitor(vis.Visit).Walk(expr) + for _, err := range vis.errs { + errors = append(errors, err) + } + + hasRefErrors := len(vis.errs) > 0 + + if err := tc.checkExpr(env, expr); err != nil { + // Suppress this error if a more actionable one has occurred. In + // this case, if an error occurred in a ref or closure contained in + // this expression, and the error is due to a nil type, then it's + // likely to be the result of the more specific error. + skip := (hasClosureErrors || hasRefErrors) && causedByNilType(err) + if !skip { + errors = append(errors, err) + } + } + return true + }) + + tc.err(errors) + return env, errors +} + +// CheckTypes runs type checking on the rules returns a TypeEnv if no errors +// are found. The resulting TypeEnv wraps the provided one. The resulting +// TypeEnv will be able to resolve types of refs that refer to rules. +func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T, as *AnnotationSet) (*TypeEnv, Errors) { + env = tc.newEnv(env) + for _, s := range sorted { + tc.checkRule(env, as, s.(*Rule)) + } + tc.errs.Sort() + return env, tc.errs +} + +func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors { + var result Errors + WalkClosures(expr, func(x interface{}) bool { + switch x := x.(type) { + case *ArrayComprehension: + _, errs := tc.copy().CheckBody(env, x.Body) + if len(errs) > 0 { + result = errs + return true + } + case *SetComprehension: + _, errs := tc.copy().CheckBody(env, x.Body) + if len(errs) > 0 { + result = errs + return true + } + case *ObjectComprehension: + _, errs := tc.copy().CheckBody(env, x.Body) + if len(errs) > 0 { + result = errs + return true + } + } + return false + }) + return result +} + +func (tc *typeChecker) getSchemaType(schemaAnnot *SchemaAnnotation, rule *Rule) (types.Type, *Error) { + if tc.schemaTypes == nil { + tc.schemaTypes = make(map[string]types.Type) + } + + if refType, exists := tc.schemaTypes[schemaAnnot.Schema.String()]; exists { + return refType, nil + } + + refType, err := processAnnotation(tc.ss, schemaAnnot, rule, tc.allowNet) + if err != nil { + return nil, err + } + + if refType == nil { + return nil, nil + } + + tc.schemaTypes[schemaAnnot.Schema.String()] = refType + return refType, nil + +} + +func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { + + env = env.wrap() + + schemaAnnots := getRuleAnnotation(as, rule) + for _, schemaAnnot := range schemaAnnots { + refType, err := tc.getSchemaType(schemaAnnot, rule) + if err != nil { + tc.err([]*Error{err}) + continue + } + + ref := schemaAnnot.Path + // if we do not have a ref or a reftype, we should not evaluate this rule. + if ref == nil || refType == nil { + continue + } + + prefixRef, t := getPrefix(env, ref) + if t == nil || len(prefixRef) == len(ref) { + env.tree.Put(ref, refType) + } else { + newType, err := override(ref[len(prefixRef):], t, refType, rule) + if err != nil { + tc.err([]*Error{err}) + continue + } + env.tree.Put(prefixRef, newType) + } + } + + cpy, err := tc.CheckBody(env, rule.Body) + env = env.next + path := rule.Ref() + + if len(err) > 0 { + // if the rule/function contains an error, add it to the type env so + // that expressions that refer to this rule/function do not encounter + // type errors. + env.tree.Put(path, types.A) + return + } + + var tpe types.Type + + if len(rule.Head.Args) > 0 { + // If args are not referred to in body, infer as any. + WalkVars(rule.Head.Args, func(v Var) bool { + if cpy.Get(v) == nil { + cpy.tree.PutOne(v, types.A) + } + return false + }) + + // Construct function type. + args := make([]types.Type, len(rule.Head.Args)) + for i := 0; i < len(rule.Head.Args); i++ { + args[i] = cpy.Get(rule.Head.Args[i]) + } + + f := types.NewFunction(args, cpy.Get(rule.Head.Value)) + + tpe = f + } else { + switch rule.Head.RuleKind() { + case SingleValue: + typeV := cpy.Get(rule.Head.Value) + if !path.IsGround() { + // e.g. store object[string: whatever] at data.p.q.r, not data.p.q.r[x] or data.p.q.r[x].y[z] + objPath := path.DynamicSuffix() + path = path.GroundPrefix() + + var err error + tpe, err = nestedObject(cpy, objPath, typeV) + if err != nil { + tc.err([]*Error{NewError(TypeErr, rule.Head.Location, err.Error())}) //nolint:govet + tpe = nil + } + } else { + if typeV != nil { + tpe = typeV + } + } + case MultiValue: + typeK := cpy.Get(rule.Head.Key) + if typeK != nil { + tpe = types.NewSet(typeK) + } + } + } + + if tpe != nil { + env.tree.Insert(path, tpe, env) + } +} + +// nestedObject creates a nested structure of object types, where each term on path corresponds to a level in the +// nesting. Each term in the path only contributes to the dynamic portion of its corresponding object. +func nestedObject(env *TypeEnv, path Ref, tpe types.Type) (types.Type, error) { + if len(path) == 0 { + return tpe, nil + } + + k := path[0] + typeV, err := nestedObject(env, path[1:], tpe) + if err != nil { + return nil, err + } + if typeV == nil { + return nil, nil + } + + var dynamicProperty *types.DynamicProperty + typeK := env.Get(k) + if typeK == nil { + return nil, nil + } + dynamicProperty = types.NewDynamicProperty(typeK, typeV) + + return types.NewObject(nil, dynamicProperty), nil +} + +func (tc *typeChecker) checkExpr(env *TypeEnv, expr *Expr) *Error { + if err := tc.checkExprWith(env, expr, 0); err != nil { + return err + } + if !expr.IsCall() { + return nil + } + + operator := expr.Operator().String() + + // If the type checker wasn't provided with a required capabilities + // structure then just skip. In some cases, type checking might be run + // without the need to record what builtins are required. + if tc.required != nil && tc.builtins != nil { + if bi, ok := tc.builtins[operator]; ok { + tc.required.addBuiltinSorted(bi) + } + } + + checker := tc.exprCheckers[operator] + if checker != nil { + return checker(env, expr) + } + + return tc.checkExprBuiltin(env, expr) +} + +func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error { + + args := expr.Operands() + pre := getArgTypes(env, args) + + // NOTE(tsandall): undefined functions will have been caught earlier in the + // compiler. We check for undefined functions before the safety check so + // that references to non-existent functions result in undefined function + // errors as opposed to unsafe var errors. + // + // We cannot run type checking before the safety check because part of the + // type checker relies on reordering (in particular for references to local + // vars). + name := expr.Operator() + tpe := env.Get(name) + + if tpe == nil { + if tc.allowUndefinedFuncs { + return nil + } + return NewError(TypeErr, expr.Location, "undefined function %v", name) + } + + // check if the expression refers to a function that contains an error + _, ok := tpe.(types.Any) + if ok { + return nil + } + + ftpe, ok := tpe.(*types.Function) + if !ok { + return NewError(TypeErr, expr.Location, "undefined function %v", name) + } + + fargs := ftpe.FuncArgs() + namedFargs := ftpe.NamedFuncArgs() + + if ftpe.Result() != nil { + fargs.Args = append(fargs.Args, ftpe.Result()) + namedFargs.Args = append(namedFargs.Args, ftpe.NamedResult()) + } + + if len(args) > len(fargs.Args) && fargs.Variadic == nil { + return newArgError(expr.Location, name, "too many arguments", pre, namedFargs) + } + + if len(args) < len(ftpe.FuncArgs().Args) { + return newArgError(expr.Location, name, "too few arguments", pre, namedFargs) + } + + for i := range args { + if !unify1(env, args[i], fargs.Arg(i), false) { + post := make([]types.Type, len(args)) + for i := range args { + post[i] = env.Get(args[i]) + } + return newArgError(expr.Location, name, "invalid argument(s)", post, namedFargs) + } + } + + return nil +} + +func checkExprEq(env *TypeEnv, expr *Expr) *Error { + + pre := getArgTypes(env, expr.Operands()) + + if len(pre) < Equality.Decl.Arity() { + return newArgError(expr.Location, expr.Operator(), "too few arguments", pre, Equality.Decl.FuncArgs()) + } + + if Equality.Decl.Arity() < len(pre) { + return newArgError(expr.Location, expr.Operator(), "too many arguments", pre, Equality.Decl.FuncArgs()) + } + + a, b := expr.Operand(0), expr.Operand(1) + typeA, typeB := env.Get(a), env.Get(b) + + if !unify2(env, a, typeA, b, typeB) { + err := NewError(TypeErr, expr.Location, "match error") + err.Details = &UnificationErrDetail{ + Left: typeA, + Right: typeB, + } + return err + } + + return nil +} + +func (tc *typeChecker) checkExprWith(env *TypeEnv, expr *Expr, i int) *Error { + if i == len(expr.With) { + return nil + } + + target, value := expr.With[i].Target, expr.With[i].Value + targetType, valueType := env.Get(target), env.Get(value) + + if t, ok := targetType.(*types.Function); ok { // built-in function replacement + switch v := valueType.(type) { + case *types.Function: // ...by function + if !unifies(targetType, valueType) { + return newArgError(expr.With[i].Loc(), target.Value.(Ref), "arity mismatch", v.FuncArgs().Args, t.NamedFuncArgs()) + } + default: // ... by value, nothing to check + } + } + + return tc.checkExprWith(env, expr, i+1) +} + +func unify2(env *TypeEnv, a *Term, typeA types.Type, b *Term, typeB types.Type) bool { + + nilA := types.Nil(typeA) + nilB := types.Nil(typeB) + + if nilA && !nilB { + return unify1(env, a, typeB, false) + } else if nilB && !nilA { + return unify1(env, b, typeA, false) + } else if !nilA && !nilB { + return unifies(typeA, typeB) + } + + switch a.Value.(type) { + case *Array: + return unify2Array(env, a, b) + case *object: + return unify2Object(env, a, b) + case Var: + switch b.Value.(type) { + case Var: + return unify1(env, a, types.A, false) && unify1(env, b, env.Get(a), false) + case *Array: + return unify2Array(env, b, a) + case *object: + return unify2Object(env, b, a) + } + } + + return false +} + +func unify2Array(env *TypeEnv, a *Term, b *Term) bool { + arr := a.Value.(*Array) + switch bv := b.Value.(type) { + case *Array: + if arr.Len() == bv.Len() { + for i := 0; i < arr.Len(); i++ { + if !unify2(env, arr.Elem(i), env.Get(arr.Elem(i)), bv.Elem(i), env.Get(bv.Elem(i))) { + return false + } + } + return true + } + case Var: + return unify1(env, a, types.A, false) && unify1(env, b, env.Get(a), false) + } + return false +} + +func unify2Object(env *TypeEnv, a *Term, b *Term) bool { + obj := a.Value.(Object) + switch bv := b.Value.(type) { + case *object: + cv := obj.Intersect(bv) + if obj.Len() == bv.Len() && bv.Len() == len(cv) { + for i := range cv { + if !unify2(env, cv[i][1], env.Get(cv[i][1]), cv[i][2], env.Get(cv[i][2])) { + return false + } + } + return true + } + case Var: + return unify1(env, a, types.A, false) && unify1(env, b, env.Get(a), false) + } + return false +} + +func unify1(env *TypeEnv, term *Term, tpe types.Type, union bool) bool { + switch v := term.Value.(type) { + case *Array: + switch tpe := tpe.(type) { + case *types.Array: + return unify1Array(env, v, tpe, union) + case types.Any: + if types.Compare(tpe, types.A) == 0 { + for i := 0; i < v.Len(); i++ { + unify1(env, v.Elem(i), types.A, true) + } + return true + } + unifies := false + for i := range tpe { + unifies = unify1(env, term, tpe[i], true) || unifies + } + return unifies + } + return false + case *object: + switch tpe := tpe.(type) { + case *types.Object: + return unify1Object(env, v, tpe, union) + case types.Any: + if types.Compare(tpe, types.A) == 0 { + v.Foreach(func(key, value *Term) { + unify1(env, key, types.A, true) + unify1(env, value, types.A, true) + }) + return true + } + unifies := false + for i := range tpe { + unifies = unify1(env, term, tpe[i], true) || unifies + } + return unifies + } + return false + case Set: + switch tpe := tpe.(type) { + case *types.Set: + return unify1Set(env, v, tpe, union) + case types.Any: + if types.Compare(tpe, types.A) == 0 { + v.Foreach(func(elem *Term) { + unify1(env, elem, types.A, true) + }) + return true + } + unifies := false + for i := range tpe { + unifies = unify1(env, term, tpe[i], true) || unifies + } + return unifies + } + return false + case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension: + return unifies(env.Get(v), tpe) + case Var: + if !union { + if exist := env.Get(v); exist != nil { + return unifies(exist, tpe) + } + env.tree.PutOne(term.Value, tpe) + } else { + env.tree.PutOne(term.Value, types.Or(env.Get(v), tpe)) + } + return true + default: + if !IsConstant(v) { + panic("unreachable") + } + return unifies(env.Get(term), tpe) + } +} + +func unify1Array(env *TypeEnv, val *Array, tpe *types.Array, union bool) bool { + if val.Len() != tpe.Len() && tpe.Dynamic() == nil { + return false + } + for i := 0; i < val.Len(); i++ { + if !unify1(env, val.Elem(i), tpe.Select(i), union) { + return false + } + } + return true +} + +func unify1Object(env *TypeEnv, val Object, tpe *types.Object, union bool) bool { + if val.Len() != len(tpe.Keys()) && tpe.DynamicValue() == nil { + return false + } + stop := val.Until(func(k, v *Term) bool { + if IsConstant(k.Value) { + if child := selectConstant(tpe, k); child != nil { + if !unify1(env, v, child, union) { + return true + } + } else { + return true + } + } else { + // Inferring type of value under dynamic key would involve unioning + // with all property values of tpe whose keys unify. For now, type + // these values as Any. We can investigate stricter inference in + // the future. + unify1(env, v, types.A, union) + } + return false + }) + return !stop +} + +func unify1Set(env *TypeEnv, val Set, tpe *types.Set, union bool) bool { + of := types.Values(tpe) + return !val.Until(func(elem *Term) bool { + return !unify1(env, elem, of, union) + }) +} + +func (tc *typeChecker) err(errors []*Error) { + tc.errs = append(tc.errs, errors...) +} + +type refChecker struct { + env *TypeEnv + errs Errors + varRewriter varRewriter +} + +func rewriteVarsNop(node Ref) Ref { + return node +} + +func newRefChecker(env *TypeEnv, f varRewriter) *refChecker { + if f == nil { + f = rewriteVarsNop + } + + return &refChecker{ + env: env, + errs: nil, + varRewriter: f, + } +} + +func (rc *refChecker) Visit(x interface{}) bool { + switch x := x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension: + return true + case *Expr: + switch terms := x.Terms.(type) { + case []*Term: + for i := 1; i < len(terms); i++ { + NewGenericVisitor(rc.Visit).Walk(terms[i]) + } + return true + case *Term: + NewGenericVisitor(rc.Visit).Walk(terms) + return true + } + case Ref: + if err := rc.checkApply(rc.env, x); err != nil { + rc.errs = append(rc.errs, err) + return true + } + if err := rc.checkRef(rc.env, rc.env.tree, x, 0); err != nil { + rc.errs = append(rc.errs, err) + } + } + return false +} + +func (rc *refChecker) checkApply(curr *TypeEnv, ref Ref) *Error { + switch tpe := curr.Get(ref).(type) { + case *types.Function: // NOTE(sr): We don't support first-class functions, except for `with`. + return newRefErrUnsupported(ref[0].Location, rc.varRewriter(ref), len(ref)-1, tpe) + } + + return nil +} + +func (rc *refChecker) checkRef(curr *TypeEnv, node *typeTreeNode, ref Ref, idx int) *Error { + + if idx == len(ref) { + return nil + } + + head := ref[idx] + + // NOTE(sr): as long as package statements are required, this isn't possible: + // the shortest possible rule ref is data.a.b (b is idx 2), idx 1 and 2 need to + // be strings or vars. + if idx == 1 || idx == 2 { + switch head.Value.(type) { + case Var, String: // OK + default: + have := rc.env.Get(head.Value) + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, have, types.S, getOneOfForNode(node)) + } + } + + if v, ok := head.Value.(Var); ok && idx != 0 { + tpe := types.Keys(rc.env.getRefRecExtent(node)) + if exist := rc.env.Get(v); exist != nil { + if !unifies(tpe, exist) { + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, tpe, getOneOfForNode(node)) + } + } else { + rc.env.tree.PutOne(v, tpe) + } + } + + child := node.Child(head.Value) + if child == nil { + // NOTE(sr): idx is reset on purpose: we start over + switch { + case curr.next != nil: + next := curr.next + return rc.checkRef(next, next.tree, ref, 0) + + case RootDocumentNames.Contains(ref[0]): + if idx != 0 { + node.Children().Iter(func(_, child util.T) bool { + _ = rc.checkRef(curr, child.(*typeTreeNode), ref, idx+1) // ignore error + return false + }) + return nil + } + return rc.checkRefLeaf(types.A, ref, 1) + + default: + return rc.checkRefLeaf(types.A, ref, 0) + } + } + + if child.Leaf() { + return rc.checkRefLeaf(child.Value(), ref, idx+1) + } + + return rc.checkRef(curr, child, ref, idx+1) +} + +func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error { + + if idx == len(ref) { + return nil + } + + head := ref[idx] + + keys := types.Keys(tpe) + if keys == nil { + return newRefErrUnsupported(ref[0].Location, rc.varRewriter(ref), idx-1, tpe) + } + + switch value := head.Value.(type) { + + case Var: + if exist := rc.env.Get(value); exist != nil { + if !unifies(exist, keys) { + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, keys, getOneOfForType(tpe)) + } + } else { + rc.env.tree.PutOne(value, types.Keys(tpe)) + } + + case Ref: + if exist := rc.env.Get(value); exist != nil { + if !unifies(exist, keys) { + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, keys, getOneOfForType(tpe)) + } + } + + case *Array, Object, Set: + if !unify1(rc.env, head, keys, false) { + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, rc.env.Get(head), keys, nil) + } + + default: + child := selectConstant(tpe, head) + if child == nil { + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, nil, types.Keys(tpe), getOneOfForType(tpe)) + } + return rc.checkRefLeaf(child, ref, idx+1) + } + + return rc.checkRefLeaf(types.Values(tpe), ref, idx+1) +} + +func unifies(a, b types.Type) bool { + + if a == nil || b == nil { + return false + } + + anyA, ok1 := a.(types.Any) + if ok1 { + if unifiesAny(anyA, b) { + return true + } + } + + anyB, ok2 := b.(types.Any) + if ok2 { + if unifiesAny(anyB, a) { + return true + } + } + + if ok1 || ok2 { + return false + } + + switch a := a.(type) { + case types.Null: + _, ok := b.(types.Null) + return ok + case types.Boolean: + _, ok := b.(types.Boolean) + return ok + case types.Number: + _, ok := b.(types.Number) + return ok + case types.String: + _, ok := b.(types.String) + return ok + case *types.Array: + b, ok := b.(*types.Array) + if !ok { + return false + } + return unifiesArrays(a, b) + case *types.Object: + b, ok := b.(*types.Object) + if !ok { + return false + } + return unifiesObjects(a, b) + case *types.Set: + b, ok := b.(*types.Set) + if !ok { + return false + } + return unifies(types.Values(a), types.Values(b)) + case *types.Function: + // NOTE(sr): variadic functions can only be internal ones, and we've forbidden + // their replacement via `with`; so we disregard variadic here + if types.Arity(a) == types.Arity(b) { + b := b.(*types.Function) + for i := range a.FuncArgs().Args { + if !unifies(a.FuncArgs().Arg(i), b.FuncArgs().Arg(i)) { + return false + } + } + return true + } + return false + default: + panic("unreachable") + } +} + +func unifiesAny(a types.Any, b types.Type) bool { + if _, ok := b.(*types.Function); ok { + return false + } + for i := range a { + if unifies(a[i], b) { + return true + } + } + return len(a) == 0 +} + +func unifiesArrays(a, b *types.Array) bool { + + if !unifiesArraysStatic(a, b) { + return false + } + + if !unifiesArraysStatic(b, a) { + return false + } + + return a.Dynamic() == nil || b.Dynamic() == nil || unifies(a.Dynamic(), b.Dynamic()) +} + +func unifiesArraysStatic(a, b *types.Array) bool { + if a.Len() != 0 { + for i := 0; i < a.Len(); i++ { + if !unifies(a.Select(i), b.Select(i)) { + return false + } + } + } + return true +} + +func unifiesObjects(a, b *types.Object) bool { + if !unifiesObjectsStatic(a, b) { + return false + } + + if !unifiesObjectsStatic(b, a) { + return false + } + + return a.DynamicValue() == nil || b.DynamicValue() == nil || unifies(a.DynamicValue(), b.DynamicValue()) +} + +func unifiesObjectsStatic(a, b *types.Object) bool { + for _, k := range a.Keys() { + if !unifies(a.Select(k), b.Select(k)) { + return false + } + } + return true +} + +// typeErrorCause defines an interface to determine the reason for a type +// error. The type error details implement this interface so that type checking +// can report more actionable errors. +type typeErrorCause interface { + nilType() bool +} + +func causedByNilType(err *Error) bool { + cause, ok := err.Details.(typeErrorCause) + if !ok { + return false + } + return cause.nilType() +} + +// ArgErrDetail represents a generic argument error. +type ArgErrDetail struct { + Have []types.Type `json:"have"` + Want types.FuncArgs `json:"want"` +} + +// Lines returns the string representation of the detail. +func (d *ArgErrDetail) Lines() []string { + lines := make([]string, 2) + lines[0] = "have: " + formatArgs(d.Have) + lines[1] = "want: " + fmt.Sprint(d.Want) + return lines +} + +func (d *ArgErrDetail) nilType() bool { + for i := range d.Have { + if types.Nil(d.Have[i]) { + return true + } + } + return false +} + +// UnificationErrDetail describes a type mismatch error when two values are +// unified (e.g., x = [1,2,y]). +type UnificationErrDetail struct { + Left types.Type `json:"a"` + Right types.Type `json:"b"` +} + +func (a *UnificationErrDetail) nilType() bool { + return types.Nil(a.Left) || types.Nil(a.Right) +} + +// Lines returns the string representation of the detail. +func (a *UnificationErrDetail) Lines() []string { + lines := make([]string, 2) + lines[0] = fmt.Sprint("left : ", types.Sprint(a.Left)) + lines[1] = fmt.Sprint("right : ", types.Sprint(a.Right)) + return lines +} + +// RefErrUnsupportedDetail describes an undefined reference error where the +// referenced value does not support dereferencing (e.g., scalars). +type RefErrUnsupportedDetail struct { + Ref Ref `json:"ref"` // invalid ref + Pos int `json:"pos"` // invalid element + Have types.Type `json:"have"` // referenced type +} + +// Lines returns the string representation of the detail. +func (r *RefErrUnsupportedDetail) Lines() []string { + lines := []string{ + r.Ref.String(), + strings.Repeat("^", len(r.Ref[:r.Pos+1].String())), + fmt.Sprintf("have: %v", r.Have), + } + return lines +} + +// RefErrInvalidDetail describes an undefined reference error where the referenced +// value does not support the reference operand (e.g., missing object key, +// invalid key type, etc.) +type RefErrInvalidDetail struct { + Ref Ref `json:"ref"` // invalid ref + Pos int `json:"pos"` // invalid element + Have types.Type `json:"have,omitempty"` // type of invalid element (for var/ref elements) + Want types.Type `json:"want"` // allowed type (for non-object values) + OneOf []Value `json:"oneOf"` // allowed values (e.g., for object keys) +} + +// Lines returns the string representation of the detail. +func (r *RefErrInvalidDetail) Lines() []string { + lines := []string{r.Ref.String()} + offset := len(r.Ref[:r.Pos].String()) + 1 + pad := strings.Repeat(" ", offset) + lines = append(lines, fmt.Sprintf("%s^", pad)) + if r.Have != nil { + lines = append(lines, fmt.Sprintf("%shave (type): %v", pad, r.Have)) + } else { + lines = append(lines, fmt.Sprintf("%shave: %v", pad, r.Ref[r.Pos])) + } + if len(r.OneOf) > 0 { + lines = append(lines, fmt.Sprintf("%swant (one of): %v", pad, r.OneOf)) + } else { + lines = append(lines, fmt.Sprintf("%swant (type): %v", pad, r.Want)) + } + return lines +} + +func formatArgs(args []types.Type) string { + buf := make([]string, len(args)) + for i := range args { + buf[i] = types.Sprint(args[i]) + } + return "(" + strings.Join(buf, ", ") + ")" +} + +func newRefErrInvalid(loc *Location, ref Ref, idx int, have, want types.Type, oneOf []Value) *Error { + err := newRefError(loc, ref) + err.Details = &RefErrInvalidDetail{ + Ref: ref, + Pos: idx, + Have: have, + Want: want, + OneOf: oneOf, + } + return err +} + +func newRefErrUnsupported(loc *Location, ref Ref, idx int, have types.Type) *Error { + err := newRefError(loc, ref) + err.Details = &RefErrUnsupportedDetail{ + Ref: ref, + Pos: idx, + Have: have, + } + return err +} + +func newRefError(loc *Location, ref Ref) *Error { + return NewError(TypeErr, loc, "undefined ref: %v", ref) +} + +func newArgError(loc *Location, builtinName Ref, msg string, have []types.Type, want types.FuncArgs) *Error { + err := NewError(TypeErr, loc, "%v: %v", builtinName, msg) + err.Details = &ArgErrDetail{ + Have: have, + Want: want, + } + return err +} + +func getOneOfForNode(node *typeTreeNode) (result []Value) { + node.Children().Iter(func(k, _ util.T) bool { + result = append(result, k.(Value)) + return false + }) + + sortValueSlice(result) + return result +} + +func getOneOfForType(tpe types.Type) (result []Value) { + switch tpe := tpe.(type) { + case *types.Object: + for _, k := range tpe.Keys() { + v, err := InterfaceToValue(k) + if err != nil { + panic(err) + } + result = append(result, v) + } + + case types.Any: + for _, object := range tpe { + objRes := getOneOfForType(object) + result = append(result, objRes...) + } + } + + result = removeDuplicate(result) + sortValueSlice(result) + return result +} + +func sortValueSlice(sl []Value) { + sort.Slice(sl, func(i, j int) bool { + return sl[i].Compare(sl[j]) < 0 + }) +} + +func removeDuplicate(list []Value) []Value { + seen := make(map[Value]bool) + var newResult []Value + for _, item := range list { + if !seen[item] { + newResult = append(newResult, item) + seen[item] = true + } + } + return newResult +} + +func getArgTypes(env *TypeEnv, args []*Term) []types.Type { + pre := make([]types.Type, len(args)) + for i := range args { + pre[i] = env.Get(args[i]) + } + return pre +} + +// getPrefix returns the shortest prefix of ref that exists in env +func getPrefix(env *TypeEnv, ref Ref) (Ref, types.Type) { + if len(ref) == 1 { + t := env.Get(ref) + if t != nil { + return ref, t + } + } + for i := 1; i < len(ref); i++ { + t := env.Get(ref[:i]) + if t != nil { + return ref[:i], t + } + } + return nil, nil +} + +// override takes a type t and returns a type obtained from t where the path represented by ref within it has type o (overriding the original type of that path) +func override(ref Ref, t types.Type, o types.Type, rule *Rule) (types.Type, *Error) { + var newStaticProps []*types.StaticProperty + obj, ok := t.(*types.Object) + if !ok { + newType, err := getObjectType(ref, o, rule, types.NewDynamicProperty(types.A, types.A)) + if err != nil { + return nil, err + } + return newType, nil + } + found := false + if ok { + staticProps := obj.StaticProperties() + for _, prop := range staticProps { + valueCopy := prop.Value + key, err := InterfaceToValue(prop.Key) + if err != nil { + return nil, NewError(TypeErr, rule.Location, "unexpected error in override: %s", err.Error()) + } + if len(ref) > 0 && ref[0].Value.Compare(key) == 0 { + found = true + if len(ref) == 1 { + valueCopy = o + } else { + newVal, err := override(ref[1:], valueCopy, o, rule) + if err != nil { + return nil, err + } + valueCopy = newVal + } + } + newStaticProps = append(newStaticProps, types.NewStaticProperty(prop.Key, valueCopy)) + } + } + + // ref[0] is not a top-level key in staticProps, so it must be added + if !found { + newType, err := getObjectType(ref, o, rule, obj.DynamicProperties()) + if err != nil { + return nil, err + } + newStaticProps = append(newStaticProps, newType.StaticProperties()...) + } + return types.NewObject(newStaticProps, obj.DynamicProperties()), nil +} + +func getKeys(ref Ref, rule *Rule) ([]interface{}, *Error) { + keys := []interface{}{} + for _, refElem := range ref { + key, err := JSON(refElem.Value) + if err != nil { + return nil, NewError(TypeErr, rule.Location, "error getting key from value: %s", err.Error()) + } + keys = append(keys, key) + } + return keys, nil +} + +func getObjectTypeRec(keys []interface{}, o types.Type, d *types.DynamicProperty) *types.Object { + if len(keys) == 1 { + staticProps := []*types.StaticProperty{types.NewStaticProperty(keys[0], o)} + return types.NewObject(staticProps, d) + } + + staticProps := []*types.StaticProperty{types.NewStaticProperty(keys[0], getObjectTypeRec(keys[1:], o, d))} + return types.NewObject(staticProps, d) +} + +func getObjectType(ref Ref, o types.Type, rule *Rule, d *types.DynamicProperty) (*types.Object, *Error) { + keys, err := getKeys(ref, rule) + if err != nil { + return nil, err + } + return getObjectTypeRec(keys, o, d), nil +} + +func getRuleAnnotation(as *AnnotationSet, rule *Rule) (result []*SchemaAnnotation) { + + for _, x := range as.GetSubpackagesScope(rule.Module.Package.Path) { + result = append(result, x.Schemas...) + } + + if x := as.GetPackageScope(rule.Module.Package); x != nil { + result = append(result, x.Schemas...) + } + + if x := as.GetDocumentScope(rule.Ref().GroundPrefix()); x != nil { + result = append(result, x.Schemas...) + } + + for _, x := range as.GetRuleScope(rule) { + result = append(result, x.Schemas...) + } + + return result +} + +func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allowNet []string) (types.Type, *Error) { + + var schema interface{} + + if annot.Schema != nil { + if ss == nil { + return nil, nil + } + schema = ss.Get(annot.Schema) + if schema == nil { + return nil, NewError(TypeErr, rule.Location, "undefined schema: %v", annot.Schema) + } + } else if annot.Definition != nil { + schema = *annot.Definition + } + + tpe, err := loadSchema(schema, allowNet) + if err != nil { + return nil, NewError(TypeErr, rule.Location, err.Error()) //nolint:govet + } + + return tpe, nil +} + +func errAnnotationRedeclared(a *Annotations, other *Location) *Error { + return NewError(TypeErr, a.Location, "%v annotation redeclared: %v", a.Scope, other) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/compare.go b/vendor/github.com/open-policy-agent/opa/v1/ast/compare.go new file mode 100644 index 000000000..3bb6f2a75 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/compare.go @@ -0,0 +1,396 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "encoding/json" + "fmt" + "math/big" +) + +// Compare returns an integer indicating whether two AST values are less than, +// equal to, or greater than each other. +// +// If a is less than b, the return value is negative. If a is greater than b, +// the return value is positive. If a is equal to b, the return value is zero. +// +// Different types are never equal to each other. For comparison purposes, types +// are sorted as follows: +// +// nil < Null < Boolean < Number < String < Var < Ref < Array < Object < Set < +// ArrayComprehension < ObjectComprehension < SetComprehension < Expr < SomeDecl +// < With < Body < Rule < Import < Package < Module. +// +// Arrays and Refs are equal if and only if both a and b have the same length +// and all corresponding elements are equal. If one element is not equal, the +// return value is the same as for the first differing element. If all elements +// are equal but a and b have different lengths, the shorter is considered less +// than the other. +// +// Objects are considered equal if and only if both a and b have the same sorted +// (key, value) pairs and are of the same length. Other comparisons are +// consistent but not defined. +// +// Sets are considered equal if and only if the symmetric difference of a and b +// is empty. +// Other comparisons are consistent but not defined. +func Compare(a, b interface{}) int { + + if t, ok := a.(*Term); ok { + if t == nil { + a = nil + } else { + a = t.Value + } + } + + if t, ok := b.(*Term); ok { + if t == nil { + b = nil + } else { + b = t.Value + } + } + + if a == nil { + if b == nil { + return 0 + } + return -1 + } + if b == nil { + return 1 + } + + sortA := sortOrder(a) + sortB := sortOrder(b) + + if sortA < sortB { + return -1 + } else if sortB < sortA { + return 1 + } + + switch a := a.(type) { + case Null: + return 0 + case Boolean: + b := b.(Boolean) + if a.Equal(b) { + return 0 + } + if !a { + return -1 + } + return 1 + case Number: + if ai, err := json.Number(a).Int64(); err == nil { + if bi, err := json.Number(b.(Number)).Int64(); err == nil { + if ai == bi { + return 0 + } + if ai < bi { + return -1 + } + return 1 + } + } + + // We use big.Rat for comparing big numbers. + // It replaces big.Float due to following reason: + // big.Float comes with a default precision of 64, and setting a + // larger precision results in more memory being allocated + // (regardless of the actual number we are parsing with SetString). + // + // Note: If we're so close to zero that big.Float says we are zero, do + // *not* big.Rat).SetString on the original string it'll potentially + // take very long. + var bigA, bigB *big.Rat + fa, ok := new(big.Float).SetString(string(a)) + if !ok { + panic("illegal value") + } + if fa.IsInt() { + if i, _ := fa.Int64(); i == 0 { + bigA = new(big.Rat).SetInt64(0) + } + } + if bigA == nil { + bigA, ok = new(big.Rat).SetString(string(a)) + if !ok { + panic("illegal value") + } + } + + fb, ok := new(big.Float).SetString(string(b.(Number))) + if !ok { + panic("illegal value") + } + if fb.IsInt() { + if i, _ := fb.Int64(); i == 0 { + bigB = new(big.Rat).SetInt64(0) + } + } + if bigB == nil { + bigB, ok = new(big.Rat).SetString(string(b.(Number))) + if !ok { + panic("illegal value") + } + } + + return bigA.Cmp(bigB) + case String: + b := b.(String) + if a.Equal(b) { + return 0 + } + if a < b { + return -1 + } + return 1 + case Var: + b := b.(Var) + if a.Equal(b) { + return 0 + } + if a < b { + return -1 + } + return 1 + case Ref: + b := b.(Ref) + return termSliceCompare(a, b) + case *Array: + b := b.(*Array) + return termSliceCompare(a.elems, b.elems) + case *lazyObj: + return Compare(a.force(), b) + case *object: + if x, ok := b.(*lazyObj); ok { + b = x.force() + } + b := b.(*object) + return a.Compare(b) + case Set: + b := b.(Set) + return a.Compare(b) + case *ArrayComprehension: + b := b.(*ArrayComprehension) + if cmp := Compare(a.Term, b.Term); cmp != 0 { + return cmp + } + return Compare(a.Body, b.Body) + case *ObjectComprehension: + b := b.(*ObjectComprehension) + if cmp := Compare(a.Key, b.Key); cmp != 0 { + return cmp + } + if cmp := Compare(a.Value, b.Value); cmp != 0 { + return cmp + } + return Compare(a.Body, b.Body) + case *SetComprehension: + b := b.(*SetComprehension) + if cmp := Compare(a.Term, b.Term); cmp != 0 { + return cmp + } + return Compare(a.Body, b.Body) + case Call: + b := b.(Call) + return termSliceCompare(a, b) + case *Expr: + b := b.(*Expr) + return a.Compare(b) + case *SomeDecl: + b := b.(*SomeDecl) + return a.Compare(b) + case *Every: + b := b.(*Every) + return a.Compare(b) + case *With: + b := b.(*With) + return a.Compare(b) + case Body: + b := b.(Body) + return a.Compare(b) + case *Head: + b := b.(*Head) + return a.Compare(b) + case *Rule: + b := b.(*Rule) + return a.Compare(b) + case Args: + b := b.(Args) + return termSliceCompare(a, b) + case *Import: + b := b.(*Import) + return a.Compare(b) + case *Package: + b := b.(*Package) + return a.Compare(b) + case *Annotations: + b := b.(*Annotations) + return a.Compare(b) + case *Module: + b := b.(*Module) + return a.Compare(b) + } + panic(fmt.Sprintf("illegal value: %T", a)) +} + +type termSlice []*Term + +func (s termSlice) Less(i, j int) bool { return Compare(s[i].Value, s[j].Value) < 0 } +func (s termSlice) Swap(i, j int) { x := s[i]; s[i] = s[j]; s[j] = x } +func (s termSlice) Len() int { return len(s) } + +func sortOrder(x interface{}) int { + switch x.(type) { + case Null: + return 0 + case Boolean: + return 1 + case Number: + return 2 + case String: + return 3 + case Var: + return 4 + case Ref: + return 5 + case *Array: + return 6 + case Object: + return 7 + case Set: + return 8 + case *ArrayComprehension: + return 9 + case *ObjectComprehension: + return 10 + case *SetComprehension: + return 11 + case Call: + return 12 + case Args: + return 13 + case *Expr: + return 100 + case *SomeDecl: + return 101 + case *Every: + return 102 + case *With: + return 110 + case *Head: + return 120 + case Body: + return 200 + case *Rule: + return 1000 + case *Import: + return 1001 + case *Package: + return 1002 + case *Annotations: + return 1003 + case *Module: + return 10000 + } + panic(fmt.Sprintf("illegal value: %T", x)) +} + +func importsCompare(a, b []*Import) int { + minLen := len(a) + if len(b) < minLen { + minLen = len(b) + } + for i := 0; i < minLen; i++ { + if cmp := a[i].Compare(b[i]); cmp != 0 { + return cmp + } + } + if len(a) < len(b) { + return -1 + } + if len(b) < len(a) { + return 1 + } + return 0 +} + +func annotationsCompare(a, b []*Annotations) int { + minLen := len(a) + if len(b) < minLen { + minLen = len(b) + } + for i := 0; i < minLen; i++ { + if cmp := a[i].Compare(b[i]); cmp != 0 { + return cmp + } + } + if len(a) < len(b) { + return -1 + } + if len(b) < len(a) { + return 1 + } + return 0 +} + +func rulesCompare(a, b []*Rule) int { + minLen := len(a) + if len(b) < minLen { + minLen = len(b) + } + for i := 0; i < minLen; i++ { + if cmp := a[i].Compare(b[i]); cmp != 0 { + return cmp + } + } + if len(a) < len(b) { + return -1 + } + if len(b) < len(a) { + return 1 + } + return 0 +} + +func termSliceCompare(a, b []*Term) int { + minLen := len(a) + if len(b) < minLen { + minLen = len(b) + } + for i := 0; i < minLen; i++ { + if cmp := Compare(a[i], b[i]); cmp != 0 { + return cmp + } + } + if len(a) < len(b) { + return -1 + } else if len(b) < len(a) { + return 1 + } + return 0 +} + +func withSliceCompare(a, b []*With) int { + minLen := len(a) + if len(b) < minLen { + minLen = len(b) + } + for i := 0; i < minLen; i++ { + if cmp := Compare(a[i], b[i]); cmp != 0 { + return cmp + } + } + if len(a) < len(b) { + return -1 + } else if len(b) < len(a) { + return 1 + } + return 0 +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/compile.go b/vendor/github.com/open-policy-agent/opa/v1/ast/compile.go new file mode 100644 index 000000000..5f78b0da1 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/compile.go @@ -0,0 +1,5984 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "errors" + "fmt" + "io" + "sort" + "strconv" + "strings" + + "github.com/open-policy-agent/opa/internal/debug" + "github.com/open-policy-agent/opa/internal/gojsonschema" + "github.com/open-policy-agent/opa/v1/ast/location" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/types" + "github.com/open-policy-agent/opa/v1/util" +) + +// CompileErrorLimitDefault is the default number errors a compiler will allow before +// exiting. +const CompileErrorLimitDefault = 10 + +var errLimitReached = NewError(CompileErr, nil, "error limit reached") + +// Compiler contains the state of a compilation process. +type Compiler struct { + + // Errors contains errors that occurred during the compilation process. + // If there are one or more errors, the compilation process is considered + // "failed". + Errors Errors + + // Modules contains the compiled modules. The compiled modules are the + // output of the compilation process. If the compilation process failed, + // there is no guarantee about the state of the modules. + Modules map[string]*Module + + // ModuleTree organizes the modules into a tree where each node is keyed by + // an element in the module's package path. E.g., given modules containing + // the following package directives: "a", "a.b", "a.c", and "a.b", the + // resulting module tree would be: + // + // root + // | + // +--- data (no modules) + // | + // +--- a (1 module) + // | + // +--- b (2 modules) + // | + // +--- c (1 module) + // + ModuleTree *ModuleTreeNode + + // RuleTree organizes rules into a tree where each node is keyed by an + // element in the rule's path. The rule path is the concatenation of the + // containing package and the stringified rule name. E.g., given the + // following module: + // + // package ex + // p[1] { true } + // p[2] { true } + // q = true + // a.b.c = 3 + // + // root + // | + // +--- data (no rules) + // | + // +--- ex (no rules) + // | + // +--- p (2 rules) + // | + // +--- q (1 rule) + // | + // +--- a + // | + // +--- b + // | + // +--- c (1 rule) + // + // Another example with general refs containing vars at arbitrary locations: + // + // package ex + // a.b[x].d { x := "c" } # R1 + // a.b.c[x] { x := "d" } # R2 + // a.b[x][y] { x := "c"; y := "d" } # R3 + // p := true # R4 + // + // root + // | + // +--- data (no rules) + // | + // +--- ex (no rules) + // | + // +--- a + // | | + // | +--- b (R1, R3) + // | | + // | +--- c (R2) + // | + // +--- p (R4) + RuleTree *TreeNode + + // Graph contains dependencies between rules. An edge (u,v) is added to the + // graph if rule 'u' refers to the virtual document defined by 'v'. + Graph *Graph + + // TypeEnv holds type information for values inferred by the compiler. + TypeEnv *TypeEnv + + // RewrittenVars is a mapping of variables that have been rewritten + // with the key being the generated name and value being the original. + RewrittenVars map[Var]Var + + // Capabilities required by the modules that were compiled. + Required *Capabilities + + localvargen *localVarGenerator + moduleLoader ModuleLoader + ruleIndices *util.HashMap + stages []stage + maxErrs int + sorted []string // list of sorted module names + pathExists func([]string) (bool, error) + after map[string][]CompilerStageDefinition + metrics metrics.Metrics + capabilities *Capabilities // user-supplied capabilities + imports map[string][]*Import // saved imports from stripping + builtins map[string]*Builtin // universe of built-in functions + customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities) + unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities) + deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions + enablePrintStatements bool // indicates if print statements should be elided (default) + comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index + initialized bool // indicates if init() has been called + debug debug.Debug // emits debug information produced during compilation + schemaSet *SchemaSet // user-supplied schemas for input and data documents + inputType types.Type // global input type retrieved from schema set + annotationSet *AnnotationSet // hierarchical set of annotations + strict bool // enforce strict compilation checks + keepModules bool // whether to keep the unprocessed, parse modules (below) + parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true + useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker + allowUndefinedFuncCalls bool // don't error on calls to unknown functions. + evalMode CompilerEvalMode // + rewriteTestRulesForTracing bool // rewrite test rules to capture dynamic values for tracing. + defaultRegoVersion RegoVersion +} + +func (c *Compiler) DefaultRegoVersion() RegoVersion { + return c.defaultRegoVersion +} + +// CompilerStage defines the interface for stages in the compiler. +type CompilerStage func(*Compiler) *Error + +// CompilerEvalMode allows toggling certain stages that are only +// needed for certain modes, Concretely, only "topdown" mode will +// have the compiler build comprehension and rule indices. +type CompilerEvalMode int + +const ( + // EvalModeTopdown (default) instructs the compiler to build rule + // and comprehension indices used by topdown evaluation. + EvalModeTopdown CompilerEvalMode = iota + + // EvalModeIR makes the compiler skip the stages for comprehension + // and rule indices. + EvalModeIR +) + +// CompilerStageDefinition defines a compiler stage +type CompilerStageDefinition struct { + Name string + MetricName string + Stage CompilerStage +} + +// RulesOptions defines the options for retrieving rules by Ref from the +// compiler. +type RulesOptions struct { + // IncludeHiddenModules determines if the result contains hidden modules, + // currently only the "system" namespace, i.e. "data.system.*". + IncludeHiddenModules bool +} + +// QueryContext contains contextual information for running an ad-hoc query. +// +// Ad-hoc queries can be run in the context of a package and imports may be +// included to provide concise access to data. +type QueryContext struct { + Package *Package + Imports []*Import +} + +// NewQueryContext returns a new QueryContext object. +func NewQueryContext() *QueryContext { + return &QueryContext{} +} + +// WithPackage sets the pkg on qc. +func (qc *QueryContext) WithPackage(pkg *Package) *QueryContext { + if qc == nil { + qc = NewQueryContext() + } + qc.Package = pkg + return qc +} + +// WithImports sets the imports on qc. +func (qc *QueryContext) WithImports(imports []*Import) *QueryContext { + if qc == nil { + qc = NewQueryContext() + } + qc.Imports = imports + return qc +} + +// Copy returns a deep copy of qc. +func (qc *QueryContext) Copy() *QueryContext { + if qc == nil { + return nil + } + cpy := *qc + if cpy.Package != nil { + cpy.Package = qc.Package.Copy() + } + cpy.Imports = make([]*Import, len(qc.Imports)) + for i := range qc.Imports { + cpy.Imports[i] = qc.Imports[i].Copy() + } + return &cpy +} + +// QueryCompiler defines the interface for compiling ad-hoc queries. +type QueryCompiler interface { + + // Compile should be called to compile ad-hoc queries. The return value is + // the compiled version of the query. + Compile(q Body) (Body, error) + + // TypeEnv returns the type environment built after running type checking + // on the query. + TypeEnv() *TypeEnv + + // WithContext sets the QueryContext on the QueryCompiler. Subsequent calls + // to Compile will take the QueryContext into account. + WithContext(qctx *QueryContext) QueryCompiler + + // WithEnablePrintStatements enables print statements in queries compiled + // with the QueryCompiler. + WithEnablePrintStatements(yes bool) QueryCompiler + + // WithUnsafeBuiltins sets the built-in functions to treat as unsafe and not + // allow inside of queries. By default the query compiler inherits the + // compiler's unsafe built-in functions. This function allows callers to + // override that set. If an empty (non-nil) map is provided, all built-ins + // are allowed. + WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler + + // WithStageAfter registers a stage to run during query compilation after + // the named stage. + WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler + + // RewrittenVars maps generated vars in the compiled query to vars from the + // parsed query. For example, given the query "input := 1" the rewritten + // query would be "__local0__ = 1". The mapping would then be {__local0__: input}. + RewrittenVars() map[Var]Var + + // ComprehensionIndex returns an index data structure for the given comprehension + // term. If no index is found, returns nil. + ComprehensionIndex(term *Term) *ComprehensionIndex + + // WithStrict enables strict mode for the query compiler. + WithStrict(strict bool) QueryCompiler +} + +// QueryCompilerStage defines the interface for stages in the query compiler. +type QueryCompilerStage func(QueryCompiler, Body) (Body, error) + +// QueryCompilerStageDefinition defines a QueryCompiler stage +type QueryCompilerStageDefinition struct { + Name string + MetricName string + Stage QueryCompilerStage +} + +type stage struct { + name string + metricName string + f func() +} + +// NewCompiler returns a new empty compiler. +func NewCompiler() *Compiler { + + c := &Compiler{ + Modules: map[string]*Module{}, + RewrittenVars: map[Var]Var{}, + Required: &Capabilities{}, + ruleIndices: util.NewHashMap(func(a, b util.T) bool { + r1, r2 := a.(Ref), b.(Ref) + return r1.Equal(r2) + }, func(x util.T) int { + return x.(Ref).Hash() + }), + maxErrs: CompileErrorLimitDefault, + after: map[string][]CompilerStageDefinition{}, + unsafeBuiltinsMap: map[string]struct{}{}, + deprecatedBuiltinsMap: map[string]struct{}{}, + comprehensionIndices: map[*Term]*ComprehensionIndex{}, + debug: debug.Discard(), + defaultRegoVersion: DefaultRegoVersion, + } + + c.ModuleTree = NewModuleTree(nil) + c.RuleTree = NewRuleTree(c.ModuleTree) + + c.stages = []stage{ + // Reference resolution should run first as it may be used to lazily + // load additional modules. If any stages run before resolution, they + // need to be re-run after resolution. + {"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs}, + // The local variable generator must be initialized after references are + // resolved and the dynamic module loader has run but before subsequent + // stages that need to generate variables. + {"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen}, + {"RewriteRuleHeadRefs", "compile_stage_rewrite_rule_head_refs", c.rewriteRuleHeadRefs}, + {"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides}, + {"CheckDuplicateImports", "compile_stage_check_imports", c.checkImports}, + {"RemoveImports", "compile_stage_remove_imports", c.removeImports}, + {"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree}, + {"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // depends on RewriteRuleHeadRefs + {"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars}, + {"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls}, + {"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls}, + {"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms}, + {"ParseMetadataBlocks", "compile_stage_parse_metadata_blocks", c.parseMetadataBlocks}, + {"SetAnnotationSet", "compile_stage_set_annotationset", c.setAnnotationSet}, + {"RewriteRegoMetadataCalls", "compile_stage_rewrite_rego_metadata_calls", c.rewriteRegoMetadataCalls}, + {"SetGraph", "compile_stage_set_graph", c.setGraph}, + {"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms}, + {"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead}, + {"RewriteWithValues", "compile_stage_rewrite_with_values", c.rewriteWithModifiers}, + {"CheckRuleConflicts", "compile_stage_check_rule_conflicts", c.checkRuleConflicts}, + {"CheckUndefinedFuncs", "compile_stage_check_undefined_funcs", c.checkUndefinedFuncs}, + {"CheckSafetyRuleHeads", "compile_stage_check_safety_rule_heads", c.checkSafetyRuleHeads}, + {"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies}, + {"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals}, + {"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms}, + {"RewriteTestRulesForTracing", "compile_stage_rewrite_test_rules_for_tracing", c.rewriteTestRuleEqualities}, // must run after RewriteDynamicTerms + {"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion}, + {"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion + {"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins}, + {"CheckDeprecatedBuiltins", "compile_state_check_deprecated_builtins", c.checkDeprecatedBuiltins}, + {"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices}, + {"BuildComprehensionIndices", "compile_stage_rebuild_comprehension_indices", c.buildComprehensionIndices}, + {"BuildRequiredCapabilities", "compile_stage_build_required_capabilities", c.buildRequiredCapabilities}, + } + + return c +} + +// SetErrorLimit sets the number of errors the compiler can encounter before it +// quits. Zero or a negative number indicates no limit. +func (c *Compiler) SetErrorLimit(limit int) *Compiler { + c.maxErrs = limit + return c +} + +// WithEnablePrintStatements enables print statements inside of modules compiled +// by the compiler. If print statements are not enabled, calls to print() are +// erased at compile-time. +func (c *Compiler) WithEnablePrintStatements(yes bool) *Compiler { + c.enablePrintStatements = yes + return c +} + +// WithPathConflictsCheck enables base-virtual document conflict +// detection. The compiler will check that rules don't overlap with +// paths that exist as determined by the provided callable. +func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Compiler { + c.pathExists = fn + return c +} + +// WithStageAfter registers a stage to run during compilation after +// the named stage. +func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler { + c.after[after] = append(c.after[after], stage) + return c +} + +// WithMetrics will set a metrics.Metrics and be used for profiling +// the Compiler instance. +func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler { + c.metrics = metrics + return c +} + +// WithCapabilities sets capabilities to enable during compilation. Capabilities allow the caller +// to specify the set of built-in functions available to the policy. In the future, capabilities +// may be able to restrict access to other language features. Capabilities allow callers to check +// if policies are compatible with a particular version of OPA. If policies are a compiled for a +// specific version of OPA, there is no guarantee that _this_ version of OPA can evaluate them +// successfully. +func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler { + c.capabilities = capabilities + return c +} + +// Capabilities returns the capabilities enabled during compilation. +func (c *Compiler) Capabilities() *Capabilities { + return c.capabilities +} + +// WithDebug sets where debug messages are written to. Passing `nil` has no +// effect. +func (c *Compiler) WithDebug(sink io.Writer) *Compiler { + if sink != nil { + c.debug = debug.New(sink) + } + return c +} + +// WithBuiltins is deprecated. Use WithCapabilities instead. +func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler { + c.customBuiltins = make(map[string]*Builtin) + for k, v := range builtins { + c.customBuiltins[k] = v + } + return c +} + +// WithUnsafeBuiltins is deprecated. Use WithCapabilities instead. +func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler { + for name := range unsafeBuiltins { + c.unsafeBuiltinsMap[name] = struct{}{} + } + return c +} + +// WithStrict enables strict mode in the compiler. +func (c *Compiler) WithStrict(strict bool) *Compiler { + c.strict = strict + return c +} + +// WithKeepModules enables retaining unprocessed modules in the compiler. +// Note that the modules aren't copied on the way in or out -- so when +// accessing them via ParsedModules(), mutations will occur in the module +// map that was passed into Compile().` +func (c *Compiler) WithKeepModules(y bool) *Compiler { + c.keepModules = y + return c +} + +// WithUseTypeCheckAnnotations use schema annotations during type checking +func (c *Compiler) WithUseTypeCheckAnnotations(enabled bool) *Compiler { + c.useTypeCheckAnnotations = enabled + return c +} + +func (c *Compiler) WithAllowUndefinedFunctionCalls(allow bool) *Compiler { + c.allowUndefinedFuncCalls = allow + return c +} + +// WithEvalMode allows setting the CompilerEvalMode of the compiler +func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler { + c.evalMode = e + return c +} + +// WithRewriteTestRules enables rewriting test rules to capture dynamic values in local variables, +// so they can be accessed by tracing. +func (c *Compiler) WithRewriteTestRules(rewrite bool) *Compiler { + c.rewriteTestRulesForTracing = rewrite + return c +} + +// ParsedModules returns the parsed, unprocessed modules from the compiler. +// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`. +// The map includes all modules loaded via the ModuleLoader, if one was used. +func (c *Compiler) ParsedModules() map[string]*Module { + return c.parsedModules +} + +func (c *Compiler) QueryCompiler() QueryCompiler { + c.init() + c0 := *c + return newQueryCompiler(&c0) +} + +// Compile runs the compilation process on the input modules. The compiled +// version of the modules and associated data structures are stored on the +// compiler. If the compilation process fails for any reason, the compiler will +// contain a slice of errors. +func (c *Compiler) Compile(modules map[string]*Module) { + + c.init() + + c.Modules = make(map[string]*Module, len(modules)) + c.sorted = make([]string, 0, len(modules)) + + if c.keepModules { + c.parsedModules = make(map[string]*Module, len(modules)) + } else { + c.parsedModules = nil + } + + for k, v := range modules { + c.Modules[k] = v.Copy() + c.sorted = append(c.sorted, k) + if c.parsedModules != nil { + c.parsedModules[k] = v + } + } + + sort.Strings(c.sorted) + + c.compile() +} + +// WithSchemas sets a schemaSet to the compiler +func (c *Compiler) WithSchemas(schemas *SchemaSet) *Compiler { + c.schemaSet = schemas + return c +} + +// Failed returns true if a compilation error has been encountered. +func (c *Compiler) Failed() bool { + return len(c.Errors) > 0 +} + +// ComprehensionIndex returns a data structure specifying how to index comprehension +// results so that callers do not have to recompute the comprehension more than once. +// If no index is found, returns nil. +func (c *Compiler) ComprehensionIndex(term *Term) *ComprehensionIndex { + return c.comprehensionIndices[term] +} + +// GetArity returns the number of args a function referred to by ref takes. If +// ref refers to built-in function, the built-in declaration is consulted, +// otherwise, the ref is used to perform a ruleset lookup. +func (c *Compiler) GetArity(ref Ref) int { + if bi := c.builtins[ref.String()]; bi != nil { + return len(bi.Decl.FuncArgs().Args) + } + rules := c.GetRulesExact(ref) + if len(rules) == 0 { + return -1 + } + return len(rules[0].Head.Args) +} + +// GetRulesExact returns a slice of rules referred to by the reference. +// +// E.g., given the following module: +// +// package a.b.c +// +// p[k] = v { ... } # rule1 +// p[k1] = v1 { ... } # rule2 +// +// The following calls yield the rules on the right. +// +// GetRulesExact("data.a.b.c.p") => [rule1, rule2] +// GetRulesExact("data.a.b.c.p.x") => nil +// GetRulesExact("data.a.b.c") => nil +func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) { + node := c.RuleTree + + for _, x := range ref { + if node = node.Child(x.Value); node == nil { + return nil + } + } + + return extractRules(node.Values) +} + +// GetRulesForVirtualDocument returns a slice of rules that produce the virtual +// document referred to by the reference. +// +// E.g., given the following module: +// +// package a.b.c +// +// p[k] = v { ... } # rule1 +// p[k1] = v1 { ... } # rule2 +// +// The following calls yield the rules on the right. +// +// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2] +// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2] +// GetRulesForVirtualDocument("data.a.b.c") => nil +func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) { + + node := c.RuleTree + + for _, x := range ref { + if node = node.Child(x.Value); node == nil { + return nil + } + if len(node.Values) > 0 { + return extractRules(node.Values) + } + } + + return extractRules(node.Values) +} + +// GetRulesWithPrefix returns a slice of rules that share the prefix ref. +// +// E.g., given the following module: +// +// package a.b.c +// +// p[x] = y { ... } # rule1 +// p[k] = v { ... } # rule2 +// q { ... } # rule3 +// +// The following calls yield the rules on the right. +// +// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2] +// GetRulesWithPrefix("data.a.b.c.p.a") => nil +// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3] +func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { + + node := c.RuleTree + + for _, x := range ref { + if node = node.Child(x.Value); node == nil { + return nil + } + } + + var acc func(node *TreeNode) + + acc = func(node *TreeNode) { + rules = append(rules, extractRules(node.Values)...) + for _, child := range node.Children { + if child.Hide { + continue + } + acc(child) + } + } + + acc(node) + + return rules +} + +func extractRules(s []util.T) []*Rule { + rules := make([]*Rule, len(s)) + for i := range s { + rules[i] = s[i].(*Rule) + } + return rules +} + +// GetRules returns a slice of rules that are referred to by ref. +// +// E.g., given the following module: +// +// package a.b.c +// +// p[x] = y { q[x] = y; ... } # rule1 +// q[x] = y { ... } # rule2 +// +// The following calls yield the rules on the right. +// +// GetRules("data.a.b.c.p") => [rule1] +// GetRules("data.a.b.c.p.x") => [rule1] +// GetRules("data.a.b.c.q") => [rule2] +// GetRules("data.a.b.c") => [rule1, rule2] +// GetRules("data.a.b.d") => nil +func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { + + set := map[*Rule]struct{}{} + + for _, rule := range c.GetRulesForVirtualDocument(ref) { + set[rule] = struct{}{} + } + + for _, rule := range c.GetRulesWithPrefix(ref) { + set[rule] = struct{}{} + } + + for rule := range set { + rules = append(rules, rule) + } + + return rules +} + +// GetRulesDynamic returns a slice of rules that could be referred to by a ref. +// +// Deprecated: use GetRulesDynamicWithOpts +func (c *Compiler) GetRulesDynamic(ref Ref) []*Rule { + return c.GetRulesDynamicWithOpts(ref, RulesOptions{}) +} + +// GetRulesDynamicWithOpts returns a slice of rules that could be referred to by +// a ref. +// When parts of the ref are statically known, we use that information to narrow +// down which rules the ref could refer to, but in the most general case this +// will be an over-approximation. +// +// E.g., given the following modules: +// +// package a.b.c +// +// r1 = 1 # rule1 +// +// and: +// +// package a.d.c +// +// r2 = 2 # rule2 +// +// The following calls yield the rules on the right. +// +// GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2] +// GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2] +// GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1] +// +// Using the RulesOptions parameter, the inclusion of hidden modules can be +// controlled: +// +// With +// +// package system.main +// +// r3 = 3 # rule3 +// +// We'd get this result: +// +// GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3] +// +// Without the options, it would be excluded. +func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule { + node := c.RuleTree + + set := map[*Rule]struct{}{} + var walk func(node *TreeNode, i int) + walk = func(node *TreeNode, i int) { + switch { + case i >= len(ref): + // We've reached the end of the reference and want to collect everything + // under this "prefix". + node.DepthFirst(func(descendant *TreeNode) bool { + insertRules(set, descendant.Values) + if opts.IncludeHiddenModules { + return false + } + return descendant.Hide + }) + + case i == 0 || IsConstant(ref[i].Value): + // The head of the ref is always grounded. In case another part of the + // ref is also grounded, we can lookup the exact child. If it's not found + // we can immediately return... + if child := node.Child(ref[i].Value); child != nil { + if len(child.Values) > 0 { + // Add any rules at this position + insertRules(set, child.Values) + } + // There might still be "sub-rules" contributing key-value "overrides" for e.g. partial object rules, continue walking + walk(child, i+1) + } else { + return + } + + default: + // This part of the ref is a dynamic term. We can't know what it refers + // to and will just need to try all of the children. + for _, child := range node.Children { + if child.Hide && !opts.IncludeHiddenModules { + continue + } + insertRules(set, child.Values) + walk(child, i+1) + } + } + } + + walk(node, 0) + rules := make([]*Rule, 0, len(set)) + for rule := range set { + rules = append(rules, rule) + } + return rules +} + +// Utility: add all rule values to the set. +func insertRules(set map[*Rule]struct{}, rules []util.T) { + for _, rule := range rules { + set[rule.(*Rule)] = struct{}{} + } +} + +// RuleIndex returns a RuleIndex built for the rule set referred to by path. +// The path must refer to the rule set exactly, i.e., given a rule set at path +// data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a +// RuleIndex built for the rule. +func (c *Compiler) RuleIndex(path Ref) RuleIndex { + r, ok := c.ruleIndices.Get(path) + if !ok { + return nil + } + return r.(RuleIndex) +} + +// PassesTypeCheck determines whether the given body passes type checking +func (c *Compiler) PassesTypeCheck(body Body) bool { + checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType) + env := c.TypeEnv + _, errs := checker.CheckBody(env, body) + return len(errs) == 0 +} + +// PassesTypeCheckRules determines whether the given rules passes type checking +func (c *Compiler) PassesTypeCheckRules(rules []*Rule) Errors { + elems := []util.T{} + + for _, rule := range rules { + elems = append(elems, rule) + } + + // Load the global input schema if one was provided. + if c.schemaSet != nil { + if schema := c.schemaSet.Get(SchemaRootRef); schema != nil { + + var allowNet []string + if c.capabilities != nil { + allowNet = c.capabilities.AllowNet + } + + tpe, err := loadSchema(schema, allowNet) + if err != nil { + return Errors{NewError(TypeErr, nil, err.Error())} //nolint:govet + } + c.inputType = tpe + } + } + + var as *AnnotationSet + if c.useTypeCheckAnnotations { + as = c.annotationSet + } + + checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType) + + if c.TypeEnv == nil { + if c.capabilities == nil { + c.capabilities = CapabilitiesForThisVersion() + } + + c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins)) + + for _, bi := range c.capabilities.Builtins { + c.builtins[bi.Name] = bi + } + + for name, bi := range c.customBuiltins { + c.builtins[name] = bi + } + + c.TypeEnv = checker.Env(c.builtins) + } + + _, errs := checker.CheckTypes(c.TypeEnv, elems, as) + return errs +} + +// ModuleLoader defines the interface that callers can implement to enable lazy +// loading of modules during compilation. +type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error) + +// WithModuleLoader sets f as the ModuleLoader on the compiler. +// +// The compiler will invoke the ModuleLoader after resolving all references in +// the current set of input modules. The ModuleLoader can return a new +// collection of parsed modules that are to be included in the compilation +// process. This process will repeat until the ModuleLoader returns an empty +// collection or an error. If an error is returned, compilation will stop +// immediately. +func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler { + c.moduleLoader = f + return c +} + +// WithDefaultRegoVersion sets the default Rego version to use when a module doesn't specify one; +// such as when it's hand-crafted instead of parsed. +func (c *Compiler) WithDefaultRegoVersion(regoVersion RegoVersion) *Compiler { + c.defaultRegoVersion = regoVersion + return c +} + +func (c *Compiler) counterAdd(name string, n uint64) { + if c.metrics == nil { + return + } + c.metrics.Counter(name).Add(n) +} + +func (c *Compiler) buildRuleIndices() { + + c.RuleTree.DepthFirst(func(node *TreeNode) bool { + if len(node.Values) == 0 { + return false + } + rules := extractRules(node.Values) + hasNonGroundRef := false + for _, r := range rules { + hasNonGroundRef = !r.Head.Ref().IsGround() + } + if hasNonGroundRef { + // Collect children to ensure that all rules within the extent of a rule with a general ref + // are found on the same index. E.g. the following rules should be indexed under data.a.b.c: + // + // package a + // b.c[x].e := 1 { x := input.x } + // b.c.d := 2 + // b.c.d2.e[x] := 3 { x := input.x } + for _, child := range node.Children { + child.DepthFirst(func(c *TreeNode) bool { + rules = append(rules, extractRules(c.Values)...) + return false + }) + } + } + + index := newBaseDocEqIndex(func(ref Ref) bool { + return isVirtual(c.RuleTree, ref.GroundPrefix()) + }) + if index.Build(rules) { + c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index) + } + return hasNonGroundRef // currently, we don't allow those branches to go deeper + }) + +} + +func (c *Compiler) buildComprehensionIndices() { + for _, name := range c.sorted { + WalkRules(c.Modules[name], func(r *Rule) bool { + candidates := r.Head.Args.Vars() + candidates.Update(ReservedVars) + n := buildComprehensionIndices(c.debug, c.GetArity, candidates, c.RewrittenVars, r.Body, c.comprehensionIndices) + c.counterAdd(compileStageComprehensionIndexBuild, n) + return false + }) + } +} + +// buildRequiredCapabilities updates the required capabilities on the compiler +// to include any keyword and feature dependencies present in the modules. The +// built-in function dependencies will have already been added by the type +// checker. +func (c *Compiler) buildRequiredCapabilities() { + + features := map[string]struct{}{} + + // extract required keywords from modules + + keywords := map[string]struct{}{} + futureKeywordsPrefix := Ref{FutureRootDocument, StringTerm("keywords")} + for _, name := range c.sorted { + for _, imp := range c.imports[name] { + mod := c.Modules[name] + path := imp.Path.Value.(Ref) + switch { + case path.Equal(RegoV1CompatibleRef): + if !c.moduleIsRegoV1(mod) { + features[FeatureRegoV1Import] = struct{}{} + } + case path.HasPrefix(futureKeywordsPrefix): + if len(path) == 2 { + if c.moduleIsRegoV1(mod) { + for kw := range futureKeywords { + keywords[kw] = struct{}{} + } + } else { + for kw := range allFutureKeywords { + keywords[kw] = struct{}{} + } + } + } else { + kw := string(path[2].Value.(String)) + if c.moduleIsRegoV1(mod) { + for allowedKw := range futureKeywords { + if kw == allowedKw { + keywords[kw] = struct{}{} + break + } + } + } else { + for allowedKw := range allFutureKeywords { + if kw == allowedKw { + keywords[kw] = struct{}{} + break + } + } + } + } + } + } + } + + c.Required.FutureKeywords = stringMapToSortedSlice(keywords) + + // extract required features from modules + + for _, name := range c.sorted { + mod := c.Modules[name] + + if c.moduleIsRegoV1(mod) { + features[FeatureRegoV1] = struct{}{} + } else { + for _, rule := range mod.Rules { + refLen := len(rule.Head.Reference) + if refLen >= 3 { + if refLen > len(rule.Head.Reference.ConstantPrefix()) { + features[FeatureRefHeads] = struct{}{} + } else { + features[FeatureRefHeadStringPrefixes] = struct{}{} + } + } + } + } + } + + c.Required.Features = stringMapToSortedSlice(features) + + for i, bi := range c.Required.Builtins { + c.Required.Builtins[i] = bi.Minimal() + } +} + +func stringMapToSortedSlice(xs map[string]struct{}) []string { + if len(xs) == 0 { + return nil + } + s := make([]string, 0, len(xs)) + for k := range xs { + s = append(s, k) + } + sort.Strings(s) + return s +} + +// checkRecursion ensures that there are no recursive definitions, i.e., there are +// no cycles in the Graph. +func (c *Compiler) checkRecursion() { + eq := func(a, b util.T) bool { + return a.(*Rule) == b.(*Rule) + } + + c.RuleTree.DepthFirst(func(node *TreeNode) bool { + for _, rule := range node.Values { + for node := rule.(*Rule); node != nil; node = node.Else { + c.checkSelfPath(node.Loc(), eq, node, node) + } + } + return false + }) +} + +func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) { + tr := NewGraphTraversal(c.Graph) + if p := util.DFSPath(tr, eq, a, b); len(p) > 0 { + n := make([]string, 0, len(p)) + for _, x := range p { + n = append(n, astNodeToString(x)) + } + c.err(NewError(RecursionErr, loc, "rule %v is recursive: %v", astNodeToString(a), strings.Join(n, " -> "))) + } +} + +func astNodeToString(x interface{}) string { + return x.(*Rule).Ref().String() +} + +// checkRuleConflicts ensures that rules definitions are not in conflict. +func (c *Compiler) checkRuleConflicts() { + rw := rewriteVarsInRef(c.RewrittenVars) + + c.RuleTree.DepthFirst(func(node *TreeNode) bool { + if len(node.Values) == 0 { + return false // go deeper + } + + kinds := make(map[RuleKind]struct{}, len(node.Values)) + completeRules := 0 + partialRules := 0 + arities := make(map[int]struct{}, len(node.Values)) + name := "" + var conflicts []Ref + defaultRules := make([]*Rule, 0) + + for _, rule := range node.Values { + r := rule.(*Rule) + ref := r.Ref() + name = rw(ref.Copy()).String() // varRewriter operates in-place + kinds[r.Head.RuleKind()] = struct{}{} + arities[len(r.Head.Args)] = struct{}{} + if r.Default { + defaultRules = append(defaultRules, r) + } + + // Single-value rules may not have any other rules in their extent. + // Rules with vars in their ref are allowed to have rules inside their extent. + // Only the ground portion (terms before the first var term) of a rule's ref is considered when determining + // whether it's inside the extent of another (c.RuleTree is organized this way already). + // These pairs are invalid: + // + // data.p.q.r { true } # data.p.q is { "r": true } + // data.p.q.r.s { true } + // + // data.p.q.r { true } + // data.p.q.r[s].t { s = input.key } + // + // But this is allowed: + // + // data.p.q.r { true } + // data.p.q[r].s.t { r = input.key } + // + // data.p[r] := x { r = input.key; x = input.bar } + // data.p.q[r] := x { r = input.key; x = input.bar } + // + // data.p.q[r] { r := input.r } + // data.p.q.r.s { true } + // + // data.p.q[r] = 1 { r := "r" } + // data.p.q.s = 2 + // + // data.p[q][r] { q := input.q; r := input.r } + // data.p.q.r { true } + // + // data.p.q[r] { r := input.r } + // data.p[q].r { q := input.q } + // + // data.p.q[r][s] { r := input.r; s := input.s } + // data.p[q].r.s { q := input.q } + + if r.Ref().IsGround() && len(node.Children) > 0 { + conflicts = node.flattenChildren() + } + + if r.Head.RuleKind() == SingleValue && r.Head.Ref().IsGround() { + completeRules++ + } else { + partialRules++ + } + } + + switch { + case conflicts != nil: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule %v conflicts with %v", name, conflicts)) + + case len(kinds) > 1 || len(arities) > 1 || (completeRules >= 1 && partialRules >= 1): + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name)) + + case len(defaultRules) > 1: + + defaultRuleLocations := strings.Builder{} + defaultRuleLocations.WriteString(defaultRules[0].Loc().String()) + for i := 1; i < len(defaultRules); i++ { + defaultRuleLocations.WriteString(", ") + defaultRuleLocations.WriteString(defaultRules[i].Loc().String()) + } + + c.err(NewError( + TypeErr, + defaultRules[0].Module.Package.Loc(), + "multiple default rules %s found at %s", + name, defaultRuleLocations.String()), + ) + } + + return false + }) + + if c.pathExists != nil { + for _, err := range CheckPathConflicts(c, c.pathExists) { + c.err(err) + } + } + + // NOTE(sr): depthfirst might better use sorted for stable errs? + c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool { + for _, mod := range node.Modules { + for _, rule := range mod.Rules { + ref := rule.Head.Ref().GroundPrefix() + // Rules with a dynamic portion in their ref are exempted, as a conflict within the dynamic portion + // can only be detected at eval-time. + if len(ref) < len(rule.Head.Ref()) { + continue + } + + childNode, tail := node.find(ref) + if childNode != nil && len(tail) == 0 { + for _, childMod := range childNode.Modules { + // Avoid recursively checking a module for equality unless we know it's a possible self-match. + if childMod.Equal(mod) { + continue // don't self-conflict + } + msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc()) + c.err(NewError(TypeErr, mod.Package.Loc(), msg)) //nolint:govet + } + } + } + } + return false + }) +} + +func (c *Compiler) checkUndefinedFuncs() { + for _, name := range c.sorted { + m := c.Modules[name] + for _, err := range checkUndefinedFuncs(c.TypeEnv, m, c.GetArity, c.RewrittenVars) { + c.err(err) + } + } +} + +func checkUndefinedFuncs(env *TypeEnv, x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors { + + var errs Errors + + WalkExprs(x, func(expr *Expr) bool { + if !expr.IsCall() { + return false + } + ref := expr.Operator() + if arity := arity(ref); arity >= 0 { + operands := len(expr.Operands()) + if expr.Generated { // an output var was added + if !expr.IsEquality() && operands != arity+1 { + ref = rewriteVarsInRef(rwVars)(ref) + errs = append(errs, arityMismatchError(env, ref, expr, arity, operands-1)) + return true + } + } else { // either output var or not + if operands != arity && operands != arity+1 { + ref = rewriteVarsInRef(rwVars)(ref) + errs = append(errs, arityMismatchError(env, ref, expr, arity, operands)) + return true + } + } + return false + } + ref = rewriteVarsInRef(rwVars)(ref) + errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref)) + return true + }) + + return errs +} + +func arityMismatchError(env *TypeEnv, f Ref, expr *Expr, exp, act int) *Error { + if want, ok := env.Get(f).(*types.Function); ok { // generate richer error for built-in functions + have := make([]types.Type, len(expr.Operands())) + for i, op := range expr.Operands() { + have[i] = env.Get(op) + } + return newArgError(expr.Loc(), f, "arity mismatch", have, want.NamedFuncArgs()) + } + if act != 1 { + return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d arguments", f, exp, act) + } + return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d argument", f, exp, act) +} + +// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target +// positions of built-in expressions will be bound when evaluating the rule from left +// to right, re-ordering as necessary. +func (c *Compiler) checkSafetyRuleBodies() { + for _, name := range c.sorted { + m := c.Modules[name] + WalkRules(m, func(r *Rule) bool { + safe := ReservedVars.Copy() + safe.Update(r.Head.Args.Vars()) + r.Body = c.checkBodySafety(safe, r.Body) + return false + }) + } +} + +func (c *Compiler) checkBodySafety(safe VarSet, b Body) Body { + reordered, unsafe := reorderBodyForSafety(c.builtins, c.GetArity, safe, b) + if errs := safetyErrorSlice(unsafe, c.RewrittenVars); len(errs) > 0 { + for _, err := range errs { + c.err(err) + } + return b + } + return reordered +} + +// SafetyCheckVisitorParams defines the AST visitor parameters to use for collecting +// variables during the safety check. This has to be exported because it's relied on +// by the copy propagation implementation in topdown. +var SafetyCheckVisitorParams = VarVisitorParams{ + SkipRefCallHead: true, + SkipClosures: true, +} + +// checkSafetyRuleHeads ensures that variables appearing in the head of a +// rule also appear in the body. +func (c *Compiler) checkSafetyRuleHeads() { + + for _, name := range c.sorted { + m := c.Modules[name] + WalkRules(m, func(r *Rule) bool { + safe := r.Body.Vars(SafetyCheckVisitorParams) + safe.Update(r.Head.Args.Vars()) + unsafe := r.Head.Vars().Diff(safe) + for v := range unsafe { + if w, ok := c.RewrittenVars[v]; ok { + v = w + } + if !v.IsGenerated() { + c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v)) + } + } + return false + }) + } +} + +func compileSchema(goSchema interface{}, allowNet []string) (*gojsonschema.Schema, error) { + gojsonschema.SetAllowNet(allowNet) + + var refLoader gojsonschema.JSONLoader + sl := gojsonschema.NewSchemaLoader() + + if goSchema != nil { + refLoader = gojsonschema.NewGoLoader(goSchema) + } else { + return nil, fmt.Errorf("no schema as input to compile") + } + schemasCompiled, err := sl.Compile(refLoader) + if err != nil { + return nil, fmt.Errorf("unable to compile the schema: %w", err) + } + return schemasCompiled, nil +} + +func mergeSchemas(schemas ...*gojsonschema.SubSchema) (*gojsonschema.SubSchema, error) { + if len(schemas) == 0 { + return nil, nil + } + var result = schemas[0] + + for i := range schemas { + if len(schemas[i].PropertiesChildren) > 0 { + if !schemas[i].Types.Contains("object") { + if err := schemas[i].Types.Add("object"); err != nil { + return nil, fmt.Errorf("unable to set the type in schemas") + } + } + } else if len(schemas[i].ItemsChildren) > 0 { + if !schemas[i].Types.Contains("array") { + if err := schemas[i].Types.Add("array"); err != nil { + return nil, fmt.Errorf("unable to set the type in schemas") + } + } + } + } + + for i := 1; i < len(schemas); i++ { + if result.Types.String() != schemas[i].Types.String() { + return nil, fmt.Errorf("unable to merge these schemas: type mismatch: %v and %v", result.Types.String(), schemas[i].Types.String()) + } else if result.Types.Contains("object") && len(result.PropertiesChildren) > 0 && schemas[i].Types.Contains("object") && len(schemas[i].PropertiesChildren) > 0 { + result.PropertiesChildren = append(result.PropertiesChildren, schemas[i].PropertiesChildren...) + } else if result.Types.Contains("array") && len(result.ItemsChildren) > 0 && schemas[i].Types.Contains("array") && len(schemas[i].ItemsChildren) > 0 { + for j := 0; j < len(schemas[i].ItemsChildren); j++ { + if len(result.ItemsChildren)-1 < j && !(len(schemas[i].ItemsChildren)-1 < j) { + result.ItemsChildren = append(result.ItemsChildren, schemas[i].ItemsChildren[j]) + } + if result.ItemsChildren[j].Types.String() != schemas[i].ItemsChildren[j].Types.String() { + return nil, fmt.Errorf("unable to merge these schemas") + } + } + } + } + return result, nil +} + +type schemaParser struct { + definitionCache map[string]*cachedDef +} + +type cachedDef struct { + properties []*types.StaticProperty +} + +func newSchemaParser() *schemaParser { + return &schemaParser{ + definitionCache: map[string]*cachedDef{}, + } +} + +func (parser *schemaParser) parseSchema(schema interface{}) (types.Type, error) { + return parser.parseSchemaWithPropertyKey(schema, "") +} + +func (parser *schemaParser) parseSchemaWithPropertyKey(schema interface{}, propertyKey string) (types.Type, error) { + subSchema, ok := schema.(*gojsonschema.SubSchema) + if !ok { + return nil, fmt.Errorf("unexpected schema type %v", subSchema) + } + + // Handle referenced schemas, returns directly when a $ref is found + if subSchema.RefSchema != nil { + if existing, ok := parser.definitionCache[subSchema.Ref.String()]; ok { + return types.NewObject(existing.properties, nil), nil + } + return parser.parseSchemaWithPropertyKey(subSchema.RefSchema, subSchema.Ref.String()) + } + + // Handle anyOf + if subSchema.AnyOf != nil { + var orType types.Type + + // If there is a core schema, find its type first + if subSchema.Types.IsTyped() { + copySchema := *subSchema + copySchemaRef := ©Schema + copySchemaRef.AnyOf = nil + coreType, err := parser.parseSchema(copySchemaRef) + if err != nil { + return nil, fmt.Errorf("unexpected schema type %v: %w", subSchema, err) + } + + // Only add Object type with static props to orType + if objType, ok := coreType.(*types.Object); ok { + if objType.StaticProperties() != nil && objType.DynamicProperties() == nil { + orType = types.Or(orType, coreType) + } + } + } + + // Iterate through every property of AnyOf and add it to orType + for _, pSchema := range subSchema.AnyOf { + newtype, err := parser.parseSchema(pSchema) + if err != nil { + return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err) + } + orType = types.Or(newtype, orType) + } + + return orType, nil + } + + if subSchema.AllOf != nil { + subSchemaArray := subSchema.AllOf + allOfResult, err := mergeSchemas(subSchemaArray...) + if err != nil { + return nil, err + } + + if subSchema.Types.IsTyped() { + if (subSchema.Types.Contains("object") && allOfResult.Types.Contains("object")) || (subSchema.Types.Contains("array") && allOfResult.Types.Contains("array")) { + objectOrArrayResult, err := mergeSchemas(allOfResult, subSchema) + if err != nil { + return nil, err + } + return parser.parseSchema(objectOrArrayResult) + } else if subSchema.Types.String() != allOfResult.Types.String() { + return nil, fmt.Errorf("unable to merge these schemas") + } + } + return parser.parseSchema(allOfResult) + } + + if subSchema.Types.IsTyped() { + if subSchema.Types.Contains("boolean") { + return types.B, nil + + } else if subSchema.Types.Contains("string") { + return types.S, nil + + } else if subSchema.Types.Contains("integer") || subSchema.Types.Contains("number") { + return types.N, nil + + } else if subSchema.Types.Contains("object") { + if len(subSchema.PropertiesChildren) > 0 { + def := &cachedDef{ + properties: make([]*types.StaticProperty, 0, len(subSchema.PropertiesChildren)), + } + for _, pSchema := range subSchema.PropertiesChildren { + def.properties = append(def.properties, types.NewStaticProperty(pSchema.Property, nil)) + } + if propertyKey != "" { + parser.definitionCache[propertyKey] = def + } + for _, pSchema := range subSchema.PropertiesChildren { + newtype, err := parser.parseSchema(pSchema) + if err != nil { + return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err) + } + for i, prop := range def.properties { + if prop.Key == pSchema.Property { + def.properties[i].Value = newtype + break + } + } + } + return types.NewObject(def.properties, nil), nil + } + return types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), nil + + } else if subSchema.Types.Contains("array") { + if len(subSchema.ItemsChildren) > 0 { + if subSchema.ItemsChildrenIsSingleSchema { + iSchema := subSchema.ItemsChildren[0] + newtype, err := parser.parseSchema(iSchema) + if err != nil { + return nil, fmt.Errorf("unexpected schema type %v", iSchema) + } + return types.NewArray(nil, newtype), nil + } + newTypes := make([]types.Type, 0, len(subSchema.ItemsChildren)) + for i := 0; i != len(subSchema.ItemsChildren); i++ { + iSchema := subSchema.ItemsChildren[i] + newtype, err := parser.parseSchema(iSchema) + if err != nil { + return nil, fmt.Errorf("unexpected schema type %v", iSchema) + } + newTypes = append(newTypes, newtype) + } + return types.NewArray(newTypes, nil), nil + } + return types.NewArray(nil, types.A), nil + } + } + + // Assume types if not specified in schema + if len(subSchema.PropertiesChildren) > 0 { + if err := subSchema.Types.Add("object"); err == nil { + return parser.parseSchema(subSchema) + } + } else if len(subSchema.ItemsChildren) > 0 { + if err := subSchema.Types.Add("array"); err == nil { + return parser.parseSchema(subSchema) + } + } + + return types.A, nil +} + +func (c *Compiler) setAnnotationSet() { + // Sorting modules by name for stable error reporting + sorted := make([]*Module, 0, len(c.Modules)) + for _, mName := range c.sorted { + sorted = append(sorted, c.Modules[mName]) + } + + as, errs := BuildAnnotationSet(sorted) + for _, err := range errs { + c.err(err) + } + c.annotationSet = as +} + +// checkTypes runs the type checker on all rules. The type checker builds a +// TypeEnv that is stored on the compiler. +func (c *Compiler) checkTypes() { + // Recursion is caught in earlier step, so this cannot fail. + sorted, _ := c.Graph.Sort() + checker := newTypeChecker(). + WithAllowNet(c.capabilities.AllowNet). + WithSchemaSet(c.schemaSet). + WithInputType(c.inputType). + WithBuiltins(c.builtins). + WithRequiredCapabilities(c.Required). + WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)). + WithAllowUndefinedFunctionCalls(c.allowUndefinedFuncCalls) + var as *AnnotationSet + if c.useTypeCheckAnnotations { + as = c.annotationSet + } + env, errs := checker.CheckTypes(c.TypeEnv, sorted, as) + for _, err := range errs { + c.err(err) + } + c.TypeEnv = env +} + +func (c *Compiler) checkUnsafeBuiltins() { + for _, name := range c.sorted { + errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name]) + for _, err := range errs { + c.err(err) + } + } +} + +func (c *Compiler) checkDeprecatedBuiltins() { + for _, name := range c.sorted { + mod := c.Modules[name] + if c.strict || mod.regoV1Compatible() { + errs := checkDeprecatedBuiltins(c.deprecatedBuiltinsMap, mod) + for _, err := range errs { + c.err(err) + } + } + } +} + +func (c *Compiler) runStage(metricName string, f func()) { + if c.metrics != nil { + c.metrics.Timer(metricName).Start() + defer c.metrics.Timer(metricName).Stop() + } + f() +} + +func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error { + if c.metrics != nil { + c.metrics.Timer(metricName).Start() + defer c.metrics.Timer(metricName).Stop() + } + return s(c) +} + +func (c *Compiler) compile() { + + defer func() { + if r := recover(); r != nil && r != errLimitReached { + panic(r) + } + }() + + for _, s := range c.stages { + if c.evalMode == EvalModeIR { + switch s.name { + case "BuildRuleIndices", "BuildComprehensionIndices": + continue // skip these stages + } + } + + if c.allowUndefinedFuncCalls && (s.name == "CheckUndefinedFuncs" || s.name == "CheckSafetyRuleBodies") { + continue + } + + c.runStage(s.metricName, s.f) + if c.Failed() { + return + } + for _, a := range c.after[s.name] { + if err := c.runStageAfter(a.MetricName, a.Stage); err != nil { + c.err(err) + return + } + } + } +} + +func (c *Compiler) init() { + + if c.initialized { + return + } + + if c.capabilities == nil { + c.capabilities = CapabilitiesForThisVersion() + } + + c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins)) + + for _, bi := range c.capabilities.Builtins { + c.builtins[bi.Name] = bi + if bi.IsDeprecated() { + c.deprecatedBuiltinsMap[bi.Name] = struct{}{} + } + } + + for name, bi := range c.customBuiltins { + c.builtins[name] = bi + } + + // Load the global input schema if one was provided. + if c.schemaSet != nil { + if schema := c.schemaSet.Get(SchemaRootRef); schema != nil { + tpe, err := loadSchema(schema, c.capabilities.AllowNet) + if err != nil { + c.err(NewError(TypeErr, nil, err.Error())) //nolint:govet + } else { + c.inputType = tpe + } + } + } + + c.TypeEnv = newTypeChecker(). + WithSchemaSet(c.schemaSet). + WithInputType(c.inputType). + Env(c.builtins) + + c.initialized = true +} + +func (c *Compiler) err(err *Error) { + if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs { + c.Errors = append(c.Errors, errLimitReached) + panic(errLimitReached) + } + c.Errors = append(c.Errors, err) +} + +func (c *Compiler) getExports() *util.HashMap { + + rules := util.NewHashMap(func(a, b util.T) bool { + return a.(Ref).Equal(b.(Ref)) + }, func(v util.T) int { + return v.(Ref).Hash() + }) + + for _, name := range c.sorted { + mod := c.Modules[name] + + for _, rule := range mod.Rules { + hashMapAdd(rules, mod.Package.Path, rule.Head.Ref().GroundPrefix()) + } + } + + return rules +} + +func hashMapAdd(rules *util.HashMap, pkg, rule Ref) { + prev, ok := rules.Get(pkg) + if !ok { + rules.Put(pkg, []Ref{rule}) + return + } + for _, p := range prev.([]Ref) { + if p.Equal(rule) { + return + } + } + rules.Put(pkg, append(prev.([]Ref), rule)) +} + +func (c *Compiler) GetAnnotationSet() *AnnotationSet { + return c.annotationSet +} + +func (c *Compiler) checkImports() { + modules := make([]*Module, 0, len(c.Modules)) + + supportsRegoV1Import := c.capabilities.ContainsFeature(FeatureRegoV1Import) || + c.capabilities.ContainsFeature(FeatureRegoV1) + + for _, name := range c.sorted { + mod := c.Modules[name] + + for _, imp := range mod.Imports { + if !supportsRegoV1Import && Compare(imp.Path, RegoV1CompatibleRef) == 0 { + c.err(NewError(CompileErr, imp.Loc(), "rego.v1 import is not supported")) + } + } + + if c.strict || c.moduleIsRegoV1Compatible(mod) { + modules = append(modules, mod) + } + } + + errs := checkDuplicateImports(modules) + for _, err := range errs { + c.err(err) + } +} + +func (c *Compiler) checkKeywordOverrides() { + for _, name := range c.sorted { + mod := c.Modules[name] + if c.strict || c.moduleIsRegoV1Compatible(mod) { + errs := checkRootDocumentOverrides(mod) + for _, err := range errs { + c.err(err) + } + } + } +} + +func (c *Compiler) moduleIsRegoV1(mod *Module) bool { + if mod.regoVersion == RegoUndefined { + switch c.defaultRegoVersion { + case RegoUndefined: + c.err(NewError(CompileErr, mod.Package.Loc(), "cannot determine rego version for module")) + return false + case RegoV1: + return true + } + return false + } + return mod.regoVersion == RegoV1 +} + +func (c *Compiler) moduleIsRegoV1Compatible(mod *Module) bool { + if mod.regoVersion == RegoUndefined { + switch c.defaultRegoVersion { + case RegoUndefined: + c.err(NewError(CompileErr, mod.Package.Loc(), "cannot determine rego version for module")) + return false + case RegoV1, RegoV0CompatV1: + return true + } + return false + } + return mod.regoV1Compatible() +} + +// resolveAllRefs resolves references in expressions to their fully qualified values. +// +// For instance, given the following module: +// +// package a.b +// import data.foo.bar +// p[x] { bar[_] = x } +// +// The reference "bar[_]" would be resolved to "data.foo.bar[_]". +// +// Ref rules are resolved, too: +// +// package a.b +// q { c.d.e == 1 } +// c.d[e] := 1 if e := "e" +// +// The reference "c.d.e" would be resolved to "data.a.b.c.d.e". +func (c *Compiler) resolveAllRefs() { + + rules := c.getExports() + + for _, name := range c.sorted { + mod := c.Modules[name] + + var ruleExports []Ref + if x, ok := rules.Get(mod.Package.Path); ok { + ruleExports = x.([]Ref) + } + + globals := getGlobals(mod.Package, ruleExports, mod.Imports) + + WalkRules(mod, func(rule *Rule) bool { + err := resolveRefsInRule(globals, rule) + if err != nil { + c.err(NewError(CompileErr, rule.Location, err.Error())) //nolint:govet + } + return false + }) + + if c.strict { // check for unused imports + for _, imp := range mod.Imports { + path := imp.Path.Value.(Ref) + if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) { + continue // ignore future and rego imports + } + + for v, u := range globals { + if v.Equal(imp.Name()) && !u.used { + c.err(NewError(CompileErr, imp.Location, "%s unused", imp.String())) + } + } + } + } + } + + if c.moduleLoader != nil { + + parsed, err := c.moduleLoader(c.Modules) + if err != nil { + c.err(NewError(CompileErr, nil, err.Error())) //nolint:govet + return + } + + if len(parsed) == 0 { + return + } + + for id, module := range parsed { + c.Modules[id] = module.Copy() + c.sorted = append(c.sorted, id) + if c.parsedModules != nil { + c.parsedModules[id] = module + } + } + + sort.Strings(c.sorted) + c.resolveAllRefs() + } +} + +func (c *Compiler) removeImports() { + c.imports = make(map[string][]*Import, len(c.Modules)) + for name := range c.Modules { + c.imports[name] = c.Modules[name].Imports + c.Modules[name].Imports = nil + } +} + +func (c *Compiler) initLocalVarGen() { + c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules) +} + +func (c *Compiler) rewriteComprehensionTerms() { + f := newEqualityFactory(c.localvargen) + for _, name := range c.sorted { + mod := c.Modules[name] + _, _ = rewriteComprehensionTerms(f, mod) // ignore error + } +} + +func (c *Compiler) rewriteExprTerms() { + for _, name := range c.sorted { + mod := c.Modules[name] + WalkRules(mod, func(rule *Rule) bool { + rewriteExprTermsInHead(c.localvargen, rule) + rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body) + return false + }) + } +} + +func (c *Compiler) rewriteRuleHeadRefs() { + f := newEqualityFactory(c.localvargen) + for _, name := range c.sorted { + WalkRules(c.Modules[name], func(rule *Rule) bool { + + ref := rule.Head.Ref() + // NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but + // it's possible to construct Module{} instances from Golang code, so we need + // to accommodate for that, too. + if len(rule.Head.Reference) == 0 { + rule.Head.Reference = ref + } + + cannotSpeakStringPrefixRefs := true + cannotSpeakGeneralRefs := true + for _, f := range c.capabilities.Features { + switch f { + case FeatureRefHeadStringPrefixes: + cannotSpeakStringPrefixRefs = false + case FeatureRefHeads: + cannotSpeakGeneralRefs = false + case FeatureRegoV1: + cannotSpeakStringPrefixRefs = false + cannotSpeakGeneralRefs = false + } + } + + if cannotSpeakStringPrefixRefs && cannotSpeakGeneralRefs && rule.Head.Name == "" { + c.err(NewError(CompileErr, rule.Loc(), "rule heads with refs are not supported: %v", rule.Head.Reference)) + return true + } + + for i := 1; i < len(ref); i++ { + if cannotSpeakGeneralRefs && (rule.Head.RuleKind() == MultiValue || i != len(ref)-1) { // last + if _, ok := ref[i].Value.(String); !ok { + c.err(NewError(TypeErr, rule.Loc(), "rule heads with general refs (containing variables) are not supported: %v", rule.Head.Reference)) + continue + } + } + + // Rewrite so that any non-scalar elements in the rule's ref are vars: + // p.q.r[y.z] { ... } => p.q.r[__local0__] { __local0__ = y.z } + // p.q[a.b][c.d] { ... } => p.q[__local0__] { __local0__ = a.b; __local1__ = c.d } + // because that's what the RuleTree knows how to deal with. + if _, ok := ref[i].Value.(Var); !ok && !IsScalar(ref[i].Value) { + expr := f.Generate(ref[i]) + if i == len(ref)-1 && rule.Head.Key.Equal(ref[i]) { + rule.Head.Key = expr.Operand(0) + } + rule.Head.Reference[i] = expr.Operand(0) + rule.Body.Append(expr) + } + } + + return true + }) + } +} + +func (c *Compiler) checkVoidCalls() { + for _, name := range c.sorted { + mod := c.Modules[name] + for _, err := range checkVoidCalls(c.TypeEnv, mod) { + c.err(err) + } + } +} + +func (c *Compiler) rewritePrintCalls() { + var modified bool + if !c.enablePrintStatements { + for _, name := range c.sorted { + if erasePrintCalls(c.Modules[name]) { + modified = true + } + } + } else { + for _, name := range c.sorted { + mod := c.Modules[name] + WalkRules(mod, func(r *Rule) bool { + safe := r.Head.Args.Vars() + safe.Update(ReservedVars) + vis := func(b Body) bool { + modrec, errs := rewritePrintCalls(c.localvargen, c.GetArity, safe, b) + if modrec { + modified = true + } + for _, err := range errs { + c.err(err) + } + return false + } + WalkBodies(r.Head, vis) + WalkBodies(r.Body, vis) + return false + }) + } + } + if modified { + c.Required.addBuiltinSorted(Print) + } +} + +// checkVoidCalls returns errors for any expressions that treat void function +// calls as values. The only void functions in Rego are specific built-ins like +// print(). +func checkVoidCalls(env *TypeEnv, x interface{}) Errors { + var errs Errors + WalkTerms(x, func(x *Term) bool { + if call, ok := x.Value.(Call); ok { + if tpe, ok := env.Get(call[0]).(*types.Function); ok && tpe.Result() == nil { + errs = append(errs, NewError(TypeErr, x.Loc(), "%v used as value", call)) + } + } + return false + }) + return errs +} + +// rewritePrintCalls will rewrite the body so that print operands are captured +// in local variables and their evaluation occurs within a comprehension. +// Wrapping the terms inside of a comprehension ensures that undefined values do +// not short-circuit evaluation. +// +// For example, given the following print statement: +// +// print("the value of x is:", input.x) +// +// The expression would be rewritten to: +// +// print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x}) +func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals VarSet, body Body) (bool, Errors) { + + var errs Errors + var modified bool + + // Visit comprehension bodies recursively to ensure print statements inside + // those bodies only close over variables that are safe. + for i := range body { + if ContainsClosures(body[i]) { + safe := outputVarsForBody(body[:i], getArity, globals) + safe.Update(globals) + WalkClosures(body[i], func(x interface{}) bool { + var modrec bool + var errsrec Errors + switch x := x.(type) { + case *SetComprehension: + modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body) + case *ArrayComprehension: + modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body) + case *ObjectComprehension: + modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body) + case *Every: + safe.Update(x.KeyValueVars()) + modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body) + } + if modrec { + modified = true + } + errs = append(errs, errsrec...) + return true + }) + if len(errs) > 0 { + return false, errs + } + } + } + + for i := range body { + + if !isPrintCall(body[i]) { + continue + } + + modified = true + + var errs Errors + safe := outputVarsForBody(body[:i], getArity, globals) + safe.Update(globals) + args := body[i].Operands() + + for j := range args { + vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams) + vis.Walk(args[j]) + unsafe := vis.Vars().Diff(safe) + for _, v := range unsafe.Sorted() { + errs = append(errs, NewError(CompileErr, args[j].Loc(), "var %v is undeclared", v)) + } + } + + if len(errs) > 0 { + return false, errs + } + + arr := NewArray() + + for j := range args { + x := NewTerm(gen.Generate()).SetLocation(args[j].Loc()) + capture := Equality.Expr(x, args[j]).SetLocation(args[j].Loc()) + arr = arr.Append(SetComprehensionTerm(x, NewBody(capture)).SetLocation(args[j].Loc())) + } + + body.Set(NewExpr([]*Term{ + NewTerm(InternalPrint.Ref()).SetLocation(body[i].Loc()), + NewTerm(arr).SetLocation(body[i].Loc()), + }).SetLocation(body[i].Loc()), i) + } + + return modified, nil +} + +func erasePrintCalls(node interface{}) bool { + var modified bool + NewGenericVisitor(func(x interface{}) bool { + var modrec bool + switch x := x.(type) { + case *Rule: + modrec, x.Body = erasePrintCallsInBody(x.Body) + case *ArrayComprehension: + modrec, x.Body = erasePrintCallsInBody(x.Body) + case *SetComprehension: + modrec, x.Body = erasePrintCallsInBody(x.Body) + case *ObjectComprehension: + modrec, x.Body = erasePrintCallsInBody(x.Body) + case *Every: + modrec, x.Body = erasePrintCallsInBody(x.Body) + } + if modrec { + modified = true + } + return false + }).Walk(node) + return modified +} + +func erasePrintCallsInBody(x Body) (bool, Body) { + + if !containsPrintCall(x) { + return false, x + } + + var cpy Body + + for i := range x { + + // Recursively visit any comprehensions contained in this expression. + erasePrintCalls(x[i]) + + if !isPrintCall(x[i]) { + cpy.Append(x[i]) + } + } + + if len(cpy) == 0 { + term := BooleanTerm(true).SetLocation(x.Loc()) + expr := NewExpr(term).SetLocation(x.Loc()) + cpy.Append(expr) + } + + return true, cpy +} + +func containsPrintCall(x interface{}) bool { + var found bool + WalkExprs(x, func(expr *Expr) bool { + if !found { + if isPrintCall(expr) { + found = true + } + } + return found + }) + return found +} + +func isPrintCall(x *Expr) bool { + return x.IsCall() && x.Operator().Equal(Print.Ref()) +} + +// rewriteRefsInHead will rewrite rules so that the head does not contain any +// terms that require evaluation (e.g., refs or comprehensions). If the key or +// value contains one or more of these terms, the key or value will be moved +// into the body and assigned to a new variable. The new variable will replace +// the key or value in the head. +// +// For instance, given the following rule: +// +// p[{"foo": data.foo[i]}] { i < 100 } +// +// The rule would be re-written as: +// +// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} } +func (c *Compiler) rewriteRefsInHead() { + f := newEqualityFactory(c.localvargen) + for _, name := range c.sorted { + mod := c.Modules[name] + WalkRules(mod, func(rule *Rule) bool { + if requiresEval(rule.Head.Key) { + expr := f.Generate(rule.Head.Key) + rule.Head.Key = expr.Operand(0) + rule.Body.Append(expr) + } + if requiresEval(rule.Head.Value) { + expr := f.Generate(rule.Head.Value) + rule.Head.Value = expr.Operand(0) + rule.Body.Append(expr) + } + for i := 0; i < len(rule.Head.Args); i++ { + if requiresEval(rule.Head.Args[i]) { + expr := f.Generate(rule.Head.Args[i]) + rule.Head.Args[i] = expr.Operand(0) + rule.Body.Append(expr) + } + } + return false + }) + } +} + +func (c *Compiler) rewriteEquals() { + modified := false + for _, name := range c.sorted { + mod := c.Modules[name] + modified = rewriteEquals(mod) || modified + } + if modified { + c.Required.addBuiltinSorted(Equal) + } +} + +func (c *Compiler) rewriteDynamicTerms() { + f := newEqualityFactory(c.localvargen) + for _, name := range c.sorted { + mod := c.Modules[name] + WalkRules(mod, func(rule *Rule) bool { + rule.Body = rewriteDynamics(f, rule.Body) + return false + }) + } +} + +// rewriteTestRuleEqualities rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise +// not have their values captured through tracing, such as refs and comprehensions not unified/assigned to a local var. +// For example, given the following module: +// +// package test +// +// p.q contains v if { +// some v in numbers.range(1, 3) +// } +// +// p.r := "foo" +// +// test_rule { +// p == { +// "q": {4, 5, 6} +// } +// } +// +// `p` in `test_rule` resolves to `data.test.p`, which won't be an entry in the virtual-cache and must therefore be calculated after-the-fact. +// If `p` isn't captured in a local var, there is no trivial way to retrieve its value for test reporting. +func (c *Compiler) rewriteTestRuleEqualities() { + if !c.rewriteTestRulesForTracing { + return + } + + f := newEqualityFactory(c.localvargen) + for _, name := range c.sorted { + mod := c.Modules[name] + WalkRules(mod, func(rule *Rule) bool { + if strings.HasPrefix(string(rule.Head.Name), "test_") { + rule.Body = rewriteTestEqualities(f, rule.Body) + } + return false + }) + } +} + +func (c *Compiler) parseMetadataBlocks() { + // Only parse annotations if rego.metadata built-ins are called + regoMetadataCalled := false + for _, name := range c.sorted { + mod := c.Modules[name] + WalkExprs(mod, func(expr *Expr) bool { + if isRegoMetadataChainCall(expr) || isRegoMetadataRuleCall(expr) { + regoMetadataCalled = true + } + return regoMetadataCalled + }) + + if regoMetadataCalled { + break + } + } + + if regoMetadataCalled { + // NOTE: Possible optimization: only parse annotations for modules on the path of rego.metadata-calling module + for _, name := range c.sorted { + mod := c.Modules[name] + + if len(mod.Annotations) == 0 { + var errs Errors + mod.Annotations, errs = parseAnnotations(mod.Comments) + errs = append(errs, attachAnnotationsNodes(mod)...) + for _, err := range errs { + c.err(err) + } + + attachRuleAnnotations(mod) + } + } + } +} + +func (c *Compiler) rewriteRegoMetadataCalls() { + eqFactory := newEqualityFactory(c.localvargen) + + _, chainFuncAllowed := c.builtins[RegoMetadataChain.Name] + _, ruleFuncAllowed := c.builtins[RegoMetadataRule.Name] + + for _, name := range c.sorted { + mod := c.Modules[name] + + WalkRules(mod, func(rule *Rule) bool { + var firstChainCall *Expr + var firstRuleCall *Expr + + WalkExprs(rule, func(expr *Expr) bool { + if chainFuncAllowed && firstChainCall == nil && isRegoMetadataChainCall(expr) { + firstChainCall = expr + } else if ruleFuncAllowed && firstRuleCall == nil && isRegoMetadataRuleCall(expr) { + firstRuleCall = expr + } + return firstChainCall != nil && firstRuleCall != nil + }) + + chainCalled := firstChainCall != nil + ruleCalled := firstRuleCall != nil + + if chainCalled || ruleCalled { + body := make(Body, 0, len(rule.Body)+2) + + var metadataChainVar Var + if chainCalled { + // Create and inject metadata chain for rule + + chain, err := createMetadataChain(c.annotationSet.Chain(rule)) + if err != nil { + c.err(err) + return false + } + + chain.Location = firstChainCall.Location + eq := eqFactory.Generate(chain) + metadataChainVar = eq.Operands()[0].Value.(Var) + body.Append(eq) + } + + var metadataRuleVar Var + if ruleCalled { + // Create and inject metadata for rule + + var metadataRuleTerm *Term + + a := getPrimaryRuleAnnotations(c.annotationSet, rule) + if a != nil { + annotObj, err := a.toObject() + if err != nil { + c.err(err) + return false + } + metadataRuleTerm = NewTerm(*annotObj) + } else { + // If rule has no annotations, assign an empty object + metadataRuleTerm = ObjectTerm() + } + + metadataRuleTerm.Location = firstRuleCall.Location + eq := eqFactory.Generate(metadataRuleTerm) + metadataRuleVar = eq.Operands()[0].Value.(Var) + body.Append(eq) + } + + for _, expr := range rule.Body { + body.Append(expr) + } + rule.Body = body + + vis := func(b Body) bool { + for _, err := range rewriteRegoMetadataCalls(&metadataChainVar, &metadataRuleVar, b, &c.RewrittenVars) { + c.err(err) + } + return false + } + WalkBodies(rule.Head, vis) + WalkBodies(rule.Body, vis) + } + + return false + }) + } +} + +func getPrimaryRuleAnnotations(as *AnnotationSet, rule *Rule) *Annotations { + annots := as.GetRuleScope(rule) + + if len(annots) == 0 { + return nil + } + + // Sort by annotation location; chain must start with annotations declared closest to rule, then going outward + sort.SliceStable(annots, func(i, j int) bool { + return annots[i].Location.Compare(annots[j].Location) > 0 + }) + + return annots[0] +} + +func rewriteRegoMetadataCalls(metadataChainVar *Var, metadataRuleVar *Var, body Body, rewrittenVars *map[Var]Var) Errors { + var errs Errors + + WalkClosures(body, func(x interface{}) bool { + switch x := x.(type) { + case *ArrayComprehension: + errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) + case *SetComprehension: + errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) + case *ObjectComprehension: + errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) + case *Every: + errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) + } + return true + }) + + for i := range body { + expr := body[i] + var metadataVar Var + + if metadataChainVar != nil && isRegoMetadataChainCall(expr) { + metadataVar = *metadataChainVar + } else if metadataRuleVar != nil && isRegoMetadataRuleCall(expr) { + metadataVar = *metadataRuleVar + } else { + continue + } + + // NOTE(johanfylling): An alternative strategy would be to walk the body and replace all operands[0] + // usages with *metadataChainVar + operands := expr.Operands() + var newExpr *Expr + if len(operands) > 0 { // There is an output var to rewrite + rewrittenVar := operands[0] + newExpr = Equality.Expr(rewrittenVar, NewTerm(metadataVar)) + } else { // No output var, just rewrite expr to metadataVar + newExpr = NewExpr(NewTerm(metadataVar)) + } + + newExpr.Generated = true + newExpr.Location = expr.Location + body.Set(newExpr, i) + } + + return errs +} + +func isRegoMetadataChainCall(x *Expr) bool { + return x.IsCall() && x.Operator().Equal(RegoMetadataChain.Ref()) +} + +func isRegoMetadataRuleCall(x *Expr) bool { + return x.IsCall() && x.Operator().Equal(RegoMetadataRule.Ref()) +} + +func createMetadataChain(chain []*AnnotationsRef) (*Term, *Error) { + + metaArray := NewArray() + for _, link := range chain { + p := link.Path.toArray(). + Slice(1, -1) // Dropping leading 'data' element of path + obj := NewObject( + Item(StringTerm("path"), NewTerm(p)), + ) + if link.Annotations != nil { + annotObj, err := link.Annotations.toObject() + if err != nil { + return nil, err + } + obj.Insert(StringTerm("annotations"), NewTerm(*annotObj)) + } + metaArray = metaArray.Append(NewTerm(obj)) + } + + return NewTerm(metaArray), nil +} + +func (c *Compiler) rewriteLocalVars() { + + var assignment bool + + for _, name := range c.sorted { + mod := c.Modules[name] + gen := c.localvargen + + WalkRules(mod, func(rule *Rule) bool { + argsStack := newLocalDeclaredVars() + + args := NewVarVisitor() + if c.strict { + args.Walk(rule.Head.Args) + } + unusedArgs := args.Vars() + + c.rewriteLocalArgVars(gen, argsStack, rule) + + // Rewrite local vars in each else-branch of the rule. + // Note: this is done instead of a walk so that we can capture any unused function arguments + // across else-branches. + for rule := rule; rule != nil; rule = rule.Else { + stack, errs := c.rewriteLocalVarsInRule(rule, unusedArgs, argsStack, gen) + if stack.assignment { + assignment = true + } + + for arg := range unusedArgs { + if stack.Count(arg) > 1 { + delete(unusedArgs, arg) + } + } + + for _, err := range errs { + c.err(err) + } + } + + if c.strict { + // Report an error for each unused function argument + for arg := range unusedArgs { + if !arg.IsWildcard() { + c.err(NewError(CompileErr, rule.Head.Location, "unused argument %v. (hint: use _ (wildcard variable) instead)", arg)) + } + } + } + + return true + }) + } + + if assignment { + c.Required.addBuiltinSorted(Assign) + } +} + +func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) { + // Rewrite assignments contained in head of rule. Assignments can + // occur in rule head if they're inside a comprehension. Note, + // assigned vars in comprehensions in the head will be rewritten + // first to preserve scoping rules. For example: + // + // p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 } + // + // This behaviour is consistent scoping inside the body. For example: + // + // p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] } + nestedXform := &rewriteNestedHeadVarLocalTransform{ + gen: gen, + RewrittenVars: c.RewrittenVars, + strict: c.strict, + } + + NewGenericVisitor(nestedXform.Visit).Walk(rule.Head) + + for _, err := range nestedXform.errs { + c.err(err) + } + + // Rewrite assignments in body. + used := NewVarSet() + + for _, t := range rule.Head.Ref()[1:] { + used.Update(t.Vars()) + } + + if rule.Head.Key != nil { + used.Update(rule.Head.Key.Vars()) + } + + if rule.Head.Value != nil { + valueVars := rule.Head.Value.Vars() + used.Update(valueVars) + for arg := range unusedArgs { + if valueVars.Contains(arg) { + delete(unusedArgs, arg) + } + } + } + + stack := argsStack.Copy() + + body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict) + + // For rewritten vars use the collection of all variables that + // were in the stack at some point in time. + for k, v := range stack.rewritten { + c.RewrittenVars[k] = v + } + + rule.Body = body + + // Rewrite vars in head that refer to locally declared vars in the body. + localXform := rewriteHeadVarLocalTransform{declared: declared} + + for i := range rule.Head.Args { + rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i]) + } + + for i := 1; i < len(rule.Head.Ref()); i++ { + rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i]) + } + if rule.Head.Key != nil { + rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key) + } + + if rule.Head.Value != nil { + rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value) + } + return stack, errs +} + +type rewriteNestedHeadVarLocalTransform struct { + gen *localVarGenerator + errs Errors + RewrittenVars map[Var]Var + strict bool +} + +func (xform *rewriteNestedHeadVarLocalTransform) Visit(x interface{}) bool { + + if term, ok := x.(*Term); ok { + + stop := false + stack := newLocalDeclaredVars() + + switch x := term.Value.(type) { + case *object: + cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) { + kcpy := k.Copy() + NewGenericVisitor(xform.Visit).Walk(kcpy) + vcpy := v.Copy() + NewGenericVisitor(xform.Visit).Walk(vcpy) + return kcpy, vcpy, nil + }) + term.Value = cpy + stop = true + case *set: + cpy, _ := x.Map(func(v *Term) (*Term, error) { + vcpy := v.Copy() + NewGenericVisitor(xform.Visit).Walk(vcpy) + return vcpy, nil + }) + term.Value = cpy + stop = true + case *ArrayComprehension: + xform.errs = rewriteDeclaredVarsInArrayComprehension(xform.gen, stack, x, xform.errs, xform.strict) + stop = true + case *SetComprehension: + xform.errs = rewriteDeclaredVarsInSetComprehension(xform.gen, stack, x, xform.errs, xform.strict) + stop = true + case *ObjectComprehension: + xform.errs = rewriteDeclaredVarsInObjectComprehension(xform.gen, stack, x, xform.errs, xform.strict) + stop = true + } + + for k, v := range stack.rewritten { + xform.RewrittenVars[k] = v + } + + return stop + } + + return false +} + +type rewriteHeadVarLocalTransform struct { + declared map[Var]Var +} + +func (xform rewriteHeadVarLocalTransform) Transform(x interface{}) (interface{}, error) { + if v, ok := x.(Var); ok { + if gv, ok := xform.declared[v]; ok { + return gv, nil + } + } + return x, nil +} + +func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) { + + vis := &ruleArgLocalRewriter{ + stack: stack, + gen: gen, + } + + for i := range rule.Head.Args { + Walk(vis, rule.Head.Args[i]) + } + + for i := range vis.errs { + c.err(vis.errs[i]) + } +} + +type ruleArgLocalRewriter struct { + stack *localDeclaredVars + gen *localVarGenerator + errs []*Error +} + +func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor { + + t, ok := x.(*Term) + if !ok { + return vis + } + + switch v := t.Value.(type) { + case Var: + gv, ok := vis.stack.Declared(v) + if ok { + vis.stack.Seen(v) + } else { + gv = vis.gen.Generate() + vis.stack.Insert(v, gv, argVar) + } + t.Value = gv + return nil + case *object: + if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) { + vcpy := v.Copy() + Walk(vis, vcpy) + return k, vcpy, nil + }); err != nil { + vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error())) //nolint:govet + } else { + t.Value = cpy + } + return nil + case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set: + // Scalars are no-ops. Comprehensions are handled above. Sets must not + // contain variables. + return nil + case Call: + vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "rule arguments cannot contain calls")) + return nil + default: + // Recurse on refs and arrays. Any embedded + // variables can be rewritten. + return vis + } +} + +func (c *Compiler) rewriteWithModifiers() { + f := newEqualityFactory(c.localvargen) + for _, name := range c.sorted { + mod := c.Modules[name] + t := NewGenericTransformer(func(x interface{}) (interface{}, error) { + body, ok := x.(Body) + if !ok { + return x, nil + } + body, err := rewriteWithModifiersInBody(c, c.unsafeBuiltinsMap, f, body) + if err != nil { + c.err(err) + } + + return body, nil + }) + _, _ = Transform(t, mod) // ignore error + } +} + +func (c *Compiler) setModuleTree() { + c.ModuleTree = NewModuleTree(c.Modules) +} + +func (c *Compiler) setRuleTree() { + c.RuleTree = NewRuleTree(c.ModuleTree) +} + +func (c *Compiler) setGraph() { + list := func(r Ref) []*Rule { + return c.GetRulesDynamicWithOpts(r, RulesOptions{IncludeHiddenModules: true}) + } + c.Graph = NewGraph(c.Modules, list) +} + +type queryCompiler struct { + compiler *Compiler + qctx *QueryContext + typeEnv *TypeEnv + rewritten map[Var]Var + after map[string][]QueryCompilerStageDefinition + unsafeBuiltins map[string]struct{} + comprehensionIndices map[*Term]*ComprehensionIndex + enablePrintStatements bool +} + +func newQueryCompiler(compiler *Compiler) QueryCompiler { + qc := &queryCompiler{ + compiler: compiler, + qctx: nil, + after: map[string][]QueryCompilerStageDefinition{}, + comprehensionIndices: map[*Term]*ComprehensionIndex{}, + } + return qc +} + +func (qc *queryCompiler) WithStrict(strict bool) QueryCompiler { + qc.compiler.WithStrict(strict) + return qc +} + +func (qc *queryCompiler) WithEnablePrintStatements(yes bool) QueryCompiler { + qc.enablePrintStatements = yes + return qc +} + +func (qc *queryCompiler) WithContext(qctx *QueryContext) QueryCompiler { + qc.qctx = qctx + return qc +} + +func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler { + qc.after[after] = append(qc.after[after], stage) + return qc +} + +func (qc *queryCompiler) WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler { + qc.unsafeBuiltins = unsafe + return qc +} + +func (qc *queryCompiler) RewrittenVars() map[Var]Var { + return qc.rewritten +} + +func (qc *queryCompiler) ComprehensionIndex(term *Term) *ComprehensionIndex { + if result, ok := qc.comprehensionIndices[term]; ok { + return result + } else if result, ok := qc.compiler.comprehensionIndices[term]; ok { + return result + } + return nil +} + +func (qc *queryCompiler) runStage(metricName string, qctx *QueryContext, query Body, s func(*QueryContext, Body) (Body, error)) (Body, error) { + if qc.compiler.metrics != nil { + qc.compiler.metrics.Timer(metricName).Start() + defer qc.compiler.metrics.Timer(metricName).Stop() + } + return s(qctx, query) +} + +func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCompilerStage) (Body, error) { + if qc.compiler.metrics != nil { + qc.compiler.metrics.Timer(metricName).Start() + defer qc.compiler.metrics.Timer(metricName).Stop() + } + return s(qc, query) +} + +type queryStage = struct { + name string + metricName string + f func(*QueryContext, Body) (Body, error) +} + +func (qc *queryCompiler) Compile(query Body) (Body, error) { + if len(query) == 0 { + return nil, Errors{NewError(CompileErr, nil, "empty query cannot be compiled")} + } + + query = query.Copy() + + stages := []queryStage{ + {"CheckKeywordOverrides", "query_compile_stage_check_keyword_overrides", qc.checkKeywordOverrides}, + {"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs}, + {"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars}, + {"CheckVoidCalls", "query_compile_stage_check_void_calls", qc.checkVoidCalls}, + {"RewritePrintCalls", "query_compile_stage_rewrite_print_calls", qc.rewritePrintCalls}, + {"RewriteExprTerms", "query_compile_stage_rewrite_expr_terms", qc.rewriteExprTerms}, + {"RewriteComprehensionTerms", "query_compile_stage_rewrite_comprehension_terms", qc.rewriteComprehensionTerms}, + {"RewriteWithValues", "query_compile_stage_rewrite_with_values", qc.rewriteWithModifiers}, + {"CheckUndefinedFuncs", "query_compile_stage_check_undefined_funcs", qc.checkUndefinedFuncs}, + {"CheckSafety", "query_compile_stage_check_safety", qc.checkSafety}, + {"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms}, + {"CheckTypes", "query_compile_stage_check_types", qc.checkTypes}, + {"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins}, + {"CheckDeprecatedBuiltins", "query_compile_stage_check_deprecated_builtins", qc.checkDeprecatedBuiltins}, + } + if qc.compiler.evalMode == EvalModeTopdown { + stages = append(stages, queryStage{"BuildComprehensionIndex", "query_compile_stage_build_comprehension_index", qc.buildComprehensionIndices}) + } + + qctx := qc.qctx.Copy() + + for _, s := range stages { + var err error + query, err = qc.runStage(s.metricName, qctx, query, s.f) + if err != nil { + return nil, qc.applyErrorLimit(err) + } + for _, s := range qc.after[s.name] { + query, err = qc.runStageAfter(s.MetricName, query, s.Stage) + if err != nil { + return nil, qc.applyErrorLimit(err) + } + } + } + + return query, nil +} + +func (qc *queryCompiler) TypeEnv() *TypeEnv { + return qc.typeEnv +} + +func (qc *queryCompiler) applyErrorLimit(err error) error { + var errs Errors + if errors.As(err, &errs) { + if qc.compiler.maxErrs > 0 && len(errs) > qc.compiler.maxErrs { + err = append(errs[:qc.compiler.maxErrs], errLimitReached) + } + } + return err +} + +func (qc *queryCompiler) checkKeywordOverrides(_ *QueryContext, body Body) (Body, error) { + if qc.compiler.strict { + if errs := checkRootDocumentOverrides(body); len(errs) > 0 { + return nil, errs + } + } + return body, nil +} + +func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) { + + var globals map[Var]*usedRef + + if qctx != nil { + pkg := qctx.Package + // Query compiler ought to generate a package if one was not provided and one or more imports were provided. + // The generated package name could even be an empty string to avoid conflicts (it doesn't have to be valid syntactically) + if pkg == nil && len(qctx.Imports) > 0 { + pkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)} + } + if pkg != nil { + var ruleExports []Ref + rules := qc.compiler.getExports() + if exist, ok := rules.Get(pkg.Path); ok { + ruleExports = exist.([]Ref) + } + + globals = getGlobals(qctx.Package, ruleExports, qctx.Imports) + qctx.Imports = nil + } + } + + ignore := &declaredVarStack{declaredVars(body)} + + return resolveRefsInBody(globals, ignore, body), nil +} + +func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) { + gen := newLocalVarGenerator("q", body) + f := newEqualityFactory(gen) + node, err := rewriteComprehensionTerms(f, body) + if err != nil { + return nil, err + } + return node.(Body), nil +} + +func (qc *queryCompiler) rewriteDynamicTerms(_ *QueryContext, body Body) (Body, error) { + gen := newLocalVarGenerator("q", body) + f := newEqualityFactory(gen) + return rewriteDynamics(f, body), nil +} + +func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, error) { + gen := newLocalVarGenerator("q", body) + return rewriteExprTermsInBody(gen, body), nil +} + +func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) { + gen := newLocalVarGenerator("q", body) + stack := newLocalDeclaredVars() + body, _, err := rewriteLocalVars(gen, stack, nil, body, qc.compiler.strict) + if len(err) != 0 { + return nil, err + } + qc.rewritten = make(map[Var]Var, len(stack.rewritten)) + for k, v := range stack.rewritten { + // The vars returned during the rewrite will include all seen vars, + // even if they're not declared with an assignment operation. We don't + // want to include these inside the rewritten set though. + qc.rewritten[k] = v + } + return body, nil +} + +func (qc *queryCompiler) rewritePrintCalls(_ *QueryContext, body Body) (Body, error) { + if !qc.enablePrintStatements { + _, cpy := erasePrintCallsInBody(body) + return cpy, nil + } + gen := newLocalVarGenerator("q", body) + if _, errs := rewritePrintCalls(gen, qc.compiler.GetArity, ReservedVars, body); len(errs) > 0 { + return nil, errs + } + return body, nil +} + +func (qc *queryCompiler) checkVoidCalls(_ *QueryContext, body Body) (Body, error) { + if errs := checkVoidCalls(qc.compiler.TypeEnv, body); len(errs) > 0 { + return nil, errs + } + return body, nil +} + +func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) { + if errs := checkUndefinedFuncs(qc.compiler.TypeEnv, body, qc.compiler.GetArity, qc.rewritten); len(errs) > 0 { + return nil, errs + } + return body, nil +} + +func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) { + safe := ReservedVars.Copy() + reordered, unsafe := reorderBodyForSafety(qc.compiler.builtins, qc.compiler.GetArity, safe, body) + if errs := safetyErrorSlice(unsafe, qc.RewrittenVars()); len(errs) > 0 { + return nil, errs + } + return reordered, nil +} + +func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) { + var errs Errors + checker := newTypeChecker(). + WithSchemaSet(qc.compiler.schemaSet). + WithInputType(qc.compiler.inputType). + WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars)) + qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body) + if len(errs) > 0 { + return nil, errs + } + + return body, nil +} + +func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) { + errs := checkUnsafeBuiltins(qc.unsafeBuiltinsMap(), body) + if len(errs) > 0 { + return nil, errs + } + return body, nil +} + +func (qc *queryCompiler) unsafeBuiltinsMap() map[string]struct{} { + if qc.unsafeBuiltins != nil { + return qc.unsafeBuiltins + } + return qc.compiler.unsafeBuiltinsMap +} + +func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) { + if qc.compiler.strict { + errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body) + if len(errs) > 0 { + return nil, errs + } + } + return body, nil +} + +func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) { + f := newEqualityFactory(newLocalVarGenerator("q", body)) + body, err := rewriteWithModifiersInBody(qc.compiler, qc.unsafeBuiltinsMap(), f, body) + if err != nil { + return nil, Errors{err} + } + return body, nil +} + +func (qc *queryCompiler) buildComprehensionIndices(_ *QueryContext, body Body) (Body, error) { + // NOTE(tsandall): The query compiler does not have a metrics object so we + // cannot record index metrics currently. + _ = buildComprehensionIndices(qc.compiler.debug, qc.compiler.GetArity, ReservedVars, qc.RewrittenVars(), body, qc.comprehensionIndices) + return body, nil +} + +// ComprehensionIndex specifies how the comprehension term can be indexed. The keys +// tell the evaluator what variables to use for indexing. In the future, the index +// could be expanded with more information that would allow the evaluator to index +// a larger fragment of comprehensions (e.g., by closing over variables in the outer +// query.) +type ComprehensionIndex struct { + Term *Term + Keys []*Term +} + +func (ci *ComprehensionIndex) String() string { + if ci == nil { + return "" + } + return fmt.Sprintf("", NewArray(ci.Keys...)) +} + +func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node interface{}, result map[*Term]*ComprehensionIndex) uint64 { + var n uint64 + cpy := candidates.Copy() + WalkBodies(node, func(b Body) bool { + for _, expr := range b { + index := getComprehensionIndex(dbg, arity, cpy, rwVars, expr) + if index != nil { + result[index.Term] = index + n++ + } + // Any variables appearing in the expressions leading up to the comprehension + // are fair-game to be used as index keys. + cpy.Update(expr.Vars(VarVisitorParams{SkipClosures: true, SkipRefCallHead: true})) + } + return false + }) + return n +} + +func getComprehensionIndex(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, expr *Expr) *ComprehensionIndex { + + // Ignore everything except = expressions. Extract + // the comprehension term from the expression. + if !expr.IsEquality() || expr.Negated || len(expr.With) > 0 { + // No debug message, these are assumed to be known hinderances + // to comprehension indexing. + return nil + } + + var term *Term + + lhs, rhs := expr.Operand(0), expr.Operand(1) + + if _, ok := lhs.Value.(Var); ok && IsComprehension(rhs.Value) { + term = rhs + } else if _, ok := rhs.Value.(Var); ok && IsComprehension(lhs.Value) { + term = lhs + } + + if term == nil { + // no debug for this, it's the ordinary "nothing to do here" case + return nil + } + + // Ignore comprehensions that contain expressions that close over variables + // in the outer body if those variables are not also output variables in the + // comprehension body. In other words, ignore comprehensions that we cannot + // safely evaluate without bindings from the outer body. For example: + // + // x = [1] + // [true | data.y[z] = x] # safe to evaluate w/o outer body + // [true | data.y[z] = x[0]] # NOT safe to evaluate because 'x' would be unsafe. + // + // By identifying output variables in the body we also know what to index on by + // intersecting with candidate variables from the outer query. + // + // For example: + // + // x = data.foo[_] + // _ = [y | data.bar[y] = x] # index on 'x' + // + // This query goes from O(data.foo*data.bar) to O(data.foo+data.bar). + var body Body + + switch x := term.Value.(type) { + case *ArrayComprehension: + body = x.Body + case *SetComprehension: + body = x.Body + case *ObjectComprehension: + body = x.Body + } + + outputs := outputVarsForBody(body, arity, ReservedVars) + unsafe := body.Vars(SafetyCheckVisitorParams).Diff(outputs).Diff(ReservedVars) + + if len(unsafe) > 0 { + dbg.Printf("%s: comprehension index: unsafe vars: %v", expr.Location, unsafe) + return nil + } + + // Similarly, ignore comprehensions that contain references with output variables + // that intersect with the candidates. Indexing these comprehensions could worsen + // performance. + regressionVis := newComprehensionIndexRegressionCheckVisitor(candidates) + regressionVis.Walk(body) + if regressionVis.worse { + dbg.Printf("%s: comprehension index: output vars intersect candidates", expr.Location) + return nil + } + + // Check if any nested comprehensions close over candidates. If any intersection is found + // the comprehension cannot be cached because it would require closing over the candidates + // which the evaluator does not support today. + nestedVis := newComprehensionIndexNestedCandidateVisitor(candidates) + nestedVis.Walk(body) + if nestedVis.found { + dbg.Printf("%s: comprehension index: nested comprehensions close over candidates", expr.Location) + return nil + } + + // Make a sorted set of variable names that will serve as the index key set. + // Sort to ensure deterministic indexing. In future this could be relaxed + // if we can decide that one ordering is better than another. If the set is + // empty, there is no indexing to do. + indexVars := candidates.Intersect(outputs) + if len(indexVars) == 0 { + dbg.Printf("%s: comprehension index: no index vars", expr.Location) + return nil + } + + result := make([]*Term, 0, len(indexVars)) + + for v := range indexVars { + result = append(result, NewTerm(v)) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].Value.Compare(result[j].Value) < 0 + }) + + debugRes := make([]*Term, len(result)) + for i, r := range result { + if o, ok := rwVars[r.Value.(Var)]; ok { + debugRes[i] = NewTerm(o) + } else { + debugRes[i] = r + } + } + dbg.Printf("%s: comprehension index: built with keys: %v", expr.Location, debugRes) + return &ComprehensionIndex{Term: term, Keys: result} +} + +type comprehensionIndexRegressionCheckVisitor struct { + candidates VarSet + seen VarSet + worse bool +} + +// TODO(tsandall): Improve this so that users can either supply this list explicitly +// or the information is maintained on the built-in function declaration. What we really +// need to know is whether the built-in function allows callers to push down output +// values or not. It's unlikely that anything outside of OPA does this today so this +// solution is fine for now. +var comprehensionIndexBlacklist = map[string]int{ + WalkBuiltin.Name: len(WalkBuiltin.Decl.FuncArgs().Args), +} + +func newComprehensionIndexRegressionCheckVisitor(candidates VarSet) *comprehensionIndexRegressionCheckVisitor { + return &comprehensionIndexRegressionCheckVisitor{ + candidates: candidates, + seen: NewVarSet(), + } +} + +func (vis *comprehensionIndexRegressionCheckVisitor) Walk(x interface{}) { + NewGenericVisitor(vis.visit).Walk(x) +} + +func (vis *comprehensionIndexRegressionCheckVisitor) visit(x interface{}) bool { + if !vis.worse { + switch x := x.(type) { + case *Expr: + operands := x.Operands() + if pos := comprehensionIndexBlacklist[x.Operator().String()]; pos > 0 && pos < len(operands) { + vis.assertEmptyIntersection(operands[pos].Vars()) + } + case Ref: + vis.assertEmptyIntersection(x.OutputVars()) + case Var: + vis.seen.Add(x) + // Always skip comprehensions. We do not have to visit their bodies here. + case *ArrayComprehension, *SetComprehension, *ObjectComprehension: + return true + } + } + return vis.worse +} + +func (vis *comprehensionIndexRegressionCheckVisitor) assertEmptyIntersection(vs VarSet) { + for v := range vs { + if vis.candidates.Contains(v) && !vis.seen.Contains(v) { + vis.worse = true + return + } + } +} + +type comprehensionIndexNestedCandidateVisitor struct { + candidates VarSet + found bool +} + +func newComprehensionIndexNestedCandidateVisitor(candidates VarSet) *comprehensionIndexNestedCandidateVisitor { + return &comprehensionIndexNestedCandidateVisitor{ + candidates: candidates, + } +} + +func (vis *comprehensionIndexNestedCandidateVisitor) Walk(x interface{}) { + NewGenericVisitor(vis.visit).Walk(x) +} + +func (vis *comprehensionIndexNestedCandidateVisitor) visit(x interface{}) bool { + + if vis.found { + return true + } + + if v, ok := x.(Value); ok && IsComprehension(v) { + varVis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true}) + varVis.Walk(v) + vis.found = len(varVis.Vars().Intersect(vis.candidates)) > 0 + return true + } + + return false +} + +// ModuleTreeNode represents a node in the module tree. The module +// tree is keyed by the package path. +type ModuleTreeNode struct { + Key Value + Modules []*Module + Children map[Value]*ModuleTreeNode + Hide bool +} + +func (n *ModuleTreeNode) String() string { + var rules []string + for _, m := range n.Modules { + for _, r := range m.Rules { + rules = append(rules, r.Head.String()) + } + } + return fmt.Sprintf("", n.Key, n.Children, rules, n.Hide) +} + +// NewModuleTree returns a new ModuleTreeNode that represents the root +// of the module tree populated with the given modules. +func NewModuleTree(mods map[string]*Module) *ModuleTreeNode { + root := &ModuleTreeNode{ + Children: map[Value]*ModuleTreeNode{}, + } + names := make([]string, 0, len(mods)) + for name := range mods { + names = append(names, name) + } + sort.Strings(names) + for _, name := range names { + m := mods[name] + node := root + for i, x := range m.Package.Path { + c, ok := node.Children[x.Value] + if !ok { + var hide bool + if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 { + hide = true + } + c = &ModuleTreeNode{ + Key: x.Value, + Children: map[Value]*ModuleTreeNode{}, + Hide: hide, + } + node.Children[x.Value] = c + } + node = c + } + node.Modules = append(node.Modules, m) + } + return root +} + +// Size returns the number of modules in the tree. +func (n *ModuleTreeNode) Size() int { + s := len(n.Modules) + for _, c := range n.Children { + s += c.Size() + } + return s +} + +// Child returns n's child with key k. +func (n *ModuleTreeNode) child(k Value) *ModuleTreeNode { + switch k.(type) { + case String, Var: + return n.Children[k] + } + return nil +} + +// Find dereferences ref along the tree. ref[0] is converted to a String +// for convenience. +func (n *ModuleTreeNode) find(ref Ref) (*ModuleTreeNode, Ref) { + if v, ok := ref[0].Value.(Var); ok { + ref = Ref{StringTerm(string(v))}.Concat(ref[1:]) + } + node := n + for i, r := range ref { + next := node.child(r.Value) + if next == nil { + tail := make(Ref, len(ref)-i) + tail[0] = VarTerm(string(ref[i].Value.(String))) + copy(tail[1:], ref[i+1:]) + return node, tail + } + node = next + } + return node, nil +} + +// DepthFirst performs a depth-first traversal of the module tree rooted at n. +// If f returns true, traversal will not continue to the children of n. +func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) { + if f(n) { + return + } + for _, node := range n.Children { + node.DepthFirst(f) + } +} + +// TreeNode represents a node in the rule tree. The rule tree is keyed by +// rule path. +type TreeNode struct { + Key Value + Values []util.T + Children map[Value]*TreeNode + Sorted []Value + Hide bool +} + +func (n *TreeNode) String() string { + return fmt.Sprintf("", n.Key, n.Values, n.Sorted, n.Hide) +} + +// NewRuleTree returns a new TreeNode that represents the root +// of the rule tree populated with the given rules. +func NewRuleTree(mtree *ModuleTreeNode) *TreeNode { + root := TreeNode{ + Key: mtree.Key, + } + + mtree.DepthFirst(func(m *ModuleTreeNode) bool { + for _, mod := range m.Modules { + if len(mod.Rules) == 0 { + root.add(mod.Package.Path, nil) + } + for _, rule := range mod.Rules { + root.add(rule.Ref().GroundPrefix(), rule) + } + } + return false + }) + + // ensure that data.system's TreeNode is hidden + node, tail := root.find(DefaultRootRef.Append(NewTerm(SystemDocumentKey))) + if len(tail) == 0 { // found + node.Hide = true + } + + root.DepthFirst(func(x *TreeNode) bool { + x.sort() + return false + }) + + return &root +} + +func (n *TreeNode) add(path Ref, rule *Rule) { + node, tail := n.find(path) + if len(tail) > 0 { + sub := treeNodeFromRef(tail, rule) + if node.Children == nil { + node.Children = make(map[Value]*TreeNode, 1) + } + node.Children[sub.Key] = sub + node.Sorted = append(node.Sorted, sub.Key) + } else { + if rule != nil { + node.Values = append(node.Values, rule) + } + } +} + +// Size returns the number of rules in the tree. +func (n *TreeNode) Size() int { + s := len(n.Values) + for _, c := range n.Children { + s += c.Size() + } + return s +} + +// Child returns n's child with key k. +func (n *TreeNode) Child(k Value) *TreeNode { + switch k.(type) { + case Ref, Call: + return nil + default: + return n.Children[k] + } +} + +// Find dereferences ref along the tree +func (n *TreeNode) Find(ref Ref) *TreeNode { + node := n + for _, r := range ref { + node = node.Child(r.Value) + if node == nil { + return nil + } + } + return node +} + +// Iteratively dereferences ref along the node's subtree. +// - If matching fails immediately, the tail will contain the full ref. +// - Partial matching will result in a tail of non-zero length. +// - A complete match will result in a 0 length tail. +func (n *TreeNode) find(ref Ref) (*TreeNode, Ref) { + node := n + for i := range ref { + next := node.Child(ref[i].Value) + if next == nil { + tail := make(Ref, len(ref)-i) + copy(tail, ref[i:]) + return node, tail + } + node = next + } + return node, nil +} + +// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If +// f returns true, traversal will not continue to the children of n. +func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) { + if f(n) { + return + } + for _, node := range n.Children { + node.DepthFirst(f) + } +} + +func (n *TreeNode) sort() { + sort.Slice(n.Sorted, func(i, j int) bool { + return n.Sorted[i].Compare(n.Sorted[j]) < 0 + }) +} + +func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode { + depth := len(ref) - 1 + key := ref[depth].Value + node := &TreeNode{ + Key: key, + Children: nil, + } + if rule != nil { + node.Values = []util.T{rule} + } + + for i := len(ref) - 2; i >= 0; i-- { + key := ref[i].Value + node = &TreeNode{ + Key: key, + Children: map[Value]*TreeNode{ref[i+1].Value: node}, + Sorted: []Value{ref[i+1].Value}, + } + } + return node +} + +// flattenChildren flattens all children's rule refs into a sorted array. +func (n *TreeNode) flattenChildren() []Ref { + ret := newRefSet() + for _, sub := range n.Children { // we only want the children, so don't use n.DepthFirst() right away + sub.DepthFirst(func(x *TreeNode) bool { + for _, r := range x.Values { + rule := r.(*Rule) + ret.AddPrefix(rule.Ref()) + } + return false + }) + } + + sort.Slice(ret.s, func(i, j int) bool { + return ret.s[i].Compare(ret.s[j]) < 0 + }) + return ret.s +} + +// Graph represents the graph of dependencies between rules. +type Graph struct { + adj map[util.T]map[util.T]struct{} + radj map[util.T]map[util.T]struct{} + nodes map[util.T]struct{} + sorted []util.T +} + +// NewGraph returns a new Graph based on modules. The list function must return +// the rules referred to directly by the ref. +func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph { + + graph := &Graph{ + adj: map[util.T]map[util.T]struct{}{}, + radj: map[util.T]map[util.T]struct{}{}, + nodes: map[util.T]struct{}{}, + sorted: nil, + } + + // Create visitor to walk a rule AST and add edges to the rule graph for + // each dependency. + vis := func(a *Rule) *GenericVisitor { + stop := false + return NewGenericVisitor(func(x interface{}) bool { + switch x := x.(type) { + case Ref: + for _, b := range list(x) { + for node := b; node != nil; node = node.Else { + graph.addDependency(a, node) + } + } + case *Rule: + if stop { + // Do not recurse into else clauses (which will be handled + // by the outer visitor.) + return true + } + stop = true + } + return false + }) + } + + // Walk over all rules, add them to graph, and build adjacency lists. + for _, module := range modules { + WalkRules(module, func(a *Rule) bool { + graph.addNode(a) + vis(a).Walk(a) + return false + }) + } + + return graph +} + +// Dependencies returns the set of rules that x depends on. +func (g *Graph) Dependencies(x util.T) map[util.T]struct{} { + return g.adj[x] +} + +// Dependents returns the set of rules that depend on x. +func (g *Graph) Dependents(x util.T) map[util.T]struct{} { + return g.radj[x] +} + +// Sort returns a slice of rules sorted by dependencies. If a cycle is found, +// ok is set to false. +func (g *Graph) Sort() (sorted []util.T, ok bool) { + if g.sorted != nil { + return g.sorted, true + } + + sorter := &graphSort{ + sorted: make([]util.T, 0, len(g.nodes)), + deps: g.Dependencies, + marked: map[util.T]struct{}{}, + temp: map[util.T]struct{}{}, + } + + for node := range g.nodes { + if !sorter.Visit(node) { + return nil, false + } + } + + g.sorted = sorter.sorted + return g.sorted, true +} + +func (g *Graph) addDependency(u util.T, v util.T) { + + if _, ok := g.nodes[u]; !ok { + g.addNode(u) + } + + if _, ok := g.nodes[v]; !ok { + g.addNode(v) + } + + edges, ok := g.adj[u] + if !ok { + edges = map[util.T]struct{}{} + g.adj[u] = edges + } + + edges[v] = struct{}{} + + edges, ok = g.radj[v] + if !ok { + edges = map[util.T]struct{}{} + g.radj[v] = edges + } + + edges[u] = struct{}{} +} + +func (g *Graph) addNode(n util.T) { + g.nodes[n] = struct{}{} +} + +type graphSort struct { + sorted []util.T + deps func(util.T) map[util.T]struct{} + marked map[util.T]struct{} + temp map[util.T]struct{} +} + +func (sort *graphSort) Marked(node util.T) bool { + _, marked := sort.marked[node] + return marked +} + +func (sort *graphSort) Visit(node util.T) (ok bool) { + if _, ok := sort.temp[node]; ok { + return false + } + if sort.Marked(node) { + return true + } + sort.temp[node] = struct{}{} + for other := range sort.deps(node) { + if !sort.Visit(other) { + return false + } + } + sort.marked[node] = struct{}{} + delete(sort.temp, node) + sort.sorted = append(sort.sorted, node) + return true +} + +// GraphTraversal is a Traversal that understands the dependency graph +type GraphTraversal struct { + graph *Graph + visited map[util.T]struct{} +} + +// NewGraphTraversal returns a Traversal for the dependency graph +func NewGraphTraversal(graph *Graph) *GraphTraversal { + return &GraphTraversal{ + graph: graph, + visited: map[util.T]struct{}{}, + } +} + +// Edges lists all dependency connections for a given node +func (g *GraphTraversal) Edges(x util.T) []util.T { + r := []util.T{} + for v := range g.graph.Dependencies(x) { + r = append(r, v) + } + return r +} + +// Visited returns whether a node has been visited, setting a node to visited if not +func (g *GraphTraversal) Visited(u util.T) bool { + _, ok := g.visited[u] + g.visited[u] = struct{}{} + return ok +} + +type unsafePair struct { + Expr *Expr + Vars VarSet +} + +type unsafeVarLoc struct { + Var Var + Loc *Location +} + +type unsafeVars map[*Expr]VarSet + +func (vs unsafeVars) Add(e *Expr, v Var) { + if u, ok := vs[e]; ok { + u[v] = struct{}{} + } else { + vs[e] = VarSet{v: struct{}{}} + } +} + +func (vs unsafeVars) Set(e *Expr, s VarSet) { + vs[e] = s +} + +func (vs unsafeVars) Update(o unsafeVars) { + for k, v := range o { + if _, ok := vs[k]; !ok { + vs[k] = VarSet{} + } + vs[k].Update(v) + } +} + +func (vs unsafeVars) Vars() (result []unsafeVarLoc) { + + locs := map[Var]*Location{} + + // If var appears in multiple sets then pick first by location. + for expr, vars := range vs { + for v := range vars { + if locs[v].Compare(expr.Location) > 0 { + locs[v] = expr.Location + } + } + } + + for v, loc := range locs { + result = append(result, unsafeVarLoc{ + Var: v, + Loc: loc, + }) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].Loc.Compare(result[j].Loc) < 0 + }) + + return result +} + +func (vs unsafeVars) Slice() (result []unsafePair) { + for expr, vs := range vs { + result = append(result, unsafePair{ + Expr: expr, + Vars: vs, + }) + } + return +} + +// reorderBodyForSafety returns a copy of the body ordered such that +// left to right evaluation of the body will not encounter unbound variables +// in input positions or negated expressions. +// +// Expressions are added to the re-ordered body as soon as they are considered +// safe. If multiple expressions become safe in the same pass, they are added +// in their original order. This results in minimal re-ordering of the body. +// +// If the body cannot be reordered to ensure safety, the second return value +// contains a mapping of expressions to unsafe variables in those expressions. +func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { + + bodyVars := body.Vars(SafetyCheckVisitorParams) + reordered := make(Body, 0, len(body)) + safe := VarSet{} + unsafe := unsafeVars{} + + for _, e := range body { + for v := range e.Vars(SafetyCheckVisitorParams) { + if globals.Contains(v) { + safe.Add(v) + } else { + unsafe.Add(e, v) + } + } + } + + for { + n := len(reordered) + + for _, e := range body { + if reordered.Contains(e) { + continue + } + + ovs := outputVarsForExpr(e, arity, safe) + + // check closures: is this expression closing over variables that + // haven't been made safe by what's already included in `reordered`? + vs := unsafeVarsInClosures(e) + cv := vs.Intersect(bodyVars).Diff(globals) + uv := cv.Diff(outputVarsForBody(reordered, arity, safe)) + + if len(uv) > 0 { + if uv.Equal(ovs) { // special case "closure-self" + continue + } + unsafe.Set(e, uv) + } + + for v := range unsafe[e] { + if ovs.Contains(v) || safe.Contains(v) { + delete(unsafe[e], v) + } + } + + if len(unsafe[e]) == 0 { + delete(unsafe, e) + reordered.Append(e) + safe.Update(ovs) // this expression's outputs are safe + } + } + + if len(reordered) == n { // fixed point, could not add any expr of body + break + } + } + + // Recursively visit closures and perform the safety checks on them. + // Update the globals at each expression to include the variables that could + // be closed over. + g := globals.Copy() + for i, e := range reordered { + if i > 0 { + g.Update(reordered[i-1].Vars(SafetyCheckVisitorParams)) + } + xform := &bodySafetyTransformer{ + builtins: builtins, + arity: arity, + current: e, + globals: g, + unsafe: unsafe, + } + NewGenericVisitor(xform.Visit).Walk(e) + } + + return reordered, unsafe +} + +type bodySafetyTransformer struct { + builtins map[string]*Builtin + arity func(Ref) int + current *Expr + globals VarSet + unsafe unsafeVars +} + +func (xform *bodySafetyTransformer) Visit(x interface{}) bool { + switch term := x.(type) { + case *Term: + switch x := term.Value.(type) { + case *object: + cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) { + kcpy := k.Copy() + NewGenericVisitor(xform.Visit).Walk(kcpy) + vcpy := v.Copy() + NewGenericVisitor(xform.Visit).Walk(vcpy) + return kcpy, vcpy, nil + }) + term.Value = cpy + return true + case *set: + cpy, _ := x.Map(func(v *Term) (*Term, error) { + vcpy := v.Copy() + NewGenericVisitor(xform.Visit).Walk(vcpy) + return vcpy, nil + }) + term.Value = cpy + return true + case *ArrayComprehension: + xform.reorderArrayComprehensionSafety(x) + return true + case *ObjectComprehension: + xform.reorderObjectComprehensionSafety(x) + return true + case *SetComprehension: + xform.reorderSetComprehensionSafety(x) + return true + } + case *Expr: + if ev, ok := term.Terms.(*Every); ok { + xform.globals.Update(ev.KeyValueVars()) + ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body) + return true + } + } + return false +} + +func (xform *bodySafetyTransformer) reorderComprehensionSafety(tv VarSet, body Body) Body { + bv := body.Vars(SafetyCheckVisitorParams) + bv.Update(xform.globals) + uv := tv.Diff(bv) + for v := range uv { + xform.unsafe.Add(xform.current, v) + } + + r, u := reorderBodyForSafety(xform.builtins, xform.arity, xform.globals, body) + if len(u) == 0 { + return r + } + + xform.unsafe.Update(u) + return body +} + +func (xform *bodySafetyTransformer) reorderArrayComprehensionSafety(ac *ArrayComprehension) { + ac.Body = xform.reorderComprehensionSafety(ac.Term.Vars(), ac.Body) +} + +func (xform *bodySafetyTransformer) reorderObjectComprehensionSafety(oc *ObjectComprehension) { + tv := oc.Key.Vars() + tv.Update(oc.Value.Vars()) + oc.Body = xform.reorderComprehensionSafety(tv, oc.Body) +} + +func (xform *bodySafetyTransformer) reorderSetComprehensionSafety(sc *SetComprehension) { + sc.Body = xform.reorderComprehensionSafety(sc.Term.Vars(), sc.Body) +} + +// unsafeVarsInClosures collects vars that are contained in closures within +// this expression. +func unsafeVarsInClosures(e *Expr) VarSet { + vs := VarSet{} + WalkClosures(e, func(x interface{}) bool { + vis := &VarVisitor{vars: vs} + if ev, ok := x.(*Every); ok { + vis.Walk(ev.Body) + return true + } + vis.Walk(x) + return true + }) + return vs +} + +// OutputVarsFromBody returns all variables which are the "output" for +// the given body. For safety checks this means that they would be +// made safe by the body. +func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet { + return outputVarsForBody(body, c.GetArity, safe) +} + +func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet { + o := safe.Copy() + for _, e := range body { + o.Update(outputVarsForExpr(e, arity, o)) + } + return o.Diff(safe) +} + +// OutputVarsFromExpr returns all variables which are the "output" for +// the given expression. For safety checks this means that they would be +// made safe by the expr. +func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet { + return outputVarsForExpr(expr, c.GetArity, safe) +} + +func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet { + + // Negated expressions must be safe. + if expr.Negated { + return VarSet{} + } + + // With modifier inputs must be safe. + for _, with := range expr.With { + vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams) + vis.Walk(with) + vars := vis.Vars() + unsafe := vars.Diff(safe) + if len(unsafe) > 0 { + return VarSet{} + } + } + + switch terms := expr.Terms.(type) { + case *Term: + return outputVarsForTerms(expr, safe) + case []*Term: + if expr.IsEquality() { + return outputVarsForExprEq(expr, safe) + } + + operator, ok := terms[0].Value.(Ref) + if !ok { + return VarSet{} + } + + ar := arity(operator) + if ar < 0 { + return VarSet{} + } + + return outputVarsForExprCall(expr, ar, safe, terms) + case *Every: + return outputVarsForTerms(terms.Domain, safe) + default: + panic("illegal expression") + } +} + +func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet { + + if !validEqAssignArgCount(expr) { + return safe + } + + output := outputVarsForTerms(expr, safe) + output.Update(safe) + output.Update(Unify(output, expr.Operand(0), expr.Operand(1))) + + return output.Diff(safe) +} + +func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) VarSet { + + output := outputVarsForTerms(expr, safe) + + numInputTerms := arity + 1 + if numInputTerms >= len(terms) { + return output + } + + params := VarVisitorParams{ + SkipClosures: true, + SkipSets: true, + SkipObjectKeys: true, + SkipRefHead: true, + } + vis := NewVarVisitor().WithParams(params) + vis.Walk(Args(terms[:numInputTerms])) + unsafe := vis.Vars().Diff(output).Diff(safe) + + if len(unsafe) > 0 { + return VarSet{} + } + + vis = NewVarVisitor().WithParams(params) + vis.Walk(Args(terms[numInputTerms:])) + output.Update(vis.vars) + return output +} + +func outputVarsForTerms(expr interface{}, safe VarSet) VarSet { + output := VarSet{} + WalkTerms(expr, func(x *Term) bool { + switch r := x.Value.(type) { + case *SetComprehension, *ArrayComprehension, *ObjectComprehension: + return true + case Ref: + if !isRefSafe(r, safe) { + return true + } + output.Update(r.OutputVars()) + return false + } + return false + }) + return output +} + +type equalityFactory struct { + gen *localVarGenerator +} + +func newEqualityFactory(gen *localVarGenerator) *equalityFactory { + return &equalityFactory{gen} +} + +func (f *equalityFactory) Generate(other *Term) *Expr { + term := NewTerm(f.gen.Generate()).SetLocation(other.Location) + expr := Equality.Expr(term, other) + expr.Generated = true + expr.Location = other.Location + return expr +} + +type localVarGenerator struct { + exclude VarSet + suffix string + next int +} + +func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator { + exclude := NewVarSet() + vis := &VarVisitor{vars: exclude} + for _, key := range sorted { + vis.Walk(modules[key]) + } + return &localVarGenerator{exclude: exclude, next: 0} +} + +func newLocalVarGenerator(suffix string, node interface{}) *localVarGenerator { + exclude := NewVarSet() + vis := &VarVisitor{vars: exclude} + vis.Walk(node) + return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0} +} + +func (l *localVarGenerator) Generate() Var { + for { + result := Var("__local" + l.suffix + strconv.Itoa(l.next) + "__") + l.next++ + if !l.exclude.Contains(result) { + return result + } + } +} + +func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef { + + globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports + + // Populate globals with exports within the package. + for _, ref := range rules { + v := ref[0].Value.(Var) + globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))} + } + + // Populate globals with imports. + for _, imp := range imports { + path := imp.Path.Value.(Ref) + if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) { + continue // ignore future and rego imports + } + globals[imp.Name()] = &usedRef{ref: path} + } + + return globals +} + +func requiresEval(x *Term) bool { + if x == nil { + return false + } + return ContainsRefs(x) || ContainsComprehensions(x) +} + +func resolveRef(globals map[Var]*usedRef, ignore *declaredVarStack, ref Ref) Ref { + + r := Ref{} + for i, x := range ref { + switch v := x.Value.(type) { + case Var: + if g, ok := globals[v]; ok && !ignore.Contains(v) { + cpy := g.ref.Copy() + for i := range cpy { + cpy[i].SetLocation(x.Location) + } + if i == 0 { + r = cpy + } else { + r = append(r, NewTerm(cpy).SetLocation(x.Location)) + } + g.used = true + } else { + r = append(r, x) + } + case Ref, *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call: + r = append(r, resolveRefsInTerm(globals, ignore, x)) + default: + r = append(r, x) + } + } + + return r +} + +type usedRef struct { + ref Ref + used bool +} + +func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error { + ignore := &declaredVarStack{} + + vars := NewVarSet() + var vis *GenericVisitor + var err error + + // Walk args to collect vars and transform body so that callers can shadow + // root documents. + vis = NewGenericVisitor(func(x interface{}) bool { + if err != nil { + return true + } + switch x := x.(type) { + case Var: + vars.Add(x) + + // Object keys cannot be pattern matched so only walk values. + case *object: + x.Foreach(func(_, v *Term) { + vis.Walk(v) + }) + + // Skip terms that could contain vars that cannot be pattern matched. + case Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call: + return true + + case *Term: + if _, ok := x.Value.(Ref); ok { + if RootDocumentRefs.Contains(x) { + // We could support args named input, data, etc. however + // this would require rewriting terms in the head and body. + // Preventing root document shadowing is simpler, and + // arguably, will prevent confusing names from being used. + // NOTE: this check is also performed as part of strict-mode in + // checkRootDocumentOverrides. + err = fmt.Errorf("args must not shadow %v (use a different variable name)", x) + return true + } + } + } + return false + }) + + vis.Walk(rule.Head.Args) + + if err != nil { + return err + } + + ignore.Push(vars) + ignore.Push(declaredVars(rule.Body)) + + ref := rule.Head.Ref() + for i := 1; i < len(ref); i++ { + ref[i] = resolveRefsInTerm(globals, ignore, ref[i]) + } + if rule.Head.Key != nil { + rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key) + } + + if rule.Head.Value != nil { + rule.Head.Value = resolveRefsInTerm(globals, ignore, rule.Head.Value) + } + + rule.Body = resolveRefsInBody(globals, ignore, rule.Body) + return nil +} + +func resolveRefsInBody(globals map[Var]*usedRef, ignore *declaredVarStack, body Body) Body { + r := make([]*Expr, 0, len(body)) + for _, expr := range body { + r = append(r, resolveRefsInExpr(globals, ignore, expr)) + } + return r +} + +func resolveRefsInExpr(globals map[Var]*usedRef, ignore *declaredVarStack, expr *Expr) *Expr { + cpy := *expr + switch ts := expr.Terms.(type) { + case *Term: + cpy.Terms = resolveRefsInTerm(globals, ignore, ts) + case []*Term: + buf := make([]*Term, len(ts)) + for i := 0; i < len(ts); i++ { + buf[i] = resolveRefsInTerm(globals, ignore, ts[i]) + } + cpy.Terms = buf + case *SomeDecl: + if val, ok := ts.Symbols[0].Value.(Call); ok { + cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}} + } + case *Every: + locals := NewVarSet() + if ts.Key != nil { + locals.Update(ts.Key.Vars()) + } + locals.Update(ts.Value.Vars()) + ignore.Push(locals) + cpy.Terms = &Every{ + Key: ts.Key.Copy(), // TODO(sr): do more? + Value: ts.Value.Copy(), // TODO(sr): do more? + Domain: resolveRefsInTerm(globals, ignore, ts.Domain), + Body: resolveRefsInBody(globals, ignore, ts.Body), + } + ignore.Pop() + } + for _, w := range cpy.With { + w.Target = resolveRefsInTerm(globals, ignore, w.Target) + w.Value = resolveRefsInTerm(globals, ignore, w.Value) + } + return &cpy +} + +func resolveRefsInTerm(globals map[Var]*usedRef, ignore *declaredVarStack, term *Term) *Term { + switch v := term.Value.(type) { + case Var: + if g, ok := globals[v]; ok && !ignore.Contains(v) { + cpy := g.ref.Copy() + for i := range cpy { + cpy[i].SetLocation(term.Location) + } + g.used = true + return NewTerm(cpy).SetLocation(term.Location) + } + return term + case Ref: + fqn := resolveRef(globals, ignore, v) + cpy := *term + cpy.Value = fqn + return &cpy + case *object: + cpy := *term + cpy.Value, _ = v.Map(func(k, v *Term) (*Term, *Term, error) { + k = resolveRefsInTerm(globals, ignore, k) + v = resolveRefsInTerm(globals, ignore, v) + return k, v, nil + }) + return &cpy + case *Array: + cpy := *term + cpy.Value = NewArray(resolveRefsInTermArray(globals, ignore, v)...) + return &cpy + case Call: + cpy := *term + cpy.Value = Call(resolveRefsInTermSlice(globals, ignore, v)) + return &cpy + case Set: + s, _ := v.Map(func(e *Term) (*Term, error) { + return resolveRefsInTerm(globals, ignore, e), nil + }) + cpy := *term + cpy.Value = s + return &cpy + case *ArrayComprehension: + ac := &ArrayComprehension{} + ignore.Push(declaredVars(v.Body)) + ac.Term = resolveRefsInTerm(globals, ignore, v.Term) + ac.Body = resolveRefsInBody(globals, ignore, v.Body) + cpy := *term + cpy.Value = ac + ignore.Pop() + return &cpy + case *ObjectComprehension: + oc := &ObjectComprehension{} + ignore.Push(declaredVars(v.Body)) + oc.Key = resolveRefsInTerm(globals, ignore, v.Key) + oc.Value = resolveRefsInTerm(globals, ignore, v.Value) + oc.Body = resolveRefsInBody(globals, ignore, v.Body) + cpy := *term + cpy.Value = oc + ignore.Pop() + return &cpy + case *SetComprehension: + sc := &SetComprehension{} + ignore.Push(declaredVars(v.Body)) + sc.Term = resolveRefsInTerm(globals, ignore, v.Term) + sc.Body = resolveRefsInBody(globals, ignore, v.Body) + cpy := *term + cpy.Value = sc + ignore.Pop() + return &cpy + default: + return term + } +} + +func resolveRefsInTermArray(globals map[Var]*usedRef, ignore *declaredVarStack, terms *Array) []*Term { + cpy := make([]*Term, terms.Len()) + for i := 0; i < terms.Len(); i++ { + cpy[i] = resolveRefsInTerm(globals, ignore, terms.Elem(i)) + } + return cpy +} + +func resolveRefsInTermSlice(globals map[Var]*usedRef, ignore *declaredVarStack, terms []*Term) []*Term { + cpy := make([]*Term, len(terms)) + for i := 0; i < len(terms); i++ { + cpy[i] = resolveRefsInTerm(globals, ignore, terms[i]) + } + return cpy +} + +type declaredVarStack []VarSet + +func (s declaredVarStack) Contains(v Var) bool { + for i := len(s) - 1; i >= 0; i-- { + if _, ok := s[i][v]; ok { + return ok + } + } + return false +} + +func (s declaredVarStack) Add(v Var) { + s[len(s)-1].Add(v) +} + +func (s *declaredVarStack) Push(vs VarSet) { + *s = append(*s, vs) +} + +func (s *declaredVarStack) Pop() { + curr := *s + *s = curr[:len(curr)-1] +} + +func declaredVars(x interface{}) VarSet { + vars := NewVarSet() + vis := NewGenericVisitor(func(x interface{}) bool { + switch x := x.(type) { + case *Expr: + if x.IsAssignment() && validEqAssignArgCount(x) { + WalkVars(x.Operand(0), func(v Var) bool { + vars.Add(v) + return false + }) + } else if decl, ok := x.Terms.(*SomeDecl); ok { + for i := range decl.Symbols { + switch val := decl.Symbols[i].Value.(type) { + case Var: + vars.Add(val) + case Call: + args := val[1:] + if len(args) == 3 { // some x, y in xs + WalkVars(args[1], func(v Var) bool { + vars.Add(v) + return false + }) + } + // some x in xs + WalkVars(args[0], func(v Var) bool { + vars.Add(v) + return false + }) + } + } + } + case *ArrayComprehension, *SetComprehension, *ObjectComprehension: + return true + } + return false + }) + vis.Walk(x) + return vars +} + +// rewriteComprehensionTerms will rewrite comprehensions so that the term part +// is bound to a variable in the body. This allows any type of term to be used +// in the term part (even if the term requires evaluation.) +// +// For instance, given the following comprehension: +// +// [x[0] | x = y[_]; y = [1,2,3]] +// +// The comprehension would be rewritten as: +// +// [__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]] +func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) { + return TransformComprehensions(node, func(x interface{}) (Value, error) { + switch x := x.(type) { + case *ArrayComprehension: + if requiresEval(x.Term) { + expr := f.Generate(x.Term) + x.Term = expr.Operand(0) + x.Body.Append(expr) + } + return x, nil + case *SetComprehension: + if requiresEval(x.Term) { + expr := f.Generate(x.Term) + x.Term = expr.Operand(0) + x.Body.Append(expr) + } + return x, nil + case *ObjectComprehension: + if requiresEval(x.Key) { + expr := f.Generate(x.Key) + x.Key = expr.Operand(0) + x.Body.Append(expr) + } + if requiresEval(x.Value) { + expr := f.Generate(x.Value) + x.Value = expr.Operand(0) + x.Body.Append(expr) + } + return x, nil + } + panic("illegal type") + }) +} + +// rewriteEquals will rewrite exprs under x as unification calls instead of == +// calls. For example: +// +// data.foo == data.bar is rewritten as data.foo = data.bar +// +// This stage should only run the safety check (since == is a built-in with no +// outputs, so the inputs must not be marked as safe.) +// +// This stage is not executed by the query compiler by default because when +// callers specify == instead of = they expect to receive a true/false/undefined +// result back whereas with = the result is only ever true/undefined. For +// partial evaluation cases we do want to rewrite == to = to simplify the +// result. +func rewriteEquals(x interface{}) (modified bool) { + doubleEq := Equal.Ref() + unifyOp := Equality.Ref() + t := NewGenericTransformer(func(x interface{}) (interface{}, error) { + if x, ok := x.(*Expr); ok && x.IsCall() { + operator := x.Operator() + if operator.Equal(doubleEq) && len(x.Operands()) == 2 { + modified = true + x.SetOperator(NewTerm(unifyOp)) + } + } + return x, nil + }) + _, _ = Transform(t, x) // ignore error + return modified +} + +func rewriteTestEqualities(f *equalityFactory, body Body) Body { + result := make(Body, 0, len(body)) + for _, expr := range body { + // We can't rewrite negated expressions; if the extracted term is undefined, evaluation would fail before + // reaching the negation check. + if !expr.Negated && !expr.Generated { + switch { + case expr.IsEquality(): + terms := expr.Terms.([]*Term) + result, terms[1] = rewriteDynamicsShallow(expr, f, terms[1], result) + result, terms[2] = rewriteDynamicsShallow(expr, f, terms[2], result) + case expr.IsEvery(): + // We rewrite equalities inside of every-bodies as a fail here will be the cause of the test-rule fail. + // Failures inside other expressions with closures, such as comprehensions, won't cause the test-rule to fail, so we skip those. + every := expr.Terms.(*Every) + every.Body = rewriteTestEqualities(f, every.Body) + } + } + result = appendExpr(result, expr) + } + return result +} + +func rewriteDynamicsShallow(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) { + switch term.Value.(type) { + case Ref, *ArrayComprehension, *SetComprehension, *ObjectComprehension: + generated := f.Generate(term) + generated.With = original.With + result.Append(generated) + connectGeneratedExprs(original, generated) + return result, result[len(result)-1].Operand(0) + } + return result, term +} + +// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and +// comprehensions) are bound to vars earlier in the query. This translation +// results in eager evaluation. +// +// For instance, given the following query: +// +// foo(data.bar) = 1 +// +// The rewritten version will be: +// +// __local0__ = data.bar; foo(__local0__) = 1 +func rewriteDynamics(f *equalityFactory, body Body) Body { + result := make(Body, 0, len(body)) + for _, expr := range body { + switch { + case expr.IsEquality(): + result = rewriteDynamicsEqExpr(f, expr, result) + case expr.IsCall(): + result = rewriteDynamicsCallExpr(f, expr, result) + case expr.IsEvery(): + result = rewriteDynamicsEveryExpr(f, expr, result) + default: + result = rewriteDynamicsTermExpr(f, expr, result) + } + } + return result +} + +func appendExpr(body Body, expr *Expr) Body { + body.Append(expr) + return body +} + +func rewriteDynamicsEqExpr(f *equalityFactory, expr *Expr, result Body) Body { + if !validEqAssignArgCount(expr) { + return appendExpr(result, expr) + } + terms := expr.Terms.([]*Term) + result, terms[1] = rewriteDynamicsInTerm(expr, f, terms[1], result) + result, terms[2] = rewriteDynamicsInTerm(expr, f, terms[2], result) + return appendExpr(result, expr) +} + +func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body { + terms := expr.Terms.([]*Term) + for i := 1; i < len(terms); i++ { + result, terms[i] = rewriteDynamicsOne(expr, f, terms[i], result) + } + return appendExpr(result, expr) +} + +func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body { + ev := expr.Terms.(*Every) + result, ev.Domain = rewriteDynamicsOne(expr, f, ev.Domain, result) + ev.Body = rewriteDynamics(f, ev.Body) + return appendExpr(result, expr) +} + +func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body { + term := expr.Terms.(*Term) + result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result) + return appendExpr(result, expr) +} + +func rewriteDynamicsInTerm(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) { + switch v := term.Value.(type) { + case Ref: + for i := 1; i < len(v); i++ { + result, v[i] = rewriteDynamicsOne(original, f, v[i], result) + } + case *ArrayComprehension: + v.Body = rewriteDynamics(f, v.Body) + case *SetComprehension: + v.Body = rewriteDynamics(f, v.Body) + case *ObjectComprehension: + v.Body = rewriteDynamics(f, v.Body) + default: + result, term = rewriteDynamicsOne(original, f, term, result) + } + return result, term +} + +func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) { + switch v := term.Value.(type) { + case Ref: + for i := 1; i < len(v); i++ { + result, v[i] = rewriteDynamicsOne(original, f, v[i], result) + } + generated := f.Generate(term) + generated.With = original.With + result.Append(generated) + connectGeneratedExprs(original, generated) + return result, result[len(result)-1].Operand(0) + case *Array: + for i := 0; i < v.Len(); i++ { + var t *Term + result, t = rewriteDynamicsOne(original, f, v.Elem(i), result) + v.set(i, t) + } + return result, term + case *object: + cpy := NewObject() + v.Foreach(func(key, value *Term) { + result, key = rewriteDynamicsOne(original, f, key, result) + result, value = rewriteDynamicsOne(original, f, value, result) + cpy.Insert(key, value) + }) + return result, NewTerm(cpy).SetLocation(term.Location) + case Set: + cpy := NewSet() + for _, term := range v.Slice() { + var rw *Term + result, rw = rewriteDynamicsOne(original, f, term, result) + cpy.Add(rw) + } + return result, NewTerm(cpy).SetLocation(term.Location) + case *ArrayComprehension: + var extra *Expr + v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term) + result.Append(extra) + connectGeneratedExprs(original, extra) + return result, result[len(result)-1].Operand(0) + case *SetComprehension: + var extra *Expr + v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term) + result.Append(extra) + connectGeneratedExprs(original, extra) + return result, result[len(result)-1].Operand(0) + case *ObjectComprehension: + var extra *Expr + v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term) + result.Append(extra) + connectGeneratedExprs(original, extra) + return result, result[len(result)-1].Operand(0) + } + return result, term +} + +func rewriteDynamicsComprehensionBody(original *Expr, f *equalityFactory, body Body, term *Term) (Body, *Expr) { + body = rewriteDynamics(f, body) + generated := f.Generate(term) + generated.With = original.With + return body, generated +} + +func rewriteExprTermsInHead(gen *localVarGenerator, rule *Rule) { + for i := range rule.Head.Args { + support, output := expandExprTerm(gen, rule.Head.Args[i]) + for j := range support { + rule.Body.Append(support[j]) + } + rule.Head.Args[i] = output + } + if rule.Head.Key != nil { + support, output := expandExprTerm(gen, rule.Head.Key) + for i := range support { + rule.Body.Append(support[i]) + } + rule.Head.Key = output + } + if rule.Head.Value != nil { + support, output := expandExprTerm(gen, rule.Head.Value) + for i := range support { + rule.Body.Append(support[i]) + } + rule.Head.Value = output + } +} + +func rewriteExprTermsInBody(gen *localVarGenerator, body Body) Body { + cpy := make(Body, 0, len(body)) + for i := 0; i < len(body); i++ { + for _, expr := range expandExpr(gen, body[i]) { + cpy.Append(expr) + } + } + return cpy +} + +func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) { + for i := range expr.With { + extras, value := expandExprTerm(gen, expr.With[i].Value) + expr.With[i].Value = value + result = append(result, extras...) + } + switch terms := expr.Terms.(type) { + case *Term: + extras, term := expandExprTerm(gen, terms) + if len(expr.With) > 0 { + for i := range extras { + extras[i].With = expr.With + } + } + result = append(result, extras...) + expr.Terms = term + result = append(result, expr) + case []*Term: + for i := 1; i < len(terms); i++ { + var extras []*Expr + extras, terms[i] = expandExprTerm(gen, terms[i]) + connectGeneratedExprs(expr, extras...) + if len(expr.With) > 0 { + for i := range extras { + extras[i].With = expr.With + } + } + result = append(result, extras...) + } + result = append(result, expr) + case *Every: + var extras []*Expr + + term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location) + eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location) + eq.Generated = true + eq.With = expr.With + extras = expandExpr(gen, eq) + terms.Domain = term + + terms.Body = rewriteExprTermsInBody(gen, terms.Body) + result = append(result, extras...) + result = append(result, expr) + } + return +} + +func connectGeneratedExprs(parent *Expr, children ...*Expr) { + for _, child := range children { + child.generatedFrom = parent + parent.generates = append(parent.generates, child) + } +} + +func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) { + output = term + switch v := term.Value.(type) { + case Call: + for i := 1; i < len(v); i++ { + var extras []*Expr + extras, v[i] = expandExprTerm(gen, v[i]) + support = append(support, extras...) + } + output = NewTerm(gen.Generate()).SetLocation(term.Location) + expr := v.MakeExpr(output).SetLocation(term.Location) + expr.Generated = true + support = append(support, expr) + case Ref: + support = expandExprRef(gen, v) + case *Array: + support = expandExprTermArray(gen, v) + case *object: + cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { + extras1, expandedKey := expandExprTerm(gen, k) + extras2, expandedValue := expandExprTerm(gen, v) + support = append(support, extras1...) + support = append(support, extras2...) + return expandedKey, expandedValue, nil + }) + output = NewTerm(cpy).SetLocation(term.Location) + case Set: + cpy, _ := v.Map(func(x *Term) (*Term, error) { + extras, expanded := expandExprTerm(gen, x) + support = append(support, extras...) + return expanded, nil + }) + output = NewTerm(cpy).SetLocation(term.Location) + case *ArrayComprehension: + support, term := expandExprTerm(gen, v.Term) + for i := range support { + v.Body.Append(support[i]) + } + v.Term = term + v.Body = rewriteExprTermsInBody(gen, v.Body) + case *SetComprehension: + support, term := expandExprTerm(gen, v.Term) + for i := range support { + v.Body.Append(support[i]) + } + v.Term = term + v.Body = rewriteExprTermsInBody(gen, v.Body) + case *ObjectComprehension: + support, key := expandExprTerm(gen, v.Key) + for i := range support { + v.Body.Append(support[i]) + } + v.Key = key + support, value := expandExprTerm(gen, v.Value) + for i := range support { + v.Body.Append(support[i]) + } + v.Value = value + v.Body = rewriteExprTermsInBody(gen, v.Body) + } + return +} + +func expandExprRef(gen *localVarGenerator, v []*Term) (support []*Expr) { + // Start by calling a normal expandExprTerm on all terms. + support = expandExprTermSlice(gen, v) + + // Rewrite references in order to support indirect references. We rewrite + // e.g. + // + // [1, 2, 3][i] + // + // to + // + // __local_var = [1, 2, 3] + // __local_var[i] + // + // to support these. This only impacts the reference subject, i.e. the + // first item in the slice. + var subject = v[0] + switch subject.Value.(type) { + case *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call: + f := newEqualityFactory(gen) + assignToLocal := f.Generate(subject) + support = append(support, assignToLocal) + v[0] = assignToLocal.Operand(0) + } + return +} + +func expandExprTermArray(gen *localVarGenerator, arr *Array) (support []*Expr) { + for i := 0; i < arr.Len(); i++ { + extras, v := expandExprTerm(gen, arr.Elem(i)) + arr.set(i, v) + support = append(support, extras...) + } + return +} + +func expandExprTermSlice(gen *localVarGenerator, v []*Term) (support []*Expr) { + for i := 0; i < len(v); i++ { + var extras []*Expr + extras, v[i] = expandExprTerm(gen, v[i]) + support = append(support, extras...) + } + return +} + +type localDeclaredVars struct { + vars []*declaredVarSet + + // rewritten contains a mapping of *all* user-defined variables + // that have been rewritten whereas vars contains the state + // from the current query (not any nested queries, and all vars + // seen). + rewritten map[Var]Var + + // indicates if an assignment (:= operator) has been seen *ever* + assignment bool +} + +type varOccurrence int + +const ( + newVar varOccurrence = iota + argVar + seenVar + assignedVar + declaredVar +) + +type declaredVarSet struct { + vs map[Var]Var + reverse map[Var]Var + occurrence map[Var]varOccurrence + count map[Var]int +} + +func newDeclaredVarSet() *declaredVarSet { + return &declaredVarSet{ + vs: map[Var]Var{}, + reverse: map[Var]Var{}, + occurrence: map[Var]varOccurrence{}, + count: map[Var]int{}, + } +} + +func newLocalDeclaredVars() *localDeclaredVars { + return &localDeclaredVars{ + vars: []*declaredVarSet{newDeclaredVarSet()}, + rewritten: map[Var]Var{}, + } +} + +func (s *localDeclaredVars) Copy() *localDeclaredVars { + stack := &localDeclaredVars{ + vars: []*declaredVarSet{}, + rewritten: map[Var]Var{}, + } + + for i := range s.vars { + stack.vars = append(stack.vars, newDeclaredVarSet()) + for k, v := range s.vars[i].vs { + stack.vars[0].vs[k] = v + } + for k, v := range s.vars[i].reverse { + stack.vars[0].reverse[k] = v + } + for k, v := range s.vars[i].count { + stack.vars[0].count[k] = v + } + for k, v := range s.vars[i].occurrence { + stack.vars[0].occurrence[k] = v + } + } + + for k, v := range s.rewritten { + stack.rewritten[k] = v + } + + return stack +} + +func (s *localDeclaredVars) Push() { + s.vars = append(s.vars, newDeclaredVarSet()) +} + +func (s *localDeclaredVars) Pop() *declaredVarSet { + sl := s.vars + curr := sl[len(sl)-1] + s.vars = sl[:len(sl)-1] + return curr +} + +func (s localDeclaredVars) Peek() *declaredVarSet { + return s.vars[len(s.vars)-1] +} + +func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) { + elem := s.vars[len(s.vars)-1] + elem.vs[x] = y + elem.reverse[y] = x + elem.occurrence[x] = occurrence + + elem.count[x] = 1 + + // If the variable has been rewritten (where x != y, with y being + // the generated value), store it in the map of rewritten vars. + // Assume that the generated values are unique for the compilation. + if !x.Equal(y) { + s.rewritten[y] = x + } +} + +func (s localDeclaredVars) Declared(x Var) (y Var, ok bool) { + for i := len(s.vars) - 1; i >= 0; i-- { + if y, ok = s.vars[i].vs[x]; ok { + return + } + } + return +} + +// Occurrence returns a flag that indicates whether x has occurred in the +// current scope. +func (s localDeclaredVars) Occurrence(x Var) varOccurrence { + return s.vars[len(s.vars)-1].occurrence[x] +} + +// GlobalOccurrence returns a flag that indicates whether x has occurred in the +// global scope. +func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) { + for i := len(s.vars) - 1; i >= 0; i-- { + if occ, ok := s.vars[i].occurrence[x]; ok { + return occ, true + } + } + return newVar, false +} + +// Seen marks x as seen by incrementing its counter +func (s localDeclaredVars) Seen(x Var) { + for i := len(s.vars) - 1; i >= 0; i-- { + dvs := s.vars[i] + if c, ok := dvs.count[x]; ok { + dvs.count[x] = c + 1 + return + } + } + + s.vars[len(s.vars)-1].count[x] = 1 +} + +// Count returns how many times x has been seen +func (s localDeclaredVars) Count(x Var) int { + for i := len(s.vars) - 1; i >= 0; i-- { + if c, ok := s.vars[i].count[x]; ok { + return c + } + } + + return 0 +} + +// rewriteLocalVars rewrites bodies to remove assignment/declaration +// expressions. For example: +// +// a := 1; p[a] +// +// Is rewritten to: +// +// __local0__ = 1; p[__local0__] +// +// During rewriting, assignees are validated to prevent use before declaration. +func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, strict bool) (Body, map[Var]Var, Errors) { + var errs Errors + body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs, strict) + return body, stack.Peek().vs, errs +} + +func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) { + + var cpy Body + + for i := range body { + var expr *Expr + switch { + case body[i].IsAssignment(): + stack.assignment = true + expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs, strict) + case body[i].IsSome(): + expr, errs = rewriteSomeDeclStatement(g, stack, body[i], errs, strict) + case body[i].IsEvery(): + expr, errs = rewriteEveryStatement(g, stack, body[i], errs, strict) + default: + expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs, strict) + } + if expr != nil { + cpy.Append(expr) + } + } + + // If the body only contained a var statement it will be empty at this + // point. Append true to the body to ensure that it's non-empty (zero length + // bodies are not supported.) + if len(cpy) == 0 { + cpy.Append(NewExpr(BooleanTerm(true))) + } + + errs = checkUnusedAssignedVars(body, stack, used, errs, strict) + return cpy, checkUnusedDeclaredVars(body, stack, used, cpy, errs) +} + +func checkUnusedAssignedVars(body Body, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors { + + if !strict || len(errs) > 0 { + return errs + } + + dvs := stack.Peek() + unused := NewVarSet() + + for v, occ := range dvs.occurrence { + // A var that was assigned in this scope must have been seen (used) more than once (the time of assignment) in + // the same, or nested, scope to be counted as used. + if !v.IsWildcard() && stack.Count(v) <= 1 && occ == assignedVar { + unused.Add(dvs.vs[v]) + } + } + + rewrittenUsed := NewVarSet() + for v := range used { + if gv, ok := stack.Declared(v); ok { + rewrittenUsed.Add(gv) + } else { + rewrittenUsed.Add(v) + } + } + + unused = unused.Diff(rewrittenUsed) + + for _, gv := range unused.Sorted() { + found := false + for i := range body { + if body[i].Vars(VarVisitorParams{}).Contains(gv) { + errs = append(errs, NewError(CompileErr, body[i].Loc(), "assigned var %v unused", dvs.reverse[gv])) + found = true + break + } + } + if !found { + errs = append(errs, NewError(CompileErr, body[0].Loc(), "assigned var %v unused", dvs.reverse[gv])) + } + } + + return errs +} + +func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors { + + // NOTE(tsandall): Do not generate more errors if there are existing + // declaration errors. + if len(errs) > 0 { + return errs + } + + dvs := stack.Peek() + declared := NewVarSet() + + for v, occ := range dvs.occurrence { + if occ == declaredVar { + declared.Add(dvs.vs[v]) + } + } + + bodyvars := cpy.Vars(VarVisitorParams{}) + + for v := range used { + if gv, ok := stack.Declared(v); ok { + bodyvars.Add(gv) + } else { + bodyvars.Add(v) + } + } + + unused := declared.Diff(bodyvars).Diff(used) + + for _, gv := range unused.Sorted() { + rv := dvs.reverse[gv] + if !rv.IsGenerated() { + // Scan through body exprs, looking for a match between the + // bad var's original name, and each expr's declared vars. + foundUnusedVarByName := false + for i := range body { + varsDeclaredInExpr := declaredVars(body[i]) + if varsDeclaredInExpr.Contains(dvs.reverse[gv]) { + // TODO(philipc): Clean up the offset logic here when the parser + // reports more accurate locations. + errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", dvs.reverse[gv])) + foundUnusedVarByName = true + break + } + } + // Default error location returned. + if !foundUnusedVarByName { + errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", dvs.reverse[gv])) + } + } + } + + return errs +} + +func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { + e := expr.Copy() + every := e.Terms.(*Every) + + errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict) + + stack.Push() + defer stack.Pop() + + // if the key exists, rewrite + if every.Key != nil { + if v := every.Key.Value.(Var); !v.IsWildcard() { + gv, err := rewriteDeclaredVar(g, stack, v, declaredVar) + if err != nil { + return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) //nolint:govet + } + every.Key.Value = gv + } + } else { // if the key doesn't exist, add dummy local + every.Key = NewTerm(g.Generate()) + } + + // value is always present + if v := every.Value.Value.(Var); !v.IsWildcard() { + gv, err := rewriteDeclaredVar(g, stack, v, declaredVar) + if err != nil { + return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) //nolint:govet + } + every.Value.Value = gv + } + + used := NewVarSet() + every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict) + + return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict) +} + +func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { + e := expr.Copy() + decl := e.Terms.(*SomeDecl) + for i := range decl.Symbols { + switch v := decl.Symbols[i].Value.(type) { + case Var: + if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil { + return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error())) //nolint:govet + } + case Call: + var key, val, container *Term + switch len(v) { + case 4: // member3 + key = v[1] + val = v[2] + container = v[3] + case 3: // member + key = NewTerm(g.Generate()) + val = v[1] + container = v[2] + } + + var rhs *Term + switch c := container.Value.(type) { + case Ref: + rhs = RefTerm(append(c, key)...) + default: + rhs = RefTerm(container, key) + } + e.Terms = []*Term{ + RefTerm(VarTerm(Equality.Name)), val, rhs, + } + + for _, v0 := range outputVarsForExprEq(e, container.Vars()).Sorted() { + if _, err := rewriteDeclaredVar(g, stack, v0, declaredVar); err != nil { + return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error())) //nolint:govet + } + } + return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict) + } + } + return nil, errs +} + +func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { + vis := NewGenericVisitor(func(x interface{}) bool { + var stop bool + switch x := x.(type) { + case *Term: + stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs, strict) + case *With: + stop, errs = true, rewriteDeclaredVarsInWithRecursive(g, stack, x, errs, strict) + } + return stop + }) + vis.Walk(expr) + return expr, errs +} + +func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { + + if expr.Negated { + errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression")) + return expr, errs + } + + numErrsBefore := len(errs) + + if !validEqAssignArgCount(expr) { + return expr, errs + } + + // Rewrite terms on right hand side capture seen vars and recursively + // process comprehensions before left hand side is processed. Also + // rewrite with modifier. + errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs, strict) + + for _, w := range expr.With { + errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs, strict) + } + + // Rewrite vars on left hand side with unique names. Catch redeclaration + // and invalid term types here. + var vis func(t *Term) bool + + vis = func(t *Term) bool { + switch v := t.Value.(type) { + case Var: + if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil { + errs = append(errs, NewError(CompileErr, t.Location, err.Error())) //nolint:govet + } else { + t.Value = gv + } + return true + case *Array: + return false + case *object: + v.Foreach(func(_, v *Term) { + WalkTerms(v, vis) + }) + return true + case Ref: + if RootDocumentRefs.Contains(t) { + if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil { + errs = append(errs, NewError(CompileErr, t.Location, err.Error())) //nolint:govet + } else { + t.Value = gv + } + return true + } + } + errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value))) + return true + } + + WalkTerms(expr.Operand(0), vis) + + if len(errs) == numErrsBefore { + loc := expr.Operator()[0].Location + expr.SetOperator(RefTerm(VarTerm(Equality.Name).SetLocation(loc)).SetLocation(loc)) + } + + return expr, errs +} + +func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) (bool, Errors) { + switch v := term.Value.(type) { + case Var: + if gv, ok := stack.Declared(v); ok { + term.Value = gv + stack.Seen(v) + } else if stack.Occurrence(v) == newVar { + stack.Insert(v, v, seenVar) + } + case Ref: + if RootDocumentRefs.Contains(term) { + x := v[0].Value.(Var) + if occ, ok := stack.GlobalOccurrence(x); ok && occ != seenVar { + gv, _ := stack.Declared(x) + term.Value = gv + } + + return true, errs + } + return false, errs + case Call: + ref := v[0] + WalkVars(ref, func(v Var) bool { + if gv, ok := stack.Declared(v); ok && !gv.Equal(v) { + // We will rewrite the ref of a function call, which is never ok since we don't have first-class functions. + errs = append(errs, NewError(CompileErr, term.Location, "called function %s shadowed", ref)) + return true + } + return false + }) + return false, errs + case *object: + cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { + kcpy := k.Copy() + errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs, strict) + return kcpy, v, nil + }) + term.Value = cpy + case Set: + cpy, _ := v.Map(func(elem *Term) (*Term, error) { + elemcpy := elem.Copy() + errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs, strict) + return elemcpy, nil + }) + term.Value = cpy + case *ArrayComprehension: + errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs, strict) + case *SetComprehension: + errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs, strict) + case *ObjectComprehension: + errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs, strict) + default: + return false, errs + } + return true, errs +} + +func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) Errors { + WalkTerms(term, func(t *Term) bool { + var stop bool + stop, errs = rewriteDeclaredVarsInTerm(g, stack, t, errs, strict) + return stop + }) + return errs +} + +func rewriteDeclaredVarsInWithRecursive(g *localVarGenerator, stack *localDeclaredVars, w *With, errs Errors, strict bool) Errors { + // NOTE(sr): `with input as` and `with input.a.b.c as` are deliberately skipped here: `input` could + // have been shadowed by a local variable/argument but should NOT be replaced in the `with` target. + // + // We cannot drop `input` from the stack since it's conceivable to do `with input[input] as` where + // the second input is meant to be the local var. It's a terrible idea, but when you're shadowing + // `input` those might be your thing. + errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Target, errs, strict) + if sdwInput, ok := stack.Declared(InputRootDocument.Value.(Var)); ok { // Was "input" shadowed... + switch value := w.Target.Value.(type) { + case Var: + if sdwInput.Equal(value) { // ...and replaced? If so, fix it + w.Target.Value = InputRootRef + } + case Ref: + if sdwInput.Equal(value[0].Value.(Var)) { + w.Target.Value.(Ref)[0].Value = InputRootDocument.Value + } + } + } + // No special handling of the `with` value + return rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs, strict) +} + +func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors, strict bool) Errors { + used := NewVarSet() + used.Update(v.Term.Vars()) + + stack.Push() + v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict) + stack.Pop() + return errs +} + +func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors, strict bool) Errors { + used := NewVarSet() + used.Update(v.Term.Vars()) + + stack.Push() + v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict) + stack.Pop() + return errs +} + +func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors, strict bool) Errors { + used := NewVarSet() + used.Update(v.Key.Vars()) + used.Update(v.Value.Vars()) + + stack.Push() + v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs, strict) + stack.Pop() + return errs +} + +func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, occ varOccurrence) (gv Var, err error) { + switch stack.Occurrence(v) { + case seenVar: + return gv, fmt.Errorf("var %v referenced above", v) + case assignedVar: + return gv, fmt.Errorf("var %v assigned above", v) + case declaredVar: + return gv, fmt.Errorf("var %v declared above", v) + case argVar: + return gv, fmt.Errorf("arg %v redeclared", v) + } + gv = g.Generate() + stack.Insert(v, gv, occ) + return +} + +// rewriteWithModifiersInBody will rewrite the body so that with modifiers do +// not contain terms that require evaluation as values. If this function +// encounters an invalid with modifier target then it will raise an error. +func rewriteWithModifiersInBody(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, body Body) (Body, *Error) { + var result Body + for i := range body { + exprs, err := rewriteWithModifier(c, unsafeBuiltinsMap, f, body[i]) + if err != nil { + return nil, err + } + if len(exprs) > 0 { + for _, expr := range exprs { + result.Append(expr) + } + } else { + result.Append(body[i]) + } + } + return result, nil +} + +func rewriteWithModifier(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, expr *Expr) ([]*Expr, *Error) { + + var result []*Expr + for i := range expr.With { + eval, err := validateWith(c, unsafeBuiltinsMap, expr, i) + if err != nil { + return nil, err + } + + if eval { + eq := f.Generate(expr.With[i].Value) + result = append(result, eq) + expr.With[i].Value = eq.Operand(0) + } + } + + return append(result, expr), nil +} + +func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr, i int) (bool, *Error) { + target, value := expr.With[i].Target, expr.With[i].Value + + // Ensure that values that are built-ins are rewritten to Ref (not Var) + if v, ok := value.Value.(Var); ok { + if _, ok := c.builtins[v.String()]; ok { + value.Value = Ref([]*Term{NewTerm(v)}) + } + } + isBuiltinRefOrVar, err := isBuiltinRefOrVar(c.builtins, unsafeBuiltinsMap, target) + if err != nil { + return false, err + } + + isAllowedUnknownFuncCall := false + if c.allowUndefinedFuncCalls { + switch target.Value.(type) { + case Ref, Var: + isAllowedUnknownFuncCall = true + } + } + + switch { + case isDataRef(target): + ref := target.Value.(Ref) + targetNode := c.RuleTree + for i := 0; i < len(ref)-1; i++ { + child := targetNode.Child(ref[i].Value) + if child == nil { + break + } else if len(child.Values) > 0 { + return false, NewError(CompileErr, target.Loc(), "with keyword cannot partially replace virtual document(s)") + } + targetNode = child + } + + if targetNode != nil { + // NOTE(sr): at this point in the compiler stages, we don't have a fully-populated + // TypeEnv yet -- so we have to make do with this check to see if the replacement + // target is a function. It's probably wrong for arity-0 functions, but those are + // and edge case anyways. + if child := targetNode.Child(ref[len(ref)-1].Value); child != nil { + for _, v := range child.Values { + if len(v.(*Rule).Head.Args) > 0 { + if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok { + return false, err // err may be nil + } + } + } + } + } + + // If the with-value is a ref to a function, but not a call, we can't rewrite it + if r, ok := value.Value.(Ref); ok { + // TODO: check that target ref doesn't exist? + if valueNode := c.RuleTree.Find(r); valueNode != nil { + for _, v := range valueNode.Values { + if len(v.(*Rule).Head.Args) > 0 { + return false, nil + } + } + } + } + case isInputRef(target): // ok, valid + case isBuiltinRefOrVar: + + // NOTE(sr): first we ensure that parsed Var builtins (`count`, `concat`, etc) + // are rewritten to their proper Ref convention + if v, ok := target.Value.(Var); ok { + target.Value = Ref([]*Term{NewTerm(v)}) + } + + targetRef := target.Value.(Ref) + bi := c.builtins[targetRef.String()] // safe because isBuiltinRefOrVar checked this + if err := validateWithBuiltinTarget(bi, targetRef, target.Loc()); err != nil { + return false, err + } + + if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok { + return false, err // err may be nil + } + case isAllowedUnknownFuncCall: + // The target isn't a ref to the input doc, data doc, or a known built-in, but it might be a ref to an unknown built-in. + return false, nil + default: + return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument) + } + return requiresEval(value), nil +} + +func validateWithBuiltinTarget(bi *Builtin, target Ref, loc *location.Location) *Error { + switch bi.Name { + case Equality.Name, + RegoMetadataChain.Name, + RegoMetadataRule.Name: + return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of %q invalid", bi.Name) + } + + switch { + case target.HasPrefix(Ref([]*Term{VarTerm("internal")})): + return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of internal function %q invalid", target) + + case bi.Relation: + return NewError(CompileErr, loc, "with keyword replacing built-in function: target must not be a relation") + + case bi.Decl.Result() == nil: + return NewError(CompileErr, loc, "with keyword replacing built-in function: target must not be a void function") + } + return nil +} + +func validateWithFunctionValue(bs map[string]*Builtin, unsafeMap map[string]struct{}, ruleTree *TreeNode, value *Term) (bool, *Error) { + if v, ok := value.Value.(Ref); ok { + if ruleTree.Find(v) != nil { // ref exists in rule tree + return true, nil + } + } + return isBuiltinRefOrVar(bs, unsafeMap, value) +} + +func isInputRef(term *Term) bool { + if ref, ok := term.Value.(Ref); ok { + if ref.HasPrefix(InputRootRef) { + return true + } + } + return false +} + +func isDataRef(term *Term) bool { + if ref, ok := term.Value.(Ref); ok { + if ref.HasPrefix(DefaultRootRef) { + return true + } + } + return false +} + +func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]struct{}, term *Term) (bool, *Error) { + switch v := term.Value.(type) { + case Ref, Var: + if _, ok := unsafeBuiltinsMap[v.String()]; ok { + return false, NewError(CompileErr, term.Location, "with keyword replacing built-in function: target must not be unsafe: %q", v) + } + _, ok := bs[v.String()] + return ok, nil + } + return false, nil +} + +func isVirtual(node *TreeNode, ref Ref) bool { + for i := range ref { + child := node.Child(ref[i].Value) + if child == nil { + return false + } else if len(child.Values) > 0 { + return true + } + node = child + } + return true +} + +func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors) { + + if len(unsafe) == 0 { + return + } + + for _, pair := range unsafe.Vars() { + v := pair.Var + if w, ok := rewritten[v]; ok { + v = w + } + if !v.IsGenerated() { + if _, ok := allFutureKeywords[string(v)]; ok { + result = append(result, NewError(UnsafeVarErr, pair.Loc, + "var %[1]v is unsafe (hint: `import future.keywords.%[1]v` to import a future keyword)", v)) + continue + } + result = append(result, NewError(UnsafeVarErr, pair.Loc, "var %v is unsafe", v)) + } + } + + if len(result) > 0 { + return + } + + // If the expression contains unsafe generated variables, report which + // expressions are unsafe instead of the variables that are unsafe (since + // the latter are not meaningful to the user.) + pairs := unsafe.Slice() + + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0 + }) + + // Report at most one error per generated variable. + seen := NewVarSet() + + for _, expr := range pairs { + before := len(seen) + for v := range expr.Vars { + if v.IsGenerated() { + seen.Add(v) + } + } + if len(seen) > before { + result = append(result, NewError(UnsafeVarErr, expr.Expr.Location, "expression is unsafe")) + } + } + + return +} + +func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors { + errs := make(Errors, 0) + WalkExprs(node, func(x *Expr) bool { + if x.IsCall() { + operator := x.Operator().String() + if _, ok := unsafeBuiltinsMap[operator]; ok { + errs = append(errs, NewError(TypeErr, x.Loc(), "unsafe built-in function calls in expression: %v", operator)) + } + } + return false + }) + return errs +} + +func rewriteVarsInRef(vars ...map[Var]Var) varRewriter { + return func(node Ref) Ref { + i, _ := TransformVars(node, func(v Var) (Value, error) { + for _, m := range vars { + if u, ok := m[v]; ok { + return u, nil + } + } + return v, nil + }) + return i.(Ref) + } +} + +// NOTE(sr): This is duplicated with compile/compile.go; but moving it into another location +// would cause a circular dependency -- the refSet definition needs ast.Ref. If we make it +// public in the ast package, the compile package could take it from there, but it would also +// increase our public interface. Let's reconsider if we need it in a third place. +type refSet struct { + s []Ref +} + +func newRefSet(x ...Ref) *refSet { + result := &refSet{} + for i := range x { + result.AddPrefix(x[i]) + } + return result +} + +// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set. +func (rs *refSet) ContainsPrefix(r Ref) bool { + for i := range rs.s { + if r.HasPrefix(rs.s[i]) { + return true + } + } + return false +} + +// AddPrefix inserts r into the set if r is not prefixed by any existing +// refs in the set. If any existing refs are prefixed by r, those existing +// refs are removed. +func (rs *refSet) AddPrefix(r Ref) { + if rs.ContainsPrefix(r) { + return + } + cpy := []Ref{r} + for i := range rs.s { + if !rs.s[i].HasPrefix(r) { + cpy = append(cpy, rs.s[i]) + } + } + rs.s = cpy +} + +// Sorted returns a sorted slice of terms for refs in the set. +func (rs *refSet) Sorted() []*Term { + terms := make([]*Term, len(rs.s)) + for i := range rs.s { + terms[i] = NewTerm(rs.s[i]) + } + sort.Slice(terms, func(i, j int) bool { + return terms[i].Value.Compare(terms[j].Value) < 0 + }) + return terms +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/compilehelper.go b/vendor/github.com/open-policy-agent/opa/v1/ast/compilehelper.go new file mode 100644 index 000000000..7d81d45e6 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/compilehelper.go @@ -0,0 +1,62 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +// CompileModules takes a set of Rego modules represented as strings and +// compiles them for evaluation. The keys of the map are used as filenames. +func CompileModules(modules map[string]string) (*Compiler, error) { + return CompileModulesWithOpt(modules, CompileOpts{}) +} + +// CompileOpts defines a set of options for the compiler. +type CompileOpts struct { + EnablePrintStatements bool + ParserOptions ParserOptions +} + +// CompileModulesWithOpt takes a set of Rego modules represented as strings and +// compiles them for evaluation. The keys of the map are used as filenames. +func CompileModulesWithOpt(modules map[string]string, opts CompileOpts) (*Compiler, error) { + + parsed := make(map[string]*Module, len(modules)) + + for f, module := range modules { + var pm *Module + var err error + if pm, err = ParseModuleWithOpts(f, module, opts.ParserOptions); err != nil { + return nil, err + } + parsed[f] = pm + } + + compiler := NewCompiler(). + WithDefaultRegoVersion(opts.ParserOptions.RegoVersion). + WithEnablePrintStatements(opts.EnablePrintStatements) + compiler.Compile(parsed) + + if compiler.Failed() { + return nil, compiler.Errors + } + + return compiler, nil +} + +// MustCompileModules compiles a set of Rego modules represented as strings. If +// the compilation process fails, this function panics. +func MustCompileModules(modules map[string]string) *Compiler { + return MustCompileModulesWithOpts(modules, CompileOpts{}) +} + +// MustCompileModulesWithOpts compiles a set of Rego modules represented as strings. If +// the compilation process fails, this function panics. +func MustCompileModulesWithOpts(modules map[string]string, opts CompileOpts) *Compiler { + + compiler, err := CompileModulesWithOpt(modules, opts) + if err != nil { + panic(err) + } + + return compiler +} diff --git a/vendor/github.com/open-policy-agent/opa/ast/compilemetrics.go b/vendor/github.com/open-policy-agent/opa/v1/ast/compilemetrics.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/ast/compilemetrics.go rename to vendor/github.com/open-policy-agent/opa/v1/ast/compilemetrics.go diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/conflicts.go b/vendor/github.com/open-policy-agent/opa/v1/ast/conflicts.go new file mode 100644 index 000000000..c2713ad57 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/conflicts.go @@ -0,0 +1,53 @@ +// Copyright 2019 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "strings" +) + +// CheckPathConflicts returns a set of errors indicating paths that +// are in conflict with the result of the provided callable. +func CheckPathConflicts(c *Compiler, exists func([]string) (bool, error)) Errors { + var errs Errors + + root := c.RuleTree.Child(DefaultRootDocument.Value) + if root == nil { + return nil + } + + for _, node := range root.Children { + errs = append(errs, checkDocumentConflicts(node, exists, nil)...) + } + + return errs +} + +func checkDocumentConflicts(node *TreeNode, exists func([]string) (bool, error), path []string) Errors { + + switch key := node.Key.(type) { + case String: + path = append(path, string(key)) + default: // other key types cannot conflict with data + return nil + } + + if len(node.Values) > 0 { + s := strings.Join(path, "/") + if ok, err := exists(path); err != nil { + return Errors{NewError(CompileErr, node.Values[0].(*Rule).Loc(), "conflict check for data path %v: %v", s, err.Error())} + } else if ok { + return Errors{NewError(CompileErr, node.Values[0].(*Rule).Loc(), "conflicting rule for data path %v found", s)} + } + } + + var errs Errors + + for _, child := range node.Children { + errs = append(errs, checkDocumentConflicts(child, exists, path)...) + } + + return errs +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/doc.go b/vendor/github.com/open-policy-agent/opa/v1/ast/doc.go new file mode 100644 index 000000000..62b04e301 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/doc.go @@ -0,0 +1,36 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package ast declares Rego syntax tree types and also includes a parser and compiler for preparing policies for execution in the policy engine. +// +// Rego policies are defined using a relatively small set of types: modules, package and import declarations, rules, expressions, and terms. At their core, policies consist of rules that are defined by one or more expressions over documents available to the policy engine. The expressions are defined by intrinsic values (terms) such as strings, objects, variables, etc. +// +// Rego policies are typically defined in text files and then parsed and compiled by the policy engine at runtime. The parsing stage takes the text or string representation of the policy and converts it into an abstract syntax tree (AST) that consists of the types mentioned above. The AST is organized as follows: +// +// Module +// | +// +--- Package (Reference) +// | +// +--- Imports +// | | +// | +--- Import (Term) +// | +// +--- Rules +// | +// +--- Rule +// | +// +--- Head +// | | +// | +--- Name (Variable) +// | | +// | +--- Key (Term) +// | | +// | +--- Value (Term) +// | +// +--- Body +// | +// +--- Expression (Term | Terms | Variable Declaration) +// +// At query time, the policy engine expects policies to have been compiled. The compilation stage takes one or more modules and compiles them into a format that the policy engine supports. +package ast diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/env.go b/vendor/github.com/open-policy-agent/opa/v1/ast/env.go new file mode 100644 index 000000000..fb374b173 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/env.go @@ -0,0 +1,526 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "strings" + + "github.com/open-policy-agent/opa/v1/types" + "github.com/open-policy-agent/opa/v1/util" +) + +// TypeEnv contains type info for static analysis such as type checking. +type TypeEnv struct { + tree *typeTreeNode + next *TypeEnv + newChecker func() *typeChecker +} + +// newTypeEnv returns an empty TypeEnv. The constructor is not exported because +// type environments should only be created by the type checker. +func newTypeEnv(f func() *typeChecker) *TypeEnv { + return &TypeEnv{ + tree: newTypeTree(), + newChecker: f, + } +} + +// Get returns the type of x. +func (env *TypeEnv) Get(x interface{}) types.Type { + + if term, ok := x.(*Term); ok { + x = term.Value + } + + switch x := x.(type) { + + // Scalars. + case Null: + return types.NewNull() + case Boolean: + return types.NewBoolean() + case Number: + return types.NewNumber() + case String: + return types.NewString() + + // Composites. + case *Array: + static := make([]types.Type, x.Len()) + for i := range static { + tpe := env.Get(x.Elem(i).Value) + static[i] = tpe + } + + var dynamic types.Type + if len(static) == 0 { + dynamic = types.A + } + + return types.NewArray(static, dynamic) + + case *lazyObj: + return env.Get(x.force()) + case *object: + static := []*types.StaticProperty{} + var dynamic *types.DynamicProperty + + x.Foreach(func(k, v *Term) { + if IsConstant(k.Value) { + kjson, err := JSON(k.Value) + if err == nil { + tpe := env.Get(v) + static = append(static, types.NewStaticProperty(kjson, tpe)) + return + } + } + // Can't handle it as a static property, fallback to dynamic + typeK := env.Get(k.Value) + typeV := env.Get(v.Value) + dynamic = types.NewDynamicProperty(typeK, typeV) + }) + + if len(static) == 0 && dynamic == nil { + dynamic = types.NewDynamicProperty(types.A, types.A) + } + + return types.NewObject(static, dynamic) + + case Set: + var tpe types.Type + x.Foreach(func(elem *Term) { + other := env.Get(elem.Value) + tpe = types.Or(tpe, other) + }) + if tpe == nil { + tpe = types.A + } + return types.NewSet(tpe) + + // Comprehensions. + case *ArrayComprehension: + cpy, errs := env.newChecker().CheckBody(env, x.Body) + if len(errs) == 0 { + return types.NewArray(nil, cpy.Get(x.Term)) + } + return nil + case *ObjectComprehension: + cpy, errs := env.newChecker().CheckBody(env, x.Body) + if len(errs) == 0 { + return types.NewObject(nil, types.NewDynamicProperty(cpy.Get(x.Key), cpy.Get(x.Value))) + } + return nil + case *SetComprehension: + cpy, errs := env.newChecker().CheckBody(env, x.Body) + if len(errs) == 0 { + return types.NewSet(cpy.Get(x.Term)) + } + return nil + + // Refs. + case Ref: + return env.getRef(x) + + // Vars. + case Var: + if node := env.tree.Child(x); node != nil { + return node.Value() + } + if env.next != nil { + return env.next.Get(x) + } + return nil + + // Calls. + case Call: + return nil + + default: + panic("unreachable") + } +} + +func (env *TypeEnv) getRef(ref Ref) types.Type { + + node := env.tree.Child(ref[0].Value) + if node == nil { + return env.getRefFallback(ref) + } + + return env.getRefRec(node, ref, ref[1:]) +} + +func (env *TypeEnv) getRefFallback(ref Ref) types.Type { + + if env.next != nil { + return env.next.Get(ref) + } + + if RootDocumentNames.Contains(ref[0]) { + return types.A + } + + return nil +} + +func (env *TypeEnv) getRefRec(node *typeTreeNode, ref, tail Ref) types.Type { + if len(tail) == 0 { + return env.getRefRecExtent(node) + } + + if node.Leaf() { + if node.children.Len() > 0 { + if child := node.Child(tail[0].Value); child != nil { + return env.getRefRec(child, ref, tail[1:]) + } + } + return selectRef(node.Value(), tail) + } + + if !IsConstant(tail[0].Value) { + return selectRef(env.getRefRecExtent(node), tail) + } + + child := node.Child(tail[0].Value) + if child == nil { + return env.getRefFallback(ref) + } + + return env.getRefRec(child, ref, tail[1:]) +} + +func (env *TypeEnv) getRefRecExtent(node *typeTreeNode) types.Type { + + if node.Leaf() { + return node.Value() + } + + children := []*types.StaticProperty{} + + node.Children().Iter(func(k, v util.T) bool { + key := k.(Value) + child := v.(*typeTreeNode) + + tpe := env.getRefRecExtent(child) + + // NOTE(sr): Converting to Golang-native types here is an extension of what we did + // before -- only supporting strings. But since we cannot differentiate sets and arrays + // that way, we could reconsider. + switch key.(type) { + case String, Number, Boolean: // skip anything else + propKey, err := JSON(key) + if err != nil { + panic(fmt.Errorf("unreachable, ValueToInterface: %w", err)) + } + children = append(children, types.NewStaticProperty(propKey, tpe)) + } + return false + }) + + // TODO(tsandall): for now, these objects can have any dynamic properties + // because we don't have schema for base docs. Once schemas are supported + // we can improve this. + return types.NewObject(children, types.NewDynamicProperty(types.S, types.A)) +} + +func (env *TypeEnv) wrap() *TypeEnv { + cpy := *env + cpy.next = env + cpy.tree = newTypeTree() + return &cpy +} + +// typeTreeNode is used to store type information in a tree. +type typeTreeNode struct { + key Value + value types.Type + children *util.HashMap +} + +func newTypeTree() *typeTreeNode { + return &typeTreeNode{ + key: nil, + value: nil, + children: util.NewHashMap(valueEq, valueHash), + } +} + +func (n *typeTreeNode) Child(key Value) *typeTreeNode { + value, ok := n.children.Get(key) + if !ok { + return nil + } + return value.(*typeTreeNode) +} + +func (n *typeTreeNode) Children() *util.HashMap { + return n.children +} + +func (n *typeTreeNode) Get(path Ref) types.Type { + curr := n + for _, term := range path { + child, ok := curr.children.Get(term.Value) + if !ok { + return nil + } + curr = child.(*typeTreeNode) + } + return curr.Value() +} + +func (n *typeTreeNode) Leaf() bool { + return n.value != nil +} + +func (n *typeTreeNode) PutOne(key Value, tpe types.Type) { + c, ok := n.children.Get(key) + + var child *typeTreeNode + if !ok { + child = newTypeTree() + child.key = key + n.children.Put(key, child) + } else { + child = c.(*typeTreeNode) + } + + child.value = tpe +} + +func (n *typeTreeNode) Put(path Ref, tpe types.Type) { + curr := n + for _, term := range path { + c, ok := curr.children.Get(term.Value) + + var child *typeTreeNode + if !ok { + child = newTypeTree() + child.key = term.Value + curr.children.Put(child.key, child) + } else { + child = c.(*typeTreeNode) + } + + curr = child + } + curr.value = tpe +} + +// Insert inserts tpe at path in the tree, but also merges the value into any types.Object present along that path. +// If a types.Object is inserted, any leafs already present further down the tree are merged into the inserted object. +// path must be ground. +func (n *typeTreeNode) Insert(path Ref, tpe types.Type, env *TypeEnv) { + curr := n + for i, term := range path { + c, ok := curr.children.Get(term.Value) + + var child *typeTreeNode + if !ok { + child = newTypeTree() + child.key = term.Value + curr.children.Put(child.key, child) + } else { + child = c.(*typeTreeNode) + + if child.value != nil && i+1 < len(path) { + // If child has an object value, merge the new value into it. + if o, ok := child.value.(*types.Object); ok { + var err error + child.value, err = insertIntoObject(o, path[i+1:], tpe, env) + if err != nil { + panic(fmt.Errorf("unreachable, insertIntoObject: %w", err)) + } + } + } + } + + curr = child + } + + curr.value = mergeTypes(curr.value, tpe) + + if _, ok := tpe.(*types.Object); ok && curr.children.Len() > 0 { + // merge all leafs into the inserted object + leafs := curr.Leafs() + for p, t := range leafs { + var err error + curr.value, err = insertIntoObject(curr.value.(*types.Object), *p, t, env) + if err != nil { + panic(fmt.Errorf("unreachable, insertIntoObject: %w", err)) + } + } + } +} + +// mergeTypes merges the types of 'a' and 'b'. If both are sets, their 'of' types are joined with an types.Or. +// If both are objects, the key types of their dynamic properties are joined with types.Or:s, and their value types +// are recursively merged (using mergeTypes). +// If 'a' and 'b' are both objects, and at least one of them have static properties, they are joined +// with an types.Or, instead of being merged. +// If 'a' is an Any containing an Object, and 'b' is an Object (or vice versa); AND both objects have no +// static properties, they are merged. +// If 'a' and 'b' are different types, they are joined with an types.Or. +func mergeTypes(a, b types.Type) types.Type { + if a == nil { + return b + } + + if b == nil { + return a + } + + switch a := a.(type) { + case *types.Object: + if bObj, ok := b.(*types.Object); ok && len(a.StaticProperties()) == 0 && len(bObj.StaticProperties()) == 0 { + if len(a.StaticProperties()) > 0 || len(bObj.StaticProperties()) > 0 { + return types.Or(a, bObj) + } + + aDynProps := a.DynamicProperties() + bDynProps := bObj.DynamicProperties() + dynProps := types.NewDynamicProperty( + types.Or(aDynProps.Key, bDynProps.Key), + mergeTypes(aDynProps.Value, bDynProps.Value)) + return types.NewObject(nil, dynProps) + } else if bAny, ok := b.(types.Any); ok && len(a.StaticProperties()) == 0 { + // If a is an object type with no static components ... + for _, t := range bAny { + if tObj, ok := t.(*types.Object); ok && len(tObj.StaticProperties()) == 0 { + // ... and b is a types.Any containing an object with no static components, we merge them. + aDynProps := a.DynamicProperties() + tDynProps := tObj.DynamicProperties() + tDynProps.Key = types.Or(tDynProps.Key, aDynProps.Key) + tDynProps.Value = types.Or(tDynProps.Value, aDynProps.Value) + return bAny + } + } + } + case *types.Set: + if bSet, ok := b.(*types.Set); ok { + return types.NewSet(types.Or(a.Of(), bSet.Of())) + } + case types.Any: + if _, ok := b.(types.Any); !ok { + return mergeTypes(b, a) + } + } + + return types.Or(a, b) +} + +func (n *typeTreeNode) String() string { + b := strings.Builder{} + + if k := n.key; k != nil { + b.WriteString(k.String()) + } else { + b.WriteString("-") + } + + if v := n.value; v != nil { + b.WriteString(": ") + b.WriteString(v.String()) + } + + n.children.Iter(func(_, v util.T) bool { + if child, ok := v.(*typeTreeNode); ok { + b.WriteString("\n\t+ ") + s := child.String() + s = strings.ReplaceAll(s, "\n", "\n\t") + b.WriteString(s) + } + return false + }) + + return b.String() +} + +func insertIntoObject(o *types.Object, path Ref, tpe types.Type, env *TypeEnv) (*types.Object, error) { + if len(path) == 0 { + return o, nil + } + + key := env.Get(path[0].Value) + + if len(path) == 1 { + var dynamicProps *types.DynamicProperty + if dp := o.DynamicProperties(); dp != nil { + dynamicProps = types.NewDynamicProperty(types.Or(o.DynamicProperties().Key, key), types.Or(o.DynamicProperties().Value, tpe)) + } else { + dynamicProps = types.NewDynamicProperty(key, tpe) + } + return types.NewObject(o.StaticProperties(), dynamicProps), nil + } + + child, err := insertIntoObject(types.NewObject(nil, nil), path[1:], tpe, env) + if err != nil { + return nil, err + } + + var dynamicProps *types.DynamicProperty + if dp := o.DynamicProperties(); dp != nil { + dynamicProps = types.NewDynamicProperty(types.Or(o.DynamicProperties().Key, key), types.Or(o.DynamicProperties().Value, child)) + } else { + dynamicProps = types.NewDynamicProperty(key, child) + } + return types.NewObject(o.StaticProperties(), dynamicProps), nil +} + +func (n *typeTreeNode) Leafs() map[*Ref]types.Type { + leafs := map[*Ref]types.Type{} + n.children.Iter(func(_, v util.T) bool { + collectLeafs(v.(*typeTreeNode), nil, leafs) + return false + }) + return leafs +} + +func collectLeafs(n *typeTreeNode, path Ref, leafs map[*Ref]types.Type) { + nPath := append(path, NewTerm(n.key)) + if n.Leaf() { + leafs[&nPath] = n.Value() + return + } + n.children.Iter(func(_, v util.T) bool { + collectLeafs(v.(*typeTreeNode), nPath, leafs) + return false + }) +} + +func (n *typeTreeNode) Value() types.Type { + return n.value +} + +// selectConstant returns the attribute of the type referred to by the term. If +// the attribute type cannot be determined, nil is returned. +func selectConstant(tpe types.Type, term *Term) types.Type { + x, err := JSON(term.Value) + if err == nil { + return types.Select(tpe, x) + } + return nil +} + +// selectRef returns the type of the nested attribute referred to by ref. If +// the attribute type cannot be determined, nil is returned. If the ref +// contains vars or refs, then the returned type will be a union of the +// possible types. +func selectRef(tpe types.Type, ref Ref) types.Type { + + if tpe == nil || len(ref) == 0 { + return tpe + } + + head, tail := ref[0], ref[1:] + + switch head.Value.(type) { + case Var, Ref, *Array, Object, Set: + return selectRef(types.Values(tpe), tail) + default: + return selectRef(selectConstant(tpe, head), tail) + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/errors.go b/vendor/github.com/open-policy-agent/opa/v1/ast/errors.go new file mode 100644 index 000000000..066dfcdd6 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/errors.go @@ -0,0 +1,123 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "sort" + "strings" +) + +// Errors represents a series of errors encountered during parsing, compiling, +// etc. +type Errors []*Error + +func (e Errors) Error() string { + + if len(e) == 0 { + return "no error(s)" + } + + if len(e) == 1 { + return fmt.Sprintf("1 error occurred: %v", e[0].Error()) + } + + s := make([]string, len(e)) + for i, err := range e { + s[i] = err.Error() + } + + return fmt.Sprintf("%d errors occurred:\n%s", len(e), strings.Join(s, "\n")) +} + +// Sort sorts the error slice by location. If the locations are equal then the +// error message is compared. +func (e Errors) Sort() { + sort.Slice(e, func(i, j int) bool { + a := e[i] + b := e[j] + + if cmp := a.Location.Compare(b.Location); cmp != 0 { + return cmp < 0 + } + + return a.Error() < b.Error() + }) +} + +const ( + // ParseErr indicates an unclassified parse error occurred. + ParseErr = "rego_parse_error" + + // CompileErr indicates an unclassified compile error occurred. + CompileErr = "rego_compile_error" + + // TypeErr indicates a type error was caught. + TypeErr = "rego_type_error" + + // UnsafeVarErr indicates an unsafe variable was found during compilation. + UnsafeVarErr = "rego_unsafe_var_error" + + // RecursionErr indicates recursion was found during compilation. + RecursionErr = "rego_recursion_error" +) + +// IsError returns true if err is an AST error with code. +func IsError(code string, err error) bool { + if err, ok := err.(*Error); ok { + return err.Code == code + } + return false +} + +// ErrorDetails defines the interface for detailed error messages. +type ErrorDetails interface { + Lines() []string +} + +// Error represents a single error caught during parsing, compiling, etc. +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + Location *Location `json:"location,omitempty"` + Details ErrorDetails `json:"details,omitempty"` +} + +func (e *Error) Error() string { + + var prefix string + + if e.Location != nil { + + if len(e.Location.File) > 0 { + prefix += e.Location.File + ":" + fmt.Sprint(e.Location.Row) + } else { + prefix += fmt.Sprint(e.Location.Row) + ":" + fmt.Sprint(e.Location.Col) + } + } + + msg := fmt.Sprintf("%v: %v", e.Code, e.Message) + + if len(prefix) > 0 { + msg = prefix + ": " + msg + } + + if e.Details != nil { + for _, line := range e.Details.Lines() { + msg += "\n\t" + line + } + } + + return msg +} + +// NewError returns a new Error object. +func NewError(code string, loc *Location, f string, a ...interface{}) *Error { + return &Error{ + Code: code, + Location: loc, + Message: fmt.Sprintf(f, a...), + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/index.go b/vendor/github.com/open-policy-agent/opa/v1/ast/index.go new file mode 100644 index 000000000..63cd480d1 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/index.go @@ -0,0 +1,932 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "sort" + "strings" + "sync" + + "github.com/open-policy-agent/opa/v1/util" +) + +// RuleIndex defines the interface for rule indices. +type RuleIndex interface { + + // Build tries to construct an index for the given rules. If the index was + // constructed, it returns true, otherwise false. + Build(rules []*Rule) bool + + // Lookup searches the index for rules that will match the provided + // resolver. If the resolver returns an error, it is returned via err. + Lookup(resolver ValueResolver) (*IndexResult, error) + + // AllRules traverses the index and returns all rules that will match + // the provided resolver without any optimizations (effectively with + // indexing disabled). If the resolver returns an error, it is returned + // via err. + AllRules(resolver ValueResolver) (*IndexResult, error) +} + +// IndexResult contains the result of an index lookup. +type IndexResult struct { + Kind RuleKind + Rules []*Rule + Else map[*Rule][]*Rule + Default *Rule + EarlyExit bool + OnlyGroundRefs bool +} + +// NewIndexResult returns a new IndexResult object. +func NewIndexResult(kind RuleKind) *IndexResult { + return &IndexResult{ + Kind: kind, + Else: map[*Rule][]*Rule{}, + } +} + +// Empty returns true if there are no rules to evaluate. +func (ir *IndexResult) Empty() bool { + return len(ir.Rules) == 0 && ir.Default == nil +} + +type baseDocEqIndex struct { + skipIndexing Set + isVirtual func(Ref) bool + root *trieNode + defaultRule *Rule + kind RuleKind + onlyGroundRefs bool +} + +var ( + equalityRef = Equality.Ref() + equalRef = Equal.Ref() + globMatchRef = GlobMatch.Ref() + internalPrintRef = InternalPrint.Ref() +) + +func newBaseDocEqIndex(isVirtual func(Ref) bool) *baseDocEqIndex { + return &baseDocEqIndex{ + skipIndexing: NewSet(NewTerm(internalPrintRef)), + isVirtual: isVirtual, + root: newTrieNodeImpl(), + onlyGroundRefs: true, + } +} + +func (i *baseDocEqIndex) Build(rules []*Rule) bool { + if len(rules) == 0 { + return false + } + + i.kind = rules[0].Head.RuleKind() + indices := newrefindices(i.isVirtual) + + // build indices for each rule. + for idx := range rules { + WalkRules(rules[idx], func(rule *Rule) bool { + if rule.Default { + i.defaultRule = rule + return false + } + if i.onlyGroundRefs { + i.onlyGroundRefs = rule.Head.Reference.IsGround() + } + var skip bool + for _, expr := range rule.Body { + if op := expr.OperatorTerm(); op != nil && i.skipIndexing.Contains(op) { + skip = true + break + } + } + if !skip { + for _, expr := range rule.Body { + indices.Update(rule, expr) + } + } + return false + }) + } + + // build trie out of indices. + for idx := range rules { + var prio int + WalkRules(rules[idx], func(rule *Rule) bool { + if rule.Default { + return false + } + node := i.root + if indices.Indexed(rule) { + for _, ref := range indices.Sorted() { + node = node.Insert(ref, indices.Value(rule, ref), indices.Mapper(rule, ref)) + } + } + // Insert rule into trie with (insertion order, priority order) + // tuple. Retaining the insertion order allows us to return rules + // in the order they were passed to this function. + node.append([...]int{idx, prio}, rule) + prio++ + return false + }) + } + return true +} + +func (i *baseDocEqIndex) Lookup(resolver ValueResolver) (*IndexResult, error) { + tr := ttrPool.Get().(*trieTraversalResult) + + defer func() { + clear(tr.unordered) + tr.ordering = tr.ordering[:0] + tr.values.clear() + + ttrPool.Put(tr) + }() + + err := i.root.Traverse(resolver, tr) + if err != nil { + return nil, err + } + + result := NewIndexResult(i.kind) + result.Default = i.defaultRule + result.OnlyGroundRefs = i.onlyGroundRefs + result.Rules = make([]*Rule, 0, len(tr.ordering)) + + for _, pos := range tr.ordering { + sort.Slice(tr.unordered[pos], func(i, j int) bool { + return tr.unordered[pos][i].prio[1] < tr.unordered[pos][j].prio[1] + }) + nodes := tr.unordered[pos] + root := nodes[0].rule + + result.Rules = append(result.Rules, root) + if len(nodes) > 1 { + result.Else[root] = make([]*Rule, len(nodes)-1) + for i := 1; i < len(nodes); i++ { + result.Else[root][i-1] = nodes[i].rule + } + } + } + + result.EarlyExit = tr.values.Len() == 1 && tr.values.Slice()[0].IsGround() + + return result, nil +} + +func (i *baseDocEqIndex) AllRules(_ ValueResolver) (*IndexResult, error) { + tr := newTrieTraversalResult() + + // Walk over the rule trie and accumulate _all_ rules + rw := &ruleWalker{result: tr} + i.root.Do(rw) + + result := NewIndexResult(i.kind) + result.Default = i.defaultRule + result.OnlyGroundRefs = i.onlyGroundRefs + result.Rules = make([]*Rule, 0, len(tr.ordering)) + + for _, pos := range tr.ordering { + sort.Slice(tr.unordered[pos], func(i, j int) bool { + return tr.unordered[pos][i].prio[1] < tr.unordered[pos][j].prio[1] + }) + nodes := tr.unordered[pos] + root := nodes[0].rule + result.Rules = append(result.Rules, root) + if len(nodes) > 1 { + result.Else[root] = make([]*Rule, len(nodes)-1) + for i := 1; i < len(nodes); i++ { + result.Else[root][i-1] = nodes[i].rule + } + } + } + + result.EarlyExit = tr.values.Len() == 1 && tr.values.Slice()[0].IsGround() + + return result, nil +} + +type ruleWalker struct { + result *trieTraversalResult +} + +func (r *ruleWalker) Do(x interface{}) trieWalker { + tn := x.(*trieNode) + r.result.Add(tn) + return r +} + +type valueMapper struct { + Key string + MapValue func(Value) Value +} + +type refindex struct { + Ref Ref + Value Value + Mapper *valueMapper +} + +type refindices struct { + isVirtual func(Ref) bool + rules map[*Rule][]*refindex + frequency *util.HashMap + sorted []Ref +} + +func newrefindices(isVirtual func(Ref) bool) *refindices { + return &refindices{ + isVirtual: isVirtual, + rules: map[*Rule][]*refindex{}, + frequency: util.NewHashMap(func(a, b util.T) bool { + r1, r2 := a.(Ref), b.(Ref) + return r1.Equal(r2) + }, func(x util.T) int { + return x.(Ref).Hash() + }), + } +} + +// Update attempts to update the refindices for the given expression in the +// given rule. If the expression cannot be indexed the update does not affect +// the indices. +func (i *refindices) Update(rule *Rule, expr *Expr) { + + if expr.Negated { + return + } + + if len(expr.With) > 0 { + // NOTE(tsandall): In the future, we may need to consider expressions + // that have with statements applied to them. + return + } + + op := expr.Operator() + + switch { + case op.Equal(equalityRef): + i.updateEq(rule, expr) + + case op.Equal(equalRef) && len(expr.Operands()) == 2: + // NOTE(tsandall): if equal() is called with more than two arguments the + // output value is being captured in which case the indexer cannot + // exclude the rule if the equal() call would return false (because the + // false value must still be produced.) + i.updateEq(rule, expr) + + case op.Equal(globMatchRef) && len(expr.Operands()) == 3: + // NOTE(sr): Same as with equal() above -- 4 operands means the output + // of `glob.match` is captured and the rule can thus not be excluded. + i.updateGlobMatch(rule, expr) + } +} + +// Sorted returns a sorted list of references that the indices were built from. +// References that appear more frequently in the indexed rules are ordered +// before less frequently appearing references. +func (i *refindices) Sorted() []Ref { + + if i.sorted == nil { + counts := make([]int, 0, i.frequency.Len()) + i.sorted = make([]Ref, 0, i.frequency.Len()) + + i.frequency.Iter(func(k, v util.T) bool { + counts = append(counts, v.(int)) + i.sorted = append(i.sorted, k.(Ref)) + return false + }) + + sort.Slice(i.sorted, func(a, b int) bool { + if counts[a] > counts[b] { + return true + } else if counts[b] > counts[a] { + return false + } + return i.sorted[a][0].Loc().Compare(i.sorted[b][0].Loc()) < 0 + }) + } + + return i.sorted +} + +func (i *refindices) Indexed(rule *Rule) bool { + return len(i.rules[rule]) > 0 +} + +func (i *refindices) Value(rule *Rule, ref Ref) Value { + if index := i.index(rule, ref); index != nil { + return index.Value + } + return nil +} + +func (i *refindices) Mapper(rule *Rule, ref Ref) *valueMapper { + if index := i.index(rule, ref); index != nil { + return index.Mapper + } + return nil +} + +func (i *refindices) updateEq(rule *Rule, expr *Expr) { + a, b := expr.Operand(0), expr.Operand(1) + args := rule.Head.Args + if idx, ok := eqOperandsToRefAndValue(i.isVirtual, args, a, b); ok { + i.insert(rule, idx) + return + } + if idx, ok := eqOperandsToRefAndValue(i.isVirtual, args, b, a); ok { + i.insert(rule, idx) + return + } +} + +func (i *refindices) updateGlobMatch(rule *Rule, expr *Expr) { + args := rule.Head.Args + + delim, ok := globDelimiterToString(expr.Operand(1)) + if !ok { + return + } + + if arr := globPatternToArray(expr.Operand(0), delim); arr != nil { + // The 3rd operand of glob.match is the value to match. We assume the + // 3rd operand was a reference that has been rewritten and bound to a + // variable earlier in the query OR a function argument variable. + match := expr.Operand(2) + if _, ok := match.Value.(Var); ok { + var ref Ref + for _, other := range i.rules[rule] { + if _, ok := other.Value.(Var); ok && other.Value.Compare(match.Value) == 0 { + ref = other.Ref + } + } + if ref == nil { + for j, arg := range args { + if arg.Equal(match) { + ref = Ref{FunctionArgRootDocument, InternedIntNumberTerm(j)} + } + } + } + if ref != nil { + i.insert(rule, &refindex{ + Ref: ref, + Value: arr.Value, + Mapper: &valueMapper{ + Key: delim, + MapValue: func(v Value) Value { + if s, ok := v.(String); ok { + return stringSliceToArray(splitStringEscaped(string(s), delim)) + } + return v + }, + }, + }) + } + } + } +} + +func (i *refindices) insert(rule *Rule, index *refindex) { + + count, ok := i.frequency.Get(index.Ref) + if !ok { + count = 0 + } + + i.frequency.Put(index.Ref, count.(int)+1) + + for pos, other := range i.rules[rule] { + if other.Ref.Equal(index.Ref) { + i.rules[rule][pos] = index + return + } + } + + i.rules[rule] = append(i.rules[rule], index) +} + +func (i *refindices) index(rule *Rule, ref Ref) *refindex { + for _, index := range i.rules[rule] { + if index.Ref.Equal(ref) { + return index + } + } + return nil +} + +type trieWalker interface { + Do(x interface{}) trieWalker +} + +type trieTraversalResult struct { + unordered map[int][]*ruleNode + ordering []int + values *set +} + +var ttrPool = sync.Pool{ + New: func() any { + return newTrieTraversalResult() + }, +} + +func newTrieTraversalResult() *trieTraversalResult { + return &trieTraversalResult{ + unordered: map[int][]*ruleNode{}, + // Number 3 is arbitrary, but seemed to be the most common number of values + // stored when benchmarking the trie traversal against a large policy library + // (Regal). + values: newset(3), + } +} + +func (tr *trieTraversalResult) Add(t *trieNode) { + for _, node := range t.rules { + root := node.prio[0] + nodes, ok := tr.unordered[root] + if !ok { + tr.ordering = append(tr.ordering, root) + } + tr.unordered[root] = append(nodes, node) + } + if t.values != nil { + t.values.Foreach(tr.values.insertNoGuard) + } +} + +type trieNode struct { + ref Ref + values Set + mappers []*valueMapper + next *trieNode + any *trieNode + undefined *trieNode + scalars *util.HashMap + array *trieNode + rules []*ruleNode +} + +func (node *trieNode) String() string { + var flags []string + flags = append(flags, fmt.Sprintf("self:%p", node)) + if len(node.ref) > 0 { + flags = append(flags, node.ref.String()) + } + if node.next != nil { + flags = append(flags, fmt.Sprintf("next:%p", node.next)) + } + if node.any != nil { + flags = append(flags, fmt.Sprintf("any:%p", node.any)) + } + if node.undefined != nil { + flags = append(flags, fmt.Sprintf("undefined:%p", node.undefined)) + } + if node.array != nil { + flags = append(flags, fmt.Sprintf("array:%p", node.array)) + } + if node.scalars.Len() > 0 { + buf := make([]string, 0, node.scalars.Len()) + node.scalars.Iter(func(k, v util.T) bool { + key := k.(Value) + val := v.(*trieNode) + buf = append(buf, fmt.Sprintf("scalar(%v):%p", key, val)) + return false + }) + sort.Strings(buf) + flags = append(flags, strings.Join(buf, " ")) + } + if len(node.rules) > 0 { + flags = append(flags, fmt.Sprintf("%d rule(s)", len(node.rules))) + } + if len(node.mappers) > 0 { + flags = append(flags, fmt.Sprintf("%d mapper(s)", len(node.mappers))) + } + if node.values != nil { + if l := node.values.Len(); l > 0 { + flags = append(flags, fmt.Sprintf("%d value(s)", l)) + } + } + return strings.Join(flags, " ") +} + +func (node *trieNode) append(prio [2]int, rule *Rule) { + node.rules = append(node.rules, &ruleNode{prio, rule}) + + if node.values != nil && rule.Head.Value != nil { + node.values.Add(rule.Head.Value) + return + } + + if node.values == nil && rule.Head.DocKind() == CompleteDoc { + node.values = NewSet(rule.Head.Value) + } +} + +type ruleNode struct { + prio [2]int + rule *Rule +} + +func newTrieNodeImpl() *trieNode { + return &trieNode{ + scalars: util.NewHashMap(valueEq, valueHash), + } +} + +func (node *trieNode) Do(walker trieWalker) { + next := walker.Do(node) + if next == nil { + return + } + if node.any != nil { + node.any.Do(next) + } + if node.undefined != nil { + node.undefined.Do(next) + } + + node.scalars.Iter(func(_, v util.T) bool { + child := v.(*trieNode) + child.Do(next) + return false + }) + + if node.array != nil { + node.array.Do(next) + } + if node.next != nil { + node.next.Do(next) + } +} + +func (node *trieNode) Insert(ref Ref, value Value, mapper *valueMapper) *trieNode { + + if node.next == nil { + node.next = newTrieNodeImpl() + node.next.ref = ref + } + + if mapper != nil { + node.next.addMapper(mapper) + } + + return node.next.insertValue(value) +} + +func (node *trieNode) Traverse(resolver ValueResolver, tr *trieTraversalResult) error { + + if node == nil { + return nil + } + + tr.Add(node) + + return node.next.traverse(resolver, tr) +} + +func (node *trieNode) addMapper(mapper *valueMapper) { + for i := range node.mappers { + if node.mappers[i].Key == mapper.Key { + return + } + } + node.mappers = append(node.mappers, mapper) +} + +func (node *trieNode) insertValue(value Value) *trieNode { + + switch value := value.(type) { + case nil: + if node.undefined == nil { + node.undefined = newTrieNodeImpl() + } + return node.undefined + case Var: + if node.any == nil { + node.any = newTrieNodeImpl() + } + return node.any + case Null, Boolean, Number, String: + child, ok := node.scalars.Get(value) + if !ok { + child = newTrieNodeImpl() + node.scalars.Put(value, child) + } + return child.(*trieNode) + case *Array: + if node.array == nil { + node.array = newTrieNodeImpl() + } + return node.array.insertArray(value) + } + + panic("illegal value") +} + +func (node *trieNode) insertArray(arr *Array) *trieNode { + + if arr.Len() == 0 { + return node + } + + switch head := arr.Elem(0).Value.(type) { + case Var: + if node.any == nil { + node.any = newTrieNodeImpl() + } + return node.any.insertArray(arr.Slice(1, -1)) + case Null, Boolean, Number, String: + child, ok := node.scalars.Get(head) + if !ok { + child = newTrieNodeImpl() + node.scalars.Put(head, child) + } + return child.(*trieNode).insertArray(arr.Slice(1, -1)) + } + + panic("illegal value") +} + +func (node *trieNode) traverse(resolver ValueResolver, tr *trieTraversalResult) error { + + if node == nil { + return nil + } + + v, err := resolver.Resolve(node.ref) + if err != nil { + if IsUnknownValueErr(err) { + return node.traverseUnknown(resolver, tr) + } + return err + } + + if node.undefined != nil { + err = node.undefined.Traverse(resolver, tr) + if err != nil { + return err + } + } + + if v == nil { + return nil + } + + if node.any != nil { + err = node.any.Traverse(resolver, tr) + if err != nil { + return err + } + } + + if err := node.traverseValue(resolver, tr, v); err != nil { + return err + } + + for i := range node.mappers { + if err := node.traverseValue(resolver, tr, node.mappers[i].MapValue(v)); err != nil { + return err + } + } + + return nil +} + +func (node *trieNode) traverseValue(resolver ValueResolver, tr *trieTraversalResult, value Value) error { + + switch value := value.(type) { + case *Array: + if node.array == nil { + return nil + } + return node.array.traverseArray(resolver, tr, value) + + case Null, Boolean, Number, String: + child, ok := node.scalars.Get(value) + if !ok { + return nil + } + return child.(*trieNode).Traverse(resolver, tr) + } + + return nil +} + +func (node *trieNode) traverseArray(resolver ValueResolver, tr *trieTraversalResult, arr *Array) error { + + if arr.Len() == 0 { + return node.Traverse(resolver, tr) + } + + if node.any != nil { + err := node.any.traverseArray(resolver, tr, arr.Slice(1, -1)) + if err != nil { + return err + } + } + + head := arr.Elem(0).Value + + if !IsScalar(head) { + return nil + } + + child, ok := node.scalars.Get(head) + if !ok { + return nil + } + return child.(*trieNode).traverseArray(resolver, tr, arr.Slice(1, -1)) +} + +func (node *trieNode) traverseUnknown(resolver ValueResolver, tr *trieTraversalResult) error { + + if node == nil { + return nil + } + + if err := node.Traverse(resolver, tr); err != nil { + return err + } + + if err := node.undefined.traverseUnknown(resolver, tr); err != nil { + return err + } + + if err := node.any.traverseUnknown(resolver, tr); err != nil { + return err + } + + if err := node.array.traverseUnknown(resolver, tr); err != nil { + return err + } + + var iterErr error + node.scalars.Iter(func(_, v util.T) bool { + child := v.(*trieNode) + if iterErr = child.traverseUnknown(resolver, tr); iterErr != nil { + return true + } + return false + }) + + return iterErr +} + +// If term `a` is one of the function's operands, we store a Ref: `args[0]` +// for the argument number. So for `f(x, y) { x = 10; y = 12 }`, we'll +// bind `args[0]` and `args[1]` to this rule when called for (x=10) and +// (y=12) respectively. +func eqOperandsToRefAndValue(isVirtual func(Ref) bool, args []*Term, a, b *Term) (*refindex, bool) { + switch v := a.Value.(type) { + case Var: + for i, arg := range args { + if arg.Value.Compare(v) == 0 { + if bval, ok := indexValue(b); ok { + return &refindex{Ref: Ref{FunctionArgRootDocument, InternedIntNumberTerm(i)}, Value: bval}, true + } + } + } + case Ref: + if !RootDocumentNames.Contains(v[0]) { + return nil, false + } + if isVirtual(v) { + return nil, false + } + if v.IsNested() || !v.IsGround() { + return nil, false + } + if bval, ok := indexValue(b); ok { + return &refindex{Ref: v, Value: bval}, true + } + } + return nil, false +} + +func indexValue(b *Term) (Value, bool) { + switch b := b.Value.(type) { + case Null, Boolean, Number, String, Var: + return b, true + case *Array: + stop := false + first := true + vis := NewGenericVisitor(func(x interface{}) bool { + if first { + first = false + return false + } + switch x.(type) { + // No nested structures or values that require evaluation (other than var). + case *Array, Object, Set, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Ref: + stop = true + } + return stop + }) + vis.Walk(b) + if !stop { + return b, true + } + } + + return nil, false +} + +func globDelimiterToString(delim *Term) (string, bool) { + + arr, ok := delim.Value.(*Array) + if !ok { + return "", false + } + + var result string + + if arr.Len() == 0 { + result = "." + } else { + for i := 0; i < arr.Len(); i++ { + term := arr.Elem(i) + s, ok := term.Value.(String) + if !ok { + return "", false + } + result += string(s) + } + } + + return result, true +} + +func globPatternToArray(pattern *Term, delim string) *Term { + + s, ok := pattern.Value.(String) + if !ok { + return nil + } + + parts := splitStringEscaped(string(s), delim) + arr := make([]*Term, len(parts)) + + for i := range parts { + if parts[i] == "*" { + arr[i] = VarTerm("$globwildcard") + } else { + var escaped bool + for _, c := range parts[i] { + if c == '\\' { + escaped = !escaped + continue + } + if !escaped { + switch c { + case '[', '?', '{', '*': + // TODO(tsandall): super glob and character pattern + // matching not supported yet. + return nil + } + } + escaped = false + } + arr[i] = StringTerm(parts[i]) + } + } + + return NewTerm(NewArray(arr...)) +} + +// splits s on characters in delim except if delim characters have been escaped +// with reverse solidus. +func splitStringEscaped(s string, delim string) []string { + + var last, curr int + var escaped bool + var result []string + + for ; curr < len(s); curr++ { + if s[curr] == '\\' || escaped { + escaped = !escaped + continue + } + if strings.ContainsRune(delim, rune(s[curr])) { + result = append(result, s[last:curr]) + last = curr + 1 + } + } + + result = append(result, s[last:]) + + return result +} + +func stringSliceToArray(s []string) *Array { + arr := make([]*Term, len(s)) + for i, v := range s { + arr[i] = StringTerm(v) + } + return NewArray(arr...) +} diff --git a/vendor/github.com/open-policy-agent/opa/ast/internal/scanner/scanner.go b/vendor/github.com/open-policy-agent/opa/v1/ast/internal/scanner/scanner.go similarity index 95% rename from vendor/github.com/open-policy-agent/opa/ast/internal/scanner/scanner.go rename to vendor/github.com/open-policy-agent/opa/v1/ast/internal/scanner/scanner.go index a0200ac18..4558f9141 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/internal/scanner/scanner.go +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/internal/scanner/scanner.go @@ -9,8 +9,9 @@ import ( "io" "unicode" "unicode/utf8" + "unsafe" - "github.com/open-policy-agent/opa/ast/internal/tokens" + "github.com/open-policy-agent/opa/v1/ast/internal/tokens" ) const bom = 0xFEFF @@ -18,31 +19,31 @@ const bom = 0xFEFF // Scanner is used to tokenize an input stream of // Rego source code. type Scanner struct { + keywords map[string]tokens.Token + bs []byte + errors []Error + tabs []int offset int row int col int - bs []byte - curr rune width int - errors []Error - keywords map[string]tokens.Token - tabs []int + curr rune regoV1Compatible bool } // Error represents a scanner error. type Error struct { - Pos Position Message string + Pos Position } // Position represents a point in the scanned source code. type Position struct { + Tabs []int // positions of any tabs preceding Col Offset int // start offset in bytes End int // end offset in bytes Row int // line number computed in bytes Col int // column number computed in bytes - Tabs []int // positions of any tabs preceding Col } // New returns an initialized scanner that will scan @@ -270,7 +271,8 @@ func (s *Scanner) scanIdentifier() string { for isLetter(s.curr) || isDigit(s.curr) { s.next() } - return string(s.bs[start : s.offset-1]) + + return byteSliceToString(s.bs[start : s.offset-1]) } func (s *Scanner) scanNumber() string { @@ -321,7 +323,7 @@ func (s *Scanner) scanNumber() string { } } - return string(s.bs[start : s.offset-1]) + return byteSliceToString(s.bs[start : s.offset-1]) } func (s *Scanner) scanString() string { @@ -355,7 +357,7 @@ func (s *Scanner) scanString() string { } } - return string(s.bs[start : s.offset-1]) + return byteSliceToString(s.bs[start : s.offset-1]) } func (s *Scanner) scanRawString() string { @@ -370,7 +372,8 @@ func (s *Scanner) scanRawString() string { break } } - return string(s.bs[start : s.offset-1]) + + return byteSliceToString(s.bs[start : s.offset-1]) } func (s *Scanner) scanComment() string { @@ -383,7 +386,8 @@ func (s *Scanner) scanComment() string { if s.offset > 1 && s.bs[s.offset-2] == '\r' { end = end - 1 } - return string(s.bs[start:end]) + + return byteSliceToString(s.bs[start:end]) } func (s *Scanner) next() { @@ -413,7 +417,7 @@ func (s *Scanner) next() { if s.curr == '\n' { s.row++ s.col = 0 - s.tabs = []int{} + s.tabs = s.tabs[:0] } else { s.col++ if s.curr == '\t' { @@ -453,3 +457,7 @@ func (s *Scanner) error(reason string) { Col: s.col, }, Message: reason}) } + +func byteSliceToString(bs []byte) string { + return unsafe.String(unsafe.SliceData(bs), len(bs)) +} diff --git a/vendor/github.com/open-policy-agent/opa/ast/internal/tokens/tokens.go b/vendor/github.com/open-policy-agent/opa/v1/ast/internal/tokens/tokens.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/ast/internal/tokens/tokens.go rename to vendor/github.com/open-policy-agent/opa/v1/ast/internal/tokens/tokens.go diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/interning.go b/vendor/github.com/open-policy-agent/opa/v1/ast/interning.go new file mode 100644 index 000000000..f521af966 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/interning.go @@ -0,0 +1,1092 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import "strconv" + +// NOTE! Great care must be taken **not** to modify the terms returned +// from these functions, as they are shared across all callers. + +var ( + booleanTrueTerm = &Term{Value: Boolean(true)} + booleanFalseTerm = &Term{Value: Boolean(false)} + + // since this is by far the most common negative number + minusOneTerm = &Term{Value: Number("-1")} +) + +// InternedBooleanTerm returns an interned term with the given boolean value. +func InternedBooleanTerm(b bool) *Term { + if b { + return booleanTrueTerm + } + + return booleanFalseTerm +} + +// InternedIntNumberTerm returns a term with the given integer value. The term is +// cached between -1 to 512, and for values outside of that range, this function +// is equivalent to ast.IntNumberTerm. +func InternedIntNumberTerm(i int) *Term { + if i >= 0 && i < len(intNumberTerms) { + return intNumberTerms[i] + } + + if i == -1 { + return minusOneTerm + } + + return &Term{Value: Number(strconv.Itoa(i))} +} + +// InternedIntFromString returns a term with the given integer value if the string +// maps to an interned term. If the string does not map to an interned term, nil is +// returned. +func InternedIntNumberTermFromString(s string) *Term { + if term, ok := stringToIntNumberTermMap[s]; ok { + return term + } + + return nil +} + +// HasInternedIntNumberTerm returns true if the given integer value maps to an interned +// term, otherwise false. +func HasInternedIntNumberTerm(i int) bool { + return i >= -1 && i < len(intNumberTerms) +} + +var stringToIntNumberTermMap = map[string]*Term{ + "-1": minusOneTerm, + "0": intNumberTerms[0], + "1": intNumberTerms[1], + "2": intNumberTerms[2], + "3": intNumberTerms[3], + "4": intNumberTerms[4], + "5": intNumberTerms[5], + "6": intNumberTerms[6], + "7": intNumberTerms[7], + "8": intNumberTerms[8], + "9": intNumberTerms[9], + "10": intNumberTerms[10], + "11": intNumberTerms[11], + "12": intNumberTerms[12], + "13": intNumberTerms[13], + "14": intNumberTerms[14], + "15": intNumberTerms[15], + "16": intNumberTerms[16], + "17": intNumberTerms[17], + "18": intNumberTerms[18], + "19": intNumberTerms[19], + "20": intNumberTerms[20], + "21": intNumberTerms[21], + "22": intNumberTerms[22], + "23": intNumberTerms[23], + "24": intNumberTerms[24], + "25": intNumberTerms[25], + "26": intNumberTerms[26], + "27": intNumberTerms[27], + "28": intNumberTerms[28], + "29": intNumberTerms[29], + "30": intNumberTerms[30], + "31": intNumberTerms[31], + "32": intNumberTerms[32], + "33": intNumberTerms[33], + "34": intNumberTerms[34], + "35": intNumberTerms[35], + "36": intNumberTerms[36], + "37": intNumberTerms[37], + "38": intNumberTerms[38], + "39": intNumberTerms[39], + "40": intNumberTerms[40], + "41": intNumberTerms[41], + "42": intNumberTerms[42], + "43": intNumberTerms[43], + "44": intNumberTerms[44], + "45": intNumberTerms[45], + "46": intNumberTerms[46], + "47": intNumberTerms[47], + "48": intNumberTerms[48], + "49": intNumberTerms[49], + "50": intNumberTerms[50], + "51": intNumberTerms[51], + "52": intNumberTerms[52], + "53": intNumberTerms[53], + "54": intNumberTerms[54], + "55": intNumberTerms[55], + "56": intNumberTerms[56], + "57": intNumberTerms[57], + "58": intNumberTerms[58], + "59": intNumberTerms[59], + "60": intNumberTerms[60], + "61": intNumberTerms[61], + "62": intNumberTerms[62], + "63": intNumberTerms[63], + "64": intNumberTerms[64], + "65": intNumberTerms[65], + "66": intNumberTerms[66], + "67": intNumberTerms[67], + "68": intNumberTerms[68], + "69": intNumberTerms[69], + "70": intNumberTerms[70], + "71": intNumberTerms[71], + "72": intNumberTerms[72], + "73": intNumberTerms[73], + "74": intNumberTerms[74], + "75": intNumberTerms[75], + "76": intNumberTerms[76], + "77": intNumberTerms[77], + "78": intNumberTerms[78], + "79": intNumberTerms[79], + "80": intNumberTerms[80], + "81": intNumberTerms[81], + "82": intNumberTerms[82], + "83": intNumberTerms[83], + "84": intNumberTerms[84], + "85": intNumberTerms[85], + "86": intNumberTerms[86], + "87": intNumberTerms[87], + "88": intNumberTerms[88], + "89": intNumberTerms[89], + "90": intNumberTerms[90], + "91": intNumberTerms[91], + "92": intNumberTerms[92], + "93": intNumberTerms[93], + "94": intNumberTerms[94], + "95": intNumberTerms[95], + "96": intNumberTerms[96], + "97": intNumberTerms[97], + "98": intNumberTerms[98], + "99": intNumberTerms[99], + "100": intNumberTerms[100], + "101": intNumberTerms[101], + "102": intNumberTerms[102], + "103": intNumberTerms[103], + "104": intNumberTerms[104], + "105": intNumberTerms[105], + "106": intNumberTerms[106], + "107": intNumberTerms[107], + "108": intNumberTerms[108], + "109": intNumberTerms[109], + "110": intNumberTerms[110], + "111": intNumberTerms[111], + "112": intNumberTerms[112], + "113": intNumberTerms[113], + "114": intNumberTerms[114], + "115": intNumberTerms[115], + "116": intNumberTerms[116], + "117": intNumberTerms[117], + "118": intNumberTerms[118], + "119": intNumberTerms[119], + "120": intNumberTerms[120], + "121": intNumberTerms[121], + "122": intNumberTerms[122], + "123": intNumberTerms[123], + "124": intNumberTerms[124], + "125": intNumberTerms[125], + "126": intNumberTerms[126], + "127": intNumberTerms[127], + "128": intNumberTerms[128], + "129": intNumberTerms[129], + "130": intNumberTerms[130], + "131": intNumberTerms[131], + "132": intNumberTerms[132], + "133": intNumberTerms[133], + "134": intNumberTerms[134], + "135": intNumberTerms[135], + "136": intNumberTerms[136], + "137": intNumberTerms[137], + "138": intNumberTerms[138], + "139": intNumberTerms[139], + "140": intNumberTerms[140], + "141": intNumberTerms[141], + "142": intNumberTerms[142], + "143": intNumberTerms[143], + "144": intNumberTerms[144], + "145": intNumberTerms[145], + "146": intNumberTerms[146], + "147": intNumberTerms[147], + "148": intNumberTerms[148], + "149": intNumberTerms[149], + "150": intNumberTerms[150], + "151": intNumberTerms[151], + "152": intNumberTerms[152], + "153": intNumberTerms[153], + "154": intNumberTerms[154], + "155": intNumberTerms[155], + "156": intNumberTerms[156], + "157": intNumberTerms[157], + "158": intNumberTerms[158], + "159": intNumberTerms[159], + "160": intNumberTerms[160], + "161": intNumberTerms[161], + "162": intNumberTerms[162], + "163": intNumberTerms[163], + "164": intNumberTerms[164], + "165": intNumberTerms[165], + "166": intNumberTerms[166], + "167": intNumberTerms[167], + "168": intNumberTerms[168], + "169": intNumberTerms[169], + "170": intNumberTerms[170], + "171": intNumberTerms[171], + "172": intNumberTerms[172], + "173": intNumberTerms[173], + "174": intNumberTerms[174], + "175": intNumberTerms[175], + "176": intNumberTerms[176], + "177": intNumberTerms[177], + "178": intNumberTerms[178], + "179": intNumberTerms[179], + "180": intNumberTerms[180], + "181": intNumberTerms[181], + "182": intNumberTerms[182], + "183": intNumberTerms[183], + "184": intNumberTerms[184], + "185": intNumberTerms[185], + "186": intNumberTerms[186], + "187": intNumberTerms[187], + "188": intNumberTerms[188], + "189": intNumberTerms[189], + "190": intNumberTerms[190], + "191": intNumberTerms[191], + "192": intNumberTerms[192], + "193": intNumberTerms[193], + "194": intNumberTerms[194], + "195": intNumberTerms[195], + "196": intNumberTerms[196], + "197": intNumberTerms[197], + "198": intNumberTerms[198], + "199": intNumberTerms[199], + "200": intNumberTerms[200], + "201": intNumberTerms[201], + "202": intNumberTerms[202], + "203": intNumberTerms[203], + "204": intNumberTerms[204], + "205": intNumberTerms[205], + "206": intNumberTerms[206], + "207": intNumberTerms[207], + "208": intNumberTerms[208], + "209": intNumberTerms[209], + "210": intNumberTerms[210], + "211": intNumberTerms[211], + "212": intNumberTerms[212], + "213": intNumberTerms[213], + "214": intNumberTerms[214], + "215": intNumberTerms[215], + "216": intNumberTerms[216], + "217": intNumberTerms[217], + "218": intNumberTerms[218], + "219": intNumberTerms[219], + "220": intNumberTerms[220], + "221": intNumberTerms[221], + "222": intNumberTerms[222], + "223": intNumberTerms[223], + "224": intNumberTerms[224], + "225": intNumberTerms[225], + "226": intNumberTerms[226], + "227": intNumberTerms[227], + "228": intNumberTerms[228], + "229": intNumberTerms[229], + "230": intNumberTerms[230], + "231": intNumberTerms[231], + "232": intNumberTerms[232], + "233": intNumberTerms[233], + "234": intNumberTerms[234], + "235": intNumberTerms[235], + "236": intNumberTerms[236], + "237": intNumberTerms[237], + "238": intNumberTerms[238], + "239": intNumberTerms[239], + "240": intNumberTerms[240], + "241": intNumberTerms[241], + "242": intNumberTerms[242], + "243": intNumberTerms[243], + "244": intNumberTerms[244], + "245": intNumberTerms[245], + "246": intNumberTerms[246], + "247": intNumberTerms[247], + "248": intNumberTerms[248], + "249": intNumberTerms[249], + "250": intNumberTerms[250], + "251": intNumberTerms[251], + "252": intNumberTerms[252], + "253": intNumberTerms[253], + "254": intNumberTerms[254], + "255": intNumberTerms[255], + "256": intNumberTerms[256], + "257": intNumberTerms[257], + "258": intNumberTerms[258], + "259": intNumberTerms[259], + "260": intNumberTerms[260], + "261": intNumberTerms[261], + "262": intNumberTerms[262], + "263": intNumberTerms[263], + "264": intNumberTerms[264], + "265": intNumberTerms[265], + "266": intNumberTerms[266], + "267": intNumberTerms[267], + "268": intNumberTerms[268], + "269": intNumberTerms[269], + "270": intNumberTerms[270], + "271": intNumberTerms[271], + "272": intNumberTerms[272], + "273": intNumberTerms[273], + "274": intNumberTerms[274], + "275": intNumberTerms[275], + "276": intNumberTerms[276], + "277": intNumberTerms[277], + "278": intNumberTerms[278], + "279": intNumberTerms[279], + "280": intNumberTerms[280], + "281": intNumberTerms[281], + "282": intNumberTerms[282], + "283": intNumberTerms[283], + "284": intNumberTerms[284], + "285": intNumberTerms[285], + "286": intNumberTerms[286], + "287": intNumberTerms[287], + "288": intNumberTerms[288], + "289": intNumberTerms[289], + "290": intNumberTerms[290], + "291": intNumberTerms[291], + "292": intNumberTerms[292], + "293": intNumberTerms[293], + "294": intNumberTerms[294], + "295": intNumberTerms[295], + "296": intNumberTerms[296], + "297": intNumberTerms[297], + "298": intNumberTerms[298], + "299": intNumberTerms[299], + "300": intNumberTerms[300], + "301": intNumberTerms[301], + "302": intNumberTerms[302], + "303": intNumberTerms[303], + "304": intNumberTerms[304], + "305": intNumberTerms[305], + "306": intNumberTerms[306], + "307": intNumberTerms[307], + "308": intNumberTerms[308], + "309": intNumberTerms[309], + "310": intNumberTerms[310], + "311": intNumberTerms[311], + "312": intNumberTerms[312], + "313": intNumberTerms[313], + "314": intNumberTerms[314], + "315": intNumberTerms[315], + "316": intNumberTerms[316], + "317": intNumberTerms[317], + "318": intNumberTerms[318], + "319": intNumberTerms[319], + "320": intNumberTerms[320], + "321": intNumberTerms[321], + "322": intNumberTerms[322], + "323": intNumberTerms[323], + "324": intNumberTerms[324], + "325": intNumberTerms[325], + "326": intNumberTerms[326], + "327": intNumberTerms[327], + "328": intNumberTerms[328], + "329": intNumberTerms[329], + "330": intNumberTerms[330], + "331": intNumberTerms[331], + "332": intNumberTerms[332], + "333": intNumberTerms[333], + "334": intNumberTerms[334], + "335": intNumberTerms[335], + "336": intNumberTerms[336], + "337": intNumberTerms[337], + "338": intNumberTerms[338], + "339": intNumberTerms[339], + "340": intNumberTerms[340], + "341": intNumberTerms[341], + "342": intNumberTerms[342], + "343": intNumberTerms[343], + "344": intNumberTerms[344], + "345": intNumberTerms[345], + "346": intNumberTerms[346], + "347": intNumberTerms[347], + "348": intNumberTerms[348], + "349": intNumberTerms[349], + "350": intNumberTerms[350], + "351": intNumberTerms[351], + "352": intNumberTerms[352], + "353": intNumberTerms[353], + "354": intNumberTerms[354], + "355": intNumberTerms[355], + "356": intNumberTerms[356], + "357": intNumberTerms[357], + "358": intNumberTerms[358], + "359": intNumberTerms[359], + "360": intNumberTerms[360], + "361": intNumberTerms[361], + "362": intNumberTerms[362], + "363": intNumberTerms[363], + "364": intNumberTerms[364], + "365": intNumberTerms[365], + "366": intNumberTerms[366], + "367": intNumberTerms[367], + "368": intNumberTerms[368], + "369": intNumberTerms[369], + "370": intNumberTerms[370], + "371": intNumberTerms[371], + "372": intNumberTerms[372], + "373": intNumberTerms[373], + "374": intNumberTerms[374], + "375": intNumberTerms[375], + "376": intNumberTerms[376], + "377": intNumberTerms[377], + "378": intNumberTerms[378], + "379": intNumberTerms[379], + "380": intNumberTerms[380], + "381": intNumberTerms[381], + "382": intNumberTerms[382], + "383": intNumberTerms[383], + "384": intNumberTerms[384], + "385": intNumberTerms[385], + "386": intNumberTerms[386], + "387": intNumberTerms[387], + "388": intNumberTerms[388], + "389": intNumberTerms[389], + "390": intNumberTerms[390], + "391": intNumberTerms[391], + "392": intNumberTerms[392], + "393": intNumberTerms[393], + "394": intNumberTerms[394], + "395": intNumberTerms[395], + "396": intNumberTerms[396], + "397": intNumberTerms[397], + "398": intNumberTerms[398], + "399": intNumberTerms[399], + "400": intNumberTerms[400], + "401": intNumberTerms[401], + "402": intNumberTerms[402], + "403": intNumberTerms[403], + "404": intNumberTerms[404], + "405": intNumberTerms[405], + "406": intNumberTerms[406], + "407": intNumberTerms[407], + "408": intNumberTerms[408], + "409": intNumberTerms[409], + "410": intNumberTerms[410], + "411": intNumberTerms[411], + "412": intNumberTerms[412], + "413": intNumberTerms[413], + "414": intNumberTerms[414], + "415": intNumberTerms[415], + "416": intNumberTerms[416], + "417": intNumberTerms[417], + "418": intNumberTerms[418], + "419": intNumberTerms[419], + "420": intNumberTerms[420], + "421": intNumberTerms[421], + "422": intNumberTerms[422], + "423": intNumberTerms[423], + "424": intNumberTerms[424], + "425": intNumberTerms[425], + "426": intNumberTerms[426], + "427": intNumberTerms[427], + "428": intNumberTerms[428], + "429": intNumberTerms[429], + "430": intNumberTerms[430], + "431": intNumberTerms[431], + "432": intNumberTerms[432], + "433": intNumberTerms[433], + "434": intNumberTerms[434], + "435": intNumberTerms[435], + "436": intNumberTerms[436], + "437": intNumberTerms[437], + "438": intNumberTerms[438], + "439": intNumberTerms[439], + "440": intNumberTerms[440], + "441": intNumberTerms[441], + "442": intNumberTerms[442], + "443": intNumberTerms[443], + "444": intNumberTerms[444], + "445": intNumberTerms[445], + "446": intNumberTerms[446], + "447": intNumberTerms[447], + "448": intNumberTerms[448], + "449": intNumberTerms[449], + "450": intNumberTerms[450], + "451": intNumberTerms[451], + "452": intNumberTerms[452], + "453": intNumberTerms[453], + "454": intNumberTerms[454], + "455": intNumberTerms[455], + "456": intNumberTerms[456], + "457": intNumberTerms[457], + "458": intNumberTerms[458], + "459": intNumberTerms[459], + "460": intNumberTerms[460], + "461": intNumberTerms[461], + "462": intNumberTerms[462], + "463": intNumberTerms[463], + "464": intNumberTerms[464], + "465": intNumberTerms[465], + "466": intNumberTerms[466], + "467": intNumberTerms[467], + "468": intNumberTerms[468], + "469": intNumberTerms[469], + "470": intNumberTerms[470], + "471": intNumberTerms[471], + "472": intNumberTerms[472], + "473": intNumberTerms[473], + "474": intNumberTerms[474], + "475": intNumberTerms[475], + "476": intNumberTerms[476], + "477": intNumberTerms[477], + "478": intNumberTerms[478], + "479": intNumberTerms[479], + "480": intNumberTerms[480], + "481": intNumberTerms[481], + "482": intNumberTerms[482], + "483": intNumberTerms[483], + "484": intNumberTerms[484], + "485": intNumberTerms[485], + "486": intNumberTerms[486], + "487": intNumberTerms[487], + "488": intNumberTerms[488], + "489": intNumberTerms[489], + "490": intNumberTerms[490], + "491": intNumberTerms[491], + "492": intNumberTerms[492], + "493": intNumberTerms[493], + "494": intNumberTerms[494], + "495": intNumberTerms[495], + "496": intNumberTerms[496], + "497": intNumberTerms[497], + "498": intNumberTerms[498], + "499": intNumberTerms[499], + "500": intNumberTerms[500], + "501": intNumberTerms[501], + "502": intNumberTerms[502], + "503": intNumberTerms[503], + "504": intNumberTerms[504], + "505": intNumberTerms[505], + "506": intNumberTerms[506], + "507": intNumberTerms[507], + "508": intNumberTerms[508], + "509": intNumberTerms[509], + "510": intNumberTerms[510], + "511": intNumberTerms[511], + "512": intNumberTerms[512], +} + +var intNumberTerms = [...]*Term{ + {Value: Number("0")}, + {Value: Number("1")}, + {Value: Number("2")}, + {Value: Number("3")}, + {Value: Number("4")}, + {Value: Number("5")}, + {Value: Number("6")}, + {Value: Number("7")}, + {Value: Number("8")}, + {Value: Number("9")}, + {Value: Number("10")}, + {Value: Number("11")}, + {Value: Number("12")}, + {Value: Number("13")}, + {Value: Number("14")}, + {Value: Number("15")}, + {Value: Number("16")}, + {Value: Number("17")}, + {Value: Number("18")}, + {Value: Number("19")}, + {Value: Number("20")}, + {Value: Number("21")}, + {Value: Number("22")}, + {Value: Number("23")}, + {Value: Number("24")}, + {Value: Number("25")}, + {Value: Number("26")}, + {Value: Number("27")}, + {Value: Number("28")}, + {Value: Number("29")}, + {Value: Number("30")}, + {Value: Number("31")}, + {Value: Number("32")}, + {Value: Number("33")}, + {Value: Number("34")}, + {Value: Number("35")}, + {Value: Number("36")}, + {Value: Number("37")}, + {Value: Number("38")}, + {Value: Number("39")}, + {Value: Number("40")}, + {Value: Number("41")}, + {Value: Number("42")}, + {Value: Number("43")}, + {Value: Number("44")}, + {Value: Number("45")}, + {Value: Number("46")}, + {Value: Number("47")}, + {Value: Number("48")}, + {Value: Number("49")}, + {Value: Number("50")}, + {Value: Number("51")}, + {Value: Number("52")}, + {Value: Number("53")}, + {Value: Number("54")}, + {Value: Number("55")}, + {Value: Number("56")}, + {Value: Number("57")}, + {Value: Number("58")}, + {Value: Number("59")}, + {Value: Number("60")}, + {Value: Number("61")}, + {Value: Number("62")}, + {Value: Number("63")}, + {Value: Number("64")}, + {Value: Number("65")}, + {Value: Number("66")}, + {Value: Number("67")}, + {Value: Number("68")}, + {Value: Number("69")}, + {Value: Number("70")}, + {Value: Number("71")}, + {Value: Number("72")}, + {Value: Number("73")}, + {Value: Number("74")}, + {Value: Number("75")}, + {Value: Number("76")}, + {Value: Number("77")}, + {Value: Number("78")}, + {Value: Number("79")}, + {Value: Number("80")}, + {Value: Number("81")}, + {Value: Number("82")}, + {Value: Number("83")}, + {Value: Number("84")}, + {Value: Number("85")}, + {Value: Number("86")}, + {Value: Number("87")}, + {Value: Number("88")}, + {Value: Number("89")}, + {Value: Number("90")}, + {Value: Number("91")}, + {Value: Number("92")}, + {Value: Number("93")}, + {Value: Number("94")}, + {Value: Number("95")}, + {Value: Number("96")}, + {Value: Number("97")}, + {Value: Number("98")}, + {Value: Number("99")}, + {Value: Number("100")}, + {Value: Number("101")}, + {Value: Number("102")}, + {Value: Number("103")}, + {Value: Number("104")}, + {Value: Number("105")}, + {Value: Number("106")}, + {Value: Number("107")}, + {Value: Number("108")}, + {Value: Number("109")}, + {Value: Number("110")}, + {Value: Number("111")}, + {Value: Number("112")}, + {Value: Number("113")}, + {Value: Number("114")}, + {Value: Number("115")}, + {Value: Number("116")}, + {Value: Number("117")}, + {Value: Number("118")}, + {Value: Number("119")}, + {Value: Number("120")}, + {Value: Number("121")}, + {Value: Number("122")}, + {Value: Number("123")}, + {Value: Number("124")}, + {Value: Number("125")}, + {Value: Number("126")}, + {Value: Number("127")}, + {Value: Number("128")}, + {Value: Number("129")}, + {Value: Number("130")}, + {Value: Number("131")}, + {Value: Number("132")}, + {Value: Number("133")}, + {Value: Number("134")}, + {Value: Number("135")}, + {Value: Number("136")}, + {Value: Number("137")}, + {Value: Number("138")}, + {Value: Number("139")}, + {Value: Number("140")}, + {Value: Number("141")}, + {Value: Number("142")}, + {Value: Number("143")}, + {Value: Number("144")}, + {Value: Number("145")}, + {Value: Number("146")}, + {Value: Number("147")}, + {Value: Number("148")}, + {Value: Number("149")}, + {Value: Number("150")}, + {Value: Number("151")}, + {Value: Number("152")}, + {Value: Number("153")}, + {Value: Number("154")}, + {Value: Number("155")}, + {Value: Number("156")}, + {Value: Number("157")}, + {Value: Number("158")}, + {Value: Number("159")}, + {Value: Number("160")}, + {Value: Number("161")}, + {Value: Number("162")}, + {Value: Number("163")}, + {Value: Number("164")}, + {Value: Number("165")}, + {Value: Number("166")}, + {Value: Number("167")}, + {Value: Number("168")}, + {Value: Number("169")}, + {Value: Number("170")}, + {Value: Number("171")}, + {Value: Number("172")}, + {Value: Number("173")}, + {Value: Number("174")}, + {Value: Number("175")}, + {Value: Number("176")}, + {Value: Number("177")}, + {Value: Number("178")}, + {Value: Number("179")}, + {Value: Number("180")}, + {Value: Number("181")}, + {Value: Number("182")}, + {Value: Number("183")}, + {Value: Number("184")}, + {Value: Number("185")}, + {Value: Number("186")}, + {Value: Number("187")}, + {Value: Number("188")}, + {Value: Number("189")}, + {Value: Number("190")}, + {Value: Number("191")}, + {Value: Number("192")}, + {Value: Number("193")}, + {Value: Number("194")}, + {Value: Number("195")}, + {Value: Number("196")}, + {Value: Number("197")}, + {Value: Number("198")}, + {Value: Number("199")}, + {Value: Number("200")}, + {Value: Number("201")}, + {Value: Number("202")}, + {Value: Number("203")}, + {Value: Number("204")}, + {Value: Number("205")}, + {Value: Number("206")}, + {Value: Number("207")}, + {Value: Number("208")}, + {Value: Number("209")}, + {Value: Number("210")}, + {Value: Number("211")}, + {Value: Number("212")}, + {Value: Number("213")}, + {Value: Number("214")}, + {Value: Number("215")}, + {Value: Number("216")}, + {Value: Number("217")}, + {Value: Number("218")}, + {Value: Number("219")}, + {Value: Number("220")}, + {Value: Number("221")}, + {Value: Number("222")}, + {Value: Number("223")}, + {Value: Number("224")}, + {Value: Number("225")}, + {Value: Number("226")}, + {Value: Number("227")}, + {Value: Number("228")}, + {Value: Number("229")}, + {Value: Number("230")}, + {Value: Number("231")}, + {Value: Number("232")}, + {Value: Number("233")}, + {Value: Number("234")}, + {Value: Number("235")}, + {Value: Number("236")}, + {Value: Number("237")}, + {Value: Number("238")}, + {Value: Number("239")}, + {Value: Number("240")}, + {Value: Number("241")}, + {Value: Number("242")}, + {Value: Number("243")}, + {Value: Number("244")}, + {Value: Number("245")}, + {Value: Number("246")}, + {Value: Number("247")}, + {Value: Number("248")}, + {Value: Number("249")}, + {Value: Number("250")}, + {Value: Number("251")}, + {Value: Number("252")}, + {Value: Number("253")}, + {Value: Number("254")}, + {Value: Number("255")}, + {Value: Number("256")}, + {Value: Number("257")}, + {Value: Number("258")}, + {Value: Number("259")}, + {Value: Number("260")}, + {Value: Number("261")}, + {Value: Number("262")}, + {Value: Number("263")}, + {Value: Number("264")}, + {Value: Number("265")}, + {Value: Number("266")}, + {Value: Number("267")}, + {Value: Number("268")}, + {Value: Number("269")}, + {Value: Number("270")}, + {Value: Number("271")}, + {Value: Number("272")}, + {Value: Number("273")}, + {Value: Number("274")}, + {Value: Number("275")}, + {Value: Number("276")}, + {Value: Number("277")}, + {Value: Number("278")}, + {Value: Number("279")}, + {Value: Number("280")}, + {Value: Number("281")}, + {Value: Number("282")}, + {Value: Number("283")}, + {Value: Number("284")}, + {Value: Number("285")}, + {Value: Number("286")}, + {Value: Number("287")}, + {Value: Number("288")}, + {Value: Number("289")}, + {Value: Number("290")}, + {Value: Number("291")}, + {Value: Number("292")}, + {Value: Number("293")}, + {Value: Number("294")}, + {Value: Number("295")}, + {Value: Number("296")}, + {Value: Number("297")}, + {Value: Number("298")}, + {Value: Number("299")}, + {Value: Number("300")}, + {Value: Number("301")}, + {Value: Number("302")}, + {Value: Number("303")}, + {Value: Number("304")}, + {Value: Number("305")}, + {Value: Number("306")}, + {Value: Number("307")}, + {Value: Number("308")}, + {Value: Number("309")}, + {Value: Number("310")}, + {Value: Number("311")}, + {Value: Number("312")}, + {Value: Number("313")}, + {Value: Number("314")}, + {Value: Number("315")}, + {Value: Number("316")}, + {Value: Number("317")}, + {Value: Number("318")}, + {Value: Number("319")}, + {Value: Number("320")}, + {Value: Number("321")}, + {Value: Number("322")}, + {Value: Number("323")}, + {Value: Number("324")}, + {Value: Number("325")}, + {Value: Number("326")}, + {Value: Number("327")}, + {Value: Number("328")}, + {Value: Number("329")}, + {Value: Number("330")}, + {Value: Number("331")}, + {Value: Number("332")}, + {Value: Number("333")}, + {Value: Number("334")}, + {Value: Number("335")}, + {Value: Number("336")}, + {Value: Number("337")}, + {Value: Number("338")}, + {Value: Number("339")}, + {Value: Number("340")}, + {Value: Number("341")}, + {Value: Number("342")}, + {Value: Number("343")}, + {Value: Number("344")}, + {Value: Number("345")}, + {Value: Number("346")}, + {Value: Number("347")}, + {Value: Number("348")}, + {Value: Number("349")}, + {Value: Number("350")}, + {Value: Number("351")}, + {Value: Number("352")}, + {Value: Number("353")}, + {Value: Number("354")}, + {Value: Number("355")}, + {Value: Number("356")}, + {Value: Number("357")}, + {Value: Number("358")}, + {Value: Number("359")}, + {Value: Number("360")}, + {Value: Number("361")}, + {Value: Number("362")}, + {Value: Number("363")}, + {Value: Number("364")}, + {Value: Number("365")}, + {Value: Number("366")}, + {Value: Number("367")}, + {Value: Number("368")}, + {Value: Number("369")}, + {Value: Number("370")}, + {Value: Number("371")}, + {Value: Number("372")}, + {Value: Number("373")}, + {Value: Number("374")}, + {Value: Number("375")}, + {Value: Number("376")}, + {Value: Number("377")}, + {Value: Number("378")}, + {Value: Number("379")}, + {Value: Number("380")}, + {Value: Number("381")}, + {Value: Number("382")}, + {Value: Number("383")}, + {Value: Number("384")}, + {Value: Number("385")}, + {Value: Number("386")}, + {Value: Number("387")}, + {Value: Number("388")}, + {Value: Number("389")}, + {Value: Number("390")}, + {Value: Number("391")}, + {Value: Number("392")}, + {Value: Number("393")}, + {Value: Number("394")}, + {Value: Number("395")}, + {Value: Number("396")}, + {Value: Number("397")}, + {Value: Number("398")}, + {Value: Number("399")}, + {Value: Number("400")}, + {Value: Number("401")}, + {Value: Number("402")}, + {Value: Number("403")}, + {Value: Number("404")}, + {Value: Number("405")}, + {Value: Number("406")}, + {Value: Number("407")}, + {Value: Number("408")}, + {Value: Number("409")}, + {Value: Number("410")}, + {Value: Number("411")}, + {Value: Number("412")}, + {Value: Number("413")}, + {Value: Number("414")}, + {Value: Number("415")}, + {Value: Number("416")}, + {Value: Number("417")}, + {Value: Number("418")}, + {Value: Number("419")}, + {Value: Number("420")}, + {Value: Number("421")}, + {Value: Number("422")}, + {Value: Number("423")}, + {Value: Number("424")}, + {Value: Number("425")}, + {Value: Number("426")}, + {Value: Number("427")}, + {Value: Number("428")}, + {Value: Number("429")}, + {Value: Number("430")}, + {Value: Number("431")}, + {Value: Number("432")}, + {Value: Number("433")}, + {Value: Number("434")}, + {Value: Number("435")}, + {Value: Number("436")}, + {Value: Number("437")}, + {Value: Number("438")}, + {Value: Number("439")}, + {Value: Number("440")}, + {Value: Number("441")}, + {Value: Number("442")}, + {Value: Number("443")}, + {Value: Number("444")}, + {Value: Number("445")}, + {Value: Number("446")}, + {Value: Number("447")}, + {Value: Number("448")}, + {Value: Number("449")}, + {Value: Number("450")}, + {Value: Number("451")}, + {Value: Number("452")}, + {Value: Number("453")}, + {Value: Number("454")}, + {Value: Number("455")}, + {Value: Number("456")}, + {Value: Number("457")}, + {Value: Number("458")}, + {Value: Number("459")}, + {Value: Number("460")}, + {Value: Number("461")}, + {Value: Number("462")}, + {Value: Number("463")}, + {Value: Number("464")}, + {Value: Number("465")}, + {Value: Number("466")}, + {Value: Number("467")}, + {Value: Number("468")}, + {Value: Number("469")}, + {Value: Number("470")}, + {Value: Number("471")}, + {Value: Number("472")}, + {Value: Number("473")}, + {Value: Number("474")}, + {Value: Number("475")}, + {Value: Number("476")}, + {Value: Number("477")}, + {Value: Number("478")}, + {Value: Number("479")}, + {Value: Number("480")}, + {Value: Number("481")}, + {Value: Number("482")}, + {Value: Number("483")}, + {Value: Number("484")}, + {Value: Number("485")}, + {Value: Number("486")}, + {Value: Number("487")}, + {Value: Number("488")}, + {Value: Number("489")}, + {Value: Number("490")}, + {Value: Number("491")}, + {Value: Number("492")}, + {Value: Number("493")}, + {Value: Number("494")}, + {Value: Number("495")}, + {Value: Number("496")}, + {Value: Number("497")}, + {Value: Number("498")}, + {Value: Number("499")}, + {Value: Number("500")}, + {Value: Number("501")}, + {Value: Number("502")}, + {Value: Number("503")}, + {Value: Number("504")}, + {Value: Number("505")}, + {Value: Number("506")}, + {Value: Number("507")}, + {Value: Number("508")}, + {Value: Number("509")}, + {Value: Number("510")}, + {Value: Number("511")}, + {Value: Number("512")}, +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/json/json.go b/vendor/github.com/open-policy-agent/opa/v1/ast/json/json.go new file mode 100644 index 000000000..565017d58 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/json/json.go @@ -0,0 +1,36 @@ +package json + +// Options defines the options for JSON operations, +// currently only marshaling can be configured +type Options struct { + MarshalOptions MarshalOptions +} + +// MarshalOptions defines the options for JSON marshaling, +// currently only toggling the marshaling of location information is supported +type MarshalOptions struct { + // IncludeLocation toggles the marshaling of location information + IncludeLocation NodeToggle + // IncludeLocationText additionally/optionally includes the text of the location + IncludeLocationText bool + // ExcludeLocationFile additionally/optionally excludes the file of the location + // Note that this is inverted (i.e. not "include" as the default needs to remain false) + ExcludeLocationFile bool +} + +// NodeToggle is a generic struct to allow the toggling of +// settings for different ast node types +type NodeToggle struct { + Term bool + Package bool + Comment bool + Import bool + Rule bool + Head bool + Expr bool + SomeDecl bool + Every bool + With bool + Annotations bool + AnnotationsRef bool +} diff --git a/vendor/github.com/open-policy-agent/opa/ast/location/location.go b/vendor/github.com/open-policy-agent/opa/v1/ast/location/location.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/ast/location/location.go rename to vendor/github.com/open-policy-agent/opa/v1/ast/location/location.go index 92226df3f..d529ad72d 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/location/location.go +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/location/location.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" - astJSON "github.com/open-policy-agent/opa/ast/json" + astJSON "github.com/open-policy-agent/opa/v1/ast/json" ) // Location records a position in source code diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/map.go b/vendor/github.com/open-policy-agent/opa/v1/ast/map.go new file mode 100644 index 000000000..c22d279a6 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/map.go @@ -0,0 +1,133 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "encoding/json" + + "github.com/open-policy-agent/opa/v1/util" +) + +// ValueMap represents a key/value map between AST term values. Any type of term +// can be used as a key in the map. +type ValueMap struct { + hashMap *util.HashMap +} + +// NewValueMap returns a new ValueMap. +func NewValueMap() *ValueMap { + vs := &ValueMap{ + hashMap: util.NewHashMap(valueEq, valueHash), + } + return vs +} + +// MarshalJSON provides a custom marshaller for the ValueMap which +// will include the key, value, and value type. +func (vs *ValueMap) MarshalJSON() ([]byte, error) { + var tmp []map[string]interface{} + vs.Iter(func(k Value, v Value) bool { + tmp = append(tmp, map[string]interface{}{ + "name": k.String(), + "type": TypeName(v), + "value": v, + }) + return false + }) + return json.Marshal(tmp) +} + +// Copy returns a shallow copy of the ValueMap. +func (vs *ValueMap) Copy() *ValueMap { + if vs == nil { + return nil + } + cpy := NewValueMap() + cpy.hashMap = vs.hashMap.Copy() + return cpy +} + +// Equal returns true if this ValueMap equals the other. +func (vs *ValueMap) Equal(other *ValueMap) bool { + if vs == nil { + return other == nil || other.Len() == 0 + } + if other == nil { + return vs.Len() == 0 + } + return vs.hashMap.Equal(other.hashMap) +} + +// Len returns the number of elements in the map. +func (vs *ValueMap) Len() int { + if vs == nil { + return 0 + } + return vs.hashMap.Len() +} + +// Get returns the value in the map for k. +func (vs *ValueMap) Get(k Value) Value { + if vs != nil { + if v, ok := vs.hashMap.Get(k); ok { + return v.(Value) + } + } + return nil +} + +// Hash returns a hash code for this ValueMap. +func (vs *ValueMap) Hash() int { + if vs == nil { + return 0 + } + return vs.hashMap.Hash() +} + +// Iter calls the iter function for each key/value pair in the map. If the iter +// function returns true, iteration stops. +func (vs *ValueMap) Iter(iter func(Value, Value) bool) bool { + if vs == nil { + return false + } + return vs.hashMap.Iter(func(kt, vt util.T) bool { + k := kt.(Value) + v := vt.(Value) + return iter(k, v) + }) +} + +// Put inserts a key k into the map with value v. +func (vs *ValueMap) Put(k, v Value) { + if vs == nil { + panic("put on nil value map") + } + vs.hashMap.Put(k, v) +} + +// Delete removes a key k from the map. +func (vs *ValueMap) Delete(k Value) { + if vs == nil { + return + } + vs.hashMap.Delete(k) +} + +func (vs *ValueMap) String() string { + if vs == nil { + return "{}" + } + return vs.hashMap.String() +} + +func valueHash(v util.T) int { + return v.(Value).Hash() +} + +func valueEq(a, b util.T) bool { + av := a.(Value) + bv := b.(Value) + return av.Compare(bv) == 0 +} diff --git a/vendor/github.com/open-policy-agent/opa/ast/marshal.go b/vendor/github.com/open-policy-agent/opa/v1/ast/marshal.go similarity index 80% rename from vendor/github.com/open-policy-agent/opa/ast/marshal.go rename to vendor/github.com/open-policy-agent/opa/v1/ast/marshal.go index 53fb11204..dd8dff684 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/marshal.go +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/marshal.go @@ -1,7 +1,7 @@ package ast import ( - astJSON "github.com/open-policy-agent/opa/ast/json" + astJSON "github.com/open-policy-agent/opa/v1/ast/json" ) // customJSON is an interface that can be implemented by AST nodes that diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/parser.go b/vendor/github.com/open-policy-agent/opa/v1/ast/parser.go new file mode 100644 index 000000000..6639ca990 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/parser.go @@ -0,0 +1,2780 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math/big" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + "unicode/utf8" + + "gopkg.in/yaml.v3" + + "github.com/open-policy-agent/opa/v1/ast/internal/scanner" + "github.com/open-policy-agent/opa/v1/ast/internal/tokens" + astJSON "github.com/open-policy-agent/opa/v1/ast/json" + "github.com/open-policy-agent/opa/v1/ast/location" +) + +var RegoV1CompatibleRef = Ref{VarTerm("rego"), StringTerm("v1")} + +// RegoVersion defines the Rego syntax requirements for a module. +type RegoVersion int + +const DefaultRegoVersion = RegoV1 + +const ( + RegoUndefined RegoVersion = iota + // RegoV0 is the default, original Rego syntax. + RegoV0 + // RegoV0CompatV1 requires modules to comply with both the RegoV0 and RegoV1 syntax (as when 'rego.v1' is imported in a module). + // Shortly, RegoV1 compatibility is required, but 'rego.v1' or 'future.keywords' must also be imported. + RegoV0CompatV1 + // RegoV1 is the Rego syntax enforced by OPA 1.0; e.g.: + // future.keywords part of default keyword set, and don't require imports; + // 'if' and 'contains' required in rule heads; + // (some) strict checks on by default. + RegoV1 +) + +func (v RegoVersion) Int() int { + if v == RegoV1 { + return 1 + } + return 0 +} + +func (v RegoVersion) String() string { + switch v { + case RegoV0: + return "v0" + case RegoV1: + return "v1" + case RegoV0CompatV1: + return "v0v1" + default: + return "unknown" + } +} + +func RegoVersionFromInt(i int) RegoVersion { + if i == 1 { + return RegoV1 + } + return RegoV0 +} + +// Note: This state is kept isolated from the parser so that we +// can do efficient shallow copies of these values when doing a +// save() and restore(). +type state struct { + s *scanner.Scanner + lastEnd int + skippedNL bool + tok tokens.Token + tokEnd int + lit string + loc Location + errors Errors + hints []string + comments []*Comment + wildcard int +} + +func (s *state) String() string { + return fmt.Sprintf("", s.s, s.tok, s.lit, s.loc, len(s.errors), len(s.comments)) +} + +func (s *state) Loc() *location.Location { + cpy := s.loc + return &cpy +} + +func (s *state) Text(offset, end int) []byte { + bs := s.s.Bytes() + if offset >= 0 && offset < len(bs) { + if end >= offset && end <= len(bs) { + return bs[offset:end] + } + } + return nil +} + +// Parser is used to parse Rego statements. +type Parser struct { + r io.Reader + s *state + po ParserOptions + cache parsedTermCache +} + +type parsedTermCacheItem struct { + t *Term + post *state // post is the post-state that's restored on a cache-hit + offset int + next *parsedTermCacheItem +} + +type parsedTermCache struct { + m *parsedTermCacheItem +} + +func (c parsedTermCache) String() string { + s := strings.Builder{} + s.WriteRune('{') + var e *parsedTermCacheItem + for e = c.m; e != nil; e = e.next { + s.WriteString(fmt.Sprintf("%v", e)) + } + s.WriteRune('}') + return s.String() +} + +func (e *parsedTermCacheItem) String() string { + return fmt.Sprintf("<%d:%v>", e.offset, e.t) +} + +// ParserOptions defines the options for parsing Rego statements. +type ParserOptions struct { + Capabilities *Capabilities + ProcessAnnotation bool + AllFutureKeywords bool + FutureKeywords []string + SkipRules bool + JSONOptions *astJSON.Options + // RegoVersion is the version of Rego to parse for. + RegoVersion RegoVersion + unreleasedKeywords bool // TODO(sr): cleanup +} + +// EffectiveRegoVersion returns the effective RegoVersion to use for parsing. +func (po *ParserOptions) EffectiveRegoVersion() RegoVersion { + if po.RegoVersion == RegoUndefined { + return DefaultRegoVersion + } + return po.RegoVersion +} + +// NewParser creates and initializes a Parser. +func NewParser() *Parser { + p := &Parser{ + s: &state{}, + po: ParserOptions{}, + } + return p +} + +// WithFilename provides the filename for Location details +// on parsed statements. +func (p *Parser) WithFilename(filename string) *Parser { + p.s.loc.File = filename + return p +} + +// WithReader provides the io.Reader that the parser will +// use as its source. +func (p *Parser) WithReader(r io.Reader) *Parser { + p.r = r + return p +} + +// WithProcessAnnotation enables or disables the processing of +// annotations by the Parser +func (p *Parser) WithProcessAnnotation(processAnnotation bool) *Parser { + p.po.ProcessAnnotation = processAnnotation + return p +} + +// WithFutureKeywords enables "future" keywords, i.e., keywords that can +// be imported via +// +// import future.keywords.kw +// import future.keywords.other +// +// but in a more direct way. The equivalent of this import would be +// +// WithFutureKeywords("kw", "other") +func (p *Parser) WithFutureKeywords(kws ...string) *Parser { + p.po.FutureKeywords = kws + return p +} + +// WithAllFutureKeywords enables all "future" keywords, i.e., the +// ParserOption equivalent of +// +// import future.keywords +func (p *Parser) WithAllFutureKeywords(yes bool) *Parser { + p.po.AllFutureKeywords = yes + return p +} + +// withUnreleasedKeywords allows using keywords that haven't surfaced +// as future keywords (see above) yet, but have tests that require +// them to be parsed +func (p *Parser) withUnreleasedKeywords(yes bool) *Parser { + p.po.unreleasedKeywords = yes + return p +} + +// WithCapabilities sets the capabilities structure on the parser. +func (p *Parser) WithCapabilities(c *Capabilities) *Parser { + p.po.Capabilities = c + return p +} + +// WithSkipRules instructs the parser not to attempt to parse Rule statements. +func (p *Parser) WithSkipRules(skip bool) *Parser { + p.po.SkipRules = skip + return p +} + +// WithJSONOptions sets the Options which will be set on nodes to configure +// their JSON marshaling behavior. +func (p *Parser) WithJSONOptions(jsonOptions *astJSON.Options) *Parser { + p.po.JSONOptions = jsonOptions + return p +} + +func (p *Parser) WithRegoVersion(version RegoVersion) *Parser { + p.po.RegoVersion = version + return p +} + +func (p *Parser) parsedTermCacheLookup() (*Term, *state) { + l := p.s.loc.Offset + // stop comparing once the cached offsets are lower than l + for h := p.cache.m; h != nil && h.offset >= l; h = h.next { + if h.offset == l { + return h.t, h.post + } + } + return nil, nil +} + +func (p *Parser) parsedTermCachePush(t *Term, s0 *state) { + s1 := p.save() + o0 := s0.loc.Offset + entry := parsedTermCacheItem{t: t, post: s1, offset: o0} + + // find the first one whose offset is smaller than ours + var e *parsedTermCacheItem + for e = p.cache.m; e != nil; e = e.next { + if e.offset < o0 { + break + } + } + entry.next = e + p.cache.m = &entry +} + +// futureParser returns a shallow copy of `p` with an empty +// cache, and a scanner that knows all future keywords. +// It's used to present hints in errors, when statements would +// only parse successfully if some future keyword is enabled. +func (p *Parser) futureParser() *Parser { + q := *p + q.s = p.save() + q.s.s = p.s.s.WithKeywords(allFutureKeywords) + q.cache = parsedTermCache{} + return &q +} + +// presentParser returns a shallow copy of `p` with an empty +// cache, and a scanner that knows none of the future keywords. +// It is used to successfully parse keyword imports, like +// +// import future.keywords.in +// +// even when the parser has already been informed about the +// future keyword "in". This parser won't error out because +// "in" is an identifier. +func (p *Parser) presentParser() (*Parser, map[string]tokens.Token) { + var cpy map[string]tokens.Token + q := *p + q.s = p.save() + q.s.s, cpy = p.s.s.WithoutKeywords(allFutureKeywords) + q.cache = parsedTermCache{} + return &q, cpy +} + +// Parse will read the Rego source and parse statements and +// comments as they are found. Any errors encountered while +// parsing will be accumulated and returned as a list of Errors. +func (p *Parser) Parse() ([]Statement, []*Comment, Errors) { + + if p.po.Capabilities == nil { + p.po.Capabilities = CapabilitiesForThisVersion(CapabilitiesRegoVersion(p.po.RegoVersion)) + } + + allowedFutureKeywords := map[string]tokens.Token{} + + if p.po.EffectiveRegoVersion() == RegoV1 { + if !p.po.Capabilities.ContainsFeature(FeatureRegoV1) { + return nil, nil, Errors{ + &Error{ + Code: ParseErr, + Message: "illegal capabilities: rego_v1 feature required for parsing v1 Rego", + Location: nil, + }, + } + } + + // rego-v1 includes all v0 future keywords in the default language definition + for k, v := range futureKeywordsV0 { + allowedFutureKeywords[k] = v + } + + for _, kw := range p.po.Capabilities.FutureKeywords { + if tok, ok := futureKeywords[kw]; ok { + allowedFutureKeywords[kw] = tok + } else { + // For sake of error reporting, we still need to check that keywords in capabilities are known in v0 + if _, ok := futureKeywordsV0[kw]; !ok { + return nil, nil, Errors{ + &Error{ + Code: ParseErr, + Message: fmt.Sprintf("illegal capabilities: unknown keyword: %v", kw), + Location: nil, + }, + } + } + } + } + + // Check that explicitly requested future keywords are known. + for _, kw := range p.po.FutureKeywords { + if _, ok := allowedFutureKeywords[kw]; !ok { + return nil, nil, Errors{ + &Error{ + Code: ParseErr, + Message: fmt.Sprintf("unknown future keyword: %v", kw), + Location: nil, + }, + } + } + } + } else { + for _, kw := range p.po.Capabilities.FutureKeywords { + var ok bool + allowedFutureKeywords[kw], ok = allFutureKeywords[kw] + if !ok { + return nil, nil, Errors{ + &Error{ + Code: ParseErr, + Message: fmt.Sprintf("illegal capabilities: unknown keyword: %v", kw), + Location: nil, + }, + } + } + } + + if p.po.Capabilities.ContainsFeature(FeatureRegoV1) { + // rego-v1 includes all v0 future keywords in the default language definition + for k, v := range futureKeywordsV0 { + allowedFutureKeywords[k] = v + } + } + } + + var err error + p.s.s, err = scanner.New(p.r) + if err != nil { + return nil, nil, Errors{ + &Error{ + Code: ParseErr, + Message: err.Error(), + Location: nil, + }, + } + } + + selected := map[string]tokens.Token{} + if p.po.AllFutureKeywords || p.po.EffectiveRegoVersion() == RegoV1 { + for kw, tok := range allowedFutureKeywords { + selected[kw] = tok + } + } else { + for _, kw := range p.po.FutureKeywords { + tok, ok := allowedFutureKeywords[kw] + if !ok { + return nil, nil, Errors{ + &Error{ + Code: ParseErr, + Message: fmt.Sprintf("unknown future keyword: %v", kw), + Location: nil, + }, + } + } + selected[kw] = tok + } + } + p.s.s = p.s.s.WithKeywords(selected) + + if p.po.EffectiveRegoVersion() == RegoV1 { + for kw, tok := range allowedFutureKeywords { + p.s.s.AddKeyword(kw, tok) + } + } + + // read the first token to initialize the parser + p.scan() + + var stmts []Statement + + // Read from the scanner until the last token is reached or no statements + // can be parsed. Attempt to parse package statements, import statements, + // rule statements, and then body/query statements (in that order). If a + // statement cannot be parsed, restore the parser state before trying the + // next type of statement. If a statement can be parsed, continue from that + // point trying to parse packages, imports, etc. in the same order. + for p.s.tok != tokens.EOF { + + s := p.save() + + if pkg := p.parsePackage(); pkg != nil { + stmts = append(stmts, pkg) + continue + } else if len(p.s.errors) > 0 { + break + } + + p.restore(s) + s = p.save() + + if imp := p.parseImport(); imp != nil { + if RegoRootDocument.Equal(imp.Path.Value.(Ref)[0]) { + p.regoV1Import(imp) + } + + if FutureRootDocument.Equal(imp.Path.Value.(Ref)[0]) { + p.futureImport(imp, allowedFutureKeywords) + } + + stmts = append(stmts, imp) + continue + } else if len(p.s.errors) > 0 { + break + } + + p.restore(s) + + if !p.po.SkipRules { + s = p.save() + + if rules := p.parseRules(); rules != nil { + for i := range rules { + stmts = append(stmts, rules[i]) + } + continue + } else if len(p.s.errors) > 0 { + break + } + + p.restore(s) + } + + if body := p.parseQuery(true, tokens.EOF); body != nil { + stmts = append(stmts, body) + continue + } + + break + } + + if p.po.ProcessAnnotation { + stmts = p.parseAnnotations(stmts) + } + + if p.po.JSONOptions != nil { + for i := range stmts { + vis := NewGenericVisitor(func(x interface{}) bool { + if x, ok := x.(customJSON); ok { + x.setJSONOptions(*p.po.JSONOptions) + } + return false + }) + + vis.Walk(stmts[i]) + } + } + + return stmts, p.s.comments, p.s.errors +} + +func (p *Parser) parseAnnotations(stmts []Statement) []Statement { + + annotStmts, errs := parseAnnotations(p.s.comments) + for _, err := range errs { + p.error(err.Location, err.Message) + } + + for _, annotStmt := range annotStmts { + stmts = append(stmts, annotStmt) + } + + return stmts +} + +func parseAnnotations(comments []*Comment) ([]*Annotations, Errors) { + + var hint = []byte("METADATA") + var curr *metadataParser + var blocks []*metadataParser + + for i := 0; i < len(comments); i++ { + if curr != nil { + if comments[i].Location.Row == comments[i-1].Location.Row+1 && comments[i].Location.Col == 1 { + curr.Append(comments[i]) + continue + } + curr = nil + } + if bytes.HasPrefix(bytes.TrimSpace(comments[i].Text), hint) { + curr = newMetadataParser(comments[i].Location) + blocks = append(blocks, curr) + } + } + + var stmts []*Annotations + var errs Errors + for _, b := range blocks { + a, err := b.Parse() + if err != nil { + errs = append(errs, &Error{ + Code: ParseErr, + Message: err.Error(), + Location: b.loc, + }) + } else { + stmts = append(stmts, a) + } + } + + return stmts, errs +} + +func (p *Parser) parsePackage() *Package { + + var pkg Package + pkg.SetLoc(p.s.Loc()) + + if p.s.tok != tokens.Package { + return nil + } + + p.scan() + if p.s.tok != tokens.Ident { + p.illegalToken() + return nil + } + + term := p.parseTerm() + + if term != nil { + switch v := term.Value.(type) { + case Var: + pkg.Path = Ref{ + DefaultRootDocument.Copy().SetLocation(term.Location), + StringTerm(string(v)).SetLocation(term.Location), + } + case Ref: + pkg.Path = make(Ref, len(v)+1) + pkg.Path[0] = DefaultRootDocument.Copy().SetLocation(v[0].Location) + first, ok := v[0].Value.(Var) + if !ok { + p.errorf(v[0].Location, "unexpected %v token: expecting var", TypeName(v[0].Value)) + return nil + } + pkg.Path[1] = StringTerm(string(first)).SetLocation(v[0].Location) + for i := 2; i < len(pkg.Path); i++ { + switch v[i-1].Value.(type) { + case String: + pkg.Path[i] = v[i-1] + default: + p.errorf(v[i-1].Location, "unexpected %v token: expecting string", TypeName(v[i-1].Value)) + return nil + } + } + default: + p.illegalToken() + return nil + } + } + + if pkg.Path == nil { + if len(p.s.errors) == 0 { + p.error(p.s.Loc(), "expected path") + } + return nil + } + + return &pkg +} + +func (p *Parser) parseImport() *Import { + + var imp Import + imp.SetLoc(p.s.Loc()) + + if p.s.tok != tokens.Import { + return nil + } + + p.scan() + if p.s.tok != tokens.Ident { + p.error(p.s.Loc(), "expected ident") + return nil + } + q, prev := p.presentParser() + term := q.parseTerm() + if term != nil { + switch v := term.Value.(type) { + case Var: + imp.Path = RefTerm(term).SetLocation(term.Location) + case Ref: + for i := 1; i < len(v); i++ { + if _, ok := v[i].Value.(String); !ok { + p.errorf(v[i].Location, "unexpected %v token: expecting string", TypeName(v[i].Value)) + return nil + } + } + imp.Path = term + } + } + // keep advanced parser state, reset known keywords + p.s = q.s + p.s.s = q.s.s.WithKeywords(prev) + + if imp.Path == nil { + p.error(p.s.Loc(), "expected path") + return nil + } + + path := imp.Path.Value.(Ref) + + switch { + case RootDocumentNames.Contains(path[0]): + case FutureRootDocument.Equal(path[0]): + case RegoRootDocument.Equal(path[0]): + default: + p.hint("if this is unexpected, try updating OPA") + p.errorf(imp.Path.Location, "unexpected import path, must begin with one of: %v, got: %v", + RootDocumentNames.Union(NewSet(FutureRootDocument, RegoRootDocument)), + path[0]) + return nil + } + + if p.s.tok == tokens.As { + p.scan() + + if p.s.tok != tokens.Ident { + p.illegal("expected var") + return nil + } + + if alias := p.parseTerm(); alias != nil { + v, ok := alias.Value.(Var) + if ok { + imp.Alias = v + return &imp + } + } + p.illegal("expected var") + return nil + } + + return &imp +} + +func (p *Parser) parseRules() []*Rule { + + var rule Rule + rule.SetLoc(p.s.Loc()) + + if p.s.tok == tokens.Default { + p.scan() + rule.Default = true + } + + if p.s.tok != tokens.Ident { + return nil + } + + usesContains := false + if rule.Head, usesContains = p.parseHead(rule.Default); rule.Head == nil { + return nil + } + + if usesContains { + rule.Head.keywords = append(rule.Head.keywords, tokens.Contains) + } + + if rule.Default { + if !p.validateDefaultRuleValue(&rule) { + return nil + } + + if len(rule.Head.Args) > 0 { + if !p.validateDefaultRuleArgs(&rule) { + return nil + } + } + + rule.Body = NewBody(NewExpr(BooleanTerm(true).SetLocation(rule.Location)).SetLocation(rule.Location)) + return []*Rule{&rule} + } + + // back-compat with `p[x] { ... }`` + hasIf := p.s.tok == tokens.If + + // p[x] if ... becomes a single-value rule p[x] + if hasIf && !usesContains && len(rule.Head.Ref()) == 2 { + if !rule.Head.Ref()[1].IsGround() && len(rule.Head.Args) == 0 { + rule.Head.Key = rule.Head.Ref()[1] + } + + if rule.Head.Value == nil { + rule.Head.generatedValue = true + rule.Head.Value = BooleanTerm(true).SetLocation(rule.Head.Location) + } else { + // p[x] = y if becomes a single-value rule p[x] with value y, but needs name for compat + v, ok := rule.Head.Ref()[0].Value.(Var) + if !ok { + return nil + } + rule.Head.Name = v + } + } + + // p[x] becomes a multi-value rule p + if !hasIf && !usesContains && + len(rule.Head.Args) == 0 && // not a function + len(rule.Head.Ref()) == 2 { // ref like 'p[x]' + v, ok := rule.Head.Ref()[0].Value.(Var) + if !ok { + return nil + } + rule.Head.Name = v + rule.Head.Key = rule.Head.Ref()[1] + if rule.Head.Value == nil { + rule.Head.SetRef(rule.Head.Ref()[:len(rule.Head.Ref())-1]) + } + } + + switch { + case hasIf: + rule.Head.keywords = append(rule.Head.keywords, tokens.If) + p.scan() + s := p.save() + if expr := p.parseLiteral(); expr != nil { + // NOTE(sr): set literals are never false or undefined, so parsing this as + // p if { true } + // ^^^^^^^^ set of one element, `true` + // isn't valid. + isSetLiteral := false + if t, ok := expr.Terms.(*Term); ok { + _, isSetLiteral = t.Value.(Set) + } + // expr.Term is []*Term or Every + if !isSetLiteral { + rule.Body.Append(expr) + break + } + } + + // parsing as literal didn't work out, expect '{ BODY }' + p.restore(s) + fallthrough + + case p.s.tok == tokens.LBrace: + p.scan() + if rule.Body = p.parseBody(tokens.RBrace); rule.Body == nil { + return nil + } + p.scan() + + case usesContains: + rule.Body = NewBody(NewExpr(BooleanTerm(true).SetLocation(rule.Location)).SetLocation(rule.Location)) + rule.generatedBody = true + rule.Location = rule.Head.Location + + return []*Rule{&rule} + + default: + return nil + } + + if p.s.tok == tokens.Else { + if r := rule.Head.Ref(); len(r) > 1 && !r.IsGround() { + p.error(p.s.Loc(), "else keyword cannot be used on rules with variables in head") + return nil + } + if rule.Head.Key != nil { + p.error(p.s.Loc(), "else keyword cannot be used on multi-value rules") + return nil + } + + if rule.Else = p.parseElse(rule.Head); rule.Else == nil { + return nil + } + } + + rule.Location.Text = p.s.Text(rule.Location.Offset, p.s.lastEnd) + + rules := []*Rule{&rule} + + for p.s.tok == tokens.LBrace { + + if rule.Else != nil { + p.error(p.s.Loc(), "expected else keyword") + return nil + } + + loc := p.s.Loc() + + p.scan() + var next Rule + + if next.Body = p.parseBody(tokens.RBrace); next.Body == nil { + return nil + } + p.scan() + + loc.Text = p.s.Text(loc.Offset, p.s.lastEnd) + next.SetLoc(loc) + + // Chained rule head's keep the original + // rule's head AST but have their location + // set to the rule body. + next.Head = rule.Head.Copy() + next.Head.keywords = rule.Head.keywords + for i := range next.Head.Args { + if v, ok := next.Head.Args[i].Value.(Var); ok && v.IsWildcard() { + next.Head.Args[i].Value = Var(p.genwildcard()) + } + } + setLocRecursive(next.Head, loc) + + rules = append(rules, &next) + } + + return rules +} + +func (p *Parser) parseElse(head *Head) *Rule { + + var rule Rule + rule.SetLoc(p.s.Loc()) + + rule.Head = head.Copy() + rule.Head.generatedValue = false + for i := range rule.Head.Args { + if v, ok := rule.Head.Args[i].Value.(Var); ok && v.IsWildcard() { + rule.Head.Args[i].Value = Var(p.genwildcard()) + } + } + rule.Head.SetLoc(p.s.Loc()) + + defer func() { + rule.Location.Text = p.s.Text(rule.Location.Offset, p.s.lastEnd) + }() + + p.scan() + + switch p.s.tok { + case tokens.LBrace, tokens.If: // no value, but a body follows directly + rule.Head.generatedValue = true + rule.Head.Value = BooleanTerm(true) + case tokens.Assign, tokens.Unify: + rule.Head.Assign = tokens.Assign == p.s.tok + p.scan() + rule.Head.Value = p.parseTermInfixCall() + if rule.Head.Value == nil { + return nil + } + rule.Head.Location.Text = p.s.Text(rule.Head.Location.Offset, p.s.lastEnd) + default: + p.illegal("expected else value term or rule body") + return nil + } + + hasIf := p.s.tok == tokens.If + hasLBrace := p.s.tok == tokens.LBrace + + if !hasIf && !hasLBrace { + rule.Body = NewBody(NewExpr(BooleanTerm(true))) + rule.generatedBody = true + setLocRecursive(rule.Body, rule.Location) + return &rule + } + + if hasIf { + rule.Head.keywords = append(rule.Head.keywords, tokens.If) + p.scan() + } + + if p.s.tok == tokens.LBrace { + p.scan() + if rule.Body = p.parseBody(tokens.RBrace); rule.Body == nil { + return nil + } + p.scan() + } else if p.s.tok != tokens.EOF { + expr := p.parseLiteral() + if expr == nil { + return nil + } + rule.Body.Append(expr) + setLocRecursive(rule.Body, rule.Location) + } else { + p.illegal("rule body expected") + return nil + } + + if p.s.tok == tokens.Else { + if rule.Else = p.parseElse(head); rule.Else == nil { + return nil + } + } + return &rule +} + +func (p *Parser) parseHead(defaultRule bool) (*Head, bool) { + head := &Head{} + loc := p.s.Loc() + defer func() { + if head != nil { + head.SetLoc(loc) + head.Location.Text = p.s.Text(head.Location.Offset, p.s.lastEnd) + } + }() + + term := p.parseVar() + if term == nil { + return nil, false + } + + ref := p.parseTermFinish(term, true) + if ref == nil { + p.illegal("expected rule head name") + return nil, false + } + + switch x := ref.Value.(type) { + case Var: + // Modify the code to add the location to the head ref + // and set the head ref's jsonOptions. + head = VarHead(x, ref.Location, p.po.JSONOptions) + case Ref: + head = RefHead(x) + case Call: + op, args := x[0], x[1:] + var ref Ref + switch y := op.Value.(type) { + case Var: + ref = Ref{op} + case Ref: + if _, ok := y[0].Value.(Var); !ok { + p.illegal("rule head ref %v invalid", y) + return nil, false + } + ref = y + } + head = RefHead(ref) + head.Args = append([]*Term{}, args...) + + default: + return nil, false + } + + name := head.Ref().String() + + switch p.s.tok { + case tokens.Contains: // NOTE: no Value for `contains` heads, we return here + // Catch error case of using 'contains' with a function definition rule head. + if head.Args != nil { + p.illegal("the contains keyword can only be used with multi-value rule definitions (e.g., %s contains { ... })", name) + } + p.scan() + head.Key = p.parseTermInfixCall() + if head.Key == nil { + p.illegal("expected rule key term (e.g., %s contains { ... })", name) + } + return head, true + + case tokens.Unify: + p.scan() + head.Value = p.parseTermInfixCall() + if head.Value == nil { + // FIX HEAD.String() + p.illegal("expected rule value term (e.g., %s[%s] = { ... })", name, head.Key) + } + case tokens.Assign: + p.scan() + head.Assign = true + head.Value = p.parseTermInfixCall() + if head.Value == nil { + switch { + case len(head.Args) > 0: + p.illegal("expected function value term (e.g., %s(...) := { ... })", name) + case head.Key != nil: + p.illegal("expected partial rule value term (e.g., %s[...] := { ... })", name) + case defaultRule: + p.illegal("expected default rule value term (e.g., default %s := )", name) + default: + p.illegal("expected rule value term (e.g., %s := { ... })", name) + } + } + } + + if head.Value == nil && head.Key == nil { + if len(head.Ref()) != 2 || len(head.Args) > 0 { + head.generatedValue = true + head.Value = BooleanTerm(true).SetLocation(head.Location) + } + } + return head, false +} + +func (p *Parser) parseBody(end tokens.Token) Body { + return p.parseQuery(false, end) +} + +func (p *Parser) parseQuery(requireSemi bool, end tokens.Token) Body { + body := Body{} + + if p.s.tok == end { + p.error(p.s.Loc(), "found empty body") + return nil + } + + for { + expr := p.parseLiteral() + if expr == nil { + return nil + } + + body.Append(expr) + + if p.s.tok == tokens.Semicolon { + p.scan() + continue + } + + if p.s.tok == end || requireSemi { + return body + } + + if !p.s.skippedNL { + // If there was already an error then don't pile this one on + if len(p.s.errors) == 0 { + p.illegal(`expected \n or %s or %s`, tokens.Semicolon, end) + } + return nil + } + } +} + +func (p *Parser) parseLiteral() (expr *Expr) { + + offset := p.s.loc.Offset + loc := p.s.Loc() + + defer func() { + if expr != nil { + loc.Text = p.s.Text(offset, p.s.lastEnd) + expr.SetLoc(loc) + } + }() + + var negated bool + if p.s.tok == tokens.Not { + p.scan() + negated = true + } + + switch p.s.tok { + case tokens.Some: + if negated { + p.illegal("illegal negation of 'some'") + return nil + } + return p.parseSome() + case tokens.Every: + if negated { + p.illegal("illegal negation of 'every'") + return nil + } + return p.parseEvery() + default: + s := p.save() + expr := p.parseExpr() + if expr != nil { + expr.Negated = negated + if p.s.tok == tokens.With { + if expr.With = p.parseWith(); expr.With == nil { + return nil + } + } + // If we find a plain `every` identifier, attempt to parse an every expression, + // add hint if it succeeds. + if term, ok := expr.Terms.(*Term); ok && Var("every").Equal(term.Value) { + var hint bool + t := p.save() + p.restore(s) + if expr := p.futureParser().parseEvery(); expr != nil { + _, hint = expr.Terms.(*Every) + } + p.restore(t) + if hint { + p.hint("`import future.keywords.every` for `every x in xs { ... }` expressions") + } + } + return expr + } + return nil + } +} + +func (p *Parser) parseWith() []*With { + + withs := []*With{} + + for { + + with := With{ + Location: p.s.Loc(), + } + p.scan() + + if p.s.tok != tokens.Ident { + p.illegal("expected ident") + return nil + } + + with.Target = p.parseTerm() + if with.Target == nil { + return nil + } + + switch with.Target.Value.(type) { + case Ref, Var: + break + default: + p.illegal("expected with target path") + } + + if p.s.tok != tokens.As { + p.illegal("expected as keyword") + return nil + } + + p.scan() + + if with.Value = p.parseTermInfixCall(); with.Value == nil { + return nil + } + + with.Location.Text = p.s.Text(with.Location.Offset, p.s.lastEnd) + + withs = append(withs, &with) + + if p.s.tok != tokens.With { + break + } + } + + return withs +} + +func (p *Parser) parseSome() *Expr { + + decl := &SomeDecl{} + decl.SetLoc(p.s.Loc()) + + // Attempt to parse "some x in xs", which will end up in + // SomeDecl{Symbols: ["member(x, xs)"]} + s := p.save() + p.scan() + if term := p.parseTermInfixCall(); term != nil { + if call, ok := term.Value.(Call); ok { + switch call[0].String() { + case Member.Name: + if len(call) != 3 { + p.illegal("illegal domain") + return nil + } + case MemberWithKey.Name: + if len(call) != 4 { + p.illegal("illegal domain") + return nil + } + default: + p.illegal("expected `x in xs` or `x, y in xs` expression") + return nil + } + + decl.Symbols = []*Term{term} + expr := NewExpr(decl).SetLocation(decl.Location) + if p.s.tok == tokens.With { + if expr.With = p.parseWith(); expr.With == nil { + return nil + } + } + return expr + } + } + + p.restore(s) + s = p.save() // new copy for later + var hint bool + p.scan() + if term := p.futureParser().parseTermInfixCall(); term != nil { + if call, ok := term.Value.(Call); ok { + switch call[0].String() { + case Member.Name, MemberWithKey.Name: + hint = true + } + } + } + + // go on as before, it's `some x[...]` or illegal + p.restore(s) + if hint { + p.hint("`import future.keywords.in` for `some x in xs` expressions") + } + + for { // collecting var args + + p.scan() + + if p.s.tok != tokens.Ident { + p.illegal("expected var") + return nil + } + + decl.Symbols = append(decl.Symbols, p.parseVar()) + + p.scan() + + if p.s.tok != tokens.Comma { + break + } + } + + return NewExpr(decl).SetLocation(decl.Location) +} + +func (p *Parser) parseEvery() *Expr { + qb := &Every{} + qb.SetLoc(p.s.Loc()) + + // TODO(sr): We'd get more accurate error messages if we didn't rely on + // parseTermInfixCall here, but parsed "var [, var] in term" manually. + p.scan() + term := p.parseTermInfixCall() + if term == nil { + return nil + } + call, ok := term.Value.(Call) + if !ok { + p.illegal("expected `x[, y] in xs { ... }` expression") + return nil + } + switch call[0].String() { + case Member.Name: // x in xs + if len(call) != 3 { + p.illegal("illegal domain") + return nil + } + qb.Value = call[1] + qb.Domain = call[2] + case MemberWithKey.Name: // k, v in xs + if len(call) != 4 { + p.illegal("illegal domain") + return nil + } + qb.Key = call[1] + qb.Value = call[2] + qb.Domain = call[3] + if _, ok := qb.Key.Value.(Var); !ok { + p.illegal("expected key to be a variable") + return nil + } + default: + p.illegal("expected `x[, y] in xs { ... }` expression") + return nil + } + if _, ok := qb.Value.Value.(Var); !ok { + p.illegal("expected value to be a variable") + return nil + } + if p.s.tok == tokens.LBrace { // every x in xs { ... } + p.scan() + body := p.parseBody(tokens.RBrace) + if body == nil { + return nil + } + p.scan() + qb.Body = body + expr := NewExpr(qb).SetLocation(qb.Location) + + if p.s.tok == tokens.With { + if expr.With = p.parseWith(); expr.With == nil { + return nil + } + } + return expr + } + + p.illegal("missing body") + return nil +} + +func (p *Parser) parseExpr() *Expr { + + lhs := p.parseTermInfixCall() + if lhs == nil { + return nil + } + + if op := p.parseTermOp(tokens.Assign, tokens.Unify); op != nil { + if rhs := p.parseTermInfixCall(); rhs != nil { + return NewExpr([]*Term{op, lhs, rhs}) + } + return nil + } + + // NOTE(tsandall): the top-level call term is converted to an expr because + // the evaluator does not support the call term type (nested calls are + // rewritten by the compiler.) + if call, ok := lhs.Value.(Call); ok { + return NewExpr([]*Term(call)) + } + + return NewExpr(lhs) +} + +// parseTermInfixCall consumes the next term from the input and returns it. If a +// term cannot be parsed the return value is nil and error will be recorded. The +// scanner will be advanced to the next token before returning. +// By starting out with infix relations (==, !=, <, etc) and further calling the +// other binary operators (|, &, arithmetics), it constitutes the binding +// precedence. +func (p *Parser) parseTermInfixCall() *Term { + return p.parseTermIn(nil, true, p.s.loc.Offset) +} + +func (p *Parser) parseTermInfixCallInList() *Term { + return p.parseTermIn(nil, false, p.s.loc.Offset) +} + +// use static references to avoid allocations, and +// copy them to the call term only when needed +var memberWithKeyRef = MemberWithKey.Ref() +var memberRef = Member.Ref() + +func (p *Parser) parseTermIn(lhs *Term, keyVal bool, offset int) *Term { + // NOTE(sr): `in` is a bit special: besides `lhs in rhs`, it also + // supports `key, val in rhs`, so it can have an optional second lhs. + // `keyVal` triggers if we attempt to parse a second lhs argument (`mhs`). + if lhs == nil { + lhs = p.parseTermRelation(nil, offset) + } + if lhs != nil { + if keyVal && p.s.tok == tokens.Comma { // second "lhs", or "middle hand side" + s := p.save() + p.scan() + if mhs := p.parseTermRelation(nil, offset); mhs != nil { + + if op := p.parseTermOpName(memberWithKeyRef, tokens.In); op != nil { + if rhs := p.parseTermRelation(nil, p.s.loc.Offset); rhs != nil { + call := p.setLoc(CallTerm(op, lhs, mhs, rhs), lhs.Location, offset, p.s.lastEnd) + switch p.s.tok { + case tokens.In: + return p.parseTermIn(call, keyVal, offset) + default: + return call + } + } + } + } + p.restore(s) + } + if op := p.parseTermOpName(memberRef, tokens.In); op != nil { + if rhs := p.parseTermRelation(nil, p.s.loc.Offset); rhs != nil { + call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) + switch p.s.tok { + case tokens.In: + return p.parseTermIn(call, keyVal, offset) + default: + return call + } + } + } + } + return lhs +} + +func (p *Parser) parseTermRelation(lhs *Term, offset int) *Term { + if lhs == nil { + lhs = p.parseTermOr(nil, offset) + } + if lhs != nil { + if op := p.parseTermOp(tokens.Equal, tokens.Neq, tokens.Lt, tokens.Gt, tokens.Lte, tokens.Gte); op != nil { + if rhs := p.parseTermOr(nil, p.s.loc.Offset); rhs != nil { + call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) + switch p.s.tok { + case tokens.Equal, tokens.Neq, tokens.Lt, tokens.Gt, tokens.Lte, tokens.Gte: + return p.parseTermRelation(call, offset) + default: + return call + } + } + } + } + return lhs +} + +func (p *Parser) parseTermOr(lhs *Term, offset int) *Term { + if lhs == nil { + lhs = p.parseTermAnd(nil, offset) + } + if lhs != nil { + if op := p.parseTermOp(tokens.Or); op != nil { + if rhs := p.parseTermAnd(nil, p.s.loc.Offset); rhs != nil { + call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) + switch p.s.tok { + case tokens.Or: + return p.parseTermOr(call, offset) + default: + return call + } + } + } + return lhs + } + return nil +} + +func (p *Parser) parseTermAnd(lhs *Term, offset int) *Term { + if lhs == nil { + lhs = p.parseTermArith(nil, offset) + } + if lhs != nil { + if op := p.parseTermOp(tokens.And); op != nil { + if rhs := p.parseTermArith(nil, p.s.loc.Offset); rhs != nil { + call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) + switch p.s.tok { + case tokens.And: + return p.parseTermAnd(call, offset) + default: + return call + } + } + } + return lhs + } + return nil +} + +func (p *Parser) parseTermArith(lhs *Term, offset int) *Term { + if lhs == nil { + lhs = p.parseTermFactor(nil, offset) + } + if lhs != nil { + if op := p.parseTermOp(tokens.Add, tokens.Sub); op != nil { + if rhs := p.parseTermFactor(nil, p.s.loc.Offset); rhs != nil { + call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) + switch p.s.tok { + case tokens.Add, tokens.Sub: + return p.parseTermArith(call, offset) + default: + return call + } + } + } + } + return lhs +} + +func (p *Parser) parseTermFactor(lhs *Term, offset int) *Term { + if lhs == nil { + lhs = p.parseTerm() + } + if lhs != nil { + if op := p.parseTermOp(tokens.Mul, tokens.Quo, tokens.Rem); op != nil { + if rhs := p.parseTerm(); rhs != nil { + call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd) + switch p.s.tok { + case tokens.Mul, tokens.Quo, tokens.Rem: + return p.parseTermFactor(call, offset) + default: + return call + } + } + } + } + return lhs +} + +func (p *Parser) parseTerm() *Term { + if term, s := p.parsedTermCacheLookup(); s != nil { + p.restore(s) + return term + } + s0 := p.save() + + var term *Term + switch p.s.tok { + case tokens.Null: + term = NullTerm().SetLocation(p.s.Loc()) + case tokens.True: + term = BooleanTerm(true).SetLocation(p.s.Loc()) + case tokens.False: + term = BooleanTerm(false).SetLocation(p.s.Loc()) + case tokens.Sub, tokens.Dot, tokens.Number: + term = p.parseNumber() + case tokens.String: + term = p.parseString() + case tokens.Ident, tokens.Contains: // NOTE(sr): contains anywhere BUT in rule heads gets no special treatment + term = p.parseVar() + case tokens.LBrack: + term = p.parseArray() + case tokens.LBrace: + term = p.parseSetOrObject() + case tokens.LParen: + offset := p.s.loc.Offset + p.scan() + if r := p.parseTermInfixCall(); r != nil { + if p.s.tok == tokens.RParen { + r.Location.Text = p.s.Text(offset, p.s.tokEnd) + term = r + } else { + p.error(p.s.Loc(), "non-terminated expression") + } + } + default: + p.illegalToken() + } + + term = p.parseTermFinish(term, false) + p.parsedTermCachePush(term, s0) + return term +} + +func (p *Parser) parseTermFinish(head *Term, skipws bool) *Term { + if head == nil { + return nil + } + offset := p.s.loc.Offset + p.doScan(skipws) + + switch p.s.tok { + case tokens.LParen, tokens.Dot, tokens.LBrack: + return p.parseRef(head, offset) + case tokens.Whitespace: + p.scan() + fallthrough + default: + if _, ok := head.Value.(Var); ok && RootDocumentNames.Contains(head) { + return RefTerm(head).SetLocation(head.Location) + } + return head + } +} + +func (p *Parser) parseNumber() *Term { + var prefix string + loc := p.s.Loc() + if p.s.tok == tokens.Sub { + prefix = "-" + p.scan() + switch p.s.tok { + case tokens.Number, tokens.Dot: + break + default: + p.illegal("expected number") + return nil + } + } + if p.s.tok == tokens.Dot { + prefix += "." + p.scan() + if p.s.tok != tokens.Number { + p.illegal("expected number") + return nil + } + } + + // Check for multiple leading 0's, parsed by math/big.Float.Parse as decimal 0: + // https://golang.org/pkg/math/big/#Float.Parse + if ((len(prefix) != 0 && prefix[0] == '-') || len(prefix) == 0) && + len(p.s.lit) > 1 && p.s.lit[0] == '0' && p.s.lit[1] == '0' { + p.illegal("expected number") + return nil + } + + // Ensure that the number is valid + s := prefix + p.s.lit + f, ok := new(big.Float).SetString(s) + if !ok { + p.illegal("invalid float") + return nil + } + + // Put limit on size of exponent to prevent non-linear cost of String() + // function on big.Float from causing denial of service: https://github.com/golang/go/issues/11068 + // + // n == sign * mantissa * 2^exp + // 0.5 <= mantissa < 1.0 + // + // The limit is arbitrary. + exp := f.MantExp(nil) + if exp > 1e5 || exp < -1e5 || f.IsInf() { // +/- inf, exp is 0 + p.error(p.s.Loc(), "number too big") + return nil + } + + // Note: Use the original string, do *not* round trip from + // the big.Float as it can cause precision loss. + return NumberTerm(json.Number(s)).SetLocation(loc) +} + +func (p *Parser) parseString() *Term { + if p.s.lit[0] == '"' { + var s string + err := json.Unmarshal([]byte(p.s.lit), &s) + if err != nil { + p.errorf(p.s.Loc(), "illegal string literal: %s", p.s.lit) + return nil + } + term := StringTerm(s).SetLocation(p.s.Loc()) + return term + } + return p.parseRawString() +} + +func (p *Parser) parseRawString() *Term { + if len(p.s.lit) < 2 { + return nil + } + term := StringTerm(p.s.lit[1 : len(p.s.lit)-1]).SetLocation(p.s.Loc()) + return term +} + +// this is the name to use for instantiating an empty set, e.g., `set()`. +var setConstructor = RefTerm(VarTerm("set")) + +func (p *Parser) parseCall(operator *Term, offset int) (term *Term) { + + loc := operator.Location + var end int + + defer func() { + p.setLoc(term, loc, offset, end) + }() + + p.scan() // steps over '(' + + if p.s.tok == tokens.RParen { // no args, i.e. set() or any.func() + end = p.s.tokEnd + p.scanWS() + if operator.Equal(setConstructor) { + return SetTerm() + } + return CallTerm(operator) + } + + if r := p.parseTermList(tokens.RParen, []*Term{operator}); r != nil { + end = p.s.tokEnd + p.scanWS() + return CallTerm(r...) + } + + return nil +} + +func (p *Parser) parseRef(head *Term, offset int) (term *Term) { + + loc := head.Location + var end int + + defer func() { + p.setLoc(term, loc, offset, end) + }() + + switch h := head.Value.(type) { + case Var, *Array, Object, Set, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Call: + // ok + default: + p.errorf(loc, "illegal ref (head cannot be %v)", TypeName(h)) + } + + ref := []*Term{head} + + for { + switch p.s.tok { + case tokens.Dot: + p.scanWS() + if p.s.tok != tokens.Ident { + p.illegal("expected %v", tokens.Ident) + return nil + } + ref = append(ref, StringTerm(p.s.lit).SetLocation(p.s.Loc())) + p.scanWS() + case tokens.LParen: + term = p.parseCall(p.setLoc(RefTerm(ref...), loc, offset, p.s.loc.Offset), offset) + if term != nil { + switch p.s.tok { + case tokens.Whitespace: + p.scan() + end = p.s.lastEnd + return term + case tokens.Dot, tokens.LBrack: + term = p.parseRef(term, offset) + } + } + end = p.s.tokEnd + return term + case tokens.LBrack: + p.scan() + if term := p.parseTermInfixCall(); term != nil { + if p.s.tok != tokens.RBrack { + p.illegal("expected %v", tokens.LBrack) + return nil + } + ref = append(ref, term) + p.scanWS() + } else { + return nil + } + case tokens.Whitespace: + end = p.s.lastEnd + p.scan() + return RefTerm(ref...) + default: + end = p.s.lastEnd + return RefTerm(ref...) + } + } +} + +func (p *Parser) parseArray() (term *Term) { + + loc := p.s.Loc() + offset := p.s.loc.Offset + + defer func() { + p.setLoc(term, loc, offset, p.s.tokEnd) + }() + + p.scan() + + if p.s.tok == tokens.RBrack { + return ArrayTerm() + } + + potentialComprehension := true + + // Skip leading commas, eg [, x, y] + // Supported for backwards compatibility. In the future + // we should make this a parse error. + if p.s.tok == tokens.Comma { + potentialComprehension = false + p.scan() + } + + s := p.save() + + // NOTE(tsandall): The parser cannot attempt a relational term here because + // of ambiguity around comprehensions. For example, given: + // + // {1 | 1} + // + // Does this represent a set comprehension or a set containing binary OR + // call? We resolve the ambiguity by prioritizing comprehensions. + head := p.parseTerm() + + if head == nil { + return nil + } + + switch p.s.tok { + case tokens.RBrack: + return ArrayTerm(head) + case tokens.Comma: + p.scan() + if terms := p.parseTermList(tokens.RBrack, []*Term{head}); terms != nil { + return NewTerm(NewArray(terms...)) + } + return nil + case tokens.Or: + if potentialComprehension { + // Try to parse as if it is an array comprehension + p.scan() + if body := p.parseBody(tokens.RBrack); body != nil { + return ArrayComprehensionTerm(head, body) + } + if p.s.tok != tokens.Comma { + return nil + } + } + // fall back to parsing as a normal array definition + } + + p.restore(s) + + if terms := p.parseTermList(tokens.RBrack, nil); terms != nil { + return NewTerm(NewArray(terms...)) + } + return nil +} + +func (p *Parser) parseSetOrObject() (term *Term) { + loc := p.s.Loc() + offset := p.s.loc.Offset + + defer func() { + p.setLoc(term, loc, offset, p.s.tokEnd) + }() + + p.scan() + + if p.s.tok == tokens.RBrace { + return ObjectTerm() + } + + potentialComprehension := true + + // Skip leading commas, eg {, x, y} + // Supported for backwards compatibility. In the future + // we should make this a parse error. + if p.s.tok == tokens.Comma { + potentialComprehension = false + p.scan() + } + + s := p.save() + + // Try parsing just a single term first to give comprehensions higher + // priority to "or" calls in ambiguous situations. Eg: { a | b } + // will be a set comprehension. + // + // Note: We don't know yet if it is a set or object being defined. + head := p.parseTerm() + if head == nil { + return nil + } + + switch p.s.tok { + case tokens.Or: + if potentialComprehension { + return p.parseSet(s, head, potentialComprehension) + } + case tokens.RBrace, tokens.Comma: + return p.parseSet(s, head, potentialComprehension) + case tokens.Colon: + return p.parseObject(head, potentialComprehension) + } + + p.restore(s) + + head = p.parseTermInfixCallInList() + if head == nil { + return nil + } + + switch p.s.tok { + case tokens.RBrace, tokens.Comma: + return p.parseSet(s, head, false) + case tokens.Colon: + // It still might be an object comprehension, eg { a+1: b | ... } + return p.parseObject(head, potentialComprehension) + } + + p.illegal("non-terminated set") + return nil +} + +func (p *Parser) parseSet(s *state, head *Term, potentialComprehension bool) *Term { + switch p.s.tok { + case tokens.RBrace: + return SetTerm(head) + case tokens.Comma: + p.scan() + if terms := p.parseTermList(tokens.RBrace, []*Term{head}); terms != nil { + return SetTerm(terms...) + } + case tokens.Or: + if potentialComprehension { + // Try to parse as if it is a set comprehension + p.scan() + if body := p.parseBody(tokens.RBrace); body != nil { + return SetComprehensionTerm(head, body) + } + if p.s.tok != tokens.Comma { + return nil + } + } + // Fall back to parsing as normal set definition + p.restore(s) + if terms := p.parseTermList(tokens.RBrace, nil); terms != nil { + return SetTerm(terms...) + } + } + return nil +} + +func (p *Parser) parseObject(k *Term, potentialComprehension bool) *Term { + // NOTE(tsandall): Assumption: this function is called after parsing the key + // of the head element and then receiving a colon token from the scanner. + // Advance beyond the colon and attempt to parse an object. + if p.s.tok != tokens.Colon { + panic("expected colon") + } + p.scan() + + s := p.save() + + // NOTE(sr): We first try to parse the value as a term (`v`), and see + // if we can parse `{ x: v | ...}` as a comprehension. + // However, if we encounter either a Comma or an RBace, it cannot be + // parsed as a comprehension -- so we save double work further down + // where `parseObjectFinish(k, v, false)` would only exercise the + // same code paths once more. + v := p.parseTerm() + if v == nil { + return nil + } + + potentialRelation := true + if potentialComprehension { + switch p.s.tok { + case tokens.RBrace, tokens.Comma: + potentialRelation = false + fallthrough + case tokens.Or: + if term := p.parseObjectFinish(k, v, true); term != nil { + return term + } + } + } + + p.restore(s) + + if potentialRelation { + v := p.parseTermInfixCallInList() + if v == nil { + return nil + } + + switch p.s.tok { + case tokens.RBrace, tokens.Comma: + return p.parseObjectFinish(k, v, false) + } + } + + p.illegal("non-terminated object") + return nil +} + +func (p *Parser) parseObjectFinish(key, val *Term, potentialComprehension bool) *Term { + switch p.s.tok { + case tokens.RBrace: + return ObjectTerm([2]*Term{key, val}) + case tokens.Or: + if potentialComprehension { + p.scan() + if body := p.parseBody(tokens.RBrace); body != nil { + return ObjectComprehensionTerm(key, val, body) + } + } else { + p.illegal("non-terminated object") + } + case tokens.Comma: + p.scan() + if r := p.parseTermPairList(tokens.RBrace, [][2]*Term{{key, val}}); r != nil { + return ObjectTerm(r...) + } + } + return nil +} + +func (p *Parser) parseTermList(end tokens.Token, r []*Term) []*Term { + if p.s.tok == end { + return r + } + for { + term := p.parseTermInfixCallInList() + if term != nil { + r = append(r, term) + switch p.s.tok { + case end: + return r + case tokens.Comma: + p.scan() + if p.s.tok == end { + return r + } + continue + default: + p.illegal(fmt.Sprintf("expected %q or %q", tokens.Comma, end)) + return nil + } + } + return nil + } +} + +func (p *Parser) parseTermPairList(end tokens.Token, r [][2]*Term) [][2]*Term { + if p.s.tok == end { + return r + } + for { + key := p.parseTermInfixCallInList() + if key != nil { + switch p.s.tok { + case tokens.Colon: + p.scan() + if val := p.parseTermInfixCallInList(); val != nil { + r = append(r, [2]*Term{key, val}) + switch p.s.tok { + case end: + return r + case tokens.Comma: + p.scan() + if p.s.tok == end { + return r + } + continue + default: + p.illegal(fmt.Sprintf("expected %q or %q", tokens.Comma, end)) + return nil + } + } + default: + p.illegal(fmt.Sprintf("expected %q", tokens.Colon)) + return nil + } + } + return nil + } +} + +func (p *Parser) parseTermOp(values ...tokens.Token) *Term { + for i := range values { + if p.s.tok == values[i] { + r := RefTerm(VarTerm(fmt.Sprint(p.s.tok)).SetLocation(p.s.Loc())).SetLocation(p.s.Loc()) + p.scan() + return r + } + } + return nil +} + +func (p *Parser) parseTermOpName(ref Ref, values ...tokens.Token) *Term { + for i := range values { + if p.s.tok == values[i] { + cp := ref.Copy() + for _, r := range cp { + r.SetLocation(p.s.Loc()) + } + t := RefTerm(cp...) + t.SetLocation(p.s.Loc()) + p.scan() + return t + } + } + return nil +} + +func (p *Parser) parseVar() *Term { + + s := p.s.lit + + term := VarTerm(s).SetLocation(p.s.Loc()) + + // Update wildcard values with unique identifiers + if term.Equal(Wildcard) { + term.Value = Var(p.genwildcard()) + } + + return term +} + +func (p *Parser) genwildcard() string { + c := p.s.wildcard + p.s.wildcard++ + return fmt.Sprintf("%v%d", WildcardPrefix, c) +} + +func (p *Parser) error(loc *location.Location, reason string) { + p.errorf(loc, reason) //nolint:govet +} + +func (p *Parser) errorf(loc *location.Location, f string, a ...interface{}) { + msg := strings.Builder{} + msg.WriteString(fmt.Sprintf(f, a...)) + + switch len(p.s.hints) { + case 0: // nothing to do + case 1: + msg.WriteString(" (hint: ") + msg.WriteString(p.s.hints[0]) + msg.WriteRune(')') + default: + msg.WriteString(" (hints: ") + for i, h := range p.s.hints { + if i > 0 { + msg.WriteString(", ") + } + msg.WriteString(h) + } + msg.WriteRune(')') + } + + p.s.errors = append(p.s.errors, &Error{ + Code: ParseErr, + Message: msg.String(), + Location: loc, + Details: newParserErrorDetail(p.s.s.Bytes(), loc.Offset), + }) + p.s.hints = nil +} + +func (p *Parser) hint(f string, a ...interface{}) { + p.s.hints = append(p.s.hints, fmt.Sprintf(f, a...)) +} + +func (p *Parser) illegal(note string, a ...interface{}) { + tok := p.s.tok.String() + + if p.s.tok == tokens.Illegal { + p.errorf(p.s.Loc(), "illegal token") + return + } + + tokType := "token" + if tokens.IsKeyword(p.s.tok) { + tokType = "keyword" + } else if _, ok := allFutureKeywords[p.s.tok.String()]; ok { + tokType = "keyword" + } + + note = fmt.Sprintf(note, a...) + if len(note) > 0 { + p.errorf(p.s.Loc(), "unexpected %s %s: %s", tok, tokType, note) + } else { + p.errorf(p.s.Loc(), "unexpected %s %s", tok, tokType) + } +} + +func (p *Parser) illegalToken() { + p.illegal("") +} + +func (p *Parser) scan() { + p.doScan(true) +} + +func (p *Parser) scanWS() { + p.doScan(false) +} + +func (p *Parser) doScan(skipws bool) { + + // NOTE(tsandall): the last position is used to compute the "text" field for + // complex AST nodes. Whitespace never affects the last position of an AST + // node so do not update it when scanning. + if p.s.tok != tokens.Whitespace { + p.s.lastEnd = p.s.tokEnd + p.s.skippedNL = false + } + + var errs []scanner.Error + for { + var pos scanner.Position + p.s.tok, pos, p.s.lit, errs = p.s.s.Scan() + + p.s.tokEnd = pos.End + p.s.loc.Row = pos.Row + p.s.loc.Col = pos.Col + p.s.loc.Offset = pos.Offset + p.s.loc.Text = p.s.Text(pos.Offset, pos.End) + p.s.loc.Tabs = pos.Tabs + + for _, err := range errs { + p.error(p.s.Loc(), err.Message) + } + + if len(errs) > 0 { + p.s.tok = tokens.Illegal + } + + if p.s.tok == tokens.Whitespace { + if p.s.lit == "\n" { + p.s.skippedNL = true + } + if skipws { + continue + } + } + + if p.s.tok != tokens.Comment { + break + } + + // For backwards compatibility leave a nil + // Text value if there is no text rather than + // an empty string. + var commentText []byte + if len(p.s.lit) > 1 { + commentText = []byte(p.s.lit[1:]) + } + comment := NewComment(commentText) + comment.SetLoc(p.s.Loc()) + p.s.comments = append(p.s.comments, comment) + } +} + +func (p *Parser) save() *state { + cpy := *p.s + s := *cpy.s + cpy.s = &s + return &cpy +} + +func (p *Parser) restore(s *state) { + p.s = s +} + +func setLocRecursive(x interface{}, loc *location.Location) { + NewGenericVisitor(func(x interface{}) bool { + if node, ok := x.(Node); ok { + node.SetLoc(loc) + } + return false + }).Walk(x) +} + +func (p *Parser) setLoc(term *Term, loc *location.Location, offset, end int) *Term { + if term != nil { + cpy := *loc + term.Location = &cpy + term.Location.Text = p.s.Text(offset, end) + } + return term +} + +func (p *Parser) validateDefaultRuleValue(rule *Rule) bool { + if rule.Head.Value == nil { + p.error(rule.Loc(), "illegal default rule (must have a value)") + return false + } + + valid := true + vis := NewGenericVisitor(func(x interface{}) bool { + switch x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension: // skip closures + return true + case Ref, Var, Call: + p.error(rule.Loc(), fmt.Sprintf("illegal default rule (value cannot contain %v)", TypeName(x))) + valid = false + return true + } + return false + }) + + vis.Walk(rule.Head.Value.Value) + return valid +} + +func (p *Parser) validateDefaultRuleArgs(rule *Rule) bool { + + valid := true + vars := NewVarSet() + + vis := NewGenericVisitor(func(x interface{}) bool { + switch x := x.(type) { + case Var: + if vars.Contains(x) { + p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot be repeated %v)", x)) + valid = false + return true + } + vars.Add(x) + + case *Term: + switch v := x.Value.(type) { + case Var: // do nothing + default: + p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot contain %v)", TypeName(v))) + valid = false + return true + } + } + + return false + }) + + vis.Walk(rule.Head.Args) + return valid +} + +// We explicitly use yaml unmarshalling, to accommodate for the '_' in 'related_resources', +// which isn't handled properly by json for some reason. +type rawAnnotation struct { + Scope string `yaml:"scope"` + Title string `yaml:"title"` + Entrypoint bool `yaml:"entrypoint"` + Description string `yaml:"description"` + Organizations []string `yaml:"organizations"` + RelatedResources []interface{} `yaml:"related_resources"` + Authors []interface{} `yaml:"authors"` + Schemas []map[string]any `yaml:"schemas"` + Custom map[string]interface{} `yaml:"custom"` +} + +type metadataParser struct { + buf *bytes.Buffer + comments []*Comment + loc *location.Location +} + +func newMetadataParser(loc *Location) *metadataParser { + return &metadataParser{loc: loc, buf: bytes.NewBuffer(nil)} +} + +func (b *metadataParser) Append(c *Comment) { + b.buf.Write(bytes.TrimPrefix(c.Text, []byte(" "))) + b.buf.WriteByte('\n') + b.comments = append(b.comments, c) +} + +var yamlLineErrRegex = regexp.MustCompile(`^yaml:(?: unmarshal errors:[\n\s]*)? line ([[:digit:]]+):`) + +func (b *metadataParser) Parse() (*Annotations, error) { + + var raw rawAnnotation + + if len(bytes.TrimSpace(b.buf.Bytes())) == 0 { + return nil, fmt.Errorf("expected METADATA block, found whitespace") + } + + if err := yaml.Unmarshal(b.buf.Bytes(), &raw); err != nil { + var comment *Comment + match := yamlLineErrRegex.FindStringSubmatch(err.Error()) + if len(match) == 2 { + index, err2 := strconv.Atoi(match[1]) + if err2 == nil { + if index >= len(b.comments) { + comment = b.comments[len(b.comments)-1] + } else { + comment = b.comments[index] + } + b.loc = comment.Location + } + } + + if match == nil && len(b.comments) > 0 { + b.loc = b.comments[0].Location + } + + return nil, augmentYamlError(err, b.comments) + } + + var result Annotations + result.comments = b.comments + result.Scope = raw.Scope + result.Entrypoint = raw.Entrypoint + result.Title = raw.Title + result.Description = raw.Description + result.Organizations = raw.Organizations + + for _, v := range raw.RelatedResources { + rr, err := parseRelatedResource(v) + if err != nil { + return nil, fmt.Errorf("invalid related-resource definition %s: %w", v, err) + } + result.RelatedResources = append(result.RelatedResources, rr) + } + + for _, pair := range raw.Schemas { + k, v := unwrapPair(pair) + + var a SchemaAnnotation + var err error + + a.Path, err = ParseRef(k) + if err != nil { + return nil, fmt.Errorf("invalid document reference") + } + + switch v := v.(type) { + case string: + a.Schema, err = parseSchemaRef(v) + if err != nil { + return nil, err + } + case map[string]any: + w, err := convertYAMLMapKeyTypes(v, nil) + if err != nil { + return nil, fmt.Errorf("invalid schema definition: %w", err) + } + a.Definition = &w + default: + return nil, fmt.Errorf("invalid schema declaration for path %q", k) + } + + result.Schemas = append(result.Schemas, &a) + } + + for _, v := range raw.Authors { + author, err := parseAuthor(v) + if err != nil { + return nil, fmt.Errorf("invalid author definition %s: %w", v, err) + } + result.Authors = append(result.Authors, author) + } + + result.Custom = make(map[string]interface{}) + for k, v := range raw.Custom { + val, err := convertYAMLMapKeyTypes(v, nil) + if err != nil { + return nil, err + } + result.Custom[k] = val + } + + result.Location = b.loc + + // recreate original text of entire metadata block for location text attribute + sb := strings.Builder{} + sb.WriteString("# METADATA\n") + + lines := bytes.Split(b.buf.Bytes(), []byte{'\n'}) + + for _, line := range lines[:len(lines)-1] { + sb.WriteString("# ") + sb.Write(line) + sb.WriteByte('\n') + } + + result.Location.Text = []byte(strings.TrimSuffix(sb.String(), "\n")) + + return &result, nil +} + +// augmentYamlError augments a YAML error with hints intended to help the user figure out the cause of an otherwise +// cryptic error. These are hints, instead of proper errors, because they are educated guesses, and aren't guaranteed +// to be correct. +func augmentYamlError(err error, comments []*Comment) error { + // Adding hints for when key/value ':' separator isn't suffixed with a legal YAML space symbol + for _, comment := range comments { + txt := string(comment.Text) + parts := strings.Split(txt, ":") + if len(parts) > 1 { + parts = parts[1:] + var invalidSpaces []string + for partIndex, part := range parts { + if len(part) == 0 && partIndex == len(parts)-1 { + invalidSpaces = []string{} + break + } + + r, _ := utf8.DecodeRuneInString(part) + if r == ' ' || r == '\t' { + invalidSpaces = []string{} + break + } + + invalidSpaces = append(invalidSpaces, fmt.Sprintf("%+q", r)) + } + if len(invalidSpaces) > 0 { + err = fmt.Errorf( + "%s\n Hint: on line %d, symbol(s) %v immediately following a key/value separator ':' is not a legal yaml space character", + err.Error(), comment.Location.Row, invalidSpaces) + } + } + } + return err +} + +func unwrapPair(pair map[string]interface{}) (string, interface{}) { + for k, v := range pair { + return k, v + } + return "", nil +} + +var errInvalidSchemaRef = fmt.Errorf("invalid schema reference") + +// NOTE(tsandall): 'schema' is not registered as a root because it's not +// supported by the compiler or evaluator today. Once we fix that, we can remove +// this function. +func parseSchemaRef(s string) (Ref, error) { + + term, err := ParseTerm(s) + if err == nil { + switch v := term.Value.(type) { + case Var: + if term.Equal(SchemaRootDocument) { + return SchemaRootRef.Copy(), nil + } + case Ref: + if v.HasPrefix(SchemaRootRef) { + return v, nil + } + } + } + + return nil, errInvalidSchemaRef +} + +func parseRelatedResource(rr interface{}) (*RelatedResourceAnnotation, error) { + rr, err := convertYAMLMapKeyTypes(rr, nil) + if err != nil { + return nil, err + } + + switch rr := rr.(type) { + case string: + if len(rr) > 0 { + u, err := url.Parse(rr) + if err != nil { + return nil, err + } + return &RelatedResourceAnnotation{Ref: *u}, nil + } + return nil, fmt.Errorf("ref URL may not be empty string") + case map[string]interface{}: + description := strings.TrimSpace(getSafeString(rr, "description")) + ref := strings.TrimSpace(getSafeString(rr, "ref")) + if len(ref) > 0 { + u, err := url.Parse(ref) + if err != nil { + return nil, err + } + return &RelatedResourceAnnotation{Description: description, Ref: *u}, nil + } + return nil, fmt.Errorf("'ref' value required in object") + } + + return nil, fmt.Errorf("invalid value type, must be string or map") +} + +func parseAuthor(a interface{}) (*AuthorAnnotation, error) { + a, err := convertYAMLMapKeyTypes(a, nil) + if err != nil { + return nil, err + } + + switch a := a.(type) { + case string: + return parseAuthorString(a) + case map[string]interface{}: + name := strings.TrimSpace(getSafeString(a, "name")) + email := strings.TrimSpace(getSafeString(a, "email")) + if len(name) > 0 || len(email) > 0 { + return &AuthorAnnotation{name, email}, nil + } + return nil, fmt.Errorf("'name' and/or 'email' values required in object") + } + + return nil, fmt.Errorf("invalid value type, must be string or map") +} + +func getSafeString(m map[string]interface{}, k string) string { + if v, found := m[k]; found { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +const emailPrefix = "<" +const emailSuffix = ">" + +// parseAuthor parses a string into an AuthorAnnotation. If the last word of the input string is enclosed within <>, +// it is extracted as the author's email. The email may not contain whitelines, as it then will be interpreted as +// multiple words. +func parseAuthorString(s string) (*AuthorAnnotation, error) { + parts := strings.Fields(s) + + if len(parts) == 0 { + return nil, fmt.Errorf("author is an empty string") + } + + namePartCount := len(parts) + trailing := parts[namePartCount-1] + var email string + if len(trailing) >= len(emailPrefix)+len(emailSuffix) && strings.HasPrefix(trailing, emailPrefix) && + strings.HasSuffix(trailing, emailSuffix) { + email = trailing[len(emailPrefix):] + email = email[0 : len(email)-len(emailSuffix)] + namePartCount = namePartCount - 1 + } + + name := strings.Join(parts[0:namePartCount], " ") + + return &AuthorAnnotation{Name: name, Email: email}, nil +} + +func convertYAMLMapKeyTypes(x any, path []string) (any, error) { + var err error + switch x := x.(type) { + case map[any]any: + result := make(map[string]any, len(x)) + for k, v := range x { + str, ok := k.(string) + if !ok { + return nil, fmt.Errorf("invalid map key type(s): %v", strings.Join(path, "/")) + } + result[str], err = convertYAMLMapKeyTypes(v, append(path, str)) + if err != nil { + return nil, err + } + } + return result, nil + case []any: + for i := range x { + x[i], err = convertYAMLMapKeyTypes(x[i], append(path, fmt.Sprintf("%d", i))) + if err != nil { + return nil, err + } + } + return x, nil + default: + return x, nil + } +} + +// futureKeywords is the source of truth for future keywords that will +// eventually become standard keywords inside of Rego. +var futureKeywords = map[string]tokens.Token{} + +// futureKeywordsV0 is the source of truth for future keywords that were +// not yet a standard part of Rego in v0, and required importing. +var futureKeywordsV0 = map[string]tokens.Token{ + "in": tokens.In, + "every": tokens.Every, + "contains": tokens.Contains, + "if": tokens.If, +} + +var allFutureKeywords map[string]tokens.Token + +func IsFutureKeyword(s string) bool { + return IsFutureKeywordForRegoVersion(s, RegoV1) +} + +func IsFutureKeywordForRegoVersion(s string, v RegoVersion) bool { + var yes bool + + switch v { + case RegoV0, RegoV0CompatV1: + _, yes = futureKeywordsV0[s] + case RegoV1: + _, yes = futureKeywords[s] + } + + return yes +} + +func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]tokens.Token) { + path := imp.Path.Value.(Ref) + + if len(path) == 1 || !path[1].Equal(StringTerm("keywords")) { + p.errorf(imp.Path.Location, "invalid import, must be `future.keywords`") + return + } + + if imp.Alias != "" { + p.errorf(imp.Path.Location, "`future` imports cannot be aliased") + return + } + + kwds := make([]string, 0, len(allowedFutureKeywords)) + for k := range allowedFutureKeywords { + kwds = append(kwds, k) + } + + switch len(path) { + case 2: // all keywords imported, nothing to do + case 3: // one keyword imported + kw, ok := path[2].Value.(String) + if !ok { + p.errorf(imp.Path.Location, "invalid import, must be `future.keywords.x`, e.g. `import future.keywords.in`") + return + } + keyword := string(kw) + _, ok = allowedFutureKeywords[keyword] + if !ok { + sort.Strings(kwds) // so the error message is stable + p.errorf(imp.Path.Location, "unexpected keyword, must be one of %v", kwds) + return + } + + kwds = []string{keyword} // overwrite + } + for _, kw := range kwds { + p.s.s.AddKeyword(kw, allowedFutureKeywords[kw]) + } +} + +func (p *Parser) regoV1Import(imp *Import) { + if !p.po.Capabilities.ContainsFeature(FeatureRegoV1Import) && !p.po.Capabilities.ContainsFeature(FeatureRegoV1) { + p.errorf(imp.Path.Location, "invalid import, `%s` is not supported by current capabilities", RegoV1CompatibleRef) + return + } + + path := imp.Path.Value.(Ref) + + // v1 is only valid option + if len(path) == 1 || !path[1].Equal(RegoV1CompatibleRef[1]) || len(path) > 2 { + p.errorf(imp.Path.Location, "invalid import `%s`, must be `%s`", path, RegoV1CompatibleRef) + return + } + + if p.po.EffectiveRegoVersion() == RegoV1 { + // We're parsing for Rego v1, where the 'rego.v1' import is a no-op. + return + } + + if imp.Alias != "" { + p.errorf(imp.Path.Location, "`rego` imports cannot be aliased") + return + } + + // import all future keywords with the rego.v1 import + kwds := make([]string, 0, len(futureKeywordsV0)) + for k := range futureKeywordsV0 { + kwds = append(kwds, k) + } + + p.s.s.SetRegoV1Compatible() + for _, kw := range kwds { + p.s.s.AddKeyword(kw, futureKeywordsV0[kw]) + } +} + +func init() { + allFutureKeywords = map[string]tokens.Token{} + for k, v := range futureKeywords { + allFutureKeywords[k] = v + } + for k, v := range futureKeywordsV0 { + allFutureKeywords[k] = v + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/parser_ext.go b/vendor/github.com/open-policy-agent/opa/v1/ast/parser_ext.go new file mode 100644 index 000000000..f08c112a7 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/parser_ext.go @@ -0,0 +1,854 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// This file contains extra functions for parsing Rego. +// Most of the parsing is handled by the code in parser.go, +// however, there are additional utilities that are +// helpful for dealing with Rego source inputs (e.g., REPL +// statements, source files, etc.) + +package ast + +import ( + "bytes" + "errors" + "fmt" + "strings" + "unicode" + + "github.com/open-policy-agent/opa/v1/ast/internal/tokens" + astJSON "github.com/open-policy-agent/opa/v1/ast/json" +) + +// MustParseBody returns a parsed body. +// If an error occurs during parsing, panic. +func MustParseBody(input string) Body { + return MustParseBodyWithOpts(input, ParserOptions{}) +} + +// MustParseBodyWithOpts returns a parsed body. +// If an error occurs during parsing, panic. +func MustParseBodyWithOpts(input string, opts ParserOptions) Body { + parsed, err := ParseBodyWithOpts(input, opts) + if err != nil { + panic(err) + } + return parsed +} + +// MustParseExpr returns a parsed expression. +// If an error occurs during parsing, panic. +func MustParseExpr(input string) *Expr { + parsed, err := ParseExpr(input) + if err != nil { + panic(err) + } + return parsed +} + +// MustParseImports returns a slice of imports. +// If an error occurs during parsing, panic. +func MustParseImports(input string) []*Import { + parsed, err := ParseImports(input) + if err != nil { + panic(err) + } + return parsed +} + +// MustParseModule returns a parsed module. +// If an error occurs during parsing, panic. +func MustParseModule(input string) *Module { + return MustParseModuleWithOpts(input, ParserOptions{}) +} + +// MustParseModuleWithOpts returns a parsed module. +// If an error occurs during parsing, panic. +func MustParseModuleWithOpts(input string, opts ParserOptions) *Module { + parsed, err := ParseModuleWithOpts("", input, opts) + if err != nil { + panic(err) + } + return parsed +} + +// MustParsePackage returns a Package. +// If an error occurs during parsing, panic. +func MustParsePackage(input string) *Package { + parsed, err := ParsePackage(input) + if err != nil { + panic(err) + } + return parsed +} + +// MustParseStatements returns a slice of parsed statements. +// If an error occurs during parsing, panic. +func MustParseStatements(input string) []Statement { + parsed, _, err := ParseStatements("", input) + if err != nil { + panic(err) + } + return parsed +} + +// MustParseStatement returns exactly one statement. +// If an error occurs during parsing, panic. +func MustParseStatement(input string) Statement { + parsed, err := ParseStatement(input) + if err != nil { + panic(err) + } + return parsed +} + +func MustParseStatementWithOpts(input string, popts ParserOptions) Statement { + parsed, err := ParseStatementWithOpts(input, popts) + if err != nil { + panic(err) + } + return parsed +} + +// MustParseRef returns a parsed reference. +// If an error occurs during parsing, panic. +func MustParseRef(input string) Ref { + parsed, err := ParseRef(input) + if err != nil { + panic(err) + } + return parsed +} + +// MustParseRule returns a parsed rule. +// If an error occurs during parsing, panic. +func MustParseRule(input string) *Rule { + parsed, err := ParseRule(input) + if err != nil { + panic(err) + } + return parsed +} + +// MustParseRuleWithOpts returns a parsed rule. +// If an error occurs during parsing, panic. +func MustParseRuleWithOpts(input string, opts ParserOptions) *Rule { + parsed, err := ParseRuleWithOpts(input, opts) + if err != nil { + panic(err) + } + return parsed +} + +// MustParseTerm returns a parsed term. +// If an error occurs during parsing, panic. +func MustParseTerm(input string) *Term { + parsed, err := ParseTerm(input) + if err != nil { + panic(err) + } + return parsed +} + +// ParseRuleFromBody returns a rule if the body can be interpreted as a rule +// definition. Otherwise, an error is returned. +func ParseRuleFromBody(module *Module, body Body) (*Rule, error) { + + if len(body) != 1 { + return nil, fmt.Errorf("multiple expressions cannot be used for rule head") + } + + return ParseRuleFromExpr(module, body[0]) +} + +// ParseRuleFromExpr returns a rule if the expression can be interpreted as a +// rule definition. +func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { + + if len(expr.With) > 0 { + return nil, fmt.Errorf("expressions using with keyword cannot be used for rule head") + } + + if expr.Negated { + return nil, fmt.Errorf("negated expressions cannot be used for rule head") + } + + if _, ok := expr.Terms.(*SomeDecl); ok { + return nil, errors.New("'some' declarations cannot be used for rule head") + } + + if term, ok := expr.Terms.(*Term); ok { + switch v := term.Value.(type) { + case Ref: + if len(v) > 2 { // 2+ dots + return ParseCompleteDocRuleWithDotsFromTerm(module, term) + } + return ParsePartialSetDocRuleFromTerm(module, term) + default: + return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(v)) + } + } + + if _, ok := expr.Terms.([]*Term); !ok { + // This is a defensive check in case other kinds of expression terms are + // introduced in the future. + return nil, errors.New("expression cannot be used for rule head") + } + + if expr.IsEquality() { + return parseCompleteRuleFromEq(module, expr) + } else if expr.IsAssignment() { + rule, err := parseCompleteRuleFromEq(module, expr) + if err != nil { + return nil, err + } + rule.Head.Assign = true + return rule, nil + } + + if _, ok := BuiltinMap[expr.Operator().String()]; ok { + return nil, fmt.Errorf("rule name conflicts with built-in function") + } + + return ParseRuleFromCallExpr(module, expr.Terms.([]*Term)) +} + +func parseCompleteRuleFromEq(module *Module, expr *Expr) (rule *Rule, err error) { + + // ensure the rule location is set to the expr location + // the helper functions called below try to set the location based + // on the terms they've been provided but that is not as accurate. + defer func() { + if rule != nil { + rule.Location = expr.Location + rule.Head.Location = expr.Location + } + }() + + lhs, rhs := expr.Operand(0), expr.Operand(1) + if lhs == nil || rhs == nil { + return nil, errors.New("assignment requires two operands") + } + + rule, err = ParseRuleFromCallEqExpr(module, lhs, rhs) + if err == nil { + return rule, nil + } + + rule, err = ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs) + if err == nil { + return rule, nil + } + + return ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) +} + +// ParseCompleteDocRuleFromAssignmentExpr returns a rule if the expression can +// be interpreted as a complete document definition declared with the assignment +// operator. +func ParseCompleteDocRuleFromAssignmentExpr(module *Module, lhs, rhs *Term) (*Rule, error) { + + rule, err := ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) + if err != nil { + return nil, err + } + + rule.Head.Assign = true + + return rule, nil +} + +// ParseCompleteDocRuleFromEqExpr returns a rule if the expression can be +// interpreted as a complete document definition. +func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { + var head *Head + + if v, ok := lhs.Value.(Var); ok { + // Modify the code to add the location to the head ref + // and set the head ref's jsonOptions. + head = VarHead(v, lhs.Location, &lhs.jsonOptions) + } else if r, ok := lhs.Value.(Ref); ok { // groundness ? + if _, ok := r[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", r) + } + head = RefHead(r) + if len(r) > 1 && !r[len(r)-1].IsGround() { + return nil, fmt.Errorf("ref not ground") + } + } else { + return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(lhs.Value)) + } + head.Value = rhs + head.Location = lhs.Location + head.setJSONOptions(lhs.jsonOptions) + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)) + setJSONOptions(body, &rhs.jsonOptions) + + return &Rule{ + Location: lhs.Location, + Head: head, + Body: body, + Module: module, + jsonOptions: lhs.jsonOptions, + generatedBody: true, + }, nil +} + +func ParseCompleteDocRuleWithDotsFromTerm(module *Module, term *Term) (*Rule, error) { + ref, ok := term.Value.(Ref) + if !ok { + return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(term.Value)) + } + + if _, ok := ref[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", ref) + } + head := RefHead(ref, BooleanTerm(true).SetLocation(term.Location)) + head.generatedValue = true + head.Location = term.Location + head.jsonOptions = term.jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location)) + setJSONOptions(body, &term.jsonOptions) + + return &Rule{ + Location: term.Location, + Head: head, + Body: body, + Module: module, + + jsonOptions: term.jsonOptions, + }, nil +} + +// ParsePartialObjectDocRuleFromEqExpr returns a rule if the expression can be +// interpreted as a partial object document definition. +func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { + ref, ok := lhs.Value.(Ref) + if !ok { + return nil, fmt.Errorf("%v cannot be used as rule name", TypeName(lhs.Value)) + } + + if _, ok := ref[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", ref) + } + + head := RefHead(ref, rhs) + if len(ref) == 2 { // backcompat for naked `foo.bar = "baz"` statements + head.Name = ref[0].Value.(Var) + head.Key = ref[1] + } + head.Location = rhs.Location + head.jsonOptions = rhs.jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)) + setJSONOptions(body, &rhs.jsonOptions) + + rule := &Rule{ + Location: rhs.Location, + Head: head, + Body: body, + Module: module, + jsonOptions: rhs.jsonOptions, + } + + return rule, nil +} + +// ParsePartialSetDocRuleFromTerm returns a rule if the term can be interpreted +// as a partial set document definition. +func ParsePartialSetDocRuleFromTerm(module *Module, term *Term) (*Rule, error) { + + ref, ok := term.Value.(Ref) + if !ok || len(ref) == 1 { + return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) + } + if _, ok := ref[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", ref) + } + + head := RefHead(ref) + if len(ref) == 2 { + v, ok := ref[0].Value.(Var) + if !ok { + return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) + } + // Modify the code to add the location to the head ref + // and set the head ref's jsonOptions. + head = VarHead(v, ref[0].Location, &ref[0].jsonOptions) + head.Key = ref[1] + } + head.Location = term.Location + head.jsonOptions = term.jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location)) + setJSONOptions(body, &term.jsonOptions) + + rule := &Rule{ + Location: term.Location, + Head: head, + Body: body, + Module: module, + jsonOptions: term.jsonOptions, + } + + return rule, nil +} + +// ParseRuleFromCallEqExpr returns a rule if the term can be interpreted as a +// function definition (e.g., f(x) = y => f(x) = y { true }). +func ParseRuleFromCallEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { + + call, ok := lhs.Value.(Call) + if !ok { + return nil, fmt.Errorf("must be call") + } + + ref, ok := call[0].Value.(Ref) + if !ok { + return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(call[0].Value)) + } + if _, ok := ref[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", ref) + } + + head := RefHead(ref, rhs) + head.Location = lhs.Location + head.Args = Args(call[1:]) + head.jsonOptions = lhs.jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)) + setJSONOptions(body, &rhs.jsonOptions) + + rule := &Rule{ + Location: lhs.Location, + Head: head, + Body: body, + Module: module, + jsonOptions: lhs.jsonOptions, + } + + return rule, nil +} + +// ParseRuleFromCallExpr returns a rule if the terms can be interpreted as a +// function returning true or some value (e.g., f(x) => f(x) = true { true }). +func ParseRuleFromCallExpr(module *Module, terms []*Term) (*Rule, error) { + + if len(terms) <= 1 { + return nil, fmt.Errorf("rule argument list must take at least one argument") + } + + loc := terms[0].Location + ref := terms[0].Value.(Ref) + if _, ok := ref[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", ref) + } + head := RefHead(ref, BooleanTerm(true).SetLocation(loc)) + head.Location = loc + head.Args = terms[1:] + head.jsonOptions = terms[0].jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(loc)).SetLocation(loc)) + setJSONOptions(body, &terms[0].jsonOptions) + + rule := &Rule{ + Location: loc, + Head: head, + Module: module, + Body: body, + jsonOptions: terms[0].jsonOptions, + } + return rule, nil +} + +// ParseImports returns a slice of Import objects. +func ParseImports(input string) ([]*Import, error) { + stmts, _, err := ParseStatements("", input) + if err != nil { + return nil, err + } + result := []*Import{} + for _, stmt := range stmts { + if imp, ok := stmt.(*Import); ok { + result = append(result, imp) + } else { + return nil, fmt.Errorf("expected import but got %T", stmt) + } + } + return result, nil +} + +// ParseModule returns a parsed Module object. +// For details on Module objects and their fields, see policy.go. +// Empty input will return nil, nil. +func ParseModule(filename, input string) (*Module, error) { + return ParseModuleWithOpts(filename, input, ParserOptions{}) +} + +// ParseModuleWithOpts returns a parsed Module object, and has an additional input ParserOptions +// For details on Module objects and their fields, see policy.go. +// Empty input will return nil, nil. +func ParseModuleWithOpts(filename, input string, popts ParserOptions) (*Module, error) { + stmts, comments, err := ParseStatementsWithOpts(filename, input, popts) + if err != nil { + return nil, err + } + return parseModule(filename, stmts, comments, popts.RegoVersion) +} + +// ParseBody returns exactly one body. +// If multiple bodies are parsed, an error is returned. +func ParseBody(input string) (Body, error) { + return ParseBodyWithOpts(input, ParserOptions{SkipRules: true}) +} + +// ParseBodyWithOpts returns exactly one body. It does _not_ set SkipRules: true on its own, +// but respects whatever ParserOptions it's been given. +func ParseBodyWithOpts(input string, popts ParserOptions) (Body, error) { + + stmts, _, err := ParseStatementsWithOpts("", input, popts) + if err != nil { + return nil, err + } + + result := Body{} + + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case Body: + for i := range stmt { + result.Append(stmt[i]) + } + case *Comment: + // skip + default: + return nil, fmt.Errorf("expected body but got %T", stmt) + } + } + + return result, nil +} + +// ParseExpr returns exactly one expression. +// If multiple expressions are parsed, an error is returned. +func ParseExpr(input string) (*Expr, error) { + body, err := ParseBody(input) + if err != nil { + return nil, fmt.Errorf("failed to parse expression: %w", err) + } + if len(body) != 1 { + return nil, fmt.Errorf("expected exactly one expression but got: %v", body) + } + return body[0], nil +} + +// ParsePackage returns exactly one Package. +// If multiple statements are parsed, an error is returned. +func ParsePackage(input string) (*Package, error) { + stmt, err := ParseStatement(input) + if err != nil { + return nil, err + } + pkg, ok := stmt.(*Package) + if !ok { + return nil, fmt.Errorf("expected package but got %T", stmt) + } + return pkg, nil +} + +// ParseTerm returns exactly one term. +// If multiple terms are parsed, an error is returned. +func ParseTerm(input string) (*Term, error) { + body, err := ParseBody(input) + if err != nil { + return nil, fmt.Errorf("failed to parse term: %w", err) + } + if len(body) != 1 { + return nil, fmt.Errorf("expected exactly one term but got: %v", body) + } + term, ok := body[0].Terms.(*Term) + if !ok { + return nil, fmt.Errorf("expected term but got %v", body[0].Terms) + } + return term, nil +} + +// ParseRef returns exactly one reference. +func ParseRef(input string) (Ref, error) { + term, err := ParseTerm(input) + if err != nil { + return nil, fmt.Errorf("failed to parse ref: %w", err) + } + ref, ok := term.Value.(Ref) + if !ok { + return nil, fmt.Errorf("expected ref but got %v", term) + } + return ref, nil +} + +// ParseRuleWithOpts returns exactly one rule. +// If multiple rules are parsed, an error is returned. +func ParseRuleWithOpts(input string, opts ParserOptions) (*Rule, error) { + stmts, _, err := ParseStatementsWithOpts("", input, opts) + if err != nil { + return nil, err + } + if len(stmts) != 1 { + return nil, fmt.Errorf("expected exactly one statement (rule), got %v = %T, %T", stmts, stmts[0], stmts[1]) + } + rule, ok := stmts[0].(*Rule) + if !ok { + return nil, fmt.Errorf("expected rule but got %T", stmts[0]) + } + return rule, nil +} + +// ParseRule returns exactly one rule. +// If multiple rules are parsed, an error is returned. +func ParseRule(input string) (*Rule, error) { + return ParseRuleWithOpts(input, ParserOptions{}) +} + +// ParseStatement returns exactly one statement. +// A statement might be a term, expression, rule, etc. Regardless, +// this function expects *exactly* one statement. If multiple +// statements are parsed, an error is returned. +func ParseStatement(input string) (Statement, error) { + stmts, _, err := ParseStatements("", input) + if err != nil { + return nil, err + } + if len(stmts) != 1 { + return nil, fmt.Errorf("expected exactly one statement") + } + return stmts[0], nil +} + +func ParseStatementWithOpts(input string, popts ParserOptions) (Statement, error) { + stmts, _, err := ParseStatementsWithOpts("", input, popts) + if err != nil { + return nil, err + } + if len(stmts) != 1 { + return nil, fmt.Errorf("expected exactly one statement") + } + return stmts[0], nil +} + +// ParseStatements is deprecated. Use ParseStatementWithOpts instead. +func ParseStatements(filename, input string) ([]Statement, []*Comment, error) { + return ParseStatementsWithOpts(filename, input, ParserOptions{}) +} + +// ParseStatementsWithOpts returns a slice of parsed statements. This is the +// default return value from the parser. +func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Statement, []*Comment, error) { + + parser := NewParser(). + WithFilename(filename). + WithReader(bytes.NewBufferString(input)). + WithProcessAnnotation(popts.ProcessAnnotation). + WithFutureKeywords(popts.FutureKeywords...). + WithAllFutureKeywords(popts.AllFutureKeywords). + WithCapabilities(popts.Capabilities). + WithSkipRules(popts.SkipRules). + WithJSONOptions(popts.JSONOptions). + WithRegoVersion(popts.RegoVersion). + withUnreleasedKeywords(popts.unreleasedKeywords) + + stmts, comments, errs := parser.Parse() + + if len(errs) > 0 { + return nil, nil, errs + } + + return stmts, comments, nil +} + +func parseModule(filename string, stmts []Statement, comments []*Comment, regoCompatibilityMode RegoVersion) (*Module, error) { + + if len(stmts) == 0 { + return nil, NewError(ParseErr, &Location{File: filename}, "empty module") + } + + var errs Errors + + pkg, ok := stmts[0].(*Package) + if !ok { + loc := stmts[0].Loc() + errs = append(errs, NewError(ParseErr, loc, "package expected")) + } + + mod := &Module{ + Package: pkg, + stmts: stmts, + } + + // The comments slice only holds comments that were not their own statements. + mod.Comments = append(mod.Comments, comments...) + + if regoCompatibilityMode == RegoUndefined { + mod.regoVersion = DefaultRegoVersion + } else { + mod.regoVersion = regoCompatibilityMode + } + + for i, stmt := range stmts[1:] { + switch stmt := stmt.(type) { + case *Import: + mod.Imports = append(mod.Imports, stmt) + if mod.regoVersion == RegoV0 && Compare(stmt.Path.Value, RegoV1CompatibleRef) == 0 { + mod.regoVersion = RegoV0CompatV1 + } + case *Rule: + setRuleModule(stmt, mod) + mod.Rules = append(mod.Rules, stmt) + case Body: + rule, err := ParseRuleFromBody(mod, stmt) + if err != nil { + errs = append(errs, NewError(ParseErr, stmt[0].Location, err.Error())) //nolint:govet + continue + } + rule.generatedBody = true + mod.Rules = append(mod.Rules, rule) + + // NOTE(tsandall): the statement should now be interpreted as a + // rule so update the statement list. This is important for the + // logic below that associates annotations with statements. + stmts[i+1] = rule + case *Package: + errs = append(errs, NewError(ParseErr, stmt.Loc(), "unexpected package")) + case *Annotations: + mod.Annotations = append(mod.Annotations, stmt) + case *Comment: + // Ignore comments, they're handled above. + default: + panic("illegal value") // Indicates grammar is out-of-sync with code. + } + } + + if mod.regoVersion == RegoV0CompatV1 || mod.regoVersion == RegoV1 { + for _, rule := range mod.Rules { + for r := rule; r != nil; r = r.Else { + errs = append(errs, CheckRegoV1(r)...) + } + } + } + + if len(errs) > 0 { + return nil, errs + } + + errs = append(errs, attachAnnotationsNodes(mod)...) + + if len(errs) > 0 { + return nil, errs + } + + attachRuleAnnotations(mod) + + return mod, nil +} + +func ruleDeclarationHasKeyword(rule *Rule, keyword tokens.Token) bool { + for _, kw := range rule.Head.keywords { + if kw == keyword { + return true + } + } + return false +} + +func newScopeAttachmentErr(a *Annotations, want string) *Error { + var have string + if a.node != nil { + have = fmt.Sprintf(" (have %v)", TypeName(a.node)) + } + return NewError(ParseErr, a.Loc(), "annotation scope '%v' must be applied to %v%v", a.Scope, want, have) +} + +func setRuleModule(rule *Rule, module *Module) { + rule.Module = module + if rule.Else != nil { + setRuleModule(rule.Else, module) + } +} + +func setJSONOptions(x interface{}, jsonOptions *astJSON.Options) { + vis := NewGenericVisitor(func(x interface{}) bool { + if x, ok := x.(customJSON); ok { + x.setJSONOptions(*jsonOptions) + } + return false + }) + vis.Walk(x) +} + +// ParserErrorDetail holds additional details for parser errors. +type ParserErrorDetail struct { + Line string `json:"line"` + Idx int `json:"idx"` +} + +func newParserErrorDetail(bs []byte, offset int) *ParserErrorDetail { + + // Find first non-space character at or before offset position. + if offset >= len(bs) { + offset = len(bs) - 1 + } else if offset < 0 { + offset = 0 + } + + for offset > 0 && unicode.IsSpace(rune(bs[offset])) { + offset-- + } + + // Find beginning of line containing offset. + begin := offset + + for begin > 0 && !isNewLineChar(bs[begin]) { + begin-- + } + + if isNewLineChar(bs[begin]) { + begin++ + } + + // Find end of line containing offset. + end := offset + + for end < len(bs) && !isNewLineChar(bs[end]) { + end++ + } + + if begin > end { + begin = end + } + + // Extract line and compute index of offset byte in line. + line := bs[begin:end] + index := offset - begin + + return &ParserErrorDetail{ + Line: string(line), + Idx: index, + } +} + +// Lines returns the pretty formatted line output for the error details. +func (d ParserErrorDetail) Lines() []string { + line := strings.TrimLeft(d.Line, "\t") // remove leading tabs + tabCount := len(d.Line) - len(line) + indent := d.Idx - tabCount + if indent < 0 { + indent = 0 + } + return []string{line, strings.Repeat(" ", indent) + "^"} +} + +func isNewLineChar(b byte) bool { + return b == '\r' || b == '\n' +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/policy.go b/vendor/github.com/open-policy-agent/opa/v1/ast/policy.go new file mode 100644 index 000000000..0e0422a9f --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/policy.go @@ -0,0 +1,2119 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "bytes" + "encoding/json" + "fmt" + "math/rand" + "strings" + "time" + + "github.com/open-policy-agent/opa/v1/ast/internal/tokens" + astJSON "github.com/open-policy-agent/opa/v1/ast/json" + "github.com/open-policy-agent/opa/v1/util" +) + +// Initialize seed for term hashing. This is intentionally placed before the +// root document sets are constructed to ensure they use the same hash seed as +// subsequent lookups. If the hash seeds are out of sync, lookups will fail. +var hashSeed = rand.New(rand.NewSource(time.Now().UnixNano())) +var hashSeed0 = (uint64(hashSeed.Uint32()) << 32) | uint64(hashSeed.Uint32()) + +// DefaultRootDocument is the default root document. +// +// All package directives inside source files are implicitly prefixed with the +// DefaultRootDocument value. +var DefaultRootDocument = VarTerm("data") + +// InputRootDocument names the document containing query arguments. +var InputRootDocument = VarTerm("input") + +// SchemaRootDocument names the document containing external data schemas. +var SchemaRootDocument = VarTerm("schema") + +// FunctionArgRootDocument names the document containing function arguments. +// It's only for internal usage, for referencing function arguments between +// the index and topdown. +var FunctionArgRootDocument = VarTerm("args") + +// FutureRootDocument names the document containing new, to-become-default, +// features. +var FutureRootDocument = VarTerm("future") + +// RegoRootDocument names the document containing new, to-become-default, +// features in a future versioned release. +var RegoRootDocument = VarTerm("rego") + +// RootDocumentNames contains the names of top-level documents that can be +// referred to in modules and queries. +// +// Note, the schema document is not currently implemented in the evaluator so it +// is not registered as a root document name (yet). +var RootDocumentNames = NewSet( + DefaultRootDocument, + InputRootDocument, +) + +// DefaultRootRef is a reference to the root of the default document. +// +// All refs to data in the policy engine's storage layer are prefixed with this ref. +var DefaultRootRef = Ref{DefaultRootDocument} + +// InputRootRef is a reference to the root of the input document. +// +// All refs to query arguments are prefixed with this ref. +var InputRootRef = Ref{InputRootDocument} + +// SchemaRootRef is a reference to the root of the schema document. +// +// All refs to schema documents are prefixed with this ref. Note, the schema +// document is not currently implemented in the evaluator so it is not +// registered as a root document ref (yet). +var SchemaRootRef = Ref{SchemaRootDocument} + +// RootDocumentRefs contains the prefixes of top-level documents that all +// non-local references start with. +var RootDocumentRefs = NewSet( + NewTerm(DefaultRootRef), + NewTerm(InputRootRef), +) + +// SystemDocumentKey is the name of the top-level key that identifies the system +// document. +const SystemDocumentKey = String("system") + +// ReservedVars is the set of names that refer to implicitly ground vars. +var ReservedVars = NewVarSet( + DefaultRootDocument.Value.(Var), + InputRootDocument.Value.(Var), +) + +// Wildcard represents the wildcard variable as defined in the language. +var Wildcard = &Term{Value: Var("_")} + +// WildcardPrefix is the special character that all wildcard variables are +// prefixed with when the statement they are contained in is parsed. +const WildcardPrefix = "$" + +// Keywords contains strings that map to language keywords. +var Keywords = KeywordsForRegoVersion(DefaultRegoVersion) + +var KeywordsV0 = [...]string{ + "not", + "package", + "import", + "as", + "default", + "else", + "with", + "null", + "true", + "false", + "some", +} + +var KeywordsV1 = [...]string{ + "not", + "package", + "import", + "as", + "default", + "else", + "with", + "null", + "true", + "false", + "some", + "if", + "contains", + "in", + "every", +} + +func KeywordsForRegoVersion(v RegoVersion) []string { + switch v { + case RegoV0: + return KeywordsV0[:] + case RegoV1, RegoV0CompatV1: + return KeywordsV1[:] + } + return nil +} + +// IsKeyword returns true if s is a language keyword. +func IsKeyword(s string) bool { + return IsInKeywords(s, Keywords) +} + +func IsInKeywords(s string, keywords []string) bool { + for _, x := range keywords { + if x == s { + return true + } + } + return false +} + +// IsKeywordInRegoVersion returns true if s is a language keyword. +func IsKeywordInRegoVersion(s string, regoVersion RegoVersion) bool { + switch regoVersion { + case RegoV0: + for _, x := range KeywordsV0 { + if x == s { + return true + } + } + case RegoV1, RegoV0CompatV1: + for _, x := range KeywordsV1 { + if x == s { + return true + } + } + } + + return false +} + +type ( + // Node represents a node in an AST. Nodes may be statements in a policy module + // or elements of an ad-hoc query, expression, etc. + Node interface { + fmt.Stringer + Loc() *Location + SetLoc(*Location) + } + + // Statement represents a single statement in a policy module. + Statement interface { + Node + } +) + +type ( + + // Module represents a collection of policies (defined by rules) + // within a namespace (defined by the package) and optional + // dependencies on external documents (defined by imports). + Module struct { + Package *Package `json:"package"` + Imports []*Import `json:"imports,omitempty"` + Annotations []*Annotations `json:"annotations,omitempty"` + Rules []*Rule `json:"rules,omitempty"` + Comments []*Comment `json:"comments,omitempty"` + stmts []Statement + regoVersion RegoVersion + } + + // Comment contains the raw text from the comment in the definition. + Comment struct { + // TODO: these fields have inconsistent JSON keys with other structs in this package. + Text []byte + Location *Location + + jsonOptions astJSON.Options + } + + // Package represents the namespace of the documents produced + // by rules inside the module. + Package struct { + Path Ref `json:"path"` + Location *Location `json:"location,omitempty"` + + jsonOptions astJSON.Options + } + + // Import represents a dependency on a document outside of the policy + // namespace. Imports are optional. + Import struct { + Path *Term `json:"path"` + Alias Var `json:"alias,omitempty"` + Location *Location `json:"location,omitempty"` + + jsonOptions astJSON.Options + } + + // Rule represents a rule as defined in the language. Rules define the + // content of documents that represent policy decisions. + Rule struct { + Default bool `json:"default,omitempty"` + Head *Head `json:"head"` + Body Body `json:"body"` + Else *Rule `json:"else,omitempty"` + Location *Location `json:"location,omitempty"` + Annotations []*Annotations `json:"annotations,omitempty"` + + // Module is a pointer to the module containing this rule. If the rule + // was NOT created while parsing/constructing a module, this should be + // left unset. The pointer is not included in any standard operations + // on the rule (e.g., printing, comparison, visiting, etc.) + Module *Module `json:"-"` + + generatedBody bool + jsonOptions astJSON.Options + } + + // Head represents the head of a rule. + Head struct { + Name Var `json:"name,omitempty"` + Reference Ref `json:"ref,omitempty"` + Args Args `json:"args,omitempty"` + Key *Term `json:"key,omitempty"` + Value *Term `json:"value,omitempty"` + Assign bool `json:"assign,omitempty"` + Location *Location `json:"location,omitempty"` + + keywords []tokens.Token + generatedValue bool + jsonOptions astJSON.Options + } + + // Args represents zero or more arguments to a rule. + Args []*Term + + // Body represents one or more expressions contained inside a rule or user + // function. + Body []*Expr + + // Expr represents a single expression contained inside the body of a rule. + Expr struct { + With []*With `json:"with,omitempty"` + Terms interface{} `json:"terms"` + Index int `json:"index"` + Generated bool `json:"generated,omitempty"` + Negated bool `json:"negated,omitempty"` + Location *Location `json:"location,omitempty"` + + jsonOptions astJSON.Options + generatedFrom *Expr + generates []*Expr + } + + // SomeDecl represents a variable declaration statement. The symbols are variables. + SomeDecl struct { + Symbols []*Term `json:"symbols"` + Location *Location `json:"location,omitempty"` + + jsonOptions astJSON.Options + } + + Every struct { + Key *Term `json:"key"` + Value *Term `json:"value"` + Domain *Term `json:"domain"` + Body Body `json:"body"` + Location *Location `json:"location,omitempty"` + + jsonOptions astJSON.Options + } + + // With represents a modifier on an expression. + With struct { + Target *Term `json:"target"` + Value *Term `json:"value"` + Location *Location `json:"location,omitempty"` + + jsonOptions astJSON.Options + } +) + +// SetModuleRegoVersion sets the RegoVersion for the Module. +func SetModuleRegoVersion(mod *Module, v RegoVersion) { + mod.regoVersion = v +} + +// Compare returns an integer indicating whether mod is less than, equal to, +// or greater than other. +func (mod *Module) Compare(other *Module) int { + if mod == nil { + if other == nil { + return 0 + } + return -1 + } else if other == nil { + return 1 + } + if cmp := mod.Package.Compare(other.Package); cmp != 0 { + return cmp + } + if cmp := importsCompare(mod.Imports, other.Imports); cmp != 0 { + return cmp + } + if cmp := annotationsCompare(mod.Annotations, other.Annotations); cmp != 0 { + return cmp + } + return rulesCompare(mod.Rules, other.Rules) +} + +// Copy returns a deep copy of mod. +func (mod *Module) Copy() *Module { + cpy := *mod + cpy.Rules = make([]*Rule, len(mod.Rules)) + + nodes := make(map[Node]Node, len(mod.Rules)+len(mod.Imports)+1 /* package */) + + for i := range mod.Rules { + cpy.Rules[i] = mod.Rules[i].Copy() + cpy.Rules[i].Module = &cpy + nodes[mod.Rules[i]] = cpy.Rules[i] + } + + cpy.Imports = make([]*Import, len(mod.Imports)) + for i := range mod.Imports { + cpy.Imports[i] = mod.Imports[i].Copy() + nodes[mod.Imports[i]] = cpy.Imports[i] + } + + cpy.Package = mod.Package.Copy() + nodes[mod.Package] = cpy.Package + + cpy.Annotations = make([]*Annotations, len(mod.Annotations)) + for i, a := range mod.Annotations { + cpy.Annotations[i] = a.Copy(nodes[a.node]) + } + + cpy.Comments = make([]*Comment, len(mod.Comments)) + for i := range mod.Comments { + cpy.Comments[i] = mod.Comments[i].Copy() + } + + cpy.stmts = make([]Statement, len(mod.stmts)) + for i := range mod.stmts { + cpy.stmts[i] = nodes[mod.stmts[i]] + } + + return &cpy +} + +// Equal returns true if mod equals other. +func (mod *Module) Equal(other *Module) bool { + return mod.Compare(other) == 0 +} + +func (mod *Module) String() string { + byNode := map[Node][]*Annotations{} + for _, a := range mod.Annotations { + byNode[a.node] = append(byNode[a.node], a) + } + + appendAnnotationStrings := func(buf []string, node Node) []string { + if as, ok := byNode[node]; ok { + for i := range as { + buf = append(buf, "# METADATA") + buf = append(buf, "# "+as[i].String()) + } + } + return buf + } + + buf := []string{} + buf = appendAnnotationStrings(buf, mod.Package) + buf = append(buf, mod.Package.String()) + + if len(mod.Imports) > 0 { + buf = append(buf, "") + for _, imp := range mod.Imports { + buf = appendAnnotationStrings(buf, imp) + buf = append(buf, imp.String()) + } + } + if len(mod.Rules) > 0 { + buf = append(buf, "") + for _, rule := range mod.Rules { + buf = appendAnnotationStrings(buf, rule) + buf = append(buf, rule.stringWithOpts(toStringOpts{regoVersion: mod.regoVersion})) + } + } + return strings.Join(buf, "\n") +} + +// RuleSet returns a RuleSet containing named rules in the mod. +func (mod *Module) RuleSet(name Var) RuleSet { + rs := NewRuleSet() + for _, rule := range mod.Rules { + if rule.Head.Name.Equal(name) { + rs.Add(rule) + } + } + return rs +} + +// UnmarshalJSON parses bs and stores the result in mod. The rules in the module +// will have their module pointer set to mod. +func (mod *Module) UnmarshalJSON(bs []byte) error { + + // Declare a new type and use a type conversion to avoid recursively calling + // Module#UnmarshalJSON. + type module Module + + if err := util.UnmarshalJSON(bs, (*module)(mod)); err != nil { + return err + } + + WalkRules(mod, func(rule *Rule) bool { + rule.Module = mod + return false + }) + + return nil +} + +func (mod *Module) regoV1Compatible() bool { + return mod.regoVersion == RegoV1 || mod.regoVersion == RegoV0CompatV1 +} + +func (mod *Module) RegoVersion() RegoVersion { + return mod.regoVersion +} + +// SetRegoVersion sets the RegoVersion for the module. +// Note: Setting a rego-version that does not match the module's rego-version might have unintended consequences. +func (mod *Module) SetRegoVersion(v RegoVersion) { + mod.regoVersion = v +} + +// NewComment returns a new Comment object. +func NewComment(text []byte) *Comment { + return &Comment{ + Text: text, + } +} + +// Loc returns the location of the comment in the definition. +func (c *Comment) Loc() *Location { + if c == nil { + return nil + } + return c.Location +} + +// SetLoc sets the location on c. +func (c *Comment) SetLoc(loc *Location) { + c.Location = loc +} + +func (c *Comment) String() string { + return "#" + string(c.Text) +} + +// Copy returns a deep copy of c. +func (c *Comment) Copy() *Comment { + cpy := *c + cpy.Text = make([]byte, len(c.Text)) + copy(cpy.Text, c.Text) + return &cpy +} + +// Equal returns true if this comment equals the other comment. +// Unlike other equality checks on AST nodes, comment equality +// depends on location. +func (c *Comment) Equal(other *Comment) bool { + return c.Location.Equal(other.Location) && bytes.Equal(c.Text, other.Text) +} + +func (c *Comment) setJSONOptions(opts astJSON.Options) { + // Note: this is not used for location since Comments use default JSON marshaling + // behavior with struct field names in JSON. + c.jsonOptions = opts + if c.Location != nil { + c.Location.JSONOptions = opts + } +} + +// Compare returns an integer indicating whether pkg is less than, equal to, +// or greater than other. +func (pkg *Package) Compare(other *Package) int { + return Compare(pkg.Path, other.Path) +} + +// Copy returns a deep copy of pkg. +func (pkg *Package) Copy() *Package { + cpy := *pkg + cpy.Path = pkg.Path.Copy() + return &cpy +} + +// Equal returns true if pkg is equal to other. +func (pkg *Package) Equal(other *Package) bool { + return pkg.Compare(other) == 0 +} + +// Loc returns the location of the Package in the definition. +func (pkg *Package) Loc() *Location { + if pkg == nil { + return nil + } + return pkg.Location +} + +// SetLoc sets the location on pkg. +func (pkg *Package) SetLoc(loc *Location) { + pkg.Location = loc +} + +func (pkg *Package) String() string { + if pkg == nil { + return "" + } else if len(pkg.Path) <= 1 { + return fmt.Sprintf("package ", pkg.Path) + } + // Omit head as all packages have the DefaultRootDocument prepended at parse time. + path := make(Ref, len(pkg.Path)-1) + path[0] = VarTerm(string(pkg.Path[1].Value.(String))) + copy(path[1:], pkg.Path[2:]) + return fmt.Sprintf("package %v", path) +} + +func (pkg *Package) setJSONOptions(opts astJSON.Options) { + pkg.jsonOptions = opts + if pkg.Location != nil { + pkg.Location.JSONOptions = opts + } +} + +func (pkg *Package) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "path": pkg.Path, + } + + if pkg.jsonOptions.MarshalOptions.IncludeLocation.Package { + if pkg.Location != nil { + data["location"] = pkg.Location + } + } + + return json.Marshal(data) +} + +// IsValidImportPath returns an error indicating if the import path is invalid. +// If the import path is valid, err is nil. +func IsValidImportPath(v Value) (err error) { + switch v := v.(type) { + case Var: + if !v.Equal(DefaultRootDocument.Value) && !v.Equal(InputRootDocument.Value) { + return fmt.Errorf("invalid path %v: path must begin with input or data", v) + } + case Ref: + if err := IsValidImportPath(v[0].Value); err != nil { + return fmt.Errorf("invalid path %v: path must begin with input or data", v) + } + for _, e := range v[1:] { + if _, ok := e.Value.(String); !ok { + return fmt.Errorf("invalid path %v: path elements must be strings", v) + } + } + default: + return fmt.Errorf("invalid path %v: path must be ref or var", v) + } + return nil +} + +// Compare returns an integer indicating whether imp is less than, equal to, +// or greater than other. +func (imp *Import) Compare(other *Import) int { + if imp == nil { + if other == nil { + return 0 + } + return -1 + } else if other == nil { + return 1 + } + if cmp := Compare(imp.Path, other.Path); cmp != 0 { + return cmp + } + return Compare(imp.Alias, other.Alias) +} + +// Copy returns a deep copy of imp. +func (imp *Import) Copy() *Import { + cpy := *imp + cpy.Path = imp.Path.Copy() + return &cpy +} + +// Equal returns true if imp is equal to other. +func (imp *Import) Equal(other *Import) bool { + return imp.Compare(other) == 0 +} + +// Loc returns the location of the Import in the definition. +func (imp *Import) Loc() *Location { + if imp == nil { + return nil + } + return imp.Location +} + +// SetLoc sets the location on imp. +func (imp *Import) SetLoc(loc *Location) { + imp.Location = loc +} + +// Name returns the variable that is used to refer to the imported virtual +// document. This is the alias if defined otherwise the last element in the +// path. +func (imp *Import) Name() Var { + if len(imp.Alias) != 0 { + return imp.Alias + } + switch v := imp.Path.Value.(type) { + case Var: + return v + case Ref: + if len(v) == 1 { + return v[0].Value.(Var) + } + return Var(v[len(v)-1].Value.(String)) + } + panic("illegal import") +} + +func (imp *Import) String() string { + buf := []string{"import", imp.Path.String()} + if len(imp.Alias) > 0 { + buf = append(buf, "as "+imp.Alias.String()) + } + return strings.Join(buf, " ") +} + +func (imp *Import) setJSONOptions(opts astJSON.Options) { + imp.jsonOptions = opts + if imp.Location != nil { + imp.Location.JSONOptions = opts + } +} + +func (imp *Import) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "path": imp.Path, + } + + if len(imp.Alias) != 0 { + data["alias"] = imp.Alias + } + + if imp.jsonOptions.MarshalOptions.IncludeLocation.Import { + if imp.Location != nil { + data["location"] = imp.Location + } + } + + return json.Marshal(data) +} + +// Compare returns an integer indicating whether rule is less than, equal to, +// or greater than other. +func (rule *Rule) Compare(other *Rule) int { + if rule == nil { + if other == nil { + return 0 + } + return -1 + } else if other == nil { + return 1 + } + if cmp := rule.Head.Compare(other.Head); cmp != 0 { + return cmp + } + if cmp := util.Compare(rule.Default, other.Default); cmp != 0 { + return cmp + } + if cmp := rule.Body.Compare(other.Body); cmp != 0 { + return cmp + } + + if cmp := annotationsCompare(rule.Annotations, other.Annotations); cmp != 0 { + return cmp + } + + return rule.Else.Compare(other.Else) +} + +// Copy returns a deep copy of rule. +func (rule *Rule) Copy() *Rule { + cpy := *rule + cpy.Head = rule.Head.Copy() + cpy.Body = rule.Body.Copy() + + cpy.Annotations = make([]*Annotations, len(rule.Annotations)) + for i, a := range rule.Annotations { + cpy.Annotations[i] = a.Copy(&cpy) + } + + if cpy.Else != nil { + cpy.Else = rule.Else.Copy() + } + return &cpy +} + +// Equal returns true if rule is equal to other. +func (rule *Rule) Equal(other *Rule) bool { + return rule.Compare(other) == 0 +} + +// Loc returns the location of the Rule in the definition. +func (rule *Rule) Loc() *Location { + if rule == nil { + return nil + } + return rule.Location +} + +// SetLoc sets the location on rule. +func (rule *Rule) SetLoc(loc *Location) { + rule.Location = loc +} + +// Path returns a ref referring to the document produced by this rule. If rule +// is not contained in a module, this function panics. +// Deprecated: Poor handling of ref rules. Use `(*Rule).Ref()` instead. +func (rule *Rule) Path() Ref { + if rule.Module == nil { + panic("assertion failed") + } + return rule.Module.Package.Path.Extend(rule.Head.Ref().GroundPrefix()) +} + +// Ref returns a ref referring to the document produced by this rule. If rule +// is not contained in a module, this function panics. The returned ref may +// contain variables in the last position. +func (rule *Rule) Ref() Ref { + if rule.Module == nil { + panic("assertion failed") + } + return rule.Module.Package.Path.Extend(rule.Head.Ref()) +} + +func (rule *Rule) String() string { + regoVersion := DefaultRegoVersion + if rule.Module != nil { + regoVersion = rule.Module.RegoVersion() + } + return rule.stringWithOpts(toStringOpts{regoVersion: regoVersion}) +} + +type toStringOpts struct { + regoVersion RegoVersion +} + +func (o toStringOpts) RegoVersion() RegoVersion { + if o.regoVersion == RegoUndefined { + return DefaultRegoVersion + } + return o.regoVersion +} + +func (rule *Rule) stringWithOpts(opts toStringOpts) string { + buf := []string{} + if rule.Default { + buf = append(buf, "default") + } + buf = append(buf, rule.Head.stringWithOpts(opts)) + if !rule.Default { + switch opts.RegoVersion() { + case RegoV1, RegoV0CompatV1: + buf = append(buf, "if") + } + buf = append(buf, "{") + buf = append(buf, rule.Body.String()) + buf = append(buf, "}") + } + if rule.Else != nil { + buf = append(buf, rule.Else.elseString(opts)) + } + return strings.Join(buf, " ") +} + +func (rule *Rule) isFunction() bool { + return len(rule.Head.Args) > 0 +} + +func (rule *Rule) setJSONOptions(opts astJSON.Options) { + rule.jsonOptions = opts + if rule.Location != nil { + rule.Location.JSONOptions = opts + } +} + +func (rule *Rule) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "head": rule.Head, + "body": rule.Body, + } + + if rule.Default { + data["default"] = true + } + + if rule.Else != nil { + data["else"] = rule.Else + } + + if rule.jsonOptions.MarshalOptions.IncludeLocation.Rule { + if rule.Location != nil { + data["location"] = rule.Location + } + } + + if len(rule.Annotations) != 0 { + data["annotations"] = rule.Annotations + } + + return json.Marshal(data) +} + +func (rule *Rule) elseString(opts toStringOpts) string { + var buf []string + + buf = append(buf, "else") + + value := rule.Head.Value + if value != nil { + buf = append(buf, "=") + buf = append(buf, value.String()) + } + + switch opts.RegoVersion() { + case RegoV1, RegoV0CompatV1: + buf = append(buf, "if") + } + + buf = append(buf, "{") + buf = append(buf, rule.Body.String()) + buf = append(buf, "}") + + if rule.Else != nil { + buf = append(buf, rule.Else.elseString(opts)) + } + + return strings.Join(buf, " ") +} + +// NewHead returns a new Head object. If args are provided, the first will be +// used for the key and the second will be used for the value. +func NewHead(name Var, args ...*Term) *Head { + head := &Head{ + Name: name, // backcompat + Reference: []*Term{NewTerm(name)}, + } + if len(args) == 0 { + return head + } + head.Key = args[0] + if len(args) == 1 { + return head + } + head.Value = args[1] + if head.Key != nil && head.Value != nil { + head.Reference = head.Reference.Append(args[0]) + } + return head +} + +// VarHead creates a head object, initializes its Name, Location, and Options, +// and returns the new head. +func VarHead(name Var, location *Location, jsonOpts *astJSON.Options) *Head { + h := NewHead(name) + h.Reference[0].Location = location + if jsonOpts != nil { + h.Reference[0].setJSONOptions(*jsonOpts) + } + return h +} + +// RefHead returns a new Head object with the passed Ref. If args are provided, +// the first will be used for the value. +func RefHead(ref Ref, args ...*Term) *Head { + head := &Head{} + head.SetRef(ref) + if len(ref) < 2 { + head.Name = ref[0].Value.(Var) + } + if len(args) >= 1 { + head.Value = args[0] + } + return head +} + +// DocKind represents the collection of document types that can be produced by rules. +type DocKind int + +const ( + // CompleteDoc represents a document that is completely defined by the rule. + CompleteDoc = iota + + // PartialSetDoc represents a set document that is partially defined by the rule. + PartialSetDoc + + // PartialObjectDoc represents an object document that is partially defined by the rule. + PartialObjectDoc +) // TODO(sr): Deprecate? + +// DocKind returns the type of document produced by this rule. +func (head *Head) DocKind() DocKind { + if head.Key != nil { + if head.Value != nil { + return PartialObjectDoc + } + return PartialSetDoc + } + return CompleteDoc +} + +type RuleKind int + +const ( + SingleValue = iota + MultiValue +) + +// RuleKind returns the type of rule this is +func (head *Head) RuleKind() RuleKind { + // NOTE(sr): This is bit verbose, since the key is irrelevant for single vs + // multi value, but as good a spot as to assert the invariant. + switch { + case head.Value != nil: + return SingleValue + case head.Key != nil: + return MultiValue + default: + panic("unreachable") + } +} + +// Ref returns the Ref of the rule. If it doesn't have one, it's filled in +// via the Head's Name. +func (head *Head) Ref() Ref { + if len(head.Reference) > 0 { + return head.Reference + } + return Ref{&Term{Value: head.Name}} +} + +// SetRef can be used to set a rule head's Reference +func (head *Head) SetRef(r Ref) { + head.Reference = r +} + +// Compare returns an integer indicating whether head is less than, equal to, +// or greater than other. +func (head *Head) Compare(other *Head) int { + if head == nil { + if other == nil { + return 0 + } + return -1 + } else if other == nil { + return 1 + } + if head.Assign && !other.Assign { + return -1 + } else if !head.Assign && other.Assign { + return 1 + } + if cmp := Compare(head.Args, other.Args); cmp != 0 { + return cmp + } + if cmp := Compare(head.Reference, other.Reference); cmp != 0 { + return cmp + } + if cmp := Compare(head.Name, other.Name); cmp != 0 { + return cmp + } + if cmp := Compare(head.Key, other.Key); cmp != 0 { + return cmp + } + return Compare(head.Value, other.Value) +} + +// Copy returns a deep copy of head. +func (head *Head) Copy() *Head { + cpy := *head + cpy.Reference = head.Reference.Copy() + cpy.Args = head.Args.Copy() + cpy.Key = head.Key.Copy() + cpy.Value = head.Value.Copy() + cpy.keywords = nil + return &cpy +} + +// Equal returns true if this head equals other. +func (head *Head) Equal(other *Head) bool { + return head.Compare(other) == 0 +} + +func (head *Head) String() string { + return head.stringWithOpts(toStringOpts{}) +} + +func (head *Head) stringWithOpts(opts toStringOpts) string { + buf := strings.Builder{} + buf.WriteString(head.Ref().String()) + containsAdded := false + + switch { + case len(head.Args) != 0: + buf.WriteString(head.Args.String()) + case len(head.Reference) == 1 && head.Key != nil: + switch opts.RegoVersion() { + case RegoV0: + buf.WriteRune('[') + buf.WriteString(head.Key.String()) + buf.WriteRune(']') + default: + containsAdded = true + buf.WriteString(" contains ") + buf.WriteString(head.Key.String()) + } + } + if head.Value != nil { + if head.Assign { + buf.WriteString(" := ") + } else { + buf.WriteString(" = ") + } + buf.WriteString(head.Value.String()) + } else if !containsAdded && head.Name == "" && head.Key != nil { + buf.WriteString(" contains ") + buf.WriteString(head.Key.String()) + } + return buf.String() +} + +func (head *Head) setJSONOptions(opts astJSON.Options) { + head.jsonOptions = opts + if head.Location != nil { + head.Location.JSONOptions = opts + } +} + +func (head *Head) MarshalJSON() ([]byte, error) { + var loc *Location + includeLoc := head.jsonOptions.MarshalOptions.IncludeLocation + if includeLoc.Head { + if head.Location != nil { + loc = head.Location + } + + for _, term := range head.Reference { + if term.Location != nil { + term.jsonOptions.MarshalOptions.IncludeLocation.Term = includeLoc.Term + } + } + } + + // NOTE(sr): we do this to override the rendering of `head.Reference`. + // It's still what'll be used via the default means of encoding/json + // for unmarshaling a json object into a Head struct! + type h Head + return json.Marshal(struct { + h + Ref Ref `json:"ref"` + Location *Location `json:"location,omitempty"` + }{ + h: h(*head), + Ref: head.Ref(), + Location: loc, + }) +} + +// Vars returns a set of vars found in the head. +func (head *Head) Vars() VarSet { + vis := &VarVisitor{vars: VarSet{}} + // TODO: improve test coverage for this. + if head.Args != nil { + vis.Walk(head.Args) + } + if head.Key != nil { + vis.Walk(head.Key) + } + if head.Value != nil { + vis.Walk(head.Value) + } + if len(head.Reference) > 0 { + vis.Walk(head.Reference[1:]) + } + return vis.vars +} + +// Loc returns the Location of head. +func (head *Head) Loc() *Location { + if head == nil { + return nil + } + return head.Location +} + +// SetLoc sets the location on head. +func (head *Head) SetLoc(loc *Location) { + head.Location = loc +} + +func (head *Head) HasDynamicRef() bool { + pos := head.Reference.Dynamic() + // Ref is dynamic if it has one non-constant term that isn't the first or last term or if it's a partial set rule. + return pos > 0 && (pos < len(head.Reference)-1 || head.RuleKind() == MultiValue) +} + +// Copy returns a deep copy of a. +func (a Args) Copy() Args { + cpy := Args{} + for _, t := range a { + cpy = append(cpy, t.Copy()) + } + return cpy +} + +func (a Args) String() string { + buf := make([]string, 0, len(a)) + for _, t := range a { + buf = append(buf, t.String()) + } + return "(" + strings.Join(buf, ", ") + ")" +} + +// Loc returns the Location of a. +func (a Args) Loc() *Location { + if len(a) == 0 { + return nil + } + return a[0].Location +} + +// SetLoc sets the location on a. +func (a Args) SetLoc(loc *Location) { + if len(a) != 0 { + a[0].SetLocation(loc) + } +} + +// Vars returns a set of vars that appear in a. +func (a Args) Vars() VarSet { + vis := &VarVisitor{vars: VarSet{}} + vis.Walk(a) + return vis.vars +} + +// NewBody returns a new Body containing the given expressions. The indices of +// the immediate expressions will be reset. +func NewBody(exprs ...*Expr) Body { + for i, expr := range exprs { + expr.Index = i + } + return Body(exprs) +} + +// MarshalJSON returns JSON encoded bytes representing body. +func (body Body) MarshalJSON() ([]byte, error) { + // Serialize empty Body to empty array. This handles both the empty case and the + // nil case (whereas by default the result would be null if body was nil.) + if len(body) == 0 { + return []byte(`[]`), nil + } + ret, err := json.Marshal([]*Expr(body)) + return ret, err +} + +// Append adds the expr to the body and updates the expr's index accordingly. +func (body *Body) Append(expr *Expr) { + n := len(*body) + expr.Index = n + *body = append(*body, expr) +} + +// Set sets the expr in the body at the specified position and updates the +// expr's index accordingly. +func (body Body) Set(expr *Expr, pos int) { + body[pos] = expr + expr.Index = pos +} + +// Compare returns an integer indicating whether body is less than, equal to, +// or greater than other. +// +// If body is a subset of other, it is considered less than (and vice versa). +func (body Body) Compare(other Body) int { + minLen := len(body) + if len(other) < minLen { + minLen = len(other) + } + for i := 0; i < minLen; i++ { + if cmp := body[i].Compare(other[i]); cmp != 0 { + return cmp + } + } + if len(body) < len(other) { + return -1 + } + if len(other) < len(body) { + return 1 + } + return 0 +} + +// Copy returns a deep copy of body. +func (body Body) Copy() Body { + cpy := make(Body, len(body)) + for i := range body { + cpy[i] = body[i].Copy() + } + return cpy +} + +// Contains returns true if this body contains the given expression. +func (body Body) Contains(x *Expr) bool { + for _, e := range body { + if e.Equal(x) { + return true + } + } + return false +} + +// Equal returns true if this Body is equal to the other Body. +func (body Body) Equal(other Body) bool { + return body.Compare(other) == 0 +} + +// Hash returns the hash code for the Body. +func (body Body) Hash() int { + s := 0 + for _, e := range body { + s += e.Hash() + } + return s +} + +// IsGround returns true if all of the expressions in the Body are ground. +func (body Body) IsGround() bool { + for _, e := range body { + if !e.IsGround() { + return false + } + } + return true +} + +// Loc returns the location of the Body in the definition. +func (body Body) Loc() *Location { + if len(body) == 0 { + return nil + } + return body[0].Location +} + +// SetLoc sets the location on body. +func (body Body) SetLoc(loc *Location) { + if len(body) != 0 { + body[0].SetLocation(loc) + } +} + +func (body Body) String() string { + buf := make([]string, 0, len(body)) + for _, v := range body { + buf = append(buf, v.String()) + } + return strings.Join(buf, "; ") +} + +// Vars returns a VarSet containing variables in body. The params can be set to +// control which vars are included. +func (body Body) Vars(params VarVisitorParams) VarSet { + vis := NewVarVisitor().WithParams(params) + vis.Walk(body) + return vis.Vars() +} + +// NewExpr returns a new Expr object. +func NewExpr(terms interface{}) *Expr { + switch terms.(type) { + case *SomeDecl, *Every, *Term, []*Term: // ok + default: + panic("unreachable") + } + return &Expr{ + Negated: false, + Terms: terms, + Index: 0, + With: nil, + } +} + +// Complement returns a copy of this expression with the negation flag flipped. +func (expr *Expr) Complement() *Expr { + cpy := *expr + cpy.Negated = !cpy.Negated + return &cpy +} + +// ComplementNoWith returns a copy of this expression with the negation flag flipped +// and the with modifier removed. This is the same as calling .Complement().NoWith() +// but without making an intermediate copy. +func (expr *Expr) ComplementNoWith() *Expr { + cpy := *expr + cpy.Negated = !cpy.Negated + cpy.With = nil + return &cpy +} + +// Equal returns true if this Expr equals the other Expr. +func (expr *Expr) Equal(other *Expr) bool { + return expr.Compare(other) == 0 +} + +// Compare returns an integer indicating whether expr is less than, equal to, +// or greater than other. +// +// Expressions are compared as follows: +// +// 1. Declarations are always less than other expressions. +// 2. Preceding expression (by Index) is always less than the other expression. +// 3. Non-negated expressions are always less than negated expressions. +// 4. Single term expressions are always less than built-in expressions. +// +// Otherwise, the expression terms are compared normally. If both expressions +// have the same terms, the modifiers are compared. +func (expr *Expr) Compare(other *Expr) int { + + if expr == nil { + if other == nil { + return 0 + } + return -1 + } else if other == nil { + return 1 + } + + o1 := expr.sortOrder() + o2 := other.sortOrder() + if o1 < o2 { + return -1 + } else if o2 < o1 { + return 1 + } + + switch { + case expr.Index < other.Index: + return -1 + case expr.Index > other.Index: + return 1 + } + + switch { + case expr.Negated && !other.Negated: + return 1 + case !expr.Negated && other.Negated: + return -1 + } + + switch t := expr.Terms.(type) { + case *Term: + if cmp := Compare(t.Value, other.Terms.(*Term).Value); cmp != 0 { + return cmp + } + case []*Term: + if cmp := termSliceCompare(t, other.Terms.([]*Term)); cmp != 0 { + return cmp + } + case *SomeDecl: + if cmp := Compare(t, other.Terms.(*SomeDecl)); cmp != 0 { + return cmp + } + case *Every: + if cmp := Compare(t, other.Terms.(*Every)); cmp != 0 { + return cmp + } + } + + return withSliceCompare(expr.With, other.With) +} + +func (expr *Expr) sortOrder() int { + switch expr.Terms.(type) { + case *SomeDecl: + return 0 + case *Term: + return 1 + case []*Term: + return 2 + case *Every: + return 3 + } + return -1 +} + +// CopyWithoutTerms returns a deep copy of expr without its Terms +func (expr *Expr) CopyWithoutTerms() *Expr { + cpy := *expr + + if expr.With != nil { + cpy.With = make([]*With, len(expr.With)) + for i := range expr.With { + cpy.With[i] = expr.With[i].Copy() + } + } + + return &cpy +} + +// Copy returns a deep copy of expr. +func (expr *Expr) Copy() *Expr { + + cpy := expr.CopyWithoutTerms() + + switch ts := expr.Terms.(type) { + case *SomeDecl: + cpy.Terms = ts.Copy() + case []*Term: + cpyTs := make([]*Term, len(ts)) + for i := range ts { + cpyTs[i] = ts[i].Copy() + } + cpy.Terms = cpyTs + case *Term: + cpy.Terms = ts.Copy() + case *Every: + cpy.Terms = ts.Copy() + } + + return cpy +} + +// Hash returns the hash code of the Expr. +func (expr *Expr) Hash() int { + s := expr.Index + switch ts := expr.Terms.(type) { + case *SomeDecl: + s += ts.Hash() + case []*Term: + for _, t := range ts { + s += t.Value.Hash() + } + case *Term: + s += ts.Value.Hash() + } + if expr.Negated { + s++ + } + for _, w := range expr.With { + s += w.Hash() + } + return s +} + +// IncludeWith returns a copy of expr with the with modifier appended. +func (expr *Expr) IncludeWith(target *Term, value *Term) *Expr { + cpy := *expr + cpy.With = append(cpy.With, &With{Target: target, Value: value}) + return &cpy +} + +// NoWith returns a copy of expr where the with modifier has been removed. +func (expr *Expr) NoWith() *Expr { + cpy := *expr + cpy.With = nil + return &cpy +} + +// IsEquality returns true if this is an equality expression. +func (expr *Expr) IsEquality() bool { + return isGlobalBuiltin(expr, Var(Equality.Name)) +} + +// IsAssignment returns true if this an assignment expression. +func (expr *Expr) IsAssignment() bool { + return isGlobalBuiltin(expr, Var(Assign.Name)) +} + +// IsCall returns true if this expression calls a function. +func (expr *Expr) IsCall() bool { + _, ok := expr.Terms.([]*Term) + return ok +} + +// IsEvery returns true if this expression is an 'every' expression. +func (expr *Expr) IsEvery() bool { + _, ok := expr.Terms.(*Every) + return ok +} + +// IsSome returns true if this expression is a 'some' expression. +func (expr *Expr) IsSome() bool { + _, ok := expr.Terms.(*SomeDecl) + return ok +} + +// Operator returns the name of the function or built-in this expression refers +// to. If this expression is not a function call, returns nil. +func (expr *Expr) Operator() Ref { + op := expr.OperatorTerm() + if op == nil { + return nil + } + return op.Value.(Ref) +} + +// OperatorTerm returns the name of the function or built-in this expression +// refers to. If this expression is not a function call, returns nil. +func (expr *Expr) OperatorTerm() *Term { + terms, ok := expr.Terms.([]*Term) + if !ok || len(terms) == 0 { + return nil + } + return terms[0] +} + +// Operand returns the term at the zero-based pos. If the expr does not include +// at least pos+1 terms, this function returns nil. +func (expr *Expr) Operand(pos int) *Term { + terms, ok := expr.Terms.([]*Term) + if !ok { + return nil + } + idx := pos + 1 + if idx < len(terms) { + return terms[idx] + } + return nil +} + +// Operands returns the built-in function operands. +func (expr *Expr) Operands() []*Term { + terms, ok := expr.Terms.([]*Term) + if !ok { + return nil + } + return terms[1:] +} + +// IsGround returns true if all of the expression terms are ground. +func (expr *Expr) IsGround() bool { + switch ts := expr.Terms.(type) { + case []*Term: + for _, t := range ts[1:] { + if !t.IsGround() { + return false + } + } + case *Term: + return ts.IsGround() + } + return true +} + +// SetOperator sets the expr's operator and returns the expr itself. If expr is +// not a call expr, this function will panic. +func (expr *Expr) SetOperator(term *Term) *Expr { + expr.Terms.([]*Term)[0] = term + return expr +} + +// SetLocation sets the expr's location and returns the expr itself. +func (expr *Expr) SetLocation(loc *Location) *Expr { + expr.Location = loc + return expr +} + +// Loc returns the Location of expr. +func (expr *Expr) Loc() *Location { + if expr == nil { + return nil + } + return expr.Location +} + +// SetLoc sets the location on expr. +func (expr *Expr) SetLoc(loc *Location) { + expr.SetLocation(loc) +} + +func (expr *Expr) String() string { + buf := make([]string, 0, 2+len(expr.With)) + if expr.Negated { + buf = append(buf, "not") + } + switch t := expr.Terms.(type) { + case []*Term: + if expr.IsEquality() && validEqAssignArgCount(expr) { + buf = append(buf, fmt.Sprintf("%v %v %v", t[1], Equality.Infix, t[2])) + } else { + buf = append(buf, Call(t).String()) + } + case fmt.Stringer: + buf = append(buf, t.String()) + } + + for i := range expr.With { + buf = append(buf, expr.With[i].String()) + } + + return strings.Join(buf, " ") +} + +func (expr *Expr) setJSONOptions(opts astJSON.Options) { + expr.jsonOptions = opts + if expr.Location != nil { + expr.Location.JSONOptions = opts + } +} + +func (expr *Expr) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "terms": expr.Terms, + "index": expr.Index, + } + + if len(expr.With) > 0 { + data["with"] = expr.With + } + + if expr.Generated { + data["generated"] = true + } + + if expr.Negated { + data["negated"] = true + } + + if expr.jsonOptions.MarshalOptions.IncludeLocation.Expr { + if expr.Location != nil { + data["location"] = expr.Location + } + } + + return json.Marshal(data) +} + +// UnmarshalJSON parses the byte array and stores the result in expr. +func (expr *Expr) UnmarshalJSON(bs []byte) error { + v := map[string]interface{}{} + if err := util.UnmarshalJSON(bs, &v); err != nil { + return err + } + return unmarshalExpr(expr, v) +} + +// Vars returns a VarSet containing variables in expr. The params can be set to +// control which vars are included. +func (expr *Expr) Vars(params VarVisitorParams) VarSet { + vis := NewVarVisitor().WithParams(params) + vis.Walk(expr) + return vis.Vars() +} + +// NewBuiltinExpr creates a new Expr object with the supplied terms. +// The builtin operator must be the first term. +func NewBuiltinExpr(terms ...*Term) *Expr { + return &Expr{Terms: terms} +} + +func (expr *Expr) CogeneratedExprs() []*Expr { + visited := map[*Expr]struct{}{} + visitCogeneratedExprs(expr, func(e *Expr) bool { + if expr.Equal(e) { + return true + } + if _, ok := visited[e]; ok { + return true + } + visited[e] = struct{}{} + return false + }) + + result := make([]*Expr, 0, len(visited)) + for e := range visited { + result = append(result, e) + } + return result +} + +func (expr *Expr) BaseCogeneratedExpr() *Expr { + if expr.generatedFrom == nil { + return expr + } + return expr.generatedFrom.BaseCogeneratedExpr() +} + +func visitCogeneratedExprs(expr *Expr, f func(*Expr) bool) { + if parent := expr.generatedFrom; parent != nil { + if stop := f(parent); !stop { + visitCogeneratedExprs(parent, f) + } + } + for _, child := range expr.generates { + if stop := f(child); !stop { + visitCogeneratedExprs(child, f) + } + } +} + +func (d *SomeDecl) String() string { + if call, ok := d.Symbols[0].Value.(Call); ok { + if len(call) == 4 { + return "some " + call[1].String() + ", " + call[2].String() + " in " + call[3].String() + } + return "some " + call[1].String() + " in " + call[2].String() + } + buf := make([]string, len(d.Symbols)) + for i := range buf { + buf[i] = d.Symbols[i].String() + } + return "some " + strings.Join(buf, ", ") +} + +// SetLoc sets the Location on d. +func (d *SomeDecl) SetLoc(loc *Location) { + d.Location = loc +} + +// Loc returns the Location of d. +func (d *SomeDecl) Loc() *Location { + return d.Location +} + +// Copy returns a deep copy of d. +func (d *SomeDecl) Copy() *SomeDecl { + cpy := *d + cpy.Symbols = termSliceCopy(d.Symbols) + return &cpy +} + +// Compare returns an integer indicating whether d is less than, equal to, or +// greater than other. +func (d *SomeDecl) Compare(other *SomeDecl) int { + return termSliceCompare(d.Symbols, other.Symbols) +} + +// Hash returns a hash code of d. +func (d *SomeDecl) Hash() int { + return termSliceHash(d.Symbols) +} + +func (d *SomeDecl) setJSONOptions(opts astJSON.Options) { + d.jsonOptions = opts + if d.Location != nil { + d.Location.JSONOptions = opts + } +} + +func (d *SomeDecl) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "symbols": d.Symbols, + } + + if d.jsonOptions.MarshalOptions.IncludeLocation.SomeDecl { + if d.Location != nil { + data["location"] = d.Location + } + } + + return json.Marshal(data) +} + +func (q *Every) String() string { + if q.Key != nil { + return fmt.Sprintf("every %s, %s in %s { %s }", + q.Key, + q.Value, + q.Domain, + q.Body) + } + return fmt.Sprintf("every %s in %s { %s }", + q.Value, + q.Domain, + q.Body) +} + +func (q *Every) Loc() *Location { + return q.Location +} + +func (q *Every) SetLoc(l *Location) { + q.Location = l +} + +// Copy returns a deep copy of d. +func (q *Every) Copy() *Every { + cpy := *q + cpy.Key = q.Key.Copy() + cpy.Value = q.Value.Copy() + cpy.Domain = q.Domain.Copy() + cpy.Body = q.Body.Copy() + return &cpy +} + +func (q *Every) Compare(other *Every) int { + for _, terms := range [][2]*Term{ + {q.Key, other.Key}, + {q.Value, other.Value}, + {q.Domain, other.Domain}, + } { + if d := Compare(terms[0], terms[1]); d != 0 { + return d + } + } + return q.Body.Compare(other.Body) +} + +// KeyValueVars returns the key and val arguments of an `every` +// expression, if they are non-nil and not wildcards. +func (q *Every) KeyValueVars() VarSet { + vis := &VarVisitor{vars: VarSet{}} + if q.Key != nil { + vis.Walk(q.Key) + } + vis.Walk(q.Value) + return vis.vars +} + +func (q *Every) setJSONOptions(opts astJSON.Options) { + q.jsonOptions = opts + if q.Location != nil { + q.Location.JSONOptions = opts + } +} + +func (q *Every) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "key": q.Key, + "value": q.Value, + "domain": q.Domain, + "body": q.Body, + } + + if q.jsonOptions.MarshalOptions.IncludeLocation.Every { + if q.Location != nil { + data["location"] = q.Location + } + } + + return json.Marshal(data) +} + +func (w *With) String() string { + return "with " + w.Target.String() + " as " + w.Value.String() +} + +// Equal returns true if this With is equals the other With. +func (w *With) Equal(other *With) bool { + return Compare(w, other) == 0 +} + +// Compare returns an integer indicating whether w is less than, equal to, or +// greater than other. +func (w *With) Compare(other *With) int { + if w == nil { + if other == nil { + return 0 + } + return -1 + } else if other == nil { + return 1 + } + if cmp := Compare(w.Target, other.Target); cmp != 0 { + return cmp + } + return Compare(w.Value, other.Value) +} + +// Copy returns a deep copy of w. +func (w *With) Copy() *With { + cpy := *w + cpy.Value = w.Value.Copy() + cpy.Target = w.Target.Copy() + return &cpy +} + +// Hash returns the hash code of the With. +func (w With) Hash() int { + return w.Target.Hash() + w.Value.Hash() +} + +// SetLocation sets the location on w. +func (w *With) SetLocation(loc *Location) *With { + w.Location = loc + return w +} + +// Loc returns the Location of w. +func (w *With) Loc() *Location { + if w == nil { + return nil + } + return w.Location +} + +// SetLoc sets the location on w. +func (w *With) SetLoc(loc *Location) { + w.Location = loc +} + +func (w *With) setJSONOptions(opts astJSON.Options) { + w.jsonOptions = opts + if w.Location != nil { + w.Location.JSONOptions = opts + } +} + +func (w *With) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "target": w.Target, + "value": w.Value, + } + + if w.jsonOptions.MarshalOptions.IncludeLocation.With { + if w.Location != nil { + data["location"] = w.Location + } + } + + return json.Marshal(data) +} + +// Copy returns a deep copy of the AST node x. If x is not an AST node, x is returned unmodified. +func Copy(x interface{}) interface{} { + switch x := x.(type) { + case *Module: + return x.Copy() + case *Package: + return x.Copy() + case *Import: + return x.Copy() + case *Rule: + return x.Copy() + case *Head: + return x.Copy() + case Args: + return x.Copy() + case Body: + return x.Copy() + case *Expr: + return x.Copy() + case *With: + return x.Copy() + case *SomeDecl: + return x.Copy() + case *Every: + return x.Copy() + case *Term: + return x.Copy() + case *ArrayComprehension: + return x.Copy() + case *SetComprehension: + return x.Copy() + case *ObjectComprehension: + return x.Copy() + case Set: + return x.Copy() + case *object: + return x.Copy() + case *Array: + return x.Copy() + case Ref: + return x.Copy() + case Call: + return x.Copy() + case *Comment: + return x.Copy() + } + return x +} + +// RuleSet represents a collection of rules that produce a virtual document. +type RuleSet []*Rule + +// NewRuleSet returns a new RuleSet containing the given rules. +func NewRuleSet(rules ...*Rule) RuleSet { + rs := make(RuleSet, 0, len(rules)) + for _, rule := range rules { + rs.Add(rule) + } + return rs +} + +// Add inserts the rule into rs. +func (rs *RuleSet) Add(rule *Rule) { + for _, exist := range *rs { + if exist.Equal(rule) { + return + } + } + *rs = append(*rs, rule) +} + +// Contains returns true if rs contains rule. +func (rs RuleSet) Contains(rule *Rule) bool { + for i := range rs { + if rs[i].Equal(rule) { + return true + } + } + return false +} + +// Diff returns a new RuleSet containing rules in rs that are not in other. +func (rs RuleSet) Diff(other RuleSet) RuleSet { + result := NewRuleSet() + for i := range rs { + if !other.Contains(rs[i]) { + result.Add(rs[i]) + } + } + return result +} + +// Equal returns true if rs equals other. +func (rs RuleSet) Equal(other RuleSet) bool { + return len(rs.Diff(other)) == 0 && len(other.Diff(rs)) == 0 +} + +// Merge returns a ruleset containing the union of rules from rs an other. +func (rs RuleSet) Merge(other RuleSet) RuleSet { + result := NewRuleSet() + for i := range rs { + result.Add(rs[i]) + } + for i := range other { + result.Add(other[i]) + } + return result +} + +func (rs RuleSet) String() string { + buf := make([]string, 0, len(rs)) + for _, rule := range rs { + buf = append(buf, rule.String()) + } + return "{" + strings.Join(buf, ", ") + "}" +} + +// Returns true if the equality or assignment expression referred to by expr +// has a valid number of arguments. +func validEqAssignArgCount(expr *Expr) bool { + return len(expr.Operands()) == 2 +} + +// this function checks if the expr refers to a non-namespaced (global) built-in +// function like eq, gt, plus, etc. +func isGlobalBuiltin(expr *Expr, name Var) bool { + terms, ok := expr.Terms.([]*Term) + if !ok { + return false + } + + // NOTE(tsandall): do not use Term#Equal or Value#Compare to avoid + // allocation here. + ref, ok := terms[0].Value.(Ref) + if !ok || len(ref) != 1 { + return false + } + if head, ok := ref[0].Value.(Var); ok { + return head.Equal(name) + } + return false +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/pretty.go b/vendor/github.com/open-policy-agent/opa/v1/ast/pretty.go new file mode 100644 index 000000000..b4f05ad50 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/pretty.go @@ -0,0 +1,82 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "io" + "strings" +) + +// Pretty writes a pretty representation of the AST rooted at x to w. +// +// This is function is intended for debug purposes when inspecting ASTs. +func Pretty(w io.Writer, x interface{}) { + pp := &prettyPrinter{ + depth: -1, + w: w, + } + NewBeforeAfterVisitor(pp.Before, pp.After).Walk(x) +} + +type prettyPrinter struct { + depth int + w io.Writer +} + +func (pp *prettyPrinter) Before(x interface{}) bool { + switch x.(type) { + case *Term: + default: + pp.depth++ + } + + switch x := x.(type) { + case *Term: + return false + case Args: + if len(x) == 0 { + return false + } + pp.writeType(x) + case *Expr: + extras := []string{} + if x.Negated { + extras = append(extras, "negated") + } + extras = append(extras, fmt.Sprintf("index=%d", x.Index)) + pp.writeIndent("%v %v", TypeName(x), strings.Join(extras, " ")) + case Null, Boolean, Number, String, Var: + pp.writeValue(x) + default: + pp.writeType(x) + } + return false +} + +func (pp *prettyPrinter) After(x interface{}) { + switch x.(type) { + case *Term: + default: + pp.depth-- + } +} + +func (pp *prettyPrinter) writeValue(x interface{}) { + pp.writeIndent(fmt.Sprint(x)) +} + +func (pp *prettyPrinter) writeType(x interface{}) { + pp.writeIndent(TypeName(x)) +} + +func (pp *prettyPrinter) writeIndent(f string, a ...interface{}) { + pad := strings.Repeat(" ", pp.depth) + pp.write(pad+f, a...) +} + +func (pp *prettyPrinter) write(f string, a ...interface{}) { + fmt.Fprintf(pp.w, f+"\n", a...) +} diff --git a/vendor/github.com/open-policy-agent/opa/ast/rego_v1.go b/vendor/github.com/open-policy-agent/opa/v1/ast/rego_v1.go similarity index 96% rename from vendor/github.com/open-policy-agent/opa/ast/rego_v1.go rename to vendor/github.com/open-policy-agent/opa/v1/ast/rego_v1.go index b64dfce7b..8b757ecc3 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/rego_v1.go +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/rego_v1.go @@ -3,7 +3,7 @@ package ast import ( "fmt" - "github.com/open-policy-agent/opa/ast/internal/tokens" + "github.com/open-policy-agent/opa/v1/ast/internal/tokens" ) func checkDuplicateImports(modules []*Module) (errors Errors) { @@ -192,7 +192,7 @@ func checkRegoV1Rule(rule *Rule, opts RegoCheckOptions) Errors { var errs Errors if opts.NoKeywordsAsRuleNames && IsKeywordInRegoVersion(rule.Head.Name.String(), RegoV1) { - errs = append(errs, NewError(ParseErr, rule.Location, fmt.Sprintf("%s keyword cannot be used for rule name", rule.Head.Name.String()))) + errs = append(errs, NewError(ParseErr, rule.Location, "%s keyword cannot be used for rule name", rule.Head.Name.String())) } if opts.RequireRuleBodyOrValue && rule.generatedBody && rule.Head.generatedValue { errs = append(errs, NewError(ParseErr, rule.Location, "%s must have value assignment and/or body declaration", t)) diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/schema.go b/vendor/github.com/open-policy-agent/opa/v1/ast/schema.go new file mode 100644 index 000000000..e84a147a4 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/schema.go @@ -0,0 +1,63 @@ +// Copyright 2021 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + + "github.com/open-policy-agent/opa/v1/types" + "github.com/open-policy-agent/opa/v1/util" +) + +// SchemaSet holds a map from a path to a schema. +type SchemaSet struct { + m *util.HashMap +} + +// NewSchemaSet returns an empty SchemaSet. +func NewSchemaSet() *SchemaSet { + + eqFunc := func(a, b util.T) bool { + return a.(Ref).Equal(b.(Ref)) + } + + hashFunc := func(x util.T) int { return x.(Ref).Hash() } + + return &SchemaSet{ + m: util.NewHashMap(eqFunc, hashFunc), + } +} + +// Put inserts a raw schema into the set. +func (ss *SchemaSet) Put(path Ref, raw interface{}) { + ss.m.Put(path, raw) +} + +// Get returns the raw schema identified by the path. +func (ss *SchemaSet) Get(path Ref) interface{} { + if ss == nil { + return nil + } + x, ok := ss.m.Get(path) + if !ok { + return nil + } + return x +} + +func loadSchema(raw interface{}, allowNet []string) (types.Type, error) { + + jsonSchema, err := compileSchema(raw, allowNet) + if err != nil { + return nil, err + } + + tpe, err := newSchemaParser().parseSchema(jsonSchema.RootSchema) + if err != nil { + return nil, fmt.Errorf("type checking: %w", err) + } + + return tpe, nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/strings.go b/vendor/github.com/open-policy-agent/opa/v1/ast/strings.go new file mode 100644 index 000000000..e489f6977 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/strings.go @@ -0,0 +1,18 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "reflect" + "strings" +) + +// TypeName returns a human readable name for the AST element type. +func TypeName(x interface{}) string { + if _, ok := x.(*lazyObj); ok { + return "object" + } + return strings.ToLower(reflect.Indirect(reflect.ValueOf(x)).Type().Name()) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/term.go b/vendor/github.com/open-policy-agent/opa/v1/ast/term.go new file mode 100644 index 000000000..d79f4418b --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/term.go @@ -0,0 +1,3358 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// nolint: deadcode // Public API. +package ast + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math" + "math/big" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + "sync" + + "github.com/OneOfOne/xxhash" + + astJSON "github.com/open-policy-agent/opa/v1/ast/json" + "github.com/open-policy-agent/opa/v1/ast/location" + "github.com/open-policy-agent/opa/v1/util" +) + +var errFindNotFound = fmt.Errorf("find: not found") + +// Location records a position in source code. +type Location = location.Location + +// NewLocation returns a new Location object. +func NewLocation(text []byte, file string, row int, col int) *Location { + return location.NewLocation(text, file, row, col) +} + +// Value declares the common interface for all Term values. Every kind of Term value +// in the language is represented as a type that implements this interface: +// +// - Null, Boolean, Number, String +// - Object, Array, Set +// - Variables, References +// - Array, Set, and Object Comprehensions +// - Calls +type Value interface { + Compare(other Value) int // Compare returns <0, 0, or >0 if this Value is less than, equal to, or greater than other, respectively. + Find(path Ref) (Value, error) // Find returns value referred to by path or an error if path is not found. + Hash() int // Returns hash code of the value. + IsGround() bool // IsGround returns true if this value is not a variable or contains no variables. + String() string // String returns a human readable string representation of the value. +} + +// InterfaceToValue converts a native Go value x to a Value. +func InterfaceToValue(x interface{}) (Value, error) { + switch x := x.(type) { + case nil: + return Null{}, nil + case bool: + return Boolean(x), nil + case json.Number: + return Number(x), nil + case int64: + return int64Number(x), nil + case uint64: + return uint64Number(x), nil + case float64: + return floatNumber(x), nil + case int: + return intNumber(x), nil + case string: + return String(x), nil + case []any: + r := util.NewPtrSlice[Term](len(x)) + for i, e := range x { + e, err := InterfaceToValue(e) + if err != nil { + return nil, err + } + r[i].Value = e + } + return NewArray(r...), nil + case map[string]any: + kvs := util.NewPtrSlice[Term](len(x) * 2) + idx := 0 + for k, v := range x { + k, err := InterfaceToValue(k) + if err != nil { + return nil, err + } + kvs[idx].Value = k + v, err := InterfaceToValue(v) + if err != nil { + return nil, err + } + kvs[idx+1].Value = v + idx += 2 + } + tuples := make([][2]*Term, len(kvs)/2) + for i := 0; i < len(kvs); i += 2 { + tuples[i/2] = *(*[2]*Term)(kvs[i : i+2]) + } + return NewObject(tuples...), nil + case map[string]string: + r := newobject(len(x)) + for k, v := range x { + k, err := InterfaceToValue(k) + if err != nil { + return nil, err + } + v, err := InterfaceToValue(v) + if err != nil { + return nil, err + } + r.Insert(NewTerm(k), NewTerm(v)) + } + return r, nil + default: + ptr := util.Reference(x) + if err := util.RoundTrip(ptr); err != nil { + return nil, fmt.Errorf("ast: interface conversion: %w", err) + } + return InterfaceToValue(*ptr) + } +} + +// ValueFromReader returns an AST value from a JSON serialized value in the reader. +func ValueFromReader(r io.Reader) (Value, error) { + var x interface{} + if err := util.NewJSONDecoder(r).Decode(&x); err != nil { + return nil, err + } + return InterfaceToValue(x) +} + +// As converts v into a Go native type referred to by x. +func As(v Value, x interface{}) error { + return util.NewJSONDecoder(bytes.NewBufferString(v.String())).Decode(x) +} + +// Resolver defines the interface for resolving references to native Go values. +type Resolver interface { + Resolve(Ref) (interface{}, error) +} + +// ValueResolver defines the interface for resolving references to AST values. +type ValueResolver interface { + Resolve(Ref) (Value, error) +} + +// UnknownValueErr indicates a ValueResolver was unable to resolve a reference +// because the reference refers to an unknown value. +type UnknownValueErr struct{} + +func (UnknownValueErr) Error() string { + return "unknown value" +} + +// IsUnknownValueErr returns true if the err is an UnknownValueErr. +func IsUnknownValueErr(err error) bool { + _, ok := err.(UnknownValueErr) + return ok +} + +type illegalResolver struct{} + +func (illegalResolver) Resolve(ref Ref) (interface{}, error) { + return nil, fmt.Errorf("illegal value: %v", ref) +} + +// ValueToInterface returns the Go representation of an AST value. The AST +// value should not contain any values that require evaluation (e.g., vars, +// comprehensions, etc.) +func ValueToInterface(v Value, resolver Resolver) (interface{}, error) { + return valueToInterface(v, resolver, JSONOpt{}) +} + +func valueToInterface(v Value, resolver Resolver, opt JSONOpt) (interface{}, error) { + switch v := v.(type) { + case Null: + return nil, nil + case Boolean: + return bool(v), nil + case Number: + return json.Number(v), nil + case String: + return string(v), nil + case *Array: + buf := []interface{}{} + for i := 0; i < v.Len(); i++ { + x1, err := valueToInterface(v.Elem(i).Value, resolver, opt) + if err != nil { + return nil, err + } + buf = append(buf, x1) + } + return buf, nil + case *object: + buf := make(map[string]interface{}, v.Len()) + err := v.Iter(func(k, v *Term) error { + ki, err := valueToInterface(k.Value, resolver, opt) + if err != nil { + return err + } + var str string + var ok bool + if str, ok = ki.(string); !ok { + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(ki); err != nil { + return err + } + str = strings.TrimSpace(buf.String()) + } + vi, err := valueToInterface(v.Value, resolver, opt) + if err != nil { + return err + } + buf[str] = vi + return nil + }) + if err != nil { + return nil, err + } + return buf, nil + case *lazyObj: + if opt.CopyMaps { + return valueToInterface(v.force(), resolver, opt) + } + return v.native, nil + case Set: + buf := []interface{}{} + iter := func(x *Term) error { + x1, err := valueToInterface(x.Value, resolver, opt) + if err != nil { + return err + } + buf = append(buf, x1) + return nil + } + var err error + if opt.SortSets { + err = v.Sorted().Iter(iter) + } else { + err = v.Iter(iter) + } + if err != nil { + return nil, err + } + return buf, nil + case Ref: + return resolver.Resolve(v) + default: + return nil, fmt.Errorf("%v requires evaluation", TypeName(v)) + } +} + +// JSON returns the JSON representation of v. The value must not contain any +// refs or terms that require evaluation (e.g., vars, comprehensions, etc.) +func JSON(v Value) (interface{}, error) { + return JSONWithOpt(v, JSONOpt{}) +} + +// JSONOpt defines parameters for AST to JSON conversion. +type JSONOpt struct { + SortSets bool // sort sets before serializing (this makes conversion more expensive) + CopyMaps bool // enforces copying of map[string]interface{} read from the store +} + +// JSONWithOpt returns the JSON representation of v. The value must not contain any +// refs or terms that require evaluation (e.g., vars, comprehensions, etc.) +func JSONWithOpt(v Value, opt JSONOpt) (interface{}, error) { + return valueToInterface(v, illegalResolver{}, opt) +} + +// MustJSON returns the JSON representation of v. The value must not contain any +// refs or terms that require evaluation (e.g., vars, comprehensions, etc.) If +// the conversion fails, this function will panic. This function is mostly for +// test purposes. +func MustJSON(v Value) interface{} { + r, err := JSON(v) + if err != nil { + panic(err) + } + return r +} + +// MustInterfaceToValue converts a native Go value x to a Value. If the +// conversion fails, this function will panic. This function is mostly for test +// purposes. +func MustInterfaceToValue(x interface{}) Value { + v, err := InterfaceToValue(x) + if err != nil { + panic(err) + } + return v +} + +// Term is an argument to a function. +type Term struct { + Value Value `json:"value"` // the value of the Term as represented in Go + Location *Location `json:"location,omitempty"` // the location of the Term in the source + + jsonOptions astJSON.Options +} + +// NewTerm returns a new Term object. +func NewTerm(v Value) *Term { + return &Term{ + Value: v, + } +} + +// SetLocation updates the term's Location and returns the term itself. +func (term *Term) SetLocation(loc *Location) *Term { + term.Location = loc + return term +} + +// Loc returns the Location of term. +func (term *Term) Loc() *Location { + if term == nil { + return nil + } + return term.Location +} + +// SetLoc sets the location on term. +func (term *Term) SetLoc(loc *Location) { + term.SetLocation(loc) +} + +// Copy returns a deep copy of term. +func (term *Term) Copy() *Term { + if term == nil { + return nil + } + + cpy := *term + + switch v := term.Value.(type) { + case Null, Boolean, Number, String, Var: + cpy.Value = v + case Ref: + cpy.Value = v.Copy() + case *Array: + cpy.Value = v.Copy() + case Set: + cpy.Value = v.Copy() + case *object: + cpy.Value = v.Copy() + case *ArrayComprehension: + cpy.Value = v.Copy() + case *ObjectComprehension: + cpy.Value = v.Copy() + case *SetComprehension: + cpy.Value = v.Copy() + case Call: + cpy.Value = v.Copy() + } + + return &cpy +} + +// Equal returns true if this term equals the other term. Equality is +// defined for each kind of term. +func (term *Term) Equal(other *Term) bool { + if term == nil && other != nil { + return false + } + if term != nil && other == nil { + return false + } + if term == other { + return true + } + + // TODO(tsandall): This early-exit avoids allocations for types that have + // Equal() functions that just use == underneath. We should revisit the + // other types and implement Equal() functions that do not require + // allocations. + switch v := term.Value.(type) { + case Null: + return v.Equal(other.Value) + case Boolean: + return v.Equal(other.Value) + case Number: + return v.Equal(other.Value) + case String: + return v.Equal(other.Value) + case Var: + return v.Equal(other.Value) + case Ref: + return v.Equal(other.Value) + case *Array: + return v.Equal(other.Value) + } + + return term.Value.Compare(other.Value) == 0 +} + +// Get returns a value referred to by name from the term. +func (term *Term) Get(name *Term) *Term { + switch v := term.Value.(type) { + case *object: + return v.Get(name) + case *Array: + return v.Get(name) + case interface { + Get(*Term) *Term + }: + return v.Get(name) + case Set: + if v.Contains(name) { + return name + } + } + return nil +} + +// Hash returns the hash code of the Term's Value. Its Location +// is ignored. +func (term *Term) Hash() int { + return term.Value.Hash() +} + +// IsGround returns true if this term's Value is ground. +func (term *Term) IsGround() bool { + return term.Value.IsGround() +} + +func (term *Term) setJSONOptions(opts astJSON.Options) { + term.jsonOptions = opts + if term.Location != nil { + term.Location.JSONOptions = opts + } +} + +// MarshalJSON returns the JSON encoding of the term. +// +// Specialized marshalling logic is required to include a type hint for Value. +func (term *Term) MarshalJSON() ([]byte, error) { + d := map[string]interface{}{ + "type": TypeName(term.Value), + "value": term.Value, + } + if term.jsonOptions.MarshalOptions.IncludeLocation.Term { + if term.Location != nil { + d["location"] = term.Location + } + } + return json.Marshal(d) +} + +func (term *Term) String() string { + return term.Value.String() +} + +// UnmarshalJSON parses the byte array and stores the result in term. +// Specialized unmarshalling is required to handle Value and Location. +func (term *Term) UnmarshalJSON(bs []byte) error { + v := map[string]interface{}{} + if err := util.UnmarshalJSON(bs, &v); err != nil { + return err + } + val, err := unmarshalValue(v) + if err != nil { + return err + } + term.Value = val + + if loc, ok := v["location"].(map[string]interface{}); ok { + term.Location = &Location{} + err := unmarshalLocation(term.Location, loc) + if err != nil { + return err + } + } + return nil +} + +// Vars returns a VarSet with variables contained in this term. +func (term *Term) Vars() VarSet { + vis := &VarVisitor{vars: VarSet{}} + vis.Walk(term) + return vis.vars +} + +// IsConstant returns true if the AST value is constant. +func IsConstant(v Value) bool { + found := false + vis := GenericVisitor{ + func(x interface{}) bool { + switch x.(type) { + case Var, Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Call: + found = true + return true + } + return false + }, + } + vis.Walk(v) + return !found +} + +// IsComprehension returns true if the supplied value is a comprehension. +func IsComprehension(x Value) bool { + switch x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension: + return true + } + return false +} + +// ContainsRefs returns true if the Value v contains refs. +func ContainsRefs(v interface{}) bool { + found := false + WalkRefs(v, func(Ref) bool { + found = true + return found + }) + return found +} + +// ContainsComprehensions returns true if the Value v contains comprehensions. +func ContainsComprehensions(v interface{}) bool { + found := false + WalkClosures(v, func(x interface{}) bool { + switch x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension: + found = true + return found + } + return found + }) + return found +} + +// ContainsClosures returns true if the Value v contains closures. +func ContainsClosures(v interface{}) bool { + found := false + WalkClosures(v, func(x interface{}) bool { + switch x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: + found = true + return found + } + return found + }) + return found +} + +// IsScalar returns true if the AST value is a scalar. +func IsScalar(v Value) bool { + switch v.(type) { + case String: + return true + case Number: + return true + case Boolean: + return true + case Null: + return true + } + return false +} + +// Null represents the null value defined by JSON. +type Null struct{} + +// NullTerm creates a new Term with a Null value. +func NullTerm() *Term { + return &Term{Value: Null{}} +} + +// Equal returns true if the other term Value is also Null. +func (null Null) Equal(other Value) bool { + switch other.(type) { + case Null: + return true + default: + return false + } +} + +// Compare compares null to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (null Null) Compare(other Value) int { + return Compare(null, other) +} + +// Find returns the current value or a not found error. +func (null Null) Find(path Ref) (Value, error) { + if len(path) == 0 { + return null, nil + } + return nil, errFindNotFound +} + +// Hash returns the hash code for the Value. +func (null Null) Hash() int { + return 0 +} + +// IsGround always returns true. +func (Null) IsGround() bool { + return true +} + +func (null Null) String() string { + return "null" +} + +// Boolean represents a boolean value defined by JSON. +type Boolean bool + +// BooleanTerm creates a new Term with a Boolean value. +func BooleanTerm(b bool) *Term { + return &Term{Value: Boolean(b)} +} + +// Equal returns true if the other Value is a Boolean and is equal. +func (bol Boolean) Equal(other Value) bool { + switch other := other.(type) { + case Boolean: + return bol == other + default: + return false + } +} + +// Compare compares bol to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (bol Boolean) Compare(other Value) int { + return Compare(bol, other) +} + +// Find returns the current value or a not found error. +func (bol Boolean) Find(path Ref) (Value, error) { + if len(path) == 0 { + return bol, nil + } + return nil, errFindNotFound +} + +// Hash returns the hash code for the Value. +func (bol Boolean) Hash() int { + if bol { + return 1 + } + return 0 +} + +// IsGround always returns true. +func (Boolean) IsGround() bool { + return true +} + +func (bol Boolean) String() string { + return strconv.FormatBool(bool(bol)) +} + +// Number represents a numeric value as defined by JSON. +type Number json.Number + +// NumberTerm creates a new Term with a Number value. +func NumberTerm(n json.Number) *Term { + return &Term{Value: Number(n)} +} + +// IntNumberTerm creates a new Term with an integer Number value. +func IntNumberTerm(i int) *Term { + return &Term{Value: Number(strconv.Itoa(i))} +} + +// UIntNumberTerm creates a new Term with an unsigned integer Number value. +func UIntNumberTerm(u uint64) *Term { + return &Term{Value: uint64Number(u)} +} + +// FloatNumberTerm creates a new Term with a floating point Number value. +func FloatNumberTerm(f float64) *Term { + s := strconv.FormatFloat(f, 'g', -1, 64) + return &Term{Value: Number(s)} +} + +// Equal returns true if the other Value is a Number and is equal. +func (num Number) Equal(other Value) bool { + switch other := other.(type) { + case Number: + n1, ok1 := num.Int64() + n2, ok2 := other.Int64() + if ok1 && ok2 && n1 == n2 { + return true + } + + return Compare(num, other) == 0 + default: + return false + } +} + +// Compare compares num to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (num Number) Compare(other Value) int { + return Compare(num, other) +} + +// Find returns the current value or a not found error. +func (num Number) Find(path Ref) (Value, error) { + if len(path) == 0 { + return num, nil + } + return nil, errFindNotFound +} + +// Hash returns the hash code for the Value. +func (num Number) Hash() int { + f, err := json.Number(num).Float64() + if err != nil { + bs := []byte(num) + h := xxhash.Checksum64(bs) + return int(h) + } + return int(f) +} + +// Int returns the int representation of num if possible. +func (num Number) Int() (int, bool) { + i64, ok := num.Int64() + return int(i64), ok +} + +// Int64 returns the int64 representation of num if possible. +func (num Number) Int64() (int64, bool) { + i, err := json.Number(num).Int64() + if err != nil { + return 0, false + } + return i, true +} + +// Float64 returns the float64 representation of num if possible. +func (num Number) Float64() (float64, bool) { + f, err := json.Number(num).Float64() + if err != nil { + return 0, false + } + return f, true +} + +// IsGround always returns true. +func (Number) IsGround() bool { + return true +} + +// MarshalJSON returns JSON encoded bytes representing num. +func (num Number) MarshalJSON() ([]byte, error) { + return json.Marshal(json.Number(num)) +} + +func (num Number) String() string { + return string(num) +} + +func intNumber(i int) Number { + return Number(strconv.Itoa(i)) +} + +func int64Number(i int64) Number { + return Number(strconv.FormatInt(i, 10)) +} + +func uint64Number(u uint64) Number { + return Number(strconv.FormatUint(u, 10)) +} + +func floatNumber(f float64) Number { + return Number(strconv.FormatFloat(f, 'g', -1, 64)) +} + +// String represents a string value as defined by JSON. +type String string + +// StringTerm creates a new Term with a String value. +func StringTerm(s string) *Term { + return &Term{Value: String(s)} +} + +// Equal returns true if the other Value is a String and is equal. +func (str String) Equal(other Value) bool { + switch other := other.(type) { + case String: + return str == other + default: + return false + } +} + +// Compare compares str to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (str String) Compare(other Value) int { + return Compare(str, other) +} + +// Find returns the current value or a not found error. +func (str String) Find(path Ref) (Value, error) { + if len(path) == 0 { + return str, nil + } + return nil, errFindNotFound +} + +// IsGround always returns true. +func (String) IsGround() bool { + return true +} + +func (str String) String() string { + return strconv.Quote(string(str)) +} + +// Hash returns the hash code for the Value. +func (str String) Hash() int { + h := xxhash.ChecksumString64S(string(str), hashSeed0) + return int(h) +} + +// Var represents a variable as defined by the language. +type Var string + +// VarTerm creates a new Term with a Variable value. +func VarTerm(v string) *Term { + return &Term{Value: Var(v)} +} + +// Equal returns true if the other Value is a Variable and has the same value +// (name). +func (v Var) Equal(other Value) bool { + switch other := other.(type) { + case Var: + return v == other + default: + return false + } +} + +// Compare compares v to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (v Var) Compare(other Value) int { + return Compare(v, other) +} + +// Find returns the current value or a not found error. +func (v Var) Find(path Ref) (Value, error) { + if len(path) == 0 { + return v, nil + } + return nil, errFindNotFound +} + +// Hash returns the hash code for the Value. +func (v Var) Hash() int { + h := xxhash.ChecksumString64S(string(v), hashSeed0) + return int(h) +} + +// IsGround always returns false. +func (Var) IsGround() bool { + return false +} + +// IsWildcard returns true if this is a wildcard variable. +func (v Var) IsWildcard() bool { + return strings.HasPrefix(string(v), WildcardPrefix) +} + +// IsGenerated returns true if this variable was generated during compilation. +func (v Var) IsGenerated() bool { + return strings.HasPrefix(string(v), "__local") +} + +func (v Var) String() string { + // Special case for wildcard so that string representation is parseable. The + // parser mangles wildcard variables to make their names unique and uses an + // illegal variable name character (WildcardPrefix) to avoid conflicts. When + // we serialize the variable here, we need to make sure it's parseable. + if v.IsWildcard() { + return Wildcard.String() + } + return string(v) +} + +// Ref represents a reference as defined by the language. +type Ref []*Term + +// EmptyRef returns a new, empty reference. +func EmptyRef() Ref { + return Ref([]*Term{}) +} + +// PtrRef returns a new reference against the head for the pointer +// s. Path components in the pointer are unescaped. +func PtrRef(head *Term, s string) (Ref, error) { + s = strings.Trim(s, "/") + if s == "" { + return Ref{head}, nil + } + parts := strings.Split(s, "/") + if maxLen := math.MaxInt32; len(parts) >= maxLen { + return nil, fmt.Errorf("path too long: %s, %d > %d (max)", s, len(parts), maxLen) + } + ref := make(Ref, uint(len(parts))+1) + ref[0] = head + for i := 0; i < len(parts); i++ { + var err error + parts[i], err = url.PathUnescape(parts[i]) + if err != nil { + return nil, err + } + ref[i+1] = StringTerm(parts[i]) + } + return ref, nil +} + +// RefTerm creates a new Term with a Ref value. +func RefTerm(r ...*Term) *Term { + return &Term{Value: Ref(r)} +} + +// Append returns a copy of ref with the term appended to the end. +func (ref Ref) Append(term *Term) Ref { + n := len(ref) + dst := make(Ref, n+1) + copy(dst, ref) + dst[n] = term + return dst +} + +// Insert returns a copy of the ref with x inserted at pos. If pos < len(ref), +// existing elements are shifted to the right. If pos > len(ref)+1 this +// function panics. +func (ref Ref) Insert(x *Term, pos int) Ref { + switch { + case pos == len(ref): + return ref.Append(x) + case pos > len(ref)+1: + panic("illegal index") + } + cpy := make(Ref, len(ref)+1) + copy(cpy, ref[:pos]) + cpy[pos] = x + copy(cpy[pos+1:], ref[pos:]) + return cpy +} + +// Extend returns a copy of ref with the terms from other appended. The head of +// other will be converted to a string. +func (ref Ref) Extend(other Ref) Ref { + dst := make(Ref, len(ref)+len(other)) + copy(dst, ref) + + head := other[0].Copy() + head.Value = String(head.Value.(Var)) + offset := len(ref) + dst[offset] = head + + copy(dst[offset+1:], other[1:]) + return dst +} + +// Concat returns a ref with the terms appended. +func (ref Ref) Concat(terms []*Term) Ref { + if len(terms) == 0 { + return ref + } + cpy := make(Ref, len(ref)+len(terms)) + copy(cpy, ref) + copy(cpy[len(ref):], terms) + return cpy +} + +// Dynamic returns the offset of the first non-constant operand of ref. +func (ref Ref) Dynamic() int { + switch ref[0].Value.(type) { + case Call: + return 0 + } + for i := 1; i < len(ref); i++ { + if !IsConstant(ref[i].Value) { + return i + } + } + return -1 +} + +// Copy returns a deep copy of ref. +func (ref Ref) Copy() Ref { + return termSliceCopy(ref) +} + +// Equal returns true if ref is equal to other. +func (ref Ref) Equal(other Value) bool { + switch o := other.(type) { + case Ref: + if len(ref) == len(o) { + for i := range ref { + if !ref[i].Equal(o[i]) { + return false + } + } + + return true + } + } + + return false +} + +// Compare compares ref to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (ref Ref) Compare(other Value) int { + return Compare(ref, other) +} + +// Find returns the current value or a "not found" error. +func (ref Ref) Find(path Ref) (Value, error) { + if len(path) == 0 { + return ref, nil + } + return nil, errFindNotFound +} + +// Hash returns the hash code for the Value. +func (ref Ref) Hash() int { + return termSliceHash(ref) +} + +// HasPrefix returns true if the other ref is a prefix of this ref. +func (ref Ref) HasPrefix(other Ref) bool { + if len(other) > len(ref) { + return false + } + for i := range other { + if !ref[i].Equal(other[i]) { + return false + } + } + return true +} + +// ConstantPrefix returns the constant portion of the ref starting from the head. +func (ref Ref) ConstantPrefix() Ref { + ref = ref.Copy() + + i := ref.Dynamic() + if i < 0 { + return ref + } + return ref[:i] +} + +func (ref Ref) StringPrefix() Ref { + r := ref.Copy() + + for i := 1; i < len(ref); i++ { + switch r[i].Value.(type) { + case String: // pass + default: // cut off + return r[:i] + } + } + + return r +} + +// GroundPrefix returns the ground portion of the ref starting from the head. By +// definition, the head of the reference is always ground. +func (ref Ref) GroundPrefix() Ref { + prefix := make(Ref, 0, len(ref)) + + for i, x := range ref { + if i > 0 && !x.IsGround() { + break + } + prefix = append(prefix, x) + } + + return prefix +} + +func (ref Ref) DynamicSuffix() Ref { + i := ref.Dynamic() + if i < 0 { + return nil + } + return ref[i:] +} + +// IsGround returns true if all of the parts of the Ref are ground. +func (ref Ref) IsGround() bool { + if len(ref) == 0 { + return true + } + return termSliceIsGround(ref[1:]) +} + +// IsNested returns true if this ref contains other Refs. +func (ref Ref) IsNested() bool { + for _, x := range ref { + if _, ok := x.Value.(Ref); ok { + return true + } + } + return false +} + +// Ptr returns a slash-separated path string for this ref. If the ref +// contains non-string terms this function returns an error. Path +// components are escaped. +func (ref Ref) Ptr() (string, error) { + parts := make([]string, 0, len(ref)-1) + for _, term := range ref[1:] { + if str, ok := term.Value.(String); ok { + parts = append(parts, url.PathEscape(string(str))) + } else { + return "", fmt.Errorf("invalid path value type") + } + } + return strings.Join(parts, "/"), nil +} + +var varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$") + +func IsVarCompatibleString(s string) bool { + return varRegexp.MatchString(s) +} + +var sbPool = sync.Pool{ + New: func() any { + return &strings.Builder{} + }, +} + +func (ref Ref) String() string { + if len(ref) == 0 { + return "" + } + + sb := sbPool.Get().(*strings.Builder) + sb.Reset() + + defer sbPool.Put(sb) + + sb.Grow(10 * len(ref)) + + sb.WriteString(ref[0].Value.String()) + + for _, p := range ref[1:] { + switch p := p.Value.(type) { + case String: + str := string(p) + if varRegexp.MatchString(str) && !IsKeyword(str) { + sb.WriteByte('.') + sb.WriteString(str) + } else { + sb.WriteString(`["`) + sb.WriteString(str) + sb.WriteString(`"]`) + } + default: + sb.WriteByte('[') + sb.WriteString(p.String()) + sb.WriteByte(']') + } + } + + return sb.String() +} + +// OutputVars returns a VarSet containing variables that would be bound by evaluating +// this expression in isolation. +func (ref Ref) OutputVars() VarSet { + vis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true}) + vis.Walk(ref) + return vis.Vars() +} + +func (ref Ref) toArray() *Array { + a := NewArray() + for _, term := range ref { + if _, ok := term.Value.(String); ok { + a = a.Append(term) + } else { + a = a.Append(StringTerm(term.Value.String())) + } + } + return a +} + +// QueryIterator defines the interface for querying AST documents with references. +type QueryIterator func(map[Var]Value, Value) error + +// ArrayTerm creates a new Term with an Array value. +func ArrayTerm(a ...*Term) *Term { + return NewTerm(NewArray(a...)) +} + +// NewArray creates an Array with the terms provided. The array will +// use the provided term slice. +func NewArray(a ...*Term) *Array { + hs := make([]int, len(a)) + for i, e := range a { + hs[i] = e.Value.Hash() + } + arr := &Array{elems: a, hashs: hs, ground: termSliceIsGround(a)} + arr.rehash() + return arr +} + +// Array represents an array as defined by the language. Arrays are similar to the +// same types as defined by JSON with the exception that they can contain Vars +// and References. +type Array struct { + elems []*Term + hashs []int // element hashes + hash int + ground bool +} + +// Copy returns a deep copy of arr. +func (arr *Array) Copy() *Array { + cpy := make([]int, len(arr.elems)) + copy(cpy, arr.hashs) + return &Array{ + elems: termSliceCopy(arr.elems), + hashs: cpy, + hash: arr.hash, + ground: arr.IsGround()} +} + +// Equal returns true if arr is equal to other. +func (arr *Array) Equal(other Value) bool { + if arr == other { + return true + } + + if other, ok := other.(*Array); ok && len(arr.elems) == len(other.elems) { + for i := range arr.elems { + if !arr.elems[i].Equal(other.elems[i]) { + return false + } + } + return true + } + + return false +} + +// Compare compares arr to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (arr *Array) Compare(other Value) int { + return Compare(arr, other) +} + +// Find returns the value at the index or an out-of-range error. +func (arr *Array) Find(path Ref) (Value, error) { + if len(path) == 0 { + return arr, nil + } + num, ok := path[0].Value.(Number) + if !ok { + return nil, errFindNotFound + } + i, ok := num.Int() + if !ok { + return nil, errFindNotFound + } + if i < 0 || i >= arr.Len() { + return nil, errFindNotFound + } + return arr.Elem(i).Value.Find(path[1:]) +} + +// Get returns the element at pos or nil if not possible. +func (arr *Array) Get(pos *Term) *Term { + num, ok := pos.Value.(Number) + if !ok { + return nil + } + + i, ok := num.Int() + if !ok { + return nil + } + + if i >= 0 && i < len(arr.elems) { + return arr.elems[i] + } + + return nil +} + +// Sorted returns a new Array that contains the sorted elements of arr. +func (arr *Array) Sorted() *Array { + cpy := make([]*Term, len(arr.elems)) + for i := range cpy { + cpy[i] = arr.elems[i] + } + sort.Sort(termSlice(cpy)) + a := NewArray(cpy...) + a.hashs = arr.hashs + return a +} + +// Hash returns the hash code for the Value. +func (arr *Array) Hash() int { + return arr.hash +} + +// IsGround returns true if all of the Array elements are ground. +func (arr *Array) IsGround() bool { + return arr.ground +} + +// MarshalJSON returns JSON encoded bytes representing arr. +func (arr *Array) MarshalJSON() ([]byte, error) { + if len(arr.elems) == 0 { + return []byte(`[]`), nil + } + return json.Marshal(arr.elems) +} + +func (arr *Array) String() string { + sb := sbPool.Get().(*strings.Builder) + sb.Reset() + sb.Grow(len(arr.elems) * 16) + + defer sbPool.Put(sb) + + sb.WriteRune('[') + for i, e := range arr.elems { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(e.String()) + } + sb.WriteRune(']') + + return sb.String() +} + +// Len returns the number of elements in the array. +func (arr *Array) Len() int { + return len(arr.elems) +} + +// Elem returns the element i of arr. +func (arr *Array) Elem(i int) *Term { + return arr.elems[i] +} + +// Set sets the element i of arr. +func (arr *Array) Set(i int, v *Term) { + arr.set(i, v) +} + +// rehash updates the cached hash of arr. +func (arr *Array) rehash() { + arr.hash = 0 + for _, h := range arr.hashs { + arr.hash += h + } +} + +// set sets the element i of arr. +func (arr *Array) set(i int, v *Term) { + arr.ground = arr.ground && v.IsGround() + arr.elems[i] = v + arr.hashs[i] = v.Value.Hash() + arr.rehash() +} + +// Slice returns a slice of arr starting from i index to j. -1 +// indicates the end of the array. The returned value array is not a +// copy and any modifications to either of arrays may be reflected to +// the other. +func (arr *Array) Slice(i, j int) *Array { + var elems []*Term + var hashs []int + if j == -1 { + elems = arr.elems[i:] + hashs = arr.hashs[i:] + } else { + elems = arr.elems[i:j] + hashs = arr.hashs[i:j] + } + // If arr is ground, the slice is, too. + // If it's not, the slice could still be. + gr := arr.ground || termSliceIsGround(elems) + + s := &Array{elems: elems, hashs: hashs, ground: gr} + s.rehash() + return s +} + +// Iter calls f on each element in arr. If f returns an error, +// iteration stops and the return value is the error. +func (arr *Array) Iter(f func(*Term) error) error { + for i := range arr.elems { + if err := f(arr.elems[i]); err != nil { + return err + } + } + return nil +} + +// Until calls f on each element in arr. If f returns true, iteration stops. +func (arr *Array) Until(f func(*Term) bool) bool { + for _, term := range arr.elems { + if f(term) { + return true + } + } + return false +} + +// Foreach calls f on each element in arr. +func (arr *Array) Foreach(f func(*Term)) { + for _, term := range arr.elems { + f(term) + } +} + +// Append appends a term to arr, returning the appended array. +func (arr *Array) Append(v *Term) *Array { + cpy := *arr + cpy.elems = append(arr.elems, v) + cpy.hashs = append(arr.hashs, v.Value.Hash()) + cpy.hash = arr.hash + v.Value.Hash() + cpy.ground = arr.ground && v.IsGround() + return &cpy +} + +// Set represents a set as defined by the language. +type Set interface { + Value + Len() int + Copy() Set + Diff(Set) Set + Intersect(Set) Set + Union(Set) Set + Add(*Term) + Iter(func(*Term) error) error + Until(func(*Term) bool) bool + Foreach(func(*Term)) + Contains(*Term) bool + Map(func(*Term) (*Term, error)) (Set, error) + Reduce(*Term, func(*Term, *Term) (*Term, error)) (*Term, error) + Sorted() *Array + Slice() []*Term +} + +// NewSet returns a new Set containing t. +func NewSet(t ...*Term) Set { + s := newset(len(t)) + for _, term := range t { + s.insert(term, false) + } + return s +} + +func newset(n int) *set { + var keys []*Term + if n > 0 { + keys = make([]*Term, 0, n) + } + return &set{ + elems: make(map[int]*Term, n), + keys: keys, + hash: 0, + ground: true, + sortGuard: new(sync.Once), + } +} + +// SetTerm returns a new Term representing a set containing terms t. +func SetTerm(t ...*Term) *Term { + set := NewSet(t...) + return &Term{ + Value: set, + } +} + +type set struct { + elems map[int]*Term + keys []*Term + hash int + ground bool + sortGuard *sync.Once // Prevents race condition around sorting. +} + +// Copy returns a deep copy of s. +func (s *set) Copy() Set { + terms := make([]*Term, len(s.keys)) + for i := range s.keys { + terms[i] = s.keys[i].Copy() + } + cpy := NewSet(terms...).(*set) + cpy.hash = s.hash + cpy.ground = s.ground + return cpy +} + +// IsGround returns true if all terms in s are ground. +func (s *set) IsGround() bool { + return s.ground +} + +// Hash returns a hash code for s. +func (s *set) Hash() int { + return s.hash +} + +func (s *set) String() string { + if s.Len() == 0 { + return "set()" + } + + sb := sbPool.Get().(*strings.Builder) + sb.Reset() + sb.Grow(s.Len() * 16) + + defer sbPool.Put(sb) + + sb.WriteRune('{') + for i := range s.sortedKeys() { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(s.keys[i].Value.String()) + } + sb.WriteRune('}') + + return sb.String() +} + +func (s *set) sortedKeys() []*Term { + s.sortGuard.Do(func() { + sort.Sort(termSlice(s.keys)) + }) + return s.keys +} + +// Compare compares s to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (s *set) Compare(other Value) int { + o1 := sortOrder(s) + o2 := sortOrder(other) + if o1 < o2 { + return -1 + } else if o1 > o2 { + return 1 + } + t := other.(*set) + return termSliceCompare(s.sortedKeys(), t.sortedKeys()) +} + +// Find returns the set or dereferences the element itself. +func (s *set) Find(path Ref) (Value, error) { + if len(path) == 0 { + return s, nil + } + if !s.Contains(path[0]) { + return nil, errFindNotFound + } + return path[0].Value.Find(path[1:]) +} + +// Diff returns elements in s that are not in other. +func (s *set) Diff(other Set) Set { + terms := make([]*Term, 0, len(s.keys)) + for _, term := range s.sortedKeys() { + if !other.Contains(term) { + terms = append(terms, term) + } + } + + return NewSet(terms...) +} + +// Intersect returns the set containing elements in both s and other. +func (s *set) Intersect(other Set) Set { + o := other.(*set) + n, m := s.Len(), o.Len() + ss := s + so := o + if m < n { + ss = o + so = s + n = m + } + + terms := make([]*Term, 0, n) + for _, term := range ss.sortedKeys() { + if so.Contains(term) { + terms = append(terms, term) + } + } + + return NewSet(terms...) +} + +// Union returns the set containing all elements of s and other. +func (s *set) Union(other Set) Set { + r := NewSet() + s.Foreach(r.Add) + other.Foreach(r.Add) + return r +} + +// Add updates s to include t. +func (s *set) Add(t *Term) { + s.insert(t, true) +} + +// Iter calls f on each element in s. If f returns an error, iteration stops +// and the return value is the error. +func (s *set) Iter(f func(*Term) error) error { + for _, term := range s.sortedKeys() { + if err := f(term); err != nil { + return err + } + } + return nil +} + +// Until calls f on each element in s. If f returns true, iteration stops. +func (s *set) Until(f func(*Term) bool) bool { + for _, term := range s.sortedKeys() { + if f(term) { + return true + } + } + return false +} + +// Foreach calls f on each element in s. +func (s *set) Foreach(f func(*Term)) { + for _, term := range s.sortedKeys() { + f(term) + } +} + +// Map returns a new Set obtained by applying f to each value in s. +func (s *set) Map(f func(*Term) (*Term, error)) (Set, error) { + mapped := make([]*Term, 0, len(s.keys)) + for _, x := range s.sortedKeys() { + term, err := f(x) + if err != nil { + return nil, err + } + mapped = append(mapped, term) + } + return NewSet(mapped...), nil +} + +// Reduce returns a Term produced by applying f to each value in s. The first +// argument to f is the reduced value (starting with i) and the second argument +// to f is the element in s. +func (s *set) Reduce(i *Term, f func(*Term, *Term) (*Term, error)) (*Term, error) { + err := s.Iter(func(x *Term) error { + var err error + i, err = f(i, x) + if err != nil { + return err + } + return nil + }) + return i, err +} + +// Contains returns true if t is in s. +func (s *set) Contains(t *Term) bool { + return s.get(t) != nil +} + +// Len returns the number of elements in the set. +func (s *set) Len() int { + return len(s.keys) +} + +// MarshalJSON returns JSON encoded bytes representing s. +func (s *set) MarshalJSON() ([]byte, error) { + if s.keys == nil { + return []byte(`[]`), nil + } + return json.Marshal(s.sortedKeys()) +} + +// Sorted returns an Array that contains the sorted elements of s. +func (s *set) Sorted() *Array { + cpy := make([]*Term, len(s.keys)) + copy(cpy, s.sortedKeys()) + return NewArray(cpy...) +} + +// Slice returns a slice of terms contained in the set. +func (s *set) Slice() []*Term { + return s.sortedKeys() +} + +// Internal method to use for cases where a set may be reused in favor +// of creating a new one (with the associated allocations). +func (s *set) clear() { + clear(s.elems) + s.keys = s.keys[:0] + s.hash = 0 + s.ground = true + s.sortGuard = new(sync.Once) +} + +func (s *set) insertNoGuard(x *Term) { + s.insert(x, false) +} + +// NOTE(philipc): We assume a many-readers, single-writer model here. +// This method should NOT be used concurrently, or else we risk data races. +func (s *set) insert(x *Term, resetSortGuard bool) { + hash := x.Hash() + insertHash := hash + // This `equal` utility is duplicated and manually inlined a number of + // time in this file. Inlining it avoids heap allocations, so it makes + // a big performance difference: some operations like lookup become twice + // as slow without it. + var equal func(v Value) bool + + switch x := x.Value.(type) { + case Null, Boolean, String, Var: + equal = func(y Value) bool { return x == y } + case Number: + if xi, err := json.Number(x).Int64(); err == nil { + equal = func(y Value) bool { + if y, ok := y.(Number); ok { + if yi, err := json.Number(y).Int64(); err == nil { + return xi == yi + } + } + + return false + } + break + } + + // We use big.Rat for comparing big numbers. + // It replaces big.Float due to following reason: + // big.Float comes with a default precision of 64, and setting a + // larger precision results in more memory being allocated + // (regardless of the actual number we are parsing with SetString). + // + // Note: If we're so close to zero that big.Float says we are zero, do + // *not* big.Rat).SetString on the original string it'll potentially + // take very long. + var a *big.Rat + fa, ok := new(big.Float).SetString(string(x)) + if !ok { + panic("illegal value") + } + if fa.IsInt() { + if i, _ := fa.Int64(); i == 0 { + a = new(big.Rat).SetInt64(0) + } + } + if a == nil { + a, ok = new(big.Rat).SetString(string(x)) + if !ok { + panic("illegal value") + } + } + + equal = func(b Value) bool { + if bNum, ok := b.(Number); ok { + var b *big.Rat + fb, ok := new(big.Float).SetString(string(bNum)) + if !ok { + panic("illegal value") + } + if fb.IsInt() { + if i, _ := fb.Int64(); i == 0 { + b = new(big.Rat).SetInt64(0) + } + } + if b == nil { + b, ok = new(big.Rat).SetString(string(bNum)) + if !ok { + panic("illegal value") + } + } + + return a.Cmp(b) == 0 + } + + return false + } + default: + equal = func(y Value) bool { return Compare(x, y) == 0 } + } + + for curr, ok := s.elems[insertHash]; ok; { + if equal(curr.Value) { + return + } + + insertHash++ + curr, ok = s.elems[insertHash] + } + + s.elems[insertHash] = x + // O(1) insertion, but we'll have to re-sort the keys later. + s.keys = append(s.keys, x) + + if resetSortGuard { + // Reset the sync.Once instance. + // See https://github.com/golang/go/issues/25955 for why we do it this way. + // Note that this will always be the case when external code calls insert via + // Add, or otherwise. Internal code may however benefit from not having to + // re-create this pointer when it's known not to be needed. + s.sortGuard = new(sync.Once) + } + + s.hash += hash + s.ground = s.ground && x.IsGround() +} + +func (s *set) get(x *Term) *Term { + hash := x.Hash() + // This `equal` utility is duplicated and manually inlined a number of + // time in this file. Inlining it avoids heap allocations, so it makes + // a big performance difference: some operations like lookup become twice + // as slow without it. + var equal func(v Value) bool + + switch x := x.Value.(type) { + case Null, Boolean, String, Var: + equal = func(y Value) bool { return x == y } + case Number: + if xi, err := json.Number(x).Int64(); err == nil { + equal = func(y Value) bool { + if y, ok := y.(Number); ok { + if yi, err := json.Number(y).Int64(); err == nil { + return xi == yi + } + } + + return false + } + break + } + + // We use big.Rat for comparing big numbers. + // It replaces big.Float due to following reason: + // big.Float comes with a default precision of 64, and setting a + // larger precision results in more memory being allocated + // (regardless of the actual number we are parsing with SetString). + // + // Note: If we're so close to zero that big.Float says we are zero, do + // *not* big.Rat).SetString on the original string it'll potentially + // take very long. + var a *big.Rat + fa, ok := new(big.Float).SetString(string(x)) + if !ok { + panic("illegal value") + } + if fa.IsInt() { + if i, _ := fa.Int64(); i == 0 { + a = new(big.Rat).SetInt64(0) + } + } + if a == nil { + a, ok = new(big.Rat).SetString(string(x)) + if !ok { + panic("illegal value") + } + } + + equal = func(b Value) bool { + if bNum, ok := b.(Number); ok { + var b *big.Rat + fb, ok := new(big.Float).SetString(string(bNum)) + if !ok { + panic("illegal value") + } + if fb.IsInt() { + if i, _ := fb.Int64(); i == 0 { + b = new(big.Rat).SetInt64(0) + } + } + if b == nil { + b, ok = new(big.Rat).SetString(string(bNum)) + if !ok { + panic("illegal value") + } + } + + return a.Cmp(b) == 0 + } + return false + + } + + default: + equal = func(y Value) bool { return Compare(x, y) == 0 } + } + + for curr, ok := s.elems[hash]; ok; { + if equal(curr.Value) { + return curr + } + + hash++ + curr, ok = s.elems[hash] + } + return nil +} + +// Object represents an object as defined by the language. +type Object interface { + Value + Len() int + Get(*Term) *Term + Copy() Object + Insert(*Term, *Term) + Iter(func(*Term, *Term) error) error + Until(func(*Term, *Term) bool) bool + Foreach(func(*Term, *Term)) + Map(func(*Term, *Term) (*Term, *Term, error)) (Object, error) + Diff(other Object) Object + Intersect(other Object) [][3]*Term + Merge(other Object) (Object, bool) + MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) + Filter(filter Object) (Object, error) + Keys() []*Term + KeysIterator() ObjectKeysIterator + get(k *Term) *objectElem // To prevent external implementations +} + +// NewObject creates a new Object with t. +func NewObject(t ...[2]*Term) Object { + obj := newobject(len(t)) + for i := range t { + obj.insert(t[i][0], t[i][1], false) + } + return obj +} + +// ObjectTerm creates a new Term with an Object value. +func ObjectTerm(o ...[2]*Term) *Term { + return &Term{Value: NewObject(o...)} +} + +func LazyObject(blob map[string]interface{}) Object { + return &lazyObj{native: blob, cache: map[string]Value{}} +} + +type lazyObj struct { + strict Object + cache map[string]Value + native map[string]interface{} +} + +func (l *lazyObj) force() Object { + if l.strict == nil { + l.strict = MustInterfaceToValue(l.native).(Object) + // NOTE(jf): a possible performance improvement here would be to check how many + // entries have been realized to AST in the cache, and if some threshold compared to the + // total number of keys is exceeded, realize the remaining entries and set l.strict to l.cache. + l.cache = map[string]Value{} // We don't need the cache anymore; drop it to free up memory. + } + return l.strict +} + +func (l *lazyObj) Compare(other Value) int { + o1 := sortOrder(l) + o2 := sortOrder(other) + if o1 < o2 { + return -1 + } else if o2 < o1 { + return 1 + } + return l.force().Compare(other) +} + +func (l *lazyObj) Copy() Object { + return l +} + +func (l *lazyObj) Diff(other Object) Object { + return l.force().Diff(other) +} + +func (l *lazyObj) Intersect(other Object) [][3]*Term { + return l.force().Intersect(other) +} + +func (l *lazyObj) Iter(f func(*Term, *Term) error) error { + return l.force().Iter(f) +} + +func (l *lazyObj) Until(f func(*Term, *Term) bool) bool { + // NOTE(sr): there could be benefits in not forcing here -- if we abort because + // `f` returns true, we could save us from converting the rest of the object. + return l.force().Until(f) +} + +func (l *lazyObj) Foreach(f func(*Term, *Term)) { + l.force().Foreach(f) +} + +func (l *lazyObj) Filter(filter Object) (Object, error) { + return l.force().Filter(filter) +} + +func (l *lazyObj) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, error) { + return l.force().Map(f) +} + +func (l *lazyObj) MarshalJSON() ([]byte, error) { + return l.force().(*object).MarshalJSON() +} + +func (l *lazyObj) Merge(other Object) (Object, bool) { + return l.force().Merge(other) +} + +func (l *lazyObj) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { + return l.force().MergeWith(other, conflictResolver) +} + +func (l *lazyObj) Len() int { + return len(l.native) +} + +func (l *lazyObj) String() string { + return l.force().String() +} + +// get is merely there to implement the Object interface -- `get` there serves the +// purpose of prohibiting external implementations. It's never called for lazyObj. +func (*lazyObj) get(*Term) *objectElem { + return nil +} + +func (l *lazyObj) Get(k *Term) *Term { + if l.strict != nil { + return l.strict.Get(k) + } + if s, ok := k.Value.(String); ok { + if v, ok := l.cache[string(s)]; ok { + return NewTerm(v) + } + + if val, ok := l.native[string(s)]; ok { + var converted Value + switch val := val.(type) { + case map[string]interface{}: + converted = LazyObject(val) + default: + converted = MustInterfaceToValue(val) + } + l.cache[string(s)] = converted + return NewTerm(converted) + } + } + return nil +} + +func (l *lazyObj) Insert(k, v *Term) { + l.force().Insert(k, v) +} + +func (*lazyObj) IsGround() bool { + return true +} + +func (l *lazyObj) Hash() int { + return l.force().Hash() +} + +func (l *lazyObj) Keys() []*Term { + if l.strict != nil { + return l.strict.Keys() + } + ret := make([]*Term, 0, len(l.native)) + for k := range l.native { + ret = append(ret, StringTerm(k)) + } + sort.Sort(termSlice(ret)) + return ret +} + +func (l *lazyObj) KeysIterator() ObjectKeysIterator { + return &lazyObjKeysIterator{keys: l.Keys()} +} + +type lazyObjKeysIterator struct { + current int + keys []*Term +} + +func (ki *lazyObjKeysIterator) Next() (*Term, bool) { + if ki.current == len(ki.keys) { + return nil, false + } + ki.current++ + return ki.keys[ki.current-1], true +} + +func (l *lazyObj) Find(path Ref) (Value, error) { + if l.strict != nil { + return l.strict.Find(path) + } + if len(path) == 0 { + return l, nil + } + if p0, ok := path[0].Value.(String); ok { + if v, ok := l.cache[string(p0)]; ok { + return v.Find(path[1:]) + } + + if v, ok := l.native[string(p0)]; ok { + var converted Value + switch v := v.(type) { + case map[string]interface{}: + converted = LazyObject(v) + default: + converted = MustInterfaceToValue(v) + } + l.cache[string(p0)] = converted + return converted.Find(path[1:]) + } + } + return nil, errFindNotFound +} + +type object struct { + elems map[int]*objectElem + keys objectElemSlice + ground int // number of key and value grounds. Counting is + // required to support insert's key-value replace. + hash int + sortGuard *sync.Once // Prevents race condition around sorting. +} + +func newobject(n int) *object { + var keys objectElemSlice + if n > 0 { + keys = make(objectElemSlice, 0, n) + } + return &object{ + elems: make(map[int]*objectElem, n), + keys: keys, + ground: 0, + hash: 0, + sortGuard: new(sync.Once), + } +} + +type objectElem struct { + key *Term + value *Term + next *objectElem +} + +type objectElemSlice []*objectElem + +func (s objectElemSlice) Less(i, j int) bool { return Compare(s[i].key.Value, s[j].key.Value) < 0 } +func (s objectElemSlice) Swap(i, j int) { x := s[i]; s[i] = s[j]; s[j] = x } +func (s objectElemSlice) Len() int { return len(s) } + +// Item is a helper for constructing an tuple containing two Terms +// representing a key/value pair in an Object. +func Item(key, value *Term) [2]*Term { + return [2]*Term{key, value} +} + +func (obj *object) sortedKeys() objectElemSlice { + obj.sortGuard.Do(func() { + sort.Sort(obj.keys) + }) + return obj.keys +} + +// Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (obj *object) Compare(other Value) int { + if x, ok := other.(*lazyObj); ok { + other = x.force() + } + o1 := sortOrder(obj) + o2 := sortOrder(other) + if o1 < o2 { + return -1 + } else if o2 < o1 { + return 1 + } + a := obj + b := other.(*object) + // Ensure that keys are in canonical sorted order before use! + akeys := a.sortedKeys() + bkeys := b.sortedKeys() + minLen := len(akeys) + if len(b.keys) < len(akeys) { + minLen = len(bkeys) + } + for i := 0; i < minLen; i++ { + keysCmp := Compare(akeys[i].key, bkeys[i].key) + if keysCmp < 0 { + return -1 + } + if keysCmp > 0 { + return 1 + } + valA := akeys[i].value + valB := bkeys[i].value + valCmp := Compare(valA, valB) + if valCmp != 0 { + return valCmp + } + } + if len(akeys) < len(bkeys) { + return -1 + } + if len(bkeys) < len(akeys) { + return 1 + } + return 0 +} + +// Find returns the value at the key or undefined. +func (obj *object) Find(path Ref) (Value, error) { + if len(path) == 0 { + return obj, nil + } + value := obj.Get(path[0]) + if value == nil { + return nil, errFindNotFound + } + return value.Value.Find(path[1:]) +} + +func (obj *object) Insert(k, v *Term) { + obj.insert(k, v, true) +} + +// Get returns the value of k in obj if k exists, otherwise nil. +func (obj *object) Get(k *Term) *Term { + if elem := obj.get(k); elem != nil { + return elem.value + } + return nil +} + +// Hash returns the hash code for the Value. +func (obj *object) Hash() int { + return obj.hash +} + +// IsGround returns true if all of the Object key/value pairs are ground. +func (obj *object) IsGround() bool { + return obj.ground == 2*len(obj.keys) +} + +// Copy returns a deep copy of obj. +func (obj *object) Copy() Object { + cpy, _ := obj.Map(func(k, v *Term) (*Term, *Term, error) { + return k.Copy(), v.Copy(), nil + }) + cpy.(*object).hash = obj.hash + return cpy +} + +// Diff returns a new Object that contains only the key/value pairs that exist in obj. +func (obj *object) Diff(other Object) Object { + r := newobject(obj.Len()) + for _, node := range obj.sortedKeys() { + if other.Get(node.key) == nil { + r.insert(node.key, node.value, false) + } + } + return r +} + +// Intersect returns a slice of term triplets that represent the intersection of keys +// between obj and other. For each intersecting key, the values from obj and other are included +// as the last two terms in the triplet (respectively). +func (obj *object) Intersect(other Object) [][3]*Term { + r := [][3]*Term{} + obj.Foreach(func(k, v *Term) { + if v2 := other.Get(k); v2 != nil { + r = append(r, [3]*Term{k, v, v2}) + } + }) + return r +} + +// Iter calls the function f for each key-value pair in the object. If f +// returns an error, iteration stops and the error is returned. +func (obj *object) Iter(f func(*Term, *Term) error) error { + for _, node := range obj.sortedKeys() { + if err := f(node.key, node.value); err != nil { + return err + } + } + return nil +} + +// Until calls f for each key-value pair in the object. If f returns +// true, iteration stops and Until returns true. Otherwise, return +// false. +func (obj *object) Until(f func(*Term, *Term) bool) bool { + for _, node := range obj.sortedKeys() { + if f(node.key, node.value) { + return true + } + } + return false +} + +// Foreach calls f for each key-value pair in the object. +func (obj *object) Foreach(f func(*Term, *Term)) { + for _, node := range obj.sortedKeys() { + f(node.key, node.value) + } +} + +// Map returns a new Object constructed by mapping each element in the object +// using the function f. +func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, error) { + cpy := newobject(obj.Len()) + for _, node := range obj.sortedKeys() { + k, v, err := f(node.key, node.value) + if err != nil { + return nil, err + } + cpy.insert(k, v, false) + } + return cpy, nil +} + +// Keys returns the keys of obj. +func (obj *object) Keys() []*Term { + keys := make([]*Term, len(obj.keys)) + + for i, elem := range obj.sortedKeys() { + keys[i] = elem.key + } + + return keys +} + +// Returns an iterator over the obj's keys. +func (obj *object) KeysIterator() ObjectKeysIterator { + return newobjectKeysIterator(obj) +} + +// MarshalJSON returns JSON encoded bytes representing obj. +func (obj *object) MarshalJSON() ([]byte, error) { + sl := make([][2]*Term, obj.Len()) + for i, node := range obj.sortedKeys() { + sl[i] = Item(node.key, node.value) + } + return json.Marshal(sl) +} + +// Merge returns a new Object containing the non-overlapping keys of obj and other. If there are +// overlapping keys between obj and other, the values of associated with the keys are merged. Only +// objects can be merged with other objects. If the values cannot be merged, the second turn value +// will be false. +func (obj object) Merge(other Object) (Object, bool) { + return obj.MergeWith(other, func(v1, v2 *Term) (*Term, bool) { + obj1, ok1 := v1.Value.(Object) + obj2, ok2 := v2.Value.(Object) + if !ok1 || !ok2 { + return nil, true + } + obj3, ok := obj1.Merge(obj2) + if !ok { + return nil, true + } + return NewTerm(obj3), false + }) +} + +// MergeWith returns a new Object containing the merged keys of obj and other. +// If there are overlapping keys between obj and other, the conflictResolver +// is called. The conflictResolver can return a merged value and a boolean +// indicating if the merge has failed and should stop. +func (obj object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { + result := NewObject() + stop := obj.Until(func(k, v *Term) bool { + v2 := other.Get(k) + // The key didn't exist in other, keep the original value + if v2 == nil { + result.Insert(k, v) + return false + } + + // The key exists in both, resolve the conflict if possible + merged, stop := conflictResolver(v, v2) + if !stop { + result.Insert(k, merged) + } + return stop + }) + + if stop { + return nil, false + } + + // Copy in any values from other for keys that don't exist in obj + other.Foreach(func(k, v *Term) { + if v2 := obj.Get(k); v2 == nil { + result.Insert(k, v) + } + }) + return result, true +} + +// Filter returns a new object from values in obj where the keys are +// found in filter. Array indices for values can be specified as +// number strings. +func (obj *object) Filter(filter Object) (Object, error) { + filtered, err := filterObject(obj, filter) + if err != nil { + return nil, err + } + return filtered.(Object), nil +} + +// Len returns the number of elements in the object. +func (obj object) Len() int { + return len(obj.keys) +} + +func (obj object) String() string { + sb := sbPool.Get().(*strings.Builder) + sb.Reset() + sb.Grow(obj.Len() * 32) + + defer sbPool.Put(sb) + + sb.WriteRune('{') + + for i, elem := range obj.sortedKeys() { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(elem.key.String()) + sb.WriteString(": ") + sb.WriteString(elem.value.String()) + } + sb.WriteRune('}') + + return sb.String() +} + +func (obj *object) get(k *Term) *objectElem { + hash := k.Hash() + + // This `equal` utility is duplicated and manually inlined a number of + // time in this file. Inlining it avoids heap allocations, so it makes + // a big performance difference: some operations like lookup become twice + // as slow without it. + var equal func(v Value) bool + + switch x := k.Value.(type) { + case Null, Boolean, String, Var: + equal = func(y Value) bool { return x == y } + case Number: + if xi, ok := x.Int64(); ok { + equal = func(y Value) bool { + if y, ok := y.(Number); ok { + if yi, ok := y.Int64(); ok { + return xi == yi + } + } + + return false + } + break + } + + // We use big.Rat for comparing big numbers. + // It replaces big.Float due to following reason: + // big.Float comes with a default precision of 64, and setting a + // larger precision results in more memory being allocated + // (regardless of the actual number we are parsing with SetString). + // + // Note: If we're so close to zero that big.Float says we are zero, do + // *not* big.Rat).SetString on the original string it'll potentially + // take very long. + var a *big.Rat + fa, ok := new(big.Float).SetString(string(x)) + if !ok { + panic("illegal value") + } + if fa.IsInt() { + if i, _ := fa.Int64(); i == 0 { + a = new(big.Rat).SetInt64(0) + } + } + if a == nil { + a, ok = new(big.Rat).SetString(string(x)) + if !ok { + panic("illegal value") + } + } + + equal = func(b Value) bool { + if bNum, ok := b.(Number); ok { + var b *big.Rat + fb, ok := new(big.Float).SetString(string(bNum)) + if !ok { + panic("illegal value") + } + if fb.IsInt() { + if i, _ := fb.Int64(); i == 0 { + b = new(big.Rat).SetInt64(0) + } + } + if b == nil { + b, ok = new(big.Rat).SetString(string(bNum)) + if !ok { + panic("illegal value") + } + } + + return a.Cmp(b) == 0 + } + + return false + } + default: + equal = func(y Value) bool { return Compare(x, y) == 0 } + } + + for curr := obj.elems[hash]; curr != nil; curr = curr.next { + if equal(curr.key.Value) { + return curr + } + } + return nil +} + +// NOTE(philipc): We assume a many-readers, single-writer model here. +// This method should NOT be used concurrently, or else we risk data races. +func (obj *object) insert(k, v *Term, resetSortGuard bool) { + hash := k.Hash() + head := obj.elems[hash] + // This `equal` utility is duplicated and manually inlined a number of + // time in this file. Inlining it avoids heap allocations, so it makes + // a big performance difference: some operations like lookup become twice + // as slow without it. + var equal func(v Value) bool + + switch x := k.Value.(type) { + case Null, Boolean, String, Var: + equal = func(y Value) bool { return x == y } + case Number: + if xi, err := json.Number(x).Int64(); err == nil { + equal = func(y Value) bool { + if y, ok := y.(Number); ok { + if yi, err := json.Number(y).Int64(); err == nil { + return xi == yi + } + } + + return false + } + break + } + + // We use big.Rat for comparing big numbers. + // It replaces big.Float due to following reason: + // big.Float comes with a default precision of 64, and setting a + // larger precision results in more memory being allocated + // (regardless of the actual number we are parsing with SetString). + // + // Note: If we're so close to zero that big.Float says we are zero, do + // *not* big.Rat).SetString on the original string it'll potentially + // take very long. + var a *big.Rat + fa, ok := new(big.Float).SetString(string(x)) + if !ok { + panic("illegal value") + } + if fa.IsInt() { + if i, _ := fa.Int64(); i == 0 { + a = new(big.Rat).SetInt64(0) + } + } + if a == nil { + a, ok = new(big.Rat).SetString(string(x)) + if !ok { + panic("illegal value") + } + } + + equal = func(b Value) bool { + if bNum, ok := b.(Number); ok { + var b *big.Rat + fb, ok := new(big.Float).SetString(string(bNum)) + if !ok { + panic("illegal value") + } + if fb.IsInt() { + if i, _ := fb.Int64(); i == 0 { + b = new(big.Rat).SetInt64(0) + } + } + if b == nil { + b, ok = new(big.Rat).SetString(string(bNum)) + if !ok { + panic("illegal value") + } + } + + return a.Cmp(b) == 0 + } + + return false + } + default: + equal = func(y Value) bool { return Compare(x, y) == 0 } + } + + for curr := head; curr != nil; curr = curr.next { + if equal(curr.key.Value) { + // The ground bit of the value may change in + // replace, hence adjust the counter per old + // and new value. + + if curr.value.IsGround() { + obj.ground-- + } + if v.IsGround() { + obj.ground++ + } + + curr.value = v + + obj.rehash() + return + } + } + elem := &objectElem{ + key: k, + value: v, + next: head, + } + obj.elems[hash] = elem + // O(1) insertion, but we'll have to re-sort the keys later. + obj.keys = append(obj.keys, elem) + + if resetSortGuard { + // Reset the sync.Once instance. + // See https://github.com/golang/go/issues/25955 for why we do it this way. + // Note that this will always be the case when external code calls insert via + // Add, or otherwise. Internal code may however benefit from not having to + // re-create this pointer when it's known not to be needed. + obj.sortGuard = new(sync.Once) + } + + obj.hash += hash + v.Hash() + + if k.IsGround() { + obj.ground++ + } + if v.IsGround() { + obj.ground++ + } +} + +func (obj *object) rehash() { + // obj.keys is considered truth, from which obj.hash and obj.elems are recalculated. + + obj.hash = 0 + obj.elems = make(map[int]*objectElem, len(obj.keys)) + + for _, elem := range obj.keys { + hash := elem.key.Hash() + obj.hash += hash + elem.value.Hash() + obj.elems[hash] = elem + } +} + +func filterObject(o Value, filter Value) (Value, error) { + if filter.Compare(Null{}) == 0 { + return o, nil + } + + filteredObj, ok := filter.(*object) + if !ok { + return nil, fmt.Errorf("invalid filter value %q, expected an object", filter) + } + + switch v := o.(type) { + case String, Number, Boolean, Null: + return o, nil + case *Array: + values := NewArray() + for i := 0; i < v.Len(); i++ { + subFilter := filteredObj.Get(StringTerm(strconv.Itoa(i))) + if subFilter != nil { + filteredValue, err := filterObject(v.Elem(i).Value, subFilter.Value) + if err != nil { + return nil, err + } + values = values.Append(NewTerm(filteredValue)) + } + } + return values, nil + case Set: + terms := make([]*Term, 0, v.Len()) + for _, t := range v.Slice() { + if filteredObj.Get(t) != nil { + filteredValue, err := filterObject(t.Value, filteredObj.Get(t).Value) + if err != nil { + return nil, err + } + terms = append(terms, NewTerm(filteredValue)) + } + } + return NewSet(terms...), nil + case *object: + values := NewObject() + + iterObj := v + other := filteredObj + if v.Len() < filteredObj.Len() { + iterObj = filteredObj + other = v + } + + err := iterObj.Iter(func(key *Term, _ *Term) error { + if other.Get(key) != nil { + filteredValue, err := filterObject(v.Get(key).Value, filteredObj.Get(key).Value) + if err != nil { + return err + } + values.Insert(key, NewTerm(filteredValue)) + } + return nil + }) + return values, err + default: + return nil, fmt.Errorf("invalid object value type %q", v) + } +} + +// NOTE(philipc): The only way to get an ObjectKeyIterator should be +// from an Object. This ensures that the iterator can have implementation- +// specific details internally, with no contracts except to the very +// limited interface. +type ObjectKeysIterator interface { + Next() (*Term, bool) +} + +type objectKeysIterator struct { + obj *object + numKeys int + index int +} + +func newobjectKeysIterator(o *object) ObjectKeysIterator { + return &objectKeysIterator{ + obj: o, + numKeys: o.Len(), + index: 0, + } +} + +func (oki *objectKeysIterator) Next() (*Term, bool) { + if oki.index == oki.numKeys || oki.numKeys == 0 { + return nil, false + } + oki.index++ + return oki.obj.sortedKeys()[oki.index-1].key, true +} + +// ArrayComprehension represents an array comprehension as defined in the language. +type ArrayComprehension struct { + Term *Term `json:"term"` + Body Body `json:"body"` +} + +// ArrayComprehensionTerm creates a new Term with an ArrayComprehension value. +func ArrayComprehensionTerm(term *Term, body Body) *Term { + return &Term{ + Value: &ArrayComprehension{ + Term: term, + Body: body, + }, + } +} + +// Copy returns a deep copy of ac. +func (ac *ArrayComprehension) Copy() *ArrayComprehension { + cpy := *ac + cpy.Body = ac.Body.Copy() + cpy.Term = ac.Term.Copy() + return &cpy +} + +// Equal returns true if ac is equal to other. +func (ac *ArrayComprehension) Equal(other Value) bool { + return Compare(ac, other) == 0 +} + +// Compare compares ac to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (ac *ArrayComprehension) Compare(other Value) int { + return Compare(ac, other) +} + +// Find returns the current value or a not found error. +func (ac *ArrayComprehension) Find(path Ref) (Value, error) { + if len(path) == 0 { + return ac, nil + } + return nil, errFindNotFound +} + +// Hash returns the hash code of the Value. +func (ac *ArrayComprehension) Hash() int { + return ac.Term.Hash() + ac.Body.Hash() +} + +// IsGround returns true if the Term and Body are ground. +func (ac *ArrayComprehension) IsGround() bool { + return ac.Term.IsGround() && ac.Body.IsGround() +} + +func (ac *ArrayComprehension) String() string { + return "[" + ac.Term.String() + " | " + ac.Body.String() + "]" +} + +// ObjectComprehension represents an object comprehension as defined in the language. +type ObjectComprehension struct { + Key *Term `json:"key"` + Value *Term `json:"value"` + Body Body `json:"body"` +} + +// ObjectComprehensionTerm creates a new Term with an ObjectComprehension value. +func ObjectComprehensionTerm(key, value *Term, body Body) *Term { + return &Term{ + Value: &ObjectComprehension{ + Key: key, + Value: value, + Body: body, + }, + } +} + +// Copy returns a deep copy of oc. +func (oc *ObjectComprehension) Copy() *ObjectComprehension { + cpy := *oc + cpy.Body = oc.Body.Copy() + cpy.Key = oc.Key.Copy() + cpy.Value = oc.Value.Copy() + return &cpy +} + +// Equal returns true if oc is equal to other. +func (oc *ObjectComprehension) Equal(other Value) bool { + return Compare(oc, other) == 0 +} + +// Compare compares oc to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (oc *ObjectComprehension) Compare(other Value) int { + return Compare(oc, other) +} + +// Find returns the current value or a not found error. +func (oc *ObjectComprehension) Find(path Ref) (Value, error) { + if len(path) == 0 { + return oc, nil + } + return nil, errFindNotFound +} + +// Hash returns the hash code of the Value. +func (oc *ObjectComprehension) Hash() int { + return oc.Key.Hash() + oc.Value.Hash() + oc.Body.Hash() +} + +// IsGround returns true if the Key, Value and Body are ground. +func (oc *ObjectComprehension) IsGround() bool { + return oc.Key.IsGround() && oc.Value.IsGround() && oc.Body.IsGround() +} + +func (oc *ObjectComprehension) String() string { + return "{" + oc.Key.String() + ": " + oc.Value.String() + " | " + oc.Body.String() + "}" +} + +// SetComprehension represents a set comprehension as defined in the language. +type SetComprehension struct { + Term *Term `json:"term"` + Body Body `json:"body"` +} + +// SetComprehensionTerm creates a new Term with an SetComprehension value. +func SetComprehensionTerm(term *Term, body Body) *Term { + return &Term{ + Value: &SetComprehension{ + Term: term, + Body: body, + }, + } +} + +// Copy returns a deep copy of sc. +func (sc *SetComprehension) Copy() *SetComprehension { + cpy := *sc + cpy.Body = sc.Body.Copy() + cpy.Term = sc.Term.Copy() + return &cpy +} + +// Equal returns true if sc is equal to other. +func (sc *SetComprehension) Equal(other Value) bool { + return Compare(sc, other) == 0 +} + +// Compare compares sc to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (sc *SetComprehension) Compare(other Value) int { + return Compare(sc, other) +} + +// Find returns the current value or a not found error. +func (sc *SetComprehension) Find(path Ref) (Value, error) { + if len(path) == 0 { + return sc, nil + } + return nil, errFindNotFound +} + +// Hash returns the hash code of the Value. +func (sc *SetComprehension) Hash() int { + return sc.Term.Hash() + sc.Body.Hash() +} + +// IsGround returns true if the Term and Body are ground. +func (sc *SetComprehension) IsGround() bool { + return sc.Term.IsGround() && sc.Body.IsGround() +} + +func (sc *SetComprehension) String() string { + return "{" + sc.Term.String() + " | " + sc.Body.String() + "}" +} + +// Call represents as function call in the language. +type Call []*Term + +// CallTerm returns a new Term with a Call value defined by terms. The first +// term is the operator and the rest are operands. +func CallTerm(terms ...*Term) *Term { + return NewTerm(Call(terms)) +} + +// Copy returns a deep copy of c. +func (c Call) Copy() Call { + return termSliceCopy(c) +} + +// Compare compares c to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (c Call) Compare(other Value) int { + return Compare(c, other) +} + +// Find returns the current value or a not found error. +func (c Call) Find(Ref) (Value, error) { + return nil, errFindNotFound +} + +// Hash returns the hash code for the Value. +func (c Call) Hash() int { + return termSliceHash(c) +} + +// IsGround returns true if the Value is ground. +func (c Call) IsGround() bool { + return termSliceIsGround(c) +} + +// MakeExpr returns an ew Expr from this call. +func (c Call) MakeExpr(output *Term) *Expr { + terms := []*Term(c) + return NewExpr(append(terms, output)) +} + +func (c Call) String() string { + args := make([]string, len(c)-1) + for i := 1; i < len(c); i++ { + args[i-1] = c[i].String() + } + return fmt.Sprintf("%v(%v)", c[0], strings.Join(args, ", ")) +} + +func termSliceCopy(a []*Term) []*Term { + cpy := make([]*Term, len(a)) + for i := range a { + cpy[i] = a[i].Copy() + } + return cpy +} + +func termSliceEqual(a, b []*Term) bool { + if len(a) == len(b) { + for i := range a { + if !a[i].Equal(b[i]) { + return false + } + } + return true + } + return false +} + +func termSliceHash(a []*Term) int { + var hash int + for _, v := range a { + hash += v.Value.Hash() + } + return hash +} + +func termSliceIsGround(a []*Term) bool { + for _, v := range a { + if !v.IsGround() { + return false + } + } + return true +} + +// NOTE(tsandall): The unmarshalling errors in these functions are not +// helpful for callers because they do not identify the source of the +// unmarshalling error. Because OPA doesn't accept JSON describing ASTs +// from callers, this is acceptable (for now). If that changes in the future, +// the error messages should be revisited. The current approach focuses +// on the happy path and treats all errors the same. If better error +// reporting is needed, the error paths will need to be fleshed out. + +func unmarshalBody(b []interface{}) (Body, error) { + buf := Body{} + for _, e := range b { + if m, ok := e.(map[string]interface{}); ok { + expr := &Expr{} + if err := unmarshalExpr(expr, m); err == nil { + buf = append(buf, expr) + continue + } + } + goto unmarshal_error + } + return buf, nil +unmarshal_error: + return nil, fmt.Errorf("ast: unable to unmarshal body") +} + +func unmarshalExpr(expr *Expr, v map[string]interface{}) error { + if x, ok := v["negated"]; ok { + if b, ok := x.(bool); ok { + expr.Negated = b + } else { + return fmt.Errorf("ast: unable to unmarshal negated field with type: %T (expected true or false)", v["negated"]) + } + } + if generatedRaw, ok := v["generated"]; ok { + if b, ok := generatedRaw.(bool); ok { + expr.Generated = b + } else { + return fmt.Errorf("ast: unable to unmarshal generated field with type: %T (expected true or false)", v["generated"]) + } + } + + if err := unmarshalExprIndex(expr, v); err != nil { + return err + } + switch ts := v["terms"].(type) { + case map[string]interface{}: + t, err := unmarshalTerm(ts) + if err != nil { + return err + } + expr.Terms = t + case []interface{}: + terms, err := unmarshalTermSlice(ts) + if err != nil { + return err + } + expr.Terms = terms + default: + return fmt.Errorf(`ast: unable to unmarshal terms field with type: %T (expected {"value": ..., "type": ...} or [{"value": ..., "type": ...}, ...])`, v["terms"]) + } + if x, ok := v["with"]; ok { + if sl, ok := x.([]interface{}); ok { + ws := make([]*With, len(sl)) + for i := range sl { + var err error + ws[i], err = unmarshalWith(sl[i]) + if err != nil { + return err + } + } + expr.With = ws + } + } + if loc, ok := v["location"].(map[string]interface{}); ok { + expr.Location = &Location{} + if err := unmarshalLocation(expr.Location, loc); err != nil { + return err + } + } + return nil +} + +func unmarshalLocation(loc *Location, v map[string]interface{}) error { + if x, ok := v["file"]; ok { + if s, ok := x.(string); ok { + loc.File = s + } else { + return fmt.Errorf("ast: unable to unmarshal file field with type: %T (expected string)", v["file"]) + } + } + if x, ok := v["row"]; ok { + if n, ok := x.(json.Number); ok { + i64, err := n.Int64() + if err != nil { + return err + } + loc.Row = int(i64) + } else { + return fmt.Errorf("ast: unable to unmarshal row field with type: %T (expected number)", v["row"]) + } + } + if x, ok := v["col"]; ok { + if n, ok := x.(json.Number); ok { + i64, err := n.Int64() + if err != nil { + return err + } + loc.Col = int(i64) + } else { + return fmt.Errorf("ast: unable to unmarshal col field with type: %T (expected number)", v["col"]) + } + } + + return nil +} + +func unmarshalExprIndex(expr *Expr, v map[string]interface{}) error { + if x, ok := v["index"]; ok { + if n, ok := x.(json.Number); ok { + i, err := n.Int64() + if err == nil { + expr.Index = int(i) + return nil + } + } + } + return fmt.Errorf("ast: unable to unmarshal index field with type: %T (expected integer)", v["index"]) +} + +func unmarshalTerm(m map[string]interface{}) (*Term, error) { + var term Term + + v, err := unmarshalValue(m) + if err != nil { + return nil, err + } + term.Value = v + + if loc, ok := m["location"].(map[string]interface{}); ok { + term.Location = &Location{} + if err := unmarshalLocation(term.Location, loc); err != nil { + return nil, err + } + } + + return &term, nil +} + +func unmarshalTermSlice(s []interface{}) ([]*Term, error) { + buf := []*Term{} + for _, x := range s { + if m, ok := x.(map[string]interface{}); ok { + t, err := unmarshalTerm(m) + if err == nil { + buf = append(buf, t) + continue + } + return nil, err + } + return nil, fmt.Errorf("ast: unable to unmarshal term") + } + return buf, nil +} + +func unmarshalTermSliceValue(d map[string]interface{}) ([]*Term, error) { + if s, ok := d["value"].([]interface{}); ok { + return unmarshalTermSlice(s) + } + return nil, fmt.Errorf(`ast: unable to unmarshal term (expected {"value": [...], "type": ...} where type is one of: ref, array, or set)`) +} + +func unmarshalWith(i interface{}) (*With, error) { + if m, ok := i.(map[string]interface{}); ok { + tgt, _ := m["target"].(map[string]interface{}) + target, err := unmarshalTerm(tgt) + if err == nil { + val, _ := m["value"].(map[string]interface{}) + value, err := unmarshalTerm(val) + if err == nil { + return &With{ + Target: target, + Value: value, + }, nil + } + return nil, err + } + return nil, err + } + return nil, fmt.Errorf(`ast: unable to unmarshal with modifier (expected {"target": {...}, "value": {...}})`) +} + +func unmarshalValue(d map[string]interface{}) (Value, error) { + v := d["value"] + switch d["type"] { + case "null": + return Null{}, nil + case "boolean": + if b, ok := v.(bool); ok { + return Boolean(b), nil + } + case "number": + if n, ok := v.(json.Number); ok { + return Number(n), nil + } + case "string": + if s, ok := v.(string); ok { + return String(s), nil + } + case "var": + if s, ok := v.(string); ok { + return Var(s), nil + } + case "ref": + if s, err := unmarshalTermSliceValue(d); err == nil { + return Ref(s), nil + } + case "array": + if s, err := unmarshalTermSliceValue(d); err == nil { + return NewArray(s...), nil + } + case "set": + if s, err := unmarshalTermSliceValue(d); err == nil { + return NewSet(s...), nil + } + case "object": + if s, ok := v.([]interface{}); ok { + buf := NewObject() + for _, x := range s { + if i, ok := x.([]interface{}); ok && len(i) == 2 { + p, err := unmarshalTermSlice(i) + if err == nil { + buf.Insert(p[0], p[1]) + continue + } + } + goto unmarshal_error + } + return buf, nil + } + case "arraycomprehension", "setcomprehension": + if m, ok := v.(map[string]interface{}); ok { + t, ok := m["term"].(map[string]interface{}) + if !ok { + goto unmarshal_error + } + + term, err := unmarshalTerm(t) + if err != nil { + goto unmarshal_error + } + + b, ok := m["body"].([]interface{}) + if !ok { + goto unmarshal_error + } + + body, err := unmarshalBody(b) + if err != nil { + goto unmarshal_error + } + + if d["type"] == "arraycomprehension" { + return &ArrayComprehension{Term: term, Body: body}, nil + } + return &SetComprehension{Term: term, Body: body}, nil + } + case "objectcomprehension": + if m, ok := v.(map[string]interface{}); ok { + k, ok := m["key"].(map[string]interface{}) + if !ok { + goto unmarshal_error + } + + key, err := unmarshalTerm(k) + if err != nil { + goto unmarshal_error + } + + v, ok := m["value"].(map[string]interface{}) + if !ok { + goto unmarshal_error + } + + value, err := unmarshalTerm(v) + if err != nil { + goto unmarshal_error + } + + b, ok := m["body"].([]interface{}) + if !ok { + goto unmarshal_error + } + + body, err := unmarshalBody(b) + if err != nil { + goto unmarshal_error + } + + return &ObjectComprehension{Key: key, Value: value, Body: body}, nil + } + case "call": + if s, err := unmarshalTermSliceValue(d); err == nil { + return Call(s), nil + } + } +unmarshal_error: + return nil, fmt.Errorf("ast: unable to unmarshal term") +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/transform.go b/vendor/github.com/open-policy-agent/opa/v1/ast/transform.go new file mode 100644 index 000000000..391a16486 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/transform.go @@ -0,0 +1,431 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" +) + +// Transformer defines the interface for transforming AST elements. If the +// transformer returns nil and does not indicate an error, the AST element will +// be set to nil and no transformations will be applied to children of the +// element. +type Transformer interface { + Transform(interface{}) (interface{}, error) +} + +// Transform iterates the AST and calls the Transform function on the +// Transformer t for x before recursing. +func Transform(t Transformer, x interface{}) (interface{}, error) { + + if term, ok := x.(*Term); ok { + return Transform(t, term.Value) + } + + y, err := t.Transform(x) + if err != nil { + return x, err + } + + if y == nil { + return nil, nil + } + + var ok bool + switch y := y.(type) { + case *Module: + p, err := Transform(t, y.Package) + if err != nil { + return nil, err + } + if y.Package, ok = p.(*Package); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.Package, p) + } + for i := range y.Imports { + imp, err := Transform(t, y.Imports[i]) + if err != nil { + return nil, err + } + if y.Imports[i], ok = imp.(*Import); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.Imports[i], imp) + } + } + for i := range y.Rules { + rule, err := Transform(t, y.Rules[i]) + if err != nil { + return nil, err + } + if y.Rules[i], ok = rule.(*Rule); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.Rules[i], rule) + } + } + for i := range y.Annotations { + a, err := Transform(t, y.Annotations[i]) + if err != nil { + return nil, err + } + if y.Annotations[i], ok = a.(*Annotations); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.Annotations[i], a) + } + } + for i := range y.Comments { + comment, err := Transform(t, y.Comments[i]) + if err != nil { + return nil, err + } + if y.Comments[i], ok = comment.(*Comment); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.Comments[i], comment) + } + } + return y, nil + case *Package: + ref, err := Transform(t, y.Path) + if err != nil { + return nil, err + } + if y.Path, ok = ref.(Ref); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.Path, ref) + } + return y, nil + case *Import: + y.Path, err = transformTerm(t, y.Path) + if err != nil { + return nil, err + } + if y.Alias, err = transformVar(t, y.Alias); err != nil { + return nil, err + } + return y, nil + case *Rule: + if y.Head, err = transformHead(t, y.Head); err != nil { + return nil, err + } + if y.Body, err = transformBody(t, y.Body); err != nil { + return nil, err + } + if y.Else != nil { + rule, err := Transform(t, y.Else) + if err != nil { + return nil, err + } + if y.Else, ok = rule.(*Rule); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.Else, rule) + } + } + return y, nil + case *Head: + if y.Reference, err = transformRef(t, y.Reference); err != nil { + return nil, err + } + if y.Name, err = transformVar(t, y.Name); err != nil { + return nil, err + } + if y.Args, err = transformArgs(t, y.Args); err != nil { + return nil, err + } + if y.Key != nil { + if y.Key, err = transformTerm(t, y.Key); err != nil { + return nil, err + } + } + if y.Value != nil { + if y.Value, err = transformTerm(t, y.Value); err != nil { + return nil, err + } + } + return y, nil + case Args: + for i := range y { + if y[i], err = transformTerm(t, y[i]); err != nil { + return nil, err + } + } + return y, nil + case Body: + for i, e := range y { + e, err := Transform(t, e) + if err != nil { + return nil, err + } + if y[i], ok = e.(*Expr); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y[i], e) + } + } + return y, nil + case *Expr: + switch ts := y.Terms.(type) { + case *SomeDecl: + decl, err := Transform(t, ts) + if err != nil { + return nil, err + } + if y.Terms, ok = decl.(*SomeDecl); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y, decl) + } + return y, nil + case []*Term: + for i := range ts { + if ts[i], err = transformTerm(t, ts[i]); err != nil { + return nil, err + } + } + case *Term: + if y.Terms, err = transformTerm(t, ts); err != nil { + return nil, err + } + case *Every: + if ts.Key != nil { + ts.Key, err = transformTerm(t, ts.Key) + if err != nil { + return nil, err + } + } + ts.Value, err = transformTerm(t, ts.Value) + if err != nil { + return nil, err + } + ts.Domain, err = transformTerm(t, ts.Domain) + if err != nil { + return nil, err + } + ts.Body, err = transformBody(t, ts.Body) + if err != nil { + return nil, err + } + y.Terms = ts + } + for i, w := range y.With { + w, err := Transform(t, w) + if err != nil { + return nil, err + } + if y.With[i], ok = w.(*With); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.With[i], w) + } + } + return y, nil + case *With: + if y.Target, err = transformTerm(t, y.Target); err != nil { + return nil, err + } + if y.Value, err = transformTerm(t, y.Value); err != nil { + return nil, err + } + return y, nil + case Ref: + for i, term := range y { + if y[i], err = transformTerm(t, term); err != nil { + return nil, err + } + } + return y, nil + case *object: + return y.Map(func(k, v *Term) (*Term, *Term, error) { + k, err := transformTerm(t, k) + if err != nil { + return nil, nil, err + } + v, err = transformTerm(t, v) + if err != nil { + return nil, nil, err + } + return k, v, nil + }) + case *Array: + for i := 0; i < y.Len(); i++ { + v, err := transformTerm(t, y.Elem(i)) + if err != nil { + return nil, err + } + y.set(i, v) + } + return y, nil + case Set: + y, err = y.Map(func(term *Term) (*Term, error) { + return transformTerm(t, term) + }) + if err != nil { + return nil, err + } + return y, nil + case *ArrayComprehension: + if y.Term, err = transformTerm(t, y.Term); err != nil { + return nil, err + } + if y.Body, err = transformBody(t, y.Body); err != nil { + return nil, err + } + return y, nil + case *ObjectComprehension: + if y.Key, err = transformTerm(t, y.Key); err != nil { + return nil, err + } + if y.Value, err = transformTerm(t, y.Value); err != nil { + return nil, err + } + if y.Body, err = transformBody(t, y.Body); err != nil { + return nil, err + } + return y, nil + case *SetComprehension: + if y.Term, err = transformTerm(t, y.Term); err != nil { + return nil, err + } + if y.Body, err = transformBody(t, y.Body); err != nil { + return nil, err + } + return y, nil + case Call: + for i := range y { + if y[i], err = transformTerm(t, y[i]); err != nil { + return nil, err + } + } + return y, nil + default: + return y, nil + } +} + +// TransformRefs calls the function f on all references under x. +func TransformRefs(x interface{}, f func(Ref) (Value, error)) (interface{}, error) { + t := &GenericTransformer{func(x interface{}) (interface{}, error) { + if r, ok := x.(Ref); ok { + return f(r) + } + return x, nil + }} + return Transform(t, x) +} + +// TransformVars calls the function f on all vars under x. +func TransformVars(x interface{}, f func(Var) (Value, error)) (interface{}, error) { + t := &GenericTransformer{func(x interface{}) (interface{}, error) { + if v, ok := x.(Var); ok { + return f(v) + } + return x, nil + }} + return Transform(t, x) +} + +// TransformComprehensions calls the functio nf on all comprehensions under x. +func TransformComprehensions(x interface{}, f func(interface{}) (Value, error)) (interface{}, error) { + t := &GenericTransformer{func(x interface{}) (interface{}, error) { + switch x := x.(type) { + case *ArrayComprehension: + return f(x) + case *SetComprehension: + return f(x) + case *ObjectComprehension: + return f(x) + } + return x, nil + }} + return Transform(t, x) +} + +// GenericTransformer implements the Transformer interface to provide a utility +// to transform AST nodes using a closure. +type GenericTransformer struct { + f func(interface{}) (interface{}, error) +} + +// NewGenericTransformer returns a new GenericTransformer that will transform +// AST nodes using the function f. +func NewGenericTransformer(f func(x interface{}) (interface{}, error)) *GenericTransformer { + return &GenericTransformer{ + f: f, + } +} + +// Transform calls the function f on the GenericTransformer. +func (t *GenericTransformer) Transform(x interface{}) (interface{}, error) { + return t.f(x) +} + +func transformHead(t Transformer, head *Head) (*Head, error) { + y, err := Transform(t, head) + if err != nil { + return nil, err + } + h, ok := y.(*Head) + if !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", head, y) + } + return h, nil +} + +func transformArgs(t Transformer, args Args) (Args, error) { + y, err := Transform(t, args) + if err != nil { + return nil, err + } + a, ok := y.(Args) + if !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", args, y) + } + return a, nil +} + +func transformBody(t Transformer, body Body) (Body, error) { + y, err := Transform(t, body) + if err != nil { + return nil, err + } + r, ok := y.(Body) + if !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", body, y) + } + return r, nil +} + +func transformTerm(t Transformer, term *Term) (*Term, error) { + v, err := transformValue(t, term.Value) + if err != nil { + return nil, err + } + r := &Term{ + Value: v, + Location: term.Location, + } + return r, nil +} + +func transformValue(t Transformer, v Value) (Value, error) { + v1, err := Transform(t, v) + if err != nil { + return nil, err + } + r, ok := v1.(Value) + if !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", v, v1) + } + return r, nil +} + +func transformVar(t Transformer, v Var) (Var, error) { + v1, err := Transform(t, v) + if err != nil { + return "", err + } + r, ok := v1.(Var) + if !ok { + return "", fmt.Errorf("illegal transform: %T != %T", v, v1) + } + return r, nil +} + +func transformRef(t Transformer, r Ref) (Ref, error) { + r1, err := Transform(t, r) + if err != nil { + return nil, err + } + r2, ok := r1.(Ref) + if !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", r, r2) + } + return r2, nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/unify.go b/vendor/github.com/open-policy-agent/opa/v1/ast/unify.go new file mode 100644 index 000000000..60244974a --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/unify.go @@ -0,0 +1,235 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +func isRefSafe(ref Ref, safe VarSet) bool { + switch head := ref[0].Value.(type) { + case Var: + return safe.Contains(head) + case Call: + return isCallSafe(head, safe) + default: + for v := range ref[0].Vars() { + if !safe.Contains(v) { + return false + } + } + return true + } +} + +func isCallSafe(call Call, safe VarSet) bool { + vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams) + vis.Walk(call) + unsafe := vis.Vars().Diff(safe) + return len(unsafe) == 0 +} + +// Unify returns a set of variables that will be unified when the equality expression defined by +// terms a and b is evaluated. The unifier assumes that variables in the VarSet safe are already +// unified. +func Unify(safe VarSet, a *Term, b *Term) VarSet { + u := &unifier{ + safe: safe, + unified: VarSet{}, + unknown: map[Var]VarSet{}, + } + u.unify(a, b) + return u.unified +} + +type unifier struct { + safe VarSet + unified VarSet + unknown map[Var]VarSet +} + +func (u *unifier) isSafe(x Var) bool { + return u.safe.Contains(x) || u.unified.Contains(x) +} + +func (u *unifier) unify(a *Term, b *Term) { + + switch a := a.Value.(type) { + + case Var: + switch b := b.Value.(type) { + case Var: + if u.isSafe(b) { + u.markSafe(a) + } else if u.isSafe(a) { + u.markSafe(b) + } else { + u.markUnknown(a, b) + u.markUnknown(b, a) + } + case *Array, Object: + u.unifyAll(a, b) + case Ref: + if isRefSafe(b, u.safe) { + u.markSafe(a) + } + case Call: + if isCallSafe(b, u.safe) { + u.markSafe(a) + } + default: + u.markSafe(a) + } + + case Ref: + if isRefSafe(a, u.safe) { + switch b := b.Value.(type) { + case Var: + u.markSafe(b) + case *Array, Object: + u.markAllSafe(b) + } + } + + case Call: + if isCallSafe(a, u.safe) { + switch b := b.Value.(type) { + case Var: + u.markSafe(b) + case *Array, Object: + u.markAllSafe(b) + } + } + + case *ArrayComprehension: + switch b := b.Value.(type) { + case Var: + u.markSafe(b) + case *Array: + u.markAllSafe(b) + } + case *ObjectComprehension: + switch b := b.Value.(type) { + case Var: + u.markSafe(b) + case *object: + u.markAllSafe(b) + } + case *SetComprehension: + switch b := b.Value.(type) { + case Var: + u.markSafe(b) + } + + case *Array: + switch b := b.Value.(type) { + case Var: + u.unifyAll(b, a) + case *ArrayComprehension, *ObjectComprehension, *SetComprehension: + u.markAllSafe(a) + case Ref: + if isRefSafe(b, u.safe) { + u.markAllSafe(a) + } + case Call: + if isCallSafe(b, u.safe) { + u.markAllSafe(a) + } + case *Array: + if a.Len() == b.Len() { + for i := 0; i < a.Len(); i++ { + u.unify(a.Elem(i), b.Elem(i)) + } + } + } + + case *object: + switch b := b.Value.(type) { + case Var: + u.unifyAll(b, a) + case Ref: + if isRefSafe(b, u.safe) { + u.markAllSafe(a) + } + case Call: + if isCallSafe(b, u.safe) { + u.markAllSafe(a) + } + case *object: + if a.Len() == b.Len() { + _ = a.Iter(func(k, v *Term) error { + if v2 := b.Get(k); v2 != nil { + u.unify(v, v2) + } + return nil + }) // impossible to return error + } + } + + default: + switch b := b.Value.(type) { + case Var: + u.markSafe(b) + } + } +} + +func (u *unifier) markAllSafe(x Value) { + vis := u.varVisitor() + vis.Walk(x) + for v := range vis.Vars() { + u.markSafe(v) + } +} + +func (u *unifier) markSafe(x Var) { + u.unified.Add(x) + + // Add dependencies of 'x' to safe set + vs := u.unknown[x] + delete(u.unknown, x) + for v := range vs { + u.markSafe(v) + } + + // Add dependants of 'x' to safe set if they have no more + // dependencies. + for v, deps := range u.unknown { + if deps.Contains(x) { + delete(deps, x) + if len(deps) == 0 { + u.markSafe(v) + } + } + } +} + +func (u *unifier) markUnknown(a, b Var) { + if _, ok := u.unknown[a]; !ok { + u.unknown[a] = NewVarSet() + } + u.unknown[a].Add(b) +} + +func (u *unifier) unifyAll(a Var, b Value) { + if u.isSafe(a) { + u.markAllSafe(b) + } else { + vis := u.varVisitor() + vis.Walk(b) + unsafe := vis.Vars().Diff(u.safe).Diff(u.unified) + if len(unsafe) == 0 { + u.markSafe(a) + } else { + for v := range unsafe { + u.markUnknown(a, v) + } + } + } +} + +func (u *unifier) varVisitor() *VarVisitor { + return NewVarVisitor().WithParams(VarVisitorParams{ + SkipRefHead: true, + SkipObjectKeys: true, + SkipClosures: true, + }) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/varset.go b/vendor/github.com/open-policy-agent/opa/v1/ast/varset.go new file mode 100644 index 000000000..14f531494 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/varset.go @@ -0,0 +1,100 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "sort" +) + +// VarSet represents a set of variables. +type VarSet map[Var]struct{} + +// NewVarSet returns a new VarSet containing the specified variables. +func NewVarSet(vs ...Var) VarSet { + s := VarSet{} + for _, v := range vs { + s.Add(v) + } + return s +} + +// Add updates the set to include the variable "v". +func (s VarSet) Add(v Var) { + s[v] = struct{}{} +} + +// Contains returns true if the set contains the variable "v". +func (s VarSet) Contains(v Var) bool { + _, ok := s[v] + return ok +} + +// Copy returns a shallow copy of the VarSet. +func (s VarSet) Copy() VarSet { + cpy := VarSet{} + for v := range s { + cpy.Add(v) + } + return cpy +} + +// Diff returns a VarSet containing variables in s that are not in vs. +func (s VarSet) Diff(vs VarSet) VarSet { + r := VarSet{} + for v := range s { + if !vs.Contains(v) { + r.Add(v) + } + } + return r +} + +// Equal returns true if s contains exactly the same elements as vs. +func (s VarSet) Equal(vs VarSet) bool { + if len(s.Diff(vs)) > 0 { + return false + } + return len(vs.Diff(s)) == 0 +} + +// Intersect returns a VarSet containing variables in s that are in vs. +func (s VarSet) Intersect(vs VarSet) VarSet { + r := VarSet{} + for v := range s { + if vs.Contains(v) { + r.Add(v) + } + } + return r +} + +// Sorted returns a sorted slice of vars from s. +func (s VarSet) Sorted() []Var { + sorted := make([]Var, 0, len(s)) + for v := range s { + sorted = append(sorted, v) + } + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Compare(sorted[j]) < 0 + }) + return sorted +} + +// Update merges the other VarSet into this VarSet. +func (s VarSet) Update(vs VarSet) { + for v := range vs { + s.Add(v) + } +} + +func (s VarSet) String() string { + tmp := make([]string, 0, len(s)) + for v := range s { + tmp = append(tmp, string(v)) + } + sort.Strings(tmp) + return fmt.Sprintf("%v", tmp) +} diff --git a/vendor/github.com/open-policy-agent/opa/ast/version_index.json b/vendor/github.com/open-policy-agent/opa/v1/ast/version_index.json similarity index 99% rename from vendor/github.com/open-policy-agent/opa/ast/version_index.json rename to vendor/github.com/open-policy-agent/opa/v1/ast/version_index.json index 718df220f..b888b3e02 100644 --- a/vendor/github.com/open-policy-agent/opa/ast/version_index.json +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/version_index.json @@ -1395,6 +1395,13 @@ } }, "features": { + "rego_v1": { + "Major": 1, + "Minor": 0, + "Patch": 0, + "PreRelease": "", + "Metadata": "" + }, "rego_v1_import": { "Major": 0, "Minor": 59, diff --git a/vendor/github.com/open-policy-agent/opa/v1/ast/visit.go b/vendor/github.com/open-policy-agent/opa/v1/ast/visit.go new file mode 100644 index 000000000..91cfa208e --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/ast/visit.go @@ -0,0 +1,783 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +// Visitor defines the interface for iterating AST elements. The Visit function +// can return a Visitor w which will be used to visit the children of the AST +// element v. If the Visit function returns nil, the children will not be +// visited. +// Deprecated: use GenericVisitor or another visitor implementation +type Visitor interface { + Visit(v interface{}) (w Visitor) +} + +// BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before +// and after the AST has been visited. +// Deprecated: use GenericVisitor or another visitor implementation +type BeforeAndAfterVisitor interface { + Visitor + Before(x interface{}) + After(x interface{}) +} + +// Walk iterates the AST by calling the Visit function on the Visitor +// v for x before recursing. +// Deprecated: use GenericVisitor.Walk +func Walk(v Visitor, x interface{}) { + if bav, ok := v.(BeforeAndAfterVisitor); !ok { + walk(v, x) + } else { + bav.Before(x) + defer bav.After(x) + walk(bav, x) + } +} + +// WalkBeforeAndAfter iterates the AST by calling the Visit function on the +// Visitor v for x before recursing. +// Deprecated: use GenericVisitor.Walk +func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x interface{}) { + Walk(v, x) +} + +func walk(v Visitor, x interface{}) { + w := v.Visit(x) + if w == nil { + return + } + switch x := x.(type) { + case *Module: + Walk(w, x.Package) + for i := range x.Imports { + Walk(w, x.Imports[i]) + } + for i := range x.Rules { + Walk(w, x.Rules[i]) + } + for i := range x.Annotations { + Walk(w, x.Annotations[i]) + } + for i := range x.Comments { + Walk(w, x.Comments[i]) + } + case *Package: + Walk(w, x.Path) + case *Import: + Walk(w, x.Path) + Walk(w, x.Alias) + case *Rule: + Walk(w, x.Head) + Walk(w, x.Body) + if x.Else != nil { + Walk(w, x.Else) + } + case *Head: + Walk(w, x.Name) + Walk(w, x.Args) + if x.Key != nil { + Walk(w, x.Key) + } + if x.Value != nil { + Walk(w, x.Value) + } + case Body: + for i := range x { + Walk(w, x[i]) + } + case Args: + for i := range x { + Walk(w, x[i]) + } + case *Expr: + switch ts := x.Terms.(type) { + case *Term, *SomeDecl, *Every: + Walk(w, ts) + case []*Term: + for i := range ts { + Walk(w, ts[i]) + } + } + for i := range x.With { + Walk(w, x.With[i]) + } + case *With: + Walk(w, x.Target) + Walk(w, x.Value) + case *Term: + Walk(w, x.Value) + case Ref: + for i := range x { + Walk(w, x[i]) + } + case *object: + x.Foreach(func(k, vv *Term) { + Walk(w, k) + Walk(w, vv) + }) + case *Array: + x.Foreach(func(t *Term) { + Walk(w, t) + }) + case Set: + x.Foreach(func(t *Term) { + Walk(w, t) + }) + case *ArrayComprehension: + Walk(w, x.Term) + Walk(w, x.Body) + case *ObjectComprehension: + Walk(w, x.Key) + Walk(w, x.Value) + Walk(w, x.Body) + case *SetComprehension: + Walk(w, x.Term) + Walk(w, x.Body) + case Call: + for i := range x { + Walk(w, x[i]) + } + case *Every: + if x.Key != nil { + Walk(w, x.Key) + } + Walk(w, x.Value) + Walk(w, x.Domain) + Walk(w, x.Body) + case *SomeDecl: + for i := range x.Symbols { + Walk(w, x.Symbols[i]) + } + } +} + +// WalkVars calls the function f on all vars under x. If the function f +// returns true, AST nodes under the last node will not be visited. +func WalkVars(x interface{}, f func(Var) bool) { + vis := &GenericVisitor{func(x interface{}) bool { + if v, ok := x.(Var); ok { + return f(v) + } + return false + }} + vis.Walk(x) +} + +// WalkClosures calls the function f on all closures under x. If the function f +// returns true, AST nodes under the last node will not be visited. +func WalkClosures(x interface{}, f func(interface{}) bool) { + vis := &GenericVisitor{func(x interface{}) bool { + switch x := x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: + return f(x) + } + return false + }} + vis.Walk(x) +} + +// WalkRefs calls the function f on all references under x. If the function f +// returns true, AST nodes under the last node will not be visited. +func WalkRefs(x interface{}, f func(Ref) bool) { + vis := &GenericVisitor{func(x interface{}) bool { + if r, ok := x.(Ref); ok { + return f(r) + } + return false + }} + vis.Walk(x) +} + +// WalkTerms calls the function f on all terms under x. If the function f +// returns true, AST nodes under the last node will not be visited. +func WalkTerms(x interface{}, f func(*Term) bool) { + vis := &GenericVisitor{func(x interface{}) bool { + if term, ok := x.(*Term); ok { + return f(term) + } + return false + }} + vis.Walk(x) +} + +// WalkWiths calls the function f on all with modifiers under x. If the function f +// returns true, AST nodes under the last node will not be visited. +func WalkWiths(x interface{}, f func(*With) bool) { + vis := &GenericVisitor{func(x interface{}) bool { + if w, ok := x.(*With); ok { + return f(w) + } + return false + }} + vis.Walk(x) +} + +// WalkExprs calls the function f on all expressions under x. If the function f +// returns true, AST nodes under the last node will not be visited. +func WalkExprs(x interface{}, f func(*Expr) bool) { + vis := &GenericVisitor{func(x interface{}) bool { + if r, ok := x.(*Expr); ok { + return f(r) + } + return false + }} + vis.Walk(x) +} + +// WalkBodies calls the function f on all bodies under x. If the function f +// returns true, AST nodes under the last node will not be visited. +func WalkBodies(x interface{}, f func(Body) bool) { + vis := &GenericVisitor{func(x interface{}) bool { + if b, ok := x.(Body); ok { + return f(b) + } + return false + }} + vis.Walk(x) +} + +// WalkRules calls the function f on all rules under x. If the function f +// returns true, AST nodes under the last node will not be visited. +func WalkRules(x interface{}, f func(*Rule) bool) { + vis := &GenericVisitor{func(x interface{}) bool { + if r, ok := x.(*Rule); ok { + stop := f(r) + // NOTE(tsandall): since rules cannot be embedded inside of queries + // we can stop early if there is no else block. + if stop || r.Else == nil { + return true + } + } + return false + }} + vis.Walk(x) +} + +// WalkNodes calls the function f on all nodes under x. If the function f +// returns true, AST nodes under the last node will not be visited. +func WalkNodes(x interface{}, f func(Node) bool) { + vis := &GenericVisitor{func(x interface{}) bool { + if n, ok := x.(Node); ok { + return f(n) + } + return false + }} + vis.Walk(x) +} + +// GenericVisitor provides a utility to walk over AST nodes using a +// closure. If the closure returns true, the visitor will not walk +// over AST nodes under x. +type GenericVisitor struct { + f func(x interface{}) bool +} + +// NewGenericVisitor returns a new GenericVisitor that will invoke the function +// f on AST nodes. +func NewGenericVisitor(f func(x interface{}) bool) *GenericVisitor { + return &GenericVisitor{f} +} + +// Walk iterates the AST by calling the function f on the +// GenericVisitor before recursing. Contrary to the generic Walk, this +// does not require allocating the visitor from heap. +func (vis *GenericVisitor) Walk(x interface{}) { + if vis.f(x) { + return + } + + switch x := x.(type) { + case *Module: + vis.Walk(x.Package) + for i := range x.Imports { + vis.Walk(x.Imports[i]) + } + for i := range x.Rules { + vis.Walk(x.Rules[i]) + } + for i := range x.Annotations { + vis.Walk(x.Annotations[i]) + } + for i := range x.Comments { + vis.Walk(x.Comments[i]) + } + case *Package: + vis.Walk(x.Path) + case *Import: + vis.Walk(x.Path) + vis.Walk(x.Alias) + case *Rule: + vis.Walk(x.Head) + vis.Walk(x.Body) + if x.Else != nil { + vis.Walk(x.Else) + } + case *Head: + vis.Walk(x.Name) + vis.Walk(x.Args) + if x.Key != nil { + vis.Walk(x.Key) + } + if x.Value != nil { + vis.Walk(x.Value) + } + case Body: + for i := range x { + vis.Walk(x[i]) + } + case Args: + for i := range x { + vis.Walk(x[i]) + } + case *Expr: + switch ts := x.Terms.(type) { + case *Term, *SomeDecl, *Every: + vis.Walk(ts) + case []*Term: + for i := range ts { + vis.Walk(ts[i]) + } + } + for i := range x.With { + vis.Walk(x.With[i]) + } + case *With: + vis.Walk(x.Target) + vis.Walk(x.Value) + case *Term: + vis.Walk(x.Value) + case Ref: + for i := range x { + vis.Walk(x[i]) + } + case *object: + x.Foreach(func(k, _ *Term) { + vis.Walk(k) + vis.Walk(x.Get(k)) + }) + case Object: + for _, k := range x.Keys() { + vis.Walk(k) + vis.Walk(x.Get(k)) + } + case *Array: + for i := 0; i < x.Len(); i++ { + vis.Walk(x.Elem(i)) + } + case Set: + xSlice := x.Slice() + for i := range xSlice { + vis.Walk(xSlice[i]) + } + case *ArrayComprehension: + vis.Walk(x.Term) + vis.Walk(x.Body) + case *ObjectComprehension: + vis.Walk(x.Key) + vis.Walk(x.Value) + vis.Walk(x.Body) + case *SetComprehension: + vis.Walk(x.Term) + vis.Walk(x.Body) + case Call: + for i := range x { + vis.Walk(x[i]) + } + case *Every: + if x.Key != nil { + vis.Walk(x.Key) + } + vis.Walk(x.Value) + vis.Walk(x.Domain) + vis.Walk(x.Body) + case *SomeDecl: + for i := range x.Symbols { + vis.Walk(x.Symbols[i]) + } + } +} + +// BeforeAfterVisitor provides a utility to walk over AST nodes using +// closures. If the before closure returns true, the visitor will not +// walk over AST nodes under x. The after closure is invoked always +// after visiting a node. +type BeforeAfterVisitor struct { + before func(x interface{}) bool + after func(x interface{}) +} + +// NewBeforeAfterVisitor returns a new BeforeAndAfterVisitor that +// will invoke the functions before and after AST nodes. +func NewBeforeAfterVisitor(before func(x interface{}) bool, after func(x interface{})) *BeforeAfterVisitor { + return &BeforeAfterVisitor{before, after} +} + +// Walk iterates the AST by calling the functions on the +// BeforeAndAfterVisitor before and after recursing. Contrary to the +// generic Walk, this does not require allocating the visitor from +// heap. +func (vis *BeforeAfterVisitor) Walk(x interface{}) { + defer vis.after(x) + if vis.before(x) { + return + } + + switch x := x.(type) { + case *Module: + vis.Walk(x.Package) + for i := range x.Imports { + vis.Walk(x.Imports[i]) + } + for i := range x.Rules { + vis.Walk(x.Rules[i]) + } + for i := range x.Annotations { + vis.Walk(x.Annotations[i]) + } + for i := range x.Comments { + vis.Walk(x.Comments[i]) + } + case *Package: + vis.Walk(x.Path) + case *Import: + vis.Walk(x.Path) + vis.Walk(x.Alias) + case *Rule: + vis.Walk(x.Head) + vis.Walk(x.Body) + if x.Else != nil { + vis.Walk(x.Else) + } + case *Head: + if len(x.Reference) > 0 { + vis.Walk(x.Reference) + } else { + vis.Walk(x.Name) + if x.Key != nil { + vis.Walk(x.Key) + } + } + vis.Walk(x.Args) + if x.Value != nil { + vis.Walk(x.Value) + } + case Body: + for i := range x { + vis.Walk(x[i]) + } + case Args: + for i := range x { + vis.Walk(x[i]) + } + case *Expr: + switch ts := x.Terms.(type) { + case *Term, *SomeDecl, *Every: + vis.Walk(ts) + case []*Term: + for i := range ts { + vis.Walk(ts[i]) + } + } + for i := range x.With { + vis.Walk(x.With[i]) + } + case *With: + vis.Walk(x.Target) + vis.Walk(x.Value) + case *Term: + vis.Walk(x.Value) + case Ref: + for i := range x { + vis.Walk(x[i]) + } + case *object: + x.Foreach(func(k, _ *Term) { + vis.Walk(k) + vis.Walk(x.Get(k)) + }) + case Object: + x.Foreach(func(k, _ *Term) { + vis.Walk(k) + vis.Walk(x.Get(k)) + }) + case *Array: + x.Foreach(func(t *Term) { + vis.Walk(t) + }) + case Set: + xSlice := x.Slice() + for i := range xSlice { + vis.Walk(xSlice[i]) + } + case *ArrayComprehension: + vis.Walk(x.Term) + vis.Walk(x.Body) + case *ObjectComprehension: + vis.Walk(x.Key) + vis.Walk(x.Value) + vis.Walk(x.Body) + case *SetComprehension: + vis.Walk(x.Term) + vis.Walk(x.Body) + case Call: + for i := range x { + vis.Walk(x[i]) + } + case *Every: + if x.Key != nil { + vis.Walk(x.Key) + } + vis.Walk(x.Value) + vis.Walk(x.Domain) + vis.Walk(x.Body) + case *SomeDecl: + for i := range x.Symbols { + vis.Walk(x.Symbols[i]) + } + } +} + +// VarVisitor walks AST nodes under a given node and collects all encountered +// variables. The collected variables can be controlled by specifying +// VarVisitorParams when creating the visitor. +type VarVisitor struct { + params VarVisitorParams + vars VarSet +} + +// VarVisitorParams contains settings for a VarVisitor. +type VarVisitorParams struct { + SkipRefHead bool + SkipRefCallHead bool + SkipObjectKeys bool + SkipClosures bool + SkipWithTarget bool + SkipSets bool +} + +// NewVarVisitor returns a new VarVisitor object. +func NewVarVisitor() *VarVisitor { + return &VarVisitor{ + vars: NewVarSet(), + } +} + +// WithParams sets the parameters in params on vis. +func (vis *VarVisitor) WithParams(params VarVisitorParams) *VarVisitor { + vis.params = params + return vis +} + +// Vars returns a VarSet that contains collected vars. +func (vis *VarVisitor) Vars() VarSet { + return vis.vars +} + +// visit determines if the VarVisitor will recurse into x: if it returns `true`, +// the visitor will _skip_ that branch of the AST +func (vis *VarVisitor) visit(v interface{}) bool { + if vis.params.SkipObjectKeys { + if o, ok := v.(Object); ok { + o.Foreach(func(_, v *Term) { + vis.Walk(v) + }) + return true + } + } + if vis.params.SkipRefHead { + if r, ok := v.(Ref); ok { + rSlice := r[1:] + for i := range rSlice { + vis.Walk(rSlice[i]) + } + return true + } + } + if vis.params.SkipClosures { + switch v := v.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension: + return true + case *Expr: + if ev, ok := v.Terms.(*Every); ok { + vis.Walk(ev.Domain) + // We're _not_ walking ev.Body -- that's the closure here + return true + } + } + } + if vis.params.SkipWithTarget { + if v, ok := v.(*With); ok { + vis.Walk(v.Value) + return true + } + } + if vis.params.SkipSets { + if _, ok := v.(Set); ok { + return true + } + } + if vis.params.SkipRefCallHead { + switch v := v.(type) { + case *Expr: + if terms, ok := v.Terms.([]*Term); ok { + termSlice := terms[0].Value.(Ref)[1:] + for i := range termSlice { + vis.Walk(termSlice[i]) + } + for i := 1; i < len(terms); i++ { + vis.Walk(terms[i]) + } + for i := range v.With { + vis.Walk(v.With[i]) + } + return true + } + case Call: + operator := v[0].Value.(Ref) + for i := 1; i < len(operator); i++ { + vis.Walk(operator[i]) + } + for i := 1; i < len(v); i++ { + vis.Walk(v[i]) + } + return true + case *With: + if ref, ok := v.Target.Value.(Ref); ok { + refSlice := ref[1:] + for i := range refSlice { + vis.Walk(refSlice[i]) + } + } + if ref, ok := v.Value.Value.(Ref); ok { + refSlice := ref[1:] + for i := range refSlice { + vis.Walk(refSlice[i]) + } + } else { + vis.Walk(v.Value) + } + return true + } + } + if v, ok := v.(Var); ok { + vis.vars.Add(v) + } + return false +} + +// Walk iterates the AST by calling the function f on the +// GenericVisitor before recursing. Contrary to the generic Walk, this +// does not require allocating the visitor from heap. +func (vis *VarVisitor) Walk(x interface{}) { + if vis.visit(x) { + return + } + + switch x := x.(type) { + case *Module: + vis.Walk(x.Package) + for i := range x.Imports { + vis.Walk(x.Imports[i]) + } + for i := range x.Rules { + vis.Walk(x.Rules[i]) + } + for i := range x.Comments { + vis.Walk(x.Comments[i]) + } + case *Package: + vis.Walk(x.Path) + case *Import: + vis.Walk(x.Path) + vis.Walk(x.Alias) + case *Rule: + vis.Walk(x.Head) + vis.Walk(x.Body) + if x.Else != nil { + vis.Walk(x.Else) + } + case *Head: + if len(x.Reference) > 0 { + vis.Walk(x.Reference) + } else { + vis.Walk(x.Name) + if x.Key != nil { + vis.Walk(x.Key) + } + } + vis.Walk(x.Args) + + if x.Value != nil { + vis.Walk(x.Value) + } + case Body: + for i := range x { + vis.Walk(x[i]) + } + case Args: + for i := range x { + vis.Walk(x[i]) + } + case *Expr: + switch ts := x.Terms.(type) { + case *Term, *SomeDecl, *Every: + vis.Walk(ts) + case []*Term: + for i := range ts { + vis.Walk(ts[i]) + } + } + for i := range x.With { + vis.Walk(x.With[i]) + } + case *With: + vis.Walk(x.Target) + vis.Walk(x.Value) + case *Term: + vis.Walk(x.Value) + case Ref: + for i := range x { + vis.Walk(x[i]) + } + case *object: + x.Foreach(func(k, _ *Term) { + vis.Walk(k) + vis.Walk(x.Get(k)) + }) + case *Array: + x.Foreach(func(t *Term) { + vis.Walk(t) + }) + case Set: + xSlice := x.Slice() + for i := range xSlice { + vis.Walk(xSlice[i]) + } + case *ArrayComprehension: + vis.Walk(x.Term) + vis.Walk(x.Body) + case *ObjectComprehension: + vis.Walk(x.Key) + vis.Walk(x.Value) + vis.Walk(x.Body) + case *SetComprehension: + vis.Walk(x.Term) + vis.Walk(x.Body) + case Call: + for i := range x { + vis.Walk(x[i]) + } + case *Every: + if x.Key != nil { + vis.Walk(x.Key) + } + vis.Walk(x.Value) + vis.Walk(x.Domain) + vis.Walk(x.Body) + case *SomeDecl: + for i := range x.Symbols { + vis.Walk(x.Symbols[i]) + } + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/bundle/bundle.go b/vendor/github.com/open-policy-agent/opa/v1/bundle/bundle.go new file mode 100644 index 000000000..be320f6a7 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/bundle/bundle.go @@ -0,0 +1,1760 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package bundle implements bundle loading. +package bundle + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "os" + "path" + "path/filepath" + "reflect" + "strings" + + "github.com/gobwas/glob" + "github.com/open-policy-agent/opa/internal/file/archive" + "github.com/open-policy-agent/opa/internal/merge" + "github.com/open-policy-agent/opa/v1/ast" + astJSON "github.com/open-policy-agent/opa/v1/ast/json" + "github.com/open-policy-agent/opa/v1/format" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/util" +) + +// Common file extensions and file names. +const ( + RegoExt = ".rego" + WasmFile = "policy.wasm" + PlanFile = "plan.json" + ManifestExt = ".manifest" + SignaturesFile = "signatures.json" + patchFile = "patch.json" + dataFile = "data.json" + yamlDataFile = "data.yaml" + ymlDataFile = "data.yml" + defaultHashingAlg = "SHA-256" + DefaultSizeLimitBytes = (1024 * 1024 * 1024) // limit bundle reads to 1GB to protect against gzip bombs + DeltaBundleType = "delta" + SnapshotBundleType = "snapshot" +) + +// Bundle represents a loaded bundle. The bundle can contain data and policies. +type Bundle struct { + Signatures SignaturesConfig + Manifest Manifest + Data map[string]interface{} + Modules []ModuleFile + Wasm []byte // Deprecated. Use WasmModules instead + WasmModules []WasmModuleFile + PlanModules []PlanModuleFile + Patch Patch + Etag string + Raw []Raw + + lazyLoadingMode bool + sizeLimitBytes int64 +} + +// Raw contains raw bytes representing the bundle's content +type Raw struct { + Path string + Value []byte +} + +// Patch contains an array of objects wherein each object represents the patch operation to be +// applied to the bundle data. +type Patch struct { + Data []PatchOperation `json:"data,omitempty"` +} + +// PatchOperation models a single patch operation against a document. +type PatchOperation struct { + Op string `json:"op"` + Path string `json:"path"` + Value interface{} `json:"value"` +} + +// SignaturesConfig represents an array of JWTs that encapsulate the signatures for the bundle. +type SignaturesConfig struct { + Signatures []string `json:"signatures,omitempty"` + Plugin string `json:"plugin,omitempty"` +} + +// isEmpty returns if the SignaturesConfig is empty. +func (s SignaturesConfig) isEmpty() bool { + return reflect.DeepEqual(s, SignaturesConfig{}) +} + +// DecodedSignature represents the decoded JWT payload. +type DecodedSignature struct { + Files []FileInfo `json:"files"` + KeyID string `json:"keyid"` // Deprecated, use kid in the JWT header instead. + Scope string `json:"scope"` + IssuedAt int64 `json:"iat"` + Issuer string `json:"iss"` +} + +// FileInfo contains the hashing algorithm used, resulting digest etc. +type FileInfo struct { + Name string `json:"name"` + Hash string `json:"hash"` + Algorithm string `json:"algorithm"` +} + +// NewFile returns a new FileInfo. +func NewFile(name, hash, alg string) FileInfo { + return FileInfo{ + Name: name, + Hash: hash, + Algorithm: alg, + } +} + +// Manifest represents the manifest from a bundle. The manifest may contain +// metadata such as the bundle revision. +type Manifest struct { + Revision string `json:"revision"` + Roots *[]string `json:"roots,omitempty"` + WasmResolvers []WasmResolver `json:"wasm,omitempty"` + // RegoVersion is the global Rego version for the bundle described by this Manifest. + // The Rego version of individual files can be overridden in FileRegoVersions. + // We don't use ast.RegoVersion here, as this iota type's order isn't guaranteed to be stable over time. + // We use a pointer so that we can support hand-made bundles that don't have an explicit version appropriately. + // E.g. in OPA 0.x if --v1-compatible is used when consuming the bundle, and there is no specified version, + // we should default to v1; if --v1-compatible isn't used, we should default to v0. In OPA 1.0, no --x-compatible + // flag and no explicit bundle version should default to v1. + RegoVersion *int `json:"rego_version,omitempty"` + // FileRegoVersions is a map from file paths to Rego versions. + // This allows individual files to override the global Rego version specified by RegoVersion. + FileRegoVersions map[string]int `json:"file_rego_versions,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + + compiledFileRegoVersions []fileRegoVersion +} + +type fileRegoVersion struct { + path glob.Glob + version int +} + +// WasmResolver maps a wasm module to an entrypoint ref. +type WasmResolver struct { + Entrypoint string `json:"entrypoint,omitempty"` + Module string `json:"module,omitempty"` + Annotations []*ast.Annotations `json:"annotations,omitempty"` +} + +// Init initializes the manifest. If you instantiate a manifest +// manually, call Init to ensure that the roots are set properly. +func (m *Manifest) Init() { + if m.Roots == nil { + defaultRoots := []string{""} + m.Roots = &defaultRoots + } +} + +// AddRoot adds r to the roots of m. This function is idempotent. +func (m *Manifest) AddRoot(r string) { + m.Init() + if !RootPathsContain(*m.Roots, r) { + *m.Roots = append(*m.Roots, r) + } +} + +func (m *Manifest) SetRegoVersion(v ast.RegoVersion) { + m.Init() + regoVersion := 0 + if v == ast.RegoV1 { + regoVersion = 1 + } + m.RegoVersion = ®oVersion +} + +// Equal returns true if m is semantically equivalent to other. +func (m Manifest) Equal(other Manifest) bool { + + // This is safe since both are passed by value. + m.Init() + other.Init() + + if m.Revision != other.Revision { + return false + } + + if m.RegoVersion == nil && other.RegoVersion != nil { + return false + } + if m.RegoVersion != nil && other.RegoVersion == nil { + return false + } + if m.RegoVersion != nil && other.RegoVersion != nil && *m.RegoVersion != *other.RegoVersion { + return false + } + + // If both are nil, or both are empty, we consider them equal. + if !(len(m.FileRegoVersions) == 0 && len(other.FileRegoVersions) == 0) && + !reflect.DeepEqual(m.FileRegoVersions, other.FileRegoVersions) { + return false + } + + if !reflect.DeepEqual(m.Metadata, other.Metadata) { + return false + } + + return m.equalWasmResolversAndRoots(other) +} + +func (m Manifest) Empty() bool { + return m.Equal(Manifest{}) +} + +// Copy returns a deep copy of the manifest. +func (m Manifest) Copy() Manifest { + m.Init() + roots := make([]string, len(*m.Roots)) + copy(roots, *m.Roots) + m.Roots = &roots + + wasmModules := make([]WasmResolver, len(m.WasmResolvers)) + copy(wasmModules, m.WasmResolvers) + m.WasmResolvers = wasmModules + + metadata := m.Metadata + + if metadata != nil { + m.Metadata = make(map[string]interface{}) + for k, v := range metadata { + m.Metadata[k] = v + } + } + + return m +} + +func (m Manifest) String() string { + m.Init() + if m.RegoVersion != nil { + return fmt.Sprintf("", + m.Revision, *m.RegoVersion, *m.Roots, m.WasmResolvers, m.Metadata) + } + return fmt.Sprintf("", + m.Revision, *m.Roots, m.WasmResolvers, m.Metadata) +} + +func (m Manifest) rootSet() stringSet { + rs := map[string]struct{}{} + + for _, r := range *m.Roots { + rs[r] = struct{}{} + } + + return stringSet(rs) +} + +func (m Manifest) equalWasmResolversAndRoots(other Manifest) bool { + if len(m.WasmResolvers) != len(other.WasmResolvers) { + return false + } + + for i := 0; i < len(m.WasmResolvers); i++ { + if !m.WasmResolvers[i].Equal(&other.WasmResolvers[i]) { + return false + } + } + + return m.rootSet().Equal(other.rootSet()) +} + +func (wr *WasmResolver) Equal(other *WasmResolver) bool { + if wr == nil && other == nil { + return true + } + + if wr == nil || other == nil { + return false + } + + if wr.Module != other.Module { + return false + } + + if wr.Entrypoint != other.Entrypoint { + return false + } + + annotLen := len(wr.Annotations) + if annotLen != len(other.Annotations) { + return false + } + + for i := 0; i < annotLen; i++ { + if wr.Annotations[i].Compare(other.Annotations[i]) != 0 { + return false + } + } + + return true +} + +type stringSet map[string]struct{} + +func (ss stringSet) Equal(other stringSet) bool { + if len(ss) != len(other) { + return false + } + for k := range other { + if _, ok := ss[k]; !ok { + return false + } + } + return true +} + +func (m *Manifest) validateAndInjectDefaults(b Bundle) error { + + m.Init() + + // Validate roots in bundle. + roots := *m.Roots + + // Standardize the roots (no starting or trailing slash) + for i := range roots { + roots[i] = strings.Trim(roots[i], "/") + } + + for i := 0; i < len(roots)-1; i++ { + for j := i + 1; j < len(roots); j++ { + if RootPathsOverlap(roots[i], roots[j]) { + return fmt.Errorf("manifest has overlapped roots: '%v' and '%v'", roots[i], roots[j]) + } + } + } + + // Validate modules in bundle. + for _, module := range b.Modules { + found := false + if path, err := module.Parsed.Package.Path.Ptr(); err == nil { + found = RootPathsContain(roots, path) + } + if !found { + return fmt.Errorf("manifest roots %v do not permit '%v' in module '%v'", roots, module.Parsed.Package, module.Path) + } + } + + // Build a set of wasm module entrypoints to validate + wasmModuleToEps := map[string]string{} + seenEps := map[string]struct{}{} + for _, wm := range b.WasmModules { + wasmModuleToEps[wm.Path] = "" + } + + for _, wmConfig := range b.Manifest.WasmResolvers { + _, ok := wasmModuleToEps[wmConfig.Module] + if !ok { + return fmt.Errorf("manifest references wasm module '%s' but the module file does not exist", wmConfig.Module) + } + + // Ensure wasm module entrypoint in within bundle roots + if !RootPathsContain(roots, wmConfig.Entrypoint) { + return fmt.Errorf("manifest roots %v do not permit '%v' entrypoint for wasm module '%v'", roots, wmConfig.Entrypoint, wmConfig.Module) + } + + if _, ok := seenEps[wmConfig.Entrypoint]; ok { + return fmt.Errorf("entrypoint '%s' cannot be used by more than one wasm module", wmConfig.Entrypoint) + } + seenEps[wmConfig.Entrypoint] = struct{}{} + + wasmModuleToEps[wmConfig.Module] = wmConfig.Entrypoint + } + + // Validate data patches in bundle. + for _, patch := range b.Patch.Data { + path := strings.Trim(patch.Path, "/") + if !RootPathsContain(roots, path) { + return fmt.Errorf("manifest roots %v do not permit data patch at path '%s'", roots, path) + } + } + + if b.lazyLoadingMode { + return nil + } + + // Validate data in bundle. + return dfs(b.Data, "", func(path string, node interface{}) (bool, error) { + path = strings.Trim(path, "/") + if RootPathsContain(roots, path) { + return true, nil + } + + if _, ok := node.(map[string]interface{}); ok { + for i := range roots { + if RootPathsContain(strings.Split(path, "/"), roots[i]) { + return false, nil + } + } + } + return false, fmt.Errorf("manifest roots %v do not permit data at path '/%s' (hint: check bundle directory structure)", roots, path) + }) +} + +// ModuleFile represents a single module contained in a bundle. +type ModuleFile struct { + URL string + Path string + RelativePath string + Raw []byte + Parsed *ast.Module +} + +// WasmModuleFile represents a single wasm module contained in a bundle. +type WasmModuleFile struct { + URL string + Path string + Entrypoints []ast.Ref + Raw []byte +} + +// PlanModuleFile represents a single plan module contained in a bundle. +// +// NOTE(tsandall): currently the plans are just opaque binary blobs. In the +// future we could inject the entrypoints so that the plans could be executed +// inside of OPA proper like we do for Wasm modules. +type PlanModuleFile struct { + URL string + Path string + Raw []byte +} + +// Reader contains the reader to load the bundle from. +type Reader struct { + loader DirectoryLoader + includeManifestInData bool + metrics metrics.Metrics + baseDir string + verificationConfig *VerificationConfig + skipVerify bool + processAnnotations bool + jsonOptions *astJSON.Options + capabilities *ast.Capabilities + files map[string]FileInfo // files in the bundle signature payload + sizeLimitBytes int64 + etag string + lazyLoadingMode bool + name string + persist bool + regoVersion ast.RegoVersion + followSymlinks bool +} + +// NewReader is deprecated. Use NewCustomReader instead. +func NewReader(r io.Reader) *Reader { + return NewCustomReader(NewTarballLoader(r)) +} + +// NewCustomReader returns a new Reader configured to use the +// specified DirectoryLoader. +func NewCustomReader(loader DirectoryLoader) *Reader { + nr := Reader{ + loader: loader, + metrics: metrics.New(), + files: make(map[string]FileInfo), + sizeLimitBytes: DefaultSizeLimitBytes + 1, + } + return &nr +} + +// IncludeManifestInData sets whether the manifest metadata should be +// included in the bundle's data. +func (r *Reader) IncludeManifestInData(includeManifestInData bool) *Reader { + r.includeManifestInData = includeManifestInData + return r +} + +// WithMetrics sets the metrics object to be used while loading bundles +func (r *Reader) WithMetrics(m metrics.Metrics) *Reader { + r.metrics = m + return r +} + +// WithBaseDir sets a base directory for file paths of loaded Rego +// modules. This will *NOT* affect the loaded path of data files. +func (r *Reader) WithBaseDir(dir string) *Reader { + r.baseDir = dir + return r +} + +// WithBundleVerificationConfig sets the key configuration used to verify a signed bundle +func (r *Reader) WithBundleVerificationConfig(config *VerificationConfig) *Reader { + r.verificationConfig = config + return r +} + +// WithSkipBundleVerification skips verification of a signed bundle +func (r *Reader) WithSkipBundleVerification(skipVerify bool) *Reader { + r.skipVerify = skipVerify + return r +} + +// WithProcessAnnotations enables annotation processing during .rego file parsing. +func (r *Reader) WithProcessAnnotations(yes bool) *Reader { + r.processAnnotations = yes + return r +} + +// WithCapabilities sets the supported capabilities when loading the files +func (r *Reader) WithCapabilities(caps *ast.Capabilities) *Reader { + r.capabilities = caps + return r +} + +// WithJSONOptions sets the JSONOptions to use when parsing policy files +func (r *Reader) WithJSONOptions(opts *astJSON.Options) *Reader { + r.jsonOptions = opts + return r +} + +// WithSizeLimitBytes sets the size limit to apply to files in the bundle. If files are larger +// than this, an error will be returned by the reader. +func (r *Reader) WithSizeLimitBytes(n int64) *Reader { + r.sizeLimitBytes = n + 1 + return r +} + +// WithBundleEtag sets the given etag value on the bundle +func (r *Reader) WithBundleEtag(etag string) *Reader { + r.etag = etag + return r +} + +// WithBundleName specifies the bundle name +func (r *Reader) WithBundleName(name string) *Reader { + r.name = name + return r +} + +func (r *Reader) WithFollowSymlinks(yes bool) *Reader { + r.followSymlinks = yes + return r +} + +// WithLazyLoadingMode sets the bundle loading mode. If true, +// bundles will be read in lazy mode. In this mode, data files in the bundle will not be +// deserialized and the check to validate that the bundle data does not contain paths +// outside the bundle's roots will not be performed while reading the bundle. +func (r *Reader) WithLazyLoadingMode(yes bool) *Reader { + r.lazyLoadingMode = yes + return r +} + +// WithBundlePersistence specifies if the downloaded bundle will eventually be persisted to disk. +func (r *Reader) WithBundlePersistence(persist bool) *Reader { + r.persist = persist + return r +} + +func (r *Reader) WithRegoVersion(version ast.RegoVersion) *Reader { + r.regoVersion = version + return r +} + +func (r *Reader) ParserOptions() ast.ParserOptions { + return ast.ParserOptions{ + ProcessAnnotation: r.processAnnotations, + Capabilities: r.capabilities, + JSONOptions: r.jsonOptions, + RegoVersion: r.regoVersion, + } +} + +// Read returns a new Bundle loaded from the reader. +func (r *Reader) Read() (Bundle, error) { + + var bundle Bundle + var descriptors []*Descriptor + var err error + var raw []Raw + + bundle.Signatures, bundle.Patch, descriptors, err = preProcessBundle(r.loader, r.skipVerify, r.sizeLimitBytes) + if err != nil { + return bundle, err + } + + bundle.lazyLoadingMode = r.lazyLoadingMode + bundle.sizeLimitBytes = r.sizeLimitBytes + + if bundle.Type() == SnapshotBundleType { + err = r.checkSignaturesAndDescriptors(bundle.Signatures) + if err != nil { + return bundle, err + } + + bundle.Data = map[string]interface{}{} + } + + var modules []ModuleFile + for _, f := range descriptors { + buf, err := readFile(f, r.sizeLimitBytes) + if err != nil { + return bundle, err + } + + // verify the file content + if bundle.Type() == SnapshotBundleType && !bundle.Signatures.isEmpty() { + path := f.Path() + if r.baseDir != "" { + path = f.URL() + } + path = strings.TrimPrefix(path, "/") + + // check if the file is to be excluded from bundle verification + if r.isFileExcluded(path) { + delete(r.files, path) + } else { + if err = r.verifyBundleFile(path, buf); err != nil { + return bundle, err + } + } + } + + // Normalize the paths to use `/` separators + path := filepath.ToSlash(f.Path()) + + if strings.HasSuffix(path, RegoExt) { + fullPath := r.fullPath(path) + bs := buf.Bytes() + + if r.lazyLoadingMode { + p := fullPath + if r.name != "" { + p = modulePathWithPrefix(r.name, fullPath) + } + + raw = append(raw, Raw{Path: p, Value: bs}) + } + + // Modules are parsed after we've had a chance to read the manifest + mf := ModuleFile{ + URL: f.URL(), + Path: fullPath, + RelativePath: path, + Raw: bs, + } + modules = append(modules, mf) + } else if filepath.Base(path) == WasmFile { + bundle.WasmModules = append(bundle.WasmModules, WasmModuleFile{ + URL: f.URL(), + Path: r.fullPath(path), + Raw: buf.Bytes(), + }) + } else if filepath.Base(path) == PlanFile { + bundle.PlanModules = append(bundle.PlanModules, PlanModuleFile{ + URL: f.URL(), + Path: r.fullPath(path), + Raw: buf.Bytes(), + }) + } else if filepath.Base(path) == dataFile { + if r.lazyLoadingMode { + raw = append(raw, Raw{Path: path, Value: buf.Bytes()}) + continue + } + + var value interface{} + + r.metrics.Timer(metrics.RegoDataParse).Start() + err := util.UnmarshalJSON(buf.Bytes(), &value) + r.metrics.Timer(metrics.RegoDataParse).Stop() + + if err != nil { + return bundle, fmt.Errorf("bundle load failed on %v: %w", r.fullPath(path), err) + } + + if err := insertValue(&bundle, path, value); err != nil { + return bundle, err + } + + } else if filepath.Base(path) == yamlDataFile || filepath.Base(path) == ymlDataFile { + if r.lazyLoadingMode { + raw = append(raw, Raw{Path: path, Value: buf.Bytes()}) + continue + } + + var value interface{} + + r.metrics.Timer(metrics.RegoDataParse).Start() + err := util.Unmarshal(buf.Bytes(), &value) + r.metrics.Timer(metrics.RegoDataParse).Stop() + + if err != nil { + return bundle, fmt.Errorf("bundle load failed on %v: %w", r.fullPath(path), err) + } + + if err := insertValue(&bundle, path, value); err != nil { + return bundle, err + } + + } else if strings.HasSuffix(path, ManifestExt) { + if err := util.NewJSONDecoder(&buf).Decode(&bundle.Manifest); err != nil { + return bundle, fmt.Errorf("bundle load failed on manifest decode: %w", err) + } + } + } + + // Parse modules + popts := r.ParserOptions() + popts.RegoVersion = bundle.RegoVersion(popts.EffectiveRegoVersion()) + for _, mf := range modules { + modulePopts := popts + if modulePopts.RegoVersion, err = bundle.RegoVersionForFile(mf.RelativePath, popts.EffectiveRegoVersion()); err != nil { + return bundle, err + } + r.metrics.Timer(metrics.RegoModuleParse).Start() + mf.Parsed, err = ast.ParseModuleWithOpts(mf.Path, string(mf.Raw), modulePopts) + r.metrics.Timer(metrics.RegoModuleParse).Stop() + if err != nil { + return bundle, err + } + bundle.Modules = append(bundle.Modules, mf) + } + + if bundle.Type() == DeltaBundleType { + if len(bundle.Data) != 0 { + return bundle, fmt.Errorf("delta bundle expected to contain only patch file but data files found") + } + + if len(bundle.Modules) != 0 { + return bundle, fmt.Errorf("delta bundle expected to contain only patch file but policy files found") + } + + if len(bundle.WasmModules) != 0 { + return bundle, fmt.Errorf("delta bundle expected to contain only patch file but wasm files found") + } + + if r.persist { + return bundle, fmt.Errorf("'persist' property is true in config. persisting delta bundle to disk is not supported") + } + } + + // check if the bundle signatures specify any files that weren't found in the bundle + if bundle.Type() == SnapshotBundleType && len(r.files) != 0 { + extra := []string{} + for k := range r.files { + extra = append(extra, k) + } + return bundle, fmt.Errorf("file(s) %v specified in bundle signatures but not found in the target bundle", extra) + } + + if err := bundle.Manifest.validateAndInjectDefaults(bundle); err != nil { + return bundle, err + } + + // Inject the wasm module entrypoint refs into the WasmModuleFile structs + epMap := map[string][]string{} + for _, r := range bundle.Manifest.WasmResolvers { + epMap[r.Module] = append(epMap[r.Module], r.Entrypoint) + } + for i := 0; i < len(bundle.WasmModules); i++ { + entrypoints := epMap[bundle.WasmModules[i].Path] + for _, entrypoint := range entrypoints { + ref, err := ast.PtrRef(ast.DefaultRootDocument, entrypoint) + if err != nil { + return bundle, fmt.Errorf("failed to parse wasm module entrypoint '%s': %s", entrypoint, err) + } + bundle.WasmModules[i].Entrypoints = append(bundle.WasmModules[i].Entrypoints, ref) + } + } + + if r.includeManifestInData { + var metadata map[string]interface{} + + b, err := json.Marshal(&bundle.Manifest) + if err != nil { + return bundle, fmt.Errorf("bundle load failed on manifest marshal: %w", err) + } + + err = util.UnmarshalJSON(b, &metadata) + if err != nil { + return bundle, fmt.Errorf("bundle load failed on manifest unmarshal: %w", err) + } + + // For backwards compatibility always write to the old unnamed manifest path + // This will *not* be correct if >1 bundle is in use... + if err := bundle.insertData(legacyManifestStoragePath, metadata); err != nil { + return bundle, fmt.Errorf("bundle load failed on %v: %w", legacyRevisionStoragePath, err) + } + } + + bundle.Etag = r.etag + bundle.Raw = raw + + return bundle, nil +} + +func (r *Reader) isFileExcluded(path string) bool { + for _, e := range r.verificationConfig.Exclude { + match, _ := filepath.Match(e, path) + if match { + return true + } + } + return false +} + +func (r *Reader) checkSignaturesAndDescriptors(signatures SignaturesConfig) error { + if r.skipVerify { + return nil + } + + if signatures.isEmpty() && r.verificationConfig != nil && r.verificationConfig.KeyID != "" { + return fmt.Errorf("bundle missing .signatures.json file") + } + + if !signatures.isEmpty() { + if r.verificationConfig == nil { + return fmt.Errorf("verification key not provided") + } + + // verify the JWT signatures included in the `.signatures.json` file + if err := r.verifyBundleSignature(signatures); err != nil { + return err + } + } + return nil +} + +func (r *Reader) verifyBundleSignature(sc SignaturesConfig) error { + var err error + r.files, err = VerifyBundleSignature(sc, r.verificationConfig) + return err +} + +func (r *Reader) verifyBundleFile(path string, data bytes.Buffer) error { + return VerifyBundleFile(path, data, r.files) +} + +func (r *Reader) fullPath(path string) string { + if r.baseDir != "" { + path = filepath.Join(r.baseDir, path) + } + return path +} + +// Write is deprecated. Use NewWriter instead. +func Write(w io.Writer, bundle Bundle) error { + return NewWriter(w). + UseModulePath(true). + DisableFormat(true). + Write(bundle) +} + +// Writer implements bundle serialization. +type Writer struct { + usePath bool + disableFormat bool + w io.Writer +} + +// NewWriter returns a bundle writer that writes to w. +func NewWriter(w io.Writer) *Writer { + return &Writer{ + w: w, + } +} + +// UseModulePath configures the writer to use the module file path instead of the +// module file URL during serialization. This is for backwards compatibility. +func (w *Writer) UseModulePath(yes bool) *Writer { + w.usePath = yes + return w +} + +// DisableFormat configures the writer to just write out raw bytes instead +// of formatting modules before serialization. +func (w *Writer) DisableFormat(yes bool) *Writer { + w.disableFormat = yes + return w +} + +// Write writes the bundle to the writer's output stream. +func (w *Writer) Write(bundle Bundle) error { + gw := gzip.NewWriter(w.w) + tw := tar.NewWriter(gw) + + bundleType := bundle.Type() + + if bundleType == SnapshotBundleType { + var buf bytes.Buffer + + if err := json.NewEncoder(&buf).Encode(bundle.Data); err != nil { + return err + } + + if err := archive.WriteFile(tw, "data.json", buf.Bytes()); err != nil { + return err + } + + for _, module := range bundle.Modules { + path := module.URL + if w.usePath { + path = module.Path + } + + if err := archive.WriteFile(tw, path, module.Raw); err != nil { + return err + } + } + + if err := w.writeWasm(tw, bundle); err != nil { + return err + } + + if err := writeSignatures(tw, bundle); err != nil { + return err + } + + if err := w.writePlan(tw, bundle); err != nil { + return err + } + } else if bundleType == DeltaBundleType { + if err := writePatch(tw, bundle); err != nil { + return err + } + } + + if err := writeManifest(tw, bundle); err != nil { + return err + } + + if err := tw.Close(); err != nil { + return err + } + + return gw.Close() +} + +func (w *Writer) writeWasm(tw *tar.Writer, bundle Bundle) error { + for _, wm := range bundle.WasmModules { + path := wm.URL + if w.usePath { + path = wm.Path + } + + err := archive.WriteFile(tw, path, wm.Raw) + if err != nil { + return err + } + } + + if len(bundle.Wasm) > 0 { + err := archive.WriteFile(tw, "/"+WasmFile, bundle.Wasm) + if err != nil { + return err + } + } + + return nil +} + +func (w *Writer) writePlan(tw *tar.Writer, bundle Bundle) error { + for _, wm := range bundle.PlanModules { + path := wm.URL + if w.usePath { + path = wm.Path + } + + err := archive.WriteFile(tw, path, wm.Raw) + if err != nil { + return err + } + } + + return nil +} + +func writeManifest(tw *tar.Writer, bundle Bundle) error { + + if bundle.Manifest.Empty() { + return nil + } + + var buf bytes.Buffer + + if err := json.NewEncoder(&buf).Encode(bundle.Manifest); err != nil { + return err + } + + return archive.WriteFile(tw, ManifestExt, buf.Bytes()) +} + +func writePatch(tw *tar.Writer, bundle Bundle) error { + + var buf bytes.Buffer + + if err := json.NewEncoder(&buf).Encode(bundle.Patch); err != nil { + return err + } + + return archive.WriteFile(tw, patchFile, buf.Bytes()) +} + +func writeSignatures(tw *tar.Writer, bundle Bundle) error { + + if bundle.Signatures.isEmpty() { + return nil + } + + bs, err := json.MarshalIndent(bundle.Signatures, "", " ") + if err != nil { + return err + } + + return archive.WriteFile(tw, fmt.Sprintf(".%v", SignaturesFile), bs) +} + +func hashBundleFiles(hash SignatureHasher, b *Bundle) ([]FileInfo, error) { + + files := []FileInfo{} + + bs, err := hash.HashFile(b.Data) + if err != nil { + return files, err + } + files = append(files, NewFile(strings.TrimPrefix("data.json", "/"), hex.EncodeToString(bs), defaultHashingAlg)) + + if len(b.Wasm) != 0 { + bs, err := hash.HashFile(b.Wasm) + if err != nil { + return files, err + } + files = append(files, NewFile(strings.TrimPrefix(WasmFile, "/"), hex.EncodeToString(bs), defaultHashingAlg)) + } + + for _, wasmModule := range b.WasmModules { + bs, err := hash.HashFile(wasmModule.Raw) + if err != nil { + return files, err + } + files = append(files, NewFile(strings.TrimPrefix(wasmModule.Path, "/"), hex.EncodeToString(bs), defaultHashingAlg)) + } + + for _, planmodule := range b.PlanModules { + bs, err := hash.HashFile(planmodule.Raw) + if err != nil { + return files, err + } + files = append(files, NewFile(strings.TrimPrefix(planmodule.Path, "/"), hex.EncodeToString(bs), defaultHashingAlg)) + } + + // If the manifest is essentially empty, don't add it to the signatures since it + // won't be written to the bundle. Otherwise: + // parse the manifest into a JSON structure; + // then recursively order the fields of all objects alphabetically and then apply + // the hash function to result to compute the hash. + if !b.Manifest.Empty() { + mbs, err := json.Marshal(b.Manifest) + if err != nil { + return files, err + } + + var result map[string]interface{} + if err := util.Unmarshal(mbs, &result); err != nil { + return files, err + } + + bs, err = hash.HashFile(result) + if err != nil { + return files, err + } + + files = append(files, NewFile(strings.TrimPrefix(ManifestExt, "/"), hex.EncodeToString(bs), defaultHashingAlg)) + } + + return files, err +} + +// FormatModules formats Rego modules +// Modules will be formatted to comply with [ast.DefaultRegoVersion], but Rego compatibility of individual parsed modules will be respected (e.g. if 'rego.v1' is imported). +func (b *Bundle) FormatModules(useModulePath bool) error { + return b.FormatModulesForRegoVersion(ast.DefaultRegoVersion, true, useModulePath) +} + +// FormatModulesForRegoVersion formats Rego modules to comply with a given Rego version +func (b *Bundle) FormatModulesForRegoVersion(version ast.RegoVersion, preserveModuleRegoVersion bool, useModulePath bool) error { + var err error + + for i, module := range b.Modules { + opts := format.Opts{} + if preserveModuleRegoVersion { + opts.RegoVersion = module.Parsed.RegoVersion() + opts.ParserOptions = &ast.ParserOptions{ + RegoVersion: opts.RegoVersion, + } + } else { + opts.RegoVersion = version + } + + if module.Raw == nil { + module.Raw, err = format.AstWithOpts(module.Parsed, opts) + if err != nil { + return err + } + } else { + path := module.URL + if useModulePath { + path = module.Path + } + + module.Raw, err = format.SourceWithOpts(path, module.Raw, opts) + if err != nil { + return err + } + } + b.Modules[i].Raw = module.Raw + } + return nil +} + +// GenerateSignature generates the signature for the given bundle. +func (b *Bundle) GenerateSignature(signingConfig *SigningConfig, keyID string, useModulePath bool) error { + + hash, err := NewSignatureHasher(HashingAlgorithm(defaultHashingAlg)) + if err != nil { + return err + } + + files := []FileInfo{} + + for _, module := range b.Modules { + bytes, err := hash.HashFile(module.Raw) + if err != nil { + return err + } + + path := module.URL + if useModulePath { + path = module.Path + } + files = append(files, NewFile(strings.TrimPrefix(path, "/"), hex.EncodeToString(bytes), defaultHashingAlg)) + } + + result, err := hashBundleFiles(hash, b) + if err != nil { + return err + } + files = append(files, result...) + + // generate signed token + token, err := GenerateSignedToken(files, signingConfig, keyID) + if err != nil { + return err + } + + if b.Signatures.isEmpty() { + b.Signatures = SignaturesConfig{} + } + + if signingConfig.Plugin != "" { + b.Signatures.Plugin = signingConfig.Plugin + } + + b.Signatures.Signatures = []string{token} + + return nil +} + +// ParsedModules returns a map of parsed modules with names that are +// unique and human readable for the given a bundle name. +func (b *Bundle) ParsedModules(bundleName string) map[string]*ast.Module { + + mods := make(map[string]*ast.Module, len(b.Modules)) + + for _, mf := range b.Modules { + mods[modulePathWithPrefix(bundleName, mf.Path)] = mf.Parsed + } + + return mods +} + +func (b *Bundle) RegoVersion(def ast.RegoVersion) ast.RegoVersion { + if v := b.Manifest.RegoVersion; v != nil { + if *v == 0 { + return ast.RegoV0 + } else if *v == 1 { + return ast.RegoV1 + } + } + return def +} + +func (b *Bundle) SetRegoVersion(v ast.RegoVersion) { + b.Manifest.SetRegoVersion(v) +} + +// RegoVersionForFile returns the rego-version for the specified file path. +// If there is no defined version for the given path, the default version def is returned. +// If the version does not correspond to ast.RegoV0 or ast.RegoV1, an error is returned. +func (b *Bundle) RegoVersionForFile(path string, def ast.RegoVersion) (ast.RegoVersion, error) { + if def == ast.RegoUndefined { + def = ast.DefaultRegoVersion + } + + version, err := b.Manifest.numericRegoVersionForFile(path) + if err != nil { + return def, err + } else if version == nil { + return def, nil + } else if *version == 0 { + return ast.RegoV0, nil + } else if *version == 1 { + return ast.RegoV1, nil + } + return def, fmt.Errorf("unknown bundle rego-version %d for file '%s'", *version, path) +} + +func (m *Manifest) numericRegoVersionForFile(path string) (*int, error) { + var version *int + + if len(m.FileRegoVersions) != len(m.compiledFileRegoVersions) { + m.compiledFileRegoVersions = make([]fileRegoVersion, 0, len(m.FileRegoVersions)) + for pattern, v := range m.FileRegoVersions { + compiled, err := glob.Compile(pattern) + if err != nil { + return nil, fmt.Errorf("failed to compile glob pattern %s: %s", pattern, err) + } + m.compiledFileRegoVersions = append(m.compiledFileRegoVersions, fileRegoVersion{compiled, v}) + } + } + + for _, fv := range m.compiledFileRegoVersions { + if fv.path.Match(path) { + version = &fv.version + break + } + } + + if version == nil { + version = m.RegoVersion + } + return version, nil +} + +// Equal returns true if this bundle's contents equal the other bundle's +// contents. +func (b Bundle) Equal(other Bundle) bool { + if !reflect.DeepEqual(b.Data, other.Data) { + return false + } + + if len(b.Modules) != len(other.Modules) { + return false + } + for i := range b.Modules { + // To support bundles built from rootless filesystems we ignore a "/" prefix + // for URLs and Paths, such that "/file" and "file" are equivalent + if strings.TrimPrefix(b.Modules[i].URL, string(filepath.Separator)) != + strings.TrimPrefix(other.Modules[i].URL, string(filepath.Separator)) { + return false + } + if strings.TrimPrefix(b.Modules[i].Path, string(filepath.Separator)) != + strings.TrimPrefix(other.Modules[i].Path, string(filepath.Separator)) { + return false + } + if !b.Modules[i].Parsed.Equal(other.Modules[i].Parsed) { + return false + } + if !bytes.Equal(b.Modules[i].Raw, other.Modules[i].Raw) { + return false + } + } + if (b.Wasm == nil && other.Wasm != nil) || (b.Wasm != nil && other.Wasm == nil) { + return false + } + + return bytes.Equal(b.Wasm, other.Wasm) +} + +// Copy returns a deep copy of the bundle. +func (b Bundle) Copy() Bundle { + + // Copy data. + var x interface{} = b.Data + + if err := util.RoundTrip(&x); err != nil { + panic(err) + } + + if x != nil { + b.Data = x.(map[string]interface{}) + } + + // Copy modules. + for i := range b.Modules { + bs := make([]byte, len(b.Modules[i].Raw)) + copy(bs, b.Modules[i].Raw) + b.Modules[i].Raw = bs + b.Modules[i].Parsed = b.Modules[i].Parsed.Copy() + } + + // Copy manifest. + b.Manifest = b.Manifest.Copy() + + return b +} + +func (b *Bundle) insertData(key []string, value interface{}) error { + // Build an object with the full structure for the value + obj, err := mktree(key, value) + if err != nil { + return err + } + + // Merge the new data in with the current bundle data object + merged, ok := merge.InterfaceMaps(b.Data, obj) + if !ok { + return fmt.Errorf("failed to insert data file from path %s", filepath.Join(key...)) + } + + b.Data = merged + + return nil +} + +func (b *Bundle) readData(key []string) *interface{} { + + if len(key) == 0 { + if len(b.Data) == 0 { + return nil + } + var result interface{} = b.Data + return &result + } + + node := b.Data + + for i := 0; i < len(key)-1; i++ { + + child, ok := node[key[i]] + if !ok { + return nil + } + + childObj, ok := child.(map[string]interface{}) + if !ok { + return nil + } + + node = childObj + } + + child, ok := node[key[len(key)-1]] + if !ok { + return nil + } + + return &child +} + +// Type returns the type of the bundle. +func (b *Bundle) Type() string { + if len(b.Patch.Data) != 0 { + return DeltaBundleType + } + return SnapshotBundleType +} + +func mktree(path []string, value interface{}) (map[string]interface{}, error) { + if len(path) == 0 { + // For 0 length path the value is the full tree. + obj, ok := value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("root value must be object") + } + return obj, nil + } + + dir := map[string]interface{}{} + for i := len(path) - 1; i > 0; i-- { + dir[path[i]] = value + value = dir + dir = map[string]interface{}{} + } + dir[path[0]] = value + + return dir, nil +} + +// Merge accepts a set of bundles and merges them into a single result bundle. If there are +// any conflicts during the merge (e.g., with roots) an error is returned. The result bundle +// will have an empty revision except in the special case where a single bundle is provided +// (and in that case the bundle is just returned unmodified.) +func Merge(bundles []*Bundle) (*Bundle, error) { + return MergeWithRegoVersion(bundles, ast.DefaultRegoVersion, false) +} + +// MergeWithRegoVersion creates a merged bundle from the provided bundles, similar to Merge. +// If more than one bundle is provided, the rego version of the result bundle is set to the provided regoVersion. +// Any Rego files in a bundle of conflicting rego version will be marked in the result's manifest with the rego version +// of its original bundle. If the Rego file already had an overriding rego version, it will be preserved. +// If a single bundle is provided, it will retain any rego version information it already had. If it has none, the +// provided regoVersion will be applied to it. +// If usePath is true, per-file rego-versions will be calculated using the file's ModuleFile.Path; otherwise, the file's +// ModuleFile.URL will be used. +func MergeWithRegoVersion(bundles []*Bundle, regoVersion ast.RegoVersion, usePath bool) (*Bundle, error) { + + if len(bundles) == 0 { + return nil, errors.New("expected at least one bundle") + } + + if regoVersion == ast.RegoUndefined { + regoVersion = ast.DefaultRegoVersion + } + + if len(bundles) == 1 { + result := bundles[0] + // We respect the bundle rego-version, defaulting to the provided rego version if not set. + result.SetRegoVersion(result.RegoVersion(regoVersion)) + fileRegoVersions, err := bundleRegoVersions(result, result.RegoVersion(regoVersion), usePath) + if err != nil { + return nil, err + } + result.Manifest.FileRegoVersions = fileRegoVersions + return result, nil + } + + var roots []string + var result Bundle + + for _, b := range bundles { + + if b.Manifest.Roots == nil { + return nil, errors.New("bundle manifest not initialized") + } + + roots = append(roots, *b.Manifest.Roots...) + + result.Modules = append(result.Modules, b.Modules...) + + for _, root := range *b.Manifest.Roots { + key := strings.Split(root, "/") + if val := b.readData(key); val != nil { + if err := result.insertData(key, *val); err != nil { + return nil, err + } + } + } + + result.Manifest.WasmResolvers = append(result.Manifest.WasmResolvers, b.Manifest.WasmResolvers...) + result.WasmModules = append(result.WasmModules, b.WasmModules...) + result.PlanModules = append(result.PlanModules, b.PlanModules...) + + if b.Manifest.RegoVersion != nil || len(b.Manifest.FileRegoVersions) > 0 { + if result.Manifest.FileRegoVersions == nil { + result.Manifest.FileRegoVersions = map[string]int{} + } + + fileRegoVersions, err := bundleRegoVersions(b, regoVersion, usePath) + if err != nil { + return nil, err + } + for k, v := range fileRegoVersions { + result.Manifest.FileRegoVersions[k] = v + } + } + } + + // We respect the bundle rego-version, defaulting to the provided rego version if not set. + result.SetRegoVersion(result.RegoVersion(regoVersion)) + + if result.Data == nil { + result.Data = map[string]interface{}{} + } + + result.Manifest.Roots = &roots + + if err := result.Manifest.validateAndInjectDefaults(result); err != nil { + return nil, err + } + + return &result, nil +} + +func bundleRegoVersions(bundle *Bundle, regoVersion ast.RegoVersion, usePath bool) (map[string]int, error) { + fileRegoVersions := map[string]int{} + + // we drop the bundle-global rego versions and record individual rego versions for each module. + for _, m := range bundle.Modules { + // We fetch rego-version by the path relative to the bundle root, as the complete path of the module might + // contain the path between OPA working directory and the bundle root. + v, err := bundle.RegoVersionForFile(bundleRelativePath(m, usePath), bundle.RegoVersion(regoVersion)) + if err != nil { + return nil, err + } + // only record the rego version if it's different from one applied globally to the result bundle + if v != regoVersion { + // We store the rego version by the absolute path to the bundle root, as this will be the - possibly new - path + // to the module inside the merged bundle. + fileRegoVersions[bundleAbsolutePath(m, usePath)] = v.Int() + } + } + + return fileRegoVersions, nil +} + +func bundleRelativePath(m ModuleFile, usePath bool) string { + p := m.RelativePath + if p == "" { + if usePath { + p = m.Path + } else { + p = m.URL + } + } + return p +} + +func bundleAbsolutePath(m ModuleFile, usePath bool) string { + var p string + if usePath { + p = m.Path + } else { + p = m.URL + } + if !path.IsAbs(p) { + p = "/" + p + } + return path.Clean(p) +} + +// RootPathsOverlap takes in two bundle root paths and returns true if they overlap. +func RootPathsOverlap(pathA string, pathB string) bool { + a := rootPathSegments(pathA) + b := rootPathSegments(pathB) + return rootContains(a, b) || rootContains(b, a) +} + +// RootPathsContain takes a set of bundle root paths and returns true if the path is contained. +func RootPathsContain(roots []string, path string) bool { + segments := rootPathSegments(path) + for i := range roots { + if rootContains(rootPathSegments(roots[i]), segments) { + return true + } + } + return false +} + +func rootPathSegments(path string) []string { + return strings.Split(path, "/") +} + +func rootContains(root []string, other []string) bool { + + // A single segment, empty string root always contains the other. + if len(root) == 1 && root[0] == "" { + return true + } + + if len(root) > len(other) { + return false + } + + for j := range root { + if root[j] != other[j] { + return false + } + } + + return true +} + +func insertValue(b *Bundle, path string, value interface{}) error { + if err := b.insertData(getNormalizedPath(path), value); err != nil { + return fmt.Errorf("bundle load failed on %v: %w", path, err) + } + return nil +} + +func getNormalizedPath(path string) []string { + // Remove leading / and . characters from the directory path. If the bundle + // was written with OPA then the paths will contain a leading slash. On the + // other hand, if the path is empty, filepath.Dir will return '.'. + // Note: filepath.Dir can return paths with '\' separators, always use + // filepath.ToSlash to keep them normalized. + dirpath := strings.TrimLeft(normalizePath(filepath.Dir(path)), "/.") + var key []string + if dirpath != "" { + key = strings.Split(dirpath, "/") + } + return key +} + +func dfs(value interface{}, path string, fn func(string, interface{}) (bool, error)) error { + if stop, err := fn(path, value); err != nil { + return err + } else if stop { + return nil + } + obj, ok := value.(map[string]interface{}) + if !ok { + return nil + } + for key := range obj { + if err := dfs(obj[key], path+"/"+key, fn); err != nil { + return err + } + } + return nil +} + +func modulePathWithPrefix(bundleName string, modulePath string) string { + // Default prefix is just the bundle name + prefix := bundleName + + // Bundle names are sometimes just file paths, some of which + // are full urls (file:///foo/). Parse these and only use the path. + parsed, err := url.Parse(bundleName) + if err == nil { + prefix = filepath.Join(parsed.Host, parsed.Path) + } + + // Note: filepath.Join can return paths with '\' separators, always use + // filepath.ToSlash to keep them normalized. + return normalizePath(filepath.Join(prefix, modulePath)) +} + +// IsStructuredDoc checks if the file name equals a structured file extension ex. ".json" +func IsStructuredDoc(name string) bool { + return filepath.Base(name) == dataFile || filepath.Base(name) == yamlDataFile || + filepath.Base(name) == SignaturesFile || filepath.Base(name) == ManifestExt +} + +func preProcessBundle(loader DirectoryLoader, skipVerify bool, sizeLimitBytes int64) (SignaturesConfig, Patch, []*Descriptor, error) { + descriptors := []*Descriptor{} + var signatures SignaturesConfig + var patch Patch + + for { + f, err := loader.NextFile() + if err == io.EOF { + break + } + + if err != nil { + return signatures, patch, nil, fmt.Errorf("bundle read failed: %w", err) + } + + // check for the signatures file + if !skipVerify && strings.HasSuffix(f.Path(), SignaturesFile) { + buf, err := readFile(f, sizeLimitBytes) + if err != nil { + return signatures, patch, nil, err + } + + if err := util.NewJSONDecoder(&buf).Decode(&signatures); err != nil { + return signatures, patch, nil, fmt.Errorf("bundle load failed on signatures decode: %w", err) + } + } else if !strings.HasSuffix(f.Path(), SignaturesFile) { + descriptors = append(descriptors, f) + + if filepath.Base(f.Path()) == patchFile { + + var b bytes.Buffer + tee := io.TeeReader(f.reader, &b) + f.reader = tee + + buf, err := readFile(f, sizeLimitBytes) + if err != nil { + return signatures, patch, nil, err + } + + if err := util.NewJSONDecoder(&buf).Decode(&patch); err != nil { + return signatures, patch, nil, fmt.Errorf("bundle load failed on patch decode: %w", err) + } + + f.reader = &b + } + } + } + return signatures, patch, descriptors, nil +} + +func readFile(f *Descriptor, sizeLimitBytes int64) (bytes.Buffer, error) { + // Case for pre-loaded byte buffers, like those from the tarballLoader. + if bb, ok := f.reader.(*bytes.Buffer); ok { + _ = f.Close() // always close, even on error + + if int64(bb.Len()) >= sizeLimitBytes { + return *bb, fmt.Errorf("bundle file '%v' size (%d bytes) exceeded max size (%v bytes)", + strings.TrimPrefix(f.Path(), "/"), bb.Len(), sizeLimitBytes-1) + } + + return *bb, nil + } + + // Case for *lazyFile readers: + if lf, ok := f.reader.(*lazyFile); ok { + var buf bytes.Buffer + if lf.file == nil { + var err error + if lf.file, err = os.Open(lf.path); err != nil { + return buf, fmt.Errorf("failed to open file %s: %w", f.path, err) + } + } + // Bail out if we can't read the whole file-- there's nothing useful we can do at that point! + fileSize, _ := fstatFileSize(lf.file) + if fileSize > sizeLimitBytes { + return buf, fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(f.Path(), "/"), fileSize, sizeLimitBytes-1) + } + // Prealloc the buffer for the file read. + buffer := make([]byte, fileSize) + _, err := io.ReadFull(lf.file, buffer) + if err != nil { + return buf, err + } + _ = lf.file.Close() // always close, even on error + + // Note(philipc): Replace the lazyFile reader in the *Descriptor with a + // pointer to the wrapping bytes.Buffer, so that we don't re-read the + // file on disk again by accident. + buf = *bytes.NewBuffer(buffer) + f.reader = &buf + return buf, nil + } + + // Fallback case: + var buf bytes.Buffer + n, err := f.Read(&buf, sizeLimitBytes) + _ = f.Close() // always close, even on error + + if err != nil && err != io.EOF { + return buf, err + } else if err == nil && n >= sizeLimitBytes { + return buf, fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(f.Path(), "/"), n, sizeLimitBytes-1) + } + + return buf, nil +} + +// Takes an already open file handle and invokes the os.Stat system call on it +// to determine the file's size. Passes any errors from *File.Stat on up to the +// caller. +func fstatFileSize(f *os.File) (int64, error) { + fileInfo, err := f.Stat() + if err != nil { + return 0, err + } + return fileInfo.Size(), nil +} + +func normalizePath(p string) string { + return filepath.ToSlash(p) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/bundle/file.go b/vendor/github.com/open-policy-agent/opa/v1/bundle/file.go new file mode 100644 index 000000000..2f1cdb8ea --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/bundle/file.go @@ -0,0 +1,508 @@ +package bundle + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + "github.com/open-policy-agent/opa/v1/loader/filter" + + "github.com/open-policy-agent/opa/v1/storage" +) + +const maxSizeLimitBytesErrMsg = "bundle file %s size (%d bytes) exceeds configured size_limit_bytes (%d bytes)" + +// Descriptor contains information about a file and +// can be used to read the file contents. +type Descriptor struct { + url string + path string + reader io.Reader + closer io.Closer + closeOnce *sync.Once +} + +// lazyFile defers reading the file until the first call of Read +type lazyFile struct { + path string + file *os.File +} + +// newLazyFile creates a new instance of lazyFile +func newLazyFile(path string) *lazyFile { + return &lazyFile{path: path} +} + +// Read implements io.Reader. It will check if the file has been opened +// and open it if it has not before attempting to read using the file's +// read method +func (f *lazyFile) Read(b []byte) (int, error) { + var err error + + if f.file == nil { + if f.file, err = os.Open(f.path); err != nil { + return 0, fmt.Errorf("failed to open file %s: %w", f.path, err) + } + } + + return f.file.Read(b) +} + +// Close closes the lazy file if it has been opened using the file's +// close method +func (f *lazyFile) Close() error { + if f.file != nil { + return f.file.Close() + } + + return nil +} + +func NewDescriptor(url, path string, reader io.Reader) *Descriptor { + return &Descriptor{ + url: url, + path: path, + reader: reader, + } +} + +func (d *Descriptor) WithCloser(closer io.Closer) *Descriptor { + d.closer = closer + d.closeOnce = new(sync.Once) + return d +} + +// Path returns the path of the file. +func (d *Descriptor) Path() string { + return d.path +} + +// URL returns the url of the file. +func (d *Descriptor) URL() string { + return d.url +} + +// Read will read all the contents from the file the Descriptor refers to +// into the dest writer up n bytes. Will return an io.EOF error +// if EOF is encountered before n bytes are read. +func (d *Descriptor) Read(dest io.Writer, n int64) (int64, error) { + n, err := io.CopyN(dest, d.reader, n) + return n, err +} + +// Close the file, on some Loader implementations this might be a no-op. +// It should *always* be called regardless of file. +func (d *Descriptor) Close() error { + var err error + if d.closer != nil { + d.closeOnce.Do(func() { + err = d.closer.Close() + }) + } + return err +} + +type PathFormat int64 + +const ( + Chrooted PathFormat = iota + SlashRooted + Passthrough +) + +// DirectoryLoader defines an interface which can be used to load +// files from a directory by iterating over each one in the tree. +type DirectoryLoader interface { + // NextFile must return io.EOF if there is no next value. The returned + // descriptor should *always* be closed when no longer needed. + NextFile() (*Descriptor, error) + WithFilter(filter filter.LoaderFilter) DirectoryLoader + WithPathFormat(PathFormat) DirectoryLoader + WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader + WithFollowSymlinks(followSymlinks bool) DirectoryLoader +} + +type dirLoader struct { + root string + files []string + idx int + filter filter.LoaderFilter + pathFormat PathFormat + maxSizeLimitBytes int64 + followSymlinks bool +} + +// Normalize root directory, ex "./src/bundle" -> "src/bundle" +// We don't need an absolute path, but this makes the joined/trimmed +// paths more uniform. +func normalizeRootDirectory(root string) string { + if len(root) > 1 { + if root[0] == '.' && root[1] == filepath.Separator { + if len(root) == 2 { + root = root[:1] // "./" -> "." + } else { + root = root[2:] // remove leading "./" + } + } + } + return root +} + +// NewDirectoryLoader returns a basic DirectoryLoader implementation +// that will load files from a given root directory path. +func NewDirectoryLoader(root string) DirectoryLoader { + d := dirLoader{ + root: normalizeRootDirectory(root), + pathFormat: Chrooted, + } + return &d +} + +// WithFilter specifies the filter object to use to filter files while loading bundles +func (d *dirLoader) WithFilter(filter filter.LoaderFilter) DirectoryLoader { + d.filter = filter + return d +} + +// WithPathFormat specifies how a path is formatted in a Descriptor +func (d *dirLoader) WithPathFormat(pathFormat PathFormat) DirectoryLoader { + d.pathFormat = pathFormat + return d +} + +// WithSizeLimitBytes specifies the maximum size of any file in the directory to read +func (d *dirLoader) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { + d.maxSizeLimitBytes = sizeLimitBytes + return d +} + +// WithFollowSymlinks specifies whether to follow symlinks when loading files from the directory +func (d *dirLoader) WithFollowSymlinks(followSymlinks bool) DirectoryLoader { + d.followSymlinks = followSymlinks + return d +} + +func formatPath(fileName string, root string, pathFormat PathFormat) string { + switch pathFormat { + case SlashRooted: + if !strings.HasPrefix(fileName, string(filepath.Separator)) { + return string(filepath.Separator) + fileName + } + return fileName + case Chrooted: + // Trim off the root directory and return path as if chrooted + result := strings.TrimPrefix(fileName, filepath.FromSlash(root)) + if root == "." && filepath.Base(fileName) == ManifestExt { + result = fileName + } + if !strings.HasPrefix(result, string(filepath.Separator)) { + result = string(filepath.Separator) + result + } + return result + case Passthrough: + fallthrough + default: + return fileName + } +} + +// NextFile iterates to the next file in the directory tree +// and returns a file Descriptor for the file. +func (d *dirLoader) NextFile() (*Descriptor, error) { + // build a list of all files we will iterate over and read, but only one time + if d.files == nil { + d.files = []string{} + err := filepath.Walk(d.root, func(path string, info os.FileInfo, _ error) error { + if info == nil { + return nil + } + + if info.Mode().IsRegular() { + if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { + return nil + } + if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { + return fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(path, "/"), info.Size(), d.maxSizeLimitBytes) + } + d.files = append(d.files, path) + } else if d.followSymlinks && info.Mode().Type()&fs.ModeSymlink == fs.ModeSymlink { + if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { + return nil + } + if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { + return fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(path, "/"), info.Size(), d.maxSizeLimitBytes) + } + d.files = append(d.files, path) + } else if info.Mode().IsDir() { + if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, true)) { + return filepath.SkipDir + } + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to list files: %w", err) + } + } + + // If done reading files then just return io.EOF + // errors for each NextFile() call + if d.idx >= len(d.files) { + return nil, io.EOF + } + + fileName := d.files[d.idx] + d.idx++ + fh := newLazyFile(fileName) + + cleanedPath := formatPath(fileName, d.root, d.pathFormat) + f := NewDescriptor(filepath.Join(d.root, cleanedPath), cleanedPath, fh).WithCloser(fh) + return f, nil +} + +type tarballLoader struct { + baseURL string + r io.Reader + tr *tar.Reader + files []file + idx int + filter filter.LoaderFilter + skipDir map[string]struct{} + pathFormat PathFormat + maxSizeLimitBytes int64 +} + +type file struct { + name string + reader io.Reader + path storage.Path + raw []byte +} + +// NewTarballLoader is deprecated. Use NewTarballLoaderWithBaseURL instead. +func NewTarballLoader(r io.Reader) DirectoryLoader { + l := tarballLoader{ + r: r, + pathFormat: Passthrough, + } + return &l +} + +// NewTarballLoaderWithBaseURL returns a new DirectoryLoader that reads +// files out of a gzipped tar archive. The file URLs will be prefixed +// with the baseURL. +func NewTarballLoaderWithBaseURL(r io.Reader, baseURL string) DirectoryLoader { + l := tarballLoader{ + baseURL: strings.TrimSuffix(baseURL, "/"), + r: r, + pathFormat: Passthrough, + } + return &l +} + +// WithFilter specifies the filter object to use to filter files while loading bundles +func (t *tarballLoader) WithFilter(filter filter.LoaderFilter) DirectoryLoader { + t.filter = filter + return t +} + +// WithPathFormat specifies how a path is formatted in a Descriptor +func (t *tarballLoader) WithPathFormat(pathFormat PathFormat) DirectoryLoader { + t.pathFormat = pathFormat + return t +} + +// WithSizeLimitBytes specifies the maximum size of any file in the tarball to read +func (t *tarballLoader) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { + t.maxSizeLimitBytes = sizeLimitBytes + return t +} + +// WithFollowSymlinks is a no-op for tarballLoader +func (t *tarballLoader) WithFollowSymlinks(_ bool) DirectoryLoader { + return t +} + +// NextFile iterates to the next file in the directory tree +// and returns a file Descriptor for the file. +func (t *tarballLoader) NextFile() (*Descriptor, error) { + if t.tr == nil { + gr, err := gzip.NewReader(t.r) + if err != nil { + return nil, fmt.Errorf("archive read failed: %w", err) + } + + t.tr = tar.NewReader(gr) + } + + if t.files == nil { + t.files = []file{} + + if t.skipDir == nil { + t.skipDir = map[string]struct{}{} + } + + for { + header, err := t.tr.Next() + + if err == io.EOF { + break + } + + if err != nil { + return nil, err + } + + // Keep iterating on the archive until we find a normal file + if header.Typeflag == tar.TypeReg { + + if t.filter != nil { + + if t.filter(filepath.ToSlash(header.Name), header.FileInfo(), getdepth(header.Name, false)) { + continue + } + + basePath := strings.Trim(filepath.Dir(filepath.ToSlash(header.Name)), "/") + + // check if the directory is to be skipped + if _, ok := t.skipDir[basePath]; ok { + continue + } + + match := false + for p := range t.skipDir { + if strings.HasPrefix(basePath, p) { + match = true + break + } + } + + if match { + continue + } + } + + if t.maxSizeLimitBytes > 0 && header.Size > t.maxSizeLimitBytes { + return nil, fmt.Errorf(maxSizeLimitBytesErrMsg, header.Name, header.Size, t.maxSizeLimitBytes) + } + + f := file{name: header.Name} + + // Note(philipc): We rely on the previous size check in this loop for safety. + buf := bytes.NewBuffer(make([]byte, 0, header.Size)) + if _, err := io.Copy(buf, t.tr); err != nil { + return nil, fmt.Errorf("failed to copy file %s: %w", header.Name, err) + } + + f.reader = buf + + t.files = append(t.files, f) + } else if header.Typeflag == tar.TypeDir { + cleanedPath := filepath.ToSlash(header.Name) + if t.filter != nil && t.filter(cleanedPath, header.FileInfo(), getdepth(header.Name, true)) { + t.skipDir[strings.Trim(cleanedPath, "/")] = struct{}{} + } + } + } + } + + // If done reading files then just return io.EOF + // errors for each NextFile() call + if t.idx >= len(t.files) { + return nil, io.EOF + } + + f := t.files[t.idx] + t.idx++ + + cleanedPath := formatPath(f.name, "", t.pathFormat) + d := NewDescriptor(filepath.Join(t.baseURL, cleanedPath), cleanedPath, f.reader) + return d, nil +} + +// Next implements the storage.Iterator interface. +// It iterates to the next policy or data file in the directory tree +// and returns a storage.Update for the file. +func (it *iterator) Next() (*storage.Update, error) { + if it.files == nil { + it.files = []file{} + + for _, item := range it.raw { + f := file{name: item.Path} + + fpath := strings.TrimLeft(normalizePath(filepath.Dir(f.name)), "/.") + if strings.HasSuffix(f.name, RegoExt) { + fpath = strings.Trim(normalizePath(f.name), "/") + } + + p, ok := storage.ParsePathEscaped("/" + fpath) + if !ok { + return nil, fmt.Errorf("storage path invalid: %v", f.name) + } + f.path = p + + f.raw = item.Value + + it.files = append(it.files, f) + } + + sortFilePathAscend(it.files) + } + + // If done reading files then just return io.EOF + // errors for each NextFile() call + if it.idx >= len(it.files) { + return nil, io.EOF + } + + f := it.files[it.idx] + it.idx++ + + isPolicy := false + if strings.HasSuffix(f.name, RegoExt) { + isPolicy = true + } + + return &storage.Update{ + Path: f.path, + Value: f.raw, + IsPolicy: isPolicy, + }, nil +} + +type iterator struct { + raw []Raw + files []file + idx int +} + +func NewIterator(raw []Raw) storage.Iterator { + it := iterator{ + raw: raw, + } + return &it +} + +func sortFilePathAscend(files []file) { + sort.Slice(files, func(i, j int) bool { + return len(files[i].path) < len(files[j].path) + }) +} + +func getdepth(path string, isDir bool) int { + if isDir { + cleanedPath := strings.Trim(filepath.ToSlash(path), "/") + return len(strings.Split(cleanedPath, "/")) + } + + basePath := strings.Trim(filepath.Dir(filepath.ToSlash(path)), "/") + return len(strings.Split(basePath, "/")) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/bundle/filefs.go b/vendor/github.com/open-policy-agent/opa/v1/bundle/filefs.go new file mode 100644 index 000000000..7ab3de989 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/bundle/filefs.go @@ -0,0 +1,143 @@ +//go:build go1.16 +// +build go1.16 + +package bundle + +import ( + "fmt" + "io" + "io/fs" + "path/filepath" + "sync" + + "github.com/open-policy-agent/opa/v1/loader/filter" +) + +const ( + defaultFSLoaderRoot = "." +) + +type dirLoaderFS struct { + sync.Mutex + filesystem fs.FS + files []string + idx int + filter filter.LoaderFilter + root string + pathFormat PathFormat + maxSizeLimitBytes int64 + followSymlinks bool +} + +// NewFSLoader returns a basic DirectoryLoader implementation +// that will load files from a fs.FS interface +func NewFSLoader(filesystem fs.FS) (DirectoryLoader, error) { + return NewFSLoaderWithRoot(filesystem, defaultFSLoaderRoot), nil +} + +// NewFSLoaderWithRoot returns a basic DirectoryLoader implementation +// that will load files from a fs.FS interface at the supplied root +func NewFSLoaderWithRoot(filesystem fs.FS, root string) DirectoryLoader { + d := dirLoaderFS{ + filesystem: filesystem, + root: normalizeRootDirectory(root), + pathFormat: Chrooted, + } + + return &d +} + +func (d *dirLoaderFS) walkDir(path string, dirEntry fs.DirEntry, err error) error { + if err != nil { + return err + } + + if dirEntry != nil { + info, err := dirEntry.Info() + if err != nil { + return err + } + + if dirEntry.Type().IsRegular() { + if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { + return nil + } + + if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { + return fmt.Errorf("file %s size %d exceeds limit of %d", path, info.Size(), d.maxSizeLimitBytes) + } + + d.files = append(d.files, path) + } else if dirEntry.Type()&fs.ModeSymlink != 0 && d.followSymlinks { + if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { + return nil + } + + if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { + return fmt.Errorf("file %s size %d exceeds limit of %d", path, info.Size(), d.maxSizeLimitBytes) + } + + d.files = append(d.files, path) + } else if dirEntry.Type().IsDir() { + if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, true)) { + return fs.SkipDir + } + } + } + return nil +} + +// WithFilter specifies the filter object to use to filter files while loading bundles +func (d *dirLoaderFS) WithFilter(filter filter.LoaderFilter) DirectoryLoader { + d.filter = filter + return d +} + +// WithPathFormat specifies how a path is formatted in a Descriptor +func (d *dirLoaderFS) WithPathFormat(pathFormat PathFormat) DirectoryLoader { + d.pathFormat = pathFormat + return d +} + +// WithSizeLimitBytes specifies the maximum size of any file in the filesystem directory to read +func (d *dirLoaderFS) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { + d.maxSizeLimitBytes = sizeLimitBytes + return d +} + +func (d *dirLoaderFS) WithFollowSymlinks(followSymlinks bool) DirectoryLoader { + d.followSymlinks = followSymlinks + return d +} + +// NextFile iterates to the next file in the directory tree +// and returns a file Descriptor for the file. +func (d *dirLoaderFS) NextFile() (*Descriptor, error) { + d.Lock() + defer d.Unlock() + + if d.files == nil { + err := fs.WalkDir(d.filesystem, d.root, d.walkDir) + if err != nil { + return nil, fmt.Errorf("failed to list files: %w", err) + } + } + + // If done reading files then just return io.EOF + // errors for each NextFile() call + if d.idx >= len(d.files) { + return nil, io.EOF + } + + fileName := d.files[d.idx] + d.idx++ + + fh, err := d.filesystem.Open(fileName) + if err != nil { + return nil, fmt.Errorf("failed to open file %s: %w", fileName, err) + } + + cleanedPath := formatPath(fileName, d.root, d.pathFormat) + f := NewDescriptor(cleanedPath, cleanedPath, fh).WithCloser(fh) + return f, nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/bundle/hash.go b/vendor/github.com/open-policy-agent/opa/v1/bundle/hash.go new file mode 100644 index 000000000..021801bb0 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/bundle/hash.go @@ -0,0 +1,141 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package bundle + +import ( + "bytes" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/json" + "fmt" + "hash" + "io" + "sort" + "strings" +) + +// HashingAlgorithm represents a subset of hashing algorithms implemented in Go +type HashingAlgorithm string + +// Supported values for HashingAlgorithm +const ( + MD5 HashingAlgorithm = "MD5" + SHA1 HashingAlgorithm = "SHA-1" + SHA224 HashingAlgorithm = "SHA-224" + SHA256 HashingAlgorithm = "SHA-256" + SHA384 HashingAlgorithm = "SHA-384" + SHA512 HashingAlgorithm = "SHA-512" + SHA512224 HashingAlgorithm = "SHA-512-224" + SHA512256 HashingAlgorithm = "SHA-512-256" +) + +// String returns the string representation of a HashingAlgorithm +func (alg HashingAlgorithm) String() string { + return string(alg) +} + +// SignatureHasher computes a signature digest for a file with (structured or unstructured) data and policy +type SignatureHasher interface { + HashFile(v interface{}) ([]byte, error) +} + +type hasher struct { + h func() hash.Hash // hash function factory +} + +// NewSignatureHasher returns a signature hasher suitable for a particular hashing algorithm +func NewSignatureHasher(alg HashingAlgorithm) (SignatureHasher, error) { + h := &hasher{} + + switch alg { + case MD5: + h.h = md5.New + case SHA1: + h.h = sha1.New + case SHA224: + h.h = sha256.New224 + case SHA256: + h.h = sha256.New + case SHA384: + h.h = sha512.New384 + case SHA512: + h.h = sha512.New + case SHA512224: + h.h = sha512.New512_224 + case SHA512256: + h.h = sha512.New512_256 + default: + return nil, fmt.Errorf("unsupported hashing algorithm: %s", alg) + } + + return h, nil +} + +// HashFile hashes the file content, JSON or binary, both in golang native format. +func (h *hasher) HashFile(v interface{}) ([]byte, error) { + hf := h.h() + walk(v, hf) + return hf.Sum(nil), nil +} + +// walk hashes the file content, JSON or binary, both in golang native format. +// +// Computation for unstructured documents is a hash of the document. +// +// Computation for the types of structured JSON document is as follows: +// +// object: Hash {, then each key (in alphabetical order) and digest of the value, then comma (between items) and finally }. +// +// array: Hash [, then digest of the value, then comma (between items) and finally ]. +func walk(v interface{}, h io.Writer) { + + switch x := v.(type) { + case map[string]interface{}: + _, _ = h.Write([]byte("{")) + + var keys []string + for k := range x { + keys = append(keys, k) + } + sort.Strings(keys) + + for i, key := range keys { + if i > 0 { + _, _ = h.Write([]byte(",")) + } + + _, _ = h.Write(encodePrimitive(key)) + _, _ = h.Write([]byte(":")) + walk(x[key], h) + } + + _, _ = h.Write([]byte("}")) + case []interface{}: + _, _ = h.Write([]byte("[")) + + for i, e := range x { + if i > 0 { + _, _ = h.Write([]byte(",")) + } + walk(e, h) + } + + _, _ = h.Write([]byte("]")) + case []byte: + _, _ = h.Write(x) + default: + _, _ = h.Write(encodePrimitive(x)) + } +} + +func encodePrimitive(v interface{}) []byte { + var buf bytes.Buffer + encoder := json.NewEncoder(&buf) + encoder.SetEscapeHTML(false) + _ = encoder.Encode(v) + return []byte(strings.Trim(buf.String(), "\n")) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/bundle/keys.go b/vendor/github.com/open-policy-agent/opa/v1/bundle/keys.go new file mode 100644 index 000000000..aad30a675 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/bundle/keys.go @@ -0,0 +1,144 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package bundle provide helpers that assist in creating the verification and signing key configuration +package bundle + +import ( + "encoding/pem" + "fmt" + "os" + + "github.com/open-policy-agent/opa/internal/jwx/jwa" + "github.com/open-policy-agent/opa/internal/jwx/jws/sign" + "github.com/open-policy-agent/opa/v1/keys" + + "github.com/open-policy-agent/opa/v1/util" +) + +const ( + defaultTokenSigningAlg = "RS256" +) + +// KeyConfig holds the keys used to sign or verify bundles and tokens +// Moved to own package, alias kept for backwards compatibility +type KeyConfig = keys.Config + +// VerificationConfig represents the key configuration used to verify a signed bundle +type VerificationConfig struct { + PublicKeys map[string]*KeyConfig + KeyID string `json:"keyid"` + Scope string `json:"scope"` + Exclude []string `json:"exclude_files"` +} + +// NewVerificationConfig return a new VerificationConfig +func NewVerificationConfig(keys map[string]*KeyConfig, id, scope string, exclude []string) *VerificationConfig { + return &VerificationConfig{ + PublicKeys: keys, + KeyID: id, + Scope: scope, + Exclude: exclude, + } +} + +// ValidateAndInjectDefaults validates the config and inserts default values +func (vc *VerificationConfig) ValidateAndInjectDefaults(keys map[string]*KeyConfig) error { + vc.PublicKeys = keys + + if vc.KeyID != "" { + found := false + for key := range keys { + if key == vc.KeyID { + found = true + break + } + } + + if !found { + return fmt.Errorf("key id %s not found", vc.KeyID) + } + } + return nil +} + +// GetPublicKey returns the public key corresponding to the given key id +func (vc *VerificationConfig) GetPublicKey(id string) (*KeyConfig, error) { + var kc *KeyConfig + var ok bool + + if kc, ok = vc.PublicKeys[id]; !ok { + return nil, fmt.Errorf("verification key corresponding to ID %v not found", id) + } + return kc, nil +} + +// SigningConfig represents the key configuration used to generate a signed bundle +type SigningConfig struct { + Plugin string + Key string + Algorithm string + ClaimsPath string +} + +// NewSigningConfig return a new SigningConfig +func NewSigningConfig(key, alg, claimsPath string) *SigningConfig { + if alg == "" { + alg = defaultTokenSigningAlg + } + + return &SigningConfig{ + Plugin: defaultSignerID, + Key: key, + Algorithm: alg, + ClaimsPath: claimsPath, + } +} + +// WithPlugin sets the signing plugin in the signing config +func (s *SigningConfig) WithPlugin(plugin string) *SigningConfig { + if plugin != "" { + s.Plugin = plugin + } + return s +} + +// GetPrivateKey returns the private key or secret from the signing config +func (s *SigningConfig) GetPrivateKey() (interface{}, error) { + + block, _ := pem.Decode([]byte(s.Key)) + if block != nil { + return sign.GetSigningKey(s.Key, jwa.SignatureAlgorithm(s.Algorithm)) + } + + var priv string + if _, err := os.Stat(s.Key); err == nil { + bs, err := os.ReadFile(s.Key) + if err != nil { + return nil, err + } + priv = string(bs) + } else if os.IsNotExist(err) { + priv = s.Key + } else { + return nil, err + } + + return sign.GetSigningKey(priv, jwa.SignatureAlgorithm(s.Algorithm)) +} + +// GetClaims returns the claims by reading the file specified in the signing config +func (s *SigningConfig) GetClaims() (map[string]interface{}, error) { + var claims map[string]interface{} + + bs, err := os.ReadFile(s.ClaimsPath) + if err != nil { + return claims, err + } + + if err := util.UnmarshalJSON(bs, &claims); err != nil { + return claims, err + } + return claims, nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/bundle/sign.go b/vendor/github.com/open-policy-agent/opa/v1/bundle/sign.go new file mode 100644 index 000000000..cf9a3e183 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/bundle/sign.go @@ -0,0 +1,135 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package bundle provide helpers that assist in the creating a signed bundle +package bundle + +import ( + "crypto/rand" + "encoding/json" + "fmt" + + "github.com/open-policy-agent/opa/internal/jwx/jwa" + "github.com/open-policy-agent/opa/internal/jwx/jws" +) + +const defaultSignerID = "_default" + +var signers map[string]Signer + +// Signer is the interface expected for implementations that generate bundle signatures. +type Signer interface { + GenerateSignedToken([]FileInfo, *SigningConfig, string) (string, error) +} + +// GenerateSignedToken will retrieve the Signer implementation based on the Plugin specified +// in SigningConfig, and call its implementation of GenerateSignedToken. The signer generates +// a signed token given the list of files to be included in the payload and the bundle +// signing config. The keyID if non-empty, represents the value for the "keyid" claim in the token. +func GenerateSignedToken(files []FileInfo, sc *SigningConfig, keyID string) (string, error) { + var plugin string + // for backwards compatibility, check if there is no plugin specified, and use default + if sc.Plugin == "" { + plugin = defaultSignerID + } else { + plugin = sc.Plugin + } + signer, err := GetSigner(plugin) + if err != nil { + return "", err + } + return signer.GenerateSignedToken(files, sc, keyID) +} + +// DefaultSigner is the default bundle signing implementation. It signs bundles by generating +// a JWT and signing it using a locally-accessible private key. +type DefaultSigner struct{} + +// GenerateSignedToken generates a signed token given the list of files to be +// included in the payload and the bundle signing config. The keyID if non-empty, +// represents the value for the "keyid" claim in the token +func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, keyID string) (string, error) { + payload, err := generatePayload(files, sc, keyID) + if err != nil { + return "", err + } + + privateKey, err := sc.GetPrivateKey() + if err != nil { + return "", err + } + + var headers jws.StandardHeaders + + if err := headers.Set(jws.AlgorithmKey, jwa.SignatureAlgorithm(sc.Algorithm)); err != nil { + return "", err + } + + if keyID != "" { + if err := headers.Set(jws.KeyIDKey, keyID); err != nil { + return "", err + } + } + + hdr, err := json.Marshal(headers) + if err != nil { + return "", err + } + + token, err := jws.SignLiteral(payload, + jwa.SignatureAlgorithm(sc.Algorithm), + privateKey, + hdr, + rand.Reader) + if err != nil { + return "", err + } + return string(token), nil +} + +func generatePayload(files []FileInfo, sc *SigningConfig, keyID string) ([]byte, error) { + payload := make(map[string]interface{}) + payload["files"] = files + + if sc.ClaimsPath != "" { + claims, err := sc.GetClaims() + if err != nil { + return nil, err + } + + for claim, value := range claims { + payload[claim] = value + } + } else { + if keyID != "" { + // keyid claim is deprecated but include it for backwards compatibility. + payload["keyid"] = keyID + } + } + return json.Marshal(payload) +} + +// GetSigner returns the Signer registered under the given id +func GetSigner(id string) (Signer, error) { + signer, ok := signers[id] + if !ok { + return nil, fmt.Errorf("no signer exists under id %s", id) + } + return signer, nil +} + +// RegisterSigner registers a Signer under the given id +func RegisterSigner(id string, s Signer) error { + if id == defaultSignerID { + return fmt.Errorf("signer id %s is reserved, use a different id", id) + } + signers[id] = s + return nil +} + +func init() { + signers = map[string]Signer{ + defaultSignerID: &DefaultSigner{}, + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/bundle/store.go b/vendor/github.com/open-policy-agent/opa/v1/bundle/store.go new file mode 100644 index 000000000..b9c8c3388 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/bundle/store.go @@ -0,0 +1,1036 @@ +// Copyright 2019 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package bundle + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "path/filepath" + "strings" + + iCompiler "github.com/open-policy-agent/opa/internal/compiler" + "github.com/open-policy-agent/opa/internal/json/patch" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/util" +) + +// BundlesBasePath is the storage path used for storing bundle metadata +var BundlesBasePath = storage.MustParsePath("/system/bundles") + +// Note: As needed these helpers could be memoized. + +// ManifestStoragePath is the storage path used for the given named bundle manifest. +func ManifestStoragePath(name string) storage.Path { + return append(BundlesBasePath, name, "manifest") +} + +// EtagStoragePath is the storage path used for the given named bundle etag. +func EtagStoragePath(name string) storage.Path { + return append(BundlesBasePath, name, "etag") +} + +func namedBundlePath(name string) storage.Path { + return append(BundlesBasePath, name) +} + +func rootsPath(name string) storage.Path { + return append(BundlesBasePath, name, "manifest", "roots") +} + +func revisionPath(name string) storage.Path { + return append(BundlesBasePath, name, "manifest", "revision") +} + +func wasmModulePath(name string) storage.Path { + return append(BundlesBasePath, name, "wasm") +} + +func wasmEntrypointsPath(name string) storage.Path { + return append(BundlesBasePath, name, "manifest", "wasm") +} + +func metadataPath(name string) storage.Path { + return append(BundlesBasePath, name, "manifest", "metadata") +} + +func read(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (interface{}, error) { + value, err := store.Read(ctx, txn, path) + if err != nil { + return nil, err + } + + if astValue, ok := value.(ast.Value); ok { + value, err = ast.JSON(astValue) + if err != nil { + return nil, err + } + } + + return value, nil +} + +// ReadBundleNamesFromStore will return a list of bundle names which have had their metadata stored. +func ReadBundleNamesFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) ([]string, error) { + value, err := read(ctx, store, txn, BundlesBasePath) + if err != nil { + return nil, err + } + + bundleMap, ok := value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("corrupt manifest roots") + } + + bundles := make([]string, len(bundleMap)) + idx := 0 + for name := range bundleMap { + bundles[idx] = name + idx++ + } + return bundles, nil +} + +// WriteManifestToStore will write the manifest into the storage. This function is called when +// the bundle is activated. +func WriteManifestToStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string, manifest Manifest) error { + return write(ctx, store, txn, ManifestStoragePath(name), manifest) +} + +// WriteEtagToStore will write the bundle etag into the storage. This function is called when the bundle is activated. +func WriteEtagToStore(ctx context.Context, store storage.Store, txn storage.Transaction, name, etag string) error { + return write(ctx, store, txn, EtagStoragePath(name), etag) +} + +func write(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path, value interface{}) error { + if err := util.RoundTrip(&value); err != nil { + return err + } + + var dir []string + if len(path) > 1 { + dir = path[:len(path)-1] + } + + if err := storage.MakeDir(ctx, store, txn, dir); err != nil { + return err + } + + return store.Write(ctx, txn, storage.AddOp, path, value) +} + +// EraseManifestFromStore will remove the manifest from storage. This function is called +// when the bundle is deactivated. +func EraseManifestFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) error { + path := namedBundlePath(name) + err := store.Write(ctx, txn, storage.RemoveOp, path, nil) + return suppressNotFound(err) +} + +// eraseBundleEtagFromStore will remove the bundle etag from storage. This function is called +// when the bundle is deactivated. +func eraseBundleEtagFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) error { + path := EtagStoragePath(name) + err := store.Write(ctx, txn, storage.RemoveOp, path, nil) + return suppressNotFound(err) +} + +func suppressNotFound(err error) error { + if err == nil || storage.IsNotFound(err) { + return nil + } + return err +} + +func writeWasmModulesToStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string, b *Bundle) error { + basePath := wasmModulePath(name) + for _, wm := range b.WasmModules { + path := append(basePath, wm.Path) + err := write(ctx, store, txn, path, base64.StdEncoding.EncodeToString(wm.Raw)) + if err != nil { + return err + } + } + return nil +} + +func eraseWasmModulesFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) error { + path := wasmModulePath(name) + + err := store.Write(ctx, txn, storage.RemoveOp, path, nil) + return suppressNotFound(err) +} + +// ReadWasmMetadataFromStore will read Wasm module resolver metadata from the store. +func ReadWasmMetadataFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) ([]WasmResolver, error) { + path := wasmEntrypointsPath(name) + value, err := read(ctx, store, txn, path) + if err != nil { + return nil, err + } + + bs, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("corrupt wasm manifest data") + } + + var wasmMetadata []WasmResolver + + err = util.UnmarshalJSON(bs, &wasmMetadata) + if err != nil { + return nil, fmt.Errorf("corrupt wasm manifest data") + } + + return wasmMetadata, nil +} + +// ReadWasmModulesFromStore will write Wasm module resolver metadata from the store. +func ReadWasmModulesFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) (map[string][]byte, error) { + path := wasmModulePath(name) + value, err := read(ctx, store, txn, path) + if err != nil { + return nil, err + } + + encodedModules, ok := value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("corrupt wasm modules") + } + + rawModules := map[string][]byte{} + for path, enc := range encodedModules { + encStr, ok := enc.(string) + if !ok { + return nil, fmt.Errorf("corrupt wasm modules") + } + bs, err := base64.StdEncoding.DecodeString(encStr) + if err != nil { + return nil, err + } + rawModules[path] = bs + } + return rawModules, nil +} + +// ReadBundleRootsFromStore returns the roots in the specified bundle. +// If the bundle is not activated, this function will return +// storage NotFound error. +func ReadBundleRootsFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) ([]string, error) { + value, err := read(ctx, store, txn, rootsPath(name)) + if err != nil { + return nil, err + } + + sl, ok := value.([]interface{}) + if !ok { + return nil, fmt.Errorf("corrupt manifest roots") + } + + roots := make([]string, len(sl)) + + for i := range sl { + roots[i], ok = sl[i].(string) + if !ok { + return nil, fmt.Errorf("corrupt manifest root") + } + } + + return roots, nil +} + +// ReadBundleRevisionFromStore returns the revision in the specified bundle. +// If the bundle is not activated, this function will return +// storage NotFound error. +func ReadBundleRevisionFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) (string, error) { + return readRevisionFromStore(ctx, store, txn, revisionPath(name)) +} + +func readRevisionFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (string, error) { + value, err := read(ctx, store, txn, path) + if err != nil { + return "", err + } + + str, ok := value.(string) + if !ok { + return "", fmt.Errorf("corrupt manifest revision") + } + + return str, nil +} + +// ReadBundleMetadataFromStore returns the metadata in the specified bundle. +// If the bundle is not activated, this function will return +// storage NotFound error. +func ReadBundleMetadataFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) (map[string]interface{}, error) { + return readMetadataFromStore(ctx, store, txn, metadataPath(name)) +} + +func readMetadataFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (map[string]interface{}, error) { + value, err := read(ctx, store, txn, path) + if err != nil { + return nil, suppressNotFound(err) + } + + data, ok := value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("corrupt manifest metadata") + } + + return data, nil +} + +// ReadBundleEtagFromStore returns the etag for the specified bundle. +// If the bundle is not activated, this function will return +// storage NotFound error. +func ReadBundleEtagFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) (string, error) { + return readEtagFromStore(ctx, store, txn, EtagStoragePath(name)) +} + +func readEtagFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (string, error) { + value, err := read(ctx, store, txn, path) + if err != nil { + return "", err + } + + str, ok := value.(string) + if !ok { + return "", fmt.Errorf("corrupt bundle etag") + } + + return str, nil +} + +// ActivateOpts defines options for the Activate API call. +type ActivateOpts struct { + Ctx context.Context + Store storage.Store + Txn storage.Transaction + TxnCtx *storage.Context + Compiler *ast.Compiler + Metrics metrics.Metrics + Bundles map[string]*Bundle // Optional + ExtraModules map[string]*ast.Module // Optional + AuthorizationDecisionRef ast.Ref + ParserOptions ast.ParserOptions + + legacy bool +} + +// Activate the bundle(s) by loading into the given Store. This will load policies, data, and record +// the manifest in storage. The compiler provided will have had the polices compiled on it. +func Activate(opts *ActivateOpts) error { + opts.legacy = false + return activateBundles(opts) +} + +// DeactivateOpts defines options for the Deactivate API call +type DeactivateOpts struct { + Ctx context.Context + Store storage.Store + Txn storage.Transaction + BundleNames map[string]struct{} + ParserOptions ast.ParserOptions +} + +// Deactivate the bundle(s). This will erase associated data, policies, and the manifest entry from the store. +func Deactivate(opts *DeactivateOpts) error { + erase := map[string]struct{}{} + for name := range opts.BundleNames { + roots, err := ReadBundleRootsFromStore(opts.Ctx, opts.Store, opts.Txn, name) + if suppressNotFound(err) != nil { + return err + } + for _, root := range roots { + erase[root] = struct{}{} + } + } + _, err := eraseBundles(opts.Ctx, opts.Store, opts.Txn, opts.ParserOptions, opts.BundleNames, erase) + return err +} + +func activateBundles(opts *ActivateOpts) error { + + // Build collections of bundle names, modules, and roots to erase + erase := map[string]struct{}{} + names := map[string]struct{}{} + deltaBundles := map[string]*Bundle{} + snapshotBundles := map[string]*Bundle{} + + for name, b := range opts.Bundles { + if b.Type() == DeltaBundleType { + deltaBundles[name] = b + } else { + snapshotBundles[name] = b + names[name] = struct{}{} + + roots, err := ReadBundleRootsFromStore(opts.Ctx, opts.Store, opts.Txn, name) + if suppressNotFound(err) != nil { + return err + } + for _, root := range roots { + erase[root] = struct{}{} + } + + // Erase data at new roots to prepare for writing the new data + for _, root := range *b.Manifest.Roots { + erase[root] = struct{}{} + } + } + } + + // Before changing anything make sure the roots don't collide with any + // other bundles that already are activated or other bundles being activated. + err := hasRootsOverlap(opts.Ctx, opts.Store, opts.Txn, opts.Bundles) + if err != nil { + return err + } + + if len(deltaBundles) != 0 { + err := activateDeltaBundles(opts, deltaBundles) + if err != nil { + return err + } + } + + // Erase data and policies at new + old roots, and remove the old + // manifests before activating a new snapshot bundle. + remaining, err := eraseBundles(opts.Ctx, opts.Store, opts.Txn, opts.ParserOptions, names, erase) + if err != nil { + return err + } + + // Validate data in bundle does not contain paths outside the bundle's roots. + for _, b := range snapshotBundles { + + if b.lazyLoadingMode { + + for _, item := range b.Raw { + path := filepath.ToSlash(item.Path) + + if filepath.Base(path) == dataFile || filepath.Base(path) == yamlDataFile { + var val map[string]json.RawMessage + err = util.Unmarshal(item.Value, &val) + if err == nil { + err = doDFS(val, filepath.Dir(strings.Trim(path, "/")), *b.Manifest.Roots) + if err != nil { + return err + } + } else { + // Build an object for the value + p := getNormalizedPath(path) + + if len(p) == 0 { + return fmt.Errorf("root value must be object") + } + + // verify valid YAML or JSON value + var x interface{} + err := util.Unmarshal(item.Value, &x) + if err != nil { + return err + } + + value := item.Value + dir := map[string]json.RawMessage{} + for i := len(p) - 1; i > 0; i-- { + dir[p[i]] = value + + bs, err := json.Marshal(dir) + if err != nil { + return err + } + + value = bs + dir = map[string]json.RawMessage{} + } + dir[p[0]] = value + + err = doDFS(dir, filepath.Dir(strings.Trim(path, "/")), *b.Manifest.Roots) + if err != nil { + return err + } + } + } + } + } + } + + // Compile the modules all at once to avoid having to re-do work. + remainingAndExtra := make(map[string]*ast.Module) + for name, mod := range remaining { + remainingAndExtra[name] = mod + } + for name, mod := range opts.ExtraModules { + remainingAndExtra[name] = mod + } + + err = compileModules(opts.Compiler, opts.Metrics, snapshotBundles, remainingAndExtra, opts.legacy, opts.AuthorizationDecisionRef) + if err != nil { + return err + } + + if err := writeDataAndModules(opts.Ctx, opts.Store, opts.Txn, opts.TxnCtx, snapshotBundles, opts.legacy); err != nil { + return err + } + + if err := ast.CheckPathConflicts(opts.Compiler, storage.NonEmpty(opts.Ctx, opts.Store, opts.Txn)); len(err) > 0 { + return err + } + + for name, b := range snapshotBundles { + if err := writeManifestToStore(opts, name, b.Manifest); err != nil { + return err + } + + if err := writeEtagToStore(opts, name, b.Etag); err != nil { + return err + } + + if err := writeWasmModulesToStore(opts.Ctx, opts.Store, opts.Txn, name, b); err != nil { + return err + } + } + + return nil +} + +func doDFS(obj map[string]json.RawMessage, path string, roots []string) error { + if len(roots) == 1 && roots[0] == "" { + return nil + } + + for key := range obj { + + newPath := filepath.Join(strings.Trim(path, "/"), key) + + // Note: filepath.Join can return paths with '\' separators, always use + // filepath.ToSlash to keep them normalized. + newPath = strings.TrimLeft(normalizePath(newPath), "/.") + + contains := false + prefix := false + if RootPathsContain(roots, newPath) { + contains = true + } else { + for i := range roots { + if strings.HasPrefix(strings.Trim(roots[i], "/"), newPath) { + prefix = true + break + } + } + } + + if !contains && !prefix { + return fmt.Errorf("manifest roots %v do not permit data at path '/%s' (hint: check bundle directory structure)", roots, newPath) + } + + if contains { + continue + } + + var next map[string]json.RawMessage + err := util.Unmarshal(obj[key], &next) + if err != nil { + return fmt.Errorf("manifest roots %v do not permit data at path '/%s' (hint: check bundle directory structure)", roots, newPath) + } + + if err := doDFS(next, newPath, roots); err != nil { + return err + } + } + return nil +} + +func activateDeltaBundles(opts *ActivateOpts, bundles map[string]*Bundle) error { + + // Check that the manifest roots and wasm resolvers in the delta bundle + // match with those currently in the store + for name, b := range bundles { + value, err := opts.Store.Read(opts.Ctx, opts.Txn, ManifestStoragePath(name)) + if err != nil { + if storage.IsNotFound(err) { + continue + } + return err + } + + manifest, err := valueToManifest(value) + if err != nil { + return fmt.Errorf("corrupt manifest data: %w", err) + } + + if !b.Manifest.equalWasmResolversAndRoots(manifest) { + return fmt.Errorf("delta bundle '%s' has wasm resolvers or manifest roots that are different from those in the store", name) + } + } + + for _, b := range bundles { + err := applyPatches(opts.Ctx, opts.Store, opts.Txn, b.Patch.Data) + if err != nil { + return err + } + } + + if err := ast.CheckPathConflicts(opts.Compiler, storage.NonEmpty(opts.Ctx, opts.Store, opts.Txn)); len(err) > 0 { + return err + } + + for name, b := range bundles { + if err := writeManifestToStore(opts, name, b.Manifest); err != nil { + return err + } + + if err := writeEtagToStore(opts, name, b.Etag); err != nil { + return err + } + } + + return nil +} + +func valueToManifest(v interface{}) (Manifest, error) { + if astV, ok := v.(ast.Value); ok { + var err error + v, err = ast.JSON(astV) + if err != nil { + return Manifest{}, err + } + } + + var manifest Manifest + + bs, err := json.Marshal(v) + if err != nil { + return Manifest{}, err + } + + err = util.UnmarshalJSON(bs, &manifest) + if err != nil { + return Manifest{}, err + } + + return manifest, nil +} + +// erase bundles by name and roots. This will clear all policies and data at its roots and remove its +// manifest from storage. +func eraseBundles(ctx context.Context, store storage.Store, txn storage.Transaction, parserOpts ast.ParserOptions, names map[string]struct{}, roots map[string]struct{}) (map[string]*ast.Module, error) { + + if err := eraseData(ctx, store, txn, roots); err != nil { + return nil, err + } + + remaining, err := erasePolicies(ctx, store, txn, parserOpts, roots) + if err != nil { + return nil, err + } + + for name := range names { + if err := EraseManifestFromStore(ctx, store, txn, name); suppressNotFound(err) != nil { + return nil, err + } + + if err := LegacyEraseManifestFromStore(ctx, store, txn); suppressNotFound(err) != nil { + return nil, err + } + + if err := eraseBundleEtagFromStore(ctx, store, txn, name); suppressNotFound(err) != nil { + return nil, err + } + + if err := eraseWasmModulesFromStore(ctx, store, txn, name); suppressNotFound(err) != nil { + return nil, err + } + } + + return remaining, nil +} + +func eraseData(ctx context.Context, store storage.Store, txn storage.Transaction, roots map[string]struct{}) error { + for root := range roots { + path, ok := storage.ParsePathEscaped("/" + root) + if !ok { + return fmt.Errorf("manifest root path invalid: %v", root) + } + + if len(path) > 0 { + if err := store.Write(ctx, txn, storage.RemoveOp, path, nil); suppressNotFound(err) != nil { + return err + } + } + } + return nil +} + +func erasePolicies(ctx context.Context, store storage.Store, txn storage.Transaction, parserOpts ast.ParserOptions, roots map[string]struct{}) (map[string]*ast.Module, error) { + + ids, err := store.ListPolicies(ctx, txn) + if err != nil { + return nil, err + } + + remaining := map[string]*ast.Module{} + + for _, id := range ids { + bs, err := store.GetPolicy(ctx, txn, id) + if err != nil { + return nil, err + } + module, err := ast.ParseModuleWithOpts(id, string(bs), parserOpts) + if err != nil { + return nil, err + } + path, err := module.Package.Path.Ptr() + if err != nil { + return nil, err + } + deleted := false + for root := range roots { + if RootPathsContain([]string{root}, path) { + if err := store.DeletePolicy(ctx, txn, id); err != nil { + return nil, err + } + deleted = true + break + } + } + if !deleted { + remaining[id] = module + } + } + + return remaining, nil +} + +func writeManifestToStore(opts *ActivateOpts, name string, manifest Manifest) error { + // Always write manifests to the named location. If the plugin is in the older style config + // then also write to the old legacy unnamed location. + if err := WriteManifestToStore(opts.Ctx, opts.Store, opts.Txn, name, manifest); err != nil { + return err + } + + if opts.legacy { + if err := LegacyWriteManifestToStore(opts.Ctx, opts.Store, opts.Txn, manifest); err != nil { + return err + } + } + + return nil +} + +func writeEtagToStore(opts *ActivateOpts, name, etag string) error { + if err := WriteEtagToStore(opts.Ctx, opts.Store, opts.Txn, name, etag); err != nil { + return err + } + + return nil +} + +func writeDataAndModules(ctx context.Context, store storage.Store, txn storage.Transaction, txnCtx *storage.Context, bundles map[string]*Bundle, legacy bool) error { + params := storage.WriteParams + params.Context = txnCtx + + for name, b := range bundles { + if len(b.Raw) == 0 { + // Write data from each new bundle into the store. Only write under the + // roots contained in their manifest. + if err := writeData(ctx, store, txn, *b.Manifest.Roots, b.Data); err != nil { + return err + } + + for _, mf := range b.Modules { + var path string + + // For backwards compatibility, in legacy mode, upsert policies to + // the unprefixed path. + if legacy { + path = mf.Path + } else { + path = modulePathWithPrefix(name, mf.Path) + } + + if err := store.UpsertPolicy(ctx, txn, path, mf.Raw); err != nil { + return err + } + } + } else { + params.BasePaths = *b.Manifest.Roots + + err := store.Truncate(ctx, txn, params, NewIterator(b.Raw)) + if err != nil { + return fmt.Errorf("store truncate failed for bundle '%s': %v", name, err) + } + } + } + + return nil +} + +func writeData(ctx context.Context, store storage.Store, txn storage.Transaction, roots []string, data map[string]interface{}) error { + for _, root := range roots { + path, ok := storage.ParsePathEscaped("/" + root) + if !ok { + return fmt.Errorf("manifest root path invalid: %v", root) + } + if value, ok := lookup(path, data); ok { + if len(path) > 0 { + if err := storage.MakeDir(ctx, store, txn, path[:len(path)-1]); err != nil { + return err + } + } + if err := store.Write(ctx, txn, storage.AddOp, path, value); err != nil { + return err + } + } + } + return nil +} + +func compileModules(compiler *ast.Compiler, m metrics.Metrics, bundles map[string]*Bundle, extraModules map[string]*ast.Module, legacy bool, authorizationDecisionRef ast.Ref) error { + + m.Timer(metrics.RegoModuleCompile).Start() + defer m.Timer(metrics.RegoModuleCompile).Stop() + + modules := map[string]*ast.Module{} + + // preserve any modules already on the compiler + for name, module := range compiler.Modules { + modules[name] = module + } + + // preserve any modules passed in from the store + for name, module := range extraModules { + modules[name] = module + } + + // include all the new bundle modules + for bundleName, b := range bundles { + if legacy { + for _, mf := range b.Modules { + modules[mf.Path] = mf.Parsed + } + } else { + for name, module := range b.ParsedModules(bundleName) { + modules[name] = module + } + } + } + + if compiler.Compile(modules); compiler.Failed() { + return compiler.Errors + } + + if authorizationDecisionRef.Equal(ast.EmptyRef()) { + return nil + } + + return iCompiler.VerifyAuthorizationPolicySchema(compiler, authorizationDecisionRef) +} + +func writeModules(ctx context.Context, store storage.Store, txn storage.Transaction, compiler *ast.Compiler, m metrics.Metrics, bundles map[string]*Bundle, extraModules map[string]*ast.Module, legacy bool) error { + + m.Timer(metrics.RegoModuleCompile).Start() + defer m.Timer(metrics.RegoModuleCompile).Stop() + + modules := map[string]*ast.Module{} + + // preserve any modules already on the compiler + for name, module := range compiler.Modules { + modules[name] = module + } + + // preserve any modules passed in from the store + for name, module := range extraModules { + modules[name] = module + } + + // include all the new bundle modules + for bundleName, b := range bundles { + if legacy { + for _, mf := range b.Modules { + modules[mf.Path] = mf.Parsed + } + } else { + for name, module := range b.ParsedModules(bundleName) { + modules[name] = module + } + } + } + + if compiler.Compile(modules); compiler.Failed() { + return compiler.Errors + } + for bundleName, b := range bundles { + for _, mf := range b.Modules { + var path string + + // For backwards compatibility, in legacy mode, upsert policies to + // the unprefixed path. + if legacy { + path = mf.Path + } else { + path = modulePathWithPrefix(bundleName, mf.Path) + } + + if err := store.UpsertPolicy(ctx, txn, path, mf.Raw); err != nil { + return err + } + } + } + return nil +} + +func lookup(path storage.Path, data map[string]interface{}) (interface{}, bool) { + if len(path) == 0 { + return data, true + } + for i := 0; i < len(path)-1; i++ { + value, ok := data[path[i]] + if !ok { + return nil, false + } + obj, ok := value.(map[string]interface{}) + if !ok { + return nil, false + } + data = obj + } + value, ok := data[path[len(path)-1]] + return value, ok +} + +func hasRootsOverlap(ctx context.Context, store storage.Store, txn storage.Transaction, bundles map[string]*Bundle) error { + collisions := map[string][]string{} + allBundles, err := ReadBundleNamesFromStore(ctx, store, txn) + if suppressNotFound(err) != nil { + return err + } + + allRoots := map[string][]string{} + + // Build a map of roots for existing bundles already in the system + for _, name := range allBundles { + roots, err := ReadBundleRootsFromStore(ctx, store, txn, name) + if suppressNotFound(err) != nil { + return err + } + allRoots[name] = roots + } + + // Add in any bundles that are being activated, overwrite existing roots + // with new ones where bundles are in both groups. + for name, bundle := range bundles { + allRoots[name] = *bundle.Manifest.Roots + } + + // Now check for each new bundle if it conflicts with any of the others + for name, bundle := range bundles { + for otherBundle, otherRoots := range allRoots { + if name == otherBundle { + // Skip the current bundle being checked + continue + } + + // Compare the "new" roots with other existing (or a different bundles new roots) + for _, newRoot := range *bundle.Manifest.Roots { + for _, otherRoot := range otherRoots { + if RootPathsOverlap(newRoot, otherRoot) { + collisions[otherBundle] = append(collisions[otherBundle], newRoot) + } + } + } + } + } + + if len(collisions) > 0 { + var bundleNames []string + for name := range collisions { + bundleNames = append(bundleNames, name) + } + return fmt.Errorf("detected overlapping roots in bundle manifest with: %s", bundleNames) + } + return nil +} + +func applyPatches(ctx context.Context, store storage.Store, txn storage.Transaction, patches []PatchOperation) error { + for _, pat := range patches { + + // construct patch path + path, ok := patch.ParsePatchPathEscaped("/" + strings.Trim(pat.Path, "/")) + if !ok { + return fmt.Errorf("error parsing patch path") + } + + var op storage.PatchOp + switch pat.Op { + case "upsert": + op = storage.AddOp + + _, err := store.Read(ctx, txn, path[:len(path)-1]) + if err != nil { + if !storage.IsNotFound(err) { + return err + } + + if err := storage.MakeDir(ctx, store, txn, path[:len(path)-1]); err != nil { + return err + } + } + case "remove": + op = storage.RemoveOp + case "replace": + op = storage.ReplaceOp + default: + return fmt.Errorf("bad patch operation: %v", pat.Op) + } + + // apply the patch + if err := store.Write(ctx, txn, op, path, pat.Value); err != nil { + return err + } + } + + return nil +} + +// Helpers for the older single (unnamed) bundle style manifest storage. + +// LegacyManifestStoragePath is the older unnamed bundle path for manifests to be stored. +// Deprecated: Use ManifestStoragePath and named bundles instead. +var legacyManifestStoragePath = storage.MustParsePath("/system/bundle/manifest") +var legacyRevisionStoragePath = append(legacyManifestStoragePath, "revision") + +// LegacyWriteManifestToStore will write the bundle manifest to the older single (unnamed) bundle manifest location. +// Deprecated: Use WriteManifestToStore and named bundles instead. +func LegacyWriteManifestToStore(ctx context.Context, store storage.Store, txn storage.Transaction, manifest Manifest) error { + return write(ctx, store, txn, legacyManifestStoragePath, manifest) +} + +// LegacyEraseManifestFromStore will erase the bundle manifest from the older single (unnamed) bundle manifest location. +// Deprecated: Use WriteManifestToStore and named bundles instead. +func LegacyEraseManifestFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) error { + err := store.Write(ctx, txn, storage.RemoveOp, legacyManifestStoragePath, nil) + if err != nil { + return err + } + return nil +} + +// LegacyReadRevisionFromStore will read the bundle manifest revision from the older single (unnamed) bundle manifest location. +// Deprecated: Use ReadBundleRevisionFromStore and named bundles instead. +func LegacyReadRevisionFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) (string, error) { + return readRevisionFromStore(ctx, store, txn, legacyRevisionStoragePath) +} + +// ActivateLegacy calls Activate for the bundles but will also write their manifest to the older unnamed store location. +// Deprecated: Use Activate with named bundles instead. +func ActivateLegacy(opts *ActivateOpts) error { + opts.legacy = true + return activateBundles(opts) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/bundle/verify.go b/vendor/github.com/open-policy-agent/opa/v1/bundle/verify.go new file mode 100644 index 000000000..2a4bb02c0 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/bundle/verify.go @@ -0,0 +1,231 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package bundle provide helpers that assist in the bundle signature verification process +package bundle + +import ( + "bytes" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + + "github.com/open-policy-agent/opa/internal/jwx/jwa" + "github.com/open-policy-agent/opa/internal/jwx/jws" + "github.com/open-policy-agent/opa/internal/jwx/jws/verify" + "github.com/open-policy-agent/opa/v1/util" +) + +const defaultVerifierID = "_default" + +var verifiers map[string]Verifier + +// Verifier is the interface expected for implementations that verify bundle signatures. +type Verifier interface { + VerifyBundleSignature(SignaturesConfig, *VerificationConfig) (map[string]FileInfo, error) +} + +// VerifyBundleSignature will retrieve the Verifier implementation based +// on the Plugin specified in SignaturesConfig, and call its implementation +// of VerifyBundleSignature. VerifyBundleSignature verifies the bundle signature +// using the given public keys or secret. If a signature is verified, it keeps +// track of the files specified in the JWT payload +func VerifyBundleSignature(sc SignaturesConfig, bvc *VerificationConfig) (map[string]FileInfo, error) { + // default implementation does not return a nil for map, so don't + // do it here either + files := make(map[string]FileInfo) + var plugin string + // for backwards compatibility, check if there is no plugin specified, and use default + if sc.Plugin == "" { + plugin = defaultVerifierID + } else { + plugin = sc.Plugin + } + verifier, err := GetVerifier(plugin) + if err != nil { + return files, err + } + return verifier.VerifyBundleSignature(sc, bvc) +} + +// DefaultVerifier is the default bundle verification implementation. It verifies bundles by checking +// the JWT signature using a locally-accessible public key. +type DefaultVerifier struct{} + +// VerifyBundleSignature verifies the bundle signature using the given public keys or secret. +// If a signature is verified, it keeps track of the files specified in the JWT payload +func (*DefaultVerifier) VerifyBundleSignature(sc SignaturesConfig, bvc *VerificationConfig) (map[string]FileInfo, error) { + files := make(map[string]FileInfo) + + if len(sc.Signatures) == 0 { + return files, fmt.Errorf(".signatures.json: missing JWT (expected exactly one)") + } + + if len(sc.Signatures) > 1 { + return files, fmt.Errorf(".signatures.json: multiple JWTs not supported (expected exactly one)") + } + + for _, token := range sc.Signatures { + payload, err := verifyJWTSignature(token, bvc) + if err != nil { + return files, err + } + + for _, file := range payload.Files { + files[file.Name] = file + } + } + return files, nil +} + +func verifyJWTSignature(token string, bvc *VerificationConfig) (*DecodedSignature, error) { + // decode JWT to check if the header specifies the key to use and/or if claims have the scope. + + parts, err := jws.SplitCompact(token) + if err != nil { + return nil, err + } + + var decodedHeader []byte + if decodedHeader, err = base64.RawURLEncoding.DecodeString(parts[0]); err != nil { + return nil, fmt.Errorf("failed to base64 decode JWT headers: %w", err) + } + + var hdr jws.StandardHeaders + if err := json.Unmarshal(decodedHeader, &hdr); err != nil { + return nil, fmt.Errorf("failed to parse JWT headers: %w", err) + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, err + } + + var ds DecodedSignature + if err := json.Unmarshal(payload, &ds); err != nil { + return nil, err + } + + // check for the id of the key to use for JWT signature verification + // first in the OPA config. If not found, then check the JWT kid. + keyID := bvc.KeyID + if keyID == "" { + keyID = hdr.KeyID + } + if keyID == "" { + // If header has no key id, check the deprecated key claim. + keyID = ds.KeyID + } + + if keyID == "" { + return nil, fmt.Errorf("verification key ID is empty") + } + + // now that we have the keyID, fetch the actual key + keyConfig, err := bvc.GetPublicKey(keyID) + if err != nil { + return nil, err + } + + // verify JWT signature + alg := jwa.SignatureAlgorithm(keyConfig.Algorithm) + key, err := verify.GetSigningKey(keyConfig.Key, alg) + if err != nil { + return nil, err + } + + _, err = jws.Verify([]byte(token), alg, key) + if err != nil { + return nil, err + } + + // verify the scope + scope := bvc.Scope + if scope == "" { + scope = keyConfig.Scope + } + + if ds.Scope != scope { + return nil, fmt.Errorf("scope mismatch") + } + return &ds, nil +} + +// VerifyBundleFile verifies the hash of a file in the bundle matches to that provided in the bundle's signature +func VerifyBundleFile(path string, data bytes.Buffer, files map[string]FileInfo) error { + var file FileInfo + var ok bool + + if file, ok = files[path]; !ok { + return fmt.Errorf("file %v not included in bundle signature", path) + } + + if file.Algorithm == "" { + return fmt.Errorf("no hashing algorithm provided for file %v", path) + } + + hash, err := NewSignatureHasher(HashingAlgorithm(file.Algorithm)) + if err != nil { + return err + } + + // hash the file content + // For unstructured files, hash the byte stream of the file + // For structured files, read the byte stream and parse into a JSON structure; + // then recursively order the fields of all objects alphabetically and then apply + // the hash function to result to compute the hash. This ensures that the digital signature is + // independent of whitespace and other non-semantic JSON features. + var value interface{} + if IsStructuredDoc(path) { + err := util.Unmarshal(data.Bytes(), &value) + if err != nil { + return err + } + } else { + value = data.Bytes() + } + + bs, err := hash.HashFile(value) + if err != nil { + return err + } + + // compare file hash with same file in the JWT payloads + fb, err := hex.DecodeString(file.Hash) + if err != nil { + return err + } + + if !bytes.Equal(fb, bs) { + return fmt.Errorf("%v: digest mismatch (want: %x, got: %x)", path, fb, bs) + } + + delete(files, path) + return nil +} + +// GetVerifier returns the Verifier registered under the given id +func GetVerifier(id string) (Verifier, error) { + verifier, ok := verifiers[id] + if !ok { + return nil, fmt.Errorf("no verifier exists under id %s", id) + } + return verifier, nil +} + +// RegisterVerifier registers a Verifier under the given id +func RegisterVerifier(id string, v Verifier) error { + if id == defaultVerifierID { + return fmt.Errorf("verifier id %s is reserved, use a different id", id) + } + verifiers[id] = v + return nil +} + +func init() { + verifiers = map[string]Verifier{ + defaultVerifierID: &DefaultVerifier{}, + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/capabilities/capabilities.go b/vendor/github.com/open-policy-agent/opa/v1/capabilities/capabilities.go new file mode 100644 index 000000000..5b0bb1ea5 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/capabilities/capabilities.go @@ -0,0 +1,18 @@ +// Copyright 2021 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +//go:build go1.16 +// +build go1.16 + +package capabilities + +import ( + v0 "github.com/open-policy-agent/opa/capabilities" +) + +// FS contains the embedded capabilities/ directory of the built version, +// which has all the capabilities of previous versions: +// "v0.18.0.json" contains the capabilities JSON of version v0.18.0, etc + +var FS = v0.FS diff --git a/vendor/github.com/open-policy-agent/opa/compile/compile.go b/vendor/github.com/open-policy-agent/opa/v1/compile/compile.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/compile/compile.go rename to vendor/github.com/open-policy-agent/opa/v1/compile/compile.go index 0b165e651..96b1f844c 100644 --- a/vendor/github.com/open-policy-agent/opa/compile/compile.go +++ b/vendor/github.com/open-policy-agent/opa/v1/compile/compile.go @@ -18,19 +18,19 @@ import ( "sort" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/internal/compiler/wasm" "github.com/open-policy-agent/opa/internal/debug" "github.com/open-policy-agent/opa/internal/planner" "github.com/open-policy-agent/opa/internal/ref" initload "github.com/open-policy-agent/opa/internal/runtime/init" "github.com/open-policy-agent/opa/internal/wasm/encoding" - "github.com/open-policy-agent/opa/ir" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/inmem" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/ir" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/inmem" ) const ( @@ -94,6 +94,7 @@ func New() *Compiler { optimizationLevel: 0, target: TargetRego, debug: debug.Discard(), + regoVersion: ast.DefaultRegoVersion, } } @@ -294,6 +295,10 @@ func addEntrypointsFromAnnotations(c *Compiler, arefs []*ast.AnnotationsRef) err // Build compiles and links the input files and outputs a bundle to the writer. func (c *Compiler) Build(ctx context.Context) error { + if c.regoVersion == ast.RegoUndefined { + return fmt.Errorf("rego-version not set") + } + if err := c.init(); err != nil { return err } diff --git a/vendor/github.com/open-policy-agent/opa/v1/config/config.go b/vendor/github.com/open-policy-agent/opa/v1/config/config.go new file mode 100644 index 000000000..09adb556f --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/config/config.go @@ -0,0 +1,259 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package config implements OPA configuration file parsing and validation. +package config + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "reflect" + "sort" + "strings" + + "github.com/open-policy-agent/opa/internal/ref" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" + "github.com/open-policy-agent/opa/v1/version" +) + +// Config represents the configuration file that OPA can be started with. +type Config struct { + Services json.RawMessage `json:"services,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + Discovery json.RawMessage `json:"discovery,omitempty"` + Bundle json.RawMessage `json:"bundle,omitempty"` // Deprecated: Use `bundles` instead + Bundles json.RawMessage `json:"bundles,omitempty"` + DecisionLogs json.RawMessage `json:"decision_logs,omitempty"` + Status json.RawMessage `json:"status,omitempty"` + Plugins map[string]json.RawMessage `json:"plugins,omitempty"` + Keys json.RawMessage `json:"keys,omitempty"` + DefaultDecision *string `json:"default_decision,omitempty"` + DefaultAuthorizationDecision *string `json:"default_authorization_decision,omitempty"` + Caching json.RawMessage `json:"caching,omitempty"` + NDBuiltinCache bool `json:"nd_builtin_cache,omitempty"` + PersistenceDirectory *string `json:"persistence_directory,omitempty"` + DistributedTracing json.RawMessage `json:"distributed_tracing,omitempty"` + Server *struct { + Encoding json.RawMessage `json:"encoding,omitempty"` + Decoding json.RawMessage `json:"decoding,omitempty"` + Metrics json.RawMessage `json:"metrics,omitempty"` + } `json:"server,omitempty"` + Storage *struct { + Disk json.RawMessage `json:"disk,omitempty"` + } `json:"storage,omitempty"` + Extra map[string]json.RawMessage `json:"-"` +} + +// ParseConfig returns a valid Config object with defaults injected. The id +// and version parameters will be set in the labels map. +func ParseConfig(raw []byte, id string) (*Config, error) { + // NOTE(sr): based on https://stackoverflow.com/a/33499066/993018 + var result Config + objValue := reflect.ValueOf(&result).Elem() + knownFields := map[string]reflect.Value{} + for i := 0; i != objValue.NumField(); i++ { + jsonName := strings.Split(objValue.Type().Field(i).Tag.Get("json"), ",")[0] + knownFields[jsonName] = objValue.Field(i) + } + + if err := util.Unmarshal(raw, &result.Extra); err != nil { + return nil, err + } + + for key, chunk := range result.Extra { + if field, found := knownFields[key]; found { + if err := util.Unmarshal(chunk, field.Addr().Interface()); err != nil { + return nil, err + } + delete(result.Extra, key) + } + } + if len(result.Extra) == 0 { + result.Extra = nil + } + return &result, result.validateAndInjectDefaults(id) +} + +// PluginNames returns a sorted list of names of enabled plugins. +func (c Config) PluginNames() (result []string) { + if c.Bundle != nil || c.Bundles != nil { + result = append(result, "bundles") + } + if c.Status != nil { + result = append(result, "status") + } + if c.DecisionLogs != nil { + result = append(result, "decision_logs") + } + for name := range c.Plugins { + result = append(result, name) + } + sort.Strings(result) + return result +} + +// PluginsEnabled returns true if one or more plugin features are enabled. +// +// Deprecated. Use PluginNames instead. +func (c Config) PluginsEnabled() bool { + return c.Bundle != nil || c.Bundles != nil || c.DecisionLogs != nil || c.Status != nil || len(c.Plugins) > 0 +} + +// DefaultDecisionRef returns the default decision as a reference. +func (c Config) DefaultDecisionRef() ast.Ref { + r, _ := ref.ParseDataPath(*c.DefaultDecision) + return r +} + +// DefaultAuthorizationDecisionRef returns the default authorization decision +// as a reference. +func (c Config) DefaultAuthorizationDecisionRef() ast.Ref { + r, _ := ref.ParseDataPath(*c.DefaultAuthorizationDecision) + return r +} + +// NDBuiltinCacheEnabled returns if the ND builtins cache should be used. +func (c Config) NDBuiltinCacheEnabled() bool { + return c.NDBuiltinCache +} + +func (c *Config) validateAndInjectDefaults(id string) error { + + if c.DefaultDecision == nil { + s := defaultDecisionPath + c.DefaultDecision = &s + } + + _, err := ref.ParseDataPath(*c.DefaultDecision) + if err != nil { + return err + } + + if c.DefaultAuthorizationDecision == nil { + s := defaultAuthorizationDecisionPath + c.DefaultAuthorizationDecision = &s + } + + _, err = ref.ParseDataPath(*c.DefaultAuthorizationDecision) + if err != nil { + return err + } + + if c.Labels == nil { + c.Labels = map[string]string{} + } + + c.Labels["id"] = id + c.Labels["version"] = version.Version + + return nil +} + +// GetPersistenceDirectory returns the configured persistence directory, or $PWD/.opa if none is configured +func (c Config) GetPersistenceDirectory() (string, error) { + if c.PersistenceDirectory == nil { + pwd, err := os.Getwd() + if err != nil { + return "", err + } + return filepath.Join(pwd, ".opa"), nil + } + return *c.PersistenceDirectory, nil +} + +// ActiveConfig returns OPA's active configuration +// with the credentials and crypto keys removed +func (c *Config) ActiveConfig() (interface{}, error) { + bs, err := json.Marshal(c) + if err != nil { + return nil, err + } + + var result map[string]interface{} + if err := util.UnmarshalJSON(bs, &result); err != nil { + return nil, err + } + for k, e := range c.Extra { + var v any + if err := util.UnmarshalJSON(e, &v); err != nil { + return nil, err + } + result[k] = v + } + + if err := removeServiceCredentials(result["services"]); err != nil { + return nil, err + } + + if err := removeCryptoKeys(result["keys"]); err != nil { + return nil, err + } + + return result, nil +} + +func removeServiceCredentials(x interface{}) error { + switch x := x.(type) { + case nil: + return nil + case []interface{}: + for _, v := range x { + err := removeKey(v, "credentials") + if err != nil { + return err + } + } + + case map[string]interface{}: + for _, v := range x { + err := removeKey(v, "credentials") + if err != nil { + return err + } + } + default: + return fmt.Errorf("illegal service config type: %T", x) + } + + return nil +} + +func removeCryptoKeys(x interface{}) error { + switch x := x.(type) { + case nil: + return nil + case map[string]interface{}: + for _, v := range x { + err := removeKey(v, "key", "private_key") + if err != nil { + return err + } + } + default: + return fmt.Errorf("illegal keys config type: %T", x) + } + + return nil +} + +func removeKey(x interface{}, keys ...string) error { + val, ok := x.(map[string]interface{}) + if !ok { + return fmt.Errorf("type assertion error") + } + + for _, key := range keys { + delete(val, key) + } + + return nil +} + +const ( + defaultDecisionPath = "/system/main" + defaultAuthorizationDecisionPath = "/system/authz/allow" +) diff --git a/vendor/github.com/open-policy-agent/opa/cover/cover.go b/vendor/github.com/open-policy-agent/opa/v1/cover/cover.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/cover/cover.go rename to vendor/github.com/open-policy-agent/opa/v1/cover/cover.go index 55443be35..2e119337b 100644 --- a/vendor/github.com/open-policy-agent/opa/cover/cover.go +++ b/vendor/github.com/open-policy-agent/opa/v1/cover/cover.go @@ -10,8 +10,8 @@ import ( "fmt" "sort" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown" ) // Cover computes and reports on coverage. diff --git a/vendor/github.com/open-policy-agent/opa/dependencies/deps.go b/vendor/github.com/open-policy-agent/opa/v1/dependencies/deps.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/dependencies/deps.go rename to vendor/github.com/open-policy-agent/opa/v1/dependencies/deps.go index 4cf18befc..217516a0a 100644 --- a/vendor/github.com/open-policy-agent/opa/dependencies/deps.go +++ b/vendor/github.com/open-policy-agent/opa/v1/dependencies/deps.go @@ -8,8 +8,8 @@ import ( "fmt" "sort" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" ) // All returns the list of data ast.Refs that the given AST element depends on. diff --git a/vendor/github.com/open-policy-agent/opa/dependencies/doc.go b/vendor/github.com/open-policy-agent/opa/v1/dependencies/doc.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/dependencies/doc.go rename to vendor/github.com/open-policy-agent/opa/v1/dependencies/doc.go diff --git a/vendor/github.com/open-policy-agent/opa/download/config.go b/vendor/github.com/open-policy-agent/opa/v1/download/config.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/download/config.go rename to vendor/github.com/open-policy-agent/opa/v1/download/config.go index bdf8e21d2..7db194768 100644 --- a/vendor/github.com/open-policy-agent/opa/download/config.go +++ b/vendor/github.com/open-policy-agent/opa/v1/download/config.go @@ -8,7 +8,7 @@ import ( "fmt" "time" - "github.com/open-policy-agent/opa/plugins" + "github.com/open-policy-agent/opa/v1/plugins" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/download/download.go b/vendor/github.com/open-policy-agent/opa/v1/download/download.go similarity index 97% rename from vendor/github.com/open-policy-agent/opa/download/download.go rename to vendor/github.com/open-policy-agent/opa/v1/download/download.go index 6f53cec57..89c62465f 100644 --- a/vendor/github.com/open-policy-agent/opa/download/download.go +++ b/vendor/github.com/open-policy-agent/opa/v1/download/download.go @@ -18,13 +18,13 @@ import ( "sync" "time" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/plugins/rest" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/plugins/rest" + "github.com/open-policy-agent/opa/v1/util" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/download/oci_download.go b/vendor/github.com/open-policy-agent/opa/v1/download/oci_download.go similarity index 97% rename from vendor/github.com/open-policy-agent/opa/download/oci_download.go rename to vendor/github.com/open-policy-agent/opa/v1/download/oci_download.go index 64aa02279..82a6e5fc3 100644 --- a/vendor/github.com/open-policy-agent/opa/download/oci_download.go +++ b/vendor/github.com/open-policy-agent/opa/v1/download/oci_download.go @@ -24,13 +24,13 @@ import ( oraslib "oras.land/oras-go/v2" "oras.land/oras-go/v2/content/oci" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/plugins/rest" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/plugins/rest" + "github.com/open-policy-agent/opa/v1/util" ) // NewOCI returns a new Downloader that can be started. diff --git a/vendor/github.com/open-policy-agent/opa/download/oci_download_unavailable.go b/vendor/github.com/open-policy-agent/opa/v1/download/oci_download_unavailable.go similarity index 90% rename from vendor/github.com/open-policy-agent/opa/download/oci_download_unavailable.go rename to vendor/github.com/open-policy-agent/opa/v1/download/oci_download_unavailable.go index e105d2bd7..ad22fca9b 100644 --- a/vendor/github.com/open-policy-agent/opa/download/oci_download_unavailable.go +++ b/vendor/github.com/open-policy-agent/opa/v1/download/oci_download_unavailable.go @@ -5,9 +5,9 @@ package download import ( "context" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/plugins/rest" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/plugins/rest" ) func NewOCI(Config, rest.Client, string, string) *OCIDownloader { diff --git a/vendor/github.com/open-policy-agent/opa/download/oci_downloader.go b/vendor/github.com/open-policy-agent/opa/v1/download/oci_downloader.go similarity index 86% rename from vendor/github.com/open-policy-agent/opa/download/oci_downloader.go rename to vendor/github.com/open-policy-agent/opa/v1/download/oci_downloader.go index 8e39aa92e..27939d92d 100644 --- a/vendor/github.com/open-policy-agent/opa/download/oci_downloader.go +++ b/vendor/github.com/open-policy-agent/opa/v1/download/oci_downloader.go @@ -4,10 +4,10 @@ import ( "context" "sync" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/plugins/rest" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/plugins/rest" "oras.land/oras-go/v2/content/oci" ) diff --git a/vendor/github.com/open-policy-agent/opa/download/testharness.go b/vendor/github.com/open-policy-agent/opa/v1/download/testharness.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/download/testharness.go rename to vendor/github.com/open-policy-agent/opa/v1/download/testharness.go index 62c8c1835..deb1b055e 100644 --- a/vendor/github.com/open-policy-agent/opa/download/testharness.go +++ b/vendor/github.com/open-policy-agent/opa/v1/download/testharness.go @@ -19,9 +19,9 @@ import ( "testing" "time" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/keys" - "github.com/open-policy-agent/opa/plugins/rest" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/keys" + "github.com/open-policy-agent/opa/v1/plugins/rest" ) var errUnauthorized = errors.New("401 Unauthorized") diff --git a/vendor/github.com/open-policy-agent/opa/features/tracing/tracing.go b/vendor/github.com/open-policy-agent/opa/v1/features/tracing/tracing.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/features/tracing/tracing.go rename to vendor/github.com/open-policy-agent/opa/v1/features/tracing/tracing.go index a77ad4cc8..47199b913 100644 --- a/vendor/github.com/open-policy-agent/opa/features/tracing/tracing.go +++ b/vendor/github.com/open-policy-agent/opa/v1/features/tracing/tracing.go @@ -7,7 +7,7 @@ package tracing import ( "net/http" - pkg_tracing "github.com/open-policy-agent/opa/tracing" + pkg_tracing "github.com/open-policy-agent/opa/v1/tracing" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) diff --git a/vendor/github.com/open-policy-agent/opa/features/wasm/wasm.go b/vendor/github.com/open-policy-agent/opa/v1/features/wasm/wasm.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/features/wasm/wasm.go rename to vendor/github.com/open-policy-agent/opa/v1/features/wasm/wasm.go diff --git a/vendor/github.com/open-policy-agent/opa/format/format.go b/vendor/github.com/open-policy-agent/opa/v1/format/format.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/format/format.go rename to vendor/github.com/open-policy-agent/opa/v1/format/format.go index 43e559466..56c30171d 100644 --- a/vendor/github.com/open-policy-agent/opa/format/format.go +++ b/vendor/github.com/open-policy-agent/opa/v1/format/format.go @@ -13,9 +13,9 @@ import ( "strings" "unicode" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/future" - "github.com/open-policy-agent/opa/types" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/types" ) // Opts lets you control the code formatting via `AstWithOpts()`. @@ -31,6 +31,17 @@ type Opts struct { // ParserOptions is the parser options used when parsing the module to be formatted. ParserOptions *ast.ParserOptions + + // DropV0Imports instructs the formatter to drop all v0 imports from the module; i.e. 'rego.v1' and 'future.keywords' imports. + // Imports are only removed if [Opts.RegoVersion] makes them redundant. + DropV0Imports bool +} + +func (o Opts) effectiveRegoVersion() ast.RegoVersion { + if o.RegoVersion == ast.RegoUndefined { + return ast.DefaultRegoVersion + } + return o.RegoVersion } // defaultLocationFile is the file name used in `Ast()` for terms @@ -46,23 +57,29 @@ func Source(filename string, src []byte) ([]byte, error) { } func SourceWithOpts(filename string, src []byte, opts Opts) ([]byte, error) { + regoVersion := opts.effectiveRegoVersion() + var parserOpts ast.ParserOptions if opts.ParserOptions != nil { parserOpts = *opts.ParserOptions } else { - if opts.RegoVersion == ast.RegoV1 { + if regoVersion == ast.RegoV1 { // If the rego version is V1, we need to parse it as such, to allow for future keywords not being imported. // Otherwise, we'll default to the default rego-version. parserOpts.RegoVersion = ast.RegoV1 } } + if parserOpts.RegoVersion == ast.RegoUndefined { + parserOpts.RegoVersion = ast.DefaultRegoVersion + } + module, err := ast.ParseModuleWithOpts(filename, string(src), parserOpts) if err != nil { return nil, err } - if opts.RegoVersion == ast.RegoV0CompatV1 || opts.RegoVersion == ast.RegoV1 { + if regoVersion == ast.RegoV0CompatV1 || regoVersion == ast.RegoV1 { checkOpts := ast.NewRegoCheckOptions() // The module is parsed as v0, so we need to disable checks that will be automatically amended by the AstWithOpts call anyways. checkOpts.RequireIfKeyword = false @@ -127,6 +144,7 @@ type fmtOpts struct { refHeads bool regoV1 bool + regoV1Imported bool futureKeywords []string } @@ -154,7 +172,8 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) { o := fmtOpts{} - if opts.RegoVersion == ast.RegoV0CompatV1 || opts.RegoVersion == ast.RegoV1 { + regoVersion := opts.effectiveRegoVersion() + if regoVersion == ast.RegoV0CompatV1 || regoVersion == ast.RegoV1 { o.regoV1 = true o.ifs = true o.contains = true @@ -186,6 +205,7 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) { switch { case isRegoV1Compatible(n): + o.regoV1Imported = true o.contains = true o.ifs = true case future.IsAllFutureKeywords(n): @@ -220,14 +240,21 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) { switch x := x.(type) { case *ast.Module: - if opts.RegoVersion == ast.RegoV1 { + if regoVersion == ast.RegoV1 && opts.DropV0Imports { x.Imports = filterRegoV1Import(x.Imports) - } else if opts.RegoVersion == ast.RegoV0CompatV1 { + } else if regoVersion == ast.RegoV0CompatV1 { x.Imports = ensureRegoV1Import(x.Imports) } - if opts.RegoVersion == ast.RegoV0CompatV1 || opts.RegoVersion == ast.RegoV1 || moduleIsRegoV1Compatible(x) { - x.Imports = future.FilterFutureImports(x.Imports) + regoV1Imported := moduleIsRegoV1Compatible(x) + if regoVersion == ast.RegoV0CompatV1 || regoVersion == ast.RegoV1 || regoV1Imported { + if !opts.DropV0Imports && !regoV1Imported { + for _, kw := range o.futureKeywords { + x.Imports = ensureFutureKeywordImport(x.Imports, kw) + } + } else { + x.Imports = future.FilterFutureImports(x.Imports) + } } else { for kw := range extraFutureKeywordImports { x.Imports = ensureFutureKeywordImport(x.Imports, kw) @@ -386,7 +413,7 @@ func (w *writer) writePackage(pkg *ast.Package, comments []*ast.Comment) []*ast. copy(path[1:], pkg.Path[2:]) w.write("package ") - w.writeRef(path) + w.writeRef(path, nil) w.blankLine() @@ -526,7 +553,11 @@ func (w *writer) writeElse(rule *ast.Rule, comments []*ast.Comment) []*ast.Comme } rule.Else.Head.Name = "else" // NOTE(sr): whaaat - rule.Else.Head.Reference = ast.Ref{ast.VarTerm("else")} + + elseHeadReference := ast.VarTerm("else") // construct a reference for the term + elseHeadReference.Location = rule.Else.Head.Location // and set the location to match the rule location + + rule.Else.Head.Reference = ast.Ref{elseHeadReference} rule.Else.Head.Args = nil comments = w.insertComments(comments, rule.Else.Head.Location) @@ -552,7 +583,7 @@ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, comm ref = ref.GroundPrefix() } if w.fmtOpts.refHeads || len(ref) == 1 { - w.writeRef(ref) + w.writeRef(ref, comments) } else { w.write(ref[0].String()) w.write("[") @@ -754,7 +785,7 @@ func (w *writer) writeFunctionCall(expr *ast.Expr, comments []*ast.Comment) []*a return w.writeFunctionCallPlain(terms, comments) } - numDeclArgs := len(bi.Decl.Args()) + numDeclArgs := bi.Decl.Arity() numCallArgs := len(terms) - 1 switch numCallArgs { @@ -811,7 +842,7 @@ func (w *writer) writeTermParens(parens bool, term *ast.Term, comments []*ast.Co switch x := term.Value.(type) { case ast.Ref: - w.writeRef(x) + comments = w.writeRef(x, comments) case ast.Object: comments = w.writeObject(x, term.Location, comments) case *ast.Array: @@ -846,14 +877,14 @@ func (w *writer) writeTermParens(parens bool, term *ast.Term, comments []*ast.Co return comments } -func (w *writer) writeRef(x ast.Ref) { +func (w *writer) writeRef(x ast.Ref, comments []*ast.Comment) []*ast.Comment { if len(x) > 0 { parens := false _, ok := x[0].Value.(ast.Call) if ok { parens = x[0].Location.Text[0] == 40 // Starts with "(" } - w.writeTermParens(parens, x[0], nil) + comments = w.writeTermParens(parens, x[0], comments) path := x[1:] for _, t := range path { switch p := t.Value.(type) { @@ -863,11 +894,13 @@ func (w *writer) writeRef(x ast.Ref) { w.writeBracketed(w.formatVar(p)) default: w.write("[") - w.writeTerm(t, nil) + comments = w.writeTerm(t, comments) w.write("]") } } } + + return comments } func (w *writer) writeBracketed(str string) { @@ -912,7 +945,7 @@ func (w *writer) writeCall(parens bool, x ast.Call, loc *ast.Location, comments // NOTE(Trolloldem): writeCall is only invoked when the function call is a term // of another function. The only valid arity is the one of the // built-in function - if len(bi.Decl.Args()) != len(x)-1 { + if bi.Decl.Arity() != len(x)-1 { w.errs = append(w.errs, ArityFormatMismatchError(x[1:], x[0].String(), loc, bi.Decl)) return comments } @@ -929,10 +962,10 @@ func (w *writer) writeCall(parens bool, x ast.Call, loc *ast.Location, comments func (w *writer) writeInOperator(parens bool, operands []*ast.Term, comments []*ast.Comment, loc *ast.Location, f *types.Function) []*ast.Comment { - if len(operands) != len(f.Args()) { + if len(operands) != f.Arity() { // The number of operands does not math the arity of the `in` operator operator := ast.Member.Name - if len(f.Args()) == 3 { + if f.Arity() == 3 { operator = ast.MemberWithKey.Name } w.errs = append(w.errs, ArityFormatMismatchError(operands, operator, loc, f)) @@ -1030,7 +1063,7 @@ func (w *writer) writeObjectComprehension(object *ast.ObjectComprehension, loc * return w.writeComprehension('{', '}', object.Value, object.Body, loc, comments) } -func (w *writer) writeComprehension(open, close byte, term *ast.Term, body ast.Body, loc *ast.Location, comments []*ast.Comment) []*ast.Comment { +func (w *writer) writeComprehension(openChar, closeChar byte, term *ast.Term, body ast.Body, loc *ast.Location, comments []*ast.Comment) []*ast.Comment { if term.Location.Row-loc.Row >= 1 { w.endLine() w.startLine() @@ -1044,10 +1077,10 @@ func (w *writer) writeComprehension(open, close byte, term *ast.Term, body ast.B comments = w.writeTermParens(parens, term, comments) w.write(" |") - return w.writeComprehensionBody(open, close, body, term.Location, loc, comments) + return w.writeComprehensionBody(openChar, closeChar, body, term.Location, loc, comments) } -func (w *writer) writeComprehensionBody(open, close byte, body ast.Body, term, compr *ast.Location, comments []*ast.Comment) []*ast.Comment { +func (w *writer) writeComprehensionBody(openChar, closeChar byte, body ast.Body, term, compr *ast.Location, comments []*ast.Comment) []*ast.Comment { exprs := make([]interface{}, 0, len(body)) for _, expr := range body { exprs = append(exprs, expr) @@ -1071,7 +1104,7 @@ func (w *writer) writeComprehensionBody(open, close byte, body ast.Body, term, c comments = w.writeExpr(body[i], comments) } - return w.insertComments(comments, closingLoc(0, 0, open, close, compr)) + return w.insertComments(comments, closingLoc(0, 0, openChar, closeChar, compr)) } func (w *writer) writeImports(imports []*ast.Import, comments []*ast.Comment) []*ast.Comment { @@ -1111,7 +1144,7 @@ func (w *writer) writeImport(imp *ast.Import) { w2 := writer{ buf: bytes.Buffer{}, } - w2.writeRef(path) + w2.writeRef(path, nil) buf = append(buf, w2.buf.String()) } else { buf = append(buf, path.String()) @@ -1397,7 +1430,7 @@ func getLoc(x interface{}) *ast.Location { } } -func closingLoc(skipOpen, skipClose, open, close byte, loc *ast.Location) *ast.Location { +func closingLoc(skipOpen, skipClose, openChar, closeChar byte, loc *ast.Location) *ast.Location { i, offset := 0, 0 // Skip past parens/brackets/braces in rule heads. @@ -1406,7 +1439,7 @@ func closingLoc(skipOpen, skipClose, open, close byte, loc *ast.Location) *ast.L } for ; i < len(loc.Text); i++ { - if loc.Text[i] == open { + if loc.Text[i] == openChar { break } } @@ -1423,9 +1456,9 @@ func closingLoc(skipOpen, skipClose, open, close byte, loc *ast.Location) *ast.L } switch loc.Text[i] { - case open: + case openChar: state++ - case close: + case closeChar: state-- case '\n': offset++ @@ -1435,10 +1468,10 @@ func closingLoc(skipOpen, skipClose, open, close byte, loc *ast.Location) *ast.L return &ast.Location{Row: loc.Row + offset} } -func skipPast(open, close byte, loc *ast.Location) (int, int) { +func skipPast(openChar, closeChar byte, loc *ast.Location) (int, int) { i := 0 for ; i < len(loc.Text); i++ { - if loc.Text[i] == open { + if loc.Text[i] == openChar { break } } @@ -1452,9 +1485,9 @@ func skipPast(open, close byte, loc *ast.Location) (int, int) { } switch loc.Text[i] { - case open: + case openChar: state++ - case close: + case closeChar: state-- case '\n': offset++ @@ -1597,9 +1630,9 @@ type ArityFormatErrDetail struct { // arityMismatchError but for `fmt` checks since the compiler has not run yet. func ArityFormatMismatchError(operands []*ast.Term, operator string, loc *ast.Location, f *types.Function) *ast.Error { - want := make([]string, len(f.Args())) - for i := range f.Args() { - want[i] = types.Sprint(f.Args()[i]) + want := make([]string, f.Arity()) + for i, arg := range f.Args() { + want[i] = types.Sprint(arg) } have := make([]string, len(operands)) diff --git a/vendor/github.com/open-policy-agent/opa/v1/hooks/hooks.go b/vendor/github.com/open-policy-agent/opa/v1/hooks/hooks.go new file mode 100644 index 000000000..caf69b124 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/hooks/hooks.go @@ -0,0 +1,77 @@ +// Copyright 2023 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package hooks + +import ( + "context" + "fmt" + + "github.com/open-policy-agent/opa/v1/config" +) + +// Hook is a hook to be called in some select places in OPA's operation. +// +// The base Hook interface is any, and wherever a hook can occur, the calling code +// will check if your hook implements an appropriate interface. If so, your hook +// is called. +// +// This allows you to only hook in to behavior you care about, and it allows the +// OPA to add more hooks in the future. +// +// All hook interfaces in this package have Hook in the name. Hooks must be safe +// for concurrent use. It is expected that hooks are fast; if a hook needs to take +// time, then copy what you need and ensure the hook is async. +// +// When multiple instances of a hook are provided, they are all going to be executed +// in an unspecified order (it's a map-range call underneath). If you need hooks to +// be run in order, you can wrap them into another hook, and configure that one. +type Hook any + +// Hooks is the type used for every struct in OPA that can work with hooks. +type Hooks struct { + m map[Hook]struct{} // we are NOT providing a stable invocation ordering +} + +// New creates a new instance of Hooks. +func New(hs ...Hook) Hooks { + h := Hooks{m: make(map[Hook]struct{}, len(hs))} + for i := range hs { + h.m[hs[i]] = struct{}{} + } + return h +} + +func (hs Hooks) Each(fn func(Hook)) { + for h := range hs.m { + fn(h) + } +} + +// ConfigHook allows inspecting or rewriting the configuration when the plugin +// manager is processing it. +// Note that this hook is not run when the plugin manager is reconfigured. This +// usually only happens when there's a new config from a discovery bundle, and +// for processing _that_, there's `ConfigDiscoveryHook`. +type ConfigHook interface { + OnConfig(context.Context, *config.Config) (*config.Config, error) +} + +// ConfigHook allows inspecting or rewriting the discovered configuration when +// the discovery plugin is processing it. +type ConfigDiscoveryHook interface { + OnConfigDiscovery(context.Context, *config.Config) (*config.Config, error) +} + +func (hs Hooks) Validate() error { + for h := range hs.m { + switch h.(type) { + case ConfigHook, + ConfigDiscoveryHook: // OK + default: + return fmt.Errorf("unknown hook type %T", h) + } + } + return nil +} diff --git a/vendor/github.com/open-policy-agent/opa/ir/ir.go b/vendor/github.com/open-policy-agent/opa/v1/ir/ir.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/ir/ir.go rename to vendor/github.com/open-policy-agent/opa/v1/ir/ir.go index c07670704..4f6961605 100644 --- a/vendor/github.com/open-policy-agent/opa/ir/ir.go +++ b/vendor/github.com/open-policy-agent/opa/v1/ir/ir.go @@ -11,7 +11,7 @@ package ir import ( "fmt" - "github.com/open-policy-agent/opa/types" + "github.com/open-policy-agent/opa/v1/types" ) type ( diff --git a/vendor/github.com/open-policy-agent/opa/ir/marshal.go b/vendor/github.com/open-policy-agent/opa/v1/ir/marshal.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/ir/marshal.go rename to vendor/github.com/open-policy-agent/opa/v1/ir/marshal.go diff --git a/vendor/github.com/open-policy-agent/opa/ir/pretty.go b/vendor/github.com/open-policy-agent/opa/v1/ir/pretty.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/ir/pretty.go rename to vendor/github.com/open-policy-agent/opa/v1/ir/pretty.go diff --git a/vendor/github.com/open-policy-agent/opa/ir/walk.go b/vendor/github.com/open-policy-agent/opa/v1/ir/walk.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/ir/walk.go rename to vendor/github.com/open-policy-agent/opa/v1/ir/walk.go diff --git a/vendor/github.com/open-policy-agent/opa/keys/keys.go b/vendor/github.com/open-policy-agent/opa/v1/keys/keys.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/keys/keys.go rename to vendor/github.com/open-policy-agent/opa/v1/keys/keys.go index de0349694..fba7a9c93 100644 --- a/vendor/github.com/open-policy-agent/opa/keys/keys.go +++ b/vendor/github.com/open-policy-agent/opa/v1/keys/keys.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/util" ) const defaultSigningAlgorithm = "RS256" diff --git a/vendor/github.com/open-policy-agent/opa/v1/loader/errors.go b/vendor/github.com/open-policy-agent/opa/v1/loader/errors.go new file mode 100644 index 000000000..55b8e7dc4 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/loader/errors.go @@ -0,0 +1,62 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package loader + +import ( + "fmt" + "strings" + + "github.com/open-policy-agent/opa/v1/ast" +) + +// Errors is a wrapper for multiple loader errors. +type Errors []error + +func (e Errors) Error() string { + if len(e) == 0 { + return "no error(s)" + } + if len(e) == 1 { + return "1 error occurred during loading: " + e[0].Error() + } + buf := make([]string, len(e)) + for i := range buf { + buf[i] = e[i].Error() + } + return fmt.Sprintf("%v errors occurred during loading:\n", len(e)) + strings.Join(buf, "\n") +} + +func (e *Errors) add(err error) { + if errs, ok := err.(ast.Errors); ok { + for i := range errs { + *e = append(*e, errs[i]) + } + } else { + *e = append(*e, err) + } +} + +type unsupportedDocumentType string + +func (path unsupportedDocumentType) Error() string { + return string(path) + ": document must be of type object" +} + +type unrecognizedFile string + +func (path unrecognizedFile) Error() string { + return string(path) + ": can't recognize file type" +} + +func isUnrecognizedFile(err error) bool { + _, ok := err.(unrecognizedFile) + return ok +} + +type mergeError string + +func (e mergeError) Error() string { + return string(e) + ": merge error" +} diff --git a/vendor/github.com/open-policy-agent/opa/loader/extension/extension.go b/vendor/github.com/open-policy-agent/opa/v1/loader/extension/extension.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/loader/extension/extension.go rename to vendor/github.com/open-policy-agent/opa/v1/loader/extension/extension.go diff --git a/vendor/github.com/open-policy-agent/opa/loader/filter/filter.go b/vendor/github.com/open-policy-agent/opa/v1/loader/filter/filter.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/loader/filter/filter.go rename to vendor/github.com/open-policy-agent/opa/v1/loader/filter/filter.go diff --git a/vendor/github.com/open-policy-agent/opa/v1/loader/loader.go b/vendor/github.com/open-policy-agent/opa/v1/loader/loader.go new file mode 100644 index 000000000..8daf22458 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/loader/loader.go @@ -0,0 +1,838 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package loader contains utilities for loading files into OPA. +package loader + +import ( + "bytes" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + + "sigs.k8s.io/yaml" + + fileurl "github.com/open-policy-agent/opa/internal/file/url" + "github.com/open-policy-agent/opa/internal/merge" + "github.com/open-policy-agent/opa/v1/ast" + astJSON "github.com/open-policy-agent/opa/v1/ast/json" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/loader/filter" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/inmem" + "github.com/open-policy-agent/opa/v1/util" +) + +// Result represents the result of successfully loading zero or more files. +type Result struct { + Documents map[string]interface{} + Modules map[string]*RegoFile + path []string +} + +// ParsedModules returns the parsed modules stored on the result. +func (l *Result) ParsedModules() map[string]*ast.Module { + modules := make(map[string]*ast.Module) + for _, module := range l.Modules { + modules[module.Name] = module.Parsed + } + return modules +} + +// Compiler returns a Compiler object with the compiled modules from this loader +// result. +func (l *Result) Compiler() (*ast.Compiler, error) { + compiler := ast.NewCompiler() + compiler.Compile(l.ParsedModules()) + if compiler.Failed() { + return nil, compiler.Errors + } + return compiler, nil +} + +// Store returns a Store object with the documents from this loader result. +func (l *Result) Store() (storage.Store, error) { + return l.StoreWithOpts() +} + +// StoreWithOpts returns a Store object with the documents from this loader result, +// instantiated with the passed options. +func (l *Result) StoreWithOpts(opts ...inmem.Opt) (storage.Store, error) { + return inmem.NewFromObjectWithOpts(l.Documents, opts...), nil +} + +// RegoFile represents the result of loading a single Rego source file. +type RegoFile struct { + Name string + Parsed *ast.Module + Raw []byte +} + +// Filter defines the interface for filtering files during loading. If the +// filter returns true, the file should be excluded from the result. +type Filter = filter.LoaderFilter + +// GlobExcludeName excludes files and directories whose names do not match the +// shell style pattern at minDepth or greater. +func GlobExcludeName(pattern string, minDepth int) Filter { + return func(_ string, info fs.FileInfo, depth int) bool { + match, _ := filepath.Match(pattern, info.Name()) + return match && depth >= minDepth + } +} + +// FileLoader defines an interface for loading OPA data files +// and Rego policies. +type FileLoader interface { + All(paths []string) (*Result, error) + Filtered(paths []string, filter Filter) (*Result, error) + AsBundle(path string) (*bundle.Bundle, error) + WithReader(io.Reader) FileLoader + WithFS(fs.FS) FileLoader + WithMetrics(metrics.Metrics) FileLoader + WithFilter(Filter) FileLoader + WithBundleVerificationConfig(*bundle.VerificationConfig) FileLoader + WithSkipBundleVerification(bool) FileLoader + WithProcessAnnotation(bool) FileLoader + WithCapabilities(*ast.Capabilities) FileLoader + WithJSONOptions(*astJSON.Options) FileLoader + WithRegoVersion(ast.RegoVersion) FileLoader + WithFollowSymlinks(bool) FileLoader +} + +// NewFileLoader returns a new FileLoader instance. +func NewFileLoader() FileLoader { + return &fileLoader{ + metrics: metrics.New(), + files: make(map[string]bundle.FileInfo), + } +} + +type fileLoader struct { + metrics metrics.Metrics + filter Filter + bvc *bundle.VerificationConfig + skipVerify bool + files map[string]bundle.FileInfo + opts ast.ParserOptions + fsys fs.FS + reader io.Reader + followSymlinks bool +} + +// WithFS provides an fs.FS to use for loading files. You can pass nil to +// use plain IO calls (e.g. os.Open, os.Stat, etc.), this is the default +// behaviour. +func (fl *fileLoader) WithFS(fsys fs.FS) FileLoader { + fl.fsys = fsys + return fl +} + +// WithReader provides an io.Reader to use for loading the bundle tarball. +// An io.Reader passed via WithReader takes precedence over an fs.FS passed +// via WithFS. +func (fl *fileLoader) WithReader(rdr io.Reader) FileLoader { + fl.reader = rdr + return fl +} + +// WithMetrics provides the metrics instance to use while loading +func (fl *fileLoader) WithMetrics(m metrics.Metrics) FileLoader { + fl.metrics = m + return fl +} + +// WithFilter specifies the filter object to use to filter files while loading +func (fl *fileLoader) WithFilter(filter Filter) FileLoader { + fl.filter = filter + return fl +} + +// WithBundleVerificationConfig sets the key configuration used to verify a signed bundle +func (fl *fileLoader) WithBundleVerificationConfig(config *bundle.VerificationConfig) FileLoader { + fl.bvc = config + return fl +} + +// WithSkipBundleVerification skips verification of a signed bundle +func (fl *fileLoader) WithSkipBundleVerification(skipVerify bool) FileLoader { + fl.skipVerify = skipVerify + return fl +} + +// WithProcessAnnotation enables or disables processing of schema annotations on rules +func (fl *fileLoader) WithProcessAnnotation(processAnnotation bool) FileLoader { + fl.opts.ProcessAnnotation = processAnnotation + return fl +} + +// WithCapabilities sets the supported capabilities when loading the files +func (fl *fileLoader) WithCapabilities(caps *ast.Capabilities) FileLoader { + fl.opts.Capabilities = caps + return fl +} + +// WithJSONOptions sets the JSONOptions for use when parsing files +func (fl *fileLoader) WithJSONOptions(opts *astJSON.Options) FileLoader { + fl.opts.JSONOptions = opts + return fl +} + +// WithRegoVersion sets the ast.RegoVersion to use when parsing and compiling modules. +func (fl *fileLoader) WithRegoVersion(version ast.RegoVersion) FileLoader { + fl.opts.RegoVersion = version + return fl +} + +// WithFollowSymlinks enables or disables following symlinks when loading files +func (fl *fileLoader) WithFollowSymlinks(followSymlinks bool) FileLoader { + fl.followSymlinks = followSymlinks + return fl +} + +// All returns a Result object loaded (recursively) from the specified paths. +func (fl fileLoader) All(paths []string) (*Result, error) { + return fl.Filtered(paths, nil) +} + +// Filtered returns a Result object loaded (recursively) from the specified +// paths while applying the given filters. If any filter returns true, the +// file/directory is excluded. +func (fl fileLoader) Filtered(paths []string, filter Filter) (*Result, error) { + return all(fl.fsys, paths, filter, func(curr *Result, path string, depth int) error { + + var ( + bs []byte + err error + ) + if fl.fsys != nil { + bs, err = fs.ReadFile(fl.fsys, path) + } else { + bs, err = os.ReadFile(path) + } + if err != nil { + return err + } + + result, err := loadKnownTypes(path, bs, fl.metrics, fl.opts) + if err != nil { + if !isUnrecognizedFile(err) { + return err + } + if depth > 0 { + return nil + } + result, err = loadFileForAnyType(path, bs, fl.metrics, fl.opts) + if err != nil { + return err + } + } + + return curr.merge(path, result) + }) +} + +// AsBundle loads a path as a bundle. If it is a single file +// it will be treated as a normal tarball bundle. If a directory +// is supplied it will be loaded as an unzipped bundle tree. +func (fl fileLoader) AsBundle(path string) (*bundle.Bundle, error) { + path, err := fileurl.Clean(path) + if err != nil { + return nil, err + } + + if err := checkForUNCPath(path); err != nil { + return nil, err + } + + var bundleLoader bundle.DirectoryLoader + var isDir bool + if fl.reader != nil { + bundleLoader = bundle.NewTarballLoaderWithBaseURL(fl.reader, path).WithFilter(fl.filter) + } else { + bundleLoader, isDir, err = GetBundleDirectoryLoaderFS(fl.fsys, path, fl.filter) + } + + if err != nil { + return nil, err + } + bundleLoader = bundleLoader.WithFollowSymlinks(fl.followSymlinks) + + br := bundle.NewCustomReader(bundleLoader). + WithMetrics(fl.metrics). + WithBundleVerificationConfig(fl.bvc). + WithSkipBundleVerification(fl.skipVerify). + WithProcessAnnotations(fl.opts.ProcessAnnotation). + WithCapabilities(fl.opts.Capabilities). + WithJSONOptions(fl.opts.JSONOptions). + WithFollowSymlinks(fl.followSymlinks). + WithRegoVersion(fl.opts.RegoVersion) + + // For bundle directories add the full path in front of module file names + // to simplify debugging. + if isDir { + br.WithBaseDir(path) + } + + b, err := br.Read() + if err != nil { + err = fmt.Errorf("bundle %s: %w", path, err) + } + + return &b, err +} + +// GetBundleDirectoryLoader returns a bundle directory loader which can be used to load +// files in the directory +func GetBundleDirectoryLoader(path string) (bundle.DirectoryLoader, bool, error) { + return GetBundleDirectoryLoaderFS(nil, path, nil) +} + +// GetBundleDirectoryLoaderWithFilter returns a bundle directory loader which can be used to load +// files in the directory after applying the given filter. +func GetBundleDirectoryLoaderWithFilter(path string, filter Filter) (bundle.DirectoryLoader, bool, error) { + return GetBundleDirectoryLoaderFS(nil, path, filter) +} + +// GetBundleDirectoryLoaderFS returns a bundle directory loader which can be used to load +// files in the directory. +func GetBundleDirectoryLoaderFS(fsys fs.FS, path string, filter Filter) (bundle.DirectoryLoader, bool, error) { + path, err := fileurl.Clean(path) + if err != nil { + return nil, false, err + } + + if err := checkForUNCPath(path); err != nil { + return nil, false, err + } + + var fi fs.FileInfo + if fsys != nil { + fi, err = fs.Stat(fsys, path) + } else { + fi, err = os.Stat(path) + } + if err != nil { + return nil, false, fmt.Errorf("error reading %q: %s", path, err) + } + + var bundleLoader bundle.DirectoryLoader + if fi.IsDir() { + if fsys != nil { + bundleLoader = bundle.NewFSLoaderWithRoot(fsys, path) + } else { + bundleLoader = bundle.NewDirectoryLoader(path) + } + } else { + var fh fs.File + if fsys != nil { + fh, err = fsys.Open(path) + } else { + fh, err = os.Open(path) + } + if err != nil { + return nil, false, err + } + bundleLoader = bundle.NewTarballLoaderWithBaseURL(fh, path) + } + + if filter != nil { + bundleLoader = bundleLoader.WithFilter(filter) + } + return bundleLoader, fi.IsDir(), nil +} + +// FilteredPaths is the same as FilterPathsFS using the current diretory file +// system +func FilteredPaths(paths []string, filter Filter) ([]string, error) { + return FilteredPathsFS(nil, paths, filter) +} + +// FilteredPathsFS return a list of files from the specified +// paths while applying the given filters. If any filter returns true, the +// file/directory is excluded. +func FilteredPathsFS(fsys fs.FS, paths []string, filter Filter) ([]string, error) { + result := []string{} + + _, err := all(fsys, paths, filter, func(_ *Result, path string, _ int) error { + result = append(result, path) + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +// Schemas loads a schema set from the specified file path. +func Schemas(schemaPath string) (*ast.SchemaSet, error) { + + var errs Errors + ss, err := loadSchemas(schemaPath) + if err != nil { + errs.add(err) + return nil, errs + } + + return ss, nil +} + +func loadSchemas(schemaPath string) (*ast.SchemaSet, error) { + + if schemaPath == "" { + return nil, nil + } + + ss := ast.NewSchemaSet() + path, err := fileurl.Clean(schemaPath) + if err != nil { + return nil, err + } + + info, err := os.Stat(path) + if err != nil { + return nil, err + } + + // Handle single file case. + if !info.IsDir() { + schema, err := loadOneSchema(path) + if err != nil { + return nil, err + } + ss.Put(ast.SchemaRootRef, schema) + return ss, nil + + } + + // Handle directory case. + rootDir := path + + err = filepath.Walk(path, + func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } else if info.IsDir() { + return nil + } + + schema, err := loadOneSchema(path) + if err != nil { + return err + } + + relPath, err := filepath.Rel(rootDir, path) + if err != nil { + return err + } + + key := getSchemaSetByPathKey(relPath) + ss.Put(key, schema) + return nil + }) + + if err != nil { + return nil, err + } + + return ss, nil +} + +func getSchemaSetByPathKey(path string) ast.Ref { + + front := filepath.Dir(path) + last := strings.TrimSuffix(filepath.Base(path), filepath.Ext(path)) + + var parts []string + + if front != "." { + parts = append(strings.Split(filepath.ToSlash(front), "/"), last) + } else { + parts = []string{last} + } + + key := make(ast.Ref, 1+len(parts)) + key[0] = ast.SchemaRootDocument + for i := range parts { + key[i+1] = ast.StringTerm(parts[i]) + } + + return key +} + +func loadOneSchema(path string) (interface{}, error) { + bs, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var schema interface{} + if err := util.Unmarshal(bs, &schema); err != nil { + return nil, fmt.Errorf("%s: %w", path, err) + } + + return schema, nil +} + +// All returns a Result object loaded (recursively) from the specified paths. +// Deprecated: Use FileLoader.Filtered() instead. +func All(paths []string) (*Result, error) { + return NewFileLoader().Filtered(paths, nil) +} + +// Filtered returns a Result object loaded (recursively) from the specified +// paths while applying the given filters. If any filter returns true, the +// file/directory is excluded. +// Deprecated: Use FileLoader.Filtered() instead. +func Filtered(paths []string, filter Filter) (*Result, error) { + return NewFileLoader().Filtered(paths, filter) +} + +// AsBundle loads a path as a bundle. If it is a single file +// it will be treated as a normal tarball bundle. If a directory +// is supplied it will be loaded as an unzipped bundle tree. +// Deprecated: Use FileLoader.AsBundle() instead. +func AsBundle(path string) (*bundle.Bundle, error) { + return NewFileLoader().AsBundle(path) +} + +// AllRegos returns a Result object loaded (recursively) with all Rego source +// files from the specified paths. +func AllRegos(paths []string) (*Result, error) { + return NewFileLoader().Filtered(paths, func(_ string, info os.FileInfo, _ int) bool { + return !info.IsDir() && !strings.HasSuffix(info.Name(), bundle.RegoExt) + }) +} + +// Rego is deprecated. Use RegoWithOpts instead. +func Rego(path string) (*RegoFile, error) { + return RegoWithOpts(path, ast.ParserOptions{}) +} + +// RegoWithOpts returns a RegoFile object loaded from the given path. +func RegoWithOpts(path string, opts ast.ParserOptions) (*RegoFile, error) { + path, err := fileurl.Clean(path) + if err != nil { + return nil, err + } + bs, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return loadRego(path, bs, metrics.New(), opts) +} + +// CleanPath returns the normalized version of a path that can be used as an identifier. +func CleanPath(path string) string { + return strings.Trim(path, "/") +} + +// Paths returns a sorted list of files contained at path. If recurse is true +// and path is a directory, then Paths will walk the directory structure +// recursively and list files at each level. +func Paths(path string, recurse bool) (paths []string, err error) { + path, err = fileurl.Clean(path) + if err != nil { + return nil, err + } + err = filepath.Walk(path, func(f string, _ os.FileInfo, _ error) error { + if !recurse { + if path != f && path != filepath.Dir(f) { + return filepath.SkipDir + } + } + paths = append(paths, f) + return nil + }) + return paths, err +} + +// Dirs resolves filepaths to directories. It will return a list of unique +// directories. +func Dirs(paths []string) []string { + unique := map[string]struct{}{} + + for _, path := range paths { + // TODO: /dir/dir will register top level directory /dir + dir := filepath.Dir(path) + unique[dir] = struct{}{} + } + + u := make([]string, 0, len(unique)) + for k := range unique { + u = append(u, k) + } + sort.Strings(u) + return u +} + +// SplitPrefix returns a tuple specifying the document prefix and the file +// path. +func SplitPrefix(path string) ([]string, string) { + // Non-prefixed URLs can be returned without modification and their contents + // can be rooted directly under data. + if strings.Index(path, "://") == strings.Index(path, ":") { + return nil, path + } + parts := strings.SplitN(path, ":", 2) + if len(parts) == 2 && len(parts[0]) > 0 { + return strings.Split(parts[0], "."), parts[1] + } + return nil, path +} + +func (l *Result) merge(path string, result interface{}) error { + switch result := result.(type) { + case bundle.Bundle: + for _, module := range result.Modules { + l.Modules[module.Path] = &RegoFile{ + Name: module.Path, + Parsed: module.Parsed, + Raw: module.Raw, + } + } + return l.mergeDocument(path, result.Data) + case *RegoFile: + l.Modules[CleanPath(path)] = result + return nil + default: + return l.mergeDocument(path, result) + } +} + +func (l *Result) mergeDocument(path string, doc interface{}) error { + obj, ok := makeDir(l.path, doc) + if !ok { + return unsupportedDocumentType(path) + } + merged, ok := merge.InterfaceMaps(l.Documents, obj) + if !ok { + return mergeError(path) + } + for k := range merged { + l.Documents[k] = merged[k] + } + return nil +} + +func (l *Result) withParent(p string) *Result { + path := append(l.path, p) + return &Result{ + Documents: l.Documents, + Modules: l.Modules, + path: path, + } +} + +func newResult() *Result { + return &Result{ + Documents: map[string]interface{}{}, + Modules: map[string]*RegoFile{}, + } +} + +func all(fsys fs.FS, paths []string, filter Filter, f func(*Result, string, int) error) (*Result, error) { + errs := Errors{} + root := newResult() + + for _, path := range paths { + + // Paths can be prefixed with a string that specifies where content should be + // loaded under data. E.g., foo.bar:/path/to/some.json will load the content + // of some.json under {"foo": {"bar": ...}}. + loaded := root + prefix, path := SplitPrefix(path) + if len(prefix) > 0 { + for _, part := range prefix { + loaded = loaded.withParent(part) + } + } + + allRec(fsys, path, filter, &errs, loaded, 0, f) + } + + if len(errs) > 0 { + return nil, errs + } + + return root, nil +} + +func allRec(fsys fs.FS, path string, filter Filter, errors *Errors, loaded *Result, depth int, f func(*Result, string, int) error) { + + path, err := fileurl.Clean(path) + if err != nil { + errors.add(err) + return + } + + if err := checkForUNCPath(path); err != nil { + errors.add(err) + return + } + + var info fs.FileInfo + if fsys != nil { + info, err = fs.Stat(fsys, path) + } else { + info, err = os.Stat(path) + } + + if err != nil { + errors.add(err) + return + } + + if filter != nil && filter(path, info, depth) { + return + } + + if !info.IsDir() { + if err := f(loaded, path, depth); err != nil { + errors.add(err) + } + return + } + + // If we are recursing on directories then content must be loaded under path + // specified by directory hierarchy. + if depth > 0 { + loaded = loaded.withParent(info.Name()) + } + + var files []fs.DirEntry + if fsys != nil { + files, err = fs.ReadDir(fsys, path) + } else { + files, err = os.ReadDir(path) + } + if err != nil { + errors.add(err) + return + } + + for _, file := range files { + allRec(fsys, filepath.Join(path, file.Name()), filter, errors, loaded, depth+1, f) + } +} + +func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (interface{}, error) { + switch filepath.Ext(path) { + case ".json": + return loadJSON(path, bs, m) + case ".rego": + return loadRego(path, bs, m, opts) + case ".yaml", ".yml": + return loadYAML(path, bs, m) + default: + if strings.HasSuffix(path, ".tar.gz") { + r, err := loadBundleFile(path, bs, m, opts) + if err != nil { + err = fmt.Errorf("bundle %s: %w", path, err) + } + return r, err + } + } + return nil, unrecognizedFile(path) +} + +func loadFileForAnyType(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (interface{}, error) { + module, err := loadRego(path, bs, m, opts) + if err == nil { + return module, nil + } + doc, err := loadJSON(path, bs, m) + if err == nil { + return doc, nil + } + doc, err = loadYAML(path, bs, m) + if err == nil { + return doc, nil + } + return nil, unrecognizedFile(path) +} + +func loadBundleFile(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (bundle.Bundle, error) { + tl := bundle.NewTarballLoaderWithBaseURL(bytes.NewBuffer(bs), path) + br := bundle.NewCustomReader(tl). + WithRegoVersion(opts.RegoVersion). + WithCapabilities(opts.Capabilities). + WithJSONOptions(opts.JSONOptions). + WithProcessAnnotations(opts.ProcessAnnotation). + WithMetrics(m). + WithSkipBundleVerification(true). + IncludeManifestInData(true) + return br.Read() +} + +func loadRego(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (*RegoFile, error) { + m.Timer(metrics.RegoModuleParse).Start() + var module *ast.Module + var err error + module, err = ast.ParseModuleWithOpts(path, string(bs), opts) + m.Timer(metrics.RegoModuleParse).Stop() + if err != nil { + return nil, err + } + result := &RegoFile{ + Name: path, + Parsed: module, + Raw: bs, + } + return result, nil +} + +func loadJSON(path string, bs []byte, m metrics.Metrics) (interface{}, error) { + m.Timer(metrics.RegoDataParse).Start() + var x interface{} + err := util.UnmarshalJSON(bs, &x) + m.Timer(metrics.RegoDataParse).Stop() + + if err != nil { + return nil, fmt.Errorf("%s: %w", path, err) + } + return x, nil +} + +func loadYAML(path string, bs []byte, m metrics.Metrics) (interface{}, error) { + m.Timer(metrics.RegoDataParse).Start() + bs, err := yaml.YAMLToJSON(bs) + m.Timer(metrics.RegoDataParse).Stop() + if err != nil { + return nil, fmt.Errorf("%v: error converting YAML to JSON: %v", path, err) + } + return loadJSON(path, bs, m) +} + +func makeDir(path []string, x interface{}) (map[string]interface{}, bool) { + if len(path) == 0 { + obj, ok := x.(map[string]interface{}) + if !ok { + return nil, false + } + return obj, true + } + return makeDir(path[:len(path)-1], map[string]interface{}{path[len(path)-1]: x}) +} + +// isUNC reports whether path is a UNC path. +func isUNC(path string) bool { + return len(path) > 1 && isSlash(path[0]) && isSlash(path[1]) +} + +func isSlash(c uint8) bool { + return c == '\\' || c == '/' +} + +func checkForUNCPath(path string) error { + if isUNC(path) { + return fmt.Errorf("UNC path read is not allowed: %s", path) + } + return nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/logging/logging.go b/vendor/github.com/open-policy-agent/opa/v1/logging/logging.go new file mode 100644 index 000000000..7a1edfb56 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/logging/logging.go @@ -0,0 +1,266 @@ +package logging + +import ( + "context" + "io" + "net/http" + + "github.com/sirupsen/logrus" +) + +// Level log level for Logger +type Level uint8 + +const ( + // Error error log level + Error Level = iota + // Warn warn log level + Warn + // Info info log level + Info + // Debug debug log level + Debug +) + +// Logger provides interface for OPA logger implementations +type Logger interface { + Debug(fmt string, a ...interface{}) + Info(fmt string, a ...interface{}) + Error(fmt string, a ...interface{}) + Warn(fmt string, a ...interface{}) + + WithFields(map[string]interface{}) Logger + + GetLevel() Level + SetLevel(Level) +} + +// StandardLogger is the default OPA logger implementation. +type StandardLogger struct { + logger *logrus.Logger + fields map[string]interface{} +} + +// New returns a new standard logger. +func New() *StandardLogger { + return &StandardLogger{ + logger: logrus.New(), + } +} + +// Get returns the standard logger used throughout OPA. +// +// Deprecated. Do not rely on the global logger. +func Get() *StandardLogger { + return &StandardLogger{ + logger: logrus.StandardLogger(), + } +} + +// SetOutput sets the underlying logrus output. +func (l *StandardLogger) SetOutput(w io.Writer) { + l.logger.SetOutput(w) +} + +// SetFormatter sets the underlying logrus formatter. +func (l *StandardLogger) SetFormatter(formatter logrus.Formatter) { + l.logger.SetFormatter(formatter) +} + +// WithFields provides additional fields to include in log output +func (l *StandardLogger) WithFields(fields map[string]interface{}) Logger { + cp := *l + cp.fields = make(map[string]interface{}) + for k, v := range l.fields { + cp.fields[k] = v + } + for k, v := range fields { + cp.fields[k] = v + } + return &cp +} + +// getFields returns additional fields of this logger +func (l *StandardLogger) getFields() map[string]interface{} { + return l.fields +} + +// SetLevel sets the standard logger level. +func (l *StandardLogger) SetLevel(level Level) { + var logrusLevel logrus.Level + switch level { + case Error: // set logging level report Warn or higher (includes Error) + logrusLevel = logrus.WarnLevel + case Warn: + logrusLevel = logrus.WarnLevel + case Info: + logrusLevel = logrus.InfoLevel + case Debug: + logrusLevel = logrus.DebugLevel + default: + l.Warn("unknown log level %v", level) + logrusLevel = logrus.InfoLevel + } + + l.logger.SetLevel(logrusLevel) +} + +// GetLevel returns the standard logger level. +func (l *StandardLogger) GetLevel() Level { + logrusLevel := l.logger.GetLevel() + + var level Level + switch logrusLevel { + case logrus.WarnLevel: + level = Error + case logrus.InfoLevel: + level = Info + case logrus.DebugLevel: + level = Debug + default: + l.Warn("unknown log level %v", logrusLevel) + level = Info + } + + return level +} + +// Debug logs at debug level +func (l *StandardLogger) Debug(fmt string, a ...interface{}) { + if len(a) == 0 { + l.logger.WithFields(l.getFields()).Debug(fmt) + return + } + l.logger.WithFields(l.getFields()).Debugf(fmt, a...) +} + +// Info logs at info level +func (l *StandardLogger) Info(fmt string, a ...interface{}) { + if len(a) == 0 { + l.logger.WithFields(l.getFields()).Info(fmt) + return + } + l.logger.WithFields(l.getFields()).Infof(fmt, a...) +} + +// Error logs at error level +func (l *StandardLogger) Error(fmt string, a ...interface{}) { + if len(a) == 0 { + l.logger.WithFields(l.getFields()).Error(fmt) + return + } + l.logger.WithFields(l.getFields()).Errorf(fmt, a...) +} + +// Warn logs at warn level +func (l *StandardLogger) Warn(fmt string, a ...interface{}) { + if len(a) == 0 { + l.logger.WithFields(l.getFields()).Warn(fmt) + return + } + l.logger.WithFields(l.getFields()).Warnf(fmt, a...) +} + +// NoOpLogger logging implementation that does nothing +type NoOpLogger struct { + level Level + fields map[string]interface{} +} + +// NewNoOpLogger instantiates new NoOpLogger +func NewNoOpLogger() *NoOpLogger { + return &NoOpLogger{ + level: Info, + } +} + +// WithFields provides additional fields to include in log output. +// Implemented here primarily to be able to switch between implementations without loss of data. +func (l *NoOpLogger) WithFields(fields map[string]interface{}) Logger { + cp := *l + cp.fields = fields + return &cp +} + +// Debug noop +func (*NoOpLogger) Debug(string, ...interface{}) {} + +// Info noop +func (*NoOpLogger) Info(string, ...interface{}) {} + +// Error noop +func (*NoOpLogger) Error(string, ...interface{}) {} + +// Warn noop +func (*NoOpLogger) Warn(string, ...interface{}) {} + +// SetLevel set log level +func (l *NoOpLogger) SetLevel(level Level) { + l.level = level +} + +// GetLevel get log level +func (l *NoOpLogger) GetLevel() Level { + return l.level +} + +type requestContextKey string + +const reqCtxKey = requestContextKey("request-context-key") + +// RequestContext represents the request context used to store data +// related to the request that could be used on logs. +type RequestContext struct { + ClientAddr string + ReqID uint64 + ReqMethod string + ReqPath string + HTTPRequestContext HTTPRequestContext +} + +type HTTPRequestContext struct { + Header http.Header +} + +// Fields adapts the RequestContext fields to logrus.Fields. +func (rctx RequestContext) Fields() logrus.Fields { + return logrus.Fields{ + "client_addr": rctx.ClientAddr, + "req_id": rctx.ReqID, + "req_method": rctx.ReqMethod, + "req_path": rctx.ReqPath, + } +} + +// NewContext returns a copy of parent with an associated RequestContext. +func NewContext(parent context.Context, val *RequestContext) context.Context { + return context.WithValue(parent, reqCtxKey, val) +} + +// FromContext returns the RequestContext associated with ctx, if any. +func FromContext(ctx context.Context) (*RequestContext, bool) { + requestContext, ok := ctx.Value(reqCtxKey).(*RequestContext) + return requestContext, ok +} + +const httpReqCtxKey = requestContextKey("http-request-context-key") + +func WithHTTPRequestContext(parent context.Context, val *HTTPRequestContext) context.Context { + return context.WithValue(parent, httpReqCtxKey, val) +} + +func HTTPRequestContextFromContext(ctx context.Context) (*HTTPRequestContext, bool) { + requestContext, ok := ctx.Value(httpReqCtxKey).(*HTTPRequestContext) + return requestContext, ok +} + +const decisionCtxKey = requestContextKey("decision_id") + +func WithDecisionID(parent context.Context, id string) context.Context { + return context.WithValue(parent, decisionCtxKey, id) +} + +func DecisionIDFromContext(ctx context.Context) (string, bool) { + s, ok := ctx.Value(decisionCtxKey).(string) + return s, ok +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/logging/test/test.go b/vendor/github.com/open-policy-agent/opa/v1/logging/test/test.go new file mode 100644 index 000000000..73588a2dc --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/logging/test/test.go @@ -0,0 +1,101 @@ +package test + +import ( + "fmt" + "sync" + + "github.com/open-policy-agent/opa/v1/logging" +) + +// LogEntry represents a log message. +type LogEntry struct { + Level logging.Level + Fields map[string]interface{} + Message string +} + +// Logger implementation that buffers messages for test purposes. +type Logger struct { + level logging.Level + fields map[string]interface{} + entries *[]LogEntry + mtx *sync.Mutex +} + +// New instantiates new Logger. +func New() *Logger { + return &Logger{ + level: logging.Info, + entries: &[]LogEntry{}, + mtx: &sync.Mutex{}, + } +} + +// WithFields provides additional fields to include in log output. +// Implemented here primarily to be able to switch between implementations without loss of data. +func (l *Logger) WithFields(fields map[string]interface{}) logging.Logger { + l.mtx.Lock() + defer l.mtx.Unlock() + cp := Logger{ + level: l.level, + entries: l.entries, + fields: l.fields, + mtx: l.mtx, + } + flds := make(map[string]interface{}) + for k, v := range cp.fields { + flds[k] = v + } + for k, v := range fields { + flds[k] = v + } + cp.fields = flds + return &cp +} + +// Debug buffers a log message. +func (l *Logger) Debug(f string, a ...interface{}) { + l.append(logging.Debug, f, a...) +} + +// Info buffers a log message. +func (l *Logger) Info(f string, a ...interface{}) { + l.append(logging.Info, f, a...) +} + +// Error buffers a log message. +func (l *Logger) Error(f string, a ...interface{}) { + l.append(logging.Error, f, a...) +} + +// Warn buffers a log message. +func (l *Logger) Warn(f string, a ...interface{}) { + l.append(logging.Warn, f, a...) +} + +// SetLevel set log level. +func (l *Logger) SetLevel(level logging.Level) { + l.level = level +} + +// GetLevel get log level. +func (l *Logger) GetLevel() logging.Level { + return l.level +} + +// Entries returns buffered log entries. +func (l *Logger) Entries() []LogEntry { + l.mtx.Lock() + defer l.mtx.Unlock() + return *l.entries +} + +func (l *Logger) append(lvl logging.Level, f string, a ...interface{}) { + l.mtx.Lock() + defer l.mtx.Unlock() + *l.entries = append(*l.entries, LogEntry{ + Level: lvl, + Fields: l.fields, + Message: fmt.Sprintf(f, a...), + }) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/metrics/metrics.go b/vendor/github.com/open-policy-agent/opa/v1/metrics/metrics.go new file mode 100644 index 000000000..53cd606a3 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/metrics/metrics.go @@ -0,0 +1,312 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package metrics contains helpers for performance metric management inside the policy engine. +package metrics + +import ( + "encoding/json" + "fmt" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + go_metrics "github.com/rcrowley/go-metrics" +) + +// Well-known metric names. +const ( + BundleRequest = "bundle_request" + ServerHandler = "server_handler" + ServerQueryCacheHit = "server_query_cache_hit" + SDKDecisionEval = "sdk_decision_eval" + RegoQueryCompile = "rego_query_compile" + RegoQueryEval = "rego_query_eval" + RegoQueryParse = "rego_query_parse" + RegoModuleParse = "rego_module_parse" + RegoDataParse = "rego_data_parse" + RegoModuleCompile = "rego_module_compile" + RegoPartialEval = "rego_partial_eval" + RegoInputParse = "rego_input_parse" + RegoLoadFiles = "rego_load_files" + RegoLoadBundles = "rego_load_bundles" + RegoExternalResolve = "rego_external_resolve" +) + +// Info contains attributes describing the underlying metrics provider. +type Info struct { + Name string `json:"name"` // name is a unique human-readable identifier for the provider. +} + +// Metrics defines the interface for a collection of performance metrics in the +// policy engine. +type Metrics interface { + Info() Info + Timer(name string) Timer + Histogram(name string) Histogram + Counter(name string) Counter + All() map[string]interface{} + Clear() + json.Marshaler +} + +type TimerMetrics interface { + Timers() map[string]interface{} +} + +type metrics struct { + mtx sync.Mutex + timers map[string]Timer + histograms map[string]Histogram + counters map[string]Counter +} + +// New returns a new Metrics object. +func New() Metrics { + m := &metrics{} + m.Clear() + return m +} + +type metric struct { + Key string + Value interface{} +} + +func (*metrics) Info() Info { + return Info{ + Name: "", + } +} + +func (m *metrics) String() string { + + all := m.All() + sorted := make([]metric, 0, len(all)) + + for key, value := range all { + sorted = append(sorted, metric{ + Key: key, + Value: value, + }) + } + + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Key < sorted[j].Key + }) + + buf := make([]string, len(sorted)) + for i := range sorted { + buf[i] = fmt.Sprintf("%v:%v", sorted[i].Key, sorted[i].Value) + } + + return strings.Join(buf, " ") +} + +func (m *metrics) MarshalJSON() ([]byte, error) { + return json.Marshal(m.All()) +} + +func (m *metrics) Timer(name string) Timer { + m.mtx.Lock() + defer m.mtx.Unlock() + t, ok := m.timers[name] + if !ok { + t = &timer{} + m.timers[name] = t + } + return t +} + +func (m *metrics) Histogram(name string) Histogram { + m.mtx.Lock() + defer m.mtx.Unlock() + h, ok := m.histograms[name] + if !ok { + h = newHistogram() + m.histograms[name] = h + } + return h +} + +func (m *metrics) Counter(name string) Counter { + m.mtx.Lock() + defer m.mtx.Unlock() + c, ok := m.counters[name] + if !ok { + zero := counter{} + c = &zero + m.counters[name] = c + } + return c +} + +func (m *metrics) All() map[string]interface{} { + m.mtx.Lock() + defer m.mtx.Unlock() + result := map[string]interface{}{} + for name, timer := range m.timers { + result[m.formatKey(name, timer)] = timer.Value() + } + for name, hist := range m.histograms { + result[m.formatKey(name, hist)] = hist.Value() + } + for name, cntr := range m.counters { + result[m.formatKey(name, cntr)] = cntr.Value() + } + return result +} + +func (m *metrics) Timers() map[string]interface{} { + m.mtx.Lock() + defer m.mtx.Unlock() + ts := map[string]interface{}{} + for n, t := range m.timers { + ts[m.formatKey(n, t)] = t.Value() + } + return ts +} + +func (m *metrics) Clear() { + m.mtx.Lock() + defer m.mtx.Unlock() + m.timers = map[string]Timer{} + m.histograms = map[string]Histogram{} + m.counters = map[string]Counter{} +} + +func (m *metrics) formatKey(name string, metrics interface{}) string { + switch metrics.(type) { + case Timer: + return "timer_" + name + "_ns" + case Histogram: + return "histogram_" + name + case Counter: + return "counter_" + name + default: + return name + } +} + +// Timer defines the interface for a restartable timer that accumulates elapsed +// time. +type Timer interface { + Value() interface{} + Int64() int64 + Start() + Stop() int64 +} + +type timer struct { + mtx sync.Mutex + start time.Time + value int64 +} + +func (t *timer) Start() { + t.mtx.Lock() + defer t.mtx.Unlock() + t.start = time.Now() +} + +func (t *timer) Stop() int64 { + t.mtx.Lock() + defer t.mtx.Unlock() + delta := time.Since(t.start).Nanoseconds() + t.value += delta + return delta +} + +func (t *timer) Value() interface{} { + return t.Int64() +} + +func (t *timer) Int64() int64 { + t.mtx.Lock() + defer t.mtx.Unlock() + return t.value +} + +// Histogram defines the interface for a histogram with hardcoded percentiles. +type Histogram interface { + Value() interface{} + Update(int64) +} + +type histogram struct { + hist go_metrics.Histogram // is thread-safe because of the underlying ExpDecaySample +} + +func newHistogram() Histogram { + // NOTE(tsandall): the reservoir size and alpha factor are taken from + // https://github.com/rcrowley/go-metrics. They may need to be tweaked in + // the future. + sample := go_metrics.NewExpDecaySample(1028, 0.015) + hist := go_metrics.NewHistogram(sample) + return &histogram{hist} +} + +func (h *histogram) Update(v int64) { + h.hist.Update(v) +} + +func (h *histogram) Value() interface{} { + values := map[string]interface{}{} + snap := h.hist.Snapshot() + percentiles := snap.Percentiles([]float64{ + 0.5, + 0.75, + 0.9, + 0.95, + 0.99, + 0.999, + 0.9999, + }) + values["count"] = snap.Count() + values["min"] = snap.Min() + values["max"] = snap.Max() + values["mean"] = snap.Mean() + values["stddev"] = snap.StdDev() + values["median"] = percentiles[0] + values["75%"] = percentiles[1] + values["90%"] = percentiles[2] + values["95%"] = percentiles[3] + values["99%"] = percentiles[4] + values["99.9%"] = percentiles[5] + values["99.99%"] = percentiles[6] + return values +} + +// Counter defines the interface for a monotonic increasing counter. +type Counter interface { + Value() interface{} + Incr() + Add(n uint64) +} + +type counter struct { + c uint64 +} + +func (c *counter) Incr() { + atomic.AddUint64(&c.c, 1) +} + +func (c *counter) Add(n uint64) { + atomic.AddUint64(&c.c, n) +} + +func (c *counter) Value() interface{} { + return atomic.LoadUint64(&c.c) +} + +func Statistics(num ...int64) interface{} { + t := newHistogram() + for _, n := range num { + t.Update(n) + } + return t.Value() +} diff --git a/vendor/github.com/open-policy-agent/opa/plugins/bundle/config.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/config.go similarity index 96% rename from vendor/github.com/open-policy-agent/opa/plugins/bundle/config.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/config.go index 1cafbccc2..32458f954 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/bundle/config.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/config.go @@ -10,12 +10,12 @@ import ( "path" "strings" - "github.com/open-policy-agent/opa/plugins" + "github.com/open-policy-agent/opa/v1/plugins" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/download" - "github.com/open-policy-agent/opa/keys" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/download" + "github.com/open-policy-agent/opa/v1/keys" + "github.com/open-policy-agent/opa/v1/util" ) // ParseConfig validates the config and injects default values. This is diff --git a/vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/errors.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/errors.go new file mode 100644 index 000000000..16fcbb5ec --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/errors.go @@ -0,0 +1,53 @@ +package bundle + +import ( + "errors" + "fmt" + + "github.com/open-policy-agent/opa/v1/download" +) + +// Errors represents a list of errors that occurred during a bundle load enriched by the bundle name. +type Errors []Error + +func (e Errors) Unwrap() []error { + output := make([]error, len(e)) + for i := range e { + output[i] = e[i] + } + return output +} +func (e Errors) Error() string { + err := errors.Join(e.Unwrap()...) + return err.Error() +} + +type Error struct { + BundleName string + Code string + HTTPCode int + Message string + Err error +} + +func NewBundleError(bundleName string, cause error) Error { + var ( + httpError download.HTTPError + ) + switch { + case cause == nil: + return Error{BundleName: bundleName, Code: "", HTTPCode: -1, Message: "", Err: nil} + case errors.As(cause, &httpError): + return Error{BundleName: bundleName, Code: errCode, HTTPCode: httpError.StatusCode, Message: httpError.Error(), Err: cause} + default: + return Error{BundleName: bundleName, Code: errCode, HTTPCode: -1, Message: cause.Error(), Err: cause} + } +} + +func (e Error) Error() string { + return fmt.Sprintf("Bundle name: %s, Code: %s, HTTPCode: %d, Message: %s", e.BundleName, errCode, e.HTTPCode, e.Message) +} + +func (e Error) Unwrap() error { + return e.Err +} diff --git a/vendor/github.com/open-policy-agent/opa/plugins/bundle/plugin.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/plugin.go similarity index 96% rename from vendor/github.com/open-policy-agent/opa/plugins/bundle/plugin.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/plugin.go index 53ccde387..5ffc041bb 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/bundle/plugin.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/plugin.go @@ -19,15 +19,15 @@ import ( "sync" "time" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/download" bundleUtils "github.com/open-policy-agent/opa/internal/bundle" "github.com/open-policy-agent/opa/internal/ref" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/download" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/storage" ) // maxActivationRetry represents the maximum number of attempts @@ -257,6 +257,8 @@ func (p *Plugin) Loaders() map[string]Loader { // Trigger triggers a bundle download on all configured bundles. func (p *Plugin) Trigger(ctx context.Context) error { + var errs Errors + p.mtx.Lock() downloaders := map[string]Loader{} for name, dl := range p.downloaders { @@ -264,9 +266,20 @@ func (p *Plugin) Trigger(ctx context.Context) error { } p.mtx.Unlock() - for _, d := range downloaders { - // plugin callback will log the trigger error and include it in the bundle status - _ = d.Trigger(ctx) + for name, d := range downloaders { + // plugin callback will also log the trigger error and include it in the bundle status + err := d.Trigger(ctx) + + // only return errors for TriggerMode manual as periodic bundles will be retried + if err != nil { + trigger := p.Config().Bundles[name].Trigger + if trigger != nil && *trigger == plugins.TriggerManual { + errs = append(errs, NewBundleError(name, err)) + } + } + } + if len(errs) > 0 { + return errs } return nil } diff --git a/vendor/github.com/open-policy-agent/opa/plugins/bundle/status.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/status.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/plugins/bundle/status.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/status.go index b65718cce..4aa592789 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/bundle/status.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/bundle/status.go @@ -10,10 +10,10 @@ import ( "strconv" "time" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/download" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/server/types" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/download" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/server/types" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/plugins/discovery/config.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/discovery/config.go similarity index 95% rename from vendor/github.com/open-policy-agent/opa/plugins/discovery/config.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/discovery/config.go index 869ae9b4d..1af9a547e 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/discovery/config.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/discovery/config.go @@ -8,13 +8,13 @@ import ( "fmt" "strings" - "github.com/open-policy-agent/opa/keys" + "github.com/open-policy-agent/opa/v1/keys" - "github.com/open-policy-agent/opa/bundle" + "github.com/open-policy-agent/opa/v1/bundle" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/download" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/download" + "github.com/open-policy-agent/opa/v1/util" ) // Config represents the configuration for the discovery feature. diff --git a/vendor/github.com/open-policy-agent/opa/plugins/discovery/discovery.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/discovery/discovery.go similarity index 96% rename from vendor/github.com/open-policy-agent/opa/plugins/discovery/discovery.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/discovery/discovery.go index 8b04e61b5..e8711d609 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/discovery/discovery.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/discovery/discovery.go @@ -17,23 +17,23 @@ import ( "strings" "sync" - "github.com/open-policy-agent/opa/ast" - bundleApi "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/config" - "github.com/open-policy-agent/opa/download" - "github.com/open-policy-agent/opa/hooks" bundleUtils "github.com/open-policy-agent/opa/internal/bundle" cfg "github.com/open-policy-agent/opa/internal/config" - "github.com/open-policy-agent/opa/keys" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/plugins/bundle" - "github.com/open-policy-agent/opa/plugins/logs" - "github.com/open-policy-agent/opa/plugins/status" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/storage/inmem" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + bundleApi "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/config" + "github.com/open-policy-agent/opa/v1/download" + "github.com/open-policy-agent/opa/v1/hooks" + "github.com/open-policy-agent/opa/v1/keys" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/plugins/bundle" + "github.com/open-policy-agent/opa/v1/plugins/logs" + "github.com/open-policy-agent/opa/v1/plugins/status" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/storage/inmem" + "github.com/open-policy-agent/opa/v1/util" ) const ( @@ -537,6 +537,10 @@ func evaluateBundle(ctx context.Context, id string, info *ast.Term, b *bundleApi compiler := ast.NewCompiler() + if regoVersion := b.RegoVersion(ast.DefaultRegoVersion); regoVersion != ast.RegoUndefined { + compiler = compiler.WithDefaultRegoVersion(regoVersion) + } + if compiler.Compile(modules); compiler.Failed() { return nil, compiler.Errors } diff --git a/vendor/github.com/open-policy-agent/opa/plugins/logs/buffer.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/logs/buffer.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/plugins/logs/buffer.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/logs/buffer.go diff --git a/vendor/github.com/open-policy-agent/opa/plugins/logs/encoder.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/logs/encoder.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/plugins/logs/encoder.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/logs/encoder.go index d301664dd..4a24b3540 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/logs/encoder.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/logs/encoder.go @@ -11,7 +11,7 @@ import ( "fmt" "math" - "github.com/open-policy-agent/opa/metrics" + "github.com/open-policy-agent/opa/v1/metrics" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/plugins/logs/mask.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/logs/mask.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/plugins/logs/mask.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/logs/mask.go diff --git a/vendor/github.com/open-policy-agent/opa/v1/plugins/logs/plugin.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/logs/plugin.go new file mode 100644 index 000000000..19d9018ed --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/logs/plugin.go @@ -0,0 +1,1130 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package logs implements decision log buffering and uploading. +package logs + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "math" + "math/rand" + "net/url" + "reflect" + "strings" + "sync" + "time" + + "golang.org/x/time/rate" + + "github.com/open-policy-agent/opa/internal/ref" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + lstat "github.com/open-policy-agent/opa/v1/plugins/logs/status" + "github.com/open-policy-agent/opa/v1/plugins/rest" + "github.com/open-policy-agent/opa/v1/plugins/status" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/server" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/util" +) + +// Logger defines the interface for decision logging plugins. +type Logger interface { + plugins.Plugin + + Log(context.Context, EventV1) error +} + +// EventV1 represents a decision log event. +// WARNING: The AST() function for EventV1 must be kept in sync with +// the struct. Any changes here MUST be reflected in the AST() +// implementation below. +type EventV1 struct { + Labels map[string]string `json:"labels"` + DecisionID string `json:"decision_id"` + TraceID string `json:"trace_id,omitempty"` + SpanID string `json:"span_id,omitempty"` + Revision string `json:"revision,omitempty"` // Deprecated: Use Bundles instead + Bundles map[string]BundleInfoV1 `json:"bundles,omitempty"` + Path string `json:"path,omitempty"` + Query string `json:"query,omitempty"` + Input *interface{} `json:"input,omitempty"` + Result *interface{} `json:"result,omitempty"` + MappedResult *interface{} `json:"mapped_result,omitempty"` + NDBuiltinCache *interface{} `json:"nd_builtin_cache,omitempty"` + Erased []string `json:"erased,omitempty"` + Masked []string `json:"masked,omitempty"` + Error error `json:"error,omitempty"` + RequestedBy string `json:"requested_by,omitempty"` + Timestamp time.Time `json:"timestamp"` + Metrics map[string]interface{} `json:"metrics,omitempty"` + RequestID uint64 `json:"req_id,omitempty"` + RequestContext *RequestContext `json:"request_context,omitempty"` + + inputAST ast.Value +} + +// BundleInfoV1 describes a bundle associated with a decision log event. +type BundleInfoV1 struct { + Revision string `json:"revision,omitempty"` +} + +type RequestContext struct { + HTTPRequest *HTTPRequestContext `json:"http,omitempty"` +} + +type HTTPRequestContext struct { + Headers map[string][]string `json:"headers,omitempty"` +} + +// AST returns the BundleInfoV1 as an AST value +func (b *BundleInfoV1) AST() ast.Value { + result := ast.NewObject() + if len(b.Revision) > 0 { + result.Insert(ast.StringTerm("revision"), ast.StringTerm(b.Revision)) + } + return result +} + +// Key ast.Term values for the Rego AST representation of the EventV1 +var labelsKey = ast.StringTerm("labels") +var decisionIDKey = ast.StringTerm("decision_id") +var revisionKey = ast.StringTerm("revision") +var bundlesKey = ast.StringTerm("bundles") +var pathKey = ast.StringTerm("path") +var queryKey = ast.StringTerm("query") +var inputKey = ast.StringTerm("input") +var resultKey = ast.StringTerm("result") +var mappedResultKey = ast.StringTerm("mapped_result") +var ndBuiltinCacheKey = ast.StringTerm("nd_builtin_cache") +var erasedKey = ast.StringTerm("erased") +var maskedKey = ast.StringTerm("masked") +var errorKey = ast.StringTerm("error") +var requestedByKey = ast.StringTerm("requested_by") +var timestampKey = ast.StringTerm("timestamp") +var metricsKey = ast.StringTerm("metrics") +var requestIDKey = ast.StringTerm("req_id") + +// AST returns the Rego AST representation for a given EventV1 object. +// This avoids having to round trip through JSON while applying a decision log +// mask policy to the event. +func (e *EventV1) AST() (ast.Value, error) { + var err error + event := ast.NewObject() + + if e.Labels != nil { + labelsObj := ast.NewObject() + for k, v := range e.Labels { + labelsObj.Insert(ast.StringTerm(k), ast.StringTerm(v)) + } + event.Insert(labelsKey, ast.NewTerm(labelsObj)) + } else { + event.Insert(labelsKey, ast.NullTerm()) + } + + event.Insert(decisionIDKey, ast.StringTerm(e.DecisionID)) + + if len(e.Revision) > 0 { + event.Insert(revisionKey, ast.StringTerm(e.Revision)) + } + + if len(e.Bundles) > 0 { + bundlesObj := ast.NewObject() + for k, v := range e.Bundles { + bundlesObj.Insert(ast.StringTerm(k), ast.NewTerm(v.AST())) + } + event.Insert(bundlesKey, ast.NewTerm(bundlesObj)) + } + + if len(e.Path) > 0 { + event.Insert(pathKey, ast.StringTerm(e.Path)) + } + + if len(e.Query) > 0 { + event.Insert(queryKey, ast.StringTerm(e.Query)) + } + + if e.Input != nil { + if e.inputAST == nil { + e.inputAST, err = roundtripJSONToAST(e.Input) + if err != nil { + return nil, err + } + } + event.Insert(inputKey, ast.NewTerm(e.inputAST)) + } + + if e.Result != nil { + results, err := roundtripJSONToAST(e.Result) + if err != nil { + return nil, err + } + event.Insert(resultKey, ast.NewTerm(results)) + } + + if e.MappedResult != nil { + mResults, err := roundtripJSONToAST(e.MappedResult) + if err != nil { + return nil, err + } + event.Insert(mappedResultKey, ast.NewTerm(mResults)) + } + + if e.NDBuiltinCache != nil { + ndbCache, err := roundtripJSONToAST(e.NDBuiltinCache) + if err != nil { + return nil, err + } + event.Insert(ndBuiltinCacheKey, ast.NewTerm(ndbCache)) + } + + if len(e.Erased) > 0 { + erased := make([]*ast.Term, len(e.Erased)) + for i, v := range e.Erased { + erased[i] = ast.StringTerm(v) + } + event.Insert(erasedKey, ast.NewTerm(ast.NewArray(erased...))) + } + + if len(e.Masked) > 0 { + masked := make([]*ast.Term, len(e.Masked)) + for i, v := range e.Masked { + masked[i] = ast.StringTerm(v) + } + event.Insert(maskedKey, ast.NewTerm(ast.NewArray(masked...))) + } + + if e.Error != nil { + evalErr, err := roundtripJSONToAST(e.Error) + if err != nil { + return nil, err + } + event.Insert(errorKey, ast.NewTerm(evalErr)) + } + + if len(e.RequestedBy) > 0 { + event.Insert(requestedByKey, ast.StringTerm(e.RequestedBy)) + } + + // Use the timestamp JSON marshaller to ensure the format is the same as + // round tripping through JSON. + timeBytes, err := e.Timestamp.MarshalJSON() + if err != nil { + return nil, err + } + event.Insert(timestampKey, ast.StringTerm(strings.Trim(string(timeBytes), "\""))) + + if e.Metrics != nil { + m, err := ast.InterfaceToValue(e.Metrics) + if err != nil { + return nil, err + } + event.Insert(metricsKey, ast.NewTerm(m)) + } + + if e.RequestID > 0 { + event.Insert(requestIDKey, ast.UIntNumberTerm(e.RequestID)) + } + + return event, nil +} + +func roundtripJSONToAST(x interface{}) (ast.Value, error) { + rawPtr := util.Reference(x) + // roundtrip through json: this turns slices (e.g. []string, []bool) into + // []interface{}, the only array type ast.InterfaceToValue can work with + if err := util.RoundTrip(rawPtr); err != nil { + return nil, err + } + + return ast.InterfaceToValue(*rawPtr) +} + +const ( + // min amount of time to wait following a failure + minRetryDelay = time.Millisecond * 100 + defaultMinDelaySeconds = int64(300) + defaultMaxDelaySeconds = int64(600) + defaultUploadSizeLimitBytes = int64(32768) // 32KB limit + defaultBufferSizeLimitBytes = int64(0) // unlimited + defaultMaskDecisionPath = "/system/log/mask" + defaultDropDecisionPath = "/system/log/drop" + logRateLimitExDropCounterName = "decision_logs_dropped_rate_limit_exceeded" + logNDBDropCounterName = "decision_logs_nd_builtin_cache_dropped" + logBufferSizeLimitExDropCounterName = "decision_logs_dropped_buffer_size_limit_bytes_exceeded" + logEncodingFailureCounterName = "decision_logs_encoding_failure" + defaultResourcePath = "/logs" +) + +// ReportingConfig represents configuration for the plugin's reporting behaviour. +type ReportingConfig struct { + BufferSizeLimitBytes *int64 `json:"buffer_size_limit_bytes,omitempty"` // max size of in-memory buffer + UploadSizeLimitBytes *int64 `json:"upload_size_limit_bytes,omitempty"` // max size of upload payload + MinDelaySeconds *int64 `json:"min_delay_seconds,omitempty"` // min amount of time to wait between successful poll attempts + MaxDelaySeconds *int64 `json:"max_delay_seconds,omitempty"` // max amount of time to wait between poll attempts + MaxDecisionsPerSecond *float64 `json:"max_decisions_per_second,omitempty"` // max number of decision logs to buffer per second + Trigger *plugins.TriggerMode `json:"trigger,omitempty"` // trigger mode +} + +type RequestContextConfig struct { + HTTPRequest *HTTPRequestContextConfig `json:"http,omitempty"` +} + +type HTTPRequestContextConfig struct { + Headers []string `json:"headers,omitempty"` +} + +// Config represents the plugin configuration. +type Config struct { + Plugin *string `json:"plugin"` + Service string `json:"service"` + PartitionName string `json:"partition_name,omitempty"` + Reporting ReportingConfig `json:"reporting"` + RequestContext RequestContextConfig `json:"request_context"` + MaskDecision *string `json:"mask_decision"` + DropDecision *string `json:"drop_decision"` + ConsoleLogs bool `json:"console"` + Resource *string `json:"resource"` + NDBuiltinCache bool `json:"nd_builtin_cache,omitempty"` + maskDecisionRef ast.Ref + dropDecisionRef ast.Ref +} + +func (c *Config) validateAndInjectDefaults(services []string, pluginsList []string, trigger *plugins.TriggerMode) error { + + if c.Plugin != nil { + var found bool + for _, other := range pluginsList { + if other == *c.Plugin { + found = true + break + } + } + if !found { + return fmt.Errorf("invalid plugin name %q in decision_logs", *c.Plugin) + } + } else if c.Service == "" && len(services) != 0 && !c.ConsoleLogs { + // For backwards compatibility allow defaulting to the first + // service listed, but only if console logging is disabled. If enabled + // we can't tell if the deployer wanted to use only console logs or + // both console logs and the default service option. + c.Service = services[0] + } else if c.Service != "" { + found := false + + for _, svc := range services { + if svc == c.Service { + found = true + break + } + } + + if !found { + return fmt.Errorf("invalid service name %q in decision_logs", c.Service) + } + } + + t, err := plugins.ValidateAndInjectDefaultsForTriggerMode(trigger, c.Reporting.Trigger) + if err != nil { + return fmt.Errorf("invalid decision_log config: %w", err) + } + c.Reporting.Trigger = t + + min := defaultMinDelaySeconds + max := defaultMaxDelaySeconds + + // reject bad min/max values + if c.Reporting.MaxDelaySeconds != nil && c.Reporting.MinDelaySeconds != nil { + if *c.Reporting.MaxDelaySeconds < *c.Reporting.MinDelaySeconds { + return fmt.Errorf("max reporting delay must be >= min reporting delay in decision_logs") + } + min = *c.Reporting.MinDelaySeconds + max = *c.Reporting.MaxDelaySeconds + } else if c.Reporting.MaxDelaySeconds == nil && c.Reporting.MinDelaySeconds != nil { + return fmt.Errorf("reporting configuration missing 'max_delay_seconds' in decision_logs") + } else if c.Reporting.MinDelaySeconds == nil && c.Reporting.MaxDelaySeconds != nil { + return fmt.Errorf("reporting configuration missing 'min_delay_seconds' in decision_logs") + } + + // scale to seconds + minSeconds := int64(time.Duration(min) * time.Second) + c.Reporting.MinDelaySeconds = &minSeconds + + maxSeconds := int64(time.Duration(max) * time.Second) + c.Reporting.MaxDelaySeconds = &maxSeconds + + // default the upload size limit + uploadLimit := defaultUploadSizeLimitBytes + if c.Reporting.UploadSizeLimitBytes != nil { + uploadLimit = *c.Reporting.UploadSizeLimitBytes + } + + c.Reporting.UploadSizeLimitBytes = &uploadLimit + + if c.Reporting.BufferSizeLimitBytes != nil && c.Reporting.MaxDecisionsPerSecond != nil { + return fmt.Errorf("invalid decision_log config, specify either 'buffer_size_limit_bytes' or 'max_decisions_per_second'") + } + + // default the buffer size limit + bufferLimit := defaultBufferSizeLimitBytes + if c.Reporting.BufferSizeLimitBytes != nil { + bufferLimit = *c.Reporting.BufferSizeLimitBytes + } + + c.Reporting.BufferSizeLimitBytes = &bufferLimit + + if c.MaskDecision == nil { + maskDecision := defaultMaskDecisionPath + c.MaskDecision = &maskDecision + } + + c.maskDecisionRef, err = ref.ParseDataPath(*c.MaskDecision) + if err != nil { + return fmt.Errorf("invalid mask_decision in decision_logs: %w", err) + } + + if c.DropDecision == nil { + dropDecision := defaultDropDecisionPath + c.DropDecision = &dropDecision + } + + c.dropDecisionRef, err = ref.ParseDataPath(*c.DropDecision) + if err != nil { + return fmt.Errorf("invalid drop_decision in decision_logs: %w", err) + } + + if c.PartitionName != "" { + resourcePath := fmt.Sprintf("/logs/%v", c.PartitionName) + c.Resource = &resourcePath + } else if c.Resource == nil { + resourcePath := defaultResourcePath + c.Resource = &resourcePath + } else { + if _, err := url.Parse(*c.Resource); err != nil { + return fmt.Errorf("invalid resource path %q: %w", *c.Resource, err) + } + } + + return nil +} + +// Plugin implements decision log buffering and uploading. +type Plugin struct { + manager *plugins.Manager + config Config + buffer *logBuffer + enc *chunkEncoder + mtx sync.Mutex + statusMtx sync.Mutex + stop chan chan struct{} + reconfig chan reconfigure + preparedMask prepareOnce + preparedDrop prepareOnce + limiter *rate.Limiter + metrics metrics.Metrics + logger logging.Logger + status *lstat.Status +} + +type prepareOnce struct { + once *sync.Once + preparedQuery *rego.PreparedEvalQuery + err error +} + +func newPrepareOnce() *prepareOnce { + return &prepareOnce{ + once: new(sync.Once), + } +} + +func (po *prepareOnce) drop() { + po.once = new(sync.Once) +} + +func (po *prepareOnce) prepareOnce(f func() (*rego.PreparedEvalQuery, error)) (*rego.PreparedEvalQuery, error) { + po.once.Do(func() { + po.preparedQuery, po.err = f() + }) + return po.preparedQuery, po.err +} + +type reconfigure struct { + config interface{} + done chan struct{} +} + +// ParseConfig validates the config and injects default values. +func ParseConfig(config []byte, services []string, pluginList []string) (*Config, error) { + t := plugins.DefaultTriggerMode + return NewConfigBuilder(). + WithBytes(config). + WithServices(services). + WithPlugins(pluginList). + WithTriggerMode(&t). + Parse() +} + +// ConfigBuilder assists in the construction of the plugin configuration. +type ConfigBuilder struct { + raw []byte + services []string + plugins []string + trigger *plugins.TriggerMode +} + +// NewConfigBuilder returns a new ConfigBuilder to build and parse the plugin config. +func NewConfigBuilder() *ConfigBuilder { + return &ConfigBuilder{} +} + +// WithBytes sets the raw plugin config. +func (b *ConfigBuilder) WithBytes(config []byte) *ConfigBuilder { + b.raw = config + return b +} + +// WithServices sets the services that implement control plane APIs. +func (b *ConfigBuilder) WithServices(services []string) *ConfigBuilder { + b.services = services + return b +} + +// WithPlugins sets the list of named plugins for decision logging. +func (b *ConfigBuilder) WithPlugins(plugins []string) *ConfigBuilder { + b.plugins = plugins + return b +} + +// WithTriggerMode sets the plugin trigger mode. +func (b *ConfigBuilder) WithTriggerMode(trigger *plugins.TriggerMode) *ConfigBuilder { + b.trigger = trigger + return b +} + +// Parse validates the config and injects default values. +func (b *ConfigBuilder) Parse() (*Config, error) { + if b.raw == nil { + return nil, nil + } + + var parsedConfig Config + + if err := util.Unmarshal(b.raw, &parsedConfig); err != nil { + return nil, err + } + + if parsedConfig.Plugin == nil && parsedConfig.Service == "" && len(b.services) == 0 && !parsedConfig.ConsoleLogs { + // Nothing to validate or inject + return nil, nil + } + + if err := parsedConfig.validateAndInjectDefaults(b.services, b.plugins, b.trigger); err != nil { + return nil, err + } + + return &parsedConfig, nil +} + +// New returns a new Plugin with the given config. +func New(parsedConfig *Config, manager *plugins.Manager) *Plugin { + + plugin := &Plugin{ + manager: manager, + config: *parsedConfig, + stop: make(chan chan struct{}), + buffer: newLogBuffer(*parsedConfig.Reporting.BufferSizeLimitBytes), + enc: newChunkEncoder(*parsedConfig.Reporting.UploadSizeLimitBytes), + reconfig: make(chan reconfigure), + logger: manager.Logger().WithFields(map[string]interface{}{"plugin": Name}), + status: &lstat.Status{}, + preparedDrop: *newPrepareOnce(), + preparedMask: *newPrepareOnce(), + } + + if parsedConfig.Reporting.MaxDecisionsPerSecond != nil { + limit := *parsedConfig.Reporting.MaxDecisionsPerSecond + plugin.limiter = rate.NewLimiter(rate.Limit(limit), int(math.Max(1, limit))) + } + + manager.RegisterCompilerTrigger(plugin.compilerUpdated) + + manager.UpdatePluginStatus(Name, &plugins.Status{State: plugins.StateNotReady}) + + return plugin +} + +// WithMetrics sets the global metrics provider to be used by the plugin. +func (p *Plugin) WithMetrics(m metrics.Metrics) *Plugin { + p.metrics = m + p.enc.WithMetrics(m) + return p +} + +// Name identifies the plugin on manager. +const Name = "decision_logs" + +// Lookup returns the decision logs plugin registered with the manager. +func Lookup(manager *plugins.Manager) *Plugin { + if p := manager.Plugin(Name); p != nil { + return p.(*Plugin) + } + return nil +} + +// Start starts the plugin. +func (p *Plugin) Start(_ context.Context) error { + p.logger.Info("Starting decision logger.") + go p.loop() + p.manager.UpdatePluginStatus(Name, &plugins.Status{State: plugins.StateOK}) + return nil +} + +// Stop stops the plugin. +func (p *Plugin) Stop(ctx context.Context) { + p.logger.Info("Stopping decision logger.") + + if *p.config.Reporting.Trigger == plugins.TriggerPeriodic { + if _, ok := ctx.Deadline(); ok && p.config.Service != "" { + p.flushDecisions(ctx) + } + } + + done := make(chan struct{}) + p.stop <- done + <-done + p.manager.UpdatePluginStatus(Name, &plugins.Status{State: plugins.StateNotReady}) +} + +// Config returns the plugin's current configuration +func (p *Plugin) Config() *Config { + return &p.config +} + +func (p *Plugin) flushDecisions(ctx context.Context) { + p.logger.Info("Flushing decision logs.") + + done := make(chan bool) + + go func(ctx context.Context, done chan bool) { + for ctx.Err() == nil { + if _, err := p.oneShot(ctx); err != nil { + p.logger.Error("Error flushing decisions: %s", err) + // Wait some before retrying, but skip incrementing interval since we are shutting down + time.Sleep(1 * time.Second) + } else { + done <- true + return + } + } + }(ctx, done) + + select { + case <-done: + p.logger.Info("All decisions in buffer uploaded.") + case <-ctx.Done(): + switch ctx.Err() { + case context.DeadlineExceeded, context.Canceled: + p.logger.Error("Plugin stopped with decisions possibly still in buffer.") + } + } +} + +// Log appends a decision log event to the buffer for uploading. +func (p *Plugin) Log(ctx context.Context, decision *server.Info) error { + + bundles := map[string]BundleInfoV1{} + for name, info := range decision.Bundles { + bundles[name] = BundleInfoV1{Revision: info.Revision} + } + + event := EventV1{ + Labels: p.manager.Labels(), + DecisionID: decision.DecisionID, + TraceID: decision.TraceID, + SpanID: decision.SpanID, + Revision: decision.Revision, + Bundles: bundles, + Path: decision.Path, + Query: decision.Query, + Input: decision.Input, + Result: decision.Results, + MappedResult: decision.MappedResults, + NDBuiltinCache: decision.NDBuiltinCache, + RequestedBy: decision.RemoteAddr, + Timestamp: decision.Timestamp, + RequestID: decision.RequestID, + inputAST: decision.InputAST, + } + + headers := map[string][]string{} + rctx := p.config.RequestContext + + if rctx.HTTPRequest != nil && len(rctx.HTTPRequest.Headers) > 0 && decision.HTTPRequestContext.Header != nil { + for _, h := range rctx.HTTPRequest.Headers { + values := decision.HTTPRequestContext.Header.Values(h) + if len(values) > 0 { + headers[h] = decision.HTTPRequestContext.Header.Values(h) + } + } + } + + if len(headers) > 0 { + event.RequestContext = &RequestContext{HTTPRequest: &HTTPRequestContext{Headers: headers}} + } + + input, err := event.AST() + if err != nil { + return err + } + + drop, err := p.dropEvent(ctx, decision.Txn, input) + if err != nil { + p.logger.Error("Log drop decision failed: %v.", err) + return nil + } + + if drop { + p.logger.Debug("Decision log event to path %v dropped", event.Path) + return nil + } + + if decision.Metrics != nil { + event.Metrics = decision.Metrics.All() + } + + if decision.Error != nil { + event.Error = decision.Error + } + + if err := p.maskEvent(ctx, decision.Txn, input, &event); err != nil { + // TODO(tsandall): see note below about error handling. + p.logger.Error("Log event masking failed: %v.", err) + return nil + } + + if p.config.ConsoleLogs { + if err := p.logEvent(event); err != nil { + p.logger.Error("Failed to log to console: %v.", err) + } + } + + if p.config.Service != "" { + p.encodeAndBufferEvent(event) + } + + if p.config.Plugin != nil { + proxy, ok := p.manager.Plugin(*p.config.Plugin).(Logger) + if !ok { + return fmt.Errorf("plugin does not implement Logger interface") + } + return proxy.Log(ctx, event) + } + + return nil +} + +// Reconfigure notifies the plugin with a new configuration. +func (p *Plugin) Reconfigure(_ context.Context, config interface{}) { + + done := make(chan struct{}) + p.reconfig <- reconfigure{config: config, done: done} + + p.preparedMask.drop() + p.preparedDrop.drop() + + <-done +} + +// Trigger can be used to control when the plugin attempts to upload +// a new decision log in manual triggering mode. +func (p *Plugin) Trigger(ctx context.Context) error { + done := make(chan error) + + go func() { + if p.config.Service != "" { + err := p.doOneShot(ctx) + if err != nil { + if ctx.Err() == nil { + done <- err + } + } + } + close(done) + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +// compilerUpdated is called when a compiler trigger on the plugin manager +// fires. This indicates a new compiler instance is available. The decision +// logger needs to prepare a new masking query. +func (p *Plugin) compilerUpdated(storage.Transaction) { + p.preparedMask.drop() + p.preparedDrop.drop() +} + +func (p *Plugin) loop() { + + ctx, cancel := context.WithCancel(context.Background()) + + var retry int + + for { + + var waitC chan struct{} + + if *p.config.Reporting.Trigger == plugins.TriggerPeriodic && p.config.Service != "" { + err := p.doOneShot(ctx) + + var delay time.Duration + + if err == nil { + min := float64(*p.config.Reporting.MinDelaySeconds) + max := float64(*p.config.Reporting.MaxDelaySeconds) + delay = time.Duration(((max - min) * rand.Float64()) + min) + } else { + delay = util.DefaultBackoff(float64(minRetryDelay), float64(*p.config.Reporting.MaxDelaySeconds), retry) + } + + p.logger.Debug("Waiting %v before next upload/retry.", delay) + + waitC = make(chan struct{}) + go func() { + timer, timerCancel := util.TimerWithCancel(delay) + select { + case <-timer.C: + if err != nil { + retry++ + } else { + retry = 0 + } + close(waitC) + case <-ctx.Done(): + timerCancel() // explicitly cancel the timer. + } + }() + } + + select { + case <-waitC: + case update := <-p.reconfig: + p.reconfigure(update.config) + update.done <- struct{}{} + case done := <-p.stop: + cancel() + done <- struct{}{} + return + } + } +} + +func (p *Plugin) doOneShot(ctx context.Context) error { + uploaded, err := p.oneShot(ctx) + + // Make a local copy of the plugins's status. + p.statusMtx.Lock() + p.status.SetError(err) + oldStatus := p.status + p.statusMtx.Unlock() + + if s := status.Lookup(p.manager); s != nil { + s.UpdateDecisionLogsStatus(*oldStatus) + } + + if err != nil { + p.logger.Error("%v.", err) + } else if uploaded { + p.logger.Info("Logs uploaded successfully.") + } else { + p.logger.Debug("Log upload queue was empty.") + } + return err +} + +func (p *Plugin) oneShot(ctx context.Context) (ok bool, err error) { + // Make a local copy of the plugins's encoder and buffer and create + // a new encoder and buffer. This is needed as locking the buffer for + // the upload duration will block policy evaluation and result in + // increased latency for OPA clients + p.mtx.Lock() + oldChunkEnc := p.enc + oldBuffer := p.buffer + p.buffer = newLogBuffer(*p.config.Reporting.BufferSizeLimitBytes) + p.enc = newChunkEncoder(*p.config.Reporting.UploadSizeLimitBytes).WithMetrics(p.metrics) + p.mtx.Unlock() + + // Along with uploading the compressed events in the buffer + // to the remote server, flush any pending compressed data to the + // underlying writer and add to the buffer. + chunk, err := oldChunkEnc.Flush() + if err != nil { + return false, err + } + + for _, ch := range chunk { + p.bufferChunk(oldBuffer, ch) + } + + if oldBuffer.Len() == 0 { + return false, nil + } + + for bs := oldBuffer.Pop(); bs != nil; bs = oldBuffer.Pop() { + if err == nil { + err = uploadChunk(ctx, p.manager.Client(p.config.Service), *p.config.Resource, bs) + } + if err != nil { + if p.limiter != nil { + events, decErr := newChunkDecoder(bs).decode() + if decErr != nil { + continue + } + + for _, event := range events { + p.encodeAndBufferEvent(event) + } + } else { + // requeue the chunk + p.mtx.Lock() + p.bufferChunk(p.buffer, bs) + p.mtx.Unlock() + } + } + } + + return err == nil, err +} + +func (p *Plugin) reconfigure(config interface{}) { + + newConfig := config.(*Config) + + if reflect.DeepEqual(p.config, *newConfig) { + p.logger.Debug("Decision log uploader configuration unchanged.") + return + } + + p.logger.Info("Decision log uploader configuration changed.") + p.config = *newConfig +} + +// NOTE(philipc): Because ND builtins caching can cause unbounded growth in +// decision log entry size, we do best-effort event encoding here, and when we +// run out of space, we drop the ND builtins cache, and try encoding again. +func (p *Plugin) encodeAndBufferEvent(event EventV1) { + if p.limiter != nil { + if !p.limiter.Allow() { + if p.metrics != nil { + p.metrics.Counter(logRateLimitExDropCounterName).Incr() + } + + p.logger.Error("Decision log dropped as rate limit exceeded. Reduce reporting interval or increase rate limit.") + return + } + } + + result, err := p.encodeEvent(event) + if err != nil { + // If there's no ND builtins cache in the event, then we don't + // need to retry encoding anything. + if event.NDBuiltinCache == nil { + // TODO(tsandall): revisit this now that we have an API that + // can return an error. Should the default behaviour be to + // fail-closed as we do for plugins? + + if p.metrics != nil { + p.metrics.Counter(logEncodingFailureCounterName).Incr() + } + p.logger.Error("Log encoding failed: %v.", err) + return + } + + // Attempt to encode the event again, dropping the ND builtins cache. + newEvent := event + newEvent.NDBuiltinCache = nil + + result, err = p.encodeEvent(newEvent) + if err != nil { + if p.metrics != nil { + p.metrics.Counter(logEncodingFailureCounterName).Incr() + } + p.logger.Error("Log encoding failed: %v.", err) + return + } + + // Re-encoding was successful, but we still need to alert users. + p.logger.Error("ND builtins cache dropped from this event to fit under maximum upload size limits. Increase upload size limit or change usage of non-deterministic builtins.") + p.metrics.Counter(logNDBDropCounterName).Incr() + } + + p.mtx.Lock() + defer p.mtx.Unlock() + for _, chunk := range result { + p.bufferChunk(p.buffer, chunk) + } +} + +func (p *Plugin) encodeEvent(event EventV1) ([][]byte, error) { + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(event); err != nil { + return nil, err + } + + p.mtx.Lock() + defer p.mtx.Unlock() + return p.enc.WriteBytes(buf.Bytes()) +} + +func (p *Plugin) bufferChunk(buffer *logBuffer, bs []byte) { + dropped := buffer.Push(bs) + if dropped > 0 { + if p.metrics != nil { + p.metrics.Counter(logBufferSizeLimitExDropCounterName).Incr() + } + p.logger.Error("Dropped %v chunks from buffer. Reduce reporting interval or increase buffer size.", dropped) + } +} + +func (p *Plugin) maskEvent(ctx context.Context, txn storage.Transaction, input ast.Value, event *EventV1) error { + pq, err := p.preparedMask.prepareOnce(func() (*rego.PreparedEvalQuery, error) { + var pq rego.PreparedEvalQuery + + query := ast.NewBody(ast.NewExpr(ast.NewTerm(p.config.maskDecisionRef))) + + r := rego.New( + rego.ParsedQuery(query), + rego.Compiler(p.manager.GetCompiler()), + rego.Store(p.manager.Store), + rego.Transaction(txn), + rego.Runtime(p.manager.Info), + rego.EnablePrintStatements(p.manager.EnablePrintStatements()), + rego.PrintHook(p.manager.PrintHook()), + ) + + pq, err := r.PrepareForEval(context.Background()) + if err != nil { + return nil, err + } + return &pq, nil + }) + + if err != nil { + return err + } + + rs, err := pq.Eval( + ctx, + rego.EvalParsedInput(input), + rego.EvalTransaction(txn), + ) + + if err != nil { + return err + } else if len(rs) == 0 { + return nil + } + + mRuleSet, err := newMaskRuleSet( + rs[0].Expressions[0].Value, + func(mRule *maskRule, err error) { + p.logger.Error("mask rule skipped: %s: %s", mRule.String(), err.Error()) + }, + ) + if err != nil { + return err + } + + mRuleSet.Mask(event) + + return nil +} + +func (p *Plugin) dropEvent(ctx context.Context, txn storage.Transaction, input ast.Value) (bool, error) { + var err error + + pq, err := p.preparedDrop.prepareOnce(func() (*rego.PreparedEvalQuery, error) { + var pq rego.PreparedEvalQuery + + query := ast.NewBody(ast.NewExpr(ast.NewTerm(p.config.dropDecisionRef))) + r := rego.New( + rego.ParsedQuery(query), + rego.Compiler(p.manager.GetCompiler()), + rego.Store(p.manager.Store), + rego.Transaction(txn), + rego.Runtime(p.manager.Info), + rego.EnablePrintStatements(p.manager.EnablePrintStatements()), + rego.PrintHook(p.manager.PrintHook()), + ) + + pq, err := r.PrepareForEval(context.Background()) + if err != nil { + return nil, err + } + return &pq, nil + }) + + if err != nil { + return false, err + } + + rs, err := pq.Eval( + ctx, + rego.EvalParsedInput(input), + rego.EvalTransaction(txn), + ) + + if err != nil { + return false, err + } + + return rs.Allowed(), nil +} + +func uploadChunk(ctx context.Context, client rest.Client, uploadPath string, data []byte) error { + + resp, err := client. + WithHeader("Content-Type", "application/json"). + WithHeader("Content-Encoding", "gzip"). + WithBytes(data). + Do(ctx, "POST", uploadPath) + + if err != nil { + return fmt.Errorf("log upload failed: %w", err) + } + + defer util.Close(resp) + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return lstat.HTTPError{StatusCode: resp.StatusCode} + } + + return nil +} + +func (p *Plugin) logEvent(event EventV1) error { + eventBuf, err := json.Marshal(&event) + if err != nil { + return err + } + fields := map[string]interface{}{} + err = util.UnmarshalJSON(eventBuf, &fields) + if err != nil { + return err + } + p.manager.ConsoleLogger().WithFields(fields).WithFields(map[string]interface{}{ + "type": "openpolicyagent.org/decision_logs", + }).Info("Decision Log") + return nil +} diff --git a/vendor/github.com/open-policy-agent/opa/plugins/logs/status/status.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/logs/status/status.go similarity index 96% rename from vendor/github.com/open-policy-agent/opa/plugins/logs/status/status.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/logs/status/status.go index 4a9b24f7b..c47cbe705 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/logs/status/status.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/logs/status/status.go @@ -11,7 +11,7 @@ import ( "net/http" "strconv" - "github.com/open-policy-agent/opa/metrics" + "github.com/open-policy-agent/opa/v1/metrics" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/v1/plugins/plugins.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/plugins.go new file mode 100644 index 000000000..82813b5e6 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/plugins.go @@ -0,0 +1,1105 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package plugins implements plugin management for the policy engine. +package plugins + +import ( + "context" + "errors" + "fmt" + mr "math/rand" + "sync" + "time" + + "github.com/open-policy-agent/opa/internal/report" + "github.com/prometheus/client_golang/prometheus" + "go.opentelemetry.io/otel/sdk/trace" + + "github.com/gorilla/mux" + + bundleUtils "github.com/open-policy-agent/opa/internal/bundle" + cfg "github.com/open-policy-agent/opa/internal/config" + initload "github.com/open-policy-agent/opa/internal/runtime/init" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/config" + "github.com/open-policy-agent/opa/v1/hooks" + "github.com/open-policy-agent/opa/v1/keys" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/plugins/rest" + "github.com/open-policy-agent/opa/v1/resolver/wasm" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" + "github.com/open-policy-agent/opa/v1/tracing" +) + +// Factory defines the interface OPA uses to instantiate your plugin. +// +// When OPA processes it's configuration it looks for factories that +// have been registered by calling runtime.RegisterPlugin. Factories +// are registered to a name which is used to key into the +// configuration blob. If your plugin has not been configured, your +// factory will not be invoked. +// +// plugins: +// my_plugin1: +// some_key: foo +// # my_plugin2: +// # some_key2: bar +// +// If OPA was started with the configuration above and received two +// calls to runtime.RegisterPlugins (one with NAME "my_plugin1" and +// one with NAME "my_plugin2"), it would only invoke the factory for +// for my_plugin1. +// +// OPA instantiates and reconfigures plugins in two steps. First, OPA +// will call Validate to check the configuration. Assuming the +// configuration is valid, your factory should return a configuration +// value that can be used to construct your plugin. Second, OPA will +// call New to instantiate your plugin providing the configuration +// value returned from the Validate call. +// +// Validate receives a slice of bytes representing plugin +// configuration and returns a configuration value that can be used to +// instantiate your plugin. The manager is provided to give access to +// the OPA's compiler, storage layer, and global configuration. Your +// Validate function will typically: +// +// 1. Deserialize the raw config bytes +// 2. Validate the deserialized config for semantic errors +// 3. Inject default values +// 4. Return a deserialized/parsed config +// +// New receives a valid configuration for your plugin and returns a +// plugin object. Your New function will typically: +// +// 1. Cast the config value to it's own type +// 2. Instantiate a plugin object +// 3. Return the plugin object +// 4. Update status via `plugins.Manager#UpdatePluginStatus` +// +// After a plugin has been created subsequent status updates can be +// send anytime the plugin enters a ready or error state. +type Factory interface { + Validate(manager *Manager, config []byte) (interface{}, error) + New(manager *Manager, config interface{}) Plugin +} + +// Plugin defines the interface OPA uses to manage your plugin. +// +// When OPA starts it will start all of the plugins it was configured +// to instantiate. Each time a new plugin is configured (via +// discovery), OPA will start it. You can use the Start call to spawn +// additional goroutines or perform initialization tasks. +// +// Currently OPA will not call Stop on plugins. +// +// When OPA receives new configuration for your plugin via discovery +// it will first Validate the configuration using your factory and +// then call Reconfigure. +type Plugin interface { + Start(ctx context.Context) error + Stop(ctx context.Context) + Reconfigure(ctx context.Context, config interface{}) +} + +// Triggerable defines the interface plugins use for manual plugin triggers. +type Triggerable interface { + Trigger(context.Context) error +} + +// State defines the state that a Plugin instance is currently +// in with pre-defined states. +type State string + +const ( + // StateNotReady indicates that the Plugin is not in an error state, but isn't + // ready for normal operation yet. This should only happen at + // initialization time. + StateNotReady State = "NOT_READY" + + // StateOK signifies that the Plugin is operating normally. + StateOK State = "OK" + + // StateErr indicates that the Plugin is in an error state and should not + // be considered as functional. + StateErr State = "ERROR" + + // StateWarn indicates the Plugin is operating, but in a potentially dangerous or + // degraded state. It may be used to indicate manual remediation is needed, or to + // alert admins of some other noteworthy state. + StateWarn State = "WARN" +) + +// TriggerMode defines the trigger mode utilized by a Plugin for bundle download, +// log upload etc. +type TriggerMode string + +const ( + // TriggerPeriodic represents periodic polling mechanism + TriggerPeriodic TriggerMode = "periodic" + + // TriggerManual represents manual triggering mechanism + TriggerManual TriggerMode = "manual" + + // DefaultTriggerMode represents default trigger mechanism + DefaultTriggerMode TriggerMode = "periodic" +) + +// default interval between OPA report uploads +var defaultUploadIntervalSec = int64(3600) + +// Status has a Plugin's current status plus an optional Message. +type Status struct { + State State `json:"state"` + Message string `json:"message,omitempty"` +} + +func (s *Status) String() string { + return fmt.Sprintf("{%v %q}", s.State, s.Message) +} + +// StatusListener defines a handler to register for status updates. +type StatusListener func(status map[string]*Status) + +// Manager implements lifecycle management of plugins and gives plugins access +// to engine-wide components like storage. +type Manager struct { + Store storage.Store + Config *config.Config + Info *ast.Term + ID string + + compiler *ast.Compiler + compilerMux sync.RWMutex + wasmResolvers []*wasm.Resolver + wasmResolversMtx sync.RWMutex + services map[string]rest.Client + keys map[string]*keys.Config + plugins []namedplugin + registeredTriggers []func(storage.Transaction) + mtx sync.Mutex + pluginStatus map[string]*Status + pluginStatusListeners map[string]StatusListener + initBundles map[string]*bundle.Bundle + initFiles loader.Result + maxErrors int + initialized bool + interQueryBuiltinCacheConfig *cache.Config + gracefulShutdownPeriod int + registeredCacheTriggers []func(*cache.Config) + logger logging.Logger + consoleLogger logging.Logger + serverInitialized chan struct{} + serverInitializedOnce sync.Once + printHook print.Hook + enablePrintStatements bool + router *mux.Router + prometheusRegister prometheus.Registerer + tracerProvider *trace.TracerProvider + distributedTacingOpts tracing.Options + registeredNDCacheTriggers []func(bool) + registeredTelemetryGatherers map[string]report.Gatherer + bootstrapConfigLabels map[string]string + hooks hooks.Hooks + enableTelemetry bool + reporter *report.Reporter + opaReportNotifyCh chan struct{} + stop chan chan struct{} + parserOptions ast.ParserOptions +} + +type managerContextKey string +type managerWasmResolverKey string + +const managerCompilerContextKey = managerContextKey("compiler") +const managerWasmResolverContextKey = managerWasmResolverKey("wasmResolvers") + +// SetCompilerOnContext puts the compiler into the storage context. Calling this +// function before committing updated policies to storage allows the manager to +// skip parsing and compiling of modules. Instead, the manager will use the +// compiler that was stored on the context. +func SetCompilerOnContext(context *storage.Context, compiler *ast.Compiler) { + context.Put(managerCompilerContextKey, compiler) +} + +// GetCompilerOnContext gets the compiler cached on the storage context. +func GetCompilerOnContext(context *storage.Context) *ast.Compiler { + compiler, ok := context.Get(managerCompilerContextKey).(*ast.Compiler) + if !ok { + return nil + } + return compiler +} + +// SetWasmResolversOnContext puts a set of Wasm Resolvers into the storage +// context. Calling this function before committing updated wasm modules to +// storage allows the manager to skip initializing modules before using them. +// Instead, the manager will use the compiler that was stored on the context. +func SetWasmResolversOnContext(context *storage.Context, rs []*wasm.Resolver) { + context.Put(managerWasmResolverContextKey, rs) +} + +// getWasmResolversOnContext gets the resolvers cached on the storage context. +func getWasmResolversOnContext(context *storage.Context) []*wasm.Resolver { + resolvers, ok := context.Get(managerWasmResolverContextKey).([]*wasm.Resolver) + if !ok { + return nil + } + return resolvers +} + +func validateTriggerMode(mode TriggerMode) error { + switch mode { + case TriggerPeriodic, TriggerManual: + return nil + default: + return fmt.Errorf("invalid trigger mode %q (want %q or %q)", mode, TriggerPeriodic, TriggerManual) + } +} + +// ValidateAndInjectDefaultsForTriggerMode validates the trigger mode and injects default values +func ValidateAndInjectDefaultsForTriggerMode(a, b *TriggerMode) (*TriggerMode, error) { + + if a == nil && b != nil { + err := validateTriggerMode(*b) + if err != nil { + return nil, err + } + return b, nil + } else if a != nil && b == nil { + err := validateTriggerMode(*a) + if err != nil { + return nil, err + } + return a, nil + } else if a != nil && b != nil { + if *a != *b { + return nil, fmt.Errorf("trigger mode mismatch: %s and %s (hint: check discovery configuration)", *a, *b) + } + err := validateTriggerMode(*a) + if err != nil { + return nil, err + } + return a, nil + } + + t := DefaultTriggerMode + return &t, nil +} + +type namedplugin struct { + name string + plugin Plugin +} + +// Info sets the runtime information on the manager. The runtime information is +// propagated to opa.runtime() built-in function calls. +func Info(term *ast.Term) func(*Manager) { + return func(m *Manager) { + m.Info = term + } +} + +// InitBundles provides the initial set of bundles to load. +func InitBundles(b map[string]*bundle.Bundle) func(*Manager) { + return func(m *Manager) { + m.initBundles = b + } +} + +// InitFiles provides the initial set of other data/policy files to load. +func InitFiles(f loader.Result) func(*Manager) { + return func(m *Manager) { + m.initFiles = f + } +} + +// MaxErrors sets the error limit for the manager's shared compiler. +func MaxErrors(n int) func(*Manager) { + return func(m *Manager) { + m.maxErrors = n + } +} + +// GracefulShutdownPeriod passes the configured graceful shutdown period to plugins +func GracefulShutdownPeriod(gracefulShutdownPeriod int) func(*Manager) { + return func(m *Manager) { + m.gracefulShutdownPeriod = gracefulShutdownPeriod + } +} + +// Logger configures the passed logger on the plugin manager (useful to +// configure default fields) +func Logger(logger logging.Logger) func(*Manager) { + return func(m *Manager) { + m.logger = logger + } +} + +// ConsoleLogger sets the passed logger to be used by plugins that are +// configured with console logging enabled. +func ConsoleLogger(logger logging.Logger) func(*Manager) { + return func(m *Manager) { + m.consoleLogger = logger + } +} + +func EnablePrintStatements(yes bool) func(*Manager) { + return func(m *Manager) { + m.enablePrintStatements = yes + } +} + +func PrintHook(h print.Hook) func(*Manager) { + return func(m *Manager) { + m.printHook = h + } +} + +func WithRouter(r *mux.Router) func(*Manager) { + return func(m *Manager) { + m.router = r + } +} + +// WithPrometheusRegister sets the passed prometheus.Registerer to be used by plugins +func WithPrometheusRegister(prometheusRegister prometheus.Registerer) func(*Manager) { + return func(m *Manager) { + m.prometheusRegister = prometheusRegister + } +} + +// WithTracerProvider sets the passed *trace.TracerProvider to be used by plugins +func WithTracerProvider(tracerProvider *trace.TracerProvider) func(*Manager) { + return func(m *Manager) { + m.tracerProvider = tracerProvider + } +} + +// WithDistributedTracingOpts sets the options to be used by distributed tracing. +func WithDistributedTracingOpts(tr tracing.Options) func(*Manager) { + return func(m *Manager) { + m.distributedTacingOpts = tr + } +} + +// WithHooks allows passing hooks to the plugin manager. +func WithHooks(hs hooks.Hooks) func(*Manager) { + return func(m *Manager) { + m.hooks = hs + } +} + +// WithParserOptions sets the parser options to be used by the plugin manager. +func WithParserOptions(opts ast.ParserOptions) func(*Manager) { + return func(m *Manager) { + m.parserOptions = opts + } +} + +// WithEnableTelemetry controls whether OPA will send telemetry reports to an external service. +func WithEnableTelemetry(enableTelemetry bool) func(*Manager) { + return func(m *Manager) { + m.enableTelemetry = enableTelemetry + } +} + +// WithTelemetryGatherers allows registration of telemetry gatherers which enable injection of additional data in the +// telemetry report +func WithTelemetryGatherers(gs map[string]report.Gatherer) func(*Manager) { + return func(m *Manager) { + m.registeredTelemetryGatherers = gs + } +} + +// New creates a new Manager using config. +func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*Manager, error) { + + parsedConfig, err := config.ParseConfig(raw, id) + if err != nil { + return nil, err + } + + m := &Manager{ + Store: store, + Config: parsedConfig, + ID: id, + pluginStatus: map[string]*Status{}, + pluginStatusListeners: map[string]StatusListener{}, + maxErrors: -1, + serverInitialized: make(chan struct{}), + bootstrapConfigLabels: parsedConfig.Labels, + } + + for _, f := range opts { + f(m) + } + + if m.logger == nil { + m.logger = logging.Get() + } + + if m.consoleLogger == nil { + m.consoleLogger = logging.New() + } + + m.hooks.Each(func(h hooks.Hook) { + if f, ok := h.(hooks.ConfigHook); ok { + if c, e := f.OnConfig(context.Background(), parsedConfig); e != nil { + err = errors.Join(err, e) + } else { + parsedConfig = c + } + } + }) + if err != nil { + return nil, err + } + + // do after options and overrides + m.keys, err = keys.ParseKeysConfig(parsedConfig.Keys) + if err != nil { + return nil, err + } + + m.interQueryBuiltinCacheConfig, err = cache.ParseCachingConfig(parsedConfig.Caching) + if err != nil { + return nil, err + } + + serviceOpts := cfg.ServiceOptions{ + Raw: parsedConfig.Services, + AuthPlugin: m.AuthPlugin, + Keys: m.keys, + Logger: m.logger, + DistributedTacingOpts: m.distributedTacingOpts, + } + + m.services, err = cfg.ParseServicesConfig(serviceOpts) + if err != nil { + return nil, err + } + + if m.enableTelemetry { + reporter, err := report.New(id, report.Options{Logger: m.logger}) + if err != nil { + return nil, err + } + m.reporter = reporter + + m.reporter.RegisterGatherer("min_compatible_version", func(_ context.Context) (any, error) { + var minimumCompatibleVersion string + if m.compiler != nil && m.compiler.Required != nil { + minimumCompatibleVersion, _ = m.compiler.Required.MinimumCompatibleVersion() + } + return minimumCompatibleVersion, nil + }) + + // register any additional gatherers + for k, g := range m.registeredTelemetryGatherers { + m.reporter.RegisterGatherer(k, g) + } + } + + return m, nil +} + +// Init returns an error if the manager could not initialize itself. Init() should +// be called before Start(). Init() is idempotent. +func (m *Manager) Init(ctx context.Context) error { + + if m.initialized { + return nil + } + + params := storage.TransactionParams{ + Write: true, + Context: storage.NewContext(), + } + + if m.enableTelemetry { + m.opaReportNotifyCh = make(chan struct{}) + m.stop = make(chan chan struct{}) + go m.sendOPAUpdateLoop(ctx) + } + + err := storage.Txn(ctx, m.Store, params, func(txn storage.Transaction) error { + + result, err := initload.InsertAndCompile(ctx, initload.InsertAndCompileOptions{ + Store: m.Store, + Txn: txn, + Files: m.initFiles, + Bundles: m.initBundles, + MaxErrors: m.maxErrors, + EnablePrintStatements: m.enablePrintStatements, + ParserOptions: m.parserOptions, + }) + + if err != nil { + return err + } + + SetCompilerOnContext(params.Context, result.Compiler) + + resolvers, err := bundleUtils.LoadWasmResolversFromStore(ctx, m.Store, txn, nil) + if err != nil { + return err + } + SetWasmResolversOnContext(params.Context, resolvers) + + _, err = m.Store.Register(ctx, txn, storage.TriggerConfig{OnCommit: m.onCommit}) + return err + }) + + if err != nil { + if m.stop != nil { + done := make(chan struct{}) + m.stop <- done + <-done + } + + return err + } + + m.initialized = true + return nil +} + +// Labels returns the set of labels from the configuration. +func (m *Manager) Labels() map[string]string { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.Config.Labels +} + +// InterQueryBuiltinCacheConfig returns the configuration for the inter-query caches. +func (m *Manager) InterQueryBuiltinCacheConfig() *cache.Config { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.interQueryBuiltinCacheConfig +} + +// Register adds a plugin to the manager. When the manager is started, all of +// the plugins will be started. +func (m *Manager) Register(name string, plugin Plugin) { + m.mtx.Lock() + defer m.mtx.Unlock() + m.plugins = append(m.plugins, namedplugin{ + name: name, + plugin: plugin, + }) + if _, ok := m.pluginStatus[name]; !ok { + m.pluginStatus[name] = &Status{State: StateNotReady} + } +} + +// Plugins returns the list of plugins registered with the manager. +func (m *Manager) Plugins() []string { + m.mtx.Lock() + defer m.mtx.Unlock() + result := make([]string, len(m.plugins)) + for i := range m.plugins { + result[i] = m.plugins[i].name + } + return result +} + +// Plugin returns the plugin registered with name or nil if name is not found. +func (m *Manager) Plugin(name string) Plugin { + m.mtx.Lock() + defer m.mtx.Unlock() + for i := range m.plugins { + if m.plugins[i].name == name { + return m.plugins[i].plugin + } + } + return nil +} + +// AuthPlugin returns the HTTPAuthPlugin registered with name or nil if name is not found. +func (m *Manager) AuthPlugin(name string) rest.HTTPAuthPlugin { + m.mtx.Lock() + defer m.mtx.Unlock() + for i := range m.plugins { + if m.plugins[i].name == name { + return m.plugins[i].plugin.(rest.HTTPAuthPlugin) + } + } + return nil +} + +// GetCompiler returns the manager's compiler. +func (m *Manager) GetCompiler() *ast.Compiler { + m.compilerMux.RLock() + defer m.compilerMux.RUnlock() + return m.compiler +} + +func (m *Manager) setCompiler(compiler *ast.Compiler) { + m.compilerMux.Lock() + defer m.compilerMux.Unlock() + m.compiler = compiler +} + +// GetRouter returns the managers router if set +func (m *Manager) GetRouter() *mux.Router { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.router +} + +// RegisterCompilerTrigger registers for change notifications when the compiler +// is changed. +func (m *Manager) RegisterCompilerTrigger(f func(storage.Transaction)) { + m.mtx.Lock() + defer m.mtx.Unlock() + m.registeredTriggers = append(m.registeredTriggers, f) +} + +// GetWasmResolvers returns the manager's set of Wasm Resolvers. +func (m *Manager) GetWasmResolvers() []*wasm.Resolver { + m.wasmResolversMtx.RLock() + defer m.wasmResolversMtx.RUnlock() + return m.wasmResolvers +} + +func (m *Manager) setWasmResolvers(rs []*wasm.Resolver) { + m.wasmResolversMtx.Lock() + defer m.wasmResolversMtx.Unlock() + m.wasmResolvers = rs +} + +// Start starts the manager. Init() should be called once before Start(). +func (m *Manager) Start(ctx context.Context) error { + + if m == nil { + return nil + } + + if !m.initialized { + if err := m.Init(ctx); err != nil { + return err + } + } + + var toStart []Plugin + + func() { + m.mtx.Lock() + defer m.mtx.Unlock() + toStart = make([]Plugin, len(m.plugins)) + for i := range m.plugins { + toStart[i] = m.plugins[i].plugin + } + }() + + for i := range toStart { + if err := toStart[i].Start(ctx); err != nil { + return err + } + } + + return nil +} + +// Stop stops the manager, stopping all the plugins registered with it. +// Any plugin that needs to perform cleanup should do so within the duration +// of the graceful shutdown period passed with the context as a timeout. +// Note that a graceful shutdown period configured with the Manager instance +// will override the timeout of the passed in context (if applicable). +func (m *Manager) Stop(ctx context.Context) { + var toStop []Plugin + + func() { + m.mtx.Lock() + defer m.mtx.Unlock() + toStop = make([]Plugin, len(m.plugins)) + for i := range m.plugins { + toStop[i] = m.plugins[i].plugin + } + }() + + var cancel context.CancelFunc + if m.gracefulShutdownPeriod > 0 { + ctx, cancel = context.WithTimeout(ctx, time.Duration(m.gracefulShutdownPeriod)*time.Second) + } else { + ctx, cancel = context.WithCancel(ctx) + } + defer cancel() + for i := range toStop { + toStop[i].Stop(ctx) + } + if c, ok := m.Store.(interface{ Close(context.Context) error }); ok { + if err := c.Close(ctx); err != nil { + m.logger.Error("Error closing store: %v", err) + } + } + + if m.stop != nil { + done := make(chan struct{}) + m.stop <- done + <-done + } +} + +// Reconfigure updates the configuration on the manager. +func (m *Manager) Reconfigure(config *config.Config) error { + opts := cfg.ServiceOptions{ + Raw: config.Services, + AuthPlugin: m.AuthPlugin, + Logger: m.logger, + DistributedTacingOpts: m.distributedTacingOpts, + } + + keys, err := keys.ParseKeysConfig(config.Keys) + if err != nil { + return err + } + opts.Keys = keys + + services, err := cfg.ParseServicesConfig(opts) + if err != nil { + return err + } + + interQueryBuiltinCacheConfig, err := cache.ParseCachingConfig(config.Caching) + if err != nil { + return err + } + + m.mtx.Lock() + defer m.mtx.Unlock() + + // don't overwrite existing labels, only allow additions - always based on the boostrap config + if config.Labels == nil { + config.Labels = m.bootstrapConfigLabels + } else { + for label, value := range m.bootstrapConfigLabels { + config.Labels[label] = value + } + } + + // don't erase persistence directory + if config.PersistenceDirectory == nil { + config.PersistenceDirectory = m.Config.PersistenceDirectory + } + + m.Config = config + m.interQueryBuiltinCacheConfig = interQueryBuiltinCacheConfig + for name, client := range services { + m.services[name] = client + } + + for name, key := range keys { + m.keys[name] = key + } + + for _, trigger := range m.registeredCacheTriggers { + trigger(interQueryBuiltinCacheConfig) + } + + for _, trigger := range m.registeredNDCacheTriggers { + trigger(config.NDBuiltinCache) + } + + return nil +} + +// PluginStatus returns the current statuses of any plugins registered. +func (m *Manager) PluginStatus() map[string]*Status { + m.mtx.Lock() + defer m.mtx.Unlock() + + return m.copyPluginStatus() +} + +// RegisterPluginStatusListener registers a StatusListener to be +// called when plugin status updates occur. +func (m *Manager) RegisterPluginStatusListener(name string, listener StatusListener) { + m.mtx.Lock() + defer m.mtx.Unlock() + + m.pluginStatusListeners[name] = listener +} + +// UnregisterPluginStatusListener removes a StatusListener registered with the +// same name. +func (m *Manager) UnregisterPluginStatusListener(name string) { + m.mtx.Lock() + defer m.mtx.Unlock() + + delete(m.pluginStatusListeners, name) +} + +// UpdatePluginStatus updates a named plugins status. Any registered +// listeners will be called with a copy of the new state of all +// plugins. +func (m *Manager) UpdatePluginStatus(pluginName string, status *Status) { + + var toNotify map[string]StatusListener + var statuses map[string]*Status + + func() { + m.mtx.Lock() + defer m.mtx.Unlock() + m.pluginStatus[pluginName] = status + toNotify = make(map[string]StatusListener, len(m.pluginStatusListeners)) + for k, v := range m.pluginStatusListeners { + toNotify[k] = v + } + statuses = m.copyPluginStatus() + }() + + for _, l := range toNotify { + l(statuses) + } +} + +func (m *Manager) copyPluginStatus() map[string]*Status { + statusCpy := map[string]*Status{} + for k, v := range m.pluginStatus { + var cpy *Status + if v != nil { + cpy = &Status{ + State: v.State, + Message: v.Message, + } + } + statusCpy[k] = cpy + } + return statusCpy +} + +func (m *Manager) onCommit(ctx context.Context, txn storage.Transaction, event storage.TriggerEvent) { + + compiler := GetCompilerOnContext(event.Context) + + // If the context does not contain the compiler fallback to loading the + // compiler from the store. Currently the bundle plugin sets the + // compiler on the context but the server does not (nor would users + // implementing their own policy loading.) + if compiler == nil && event.PolicyChanged() { + compiler, _ = loadCompilerFromStore(ctx, m.Store, txn, m.enablePrintStatements, m.ParserOptions()) + } + + if compiler != nil { + m.setCompiler(compiler) + + if m.enableTelemetry && event.PolicyChanged() { + m.opaReportNotifyCh <- struct{}{} + } + + for _, f := range m.registeredTriggers { + f(txn) + } + } + + // Similar to the compiler, look for a set of resolvers on the transaction + // context. If they are not set we may need to reload from the store. + resolvers := getWasmResolversOnContext(event.Context) + if resolvers != nil { + m.setWasmResolvers(resolvers) + + } else if event.DataChanged() { + if requiresWasmResolverReload(event) { + resolvers, err := bundleUtils.LoadWasmResolversFromStore(ctx, m.Store, txn, nil) + if err != nil { + panic(err) + } + m.setWasmResolvers(resolvers) + } else { + err := m.updateWasmResolversData(ctx, event) + if err != nil { + panic(err) + } + } + } +} + +func loadCompilerFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, enablePrintStatements bool, popts ast.ParserOptions) (*ast.Compiler, error) { + policies, err := store.ListPolicies(ctx, txn) + if err != nil { + return nil, err + } + modules := map[string]*ast.Module{} + + for _, policy := range policies { + bs, err := store.GetPolicy(ctx, txn, policy) + if err != nil { + return nil, err + } + module, err := ast.ParseModuleWithOpts(policy, string(bs), popts) + if err != nil { + return nil, err + } + modules[policy] = module + } + + compiler := ast.NewCompiler(). + WithEnablePrintStatements(enablePrintStatements) + + if popts.RegoVersion != ast.RegoUndefined { + compiler = compiler.WithDefaultRegoVersion(popts.RegoVersion) + } + + compiler.Compile(modules) + return compiler, nil +} + +func requiresWasmResolverReload(event storage.TriggerEvent) bool { + // If the data changes touched the bundle path (which includes + // the wasm modules) we will reload them. Otherwise update + // data for each module already on the manager. + for _, dataEvent := range event.Data { + if dataEvent.Path.HasPrefix(bundle.BundlesBasePath) { + return true + } + } + return false +} + +func (m *Manager) updateWasmResolversData(ctx context.Context, event storage.TriggerEvent) error { + m.wasmResolversMtx.Lock() + defer m.wasmResolversMtx.Unlock() + + for _, resolver := range m.wasmResolvers { + for _, dataEvent := range event.Data { + var err error + if dataEvent.Removed { + err = resolver.RemoveDataPath(ctx, dataEvent.Path) + } else { + err = resolver.SetDataPath(ctx, dataEvent.Path, dataEvent.Data) + } + if err != nil { + return fmt.Errorf("failed to update wasm runtime data: %s", err) + } + } + } + return nil +} + +// PublicKeys returns a public keys that can be used for verifying signed bundles. +func (m *Manager) PublicKeys() map[string]*keys.Config { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.keys +} + +// Client returns a client for communicating with a remote service. +func (m *Manager) Client(name string) rest.Client { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.services[name] +} + +// Services returns a list of services that m can provide clients for. +func (m *Manager) Services() []string { + m.mtx.Lock() + defer m.mtx.Unlock() + s := make([]string, 0, len(m.services)) + for name := range m.services { + s = append(s, name) + } + return s +} + +// Logger gets the standard logger for this plugin manager. +func (m *Manager) Logger() logging.Logger { + return m.logger +} + +// ConsoleLogger gets the console logger for this plugin manager. +func (m *Manager) ConsoleLogger() logging.Logger { + return m.consoleLogger +} + +func (m *Manager) PrintHook() print.Hook { + return m.printHook +} + +func (m *Manager) EnablePrintStatements() bool { + return m.enablePrintStatements +} + +// ServerInitialized signals a channel indicating that the OPA +// server has finished initialization. +func (m *Manager) ServerInitialized() { + m.serverInitializedOnce.Do(func() { close(m.serverInitialized) }) +} + +// ServerInitializedChannel returns a receive-only channel that +// is closed when the OPA server has finished initialization. +// Be aware that the socket of the server listener may not be +// open by the time this channel is closed. There is a very +// small window where the socket may still be closed, due to +// a race condition. +func (m *Manager) ServerInitializedChannel() <-chan struct{} { + return m.serverInitialized +} + +// RegisterCacheTrigger accepts a func that receives new inter-query cache config generated by +// a reconfigure of the plugin manager, so that it can be propagated to existing inter-query caches. +func (m *Manager) RegisterCacheTrigger(trigger func(*cache.Config)) { + m.mtx.Lock() + defer m.mtx.Unlock() + m.registeredCacheTriggers = append(m.registeredCacheTriggers, trigger) +} + +// PrometheusRegister gets the prometheus.Registerer for this plugin manager. +func (m *Manager) PrometheusRegister() prometheus.Registerer { + return m.prometheusRegister +} + +// TracerProvider gets the *trace.TracerProvider for this plugin manager. +func (m *Manager) TracerProvider() *trace.TracerProvider { + return m.tracerProvider +} + +func (m *Manager) RegisterNDCacheTrigger(trigger func(bool)) { + m.mtx.Lock() + defer m.mtx.Unlock() + m.registeredNDCacheTriggers = append(m.registeredNDCacheTriggers, trigger) +} + +func (m *Manager) sendOPAUpdateLoop(ctx context.Context) { + ticker := time.NewTicker(time.Duration(int64(time.Second) * defaultUploadIntervalSec)) + mr.New(mr.NewSource(time.Now().UnixNano())) + + ctx, cancel := context.WithCancel(ctx) + + var opaReportNotify bool + + for { + select { + case <-m.opaReportNotifyCh: + opaReportNotify = true + case <-ticker.C: + ticker.Stop() + + if opaReportNotify { + opaReportNotify = false + _, err := m.reporter.SendReport(ctx) + if err != nil { + m.logger.WithFields(map[string]interface{}{"err": err}).Debug("Unable to send OPA telemetry report.") + } + } + + newInterval := mr.Int63n(defaultUploadIntervalSec) + defaultUploadIntervalSec + ticker = time.NewTicker(time.Duration(int64(time.Second) * newInterval)) + case done := <-m.stop: + cancel() + ticker.Stop() + done <- struct{}{} + return + } + } +} + +func (m *Manager) ParserOptions() ast.ParserOptions { + return m.parserOptions +} diff --git a/vendor/github.com/open-policy-agent/opa/plugins/rest/auth.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/rest/auth.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/plugins/rest/auth.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/rest/auth.go index 11e72001a..964630fa2 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/rest/auth.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/rest/auth.go @@ -33,8 +33,8 @@ import ( "github.com/open-policy-agent/opa/internal/jwx/jws/sign" "github.com/open-policy-agent/opa/internal/providers/aws" "github.com/open-policy-agent/opa/internal/uuid" - "github.com/open-policy-agent/opa/keys" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/keys" + "github.com/open-policy-agent/opa/v1/logging" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/plugins/rest/aws.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/rest/aws.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/plugins/rest/aws.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/rest/aws.go index cc45dfa9c..133df8099 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/rest/aws.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/rest/aws.go @@ -19,7 +19,7 @@ import ( "github.com/go-ini/ini" "github.com/open-policy-agent/opa/internal/providers/aws" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/logging" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/plugins/rest/azure.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/rest/azure.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/plugins/rest/azure.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/rest/azure.go diff --git a/vendor/github.com/open-policy-agent/opa/plugins/rest/gcp.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/rest/gcp.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/plugins/rest/gcp.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/rest/gcp.go diff --git a/vendor/github.com/open-policy-agent/opa/plugins/rest/rest.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/rest/rest.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/plugins/rest/rest.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/rest/rest.go index fd59058ca..fea351557 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/rest/rest.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/rest/rest.go @@ -18,10 +18,10 @@ import ( "strings" "github.com/open-policy-agent/opa/internal/version" - "github.com/open-policy-agent/opa/keys" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/tracing" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/keys" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/tracing" + "github.com/open-policy-agent/opa/v1/util" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/plugins/server/decoding/config.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/server/decoding/config.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/plugins/server/decoding/config.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/server/decoding/config.go index 5cd009637..9adc5f72d 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/server/decoding/config.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/server/decoding/config.go @@ -19,7 +19,7 @@ package decoding import ( "fmt" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/util" ) var ( diff --git a/vendor/github.com/open-policy-agent/opa/plugins/server/encoding/config.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/server/encoding/config.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/plugins/server/encoding/config.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/server/encoding/config.go index 4da12d0c0..90b9cbf6c 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/server/encoding/config.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/server/encoding/config.go @@ -4,7 +4,7 @@ import ( "compress/gzip" "fmt" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/util" ) var defaultGzipMinLength = 1024 diff --git a/vendor/github.com/open-policy-agent/opa/plugins/server/metrics/config.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/server/metrics/config.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/plugins/server/metrics/config.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/server/metrics/config.go index e13bb8798..e3eff7edf 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/server/metrics/config.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/server/metrics/config.go @@ -1,7 +1,7 @@ package metrics import ( - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/util" ) var defaultHTTPRequestBuckets = []float64{ diff --git a/vendor/github.com/open-policy-agent/opa/v1/plugins/status/metrics.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/status/metrics.go new file mode 100644 index 000000000..9141ed31c --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/status/metrics.go @@ -0,0 +1,174 @@ +package status + +import ( + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/version" + "github.com/prometheus/client_golang/prometheus" +) + +var defaultBundleLoadStageBuckets = prometheus.ExponentialBuckets(1000, 2, 20) + +type PrometheusConfig struct { + Collectors *Collectors `json:"collectors,omitempty"` +} + +type Collectors struct { + BundleLoadDurationNanoseconds *BundleLoadDurationNanoseconds `json:"bundle_loading_duration_ns,omitempty"` +} + +func injectDefaultDurationBuckets(p *PrometheusConfig) *PrometheusConfig { + if p != nil && p.Collectors != nil && p.Collectors.BundleLoadDurationNanoseconds != nil && p.Collectors.BundleLoadDurationNanoseconds.Buckets != nil { + return p + } + + return &PrometheusConfig{ + Collectors: &Collectors{ + BundleLoadDurationNanoseconds: &BundleLoadDurationNanoseconds{ + Buckets: defaultBundleLoadStageBuckets, + }, + }, + } +} + +// collectors is a list of all collectors maintained by the status plugin. +// Note: when adding a new collector, make sure to also add it to this list, +// or it won't survive status plugin reconfigure events. +type collectors struct { + opaInfo prometheus.Gauge + pluginStatus *prometheus.GaugeVec + loaded *prometheus.CounterVec + failLoad *prometheus.CounterVec + lastRequest *prometheus.GaugeVec + lastSuccessfulActivation *prometheus.GaugeVec + lastSuccessfulDownload *prometheus.GaugeVec + lastSuccessfulRequest *prometheus.GaugeVec + bundleLoadDuration *prometheus.HistogramVec +} + +func newCollectors(prometheusConfig *PrometheusConfig) *collectors { + opaInfo := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "opa_info", + Help: "Information about the OPA environment.", + ConstLabels: map[string]string{"version": version.Version}, + }, + ) + opaInfo.Set(1) // only publish once + + pluginStatus := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "plugin_status_gauge", + Help: "Gauge for the plugin by status.", + }, + []string{"name", "status"}, + ) + loaded := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "bundle_loaded_counter", + Help: "Counter for the bundle loaded.", + }, + []string{"name"}, + ) + failLoad := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "bundle_failed_load_counter", + Help: "Counter for the failed bundle load.", + }, + []string{"name", "code", "message"}, + ) + lastRequest := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "last_bundle_request", + Help: "Gauge for the last bundle request.", + }, + []string{"name"}, + ) + lastSuccessfulActivation := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "last_success_bundle_activation", + Help: "Gauge for the last success bundle activation.", + }, + []string{"name", "active_revision"}, + ) + lastSuccessfulDownload := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "last_success_bundle_download", + Help: "Gauge for the last success bundle download.", + }, + []string{"name"}, + ) + lastSuccessfulRequest := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "last_success_bundle_request", + Help: "Gauge for the last success bundle request.", + }, + []string{"name"}, + ) + + bundleLoadDuration := newBundleLoadDurationCollector(prometheusConfig) + + return &collectors{ + opaInfo: opaInfo, + pluginStatus: pluginStatus, + loaded: loaded, + failLoad: failLoad, + lastRequest: lastRequest, + lastSuccessfulActivation: lastSuccessfulActivation, + lastSuccessfulDownload: lastSuccessfulDownload, + lastSuccessfulRequest: lastSuccessfulRequest, + bundleLoadDuration: bundleLoadDuration, + } +} + +func newBundleLoadDurationCollector(prometheusConfig *PrometheusConfig) *prometheus.HistogramVec { + return prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: "bundle_loading_duration_ns", + Help: "Histogram for the bundle loading duration by stage.", + Buckets: prometheusConfig.Collectors.BundleLoadDurationNanoseconds.Buckets, + }, []string{"name", "stage"}) +} + +func (c *collectors) RegisterAll(register prometheus.Registerer, logger logging.Logger) { + if register == nil { + return + } + for _, collector := range c.toList() { + if err := register.Register(collector); err != nil { + logger.Error("Status metric failed to register on prometheus :%v.", err) + } + } +} + +func (c *collectors) UnregisterAll(register prometheus.Registerer) { + if register == nil { + return + } + + for _, collector := range c.toList() { + register.Unregister(collector) + } +} + +func (c *collectors) ReregisterBundleLoadDuration(register prometheus.Registerer, config *PrometheusConfig, logger logging.Logger) { + logger.Debug("Re-register bundleLoadDuration collector") + register.Unregister(c.bundleLoadDuration) + c.bundleLoadDuration = newBundleLoadDurationCollector(config) + if err := register.Register(c.bundleLoadDuration); err != nil { + logger.Error("Status metric failed to register bundleLoadDuration collector on prometheus :%v.", err) + } +} + +// helper function +func (c *collectors) toList() []prometheus.Collector { + return []prometheus.Collector{ + c.opaInfo, + c.pluginStatus, + c.loaded, + c.failLoad, + c.lastRequest, + c.lastSuccessfulActivation, + c.lastSuccessfulDownload, + c.lastSuccessfulRequest, + c.bundleLoadDuration, + } +} diff --git a/vendor/github.com/open-policy-agent/opa/plugins/status/plugin.go b/vendor/github.com/open-policy-agent/opa/v1/plugins/status/plugin.go similarity index 83% rename from vendor/github.com/open-policy-agent/opa/plugins/status/plugin.go rename to vendor/github.com/open-policy-agent/opa/v1/plugins/status/plugin.go index 880752403..6a447fe73 100644 --- a/vendor/github.com/open-policy-agent/opa/plugins/status/plugin.go +++ b/vendor/github.com/open-policy-agent/opa/v1/plugins/status/plugin.go @@ -12,15 +12,13 @@ import ( "net/http" "reflect" - prom "github.com/prometheus/client_golang/prometheus" + lstat "github.com/open-policy-agent/opa/v1/plugins/logs/status" - lstat "github.com/open-policy-agent/opa/plugins/logs/status" - - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/plugins/bundle" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/plugins/bundle" + "github.com/open-policy-agent/opa/v1/util" ) // Logger defines the interface for status plugins. @@ -62,16 +60,23 @@ type Plugin struct { metrics metrics.Metrics logger logging.Logger trigger chan trigger + collectors *collectors } // Config contains configuration for the plugin. type Config struct { - Plugin *string `json:"plugin"` - Service string `json:"service"` - PartitionName string `json:"partition_name,omitempty"` - ConsoleLogs bool `json:"console"` - Prometheus bool `json:"prometheus"` - Trigger *plugins.TriggerMode `json:"trigger,omitempty"` // trigger mode + Plugin *string `json:"plugin"` + Service string `json:"service"` + PartitionName string `json:"partition_name,omitempty"` + ConsoleLogs bool `json:"console"` + Prometheus bool `json:"prometheus"` + PrometheusConfig *PrometheusConfig `json:"prometheus_config,omitempty"` + Trigger *plugins.TriggerMode `json:"trigger,omitempty"` // trigger mode +} + +// BundleLoadDurationNanoseconds represents the configuration for the status.prometheus_config.bundle_loading_duration_ns settings +type BundleLoadDurationNanoseconds struct { + Buckets []float64 `json:"buckets,omitempty"` // the float64 array of buckets representing nanoseconds or multiple of nanoseconds } type reconfigure struct { @@ -85,7 +90,6 @@ type trigger struct { } func (c *Config) validateAndInjectDefaults(services []string, pluginsList []string, trigger *plugins.TriggerMode) error { - if c.Plugin != nil { var found bool for _, other := range pluginsList { @@ -124,6 +128,8 @@ func (c *Config) validateAndInjectDefaults(services []string, pluginsList []stri } c.Trigger = t + c.PrometheusConfig = injectDefaultDurationBuckets(c.PrometheusConfig) + return nil } @@ -211,6 +217,7 @@ func New(parsedConfig *Config, manager *plugins.Manager) *Plugin { queryCh: make(chan chan *UpdateRequestV1), logger: manager.Logger().WithFields(map[string]interface{}{"plugin": Name}), trigger: make(chan trigger), + collectors: newCollectors(parsedConfig.PrometheusConfig), } p.manager.UpdatePluginStatus(Name, &plugins.Status{State: plugins.StateNotReady}) @@ -246,7 +253,7 @@ func (p *Plugin) Start(ctx context.Context) error { p.manager.RegisterPluginStatusListener(Name, p.UpdatePluginStatus) if p.config.Prometheus { - p.registerAll() + p.collectors.RegisterAll(p.manager.PrometheusRegister(), p.logger) } // Set the status plugin's status to OK now that everything is registered and @@ -256,32 +263,6 @@ func (p *Plugin) Start(ctx context.Context) error { return nil } -func (p *Plugin) register(r prom.Registerer, cs ...prom.Collector) { - for _, c := range cs { - if err := r.Register(c); err != nil { - p.logger.Error("Status metric failed to register on prometheus :%v.", err) - } - } -} - -func (p *Plugin) registerAll() { - if p.manager.PrometheusRegister() != nil { - p.register(p.manager.PrometheusRegister(), allCollectors...) - } -} - -func (p *Plugin) unregister(r prom.Registerer, cs ...prom.Collector) { - for _, c := range cs { - r.Unregister(c) - } -} - -func (p *Plugin) unregisterAll() { - if p.manager.PrometheusRegister() != nil { - p.unregister(p.manager.PrometheusRegister(), allCollectors...) - } -} - // Stop stops the plugin. func (p *Plugin) Stop(_ context.Context) { p.logger.Info("Stopping status reporter.") @@ -348,11 +329,9 @@ func (p *Plugin) Trigger(ctx context.Context) error { } func (p *Plugin) loop(ctx context.Context) { - ctx, cancel := context.WithCancel(ctx) for { - select { case statuses := <-p.pluginStatusCh: p.lastPluginStatuses = statuses @@ -429,7 +408,6 @@ func (p *Plugin) loop(ctx context.Context) { } func (p *Plugin) oneShot(ctx context.Context) error { - req := p.snapshot() if p.config.ConsoleLogs { @@ -440,7 +418,7 @@ func (p *Plugin) oneShot(ctx context.Context) error { } if p.config.Prometheus { - updatePrometheusMetrics(req) + p.updatePrometheusMetrics(req) } if p.config.Plugin != nil { @@ -455,7 +433,6 @@ func (p *Plugin) oneShot(ctx context.Context) error { resp, err := p.manager.Client(p.config.Service). WithJSON(req). Do(ctx, "POST", fmt.Sprintf("/status/%v", p.config.PartitionName)) - if err != nil { return fmt.Errorf("Status update failed: %w", err) } @@ -480,16 +457,19 @@ func (p *Plugin) reconfigure(config interface{}) { p.logger.Info("Status reporter configuration changed.") if newConfig.Prometheus && !p.config.Prometheus { - p.registerAll() + p.collectors.RegisterAll(p.manager.PrometheusRegister(), p.logger) } else if !newConfig.Prometheus && p.config.Prometheus { - p.unregisterAll() + p.collectors.UnregisterAll(p.manager.PrometheusRegister()) + } else if newConfig.Prometheus && p.config.Prometheus { + if !reflect.DeepEqual(newConfig.PrometheusConfig, p.config.PrometheusConfig) { + p.collectors.ReregisterBundleLoadDuration(p.manager.PrometheusRegister(), newConfig.PrometheusConfig, p.logger) + } } p.config = *newConfig } func (p *Plugin) snapshot() *UpdateRequestV1 { - s := &UpdateRequestV1{ Labels: p.manager.Labels(), Discovery: p.lastDiscoStatus, @@ -522,27 +502,28 @@ func (p *Plugin) logUpdate(update *UpdateRequestV1) error { return nil } -func updatePrometheusMetrics(u *UpdateRequestV1) { - pluginStatus.Reset() +func (p *Plugin) updatePrometheusMetrics(u *UpdateRequestV1) { + p.collectors.pluginStatus.Reset() for name, plugin := range u.Plugins { - pluginStatus.WithLabelValues(name, string(plugin.State)).Set(1) + p.collectors.pluginStatus.WithLabelValues(name, string(plugin.State)).Set(1) } - lastSuccessfulActivation.Reset() + p.collectors.lastSuccessfulActivation.Reset() for _, bundle := range u.Bundles { if bundle.Code == "" && !bundle.LastSuccessfulActivation.IsZero() { - loaded.WithLabelValues(bundle.Name).Inc() + p.collectors.loaded.WithLabelValues(bundle.Name).Inc() } else { - failLoad.WithLabelValues(bundle.Name, bundle.Code, bundle.Message).Inc() + p.collectors.failLoad.WithLabelValues(bundle.Name, bundle.Code, bundle.Message).Inc() } - lastSuccessfulActivation.WithLabelValues(bundle.Name, bundle.ActiveRevision).Set(float64(bundle.LastSuccessfulActivation.UnixNano())) - lastSuccessfulDownload.WithLabelValues(bundle.Name).Set(float64(bundle.LastSuccessfulDownload.UnixNano())) - lastSuccessfulRequest.WithLabelValues(bundle.Name).Set(float64(bundle.LastSuccessfulRequest.UnixNano())) - lastRequest.WithLabelValues(bundle.Name).Set(float64(bundle.LastRequest.UnixNano())) + p.collectors.lastSuccessfulActivation.WithLabelValues(bundle.Name, bundle.ActiveRevision).Set(float64(bundle.LastSuccessfulActivation.UnixNano())) + p.collectors.lastSuccessfulDownload.WithLabelValues(bundle.Name).Set(float64(bundle.LastSuccessfulDownload.UnixNano())) + p.collectors.lastSuccessfulRequest.WithLabelValues(bundle.Name).Set(float64(bundle.LastSuccessfulRequest.UnixNano())) + p.collectors.lastRequest.WithLabelValues(bundle.Name).Set(float64(bundle.LastRequest.UnixNano())) + if bundle.Metrics != nil { for stage, metric := range bundle.Metrics.All() { switch stage { case "timer_bundle_request_ns", "timer_rego_data_parse_ns", "timer_rego_module_parse_ns", "timer_rego_module_compile_ns", "timer_rego_load_bundles_ns": - bundleLoadDuration.WithLabelValues(bundle.Name, stage).Observe(float64(metric.(int64))) + p.collectors.bundleLoadDuration.WithLabelValues(bundle.Name, stage).Observe(float64(metric.(int64))) } } } diff --git a/vendor/github.com/open-policy-agent/opa/profiler/profiler.go b/vendor/github.com/open-policy-agent/opa/v1/profiler/profiler.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/profiler/profiler.go rename to vendor/github.com/open-policy-agent/opa/v1/profiler/profiler.go index 62dfd2d38..ca88cadd8 100644 --- a/vendor/github.com/open-policy-agent/opa/profiler/profiler.go +++ b/vendor/github.com/open-policy-agent/opa/v1/profiler/profiler.go @@ -9,9 +9,9 @@ import ( "sort" "time" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/topdown" ) // Profiler computes and reports on the time spent on expressions. diff --git a/vendor/github.com/open-policy-agent/opa/refactor/refactor.go b/vendor/github.com/open-policy-agent/opa/v1/refactor/refactor.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/refactor/refactor.go rename to vendor/github.com/open-policy-agent/opa/v1/refactor/refactor.go index 1285218f1..a951d01f8 100644 --- a/vendor/github.com/open-policy-agent/opa/refactor/refactor.go +++ b/vendor/github.com/open-policy-agent/opa/v1/refactor/refactor.go @@ -8,7 +8,7 @@ package refactor import ( "fmt" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) // Error defines the structure of errors returned by refactor. diff --git a/vendor/github.com/open-policy-agent/opa/v1/rego/errors.go b/vendor/github.com/open-policy-agent/opa/v1/rego/errors.go new file mode 100644 index 000000000..dcc5e2679 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/rego/errors.go @@ -0,0 +1,24 @@ +package rego + +// HaltError is an error type to return from a custom function implementation +// that will abort the evaluation process (analogous to topdown.Halt). +type HaltError struct { + err error +} + +// Error delegates to the wrapped error +func (h *HaltError) Error() string { + return h.err.Error() +} + +// NewHaltError wraps an error such that the evaluation process will stop +// when it occurs. +func NewHaltError(err error) error { + return &HaltError{err: err} +} + +// ErrorDetails interface is satisfied by an error that provides further +// details. +type ErrorDetails interface { + Lines() []string +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/rego/plugins.go b/vendor/github.com/open-policy-agent/opa/v1/rego/plugins.go new file mode 100644 index 000000000..88f23480b --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/rego/plugins.go @@ -0,0 +1,43 @@ +// Copyright 2023 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package rego + +import ( + "context" + "sync" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/ir" +) + +var targetPlugins = map[string]TargetPlugin{} +var pluginMtx sync.Mutex + +type TargetPlugin interface { + IsTarget(string) bool + PrepareForEval(context.Context, *ir.Policy, ...PrepareOption) (TargetPluginEval, error) +} + +type TargetPluginEval interface { + Eval(context.Context, *EvalContext, ast.Value) (ast.Value, error) +} + +func (r *Rego) targetPlugin(tgt string) TargetPlugin { + for _, p := range targetPlugins { + if p.IsTarget(tgt) { + return p + } + } + return nil +} + +func RegisterPlugin(name string, p TargetPlugin) { + pluginMtx.Lock() + defer pluginMtx.Unlock() + if _, ok := targetPlugins[name]; ok { + panic("plugin already registered " + name) + } + targetPlugins[name] = p +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/rego/rego.go b/vendor/github.com/open-policy-agent/opa/v1/rego/rego.go new file mode 100644 index 000000000..1b7ea47bd --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/rego/rego.go @@ -0,0 +1,2912 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package rego exposes high level APIs for evaluating Rego policies. +package rego + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "strings" + "time" + + bundleUtils "github.com/open-policy-agent/opa/internal/bundle" + "github.com/open-policy-agent/opa/internal/compiler/wasm" + "github.com/open-policy-agent/opa/internal/future" + "github.com/open-policy-agent/opa/internal/planner" + "github.com/open-policy-agent/opa/internal/rego/opa" + "github.com/open-policy-agent/opa/internal/wasm/encoding" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/ir" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/resolver" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/inmem" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" + "github.com/open-policy-agent/opa/v1/tracing" + "github.com/open-policy-agent/opa/v1/types" + "github.com/open-policy-agent/opa/v1/util" +) + +const ( + defaultPartialNamespace = "partial" + wasmVarPrefix = "^" +) + +// nolint: deadcode,varcheck +const ( + targetWasm = "wasm" + targetRego = "rego" +) + +// CompileResult represents the result of compiling a Rego query, zero or more +// Rego modules, and arbitrary contextual data into an executable. +type CompileResult struct { + Bytes []byte `json:"bytes"` +} + +// PartialQueries contains the queries and support modules produced by partial +// evaluation. +type PartialQueries struct { + Queries []ast.Body `json:"queries,omitempty"` + Support []*ast.Module `json:"modules,omitempty"` +} + +// PartialResult represents the result of partial evaluation. The result can be +// used to generate a new query that can be run when inputs are known. +type PartialResult struct { + compiler *ast.Compiler + store storage.Store + body ast.Body + builtinDecls map[string]*ast.Builtin + builtinFuncs map[string]*topdown.Builtin +} + +// Rego returns an object that can be evaluated to produce a query result. +func (pr PartialResult) Rego(options ...func(*Rego)) *Rego { + options = append(options, Compiler(pr.compiler), Store(pr.store), ParsedQuery(pr.body)) + r := New(options...) + + // Propagate any custom builtins. + for k, v := range pr.builtinDecls { + r.builtinDecls[k] = v + } + for k, v := range pr.builtinFuncs { + r.builtinFuncs[k] = v + } + return r +} + +// preparedQuery is a wrapper around a Rego object which has pre-processed +// state stored on it. Once prepared there are a more limited number of actions +// that can be taken with it. It will, however, be able to evaluate faster since +// it will not have to re-parse or compile as much. +type preparedQuery struct { + r *Rego + cfg *PrepareConfig +} + +// EvalContext defines the set of options allowed to be set at evaluation +// time. Any other options will need to be set on a new Rego object. +type EvalContext struct { + hasInput bool + time time.Time + seed io.Reader + rawInput *interface{} + parsedInput ast.Value + metrics metrics.Metrics + txn storage.Transaction + instrument bool + instrumentation *topdown.Instrumentation + partialNamespace string + queryTracers []topdown.QueryTracer + compiledQuery compiledQuery + unknowns []string + disableInlining []ast.Ref + parsedUnknowns []*ast.Term + indexing bool + earlyExit bool + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + ndBuiltinCache builtins.NDBCache + resolvers []refResolver + httpRoundTripper topdown.CustomizeRoundTripper + sortSets bool + copyMaps bool + printHook print.Hook + capabilities *ast.Capabilities + strictBuiltinErrors bool + virtualCache topdown.VirtualCache +} + +func (e *EvalContext) RawInput() *interface{} { + return e.rawInput +} + +func (e *EvalContext) ParsedInput() ast.Value { + return e.parsedInput +} + +func (e *EvalContext) Time() time.Time { + return e.time +} + +func (e *EvalContext) Seed() io.Reader { + return e.seed +} + +func (e *EvalContext) InterQueryBuiltinCache() cache.InterQueryCache { + return e.interQueryBuiltinCache +} + +func (e *EvalContext) InterQueryBuiltinValueCache() cache.InterQueryValueCache { + return e.interQueryBuiltinValueCache +} + +func (e *EvalContext) PrintHook() print.Hook { + return e.printHook +} + +func (e *EvalContext) Metrics() metrics.Metrics { + return e.metrics +} + +func (e *EvalContext) StrictBuiltinErrors() bool { + return e.strictBuiltinErrors +} + +func (e *EvalContext) NDBCache() builtins.NDBCache { + return e.ndBuiltinCache +} + +func (e *EvalContext) CompiledQuery() ast.Body { + return e.compiledQuery.query +} + +func (e *EvalContext) Capabilities() *ast.Capabilities { + return e.capabilities +} + +func (e *EvalContext) Transaction() storage.Transaction { + return e.txn +} + +// EvalOption defines a function to set an option on an EvalConfig +type EvalOption func(*EvalContext) + +// EvalInput configures the input for a Prepared Query's evaluation +func EvalInput(input interface{}) EvalOption { + return func(e *EvalContext) { + e.rawInput = &input + e.hasInput = true + } +} + +// EvalParsedInput configures the input for a Prepared Query's evaluation +func EvalParsedInput(input ast.Value) EvalOption { + return func(e *EvalContext) { + e.parsedInput = input + e.hasInput = true + } +} + +// EvalMetrics configures the metrics for a Prepared Query's evaluation +func EvalMetrics(metric metrics.Metrics) EvalOption { + return func(e *EvalContext) { + e.metrics = metric + } +} + +// EvalTransaction configures the Transaction for a Prepared Query's evaluation +func EvalTransaction(txn storage.Transaction) EvalOption { + return func(e *EvalContext) { + e.txn = txn + } +} + +// EvalInstrument enables or disables instrumenting for a Prepared Query's evaluation +func EvalInstrument(instrument bool) EvalOption { + return func(e *EvalContext) { + e.instrument = instrument + } +} + +// EvalTracer configures a tracer for a Prepared Query's evaluation +// Deprecated: Use EvalQueryTracer instead. +func EvalTracer(tracer topdown.Tracer) EvalOption { + return func(e *EvalContext) { + if tracer != nil { + e.queryTracers = append(e.queryTracers, topdown.WrapLegacyTracer(tracer)) + } + } +} + +// EvalQueryTracer configures a tracer for a Prepared Query's evaluation +func EvalQueryTracer(tracer topdown.QueryTracer) EvalOption { + return func(e *EvalContext) { + if tracer != nil { + e.queryTracers = append(e.queryTracers, tracer) + } + } +} + +// EvalPartialNamespace returns an argument that sets the namespace to use for +// partial evaluation results. The namespace must be a valid package path +// component. +func EvalPartialNamespace(ns string) EvalOption { + return func(e *EvalContext) { + e.partialNamespace = ns + } +} + +// EvalUnknowns returns an argument that sets the values to treat as +// unknown during partial evaluation. +func EvalUnknowns(unknowns []string) EvalOption { + return func(e *EvalContext) { + e.unknowns = unknowns + } +} + +// EvalDisableInlining returns an argument that adds a set of paths to exclude from +// partial evaluation inlining. +func EvalDisableInlining(paths []ast.Ref) EvalOption { + return func(e *EvalContext) { + e.disableInlining = paths + } +} + +// EvalParsedUnknowns returns an argument that sets the values to treat +// as unknown during partial evaluation. +func EvalParsedUnknowns(unknowns []*ast.Term) EvalOption { + return func(e *EvalContext) { + e.parsedUnknowns = unknowns + } +} + +// EvalRuleIndexing will disable indexing optimizations for the +// evaluation. This should only be used when tracing in debug mode. +func EvalRuleIndexing(enabled bool) EvalOption { + return func(e *EvalContext) { + e.indexing = enabled + } +} + +// EvalEarlyExit will disable 'early exit' optimizations for the +// evaluation. This should only be used when tracing in debug mode. +func EvalEarlyExit(enabled bool) EvalOption { + return func(e *EvalContext) { + e.earlyExit = enabled + } +} + +// EvalTime sets the wall clock time to use during policy evaluation. +// time.now_ns() calls will return this value. +func EvalTime(x time.Time) EvalOption { + return func(e *EvalContext) { + e.time = x + } +} + +// EvalSeed sets a reader that will seed randomization required by built-in functions. +// If a seed is not provided crypto/rand.Reader is used. +func EvalSeed(r io.Reader) EvalOption { + return func(e *EvalContext) { + e.seed = r + } +} + +// EvalInterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize +// during evaluation. +func EvalInterQueryBuiltinCache(c cache.InterQueryCache) EvalOption { + return func(e *EvalContext) { + e.interQueryBuiltinCache = c + } +} + +// EvalInterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize +// during evaluation. +func EvalInterQueryBuiltinValueCache(c cache.InterQueryValueCache) EvalOption { + return func(e *EvalContext) { + e.interQueryBuiltinValueCache = c + } +} + +// EvalNDBuiltinCache sets the non-deterministic builtin cache that built-in functions can +// use during evaluation. +func EvalNDBuiltinCache(c builtins.NDBCache) EvalOption { + return func(e *EvalContext) { + e.ndBuiltinCache = c + } +} + +// EvalResolver sets a Resolver for a specified ref path for this evaluation. +func EvalResolver(ref ast.Ref, r resolver.Resolver) EvalOption { + return func(e *EvalContext) { + e.resolvers = append(e.resolvers, refResolver{ref, r}) + } +} + +// EvalHTTPRoundTripper allows customizing the http.RoundTripper for this evaluation. +func EvalHTTPRoundTripper(t topdown.CustomizeRoundTripper) EvalOption { + return func(e *EvalContext) { + e.httpRoundTripper = t + } +} + +// EvalSortSets causes the evaluator to sort sets before returning them as JSON arrays. +func EvalSortSets(yes bool) EvalOption { + return func(e *EvalContext) { + e.sortSets = yes + } +} + +// EvalCopyMaps causes the evaluator to copy `map[string]interface{}`s before returning them. +func EvalCopyMaps(yes bool) EvalOption { + return func(e *EvalContext) { + e.copyMaps = yes + } +} + +// EvalPrintHook sets the object to use for handling print statement outputs. +func EvalPrintHook(ph print.Hook) EvalOption { + return func(e *EvalContext) { + e.printHook = ph + } +} + +// EvalVirtualCache sets the topdown.VirtualCache to use for evaluation. This is +// optional, and if not set, the default cache is used. +func EvalVirtualCache(vc topdown.VirtualCache) EvalOption { + return func(e *EvalContext) { + e.virtualCache = vc + } +} + +func (pq preparedQuery) Modules() map[string]*ast.Module { + mods := make(map[string]*ast.Module) + + for name, mod := range pq.r.parsedModules { + mods[name] = mod + } + + for _, b := range pq.r.bundles { + for _, mod := range b.Modules { + mods[mod.Path] = mod.Parsed + } + } + + return mods +} + +// newEvalContext creates a new EvalContext overlaying any EvalOptions over top +// the Rego object on the preparedQuery. The returned function should be called +// once the evaluation is complete to close any transactions that might have +// been opened. +func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption) (*EvalContext, func(context.Context), error) { + ectx := &EvalContext{ + hasInput: false, + rawInput: nil, + parsedInput: nil, + metrics: nil, + txn: nil, + instrument: false, + instrumentation: nil, + partialNamespace: pq.r.partialNamespace, + queryTracers: nil, + unknowns: pq.r.unknowns, + parsedUnknowns: pq.r.parsedUnknowns, + compiledQuery: compiledQuery{}, + indexing: true, + earlyExit: true, + resolvers: pq.r.resolvers, + printHook: pq.r.printHook, + capabilities: pq.r.capabilities, + strictBuiltinErrors: pq.r.strictBuiltinErrors, + } + + for _, o := range options { + o(ectx) + } + + if ectx.metrics == nil { + ectx.metrics = metrics.New() + } + + if ectx.instrument { + ectx.instrumentation = topdown.NewInstrumentation(ectx.metrics) + } + + // Default to an empty "finish" function + finishFunc := func(context.Context) {} + + var err error + ectx.disableInlining, err = parseStringsToRefs(pq.r.disableInlining) + if err != nil { + return nil, finishFunc, err + } + + if ectx.txn == nil { + ectx.txn, err = pq.r.store.NewTransaction(ctx) + if err != nil { + return nil, finishFunc, err + } + finishFunc = func(ctx context.Context) { + pq.r.store.Abort(ctx, ectx.txn) + } + } + + // If we didn't get an input specified in the Eval options + // then fall back to the Rego object's input fields. + if !ectx.hasInput { + ectx.rawInput = pq.r.rawInput + ectx.parsedInput = pq.r.parsedInput + } + + if ectx.parsedInput == nil { + if ectx.rawInput == nil { + // Fall back to the original Rego objects input if none was specified + // Note that it could still be nil + ectx.rawInput = pq.r.rawInput + } + + if pq.r.targetPlugin(pq.r.target) == nil && // no plugin claims this target + pq.r.target != targetWasm { + ectx.parsedInput, err = pq.r.parseRawInput(ectx.rawInput, ectx.metrics) + if err != nil { + return nil, finishFunc, err + } + } + } + + return ectx, finishFunc, nil +} + +// PreparedEvalQuery holds the prepared Rego state that has been pre-processed +// for subsequent evaluations. +type PreparedEvalQuery struct { + preparedQuery +} + +// Eval evaluates this PartialResult's Rego object with additional eval options +// and returns a ResultSet. +// If options are provided they will override the original Rego options respective value. +// The original Rego object transaction will *not* be re-used. A new transaction will be opened +// if one is not provided with an EvalOption. +func (pq PreparedEvalQuery) Eval(ctx context.Context, options ...EvalOption) (ResultSet, error) { + ectx, finish, err := pq.newEvalContext(ctx, options) + if err != nil { + return nil, err + } + defer finish(ctx) + + ectx.compiledQuery = pq.r.compiledQueries[evalQueryType] + + return pq.r.eval(ctx, ectx) +} + +// PreparedPartialQuery holds the prepared Rego state that has been pre-processed +// for partial evaluations. +type PreparedPartialQuery struct { + preparedQuery +} + +// Partial runs partial evaluation on the prepared query and returns the result. +// The original Rego object transaction will *not* be re-used. A new transaction will be opened +// if one is not provided with an EvalOption. +func (pq PreparedPartialQuery) Partial(ctx context.Context, options ...EvalOption) (*PartialQueries, error) { + ectx, finish, err := pq.newEvalContext(ctx, options) + if err != nil { + return nil, err + } + defer finish(ctx) + + ectx.compiledQuery = pq.r.compiledQueries[partialQueryType] + + return pq.r.partial(ctx, ectx) +} + +// Errors represents a collection of errors returned when evaluating Rego. +type Errors []error + +func (errs Errors) Error() string { + if len(errs) == 0 { + return "no error" + } + if len(errs) == 1 { + return fmt.Sprintf("1 error occurred: %v", errs[0].Error()) + } + buf := []string{fmt.Sprintf("%v errors occurred", len(errs))} + for _, err := range errs { + buf = append(buf, err.Error()) + } + return strings.Join(buf, "\n") +} + +var errPartialEvaluationNotEffective = errors.New("partial evaluation not effective") + +// IsPartialEvaluationNotEffectiveErr returns true if err is an error returned by +// this package to indicate that partial evaluation was ineffective. +func IsPartialEvaluationNotEffectiveErr(err error) bool { + errs, ok := err.(Errors) + if !ok { + return false + } + return len(errs) == 1 && errs[0] == errPartialEvaluationNotEffective +} + +type compiledQuery struct { + query ast.Body + compiler ast.QueryCompiler +} + +type queryType int + +// Define a query type for each of the top level Rego +// API's that compile queries differently. +const ( + evalQueryType queryType = iota + partialResultQueryType + partialQueryType + compileQueryType +) + +type loadPaths struct { + paths []string + filter loader.Filter +} + +// Rego constructs a query and can be evaluated to obtain results. +type Rego struct { + query string + parsedQuery ast.Body + compiledQueries map[queryType]compiledQuery + pkg string + parsedPackage *ast.Package + imports []string + parsedImports []*ast.Import + rawInput *interface{} + parsedInput ast.Value + unknowns []string + parsedUnknowns []*ast.Term + disableInlining []string + shallowInlining bool + skipPartialNamespace bool + partialNamespace string + modules []rawModule + parsedModules map[string]*ast.Module + compiler *ast.Compiler + store storage.Store + ownStore bool + ownStoreReadAst bool + txn storage.Transaction + metrics metrics.Metrics + queryTracers []topdown.QueryTracer + tracebuf *topdown.BufferTracer + trace bool + instrumentation *topdown.Instrumentation + instrument bool + capture map[*ast.Expr]ast.Var // map exprs to generated capture vars + termVarID int + dump io.Writer + runtime *ast.Term + time time.Time + seed io.Reader + capabilities *ast.Capabilities + builtinDecls map[string]*ast.Builtin + builtinFuncs map[string]*topdown.Builtin + unsafeBuiltins map[string]struct{} + loadPaths loadPaths + bundlePaths []string + bundles map[string]*bundle.Bundle + skipBundleVerification bool + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + ndBuiltinCache builtins.NDBCache + strictBuiltinErrors bool + builtinErrorList *[]topdown.Error + resolvers []refResolver + schemaSet *ast.SchemaSet + target string // target type (wasm, rego, etc.) + opa opa.EvalEngine + generateJSON func(*ast.Term, *EvalContext) (interface{}, error) + printHook print.Hook + enablePrintStatements bool + distributedTacingOpts tracing.Options + strict bool + pluginMgr *plugins.Manager + plugins []TargetPlugin + targetPrepState TargetPluginEval + regoVersion ast.RegoVersion +} + +func (r *Rego) RegoVersion() ast.RegoVersion { + return r.regoVersion +} + +// Function represents a built-in function that is callable in Rego. +type Function struct { + Name string + Description string + Decl *types.Function + Memoize bool + Nondeterministic bool +} + +// BuiltinContext contains additional attributes from the evaluator that +// built-in functions can use, e.g., the request context.Context, caches, etc. +type BuiltinContext = topdown.BuiltinContext + +type ( + // Builtin1 defines a built-in function that accepts 1 argument. + Builtin1 func(bctx BuiltinContext, op1 *ast.Term) (*ast.Term, error) + + // Builtin2 defines a built-in function that accepts 2 arguments. + Builtin2 func(bctx BuiltinContext, op1, op2 *ast.Term) (*ast.Term, error) + + // Builtin3 defines a built-in function that accepts 3 argument. + Builtin3 func(bctx BuiltinContext, op1, op2, op3 *ast.Term) (*ast.Term, error) + + // Builtin4 defines a built-in function that accepts 4 argument. + Builtin4 func(bctx BuiltinContext, op1, op2, op3, op4 *ast.Term) (*ast.Term, error) + + // BuiltinDyn defines a built-in function that accepts a list of arguments. + BuiltinDyn func(bctx BuiltinContext, terms []*ast.Term) (*ast.Term, error) +) + +// RegisterBuiltin1 adds a built-in function globally inside the OPA runtime. +func RegisterBuiltin1(decl *Function, impl Builtin1) { + ast.RegisterBuiltin(&ast.Builtin{ + Name: decl.Name, + Description: decl.Description, + Decl: decl.Decl, + Nondeterministic: decl.Nondeterministic, + }) + topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms[0]) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// RegisterBuiltin2 adds a built-in function globally inside the OPA runtime. +func RegisterBuiltin2(decl *Function, impl Builtin2) { + ast.RegisterBuiltin(&ast.Builtin{ + Name: decl.Name, + Description: decl.Description, + Decl: decl.Decl, + Nondeterministic: decl.Nondeterministic, + }) + topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms[0], terms[1]) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// RegisterBuiltin3 adds a built-in function globally inside the OPA runtime. +func RegisterBuiltin3(decl *Function, impl Builtin3) { + ast.RegisterBuiltin(&ast.Builtin{ + Name: decl.Name, + Description: decl.Description, + Decl: decl.Decl, + Nondeterministic: decl.Nondeterministic, + }) + topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms[0], terms[1], terms[2]) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// RegisterBuiltin4 adds a built-in function globally inside the OPA runtime. +func RegisterBuiltin4(decl *Function, impl Builtin4) { + ast.RegisterBuiltin(&ast.Builtin{ + Name: decl.Name, + Description: decl.Description, + Decl: decl.Decl, + Nondeterministic: decl.Nondeterministic, + }) + topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms[0], terms[1], terms[2], terms[3]) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// RegisterBuiltinDyn adds a built-in function globally inside the OPA runtime. +func RegisterBuiltinDyn(decl *Function, impl BuiltinDyn) { + ast.RegisterBuiltin(&ast.Builtin{ + Name: decl.Name, + Description: decl.Description, + Decl: decl.Decl, + Nondeterministic: decl.Nondeterministic, + }) + topdown.RegisterBuiltinFunc(decl.Name, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return impl(bctx, terms) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// Function1 returns an option that adds a built-in function to the Rego object. +func Function1(decl *Function, f Builtin1) func(*Rego) { + return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms[0]) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// Function2 returns an option that adds a built-in function to the Rego object. +func Function2(decl *Function, f Builtin2) func(*Rego) { + return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms[0], terms[1]) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// Function3 returns an option that adds a built-in function to the Rego object. +func Function3(decl *Function, f Builtin3) func(*Rego) { + return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms[0], terms[1], terms[2]) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// Function4 returns an option that adds a built-in function to the Rego object. +func Function4(decl *Function, f Builtin4) func(*Rego) { + return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms[0], terms[1], terms[2], terms[3]) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// FunctionDyn returns an option that adds a built-in function to the Rego object. +func FunctionDyn(decl *Function, f BuiltinDyn) func(*Rego) { + return newFunction(decl, func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error { + result, err := memoize(decl, bctx, terms, func() (*ast.Term, error) { return f(bctx, terms) }) + return finishFunction(decl.Name, bctx, result, err, iter) + }) +} + +// FunctionDecl returns an option that adds a custom-built-in function +// __declaration__. NO implementation is provided. This is used for +// non-interpreter execution envs (e.g., Wasm). +func FunctionDecl(decl *Function) func(*Rego) { + return newDecl(decl) +} + +func newDecl(decl *Function) func(*Rego) { + return func(r *Rego) { + r.builtinDecls[decl.Name] = &ast.Builtin{ + Name: decl.Name, + Decl: decl.Decl, + } + } +} + +type memo struct { + term *ast.Term + err error +} + +type memokey string + +func memoize(decl *Function, bctx BuiltinContext, terms []*ast.Term, ifEmpty func() (*ast.Term, error)) (*ast.Term, error) { + + if !decl.Memoize { + return ifEmpty() + } + + // NOTE(tsandall): we assume memoization is applied to infrequent built-in + // calls that do things like fetch data from remote locations. As such, + // converting the terms to strings is acceptable for now. + var b strings.Builder + if _, err := b.WriteString(decl.Name); err != nil { + return nil, err + } + + // The term slice _may_ include an output term depending on how the caller + // referred to the built-in function. Only use the arguments as the cache + // key. Unification ensures we don't get false positive matches. + for i := 0; i < decl.Decl.Arity(); i++ { + if _, err := b.WriteString(terms[i].String()); err != nil { + return nil, err + } + } + + key := memokey(b.String()) + hit, ok := bctx.Cache.Get(key) + var m memo + if ok { + m = hit.(memo) + } else { + m.term, m.err = ifEmpty() + bctx.Cache.Put(key, m) + } + + return m.term, m.err +} + +// Dump returns an argument that sets the writer to dump debugging information to. +func Dump(w io.Writer) func(r *Rego) { + return func(r *Rego) { + r.dump = w + } +} + +// Query returns an argument that sets the Rego query. +func Query(q string) func(r *Rego) { + return func(r *Rego) { + r.query = q + } +} + +// ParsedQuery returns an argument that sets the Rego query. +func ParsedQuery(q ast.Body) func(r *Rego) { + return func(r *Rego) { + r.parsedQuery = q + } +} + +// Package returns an argument that sets the Rego package on the query's +// context. +func Package(p string) func(r *Rego) { + return func(r *Rego) { + r.pkg = p + } +} + +// ParsedPackage returns an argument that sets the Rego package on the query's +// context. +func ParsedPackage(pkg *ast.Package) func(r *Rego) { + return func(r *Rego) { + r.parsedPackage = pkg + } +} + +// Imports returns an argument that adds a Rego import to the query's context. +func Imports(p []string) func(r *Rego) { + return func(r *Rego) { + r.imports = append(r.imports, p...) + } +} + +// ParsedImports returns an argument that adds Rego imports to the query's +// context. +func ParsedImports(imp []*ast.Import) func(r *Rego) { + return func(r *Rego) { + r.parsedImports = append(r.parsedImports, imp...) + } +} + +// Input returns an argument that sets the Rego input document. Input should be +// a native Go value representing the input document. +func Input(x interface{}) func(r *Rego) { + return func(r *Rego) { + r.rawInput = &x + } +} + +// ParsedInput returns an argument that sets the Rego input document. +func ParsedInput(x ast.Value) func(r *Rego) { + return func(r *Rego) { + r.parsedInput = x + } +} + +// Unknowns returns an argument that sets the values to treat as unknown during +// partial evaluation. +func Unknowns(unknowns []string) func(r *Rego) { + return func(r *Rego) { + r.unknowns = unknowns + } +} + +// ParsedUnknowns returns an argument that sets the values to treat as unknown +// during partial evaluation. +func ParsedUnknowns(unknowns []*ast.Term) func(r *Rego) { + return func(r *Rego) { + r.parsedUnknowns = unknowns + } +} + +// DisableInlining adds a set of paths to exclude from partial evaluation inlining. +func DisableInlining(paths []string) func(r *Rego) { + return func(r *Rego) { + r.disableInlining = paths + } +} + +// ShallowInlining prevents rules that depend on unknown values from being inlined. +// Rules that only depend on known values are inlined. +func ShallowInlining(yes bool) func(r *Rego) { + return func(r *Rego) { + r.shallowInlining = yes + } +} + +// SkipPartialNamespace disables namespacing of partial evalution results for support +// rules generated from policy. Synthetic support rules are still namespaced. +func SkipPartialNamespace(yes bool) func(r *Rego) { + return func(r *Rego) { + r.skipPartialNamespace = yes + } +} + +// PartialNamespace returns an argument that sets the namespace to use for +// partial evaluation results. The namespace must be a valid package path +// component. +func PartialNamespace(ns string) func(r *Rego) { + return func(r *Rego) { + r.partialNamespace = ns + } +} + +// Module returns an argument that adds a Rego module. +func Module(filename, input string) func(r *Rego) { + return func(r *Rego) { + r.modules = append(r.modules, rawModule{ + filename: filename, + module: input, + }) + } +} + +// ParsedModule returns an argument that adds a parsed Rego module. If a string +// module with the same filename name is added, it will override the parsed +// module. +func ParsedModule(module *ast.Module) func(*Rego) { + return func(r *Rego) { + var filename string + if module.Package.Location != nil { + filename = module.Package.Location.File + } else { + filename = fmt.Sprintf("module_%p.rego", module) + } + r.parsedModules[filename] = module + } +} + +// Load returns an argument that adds a filesystem path to load data +// and Rego modules from. Any file with a *.rego, *.yaml, or *.json +// extension will be loaded. The path can be either a directory or file, +// directories are loaded recursively. The optional ignore string patterns +// can be used to filter which files are used. +// The Load option can only be used once. +// Note: Loading files will require a write transaction on the store. +func Load(paths []string, filter loader.Filter) func(r *Rego) { + return func(r *Rego) { + r.loadPaths = loadPaths{paths, filter} + } +} + +// LoadBundle returns an argument that adds a filesystem path to load +// a bundle from. The path can be a compressed bundle file or a directory +// to be loaded as a bundle. +// Note: Loading bundles will require a write transaction on the store. +func LoadBundle(path string) func(r *Rego) { + return func(r *Rego) { + r.bundlePaths = append(r.bundlePaths, path) + } +} + +// ParsedBundle returns an argument that adds a bundle to be loaded. +func ParsedBundle(name string, b *bundle.Bundle) func(r *Rego) { + return func(r *Rego) { + r.bundles[name] = b + } +} + +// Compiler returns an argument that sets the Rego compiler. +func Compiler(c *ast.Compiler) func(r *Rego) { + return func(r *Rego) { + r.compiler = c + } +} + +// Store returns an argument that sets the policy engine's data storage layer. +// +// If using the Load, LoadBundle, or ParsedBundle options then a transaction +// must also be provided via the Transaction() option. After loading files +// or bundles the transaction should be aborted or committed. +func Store(s storage.Store) func(r *Rego) { + return func(r *Rego) { + r.store = s + } +} + +// StoreReadAST returns an argument that sets whether the store should eagerly convert data to AST values. +// +// Only applicable when no store has been set on the Rego object through the Store option. +func StoreReadAST(enabled bool) func(r *Rego) { + return func(r *Rego) { + r.ownStoreReadAst = enabled + } +} + +// Transaction returns an argument that sets the transaction to use for storage +// layer operations. +// +// Requires the store associated with the transaction to be provided via the +// Store() option. If using Load(), LoadBundle(), or ParsedBundle() options +// the transaction will likely require write params. +func Transaction(txn storage.Transaction) func(r *Rego) { + return func(r *Rego) { + r.txn = txn + } +} + +// Metrics returns an argument that sets the metrics collection. +func Metrics(m metrics.Metrics) func(r *Rego) { + return func(r *Rego) { + r.metrics = m + } +} + +// Instrument returns an argument that enables instrumentation for diagnosing +// performance issues. +func Instrument(yes bool) func(r *Rego) { + return func(r *Rego) { + r.instrument = yes + } +} + +// Trace returns an argument that enables tracing on r. +func Trace(yes bool) func(r *Rego) { + return func(r *Rego) { + r.trace = yes + } +} + +// Tracer returns an argument that adds a query tracer to r. +// Deprecated: Use QueryTracer instead. +func Tracer(t topdown.Tracer) func(r *Rego) { + return func(r *Rego) { + if t != nil { + r.queryTracers = append(r.queryTracers, topdown.WrapLegacyTracer(t)) + } + } +} + +// QueryTracer returns an argument that adds a query tracer to r. +func QueryTracer(t topdown.QueryTracer) func(r *Rego) { + return func(r *Rego) { + if t != nil { + r.queryTracers = append(r.queryTracers, t) + } + } +} + +// Runtime returns an argument that sets the runtime data to provide to the +// evaluation engine. +func Runtime(term *ast.Term) func(r *Rego) { + return func(r *Rego) { + r.runtime = term + } +} + +// Time sets the wall clock time to use during policy evaluation. Prepared queries +// do not inherit this parameter. Use EvalTime to set the wall clock time when +// executing a prepared query. +func Time(x time.Time) func(r *Rego) { + return func(r *Rego) { + r.time = x + } +} + +// Seed sets a reader that will seed randomization required by built-in functions. +// If a seed is not provided crypto/rand.Reader is used. +func Seed(r io.Reader) func(*Rego) { + return func(e *Rego) { + e.seed = r + } +} + +// PrintTrace is a helper function to write a human-readable version of the +// trace to the writer w. +func PrintTrace(w io.Writer, r *Rego) { + if r == nil || r.tracebuf == nil { + return + } + topdown.PrettyTrace(w, *r.tracebuf) +} + +// PrintTraceWithLocation is a helper function to write a human-readable version of the +// trace to the writer w. +func PrintTraceWithLocation(w io.Writer, r *Rego) { + if r == nil || r.tracebuf == nil { + return + } + topdown.PrettyTraceWithLocation(w, *r.tracebuf) +} + +// UnsafeBuiltins sets the built-in functions to treat as unsafe and not allow. +// This option is ignored for module compilation if the caller supplies the +// compiler. This option is always honored for query compilation. Provide an +// empty (non-nil) map to disable checks on queries. +func UnsafeBuiltins(unsafeBuiltins map[string]struct{}) func(r *Rego) { + return func(r *Rego) { + r.unsafeBuiltins = unsafeBuiltins + } +} + +// SkipBundleVerification skips verification of a signed bundle. +func SkipBundleVerification(yes bool) func(r *Rego) { + return func(r *Rego) { + r.skipBundleVerification = yes + } +} + +// InterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize +// during evaluation. +func InterQueryBuiltinCache(c cache.InterQueryCache) func(r *Rego) { + return func(r *Rego) { + r.interQueryBuiltinCache = c + } +} + +// InterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize +// during evaluation. +func InterQueryBuiltinValueCache(c cache.InterQueryValueCache) func(r *Rego) { + return func(r *Rego) { + r.interQueryBuiltinValueCache = c + } +} + +// NDBuiltinCache sets the non-deterministic builtins cache. +func NDBuiltinCache(c builtins.NDBCache) func(r *Rego) { + return func(r *Rego) { + r.ndBuiltinCache = c + } +} + +// StrictBuiltinErrors tells the evaluator to treat all built-in function errors as fatal errors. +func StrictBuiltinErrors(yes bool) func(r *Rego) { + return func(r *Rego) { + r.strictBuiltinErrors = yes + } +} + +// BuiltinErrorList supplies an error slice to store built-in function errors. +func BuiltinErrorList(list *[]topdown.Error) func(r *Rego) { + return func(r *Rego) { + r.builtinErrorList = list + } +} + +// Resolver sets a Resolver for a specified ref path. +func Resolver(ref ast.Ref, r resolver.Resolver) func(r *Rego) { + return func(rego *Rego) { + rego.resolvers = append(rego.resolvers, refResolver{ref, r}) + } +} + +// Schemas sets the schemaSet +func Schemas(x *ast.SchemaSet) func(r *Rego) { + return func(r *Rego) { + r.schemaSet = x + } +} + +// Capabilities configures the underlying compiler's capabilities. +// This option is ignored for module compilation if the caller supplies the +// compiler. +func Capabilities(c *ast.Capabilities) func(r *Rego) { + return func(r *Rego) { + r.capabilities = c + } +} + +// Target sets the runtime to exercise. +func Target(t string) func(r *Rego) { + return func(r *Rego) { + r.target = t + } +} + +// GenerateJSON sets the AST to JSON converter for the results. +func GenerateJSON(f func(*ast.Term, *EvalContext) (interface{}, error)) func(r *Rego) { + return func(r *Rego) { + r.generateJSON = f + } +} + +// PrintHook sets the object to use for handling print statement outputs. +func PrintHook(h print.Hook) func(r *Rego) { + return func(r *Rego) { + r.printHook = h + } +} + +// DistributedTracingOpts sets the options to be used by distributed tracing. +func DistributedTracingOpts(tr tracing.Options) func(r *Rego) { + return func(r *Rego) { + r.distributedTacingOpts = tr + } +} + +// EnablePrintStatements enables print() calls. If this option is not provided, +// print() calls will be erased from the policy. This option only applies to +// queries and policies that passed as raw strings, i.e., this function will not +// have any affect if the caller supplies the ast.Compiler instance. +func EnablePrintStatements(yes bool) func(r *Rego) { + return func(r *Rego) { + r.enablePrintStatements = yes + } +} + +// Strict enables or disables strict-mode in the compiler +func Strict(yes bool) func(r *Rego) { + return func(r *Rego) { + r.strict = yes + } +} + +func SetRegoVersion(version ast.RegoVersion) func(r *Rego) { + return func(r *Rego) { + r.regoVersion = version + } +} + +// New returns a new Rego object. +func New(options ...func(r *Rego)) *Rego { + + r := &Rego{ + parsedModules: map[string]*ast.Module{}, + capture: map[*ast.Expr]ast.Var{}, + compiledQueries: map[queryType]compiledQuery{}, + builtinDecls: map[string]*ast.Builtin{}, + builtinFuncs: map[string]*topdown.Builtin{}, + bundles: map[string]*bundle.Bundle{}, + } + + for _, option := range options { + option(r) + } + + if r.compiler == nil { + r.compiler = ast.NewCompiler(). + WithUnsafeBuiltins(r.unsafeBuiltins). + WithBuiltins(r.builtinDecls). + WithDebug(r.dump). + WithSchemas(r.schemaSet). + WithCapabilities(r.capabilities). + WithEnablePrintStatements(r.enablePrintStatements). + WithStrict(r.strict). + WithUseTypeCheckAnnotations(true) + + // topdown could be target "" or "rego", but both could be overridden by + // a target plugin (checked below) + if r.target == targetWasm { + r.compiler = r.compiler.WithEvalMode(ast.EvalModeIR) + } + + if r.regoVersion != ast.RegoUndefined { + r.compiler = r.compiler.WithDefaultRegoVersion(r.regoVersion) + } + } + + if r.store == nil { + r.store = inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(r.ownStoreReadAst)) + r.ownStore = true + } else { + r.ownStore = false + } + + if r.metrics == nil { + r.metrics = metrics.New() + } + + if r.instrument { + r.instrumentation = topdown.NewInstrumentation(r.metrics) + r.compiler.WithMetrics(r.metrics) + } + + if r.trace { + r.tracebuf = topdown.NewBufferTracer() + r.queryTracers = append(r.queryTracers, r.tracebuf) + } + + if r.partialNamespace == "" { + r.partialNamespace = defaultPartialNamespace + } + + if r.generateJSON == nil { + r.generateJSON = generateJSON + } + + if r.pluginMgr != nil { + for _, name := range r.pluginMgr.Plugins() { + p := r.pluginMgr.Plugin(name) + if p0, ok := p.(TargetPlugin); ok { + r.plugins = append(r.plugins, p0) + } + } + } + + if t := r.targetPlugin(r.target); t != nil { + r.compiler = r.compiler.WithEvalMode(ast.EvalModeIR) + } + + return r +} + +// Eval evaluates this Rego object and returns a ResultSet. +func (r *Rego) Eval(ctx context.Context) (ResultSet, error) { + var err error + var txnClose transactionCloser + r.txn, txnClose, err = r.getTxn(ctx) + if err != nil { + return nil, err + } + + pq, err := r.PrepareForEval(ctx) + if err != nil { + _ = txnClose(ctx, err) // Ignore error + return nil, err + } + + evalArgs := []EvalOption{ + EvalTransaction(r.txn), + EvalMetrics(r.metrics), + EvalInstrument(r.instrument), + EvalTime(r.time), + EvalInterQueryBuiltinCache(r.interQueryBuiltinCache), + EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache), + EvalSeed(r.seed), + } + + if r.ndBuiltinCache != nil { + evalArgs = append(evalArgs, EvalNDBuiltinCache(r.ndBuiltinCache)) + } + + for _, qt := range r.queryTracers { + evalArgs = append(evalArgs, EvalQueryTracer(qt)) + } + + for i := range r.resolvers { + evalArgs = append(evalArgs, EvalResolver(r.resolvers[i].ref, r.resolvers[i].r)) + } + + rs, err := pq.Eval(ctx, evalArgs...) + txnErr := txnClose(ctx, err) // Always call closer + if err == nil { + err = txnErr + } + return rs, err +} + +// PartialEval has been deprecated and renamed to PartialResult. +func (r *Rego) PartialEval(ctx context.Context) (PartialResult, error) { + return r.PartialResult(ctx) +} + +// PartialResult partially evaluates this Rego object and returns a PartialResult. +func (r *Rego) PartialResult(ctx context.Context) (PartialResult, error) { + var err error + var txnClose transactionCloser + r.txn, txnClose, err = r.getTxn(ctx) + if err != nil { + return PartialResult{}, err + } + + pq, err := r.PrepareForEval(ctx, WithPartialEval()) + txnErr := txnClose(ctx, err) // Always call closer + if err != nil { + return PartialResult{}, err + } + if txnErr != nil { + return PartialResult{}, txnErr + } + + pr := PartialResult{ + compiler: pq.r.compiler, + store: pq.r.store, + body: pq.r.parsedQuery, + builtinDecls: pq.r.builtinDecls, + builtinFuncs: pq.r.builtinFuncs, + } + + return pr, nil +} + +// Partial runs partial evaluation on r and returns the result. +func (r *Rego) Partial(ctx context.Context) (*PartialQueries, error) { + var err error + var txnClose transactionCloser + r.txn, txnClose, err = r.getTxn(ctx) + if err != nil { + return nil, err + } + + pq, err := r.PrepareForPartial(ctx) + if err != nil { + _ = txnClose(ctx, err) // Ignore error + return nil, err + } + + evalArgs := []EvalOption{ + EvalTransaction(r.txn), + EvalMetrics(r.metrics), + EvalInstrument(r.instrument), + EvalInterQueryBuiltinCache(r.interQueryBuiltinCache), + EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache), + } + + if r.ndBuiltinCache != nil { + evalArgs = append(evalArgs, EvalNDBuiltinCache(r.ndBuiltinCache)) + } + + for _, t := range r.queryTracers { + evalArgs = append(evalArgs, EvalQueryTracer(t)) + } + + for i := range r.resolvers { + evalArgs = append(evalArgs, EvalResolver(r.resolvers[i].ref, r.resolvers[i].r)) + } + + pqs, err := pq.Partial(ctx, evalArgs...) + txnErr := txnClose(ctx, err) // Always call closer + if err == nil { + err = txnErr + } + return pqs, err +} + +// CompileOption defines a function to set options on Compile calls. +type CompileOption func(*CompileContext) + +// CompileContext contains options for Compile calls. +type CompileContext struct { + partial bool +} + +// CompilePartial defines an option to control whether partial evaluation is run +// before the query is planned and compiled. +func CompilePartial(yes bool) CompileOption { + return func(cfg *CompileContext) { + cfg.partial = yes + } +} + +// Compile returns a compiled policy query. +func (r *Rego) Compile(ctx context.Context, opts ...CompileOption) (*CompileResult, error) { + + var cfg CompileContext + + for _, opt := range opts { + opt(&cfg) + } + + var queries []ast.Body + modules := make([]*ast.Module, 0, len(r.compiler.Modules)) + + if cfg.partial { + + pq, err := r.Partial(ctx) + if err != nil { + return nil, err + } + if r.dump != nil { + if len(pq.Queries) != 0 { + msg := fmt.Sprintf("QUERIES (%d total):", len(pq.Queries)) + fmt.Fprintln(r.dump, msg) + fmt.Fprintln(r.dump, strings.Repeat("-", len(msg))) + for i := range pq.Queries { + fmt.Println(pq.Queries[i]) + } + fmt.Fprintln(r.dump) + } + if len(pq.Support) != 0 { + msg := fmt.Sprintf("SUPPORT (%d total):", len(pq.Support)) + fmt.Fprintln(r.dump, msg) + fmt.Fprintln(r.dump, strings.Repeat("-", len(msg))) + for i := range pq.Support { + fmt.Println(pq.Support[i]) + } + fmt.Fprintln(r.dump) + } + } + + queries = pq.Queries + modules = pq.Support + + for _, module := range r.compiler.Modules { + modules = append(modules, module) + } + } else { + var err error + // If creating a new transaction it should be closed before calling the + // planner to avoid holding open the transaction longer than needed. + // + // TODO(tsandall): in future, planner could make use of store, in which + // case this will need to change. + var txnClose transactionCloser + r.txn, txnClose, err = r.getTxn(ctx) + if err != nil { + return nil, err + } + + err = r.prepare(ctx, compileQueryType, nil) + txnErr := txnClose(ctx, err) // Always call closer + if err != nil { + return nil, err + } + if txnErr != nil { + return nil, err + } + + for _, module := range r.compiler.Modules { + modules = append(modules, module) + } + + queries = []ast.Body{r.compiledQueries[compileQueryType].query} + } + + if tgt := r.targetPlugin(r.target); tgt != nil { + return nil, fmt.Errorf("unsupported for rego target plugins") + } + + return r.compileWasm(modules, queries, compileQueryType) // TODO(sr) control flow is funky here +} + +func (r *Rego) compileWasm(_ []*ast.Module, queries []ast.Body, qType queryType) (*CompileResult, error) { + policy, err := r.planQuery(queries, qType) + if err != nil { + return nil, err + } + + m, err := wasm.New().WithPolicy(policy).Compile() + if err != nil { + return nil, err + } + + var out bytes.Buffer + if err := encoding.WriteModule(&out, m); err != nil { + return nil, err + } + + return &CompileResult{ + Bytes: out.Bytes(), + }, nil +} + +// PrepareOption defines a function to set an option to control +// the behavior of the Prepare call. +type PrepareOption func(*PrepareConfig) + +// PrepareConfig holds settings to control the behavior of the +// Prepare call. +type PrepareConfig struct { + doPartialEval bool + disableInlining *[]string + builtinFuncs map[string]*topdown.Builtin +} + +// WithPartialEval configures an option for PrepareForEval +// which will have it perform partial evaluation while preparing +// the query (similar to rego.Rego#PartialResult) +func WithPartialEval() PrepareOption { + return func(p *PrepareConfig) { + p.doPartialEval = true + } +} + +// WithNoInline adds a set of paths to exclude from partial evaluation inlining. +func WithNoInline(paths []string) PrepareOption { + return func(p *PrepareConfig) { + p.disableInlining = &paths + } +} + +// WithBuiltinFuncs carries the rego.Function{1,2,3} per-query function definitions +// to the target plugins. +func WithBuiltinFuncs(bis map[string]*topdown.Builtin) PrepareOption { + return func(p *PrepareConfig) { + if p.builtinFuncs == nil { + p.builtinFuncs = make(map[string]*topdown.Builtin, len(bis)) + } + for k, v := range bis { + p.builtinFuncs[k] = v + } + } +} + +// BuiltinFuncs allows retrieving the builtin funcs set via PrepareOption +// WithBuiltinFuncs. +func (p *PrepareConfig) BuiltinFuncs() map[string]*topdown.Builtin { + return p.builtinFuncs +} + +// PrepareForEval will parse inputs, modules, and query arguments in preparation +// of evaluating them. +func (r *Rego) PrepareForEval(ctx context.Context, opts ...PrepareOption) (PreparedEvalQuery, error) { + if !r.hasQuery() { + return PreparedEvalQuery{}, fmt.Errorf("cannot evaluate empty query") + } + + pCfg := &PrepareConfig{} + for _, o := range opts { + o(pCfg) + } + + var err error + var txnClose transactionCloser + r.txn, txnClose, err = r.getTxn(ctx) + if err != nil { + return PreparedEvalQuery{}, err + } + + // If the caller wanted to do partial evaluation as part of preparation + // do it now and use the new Rego object. + if pCfg.doPartialEval { + + pr, err := r.partialResult(ctx, pCfg) + if err != nil { + _ = txnClose(ctx, err) // Ignore error + return PreparedEvalQuery{}, err + } + + // Prepare the new query using the result of partial evaluation + pq, err := pr.Rego(Transaction(r.txn)).PrepareForEval(ctx) + txnErr := txnClose(ctx, err) + if err != nil { + return pq, err + } + return pq, txnErr + } + + err = r.prepare(ctx, evalQueryType, []extraStage{ + { + after: "ResolveRefs", + stage: ast.QueryCompilerStageDefinition{ + Name: "RewriteToCaptureValue", + MetricName: "query_compile_stage_rewrite_to_capture_value", + Stage: r.rewriteQueryToCaptureValue, + }, + }, + }) + if err != nil { + _ = txnClose(ctx, err) // Ignore error + return PreparedEvalQuery{}, err + } + + switch r.target { + case targetWasm: // TODO(sr): make wasm a target plugin, too + + if r.hasWasmModule() { + _ = txnClose(ctx, err) // Ignore error + return PreparedEvalQuery{}, fmt.Errorf("wasm target not supported") + } + + var modules []*ast.Module + for _, module := range r.compiler.Modules { + modules = append(modules, module) + } + + queries := []ast.Body{r.compiledQueries[evalQueryType].query} + + e, err := opa.LookupEngine(targetWasm) + if err != nil { + return PreparedEvalQuery{}, err + } + + // nolint: staticcheck // SA4006 false positive + cr, err := r.compileWasm(modules, queries, evalQueryType) + if err != nil { + _ = txnClose(ctx, err) // Ignore error + return PreparedEvalQuery{}, err + } + + // nolint: staticcheck // SA4006 false positive + data, err := r.store.Read(ctx, r.txn, storage.Path{}) + if err != nil { + _ = txnClose(ctx, err) // Ignore error + return PreparedEvalQuery{}, err + } + + o, err := e.New().WithPolicyBytes(cr.Bytes).WithDataJSON(data).Init() + if err != nil { + _ = txnClose(ctx, err) // Ignore error + return PreparedEvalQuery{}, err + } + r.opa = o + + case targetRego: // do nothing, don't lookup default plugin + default: // either a specific plugin target, or one that is default + if tgt := r.targetPlugin(r.target); tgt != nil { + queries := []ast.Body{r.compiledQueries[evalQueryType].query} + pol, err := r.planQuery(queries, evalQueryType) + if err != nil { + return PreparedEvalQuery{}, err + } + // always add the builtins provided via rego.FunctionN options + opts = append(opts, WithBuiltinFuncs(r.builtinFuncs)) + r.targetPrepState, err = tgt.PrepareForEval(ctx, pol, opts...) + if err != nil { + return PreparedEvalQuery{}, err + } + } + } + + txnErr := txnClose(ctx, err) // Always call closer + if txnErr != nil { + return PreparedEvalQuery{}, txnErr + } + + return PreparedEvalQuery{preparedQuery{r, pCfg}}, err +} + +// PrepareForPartial will parse inputs, modules, and query arguments in preparation +// of partially evaluating them. +func (r *Rego) PrepareForPartial(ctx context.Context, opts ...PrepareOption) (PreparedPartialQuery, error) { + if !r.hasQuery() { + return PreparedPartialQuery{}, fmt.Errorf("cannot evaluate empty query") + } + + pCfg := &PrepareConfig{} + for _, o := range opts { + o(pCfg) + } + + var err error + var txnClose transactionCloser + r.txn, txnClose, err = r.getTxn(ctx) + if err != nil { + return PreparedPartialQuery{}, err + } + + err = r.prepare(ctx, partialQueryType, []extraStage{ + { + after: "CheckSafety", + stage: ast.QueryCompilerStageDefinition{ + Name: "RewriteEquals", + MetricName: "query_compile_stage_rewrite_equals", + Stage: r.rewriteEqualsForPartialQueryCompile, + }, + }, + }) + txnErr := txnClose(ctx, err) // Always call closer + if err != nil { + return PreparedPartialQuery{}, err + } + if txnErr != nil { + return PreparedPartialQuery{}, txnErr + } + + return PreparedPartialQuery{preparedQuery{r, pCfg}}, err +} + +func (r *Rego) prepare(ctx context.Context, qType queryType, extras []extraStage) error { + var err error + + r.parsedInput, err = r.parseInput() + if err != nil { + return err + } + + err = r.loadFiles(ctx, r.txn, r.metrics) + if err != nil { + return err + } + + err = r.loadBundles(ctx, r.txn, r.metrics) + if err != nil { + return err + } + + err = r.parseModules(ctx, r.txn, r.metrics) + if err != nil { + return err + } + + // Compile the modules *before* the query, else functions + // defined in the module won't be found... + err = r.compileModules(ctx, r.txn, r.metrics) + if err != nil { + return err + } + + imports, err := r.prepareImports() + if err != nil { + return err + } + + queryImports := []*ast.Import{} + for _, imp := range imports { + path := imp.Path.Value.(ast.Ref) + if path.HasPrefix([]*ast.Term{ast.FutureRootDocument}) || path.HasPrefix([]*ast.Term{ast.RegoRootDocument}) { + queryImports = append(queryImports, imp) + } + } + + r.parsedQuery, err = r.parseQuery(queryImports, r.metrics) + if err != nil { + return err + } + + err = r.compileAndCacheQuery(qType, r.parsedQuery, imports, r.metrics, extras) + if err != nil { + return err + } + + return nil +} + +func (r *Rego) parseModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error { + if len(r.modules) == 0 { + return nil + } + + ids, err := r.store.ListPolicies(ctx, txn) + if err != nil { + return err + } + + m.Timer(metrics.RegoModuleParse).Start() + defer m.Timer(metrics.RegoModuleParse).Stop() + var errs Errors + + // Parse any modules that are saved to the store, but only if + // another compile step is going to occur (ie. we have parsed modules + // that need to be compiled). + for _, id := range ids { + // if it is already on the compiler we're using + // then don't bother to re-parse it from source + if _, haveMod := r.compiler.Modules[id]; haveMod { + continue + } + + bs, err := r.store.GetPolicy(ctx, txn, id) + if err != nil { + return err + } + + parsed, err := ast.ParseModuleWithOpts(id, string(bs), ast.ParserOptions{RegoVersion: r.regoVersion}) + if err != nil { + errs = append(errs, err) + } + + r.parsedModules[id] = parsed + } + + // Parse any passed in as arguments to the Rego object + for _, module := range r.modules { + p, err := module.ParseWithOpts(ast.ParserOptions{RegoVersion: r.regoVersion}) + if err != nil { + switch errorWithType := err.(type) { + case ast.Errors: + for _, e := range errorWithType { + errs = append(errs, e) + } + default: + errs = append(errs, errorWithType) + } + } + r.parsedModules[module.filename] = p + } + + if len(errs) > 0 { + return errs + } + + return nil +} + +func (r *Rego) loadFiles(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error { + if len(r.loadPaths.paths) == 0 { + return nil + } + + m.Timer(metrics.RegoLoadFiles).Start() + defer m.Timer(metrics.RegoLoadFiles).Stop() + + result, err := loader.NewFileLoader(). + WithMetrics(m). + WithProcessAnnotation(true). + WithRegoVersion(r.regoVersion). + WithCapabilities(r.capabilities). + Filtered(r.loadPaths.paths, r.loadPaths.filter) + if err != nil { + return err + } + for name, mod := range result.Modules { + r.parsedModules[name] = mod.Parsed + } + + if len(result.Documents) > 0 { + err = r.store.Write(ctx, txn, storage.AddOp, storage.Path{}, result.Documents) + if err != nil { + return err + } + } + return nil +} + +func (r *Rego) loadBundles(_ context.Context, _ storage.Transaction, m metrics.Metrics) error { + if len(r.bundlePaths) == 0 { + return nil + } + + m.Timer(metrics.RegoLoadBundles).Start() + defer m.Timer(metrics.RegoLoadBundles).Stop() + + for _, path := range r.bundlePaths { + bndl, err := loader.NewFileLoader(). + WithMetrics(m). + WithProcessAnnotation(true). + WithSkipBundleVerification(r.skipBundleVerification). + WithRegoVersion(r.regoVersion). + WithCapabilities(r.capabilities). + AsBundle(path) + if err != nil { + return fmt.Errorf("loading error: %s", err) + } + r.bundles[path] = bndl + } + return nil +} + +func (r *Rego) parseInput() (ast.Value, error) { + if r.parsedInput != nil { + return r.parsedInput, nil + } + return r.parseRawInput(r.rawInput, r.metrics) +} + +func (r *Rego) parseRawInput(rawInput *interface{}, m metrics.Metrics) (ast.Value, error) { + var input ast.Value + + if rawInput == nil { + return input, nil + } + + m.Timer(metrics.RegoInputParse).Start() + defer m.Timer(metrics.RegoInputParse).Stop() + + rawPtr := util.Reference(rawInput) + + // roundtrip through json: this turns slices (e.g. []string, []bool) into + // []interface{}, the only array type ast.InterfaceToValue can work with + if err := util.RoundTrip(rawPtr); err != nil { + return nil, err + } + + return ast.InterfaceToValue(*rawPtr) +} + +func (r *Rego) parseQuery(queryImports []*ast.Import, m metrics.Metrics) (ast.Body, error) { + if r.parsedQuery != nil { + return r.parsedQuery, nil + } + + m.Timer(metrics.RegoQueryParse).Start() + defer m.Timer(metrics.RegoQueryParse).Stop() + + popts, err := future.ParserOptionsFromFutureImports(queryImports) + if err != nil { + return nil, err + } + popts.RegoVersion = r.regoVersion + popts, err = parserOptionsFromRegoVersionImport(queryImports, popts) + if err != nil { + return nil, err + } + popts.SkipRules = true + return ast.ParseBodyWithOpts(r.query, popts) +} + +func parserOptionsFromRegoVersionImport(imports []*ast.Import, popts ast.ParserOptions) (ast.ParserOptions, error) { + for _, imp := range imports { + path := imp.Path.Value.(ast.Ref) + if ast.Compare(path, ast.RegoV1CompatibleRef) == 0 { + popts.RegoVersion = ast.RegoV1 + return popts, nil + } + } + return popts, nil +} + +func (r *Rego) compileModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error { + + // Only compile again if there are new modules. + if len(r.bundles) > 0 || len(r.parsedModules) > 0 { + + // The bundle.Activate call will activate any bundles passed in + // (ie compile + handle data store changes), and include any of + // the additional modules passed in. If no bundles are provided + // it will only compile the passed in modules. + // Use this as the single-point of compiling everything only a + // single time. + opts := &bundle.ActivateOpts{ + Ctx: ctx, + Store: r.store, + Txn: txn, + Compiler: r.compilerForTxn(ctx, r.store, txn), + Metrics: m, + Bundles: r.bundles, + ExtraModules: r.parsedModules, + ParserOptions: ast.ParserOptions{RegoVersion: r.regoVersion}, + } + err := bundle.Activate(opts) + if err != nil { + return err + } + } + + // Ensure all configured resolvers from the store are loaded. Skip if any were explicitly provided. + if len(r.resolvers) == 0 { + resolvers, err := bundleUtils.LoadWasmResolversFromStore(ctx, r.store, txn, r.bundles) + if err != nil { + return err + } + + for _, rslvr := range resolvers { + for _, ep := range rslvr.Entrypoints() { + r.resolvers = append(r.resolvers, refResolver{ep, rslvr}) + } + } + } + return nil +} + +func (r *Rego) compileAndCacheQuery(qType queryType, query ast.Body, imports []*ast.Import, m metrics.Metrics, extras []extraStage) error { + m.Timer(metrics.RegoQueryCompile).Start() + defer m.Timer(metrics.RegoQueryCompile).Stop() + + cachedQuery, ok := r.compiledQueries[qType] + if ok && cachedQuery.query != nil && cachedQuery.compiler != nil { + return nil + } + + qc, compiled, err := r.compileQuery(query, imports, m, extras) + if err != nil { + return err + } + + // cache the query for future use + r.compiledQueries[qType] = compiledQuery{ + query: compiled, + compiler: qc, + } + return nil +} + +func (r *Rego) prepareImports() ([]*ast.Import, error) { + imports := r.parsedImports + + if len(r.imports) > 0 { + s := make([]string, len(r.imports)) + for i := range r.imports { + s[i] = fmt.Sprintf("import %v", r.imports[i]) + } + parsed, err := ast.ParseImports(strings.Join(s, "\n")) + if err != nil { + return nil, err + } + imports = append(imports, parsed...) + } + return imports, nil +} + +func (r *Rego) compileQuery(query ast.Body, imports []*ast.Import, _ metrics.Metrics, extras []extraStage) (ast.QueryCompiler, ast.Body, error) { + var pkg *ast.Package + + if r.pkg != "" { + var err error + pkg, err = ast.ParsePackage(fmt.Sprintf("package %v", r.pkg)) + if err != nil { + return nil, nil, err + } + } else { + pkg = r.parsedPackage + } + + qctx := ast.NewQueryContext(). + WithPackage(pkg). + WithImports(imports) + + qc := r.compiler.QueryCompiler(). + WithContext(qctx). + WithUnsafeBuiltins(r.unsafeBuiltins). + WithEnablePrintStatements(r.enablePrintStatements). + WithStrict(false) + + for _, extra := range extras { + qc = qc.WithStageAfter(extra.after, extra.stage) + } + + compiled, err := qc.Compile(query) + + return qc, compiled, err + +} + +func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) { + switch { + case r.targetPrepState != nil: // target plugin flow + var val ast.Value + if r.runtime != nil { + val = r.runtime.Value + } + s, err := r.targetPrepState.Eval(ctx, ectx, val) + if err != nil { + return nil, err + } + return r.valueToQueryResult(s, ectx) + case r.target == targetWasm: + return r.evalWasm(ctx, ectx) + case r.target == targetRego: // continue + } + + q := topdown.NewQuery(ectx.compiledQuery.query). + WithQueryCompiler(ectx.compiledQuery.compiler). + WithCompiler(r.compiler). + WithStore(r.store). + WithTransaction(ectx.txn). + WithBuiltins(r.builtinFuncs). + WithMetrics(ectx.metrics). + WithInstrumentation(ectx.instrumentation). + WithRuntime(r.runtime). + WithIndexing(ectx.indexing). + WithEarlyExit(ectx.earlyExit). + WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache). + WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache). + WithStrictBuiltinErrors(r.strictBuiltinErrors). + WithBuiltinErrorList(r.builtinErrorList). + WithSeed(ectx.seed). + WithPrintHook(ectx.printHook). + WithDistributedTracingOpts(r.distributedTacingOpts). + WithVirtualCache(ectx.virtualCache) + + if !ectx.time.IsZero() { + q = q.WithTime(ectx.time) + } + + if ectx.ndBuiltinCache != nil { + q = q.WithNDBuiltinCache(ectx.ndBuiltinCache) + } + + for i := range ectx.queryTracers { + q = q.WithQueryTracer(ectx.queryTracers[i]) + } + + if ectx.parsedInput != nil { + q = q.WithInput(ast.NewTerm(ectx.parsedInput)) + } + + if ectx.httpRoundTripper != nil { + q = q.WithHTTPRoundTripper(ectx.httpRoundTripper) + } + + for i := range ectx.resolvers { + q = q.WithResolver(ectx.resolvers[i].ref, ectx.resolvers[i].r) + } + + // Cancel query if context is cancelled or deadline is reached. + c := topdown.NewCancel() + q = q.WithCancel(c) + exit := make(chan struct{}) + defer close(exit) + go waitForDone(ctx, exit, func() { + c.Cancel() + }) + + var rs ResultSet + err := q.Iter(ctx, func(qr topdown.QueryResult) error { + result, err := r.generateResult(qr, ectx) + if err != nil { + return err + } + rs = append(rs, result) + return nil + }) + + if err != nil { + return nil, err + } + + if len(rs) == 0 { + return nil, nil + } + + return rs, nil +} + +func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, error) { + input := ectx.rawInput + if ectx.parsedInput != nil { + i := interface{}(ectx.parsedInput) + input = &i + } + result, err := r.opa.Eval(ctx, opa.EvalOpts{ + Metrics: r.metrics, + Input: input, + Time: ectx.time, + Seed: ectx.seed, + InterQueryBuiltinCache: ectx.interQueryBuiltinCache, + NDBuiltinCache: ectx.ndBuiltinCache, + PrintHook: ectx.printHook, + Capabilities: ectx.capabilities, + }) + if err != nil { + return nil, err + } + + parsed, err := ast.ParseTerm(string(result.Result)) + if err != nil { + return nil, err + } + + return r.valueToQueryResult(parsed.Value, ectx) +} + +func (r *Rego) valueToQueryResult(res ast.Value, ectx *EvalContext) (ResultSet, error) { + resultSet, ok := res.(ast.Set) + if !ok { + return nil, fmt.Errorf("illegal result type") + } + + if resultSet.Len() == 0 { + return nil, nil + } + + var rs ResultSet + err := resultSet.Iter(func(term *ast.Term) error { + obj, ok := term.Value.(ast.Object) + if !ok { + return fmt.Errorf("illegal result type") + } + qr := topdown.QueryResult{} + obj.Foreach(func(k, v *ast.Term) { + kvt := ast.VarTerm(string(k.Value.(ast.String))) + qr[kvt.Value.(ast.Var)] = v + }) + result, err := r.generateResult(qr, ectx) + if err != nil { + return err + } + rs = append(rs, result) + return nil + }) + + return rs, err +} + +func (r *Rego) generateResult(qr topdown.QueryResult, ectx *EvalContext) (Result, error) { + + rewritten := ectx.compiledQuery.compiler.RewrittenVars() + + result := newResult() + for k, term := range qr { + v, err := r.generateJSON(term, ectx) + if err != nil { + return result, err + } + + if rw, ok := rewritten[k]; ok { + k = rw + } + if isTermVar(k) || isTermWasmVar(k) || k.IsGenerated() || k.IsWildcard() { + continue + } + result.Bindings[string(k)] = v + } + + for _, expr := range ectx.compiledQuery.query { + if expr.Generated { + continue + } + + if k, ok := r.capture[expr]; ok { + v, err := r.generateJSON(qr[k], ectx) + if err != nil { + return result, err + } + result.Expressions = append(result.Expressions, newExpressionValue(expr, v)) + } else { + result.Expressions = append(result.Expressions, newExpressionValue(expr, true)) + } + + } + return result, nil +} + +func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialResult, error) { + + err := r.prepare(ctx, partialResultQueryType, []extraStage{ + { + after: "ResolveRefs", + stage: ast.QueryCompilerStageDefinition{ + Name: "RewriteForPartialEval", + MetricName: "query_compile_stage_rewrite_for_partial_eval", + Stage: r.rewriteQueryForPartialEval, + }, + }, + }) + if err != nil { + return PartialResult{}, err + } + + ectx := &EvalContext{ + parsedInput: r.parsedInput, + metrics: r.metrics, + txn: r.txn, + partialNamespace: r.partialNamespace, + queryTracers: r.queryTracers, + compiledQuery: r.compiledQueries[partialResultQueryType], + instrumentation: r.instrumentation, + indexing: true, + resolvers: r.resolvers, + capabilities: r.capabilities, + strictBuiltinErrors: r.strictBuiltinErrors, + } + + disableInlining := r.disableInlining + + if pCfg.disableInlining != nil { + disableInlining = *pCfg.disableInlining + } + + ectx.disableInlining, err = parseStringsToRefs(disableInlining) + if err != nil { + return PartialResult{}, err + } + + pq, err := r.partial(ctx, ectx) + if err != nil { + return PartialResult{}, err + } + + // Construct module for queries. + id := fmt.Sprintf("__partialresult__%s__", ectx.partialNamespace) + + module, err := ast.ParseModuleWithOpts(id, "package "+ectx.partialNamespace, + ast.ParserOptions{RegoVersion: r.regoVersion}) + if err != nil { + return PartialResult{}, fmt.Errorf("bad partial namespace") + } + + module.Rules = make([]*ast.Rule, len(pq.Queries)) + for i, body := range pq.Queries { + rule := &ast.Rule{ + Head: ast.NewHead(ast.Var("__result__"), nil, ast.Wildcard), + Body: body, + Module: module, + } + module.Rules[i] = rule + if checkPartialResultForRecursiveRefs(body, rule.Path()) { + return PartialResult{}, Errors{errPartialEvaluationNotEffective} + } + } + + // Update compiler with partial evaluation output. + r.compiler.Modules[id] = module + for i, module := range pq.Support { + r.compiler.Modules[fmt.Sprintf("__partialsupport__%s__%d__", ectx.partialNamespace, i)] = module + } + + r.metrics.Timer(metrics.RegoModuleCompile).Start() + r.compilerForTxn(ctx, r.store, r.txn).Compile(r.compiler.Modules) + r.metrics.Timer(metrics.RegoModuleCompile).Stop() + + if r.compiler.Failed() { + return PartialResult{}, r.compiler.Errors + } + + result := PartialResult{ + compiler: r.compiler, + store: r.store, + body: ast.MustParseBody(fmt.Sprintf("data.%v.__result__", ectx.partialNamespace)), + builtinDecls: r.builtinDecls, + builtinFuncs: r.builtinFuncs, + } + + return result, nil +} + +func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, error) { + + var unknowns []*ast.Term + + switch { + case ectx.parsedUnknowns != nil: + unknowns = ectx.parsedUnknowns + case ectx.unknowns != nil: + unknowns = make([]*ast.Term, len(ectx.unknowns)) + for i := range ectx.unknowns { + var err error + unknowns[i], err = ast.ParseTerm(ectx.unknowns[i]) + if err != nil { + return nil, err + } + } + default: + // Use input document as unknown if caller has not specified any. + unknowns = []*ast.Term{ast.NewTerm(ast.InputRootRef)} + } + + q := topdown.NewQuery(ectx.compiledQuery.query). + WithQueryCompiler(ectx.compiledQuery.compiler). + WithCompiler(r.compiler). + WithStore(r.store). + WithTransaction(ectx.txn). + WithBuiltins(r.builtinFuncs). + WithMetrics(ectx.metrics). + WithInstrumentation(ectx.instrumentation). + WithUnknowns(unknowns). + WithDisableInlining(ectx.disableInlining). + WithRuntime(r.runtime). + WithIndexing(ectx.indexing). + WithEarlyExit(ectx.earlyExit). + WithPartialNamespace(ectx.partialNamespace). + WithSkipPartialNamespace(r.skipPartialNamespace). + WithShallowInlining(r.shallowInlining). + WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache). + WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache). + WithStrictBuiltinErrors(ectx.strictBuiltinErrors). + WithSeed(ectx.seed). + WithPrintHook(ectx.printHook) + + if !ectx.time.IsZero() { + q = q.WithTime(ectx.time) + } + + if ectx.ndBuiltinCache != nil { + q = q.WithNDBuiltinCache(ectx.ndBuiltinCache) + } + + for i := range ectx.queryTracers { + q = q.WithQueryTracer(ectx.queryTracers[i]) + } + + if ectx.parsedInput != nil { + q = q.WithInput(ast.NewTerm(ectx.parsedInput)) + } + + for i := range ectx.resolvers { + q = q.WithResolver(ectx.resolvers[i].ref, ectx.resolvers[i].r) + } + + // Cancel query if context is cancelled or deadline is reached. + c := topdown.NewCancel() + q = q.WithCancel(c) + exit := make(chan struct{}) + defer close(exit) + go waitForDone(ctx, exit, func() { + c.Cancel() + }) + + queries, support, err := q.PartialRun(ctx) + if err != nil { + return nil, err + } + + // If the target rego-version is v0, and the rego.v1 import is available, then we attempt to apply it to support modules. + if r.regoVersion == ast.RegoV0 && + (r.capabilities == nil || + r.capabilities.ContainsFeature(ast.FeatureRegoV1Import) || + r.capabilities.ContainsFeature(ast.FeatureRegoV1)) { + + for i, mod := range support { + // We can't apply the RegoV0CompatV1 version to the support module if it contains rules or vars that + // conflict with future keywords. + applyRegoVersion := true + + ast.WalkRules(mod, func(r *ast.Rule) bool { + name := r.Head.Name + if name == "" && len(r.Head.Reference) > 0 { + name = r.Head.Reference[0].Value.(ast.Var) + } + if ast.IsFutureKeywordForRegoVersion(name.String(), ast.RegoV0) { + applyRegoVersion = false + return true + } + return false + }) + + if applyRegoVersion { + ast.WalkVars(mod, func(v ast.Var) bool { + if ast.IsFutureKeywordForRegoVersion(v.String(), ast.RegoV0) { + applyRegoVersion = false + return true + } + return false + }) + } + + if applyRegoVersion { + support[i].SetRegoVersion(ast.RegoV0CompatV1) + } else { + support[i].SetRegoVersion(r.regoVersion) + } + } + } else { + // If the target rego-version is not v0, then we apply the target rego-version to the support modules. + for i := range support { + support[i].SetRegoVersion(r.regoVersion) + } + } + + pq := &PartialQueries{ + Queries: queries, + Support: support, + } + + return pq, nil +} + +func (r *Rego) rewriteQueryToCaptureValue(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) { + + checkCapture := iteration(query) || len(query) > 1 + + for _, expr := range query { + + if expr.Negated { + continue + } + + if expr.IsAssignment() || expr.IsEquality() { + continue + } + + var capture *ast.Term + + // If the expression can be evaluated as a function, rewrite it to + // capture the return value. E.g., neq(1,2) becomes neq(1,2,x) but + // plus(1,2,x) does not get rewritten. + switch terms := expr.Terms.(type) { + case *ast.Term: + capture = r.generateTermVar() + expr.Terms = ast.Equality.Expr(terms, capture).Terms + r.capture[expr] = capture.Value.(ast.Var) + case []*ast.Term: + tpe := r.compiler.TypeEnv.Get(terms[0]) + if !types.Void(tpe) && types.Arity(tpe) == len(terms)-1 { + capture = r.generateTermVar() + expr.Terms = append(terms, capture) + r.capture[expr] = capture.Value.(ast.Var) + } + } + + if capture != nil && checkCapture { + cpy := expr.Copy() + cpy.Terms = capture + cpy.Generated = true + cpy.With = nil + query.Append(cpy) + } + } + + return query, nil +} + +func (r *Rego) rewriteQueryForPartialEval(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) { + if len(query) != 1 { + return nil, fmt.Errorf("partial evaluation requires single ref (not multiple expressions)") + } + + term, ok := query[0].Terms.(*ast.Term) + if !ok { + return nil, fmt.Errorf("partial evaluation requires ref (not expression)") + } + + ref, ok := term.Value.(ast.Ref) + if !ok { + return nil, fmt.Errorf("partial evaluation requires ref (not %v)", ast.TypeName(term.Value)) + } + + if !ref.IsGround() { + return nil, fmt.Errorf("partial evaluation requires ground ref") + } + + return ast.NewBody(ast.Equality.Expr(ast.Wildcard, term)), nil +} + +// rewriteEqualsForPartialQueryCompile will rewrite == to = in queries. Normally +// this wouldn't be done, except for handling queries with the `Partial` API +// where rewriting them can substantially simplify the result, and it is unlikely +// that the caller would need expression values. +func (r *Rego) rewriteEqualsForPartialQueryCompile(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) { + doubleEq := ast.Equal.Ref() + unifyOp := ast.Equality.Ref() + ast.WalkExprs(query, func(x *ast.Expr) bool { + if x.IsCall() { + operator := x.Operator() + if operator.Equal(doubleEq) && len(x.Operands()) == 2 { + x.SetOperator(ast.NewTerm(unifyOp)) + } + } + return false + }) + return query, nil +} + +func (r *Rego) generateTermVar() *ast.Term { + r.termVarID++ + prefix := ast.WildcardPrefix + if p := r.targetPlugin(r.target); p != nil { + prefix = wasmVarPrefix + } else if r.target == targetWasm { + prefix = wasmVarPrefix + } + return ast.VarTerm(fmt.Sprintf("%sterm%v", prefix, r.termVarID)) +} + +func (r Rego) hasQuery() bool { + return len(r.query) != 0 || len(r.parsedQuery) != 0 +} + +func (r Rego) hasWasmModule() bool { + for _, b := range r.bundles { + if len(b.WasmModules) > 0 { + return true + } + } + return false +} + +type transactionCloser func(ctx context.Context, err error) error + +// getTxn will conditionally create a read or write transaction suitable for +// the configured Rego object. The returned function should be used to close the txn +// regardless of status. +func (r *Rego) getTxn(ctx context.Context) (storage.Transaction, transactionCloser, error) { + + noopCloser := func(_ context.Context, _ error) error { + return nil // no-op default + } + + if r.txn != nil { + // Externally provided txn + return r.txn, noopCloser, nil + } + + // Create a new transaction.. + params := storage.TransactionParams{} + + // Bundles and data paths may require writing data files or manifests to storage + if len(r.bundles) > 0 || len(r.bundlePaths) > 0 || len(r.loadPaths.paths) > 0 { + + // If we were given a store we will *not* write to it, only do that on one + // which was created automatically on behalf of the user. + if !r.ownStore { + return nil, noopCloser, errors.New("unable to start write transaction when store was provided") + } + + params.Write = true + } + + txn, err := r.store.NewTransaction(ctx, params) + if err != nil { + return nil, noopCloser, err + } + + // Setup a closer function that will abort or commit as needed. + closer := func(ctx context.Context, txnErr error) error { + var err error + + if txnErr == nil && params.Write { + err = r.store.Commit(ctx, txn) + } else { + r.store.Abort(ctx, txn) + } + + // Clear the auto created transaction now that it is closed. + r.txn = nil + + return err + } + + return txn, closer, nil +} + +func (r *Rego) compilerForTxn(ctx context.Context, store storage.Store, txn storage.Transaction) *ast.Compiler { + // Update the compiler to have a valid path conflict check + // for the current context and transaction. + return r.compiler.WithPathConflictsCheck(storage.NonEmpty(ctx, store, txn)) +} + +func checkPartialResultForRecursiveRefs(body ast.Body, path ast.Ref) bool { + var stop bool + ast.WalkRefs(body, func(x ast.Ref) bool { + if !stop { + if path.HasPrefix(x) { + stop = true + } + } + return stop + }) + return stop +} + +func isTermVar(v ast.Var) bool { + return strings.HasPrefix(string(v), ast.WildcardPrefix+"term") +} + +func isTermWasmVar(v ast.Var) bool { + return strings.HasPrefix(string(v), wasmVarPrefix+"term") +} + +func waitForDone(ctx context.Context, exit chan struct{}, f func()) { + select { + case <-exit: + return + case <-ctx.Done(): + f() + return + } +} + +type rawModule struct { + filename string + module string +} + +func (m rawModule) Parse() (*ast.Module, error) { + return ast.ParseModule(m.filename, m.module) +} + +func (m rawModule) ParseWithOpts(opts ast.ParserOptions) (*ast.Module, error) { + return ast.ParseModuleWithOpts(m.filename, m.module, opts) +} + +type extraStage struct { + after string + stage ast.QueryCompilerStageDefinition +} + +type refResolver struct { + ref ast.Ref + r resolver.Resolver +} + +func iteration(x interface{}) bool { + + var stopped bool + + vis := ast.NewGenericVisitor(func(x interface{}) bool { + switch x := x.(type) { + case *ast.Term: + if ast.IsComprehension(x.Value) { + return true + } + case ast.Ref: + if !stopped { + if bi := ast.BuiltinMap[x.String()]; bi != nil { + if bi.Relation { + stopped = true + return stopped + } + } + for i := 1; i < len(x); i++ { + if _, ok := x[i].Value.(ast.Var); ok { + stopped = true + return stopped + } + } + } + return stopped + } + return stopped + }) + + vis.Walk(x) + + return stopped +} + +func parseStringsToRefs(s []string) ([]ast.Ref, error) { + + refs := make([]ast.Ref, len(s)) + for i := range refs { + var err error + refs[i], err = ast.ParseRef(s[i]) + if err != nil { + return nil, err + } + } + + return refs, nil +} + +// helper function to finish a built-in function call. If an error occurred, +// wrap the error and return it. Otherwise, invoke the iterator if the result +// was defined. +func finishFunction(name string, bctx topdown.BuiltinContext, result *ast.Term, err error, iter func(*ast.Term) error) error { + if err != nil { + var e *HaltError + if errors.As(err, &e) { + tdErr := &topdown.Error{ + Code: topdown.BuiltinErr, + Message: fmt.Sprintf("%v: %v", name, e.Error()), + Location: bctx.Location, + } + return topdown.Halt{Err: tdErr.Wrap(e)} + } + tdErr := &topdown.Error{ + Code: topdown.BuiltinErr, + Message: fmt.Sprintf("%v: %v", name, err.Error()), + Location: bctx.Location, + } + return tdErr.Wrap(err) + } + if result == nil { + return nil + } + return iter(result) +} + +// helper function to return an option that sets a custom built-in function. +func newFunction(decl *Function, f topdown.BuiltinFunc) func(*Rego) { + return func(r *Rego) { + r.builtinDecls[decl.Name] = &ast.Builtin{ + Name: decl.Name, + Decl: decl.Decl, + Nondeterministic: decl.Nondeterministic, + } + r.builtinFuncs[decl.Name] = &topdown.Builtin{ + Decl: r.builtinDecls[decl.Name], + Func: f, + } + } +} + +func generateJSON(term *ast.Term, ectx *EvalContext) (interface{}, error) { + return ast.JSONWithOpt(term.Value, + ast.JSONOpt{ + SortSets: ectx.sortSets, + CopyMaps: ectx.copyMaps, + }) +} + +func (r *Rego) planQuery(queries []ast.Body, evalQueryType queryType) (*ir.Policy, error) { + modules := make([]*ast.Module, 0, len(r.compiler.Modules)) + for _, module := range r.compiler.Modules { + modules = append(modules, module) + } + + decls := make(map[string]*ast.Builtin, len(r.builtinDecls)+len(ast.BuiltinMap)) + + for k, v := range ast.BuiltinMap { + decls[k] = v + } + + for k, v := range r.builtinDecls { + decls[k] = v + } + + const queryName = "eval" // NOTE(tsandall): the query name is arbitrary + + p := planner.New(). + WithQueries([]planner.QuerySet{ + { + Name: queryName, + Queries: queries, + RewrittenVars: r.compiledQueries[evalQueryType].compiler.RewrittenVars(), + }, + }). + WithModules(modules). + WithBuiltinDecls(decls). + WithDebug(r.dump) + + policy, err := p.Plan() + if err != nil { + return nil, err + } + if r.dump != nil { + fmt.Fprintln(r.dump, "PLAN:") + fmt.Fprintln(r.dump, "-----") + err = ir.Pretty(r.dump, policy) + if err != nil { + return nil, err + } + fmt.Fprintln(r.dump) + } + return policy, nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/rego/resultset.go b/vendor/github.com/open-policy-agent/opa/v1/rego/resultset.go new file mode 100644 index 000000000..cc0710426 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/rego/resultset.go @@ -0,0 +1,90 @@ +package rego + +import ( + "fmt" + + "github.com/open-policy-agent/opa/v1/ast" +) + +// ResultSet represents a collection of output from Rego evaluation. An empty +// result set represents an undefined query. +type ResultSet []Result + +// Vars represents a collection of variable bindings. The keys are the variable +// names and the values are the binding values. +type Vars map[string]interface{} + +// WithoutWildcards returns a copy of v with wildcard variables removed. +func (v Vars) WithoutWildcards() Vars { + n := Vars{} + for k, v := range v { + if ast.Var(k).IsWildcard() || ast.Var(k).IsGenerated() { + continue + } + n[k] = v + } + return n +} + +// Result defines the output of Rego evaluation. +type Result struct { + Expressions []*ExpressionValue `json:"expressions"` + Bindings Vars `json:"bindings,omitempty"` +} + +func newResult() Result { + return Result{ + Bindings: Vars{}, + } +} + +// Location defines a position in a Rego query or module. +type Location struct { + Row int `json:"row"` + Col int `json:"col"` +} + +// ExpressionValue defines the value of an expression in a Rego query. +type ExpressionValue struct { + Value interface{} `json:"value"` + Text string `json:"text"` + Location *Location `json:"location"` +} + +func newExpressionValue(expr *ast.Expr, value interface{}) *ExpressionValue { + result := &ExpressionValue{ + Value: value, + } + if expr.Location != nil { + result.Text = string(expr.Location.Text) + result.Location = &Location{ + Row: expr.Location.Row, + Col: expr.Location.Col, + } + } + return result +} + +func (ev *ExpressionValue) String() string { + return fmt.Sprint(ev.Value) +} + +// Allowed is a helper method that'll return true if all of these conditions hold: +// - the result set only has one element +// - there is only one expression in the result set's only element +// - that expression has the value `true` +// - there are no bindings. +// +// If bindings are present, this will yield `false`: it would be a pitfall to +// return `true` for a query like `data.authz.allow = x`, which always has result +// set element with value true, but could also have a binding `x: false`. +func (rs ResultSet) Allowed() bool { + if len(rs) == 1 && len(rs[0].Bindings) == 0 { + if exprs := rs[0].Expressions; len(exprs) == 1 { + if b, ok := exprs[0].Value.(bool); ok { + return b + } + } + } + return false +} diff --git a/vendor/github.com/open-policy-agent/opa/repl/errors.go b/vendor/github.com/open-policy-agent/opa/v1/repl/errors.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/repl/errors.go rename to vendor/github.com/open-policy-agent/opa/v1/repl/errors.go diff --git a/vendor/github.com/open-policy-agent/opa/repl/repl.go b/vendor/github.com/open-policy-agent/opa/v1/repl/repl.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/repl/repl.go rename to vendor/github.com/open-policy-agent/opa/v1/repl/repl.go index fc8b73c66..4147a8628 100644 --- a/vendor/github.com/open-policy-agent/opa/repl/repl.go +++ b/vendor/github.com/open-policy-agent/opa/v1/repl/repl.go @@ -18,23 +18,21 @@ import ( "strings" "sync" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/compile" - - "github.com/open-policy-agent/opa/version" - "github.com/peterh/liner" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/format" "github.com/open-policy-agent/opa/internal/future" pr "github.com/open-policy-agent/opa/internal/presentation" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/profiler" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown" - "github.com/open-policy-agent/opa/topdown/lineage" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/compile" + "github.com/open-policy-agent/opa/v1/format" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/profiler" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/topdown/lineage" + "github.com/open-policy-agent/opa/v1/version" ) // REPL represents an instance of the interactive shell. @@ -126,6 +124,7 @@ func New(store storage.Store, historyPath string, output io.Writer, outputFormat errLimit: errLimit, prettyLimit: defaultPrettyLimit, target: compile.TargetRego, + regoVersion: ast.DefaultRegoVersion, } } @@ -930,12 +929,7 @@ func (r *REPL) parserOptions() (ast.ParserOptions, error) { if err == nil { for _, i := range r.modules[r.currentModuleID].Imports { if ast.Compare(i.Path.Value, ast.RegoV1CompatibleRef) == 0 { - opts.RegoVersion = ast.RegoV0CompatV1 - - // ast.RegoV0CompatV1 sets parsing requirements, but doesn't imply allowed future keywords - if r.capabilities != nil { - opts.FutureKeywords = r.capabilities.FutureKeywords - } + opts.RegoVersion = ast.RegoV1 } } } diff --git a/vendor/github.com/open-policy-agent/opa/resolver/interface.go b/vendor/github.com/open-policy-agent/opa/v1/resolver/interface.go similarity index 86% rename from vendor/github.com/open-policy-agent/opa/resolver/interface.go rename to vendor/github.com/open-policy-agent/opa/v1/resolver/interface.go index fc02329f5..1f04d21c0 100644 --- a/vendor/github.com/open-policy-agent/opa/resolver/interface.go +++ b/vendor/github.com/open-policy-agent/opa/v1/resolver/interface.go @@ -7,8 +7,8 @@ package resolver import ( "context" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/metrics" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" ) // Resolver defines an external value resolver for OPA evaluations. diff --git a/vendor/github.com/open-policy-agent/opa/v1/resolver/wasm/wasm.go b/vendor/github.com/open-policy-agent/opa/v1/resolver/wasm/wasm.go new file mode 100644 index 000000000..4f57b3ef8 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/resolver/wasm/wasm.go @@ -0,0 +1,173 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package wasm + +import ( + "context" + "fmt" + "strconv" + + "github.com/open-policy-agent/opa/internal/rego/opa" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/resolver" +) + +// New creates a new Resolver instance which is using the Wasm module +// policy for the given entrypoint ref. +func New(entrypoints []ast.Ref, policy []byte, data interface{}) (*Resolver, error) { + e, err := opa.LookupEngine("wasm") + if err != nil { + return nil, err + } + o, err := e.New(). + WithPolicyBytes(policy). + WithDataJSON(data). + Init() + if err != nil { + return nil, err + } + + // Construct a quick lookup table of ref -> entrypoint ID + // for handling evaluations. Only the entrypoints provided + // by the caller will be constructed, this may be a subset + // of entrypoints available in the Wasm module, however + // only the configured ones will be used when Eval() is + // called. + entrypointRefToID := ast.NewValueMap() + epIDs, err := o.Entrypoints(context.Background()) + if err != nil { + return nil, err + } + for path, id := range epIDs { + for _, ref := range entrypoints { + refPtr, err := ref.Ptr() + if err != nil { + return nil, err + } + if refPtr == path { + entrypointRefToID.Put(ref, ast.Number(strconv.Itoa(int(id)))) + } + } + } + + return &Resolver{ + entrypoints: entrypoints, + entrypointIDs: entrypointRefToID, + o: o, + }, nil +} + +// Resolver implements the resolver.Resolver interface +// using Wasm modules to perform an evaluation. +type Resolver struct { + entrypoints []ast.Ref + entrypointIDs *ast.ValueMap + o opa.EvalEngine +} + +// Entrypoints returns a list of entrypoints this resolver is configured to +// perform evaluations on. +func (r *Resolver) Entrypoints() []ast.Ref { + return r.entrypoints +} + +// Close shuts down the resolver. +func (r *Resolver) Close() { + r.o.Close() +} + +// Eval performs an evaluation using the provided input and the Wasm module +// associated with this Resolver instance. +func (r *Resolver) Eval(ctx context.Context, input resolver.Input) (resolver.Result, error) { + v := r.entrypointIDs.Get(input.Ref) + if v == nil { + return resolver.Result{}, fmt.Errorf("unknown entrypoint %s", input.Ref) + } + + numValue, ok := v.(ast.Number) + if !ok { + return resolver.Result{}, fmt.Errorf("internal error: invalid entrypoint id %s", numValue) + } + + epID, ok := numValue.Int() + if !ok { + return resolver.Result{}, fmt.Errorf("internal error: invalid entrypoint id %s", numValue) + } + + var in *interface{} + if input.Input != nil { + var str interface{} = []byte(input.Input.String()) + in = &str + } + + opts := opa.EvalOpts{ + Input: in, + Entrypoint: int32(epID), + Metrics: input.Metrics, + } + out, err := r.o.Eval(ctx, opts) + if err != nil { + return resolver.Result{}, err + } + + result, err := getResult(out) + if err != nil { + return resolver.Result{}, err + } + + return resolver.Result{Value: result}, nil +} + +// SetData will update the external data for the Wasm instance. +func (r *Resolver) SetData(ctx context.Context, data interface{}) error { + return r.o.SetData(ctx, data) +} + +// SetDataPath will set the provided data on the wasm instance at the specified path. +func (r *Resolver) SetDataPath(ctx context.Context, path []string, data interface{}) error { + return r.o.SetDataPath(ctx, path, data) +} + +// RemoveDataPath will remove any data at the specified path. +func (r *Resolver) RemoveDataPath(ctx context.Context, path []string) error { + return r.o.RemoveDataPath(ctx, path) +} + +func getResult(evalResult *opa.Result) (ast.Value, error) { + + parsed, err := ast.ParseTerm(string(evalResult.Result)) + if err != nil { + return nil, fmt.Errorf("failed to parse wasm result: %s", err) + } + + resultSet, ok := parsed.Value.(ast.Set) + if !ok { + return nil, fmt.Errorf("illegal result type") + } + + if resultSet.Len() == 0 { + return nil, nil + } + + if resultSet.Len() > 1 { + return nil, fmt.Errorf("illegal result type") + } + + var obj ast.Object + err = resultSet.Iter(func(term *ast.Term) error { + obj, ok = term.Value.(ast.Object) + if !ok || obj.Len() != 1 { + return fmt.Errorf("illegal result type") + } + return nil + }) + if err != nil { + return nil, err + } + + result := obj.Get(ast.StringTerm("result")) + + return result.Value, nil +} diff --git a/vendor/github.com/open-policy-agent/opa/runtime/check_user_linux.go b/vendor/github.com/open-policy-agent/opa/v1/runtime/check_user_linux.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/runtime/check_user_linux.go rename to vendor/github.com/open-policy-agent/opa/v1/runtime/check_user_linux.go index 474203ed9..f3ed8b323 100644 --- a/vendor/github.com/open-policy-agent/opa/runtime/check_user_linux.go +++ b/vendor/github.com/open-policy-agent/opa/v1/runtime/check_user_linux.go @@ -7,7 +7,7 @@ package runtime import ( "os/user" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/logging" ) // checkUserPrivileges on Linux could be running in Docker, so we check if diff --git a/vendor/github.com/open-policy-agent/opa/runtime/check_user_unix.go b/vendor/github.com/open-policy-agent/opa/v1/runtime/check_user_unix.go similarity index 89% rename from vendor/github.com/open-policy-agent/opa/runtime/check_user_unix.go rename to vendor/github.com/open-policy-agent/opa/v1/runtime/check_user_unix.go index a47ec5ece..7b398c195 100644 --- a/vendor/github.com/open-policy-agent/opa/runtime/check_user_unix.go +++ b/vendor/github.com/open-policy-agent/opa/v1/runtime/check_user_unix.go @@ -1,3 +1,4 @@ +//go:build !linux && !windows // +build !linux,!windows // Copyright 2022 The OPA Authors. All rights reserved. @@ -9,7 +10,7 @@ package runtime import ( "os/user" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/logging" ) // checkUserPrivileges could not be running in Docker, so we only warn diff --git a/vendor/github.com/open-policy-agent/opa/runtime/check_user_windows.go b/vendor/github.com/open-policy-agent/opa/v1/runtime/check_user_windows.go similarity index 89% rename from vendor/github.com/open-policy-agent/opa/runtime/check_user_windows.go rename to vendor/github.com/open-policy-agent/opa/v1/runtime/check_user_windows.go index 68f943c12..012b5f7e7 100644 --- a/vendor/github.com/open-policy-agent/opa/runtime/check_user_windows.go +++ b/vendor/github.com/open-policy-agent/opa/v1/runtime/check_user_windows.go @@ -5,7 +5,7 @@ package runtime import ( - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/logging" ) // checkUserPrivileges is a no-op in Windows to avoid lookups with diff --git a/vendor/github.com/open-policy-agent/opa/v1/runtime/doc.go b/vendor/github.com/open-policy-agent/opa/v1/runtime/doc.go new file mode 100644 index 000000000..4e047af6a --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/runtime/doc.go @@ -0,0 +1,6 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package runtime contains the entry point to the policy engine. +package runtime diff --git a/vendor/github.com/open-policy-agent/opa/v1/runtime/logging.go b/vendor/github.com/open-policy-agent/opa/v1/runtime/logging.go new file mode 100644 index 000000000..a2d56fd7d --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/runtime/logging.go @@ -0,0 +1,266 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package runtime + +import ( + "bytes" + "compress/gzip" + "io" + "net/http" + "strings" + "sync/atomic" + "time" + + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/topdown/print" +) + +type loggingPrintHook struct { + logger logging.Logger +} + +func (h loggingPrintHook) Print(pctx print.Context, msg string) error { + // NOTE(tsandall): if the request context is not present then do not panic, + // just log the print message without the additional context. + var fields map[string]any + rctx, ok := logging.FromContext(pctx.Context) + if ok { + fields = rctx.Fields() + } else { + fields = make(map[string]any, 1) + } + fields["line"] = pctx.Location.String() + h.logger.WithFields(fields).Info(msg) + return nil +} + +// LoggingHandler returns an http.Handler that will print log messages +// containing the request information as well as response status and latency. +type LoggingHandler struct { + logger logging.Logger + inner http.Handler + requestID uint64 +} + +// NewLoggingHandler returns a new http.Handler. +func NewLoggingHandler(logger logging.Logger, inner http.Handler) http.Handler { + return &LoggingHandler{ + logger: logger, + inner: inner, + requestID: uint64(0), + } +} + +func (h *LoggingHandler) loggingEnabled(level logging.Level) bool { + return level <= h.logger.GetLevel() +} + +func (h *LoggingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + + cloneHeaders := r.Header.Clone() + + // set the HTTP headers via a dedicated key on the parent context irrespective of logging level + r = r.WithContext(logging.WithHTTPRequestContext(r.Context(), &logging.HTTPRequestContext{Header: cloneHeaders})) + + var rctx logging.RequestContext + rctx.ReqID = atomic.AddUint64(&h.requestID, uint64(1)) + + recorder := newRecorder(h.logger, w, r, rctx.ReqID, h.loggingEnabled(logging.Debug)) + t0 := time.Now() + + if h.loggingEnabled(logging.Info) { + + rctx.ClientAddr = r.RemoteAddr + rctx.ReqMethod = r.Method + rctx.ReqPath = r.URL.EscapedPath() + r = r.WithContext(logging.NewContext(r.Context(), &rctx)) + + var err error + fields := rctx.Fields() + + if h.loggingEnabled(logging.Debug) { + var bs []byte + if r.Body != nil { + bs, r.Body, err = readBody(r.Body) + } + if err == nil { + if gzipReceived(r.Header) { + // the request is compressed + var gzReader *gzip.Reader + var plainOutput []byte + reader := bytes.NewReader(bs) + gzReader, err = gzip.NewReader(reader) + if err == nil { + plainOutput, err = io.ReadAll(gzReader) + if err == nil { + defer gzReader.Close() + fields["req_body"] = string(plainOutput) + } + } + } else { + fields["req_body"] = string(bs) + } + } + + // err can be thrown on different statements + if err != nil { + fields["err"] = err + } + + fields["req_params"] = r.URL.Query() + } + + if err == nil { + h.logger.WithFields(fields).Info("Received request.") + } else { + h.logger.WithFields(fields).Error("Failed to read body.") + } + } + + h.inner.ServeHTTP(recorder, r) + + dt := time.Since(t0) + statusCode := 200 + if recorder.statusCode != 0 { + statusCode = recorder.statusCode + } + + if h.loggingEnabled(logging.Info) { + fields := map[string]interface{}{ + "client_addr": rctx.ClientAddr, + "req_id": rctx.ReqID, + "req_method": rctx.ReqMethod, + "req_path": rctx.ReqPath, + "resp_status": statusCode, + "resp_bytes": recorder.bytesWritten, + "resp_duration": float64(dt.Nanoseconds()) / 1e6, + } + + if h.loggingEnabled(logging.Debug) { + switch { + case isPprofEndpoint(r): + // pprof always sends binary data (protobuf) + fields["resp_body"] = "[binary payload]" + + case gzipAccepted(r.Header) && isMetricsEndpoint(r): + // metrics endpoint does so when the client accepts it (e.g. prometheus) + fields["resp_body"] = "[compressed payload]" + + case gzipAccepted(r.Header) && gzipReceived(w.Header()) && (isDataEndpoint(r) || isCompileEndpoint(r)): + // data and compile endpoints might compress the response + gzReader, gzErr := gzip.NewReader(recorder.buf) + if gzErr == nil { + plainOutput, readErr := io.ReadAll(gzReader) + if readErr == nil { + defer gzReader.Close() + fields["resp_body"] = string(plainOutput) + } else { + h.logger.Error("Failed to decompressed the payload: %v", readErr.Error()) + } + } else { + h.logger.Error("Failed to read the compressed payload: %v", gzErr.Error()) + } + + default: + fields["resp_body"] = recorder.buf.String() + } + } + + h.logger.WithFields(fields).Info("Sent response.") + } +} + +func gzipAccepted(header http.Header) bool { + a := header.Get("Accept-Encoding") + parts := strings.Split(a, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "gzip" || strings.HasPrefix(part, "gzip;") { + return true + } + } + return false +} + +func gzipReceived(header http.Header) bool { + a := header.Get("Content-Encoding") + parts := strings.Split(a, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "gzip" || strings.HasPrefix(part, "gzip;") { + return true + } + } + return false +} + +func isPprofEndpoint(req *http.Request) bool { + return strings.HasPrefix(req.URL.Path, "/debug/pprof/") +} + +func isMetricsEndpoint(req *http.Request) bool { + return strings.HasPrefix(req.URL.Path, "/metrics") +} + +func isDataEndpoint(req *http.Request) bool { + return strings.HasPrefix(req.URL.Path, "/v1/data") || strings.HasPrefix(req.URL.Path, "/v0/data") +} + +func isCompileEndpoint(req *http.Request) bool { + return strings.HasPrefix(req.URL.Path, "/v1/compile") +} + +type recorder struct { + logger logging.Logger + inner http.ResponseWriter + req *http.Request + id uint64 + + buf *bytes.Buffer + bytesWritten int + statusCode int +} + +func newRecorder(logger logging.Logger, w http.ResponseWriter, r *http.Request, id uint64, buffer bool) *recorder { + var buf *bytes.Buffer + if buffer { + buf = new(bytes.Buffer) + } + return &recorder{ + logger: logger, + buf: buf, + inner: w, + req: r, + id: id, + } +} + +func (r *recorder) Header() http.Header { + return r.inner.Header() +} + +func (r *recorder) Write(bs []byte) (int, error) { + r.bytesWritten += len(bs) + if r.buf != nil { + r.buf.Write(bs) + } + return r.inner.Write(bs) +} + +func (r *recorder) WriteHeader(s int) { + r.statusCode = s + r.inner.WriteHeader(s) +} + +func readBody(r io.ReadCloser) ([]byte, io.ReadCloser, error) { + if r == http.NoBody { + return nil, r, nil + } + var buf bytes.Buffer + if _, err := buf.ReadFrom(r); err != nil { + return nil, r, err + } + return buf.Bytes(), io.NopCloser(bytes.NewReader(buf.Bytes())), nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/runtime/runtime.go b/vendor/github.com/open-policy-agent/opa/v1/runtime/runtime.go new file mode 100644 index 000000000..1d54f0b6d --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/runtime/runtime.go @@ -0,0 +1,1009 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package runtime + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + mr "math/rand" + "net/url" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/gorilla/mux" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace" + "go.opentelemetry.io/otel/propagation" + "go.uber.org/automaxprocs/maxprocs" + + "github.com/open-policy-agent/opa/internal/compiler" + "github.com/open-policy-agent/opa/internal/config" + internal_tracing "github.com/open-policy-agent/opa/internal/distributedtracing" + internal_logging "github.com/open-policy-agent/opa/internal/logging" + "github.com/open-policy-agent/opa/internal/pathwatcher" + "github.com/open-policy-agent/opa/internal/prometheus" + "github.com/open-policy-agent/opa/internal/ref" + "github.com/open-policy-agent/opa/internal/report" + "github.com/open-policy-agent/opa/internal/runtime" + initload "github.com/open-policy-agent/opa/internal/runtime/init" + "github.com/open-policy-agent/opa/internal/uuid" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + opa_config "github.com/open-policy-agent/opa/v1/config" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/plugins/discovery" + "github.com/open-policy-agent/opa/v1/plugins/logs" + metrics_config "github.com/open-policy-agent/opa/v1/plugins/server/metrics" + "github.com/open-policy-agent/opa/v1/repl" + "github.com/open-policy-agent/opa/v1/server" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/disk" + "github.com/open-policy-agent/opa/v1/storage/inmem" + "github.com/open-policy-agent/opa/v1/tracing" + "github.com/open-policy-agent/opa/v1/util" + "github.com/open-policy-agent/opa/v1/version" +) + +var ( + registeredPlugins map[string]plugins.Factory + registeredPluginsMux sync.Mutex +) + +const ( + // default interval between OPA version report uploads after startup (1h) + defaultInitialUploadInterval = time.Hour + // upload interval when OPA has been running for 6+ hrs (6h) + defaultLaterUploadInterval = 6 * time.Hour +) + +// RegisterPlugin registers a plugin factory with the runtime +// package. When the runtime is created, the factories are used to parse +// plugin configuration and instantiate plugins. If no configuration is +// provided, plugins are not instantiated. This function is idempotent. +func RegisterPlugin(name string, factory plugins.Factory) { + registeredPluginsMux.Lock() + defer registeredPluginsMux.Unlock() + registeredPlugins[name] = factory +} + +// Params stores the configuration for an OPA instance. +type Params struct { + // Globally unique identifier for this OPA instance. If an ID is not specified, + // the runtime will generate one. + ID string + + // Addrs are the listening addresses that the OPA server will bind to. + Addrs *[]string + + // DiagnosticAddrs are the listening addresses that the OPA server will bind to + // for read-only diagnostic API's (/health, /metrics, etc) + DiagnosticAddrs *[]string + + // H2CEnabled flag controls whether OPA will allow H2C (HTTP/2 cleartext) on + // HTTP listeners. + H2CEnabled bool + + // Authentication is the type of authentication scheme to use. + Authentication server.AuthenticationScheme + + // Authorization is the type of authorization scheme to use. + Authorization server.AuthorizationScheme + + // Certificate is the certificate to use in server-mode. If the certificate + // is nil, the server will NOT use TLS. + Certificate *tls.Certificate + + // CertificateFile and CertificateKeyFile are the paths to the cert and its + // keyfile. It'll be used to periodically reload the files from disk if they + // have changed. The server will attempt to refresh every 5 minutes, unless + // a different CertificateRefresh time.Duration is provided + CertificateFile string + CertificateKeyFile string + CertificateRefresh time.Duration + + // CertPool holds the CA certs trusted by the OPA server. + CertPool *x509.CertPool + // CertPoolFile, if set permits the reloading of the CA cert pool from disk + CertPoolFile string + + // MinVersion contains the minimum TLS version that is acceptable. + // If zero, TLS 1.2 is currently taken as the minimum. + MinTLSVersion uint16 + + // HistoryPath is the filename to store the interactive shell user + // input history. + HistoryPath string + + // Output format controls how the REPL will print query results. + // Default: "pretty". + OutputFormat string + + // Paths contains filenames of base documents and policy modules to load on + // startup. Data files may be prefixed with ":" to indicate + // where the contained document should be loaded. + Paths []string + + // Optional filter that will be passed to the file loader. + Filter loader.Filter + + // BundleMode will enable treating the Paths provided as bundles rather than + // loading all data & policy files. + BundleMode bool + + // Watch flag controls whether OPA will watch the Paths files for changes. + // If this flag is true, OPA will watch the Paths files for changes and + // reload the storage layer each time they change. This is useful for + // interactive development. + Watch bool + + // ErrorLimit is the number of errors the compiler will allow to occur before + // exiting early. + ErrorLimit int + + // PprofEnabled flag controls whether pprof endpoints are enabled + PprofEnabled bool + + // DecisionIDFactory generates decision IDs to include in API responses + // sent by the server (in response to Data API queries.) + DecisionIDFactory func() string + + // Logging configures the logging behaviour. + Logging LoggingConfig + + // Logger sets the logger implementation to use for debug logs. + Logger logging.Logger + + // ConsoleLogger sets the logger implementation to use for console logs. + ConsoleLogger logging.Logger + + // ConfigFile refers to the OPA configuration to load on startup. + ConfigFile string + + // ConfigOverrides are overrides for the OPA configuration that are applied + // over top the config file They are in a list of key=value syntax that + // conform to the syntax defined in the `strval` package + ConfigOverrides []string + + // ConfigOverrideFiles Similar to `ConfigOverrides` except they are in the + // form of `key=path/to/file`where the file contains the value to be used. + ConfigOverrideFiles []string + + // Output is the output stream used when run as an interactive shell. This + // is mostly for test purposes. + Output io.Writer + + // GracefulShutdownPeriod is the time (in seconds) to wait for the http + // server to shutdown gracefully. + GracefulShutdownPeriod int + + // ShutdownWaitPeriod is the time (in seconds) to wait before initiating shutdown. + ShutdownWaitPeriod int + + // EnableVersionCheck flag controls whether OPA will report its version to an external service. + // If this flag is true, OPA will report its version to the external service + EnableVersionCheck bool + + // BundleVerificationConfig sets the key configuration used to verify a signed bundle + BundleVerificationConfig *bundle.VerificationConfig + + // SkipBundleVerification flag controls whether OPA will verify a signed bundle + SkipBundleVerification bool + + // SkipKnownSchemaCheck flag controls whether OPA will perform type checking on known input schemas + SkipKnownSchemaCheck bool + + // ReadyTimeout flag controls if and for how long OPA server will wait (in seconds) for + // configured bundles and plugins to be activated/ready before listening for traffic. + // A value of 0 or less means no wait is exercised. + ReadyTimeout int + + // Router is the router to which handlers for the REST API are added. + // Router uses a first-matching-route-wins strategy, so no existing routes are overridden + // If it is nil, a new mux.Router will be created + Router *mux.Router + + // DiskStorage, if set, will make the runtime instantiate a disk-backed storage + // implementation (instead of the default, in-memory store). + // It can also be enabled via config, and this runtime field takes precedence. + DiskStorage *disk.Options + + DistributedTracingOpts tracing.Options + + // Check if default Addr is set or the user has changed it. + AddrSetByUser bool + + // UnixSocketPerm specifies the permission for the Unix domain socket if used to listen for connections + UnixSocketPerm *string + + // V0Compatible will enable OPA features and behaviors that were enabled by default in OPA v0.x releases. + // Takes precedence over V1Compatible. + V0Compatible bool + + // V1Compatible will enable OPA features and behaviors that will be enabled by default in a future OPA v1.0 release. + // This flag allows users to opt-in to the new behavior and helps transition to the future release upon which + // the new behavior will be enabled by default. + // If V0Compatible is set, V1Compatible will be ignored. + V1Compatible bool + + // CipherSuites specifies the list of enabled TLS 1.0–1.2 cipher suites + CipherSuites *[]uint16 + + // ReadAstValuesFromStore controls whether the storage layer should return AST values when reading from the store. + // This is an eager conversion, that comes with an upfront performance cost when updating the store (e.g. bundle updates). + // Evaluation performance is affected in that data doesn't need to be converted to AST during evaluation. + // Only applicable when using the default in-memory store, and not when used together with the DiskStorage option. + ReadAstValuesFromStore bool +} + +func (p *Params) regoVersion() ast.RegoVersion { + // v0 takes precedence over v1 + if p.V0Compatible { + return ast.RegoV0 + } + if p.V1Compatible { + return ast.RegoV1 + } + return ast.DefaultRegoVersion +} + +func (p *Params) parserOptions() ast.ParserOptions { + return ast.ParserOptions{RegoVersion: p.regoVersion()} +} + +// LoggingConfig stores the configuration for OPA's logging behaviour. +type LoggingConfig struct { + Level string + Format string + TimestampFormat string +} + +// NewParams returns a new Params object. +func NewParams() Params { + return Params{ + Output: os.Stdout, + BundleMode: false, + EnableVersionCheck: false, + } +} + +// Runtime represents a single OPA instance. +type Runtime struct { + Params Params + Store storage.Store + Manager *plugins.Manager + + logger logging.Logger + server *server.Server + metrics *prometheus.Provider + reporter *report.Reporter + traceExporter *otlptrace.Exporter + loadedPathsResult *initload.LoadPathsResult + + serverInitialized bool + serverInitMtx sync.RWMutex + done chan struct{} + repl *repl.REPL +} + +// NewRuntime returns a new Runtime object initialized with params. Clients must +// call StartServer() or StartREPL() to start the runtime in either mode. +func NewRuntime(ctx context.Context, params Params) (*Runtime, error) { + if params.ID == "" { + var err error + params.ID, err = generateInstanceID() + if err != nil { + return nil, err + } + } + + level, err := internal_logging.GetLevel(params.Logging.Level) + if err != nil { + return nil, err + } + + // NOTE(tsandall): This is a temporary hack to ensure that log formatting + // and leveling is applied correctly. Currently there are a few places where + // the global logger is used as a fallback, however, that fallback _should_ + // never be used. This ensures that _if_ the fallback is used accidentally, + // that the logging configuration is applied. Once we remove all usage of + // the global logger and we remove the API that allows callers to access the + // global logger, we can remove this. + logging.Get().SetFormatter(internal_logging.GetFormatter(params.Logging.Format, params.Logging.TimestampFormat)) + logging.Get().SetLevel(level) + + var logger logging.Logger + + if params.Logger != nil { + logger = params.Logger + } else { + stdLogger := logging.New() + stdLogger.SetLevel(level) + stdLogger.SetFormatter(internal_logging.GetFormatter(params.Logging.Format, params.Logging.TimestampFormat)) + logger = stdLogger + } + + var filePaths []string + urlPathCount := 0 + for _, path := range params.Paths { + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + urlPathCount++ + override, err := urlPathToConfigOverride(urlPathCount, path) + if err != nil { + return nil, err + } + params.ConfigOverrides = append(params.ConfigOverrides, override...) + } else { + filePaths = append(filePaths, path) + } + } + params.Paths = filePaths + + config, err := config.Load(params.ConfigFile, params.ConfigOverrides, params.ConfigOverrideFiles) + if err != nil { + return nil, fmt.Errorf("config error: %w", err) + } + + var reporter *report.Reporter + if params.EnableVersionCheck { + var err error + reporter, err = report.New(params.ID, report.Options{Logger: logger}) + if err != nil { + return nil, fmt.Errorf("config error: %w", err) + } + } + + regoVersion := params.regoVersion() + + loaded, err := initload.LoadPathsForRegoVersion(regoVersion, params.Paths, params.Filter, params.BundleMode, params.BundleVerificationConfig, params.SkipBundleVerification, false, false, nil, nil) + if err != nil { + return nil, fmt.Errorf("load error: %w", err) + } + + isAuthorizationEnabled := params.Authorization != server.AuthorizationOff + + info, err := runtime.Term(runtime.Params{Config: config, IsAuthorizationEnabled: isAuthorizationEnabled, SkipKnownSchemaCheck: params.SkipKnownSchemaCheck}) + if err != nil { + return nil, err + } + + consoleLogger := params.ConsoleLogger + if consoleLogger == nil { + l := logging.New() + l.SetFormatter(internal_logging.GetFormatter(params.Logging.Format, params.Logging.TimestampFormat)) + consoleLogger = l + } + + if params.Router == nil { + params.Router = mux.NewRouter() + } + + metricsConfig, parseConfigErr := extractMetricsConfig(config, params) + if parseConfigErr != nil { + return nil, parseConfigErr + } + metrics := prometheus.New(metrics.New(), errorLogger(logger), metricsConfig.Prom.HTTPRequestDurationSeconds.Buckets) + + var store storage.Store + if params.DiskStorage == nil { + params.DiskStorage, err = disk.OptionsFromConfig(config, params.ID) + if err != nil { + return nil, fmt.Errorf("parse disk store configuration: %w", err) + } + } + + if params.DiskStorage != nil { + store, err = disk.New(ctx, logger, metrics, *params.DiskStorage) + if err != nil { + return nil, fmt.Errorf("initialize disk store: %w", err) + } + } else { + store = inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false), + inmem.OptReturnASTValuesOnRead(params.ReadAstValuesFromStore)) + } + + traceExporter, tracerProvider, _, err := internal_tracing.Init(ctx, config, params.ID) + if err != nil { + return nil, fmt.Errorf("config error: %w", err) + } + if tracerProvider != nil { + params.DistributedTracingOpts = tracing.NewOptions( + otelhttp.WithTracerProvider(tracerProvider), + otelhttp.WithPropagators(propagation.TraceContext{}), + ) + } + + manager, err := plugins.New(config, + params.ID, + store, + plugins.Info(info), + plugins.InitBundles(loaded.Bundles), + plugins.InitFiles(loaded.Files), + plugins.MaxErrors(params.ErrorLimit), + plugins.GracefulShutdownPeriod(params.GracefulShutdownPeriod), + plugins.ConsoleLogger(consoleLogger), + plugins.Logger(logger), + plugins.EnablePrintStatements(logger.GetLevel() >= logging.Info), + plugins.PrintHook(loggingPrintHook{logger: logger}), + plugins.WithRouter(params.Router), + plugins.WithPrometheusRegister(metrics), + plugins.WithTracerProvider(tracerProvider), + plugins.WithEnableTelemetry(params.EnableVersionCheck), + plugins.WithParserOptions(params.parserOptions())) + if err != nil { + return nil, fmt.Errorf("config error: %w", err) + } + + if err := manager.Init(ctx); err != nil { + return nil, fmt.Errorf("initialization error: %w", err) + } + + if isAuthorizationEnabled && !params.SkipKnownSchemaCheck { + if err := verifyAuthorizationPolicySchema(manager); err != nil { + return nil, fmt.Errorf("initialization error: %w", err) + } + } + + var bootConfig map[string]interface{} + err = util.Unmarshal(config, &bootConfig) + if err != nil { + return nil, fmt.Errorf("config error: %w", err) + } + + disco, err := discovery.New(manager, discovery.Factories(registeredPlugins), discovery.Metrics(metrics), discovery.BootConfig(bootConfig)) + if err != nil { + return nil, fmt.Errorf("config error: %w", err) + } + + manager.Register(discovery.Name, disco) + + rt := &Runtime{ + Store: manager.Store, + Params: params, + Manager: manager, + logger: logger, + metrics: metrics, + reporter: reporter, + serverInitialized: false, + traceExporter: traceExporter, + loadedPathsResult: loaded, + } + + return rt, nil +} + +// extractMetricsConfig returns the configuration for server metrics and parsing errors if any +func extractMetricsConfig(config []byte, params Params) (*metrics_config.Config, error) { + var opaParsedConfig, opaParsedConfigErr = opa_config.ParseConfig(config, params.ID) + if opaParsedConfigErr != nil { + return nil, opaParsedConfigErr + } + + var serverMetricsData []byte + if opaParsedConfig.Server != nil { + serverMetricsData = opaParsedConfig.Server.Metrics + } + + var configBuilder = metrics_config.NewConfigBuilder() + var metricsParsedConfig, metricsParsedConfigErr = configBuilder.WithBytes(serverMetricsData).Parse() + if metricsParsedConfigErr != nil { + return nil, fmt.Errorf("server metrics configuration parse error: %w", metricsParsedConfigErr) + } + + return metricsParsedConfig, nil +} + +// StartServer starts the runtime in server mode. This function will block the +// calling goroutine and will exit the program on error. +func (rt *Runtime) StartServer(ctx context.Context) { + err := rt.Serve(ctx) + if err != nil { + os.Exit(1) + } +} + +// Serve will start a new REST API server and listen for requests. This +// will block until either: an error occurs, the context is canceled, or +// a SIGTERM or SIGKILL signal is sent. +func (rt *Runtime) Serve(ctx context.Context) error { + if rt.Params.Addrs == nil { + return fmt.Errorf("at least one address must be configured in runtime parameters") + } + + serverInitializingMessage := "Initializing server." + if !rt.Params.AddrSetByUser && rt.Params.V0Compatible { + serverInitializingMessage += " OPA is running on a public (0.0.0.0) network interface. Unless you intend to expose OPA outside of the host, binding to the localhost interface (--addr localhost:8181) is recommended. See https://www.openpolicyagent.org/docs/latest/security/#interface-binding" + } + + if rt.Params.DiagnosticAddrs == nil { + rt.Params.DiagnosticAddrs = &[]string{} + } + + rt.logger.WithFields(map[string]interface{}{ + "addrs": *rt.Params.Addrs, + "diagnostic-addrs": *rt.Params.DiagnosticAddrs, + }).Info(serverInitializingMessage) + + if rt.Params.Authorization == server.AuthorizationOff && rt.Params.Authentication == server.AuthenticationToken { + rt.logger.Error("Token authentication enabled without authorization. Authentication will be ineffective. See https://www.openpolicyagent.org/docs/latest/security/#authentication-and-authorization for more information.") + } + + checkUserPrivileges(rt.logger) + + // NOTE(tsandall): at some point, hopefully we can remove this because the + // Go runtime will just do the right thing. Until then, try to set + // GOMAXPROCS based on the CPU quota applied to the process. + undo, err := maxprocs.Set(maxprocs.Logger(func(f string, a ...interface{}) { + rt.logger.Debug(f, a...) + })) + if err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Debug("Failed to set GOMAXPROCS from CPU quota.") + } + + defer undo() + + if err := rt.Manager.Start(ctx); err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to start plugins.") + return err + } + + defer rt.Manager.Stop(ctx) + + if rt.traceExporter != nil { + if err := rt.traceExporter.Start(ctx); err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to start OpenTelemetry trace exporter.") + return err + } + } + + rt.server = server.New(). + WithRouter(rt.Params.Router). + WithStore(rt.Store). + WithManager(rt.Manager). + WithCompilerErrorLimit(rt.Params.ErrorLimit). + WithPprofEnabled(rt.Params.PprofEnabled). + WithAddresses(*rt.Params.Addrs). + WithH2CEnabled(rt.Params.H2CEnabled). + // always use the initial values for the certificate and ca pool, reloading behavior is configured below + WithCertificate(rt.Params.Certificate). + WithCertPool(rt.Params.CertPool). + WithAuthentication(rt.Params.Authentication). + WithAuthorization(rt.Params.Authorization). + WithDecisionIDFactory(rt.decisionIDFactory). + WithDecisionLoggerWithErr(rt.decisionLogger). + WithRuntime(rt.Manager.Info). + WithMetrics(rt.metrics). + WithMinTLSVersion(rt.Params.MinTLSVersion). + WithCipherSuites(rt.Params.CipherSuites). + WithDistributedTracingOpts(rt.Params.DistributedTracingOpts) + + // If decision_logging plugin enabled, check to see if we opted in to the ND builtins cache. + if lp := logs.Lookup(rt.Manager); lp != nil { + rt.server = rt.server.WithNDBCacheEnabled(rt.Manager.Config.NDBuiltinCacheEnabled()) + } + + if rt.Params.DiagnosticAddrs != nil { + rt.server = rt.server.WithDiagnosticAddresses(*rt.Params.DiagnosticAddrs) + } + + if rt.Params.UnixSocketPerm != nil { + rt.server = rt.server.WithUnixSocketPermission(rt.Params.UnixSocketPerm) + } + + // If a refresh period is set, then we will periodically reload the certificate and ca pool. Otherwise, we will only + // reload cert, key and ca pool files when they change on disk. + if rt.Params.CertificateRefresh > 0 { + rt.server = rt.server.WithCertRefresh(rt.Params.CertificateRefresh) + } + + // if either the cert or the ca pool file is set then these fields will be set on the server and reloaded when they + // change on disk. + if rt.Params.CertificateFile != "" || rt.Params.CertPoolFile != "" { + rt.server = rt.server.WithTLSConfig(&server.TLSConfig{ + CertFile: rt.Params.CertificateFile, + KeyFile: rt.Params.CertificateKeyFile, + CertPoolFile: rt.Params.CertPoolFile, + }) + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + rt.server, err = rt.server.Init(ctx) + if err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to initialize server.") + return err + } + + if rt.Params.Watch { + if err := rt.startWatcher(ctx, rt.Params.Paths, rt.onReloadLogger); err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to open watch.") + return err + } + } + + if rt.Params.EnableVersionCheck { + rt.done = make(chan struct{}) + go rt.checkOPAUpdateLoop(ctx, rt.done) + } + + defer func() { + if rt.done != nil { + rt.done <- struct{}{} + } + }() + + rt.server.Handler = NewLoggingHandler(rt.logger, rt.server.Handler) + rt.server.DiagnosticHandler = NewLoggingHandler(rt.logger, rt.server.DiagnosticHandler) + + if err := rt.waitPluginsReady( + 100*time.Millisecond, + time.Second*time.Duration(rt.Params.ReadyTimeout)); err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to wait for plugins activation.") + return err + } + + loops, err := rt.server.Listeners() + if err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to create listeners.") + return err + } + + errc := make(chan error) + for _, loop := range loops { + go func(serverLoop func() error) { + errc <- serverLoop() + }(loop) + } + + // Buffer one element as os/signal uses non-blocking channel sends. + // This prevents potentially dropping the first element and failing to shut + // down gracefully. A buffer of 1 is sufficient as we're just looking for a + // one-time shutdown signal. + signalc := make(chan os.Signal, 1) + signal.Notify(signalc, syscall.SIGINT, syscall.SIGTERM) + + // Note that there is a small chance the socket of the server listener is still + // closed by the time this block is executed, due to the serverLoop above + // executing in a goroutine. + rt.serverInitMtx.Lock() + rt.serverInitialized = true + rt.serverInitMtx.Unlock() + rt.Manager.ServerInitialized() + + rt.logger.Debug("Server initialized.") + + for { + select { + case <-ctx.Done(): + return rt.gracefulServerShutdown(rt.server) + case <-signalc: + return rt.gracefulServerShutdown(rt.server) + case err := <-errc: + rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Listener failed.") + os.Exit(1) + } + } +} + +// Addrs returns a list of addresses that the runtime is listening on (when +// in server mode). Returns an empty list if it hasn't started listening. +func (rt *Runtime) Addrs() []string { + rt.serverInitMtx.RLock() + defer rt.serverInitMtx.RUnlock() + + if !rt.serverInitialized { + return nil + } + + return rt.server.Addrs() +} + +// DiagnosticAddrs returns a list of diagnostic addresses that the runtime is +// listening on (when in server mode). Returns an empty list if it hasn't +// started listening. +func (rt *Runtime) DiagnosticAddrs() []string { + if rt.server == nil { + return nil + } + + return rt.server.DiagnosticAddrs() +} + +// StartREPL starts the runtime in REPL mode. This function will block the calling goroutine. +func (rt *Runtime) StartREPL(ctx context.Context) { + if err := rt.Manager.Start(ctx); err != nil { + fmt.Fprintln(rt.Params.Output, "error starting plugins:", err) + os.Exit(1) + } + + defer rt.Manager.Stop(ctx) + + banner := rt.getBanner() + repl := repl.New(rt.Store, rt.Params.HistoryPath, rt.Params.Output, rt.Params.OutputFormat, rt.Params.ErrorLimit, banner). + WithRuntime(rt.Manager.Info). + WithRegoVersion(rt.Params.regoVersion()). + WithInitBundles(rt.loadedPathsResult.Bundles) + + if rt.Params.Watch { + if err := rt.startWatcher(ctx, rt.Params.Paths, onReloadPrinter(rt.Params.Output)); err != nil { + fmt.Fprintln(rt.Params.Output, "error opening watch:", err) + os.Exit(1) + } + } + + if rt.Params.EnableVersionCheck { + go func() { + repl.SetOPAVersionReport(rt.checkOPAUpdate(ctx).Slice()) + }() + } + + rt.repl = repl + repl.Loop(ctx) +} + +// SetDistributedTracingLogging configures the distributed tracing's ErrorHandler, +// and logger instances. +func (rt *Runtime) SetDistributedTracingLogging() { + internal_tracing.SetupLogging(rt.logger) +} + +func (rt *Runtime) checkOPAUpdate(ctx context.Context) *report.DataResponse { + resp, _ := rt.reporter.SendReport(ctx) + return resp +} + +func (rt *Runtime) checkOPAUpdateLoop(ctx context.Context, done chan struct{}) { + rt.checkOPAUpdateLoopDurations(ctx, done, defaultInitialUploadInterval, defaultLaterUploadInterval) +} + +func (rt *Runtime) checkOPAUpdateLoopDurations(ctx context.Context, done chan struct{}, initialDur, laterDur time.Duration) { + ticker := time.NewTicker(initialDur) + i := 0 + mr.New(mr.NewSource(time.Now().UnixNano())) // Seed the PRNG. + + for { + resp, err := rt.reporter.SendReport(ctx) + if err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Debug("Unable to send OPA version report.") + } else { + if resp.Latest.OPAUpToDate { + rt.logger.WithFields(map[string]interface{}{ + "current_version": version.Version, + }).Debug("OPA is up to date.") + } else { + rt.logger.WithFields(map[string]interface{}{ + "download_opa": resp.Latest.Download, + "release_notes": resp.Latest.ReleaseNotes, + "current_version": version.Version, + "latest_version": strings.TrimPrefix(resp.Latest.LatestRelease, "v"), + }).Info("OPA is out of date.") + } + } + select { + case <-ticker.C: + ticker.Stop() + i++ // count the attempts + + newInterval := time.Duration(mr.Int63n(int64(time.Hour / time.Second))) // spray, between 0 and 1 hr + if i < 6 { + newInterval += initialDur + } else { + newInterval += laterDur + } + ticker = time.NewTicker(newInterval) + case <-done: + ticker.Stop() + return + } + } +} + +func (rt *Runtime) decisionIDFactory() string { + if rt.Params.DecisionIDFactory != nil { + return rt.Params.DecisionIDFactory() + } + if logs.Lookup(rt.Manager) != nil { + return generateDecisionID() + } + return "" +} + +func (rt *Runtime) decisionLogger(ctx context.Context, event *server.Info) error { + plugin := logs.Lookup(rt.Manager) + if plugin == nil { + return nil + } + + return plugin.Log(ctx, event) +} + +func (rt *Runtime) startWatcher(ctx context.Context, paths []string, onReload func(time.Duration, error)) error { + watcher, err := rt.getWatcher(paths) + if err != nil { + return err + } + go rt.readWatcher(ctx, watcher, paths, onReload) + return nil +} + +func (rt *Runtime) readWatcher(ctx context.Context, watcher *fsnotify.Watcher, paths []string, onReload func(time.Duration, error)) { + for evt := range watcher.Events { + removalMask := fsnotify.Remove | fsnotify.Rename + mask := fsnotify.Create | fsnotify.Write | removalMask + if (evt.Op & mask) != 0 { + rt.logger.WithFields(map[string]interface{}{ + "event": evt.String(), + }).Debug("Registered file event.") + t0 := time.Now() + removed := "" + if (evt.Op & removalMask) != 0 { + removed = evt.Name + } + err := rt.processWatcherUpdate(ctx, paths, removed) + onReload(time.Since(t0), err) + } + } +} + +func (rt *Runtime) processWatcherUpdate(ctx context.Context, paths []string, removed string) error { + + return pathwatcher.ProcessWatcherUpdateForRegoVersion(ctx, rt.Manager.ParserOptions().RegoVersion, paths, removed, rt.Store, rt.Params.Filter, rt.Params.BundleMode, func(ctx context.Context, txn storage.Transaction, loaded *initload.LoadPathsResult) error { + _, err := initload.InsertAndCompile(ctx, initload.InsertAndCompileOptions{ + Store: rt.Store, + Txn: txn, + Files: loaded.Files, + Bundles: loaded.Bundles, + MaxErrors: -1, + ParserOptions: rt.Manager.ParserOptions(), + }) + + return err + }) +} + +func (rt *Runtime) getBanner() string { + var buf bytes.Buffer + fmt.Fprintf(&buf, "OPA %v (commit %v, built at %v)\n", version.Version, version.Vcs, version.Timestamp) + fmt.Fprintf(&buf, "\n") + fmt.Fprintf(&buf, "Run 'help' to see a list of commands and check for updates.\n") + return buf.String() +} + +func (rt *Runtime) gracefulServerShutdown(s *server.Server) error { + if rt.Params.ShutdownWaitPeriod > 0 { + rt.logger.Info("Waiting %vs before initiating shutdown...", rt.Params.ShutdownWaitPeriod) + time.Sleep(time.Duration(rt.Params.ShutdownWaitPeriod) * time.Second) + } + + rt.logger.Info("Shutting down...") + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(rt.Params.GracefulShutdownPeriod)*time.Second) + defer cancel() + err := s.Shutdown(ctx) + if err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to shutdown server gracefully.") + return err + } + rt.logger.Info("Server shutdown.") + + if rt.traceExporter != nil { + err = rt.traceExporter.Shutdown(ctx) + if err != nil { + rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to shutdown OpenTelemetry trace exporter gracefully.") + } + } + return nil +} + +func (rt *Runtime) waitPluginsReady(checkInterval, timeout time.Duration) error { + if timeout <= 0 { + return nil + } + + // check readiness of all plugins + pluginsReady := func() bool { + for _, status := range rt.Manager.PluginStatus() { + if status != nil && status.State != plugins.StateOK { + return false + } + } + return true + } + + rt.logger.Debug("Waiting for plugins activation (%v).", timeout) + + return util.WaitFunc(pluginsReady, checkInterval, timeout) +} + +func (rt *Runtime) onReloadLogger(d time.Duration, err error) { + rt.logger.WithFields(map[string]interface{}{ + "duration": d, + "err": err, + }).Info("Processed file watch event.") +} + +func (rt *Runtime) getWatcher(rootPaths []string) (*fsnotify.Watcher, error) { + watcher, err := pathwatcher.CreatePathWatcher(rootPaths) + if err != nil { + return nil, err + } + + for _, path := range watcher.WatchList() { + rt.logger.WithFields(map[string]interface{}{"path": path}).Debug("watching path") + } + + return watcher, nil +} + +func urlPathToConfigOverride(pathCount int, path string) ([]string, error) { + uri, err := url.Parse(path) + if err != nil { + return nil, err + } + baseURL := uri.Scheme + "://" + uri.Host + urlPath := uri.Path + if uri.RawQuery != "" { + urlPath += "?" + uri.RawQuery + } + + return []string{ + fmt.Sprintf("services.cli%d.url=%s", pathCount, baseURL), + fmt.Sprintf("bundles.cli%d.service=cli%d", pathCount, pathCount), + fmt.Sprintf("bundles.cli%d.resource=%s", pathCount, urlPath), + fmt.Sprintf("bundles.cli%d.persist=true", pathCount), + }, nil +} + +func errorLogger(logger logging.Logger) func(attrs map[string]interface{}, f string, a ...interface{}) { + return func(attrs map[string]interface{}, f string, a ...interface{}) { + logger.WithFields(attrs).Error(f, a...) + } +} + +func onReloadPrinter(output io.Writer) func(time.Duration, error) { + return func(d time.Duration, err error) { + if err != nil { + fmt.Fprintf(output, "\n# reload error (took %v): %v", d, err) + } else { + fmt.Fprintf(output, "\n# reloaded files (took %v)", d) + } + } +} + +func generateInstanceID() (string, error) { + return uuid.New(rand.Reader) +} + +func generateDecisionID() string { + id, err := uuid.New(rand.Reader) + if err != nil { + return "" + } + return id +} + +func verifyAuthorizationPolicySchema(m *plugins.Manager) error { + authorizationDecisionRef, err := ref.ParseDataPath(*m.Config.DefaultAuthorizationDecision) + if err != nil { + return err + } + + return compiler.VerifyAuthorizationPolicySchema(m.GetCompiler(), authorizationDecisionRef) +} + +func init() { + registeredPlugins = make(map[string]plugins.Factory) +} diff --git a/vendor/github.com/open-policy-agent/opa/schemas/authorizationPolicy.json b/vendor/github.com/open-policy-agent/opa/v1/schemas/authorizationPolicy.json similarity index 100% rename from vendor/github.com/open-policy-agent/opa/schemas/authorizationPolicy.json rename to vendor/github.com/open-policy-agent/opa/v1/schemas/authorizationPolicy.json diff --git a/vendor/github.com/open-policy-agent/opa/schemas/schemas.go b/vendor/github.com/open-policy-agent/opa/v1/schemas/schemas.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/schemas/schemas.go rename to vendor/github.com/open-policy-agent/opa/v1/schemas/schemas.go diff --git a/vendor/github.com/open-policy-agent/opa/sdk/RawMapper.go b/vendor/github.com/open-policy-agent/opa/v1/sdk/RawMapper.go similarity index 85% rename from vendor/github.com/open-policy-agent/opa/sdk/RawMapper.go rename to vendor/github.com/open-policy-agent/opa/v1/sdk/RawMapper.go index 14221e41c..b0f6f1b0d 100644 --- a/vendor/github.com/open-policy-agent/opa/sdk/RawMapper.go +++ b/vendor/github.com/open-policy-agent/opa/v1/sdk/RawMapper.go @@ -1,7 +1,7 @@ package sdk import ( - "github.com/open-policy-agent/opa/rego" + "github.com/open-policy-agent/opa/v1/rego" ) type RawMapper struct { diff --git a/vendor/github.com/open-policy-agent/opa/sdk/opa.go b/vendor/github.com/open-policy-agent/opa/v1/sdk/opa.go similarity index 96% rename from vendor/github.com/open-policy-agent/opa/sdk/opa.go rename to vendor/github.com/open-policy-agent/opa/v1/sdk/opa.go index 59ebe7466..e309a8365 100644 --- a/vendor/github.com/open-policy-agent/opa/sdk/opa.go +++ b/vendor/github.com/open-policy-agent/opa/v1/sdk/opa.go @@ -13,27 +13,27 @@ import ( "sync" "time" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" - "github.com/open-policy-agent/opa/hooks" "github.com/open-policy-agent/opa/internal/ref" "github.com/open-policy-agent/opa/internal/runtime" "github.com/open-policy-agent/opa/internal/uuid" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/plugins/discovery" - "github.com/open-policy-agent/opa/plugins/logs" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/server" - "github.com/open-policy-agent/opa/server/types" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/print" - "github.com/open-policy-agent/opa/util" - "github.com/open-policy-agent/opa/version" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/hooks" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/plugins/discovery" + "github.com/open-policy-agent/opa/v1/plugins/logs" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/server" + "github.com/open-policy-agent/opa/v1/server/types" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" + "github.com/open-policy-agent/opa/v1/util" + "github.com/open-policy-agent/opa/v1/version" ) // OPA represents an instance of the policy engine. OPA can be started with diff --git a/vendor/github.com/open-policy-agent/opa/sdk/options.go b/vendor/github.com/open-policy-agent/opa/v1/sdk/options.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/sdk/options.go rename to vendor/github.com/open-policy-agent/opa/v1/sdk/options.go index 112093ca8..a702d3e1b 100644 --- a/vendor/github.com/open-policy-agent/opa/sdk/options.go +++ b/vendor/github.com/open-policy-agent/opa/v1/sdk/options.go @@ -8,14 +8,14 @@ import ( "fmt" "io" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" "github.com/sirupsen/logrus" - "github.com/open-policy-agent/opa/hooks" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/plugins" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/inmem" + "github.com/open-policy-agent/opa/v1/hooks" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/plugins" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/inmem" ) // Options contains parameters to setup and configure OPA. diff --git a/vendor/github.com/open-policy-agent/opa/server/authorizer/authorizer.go b/vendor/github.com/open-policy-agent/opa/v1/server/authorizer/authorizer.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/server/authorizer/authorizer.go rename to vendor/github.com/open-policy-agent/opa/v1/server/authorizer/authorizer.go index 4240f4037..f3a0ea281 100644 --- a/vendor/github.com/open-policy-agent/opa/server/authorizer/authorizer.go +++ b/vendor/github.com/open-policy-agent/opa/v1/server/authorizer/authorizer.go @@ -11,15 +11,15 @@ import ( "net/url" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/server/identifier" - "github.com/open-policy-agent/opa/server/types" - "github.com/open-policy-agent/opa/server/writer" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/print" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/server/identifier" + "github.com/open-policy-agent/opa/v1/server/types" + "github.com/open-policy-agent/opa/v1/server/writer" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" + "github.com/open-policy-agent/opa/v1/util" ) // Basic provides policy-based authorization over incoming requests. @@ -146,7 +146,7 @@ func (h *Basic) ServeHTTP(w http.ResponseWriter, r *http.Request) { if reason, ok := allowed["reason"]; ok { message, ok := reason.(string) if ok { - writer.Error(w, http.StatusUnauthorized, types.NewErrorV1(types.CodeUnauthorized, message)) + writer.Error(w, http.StatusUnauthorized, types.NewErrorV1(types.CodeUnauthorized, message)) //nolint:govet return } } diff --git a/vendor/github.com/open-policy-agent/opa/v1/server/buffer.go b/vendor/github.com/open-policy-agent/opa/v1/server/buffer.go new file mode 100644 index 000000000..8c805a31c --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/server/buffer.go @@ -0,0 +1,44 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package server + +import ( + "time" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown" +) + +// Info contains information describing a policy decision. +type Info struct { + Txn storage.Transaction + Revision string // Deprecated: Use `Bundles` instead + Bundles map[string]BundleInfo + DecisionID string + TraceID string + SpanID string + RemoteAddr string + HTTPRequestContext logging.HTTPRequestContext + Query string + Path string + Timestamp time.Time + Input *interface{} + InputAST ast.Value + Results *interface{} + MappedResults *interface{} + NDBuiltinCache *interface{} + Error error + Metrics metrics.Metrics + Trace []*topdown.Event + RequestID uint64 +} + +// BundleInfo contains information describing a bundle. +type BundleInfo struct { + Revision string +} diff --git a/vendor/github.com/open-policy-agent/opa/server/cache.go b/vendor/github.com/open-policy-agent/opa/v1/server/cache.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/server/cache.go rename to vendor/github.com/open-policy-agent/opa/v1/server/cache.go diff --git a/vendor/github.com/open-policy-agent/opa/server/certs.go b/vendor/github.com/open-policy-agent/opa/v1/server/certs.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/server/certs.go rename to vendor/github.com/open-policy-agent/opa/v1/server/certs.go index a02d74729..f0889b21c 100644 --- a/vendor/github.com/open-policy-agent/opa/server/certs.go +++ b/vendor/github.com/open-policy-agent/opa/v1/server/certs.go @@ -18,7 +18,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/open-policy-agent/opa/internal/pathwatcher" - "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/v1/logging" ) func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { diff --git a/vendor/github.com/open-policy-agent/opa/v1/server/doc.go b/vendor/github.com/open-policy-agent/opa/v1/server/doc.go new file mode 100644 index 000000000..378194cd4 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/server/doc.go @@ -0,0 +1,6 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package server contains the policy engine's server handlers. +package server diff --git a/vendor/github.com/open-policy-agent/opa/v1/server/features.go b/vendor/github.com/open-policy-agent/opa/v1/server/features.go new file mode 100644 index 000000000..3b0153722 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/server/features.go @@ -0,0 +1,10 @@ +// Copyright 2021 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +//go:build opa_wasm +// +build opa_wasm + +package server + +import _ "github.com/open-policy-agent/opa/v1/features/wasm" diff --git a/vendor/github.com/open-policy-agent/opa/server/handlers/compress.go b/vendor/github.com/open-policy-agent/opa/v1/server/handlers/compress.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/server/handlers/compress.go rename to vendor/github.com/open-policy-agent/opa/v1/server/handlers/compress.go diff --git a/vendor/github.com/open-policy-agent/opa/server/handlers/decoding.go b/vendor/github.com/open-policy-agent/opa/v1/server/handlers/decoding.go similarity index 92% rename from vendor/github.com/open-policy-agent/opa/server/handlers/decoding.go rename to vendor/github.com/open-policy-agent/opa/v1/server/handlers/decoding.go index 412d5975e..3dae717fd 100644 --- a/vendor/github.com/open-policy-agent/opa/server/handlers/decoding.go +++ b/vendor/github.com/open-policy-agent/opa/v1/server/handlers/decoding.go @@ -4,9 +4,9 @@ import ( "net/http" "strings" - "github.com/open-policy-agent/opa/server/types" - "github.com/open-policy-agent/opa/server/writer" - util_decoding "github.com/open-policy-agent/opa/util/decoding" + "github.com/open-policy-agent/opa/v1/server/types" + "github.com/open-policy-agent/opa/v1/server/writer" + util_decoding "github.com/open-policy-agent/opa/v1/util/decoding" ) // This handler provides hard limits on the size of the request body, for both diff --git a/vendor/github.com/open-policy-agent/opa/server/identifier/certs.go b/vendor/github.com/open-policy-agent/opa/v1/server/identifier/certs.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/server/identifier/certs.go rename to vendor/github.com/open-policy-agent/opa/v1/server/identifier/certs.go diff --git a/vendor/github.com/open-policy-agent/opa/server/identifier/identifier.go b/vendor/github.com/open-policy-agent/opa/v1/server/identifier/identifier.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/server/identifier/identifier.go rename to vendor/github.com/open-policy-agent/opa/v1/server/identifier/identifier.go diff --git a/vendor/github.com/open-policy-agent/opa/server/identifier/tls.go b/vendor/github.com/open-policy-agent/opa/v1/server/identifier/tls.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/server/identifier/tls.go rename to vendor/github.com/open-policy-agent/opa/v1/server/identifier/tls.go diff --git a/vendor/github.com/open-policy-agent/opa/server/identifier/token.go b/vendor/github.com/open-policy-agent/opa/v1/server/identifier/token.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/server/identifier/token.go rename to vendor/github.com/open-policy-agent/opa/v1/server/identifier/token.go diff --git a/vendor/github.com/open-policy-agent/opa/v1/server/server.go b/vendor/github.com/open-policy-agent/opa/v1/server/server.go new file mode 100644 index 000000000..44bf4852b --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/server/server.go @@ -0,0 +1,3068 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package server + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "html/template" + "io" + "net" + "net/http" + "net/http/pprof" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" + + serverDecodingPlugin "github.com/open-policy-agent/opa/v1/plugins/server/decoding" + serverEncodingPlugin "github.com/open-policy-agent/opa/v1/plugins/server/encoding" + + "github.com/gorilla/mux" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + + "github.com/open-policy-agent/opa/internal/json/patch" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/plugins" + bundlePlugin "github.com/open-policy-agent/opa/v1/plugins/bundle" + "github.com/open-policy-agent/opa/v1/plugins/status" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/server/authorizer" + "github.com/open-policy-agent/opa/v1/server/handlers" + "github.com/open-policy-agent/opa/v1/server/identifier" + "github.com/open-policy-agent/opa/v1/server/types" + "github.com/open-policy-agent/opa/v1/server/writer" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + iCache "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/lineage" + "github.com/open-policy-agent/opa/v1/tracing" + "github.com/open-policy-agent/opa/v1/util" + "github.com/open-policy-agent/opa/v1/version" +) + +// AuthenticationScheme enumerates the supported authentication schemes. The +// authentication scheme determines how client identities are established. +type AuthenticationScheme int + +// Set of supported authentication schemes. +const ( + AuthenticationOff AuthenticationScheme = iota + AuthenticationToken + AuthenticationTLS +) + +var supportedTLSVersions = []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} + +// AuthorizationScheme enumerates the supported authorization schemes. The authorization +// scheme determines how access to OPA is controlled. +type AuthorizationScheme int + +// Set of supported authorization schemes. +const ( + AuthorizationOff AuthorizationScheme = iota + AuthorizationBasic +) + +const defaultMinTLSVersion = tls.VersionTLS12 + +// Set of handlers for use in the "handler" dimension of the duration metric. +const ( + PromHandlerV0Data = "v0/data" + PromHandlerV1Data = "v1/data" + PromHandlerV1Query = "v1/query" + PromHandlerV1Policies = "v1/policies" + PromHandlerV1Compile = "v1/compile" + PromHandlerV1Config = "v1/config" + PromHandlerV1Status = "v1/status" + PromHandlerIndex = "index" + PromHandlerCatch = "catchall" + PromHandlerHealth = "health" + PromHandlerAPIAuthz = "authz" +) + +const pqMaxCacheSize = 100 + +// OpenTelemetry attributes +const otelDecisionIDAttr = "opa.decision_id" + +// map of unsafe builtins +var unsafeBuiltinsMap = map[string]struct{}{ast.HTTPSend.Name: {}} + +// Server represents an instance of OPA running in server mode. +type Server struct { + Handler http.Handler + DiagnosticHandler http.Handler + + router *mux.Router + addrs []string + diagAddrs []string + h2cEnabled bool + authentication AuthenticationScheme + authorization AuthorizationScheme + cert *tls.Certificate + tlsConfigMtx sync.RWMutex + certFile string + certFileHash []byte + certKeyFile string + certKeyFileHash []byte + certRefresh time.Duration + certPool *x509.CertPool + certPoolFile string + certPoolFileHash []byte + minTLSVersion uint16 + mtx sync.RWMutex + partials map[string]rego.PartialResult + preparedEvalQueries *cache + store storage.Store + manager *plugins.Manager + decisionIDFactory func() string + logger func(context.Context, *Info) error + errLimit int + pprofEnabled bool + runtime *ast.Term + httpListeners []httpListener + metrics Metrics + defaultDecisionPath string + interQueryBuiltinCache iCache.InterQueryCache + interQueryBuiltinValueCache iCache.InterQueryValueCache + allPluginsOkOnce bool + distributedTracingOpts tracing.Options + ndbCacheEnabled bool + unixSocketPerm *string + cipherSuites *[]uint16 +} + +// Metrics defines the interface that the server requires for recording HTTP +// handler metrics. +type Metrics interface { + RegisterEndpoints(registrar func(path, method string, handler http.Handler)) + InstrumentHandler(handler http.Handler, label string) http.Handler +} + +// TLSConfig represents the TLS configuration for the server. +// This configuration is used to configure file watchers to reload each file as it +// changes on disk. +type TLSConfig struct { + // CertFile is the path to the server's serving certificate file. + CertFile string + + // KeyFile is the path to the server's key file, completing the key pair for the + // CertFile certificate. + KeyFile string + + // CertPoolFile is the path to the CA cert pool file. The contents of this file will be + // reloaded when the file changes on disk and used in as trusted client CAs in the TLS config + // for new connections to the server. + CertPoolFile string +} + +// Loop will contain all the calls from the server that we'll be listening on. +type Loop func() error + +// New returns a new Server. +func New() *Server { + s := Server{} + return &s +} + +// Init initializes the server. This function MUST be called before starting any loops +// from s.Listeners(). +func (s *Server) Init(ctx context.Context) (*Server, error) { + s.initRouters(ctx) + + txn, err := s.store.NewTransaction(ctx, storage.WriteParams) + if err != nil { + return nil, err + } + + // Register triggers so that if runtime reloads the policies, the + // server sees the change. + config := storage.TriggerConfig{ + OnCommit: s.reload, + } + if _, err := s.store.Register(ctx, txn, config); err != nil { + s.store.Abort(ctx, txn) + return nil, err + } + + s.partials = map[string]rego.PartialResult{} + s.preparedEvalQueries = newCache(pqMaxCacheSize) + s.defaultDecisionPath = s.generateDefaultDecisionPath() + s.manager.RegisterNDCacheTrigger(s.updateNDCache) + + s.Handler = s.initHandlerAuthn(s.Handler) + + // compression handler + s.Handler, err = s.initHandlerCompression(s.Handler) + if err != nil { + return nil, err + } + s.DiagnosticHandler = s.initHandlerAuthn(s.DiagnosticHandler) + + s.Handler, err = s.initHandlerDecodingLimits(s.Handler) + if err != nil { + return nil, err + } + + return s, s.store.Commit(ctx, txn) +} + +// Shutdown will attempt to gracefully shutdown each of the http servers +// currently in use by the OPA Server. If any exceed the deadline specified +// by the context an error will be returned. +func (s *Server) Shutdown(ctx context.Context) error { + errChan := make(chan error) + for _, srvr := range s.httpListeners { + go func(s httpListener) { + errChan <- s.Shutdown(ctx) + }(srvr) + } + // wait until each server has finished shutting down + var errorList []error + for i := 0; i < len(s.httpListeners); i++ { + err := <-errChan + if err != nil { + errorList = append(errorList, err) + } + } + + if len(errorList) > 0 { + errMsg := "error while shutting down: " + for i, err := range errorList { + errMsg += fmt.Sprintf("(%d) %s. ", i, err.Error()) + } + return errors.New(errMsg) + } + return nil +} + +// WithAddresses sets the listening addresses that the server will bind to. +func (s *Server) WithAddresses(addrs []string) *Server { + s.addrs = addrs + return s +} + +// WithDiagnosticAddresses sets the listening addresses that the server will +// bind to and *only* serve read-only diagnostic API's. +func (s *Server) WithDiagnosticAddresses(addrs []string) *Server { + s.diagAddrs = addrs + return s +} + +// WithAuthentication sets authentication scheme to use on the server. +func (s *Server) WithAuthentication(scheme AuthenticationScheme) *Server { + s.authentication = scheme + return s +} + +// WithAuthorization sets authorization scheme to use on the server. +func (s *Server) WithAuthorization(scheme AuthorizationScheme) *Server { + s.authorization = scheme + return s +} + +// WithCertificate sets the server-side certificate that the server will use. +func (s *Server) WithCertificate(cert *tls.Certificate) *Server { + s.cert = cert + return s +} + +// WithCertificatePaths sets the server-side certificate and keyfile paths +// that the server will periodically check for changes, and reload if necessary. +func (s *Server) WithCertificatePaths(certFile, keyFile string, refresh time.Duration) *Server { + s.certFile = certFile + s.certKeyFile = keyFile + s.certRefresh = refresh + return s +} + +// WithCertPool sets the server-side cert pool that the server will use. +func (s *Server) WithCertPool(pool *x509.CertPool) *Server { + s.certPool = pool + return s +} + +// WithTLSConfig sets the TLS configuration used by the server. +func (s *Server) WithTLSConfig(tlsConfig *TLSConfig) *Server { + s.certFile = tlsConfig.CertFile + s.certKeyFile = tlsConfig.KeyFile + s.certPoolFile = tlsConfig.CertPoolFile + return s +} + +// WithCertRefresh sets the period on which certs, keys and cert pools are reloaded from disk. +func (s *Server) WithCertRefresh(refresh time.Duration) *Server { + s.certRefresh = refresh + return s +} + +// WithStore sets the storage used by the server. +func (s *Server) WithStore(store storage.Store) *Server { + s.store = store + return s +} + +// WithMetrics sets the metrics provider used by the server. +func (s *Server) WithMetrics(m Metrics) *Server { + s.metrics = m + return s +} + +// WithManager sets the plugins manager used by the server. +func (s *Server) WithManager(manager *plugins.Manager) *Server { + s.manager = manager + return s +} + +// WithCompilerErrorLimit sets the limit on the number of compiler errors the server will +// allow. +func (s *Server) WithCompilerErrorLimit(limit int) *Server { + s.errLimit = limit + return s +} + +// WithPprofEnabled sets whether pprof endpoints are enabled +func (s *Server) WithPprofEnabled(pprofEnabled bool) *Server { + s.pprofEnabled = pprofEnabled + return s +} + +// WithH2CEnabled sets whether h2c ("HTTP/2 cleartext") is enabled for the http listener +func (s *Server) WithH2CEnabled(enabled bool) *Server { + s.h2cEnabled = enabled + return s +} + +// WithDecisionLogger sets the decision logger used by the +// server. DEPRECATED. Use WithDecisionLoggerWithErr instead. +func (s *Server) WithDecisionLogger(logger func(context.Context, *Info)) *Server { + s.logger = func(ctx context.Context, info *Info) error { + logger(ctx, info) + return nil + } + return s +} + +// WithDecisionLoggerWithErr sets the decision logger used by the server. +func (s *Server) WithDecisionLoggerWithErr(logger func(context.Context, *Info) error) *Server { + s.logger = logger + return s +} + +// WithDecisionIDFactory sets a function on the server to generate decision IDs. +func (s *Server) WithDecisionIDFactory(f func() string) *Server { + s.decisionIDFactory = f + return s +} + +// WithRuntime sets the runtime data to provide to the evaluation engine. +func (s *Server) WithRuntime(term *ast.Term) *Server { + s.runtime = term + return s +} + +// WithRouter sets the mux.Router to attach OPA's HTTP API routes onto. If a +// router is not supplied, the server will create it's own. +func (s *Server) WithRouter(router *mux.Router) *Server { + s.router = router + return s +} + +func (s *Server) WithMinTLSVersion(minTLSVersion uint16) *Server { + if isMinTLSVersionSupported(minTLSVersion) { + s.minTLSVersion = minTLSVersion + } else { + s.minTLSVersion = defaultMinTLSVersion + } + return s +} + +// WithDistributedTracingOpts sets the options to be used by distributed tracing. +func (s *Server) WithDistributedTracingOpts(opts tracing.Options) *Server { + s.distributedTracingOpts = opts + return s +} + +// WithNDBCacheEnabled sets whether the ND builtins cache is to be used. +func (s *Server) WithNDBCacheEnabled(ndbCacheEnabled bool) *Server { + s.ndbCacheEnabled = ndbCacheEnabled + return s +} + +// WithCipherSuites sets the list of enabled TLS 1.0–1.2 cipher suites. +func (s *Server) WithCipherSuites(cipherSuites *[]uint16) *Server { + s.cipherSuites = cipherSuites + return s +} + +// WithUnixSocketPermission sets the permission for the Unix domain socket if used to listen for +// incoming connections. Applies to the sockets the server is listening on including diagnostic API's. +func (s *Server) WithUnixSocketPermission(unixSocketPerm *string) *Server { + s.unixSocketPerm = unixSocketPerm + return s +} + +// Listeners returns functions that listen and serve connections. +func (s *Server) Listeners() ([]Loop, error) { + loops := []Loop{} + + handlerBindings := map[httpListenerType]struct { + addrs []string + handler http.Handler + }{ + defaultListenerType: {s.addrs, s.Handler}, + diagnosticListenerType: {s.diagAddrs, s.DiagnosticHandler}, + } + + for t, binding := range handlerBindings { + for _, addr := range binding.addrs { + l, listener, err := s.getListener(addr, binding.handler, t) + if err != nil { + return nil, err + } + s.httpListeners = append(s.httpListeners, listener) + loops = append(loops, l...) + } + } + + return loops, nil +} + +// Addrs returns a list of addresses that the server is listening on. +// If the server hasn't been started it will not return an address. +func (s *Server) Addrs() []string { + return s.addrsForType(defaultListenerType) +} + +// DiagnosticAddrs returns a list of addresses that the server is listening on +// for the read-only diagnostic API's (eg /health, /metrics, etc) +// If the server hasn't been started it will not return an address. +func (s *Server) DiagnosticAddrs() []string { + return s.addrsForType(diagnosticListenerType) +} + +func (s *Server) addrsForType(t httpListenerType) []string { + var addrs []string + for _, l := range s.httpListeners { + a := l.Addr() + if a != "" && l.Type() == t { + addrs = append(addrs, a) + } + } + return addrs +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { + tc, err := ln.AcceptTCP() + if err != nil { + return nil, err + } + err = tc.SetKeepAlive(true) + if err != nil { + return nil, err + } + err = tc.SetKeepAlivePeriod(3 * time.Minute) + if err != nil { + return nil, err + } + return tc, nil +} + +type httpListenerType int + +const ( + defaultListenerType httpListenerType = iota + diagnosticListenerType +) + +type httpListener interface { + Addr() string + ListenAndServe() error + ListenAndServeTLS(certFile, keyFile string) error + Shutdown(context.Context) error + Type() httpListenerType +} + +// baseHTTPListener is just a wrapper around http.Server +type baseHTTPListener struct { + s *http.Server + l net.Listener + t httpListenerType + addr string + addrMtx sync.RWMutex +} + +var _ httpListener = (*baseHTTPListener)(nil) + +func newHTTPListener(srvr *http.Server, t httpListenerType) httpListener { + return &baseHTTPListener{s: srvr, t: t} +} + +func newHTTPUnixSocketListener(srvr *http.Server, l net.Listener, t httpListenerType) httpListener { + return &baseHTTPListener{s: srvr, l: l, t: t} +} + +func (b *baseHTTPListener) ListenAndServe() error { + addr := b.s.Addr + if addr == "" { + addr = ":http" + } + var err error + b.l, err = net.Listen("tcp", addr) + if err != nil { + return err + } + + b.initAddr() + + return b.s.Serve(tcpKeepAliveListener{b.l.(*net.TCPListener)}) +} + +func (b *baseHTTPListener) initAddr() { + b.addrMtx.Lock() + if addr := b.l.(*net.TCPListener).Addr(); addr != nil { + b.addr = addr.String() + } + b.addrMtx.Unlock() +} + +func (b *baseHTTPListener) Addr() string { + b.addrMtx.Lock() + defer b.addrMtx.Unlock() + return b.addr +} + +func (b *baseHTTPListener) ListenAndServeTLS(certFile, keyFile string) error { + addr := b.s.Addr + if addr == "" { + addr = ":https" + } + + var err error + b.l, err = net.Listen("tcp", addr) + if err != nil { + return err + } + + b.initAddr() + + defer b.l.Close() + + return b.s.ServeTLS(tcpKeepAliveListener{b.l.(*net.TCPListener)}, certFile, keyFile) +} + +func (b *baseHTTPListener) Shutdown(ctx context.Context) error { + return b.s.Shutdown(ctx) +} + +func (b *baseHTTPListener) Type() httpListenerType { + return b.t +} + +func isMinTLSVersionSupported(TLSVersion uint16) bool { + for _, version := range supportedTLSVersions { + if TLSVersion == version { + return true + } + } + return false +} + +func (s *Server) getListener(addr string, h http.Handler, t httpListenerType) ([]Loop, httpListener, error) { + parsedURL, err := parseURL(addr, s.cert != nil) + if err != nil { + return nil, nil, err + } + + var loops []Loop + var loop Loop + var listener httpListener + switch parsedURL.Scheme { + case "unix": + loop, listener, err = s.getListenerForUNIXSocket(parsedURL, h, t) + loops = []Loop{loop} + case "http": + loop, listener, err = s.getListenerForHTTPServer(parsedURL, h, t) + loops = []Loop{loop} + case "https": + loop, listener, err = s.getListenerForHTTPSServer(parsedURL, h, t) + logger := s.manager.Logger().WithFields(map[string]interface{}{ + "cert-file": s.certFile, + "cert-key-file": s.certKeyFile, + }) + + // if a manual cert refresh period has been set, then use the polling behavior, + // otherwise use the fsnotify default behavior + if s.certRefresh > 0 { + loops = []Loop{loop, s.certLoopPolling(logger)} + } else if s.certFile != "" || s.certPoolFile != "" { + loops = []Loop{loop, s.certLoopNotify(logger)} + } + default: + err = fmt.Errorf("invalid url scheme %q", parsedURL.Scheme) + } + + return loops, listener, err +} + +func (s *Server) getListenerForHTTPServer(u *url.URL, h http.Handler, t httpListenerType) (Loop, httpListener, error) { + if s.h2cEnabled { + h2s := &http2.Server{} + h = h2c.NewHandler(h, h2s) + } + h1s := http.Server{ + Addr: u.Host, + Handler: h, + } + + l := newHTTPListener(&h1s, t) + + return l.ListenAndServe, l, nil +} + +func (s *Server) getListenerForHTTPSServer(u *url.URL, h http.Handler, t httpListenerType) (Loop, httpListener, error) { + + if s.cert == nil { + return nil, nil, fmt.Errorf("TLS certificate required but not supplied") + } + + tlsConfig := tls.Config{ + GetCertificate: s.getCertificate, + // GetConfigForClient is used to ensure that a fresh config is provided containing the latest cert pool. + // This is not required, but appears to be how connect time updates config should be done: + // https://github.com/golang/go/issues/16066#issuecomment-250606132 + GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) { + s.tlsConfigMtx.Lock() + defer s.tlsConfigMtx.Unlock() + + cfg := &tls.Config{ + GetCertificate: s.getCertificate, + ClientCAs: s.certPool, + } + + if s.authentication == AuthenticationTLS { + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + + if s.minTLSVersion != 0 { + cfg.MinVersion = s.minTLSVersion + } else { + cfg.MinVersion = defaultMinTLSVersion + } + + if s.cipherSuites != nil { + cfg.CipherSuites = *s.cipherSuites + } + + return cfg, nil + }, + } + + httpsServer := http.Server{ + Addr: u.Host, + Handler: h, + TLSConfig: &tlsConfig, + } + + l := newHTTPListener(&httpsServer, t) + + httpsLoop := func() error { return l.ListenAndServeTLS("", "") } + + return httpsLoop, l, nil +} + +func (s *Server) getListenerForUNIXSocket(u *url.URL, h http.Handler, t httpListenerType) (Loop, httpListener, error) { + socketPath := u.Host + u.Path + + // Recover @ prefix for abstract Unix sockets. + if strings.HasPrefix(u.String(), u.Scheme+"://@") { + socketPath = "@" + socketPath + } else { + // Remove domain socket file in case it already exists. + os.Remove(socketPath) + } + + domainSocketServer := http.Server{Handler: h} + unixListener, err := net.Listen("unix", socketPath) + if err != nil { + return nil, nil, err + } + + if s.unixSocketPerm != nil { + modeVal, err := strconv.ParseUint(*s.unixSocketPerm, 8, 32) + if err != nil { + return nil, nil, err + } + + if err := os.Chmod(socketPath, os.FileMode(modeVal)); err != nil { + return nil, nil, err + } + } + + l := newHTTPUnixSocketListener(&domainSocketServer, unixListener, t) + + domainSocketLoop := func() error { return domainSocketServer.Serve(unixListener) } + return domainSocketLoop, l, nil +} + +func (s *Server) initHandlerAuthn(handler http.Handler) http.Handler { + switch s.authentication { + case AuthenticationToken: + handler = identifier.NewTokenBased(handler) + case AuthenticationTLS: + handler = identifier.NewTLSBased(handler) + } + + return handler +} + +func (s *Server) initHandlerAuthz(handler http.Handler) http.Handler { + switch s.authorization { + case AuthorizationBasic: + handler = authorizer.NewBasic( + handler, + s.getCompiler, + s.store, + authorizer.Runtime(s.runtime), + authorizer.Decision(s.manager.Config.DefaultAuthorizationDecisionRef), + authorizer.PrintHook(s.manager.PrintHook()), + authorizer.EnablePrintStatements(s.manager.EnablePrintStatements()), + authorizer.InterQueryCache(s.interQueryBuiltinCache), + authorizer.InterQueryValueCache(s.interQueryBuiltinValueCache)) + + if s.metrics != nil { + handler = s.instrumentHandler(handler.ServeHTTP, PromHandlerAPIAuthz) + } + } + + return handler +} + +// Enforces request body size limits on incoming requests. For gzipped requests, +// it passes the size limit down the body-reading method via the request +// context. +func (s *Server) initHandlerDecodingLimits(handler http.Handler) (http.Handler, error) { + var decodingRawConfig json.RawMessage + serverConfig := s.manager.Config.Server + if serverConfig != nil { + decodingRawConfig = serverConfig.Decoding + } + decodingConfig, err := serverDecodingPlugin.NewConfigBuilder().WithBytes(decodingRawConfig).Parse() + if err != nil { + return nil, err + } + decodingHandler := handlers.DecodingLimitsHandler(handler, *decodingConfig.MaxLength, *decodingConfig.Gzip.MaxLength) + + return decodingHandler, nil +} + +func (s *Server) initHandlerCompression(handler http.Handler) (http.Handler, error) { + var encodingRawConfig json.RawMessage + serverConfig := s.manager.Config.Server + if serverConfig != nil { + encodingRawConfig = serverConfig.Encoding + } + encodingConfig, err := serverEncodingPlugin.NewConfigBuilder().WithBytes(encodingRawConfig).Parse() + if err != nil { + return nil, err + } + compressHandler := handlers.CompressHandler(handler, *encodingConfig.Gzip.MinLength, *encodingConfig.Gzip.CompressionLevel) + + return compressHandler, nil +} + +func (s *Server) initRouters(ctx context.Context) { + mainRouter := s.router + if mainRouter == nil { + mainRouter = mux.NewRouter() + } + + diagRouter := mux.NewRouter() + + // authorizer, if configured, needs the iCache to be set up already + + cacheConfig := s.manager.InterQueryBuiltinCacheConfig() + + s.interQueryBuiltinCache = iCache.NewInterQueryCacheWithContext(ctx, cacheConfig) + s.interQueryBuiltinValueCache = iCache.NewInterQueryValueCache(ctx, cacheConfig) + + s.manager.RegisterCacheTrigger(s.updateCacheConfig) + + // Add authorization handler. This must come BEFORE authentication handler + // so that the latter can run first. + handlerAuthz := s.initHandlerAuthz(mainRouter) + + handlerAuthzDiag := s.initHandlerAuthz(diagRouter) + + // All routers get the same base configuration *and* diagnostic API's + for _, router := range []*mux.Router{mainRouter, diagRouter} { + router.StrictSlash(true) + router.UseEncodedPath() + router.StrictSlash(true) + + if s.metrics != nil { + s.metrics.RegisterEndpoints(func(path, method string, handler http.Handler) { + router.Handle(path, handler).Methods(method) + }) + } + + router.Handle("/health", s.instrumentHandler(s.unversionedGetHealth, PromHandlerHealth)).Methods(http.MethodGet) + // Use this route to evaluate health policy defined at system.health + // By convention, policy is typically defined at system.health.live and system.health.ready, and is + // evaluated by calling /health/live and /health/ready respectively. + router.Handle("/health/{path:.+}", s.instrumentHandler(s.unversionedGetHealthWithPolicy, PromHandlerHealth)).Methods(http.MethodGet) + } + + if s.pprofEnabled { + mainRouter.HandleFunc("/debug/pprof/", pprof.Index) + mainRouter.Handle("/debug/pprof/allocs", pprof.Handler("allocs")) + mainRouter.Handle("/debug/pprof/block", pprof.Handler("block")) + mainRouter.Handle("/debug/pprof/heap", pprof.Handler("heap")) + mainRouter.Handle("/debug/pprof/mutex", pprof.Handler("mutex")) + mainRouter.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mainRouter.HandleFunc("/debug/pprof/profile", pprof.Profile) + mainRouter.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mainRouter.HandleFunc("/debug/pprof/trace", pprof.Trace) + } + + // Only the main mainRouter gets the OPA API's (data, policies, query, etc) + mainRouter.Handle("/v0/data/{path:.+}", s.instrumentHandler(s.v0DataPost, PromHandlerV0Data)).Methods(http.MethodPost) + mainRouter.Handle("/v0/data", s.instrumentHandler(s.v0DataPost, PromHandlerV0Data)).Methods(http.MethodPost) + mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataDelete, PromHandlerV1Data)).Methods(http.MethodDelete) + mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataPut, PromHandlerV1Data)).Methods(http.MethodPut) + mainRouter.Handle("/v1/data", s.instrumentHandler(s.v1DataPut, PromHandlerV1Data)).Methods(http.MethodPut) + mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataGet, PromHandlerV1Data)).Methods(http.MethodGet) + mainRouter.Handle("/v1/data", s.instrumentHandler(s.v1DataGet, PromHandlerV1Data)).Methods(http.MethodGet) + mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataPatch, PromHandlerV1Data)).Methods(http.MethodPatch) + mainRouter.Handle("/v1/data", s.instrumentHandler(s.v1DataPatch, PromHandlerV1Data)).Methods(http.MethodPatch) + mainRouter.Handle("/v1/data/{path:.+}", s.instrumentHandler(s.v1DataPost, PromHandlerV1Data)).Methods(http.MethodPost) + mainRouter.Handle("/v1/data", s.instrumentHandler(s.v1DataPost, PromHandlerV1Data)).Methods(http.MethodPost) + mainRouter.Handle("/v1/policies", s.instrumentHandler(s.v1PoliciesList, PromHandlerV1Policies)).Methods(http.MethodGet) + mainRouter.Handle("/v1/policies/{path:.+}", s.instrumentHandler(s.v1PoliciesDelete, PromHandlerV1Policies)).Methods(http.MethodDelete) + mainRouter.Handle("/v1/policies/{path:.+}", s.instrumentHandler(s.v1PoliciesGet, PromHandlerV1Policies)).Methods(http.MethodGet) + mainRouter.Handle("/v1/policies/{path:.+}", s.instrumentHandler(s.v1PoliciesPut, PromHandlerV1Policies)).Methods(http.MethodPut) + mainRouter.Handle("/v1/query", s.instrumentHandler(s.v1QueryGet, PromHandlerV1Query)).Methods(http.MethodGet) + mainRouter.Handle("/v1/query", s.instrumentHandler(s.v1QueryPost, PromHandlerV1Query)).Methods(http.MethodPost) + mainRouter.Handle("/v1/compile", s.instrumentHandler(s.v1CompilePost, PromHandlerV1Compile)).Methods(http.MethodPost) + mainRouter.Handle("/v1/config", s.instrumentHandler(s.v1ConfigGet, PromHandlerV1Config)).Methods(http.MethodGet) + mainRouter.Handle("/v1/status", s.instrumentHandler(s.v1StatusGet, PromHandlerV1Status)).Methods(http.MethodGet) + mainRouter.Handle("/", s.instrumentHandler(s.unversionedPost, PromHandlerIndex)).Methods(http.MethodPost) + mainRouter.Handle("/", s.instrumentHandler(s.indexGet, PromHandlerIndex)).Methods(http.MethodGet) + + // These are catch all handlers that respond http.StatusMethodNotAllowed for resources that exist but the method is not allowed + mainRouter.Handle("/v0/data/{path:.*}", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodGet, http.MethodHead, + http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodPatch, http.MethodPut, http.MethodTrace) + mainRouter.Handle("/v0/data", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodGet, http.MethodHead, + http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodPatch, http.MethodPut, + http.MethodTrace) + // v1 Data catch all + mainRouter.Handle("/v1/data/{path:.*}", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, + http.MethodConnect, http.MethodOptions, http.MethodTrace) + mainRouter.Handle("/v1/data", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, + http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodTrace) + // Policies catch all + mainRouter.Handle("/v1/policies", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, + http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodTrace, http.MethodPost, http.MethodPut, + http.MethodPatch) + // Policies (/policies/{path.+} catch all + mainRouter.Handle("/v1/policies/{path:.*}", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, + http.MethodConnect, http.MethodOptions, http.MethodTrace, http.MethodPost) + // Query catch all + mainRouter.Handle("/v1/query/{path:.*}", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, + http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodTrace, http.MethodPost, http.MethodPut, http.MethodPatch) + mainRouter.Handle("/v1/query", s.instrumentHandler(writer.HTTPStatus(http.StatusMethodNotAllowed), PromHandlerCatch)).Methods(http.MethodHead, + http.MethodConnect, http.MethodDelete, http.MethodOptions, http.MethodTrace, http.MethodPut, http.MethodPatch) + + s.Handler = mainRouter + s.DiagnosticHandler = diagRouter + + // Add authorization handler in the end so that it can run first + s.Handler = handlerAuthz + s.DiagnosticHandler = handlerAuthzDiag +} + +func (s *Server) instrumentHandler(handler func(http.ResponseWriter, *http.Request), label string) http.Handler { + var httpHandler http.Handler = http.HandlerFunc(handler) + if len(s.distributedTracingOpts) > 0 { + httpHandler = tracing.NewHandler(httpHandler, label, s.distributedTracingOpts) + } + if s.metrics != nil { + return s.metrics.InstrumentHandler(httpHandler, label) + } + return httpHandler +} + +func (s *Server) execQuery(ctx context.Context, br bundleRevisions, txn storage.Transaction, parsedQuery ast.Body, input ast.Value, rawInput *interface{}, m metrics.Metrics, explainMode types.ExplainModeV1, includeMetrics, includeInstrumentation, pretty bool) (*types.QueryResponseV1, error) { + results := types.QueryResponseV1{} + logger := s.getDecisionLogger(br) + + var buf *topdown.BufferTracer + if explainMode != types.ExplainOffV1 { + buf = topdown.NewBufferTracer() + } + + var ndbCache builtins.NDBCache + if s.ndbCacheEnabled { + ndbCache = builtins.NDBCache{} + } + + opts := []func(*rego.Rego){ + rego.Store(s.store), + rego.Transaction(txn), + rego.Compiler(s.getCompiler()), + rego.ParsedQuery(parsedQuery), + rego.ParsedInput(input), + rego.Metrics(m), + rego.Instrument(includeInstrumentation), + rego.QueryTracer(buf), + rego.Runtime(s.runtime), + rego.UnsafeBuiltins(unsafeBuiltinsMap), + rego.InterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.InterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), + rego.PrintHook(s.manager.PrintHook()), + rego.EnablePrintStatements(s.manager.EnablePrintStatements()), + rego.DistributedTracingOpts(s.distributedTracingOpts), + rego.NDBuiltinCache(ndbCache), + } + + for _, r := range s.manager.GetWasmResolvers() { + for _, entrypoint := range r.Entrypoints() { + opts = append(opts, rego.Resolver(entrypoint, r)) + } + } + + rego := rego.New(opts...) + + output, err := rego.Eval(ctx) + if err != nil { + _ = logger.Log(ctx, txn, "", parsedQuery.String(), rawInput, input, nil, ndbCache, err, m) + return nil, err + } + + for _, result := range output { + results.Result = append(results.Result, result.Bindings.WithoutWildcards()) + } + + if includeMetrics || includeInstrumentation { + results.Metrics = m.All() + } + + if explainMode != types.ExplainOffV1 { + results.Explanation = s.getExplainResponse(explainMode, *buf, pretty) + } + + var x interface{} = results.Result + if err := logger.Log(ctx, txn, "", parsedQuery.String(), rawInput, input, &x, ndbCache, nil, m); err != nil { + return nil, err + } + return &results, nil +} + +func (s *Server) indexGet(w http.ResponseWriter, _ *http.Request) { + _ = indexHTML.Execute(w, struct { + Version string + BuildCommit string + BuildTimestamp string + BuildHostname string + }{ + Version: version.Version, + BuildCommit: version.Vcs, + BuildTimestamp: version.Timestamp, + BuildHostname: version.Hostname, + }) +} + +type bundleRevisions struct { + LegacyRevision string + Revisions map[string]string +} + +func getRevisions(ctx context.Context, store storage.Store, txn storage.Transaction) (bundleRevisions, error) { + + var err error + var br bundleRevisions + br.Revisions = map[string]string{} + + // Check if we still have a legacy bundle manifest in the store + br.LegacyRevision, err = bundle.LegacyReadRevisionFromStore(ctx, store, txn) + if err != nil && !storage.IsNotFound(err) { + return br, err + } + + // read all bundle revisions from storage (if any exist) + names, err := bundle.ReadBundleNamesFromStore(ctx, store, txn) + if err != nil && !storage.IsNotFound(err) { + return br, err + } + + for _, name := range names { + r, err := bundle.ReadBundleRevisionFromStore(ctx, store, txn, name) + if err != nil && !storage.IsNotFound(err) { + return br, err + } + br.Revisions[name] = r + } + + return br, nil +} + +func (s *Server) reload(context.Context, storage.Transaction, storage.TriggerEvent) { + + // NOTE(tsandall): We currently rely on the storage txn to provide + // critical sections in the server. + // + // If you modify this function to change any other state on the server, you must + // review the other places in the server where that state is accessed to avoid data + // races--the state must be accessed _after_ a txn has been opened. + + // reset some cached info + s.partials = map[string]rego.PartialResult{} + s.preparedEvalQueries = newCache(pqMaxCacheSize) + s.defaultDecisionPath = s.generateDefaultDecisionPath() +} + +func (s *Server) unversionedPost(w http.ResponseWriter, r *http.Request) { + s.v0QueryPath(w, r, "", true) +} + +func (s *Server) v0DataPost(w http.ResponseWriter, r *http.Request) { + s.v0QueryPath(w, r, mux.Vars(r)["path"], false) +} + +func (s *Server) v0QueryPath(w http.ResponseWriter, r *http.Request, urlPath string, useDefaultDecisionPath bool) { + m := metrics.New() + m.Timer(metrics.ServerHandler).Start() + + decisionID := s.generateDecisionID() + ctx := logging.WithDecisionID(r.Context(), decisionID) + annotateSpan(ctx, decisionID) + + input, goInput, err := readInputV0(r) + if err != nil { + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, fmt.Errorf("unexpected parse error for input: %w", err)) + return + } + + // Prepare for query. + txn, err := s.store.NewTransaction(ctx) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + defer s.store.Abort(ctx, txn) + + br, err := getRevisions(ctx, s.store, txn) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + if useDefaultDecisionPath { + urlPath = s.generateDefaultDecisionPath() + } + + logger := s.getDecisionLogger(br) + + var ndbCache builtins.NDBCache + if s.ndbCacheEnabled { + ndbCache = builtins.NDBCache{} + } + + pqID := "v0QueryPath::" + urlPath + preparedQuery, ok := s.getCachedPreparedEvalQuery(pqID, m) + if !ok { + opts := []func(*rego.Rego){ + rego.Compiler(s.getCompiler()), + rego.Store(s.store), + } + + // Set resolvers on the base Rego object to avoid having them get + // re-initialized, and to propagate them to the prepared query. + for _, r := range s.manager.GetWasmResolvers() { + for _, entrypoint := range r.Entrypoints() { + opts = append(opts, rego.Resolver(entrypoint, r)) + } + } + + rego, err := s.makeRego(ctx, false, txn, input, urlPath, m, false, nil, opts) + if err != nil { + _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) + writer.ErrorAuto(w, err) + return + } + + pq, err := rego.PrepareForEval(ctx) + if err != nil { + _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) + writer.ErrorAuto(w, err) + return + } + preparedQuery = &pq + s.preparedEvalQueries.Insert(pqID, preparedQuery) + } + + evalOpts := []rego.EvalOption{ + rego.EvalTransaction(txn), + rego.EvalParsedInput(input), + rego.EvalMetrics(m), + rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), + rego.EvalNDBuiltinCache(ndbCache), + } + + rs, err := preparedQuery.Eval( + ctx, + evalOpts..., + ) + + m.Timer(metrics.ServerHandler).Stop() + + // Handle results. + if err != nil { + _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) + writer.ErrorAuto(w, err) + return + } + + if len(rs) == 0 { + ref := stringPathToDataRef(urlPath) + + var messageType = types.MsgMissingError + if len(s.getCompiler().GetRulesForVirtualDocument(ref)) > 0 { + messageType = types.MsgFoundUndefinedError + } + err := types.NewErrorV1(types.CodeUndefinedDocument, "%v: %v", messageType, ref) + if err := logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m); err != nil { + writer.ErrorAuto(w, err) + return + } + + writer.Error(w, http.StatusNotFound, err) + return + } + err = logger.Log(ctx, txn, urlPath, "", goInput, input, &rs[0].Expressions[0].Value, ndbCache, nil, m) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + writer.JSONOK(w, rs[0].Expressions[0].Value, pretty(r)) +} + +func (s *Server) getCachedPreparedEvalQuery(key string, m metrics.Metrics) (*rego.PreparedEvalQuery, bool) { + pq, ok := s.preparedEvalQueries.Get(key) + m.Counter(metrics.ServerQueryCacheHit) // Creates the counter on the metrics if it doesn't exist, starts at 0 + if ok { + m.Counter(metrics.ServerQueryCacheHit).Incr() // Increment counter on hit + return pq.(*rego.PreparedEvalQuery), true + } + return nil, false +} + +func (s *Server) canEval(ctx context.Context) bool { + // Create very simple query that binds a single variable. + opts := []func(*rego.Rego){ + rego.Compiler(s.getCompiler()), + rego.Store(s.store), + rego.Query("x = 1"), + } + + for _, r := range s.manager.GetWasmResolvers() { + for _, ep := range r.Entrypoints() { + opts = append(opts, rego.Resolver(ep, r)) + } + } + + eval := rego.New(opts...) + // Run evaluation. + rs, err := eval.Eval(ctx) + if err != nil { + return false + } + + v, ok := rs[0].Bindings["x"] + if ok { + jsonNumber, ok := v.(json.Number) + if ok && jsonNumber.String() == "1" { + return true + } + } + return false +} + +func (s *Server) bundlesReady(pluginStatuses map[string]*plugins.Status) bool { + + // Look for a discovery plugin first, if it exists and isn't ready + // then don't bother with the others. + // Note: use "discovery" instead of `discovery.Name` to avoid import + // cycle problems.. + dpStatus, ok := pluginStatuses["discovery"] + if ok && dpStatus != nil && (dpStatus.State != plugins.StateOK) { + return false + } + + // The bundle plugin won't return "OK" until the first activation + // of each configured bundle. + bpStatus, ok := pluginStatuses[bundlePlugin.Name] + if ok && bpStatus != nil && (bpStatus.State != plugins.StateOK) { + return false + } + + return true +} + +func (s *Server) unversionedGetHealth(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + includeBundleStatus := getBoolParam(r.URL, types.ParamBundleActivationV1, true) || + getBoolParam(r.URL, types.ParamBundlesActivationV1, true) + includePluginStatus := getBoolParam(r.URL, types.ParamPluginsV1, true) + excludePlugin := getStringSliceParam(r.URL, types.ParamExcludePluginV1) + excludePluginMap := map[string]struct{}{} + for _, name := range excludePlugin { + excludePluginMap[name] = struct{}{} + } + + // Ensure the server can evaluate a simple query + if !s.canEval(ctx) { + writeHealthResponse(w, errors.New("unable to perform evaluation")) + return + } + + pluginStatuses := s.manager.PluginStatus() + + // Ensure that bundles (if configured, and requested to be included in the result) + // have been activated successfully. This will include discovery bundles as well as + // normal bundles that are configured. + if includeBundleStatus && !s.bundlesReady(pluginStatuses) { + // For backwards compatibility we don't return a payload with statuses for the bundle endpoint + writeHealthResponse(w, errors.New("one or more bundles are not activated")) + return + } + + if includePluginStatus { + // Ensure that all plugins (if requested to be included in the result) have an OK status. + hasErr := false + for name, status := range pluginStatuses { + if _, exclude := excludePluginMap[name]; exclude { + continue + } + if status != nil && status.State != plugins.StateOK { + hasErr = true + break + } + } + if hasErr { + writeHealthResponse(w, errors.New("one or more plugins are not up")) + return + } + } + writeHealthResponse(w, nil) +} + +func (s *Server) unversionedGetHealthWithPolicy(w http.ResponseWriter, r *http.Request) { + pluginStatus := s.manager.PluginStatus() + pluginState := map[string]string{} + + // optimistically assume all plugins are ok + allPluginsOk := true + + // build input document for health check query + input := func() map[string]interface{} { + s.mtx.Lock() + defer s.mtx.Unlock() + + // iterate over plugin status to extract state + for name, status := range pluginStatus { + if status != nil { + pluginState[name] = string(status.State) + // if all plugins have not been in OK state yet, then check to see if plugin state is OKx + if !s.allPluginsOkOnce && status.State != plugins.StateOK { + allPluginsOk = false + } + } + } + // once all plugins are OK, set the allPluginsOkOnce flag to true, indicating that all + // plugins have achieved a "ready" state at least once on the server. + if allPluginsOk { + s.allPluginsOkOnce = true + } + + return map[string]interface{}{ + "plugin_state": pluginState, + "plugins_ready": s.allPluginsOkOnce, + } + }() + + vars := mux.Vars(r) + urlPath := vars["path"] + healthDataPath := fmt.Sprintf("/system/health/%s", urlPath) + healthDataPath = stringPathToDataRef(healthDataPath).String() + + rego := rego.New( + rego.Query(healthDataPath), + rego.Compiler(s.getCompiler()), + rego.Store(s.store), + rego.Input(input), + rego.Runtime(s.runtime), + rego.PrintHook(s.manager.PrintHook()), + ) + + rs, err := rego.Eval(r.Context()) + if err != nil { + writeHealthResponse(w, err) + return + } + + if len(rs) == 0 { + writeHealthResponse(w, fmt.Errorf("health check (%v) was undefined", healthDataPath)) + return + } + + result, ok := rs[0].Expressions[0].Value.(bool) + if ok && result { + writeHealthResponse(w, nil) + return + } + + writeHealthResponse(w, fmt.Errorf("health check (%v) returned unexpected value", healthDataPath)) +} + +func writeHealthResponse(w http.ResponseWriter, err error) { + if err != nil { + writer.JSON(w, http.StatusInternalServerError, types.HealthResponseV1{Error: err.Error()}, false) + return + } + + writer.JSONOK(w, types.HealthResponseV1{}, false) +} + +func (s *Server) v1CompilePost(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + explainMode := getExplain(r.URL.Query()[types.ParamExplainV1], types.ExplainOffV1) + includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) + + m := metrics.New() + m.Timer(metrics.ServerHandler).Start() + m.Timer(metrics.RegoQueryParse).Start() + + // decompress the input if sent as zip + body, err := util.ReadMaybeCompressedBody(r) + if err != nil { + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "could not decompress the body")) + return + } + + request, reqErr := readInputCompilePostV1(body, s.manager.ParserOptions()) + if reqErr != nil { + writer.Error(w, http.StatusBadRequest, reqErr) + return + } + + m.Timer(metrics.RegoQueryParse).Stop() + + c := storage.NewContext().WithMetrics(m) + txn, err := s.store.NewTransaction(ctx, storage.TransactionParams{Context: c}) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + defer s.store.Abort(ctx, txn) + + var buf *topdown.BufferTracer + if explainMode != types.ExplainOffV1 { + buf = topdown.NewBufferTracer() + } + + eval := rego.New( + rego.Compiler(s.getCompiler()), + rego.Store(s.store), + rego.Transaction(txn), + rego.ParsedQuery(request.Query), + rego.ParsedInput(request.Input), + rego.ParsedUnknowns(request.Unknowns), + rego.DisableInlining(request.Options.DisableInlining), + rego.QueryTracer(buf), + rego.Instrument(includeInstrumentation), + rego.Metrics(m), + rego.Runtime(s.runtime), + rego.UnsafeBuiltins(unsafeBuiltinsMap), + rego.InterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.InterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), + rego.PrintHook(s.manager.PrintHook()), + ) + + pq, err := eval.Partial(ctx) + if err != nil { + switch err := err.(type) { + case ast.Errors: + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileModuleError).WithASTErrors(err)) + default: + writer.ErrorAuto(w, err) + } + return + } + + m.Timer(metrics.ServerHandler).Stop() + + result := types.CompileResponseV1{} + + if includeMetrics(r) || includeInstrumentation { + result.Metrics = m.All() + } + + if explainMode != types.ExplainOffV1 { + result.Explanation = s.getExplainResponse(explainMode, *buf, pretty(r)) + } + + var i interface{} = types.PartialEvaluationResultV1{ + Queries: pq.Queries, + Support: pq.Support, + } + + result.Result = &i + + writer.JSONOK(w, result, pretty(r)) +} + +func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) { + m := metrics.New() + + m.Timer(metrics.ServerHandler).Start() + + decisionID := s.generateDecisionID() + ctx := logging.WithDecisionID(r.Context(), decisionID) + annotateSpan(ctx, decisionID) + + vars := mux.Vars(r) + urlPath := vars["path"] + explainMode := getExplain(r.URL.Query()["explain"], types.ExplainOffV1) + includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) + provenance := getBoolParam(r.URL, types.ParamProvenanceV1, true) + strictBuiltinErrors := getBoolParam(r.URL, types.ParamStrictBuiltinErrors, true) + + m.Timer(metrics.RegoInputParse).Start() + + inputs := r.URL.Query()[types.ParamInputV1] + + var input ast.Value + var goInput *interface{} + + if len(inputs) > 0 { + var err error + input, goInput, err = readInputGetV1(inputs[len(inputs)-1]) + if err != nil { + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + } + + m.Timer(metrics.RegoInputParse).Stop() + + // Prepare for query. + c := storage.NewContext().WithMetrics(m) + txn, err := s.store.NewTransaction(ctx, storage.TransactionParams{Context: c}) + if err != nil { + writer.ErrorAuto(w, err) + return + } + defer s.store.Abort(ctx, txn) + + br, err := getRevisions(ctx, s.store, txn) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + logger := s.getDecisionLogger(br) + + var ndbCache builtins.NDBCache + if s.ndbCacheEnabled { + ndbCache = builtins.NDBCache{} + } + + var buf *topdown.BufferTracer + + if explainMode != types.ExplainOffV1 { + buf = topdown.NewBufferTracer() + } + + pqID := "v1DataGet::" + if strictBuiltinErrors { + pqID += "strict-builtin-errors::" + } + pqID += urlPath + preparedQuery, ok := s.getCachedPreparedEvalQuery(pqID, m) + if !ok { + opts := []func(*rego.Rego){ + rego.Compiler(s.getCompiler()), + rego.Store(s.store), + } + + for _, r := range s.manager.GetWasmResolvers() { + for _, entrypoint := range r.Entrypoints() { + opts = append(opts, rego.Resolver(entrypoint, r)) + } + } + + rego, err := s.makeRego(ctx, strictBuiltinErrors, txn, input, urlPath, m, includeInstrumentation, buf, opts) + if err != nil { + _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) + writer.ErrorAuto(w, err) + return + } + + pq, err := rego.PrepareForEval(ctx) + if err != nil { + _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) + writer.ErrorAuto(w, err) + return + } + preparedQuery = &pq + s.preparedEvalQueries.Insert(pqID, preparedQuery) + } + + evalOpts := []rego.EvalOption{ + rego.EvalTransaction(txn), + rego.EvalParsedInput(input), + rego.EvalMetrics(m), + rego.EvalQueryTracer(buf), + rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), + rego.EvalInstrument(includeInstrumentation), + rego.EvalNDBuiltinCache(ndbCache), + } + + rs, err := preparedQuery.Eval( + ctx, + evalOpts..., + ) + + m.Timer(metrics.ServerHandler).Stop() + + // Handle results. + if err != nil { + _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) + writer.ErrorAuto(w, err) + return + } + + result := types.DataResponseV1{ + DecisionID: decisionID, + } + + if includeMetrics(r) || includeInstrumentation { + result.Metrics = m.All() + } + + if provenance { + result.Provenance = s.getProvenance(br) + } + + if len(rs) == 0 { + if explainMode == types.ExplainFullV1 { + result.Explanation, err = types.NewTraceV1(lineage.Full(*buf), pretty(r)) + if err != nil { + writer.ErrorAuto(w, err) + return + } + } + + if err := logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, nil, m); err != nil { + writer.ErrorAuto(w, err) + return + } + writer.JSONOK(w, result, pretty(r)) + return + } + + result.Result = &rs[0].Expressions[0].Value + + if explainMode != types.ExplainOffV1 { + result.Explanation = s.getExplainResponse(explainMode, *buf, pretty(r)) + } + + if err := logger.Log(ctx, txn, urlPath, "", goInput, input, result.Result, ndbCache, nil, m); err != nil { + writer.ErrorAuto(w, err) + return + } + writer.JSONOK(w, result, pretty(r)) +} + +func (s *Server) v1DataPatch(w http.ResponseWriter, r *http.Request) { + m := metrics.New() + m.Timer(metrics.ServerHandler).Start() + defer m.Timer(metrics.ServerHandler).Stop() + + ctx := r.Context() + vars := mux.Vars(r) + var ops []types.PatchV1 + + m.Timer(metrics.RegoInputParse).Start() + if err := util.NewJSONDecoder(r.Body).Decode(&ops); err != nil { + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + m.Timer(metrics.RegoInputParse).Stop() + + patches, err := s.prepareV1PatchSlice(vars["path"], ops) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + params := storage.WriteParams + params.Context = storage.NewContext().WithMetrics(m) + txn, err := s.store.NewTransaction(ctx, params) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + for _, patch := range patches { + if err := s.checkPathScope(ctx, txn, patch.path); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + if err := s.store.Write(ctx, txn, patch.op, patch.path, patch.value); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + } + + if err := ast.CheckPathConflicts(s.getCompiler(), storage.NonEmpty(ctx, s.store, txn)); len(err) > 0 { + s.store.Abort(ctx, txn) + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + + if err := s.store.Commit(ctx, txn); err != nil { + writer.ErrorAuto(w, err) + return + } + + if includeMetrics(r) { + result := types.DataResponseV1{ + Metrics: m.All(), + } + writer.JSONOK(w, result, false) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { + m := metrics.New() + m.Timer(metrics.ServerHandler).Start() + + decisionID := s.generateDecisionID() + ctx := logging.WithDecisionID(r.Context(), decisionID) + annotateSpan(ctx, decisionID) + + vars := mux.Vars(r) + urlPath := vars["path"] + explainMode := getExplain(r.URL.Query()[types.ParamExplainV1], types.ExplainOffV1) + includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) + provenance := getBoolParam(r.URL, types.ParamProvenanceV1, true) + strictBuiltinErrors := getBoolParam(r.URL, types.ParamStrictBuiltinErrors, true) + + m.Timer(metrics.RegoInputParse).Start() + + input, goInput, err := readInputPostV1(r) + if err != nil { + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + + m.Timer(metrics.RegoInputParse).Stop() + + txn, err := s.store.NewTransaction(ctx, storage.TransactionParams{Context: storage.NewContext().WithMetrics(m)}) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + defer s.store.Abort(ctx, txn) + + br, err := getRevisions(ctx, s.store, txn) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + logger := s.getDecisionLogger(br) + + var ndbCache builtins.NDBCache + if s.ndbCacheEnabled { + ndbCache = builtins.NDBCache{} + } + + var buf *topdown.BufferTracer + + if explainMode != types.ExplainOffV1 { + buf = topdown.NewBufferTracer() + } + + pqID := "v1DataPost::" + if strictBuiltinErrors { + pqID += "strict-builtin-errors::" + } + pqID += urlPath + preparedQuery, ok := s.getCachedPreparedEvalQuery(pqID, m) + if !ok { + opts := []func(*rego.Rego){ + rego.Compiler(s.getCompiler()), + rego.Store(s.store), + } + + // Set resolvers on the base Rego object to avoid having them get + // re-initialized, and to propagate them to the prepared query. + for _, r := range s.manager.GetWasmResolvers() { + for _, entrypoint := range r.Entrypoints() { + opts = append(opts, rego.Resolver(entrypoint, r)) + } + } + + rego, err := s.makeRego(ctx, strictBuiltinErrors, txn, input, urlPath, m, includeInstrumentation, buf, opts) + if err != nil { + _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) + writer.ErrorAuto(w, err) + return + } + + pq, err := rego.PrepareForEval(ctx) + if err != nil { + _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) + writer.ErrorAuto(w, err) + return + } + preparedQuery = &pq + s.preparedEvalQueries.Insert(pqID, preparedQuery) + } + + evalOpts := []rego.EvalOption{ + rego.EvalTransaction(txn), + rego.EvalParsedInput(input), + rego.EvalMetrics(m), + rego.EvalQueryTracer(buf), + rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), + rego.EvalInstrument(includeInstrumentation), + rego.EvalNDBuiltinCache(ndbCache), + } + + rs, err := preparedQuery.Eval( + ctx, + evalOpts..., + ) + + m.Timer(metrics.ServerHandler).Stop() + + // Handle results. + if err != nil { + _ = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, err, m) + writer.ErrorAuto(w, err) + return + } + + result := types.DataResponseV1{ + DecisionID: decisionID, + } + + if input == nil { + result.Warning = types.NewWarning(types.CodeAPIUsageWarn, types.MsgInputKeyMissing) + } + + if includeMetrics(r) || includeInstrumentation { + result.Metrics = m.All() + } + + if provenance { + result.Provenance = s.getProvenance(br) + } + + if len(rs) == 0 { + if explainMode == types.ExplainFullV1 { + result.Explanation, err = types.NewTraceV1(lineage.Full(*buf), pretty(r)) + if err != nil { + writer.ErrorAuto(w, err) + return + } + } + err = logger.Log(ctx, txn, urlPath, "", goInput, input, nil, ndbCache, nil, m) + if err != nil { + writer.ErrorAuto(w, err) + return + } + writer.JSONOK(w, result, pretty(r)) + return + } + + result.Result = &rs[0].Expressions[0].Value + + if explainMode != types.ExplainOffV1 { + result.Explanation = s.getExplainResponse(explainMode, *buf, pretty(r)) + } + + if err := logger.Log(ctx, txn, urlPath, "", goInput, input, result.Result, ndbCache, nil, m); err != nil { + writer.ErrorAuto(w, err) + return + } + writer.JSONOK(w, result, pretty(r)) +} + +func (s *Server) v1DataPut(w http.ResponseWriter, r *http.Request) { + m := metrics.New() + m.Timer(metrics.ServerHandler).Start() + defer m.Timer(metrics.ServerHandler).Stop() + + ctx := r.Context() + vars := mux.Vars(r) + + m.Timer(metrics.RegoInputParse).Start() + var value interface{} + if err := util.NewJSONDecoder(r.Body).Decode(&value); err != nil { + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + m.Timer(metrics.RegoInputParse).Stop() + + path, ok := storage.ParsePathEscaped("/" + strings.Trim(vars["path"], "/")) + if !ok { + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "bad path: %v", vars["path"])) + return + } + + params := storage.WriteParams + params.Context = storage.NewContext().WithMetrics(m) + txn, err := s.store.NewTransaction(ctx, params) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + if err := s.checkPathScope(ctx, txn, path); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + _, err = s.store.Read(ctx, txn, path) + if err != nil { + if !storage.IsNotFound(err) { + s.abortAuto(ctx, txn, w, err) + return + } + if len(path) > 0 { + if err := storage.MakeDir(ctx, s.store, txn, path[:len(path)-1]); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + } + } else if r.Header.Get("If-None-Match") == "*" { + s.store.Abort(ctx, txn) + w.WriteHeader(http.StatusNotModified) + return + } + + if err := s.store.Write(ctx, txn, storage.AddOp, path, value); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + if err := ast.CheckPathConflicts(s.getCompiler(), storage.NonEmpty(ctx, s.store, txn)); len(err) > 0 { + s.store.Abort(ctx, txn) + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + + if err := s.store.Commit(ctx, txn); err != nil { + writer.ErrorAuto(w, err) + return + } + + if includeMetrics(r) { + result := types.DataResponseV1{ + Metrics: m.All(), + } + writer.JSONOK(w, result, false) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (s *Server) v1DataDelete(w http.ResponseWriter, r *http.Request) { + m := metrics.New() + m.Timer(metrics.ServerHandler).Start() + defer m.Timer(metrics.ServerHandler).Stop() + + ctx := r.Context() + vars := mux.Vars(r) + + path, ok := storage.ParsePathEscaped("/" + strings.Trim(vars["path"], "/")) + if !ok { + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "bad path: %v", vars["path"])) + return + } + + params := storage.WriteParams + params.Context = storage.NewContext().WithMetrics(m) + txn, err := s.store.NewTransaction(ctx, params) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + if err := s.checkPathScope(ctx, txn, path); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + _, err = s.store.Read(ctx, txn, path) + if err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + if err := s.store.Write(ctx, txn, storage.RemoveOp, path, nil); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + if err := s.store.Commit(ctx, txn); err != nil { + writer.ErrorAuto(w, err) + return + } + + if includeMetrics(r) { + result := types.DataResponseV1{ + Metrics: m.All(), + } + writer.JSONOK(w, result, false) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (s *Server) v1PoliciesDelete(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + vars := mux.Vars(r) + + id, err := url.PathUnescape(vars["path"]) + if err != nil { + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + + m := metrics.New() + params := storage.WriteParams + params.Context = storage.NewContext().WithMetrics(m) + txn, err := s.store.NewTransaction(ctx, params) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + if err := s.checkPolicyIDScope(ctx, txn, id); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + modules, err := s.loadModules(ctx, txn) + if err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + delete(modules, id) + + c := ast.NewCompiler().SetErrorLimit(s.errLimit) + + m.Timer(metrics.RegoModuleCompile).Start() + + if c.Compile(modules); c.Failed() { + s.abort(ctx, txn, func() { + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidOperation, types.MsgCompileModuleError).WithASTErrors(c.Errors)) + }) + return + } + + m.Timer(metrics.RegoModuleCompile).Stop() + + if err := s.store.DeletePolicy(ctx, txn, id); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + if err := s.store.Commit(ctx, txn); err != nil { + writer.ErrorAuto(w, err) + return + } + + resp := types.PolicyDeleteResponseV1{} + if includeMetrics(r) { + resp.Metrics = m.All() + } + + writer.JSONOK(w, resp, pretty(r)) +} + +func (s *Server) v1PoliciesGet(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + vars := mux.Vars(r) + + path, err := url.PathUnescape(vars["path"]) + if err != nil { + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + + txn, err := s.store.NewTransaction(ctx) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + defer s.store.Abort(ctx, txn) + + bs, err := s.store.GetPolicy(ctx, txn, path) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + c := s.getCompiler() + + resp := types.PolicyGetResponseV1{ + Result: types.PolicyV1{ + ID: path, + Raw: string(bs), + AST: c.Modules[path], + }, + } + + writer.JSONOK(w, resp, pretty(r)) +} + +func (s *Server) v1PoliciesList(w http.ResponseWriter, r *http.Request) { + + ctx := r.Context() + + txn, err := s.store.NewTransaction(ctx) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + defer s.store.Abort(ctx, txn) + + policies := []types.PolicyV1{} + c := s.getCompiler() + + // Only return policies from the store, the compiler + // may contain additional partially compiled modules. + ids, err := s.store.ListPolicies(ctx, txn) + if err != nil { + writer.ErrorAuto(w, err) + return + } + for _, id := range ids { + bs, err := s.store.GetPolicy(ctx, txn, id) + if err != nil { + writer.ErrorAuto(w, err) + return + } + policy := types.PolicyV1{ + ID: id, + Raw: string(bs), + AST: c.Modules[id], + } + policies = append(policies, policy) + } + + writer.JSONOK(w, types.PolicyListResponseV1{Result: policies}, pretty(r)) +} + +func (s *Server) v1PoliciesPut(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + vars := mux.Vars(r) + + id, err := url.PathUnescape(vars["path"]) + if err != nil { + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + + includeMetrics := includeMetrics(r) + m := metrics.New() + + m.Timer("server_read_bytes").Start() + + buf, err := io.ReadAll(r.Body) + if err != nil { + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + return + } + + m.Timer("server_read_bytes").Stop() + + params := storage.WriteParams + params.Context = storage.NewContext().WithMetrics(m) + txn, err := s.store.NewTransaction(ctx, params) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + if err := s.checkPolicyIDScope(ctx, txn, id); err != nil && !storage.IsNotFound(err) { + s.abortAuto(ctx, txn, w, err) + return + } + + if bs, err := s.store.GetPolicy(ctx, txn, id); err != nil { + if !storage.IsNotFound(err) { + s.abortAuto(ctx, txn, w, err) + return + } + } else if bytes.Equal(buf, bs) { + s.store.Abort(ctx, txn) + resp := types.PolicyPutResponseV1{} + if includeMetrics { + resp.Metrics = m.All() + } + writer.JSONOK(w, resp, pretty(r)) + return + } + + m.Timer(metrics.RegoModuleParse).Start() + parsedMod, err := ast.ParseModuleWithOpts(id, string(buf), s.manager.ParserOptions()) + m.Timer(metrics.RegoModuleParse).Stop() + + if err != nil { + s.store.Abort(ctx, txn) + switch err := err.(type) { + case ast.Errors: + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileModuleError).WithASTErrors(err)) + default: + writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) + } + return + } + + if parsedMod == nil { + s.store.Abort(ctx, txn) + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "empty module")) + return + } + + if err := s.checkPolicyPackageScope(ctx, txn, parsedMod.Package); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + modules, err := s.loadModules(ctx, txn) + if err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + modules[id] = parsedMod + + c := ast.NewCompiler(). + SetErrorLimit(s.errLimit). + WithPathConflictsCheck(storage.NonEmpty(ctx, s.store, txn)). + WithEnablePrintStatements(s.manager.EnablePrintStatements()) + + m.Timer(metrics.RegoModuleCompile).Start() + + if c.Compile(modules); c.Failed() { + s.abort(ctx, txn, func() { + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileModuleError).WithASTErrors(c.Errors)) + }) + return + } + + m.Timer(metrics.RegoModuleCompile).Stop() + + if err := s.store.UpsertPolicy(ctx, txn, id, buf); err != nil { + s.abortAuto(ctx, txn, w, err) + return + } + + if err := s.store.Commit(ctx, txn); err != nil { + writer.ErrorAuto(w, err) + return + } + + resp := types.PolicyPutResponseV1{} + + if includeMetrics { + resp.Metrics = m.All() + } + + writer.JSONOK(w, resp, pretty(r)) +} + +func (s *Server) v1QueryGet(w http.ResponseWriter, r *http.Request) { + m := metrics.New() + + decisionID := s.generateDecisionID() + ctx := logging.WithDecisionID(r.Context(), decisionID) + annotateSpan(ctx, decisionID) + + values := r.URL.Query() + + qStrs := values[types.ParamQueryV1] + if len(qStrs) == 0 { + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "missing parameter 'q'")) + return + } + qStr := qStrs[len(qStrs)-1] + + parsedQuery, err := validateQuery(qStr, s.manager.ParserOptions()) + if err != nil { + switch err := err.(type) { + case ast.Errors: + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgParseQueryError).WithASTErrors(err)) + default: + writer.ErrorAuto(w, err) + } + return + } + + explainMode := getExplain(r.URL.Query()["explain"], types.ExplainOffV1) + includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) + + params := storage.TransactionParams{Context: storage.NewContext().WithMetrics(m)} + txn, err := s.store.NewTransaction(ctx, params) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + defer s.store.Abort(ctx, txn) + + br, err := getRevisions(ctx, s.store, txn) + if err != nil { + writer.ErrorAuto(w, err) + return + } + pretty := pretty(r) + results, err := s.execQuery(ctx, br, txn, parsedQuery, nil, nil, m, explainMode, includeMetrics(r), includeInstrumentation, pretty) + if err != nil { + switch err := err.(type) { + case ast.Errors: + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileQueryError).WithASTErrors(err)) + default: + writer.ErrorAuto(w, err) + } + return + } + + writer.JSONOK(w, results, pretty) +} + +func (s *Server) v1QueryPost(w http.ResponseWriter, r *http.Request) { + m := metrics.New() + m.Timer(metrics.ServerHandler).Start() + + decisionID := s.generateDecisionID() + ctx := logging.WithDecisionID(r.Context(), decisionID) + annotateSpan(ctx, decisionID) + + var request types.QueryRequestV1 + err := util.NewJSONDecoder(r.Body).Decode(&request) + if err != nil { + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "error(s) occurred while decoding request: %v", err.Error())) + return + } + qStr := request.Query + parsedQuery, err := validateQuery(qStr, s.manager.ParserOptions()) + if err != nil { + switch err := err.(type) { + case ast.Errors: + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgParseQueryError).WithASTErrors(err)) + default: + writer.ErrorAuto(w, err) + } + return + } + + pretty := pretty(r) + explainMode := getExplain(r.URL.Query()["explain"], types.ExplainOffV1) + includeMetrics := includeMetrics(r) + includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) + + var input ast.Value + + if request.Input != nil { + input, err = ast.InterfaceToValue(*request.Input) + if err != nil { + writer.ErrorAuto(w, err) + return + } + } + + params := storage.TransactionParams{Context: storage.NewContext().WithMetrics(m)} + txn, err := s.store.NewTransaction(ctx, params) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + defer s.store.Abort(ctx, txn) + + br, err := getRevisions(ctx, s.store, txn) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + results, err := s.execQuery(ctx, br, txn, parsedQuery, input, request.Input, m, explainMode, includeMetrics, includeInstrumentation, pretty) + if err != nil { + switch err := err.(type) { + case ast.Errors: + writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, types.MsgCompileQueryError).WithASTErrors(err)) + default: + writer.ErrorAuto(w, err) + } + return + } + + m.Timer(metrics.ServerHandler).Stop() + + if includeMetrics || includeInstrumentation { + results.Metrics = m.All() + } + + writer.JSONOK(w, results, pretty) +} + +func (s *Server) v1ConfigGet(w http.ResponseWriter, r *http.Request) { + result, err := s.manager.Config.ActiveConfig() + if err != nil { + writer.ErrorAuto(w, err) + return + } + writer.JSONOK(w, types.ConfigResponseV1{Result: &result}, pretty(r)) +} + +func (s *Server) v1StatusGet(w http.ResponseWriter, r *http.Request) { + p := status.Lookup(s.manager) + if p == nil { + writer.ErrorString(w, http.StatusInternalServerError, types.CodeInternal, errors.New("status plugin not enabled")) + return + } + + var st interface{} = p.Snapshot() + writer.JSONOK(w, types.StatusResponseV1{Result: &st}, pretty(r)) +} + +func (s *Server) checkPolicyIDScope(ctx context.Context, txn storage.Transaction, id string) error { + + bs, err := s.store.GetPolicy(ctx, txn, id) + if err != nil { + return err + } + + module, err := ast.ParseModuleWithOpts(id, string(bs), s.manager.ParserOptions()) + if err != nil { + return err + } + + return s.checkPolicyPackageScope(ctx, txn, module.Package) +} + +func (s *Server) checkPolicyPackageScope(ctx context.Context, txn storage.Transaction, pkg *ast.Package) error { + + path, err := pkg.Path.Ptr() + if err != nil { + return err + } + + spath, ok := storage.ParsePathEscaped("/" + path) + if !ok { + return types.BadRequestErr("invalid package path: cannot determine scope") + } + + return s.checkPathScope(ctx, txn, spath) +} + +func (s *Server) checkPathScope(ctx context.Context, txn storage.Transaction, path storage.Path) error { + + names, err := bundle.ReadBundleNamesFromStore(ctx, s.store, txn) + if err != nil { + if !storage.IsNotFound(err) { + return err + } + return nil + } + + bundleRoots := map[string][]string{} + for _, name := range names { + roots, err := bundle.ReadBundleRootsFromStore(ctx, s.store, txn, name) + if err != nil && !storage.IsNotFound(err) { + return err + } + bundleRoots[name] = roots + } + + spath := strings.Trim(path.String(), "/") + + if spath == "" && len(bundleRoots) > 0 { + return types.BadRequestErr("can't write to document root with bundle roots configured") + } + + spathParts := strings.Split(spath, "/") + + for name, roots := range bundleRoots { + if roots == nil { + return types.BadRequestErr(fmt.Sprintf("all paths owned by bundle %q", name)) + } + for _, root := range roots { + if root == "" { + return types.BadRequestErr(fmt.Sprintf("all paths owned by bundle %q", name)) + } + if isPathOwned(spathParts, strings.Split(root, "/")) { + return types.BadRequestErr(fmt.Sprintf("path %v is owned by bundle %q", spath, name)) + } + } + } + + return nil +} + +func (s *Server) getDecisionLogger(br bundleRevisions) (logger decisionLogger) { + // For backwards compatibility use `revision` as needed. + if s.hasLegacyBundle(br) { + logger.revision = br.LegacyRevision + } else { + logger.revisions = br.Revisions + } + logger.logger = s.logger + return logger +} + +func (s *Server) getExplainResponse(explainMode types.ExplainModeV1, trace []*topdown.Event, pretty bool) (explanation types.TraceV1) { + switch explainMode { + case types.ExplainNotesV1: + var err error + explanation, err = types.NewTraceV1(lineage.Notes(trace), pretty) + if err != nil { + break + } + case types.ExplainFailsV1: + var err error + explanation, err = types.NewTraceV1(lineage.Fails(trace), pretty) + if err != nil { + break + } + case types.ExplainFullV1: + var err error + explanation, err = types.NewTraceV1(lineage.Full(trace), pretty) + if err != nil { + break + } + case types.ExplainDebugV1: + var err error + explanation, err = types.NewTraceV1(lineage.Debug(trace), pretty) + if err != nil { + break + } + } + return explanation +} + +func (s *Server) abort(ctx context.Context, txn storage.Transaction, finish func()) { + s.store.Abort(ctx, txn) + finish() +} + +func (s *Server) abortAuto(ctx context.Context, txn storage.Transaction, w http.ResponseWriter, err error) { + s.abort(ctx, txn, func() { writer.ErrorAuto(w, err) }) +} + +func (s *Server) loadModules(ctx context.Context, txn storage.Transaction) (map[string]*ast.Module, error) { + + ids, err := s.store.ListPolicies(ctx, txn) + if err != nil { + return nil, err + } + + modules := make(map[string]*ast.Module, len(ids)) + + for _, id := range ids { + bs, err := s.store.GetPolicy(ctx, txn, id) + if err != nil { + return nil, err + } + + parsed, err := ast.ParseModuleWithOpts(id, string(bs), s.manager.ParserOptions()) + if err != nil { + return nil, err + } + + modules[id] = parsed + } + + return modules, nil +} + +func (s *Server) getCompiler() *ast.Compiler { + return s.manager.GetCompiler() +} + +func (s *Server) makeRego(_ context.Context, + strictBuiltinErrors bool, + txn storage.Transaction, + input ast.Value, + urlPath string, + m metrics.Metrics, + instrument bool, + tracer topdown.QueryTracer, + opts []func(*rego.Rego), +) (*rego.Rego, error) { + queryPath := stringPathToDataRef(urlPath).String() + + opts = append( + opts, + rego.Transaction(txn), + rego.Query(queryPath), + rego.ParsedInput(input), + rego.Metrics(m), + rego.QueryTracer(tracer), + rego.Instrument(instrument), + rego.Runtime(s.runtime), + rego.UnsafeBuiltins(unsafeBuiltinsMap), + rego.StrictBuiltinErrors(strictBuiltinErrors), + rego.PrintHook(s.manager.PrintHook()), + rego.DistributedTracingOpts(s.distributedTracingOpts), + ) + + return rego.New(opts...), nil +} + +func (s *Server) prepareV1PatchSlice(root string, ops []types.PatchV1) (result []patchImpl, err error) { + + root = "/" + strings.Trim(root, "/") + + for _, op := range ops { + + impl := patchImpl{ + value: op.Value, + } + + // Map patch operation. + switch op.Op { + case "add": + impl.op = storage.AddOp + case "remove": + impl.op = storage.RemoveOp + case "replace": + impl.op = storage.ReplaceOp + default: + return nil, types.BadPatchOperationErr(op.Op) + } + + // Construct patch path. + path := strings.Trim(op.Path, "/") + if len(path) > 0 { + if root == "/" { + path = root + path + } else { + path = root + "/" + path + } + } else { + path = root + } + + var ok bool + impl.path, ok = patch.ParsePatchPathEscaped(path) + if !ok { + return nil, types.BadPatchPathErr(op.Path) + } + + result = append(result, impl) + } + + return result, nil +} + +func (s *Server) generateDecisionID() string { + if s.decisionIDFactory != nil { + return s.decisionIDFactory() + } + return "" +} + +func (s *Server) getProvenance(br bundleRevisions) *types.ProvenanceV1 { + + p := &types.ProvenanceV1{ + Version: version.Version, + Vcs: version.Vcs, + Timestamp: version.Timestamp, + Hostname: version.Hostname, + } + + // For backwards compatibility, if the bundles are using the old + // style config we need to fill in the older `Revision` field. + // Otherwise use the newer `Bundles` keyword. + if s.hasLegacyBundle(br) { + p.Revision = br.LegacyRevision + } else { + p.Bundles = map[string]types.ProvenanceBundleV1{} + for name, revision := range br.Revisions { + p.Bundles[name] = types.ProvenanceBundleV1{Revision: revision} + } + } + + return p +} + +func (s *Server) hasLegacyBundle(br bundleRevisions) bool { + bp := bundlePlugin.Lookup(s.manager) + return br.LegacyRevision != "" || (bp != nil && !bp.Config().IsMultiBundle()) +} + +func (s *Server) generateDefaultDecisionPath() string { + // Assume the path is safe to transition back to a url + p, _ := s.manager.Config.DefaultDecisionRef().Ptr() + return p +} + +func isPathOwned(path, root []string) bool { + for i := 0; i < len(path) && i < len(root); i++ { + if path[i] != root[i] { + return false + } + } + return true +} + +func (s *Server) updateCacheConfig(cacheConfig *iCache.Config) { + s.interQueryBuiltinCache.UpdateConfig(cacheConfig) + s.interQueryBuiltinValueCache.UpdateConfig(cacheConfig) +} + +func (s *Server) updateNDCache(enabled bool) { + s.mtx.Lock() + defer s.mtx.Unlock() + s.ndbCacheEnabled = enabled +} + +func stringPathToDataRef(s string) (r ast.Ref) { + result := ast.Ref{ast.DefaultRootDocument} + return append(result, stringPathToRef(s)...) +} + +func stringPathToRef(s string) (r ast.Ref) { + if len(s) == 0 { + return r + } + p := strings.Split(s, "/") + for _, x := range p { + if x == "" { + continue + } + if y, err := url.PathUnescape(x); err == nil { + x = y + } + i, err := strconv.Atoi(x) + if err != nil { + r = append(r, ast.StringTerm(x)) + } else { + r = append(r, ast.IntNumberTerm(i)) + } + } + return r +} + +func validateQuery(query string, opts ast.ParserOptions) (ast.Body, error) { + return ast.ParseBodyWithOpts(query, opts) +} + +func getBoolParam(url *url.URL, name string, ifEmpty bool) bool { + + p, ok := url.Query()[name] + if !ok { + return false + } + + // Query params w/o values are represented as slice (of len 1) with an + // empty string. + if len(p) == 1 && p[0] == "" { + return ifEmpty + } + + for _, x := range p { + if strings.ToLower(x) == "true" { + return true + } + } + + return false +} + +func getStringSliceParam(url *url.URL, name string) []string { + + p, ok := url.Query()[name] + if !ok { + return nil + } + + // Query params w/o values are represented as slice (of len 1) with an + // empty string. + if len(p) == 1 && p[0] == "" { + return nil + } + + return p +} + +func getExplain(p []string, zero types.ExplainModeV1) types.ExplainModeV1 { + for _, x := range p { + switch x { + case string(types.ExplainNotesV1): + return types.ExplainNotesV1 + case string(types.ExplainFailsV1): + return types.ExplainFailsV1 + case string(types.ExplainFullV1): + return types.ExplainFullV1 + case string(types.ExplainDebugV1): + return types.ExplainDebugV1 + } + } + return zero +} + +func readInputV0(r *http.Request) (ast.Value, *interface{}, error) { + + parsed, ok := authorizer.GetBodyOnContext(r.Context()) + if ok { + v, err := ast.InterfaceToValue(parsed) + return v, &parsed, err + } + + // decompress the input if sent as zip + bodyBytes, err := util.ReadMaybeCompressedBody(r) + if err != nil { + return nil, nil, fmt.Errorf("could not decompress the body: %w", err) + } + + var x interface{} + + if strings.Contains(r.Header.Get("Content-Type"), "yaml") { + if len(bodyBytes) > 0 { + if err = util.Unmarshal(bodyBytes, &x); err != nil { + return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) + } + } + } else { + dec := util.NewJSONDecoder(bytes.NewBuffer(bodyBytes)) + if err := dec.Decode(&x); err != nil && err != io.EOF { + return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) + } + } + + v, err := ast.InterfaceToValue(x) + return v, &x, err +} + +func readInputGetV1(str string) (ast.Value, *interface{}, error) { + var input interface{} + if err := util.UnmarshalJSON([]byte(str), &input); err != nil { + return nil, nil, fmt.Errorf("parameter contains malformed input document: %w", err) + } + v, err := ast.InterfaceToValue(input) + return v, &input, err +} + +func readInputPostV1(r *http.Request) (ast.Value, *interface{}, error) { + + parsed, ok := authorizer.GetBodyOnContext(r.Context()) + if ok { + if obj, ok := parsed.(map[string]interface{}); ok { + if input, ok := obj["input"]; ok { + v, err := ast.InterfaceToValue(input) + return v, &input, err + } + } + return nil, nil, nil + } + + var request types.DataRequestV1 + + // decompress the input if sent as zip + bodyBytes, err := util.ReadMaybeCompressedBody(r) + if err != nil { + return nil, nil, fmt.Errorf("could not decompress the body: %w", err) + } + + ct := r.Header.Get("Content-Type") + // There is no standard for yaml mime-type so we just look for + // anything related + if strings.Contains(ct, "yaml") { + if len(bodyBytes) > 0 { + if err = util.Unmarshal(bodyBytes, &request); err != nil { + return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) + } + } + } else { + dec := util.NewJSONDecoder(bytes.NewBuffer(bodyBytes)) + if err := dec.Decode(&request); err != nil && err != io.EOF { + return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) + } + } + + if request.Input == nil { + return nil, nil, nil + } + + v, err := ast.InterfaceToValue(*request.Input) + return v, request.Input, err +} + +type compileRequest struct { + Query ast.Body + Input ast.Value + Unknowns []*ast.Term + Options compileRequestOptions +} + +type compileRequestOptions struct { + DisableInlining []string +} + +func readInputCompilePostV1(reqBytes []byte, queryParserOptions ast.ParserOptions) (*compileRequest, *types.ErrorV1) { + var request types.CompileRequestV1 + + err := util.NewJSONDecoder(bytes.NewBuffer(reqBytes)).Decode(&request) + if err != nil { + return nil, types.NewErrorV1(types.CodeInvalidParameter, "error(s) occurred while decoding request: %v", err.Error()) + } + + query, err := ast.ParseBodyWithOpts(request.Query, queryParserOptions) + if err != nil { + switch err := err.(type) { + case ast.Errors: + return nil, types.NewErrorV1(types.CodeInvalidParameter, types.MsgParseQueryError).WithASTErrors(err) + default: + return nil, types.NewErrorV1(types.CodeInvalidParameter, "%v: %v", types.MsgParseQueryError, err) + } + } else if len(query) == 0 { + return nil, types.NewErrorV1(types.CodeInvalidParameter, "missing required 'query' value") + } + + var input ast.Value + if request.Input != nil { + input, err = ast.InterfaceToValue(*request.Input) + if err != nil { + return nil, types.NewErrorV1(types.CodeInvalidParameter, "error(s) occurred while converting input: %v", err) + } + } + + var unknowns []*ast.Term + if request.Unknowns != nil { + unknowns = make([]*ast.Term, len(*request.Unknowns)) + for i, s := range *request.Unknowns { + unknowns[i], err = ast.ParseTerm(s) + if err != nil { + return nil, types.NewErrorV1(types.CodeInvalidParameter, "error(s) occurred while parsing unknowns: %v", err) + } + } + } + + result := &compileRequest{ + Query: query, + Input: input, + Unknowns: unknowns, + Options: compileRequestOptions{ + DisableInlining: request.Options.DisableInlining, + }, + } + + return result, nil +} + +var indexHTML, _ = template.New("index").Parse(` + + + + + +
+ ________      ________    ________
+|\   __  \    |\   __  \  |\   __  \
+\ \  \|\  \   \ \  \|\  \ \ \  \|\  \
+ \ \  \\\  \   \ \   ____\ \ \   __  \
+  \ \  \\\  \   \ \  \___|  \ \  \ \  \
+   \ \_______\   \ \__\      \ \__\ \__\
+    \|_______|    \|__|       \|__|\|__|
+
+Open Policy Agent - An open source project to policy-enable your service.
+
+Version: {{ .Version }}
+Build Commit: {{ .BuildCommit }}
+Build Timestamp: {{ .BuildTimestamp }}
+Build Hostname: {{ .BuildHostname }}
+
+Query:
+
+
Input Data (JSON):
+
+
+
+ + +`) + +type decisionLogger struct { + revisions map[string]string + revision string // Deprecated: Use `revisions` instead. + logger func(context.Context, *Info) error +} + +func (l decisionLogger) Log(ctx context.Context, txn storage.Transaction, path string, query string, goInput *interface{}, astInput ast.Value, goResults *interface{}, ndbCache builtins.NDBCache, err error, m metrics.Metrics) error { + + bundles := map[string]BundleInfo{} + for name, rev := range l.revisions { + bundles[name] = BundleInfo{Revision: rev} + } + + rctx := logging.RequestContext{} + if r, ok := logging.FromContext(ctx); ok { + rctx = *r + } + decisionID, _ := logging.DecisionIDFromContext(ctx) + + var httpRctx logging.HTTPRequestContext + + httpRctxVal, _ := logging.HTTPRequestContextFromContext(ctx) + if httpRctxVal != nil { + httpRctx = *httpRctxVal + } + + info := &Info{ + Txn: txn, + Revision: l.revision, + Bundles: bundles, + Timestamp: time.Now().UTC(), + DecisionID: decisionID, + RemoteAddr: rctx.ClientAddr, + HTTPRequestContext: httpRctx, + Path: path, + Query: query, + Input: goInput, + InputAST: astInput, + Results: goResults, + Error: err, + Metrics: m, + RequestID: rctx.ReqID, + } + + if ndbCache != nil { + x, err := ast.JSON(ndbCache.AsValue()) + if err != nil { + return err + } + info.NDBuiltinCache = &x + } + + sctx := trace.SpanFromContext(ctx).SpanContext() + if sctx.IsValid() { + info.TraceID = sctx.TraceID().String() + info.SpanID = sctx.SpanID().String() + } + + if l.logger != nil { + if err := l.logger(ctx, info); err != nil { + return fmt.Errorf("decision_logs: %w", err) + } + } + + return nil +} + +type patchImpl struct { + path storage.Path + op storage.PatchOp + value interface{} +} + +func parseURL(s string, useHTTPSByDefault bool) (*url.URL, error) { + if !strings.Contains(s, "://") { + scheme := "http://" + if useHTTPSByDefault { + scheme = "https://" + } + s = scheme + s + } + return url.Parse(s) +} + +func annotateSpan(ctx context.Context, decisionID string) { + if decisionID == "" { + return + } + trace.SpanFromContext(ctx). + SetAttributes(attribute.String(otelDecisionIDAttr, decisionID)) +} + +func pretty(r *http.Request) bool { + return getBoolParam(r.URL, types.ParamPrettyV1, true) +} + +func includeMetrics(r *http.Request) bool { + return getBoolParam(r.URL, types.ParamMetricsV1, true) +} diff --git a/vendor/github.com/open-policy-agent/opa/server/types/types.go b/vendor/github.com/open-policy-agent/opa/v1/server/types/types.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/server/types/types.go rename to vendor/github.com/open-policy-agent/opa/v1/server/types/types.go index 37917913d..ca9f21549 100644 --- a/vendor/github.com/open-policy-agent/opa/server/types/types.go +++ b/vendor/github.com/open-policy-agent/opa/v1/server/types/types.go @@ -11,9 +11,9 @@ import ( "fmt" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/util" ) // Error codes returned by OPA's REST API. diff --git a/vendor/github.com/open-policy-agent/opa/server/writer/writer.go b/vendor/github.com/open-policy-agent/opa/v1/server/writer/writer.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/server/writer/writer.go rename to vendor/github.com/open-policy-agent/opa/v1/server/writer/writer.go index 09cda5fb7..8cf1e3090 100644 --- a/vendor/github.com/open-policy-agent/opa/server/writer/writer.go +++ b/vendor/github.com/open-policy-agent/opa/v1/server/writer/writer.go @@ -9,9 +9,9 @@ import ( "encoding/json" "net/http" - "github.com/open-policy-agent/opa/server/types" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/v1/server/types" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown" ) // HTTPStatus is used to set a specific status code @@ -44,7 +44,7 @@ func ErrorAuto(w http.ResponseWriter, err error) { // ErrorString writes a response with specified status, code, and message set to // the err's string representation. func ErrorString(w http.ResponseWriter, status int, code string, err error) { - Error(w, status, types.NewErrorV1(code, err.Error())) + Error(w, status, types.NewErrorV1(code, err.Error())) //nolint:govet } // Error writes a response with specified status and error response. diff --git a/vendor/github.com/open-policy-agent/opa/storage/disk/config.go b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/config.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/storage/disk/config.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/disk/config.go index bab976ceb..5bb6f6046 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/disk/config.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/config.go @@ -10,9 +10,9 @@ import ( "os" badger "github.com/dgraph-io/badger/v3" - "github.com/open-policy-agent/opa/config" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/config" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/util" ) type cfg struct { diff --git a/vendor/github.com/open-policy-agent/opa/storage/disk/disk.go b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/disk.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/storage/disk/disk.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/disk/disk.go index 0ebd980ff..8b35e60df 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/disk/disk.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/disk.go @@ -74,9 +74,9 @@ import ( badger "github.com/dgraph-io/badger/v3" "github.com/prometheus/client_golang/prometheus" - "github.com/open-policy-agent/opa/logging" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/logging" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/util" ) // TODO(tsandall): add support for migrations diff --git a/vendor/github.com/open-policy-agent/opa/storage/disk/errors.go b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/errors.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/storage/disk/errors.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/disk/errors.go index 0a3c987bc..53102f7f0 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/disk/errors.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/errors.go @@ -5,7 +5,7 @@ package disk import ( - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/storage" ) var errNotFound = &storage.Error{Code: storage.NotFoundErr} diff --git a/vendor/github.com/open-policy-agent/opa/storage/disk/metrics.go b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/metrics.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/storage/disk/metrics.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/disk/metrics.go diff --git a/vendor/github.com/open-policy-agent/opa/storage/disk/partition.go b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/partition.go similarity index 95% rename from vendor/github.com/open-policy-agent/opa/storage/disk/partition.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/disk/partition.go index a10fc9640..d93bc33ef 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/disk/partition.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/partition.go @@ -5,7 +5,7 @@ package disk import ( - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/storage" ) type partitionTrie struct { diff --git a/vendor/github.com/open-policy-agent/opa/storage/disk/paths.go b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/paths.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/storage/disk/paths.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/disk/paths.go index f231608c9..53ef9a7ec 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/disk/paths.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/paths.go @@ -9,7 +9,7 @@ import ( "sort" "strings" - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/storage" ) const pathWildcard = "*" diff --git a/vendor/github.com/open-policy-agent/opa/storage/disk/txn.go b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/txn.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/storage/disk/txn.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/disk/txn.go index 95f962db5..70dc13bc8 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/disk/txn.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/disk/txn.go @@ -13,11 +13,11 @@ import ( badger "github.com/dgraph-io/badger/v3" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/internal/errors" - "github.com/open-policy-agent/opa/storage/internal/ptr" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/internal/errors" + "github.com/open-policy-agent/opa/v1/storage/internal/ptr" + "github.com/open-policy-agent/opa/v1/util" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/v1/storage/doc.go b/vendor/github.com/open-policy-agent/opa/v1/storage/doc.go new file mode 100644 index 000000000..6fa2f86d9 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/doc.go @@ -0,0 +1,6 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package storage exposes the policy engine's storage layer. +package storage diff --git a/vendor/github.com/open-policy-agent/opa/v1/storage/errors.go b/vendor/github.com/open-policy-agent/opa/v1/storage/errors.go new file mode 100644 index 000000000..8c789052e --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/errors.go @@ -0,0 +1,122 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package storage + +import ( + "fmt" +) + +const ( + // InternalErr indicates an unknown, internal error has occurred. + InternalErr = "storage_internal_error" + + // NotFoundErr indicates the path used in the storage operation does not + // locate a document. + NotFoundErr = "storage_not_found_error" + + // WriteConflictErr indicates a write on the path enocuntered a conflicting + // value inside the transaction. + WriteConflictErr = "storage_write_conflict_error" + + // InvalidPatchErr indicates an invalid patch/write was issued. The patch + // was rejected. + InvalidPatchErr = "storage_invalid_patch_error" + + // InvalidTransactionErr indicates an invalid operation was performed + // inside of the transaction. + InvalidTransactionErr = "storage_invalid_txn_error" + + // TriggersNotSupportedErr indicates the caller attempted to register a + // trigger against a store that does not support them. + TriggersNotSupportedErr = "storage_triggers_not_supported_error" + + // WritesNotSupportedErr indicate the caller attempted to perform a write + // against a store that does not support them. + WritesNotSupportedErr = "storage_writes_not_supported_error" + + // PolicyNotSupportedErr indicate the caller attempted to perform a policy + // management operation against a store that does not support them. + PolicyNotSupportedErr = "storage_policy_not_supported_error" +) + +// Error is the error type returned by the storage layer. +type Error struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func (err *Error) Error() string { + if err.Message != "" { + return fmt.Sprintf("%v: %v", err.Code, err.Message) + } + return err.Code +} + +// IsNotFound returns true if this error is a NotFoundErr. +func IsNotFound(err error) bool { + switch err := err.(type) { + case *Error: + return err.Code == NotFoundErr + } + return false +} + +// IsWriteConflictError returns true if this error a WriteConflictErr. +func IsWriteConflictError(err error) bool { + switch err := err.(type) { + case *Error: + return err.Code == WriteConflictErr + } + return false +} + +// IsInvalidPatch returns true if this error is a InvalidPatchErr. +func IsInvalidPatch(err error) bool { + switch err := err.(type) { + case *Error: + return err.Code == InvalidPatchErr + } + return false +} + +// IsInvalidTransaction returns true if this error is a InvalidTransactionErr. +func IsInvalidTransaction(err error) bool { + switch err := err.(type) { + case *Error: + return err.Code == InvalidTransactionErr + } + return false +} + +// IsIndexingNotSupported is a stub for backwards-compatibility. +// +// Deprecated: We no longer return IndexingNotSupported errors, so it is +// unnecessary to check for them. +func IsIndexingNotSupported(error) bool { return false } + +func writeConflictError(path Path) *Error { + return &Error{ + Code: WriteConflictErr, + Message: path.String(), + } +} + +func triggersNotSupportedError() *Error { + return &Error{ + Code: TriggersNotSupportedErr, + } +} + +func writesNotSupportedError() *Error { + return &Error{ + Code: WritesNotSupportedErr, + } +} + +func policyNotSupportedError() *Error { + return &Error{ + Code: PolicyNotSupportedErr, + } +} diff --git a/vendor/github.com/open-policy-agent/opa/storage/inmem/ast.go b/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/ast.go similarity index 97% rename from vendor/github.com/open-policy-agent/opa/storage/inmem/ast.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/inmem/ast.go index 5a8a6743f..667ca608e 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/inmem/ast.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/ast.go @@ -8,10 +8,10 @@ import ( "fmt" "strconv" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/internal/errors" - "github.com/open-policy-agent/opa/storage/internal/ptr" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/internal/errors" + "github.com/open-policy-agent/opa/v1/storage/internal/ptr" ) type updateAST struct { diff --git a/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/inmem.go b/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/inmem.go new file mode 100644 index 000000000..7c5116b52 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/inmem.go @@ -0,0 +1,458 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package inmem implements an in-memory version of the policy engine's storage +// layer. +// +// The in-memory store is used as the default storage layer implementation. The +// in-memory store supports multi-reader/single-writer concurrency with +// rollback. +// +// Callers should assume the in-memory store does not make copies of written +// data. Once data is written to the in-memory store, it should not be modified +// (outside of calling Store.Write). Furthermore, data read from the in-memory +// store should be treated as read-only. +package inmem + +import ( + "context" + "fmt" + "io" + "path/filepath" + "strings" + "sync" + "sync/atomic" + + "github.com/open-policy-agent/opa/internal/merge" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/util" +) + +// New returns an empty in-memory store. +func New() storage.Store { + return NewWithOpts() +} + +// NewWithOpts returns an empty in-memory store, with extra options passed. +func NewWithOpts(opts ...Opt) storage.Store { + s := &store{ + triggers: map[*handle]storage.TriggerConfig{}, + policies: map[string][]byte{}, + roundTripOnWrite: true, + returnASTValuesOnRead: false, + } + + for _, opt := range opts { + opt(s) + } + + if s.returnASTValuesOnRead { + s.data = ast.NewObject() + } else { + s.data = map[string]interface{}{} + } + + return s +} + +// NewFromObject returns a new in-memory store from the supplied data object. +func NewFromObject(data map[string]interface{}) storage.Store { + return NewFromObjectWithOpts(data) +} + +// NewFromObjectWithOpts returns a new in-memory store from the supplied data object, with the +// options passed. +func NewFromObjectWithOpts(data map[string]interface{}, opts ...Opt) storage.Store { + db := NewWithOpts(opts...) + ctx := context.Background() + txn, err := db.NewTransaction(ctx, storage.WriteParams) + if err != nil { + panic(err) + } + if err := db.Write(ctx, txn, storage.AddOp, storage.Path{}, data); err != nil { + panic(err) + } + if err := db.Commit(ctx, txn); err != nil { + panic(err) + } + return db +} + +// NewFromReader returns a new in-memory store from a reader that produces a +// JSON serialized object. This function is for test purposes. +func NewFromReader(r io.Reader) storage.Store { + return NewFromReaderWithOpts(r) +} + +// NewFromReader returns a new in-memory store from a reader that produces a +// JSON serialized object, with extra options. This function is for test purposes. +func NewFromReaderWithOpts(r io.Reader, opts ...Opt) storage.Store { + d := util.NewJSONDecoder(r) + var data map[string]interface{} + if err := d.Decode(&data); err != nil { + panic(err) + } + return NewFromObjectWithOpts(data, opts...) +} + +type store struct { + rmu sync.RWMutex // reader-writer lock + wmu sync.Mutex // writer lock + xid uint64 // last generated transaction id + data interface{} // raw or AST data + policies map[string][]byte // raw policies + triggers map[*handle]storage.TriggerConfig // registered triggers + + // roundTripOnWrite, if true, means that every call to Write round trips the + // data through JSON before adding the data to the store. Defaults to true. + roundTripOnWrite bool + + // returnASTValuesOnRead, if true, means that the store will eagerly convert data to AST values, + // and return them on Read. + // FIXME: naming(?) + returnASTValuesOnRead bool +} + +type handle struct { + db *store +} + +func (db *store) NewTransaction(_ context.Context, params ...storage.TransactionParams) (storage.Transaction, error) { + var write bool + var ctx *storage.Context + if len(params) > 0 { + write = params[0].Write + ctx = params[0].Context + } + xid := atomic.AddUint64(&db.xid, uint64(1)) + if write { + db.wmu.Lock() + } else { + db.rmu.RLock() + } + return newTransaction(xid, write, ctx, db), nil +} + +// Truncate implements the storage.Store interface. This method must be called within a transaction. +func (db *store) Truncate(ctx context.Context, txn storage.Transaction, params storage.TransactionParams, it storage.Iterator) error { + var update *storage.Update + var err error + mergedData := map[string]interface{}{} + + underlying, err := db.underlying(txn) + if err != nil { + return err + } + + for { + update, err = it.Next() + if err != nil { + break + } + + if update.IsPolicy { + err = underlying.UpsertPolicy(strings.TrimLeft(update.Path.String(), "/"), update.Value) + if err != nil { + return err + } + } else { + var value interface{} + err = util.Unmarshal(update.Value, &value) + if err != nil { + return err + } + + var key []string + dirpath := strings.TrimLeft(update.Path.String(), "/") + if len(dirpath) > 0 { + key = strings.Split(dirpath, "/") + } + + if value != nil { + obj, err := mktree(key, value) + if err != nil { + return err + } + + merged, ok := merge.InterfaceMaps(mergedData, obj) + if !ok { + return fmt.Errorf("failed to insert data file from path %s", filepath.Join(key...)) + } + mergedData = merged + } + } + } + + if err != nil && err != io.EOF { + return err + } + + // For backwards compatibility, check if `RootOverwrite` was configured. + if params.RootOverwrite { + newPath, ok := storage.ParsePathEscaped("/") + if !ok { + return fmt.Errorf("storage path invalid: %v", newPath) + } + return underlying.Write(storage.AddOp, newPath, mergedData) + } + + for _, root := range params.BasePaths { + newPath, ok := storage.ParsePathEscaped("/" + root) + if !ok { + return fmt.Errorf("storage path invalid: %v", newPath) + } + + if value, ok := lookup(newPath, mergedData); ok { + if len(newPath) > 0 { + if err := storage.MakeDir(ctx, db, txn, newPath[:len(newPath)-1]); err != nil { + return err + } + } + if err := underlying.Write(storage.AddOp, newPath, value); err != nil { + return err + } + } + } + return nil +} + +func (db *store) Commit(ctx context.Context, txn storage.Transaction) error { + underlying, err := db.underlying(txn) + if err != nil { + return err + } + if underlying.write { + db.rmu.Lock() + event := underlying.Commit() + db.runOnCommitTriggers(ctx, txn, event) + // Mark the transaction stale after executing triggers, so they can + // perform store operations if needed. + underlying.stale = true + db.rmu.Unlock() + db.wmu.Unlock() + } else { + db.rmu.RUnlock() + } + return nil +} + +func (db *store) Abort(_ context.Context, txn storage.Transaction) { + underlying, err := db.underlying(txn) + if err != nil { + panic(err) + } + underlying.stale = true + if underlying.write { + db.wmu.Unlock() + } else { + db.rmu.RUnlock() + } +} + +func (db *store) ListPolicies(_ context.Context, txn storage.Transaction) ([]string, error) { + underlying, err := db.underlying(txn) + if err != nil { + return nil, err + } + return underlying.ListPolicies(), nil +} + +func (db *store) GetPolicy(_ context.Context, txn storage.Transaction, id string) ([]byte, error) { + underlying, err := db.underlying(txn) + if err != nil { + return nil, err + } + return underlying.GetPolicy(id) +} + +func (db *store) UpsertPolicy(_ context.Context, txn storage.Transaction, id string, bs []byte) error { + underlying, err := db.underlying(txn) + if err != nil { + return err + } + return underlying.UpsertPolicy(id, bs) +} + +func (db *store) DeletePolicy(_ context.Context, txn storage.Transaction, id string) error { + underlying, err := db.underlying(txn) + if err != nil { + return err + } + if _, err := underlying.GetPolicy(id); err != nil { + return err + } + return underlying.DeletePolicy(id) +} + +func (db *store) Register(_ context.Context, txn storage.Transaction, config storage.TriggerConfig) (storage.TriggerHandle, error) { + underlying, err := db.underlying(txn) + if err != nil { + return nil, err + } + if !underlying.write { + return nil, &storage.Error{ + Code: storage.InvalidTransactionErr, + Message: "triggers must be registered with a write transaction", + } + } + h := &handle{db} + db.triggers[h] = config + return h, nil +} + +func (db *store) Read(_ context.Context, txn storage.Transaction, path storage.Path) (interface{}, error) { + underlying, err := db.underlying(txn) + if err != nil { + return nil, err + } + + v, err := underlying.Read(path) + if err != nil { + return nil, err + } + + return v, nil +} + +func (db *store) Write(_ context.Context, txn storage.Transaction, op storage.PatchOp, path storage.Path, value interface{}) error { + underlying, err := db.underlying(txn) + if err != nil { + return err + } + val := util.Reference(value) + if db.roundTripOnWrite { + if err := util.RoundTrip(val); err != nil { + return err + } + } + return underlying.Write(op, path, *val) +} + +func (h *handle) Unregister(_ context.Context, txn storage.Transaction) { + underlying, err := h.db.underlying(txn) + if err != nil { + panic(err) + } + if !underlying.write { + panic(&storage.Error{ + Code: storage.InvalidTransactionErr, + Message: "triggers must be unregistered with a write transaction", + }) + } + delete(h.db.triggers, h) +} + +func (db *store) runOnCommitTriggers(ctx context.Context, txn storage.Transaction, event storage.TriggerEvent) { + if db.returnASTValuesOnRead && len(db.triggers) > 0 { + // FIXME: Not very performant for large data. + + dataEvents := make([]storage.DataEvent, 0, len(event.Data)) + + for _, dataEvent := range event.Data { + if astData, ok := dataEvent.Data.(ast.Value); ok { + jsn, err := ast.ValueToInterface(astData, illegalResolver{}) + if err != nil { + panic(err) + } + dataEvents = append(dataEvents, storage.DataEvent{ + Path: dataEvent.Path, + Data: jsn, + Removed: dataEvent.Removed, + }) + } else { + dataEvents = append(dataEvents, dataEvent) + } + } + + event = storage.TriggerEvent{ + Policy: event.Policy, + Data: dataEvents, + Context: event.Context, + } + } + + for _, t := range db.triggers { + t.OnCommit(ctx, txn, event) + } +} + +type illegalResolver struct{} + +func (illegalResolver) Resolve(ref ast.Ref) (interface{}, error) { + return nil, fmt.Errorf("illegal value: %v", ref) +} + +func (db *store) underlying(txn storage.Transaction) (*transaction, error) { + underlying, ok := txn.(*transaction) + if !ok { + return nil, &storage.Error{ + Code: storage.InvalidTransactionErr, + Message: fmt.Sprintf("unexpected transaction type %T", txn), + } + } + if underlying.db != db { + return nil, &storage.Error{ + Code: storage.InvalidTransactionErr, + Message: "unknown transaction", + } + } + if underlying.stale { + return nil, &storage.Error{ + Code: storage.InvalidTransactionErr, + Message: "stale transaction", + } + } + return underlying, nil +} + +const rootMustBeObjectMsg = "root must be object" +const rootCannotBeRemovedMsg = "root cannot be removed" + +func invalidPatchError(f string, a ...interface{}) *storage.Error { + return &storage.Error{ + Code: storage.InvalidPatchErr, + Message: fmt.Sprintf(f, a...), + } +} + +func mktree(path []string, value interface{}) (map[string]interface{}, error) { + if len(path) == 0 { + // For 0 length path the value is the full tree. + obj, ok := value.(map[string]interface{}) + if !ok { + return nil, invalidPatchError(rootMustBeObjectMsg) + } + return obj, nil + } + + dir := map[string]interface{}{} + for i := len(path) - 1; i > 0; i-- { + dir[path[i]] = value + value = dir + dir = map[string]interface{}{} + } + dir[path[0]] = value + + return dir, nil +} + +func lookup(path storage.Path, data map[string]interface{}) (interface{}, bool) { + if len(path) == 0 { + return data, true + } + for i := 0; i < len(path)-1; i++ { + value, ok := data[path[i]] + if !ok { + return nil, false + } + obj, ok := value.(map[string]interface{}) + if !ok { + return nil, false + } + data = obj + } + value, ok := data[path[len(path)-1]] + return value, ok +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/opts.go b/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/opts.go new file mode 100644 index 000000000..2239fc73a --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/opts.go @@ -0,0 +1,37 @@ +package inmem + +// An Opt modifies store at instantiation. +type Opt func(*store) + +// OptRoundTripOnWrite sets whether incoming objects written to store are +// round-tripped through JSON to ensure they are serializable to JSON. +// +// Callers should disable this if they can guarantee all objects passed to +// Write() are serializable to JSON. Failing to do so may result in undefined +// behavior, including panics. +// +// Usually, when only storing objects in the inmem store that have been read +// via encoding/json, this is safe to disable, and comes with an improvement +// in performance and memory use. +// +// If setting to false, callers should deep-copy any objects passed to Write() +// unless they can guarantee the objects will not be mutated after being written, +// and that mutations happening to the objects after they have been passed into +// Write() don't affect their logic. +func OptRoundTripOnWrite(enabled bool) Opt { + return func(s *store) { + s.roundTripOnWrite = enabled + } +} + +// OptReturnASTValuesOnRead sets whether data values added to the store should be +// eagerly converted to AST values, which are then returned on read. +// +// When enabled, this feature does not sanity check data before converting it to AST values, +// which may result in panics if the data is not valid. Callers should ensure that passed data +// can be serialized to AST values; otherwise, it's recommended to also enable OptRoundTripOnWrite. +func OptReturnASTValuesOnRead(enabled bool) Opt { + return func(s *store) { + s.returnASTValuesOnRead = enabled + } +} diff --git a/vendor/github.com/open-policy-agent/opa/storage/inmem/txn.go b/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/txn.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/storage/inmem/txn.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/inmem/txn.go index d3252e882..f8a730391 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/inmem/txn.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/inmem/txn.go @@ -9,11 +9,11 @@ import ( "encoding/json" "strconv" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/deepcopy" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/internal/errors" - "github.com/open-policy-agent/opa/storage/internal/ptr" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/internal/errors" + "github.com/open-policy-agent/opa/v1/storage/internal/ptr" ) // transaction implements the low-level read/write operations on the in-memory diff --git a/vendor/github.com/open-policy-agent/opa/v1/storage/interface.go b/vendor/github.com/open-policy-agent/opa/v1/storage/interface.go new file mode 100644 index 000000000..94e02a47b --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/interface.go @@ -0,0 +1,247 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package storage + +import ( + "context" + + "github.com/open-policy-agent/opa/v1/metrics" +) + +// Transaction defines the interface that identifies a consistent snapshot over +// the policy engine's storage layer. +type Transaction interface { + ID() uint64 +} + +// Store defines the interface for the storage layer's backend. +type Store interface { + Trigger + Policy + + // NewTransaction is called create a new transaction in the store. + NewTransaction(context.Context, ...TransactionParams) (Transaction, error) + + // Read is called to fetch a document referred to by path. + Read(context.Context, Transaction, Path) (interface{}, error) + + // Write is called to modify a document referred to by path. + Write(context.Context, Transaction, PatchOp, Path, interface{}) error + + // Commit is called to finish the transaction. If Commit returns an error, the + // transaction must be automatically aborted by the Store implementation. + Commit(context.Context, Transaction) error + + // Truncate is called to make a copy of the underlying store, write documents in the new store + // by creating multiple transactions in the new store as needed and finally swapping + // over to the new storage instance. This method must be called within a transaction on the original store. + Truncate(context.Context, Transaction, TransactionParams, Iterator) error + + // Abort is called to cancel the transaction. + Abort(context.Context, Transaction) +} + +// MakeDirer defines the interface a Store could realize to override the +// generic MakeDir functionality in storage.MakeDir +type MakeDirer interface { + MakeDir(context.Context, Transaction, Path) error +} + +// TransactionParams describes a new transaction. +type TransactionParams struct { + + // BasePaths indicates the top-level paths where write operations will be performed in this transaction. + BasePaths []string + + // RootOverwrite is deprecated. Use BasePaths instead. + RootOverwrite bool + + // Write indicates if this transaction will perform any write operations. + Write bool + + // Context contains key/value pairs passed to triggers. + Context *Context +} + +// Context is a simple container for key/value pairs. +type Context struct { + values map[interface{}]interface{} +} + +// NewContext returns a new context object. +func NewContext() *Context { + return &Context{ + values: map[interface{}]interface{}{}, + } +} + +// Get returns the key value in the context. +func (ctx *Context) Get(key interface{}) interface{} { + if ctx == nil { + return nil + } + return ctx.values[key] +} + +// Put adds a key/value pair to the context. +func (ctx *Context) Put(key, value interface{}) { + ctx.values[key] = value +} + +var metricsKey = struct{}{} + +// WithMetrics allows passing metrics via the Context. +// It puts the metrics object in the ctx, and returns the same +// ctx (not a copy) for convenience. +func (ctx *Context) WithMetrics(m metrics.Metrics) *Context { + ctx.values[metricsKey] = m + return ctx +} + +// Metrics() allows using a Context's metrics. Returns nil if metrics +// were not attached to the Context. +func (ctx *Context) Metrics() metrics.Metrics { + if m, ok := ctx.values[metricsKey]; ok { + if met, ok := m.(metrics.Metrics); ok { + return met + } + } + return nil +} + +// WriteParams specifies the TransactionParams for a write transaction. +var WriteParams = TransactionParams{ + Write: true, +} + +// PatchOp is the enumeration of supposed modifications. +type PatchOp int + +// Patch supports add, remove, and replace operations. +const ( + AddOp PatchOp = iota + RemoveOp = iota + ReplaceOp = iota +) + +// WritesNotSupported provides a default implementation of the write +// interface which may be used if the backend does not support writes. +type WritesNotSupported struct{} + +func (WritesNotSupported) Write(context.Context, Transaction, PatchOp, Path, interface{}) error { + return writesNotSupportedError() +} + +// Policy defines the interface for policy module storage. +type Policy interface { + ListPolicies(context.Context, Transaction) ([]string, error) + GetPolicy(context.Context, Transaction, string) ([]byte, error) + UpsertPolicy(context.Context, Transaction, string, []byte) error + DeletePolicy(context.Context, Transaction, string) error +} + +// PolicyNotSupported provides a default implementation of the policy interface +// which may be used if the backend does not support policy storage. +type PolicyNotSupported struct{} + +// ListPolicies always returns a PolicyNotSupportedErr. +func (PolicyNotSupported) ListPolicies(context.Context, Transaction) ([]string, error) { + return nil, policyNotSupportedError() +} + +// GetPolicy always returns a PolicyNotSupportedErr. +func (PolicyNotSupported) GetPolicy(context.Context, Transaction, string) ([]byte, error) { + return nil, policyNotSupportedError() +} + +// UpsertPolicy always returns a PolicyNotSupportedErr. +func (PolicyNotSupported) UpsertPolicy(context.Context, Transaction, string, []byte) error { + return policyNotSupportedError() +} + +// DeletePolicy always returns a PolicyNotSupportedErr. +func (PolicyNotSupported) DeletePolicy(context.Context, Transaction, string) error { + return policyNotSupportedError() +} + +// PolicyEvent describes a change to a policy. +type PolicyEvent struct { + ID string + Data []byte + Removed bool +} + +// DataEvent describes a change to a base data document. +type DataEvent struct { + Path Path + Data interface{} + Removed bool +} + +// TriggerEvent describes the changes that caused the trigger to be invoked. +type TriggerEvent struct { + Policy []PolicyEvent + Data []DataEvent + Context *Context +} + +// IsZero returns true if the TriggerEvent indicates no changes occurred. This +// function is primarily for test purposes. +func (e TriggerEvent) IsZero() bool { + return !e.PolicyChanged() && !e.DataChanged() +} + +// PolicyChanged returns true if the trigger was caused by a policy change. +func (e TriggerEvent) PolicyChanged() bool { + return len(e.Policy) > 0 +} + +// DataChanged returns true if the trigger was caused by a data change. +func (e TriggerEvent) DataChanged() bool { + return len(e.Data) > 0 +} + +// TriggerConfig contains the trigger registration configuration. +type TriggerConfig struct { + + // OnCommit is invoked when a transaction is successfully committed. The + // callback is invoked with a handle to the write transaction that + // successfully committed before other clients see the changes. + OnCommit func(context.Context, Transaction, TriggerEvent) +} + +// Trigger defines the interface that stores implement to register for change +// notifications when the store is changed. +type Trigger interface { + Register(context.Context, Transaction, TriggerConfig) (TriggerHandle, error) +} + +// TriggersNotSupported provides default implementations of the Trigger +// interface which may be used if the backend does not support triggers. +type TriggersNotSupported struct{} + +// Register always returns an error indicating triggers are not supported. +func (TriggersNotSupported) Register(context.Context, Transaction, TriggerConfig) (TriggerHandle, error) { + return nil, triggersNotSupportedError() +} + +// TriggerHandle defines the interface that can be used to unregister triggers that have +// been registered on a Store. +type TriggerHandle interface { + Unregister(context.Context, Transaction) +} + +// Iterator defines the interface that can be used to read files from a directory starting with +// files at the base of the directory, then sub-directories etc. +type Iterator interface { + Next() (*Update, error) +} + +// Update contains information about a file +type Update struct { + Path Path + Value []byte + IsPolicy bool +} diff --git a/vendor/github.com/open-policy-agent/opa/storage/internal/errors/errors.go b/vendor/github.com/open-policy-agent/opa/v1/storage/internal/errors/errors.go similarity index 95% rename from vendor/github.com/open-policy-agent/opa/storage/internal/errors/errors.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/internal/errors/errors.go index 0bba74b90..06063b4c7 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/internal/errors/errors.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/internal/errors/errors.go @@ -8,7 +8,7 @@ package errors import ( "fmt" - "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/v1/storage" ) const ArrayIndexTypeMsg = "array index must be integer" diff --git a/vendor/github.com/open-policy-agent/opa/storage/internal/ptr/ptr.go b/vendor/github.com/open-policy-agent/opa/v1/storage/internal/ptr/ptr.go similarity index 94% rename from vendor/github.com/open-policy-agent/opa/storage/internal/ptr/ptr.go rename to vendor/github.com/open-policy-agent/opa/v1/storage/internal/ptr/ptr.go index 14adbd682..d1c36a15a 100644 --- a/vendor/github.com/open-policy-agent/opa/storage/internal/ptr/ptr.go +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/internal/ptr/ptr.go @@ -8,9 +8,9 @@ package ptr import ( "strconv" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/internal/errors" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/internal/errors" ) func Ptr(data interface{}, path storage.Path) (interface{}, error) { diff --git a/vendor/github.com/open-policy-agent/opa/v1/storage/path.go b/vendor/github.com/open-policy-agent/opa/v1/storage/path.go new file mode 100644 index 000000000..7f90c666b --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/path.go @@ -0,0 +1,150 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package storage + +import ( + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/open-policy-agent/opa/v1/ast" +) + +// Path refers to a document in storage. +type Path []string + +// ParsePath returns a new path for the given str. +func ParsePath(str string) (path Path, ok bool) { + if len(str) == 0 { + return nil, false + } + if str[0] != '/' { + return nil, false + } + if len(str) == 1 { + return Path{}, true + } + parts := strings.Split(str[1:], "/") + return parts, true +} + +// ParsePathEscaped returns a new path for the given escaped str. +func ParsePathEscaped(str string) (path Path, ok bool) { + path, ok = ParsePath(str) + if !ok { + return + } + for i := range path { + segment, err := url.PathUnescape(path[i]) + if err == nil { + path[i] = segment + } + } + return +} + +// NewPathForRef returns a new path for the given ref. +func NewPathForRef(ref ast.Ref) (path Path, err error) { + + if len(ref) == 0 { + return nil, fmt.Errorf("empty reference (indicates error in caller)") + } + + if len(ref) == 1 { + return Path{}, nil + } + + path = make(Path, 0, len(ref)-1) + + for _, term := range ref[1:] { + switch v := term.Value.(type) { + case ast.String: + path = append(path, string(v)) + case ast.Number: + path = append(path, v.String()) + case ast.Boolean, ast.Null: + return nil, &Error{ + Code: NotFoundErr, + Message: fmt.Sprintf("%v: does not exist", ref), + } + case *ast.Array, ast.Object, ast.Set: + return nil, fmt.Errorf("composites cannot be base document keys: %v", ref) + default: + return nil, fmt.Errorf("unresolved reference (indicates error in caller): %v", ref) + } + } + + return path, nil +} + +// Compare performs lexigraphical comparison on p and other and returns -1 if p +// is less than other, 0 if p is equal to other, or 1 if p is greater than +// other. +func (p Path) Compare(other Path) (cmp int) { + for i := 0; i < min(len(p), len(other)); i++ { + if cmp := strings.Compare(p[i], other[i]); cmp != 0 { + return cmp + } + } + if len(p) < len(other) { + return -1 + } + if len(p) == len(other) { + return 0 + } + return 1 +} + +// Equal returns true if p is the same as other. +func (p Path) Equal(other Path) bool { + return p.Compare(other) == 0 +} + +// HasPrefix returns true if p starts with other. +func (p Path) HasPrefix(other Path) bool { + if len(other) > len(p) { + return false + } + for i := range other { + if p[i] != other[i] { + return false + } + } + return true +} + +// Ref returns a ref that represents p rooted at head. +func (p Path) Ref(head *ast.Term) (ref ast.Ref) { + ref = make(ast.Ref, len(p)+1) + ref[0] = head + for i := range p { + idx, err := strconv.ParseInt(p[i], 10, 64) + if err == nil { + ref[i+1] = ast.UIntNumberTerm(uint64(idx)) + } else { + ref[i+1] = ast.StringTerm(p[i]) + } + } + return ref +} + +func (p Path) String() string { + buf := make([]string, len(p)) + for i := range buf { + buf[i] = url.PathEscape(p[i]) + } + return "/" + strings.Join(buf, "/") +} + +// MustParsePath returns a new Path for s. If s cannot be parsed, this function +// will panic. This is mostly for test purposes. +func MustParsePath(s string) Path { + path, ok := ParsePath(s) + if !ok { + panic(s) + } + return path +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/storage/storage.go b/vendor/github.com/open-policy-agent/opa/v1/storage/storage.go new file mode 100644 index 000000000..34305f291 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/storage/storage.go @@ -0,0 +1,136 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package storage + +import ( + "context" + + "github.com/open-policy-agent/opa/v1/ast" +) + +// NewTransactionOrDie is a helper function to create a new transaction. If the +// storage layer cannot create a new transaction, this function will panic. This +// function should only be used for tests. +func NewTransactionOrDie(ctx context.Context, store Store, params ...TransactionParams) Transaction { + txn, err := store.NewTransaction(ctx, params...) + if err != nil { + panic(err) + } + return txn +} + +// ReadOne is a convenience function to read a single value from the provided Store. It +// will create a new Transaction to perform the read with, and clean up after itself +// should an error occur. +func ReadOne(ctx context.Context, store Store, path Path) (interface{}, error) { + txn, err := store.NewTransaction(ctx) + if err != nil { + return nil, err + } + defer store.Abort(ctx, txn) + + return store.Read(ctx, txn, path) +} + +// WriteOne is a convenience function to write a single value to the provided Store. It +// will create a new Transaction to perform the write with, and clean up after itself +// should an error occur. +func WriteOne(ctx context.Context, store Store, op PatchOp, path Path, value interface{}) error { + txn, err := store.NewTransaction(ctx, WriteParams) + if err != nil { + return err + } + + if err := store.Write(ctx, txn, op, path, value); err != nil { + store.Abort(ctx, txn) + return err + } + + return store.Commit(ctx, txn) +} + +// MakeDir inserts an empty object at path. If the parent path does not exist, +// MakeDir will create it recursively. +func MakeDir(ctx context.Context, store Store, txn Transaction, path Path) error { + + // Allow the Store implementation to deal with this in its own way. + if md, ok := store.(MakeDirer); ok { + return md.MakeDir(ctx, txn, path) + } + + if len(path) == 0 { + return nil + } + + node, err := store.Read(ctx, txn, path) + if err != nil { + if !IsNotFound(err) { + return err + } + + if err := MakeDir(ctx, store, txn, path[:len(path)-1]); err != nil { + return err + } + + return store.Write(ctx, txn, AddOp, path, map[string]interface{}{}) + } + + if _, ok := node.(map[string]interface{}); ok { + return nil + } + + if _, ok := node.(ast.Object); ok { + return nil + } + + return writeConflictError(path) +} + +// Txn is a convenience function that executes f inside a new transaction +// opened on the store. If the function returns an error, the transaction is +// aborted and the error is returned. Otherwise, the transaction is committed +// and the result of the commit is returned. +func Txn(ctx context.Context, store Store, params TransactionParams, f func(Transaction) error) error { + + txn, err := store.NewTransaction(ctx, params) + if err != nil { + return err + } + + if err := f(txn); err != nil { + store.Abort(ctx, txn) + return err + } + + return store.Commit(ctx, txn) +} + +// NonEmpty returns a function that tests if a path is non-empty. A +// path is non-empty if a Read on the path returns a value or a Read +// on any of the path prefixes returns a non-object value. +func NonEmpty(ctx context.Context, store Store, txn Transaction) func([]string) (bool, error) { + return func(path []string) (bool, error) { + if _, err := store.Read(ctx, txn, Path(path)); err == nil { + return true, nil + } else if !IsNotFound(err) { + return false, err + } + for i := len(path) - 1; i > 0; i-- { + val, err := store.Read(ctx, txn, Path(path[:i])) + if err != nil && !IsNotFound(err) { + return false, err + } else if err == nil { + if _, ok := val.(map[string]interface{}); ok { + return false, nil + } + if _, ok := val.(ast.Object); ok { + return false, nil + } + return true, nil + } + } + return false, nil + } +} diff --git a/vendor/github.com/open-policy-agent/opa/tester/reporter.go b/vendor/github.com/open-policy-agent/opa/v1/tester/reporter.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/tester/reporter.go rename to vendor/github.com/open-policy-agent/opa/v1/tester/reporter.go index c7ab9cbc9..1e7bd1e02 100644 --- a/vendor/github.com/open-policy-agent/opa/tester/reporter.go +++ b/vendor/github.com/open-policy-agent/opa/v1/tester/reporter.go @@ -11,9 +11,9 @@ import ( "io" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/cover" - "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/cover" + "github.com/open-policy-agent/opa/v1/topdown" ) // Reporter defines the interface for reporting test results. diff --git a/vendor/github.com/open-policy-agent/opa/tester/runner.go b/vendor/github.com/open-policy-agent/opa/v1/tester/runner.go similarity index 85% rename from vendor/github.com/open-policy-agent/opa/tester/runner.go rename to vendor/github.com/open-policy-agent/opa/v1/tester/runner.go index 25c23abf4..4522ffcef 100644 --- a/vendor/github.com/open-policy-agent/opa/tester/runner.go +++ b/vendor/github.com/open-policy-agent/opa/v1/tester/runner.go @@ -15,15 +15,15 @@ import ( "testing" "time" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/bundle" wasm_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors" - "github.com/open-policy-agent/opa/loader" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/storage/inmem" - "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/bundle" + "github.com/open-policy-agent/opa/v1/loader" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/inmem" + "github.com/open-policy-agent/opa/v1/topdown" ) // TestPrefix declares the prefix for all test rules. @@ -126,15 +126,24 @@ type Runner struct { filter string target string // target type (wasm, rego, etc.) customBuiltins []*Builtin + defaultRegoVersion ast.RegoVersion } // NewRunner returns a new runner. func NewRunner() *Runner { return &Runner{ - timeout: 5 * time.Second, + timeout: 5 * time.Second, + defaultRegoVersion: ast.DefaultRegoVersion, } } +// SetDefaultRegoVersion sets the default Rego version to use when compiling modules. +// Not applicable if a custom [ast.Compiler] is set via [SetCompiler]. +func (r *Runner) SetDefaultRegoVersion(v ast.RegoVersion) *Runner { + r.defaultRegoVersion = v + return r +} + // SetCompiler sets the compiler used by the runner. func (r *Runner) SetCompiler(compiler *ast.Compiler) *Runner { r.compiler = compiler @@ -303,7 +312,8 @@ func (r *Runner) runTests(ctx context.Context, txn storage.Transaction, enablePr r.compiler = ast.NewCompiler(). WithCapabilities(capabilities). - WithEnablePrintStatements(enablePrintStatements) + WithEnablePrintStatements(enablePrintStatements). + WithDefaultRegoVersion(r.defaultRegoVersion) } // rewrite duplicate test_* rule names as we compile modules @@ -317,7 +327,7 @@ func (r *Runner) runTests(ctx context.Context, txn storage.Transaction, enablePr r.store = inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false)) } - if r.bundles != nil && len(r.bundles) > 0 { + if len(r.bundles) > 0 { if txn == nil { return nil, fmt.Errorf("unable to activate bundles: storage transaction is nil") } @@ -612,6 +622,10 @@ func Load(args []string, filter loader.Filter) (map[string]*ast.Module, storage. // LoadWithRegoVersion returns modules and an in-memory store for running tests. // Modules are parsed in accordance with the given RegoVersion. func LoadWithRegoVersion(args []string, filter loader.Filter, regoVersion ast.RegoVersion) (map[string]*ast.Module, storage.Store, error) { + if regoVersion == ast.RegoUndefined { + regoVersion = ast.DefaultRegoVersion + } + loaded, err := loader.NewFileLoader(). WithRegoVersion(regoVersion). WithProcessAnnotation(true). @@ -639,6 +653,38 @@ func LoadWithRegoVersion(args []string, filter loader.Filter, regoVersion ast.Re return modules, store, err } +// LoadWithParserOptions returns modules and an in-memory store for running tests. +// Modules are parsed in accordance with the given [ast.ParserOptions]. +func LoadWithParserOptions(args []string, filter loader.Filter, popts ast.ParserOptions) (map[string]*ast.Module, storage.Store, error) { + loaded, err := loader.NewFileLoader(). + WithRegoVersion(popts.RegoVersion). + WithCapabilities(popts.Capabilities). + WithProcessAnnotation(popts.ProcessAnnotation). + WithJSONOptions(popts.JSONOptions). + Filtered(args, filter) + if err != nil { + return nil, nil, err + } + store := inmem.NewFromObject(loaded.Documents) + modules := map[string]*ast.Module{} + ctx := context.Background() + err = storage.Txn(ctx, store, storage.WriteParams, func(txn storage.Transaction) error { + for _, loadedModule := range loaded.Modules { + modules[loadedModule.Name] = loadedModule.Parsed + + // Add the policies to the store to ensure that any future bundle + // activations will preserve them and re-compile the module with + // the bundle modules. + err := store.UpsertPolicy(ctx, txn, loadedModule.Name, loadedModule.Raw) + if err != nil { + return err + } + } + return nil + }) + return modules, store, err +} + // LoadBundles will load the given args as bundles, either tarball or directory is OK. func LoadBundles(args []string, filter loader.Filter) (map[string]*bundle.Bundle, error) { return LoadBundlesWithRegoVersion(args, filter, ast.RegoV0) @@ -647,6 +693,10 @@ func LoadBundles(args []string, filter loader.Filter) (map[string]*bundle.Bundle // LoadBundlesWithRegoVersion will load the given args as bundles, either tarball or directory is OK. // Bundles are parsed in accordance with the given RegoVersion. func LoadBundlesWithRegoVersion(args []string, filter loader.Filter, regoVersion ast.RegoVersion) (map[string]*bundle.Bundle, error) { + if regoVersion == ast.RegoUndefined { + regoVersion = ast.DefaultRegoVersion + } + bundles := map[string]*bundle.Bundle{} for _, bundleDir := range args { b, err := loader.NewFileLoader(). @@ -663,3 +713,29 @@ func LoadBundlesWithRegoVersion(args []string, filter loader.Filter, regoVersion return bundles, nil } + +// LoadBundlesWithParserOptions will load the given args as bundles, either tarball or directory is OK. +// Bundles are parsed in accordance with the given [ast.ParserOptions]. +func LoadBundlesWithParserOptions(args []string, filter loader.Filter, popts ast.ParserOptions) (map[string]*bundle.Bundle, error) { + if popts.RegoVersion == ast.RegoUndefined { + popts.RegoVersion = ast.DefaultRegoVersion + } + + bundles := map[string]*bundle.Bundle{} + for _, bundleDir := range args { + b, err := loader.NewFileLoader(). + WithRegoVersion(popts.RegoVersion). + WithCapabilities(popts.Capabilities). + WithProcessAnnotation(popts.ProcessAnnotation). + WithJSONOptions(popts.JSONOptions). + WithSkipBundleVerification(true). + WithFilter(filter). + AsBundle(bundleDir) + if err != nil { + return nil, fmt.Errorf("unable to load bundle %s: %s", bundleDir, err) + } + bundles[bundleDir] = b + } + + return bundles, nil +} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/aggregates.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/aggregates.go similarity index 84% rename from vendor/github.com/open-policy-agent/opa/topdown/aggregates.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/aggregates.go index a0f67a7c9..e7d057822 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/aggregates.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/aggregates.go @@ -7,20 +7,20 @@ package topdown import ( "math/big" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func builtinCount(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch a := operands[0].Value.(type) { case *ast.Array: - return iter(ast.IntNumberTerm(a.Len())) + return iter(ast.InternedIntNumberTerm(a.Len())) case ast.Object: - return iter(ast.IntNumberTerm(a.Len())) + return iter(ast.InternedIntNumberTerm(a.Len())) case ast.Set: - return iter(ast.IntNumberTerm(a.Len())) + return iter(ast.InternedIntNumberTerm(a.Len())) case ast.String: - return iter(ast.IntNumberTerm(len([]rune(a)))) + return iter(ast.InternedIntNumberTerm(len([]rune(a)))) } return builtins.NewOperandTypeErr(1, operands[0].Value, "array", "object", "set", "string") } @@ -178,7 +178,7 @@ func builtinAll(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err switch val := operands[0].Value.(type) { case ast.Set: res := true - match := ast.BooleanTerm(true) + match := ast.InternedBooleanTerm(true) val.Until(func(term *ast.Term) bool { if !match.Equal(term) { res = false @@ -186,10 +186,10 @@ func builtinAll(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err } return false }) - return iter(ast.BooleanTerm(res)) + return iter(ast.InternedBooleanTerm(res)) case *ast.Array: res := true - match := ast.BooleanTerm(true) + match := ast.InternedBooleanTerm(true) val.Until(func(term *ast.Term) bool { if !match.Equal(term) { res = false @@ -197,7 +197,7 @@ func builtinAll(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err } return false }) - return iter(ast.BooleanTerm(res)) + return iter(ast.InternedBooleanTerm(res)) default: return builtins.NewOperandTypeErr(1, operands[0].Value, "array", "set") } @@ -206,11 +206,11 @@ func builtinAll(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err func builtinAny(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch val := operands[0].Value.(type) { case ast.Set: - res := val.Len() > 0 && val.Contains(ast.BooleanTerm(true)) - return iter(ast.BooleanTerm(res)) + res := val.Len() > 0 && val.Contains(ast.InternedBooleanTerm(true)) + return iter(ast.InternedBooleanTerm(res)) case *ast.Array: res := false - match := ast.BooleanTerm(true) + match := ast.InternedBooleanTerm(true) val.Until(func(term *ast.Term) bool { if match.Equal(term) { res = true @@ -218,7 +218,7 @@ func builtinAny(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err } return false }) - return iter(ast.BooleanTerm(res)) + return iter(ast.InternedBooleanTerm(res)) default: return builtins.NewOperandTypeErr(1, operands[0].Value, "array", "set") } @@ -228,27 +228,20 @@ func builtinMember(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) containee := operands[0] switch c := operands[1].Value.(type) { case ast.Set: - return iter(ast.BooleanTerm(c.Contains(containee))) + return iter(ast.InternedBooleanTerm(c.Contains(containee))) case *ast.Array: - ret := false - c.Until(func(v *ast.Term) bool { - if v.Value.Compare(containee.Value) == 0 { - ret = true + for i := 0; i < c.Len(); i++ { + if c.Elem(i).Value.Compare(containee.Value) == 0 { + return iter(ast.InternedBooleanTerm(true)) } - return ret - }) - return iter(ast.BooleanTerm(ret)) + } + return iter(ast.InternedBooleanTerm(false)) case ast.Object: - ret := false - c.Until(func(_, v *ast.Term) bool { - if v.Value.Compare(containee.Value) == 0 { - ret = true - } - return ret - }) - return iter(ast.BooleanTerm(ret)) + return iter(ast.InternedBooleanTerm(c.Until(func(_, v *ast.Term) bool { + return v.Value.Compare(containee.Value) == 0 + }))) } - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } func builtinMemberWithKey(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -259,9 +252,9 @@ func builtinMemberWithKey(_ BuiltinContext, operands []*ast.Term, iter func(*ast if act := c.Get(key); act != nil { ret = act.Value.Compare(val.Value) == 0 } - return iter(ast.BooleanTerm(ret)) + return iter(ast.InternedBooleanTerm(ret)) } - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } func init() { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/arithmetic.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/arithmetic.go similarity index 95% rename from vendor/github.com/open-policy-agent/opa/topdown/arithmetic.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/arithmetic.go index 3ac703efa..68c3b496e 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/arithmetic.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/arithmetic.go @@ -8,8 +8,8 @@ import ( "fmt" "math/big" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) type arithArity1 func(a *big.Float) (*big.Float, error) @@ -67,7 +67,7 @@ func builtinPlus(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) er y, ok2 := n2.Int() if ok1 && ok2 && inSmallIntRange(x) && inSmallIntRange(y) { - return iter(ast.IntNumberTerm(x + y)) + return iter(ast.InternedIntNumberTerm(x + y)) } f, err := arithPlus(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2)) @@ -91,7 +91,7 @@ func builtinMultiply(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term y, ok2 := n2.Int() if ok1 && ok2 && inSmallIntRange(x) && inSmallIntRange(y) { - return iter(ast.IntNumberTerm(x * y)) + return iter(ast.InternedIntNumberTerm(x * y)) } f, err := arithMultiply(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2)) @@ -171,7 +171,7 @@ func builtinMinus(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) e y, oky := n2.Int() if okx && oky && inSmallIntRange(x) && inSmallIntRange(y) { - return iter(ast.IntNumberTerm(x - y)) + return iter(ast.InternedIntNumberTerm(x - y)) } f, err := arithMinus(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2)) @@ -213,7 +213,7 @@ func builtinRem(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err return fmt.Errorf("modulo by zero") } - return iter(ast.IntNumberTerm(x % y)) + return iter(ast.InternedIntNumberTerm(x % y)) } op1, err1 := builtins.NumberToInt(n1) diff --git a/vendor/github.com/open-policy-agent/opa/topdown/array.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/array.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/topdown/array.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/array.go index e7fe5be64..d37204bef 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/array.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/array.go @@ -5,8 +5,8 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func builtinArrayConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -33,7 +33,7 @@ func builtinArrayConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T i++ }) - return iter(ast.NewTerm(ast.NewArray(arrC...))) + return iter(ast.ArrayTerm(arrC...)) } func builtinArraySlice(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/binary.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/binary.go similarity index 90% rename from vendor/github.com/open-policy-agent/opa/topdown/binary.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/binary.go index b4f9dbd39..6f7ebaf40 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/binary.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/binary.go @@ -5,8 +5,8 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func builtinBinaryAnd(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/bindings.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/bindings.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/topdown/bindings.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/bindings.go index 30a8ac5ec..ae6ca15da 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/bindings.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/bindings.go @@ -8,7 +8,7 @@ import ( "fmt" "strings" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) type undo struct { @@ -68,7 +68,7 @@ func (u *bindings) Plug(a *ast.Term) *ast.Term { } func (u *bindings) PlugNamespaced(a *ast.Term, caller *bindings) *ast.Term { - if u != nil { + if u != nil && u.instr != nil { u.instr.startTimer(evalOpPlug) t := u.plugNamespaced(a, caller) u.instr.stopTimer(evalOpPlug) diff --git a/vendor/github.com/open-policy-agent/opa/topdown/bits.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/bits.go similarity index 96% rename from vendor/github.com/open-policy-agent/opa/topdown/bits.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/bits.go index 7a63c0df1..e420ffe61 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/bits.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/bits.go @@ -7,8 +7,8 @@ package topdown import ( "math/big" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) type bitsArity1 func(a *big.Int) (*big.Int, error) diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/builtins.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/builtins.go new file mode 100644 index 000000000..e0b893d47 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/builtins.go @@ -0,0 +1,224 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "math/rand" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/print" + "github.com/open-policy-agent/opa/v1/tracing" +) + +type ( + // Deprecated: Functional-style builtins are deprecated. Use BuiltinFunc instead. + FunctionalBuiltin1 func(op1 ast.Value) (output ast.Value, err error) + + // Deprecated: Functional-style builtins are deprecated. Use BuiltinFunc instead. + FunctionalBuiltin2 func(op1, op2 ast.Value) (output ast.Value, err error) + + // Deprecated: Functional-style builtins are deprecated. Use BuiltinFunc instead. + FunctionalBuiltin3 func(op1, op2, op3 ast.Value) (output ast.Value, err error) + + // Deprecated: Functional-style builtins are deprecated. Use BuiltinFunc instead. + FunctionalBuiltin4 func(op1, op2, op3, op4 ast.Value) (output ast.Value, err error) + + // BuiltinContext contains context from the evaluator that may be used by + // built-in functions. + BuiltinContext struct { + Context context.Context // request context that was passed when query started + Metrics metrics.Metrics // metrics registry for recording built-in specific metrics + Seed io.Reader // randomization source + Time *ast.Term // wall clock time + Cancel Cancel // atomic value that signals evaluation to halt + Runtime *ast.Term // runtime information on the OPA instance + Cache builtins.Cache // built-in function state cache + InterQueryBuiltinCache cache.InterQueryCache // cross-query built-in function state cache + InterQueryBuiltinValueCache cache.InterQueryValueCache // cross-query built-in function state value cache. this cache is useful for scenarios where the entry size cannot be calculated + NDBuiltinCache builtins.NDBCache // cache for non-deterministic built-in state + Location *ast.Location // location of built-in call + Tracers []Tracer // Deprecated: Use QueryTracers instead + QueryTracers []QueryTracer // tracer objects for trace() built-in function + TraceEnabled bool // indicates whether tracing is enabled for the evaluation + QueryID uint64 // identifies query being evaluated + ParentID uint64 // identifies parent of query being evaluated + PrintHook print.Hook // provides callback function to use for printing + RoundTripper CustomizeRoundTripper // customize transport to use for HTTP requests + DistributedTracingOpts tracing.Options // options to be used by distributed tracing. + rand *rand.Rand // randomization source for non-security-sensitive operations + Capabilities *ast.Capabilities + } + + // BuiltinFunc defines an interface for implementing built-in functions. + // The built-in function is called with the plugged operands from the call + // (including the output operands.) The implementation should evaluate the + // operands and invoke the iterator for each successful/defined output + // value. + BuiltinFunc func(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error +) + +// Rand returns a random number generator based on the Seed for this built-in +// context. The random number will be re-used across multiple calls to this +// function. If a random number generator cannot be created, an error is +// returned. +func (bctx *BuiltinContext) Rand() (*rand.Rand, error) { + + if bctx.rand != nil { + return bctx.rand, nil + } + + seed, err := readInt64(bctx.Seed) + if err != nil { + return nil, err + } + + bctx.rand = rand.New(rand.NewSource(seed)) + return bctx.rand, nil +} + +// RegisterBuiltinFunc adds a new built-in function to the evaluation engine. +func RegisterBuiltinFunc(name string, f BuiltinFunc) { + builtinFunctions[name] = builtinErrorWrapper(name, f) +} + +// Deprecated: Functional-style builtins are deprecated. Use RegisterBuiltinFunc instead. +func RegisterFunctionalBuiltin1(name string, fun FunctionalBuiltin1) { + builtinFunctions[name] = functionalWrapper1(name, fun) +} + +// Deprecated: Functional-style builtins are deprecated. Use RegisterBuiltinFunc instead. +func RegisterFunctionalBuiltin2(name string, fun FunctionalBuiltin2) { + builtinFunctions[name] = functionalWrapper2(name, fun) +} + +// Deprecated: Functional-style builtins are deprecated. Use RegisterBuiltinFunc instead. +func RegisterFunctionalBuiltin3(name string, fun FunctionalBuiltin3) { + builtinFunctions[name] = functionalWrapper3(name, fun) +} + +// Deprecated: Functional-style builtins are deprecated. Use RegisterBuiltinFunc instead. +func RegisterFunctionalBuiltin4(name string, fun FunctionalBuiltin4) { + builtinFunctions[name] = functionalWrapper4(name, fun) +} + +// GetBuiltin returns a built-in function implementation, nil if no built-in found. +func GetBuiltin(name string) BuiltinFunc { + return builtinFunctions[name] +} + +// Deprecated: The BuiltinEmpty type is no longer needed. Use nil return values instead. +type BuiltinEmpty struct{} + +func (BuiltinEmpty) Error() string { + return "" +} + +var builtinFunctions = map[string]BuiltinFunc{} + +func builtinErrorWrapper(name string, fn BuiltinFunc) BuiltinFunc { + return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { + err := fn(bctx, args, iter) + if err == nil { + return nil + } + return handleBuiltinErr(name, bctx.Location, err) + } +} + +func functionalWrapper1(name string, fn FunctionalBuiltin1) BuiltinFunc { + return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { + result, err := fn(args[0].Value) + if err == nil { + return iter(ast.NewTerm(result)) + } + return handleBuiltinErr(name, bctx.Location, err) + } +} + +func functionalWrapper2(name string, fn FunctionalBuiltin2) BuiltinFunc { + return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { + result, err := fn(args[0].Value, args[1].Value) + if err == nil { + return iter(ast.NewTerm(result)) + } + return handleBuiltinErr(name, bctx.Location, err) + } +} + +func functionalWrapper3(name string, fn FunctionalBuiltin3) BuiltinFunc { + return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { + result, err := fn(args[0].Value, args[1].Value, args[2].Value) + if err == nil { + return iter(ast.NewTerm(result)) + } + return handleBuiltinErr(name, bctx.Location, err) + } +} + +func functionalWrapper4(name string, fn FunctionalBuiltin4) BuiltinFunc { + return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { + result, err := fn(args[0].Value, args[1].Value, args[2].Value, args[3].Value) + if err == nil { + return iter(ast.NewTerm(result)) + } + if _, empty := err.(BuiltinEmpty); empty { + return nil + } + return handleBuiltinErr(name, bctx.Location, err) + } +} + +func handleBuiltinErr(name string, loc *ast.Location, err error) error { + switch err := err.(type) { + case BuiltinEmpty: + return nil + case *Error, Halt: + return err + case builtins.ErrOperand: + e := &Error{ + Code: TypeErr, + Message: fmt.Sprintf("%v: %v", name, err.Error()), + Location: loc, + } + return e.Wrap(err) + default: + e := &Error{ + Code: BuiltinErr, + Message: fmt.Sprintf("%v: %v", name, err.Error()), + Location: loc, + } + return e.Wrap(err) + } +} + +func readInt64(r io.Reader) (int64, error) { + bs := make([]byte, 8) + n, err := io.ReadFull(r, bs) + if n != len(bs) || err != nil { + return 0, err + } + return int64(binary.BigEndian.Uint64(bs)), nil +} + +// Used to get older-style (ast.Term, error) tuples out of newer functions. +func getResult(fn BuiltinFunc, operands ...*ast.Term) (*ast.Term, error) { + var result *ast.Term + extractionFn := func(r *ast.Term) error { + result = r + return nil + } + err := fn(BuiltinContext{}, operands, extractionFn) + if err != nil { + return nil, err + } + return result, nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/builtins/builtins.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/builtins/builtins.go new file mode 100644 index 000000000..c788cf253 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/builtins/builtins.go @@ -0,0 +1,328 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package builtins contains utilities for implementing built-in functions. +package builtins + +import ( + "encoding/json" + "fmt" + "math/big" + "strings" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" +) + +// Cache defines the built-in cache used by the top-down evaluation. The keys +// must be comparable and should not be of type string. +type Cache map[interface{}]interface{} + +// Put updates the cache for the named built-in. +func (c Cache) Put(k, v interface{}) { + c[k] = v +} + +// Get returns the cached value for k. +func (c Cache) Get(k interface{}) (interface{}, bool) { + v, ok := c[k] + return v, ok +} + +// We use an ast.Object for the cached keys/values because a naive +// map[ast.Value]ast.Value will not correctly detect value equality of +// the member keys. +type NDBCache map[string]ast.Object + +func (c NDBCache) AsValue() ast.Value { + out := ast.NewObject() + for bname, obj := range c { + out.Insert(ast.StringTerm(bname), ast.NewTerm(obj)) + } + return out +} + +// Put updates the cache for the named built-in. +// Automatically creates the 2-level hierarchy as needed. +func (c NDBCache) Put(name string, k, v ast.Value) { + if _, ok := c[name]; !ok { + c[name] = ast.NewObject() + } + c[name].Insert(ast.NewTerm(k), ast.NewTerm(v)) +} + +// Get returns the cached value for k for the named builtin. +func (c NDBCache) Get(name string, k ast.Value) (ast.Value, bool) { + if m, ok := c[name]; ok { + v := m.Get(ast.NewTerm(k)) + if v != nil { + return v.Value, true + } + return nil, false + } + return nil, false +} + +// Convenience functions for serializing the data structure. +func (c NDBCache) MarshalJSON() ([]byte, error) { + v, err := ast.JSON(c.AsValue()) + if err != nil { + return nil, err + } + return json.Marshal(v) +} + +func (c *NDBCache) UnmarshalJSON(data []byte) error { + out := map[string]ast.Object{} + var incoming interface{} + + // Note: We use util.Unmarshal instead of json.Unmarshal to get + // correct deserialization of number types. + err := util.Unmarshal(data, &incoming) + if err != nil { + return err + } + + // Convert interface types back into ast.Value types. + nestedObject, err := ast.InterfaceToValue(incoming) + if err != nil { + return err + } + + // Reconstruct NDBCache from nested ast.Object structure. + if source, ok := nestedObject.(ast.Object); ok { + err = source.Iter(func(k, v *ast.Term) error { + if obj, ok := v.Value.(ast.Object); ok { + out[string(k.Value.(ast.String))] = obj + return nil + } + return fmt.Errorf("expected Object, got other Value type in conversion") + }) + if err != nil { + return err + } + } + + *c = out + + return nil +} + +// ErrOperand represents an invalid operand has been passed to a built-in +// function. Built-ins should return ErrOperand to indicate a type error has +// occurred. +type ErrOperand string + +func (err ErrOperand) Error() string { + return string(err) +} + +// NewOperandErr returns a generic operand error. +func NewOperandErr(pos int, f string, a ...interface{}) error { + f = fmt.Sprintf("operand %v ", pos) + f + return ErrOperand(fmt.Sprintf(f, a...)) +} + +// NewOperandTypeErr returns an operand error indicating the operand's type was wrong. +func NewOperandTypeErr(pos int, got ast.Value, expected ...string) error { + + if len(expected) == 1 { + return NewOperandErr(pos, "must be %v but got %v", expected[0], ast.TypeName(got)) + } + + return NewOperandErr(pos, "must be one of {%v} but got %v", strings.Join(expected, ", "), ast.TypeName(got)) +} + +// NewOperandElementErr returns an operand error indicating an element in the +// composite operand was wrong. +func NewOperandElementErr(pos int, composite ast.Value, got ast.Value, expected ...string) error { + + tpe := ast.TypeName(composite) + + if len(expected) == 1 { + return NewOperandErr(pos, "must be %v of %vs but got %v containing %v", tpe, expected[0], tpe, ast.TypeName(got)) + } + + return NewOperandErr(pos, "must be %v of (any of) {%v} but got %v containing %v", tpe, strings.Join(expected, ", "), tpe, ast.TypeName(got)) +} + +// NewOperandEnumErr returns an operand error indicating a value was wrong. +func NewOperandEnumErr(pos int, expected ...string) error { + + if len(expected) == 1 { + return NewOperandErr(pos, "must be %v", expected[0]) + } + + return NewOperandErr(pos, "must be one of {%v}", strings.Join(expected, ", ")) +} + +// IntOperand converts x to an int. If the cast fails, a descriptive error is +// returned. +func IntOperand(x ast.Value, pos int) (int, error) { + n, ok := x.(ast.Number) + if !ok { + return 0, NewOperandTypeErr(pos, x, "number") + } + + i, ok := n.Int() + if !ok { + return 0, NewOperandErr(pos, "must be integer number but got floating-point number") + } + + return i, nil +} + +// BigIntOperand converts x to a big int. If the cast fails, a descriptive error +// is returned. +func BigIntOperand(x ast.Value, pos int) (*big.Int, error) { + n, err := NumberOperand(x, 1) + if err != nil { + return nil, NewOperandTypeErr(pos, x, "integer") + } + bi, err := NumberToInt(n) + if err != nil { + return nil, NewOperandErr(pos, "must be integer number but got floating-point number") + } + + return bi, nil +} + +// NumberOperand converts x to a number. If the cast fails, a descriptive error is +// returned. +func NumberOperand(x ast.Value, pos int) (ast.Number, error) { + n, ok := x.(ast.Number) + if !ok { + return ast.Number(""), NewOperandTypeErr(pos, x, "number") + } + return n, nil +} + +// SetOperand converts x to a set. If the cast fails, a descriptive error is +// returned. +func SetOperand(x ast.Value, pos int) (ast.Set, error) { + s, ok := x.(ast.Set) + if !ok { + return nil, NewOperandTypeErr(pos, x, "set") + } + return s, nil +} + +// StringOperand converts x to a string. If the cast fails, a descriptive error is +// returned. +func StringOperand(x ast.Value, pos int) (ast.String, error) { + s, ok := x.(ast.String) + if !ok { + return ast.String(""), NewOperandTypeErr(pos, x, "string") + } + return s, nil +} + +// ObjectOperand converts x to an object. If the cast fails, a descriptive +// error is returned. +func ObjectOperand(x ast.Value, pos int) (ast.Object, error) { + o, ok := x.(ast.Object) + if !ok { + return nil, NewOperandTypeErr(pos, x, "object") + } + return o, nil +} + +// ArrayOperand converts x to an array. If the cast fails, a descriptive +// error is returned. +func ArrayOperand(x ast.Value, pos int) (*ast.Array, error) { + a, ok := x.(*ast.Array) + if !ok { + return ast.NewArray(), NewOperandTypeErr(pos, x, "array") + } + return a, nil +} + +// NumberToFloat converts n to a big float. +func NumberToFloat(n ast.Number) *big.Float { + r, ok := new(big.Float).SetString(string(n)) + if !ok { + panic("illegal value") + } + return r +} + +// FloatToNumber converts f to a number. +func FloatToNumber(f *big.Float) ast.Number { + var format byte = 'g' + if f.IsInt() { + format = 'f' + } + return ast.Number(f.Text(format, -1)) +} + +// NumberToInt converts n to a big int. +// If n cannot be converted to an big int, an error is returned. +func NumberToInt(n ast.Number) (*big.Int, error) { + f := NumberToFloat(n) + r, accuracy := f.Int(nil) + if accuracy != big.Exact { + return nil, fmt.Errorf("illegal value") + } + return r, nil +} + +// IntToNumber converts i to a number. +func IntToNumber(i *big.Int) ast.Number { + return ast.Number(i.String()) +} + +// StringSliceOperand converts x to a []string. If the cast fails, a descriptive error is +// returned. +func StringSliceOperand(a ast.Value, pos int) ([]string, error) { + type iterable interface { + Iter(func(*ast.Term) error) error + Len() int + } + + strs, ok := a.(iterable) + if !ok { + return nil, NewOperandTypeErr(pos, a, "array", "set") + } + + var outStrs = make([]string, 0, strs.Len()) + if err := strs.Iter(func(x *ast.Term) error { + s, ok := x.Value.(ast.String) + if !ok { + return NewOperandElementErr(pos, a, x.Value, "string") + } + outStrs = append(outStrs, string(s)) + return nil + }); err != nil { + return nil, err + } + + return outStrs, nil +} + +// RuneSliceOperand converts x to a []rune. If the cast fails, a descriptive error is +// returned. +func RuneSliceOperand(x ast.Value, pos int) ([]rune, error) { + a, err := ArrayOperand(x, pos) + if err != nil { + return nil, err + } + + var f = make([]rune, a.Len()) + for k := 0; k < a.Len(); k++ { + b := a.Elem(k) + c, ok := b.Value.(ast.String) + if !ok { + return nil, NewOperandElementErr(pos, x, b.Value, "string") + } + + d := []rune(string(c)) + if len(d) != 1 { + return nil, NewOperandElementErr(pos, x, b.Value, "rune") + } + + f[k] = d[0] + } + + return f, nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/cache.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/cache.go new file mode 100644 index 000000000..607abf46e --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/cache.go @@ -0,0 +1,352 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" +) + +// VirtualCache defines the interface for a cache that stores the results of +// evaluated virtual documents (rules). +// The cache is a stack of frames, where each frame is a mapping from references +// to values. +type VirtualCache interface { + // Push pushes a new, empty frame of value mappings onto the stack. + Push() + + // Pop pops the top frame of value mappings from the stack, removing all associated entries. + Pop() + + // Get returns the value associated with the given reference. The second return value + // indicates whether the reference has a recorded 'undefined' result. + Get(ref ast.Ref) (*ast.Term, bool) + + // Put associates the given reference with the given value. If the value is nil, the reference + // is marked as having an 'undefined' result. + Put(ref ast.Ref, value *ast.Term) + + // Keys returns the set of keys that have been cached for the active frame. + Keys() []ast.Ref +} + +type virtualCache struct { + stack []*virtualCacheElem +} + +type virtualCacheElem struct { + value *ast.Term + children *util.HashMap + undefined bool +} + +func NewVirtualCache() VirtualCache { + cache := &virtualCache{} + cache.Push() + return cache +} + +func (c *virtualCache) Push() { + c.stack = append(c.stack, newVirtualCacheElem()) +} + +func (c *virtualCache) Pop() { + c.stack = c.stack[:len(c.stack)-1] +} + +// Returns the resolved value of the AST term and a flag indicating if the value +// should be interpretted as undefined: +// +// nil, true indicates the ref is undefined +// ast.Term, false indicates the ref is defined +// nil, false indicates the ref has not been cached +// ast.Term, true is impossible +func (c *virtualCache) Get(ref ast.Ref) (*ast.Term, bool) { + node := c.stack[len(c.stack)-1] + for i := 0; i < len(ref); i++ { + x, ok := node.children.Get(ref[i]) + if !ok { + return nil, false + } + node = x.(*virtualCacheElem) + } + if node.undefined { + return nil, true + } + + return node.value, false +} + +// If value is a nil pointer, set the 'undefined' flag on the cache element to +// indicate that the Ref has resolved to undefined. +func (c *virtualCache) Put(ref ast.Ref, value *ast.Term) { + node := c.stack[len(c.stack)-1] + for i := 0; i < len(ref); i++ { + x, ok := node.children.Get(ref[i]) + if ok { + node = x.(*virtualCacheElem) + } else { + next := newVirtualCacheElem() + node.children.Put(ref[i], next) + node = next + } + } + if value != nil { + node.value = value + } else { + node.undefined = true + } +} + +func (c *virtualCache) Keys() []ast.Ref { + node := c.stack[len(c.stack)-1] + return keysRecursive(nil, node) +} + +func keysRecursive(root ast.Ref, node *virtualCacheElem) []ast.Ref { + var keys []ast.Ref + node.children.Iter(func(k, v util.T) bool { + ref := root.Append(k.(*ast.Term)) + if v.(*virtualCacheElem).value != nil { + keys = append(keys, ref) + } + if v.(*virtualCacheElem).children.Len() > 0 { + keys = append(keys, keysRecursive(ref, v.(*virtualCacheElem))...) + } + return false + }) + return keys +} + +func newVirtualCacheElem() *virtualCacheElem { + return &virtualCacheElem{children: newVirtualCacheHashMap()} +} + +func newVirtualCacheHashMap() *util.HashMap { + return util.NewHashMap(func(a, b util.T) bool { + return a.(*ast.Term).Equal(b.(*ast.Term)) + }, func(x util.T) int { + return x.(*ast.Term).Hash() + }) +} + +// baseCache implements a trie structure to cache base documents read out of +// storage. Values inserted into the cache may contain other values that were +// previously inserted. In this case, the previous values are erased from the +// structure. +type baseCache struct { + root *baseCacheElem +} + +func newBaseCache() *baseCache { + return &baseCache{ + root: newBaseCacheElem(), + } +} + +func (c *baseCache) Get(ref ast.Ref) ast.Value { + node := c.root + for i := 0; i < len(ref); i++ { + node = node.children[ref[i].Value] + if node == nil { + return nil + } else if node.value != nil { + result, err := node.value.Find(ref[i+1:]) + if err != nil { + return nil + } + return result + } + } + return nil +} + +func (c *baseCache) Put(ref ast.Ref, value ast.Value) { + node := c.root + for i := 0; i < len(ref); i++ { + if child, ok := node.children[ref[i].Value]; ok { + node = child + } else { + child := newBaseCacheElem() + node.children[ref[i].Value] = child + node = child + } + } + node.set(value) +} + +type baseCacheElem struct { + value ast.Value + children map[ast.Value]*baseCacheElem +} + +func newBaseCacheElem() *baseCacheElem { + return &baseCacheElem{ + children: map[ast.Value]*baseCacheElem{}, + } +} + +func (e *baseCacheElem) set(value ast.Value) { + e.value = value + e.children = map[ast.Value]*baseCacheElem{} +} + +type refStack struct { + sl []refStackElem +} + +type refStackElem struct { + refs []ast.Ref +} + +func newRefStack() *refStack { + return &refStack{} +} + +func (s *refStack) Push(refs []ast.Ref) { + s.sl = append(s.sl, refStackElem{refs: refs}) +} + +func (s *refStack) Pop() { + s.sl = s.sl[:len(s.sl)-1] +} + +func (s *refStack) Prefixed(ref ast.Ref) bool { + if s != nil { + for i := len(s.sl) - 1; i >= 0; i-- { + for j := range s.sl[i].refs { + if ref.HasPrefix(s.sl[i].refs[j]) { + return true + } + } + } + } + return false +} + +type comprehensionCache struct { + stack []map[*ast.Term]*comprehensionCacheElem +} + +type comprehensionCacheElem struct { + value *ast.Term + children *util.HashMap +} + +func newComprehensionCache() *comprehensionCache { + cache := &comprehensionCache{} + cache.Push() + return cache +} + +func (c *comprehensionCache) Push() { + c.stack = append(c.stack, map[*ast.Term]*comprehensionCacheElem{}) +} + +func (c *comprehensionCache) Pop() { + c.stack = c.stack[:len(c.stack)-1] +} + +func (c *comprehensionCache) Elem(t *ast.Term) (*comprehensionCacheElem, bool) { + elem, ok := c.stack[len(c.stack)-1][t] + return elem, ok +} + +func (c *comprehensionCache) Set(t *ast.Term, elem *comprehensionCacheElem) { + c.stack[len(c.stack)-1][t] = elem +} + +func newComprehensionCacheElem() *comprehensionCacheElem { + return &comprehensionCacheElem{children: newComprehensionCacheHashMap()} +} + +func (c *comprehensionCacheElem) Get(key []*ast.Term) *ast.Term { + node := c + for i := 0; i < len(key); i++ { + x, ok := node.children.Get(key[i]) + if !ok { + return nil + } + node = x.(*comprehensionCacheElem) + } + return node.value +} + +func (c *comprehensionCacheElem) Put(key []*ast.Term, value *ast.Term) { + node := c + for i := 0; i < len(key); i++ { + x, ok := node.children.Get(key[i]) + if ok { + node = x.(*comprehensionCacheElem) + } else { + next := newComprehensionCacheElem() + node.children.Put(key[i], next) + node = next + } + } + node.value = value +} + +func newComprehensionCacheHashMap() *util.HashMap { + return util.NewHashMap(func(a, b util.T) bool { + return a.(*ast.Term).Equal(b.(*ast.Term)) + }, func(x util.T) int { + return x.(*ast.Term).Hash() + }) +} + +type functionMocksStack struct { + stack []*functionMocksElem +} + +type functionMocksElem []frame + +type frame map[string]*ast.Term + +func newFunctionMocksStack() *functionMocksStack { + stack := &functionMocksStack{} + stack.Push() + return stack +} + +func newFunctionMocksElem() *functionMocksElem { + return &functionMocksElem{} +} + +func (s *functionMocksStack) Push() { + s.stack = append(s.stack, newFunctionMocksElem()) +} + +func (s *functionMocksStack) Pop() { + s.stack = s.stack[:len(s.stack)-1] +} + +func (s *functionMocksStack) PopPairs() { + current := s.stack[len(s.stack)-1] + *current = (*current)[:len(*current)-1] +} + +func (s *functionMocksStack) PutPairs(mocks [][2]*ast.Term) { + el := frame{} + for i := range mocks { + el[mocks[i][0].Value.String()] = mocks[i][1] + } + s.Put(el) +} + +func (s *functionMocksStack) Put(el frame) { + current := s.stack[len(s.stack)-1] + *current = append(*current, el) +} + +func (s *functionMocksStack) Get(f ast.Ref) (*ast.Term, bool) { + current := *s.stack[len(s.stack)-1] + for i := len(current) - 1; i >= 0; i-- { + if r, ok := current[i][f.String()]; ok { + return r, true + } + } + return nil, false +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/cache/cache.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/cache/cache.go new file mode 100644 index 000000000..1edadbb28 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/cache/cache.go @@ -0,0 +1,409 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package cache defines the inter-query cache interface that can cache data across queries +package cache + +import ( + "container/list" + "context" + "fmt" + "math" + "sync" + "time" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" +) + +const ( + defaultInterQueryBuiltinValueCacheSize = int(0) // unlimited + defaultMaxSizeBytes = int64(0) // unlimited + defaultForcedEvictionThresholdPercentage = int64(100) // trigger at max_size_bytes + defaultStaleEntryEvictionPeriodSeconds = int64(0) // never +) + +// Config represents the configuration for the inter-query builtin cache. +type Config struct { + InterQueryBuiltinCache InterQueryBuiltinCacheConfig `json:"inter_query_builtin_cache"` + InterQueryBuiltinValueCache InterQueryBuiltinValueCacheConfig `json:"inter_query_builtin_value_cache"` +} + +// InterQueryBuiltinValueCacheConfig represents the configuration of the inter-query value cache that built-in functions can utilize. +// MaxNumEntries - max number of cache entries +type InterQueryBuiltinValueCacheConfig struct { + MaxNumEntries *int `json:"max_num_entries,omitempty"` +} + +// InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize. +// MaxSizeBytes - max capacity of cache in bytes +// ForcedEvictionThresholdPercentage - capacity usage in percentage after which forced FIFO eviction starts +// StaleEntryEvictionPeriodSeconds - time period between end of previous and start of new stale entry eviction routine +type InterQueryBuiltinCacheConfig struct { + MaxSizeBytes *int64 `json:"max_size_bytes,omitempty"` + ForcedEvictionThresholdPercentage *int64 `json:"forced_eviction_threshold_percentage,omitempty"` + StaleEntryEvictionPeriodSeconds *int64 `json:"stale_entry_eviction_period_seconds,omitempty"` +} + +// ParseCachingConfig returns the config for the inter-query cache. +func ParseCachingConfig(raw []byte) (*Config, error) { + if raw == nil { + maxSize := new(int64) + *maxSize = defaultMaxSizeBytes + threshold := new(int64) + *threshold = defaultForcedEvictionThresholdPercentage + period := new(int64) + *period = defaultStaleEntryEvictionPeriodSeconds + + maxInterQueryBuiltinValueCacheSize := new(int) + *maxInterQueryBuiltinValueCacheSize = defaultInterQueryBuiltinValueCacheSize + + return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, ForcedEvictionThresholdPercentage: threshold, StaleEntryEvictionPeriodSeconds: period}, + InterQueryBuiltinValueCache: InterQueryBuiltinValueCacheConfig{MaxNumEntries: maxInterQueryBuiltinValueCacheSize}}, nil + } + + var config Config + + if err := util.Unmarshal(raw, &config); err == nil { + if err = config.validateAndInjectDefaults(); err != nil { + return nil, err + } + } else { + return nil, err + } + + return &config, nil +} + +func (c *Config) validateAndInjectDefaults() error { + if c.InterQueryBuiltinCache.MaxSizeBytes == nil { + maxSize := new(int64) + *maxSize = defaultMaxSizeBytes + c.InterQueryBuiltinCache.MaxSizeBytes = maxSize + } + if c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage == nil { + threshold := new(int64) + *threshold = defaultForcedEvictionThresholdPercentage + c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage = threshold + } else { + threshold := *c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage + if threshold < 0 || threshold > 100 { + return fmt.Errorf("invalid forced_eviction_threshold_percentage %v", threshold) + } + } + if c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds == nil { + period := new(int64) + *period = defaultStaleEntryEvictionPeriodSeconds + c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds = period + } else { + period := *c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds + if period < 0 { + return fmt.Errorf("invalid stale_entry_eviction_period_seconds %v", period) + } + } + + if c.InterQueryBuiltinValueCache.MaxNumEntries == nil { + maxSize := new(int) + *maxSize = defaultInterQueryBuiltinValueCacheSize + c.InterQueryBuiltinValueCache.MaxNumEntries = maxSize + } else { + numEntries := *c.InterQueryBuiltinValueCache.MaxNumEntries + if numEntries < 0 { + return fmt.Errorf("invalid max_num_entries %v", numEntries) + } + } + + return nil +} + +// InterQueryCacheValue defines the interface for the data that the inter-query cache holds. +type InterQueryCacheValue interface { + SizeInBytes() int64 + Clone() (InterQueryCacheValue, error) +} + +// InterQueryCache defines the interface for the inter-query cache. +type InterQueryCache interface { + Get(key ast.Value) (value InterQueryCacheValue, found bool) + Insert(key ast.Value, value InterQueryCacheValue) int + InsertWithExpiry(key ast.Value, value InterQueryCacheValue, expiresAt time.Time) int + Delete(key ast.Value) + UpdateConfig(config *Config) + Clone(value InterQueryCacheValue) (InterQueryCacheValue, error) +} + +// NewInterQueryCache returns a new inter-query cache. +// The cache uses a FIFO eviction policy when it reaches the forced eviction threshold. +// Parameters: +// +// config - to configure the InterQueryCache +func NewInterQueryCache(config *Config) InterQueryCache { + return newCache(config) +} + +// NewInterQueryCacheWithContext returns a new inter-query cache with context. +// The cache uses a combination of FIFO eviction policy when it reaches the forced eviction threshold +// and a periodic cleanup routine to remove stale entries that exceed their expiration time, if specified. +// If configured with a zero stale_entry_eviction_period_seconds value, the stale entry cleanup routine is disabled. +// +// Parameters: +// +// ctx - used to control lifecycle of the stale entry cleanup routine +// config - to configure the InterQueryCache +func NewInterQueryCacheWithContext(ctx context.Context, config *Config) InterQueryCache { + iqCache := newCache(config) + if iqCache.staleEntryEvictionTimePeriodSeconds() > 0 { + go func() { + cleanupTicker := time.NewTicker(time.Duration(iqCache.staleEntryEvictionTimePeriodSeconds()) * time.Second) + for { + select { + case <-cleanupTicker.C: + // NOTE: We stop the ticker and create a new one here to ensure that applications + // get _at least_ staleEntryEvictionTimePeriodSeconds with the cache unlocked; + // see https://github.com/open-policy-agent/opa/pull/7188/files#r1855342998 + cleanupTicker.Stop() + iqCache.cleanStaleValues() + cleanupTicker = time.NewTicker(time.Duration(iqCache.staleEntryEvictionTimePeriodSeconds()) * time.Second) + case <-ctx.Done(): + cleanupTicker.Stop() + return + } + } + }() + } + + return iqCache +} + +type cacheItem struct { + value InterQueryCacheValue + expiresAt time.Time + keyElement *list.Element +} + +type cache struct { + items map[string]cacheItem + usage int64 + config *Config + l *list.List + mtx sync.Mutex +} + +func newCache(config *Config) *cache { + return &cache{ + items: map[string]cacheItem{}, + usage: 0, + config: config, + l: list.New(), + } +} + +// InsertWithExpiry inserts a key k into the cache with value v with an expiration time expiresAt. +// A zero time value for expiresAt indicates no expiry +func (c *cache) InsertWithExpiry(k ast.Value, v InterQueryCacheValue, expiresAt time.Time) (dropped int) { + c.mtx.Lock() + defer c.mtx.Unlock() + return c.unsafeInsert(k, v, expiresAt) +} + +// Insert inserts a key k into the cache with value v with no expiration time. +func (c *cache) Insert(k ast.Value, v InterQueryCacheValue) (dropped int) { + return c.InsertWithExpiry(k, v, time.Time{}) +} + +// Get returns the value in the cache for k. +func (c *cache) Get(k ast.Value) (InterQueryCacheValue, bool) { + c.mtx.Lock() + defer c.mtx.Unlock() + cacheItem, ok := c.unsafeGet(k) + + if ok { + return cacheItem.value, true + } + return nil, false +} + +// Delete deletes the value in the cache for k. +func (c *cache) Delete(k ast.Value) { + c.mtx.Lock() + defer c.mtx.Unlock() + c.unsafeDelete(k) +} + +func (c *cache) UpdateConfig(config *Config) { + if config == nil { + return + } + c.mtx.Lock() + defer c.mtx.Unlock() + c.config = config +} + +func (c *cache) Clone(value InterQueryCacheValue) (InterQueryCacheValue, error) { + c.mtx.Lock() + defer c.mtx.Unlock() + return c.unsafeClone(value) +} + +func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue, expiresAt time.Time) (dropped int) { + size := v.SizeInBytes() + limit := int64(math.Ceil(float64(c.forcedEvictionThresholdPercentage())/100.0) * (float64(c.maxSizeBytes()))) + if limit > 0 { + if size > limit { + dropped++ + return dropped + } + + for key := c.l.Front(); key != nil && (c.usage+size > limit); key = c.l.Front() { + dropKey := key.Value.(ast.Value) + c.unsafeDelete(dropKey) + dropped++ + } + } + + // By deleting the old value, if it exists, we ensure the usage variable stays correct + c.unsafeDelete(k) + + c.items[k.String()] = cacheItem{ + value: v, + expiresAt: expiresAt, + keyElement: c.l.PushBack(k), + } + c.usage += size + return dropped +} + +func (c *cache) unsafeGet(k ast.Value) (cacheItem, bool) { + value, ok := c.items[k.String()] + return value, ok +} + +func (c *cache) unsafeDelete(k ast.Value) { + cacheItem, ok := c.unsafeGet(k) + if !ok { + return + } + + c.usage -= cacheItem.value.SizeInBytes() + delete(c.items, k.String()) + c.l.Remove(cacheItem.keyElement) +} + +func (c *cache) unsafeClone(value InterQueryCacheValue) (InterQueryCacheValue, error) { + return value.Clone() +} + +func (c *cache) maxSizeBytes() int64 { + if c.config == nil { + return defaultMaxSizeBytes + } + return *c.config.InterQueryBuiltinCache.MaxSizeBytes +} + +func (c *cache) forcedEvictionThresholdPercentage() int64 { + if c.config == nil { + return defaultForcedEvictionThresholdPercentage + } + return *c.config.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage +} + +func (c *cache) staleEntryEvictionTimePeriodSeconds() int64 { + if c.config == nil { + return defaultStaleEntryEvictionPeriodSeconds + } + return *c.config.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds +} + +func (c *cache) cleanStaleValues() (dropped int) { + c.mtx.Lock() + defer c.mtx.Unlock() + for key := c.l.Front(); key != nil; { + nextKey := key.Next() + // if expiresAt is zero, the item doesn't have an expiry + if ea := c.items[(key.Value.(ast.Value)).String()].expiresAt; !ea.IsZero() && ea.Before(time.Now()) { + c.unsafeDelete(key.Value.(ast.Value)) + dropped++ + } + key = nextKey + } + return dropped +} + +type InterQueryValueCache interface { + Get(key ast.Value) (value any, found bool) + Insert(key ast.Value, value any) int + Delete(key ast.Value) + UpdateConfig(config *Config) +} + +type interQueryValueCache struct { + items map[string]any + config *Config + mtx sync.RWMutex +} + +// Get returns the value in the cache for k. +func (c *interQueryValueCache) Get(k ast.Value) (any, bool) { + c.mtx.RLock() + defer c.mtx.RUnlock() + value, ok := c.items[k.String()] + return value, ok +} + +// Insert inserts a key k into the cache with value v. +func (c *interQueryValueCache) Insert(k ast.Value, v any) (dropped int) { + c.mtx.Lock() + defer c.mtx.Unlock() + + maxEntries := c.maxNumEntries() + if maxEntries > 0 { + if len(c.items) >= maxEntries { + itemsToRemove := len(c.items) - maxEntries + 1 + + // Delete a (semi-)random key to make room for the new one. + for k := range c.items { + delete(c.items, k) + dropped++ + + if itemsToRemove == dropped { + break + } + } + } + } + + c.items[k.String()] = v + return dropped +} + +// Delete deletes the value in the cache for k. +func (c *interQueryValueCache) Delete(k ast.Value) { + c.mtx.Lock() + defer c.mtx.Unlock() + delete(c.items, k.String()) +} + +// UpdateConfig updates the cache config. +func (c *interQueryValueCache) UpdateConfig(config *Config) { + if config == nil { + return + } + c.mtx.Lock() + defer c.mtx.Unlock() + c.config = config +} + +func (c *interQueryValueCache) maxNumEntries() int { + if c.config == nil { + return defaultInterQueryBuiltinValueCacheSize + } + return *c.config.InterQueryBuiltinValueCache.MaxNumEntries +} + +func NewInterQueryValueCache(_ context.Context, config *Config) InterQueryValueCache { + return &interQueryValueCache{ + items: map[string]any{}, + config: config, + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/cancel.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/cancel.go new file mode 100644 index 000000000..534e0799a --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/cancel.go @@ -0,0 +1,33 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "sync/atomic" +) + +// Cancel defines the interface for cancelling topdown queries. Cancel +// operations are thread-safe and idempotent. +type Cancel interface { + Cancel() + Cancelled() bool +} + +type cancel struct { + flag int32 +} + +// NewCancel returns a new Cancel object. +func NewCancel() Cancel { + return &cancel{} +} + +func (c *cancel) Cancel() { + atomic.StoreInt32(&c.flag, 1) +} + +func (c *cancel) Cancelled() bool { + return atomic.LoadInt32(&c.flag) != 0 +} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/casts.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/casts.go similarity index 82% rename from vendor/github.com/open-policy-agent/opa/topdown/casts.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/casts.go index 2eb8f97fc..9be7271c4 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/casts.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/casts.go @@ -6,24 +6,38 @@ package topdown import ( "strconv" + "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func builtinToNumber(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch a := operands[0].Value.(type) { case ast.Null: - return iter(ast.NumberTerm("0")) + return iter(ast.InternedIntNumberTerm(0)) case ast.Boolean: if a { - return iter(ast.NumberTerm("1")) + return iter(ast.InternedIntNumberTerm(1)) } - return iter(ast.NumberTerm("0")) + return iter(ast.InternedIntNumberTerm(0)) case ast.Number: return iter(ast.NewTerm(a)) case ast.String: - _, err := strconv.ParseFloat(string(a), 64) + strValue := string(a) + + if it := ast.InternedIntNumberTermFromString(strValue); it != nil { + return iter(it) + } + + trimmedVal := strings.TrimLeft(strValue, "+-") + lowerCaseVal := strings.ToLower(trimmedVal) + + if lowerCaseVal == "inf" || lowerCaseVal == "infinity" || lowerCaseVal == "nan" { + return builtins.NewOperandTypeErr(1, operands[0].Value, "valid number string") + } + + _, err := strconv.ParseFloat(strValue, 64) if err != nil { return err } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/cidr.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/cidr.go similarity index 95% rename from vendor/github.com/open-policy-agent/opa/topdown/cidr.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/cidr.go index 5b011bd16..113bd2f37 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/cidr.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/cidr.go @@ -8,9 +8,9 @@ import ( "net" "sort" - "github.com/open-policy-agent/opa/ast" cidrMerge "github.com/open-policy-agent/opa/internal/cidr/merge" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func getNetFromOperand(v ast.Value) (*net.IPNet, error) { @@ -75,7 +75,7 @@ func builtinNetCIDRIntersects(_ BuiltinContext, operands []*ast.Term, iter func( // If either net contains the others starting IP they are overlapping cidrsOverlap := cidrnetA.Contains(cidrnetB.IP) || cidrnetB.Contains(cidrnetA.IP) - return iter(ast.BooleanTerm(cidrsOverlap)) + return iter(ast.InternedBooleanTerm(cidrsOverlap)) } func builtinNetCIDRContains(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -92,7 +92,7 @@ func builtinNetCIDRContains(_ BuiltinContext, operands []*ast.Term, iter func(*a ip := net.ParseIP(string(bStr)) if ip != nil { - return iter(ast.BooleanTerm(cidrnetA.Contains(ip))) + return iter(ast.InternedBooleanTerm(cidrnetA.Contains(ip))) } // It wasn't an IP, try and parse it as a CIDR @@ -113,7 +113,7 @@ func builtinNetCIDRContains(_ BuiltinContext, operands []*ast.Term, iter func(*a cidrContained = cidrnetA.Contains(lastIP) } - return iter(ast.BooleanTerm(cidrContained)) + return iter(ast.InternedBooleanTerm(cidrContained)) } var errNetCIDRContainsMatchElementType = errors.New("element must be string or non-empty array") @@ -142,7 +142,7 @@ func evalNetCIDRContainsMatchesOperand(operand int, a *ast.Term, iter func(cidr, if err != nil { return fmt.Errorf("operand %v: %v", operand, err) } - if err := iter(cidr, ast.IntNumberTerm(i)); err != nil { + if err := iter(cidr, ast.InternedIntNumberTerm(i)); err != nil { return err } } @@ -219,13 +219,13 @@ func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(* func builtinNetCIDRIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { cidr, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } if _, _, err := net.ParseCIDR(string(cidr)); err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) } type cidrBlockRange struct { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/comparison.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/comparison.go similarity index 90% rename from vendor/github.com/open-policy-agent/opa/topdown/comparison.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/comparison.go index 0d033d2c3..9e1585a28 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/comparison.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/comparison.go @@ -4,7 +4,7 @@ package topdown -import "github.com/open-policy-agent/opa/ast" +import "github.com/open-policy-agent/opa/v1/ast" type compareFunc func(a, b ast.Value) bool @@ -34,7 +34,7 @@ func compareEq(a, b ast.Value) bool { func builtinCompare(cmp compareFunc) BuiltinFunc { return func(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - return iter(ast.BooleanTerm(cmp(operands[0].Value, operands[1].Value))) + return iter(ast.InternedBooleanTerm(cmp(operands[0].Value, operands[1].Value))) } } diff --git a/vendor/github.com/open-policy-agent/opa/topdown/copypropagation/copypropagation.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/copypropagation/copypropagation.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/topdown/copypropagation/copypropagation.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/copypropagation/copypropagation.go index 8824d19bd..233bbcad1 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/copypropagation/copypropagation.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/copypropagation/copypropagation.go @@ -8,7 +8,7 @@ import ( "fmt" "sort" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) // CopyPropagator implements a simple copy propagation optimization to remove diff --git a/vendor/github.com/open-policy-agent/opa/topdown/copypropagation/unionfind.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/copypropagation/unionfind.go similarity index 96% rename from vendor/github.com/open-policy-agent/opa/topdown/copypropagation/unionfind.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/copypropagation/unionfind.go index 38ec56f31..679464250 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/copypropagation/unionfind.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/copypropagation/unionfind.go @@ -7,8 +7,8 @@ package copypropagation import ( "fmt" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" ) type rankFunc func(*unionFindRoot, *unionFindRoot) (*unionFindRoot, *unionFindRoot) diff --git a/vendor/github.com/open-policy-agent/opa/topdown/crypto.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/crypto.go similarity index 97% rename from vendor/github.com/open-policy-agent/opa/topdown/crypto.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/crypto.go index f24432a26..ff5355074 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/crypto.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/crypto.go @@ -25,9 +25,9 @@ import ( "github.com/open-policy-agent/opa/internal/jwx/jwk" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/util" ) const ( @@ -96,7 +96,7 @@ func builtinCryptoX509ParseAndVerifyCertificates(_ BuiltinContext, operands []*a } invalid := ast.ArrayTerm( - ast.BooleanTerm(false), + ast.InternedBooleanTerm(false), ast.NewTerm(ast.NewArray()), ) @@ -116,7 +116,7 @@ func builtinCryptoX509ParseAndVerifyCertificates(_ BuiltinContext, operands []*a } valid := ast.ArrayTerm( - ast.BooleanTerm(true), + ast.InternedBooleanTerm(true), ast.NewTerm(value), ) @@ -152,14 +152,12 @@ func builtinCryptoX509ParseAndVerifyCertificatesWithOptions(_ BuiltinContext, op return err } - invalid := ast.ArrayTerm( - ast.BooleanTerm(false), - ast.NewTerm(ast.NewArray()), - ) - certs, err := getX509CertsFromString(string(input)) if err != nil { - return iter(invalid) + return iter(ast.ArrayTerm( + ast.InternedBooleanTerm(false), + ast.NewTerm(ast.NewArray()), + )) } // Collect the cert verification options @@ -170,7 +168,10 @@ func builtinCryptoX509ParseAndVerifyCertificatesWithOptions(_ BuiltinContext, op verified, err := verifyX509CertificateChain(certs, verifyOpt) if err != nil { - return iter(invalid) + return iter(ast.ArrayTerm( + ast.InternedBooleanTerm(false), + ast.NewTerm(ast.NewArray()), + )) } value, err := ast.InterfaceToValue(verified) @@ -178,12 +179,10 @@ func builtinCryptoX509ParseAndVerifyCertificatesWithOptions(_ BuiltinContext, op return err } - valid := ast.ArrayTerm( - ast.BooleanTerm(true), + return iter(ast.ArrayTerm( + ast.InternedBooleanTerm(true), ast.NewTerm(value), - ) - - return iter(valid) + )) } func extractVerifyOpts(options ast.Object) (verifyOpt x509.VerifyOptions, err error) { @@ -522,7 +521,7 @@ func builtinCryptoHmacEqual(_ BuiltinContext, operands []*ast.Term, iter func(*a res := hmac.Equal([]byte(mac1), []byte(mac2)) - return iter(ast.BooleanTerm(res)) + return iter(ast.InternedBooleanTerm(res)) } func init() { diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/doc.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/doc.go new file mode 100644 index 000000000..9aa7aa45c --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/doc.go @@ -0,0 +1,10 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package topdown provides low-level query evaluation support. +// +// The topdown implementation is a modified version of the standard top-down +// evaluation algorithm used in Datalog. References and comprehensions are +// evaluated eagerly while all other terms are evaluated lazily. +package topdown diff --git a/vendor/github.com/open-policy-agent/opa/topdown/encoding.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/encoding.go similarity index 95% rename from vendor/github.com/open-policy-agent/opa/topdown/encoding.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/encoding.go index f3475a60d..a27a9c245 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/encoding.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/encoding.go @@ -15,9 +15,9 @@ import ( "sigs.k8s.io/yaml" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/util" ) func builtinJSONMarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -144,10 +144,10 @@ func builtinJSONIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T str, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } - return iter(ast.BooleanTerm(json.Valid([]byte(str)))) + return iter(ast.InternedBooleanTerm(json.Valid([]byte(str)))) } func builtinBase64Encode(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -175,11 +175,11 @@ func builtinBase64Decode(_ BuiltinContext, operands []*ast.Term, iter func(*ast. func builtinBase64IsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { str, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } _, err = base64.StdEncoding.DecodeString(string(str)) - return iter(ast.BooleanTerm(err == nil)) + return iter(ast.InternedBooleanTerm(err == nil)) } func builtinBase64UrlEncode(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -355,12 +355,12 @@ func builtinYAMLUnmarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast func builtinYAMLIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { str, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } var x interface{} err = yaml.Unmarshal([]byte(str), &x) - return iter(ast.BooleanTerm(err == nil)) + return iter(ast.InternedBooleanTerm(err == nil)) } func builtinHexEncode(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/errors.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/errors.go new file mode 100644 index 000000000..cadd16319 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/errors.go @@ -0,0 +1,149 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "errors" + "fmt" + + "github.com/open-policy-agent/opa/v1/ast" +) + +// Halt is a special error type that built-in function implementations return to indicate +// that policy evaluation should stop immediately. +type Halt struct { + Err error +} + +func (h Halt) Error() string { + return h.Err.Error() +} + +func (h Halt) Unwrap() error { return h.Err } + +// Error is the error type returned by the Eval and Query functions when +// an evaluation error occurs. +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + Location *ast.Location `json:"location,omitempty"` + err error `json:"-"` +} + +const ( + + // InternalErr represents an unknown evaluation error. + InternalErr string = "eval_internal_error" + + // CancelErr indicates the evaluation process was cancelled. + CancelErr string = "eval_cancel_error" + + // ConflictErr indicates a conflict was encountered during evaluation. For + // instance, a conflict occurs if a rule produces multiple, differing values + // for the same key in an object. Conflict errors indicate the policy does + // not account for the data loaded into the policy engine. + ConflictErr string = "eval_conflict_error" + + // TypeErr indicates evaluation stopped because an expression was applied to + // a value of an inappropriate type. + TypeErr string = "eval_type_error" + + // BuiltinErr indicates a built-in function received a semantically invalid + // input or encountered some kind of runtime error, e.g., connection + // timeout, connection refused, etc. + BuiltinErr string = "eval_builtin_error" + + // WithMergeErr indicates that the real and replacement data could not be merged. + WithMergeErr string = "eval_with_merge_error" +) + +// IsError returns true if the err is an Error. +func IsError(err error) bool { + var e *Error + return errors.As(err, &e) +} + +// IsCancel returns true if err was caused by cancellation. +func IsCancel(err error) bool { + return errors.Is(err, &Error{Code: CancelErr}) +} + +// Is allows matching topdown errors using errors.Is (see IsCancel). +func (e *Error) Is(target error) bool { + var t *Error + if errors.As(target, &t) { + return (t.Code == "" || e.Code == t.Code) && + (t.Message == "" || e.Message == t.Message) && + (t.Location == nil || t.Location.Compare(e.Location) == 0) + } + return false +} + +func (e *Error) Error() string { + msg := fmt.Sprintf("%v: %v", e.Code, e.Message) + + if e.Location != nil { + msg = e.Location.String() + ": " + msg + } + + return msg +} + +func (e *Error) Wrap(err error) *Error { + e.err = err + return e +} + +func (e *Error) Unwrap() error { + return e.err +} + +func functionConflictErr(loc *ast.Location) error { + return &Error{ + Code: ConflictErr, + Location: loc, + Message: "functions must not produce multiple outputs for same inputs", + } +} + +func completeDocConflictErr(loc *ast.Location) error { + return &Error{ + Code: ConflictErr, + Location: loc, + Message: "complete rules must not produce multiple outputs", + } +} + +func objectDocKeyConflictErr(loc *ast.Location) error { + return &Error{ + Code: ConflictErr, + Location: loc, + Message: "object keys must be unique", + } +} + +func unsupportedBuiltinErr(loc *ast.Location) error { + return &Error{ + Code: InternalErr, + Location: loc, + Message: "unsupported built-in", + } +} + +func mergeConflictErr(loc *ast.Location) error { + return &Error{ + Code: WithMergeErr, + Location: loc, + Message: "real and replacement data could not be merged", + } +} + +func internalErr(loc *ast.Location, msg string) error { + return &Error{ + Code: InternalErr, + Location: loc, + Message: msg, + } +} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/eval.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/eval.go similarity index 91% rename from vendor/github.com/open-policy-agent/opa/topdown/eval.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/eval.go index 7884ac01e..9f88a7b15 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/eval.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/eval.go @@ -8,16 +8,17 @@ import ( "sort" "strconv" "strings" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/storage" - "github.com/open-policy-agent/opa/topdown/builtins" - "github.com/open-policy-agent/opa/topdown/cache" - "github.com/open-policy-agent/opa/topdown/copypropagation" - "github.com/open-policy-agent/opa/topdown/print" - "github.com/open-policy-agent/opa/tracing" - "github.com/open-policy-agent/opa/types" + "sync" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/copypropagation" + "github.com/open-policy-agent/opa/v1/topdown/print" + "github.com/open-policy-agent/opa/v1/tracing" + "github.com/open-policy-agent/opa/v1/types" ) type evalIterator func(*eval) error @@ -57,60 +58,90 @@ func (ee deferredEarlyExitError) Error() string { return fmt.Sprintf("%v: deferred early exit", ee.e.query) } +// Note(æ): this struct is formatted for optimal alignment as it is big, internal and instantiated +// *very* frequently during evaluation. If you need to add fields here, please consider the alignment +// of the struct, and use something like betteralign (https://github.com/dkorunic/betteralign) if you +// need help with that. type eval struct { ctx context.Context metrics metrics.Metrics seed io.Reader + cancel Cancel + queryCompiler ast.QueryCompiler + store storage.Store + txn storage.Transaction + virtualCache VirtualCache + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + printHook print.Hook time *ast.Term - queryID uint64 queryIDFact *queryIDFactory parent *eval caller *eval - cancel Cancel - query ast.Body - queryCompiler ast.QueryCompiler - index int - indexing bool - earlyExit bool bindings *bindings - store storage.Store baseCache *baseCache - txn storage.Transaction compiler *ast.Compiler input *ast.Term data *ast.Term external *resolverTrie targetStack *refStack - tracers []QueryTracer - traceEnabled bool traceLastLocation *ast.Location // Last location of a trace event. - plugTraceVars bool instr *Instrumentation builtins map[string]*Builtin builtinCache builtins.Cache ndBuiltinCache builtins.NDBCache functionMocks *functionMocksStack - virtualCache VirtualCache comprehensionCache *comprehensionCache - interQueryBuiltinCache cache.InterQueryCache - interQueryBuiltinValueCache cache.InterQueryValueCache saveSet *saveSet saveStack *saveStack saveSupport *saveSupport saveNamespace *ast.Term - skipSaveNamespace bool inliningControl *inliningControl - genvarprefix string - genvarid int runtime *ast.Term builtinErrors *builtinErrors - printHook print.Hook + roundTripper CustomizeRoundTripper + genvarprefix string + query ast.Body + tracers []QueryTracer tracingOpts tracing.Options + queryID uint64 + index int + genvarid int + indexing bool + earlyExit bool + traceEnabled bool + plugTraceVars bool + skipSaveNamespace bool findOne bool strictObjects bool } +type evp struct { + pool sync.Pool +} + +func (ep *evp) Put(e *eval) { + ep.pool.Put(e) +} + +func (ep *evp) Get() *eval { + return ep.pool.Get().(*eval) +} + +var evalPool = evp{ + pool: sync.Pool{ + New: func() any { + return &eval{} + }, + }, +} + func (e *eval) Run(iter evalIterator) error { + if !e.traceEnabled { + // avoid function literal escaping to heap if we don't need the trace + return e.eval(iter) + } + e.traceEnter(e.query) return e.eval(func(e *eval) error { e.traceExit(e.query) @@ -151,25 +182,23 @@ func (e *eval) builtinFunc(name string) (*ast.Builtin, BuiltinFunc, bool) { return nil, nil, false } -func (e *eval) closure(query ast.Body) *eval { - cpy := *e +func (e *eval) closure(query ast.Body, cpy *eval) { + *cpy = *e cpy.index = 0 cpy.query = query cpy.queryID = cpy.queryIDFact.Next() cpy.parent = e cpy.findOne = false - return &cpy } -func (e *eval) child(query ast.Body) *eval { - cpy := *e +func (e *eval) child(query ast.Body, cpy *eval) { + *cpy = *e cpy.index = 0 cpy.query = query cpy.queryID = cpy.queryIDFact.Next() cpy.bindings = newBindings(cpy.queryID, e.instr) cpy.parent = e cpy.findOne = false - return &cpy } func (e *eval) next(iter evalIterator) error { @@ -335,6 +364,13 @@ func (e *eval) evalExpr(iter evalIterator) error { } if e.cancel != nil && e.cancel.Cancelled() { + if e.ctx != nil && e.ctx.Err() != nil { + return &Error{ + Code: CancelErr, + Message: e.ctx.Err().Error(), + err: e.ctx.Err(), + } + } return &Error{ Code: CancelErr, Message: "caller cancelled query execution", @@ -346,9 +382,7 @@ func (e *eval) evalExpr(iter evalIterator) error { if err != nil { switch err := err.(type) { - case *deferredEarlyExitError: - return wrapErr(err) - case *earlyExitError: + case *deferredEarlyExitError, *earlyExitError: return wrapErr(err) default: return err @@ -374,46 +408,110 @@ func (e *eval) evalExpr(iter evalIterator) error { } func (e *eval) evalStep(iter evalIterator) error { - expr := e.query[e.index] if expr.Negated { return e.evalNot(iter) } - var defined bool var err error + + // NOTE(æ): the reason why there's one branch for the tracing case and one almost + // identical branch below for when tracing is disabled is that the tracing case + // allocates wildly. These allocations are cause by the "defined" boolean variable + // escaping to the heap as its value is set from inside of closures. There may very + // well be more elegant solutions to this problem, but this is one that works, and + // saves several *million* allocations for some workloads. So feel free to refactor + // this, but do make sure that the common non-tracing case doesn't pay in allocations + // for something that is only needed when tracing is enabled. + if e.traceEnabled { + var defined bool + switch terms := expr.Terms.(type) { + case []*ast.Term: + switch { + case expr.IsEquality(): + err = e.unify(terms[1], terms[2], func() error { + defined = true + err := iter(e) + e.traceRedo(expr) + return err + }) + default: + err = e.evalCall(terms, func() error { + defined = true + err := iter(e) + e.traceRedo(expr) + return err + }) + } + case *ast.Term: + // generateVar inlined here to avoid extra allocations in hot path + rterm := ast.VarTerm(fmt.Sprintf("%s_term_%d_%d", e.genvarprefix, e.queryID, e.index)) + err = e.unify(terms, rterm, func() error { + if e.saveSet.Contains(rterm, e.bindings) { + return e.saveExpr(ast.NewExpr(rterm), e.bindings, func() error { + return iter(e) + }) + } + if !e.bindings.Plug(rterm).Equal(ast.InternedBooleanTerm(false)) { + defined = true + err := iter(e) + e.traceRedo(expr) + return err + } + return nil + }) + case *ast.Every: + eval := evalEvery{ + Every: terms, + e: e, + expr: expr, + } + err = eval.eval(func() error { + defined = true + err := iter(e) + e.traceRedo(expr) + return err + }) + + default: // guard-rail for adding extra (Expr).Terms types + return fmt.Errorf("got %T terms: %[1]v", terms) + } + + if err != nil { + return err + } + + if !defined { + e.traceFail(expr) + } + + return nil + } + switch terms := expr.Terms.(type) { case []*ast.Term: switch { case expr.IsEquality(): err = e.unify(terms[1], terms[2], func() error { - defined = true - err := iter(e) - e.traceRedo(expr) - return err + return iter(e) }) default: err = e.evalCall(terms, func() error { - defined = true - err := iter(e) - e.traceRedo(expr) - return err + return iter(e) }) } case *ast.Term: - rterm := e.generateVar(fmt.Sprintf("term_%d_%d", e.queryID, e.index)) + // generateVar inlined here to avoid extra allocations in hot path + rterm := ast.VarTerm(fmt.Sprintf("%s_term_%d_%d", e.genvarprefix, e.queryID, e.index)) err = e.unify(terms, rterm, func() error { if e.saveSet.Contains(rterm, e.bindings) { return e.saveExpr(ast.NewExpr(rterm), e.bindings, func() error { return iter(e) }) } - if !e.bindings.Plug(rterm).Equal(ast.BooleanTerm(false)) { - defined = true - err := iter(e) - e.traceRedo(expr) - return err + if !e.bindings.Plug(rterm).Equal(ast.InternedBooleanTerm(false)) { + return iter(e) } return nil }) @@ -424,25 +522,14 @@ func (e *eval) evalStep(iter evalIterator) error { expr: expr, } err = eval.eval(func() error { - defined = true - err := iter(e) - e.traceRedo(expr) - return err + return iter(e) }) default: // guard-rail for adding extra (Expr).Terms types return fmt.Errorf("got %T terms: %[1]v", terms) } - if err != nil { - return err - } - - if !defined { - e.traceFail(expr) - } - - return nil + return err } func (e *eval) evalNot(iter evalIterator) error { @@ -453,20 +540,27 @@ func (e *eval) evalNot(iter evalIterator) error { return e.evalNotPartial(iter) } - negation := ast.NewBody(expr.Complement().NoWith()) - child := e.closure(negation) + negation := ast.NewBody(expr.ComplementNoWith()) + child := evalPool.Get() + + e.closure(negation, child) + + defer evalPool.Put(child) var defined bool - child.traceEnter(negation) + if e.traceEnabled { + child.traceEnter(negation) + } - err := child.eval(func(*eval) error { - child.traceExit(negation) + if err := child.eval(func(*eval) error { + if e.traceEnabled { + child.traceExit(negation) + child.traceRedo(negation) + } defined = true - child.traceRedo(negation) - return nil - }) - if err != nil { + return nil + }); err != nil { return err } @@ -613,11 +707,14 @@ func (e *eval) evalWithPop(input, data *ast.Term) { } func (e *eval) evalNotPartial(iter evalIterator) error { - // Prepare query normally. expr := e.query[e.index] - negation := expr.Complement().NoWith() - child := e.closure(ast.NewBody(negation)) + negation := expr.ComplementNoWith() + + child := evalPool.Get() + defer evalPool.Put(child) + + e.closure(ast.NewBody(negation), child) // Unknowns is the set of variables that are marked as unknown. The variables // are namespaced with the query ID that they originate in. This ensures that @@ -770,7 +867,7 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { if ref[0].Equal(ast.DefaultRootDocument) { if mocked { f := e.compiler.TypeEnv.Get(ref).(*types.Function) - return e.evalCallValue(len(f.FuncArgs().Args), terms, mock, iter) + return e.evalCallValue(f.Arity(), terms, mock, iter) } var ir *ast.IndexResult @@ -800,11 +897,11 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { } if mocked { // value replacement of built-in call - return e.evalCallValue(len(bi.Decl.Args()), terms, mock, iter) + return e.evalCallValue(bi.Decl.Arity(), terms, mock, iter) } if e.unknown(e.query[e.index], e.bindings) { - return e.saveCall(len(bi.Decl.Args()), terms, iter) + return e.saveCall(bi.Decl.Arity(), terms, iter) } var parentID uint64 @@ -836,6 +933,7 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { PrintHook: e.printHook, DistributedTracingOpts: e.tracingOpts, Capabilities: capabilities, + RoundTripper: e.roundTripper, } eval := evalBuiltin{ @@ -855,7 +953,7 @@ func (e *eval) evalCallValue(arity int, terms []*ast.Term, mock *ast.Term, iter return e.unify(terms[len(terms)-1], mock, iter) case len(terms) == arity+1: - if mock.Value.Compare(ast.Boolean(false)) != 0 { + if !ast.Boolean(false).Equal(mock.Value) { return iter() } return nil @@ -932,6 +1030,22 @@ func (e *eval) biunifyArraysRec(a, b *ast.Array, b1, b2 *bindings, iter unifyIte }) } +func (e *eval) biunifyTerms(a, b []*ast.Term, b1, b2 *bindings, iter unifyIterator) error { + if len(a) != len(b) { + return nil + } + return e.biunifyTermsRec(a, b, b1, b2, iter, 0) +} + +func (e *eval) biunifyTermsRec(a, b []*ast.Term, b1, b2 *bindings, iter unifyIterator, idx int) error { + if idx == len(a) { + return iter() + } + return e.biunify(a[idx], b[idx], b1, b2, func() error { + return e.biunifyTermsRec(a, b, b1, b2, iter, idx+1) + }) +} + func (e *eval) biunifyObjects(a, b ast.Object, b1, b2 *bindings, iter unifyIterator) error { if a.Len() != b.Len() { return nil @@ -1165,7 +1279,10 @@ func (e *eval) buildComprehensionCache(a *ast.Term) (*ast.Term, error) { } func (e *eval) buildComprehensionCacheArray(x *ast.ArrayComprehension, keys []*ast.Term) (*comprehensionCacheElem, error) { - child := e.child(x.Body) + child := evalPool.Get() + defer evalPool.Put(child) + + e.child(x.Body, child) node := newComprehensionCacheElem() return node, child.Run(func(child *eval) error { values := make([]*ast.Term, len(keys)) @@ -1184,7 +1301,10 @@ func (e *eval) buildComprehensionCacheArray(x *ast.ArrayComprehension, keys []*a } func (e *eval) buildComprehensionCacheSet(x *ast.SetComprehension, keys []*ast.Term) (*comprehensionCacheElem, error) { - child := e.child(x.Body) + child := evalPool.Get() + defer evalPool.Put(child) + + e.child(x.Body, child) node := newComprehensionCacheElem() return node, child.Run(func(child *eval) error { values := make([]*ast.Term, len(keys)) @@ -1204,7 +1324,10 @@ func (e *eval) buildComprehensionCacheSet(x *ast.SetComprehension, keys []*ast.T } func (e *eval) buildComprehensionCacheObject(x *ast.ObjectComprehension, keys []*ast.Term) (*comprehensionCacheElem, error) { - child := e.child(x.Body) + child := evalPool.Get() + defer evalPool.Put(child) + + e.child(x.Body, child) node := newComprehensionCacheElem() return node, child.Run(func(child *eval) error { values := make([]*ast.Term, len(keys)) @@ -1285,7 +1408,11 @@ func (e *eval) amendComprehension(a *ast.Term, b1 *bindings) (*ast.Term, error) func (e *eval) biunifyComprehensionArray(x *ast.ArrayComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error { result := ast.NewArray() - child := e.closure(x.Body) + child := evalPool.Get() + + e.closure(x.Body, child) + defer evalPool.Put(child) + err := child.Run(func(child *eval) error { result = result.Append(child.bindings.Plug(x.Term)) return nil @@ -1298,7 +1425,11 @@ func (e *eval) biunifyComprehensionArray(x *ast.ArrayComprehension, b *ast.Term, func (e *eval) biunifyComprehensionSet(x *ast.SetComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error { result := ast.NewSet() - child := e.closure(x.Body) + child := evalPool.Get() + + e.closure(x.Body, child) + defer evalPool.Put(child) + err := child.Run(func(child *eval) error { result.Add(child.bindings.Plug(x.Term)) return nil @@ -1310,8 +1441,13 @@ func (e *eval) biunifyComprehensionSet(x *ast.SetComprehension, b *ast.Term, b1, } func (e *eval) biunifyComprehensionObject(x *ast.ObjectComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error { + child := evalPool.Get() + defer evalPool.Put(child) + + e.closure(x.Body, child) + result := ast.NewObject() - child := e.closure(x.Body) + err := child.Run(func(child *eval) error { key := child.bindings.Plug(x.Key) value := child.bindings.Plug(x.Value) @@ -1448,12 +1584,22 @@ func (e *eval) getRules(ref ast.Ref, args []*ast.Term) (*ast.IndexResult, error) return nil, nil } + resolver := resolverPool.Get().(*evalResolver) + defer func() { + resolver.e = nil + resolver.args = nil + resolverPool.Put(resolver) + }() + var result *ast.IndexResult var err error if e.indexing { - result, err = index.Lookup(&evalResolver{e: e, args: args}) + resolver.e = e + resolver.args = args + result, err = index.Lookup(resolver) } else { - result, err = index.AllRules(&evalResolver{e: e}) + resolver.e = e + result, err = index.AllRules(resolver) } if err != nil { return nil, err @@ -1461,20 +1607,23 @@ func (e *eval) getRules(ref ast.Ref, args []*ast.Term) (*ast.IndexResult, error) result.EarlyExit = result.EarlyExit && e.earlyExit - var msg strings.Builder - if len(result.Rules) == 1 { - msg.WriteString("(matched 1 rule") - } else { - msg.Grow(len("(matched NNNN rules)")) - msg.WriteString("(matched ") - msg.WriteString(strconv.Itoa(len(result.Rules))) - msg.WriteString(" rules") - } - if result.EarlyExit { - msg.WriteString(", early exit") + if e.traceEnabled { + var msg strings.Builder + if len(result.Rules) == 1 { + msg.WriteString("(matched 1 rule") + } else { + msg.Grow(len("(matched NNNN rules)")) + msg.WriteString("(matched ") + msg.WriteString(strconv.Itoa(len(result.Rules))) + msg.WriteString(" rules") + } + if result.EarlyExit { + msg.WriteString(", early exit") + } + msg.WriteRune(')') + e.traceIndex(e.query[e.index], msg.String(), &ref) } - msg.WriteRune(')') - e.traceIndex(e.query[e.index], msg.String(), &ref) + return result, err } @@ -1487,6 +1636,14 @@ type evalResolver struct { args []*ast.Term } +var ( + resolverPool = sync.Pool{ + New: func() any { + return &evalResolver{} + }, + } +) + func (e *evalResolver) Resolve(ref ast.Ref) (ast.Value, error) { e.e.instr.startTimer(evalOpResolve) @@ -1681,7 +1838,7 @@ func (e *eval) getDeclArgsLen(x *ast.Expr) (int, error) { bi, _, ok := e.builtinFunc(operator.String()) if ok { - return len(bi.Decl.Args()), nil + return bi.Decl.Arity(), nil } ir, err := e.getRules(operator, nil) @@ -1724,7 +1881,7 @@ func (e evalBuiltin) eval(iter unifyIterator) error { operands[i] = e.e.bindings.Plug(e.terms[i]) } - numDeclArgs := len(e.bi.Decl.FuncArgs().Args) + numDeclArgs := e.bi.Decl.Arity() e.e.instr.startTimer(evalOpBuiltinCall) var err error @@ -1749,7 +1906,7 @@ func (e evalBuiltin) eval(iter unifyIterator) error { case e.bi.Decl.Result() == nil: return iter() case len(operands) == numDeclArgs: - if v.Compare(ast.Boolean(false)) == 0 { + if ast.Boolean(false).Equal(v) { return nil // nothing to do } return iter() @@ -1773,7 +1930,7 @@ func (e evalBuiltin) eval(iter unifyIterator) error { case e.bi.Decl.Result() == nil: err = iter() case len(operands) == numDeclArgs: - if output.Value.Compare(ast.Boolean(false)) != 0 { + if !ast.Boolean(false).Equal(output.Value) { err = iter() } // else: nothing to do, don't iter() default: @@ -1813,9 +1970,9 @@ func (e evalBuiltin) eval(iter unifyIterator) error { type evalFunc struct { e *eval + ir *ast.IndexResult ref ast.Ref terms []*ast.Term - ir *ast.IndexResult } func (e evalFunc) eval(iter unifyIterator) error { @@ -1854,9 +2011,9 @@ func (e evalFunc) eval(iter unifyIterator) error { func (e evalFunc) evalValue(iter unifyIterator, argCount int, findOne bool) error { var cacheKey ast.Ref - var hit bool - var err error if !e.e.partial() { + var hit bool + var err error cacheKey, hit, err = e.evalCache(argCount, iter) if err != nil { return err @@ -1944,8 +2101,10 @@ func (e evalFunc) evalCache(argCount int, iter unifyIterator) (ast.Ref, bool, er } func (e evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, cacheKey ast.Ref, prev *ast.Term, findOne bool) (*ast.Term, error) { + child := evalPool.Get() + defer evalPool.Put(child) - child := e.e.child(rule.Body) + e.e.child(rule.Body, child) child.findOne = findOne args := make([]*ast.Term, len(e.terms)-1) @@ -1959,7 +2118,7 @@ func (e evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, cacheKey ast.R child.traceEnter(rule) - err := child.biunifyArrays(ast.NewArray(e.terms[1:]...), ast.NewArray(args...), e.e.bindings, child.bindings, func() error { + err := child.biunifyTerms(e.terms[1:], args, e.e.bindings, child.bindings, func() error { return child.eval(func(child *eval) error { child.traceExit(rule) @@ -1977,8 +2136,8 @@ func (e evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, cacheKey ast.R } if len(rule.Head.Args) == len(e.terms)-1 { - if result.Value.Compare(ast.Boolean(false)) == 0 { - if prev != nil && ast.Compare(prev, result) != 0 { + if ast.Boolean(false).Equal(result.Value) { + if prev != nil && !prev.Equal(result) { return functionConflictErr(rule.Location) } prev = result @@ -1992,7 +2151,7 @@ func (e evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, cacheKey ast.R // an example. if !e.e.partial() { if prev != nil { - if ast.Compare(prev, result) != 0 { + if !prev.Equal(result) { return functionConflictErr(rule.Location) } child.traceRedo(rule) @@ -2036,8 +2195,10 @@ func (e evalFunc) partialEvalSupport(declArgsLen int, iter unifyIterator) error } func (e evalFunc) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error { + child := evalPool.Get() + defer evalPool.Put(child) - child := e.e.child(rule.Body) + e.e.child(rule.Body, child) child.traceEnter(rule) e.e.saveStack.PushQuery(nil) @@ -2086,13 +2247,13 @@ func (e evalFunc) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error { type evalTree struct { e *eval - ref ast.Ref - plugged ast.Ref - pos int bindings *bindings rterm *ast.Term rbindings *bindings node *ast.TreeNode + ref ast.Ref + plugged ast.Ref + pos int } func (e evalTree) eval(iter unifyIterator) error { @@ -2187,7 +2348,7 @@ func (e evalTree) enumerate(iter unifyIterator) error { switch doc := doc.(type) { case *ast.Array: for i := 0; i < doc.Len(); i++ { - k := ast.IntNumberTerm(i) + k := ast.InternedIntNumberTerm(i) err := e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error { return e.next(iter, k) }) @@ -2317,12 +2478,12 @@ func (e evalTree) leaves(plugged ast.Ref, node *ast.TreeNode) (ast.Object, error type evalVirtual struct { e *eval - ref ast.Ref - plugged ast.Ref - pos int bindings *bindings rterm *ast.Term rbindings *bindings + ref ast.Ref + plugged ast.Ref + pos int } func (e evalVirtual) eval(iter unifyIterator) error { @@ -2393,14 +2554,14 @@ func (e evalVirtual) eval(iter unifyIterator) error { type evalVirtualPartial struct { e *eval - ref ast.Ref - plugged ast.Ref - pos int ir *ast.IndexResult bindings *bindings rterm *ast.Term rbindings *bindings empty *ast.Term + ref ast.Ref + plugged ast.Ref + pos int } type evalVirtualPartialCacheHint struct { @@ -2536,8 +2697,11 @@ func (e evalVirtualPartial) evalAllRulesNoCache(rules []*ast.Rule) (*ast.Term, e var visitedRefs []ast.Ref + child := evalPool.Get() + defer evalPool.Put(child) + for _, rule := range rules { - child := e.e.child(rule.Body) + e.e.child(rule.Body, child) child.traceEnter(rule) err := child.eval(func(*eval) error { child.traceExit(rule) @@ -2570,8 +2734,10 @@ func wrapInObjects(leaf *ast.Term, ref ast.Ref) *ast.Term { } func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Rule, result *ast.Term, unknown bool, visitedRefs *[]ast.Ref) (*ast.Term, error) { + child := evalPool.Get() + defer evalPool.Put(child) - child := e.e.child(rule.Body) + e.e.child(rule.Body, child) child.traceEnter(rule) var defined bool @@ -2663,7 +2829,10 @@ func (e *eval) biunifyDynamicRef(pos int, a, b ast.Ref, b1, b2 *bindings, iter u } func (e evalVirtualPartial) evalOneRulePostUnify(iter unifyIterator, rule *ast.Rule) error { - child := e.e.child(rule.Body) + child := evalPool.Get() + defer evalPool.Put(child) + + e.e.child(rule.Body, child) child.traceEnter(rule) var defined bool @@ -2747,8 +2916,10 @@ func (e evalVirtualPartial) partialEvalSupport(iter unifyIterator) error { } func (e evalVirtualPartial) partialEvalSupportRule(rule *ast.Rule, _ ast.Ref) (bool, error) { + child := evalPool.Get() + defer evalPool.Put(child) - child := e.e.child(rule.Body) + e.e.child(rule.Body, child) child.traceEnter(rule) e.e.saveStack.PushQuery(nil) @@ -3111,13 +3282,13 @@ func (e evalVirtualPartial) reduce(rule *ast.Rule, b *bindings, result *ast.Term type evalVirtualComplete struct { e *eval - ref ast.Ref - plugged ast.Ref - pos int ir *ast.IndexResult bindings *bindings rterm *ast.Term rbindings *bindings + ref ast.Ref + plugged ast.Ref + pos int } func (e evalVirtualComplete) eval(iter unifyIterator) error { @@ -3226,8 +3397,10 @@ func (e evalVirtualComplete) evalValue(iter unifyIterator, findOne bool) error { } func (e evalVirtualComplete) evalValueRule(iter unifyIterator, rule *ast.Rule, prev *ast.Term, findOne bool) (*ast.Term, error) { + child := evalPool.Get() + defer evalPool.Put(child) - child := e.e.child(rule.Body) + e.e.child(rule.Body, child) child.findOne = findOne child.traceEnter(rule) var result *ast.Term @@ -3262,9 +3435,11 @@ func (e evalVirtualComplete) evalValueRule(iter unifyIterator, rule *ast.Rule, p } func (e evalVirtualComplete) partialEval(iter unifyIterator) error { + child := evalPool.Get() + defer evalPool.Put(child) for _, rule := range e.ir.Rules { - child := e.e.child(rule.Body) + e.e.child(rule.Body, child) child.traceEnter(rule) err := child.eval(func(child *eval) error { @@ -3327,8 +3502,10 @@ func (e evalVirtualComplete) partialEvalSupport(iter unifyIterator) error { } func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) (bool, error) { + child := evalPool.Get() + defer evalPool.Put(child) - child := e.e.child(rule.Body) + e.e.child(rule.Body, child) child.traceEnter(rule) e.e.saveStack.PushQuery(nil) @@ -3383,13 +3560,13 @@ func (e evalVirtualComplete) evalTerm(iter unifyIterator, term *ast.Term, termbi type evalTerm struct { e *eval - ref ast.Ref - pos int bindings *bindings term *ast.Term termbindings *bindings rterm *ast.Term rbindings *bindings + ref ast.Ref + pos int } func (e evalTerm) eval(iter unifyIterator) error { @@ -3441,32 +3618,28 @@ func (e evalTerm) enumerate(iter unifyIterator) error { switch v := e.term.Value.(type) { case *ast.Array: for i := 0; i < v.Len(); i++ { - k := ast.IntNumberTerm(i) - err := e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error { + k := ast.InternedIntNumberTerm(i) + if err := handleErr(e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error { return e.next(iter, k) - }) - - if err := handleErr(err); err != nil { + })); err != nil { return err } } case ast.Object: - if err := v.Iter(func(k, _ *ast.Term) error { - err := e.e.biunify(k, e.ref[e.pos], e.termbindings, e.bindings, func() error { + for _, k := range v.Keys() { + if err := handleErr(e.e.biunify(k, e.ref[e.pos], e.termbindings, e.bindings, func() error { return e.next(iter, e.termbindings.Plug(k)) - }) - return handleErr(err) - }); err != nil { - return err + })); err != nil { + return err + } } case ast.Set: - if err := v.Iter(func(elem *ast.Term) error { - err := e.e.biunify(elem, e.ref[e.pos], e.termbindings, e.bindings, func() error { + for _, elem := range v.Slice() { + if err := handleErr(e.e.biunify(elem, e.ref[e.pos], e.termbindings, e.bindings, func() error { return e.next(iter, e.termbindings.Plug(elem)) - }) - return handleErr(err) - }); err != nil { - return err + })); err != nil { + return err + } } } @@ -3569,7 +3742,11 @@ func (e evalEvery) eval(iter unifyIterator) error { ).SetLocation(e.Domain.Location), ) - domain := e.e.closure(generator) + domain := evalPool.Get() + defer evalPool.Put(domain) + + e.e.closure(generator, domain) + all := true // all generator evaluations yield one successful body evaluation domain.traceEnter(e.expr) @@ -3580,7 +3757,11 @@ func (e evalEvery) eval(iter unifyIterator) error { // This would do extra work, like iterating needlessly if domain was a large array. return nil } - body := child.closure(e.Body) + + body := evalPool.Get() + defer evalPool.Put(body) + + child.closure(e.Body, body) body.findOne = true body.traceEnter(e.Body) done := false @@ -3707,10 +3888,12 @@ func applyCopyPropagation(p *copypropagation.CopyPropagator, instr *Instrumentat return result } +func nonGroundKey(k, _ *ast.Term) bool { + return !k.IsGround() +} + func nonGroundKeys(a ast.Object) bool { - return a.Until(func(k, _ *ast.Term) bool { - return !k.IsGround() - }) + return a.Until(nonGroundKey) } func plugKeys(a ast.Object, b *bindings) ast.Object { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/glob.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/glob.go similarity index 95% rename from vendor/github.com/open-policy-agent/opa/topdown/glob.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/glob.go index baf092ab6..cda17f382 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/glob.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/glob.go @@ -6,8 +6,8 @@ import ( "github.com/gobwas/glob" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) const globCacheMaxSize = 100 @@ -55,7 +55,7 @@ func builtinGlobMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast. if err != nil { return err } - return iter(ast.BooleanTerm(m)) + return iter(ast.InternedBooleanTerm(m)) } func globCompileAndMatch(bctx BuiltinContext, id, pattern, match string, delimiters []rune) (bool, error) { diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/graphql.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/graphql.go new file mode 100644 index 000000000..0ad1cfdb5 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/graphql.go @@ -0,0 +1,485 @@ +// Copyright 2022 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "encoding/json" + "fmt" + "strings" + + gqlast "github.com/open-policy-agent/opa/internal/gqlparser/ast" + gqlparser "github.com/open-policy-agent/opa/internal/gqlparser/parser" + gqlvalidator "github.com/open-policy-agent/opa/internal/gqlparser/validator" + + // Side-effecting import. Triggers GraphQL library's validation rule init() functions. + _ "github.com/open-policy-agent/opa/internal/gqlparser/validator/rules" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" +) + +// Parses a GraphQL schema, and returns the GraphQL AST for the schema. +func parseSchema(schema string) (*gqlast.SchemaDocument, error) { + // NOTE(philipc): We don't include the "built-in schema defs" from the + // underlying graphql parsing library here, because those definitions + // generate enormous AST blobs. In the future, if there is demand for + // a "full-spec" version of schema ASTs, we may need to provide a + // version of this function that includes the built-in schema + // definitions. + schemaAST, err := gqlparser.ParseSchema(&gqlast.Source{Input: schema}) + if err != nil { + errorParts := strings.SplitN(err.Error(), ":", 4) + msg := strings.TrimLeft(errorParts[3], " ") + return nil, fmt.Errorf("%s in GraphQL string at location %s:%s", msg, errorParts[1], errorParts[2]) + } + return schemaAST, nil +} + +// Parses a GraphQL query, and returns the GraphQL AST for the query. +func parseQuery(query string) (*gqlast.QueryDocument, error) { + queryAST, err := gqlparser.ParseQuery(&gqlast.Source{Input: query}) + if err != nil { + errorParts := strings.SplitN(err.Error(), ":", 4) + msg := strings.TrimLeft(errorParts[3], " ") + return nil, fmt.Errorf("%s in GraphQL string at location %s:%s", msg, errorParts[1], errorParts[2]) + } + return queryAST, nil +} + +// Validates a GraphQL query against a schema, and returns an error. +// In this case, we get a wrappered error list type, and pluck out +// just the first error message in the list. +func validateQuery(schema *gqlast.Schema, query *gqlast.QueryDocument) error { + // Validate the query against the schema, erroring if there's an issue. + err := gqlvalidator.Validate(schema, query) + if err != nil { + // We use strings.TrimSuffix to remove the '.' characters that the library + // authors include on most of their validation errors. This should be safe, + // since variable names in their error messages are usually quoted, and + // this affects only the last character(s) in the string. + // NOTE(philipc): We know the error location will be in the query string, + // because schema validation always happens before this function is called. + errorParts := strings.SplitN(err.Error(), ":", 4) + msg := strings.TrimSuffix(strings.TrimLeft(errorParts[3], " "), ".\n") + return fmt.Errorf("%s in GraphQL query string at location %s:%s", msg, errorParts[1], errorParts[2]) + } + return nil +} + +func getBuiltinSchema() *gqlast.SchemaDocument { + schema, err := gqlparser.ParseSchema(gqlvalidator.Prelude) + if err != nil { + panic(fmt.Errorf("Error in gqlparser Prelude (should be impossible): %w", err)) + } + return schema +} + +// NOTE(philipc): This function expects *validated* schema documents, and will break +// if it is fed arbitrary structures. +func mergeSchemaDocuments(docA *gqlast.SchemaDocument, docB *gqlast.SchemaDocument) *gqlast.SchemaDocument { + ast := &gqlast.SchemaDocument{} + ast.Merge(docA) + ast.Merge(docB) + return ast +} + +// Converts a SchemaDocument into a gqlast.Schema object that can be used for validation. +// It merges in the builtin schema typedefs exactly as gqltop.LoadSchema did internally. +func convertSchema(schemaDoc *gqlast.SchemaDocument) (*gqlast.Schema, error) { + // Merge builtin schema + schema we were provided. + builtinsSchemaDoc := getBuiltinSchema() + mergedSchemaDoc := mergeSchemaDocuments(builtinsSchemaDoc, schemaDoc) + schema, err := gqlvalidator.ValidateSchemaDocument(mergedSchemaDoc) + if err != nil { + return nil, fmt.Errorf("Error in gqlparser SchemaDocument to Schema conversion: %w", err) + } + return schema, nil +} + +// Converts an ast.Object into a gqlast.QueryDocument object. +func objectToQueryDocument(value ast.Object) (*gqlast.QueryDocument, error) { + // Convert ast.Term to interface{} for JSON encoding below. + asJSON, err := ast.JSON(value) + if err != nil { + return nil, err + } + // Marshal to JSON. + bs, err := json.Marshal(asJSON) + if err != nil { + return nil, err + } + // Unmarshal from JSON -> gqlast.QueryDocument. + var result gqlast.QueryDocument + err = json.Unmarshal(bs, &result) + if err != nil { + return nil, err + } + return &result, nil +} + +// Converts an ast.Object into a gqlast.SchemaDocument object. +func objectToSchemaDocument(value ast.Object) (*gqlast.SchemaDocument, error) { + // Convert ast.Term to interface{} for JSON encoding below. + asJSON, err := ast.JSON(value) + if err != nil { + return nil, err + } + // Marshal to JSON. + bs, err := json.Marshal(asJSON) + if err != nil { + return nil, err + } + // Unmarshal from JSON -> gqlast.SchemaDocument. + var result gqlast.SchemaDocument + err = json.Unmarshal(bs, &result) + if err != nil { + return nil, err + } + return &result, nil +} + +// Recursively traverses an AST that has been run through InterfaceToValue, +// and prunes away the fields with null or empty values, and all `Position` +// structs. +// NOTE(philipc): We currently prune away null values to reduce the level +// of clutter in the returned AST objects. In the future, if there is demand +// for ASTs that have a more regular/fixed structure, we may need to provide +// a "raw" version of the AST, where we still prune away the `Position` +// structs, but leave in the null fields. +func pruneIrrelevantGraphQLASTNodes(value ast.Value) ast.Value { + // We iterate over the Value we've been provided, and recurse down + // in the case of complex types, such as Arrays/Objects. + // We are guaranteed to only have to deal with standard JSON types, + // so this is much less ugly than what we'd need for supporting every + // extant ast type! + switch x := value.(type) { + case *ast.Array: + result := ast.NewArray() + // Iterate over the array's elements, and do the following: + // - Drop any Nulls + // - Drop any any empty object/array value (after running the pruner) + for i := 0; i < x.Len(); i++ { + vTerm := x.Elem(i) + switch v := vTerm.Value.(type) { + case ast.Null: + continue + case *ast.Array: + // Safe, because we knew the type before going to prune it. + va := pruneIrrelevantGraphQLASTNodes(v).(*ast.Array) + if va.Len() > 0 { + result = result.Append(ast.NewTerm(va)) + } + case ast.Object: + // Safe, because we knew the type before going to prune it. + vo := pruneIrrelevantGraphQLASTNodes(v).(ast.Object) + if vo.Len() > 0 { + result = result.Append(ast.NewTerm(vo)) + } + default: + result = result.Append(vTerm) + } + } + return result + case ast.Object: + result := ast.NewObject() + // Iterate over our object's keys, and do the following: + // - Drop "Position". + // - Drop any key with a Null value. + // - Drop any key with an empty object/array value (after running the pruner) + keys := x.Keys() + for _, k := range keys { + // We drop the "Position" objects because we don't need the + // source-backref/location info they provide for policy rules. + // Note that keys are ast.Strings. + if ast.String("Position").Equal(k.Value) { + continue + } + vTerm := x.Get(k) + switch v := vTerm.Value.(type) { + case ast.Null: + continue + case *ast.Array: + // Safe, because we knew the type before going to prune it. + va := pruneIrrelevantGraphQLASTNodes(v).(*ast.Array) + if va.Len() > 0 { + result.Insert(k, ast.NewTerm(va)) + } + case ast.Object: + // Safe, because we knew the type before going to prune it. + vo := pruneIrrelevantGraphQLASTNodes(v).(ast.Object) + if vo.Len() > 0 { + result.Insert(k, ast.NewTerm(vo)) + } + default: + result.Insert(k, vTerm) + } + } + return result + default: + return x + } +} + +// Reports errors from parsing/validation. +func builtinGraphQLParse(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + var queryDoc *gqlast.QueryDocument + var schemaDoc *gqlast.SchemaDocument + var err error + + // Parse/translate query if it's a string/object. + switch x := operands[0].Value.(type) { + case ast.String: + queryDoc, err = parseQuery(string(x)) + case ast.Object: + queryDoc, err = objectToQueryDocument(x) + default: + // Error if wrong type. + return builtins.NewOperandTypeErr(0, x, "string", "object") + } + if err != nil { + return err + } + + // Parse/translate schema if it's a string/object. + switch x := operands[1].Value.(type) { + case ast.String: + schemaDoc, err = parseSchema(string(x)) + case ast.Object: + schemaDoc, err = objectToSchemaDocument(x) + default: + // Error if wrong type. + return builtins.NewOperandTypeErr(1, x, "string", "object") + } + if err != nil { + return err + } + + // Transform the ASTs into Objects. + queryASTValue, err := ast.InterfaceToValue(queryDoc) + if err != nil { + return err + } + schemaASTValue, err := ast.InterfaceToValue(schemaDoc) + if err != nil { + return err + } + + // Validate the query against the schema, erroring if there's an issue. + schema, err := convertSchema(schemaDoc) + if err != nil { + return err + } + if err := validateQuery(schema, queryDoc); err != nil { + return err + } + + // Recursively remove irrelevant AST structures. + queryResult := pruneIrrelevantGraphQLASTNodes(queryASTValue.(ast.Object)) + querySchema := pruneIrrelevantGraphQLASTNodes(schemaASTValue.(ast.Object)) + + // Construct return value. + verified := ast.ArrayTerm( + ast.NewTerm(queryResult), + ast.NewTerm(querySchema), + ) + + return iter(verified) +} + +// Returns default value when errors occur. +func builtinGraphQLParseAndVerify(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + var queryDoc *gqlast.QueryDocument + var schemaDoc *gqlast.SchemaDocument + var err error + + unverified := ast.ArrayTerm( + ast.InternedBooleanTerm(false), + ast.NewTerm(ast.NewObject()), + ast.NewTerm(ast.NewObject()), + ) + + // Parse/translate query if it's a string/object. + switch x := operands[0].Value.(type) { + case ast.String: + queryDoc, err = parseQuery(string(x)) + case ast.Object: + queryDoc, err = objectToQueryDocument(x) + default: + // Error if wrong type. + return iter(unverified) + } + if err != nil { + return iter(unverified) + } + + // Parse/translate schema if it's a string/object. + switch x := operands[1].Value.(type) { + case ast.String: + schemaDoc, err = parseSchema(string(x)) + case ast.Object: + schemaDoc, err = objectToSchemaDocument(x) + default: + // Error if wrong type. + return iter(unverified) + } + if err != nil { + return iter(unverified) + } + + // Transform the ASTs into Objects. + queryASTValue, err := ast.InterfaceToValue(queryDoc) + if err != nil { + return iter(unverified) + } + schemaASTValue, err := ast.InterfaceToValue(schemaDoc) + if err != nil { + return iter(unverified) + } + + // Validate the query against the schema, erroring if there's an issue. + schema, err := convertSchema(schemaDoc) + if err != nil { + return iter(unverified) + } + if err := validateQuery(schema, queryDoc); err != nil { + return iter(unverified) + } + + // Recursively remove irrelevant AST structures. + queryResult := pruneIrrelevantGraphQLASTNodes(queryASTValue.(ast.Object)) + querySchema := pruneIrrelevantGraphQLASTNodes(schemaASTValue.(ast.Object)) + + // Construct return value. + verified := ast.ArrayTerm( + ast.InternedBooleanTerm(true), + ast.NewTerm(queryResult), + ast.NewTerm(querySchema), + ) + + return iter(verified) +} + +func builtinGraphQLParseQuery(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + raw, err := builtins.StringOperand(operands[0].Value, 1) + if err != nil { + return err + } + + // Get the highly-nested AST struct, along with any errors generated. + query, err := parseQuery(string(raw)) + if err != nil { + return err + } + + // Transform the AST into an Object. + value, err := ast.InterfaceToValue(query) + if err != nil { + return err + } + + // Recursively remove irrelevant AST structures. + result := pruneIrrelevantGraphQLASTNodes(value.(ast.Object)) + + return iter(ast.NewTerm(result)) +} + +func builtinGraphQLParseSchema(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + raw, err := builtins.StringOperand(operands[0].Value, 1) + if err != nil { + return err + } + + // Get the highly-nested AST struct, along with any errors generated. + schema, err := parseSchema(string(raw)) + if err != nil { + return err + } + + // Transform the AST into an Object. + value, err := ast.InterfaceToValue(schema) + if err != nil { + return err + } + + // Recursively remove irrelevant AST structures. + result := pruneIrrelevantGraphQLASTNodes(value.(ast.Object)) + + return iter(ast.NewTerm(result)) +} + +func builtinGraphQLIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + var queryDoc *gqlast.QueryDocument + var schemaDoc *gqlast.SchemaDocument + var err error + + switch x := operands[0].Value.(type) { + case ast.String: + queryDoc, err = parseQuery(string(x)) + case ast.Object: + queryDoc, err = objectToQueryDocument(x) + default: + // Error if wrong type. + return iter(ast.InternedBooleanTerm(false)) + } + if err != nil { + return iter(ast.InternedBooleanTerm(false)) + } + + switch x := operands[1].Value.(type) { + case ast.String: + schemaDoc, err = parseSchema(string(x)) + case ast.Object: + schemaDoc, err = objectToSchemaDocument(x) + default: + // Error if wrong type. + return iter(ast.InternedBooleanTerm(false)) + } + if err != nil { + return iter(ast.InternedBooleanTerm(false)) + } + + // Validate the query against the schema, erroring if there's an issue. + schema, err := convertSchema(schemaDoc) + if err != nil { + return iter(ast.InternedBooleanTerm(false)) + } + if err := validateQuery(schema, queryDoc); err != nil { + return iter(ast.InternedBooleanTerm(false)) + } + + // If we got this far, the GraphQL query passed validation. + return iter(ast.InternedBooleanTerm(true)) +} + +func builtinGraphQLSchemaIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + var schemaDoc *gqlast.SchemaDocument + var err error + + switch x := operands[0].Value.(type) { + case ast.String: + schemaDoc, err = parseSchema(string(x)) + case ast.Object: + schemaDoc, err = objectToSchemaDocument(x) + default: + // Error if wrong type. + return iter(ast.InternedBooleanTerm(false)) + } + if err != nil { + return iter(ast.InternedBooleanTerm(false)) + } + + // Validate the schema, this determines the result + _, err = convertSchema(schemaDoc) + return iter(ast.InternedBooleanTerm(err == nil)) +} + +func init() { + RegisterBuiltinFunc(ast.GraphQLParse.Name, builtinGraphQLParse) + RegisterBuiltinFunc(ast.GraphQLParseAndVerify.Name, builtinGraphQLParseAndVerify) + RegisterBuiltinFunc(ast.GraphQLParseQuery.Name, builtinGraphQLParseQuery) + RegisterBuiltinFunc(ast.GraphQLParseSchema.Name, builtinGraphQLParseSchema) + RegisterBuiltinFunc(ast.GraphQLIsValid.Name, builtinGraphQLIsValid) + RegisterBuiltinFunc(ast.GraphQLSchemaIsValid.Name, builtinGraphQLSchemaIsValid) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/http.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/http.go new file mode 100644 index 000000000..20fea2d7a --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/http.go @@ -0,0 +1,1630 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io" + "math" + "net" + "net/http" + "net/url" + "os" + "runtime" + "strconv" + "strings" + "time" + + "github.com/open-policy-agent/opa/internal/version" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/tracing" + "github.com/open-policy-agent/opa/v1/util" +) + +type cachingMode string + +const ( + defaultHTTPRequestTimeoutEnv = "HTTP_SEND_TIMEOUT" + defaultCachingMode cachingMode = "serialized" + cachingModeDeserialized cachingMode = "deserialized" +) + +var defaultHTTPRequestTimeout = time.Second * 5 + +var allowedKeyNames = [...]string{ + "method", + "url", + "body", + "enable_redirect", + "force_json_decode", + "force_yaml_decode", + "headers", + "raw_body", + "tls_use_system_certs", + "tls_ca_cert", + "tls_ca_cert_file", + "tls_ca_cert_env_variable", + "tls_client_cert", + "tls_client_cert_file", + "tls_client_cert_env_variable", + "tls_client_key", + "tls_client_key_file", + "tls_client_key_env_variable", + "tls_insecure_skip_verify", + "tls_server_name", + "timeout", + "cache", + "force_cache", + "force_cache_duration_seconds", + "raise_error", + "caching_mode", + "max_retry_attempts", + "cache_ignored_headers", +} + +// ref: https://www.rfc-editor.org/rfc/rfc7231#section-6.1 +var cacheableHTTPStatusCodes = [...]int{ + http.StatusOK, + http.StatusNonAuthoritativeInfo, + http.StatusNoContent, + http.StatusPartialContent, + http.StatusMultipleChoices, + http.StatusMovedPermanently, + http.StatusNotFound, + http.StatusMethodNotAllowed, + http.StatusGone, + http.StatusRequestURITooLong, + http.StatusNotImplemented, +} + +var ( + allowedKeys = ast.NewSet() + cacheableCodes = ast.NewSet() + requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url")) + httpSendLatencyMetricKey = "rego_builtin_" + strings.ReplaceAll(ast.HTTPSend.Name, ".", "_") + httpSendInterQueryCacheHits = httpSendLatencyMetricKey + "_interquery_cache_hits" +) + +type httpSendKey string + +// CustomizeRoundTripper allows customizing an existing http.Transport, +// to the returned value, which could be the same Transport or a new one. +type CustomizeRoundTripper func(*http.Transport) http.RoundTripper + +const ( + // httpSendBuiltinCacheKey is the key in the builtin context cache that + // points to the http.send() specific cache resides at. + httpSendBuiltinCacheKey httpSendKey = "HTTP_SEND_CACHE_KEY" + + // HTTPSendInternalErr represents a runtime evaluation error. + HTTPSendInternalErr string = "eval_http_send_internal_error" + + // HTTPSendNetworkErr represents a network error. + HTTPSendNetworkErr string = "eval_http_send_network_error" + + // minRetryDelay is amount of time to backoff after the first failure. + minRetryDelay = time.Millisecond * 100 + + // maxRetryDelay is the upper bound of backoff delay. + maxRetryDelay = time.Second * 60 +) + +func builtinHTTPSend(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + + obj, err := builtins.ObjectOperand(operands[0].Value, 1) + if err != nil { + return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) + } + + raiseError, err := getRaiseErrorValue(obj) + if err != nil { + return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) + } + + req, err := validateHTTPRequestOperand(operands[0], 1) + if err != nil { + if raiseError { + return handleHTTPSendErr(bctx, err) + } + + return iter(generateRaiseErrorResult(handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err))) + } + + result, err := getHTTPResponse(bctx, req) + if err != nil { + if raiseError { + return handleHTTPSendErr(bctx, err) + } + + result = generateRaiseErrorResult(err) + } + return iter(result) +} + +func generateRaiseErrorResult(err error) *ast.Term { + obj := ast.NewObject() + obj.Insert(ast.StringTerm("status_code"), ast.InternedIntNumberTerm(0)) + + errObj := ast.NewObject() + + switch err.(type) { + case *url.Error: + errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendNetworkErr)) + default: + errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendInternalErr)) + } + + errObj.Insert(ast.StringTerm("message"), ast.StringTerm(err.Error())) + obj.Insert(ast.StringTerm("error"), ast.NewTerm(errObj)) + + return ast.NewTerm(obj) +} + +func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) { + + bctx.Metrics.Timer(httpSendLatencyMetricKey).Start() + defer bctx.Metrics.Timer(httpSendLatencyMetricKey).Stop() + + key, err := getKeyFromRequest(req) + if err != nil { + return nil, err + } + + reqExecutor, err := newHTTPRequestExecutor(bctx, req, key) + if err != nil { + return nil, err + } + // Check if cache already has a response for this query + // set headers to exclude cache_ignored_headers + resp, err := reqExecutor.CheckCache() + if err != nil { + return nil, err + } + + if resp == nil { + httpResp, err := reqExecutor.ExecuteHTTPRequest() + if err != nil { + reqExecutor.InsertErrorIntoCache(err) + return nil, err + } + defer util.Close(httpResp) + // Add result to intra/inter-query cache. + resp, err = reqExecutor.InsertIntoCache(httpResp) + if err != nil { + return nil, err + } + } + + return ast.NewTerm(resp), nil +} + +// getKeyFromRequest returns a key to be used for caching HTTP responses +// deletes headers from request object mentioned in cache_ignored_headers +func getKeyFromRequest(req ast.Object) (ast.Object, error) { + // deep copy so changes to key do not reflect in the request object + key := req.Copy() + cacheIgnoredHeadersTerm := req.Get(ast.StringTerm("cache_ignored_headers")) + allHeadersTerm := req.Get(ast.StringTerm("headers")) + // skip because no headers to delete + if cacheIgnoredHeadersTerm == nil || allHeadersTerm == nil { + // need to explicitly set cache_ignored_headers to null + // equivalent requests might have different sets of exclusion lists + key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm()) + return key, nil + } + var cacheIgnoredHeaders []string + var allHeaders map[string]interface{} + err := ast.As(cacheIgnoredHeadersTerm.Value, &cacheIgnoredHeaders) + if err != nil { + return nil, err + } + err = ast.As(allHeadersTerm.Value, &allHeaders) + if err != nil { + return nil, err + } + for _, header := range cacheIgnoredHeaders { + delete(allHeaders, header) + } + val, err := ast.InterfaceToValue(allHeaders) + if err != nil { + return nil, err + } + key.Insert(ast.StringTerm("headers"), ast.NewTerm(val)) + // remove cache_ignored_headers key + key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm()) + return key, nil +} + +func init() { + createAllowedKeys() + createCacheableHTTPStatusCodes() + initDefaults() + RegisterBuiltinFunc(ast.HTTPSend.Name, builtinHTTPSend) +} + +func handleHTTPSendErr(bctx BuiltinContext, err error) error { + // Return HTTP client timeout errors in a generic error message to avoid confusion about what happened. + // Do not do this if the builtin context was cancelled and is what caused the request to stop. + if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() && bctx.Context.Err() == nil { + err = fmt.Errorf("%s %s: request timed out", urlErr.Op, urlErr.URL) + } + if err := bctx.Context.Err(); err != nil { + return Halt{ + Err: &Error{ + Code: CancelErr, + Message: fmt.Sprintf("http.send: timed out (%s)", err.Error()), + }, + } + } + return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) +} + +func initDefaults() { + timeoutDuration := os.Getenv(defaultHTTPRequestTimeoutEnv) + if timeoutDuration != "" { + var err error + defaultHTTPRequestTimeout, err = time.ParseDuration(timeoutDuration) + if err != nil { + // If it is set to something not valid don't let the process continue in a state + // that will almost definitely give unexpected results by having it set at 0 + // which means no timeout.. + // This environment variable isn't considered part of the public API. + // TODO(patrick-east): Remove the environment variable + panic(fmt.Sprintf("invalid value for HTTP_SEND_TIMEOUT: %s", err)) + } + } +} + +func validateHTTPRequestOperand(term *ast.Term, pos int) (ast.Object, error) { + + obj, err := builtins.ObjectOperand(term.Value, pos) + if err != nil { + return nil, err + } + + requestKeys := ast.NewSet(obj.Keys()...) + + invalidKeys := requestKeys.Diff(allowedKeys) + if invalidKeys.Len() != 0 { + return nil, builtins.NewOperandErr(pos, "invalid request parameters(s): %v", invalidKeys) + } + + missingKeys := requiredKeys.Diff(requestKeys) + if missingKeys.Len() != 0 { + return nil, builtins.NewOperandErr(pos, "missing required request parameters(s): %v", missingKeys) + } + + return obj, nil + +} + +// canonicalizeHeaders returns a copy of the headers where the keys are in +// canonical HTTP form. +func canonicalizeHeaders(headers map[string]interface{}) map[string]interface{} { + canonicalized := map[string]interface{}{} + + for k, v := range headers { + canonicalized[http.CanonicalHeaderKey(k)] = v + } + + return canonicalized +} + +// useSocket examines the url for "unix://" and returns a *http.Transport with +// a DialContext that opens a socket (specified in the http call). +// The url is expected to contain socket=/path/to/socket (url encoded) +// Ex. "unix://localhost/end/point?socket=%2Ftmp%2Fhttp.sock" +func useSocket(rawURL string, tlsConfig *tls.Config) (bool, string, *http.Transport) { + u, err := url.Parse(rawURL) + if err != nil { + return false, "", nil + } + + if u.Scheme != "unix" || u.RawQuery == "" { + return false, rawURL, nil + } + + v, err := url.ParseQuery(u.RawQuery) + if err != nil { + return false, rawURL, nil + } + + // Rewrite URL targeting the UNIX domain socket. + u.Scheme = "http" + + // Extract the path to the socket. + // Only retrieve the first value. Subsequent values are ignored and removed + // to prevent HTTP parameter pollution. + socket := v.Get("socket") + v.Del("socket") + u.RawQuery = v.Encode() + + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return http.DefaultTransport.(*http.Transport).DialContext(ctx, "unix", socket) + } + tr.TLSClientConfig = tlsConfig + tr.DisableKeepAlives = true + + return true, u.String(), tr +} + +func verifyHost(bctx BuiltinContext, host string) error { + if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil { + return nil + } + + for _, allowed := range bctx.Capabilities.AllowNet { + if allowed == host { + return nil + } + } + + return fmt.Errorf("unallowed host: %s", host) +} + +func verifyURLHost(bctx BuiltinContext, unverifiedURL string) error { + // Eager return to avoid unnecessary URL parsing + if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil { + return nil + } + + parsedURL, err := url.Parse(unverifiedURL) + if err != nil { + return err + } + + host := strings.Split(parsedURL.Host, ":")[0] + + return verifyHost(bctx, host) +} + +func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *http.Client, error) { + var url string + var method string + + // Additional CA certificates loading options. + var tlsCaCert []byte + var tlsCaCertEnvVar string + var tlsCaCertFile string + + // Client TLS certificate and key options. Each input source + // comes in a matched pair. + var tlsClientCert []byte + var tlsClientKey []byte + + var tlsClientCertEnvVar string + var tlsClientKeyEnvVar string + + var tlsClientCertFile string + var tlsClientKeyFile string + + var tlsServerName string + var body *bytes.Buffer + var rawBody *bytes.Buffer + var enableRedirect bool + var tlsUseSystemCerts *bool + var tlsConfig tls.Config + var customHeaders map[string]interface{} + var tlsInsecureSkipVerify bool + timeout := defaultHTTPRequestTimeout + + for _, val := range obj.Keys() { + key, err := ast.JSON(val.Value) + if err != nil { + return nil, nil, err + } + + key = key.(string) + + var strVal string + + if s, ok := obj.Get(val).Value.(ast.String); ok { + strVal = strings.Trim(string(s), "\"") + } else { + // Most parameters are strings, so consolidate the type checking. + switch key { + case "method", + "url", + "raw_body", + "tls_ca_cert", + "tls_ca_cert_file", + "tls_ca_cert_env_variable", + "tls_client_cert", + "tls_client_cert_file", + "tls_client_cert_env_variable", + "tls_client_key", + "tls_client_key_file", + "tls_client_key_env_variable", + "tls_server_name": + return nil, nil, fmt.Errorf("%q must be a string", key) + } + } + + switch key { + case "method": + method = strings.ToUpper(strVal) + case "url": + err := verifyURLHost(bctx, strVal) + if err != nil { + return nil, nil, err + } + url = strVal + case "enable_redirect": + enableRedirect, err = strconv.ParseBool(obj.Get(val).String()) + if err != nil { + return nil, nil, err + } + case "body": + bodyVal := obj.Get(val).Value + bodyValInterface, err := ast.JSON(bodyVal) + if err != nil { + return nil, nil, err + } + + bodyValBytes, err := json.Marshal(bodyValInterface) + if err != nil { + return nil, nil, err + } + body = bytes.NewBuffer(bodyValBytes) + case "raw_body": + rawBody = bytes.NewBufferString(strVal) + case "tls_use_system_certs": + tempTLSUseSystemCerts, err := strconv.ParseBool(obj.Get(val).String()) + if err != nil { + return nil, nil, err + } + tlsUseSystemCerts = &tempTLSUseSystemCerts + case "tls_ca_cert": + tlsCaCert = []byte(strVal) + case "tls_ca_cert_file": + tlsCaCertFile = strVal + case "tls_ca_cert_env_variable": + tlsCaCertEnvVar = strVal + case "tls_client_cert": + tlsClientCert = []byte(strVal) + case "tls_client_cert_file": + tlsClientCertFile = strVal + case "tls_client_cert_env_variable": + tlsClientCertEnvVar = strVal + case "tls_client_key": + tlsClientKey = []byte(strVal) + case "tls_client_key_file": + tlsClientKeyFile = strVal + case "tls_client_key_env_variable": + tlsClientKeyEnvVar = strVal + case "tls_server_name": + tlsServerName = strVal + case "headers": + headersVal := obj.Get(val).Value + headersValInterface, err := ast.JSON(headersVal) + if err != nil { + return nil, nil, err + } + var ok bool + customHeaders, ok = headersValInterface.(map[string]interface{}) + if !ok { + return nil, nil, fmt.Errorf("invalid type for headers key") + } + case "tls_insecure_skip_verify": + tlsInsecureSkipVerify, err = strconv.ParseBool(obj.Get(val).String()) + if err != nil { + return nil, nil, err + } + case "timeout": + timeout, err = parseTimeout(obj.Get(val).Value) + if err != nil { + return nil, nil, err + } + case "cache", "caching_mode", + "force_cache", "force_cache_duration_seconds", + "force_json_decode", "force_yaml_decode", + "raise_error", "max_retry_attempts", "cache_ignored_headers": // no-op + default: + return nil, nil, fmt.Errorf("invalid parameter %q", key) + } + } + + isTLS := false + client := &http.Client{ + Timeout: timeout, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + if tlsInsecureSkipVerify { + isTLS = true + tlsConfig.InsecureSkipVerify = tlsInsecureSkipVerify + } + + if len(tlsClientCert) > 0 && len(tlsClientKey) > 0 { + cert, err := tls.X509KeyPair(tlsClientCert, tlsClientKey) + if err != nil { + return nil, nil, err + } + + isTLS = true + tlsConfig.Certificates = append(tlsConfig.Certificates, cert) + } + + if tlsClientCertFile != "" && tlsClientKeyFile != "" { + cert, err := tls.LoadX509KeyPair(tlsClientCertFile, tlsClientKeyFile) + if err != nil { + return nil, nil, err + } + + isTLS = true + tlsConfig.Certificates = append(tlsConfig.Certificates, cert) + } + + if tlsClientCertEnvVar != "" && tlsClientKeyEnvVar != "" { + cert, err := tls.X509KeyPair( + []byte(os.Getenv(tlsClientCertEnvVar)), + []byte(os.Getenv(tlsClientKeyEnvVar))) + if err != nil { + return nil, nil, fmt.Errorf("cannot extract public/private key pair from envvars %q, %q: %w", + tlsClientCertEnvVar, tlsClientKeyEnvVar, err) + } + + isTLS = true + tlsConfig.Certificates = append(tlsConfig.Certificates, cert) + } + + // Use system certs if no CA cert is provided + // or system certs flag is not set + if len(tlsCaCert) == 0 && tlsCaCertFile == "" && tlsCaCertEnvVar == "" && tlsUseSystemCerts == nil { + trueValue := true + tlsUseSystemCerts = &trueValue + } + + // Check the system certificates config first so that we + // load additional certificated into the correct pool. + if tlsUseSystemCerts != nil && *tlsUseSystemCerts && runtime.GOOS != "windows" { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, nil, err + } + + isTLS = true + tlsConfig.RootCAs = pool + } + + if len(tlsCaCert) != 0 { + tlsCaCert = bytes.Replace(tlsCaCert, []byte("\\n"), []byte("\n"), -1) + pool, err := addCACertsFromBytes(tlsConfig.RootCAs, tlsCaCert) + if err != nil { + return nil, nil, err + } + + isTLS = true + tlsConfig.RootCAs = pool + } + + if tlsCaCertFile != "" { + pool, err := addCACertsFromFile(tlsConfig.RootCAs, tlsCaCertFile) + if err != nil { + return nil, nil, err + } + + isTLS = true + tlsConfig.RootCAs = pool + } + + if tlsCaCertEnvVar != "" { + pool, err := addCACertsFromEnv(tlsConfig.RootCAs, tlsCaCertEnvVar) + if err != nil { + return nil, nil, err + } + + isTLS = true + tlsConfig.RootCAs = pool + } + + var transport *http.Transport + if isTLS { + if ok, parsedURL, tr := useSocket(url, &tlsConfig); ok { + transport = tr + url = parsedURL + } else { + transport = http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tlsConfig + transport.DisableKeepAlives = true + } + } else { + if ok, parsedURL, tr := useSocket(url, nil); ok { + transport = tr + url = parsedURL + } + } + + if bctx.RoundTripper != nil { + client.Transport = bctx.RoundTripper(transport) + } else if transport != nil { + client.Transport = transport + } + + // check if redirects are enabled + if enableRedirect { + client.CheckRedirect = func(req *http.Request, _ []*http.Request) error { + return verifyURLHost(bctx, req.URL.String()) + } + } + + if rawBody != nil { + body = rawBody + } else if body == nil { + body = bytes.NewBufferString("") + } + + // create the http request, use the builtin context's context to ensure + // the request is cancelled if evaluation is cancelled. + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, nil, err + } + + req = req.WithContext(bctx.Context) + + // Add custom headers + if len(customHeaders) != 0 { + customHeaders = canonicalizeHeaders(customHeaders) + + for k, v := range customHeaders { + header, ok := v.(string) + if !ok { + return nil, nil, fmt.Errorf("invalid type for headers value %q", v) + } + + req.Header.Add(k, header) + } + + // Don't overwrite or append to one that was set in the custom headers + if _, hasUA := customHeaders["User-Agent"]; !hasUA { + req.Header.Add("User-Agent", version.UserAgent) + } + + // If the caller specifies the Host header, use it for the HTTP + // request host and the TLS server name. + if host, hasHost := customHeaders["Host"]; hasHost { + host := host.(string) // We already checked that it's a string. + req.Host = host + + // Only default the ServerName if the caller has + // specified the host. If we don't specify anything, + // Go will default to the target hostname. This name + // is not the same as the default that Go populates + // `req.Host` with, which is why we don't just set + // this unconditionally. + tlsConfig.ServerName = host + } + } + + if tlsServerName != "" { + tlsConfig.ServerName = tlsServerName + } + + if len(bctx.DistributedTracingOpts) > 0 { + client.Transport = tracing.NewTransport(client.Transport, bctx.DistributedTracingOpts) + } + + return req, client, nil +} + +func executeHTTPRequest(req *http.Request, client *http.Client, inputReqObj ast.Object) (*http.Response, error) { + var err error + var retry int + + retry, err = getNumberValFromReqObj(inputReqObj, ast.StringTerm("max_retry_attempts")) + if err != nil { + return nil, err + } + + for i := 0; true; i++ { + + var resp *http.Response + resp, err = client.Do(req) + if err == nil { + return resp, nil + } + + // final attempt + if i == retry { + break + } + + if err == context.Canceled { + return nil, err + } + + delay := util.DefaultBackoff(float64(minRetryDelay), float64(maxRetryDelay), i) + timer, timerCancel := util.TimerWithCancel(delay) + select { + case <-timer.C: + case <-req.Context().Done(): + timerCancel() // explicitly cancel the timer. + return nil, context.Canceled + } + } + return nil, err +} + +func isContentType(header http.Header, typ ...string) bool { + for _, t := range typ { + if strings.Contains(header.Get("Content-Type"), t) { + return true + } + } + return false +} + +type httpSendCacheEntry struct { + response *ast.Value + error error +} + +// The httpSendCache is used for intra-query caching of http.send results. +type httpSendCache struct { + entries *util.HashMap +} + +func newHTTPSendCache() *httpSendCache { + return &httpSendCache{ + entries: util.NewHashMap(valueEq, valueHash), + } +} + +func valueHash(v util.T) int { + return ast.StringTerm(v.(ast.Value).String()).Hash() +} + +func valueEq(a, b util.T) bool { + av := a.(ast.Value) + bv := b.(ast.Value) + return av.String() == bv.String() +} + +func (cache *httpSendCache) get(k ast.Value) *httpSendCacheEntry { + if v, ok := cache.entries.Get(k); ok { + v := v.(httpSendCacheEntry) + return &v + } + return nil +} + +func (cache *httpSendCache) putResponse(k ast.Value, v *ast.Value) { + cache.entries.Put(k, httpSendCacheEntry{response: v}) +} + +func (cache *httpSendCache) putError(k ast.Value, v error) { + cache.entries.Put(k, httpSendCacheEntry{error: v}) +} + +// In the BuiltinContext cache we only store a single entry that points to +// our ValueMap which is the "real" http.send() cache. +func getHTTPSendCache(bctx BuiltinContext) *httpSendCache { + raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey) + if !ok { + // Initialize if it isn't there + c := newHTTPSendCache() + bctx.Cache.Put(httpSendBuiltinCacheKey, c) + return c + } + + c, ok := raw.(*httpSendCache) + if !ok { + return nil + } + return c +} + +// checkHTTPSendCache checks for the given key's value in the cache +func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) (ast.Value, error) { + requestCache := getHTTPSendCache(bctx) + if requestCache == nil { + return nil, nil + } + + v := requestCache.get(key) + if v != nil { + if v.error != nil { + return nil, v.error + } + if v.response != nil { + return *v.response, nil + } + // This should never happen + } + + return nil, nil +} + +func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) { + requestCache := getHTTPSendCache(bctx) + if requestCache == nil { + // Should never happen.. if it does just skip caching the value + // FIXME: return error instead, to prevent inconsistencies? + return + } + requestCache.putResponse(key, &value) +} + +func insertErrorIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, err error) { + requestCache := getHTTPSendCache(bctx) + if requestCache == nil { + // Should never happen.. if it does just skip caching the value + // FIXME: return error instead, to prevent inconsistencies? + return + } + requestCache.putError(key, err) +} + +// checkHTTPSendInterQueryCache checks for the given key's value in the inter-query cache +func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) { + requestCache := c.bctx.InterQueryBuiltinCache + + cachedValue, found := requestCache.Get(c.key) + if !found { + return nil, nil + } + + value, cerr := requestCache.Clone(cachedValue) + if cerr != nil { + return nil, handleHTTPSendErr(c.bctx, cerr) + } + + c.bctx.Metrics.Counter(httpSendInterQueryCacheHits).Incr() + var cachedRespData *interQueryCacheData + + switch v := value.(type) { + case *interQueryCacheValue: + var err error + cachedRespData, err = v.copyCacheData() + if err != nil { + return nil, err + } + case *interQueryCacheData: + cachedRespData = v + default: + return nil, nil + } + + if getCurrentTime(c.bctx).Before(cachedRespData.ExpiresAt) { + return cachedRespData.formatToAST(c.forceJSONDecode, c.forceYAMLDecode) + } + + var err error + c.httpReq, c.httpClient, err = createHTTPRequest(c.bctx, c.key) + if err != nil { + return nil, handleHTTPSendErr(c.bctx, err) + } + + headers := parseResponseHeaders(cachedRespData.Headers) + + // check with the server if the stale response is still up-to-date. + // If server returns a new response (ie. status_code=200), update the cache with the new response + // If server returns an unmodified response (ie. status_code=304), update the headers for the existing response + result, modified, err := revalidateCachedResponse(c.httpReq, c.httpClient, c.key, headers) + requestCache.Delete(c.key) + if err != nil || result == nil { + return nil, err + } + + defer result.Body.Close() + + if !modified { + // update the headers in the cached response with their corresponding values from the 304 (Not Modified) response + for headerName, values := range result.Header { + cachedRespData.Headers.Del(headerName) + for _, v := range values { + cachedRespData.Headers.Add(headerName, v) + } + } + + if forceCaching(c.forceCacheParams) { + createdAt := getCurrentTime(c.bctx) + cachedRespData.ExpiresAt = createdAt.Add(time.Second * time.Duration(c.forceCacheParams.forceCacheDurationSeconds)) + } else { + expiresAt, err := expiryFromHeaders(result.Header) + if err != nil { + return nil, err + } + cachedRespData.ExpiresAt = expiresAt + } + + cachingMode, err := getCachingMode(c.key) + if err != nil { + return nil, err + } + + var pcv cache.InterQueryCacheValue + + if cachingMode == defaultCachingMode { + pcv, err = cachedRespData.toCacheValue() + if err != nil { + return nil, err + } + } else { + pcv = cachedRespData + } + + c.bctx.InterQueryBuiltinCache.InsertWithExpiry(c.key, pcv, cachedRespData.ExpiresAt) + + return cachedRespData.formatToAST(c.forceJSONDecode, c.forceYAMLDecode) + } + + newValue, respBody, err := formatHTTPResponseToAST(result, c.forceJSONDecode, c.forceYAMLDecode) + if err != nil { + return nil, err + } + + if err := insertIntoHTTPSendInterQueryCache(c.bctx, c.key, result, respBody, c.forceCacheParams); err != nil { + return nil, err + } + + return newValue, nil +} + +// insertIntoHTTPSendInterQueryCache inserts given key and value in the inter-query cache +func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) error { + if resp == nil || (!forceCaching(cacheParams) && !canStore(resp.Header)) || !cacheableCodes.Contains(ast.InternedIntNumberTerm(resp.StatusCode)) { + return nil + } + + requestCache := bctx.InterQueryBuiltinCache + + obj, ok := key.(ast.Object) + if !ok { + return fmt.Errorf("interface conversion error") + } + + cachingMode, err := getCachingMode(obj) + if err != nil { + return err + } + + var pcv cache.InterQueryCacheValue + var pcvData *interQueryCacheData + if cachingMode == defaultCachingMode { + pcv, pcvData, err = newInterQueryCacheValue(bctx, resp, respBody, cacheParams) + } else { + pcvData, err = newInterQueryCacheData(bctx, resp, respBody, cacheParams) + pcv = pcvData + } + + if err != nil { + return err + } + + requestCache.InsertWithExpiry(key, pcv, pcvData.ExpiresAt) + return nil +} + +func createAllowedKeys() { + for _, element := range allowedKeyNames { + allowedKeys.Add(ast.StringTerm(element)) + } +} + +func createCacheableHTTPStatusCodes() { + for _, element := range cacheableHTTPStatusCodes { + cacheableCodes.Add(ast.InternedIntNumberTerm(element)) + } +} + +func parseTimeout(timeoutVal ast.Value) (time.Duration, error) { + var timeout time.Duration + switch t := timeoutVal.(type) { + case ast.Number: + timeoutInt, ok := t.Int64() + if !ok { + return timeout, fmt.Errorf("invalid timeout number value %v, must be int64", timeoutVal) + } + return time.Duration(timeoutInt), nil + case ast.String: + // Support strings without a unit, treat them the same as just a number value (ns) + var err error + timeoutInt, err := strconv.ParseInt(string(t), 10, 64) + if err == nil { + return time.Duration(timeoutInt), nil + } + + // Try parsing it as a duration (requires a supported units suffix) + timeout, err = time.ParseDuration(string(t)) + if err != nil { + return timeout, fmt.Errorf("invalid timeout value %v: %s", timeoutVal, err) + } + return timeout, nil + default: + return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.TypeName(t)) + } +} + +func getBoolValFromReqObj(req ast.Object, key *ast.Term) (bool, error) { + var b ast.Boolean + var ok bool + if v := req.Get(key); v != nil { + if b, ok = v.Value.(ast.Boolean); !ok { + return false, fmt.Errorf("invalid value for %v field", key.String()) + } + } + return bool(b), nil +} + +func getNumberValFromReqObj(req ast.Object, key *ast.Term) (int, error) { + term := req.Get(key) + if term == nil { + return 0, nil + } + + if t, ok := term.Value.(ast.Number); ok { + num, ok := t.Int() + if !ok || num < 0 { + return 0, fmt.Errorf("invalid value %v for field %v", t.String(), key.String()) + } + return num, nil + } + + return 0, fmt.Errorf("invalid value %v for field %v", term.String(), key.String()) +} + +func getCachingMode(req ast.Object) (cachingMode, error) { + key := ast.StringTerm("caching_mode") + var s ast.String + var ok bool + if v := req.Get(key); v != nil { + if s, ok = v.Value.(ast.String); !ok { + return "", fmt.Errorf("invalid value for %v field", key.String()) + } + + switch cachingMode(s) { + case defaultCachingMode, cachingModeDeserialized: + return cachingMode(s), nil + default: + return "", fmt.Errorf("invalid value specified for %v field: %v", key.String(), string(s)) + } + } + return defaultCachingMode, nil +} + +type interQueryCacheValue struct { + Data []byte +} + +func newInterQueryCacheValue(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheValue, *interQueryCacheData, error) { + data, err := newInterQueryCacheData(bctx, resp, respBody, cacheParams) + if err != nil { + return nil, nil, err + } + + b, err := json.Marshal(data) + if err != nil { + return nil, nil, err + } + return &interQueryCacheValue{Data: b}, data, nil +} + +func (cb interQueryCacheValue) Clone() (cache.InterQueryCacheValue, error) { + dup := make([]byte, len(cb.Data)) + copy(dup, cb.Data) + return &interQueryCacheValue{Data: dup}, nil +} + +func (cb interQueryCacheValue) SizeInBytes() int64 { + return int64(len(cb.Data)) +} + +func (cb *interQueryCacheValue) copyCacheData() (*interQueryCacheData, error) { + var res interQueryCacheData + err := util.UnmarshalJSON(cb.Data, &res) + if err != nil { + return nil, err + } + return &res, nil +} + +type interQueryCacheData struct { + RespBody []byte + Status string + StatusCode int + Headers http.Header + ExpiresAt time.Time +} + +func forceCaching(cacheParams *forceCacheParams) bool { + return cacheParams != nil && cacheParams.forceCacheDurationSeconds > 0 +} + +func expiryFromHeaders(headers http.Header) (time.Time, error) { + var expiresAt time.Time + maxAge, err := parseMaxAgeCacheDirective(parseCacheControlHeader(headers)) + if err != nil { + return time.Time{}, err + } + if maxAge != -1 { + createdAt, err := getResponseHeaderDate(headers) + if err != nil { + return time.Time{}, err + } + expiresAt = createdAt.Add(time.Second * time.Duration(maxAge)) + } else { + expiresAt = getResponseHeaderExpires(headers) + } + return expiresAt, nil +} + +func newInterQueryCacheData(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheData, error) { + var expiresAt time.Time + + if forceCaching(cacheParams) { + createdAt := getCurrentTime(bctx) + expiresAt = createdAt.Add(time.Second * time.Duration(cacheParams.forceCacheDurationSeconds)) + } else { + var err error + expiresAt, err = expiryFromHeaders(resp.Header) + if err != nil { + return nil, err + } + } + + cv := interQueryCacheData{ + ExpiresAt: expiresAt, + RespBody: respBody, + Status: resp.Status, + StatusCode: resp.StatusCode, + Headers: resp.Header} + + return &cv, nil +} + +func (c *interQueryCacheData) formatToAST(forceJSONDecode, forceYAMLDecode bool) (ast.Value, error) { + return prepareASTResult(c.Headers, forceJSONDecode, forceYAMLDecode, c.RespBody, c.Status, c.StatusCode) +} + +func (c *interQueryCacheData) toCacheValue() (*interQueryCacheValue, error) { + b, err := json.Marshal(c) + if err != nil { + return nil, err + } + return &interQueryCacheValue{Data: b}, nil +} + +func (c *interQueryCacheData) SizeInBytes() int64 { + return 0 +} + +func (c *interQueryCacheData) Clone() (cache.InterQueryCacheValue, error) { + dup := make([]byte, len(c.RespBody)) + copy(dup, c.RespBody) + + return &interQueryCacheData{ + ExpiresAt: c.ExpiresAt, + RespBody: dup, + Status: c.Status, + StatusCode: c.StatusCode, + Headers: c.Headers.Clone()}, nil +} + +type responseHeaders struct { + etag string // identifier for a specific version of the response + lastModified string // date and time response was last modified as per origin server +} + +// deltaSeconds specifies a non-negative integer, representing +// time in seconds: http://tools.ietf.org/html/rfc7234#section-1.2.1 +type deltaSeconds int32 + +func parseResponseHeaders(headers http.Header) *responseHeaders { + result := responseHeaders{} + + result.etag = headers.Get("etag") + + result.lastModified = headers.Get("last-modified") + + return &result +} + +func revalidateCachedResponse(req *http.Request, client *http.Client, inputReqObj ast.Object, headers *responseHeaders) (*http.Response, bool, error) { + etag := headers.etag + lastModified := headers.lastModified + + if etag == "" && lastModified == "" { + return nil, false, nil + } + + cloneReq := req.Clone(req.Context()) + + if etag != "" { + cloneReq.Header.Set("if-none-match", etag) + } + + if lastModified != "" { + cloneReq.Header.Set("if-modified-since", lastModified) + } + + response, err := executeHTTPRequest(cloneReq, client, inputReqObj) + if err != nil { + return nil, false, err + } + + switch response.StatusCode { + case http.StatusOK: + return response, true, nil + + case http.StatusNotModified: + return response, false, nil + } + util.Close(response) + return nil, false, nil +} + +func canStore(headers http.Header) bool { + ccHeaders := parseCacheControlHeader(headers) + + // Check "no-store" cache directive + // The "no-store" response directive indicates that a cache MUST NOT + // store any part of either the immediate request or response. + if _, ok := ccHeaders["no-store"]; ok { + return false + } + return true +} + +func getCurrentTime(bctx BuiltinContext) time.Time { + var current time.Time + + value, err := ast.JSON(bctx.Time.Value) + if err != nil { + return current + } + + valueNum, ok := value.(json.Number) + if !ok { + return current + } + + valueNumInt, err := valueNum.Int64() + if err != nil { + return current + } + + current = time.Unix(0, valueNumInt).UTC() + return current +} + +func parseCacheControlHeader(headers http.Header) map[string]string { + ccDirectives := map[string]string{} + ccHeader := headers.Get("cache-control") + + for _, part := range strings.Split(ccHeader, ",") { + part = strings.Trim(part, " ") + if part == "" { + continue + } + if strings.ContainsRune(part, '=') { + items := strings.Split(part, "=") + if len(items) != 2 { + continue + } + ccDirectives[strings.Trim(items[0], " ")] = strings.Trim(items[1], ",") + } else { + ccDirectives[part] = "" + } + } + + return ccDirectives +} + +func getResponseHeaderDate(headers http.Header) (date time.Time, err error) { + dateHeader := headers.Get("date") + if dateHeader == "" { + err = fmt.Errorf("no date header") + return + } + return http.ParseTime(dateHeader) +} + +func getResponseHeaderExpires(headers http.Header) time.Time { + expiresHeader := headers.Get("expires") + if expiresHeader == "" { + return time.Time{} + } + + date, err := http.ParseTime(expiresHeader) + if err != nil { + // servers can set `Expires: 0` which is an invalid date to indicate expired content + return time.Time{} + } + + return date +} + +// parseMaxAgeCacheDirective parses the max-age directive expressed in delta-seconds as per +// https://tools.ietf.org/html/rfc7234#section-1.2.1 +func parseMaxAgeCacheDirective(cc map[string]string) (deltaSeconds, error) { + maxAge, ok := cc["max-age"] + if !ok { + return deltaSeconds(-1), nil + } + + val, err := strconv.ParseUint(maxAge, 10, 32) + if err != nil { + if numError, ok := err.(*strconv.NumError); ok { + if numError.Err == strconv.ErrRange { + return deltaSeconds(math.MaxInt32), nil + } + } + return deltaSeconds(-1), err + } + + if val > math.MaxInt32 { + return deltaSeconds(math.MaxInt32), nil + } + return deltaSeconds(val), nil +} + +func formatHTTPResponseToAST(resp *http.Response, forceJSONDecode, forceYAMLDecode bool) (ast.Value, []byte, error) { + + resultRawBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, err + } + + resultObj, err := prepareASTResult(resp.Header, forceJSONDecode, forceYAMLDecode, resultRawBody, resp.Status, resp.StatusCode) + if err != nil { + return nil, nil, err + } + + return resultObj, resultRawBody, nil +} + +func prepareASTResult(headers http.Header, forceJSONDecode, forceYAMLDecode bool, body []byte, status string, statusCode int) (ast.Value, error) { + var resultBody interface{} + + // If the response body cannot be JSON/YAML decoded, + // an error will not be returned. Instead, the "body" field + // in the result will be null. + switch { + case forceJSONDecode || isContentType(headers, "application/json"): + _ = util.UnmarshalJSON(body, &resultBody) + case forceYAMLDecode || isContentType(headers, "application/yaml", "application/x-yaml"): + _ = util.Unmarshal(body, &resultBody) + } + + result := make(map[string]interface{}) + result["status"] = status + result["status_code"] = statusCode + result["body"] = resultBody + result["raw_body"] = string(body) + result["headers"] = getResponseHeaders(headers) + + resultObj, err := ast.InterfaceToValue(result) + if err != nil { + return nil, err + } + + return resultObj, nil +} + +func getResponseHeaders(headers http.Header) map[string]interface{} { + respHeaders := map[string]interface{}{} + for headerName, values := range headers { + var respValues []interface{} + for _, v := range values { + respValues = append(respValues, v) + } + respHeaders[strings.ToLower(headerName)] = respValues + } + return respHeaders +} + +// httpRequestExecutor defines an interface for the http send cache +type httpRequestExecutor interface { + CheckCache() (ast.Value, error) + InsertIntoCache(value *http.Response) (ast.Value, error) + InsertErrorIntoCache(err error) + ExecuteHTTPRequest() (*http.Response, error) +} + +// newHTTPRequestExecutor returns a new HTTP request executor that wraps either an inter-query or +// intra-query cache implementation +func newHTTPRequestExecutor(bctx BuiltinContext, req ast.Object, key ast.Object) (httpRequestExecutor, error) { + useInterQueryCache, forceCacheParams, err := useInterQueryCache(req) + if err != nil { + return nil, handleHTTPSendErr(bctx, err) + } + + if useInterQueryCache && bctx.InterQueryBuiltinCache != nil { + return newInterQueryCache(bctx, req, key, forceCacheParams) + } + return newIntraQueryCache(bctx, req, key) +} + +type interQueryCache struct { + bctx BuiltinContext + req ast.Object + key ast.Object + httpReq *http.Request + httpClient *http.Client + forceJSONDecode bool + forceYAMLDecode bool + forceCacheParams *forceCacheParams +} + +func newInterQueryCache(bctx BuiltinContext, req ast.Object, key ast.Object, forceCacheParams *forceCacheParams) (*interQueryCache, error) { + return &interQueryCache{bctx: bctx, req: req, key: key, forceCacheParams: forceCacheParams}, nil +} + +// CheckCache checks the cache for the value of the key set on this object +func (c *interQueryCache) CheckCache() (ast.Value, error) { + var err error + + // Checking the intra-query cache first ensures consistency of errors and HTTP responses within a query. + resp, err := checkHTTPSendCache(c.bctx, c.key) + if err != nil { + return nil, err + } + if resp != nil { + return resp, nil + } + + c.forceJSONDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode")) + if err != nil { + return nil, handleHTTPSendErr(c.bctx, err) + } + c.forceYAMLDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode")) + if err != nil { + return nil, handleHTTPSendErr(c.bctx, err) + } + + resp, err = c.checkHTTPSendInterQueryCache() + // Always insert the result of the inter-query cache into the intra-query cache, to maintain consistency within the same query. + if err != nil { + insertErrorIntoHTTPSendCache(c.bctx, c.key, err) + } + if resp != nil { + insertIntoHTTPSendCache(c.bctx, c.key, resp) + } + return resp, err +} + +// InsertIntoCache inserts the key set on this object into the cache with the given value +func (c *interQueryCache) InsertIntoCache(value *http.Response) (ast.Value, error) { + result, respBody, err := formatHTTPResponseToAST(value, c.forceJSONDecode, c.forceYAMLDecode) + if err != nil { + return nil, handleHTTPSendErr(c.bctx, err) + } + + // Always insert into the intra-query cache, to maintain consistency within the same query. + insertIntoHTTPSendCache(c.bctx, c.key, result) + + // We ignore errors when populating the inter-query cache, because we've already populated the intra-cache, + // and query consistency is our primary concern. + _ = insertIntoHTTPSendInterQueryCache(c.bctx, c.key, value, respBody, c.forceCacheParams) + return result, nil +} + +func (c *interQueryCache) InsertErrorIntoCache(err error) { + insertErrorIntoHTTPSendCache(c.bctx, c.key, err) +} + +// ExecuteHTTPRequest executes a HTTP request +func (c *interQueryCache) ExecuteHTTPRequest() (*http.Response, error) { + var err error + c.httpReq, c.httpClient, err = createHTTPRequest(c.bctx, c.req) + if err != nil { + return nil, handleHTTPSendErr(c.bctx, err) + } + + return executeHTTPRequest(c.httpReq, c.httpClient, c.req) +} + +type intraQueryCache struct { + bctx BuiltinContext + req ast.Object + key ast.Object +} + +func newIntraQueryCache(bctx BuiltinContext, req ast.Object, key ast.Object) (*intraQueryCache, error) { + return &intraQueryCache{bctx: bctx, req: req, key: key}, nil +} + +// CheckCache checks the cache for the value of the key set on this object +func (c *intraQueryCache) CheckCache() (ast.Value, error) { + return checkHTTPSendCache(c.bctx, c.key) +} + +// InsertIntoCache inserts the key set on this object into the cache with the given value +func (c *intraQueryCache) InsertIntoCache(value *http.Response) (ast.Value, error) { + forceJSONDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode")) + if err != nil { + return nil, handleHTTPSendErr(c.bctx, err) + } + forceYAMLDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode")) + if err != nil { + return nil, handleHTTPSendErr(c.bctx, err) + } + + result, _, err := formatHTTPResponseToAST(value, forceJSONDecode, forceYAMLDecode) + if err != nil { + return nil, handleHTTPSendErr(c.bctx, err) + } + + if cacheableCodes.Contains(ast.InternedIntNumberTerm(value.StatusCode)) { + insertIntoHTTPSendCache(c.bctx, c.key, result) + } + + return result, nil +} + +func (c *intraQueryCache) InsertErrorIntoCache(err error) { + insertErrorIntoHTTPSendCache(c.bctx, c.key, err) +} + +// ExecuteHTTPRequest executes a HTTP request +func (c *intraQueryCache) ExecuteHTTPRequest() (*http.Response, error) { + httpReq, httpClient, err := createHTTPRequest(c.bctx, c.req) + if err != nil { + return nil, handleHTTPSendErr(c.bctx, err) + } + return executeHTTPRequest(httpReq, httpClient, c.req) +} + +func useInterQueryCache(req ast.Object) (bool, *forceCacheParams, error) { + value, err := getBoolValFromReqObj(req, ast.StringTerm("cache")) + if err != nil { + return false, nil, err + } + + valueForceCache, err := getBoolValFromReqObj(req, ast.StringTerm("force_cache")) + if err != nil { + return false, nil, err + } + + if valueForceCache { + forceCacheParams, err := newForceCacheParams(req) + return true, forceCacheParams, err + } + + return value, nil, nil +} + +type forceCacheParams struct { + forceCacheDurationSeconds int32 +} + +func newForceCacheParams(req ast.Object) (*forceCacheParams, error) { + term := req.Get(ast.StringTerm("force_cache_duration_seconds")) + if term == nil { + return nil, fmt.Errorf("'force_cache' set but 'force_cache_duration_seconds' parameter is missing") + } + + forceCacheDurationSeconds := term.String() + + value, err := strconv.ParseInt(forceCacheDurationSeconds, 10, 32) + if err != nil { + return nil, err + } + + return &forceCacheParams{forceCacheDurationSeconds: int32(value)}, nil +} + +func getRaiseErrorValue(req ast.Object) (bool, error) { + result := ast.Boolean(true) + var ok bool + if v := req.Get(ast.StringTerm("raise_error")); v != nil { + if result, ok = v.Value.(ast.Boolean); !ok { + return false, fmt.Errorf("invalid value for raise_error field") + } + } + return bool(result), nil +} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/http_fixup.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/http_fixup.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/topdown/http_fixup.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/http_fixup.go diff --git a/vendor/github.com/open-policy-agent/opa/topdown/http_fixup_darwin.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/http_fixup_darwin.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/topdown/http_fixup_darwin.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/http_fixup_darwin.go diff --git a/vendor/github.com/open-policy-agent/opa/topdown/input.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/input.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/topdown/input.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/input.go index cb70aeb71..dccf94d89 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/input.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/input.go @@ -7,7 +7,7 @@ package topdown import ( "fmt" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) var errBadPath = fmt.Errorf("bad document path") diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/instrumentation.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/instrumentation.go new file mode 100644 index 000000000..93da1d002 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/instrumentation.go @@ -0,0 +1,63 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import "github.com/open-policy-agent/opa/v1/metrics" + +const ( + evalOpPlug = "eval_op_plug" + evalOpResolve = "eval_op_resolve" + evalOpRuleIndex = "eval_op_rule_index" + evalOpBuiltinCall = "eval_op_builtin_call" + evalOpVirtualCacheHit = "eval_op_virtual_cache_hit" + evalOpVirtualCacheMiss = "eval_op_virtual_cache_miss" + evalOpBaseCacheHit = "eval_op_base_cache_hit" + evalOpBaseCacheMiss = "eval_op_base_cache_miss" + evalOpComprehensionCacheSkip = "eval_op_comprehension_cache_skip" + evalOpComprehensionCacheBuild = "eval_op_comprehension_cache_build" + evalOpComprehensionCacheHit = "eval_op_comprehension_cache_hit" + evalOpComprehensionCacheMiss = "eval_op_comprehension_cache_miss" + partialOpSaveUnify = "partial_op_save_unify" + partialOpSaveSetContains = "partial_op_save_set_contains" + partialOpSaveSetContainsRec = "partial_op_save_set_contains_rec" + partialOpCopyPropagation = "partial_op_copy_propagation" +) + +// Instrumentation implements helper functions to instrument query evaluation +// to diagnose performance issues. Instrumentation may be expensive in some +// cases, so it is disabled by default. +type Instrumentation struct { + m metrics.Metrics +} + +// NewInstrumentation returns a new Instrumentation object. Performance +// diagnostics recorded on this Instrumentation object will stored in m. +func NewInstrumentation(m metrics.Metrics) *Instrumentation { + return &Instrumentation{ + m: m, + } +} + +func (instr *Instrumentation) startTimer(name string) { + if instr == nil { + return + } + instr.m.Timer(name).Start() +} + +func (instr *Instrumentation) stopTimer(name string) { + if instr == nil { + return + } + delta := instr.m.Timer(name).Stop() + instr.m.Histogram(name).Update(delta) +} + +func (instr *Instrumentation) counterIncr(name string) { + if instr == nil { + return + } + instr.m.Counter(name).Incr() +} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/json.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/json.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/topdown/json.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/json.go index 8a5d23283..57e079d2e 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/json.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/json.go @@ -9,8 +9,8 @@ import ( "strconv" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" "github.com/open-policy-agent/opa/internal/edittree" ) diff --git a/vendor/github.com/open-policy-agent/opa/topdown/jsonschema.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/jsonschema.go similarity index 97% rename from vendor/github.com/open-policy-agent/opa/topdown/jsonschema.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/jsonschema.go index d319bc0b0..588b7ec4c 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/jsonschema.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/jsonschema.go @@ -8,8 +8,8 @@ import ( "encoding/json" "errors" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/gojsonschema" + "github.com/open-policy-agent/opa/v1/ast" ) // astValueToJSONSchemaLoader converts a value to JSON Loader. @@ -44,7 +44,7 @@ func astValueToJSONSchemaLoader(value ast.Value) (gojsonschema.JSONLoader, error } func newResultTerm(valid bool, data *ast.Term) *ast.Term { - return ast.ArrayTerm(ast.BooleanTerm(valid), data) + return ast.ArrayTerm(ast.InternedBooleanTerm(valid), data) } // builtinJSONSchemaVerify accepts 1 argument which can be string or object and checks if it is valid JSON schema. diff --git a/vendor/github.com/open-policy-agent/opa/topdown/lineage/lineage.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/lineage/lineage.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/topdown/lineage/lineage.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/lineage/lineage.go index 2c0ec7287..26e66eb59 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/lineage/lineage.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/lineage/lineage.go @@ -5,7 +5,7 @@ package lineage import ( - "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/v1/topdown" ) // Debug contains everything in the log. diff --git a/vendor/github.com/open-policy-agent/opa/topdown/net.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/net.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/topdown/net.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/net.go index 534520529..17ed77984 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/net.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/net.go @@ -8,8 +8,8 @@ import ( "net" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) type lookupIPAddrCacheKey string diff --git a/vendor/github.com/open-policy-agent/opa/topdown/numbers.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/numbers.go similarity index 67% rename from vendor/github.com/open-policy-agent/opa/topdown/numbers.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/numbers.go index 27f3156b8..05a4f96f0 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/numbers.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/numbers.go @@ -8,15 +8,19 @@ import ( "fmt" "math/big" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) type randIntCachingKey string +var zero = big.NewInt(0) var one = big.NewInt(1) func builtinNumbersRange(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + if canGenerateCheapRange(operands) { + return generateCheapRange(operands, iter) + } x, err := builtins.BigIntOperand(operands[0].Value, 1) if err != nil { @@ -53,7 +57,7 @@ func builtinNumbersRangeStep(bctx BuiltinContext, operands []*ast.Term, iter fun return err } - if step.Cmp(big.NewInt(0)) <= 0 { + if step.Cmp(zero) <= 0 { return fmt.Errorf("numbers.range_step: step must be a positive number above zero") } @@ -65,6 +69,57 @@ func builtinNumbersRangeStep(bctx BuiltinContext, operands []*ast.Term, iter fun return iter(ast) } +func canGenerateCheapRange(operands []*ast.Term) bool { + x, err := builtins.IntOperand(operands[0].Value, 1) + if err != nil || !ast.HasInternedIntNumberTerm(x) { + return false + } + + y, err := builtins.IntOperand(operands[1].Value, 2) + if err != nil || !ast.HasInternedIntNumberTerm(y) { + return false + } + + return true +} + +func generateCheapRange(operands []*ast.Term, iter func(*ast.Term) error) error { + x, err := builtins.IntOperand(operands[0].Value, 1) + if err != nil { + return err + } + + y, err := builtins.IntOperand(operands[1].Value, 2) + if err != nil { + return err + } + + step := 1 + + stepOp, err := builtins.IntOperand(operands[2].Value, 3) + if err == nil { + step = stepOp + } + + if step <= 0 { + return fmt.Errorf("numbers.range_step: step must be a positive number above zero") + } + + terms := make([]*ast.Term, 0, y+1) + + if x <= y { + for i := x; i <= y; i += step { + terms = append(terms, ast.InternedIntNumberTerm(i)) + } + } else { + for i := x; i >= y; i -= step { + terms = append(terms, ast.InternedIntNumberTerm(i)) + } + } + + return iter(ast.ArrayTerm(terms...)) +} + func generateRange(bctx BuiltinContext, x *big.Int, y *big.Int, step *big.Int, funcName string) (*ast.Term, error) { cmp := x.Cmp(y) @@ -109,7 +164,7 @@ func builtinRandIntn(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.T } if n == 0 { - return iter(ast.IntNumberTerm(0)) + return iter(ast.InternedIntNumberTerm(0)) } if n < 0 { @@ -126,7 +181,7 @@ func builtinRandIntn(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.T if err != nil { return err } - result := ast.IntNumberTerm(r.Intn(n)) + result := ast.InternedIntNumberTerm(r.Intn(n)) bctx.Cache.Put(key, result) return iter(result) diff --git a/vendor/github.com/open-policy-agent/opa/topdown/object.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/object.go similarity index 91% rename from vendor/github.com/open-policy-agent/opa/topdown/object.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/object.go index ba5d77ff3..11671da5f 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/object.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/object.go @@ -5,9 +5,9 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/ref" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -50,9 +50,6 @@ func builtinObjectUnionN(_ BuiltinContext, operands []*ast.Term, iter func(*ast. return builtins.NewOperandElementErr(1, arr, arr.Elem(i).Value, "object") } mergewithOverwriteInPlace(result, o, frozenKeys) - if err != nil { - return err - } } return iter(ast.NewTerm(result)) @@ -144,37 +141,24 @@ func builtinObjectKeys(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return err } - keys := ast.SetTerm(object.Keys()...) - - return iter(keys) + return iter(ast.SetTerm(object.Keys()...)) } // getObjectKeysParam returns a set of key values // from a supplied ast array, object, set value func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) { - keys := ast.NewSet() - switch v := arrayOrSet.(type) { case *ast.Array: - _ = v.Iter(func(f *ast.Term) error { - keys.Add(f) - return nil - }) + keys := ast.NewSet() + v.Foreach(keys.Add) + return keys, nil case ast.Set: - _ = v.Iter(func(f *ast.Term) error { - keys.Add(f) - return nil - }) + return ast.NewSet(v.Slice()...), nil case ast.Object: - _ = v.Iter(func(k *ast.Term, _ *ast.Term) error { - keys.Add(k) - return nil - }) - default: - return nil, builtins.NewOperandTypeErr(2, arrayOrSet, "object", "set", "array") + return ast.NewSet(v.Keys()...), nil } - return keys, nil + return nil, builtins.NewOperandTypeErr(2, arrayOrSet, "object", "set", "array") } func mergeWithOverwrite(objA, objB ast.Object) ast.Object { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/parse.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/parse.go similarity index 91% rename from vendor/github.com/open-policy-agent/opa/topdown/parse.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/parse.go index c46222b41..464e0141a 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/parse.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/parse.go @@ -9,8 +9,8 @@ import ( "encoding/json" "fmt" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func builtinRegoParseModule(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -25,6 +25,7 @@ func builtinRegoParseModule(_ BuiltinContext, operands []*ast.Term, iter func(*a return err } + // FIXME: Use configured rego-version? module, err := ast.ParseModule(string(filename), string(input)) if err != nil { return err diff --git a/vendor/github.com/open-policy-agent/opa/topdown/parse_bytes.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/parse_bytes.go similarity index 76% rename from vendor/github.com/open-policy-agent/opa/topdown/parse_bytes.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/parse_bytes.go index 0cd4bc193..dcc8e2199 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/parse_bytes.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/parse_bytes.go @@ -10,8 +10,8 @@ import ( "strings" "unicode" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) const ( @@ -121,21 +121,35 @@ func extractNumAndUnit(s string) (string, string) { } firstNonNumIdx := -1 - for idx, r := range s { - if !isNum(r) { + for idx := 0; idx < len(s); idx++ { + r := rune(s[idx]) + // Identify the first non-numeric character, marking the boundary between the number and the unit. + if !isNum(r) && r != 'e' && r != 'E' && r != '+' && r != '-' { firstNonNumIdx = idx break } + if r == 'e' || r == 'E' { + // Check if the next character is a valid digit or +/- for scientific notation + if idx == len(s)-1 || (!unicode.IsDigit(rune(s[idx+1])) && rune(s[idx+1]) != '+' && rune(s[idx+1]) != '-') { + firstNonNumIdx = idx + break + } + // Skip the next character if it is '+' or '-' + if idx+1 < len(s) && (s[idx+1] == '+' || s[idx+1] == '-') { + idx++ + } + } } - if firstNonNumIdx == -1 { // only digits and '.' + if firstNonNumIdx == -1 { // only digits, '.', or valid scientific notation return s, "" } if firstNonNumIdx == 0 { // only units (starts with non-digit) return "", s } - return s[0:firstNonNumIdx], s[firstNonNumIdx:] + // Return the number and the rest as the unit + return s[:firstNonNumIdx], s[firstNonNumIdx:] } func init() { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/parse_units.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/parse_units.go similarity index 96% rename from vendor/github.com/open-policy-agent/opa/topdown/parse_units.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/parse_units.go index daf240214..47e459510 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/parse_units.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/parse_units.go @@ -10,8 +10,8 @@ import ( "math/big" "strings" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) // Binary Si unit constants are borrowed from topdown/parse_bytes diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/print.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/print.go new file mode 100644 index 000000000..2d16c2baa --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/print.go @@ -0,0 +1,86 @@ +// Copyright 2021 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "fmt" + "io" + "strings" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/print" +) + +func NewPrintHook(w io.Writer) print.Hook { + return printHook{w: w} +} + +type printHook struct { + w io.Writer +} + +func (h printHook) Print(_ print.Context, msg string) error { + _, err := fmt.Fprintln(h.w, msg) + return err +} + +func builtinPrint(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + + if bctx.PrintHook == nil { + return iter(nil) + } + + arr, err := builtins.ArrayOperand(operands[0].Value, 1) + if err != nil { + return err + } + + buf := make([]string, arr.Len()) + + err = builtinPrintCrossProductOperands(bctx, buf, arr, 0, func(buf []string) error { + pctx := print.Context{ + Context: bctx.Context, + Location: bctx.Location, + } + return bctx.PrintHook.Print(pctx, strings.Join(buf, " ")) + }) + if err != nil { + return err + } + + return iter(nil) +} + +func builtinPrintCrossProductOperands(bctx BuiltinContext, buf []string, operands *ast.Array, i int, f func([]string) error) error { + + if i >= operands.Len() { + return f(buf) + } + + xs, ok := operands.Elem(i).Value.(ast.Set) + if !ok { + return Halt{Err: internalErr(bctx.Location, fmt.Sprintf("illegal argument type: %v", ast.TypeName(operands.Elem(i).Value)))} + } + + if xs.Len() == 0 { + buf[i] = "" + return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f) + } + + return xs.Iter(func(x *ast.Term) error { + switch v := x.Value.(type) { + case ast.String: + buf[i] = string(v) + default: + buf[i] = v.String() + } + return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f) + }) +} + +func init() { + RegisterBuiltinFunc(ast.InternalPrint.Name, builtinPrint) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/print/print.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/print/print.go new file mode 100644 index 000000000..ce684ae94 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/print/print.go @@ -0,0 +1,21 @@ +package print + +import ( + "context" + + "github.com/open-policy-agent/opa/v1/ast" +) + +// Context provides the Hook implementation context about the print() call. +type Context struct { + Context context.Context // request context passed when query executed + Location *ast.Location // location of print call +} + +// Hook defines the interface that callers can implement to receive print +// statement outputs. If the hook returns an error, it will be surfaced if +// strict builtin error checking is enabled (otherwise, it will not halt +// execution.) +type Hook interface { + Print(Context, string) error +} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/providers.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/providers.go similarity index 97% rename from vendor/github.com/open-policy-agent/opa/topdown/providers.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/providers.go index 77db91798..dd84026e4 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/providers.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/providers.go @@ -9,9 +9,9 @@ import ( "net/url" "time" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/providers/aws" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) var awsRequiredConfigKeyNames = ast.NewSet( @@ -119,9 +119,6 @@ func builtinAWSSigV4SignReq(_ BuiltinContext, operands []*ast.Term, iter func(*a } signingTimestamp = time.Unix(0, ts) - if err != nil { - return err - } // Make sure our required keys exist! // This check is stricter than required, but better to break here than downstream. diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/query.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/query.go new file mode 100644 index 000000000..6029ee721 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/query.go @@ -0,0 +1,611 @@ +package topdown + +import ( + "context" + "crypto/rand" + "io" + "sort" + "time" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/resolver" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/topdown/cache" + "github.com/open-policy-agent/opa/v1/topdown/copypropagation" + "github.com/open-policy-agent/opa/v1/topdown/print" + "github.com/open-policy-agent/opa/v1/tracing" +) + +// QueryResultSet represents a collection of results returned by a query. +type QueryResultSet []QueryResult + +// QueryResult represents a single result returned by a query. The result +// contains bindings for all variables that appear in the query. +type QueryResult map[ast.Var]*ast.Term + +// Query provides a configurable interface for performing query evaluation. +type Query struct { + seed io.Reader + time time.Time + cancel Cancel + query ast.Body + queryCompiler ast.QueryCompiler + compiler *ast.Compiler + store storage.Store + txn storage.Transaction + input *ast.Term + external *resolverTrie + tracers []QueryTracer + plugTraceVars bool + unknowns []*ast.Term + partialNamespace string + skipSaveNamespace bool + metrics metrics.Metrics + instr *Instrumentation + disableInlining []ast.Ref + shallowInlining bool + genvarprefix string + runtime *ast.Term + builtins map[string]*Builtin + indexing bool + earlyExit bool + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + ndBuiltinCache builtins.NDBCache + strictBuiltinErrors bool + builtinErrorList *[]Error + strictObjects bool + roundTripper CustomizeRoundTripper + printHook print.Hook + tracingOpts tracing.Options + virtualCache VirtualCache +} + +// Builtin represents a built-in function that queries can call. +type Builtin struct { + Decl *ast.Builtin + Func BuiltinFunc +} + +// NewQuery returns a new Query object that can be run. +func NewQuery(query ast.Body) *Query { + return &Query{ + query: query, + genvarprefix: ast.WildcardPrefix, + indexing: true, + earlyExit: true, + external: newResolverTrie(), + } +} + +// WithQueryCompiler sets the queryCompiler used for the query. +func (q *Query) WithQueryCompiler(queryCompiler ast.QueryCompiler) *Query { + q.queryCompiler = queryCompiler + return q +} + +// WithCompiler sets the compiler to use for the query. +func (q *Query) WithCompiler(compiler *ast.Compiler) *Query { + q.compiler = compiler + return q +} + +// WithStore sets the store to use for the query. +func (q *Query) WithStore(store storage.Store) *Query { + q.store = store + return q +} + +// WithTransaction sets the transaction to use for the query. All queries +// should be performed over a consistent snapshot of the storage layer. +func (q *Query) WithTransaction(txn storage.Transaction) *Query { + q.txn = txn + return q +} + +// WithCancel sets the cancellation object to use for the query. Set this if +// you need to abort queries based on a deadline. This is optional. +func (q *Query) WithCancel(cancel Cancel) *Query { + q.cancel = cancel + return q +} + +// WithInput sets the input object to use for the query. References rooted at +// input will be evaluated against this value. This is optional. +func (q *Query) WithInput(input *ast.Term) *Query { + q.input = input + return q +} + +// WithTracer adds a query tracer to use during evaluation. This is optional. +// Deprecated: Use WithQueryTracer instead. +func (q *Query) WithTracer(tracer Tracer) *Query { + qt, ok := tracer.(QueryTracer) + if !ok { + qt = WrapLegacyTracer(tracer) + } + return q.WithQueryTracer(qt) +} + +// WithQueryTracer adds a query tracer to use during evaluation. This is optional. +// Disabled QueryTracers will be ignored. +func (q *Query) WithQueryTracer(tracer QueryTracer) *Query { + if !tracer.Enabled() { + return q + } + + q.tracers = append(q.tracers, tracer) + + // If *any* of the tracers require local variable metadata we need to + // enabled plugging local trace variables. + conf := tracer.Config() + if conf.PlugLocalVars { + q.plugTraceVars = true + } + + return q +} + +// WithMetrics sets the metrics collection to add evaluation metrics to. This +// is optional. +func (q *Query) WithMetrics(m metrics.Metrics) *Query { + q.metrics = m + return q +} + +// WithInstrumentation sets the instrumentation configuration to enable on the +// evaluation process. By default, instrumentation is turned off. +func (q *Query) WithInstrumentation(instr *Instrumentation) *Query { + q.instr = instr + return q +} + +// WithUnknowns sets the initial set of variables or references to treat as +// unknown during query evaluation. This is required for partial evaluation. +func (q *Query) WithUnknowns(terms []*ast.Term) *Query { + q.unknowns = terms + return q +} + +// WithPartialNamespace sets the namespace to use for supporting rules +// generated as part of the partial evaluation process. The ns value must be a +// valid package path component. +func (q *Query) WithPartialNamespace(ns string) *Query { + q.partialNamespace = ns + return q +} + +// WithSkipPartialNamespace disables namespacing of saved support rules that are generated +// from the original policy (rules which are completely synthetic are still namespaced.) +func (q *Query) WithSkipPartialNamespace(yes bool) *Query { + q.skipSaveNamespace = yes + return q +} + +// WithDisableInlining adds a set of paths to the query that should be excluded from +// inlining. Inlining during partial evaluation can be expensive in some cases +// (e.g., when a cross-product is computed.) Disabling inlining avoids expensive +// computation at the cost of generating support rules. +func (q *Query) WithDisableInlining(paths []ast.Ref) *Query { + q.disableInlining = paths + return q +} + +// WithShallowInlining disables aggressive inlining performed during partial evaluation. +// When shallow inlining is enabled rules that depend (transitively) on unknowns are not inlined. +// Only rules/values that are completely known will be inlined. +func (q *Query) WithShallowInlining(yes bool) *Query { + q.shallowInlining = yes + return q +} + +// WithRuntime sets the runtime data to execute the query with. The runtime data +// can be returned by the `opa.runtime` built-in function. +func (q *Query) WithRuntime(runtime *ast.Term) *Query { + q.runtime = runtime + return q +} + +// WithBuiltins adds a set of built-in functions that can be called by the +// query. +func (q *Query) WithBuiltins(builtins map[string]*Builtin) *Query { + q.builtins = builtins + return q +} + +// WithIndexing will enable or disable using rule indexing for the evaluation +// of the query. The default is enabled. +func (q *Query) WithIndexing(enabled bool) *Query { + q.indexing = enabled + return q +} + +// WithEarlyExit will enable or disable using 'early exit' for the evaluation +// of the query. The default is enabled. +func (q *Query) WithEarlyExit(enabled bool) *Query { + q.earlyExit = enabled + return q +} + +// WithSeed sets a reader that will seed randomization required by built-in functions. +// If a seed is not provided crypto/rand.Reader is used. +func (q *Query) WithSeed(r io.Reader) *Query { + q.seed = r + return q +} + +// WithTime sets the time that will be returned by the time.now_ns() built-in function. +func (q *Query) WithTime(x time.Time) *Query { + q.time = x + return q +} + +// WithInterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize. +func (q *Query) WithInterQueryBuiltinCache(c cache.InterQueryCache) *Query { + q.interQueryBuiltinCache = c + return q +} + +// WithInterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize. +func (q *Query) WithInterQueryBuiltinValueCache(c cache.InterQueryValueCache) *Query { + q.interQueryBuiltinValueCache = c + return q +} + +// WithNDBuiltinCache sets the non-deterministic builtin cache. +func (q *Query) WithNDBuiltinCache(c builtins.NDBCache) *Query { + q.ndBuiltinCache = c + return q +} + +// WithStrictBuiltinErrors tells the evaluator to treat all built-in function errors as fatal errors. +func (q *Query) WithStrictBuiltinErrors(yes bool) *Query { + q.strictBuiltinErrors = yes + return q +} + +// WithBuiltinErrorList supplies a pointer to an Error slice to store built-in function errors +// encountered during evaluation. This error slice can be inspected after evaluation to determine +// which built-in function errors occurred. +func (q *Query) WithBuiltinErrorList(list *[]Error) *Query { + q.builtinErrorList = list + return q +} + +// WithResolver configures an external resolver to use for the given ref. +func (q *Query) WithResolver(ref ast.Ref, r resolver.Resolver) *Query { + q.external.Put(ref, r) + return q +} + +// WithHTTPRoundTripper configures a custom HTTP transport for built-in functions that make HTTP requests. +func (q *Query) WithHTTPRoundTripper(t CustomizeRoundTripper) *Query { + q.roundTripper = t + return q +} + +func (q *Query) WithPrintHook(h print.Hook) *Query { + q.printHook = h + return q +} + +// WithDistributedTracingOpts sets the options to be used by distributed tracing. +func (q *Query) WithDistributedTracingOpts(tr tracing.Options) *Query { + q.tracingOpts = tr + return q +} + +// WithStrictObjects tells the evaluator to avoid the "lazy object" optimization +// applied when reading objects from the store. It will result in higher memory +// usage and should only be used temporarily while adjusting code that breaks +// because of the optimization. +func (q *Query) WithStrictObjects(yes bool) *Query { + q.strictObjects = yes + return q +} + +// WithVirtualCache sets the VirtualCache to use during evaluation. This is +// optional, and if not set, the default cache is used. +func (q *Query) WithVirtualCache(vc VirtualCache) *Query { + q.virtualCache = vc + return q +} + +// PartialRun executes partial evaluation on the query with respect to unknown +// values. Partial evaluation attempts to evaluate as much of the query as +// possible without requiring values for the unknowns set on the query. The +// result of partial evaluation is a new set of queries that can be evaluated +// once the unknown value is known. In addition to new queries, partial +// evaluation may produce additional support modules that should be used in +// conjunction with the partially evaluated queries. +func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []*ast.Module, err error) { + if q.partialNamespace == "" { + q.partialNamespace = "partial" // lazily initialize partial namespace + } + if q.seed == nil { + q.seed = rand.Reader + } + if !q.time.IsZero() { + q.time = time.Now() + } + if q.metrics == nil { + q.metrics = metrics.New() + } + + f := &queryIDFactory{} + b := newBindings(0, q.instr) + + var vc VirtualCache + if q.virtualCache != nil { + vc = q.virtualCache + } else { + vc = NewVirtualCache() + } + + e := &eval{ + ctx: ctx, + metrics: q.metrics, + seed: q.seed, + time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), + cancel: q.cancel, + query: q.query, + queryCompiler: q.queryCompiler, + queryIDFact: f, + queryID: f.Next(), + bindings: b, + compiler: q.compiler, + store: q.store, + baseCache: newBaseCache(), + targetStack: newRefStack(), + txn: q.txn, + input: q.input, + external: q.external, + tracers: q.tracers, + traceEnabled: len(q.tracers) > 0, + plugTraceVars: q.plugTraceVars, + instr: q.instr, + builtins: q.builtins, + builtinCache: builtins.Cache{}, + functionMocks: newFunctionMocksStack(), + interQueryBuiltinCache: q.interQueryBuiltinCache, + interQueryBuiltinValueCache: q.interQueryBuiltinValueCache, + ndBuiltinCache: q.ndBuiltinCache, + virtualCache: vc, + comprehensionCache: newComprehensionCache(), + saveSet: newSaveSet(q.unknowns, b, q.instr), + saveStack: newSaveStack(), + saveSupport: newSaveSupport(), + saveNamespace: ast.StringTerm(q.partialNamespace), + skipSaveNamespace: q.skipSaveNamespace, + inliningControl: &inliningControl{ + shallow: q.shallowInlining, + }, + genvarprefix: q.genvarprefix, + runtime: q.runtime, + indexing: q.indexing, + earlyExit: q.earlyExit, + builtinErrors: &builtinErrors{}, + printHook: q.printHook, + strictObjects: q.strictObjects, + } + + if len(q.disableInlining) > 0 { + e.inliningControl.PushDisable(q.disableInlining, false) + } + + e.caller = e + q.metrics.Timer(metrics.RegoPartialEval).Start() + defer q.metrics.Timer(metrics.RegoPartialEval).Stop() + + livevars := ast.NewVarSet() + for _, t := range q.unknowns { + switch v := t.Value.(type) { + case ast.Var: + livevars.Add(v) + case ast.Ref: + livevars.Add(v[0].Value.(ast.Var)) + } + } + + ast.WalkVars(q.query, func(x ast.Var) bool { + if !x.IsGenerated() { + livevars.Add(x) + } + return false + }) + + p := copypropagation.New(livevars).WithCompiler(q.compiler) + + err = e.Run(func(e *eval) error { + + // Build output from saved expressions. + body := ast.NewBody() + + for _, elem := range e.saveStack.Stack[len(e.saveStack.Stack)-1] { + body.Append(elem.Plug(e.bindings)) + } + + // Include bindings as exprs so that when caller evals the result, they + // can obtain values for the vars in their query. + bindingExprs := []*ast.Expr{} + _ = e.bindings.Iter(e.bindings, func(a, b *ast.Term) error { + bindingExprs = append(bindingExprs, ast.Equality.Expr(a, b)) + return nil + }) // cannot return error + + // Sort binding expressions so that results are deterministic. + sort.Slice(bindingExprs, func(i, j int) bool { + return bindingExprs[i].Compare(bindingExprs[j]) < 0 + }) + + for i := range bindingExprs { + body.Append(bindingExprs[i]) + } + + // Skip this rule body if it fails to type-check. + // Type-checking failure means the rule body will never succeed. + if !e.compiler.PassesTypeCheck(body) { + return nil + } + + if !q.shallowInlining { + body = applyCopyPropagation(p, e.instr, body) + } + + partials = append(partials, body) + return nil + }) + + support = e.saveSupport.List() + + if len(e.builtinErrors.errs) > 0 { + if q.strictBuiltinErrors { + err = e.builtinErrors.errs[0] + } else if q.builtinErrorList != nil { + // If a builtinErrorList has been supplied, we must use pointer indirection + // to append to it. builtinErrorList is a slice pointer so that errors can be + // appended to it without returning a new slice and changing the interface + // of PartialRun. + for _, err := range e.builtinErrors.errs { + if tdError, ok := err.(*Error); ok { + *(q.builtinErrorList) = append(*(q.builtinErrorList), *tdError) + } else { + *(q.builtinErrorList) = append(*(q.builtinErrorList), Error{ + Code: BuiltinErr, + Message: err.Error(), + }) + } + } + } + } + + for i, m := range support { + if regoVersion := q.compiler.DefaultRegoVersion(); regoVersion != ast.RegoUndefined { + ast.SetModuleRegoVersion(m, q.compiler.DefaultRegoVersion()) + } + + sort.Slice(support[i].Rules, func(j, k int) bool { + return support[i].Rules[j].Compare(support[i].Rules[k]) < 0 + }) + } + + return partials, support, err +} + +// Run is a wrapper around Iter that accumulates query results and returns them +// in one shot. +func (q *Query) Run(ctx context.Context) (QueryResultSet, error) { + qrs := QueryResultSet{} + return qrs, q.Iter(ctx, func(qr QueryResult) error { + qrs = append(qrs, qr) + return nil + }) +} + +// Iter executes the query and invokes the iter function with query results +// produced by evaluating the query. +func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error { + // Query evaluation must not be allowed if the compiler has errors and is in an undefined, possibly inconsistent state + if q.compiler != nil && len(q.compiler.Errors) > 0 { + return &Error{ + Code: InternalErr, + Message: "compiler has errors", + } + } + + if q.seed == nil { + q.seed = rand.Reader + } + if q.time.IsZero() { + q.time = time.Now() + } + if q.metrics == nil { + q.metrics = metrics.New() + } + + f := &queryIDFactory{} + + var vc VirtualCache + if q.virtualCache != nil { + vc = q.virtualCache + } else { + vc = NewVirtualCache() + } + + e := &eval{ + ctx: ctx, + metrics: q.metrics, + seed: q.seed, + time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), + cancel: q.cancel, + query: q.query, + queryCompiler: q.queryCompiler, + queryIDFact: f, + queryID: f.Next(), + bindings: newBindings(0, q.instr), + compiler: q.compiler, + store: q.store, + baseCache: newBaseCache(), + targetStack: newRefStack(), + txn: q.txn, + input: q.input, + external: q.external, + tracers: q.tracers, + traceEnabled: len(q.tracers) > 0, + plugTraceVars: q.plugTraceVars, + instr: q.instr, + builtins: q.builtins, + builtinCache: builtins.Cache{}, + functionMocks: newFunctionMocksStack(), + interQueryBuiltinCache: q.interQueryBuiltinCache, + interQueryBuiltinValueCache: q.interQueryBuiltinValueCache, + ndBuiltinCache: q.ndBuiltinCache, + virtualCache: vc, + comprehensionCache: newComprehensionCache(), + genvarprefix: q.genvarprefix, + runtime: q.runtime, + indexing: q.indexing, + earlyExit: q.earlyExit, + builtinErrors: &builtinErrors{}, + printHook: q.printHook, + tracingOpts: q.tracingOpts, + strictObjects: q.strictObjects, + roundTripper: q.roundTripper, + } + e.caller = e + q.metrics.Timer(metrics.RegoQueryEval).Start() + err := e.Run(func(e *eval) error { + qr := QueryResult{} + _ = e.bindings.Iter(nil, func(k, v *ast.Term) error { + qr[k.Value.(ast.Var)] = v + return nil + }) // cannot return error + return iter(qr) + }) + + if len(e.builtinErrors.errs) > 0 { + if q.strictBuiltinErrors { + err = e.builtinErrors.errs[0] + } else if q.builtinErrorList != nil { + // If a builtinErrorList has been supplied, we must use pointer indirection + // to append to it. builtinErrorList is a slice pointer so that errors can be + // appended to it without returning a new slice and changing the interface + // of Iter. + for _, err := range e.builtinErrors.errs { + if tdError, ok := err.(*Error); ok { + *(q.builtinErrorList) = append(*(q.builtinErrorList), *tdError) + } else { + *(q.builtinErrorList) = append(*(q.builtinErrorList), Error{ + Code: BuiltinErr, + Message: err.Error(), + }) + } + } + } + } + + q.metrics.Timer(metrics.RegoQueryEval).Stop() + return err +} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/reachable.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/reachable.go similarity index 97% rename from vendor/github.com/open-policy-agent/opa/topdown/reachable.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/reachable.go index 8d61018e7..1c31019db 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/reachable.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/reachable.go @@ -5,8 +5,8 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) // Helper: sets of vertices can be represented as Arrays or Sets. diff --git a/vendor/github.com/open-policy-agent/opa/topdown/regex.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/regex.go similarity index 94% rename from vendor/github.com/open-policy-agent/opa/topdown/regex.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/regex.go index 452e7d58b..fb55d9506 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/regex.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/regex.go @@ -11,8 +11,8 @@ import ( gintersect "github.com/yashtewari/glob-intersection" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) const regexCacheMaxSize = 100 @@ -25,15 +25,15 @@ func builtinRegexIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast. s, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } _, err = regexp.Compile(string(s)) if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) } func builtinRegexMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -49,7 +49,7 @@ func builtinRegexMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast if err != nil { return err } - return iter(ast.BooleanTerm(re.MatchString(string(s2)))) + return iter(ast.InternedBooleanTerm(re.MatchString(string(s2)))) } func builtinRegexMatchTemplate(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -79,7 +79,7 @@ func builtinRegexMatchTemplate(_ BuiltinContext, operands []*ast.Term, iter func if err != nil { return err } - return iter(ast.BooleanTerm(re.MatchString(string(match)))) + return iter(ast.InternedBooleanTerm(re.MatchString(string(match)))) } func builtinRegexSplit(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -176,7 +176,7 @@ func builtinGlobsMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te if err != nil { return err } - return iter(ast.BooleanTerm(ne)) + return iter(ast.InternedBooleanTerm(ne)) } func builtinRegexFind(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/regex_template.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/regex_template.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/topdown/regex_template.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/regex_template.go diff --git a/vendor/github.com/open-policy-agent/opa/topdown/resolver.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/resolver.go similarity index 94% rename from vendor/github.com/open-policy-agent/opa/topdown/resolver.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/resolver.go index 5ed6c1e44..170e6e640 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/resolver.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/resolver.go @@ -5,9 +5,9 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/metrics" - "github.com/open-policy-agent/opa/resolver" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" + "github.com/open-policy-agent/opa/v1/resolver" ) type resolverTrie struct { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/runtime.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/runtime.go similarity index 94% rename from vendor/github.com/open-policy-agent/opa/topdown/runtime.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/runtime.go index 7d512f7c0..f892f1751 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/runtime.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/runtime.go @@ -7,16 +7,18 @@ package topdown import ( "fmt" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) +var configStringTerm = ast.StringTerm("config") + func builtinOPARuntime(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error { if bctx.Runtime == nil { return iter(ast.ObjectTerm()) } - if bctx.Runtime.Get(ast.StringTerm("config")) != nil { + if bctx.Runtime.Get(configStringTerm) != nil { iface, err := ast.ValueToInterface(bctx.Runtime.Value, illegalResolver{}) if err != nil { return err diff --git a/vendor/github.com/open-policy-agent/opa/topdown/save.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/save.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/topdown/save.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/save.go index 0468692cc..0c853085a 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/save.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/save.go @@ -5,7 +5,7 @@ import ( "fmt" "strings" - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) // saveSet contains a stack of terms that are considered 'unknown' during diff --git a/vendor/github.com/open-policy-agent/opa/topdown/semver.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/semver.go similarity index 85% rename from vendor/github.com/open-policy-agent/opa/topdown/semver.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/semver.go index 7bb7b9c18..0e7daaeae 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/semver.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/semver.go @@ -7,9 +7,9 @@ package topdown import ( "fmt" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/semver" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func builtinSemVerCompare(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -34,13 +34,13 @@ func builtinSemVerCompare(_ BuiltinContext, operands []*ast.Term, iter func(*ast result := versionA.Compare(*versionB) - return iter(ast.IntNumberTerm(result)) + return iter(ast.InternedIntNumberTerm(result)) } func builtinSemVerIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { versionString, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } result := true @@ -50,7 +50,7 @@ func builtinSemVerIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast result = false } - return iter(ast.BooleanTerm(result)) + return iter(ast.InternedBooleanTerm(result)) } func init() { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/sets.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/sets.go similarity index 95% rename from vendor/github.com/open-policy-agent/opa/topdown/sets.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/sets.go index a973404f3..b7566b8e6 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/sets.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/sets.go @@ -5,8 +5,8 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) // Deprecated in v0.4.2 in favour of minus/infix "-" operation. diff --git a/vendor/github.com/open-policy-agent/opa/topdown/strings.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/strings.go similarity index 91% rename from vendor/github.com/open-policy-agent/opa/topdown/strings.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/strings.go index d9e4a55e5..8d6c753e6 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/strings.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/strings.go @@ -8,12 +8,14 @@ import ( "fmt" "math/big" "sort" + "strconv" "strings" "github.com/tchap/go-patricia/v2/patricia" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" + "github.com/open-policy-agent/opa/v1/util" ) func builtinAnyPrefixMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -47,7 +49,7 @@ func builtinAnyPrefixMatch(_ BuiltinContext, operands []*ast.Term, iter func(*as return builtins.NewOperandTypeErr(2, b, "string", "set", "array") } - return iter(ast.BooleanTerm(anyStartsWithAny(strs, prefixes))) + return iter(ast.InternedBooleanTerm(anyStartsWithAny(strs, prefixes))) } func builtinAnySuffixMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -87,7 +89,7 @@ func builtinAnySuffixMatch(_ BuiltinContext, operands []*ast.Term, iter func(*as return builtins.NewOperandTypeErr(2, b, "string", "set", "array") } - return iter(ast.BooleanTerm(anyStartsWithAny(strsReversed, suffixesReversed))) + return iter(ast.InternedBooleanTerm(anyStartsWithAny(strsReversed, suffixesReversed))) } func anyStartsWithAny(strs []string, prefixes []string) bool { @@ -218,14 +220,14 @@ func builtinIndexOf(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) for i, r := range baseRunes { if len(baseRunes) >= i+searchLen { if r == searchRunes[0] && runesEqual(baseRunes[i:i+searchLen], searchRunes) { - return iter(ast.IntNumberTerm(i)) + return iter(ast.InternedIntNumberTerm(i)) } } else { break } } - return iter(ast.IntNumberTerm(-1)) + return iter(ast.InternedIntNumberTerm(-1)) } func builtinIndexOfN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -250,7 +252,7 @@ func builtinIndexOfN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term for i, r := range baseRunes { if len(baseRunes) >= i+searchLen { if r == searchRunes[0] && runesEqual(baseRunes[i:i+searchLen], searchRunes) { - arr = append(arr, ast.IntNumberTerm(i)) + arr = append(arr, ast.InternedIntNumberTerm(i)) } } else { break @@ -307,7 +309,7 @@ func builtinContains(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term return err } - return iter(ast.BooleanTerm(strings.Contains(string(s), string(substr)))) + return iter(ast.InternedBooleanTerm(strings.Contains(string(s), string(substr)))) } func builtinStringCount(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -326,7 +328,7 @@ func builtinStringCount(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T count := strings.Count(baseTerm, searchTerm) - return iter(ast.IntNumberTerm(count)) + return iter(ast.InternedIntNumberTerm(count)) } func builtinStartsWith(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -340,7 +342,7 @@ func builtinStartsWith(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return err } - return iter(ast.BooleanTerm(strings.HasPrefix(string(s), string(prefix)))) + return iter(ast.InternedBooleanTerm(strings.HasPrefix(string(s), string(prefix)))) } func builtinEndsWith(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -354,7 +356,7 @@ func builtinEndsWith(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term return err } - return iter(ast.BooleanTerm(strings.HasSuffix(string(s), string(suffix)))) + return iter(ast.InternedBooleanTerm(strings.HasSuffix(string(s), string(suffix)))) } func builtinLower(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -385,9 +387,9 @@ func builtinSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) e return err } elems := strings.Split(string(s), string(d)) - arr := make([]*ast.Term, len(elems)) + arr := util.NewPtrSlice[ast.Term](len(elems)) for i := range elems { - arr[i] = ast.StringTerm(elems[i]) + arr[i].Value = ast.String(elems[i]) } return iter(ast.ArrayTerm(arr...)) } @@ -437,14 +439,8 @@ func builtinReplaceN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term } oldnewArr = append(oldnewArr, string(keyVal), string(strVal)) } - if err != nil { - return err - } - - r := strings.NewReplacer(oldnewArr...) - replaced := r.Replace(string(s)) - return iter(ast.StringTerm(replaced)) + return iter(ast.StringTerm(strings.NewReplacer(oldnewArr...).Replace(string(s)))) } func builtinTrim(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -537,7 +533,17 @@ func builtinSprintf(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) return builtins.NewOperandTypeErr(2, operands[1].Value, "array") } - args := make([]interface{}, astArr.Len()) + // Optimized path for where sprintf is used as a "to_string" function for + // a single integer, i.e. sprintf("%d", [x]) where x is an integer. + if s == "%d" && astArr.Len() == 1 { + if n, ok := astArr.Elem(0).Value.(ast.Number); ok { + if i, ok := n.Int(); ok { + return iter(ast.StringTerm(strconv.Itoa(i))) + } + } + } + + args := make([]any, astArr.Len()) for i := range args { switch v := astArr.Elem(i).Value.(type) { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/subset.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/subset.go similarity index 94% rename from vendor/github.com/open-policy-agent/opa/topdown/subset.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/subset.go index 7b152a5ef..08bdc8db4 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/subset.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/subset.go @@ -5,8 +5,8 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func bothObjects(t1, t2 *ast.Term) (bool, ast.Object, ast.Object) { @@ -237,22 +237,22 @@ func builtinObjectSubset(_ BuiltinContext, operands []*ast.Term, iter func(*ast. if ok, superObj, subObj := bothObjects(superTerm, subTerm); ok { // Both operands are objects. - return iter(ast.BooleanTerm(objectSubset(superObj, subObj))) + return iter(ast.InternedBooleanTerm(objectSubset(superObj, subObj))) } if ok, superSet, subSet := bothSets(superTerm, subTerm); ok { // Both operands are sets. - return iter(ast.BooleanTerm(setSubset(superSet, subSet))) + return iter(ast.InternedBooleanTerm(setSubset(superSet, subSet))) } if ok, superArray, subArray := bothArrays(superTerm, subTerm); ok { // Both operands are sets. - return iter(ast.BooleanTerm(arraySubset(superArray, subArray))) + return iter(ast.InternedBooleanTerm(arraySubset(superArray, subArray))) } if ok, superArray, subSet := arraySet(superTerm, subTerm); ok { // Super operand is array and sub operand is set - return iter(ast.BooleanTerm(arraySetSubset(superArray, subSet))) + return iter(ast.InternedBooleanTerm(arraySetSubset(superArray, subSet))) } return builtins.ErrOperand("both arguments object.subset must be of the same type or array and set") diff --git a/vendor/github.com/open-policy-agent/opa/topdown/template.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/template.go similarity index 90% rename from vendor/github.com/open-policy-agent/opa/topdown/template.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/template.go index cf42477ee..cf4635559 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/template.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/template.go @@ -4,8 +4,8 @@ import ( "bytes" "text/template" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) func renderTemplate(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/time.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/time.go similarity index 93% rename from vendor/github.com/open-policy-agent/opa/topdown/time.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/time.go index ba3efc75d..1c5ddaa6f 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/time.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/time.go @@ -14,8 +14,8 @@ import ( "time" _ "time/tzdata" // this is needed to have LoadLocation when no filesystem tzdata is available - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) var tzCache map[string]*time.Location @@ -127,7 +127,7 @@ func builtinDate(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) er return err } year, month, day := t.Date() - result := ast.NewArray(ast.IntNumberTerm(year), ast.IntNumberTerm(int(month)), ast.IntNumberTerm(day)) + result := ast.NewArray(ast.InternedIntNumberTerm(year), ast.InternedIntNumberTerm(int(month)), ast.InternedIntNumberTerm(day)) return iter(ast.NewTerm(result)) } @@ -137,7 +137,7 @@ func builtinClock(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) e return err } hour, minute, second := t.Clock() - result := ast.NewArray(ast.IntNumberTerm(hour), ast.IntNumberTerm(minute), ast.IntNumberTerm(second)) + result := ast.NewArray(ast.InternedIntNumberTerm(hour), ast.InternedIntNumberTerm(minute), ast.InternedIntNumberTerm(second)) return iter(ast.NewTerm(result)) } @@ -238,8 +238,8 @@ func builtinDiff(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) er } // END REDISTRIBUTION FROM APACHE 2.0 LICENSED PROJECT - return iter(ast.ArrayTerm(ast.IntNumberTerm(year), ast.IntNumberTerm(month), ast.IntNumberTerm(day), - ast.IntNumberTerm(hour), ast.IntNumberTerm(min), ast.IntNumberTerm(sec))) + return iter(ast.ArrayTerm(ast.InternedIntNumberTerm(year), ast.InternedIntNumberTerm(month), ast.InternedIntNumberTerm(day), + ast.InternedIntNumberTerm(hour), ast.InternedIntNumberTerm(min), ast.InternedIntNumberTerm(sec))) } func tzTime(a ast.Value) (t time.Time, lay string, err error) { diff --git a/vendor/github.com/open-policy-agent/opa/topdown/tokens.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/tokens.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/topdown/tokens.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/tokens.go index 7457f1f15..9e5f9767d 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/tokens.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/tokens.go @@ -21,11 +21,11 @@ import ( "math/big" "strings" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/jwx/jwa" "github.com/open-policy-agent/opa/internal/jwx/jwk" "github.com/open-policy-agent/opa/internal/jwx/jws" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) var ( @@ -1024,7 +1024,7 @@ func builtinJWTDecodeVerify(bctx BuiltinContext, operands []*ast.Term, iter func } unverified := ast.ArrayTerm( - ast.BooleanTerm(false), + ast.InternedBooleanTerm(false), ast.NewTerm(ast.NewObject()), ast.NewTerm(ast.NewObject()), ) @@ -1137,7 +1137,7 @@ func builtinJWTDecodeVerify(bctx BuiltinContext, operands []*ast.Term, iter func } verified := ast.ArrayTerm( - ast.BooleanTerm(true), + ast.InternedBooleanTerm(true), ast.NewTerm(token.decodedHeader), ast.NewTerm(payload), ) diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/trace.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/trace.go new file mode 100644 index 000000000..1c45ef23b --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/trace.go @@ -0,0 +1,902 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "bytes" + "fmt" + "io" + "slices" + "strings" + + iStrs "github.com/open-policy-agent/opa/internal/strings" + + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" +) + +const ( + minLocationWidth = 5 // len("query") + maxIdealLocationWidth = 64 + columnPadding = 4 + maxExprVarWidth = 32 + maxPrettyExprVarWidth = 64 +) + +// Op defines the types of tracing events. +type Op string + +const ( + // EnterOp is emitted when a new query is about to be evaluated. + EnterOp Op = "Enter" + + // ExitOp is emitted when a query has evaluated to true. + ExitOp Op = "Exit" + + // EvalOp is emitted when an expression is about to be evaluated. + EvalOp Op = "Eval" + + // RedoOp is emitted when an expression, rule, or query is being re-evaluated. + RedoOp Op = "Redo" + + // SaveOp is emitted when an expression is saved instead of evaluated + // during partial evaluation. + SaveOp Op = "Save" + + // FailOp is emitted when an expression evaluates to false. + FailOp Op = "Fail" + + // DuplicateOp is emitted when a query has produced a duplicate value. The search + // will stop at the point where the duplicate was emitted and backtrack. + DuplicateOp Op = "Duplicate" + + // NoteOp is emitted when an expression invokes a tracing built-in function. + NoteOp Op = "Note" + + // IndexOp is emitted during an expression evaluation to represent lookup + // matches. + IndexOp Op = "Index" + + // WasmOp is emitted when resolving a ref using an external + // Resolver. + WasmOp Op = "Wasm" + + // UnifyOp is emitted when two terms are unified. Node will be set to an + // equality expression with the two terms. This Node will not have location + // info. + UnifyOp Op = "Unify" + FailedAssertionOp Op = "FailedAssertion" +) + +// VarMetadata provides some user facing information about +// a variable in some policy. +type VarMetadata struct { + Name ast.Var `json:"name"` + Location *ast.Location `json:"location"` +} + +// Event contains state associated with a tracing event. +type Event struct { + Op Op // Identifies type of event. + Node ast.Node // Contains AST node relevant to the event. + Location *ast.Location // The location of the Node this event relates to. + QueryID uint64 // Identifies the query this event belongs to. + ParentID uint64 // Identifies the parent query this event belongs to. + Locals *ast.ValueMap // Contains local variable bindings from the query context. Nil if variables were not included in the trace event. + LocalMetadata map[ast.Var]VarMetadata // Contains metadata for the local variable bindings. Nil if variables were not included in the trace event. + Message string // Contains message for Note events. + Ref *ast.Ref // Identifies the subject ref for the event. Only applies to Index and Wasm operations. + + input *ast.Term + bindings *bindings + localVirtualCacheSnapshot *ast.ValueMap +} + +func (evt *Event) WithInput(input *ast.Term) *Event { + evt.input = input + return evt +} + +// HasRule returns true if the Event contains an ast.Rule. +func (evt *Event) HasRule() bool { + _, ok := evt.Node.(*ast.Rule) + return ok +} + +// HasBody returns true if the Event contains an ast.Body. +func (evt *Event) HasBody() bool { + _, ok := evt.Node.(ast.Body) + return ok +} + +// HasExpr returns true if the Event contains an ast.Expr. +func (evt *Event) HasExpr() bool { + _, ok := evt.Node.(*ast.Expr) + return ok +} + +// Equal returns true if this event is equal to the other event. +func (evt *Event) Equal(other *Event) bool { + if evt.Op != other.Op { + return false + } + if evt.QueryID != other.QueryID { + return false + } + if evt.ParentID != other.ParentID { + return false + } + if !evt.equalNodes(other) { + return false + } + return evt.Locals.Equal(other.Locals) +} + +func (evt *Event) String() string { + return fmt.Sprintf("%v %v %v (qid=%v, pqid=%v)", evt.Op, evt.Node, evt.Locals, evt.QueryID, evt.ParentID) +} + +// Input returns the input object as it was at the event. +func (evt *Event) Input() *ast.Term { + return evt.input +} + +// Plug plugs event bindings into the provided ast.Term. Because bindings are mutable, this only makes sense to do when +// the event is emitted rather than on recorded trace events as the bindings are going to be different by then. +func (evt *Event) Plug(term *ast.Term) *ast.Term { + return evt.bindings.Plug(term) +} + +func (evt *Event) equalNodes(other *Event) bool { + switch a := evt.Node.(type) { + case ast.Body: + if b, ok := other.Node.(ast.Body); ok { + return a.Equal(b) + } + case *ast.Rule: + if b, ok := other.Node.(*ast.Rule); ok { + return a.Equal(b) + } + case *ast.Expr: + if b, ok := other.Node.(*ast.Expr); ok { + return a.Equal(b) + } + case nil: + return other.Node == nil + } + return false +} + +// Tracer defines the interface for tracing in the top-down evaluation engine. +// Deprecated: Use QueryTracer instead. +type Tracer interface { + Enabled() bool + Trace(*Event) +} + +// QueryTracer defines the interface for tracing in the top-down evaluation engine. +// The implementation can provide additional configuration to modify the tracing +// behavior for query evaluations. +type QueryTracer interface { + Enabled() bool + TraceEvent(Event) + Config() TraceConfig +} + +// TraceConfig defines some common configuration for Tracer implementations +type TraceConfig struct { + PlugLocalVars bool // Indicate whether to plug local variable bindings before calling into the tracer. +} + +// legacyTracer Implements the QueryTracer interface by wrapping an older Tracer instance. +type legacyTracer struct { + t Tracer +} + +func (l *legacyTracer) Enabled() bool { + return l.t.Enabled() +} + +func (l *legacyTracer) Config() TraceConfig { + return TraceConfig{ + PlugLocalVars: true, // For backwards compatibility old tracers will plug local variables + } +} + +func (l *legacyTracer) TraceEvent(evt Event) { + l.t.Trace(&evt) +} + +// WrapLegacyTracer will create a new QueryTracer which wraps an +// older Tracer instance. +func WrapLegacyTracer(tracer Tracer) QueryTracer { + return &legacyTracer{t: tracer} +} + +// BufferTracer implements the Tracer and QueryTracer interface by +// simply buffering all events received. +type BufferTracer []*Event + +// NewBufferTracer returns a new BufferTracer. +func NewBufferTracer() *BufferTracer { + return &BufferTracer{} +} + +// Enabled always returns true if the BufferTracer is instantiated. +func (b *BufferTracer) Enabled() bool { + return b != nil +} + +// Trace adds the event to the buffer. +// Deprecated: Use TraceEvent instead. +func (b *BufferTracer) Trace(evt *Event) { + *b = append(*b, evt) +} + +// TraceEvent adds the event to the buffer. +func (b *BufferTracer) TraceEvent(evt Event) { + *b = append(*b, &evt) +} + +// Config returns the Tracers standard configuration +func (b *BufferTracer) Config() TraceConfig { + return TraceConfig{PlugLocalVars: true} +} + +// PrettyTrace pretty prints the trace to the writer. +func PrettyTrace(w io.Writer, trace []*Event) { + PrettyTraceWithOpts(w, trace, PrettyTraceOptions{}) +} + +// PrettyTraceWithLocation prints the trace to the writer and includes location information +func PrettyTraceWithLocation(w io.Writer, trace []*Event) { + PrettyTraceWithOpts(w, trace, PrettyTraceOptions{Locations: true}) +} + +type PrettyTraceOptions struct { + Locations bool // Include location information + ExprVariables bool // Include variables found in the expression + LocalVariables bool // Include all local variables +} + +type traceRow []string + +func (r *traceRow) add(s string) { + *r = append(*r, s) +} + +type traceTable struct { + rows []traceRow + maxWidths []int +} + +func (t *traceTable) add(row traceRow) { + t.rows = append(t.rows, row) + for i := range row { + if i >= len(t.maxWidths) { + t.maxWidths = append(t.maxWidths, len(row[i])) + } else if len(row[i]) > t.maxWidths[i] { + t.maxWidths[i] = len(row[i]) + } + } +} + +func (t *traceTable) write(w io.Writer, padding int) { + for _, row := range t.rows { + for i, cell := range row { + width := t.maxWidths[i] + padding + if i < len(row)-1 { + _, _ = fmt.Fprintf(w, "%-*s ", width, cell) + } else { + _, _ = fmt.Fprintf(w, "%s", cell) + } + } + _, _ = fmt.Fprintln(w) + } +} + +func PrettyTraceWithOpts(w io.Writer, trace []*Event, opts PrettyTraceOptions) { + depths := depths{} + + // FIXME: Can we shorten each location as we process each trace event instead of beforehand? + filePathAliases, _ := getShortenedFileNames(trace) + + table := traceTable{} + + for _, event := range trace { + depth := depths.GetOrSet(event.QueryID, event.ParentID) + row := traceRow{} + + if opts.Locations { + location := formatLocation(event, filePathAliases) + row.add(location) + } + + row.add(formatEvent(event, depth)) + + if opts.ExprVariables { + vars := exprLocalVars(event) + keys := sortedKeys(vars) + + buf := new(bytes.Buffer) + buf.WriteString("{") + for i, k := range keys { + if i > 0 { + buf.WriteString(", ") + } + _, _ = fmt.Fprintf(buf, "%v: %s", k, iStrs.Truncate(vars.Get(k).String(), maxExprVarWidth)) + } + buf.WriteString("}") + row.add(buf.String()) + } + + if opts.LocalVariables { + if locals := event.Locals; locals != nil { + keys := sortedKeys(locals) + + buf := new(bytes.Buffer) + buf.WriteString("{") + for i, k := range keys { + if i > 0 { + buf.WriteString(", ") + } + _, _ = fmt.Fprintf(buf, "%v: %s", k, iStrs.Truncate(locals.Get(k).String(), maxExprVarWidth)) + } + buf.WriteString("}") + row.add(buf.String()) + } else { + row.add("{}") + } + } + + table.add(row) + } + + table.write(w, columnPadding) +} + +func sortedKeys(vm *ast.ValueMap) []ast.Value { + keys := make([]ast.Value, 0, vm.Len()) + vm.Iter(func(k, _ ast.Value) bool { + keys = append(keys, k) + return false + }) + slices.SortFunc(keys, func(a, b ast.Value) int { + return strings.Compare(a.String(), b.String()) + }) + return keys +} + +func exprLocalVars(e *Event) *ast.ValueMap { + vars := ast.NewValueMap() + + findVars := func(term *ast.Term) bool { + //if r, ok := term.Value.(ast.Ref); ok { + // fmt.Printf("ref: %v\n", r) + // //return true + //} + if name, ok := term.Value.(ast.Var); ok { + if meta, ok := e.LocalMetadata[name]; ok { + if val := e.Locals.Get(name); val != nil { + vars.Put(meta.Name, val) + } + } + } + return false + } + + if r, ok := e.Node.(*ast.Rule); ok { + // We're only interested in vars in the head, not the body + ast.WalkTerms(r.Head, findVars) + return vars + } + + // The local cache snapshot only contains a snapshot for those refs present in the event node, + // so they can all be added to the vars map. + e.localVirtualCacheSnapshot.Iter(func(k, v ast.Value) bool { + vars.Put(k, v) + return false + }) + + ast.WalkTerms(e.Node, findVars) + + return vars +} + +func formatEvent(event *Event, depth int) string { + padding := formatEventPadding(event, depth) + if event.Op == NoteOp { + return fmt.Sprintf("%v%v %q", padding, event.Op, event.Message) + } + + var details interface{} + if node, ok := event.Node.(*ast.Rule); ok { + details = node.Path() + } else if event.Ref != nil { + details = event.Ref + } else { + details = rewrite(event).Node + } + + template := "%v%v %v" + opts := []interface{}{padding, event.Op, details} + + if event.Message != "" { + template += " %v" + opts = append(opts, event.Message) + } + + return fmt.Sprintf(template, opts...) +} + +func formatEventPadding(event *Event, depth int) string { + spaces := formatEventSpaces(event, depth) + if spaces > 1 { + return strings.Repeat("| ", spaces-1) + } + return "" +} + +func formatEventSpaces(event *Event, depth int) int { + switch event.Op { + case EnterOp: + return depth + case RedoOp: + if _, ok := event.Node.(*ast.Expr); !ok { + return depth + } + } + return depth + 1 +} + +// getShortenedFileNames will return a map of file paths to shortened aliases +// that were found in the trace. It also returns the longest location expected +func getShortenedFileNames(trace []*Event) (map[string]string, int) { + // Get a deduplicated list of all file paths + // and the longest file path size + fpAliases := map[string]string{} + var canShorten []string + longestLocation := 0 + for _, event := range trace { + if event.Location != nil { + if event.Location.File != "" { + // length of ":" + curLen := len(event.Location.File) + numDigits10(event.Location.Row) + 1 + if curLen > longestLocation { + longestLocation = curLen + } + + if _, ok := fpAliases[event.Location.File]; ok { + continue + } + + canShorten = append(canShorten, event.Location.File) + + // Default to just alias their full path + fpAliases[event.Location.File] = event.Location.File + } else { + // length of ":" + curLen := minLocationWidth + numDigits10(event.Location.Row) + 1 + if curLen > longestLocation { + longestLocation = curLen + } + } + } + } + + if len(canShorten) > 0 && longestLocation > maxIdealLocationWidth { + fpAliases, longestLocation = iStrs.TruncateFilePaths(maxIdealLocationWidth, longestLocation, canShorten...) + } + + return fpAliases, longestLocation +} + +func numDigits10(n int) int { + if n < 10 { + return 1 + } + return numDigits10(n/10) + 1 +} + +func formatLocation(event *Event, fileAliases map[string]string) string { + + location := event.Location + if location == nil { + return "" + } + + if location.File == "" { + return fmt.Sprintf("query:%v", location.Row) + } + + return fmt.Sprintf("%v:%v", fileAliases[location.File], location.Row) +} + +// depths is a helper for computing the depth of an event. Events within the +// same query all have the same depth. The depth of query is +// depth(parent(query))+1. +type depths map[uint64]int + +func (ds depths) GetOrSet(qid uint64, pqid uint64) int { + depth := ds[qid] + if depth == 0 { + depth = ds[pqid] + depth++ + ds[qid] = depth + } + return depth +} + +func builtinTrace(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + + str, err := builtins.StringOperand(operands[0].Value, 1) + if err != nil { + return handleBuiltinErr(ast.Trace.Name, bctx.Location, err) + } + + if !bctx.TraceEnabled { + return iter(ast.InternedBooleanTerm(true)) + } + + evt := Event{ + Op: NoteOp, + Location: bctx.Location, + QueryID: bctx.QueryID, + ParentID: bctx.ParentID, + Message: string(str), + } + + for i := range bctx.QueryTracers { + bctx.QueryTracers[i].TraceEvent(evt) + } + + return iter(ast.InternedBooleanTerm(true)) +} + +func rewrite(event *Event) *Event { + + cpy := *event + + var node ast.Node + + switch v := event.Node.(type) { + case *ast.Expr: + expr := v.Copy() + + // Hide generated local vars in 'key' position that have not been + // rewritten. + if ev, ok := v.Terms.(*ast.Every); ok { + if kv, ok := ev.Key.Value.(ast.Var); ok { + if rw, ok := cpy.LocalMetadata[kv]; !ok || rw.Name.IsGenerated() { + expr.Terms.(*ast.Every).Key = nil + } + } + } + node = expr + case ast.Body: + node = v.Copy() + case *ast.Rule: + node = v.Copy() + } + + _, _ = ast.TransformVars(node, func(v ast.Var) (ast.Value, error) { + if meta, ok := cpy.LocalMetadata[v]; ok { + return meta.Name, nil + } + return v, nil + }) + + cpy.Node = node + + return &cpy +} + +type varInfo struct { + VarMetadata + val ast.Value + exprLoc *ast.Location + col int // 0-indexed column +} + +func (v varInfo) Value() string { + if v.val != nil { + return v.val.String() + } + return "undefined" +} + +func (v varInfo) Title() string { + if v.exprLoc != nil && v.exprLoc.Text != nil { + return string(v.exprLoc.Text) + } + return string(v.Name) +} + +func padLocationText(loc *ast.Location) string { + if loc == nil { + return "" + } + + text := string(loc.Text) + + if loc.Col == 0 { + return text + } + + buf := new(bytes.Buffer) + j := 0 + for i := 1; i < loc.Col; i++ { + if len(loc.Tabs) > 0 && j < len(loc.Tabs) && loc.Tabs[j] == i { + buf.WriteString("\t") + j++ + } else { + buf.WriteString(" ") + } + } + + buf.WriteString(text) + return buf.String() +} + +type PrettyEventOpts struct { + PrettyVars bool +} + +func walkTestTerms(x interface{}, f func(*ast.Term) bool) { + var vis *ast.GenericVisitor + vis = ast.NewGenericVisitor(func(x interface{}) bool { + switch x := x.(type) { + case ast.Call: + for _, t := range x[1:] { + vis.Walk(t) + } + return true + case *ast.Expr: + if x.IsCall() { + for _, o := range x.Operands() { + vis.Walk(o) + } + for i := range x.With { + vis.Walk(x.With[i]) + } + return true + } + case *ast.Term: + return f(x) + case *ast.With: + vis.Walk(x.Value) + return true + } + return false + }) + vis.Walk(x) +} + +func PrettyEvent(w io.Writer, e *Event, opts PrettyEventOpts) error { + if !opts.PrettyVars { + _, _ = fmt.Fprintln(w, padLocationText(e.Location)) + return nil + } + + buf := new(bytes.Buffer) + exprVars := map[string]varInfo{} + + findVars := func(unknownAreUndefined bool) func(term *ast.Term) bool { + return func(term *ast.Term) bool { + if term.Location == nil { + return false + } + + switch v := term.Value.(type) { + case *ast.ArrayComprehension, *ast.SetComprehension, *ast.ObjectComprehension: + // we don't report on the internals of a comprehension, as it's already evaluated, and we won't have the local vars. + return true + case ast.Var: + var info *varInfo + if meta, ok := e.LocalMetadata[v]; ok { + info = &varInfo{ + VarMetadata: meta, + val: e.Locals.Get(v), + exprLoc: term.Location, + } + } else if unknownAreUndefined { + info = &varInfo{ + VarMetadata: VarMetadata{Name: v}, + exprLoc: term.Location, + col: term.Location.Col, + } + } + + if info != nil { + if v, exists := exprVars[info.Title()]; !exists || v.val == nil { + if term.Location != nil { + info.col = term.Location.Col + } + exprVars[info.Title()] = *info + } + } + } + return false + } + } + + expr, ok := e.Node.(*ast.Expr) + if !ok || expr == nil { + return nil + } + + base := expr.BaseCogeneratedExpr() + exprText := padLocationText(base.Location) + buf.WriteString(exprText) + + e.localVirtualCacheSnapshot.Iter(func(k, v ast.Value) bool { + var info *varInfo + switch k := k.(type) { + case ast.Ref: + info = &varInfo{ + VarMetadata: VarMetadata{Name: ast.Var(k.String())}, + val: v, + exprLoc: k[0].Location, + col: k[0].Location.Col, + } + case *ast.ArrayComprehension: + info = &varInfo{ + VarMetadata: VarMetadata{Name: ast.Var(k.String())}, + val: v, + exprLoc: k.Term.Location, + col: k.Term.Location.Col, + } + case *ast.SetComprehension: + info = &varInfo{ + VarMetadata: VarMetadata{Name: ast.Var(k.String())}, + val: v, + exprLoc: k.Term.Location, + col: k.Term.Location.Col, + } + case *ast.ObjectComprehension: + info = &varInfo{ + VarMetadata: VarMetadata{Name: ast.Var(k.String())}, + val: v, + exprLoc: k.Key.Location, + col: k.Key.Location.Col, + } + } + + if info != nil { + exprVars[info.Title()] = *info + } + + return false + }) + + // If the expression is negated, we can't confidently assert that vars with unknown values are 'undefined', + // since the compiler might have opted out of the necessary rewrite. + walkTestTerms(expr, findVars(!expr.Negated)) + coExprs := expr.CogeneratedExprs() + for _, coExpr := range coExprs { + // Only the current "co-expr" can have undefined vars, if we don't know the value for a var in any other co-expr, + // it's unknown, not undefined. A var can be unknown if it hasn't been assigned a value yet, because the co-expr + // hasn't been evaluated yet (the fail happened before it). + walkTestTerms(coExpr, findVars(false)) + } + + printPrettyVars(buf, exprVars) + _, _ = fmt.Fprint(w, buf.String()) + return nil +} + +func printPrettyVars(w *bytes.Buffer, exprVars map[string]varInfo) { + containsTabs := false + varRows := make(map[int]interface{}) + for _, info := range exprVars { + if len(info.exprLoc.Tabs) > 0 { + containsTabs = true + } + varRows[info.exprLoc.Row] = nil + } + + if containsTabs && len(varRows) > 1 { + // We can't (currently) reliably point to var locations when they are on different rows that contain tabs. + // So we'll just print them in alphabetical order instead. + byName := make([]varInfo, 0, len(exprVars)) + for _, info := range exprVars { + byName = append(byName, info) + } + slices.SortStableFunc(byName, func(a, b varInfo) int { + return strings.Compare(a.Title(), b.Title()) + }) + + w.WriteString("\n\nWhere:\n") + for _, info := range byName { + w.WriteString(fmt.Sprintf("\n%s: %s", info.Title(), iStrs.Truncate(info.Value(), maxPrettyExprVarWidth))) + } + + return + } + + byCol := make([]varInfo, 0, len(exprVars)) + for _, info := range exprVars { + byCol = append(byCol, info) + } + slices.SortFunc(byCol, func(a, b varInfo) int { + // sort first by column, then by reverse row (to present vars in the same order they appear in the expr) + if a.col == b.col { + if a.exprLoc.Row == b.exprLoc.Row { + return strings.Compare(a.Title(), b.Title()) + } + return b.exprLoc.Row - a.exprLoc.Row + } + return a.col - b.col + }) + + if len(byCol) == 0 { + return + } + + w.WriteString("\n") + printArrows(w, byCol, -1) + for i := len(byCol) - 1; i >= 0; i-- { + w.WriteString("\n") + printArrows(w, byCol, i) + } +} + +func printArrows(w *bytes.Buffer, l []varInfo, printValueAt int) { + prevCol := 0 + var slice []varInfo + if printValueAt >= 0 { + slice = l[:printValueAt+1] + } else { + slice = l + } + isFirst := true + for i, info := range slice { + + isLast := i >= len(slice)-1 + col := info.col + + if !isLast && col == l[i+1].col { + // We're sharing the same column with another, subsequent var + continue + } + + spaces := col - 1 + if i > 0 && !isFirst { + spaces = (col - prevCol) - 1 + } + + for j := 0; j < spaces; j++ { + tab := false + for _, t := range info.exprLoc.Tabs { + if t == j+prevCol+1 { + w.WriteString("\t") + tab = true + break + } + } + if !tab { + w.WriteString(" ") + } + } + + if isLast && printValueAt >= 0 { + valueStr := iStrs.Truncate(info.Value(), maxPrettyExprVarWidth) + if (i > 0 && col == l[i-1].col) || (i < len(l)-1 && col == l[i+1].col) { + // There is another var on this column, so we need to include the name to differentiate them. + w.WriteString(fmt.Sprintf("%s: %s", info.Title(), valueStr)) + } else { + w.WriteString(valueStr) + } + } else { + w.WriteString("|") + } + prevCol = col + isFirst = false + } +} + +func init() { + RegisterBuiltinFunc(ast.Trace.Name, builtinTrace) +} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/type.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/type.go similarity index 71% rename from vendor/github.com/open-policy-agent/opa/topdown/type.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/type.go index dab5c853c..6103fbe48 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/type.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/type.go @@ -5,69 +5,69 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) func builtinIsNumber(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch operands[0].Value.(type) { case ast.Number: - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) default: - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } } func builtinIsString(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch operands[0].Value.(type) { case ast.String: - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) default: - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } } func builtinIsBoolean(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch operands[0].Value.(type) { case ast.Boolean: - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) default: - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } } func builtinIsArray(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch operands[0].Value.(type) { case *ast.Array: - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) default: - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } } func builtinIsSet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch operands[0].Value.(type) { case ast.Set: - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) default: - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } } func builtinIsObject(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch operands[0].Value.(type) { case ast.Object: - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) default: - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } } func builtinIsNull(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch operands[0].Value.(type) { case ast.Null: - return iter(ast.BooleanTerm(true)) + return iter(ast.InternedBooleanTerm(true)) default: - return iter(ast.BooleanTerm(false)) + return iter(ast.InternedBooleanTerm(false)) } } diff --git a/vendor/github.com/open-policy-agent/opa/v1/topdown/type_name.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/type_name.go new file mode 100644 index 000000000..fc3de4879 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/type_name.go @@ -0,0 +1,46 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "fmt" + + "github.com/open-policy-agent/opa/v1/ast" +) + +var ( + nullStringTerm = ast.StringTerm("null") + booleanStringTerm = ast.StringTerm("boolean") + numberStringTerm = ast.StringTerm("number") + stringStringTerm = ast.StringTerm("string") + arrayStringTerm = ast.StringTerm("array") + objectStringTerm = ast.StringTerm("object") + setStringTerm = ast.StringTerm("set") +) + +func builtinTypeName(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + switch operands[0].Value.(type) { + case ast.Null: + return iter(nullStringTerm) + case ast.Boolean: + return iter(booleanStringTerm) + case ast.Number: + return iter(numberStringTerm) + case ast.String: + return iter(stringStringTerm) + case *ast.Array: + return iter(arrayStringTerm) + case ast.Object: + return iter(objectStringTerm) + case ast.Set: + return iter(setStringTerm) + } + + return fmt.Errorf("illegal value") +} + +func init() { + RegisterBuiltinFunc(ast.TypeNameBuiltin.Name, builtinTypeName) +} diff --git a/vendor/github.com/open-policy-agent/opa/topdown/uuid.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/uuid.go similarity index 92% rename from vendor/github.com/open-policy-agent/opa/topdown/uuid.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/uuid.go index d3a7a5f90..d013df9fe 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/uuid.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/uuid.go @@ -5,9 +5,9 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/uuid" - "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/topdown/builtins" ) type uuidCachingKey string diff --git a/vendor/github.com/open-policy-agent/opa/topdown/walk.go b/vendor/github.com/open-policy-agent/opa/v1/topdown/walk.go similarity index 62% rename from vendor/github.com/open-policy-agent/opa/topdown/walk.go rename to vendor/github.com/open-policy-agent/opa/v1/topdown/walk.go index 0f3b3544b..f5dcf5c9f 100644 --- a/vendor/github.com/open-policy-agent/opa/topdown/walk.go +++ b/vendor/github.com/open-policy-agent/opa/v1/topdown/walk.go @@ -5,18 +5,20 @@ package topdown import ( - "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/v1/ast" ) +var emptyArr = ast.ArrayTerm() + func evalWalk(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { input := operands[0] if pathIsWildcard(operands) { // When the path assignment is a wildcard: walk(input, [_, value]) // we may skip the path construction entirely, and simply return - // same pointer in each iteration. This is a much more efficient + // same pointer in each iteration. This is a *much* more efficient // path when only the values are needed. - return walkNoPath(input, iter) + return walkNoPath(ast.ArrayTerm(emptyArr, input), iter) } filter := getOutputPath(operands) @@ -24,7 +26,6 @@ func evalWalk(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error } func walk(filter, path *ast.Array, input *ast.Term, iter func(*ast.Term) error) error { - if filter == nil || filter.Len() == 0 { if path == nil { path = ast.NewArray() @@ -50,57 +51,62 @@ func walk(filter, path *ast.Array, input *ast.Term, iter func(*ast.Term) error) switch v := input.Value.(type) { case *ast.Array: for i := 0; i < v.Len(); i++ { - path = pathAppend(path, ast.IntNumberTerm(i)) - if err := walk(filter, path, v.Elem(i), iter); err != nil { + if err := walk(filter, pathAppend(path, ast.InternedIntNumberTerm(i)), v.Elem(i), iter); err != nil { return err } - path = path.Slice(0, path.Len()-1) } case ast.Object: - return v.Iter(func(k, v *ast.Term) error { - path = pathAppend(path, k) - if err := walk(filter, path, v, iter); err != nil { + for _, k := range v.Keys() { + if err := walk(filter, pathAppend(path, k), v.Get(k), iter); err != nil { return err } - path = path.Slice(0, path.Len()-1) - return nil - }) + } case ast.Set: - return v.Iter(func(elem *ast.Term) error { - path = pathAppend(path, elem) - if err := walk(filter, path, elem, iter); err != nil { + for _, elem := range v.Slice() { + if err := walk(filter, pathAppend(path, elem), elem, iter); err != nil { return err } - path = path.Slice(0, path.Len()-1) - return nil - }) + } } return nil } -var emptyArr = ast.ArrayTerm() - func walkNoPath(input *ast.Term, iter func(*ast.Term) error) error { - if err := iter(ast.ArrayTerm(emptyArr, input)); err != nil { + // Note: the path array is embedded in the input from the start here + // in order to avoid an extra allocation per iteration. This leads to + // a little convoluted code below in order to extract and set the value, + // but since walk is commonly used to traverse large data structures, + // the performance gain is worth it. + if err := iter(input); err != nil { return err } - switch v := input.Value.(type) { + inputArray := input.Value.(*ast.Array) + value := inputArray.Get(ast.InternedIntNumberTerm(1)).Value + + switch v := value.(type) { case ast.Object: - return v.Iter(func(_, v *ast.Term) error { - return walkNoPath(v, iter) - }) + for _, k := range v.Keys() { + inputArray.Set(1, v.Get(k)) + if err := walkNoPath(input, iter); err != nil { + return err + } + } case *ast.Array: for i := 0; i < v.Len(); i++ { - if err := walkNoPath(v.Elem(i), iter); err != nil { + inputArray.Set(1, v.Elem(i)) + if err := walkNoPath(input, iter); err != nil { return err } } case ast.Set: - return v.Iter(func(elem *ast.Term) error { - return walkNoPath(elem, iter) - }) + for _, elem := range v.Slice() { + inputArray.Set(1, elem) + if err := walkNoPath(input, iter); err != nil { + return err + } + } } return nil diff --git a/vendor/github.com/open-policy-agent/opa/v1/tracing/tracing.go b/vendor/github.com/open-policy-agent/opa/v1/tracing/tracing.go new file mode 100644 index 000000000..2708b78e2 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/tracing/tracing.go @@ -0,0 +1,55 @@ +// Copyright 2021 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package tracing enables dependency-injection at runtime. When used +// together with an underscore-import of `github.com/open-policy-agent/opa/features/tracing`, +// the server and its runtime will emit OpenTelemetry spans to the +// configured sink. +package tracing + +import "net/http" + +// Options are options for the HTTPTracingService, passed along as-is. +type Options []interface{} + +// NewOptions is a helper method for constructing `tracing.Options` +func NewOptions(opts ...interface{}) Options { + return opts +} + +// HTTPTracingService defines how distributed tracing comes in, server- and client-side +type HTTPTracingService interface { + // NewTransport is used when setting up an HTTP client + NewTransport(http.RoundTripper, Options) http.RoundTripper + + // NewHandler is used to wrap an http.Handler in the server + NewHandler(http.Handler, string, Options) http.Handler +} + +var tracing HTTPTracingService + +// RegisterHTTPTracing enables a HTTPTracingService for further use. +func RegisterHTTPTracing(ht HTTPTracingService) { + tracing = ht +} + +// NewTransport returns another http.RoundTripper, instrumented to emit tracing +// spans according to Options. Provided by the HTTPTracingService registered with +// this package via RegisterHTTPTracing. +func NewTransport(tr http.RoundTripper, opts Options) http.RoundTripper { + if tracing == nil { + return tr + } + return tracing.NewTransport(tr, opts) +} + +// NewHandler returns another http.Handler, instrumented to emit tracing spans +// according to Options. Provided by the HTTPTracingService registered with +// this package via RegisterHTTPTracing. +func NewHandler(f http.Handler, label string, opts Options) http.Handler { + if tracing == nil { + return f + } + return tracing.NewHandler(f, label, opts) +} diff --git a/vendor/github.com/open-policy-agent/opa/types/decode.go b/vendor/github.com/open-policy-agent/opa/v1/types/decode.go similarity index 99% rename from vendor/github.com/open-policy-agent/opa/types/decode.go rename to vendor/github.com/open-policy-agent/opa/v1/types/decode.go index a6bd9ea03..3fcc01664 100644 --- a/vendor/github.com/open-policy-agent/opa/types/decode.go +++ b/vendor/github.com/open-policy-agent/opa/v1/types/decode.go @@ -8,7 +8,7 @@ import ( "encoding/json" "fmt" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/util" ) const ( diff --git a/vendor/github.com/open-policy-agent/opa/types/types.go b/vendor/github.com/open-policy-agent/opa/v1/types/types.go similarity index 98% rename from vendor/github.com/open-policy-agent/opa/types/types.go rename to vendor/github.com/open-policy-agent/opa/v1/types/types.go index 2a050927d..070521087 100644 --- a/vendor/github.com/open-policy-agent/opa/types/types.go +++ b/vendor/github.com/open-policy-agent/opa/v1/types/types.go @@ -12,7 +12,7 @@ import ( "sort" "strings" - "github.com/open-policy-agent/opa/util" + "github.com/open-policy-agent/opa/v1/util" ) // Sprint returns the string representation of the type. @@ -675,7 +675,7 @@ func Arity(x Type) int { if !ok { return 0 } - return len(f.FuncArgs().Args) + return f.Arity() } // NewFunction returns a new Function object of the given argument and result types. @@ -723,6 +723,11 @@ func (t *Function) Args() []Type { return cpy } +// Arity returns the number of arguments in the function signature. +func (t *Function) Arity() int { + return len(t.args) +} + // Result returns the function's result type. func (t *Function) Result() Type { return unwrap(t.result) @@ -780,14 +785,15 @@ func (t *Function) Union(other *Function) *Function { return other } - a := t.Args() - b := other.Args() - if len(a) != len(b) { + if t.Arity() != other.Arity() { return nil } - aIsVariadic := t.FuncArgs().Variadic != nil - bIsVariadic := other.FuncArgs().Variadic != nil + tfa := t.FuncArgs() + ofa := other.FuncArgs() + + aIsVariadic := tfa.Variadic != nil + bIsVariadic := ofa.Variadic != nil if aIsVariadic && !bIsVariadic { return nil @@ -795,13 +801,16 @@ func (t *Function) Union(other *Function) *Function { return nil } + a := t.Args() + b := other.Args() + args := make([]Type, len(a)) for i := range a { args[i] = Or(a[i], b[i]) } result := NewFunction(args, Or(t.Result(), other.Result())) - result.variadic = Or(t.FuncArgs().Variadic, other.FuncArgs().Variadic) + result.variadic = Or(tfa.Variadic, ofa.Variadic) return result } diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/backoff.go b/vendor/github.com/open-policy-agent/opa/v1/util/backoff.go new file mode 100644 index 000000000..36d57f14e --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/backoff.go @@ -0,0 +1,53 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package util + +import ( + "math/rand" + "time" +) + +func init() { + // NOTE(sr): We don't need good random numbers here; it's used for jittering + // the backup timing a bit. But anyways, let's make it random enough; without + // a call to rand.Seed() we'd get the same stream of numbers for each program + // run. (Or not, if some other packages happens to seed the global randomness + // source.) + // Note(philipc): rand.Seed() was deprecated in Go 1.20, so we've switched to + // using the recommended rand.New(rand.NewSource(seed)) style. + rand.New(rand.NewSource(time.Now().UnixNano())) +} + +// DefaultBackoff returns a delay with an exponential backoff based on the +// number of retries. +func DefaultBackoff(base, maxNS float64, retries int) time.Duration { + return Backoff(base, maxNS, .2, 1.6, retries) +} + +// Backoff returns a delay with an exponential backoff based on the number of +// retries. Same algorithm used in gRPC. +func Backoff(base, maxNS, jitter, factor float64, retries int) time.Duration { + if retries == 0 { + return 0 + } + + backoff, maxNS := base, maxNS + for backoff < maxNS && retries > 0 { + backoff *= factor + retries-- + } + if backoff > maxNS { + backoff = maxNS + } + + // Randomize backoff delays so that if a cluster of requests start at + // the same time, they won't operate in lockstep. + backoff *= 1 + jitter*(rand.Float64()*2-1) + if backoff < 0 { + return 0 + } + + return time.Duration(backoff) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/close.go b/vendor/github.com/open-policy-agent/opa/v1/util/close.go new file mode 100644 index 000000000..c3c177557 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/close.go @@ -0,0 +1,22 @@ +// Copyright 2018 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package util + +import ( + "io" + "net/http" +) + +// Close reads the remaining bytes from the response and then closes it to +// ensure that the connection is freed. If the body is not read and closed, a +// leak can occur. +func Close(resp *http.Response) { + if resp != nil && resp.Body != nil { + if _, err := io.Copy(io.Discard, resp.Body); err != nil { + return + } + resp.Body.Close() + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/compare.go b/vendor/github.com/open-policy-agent/opa/v1/util/compare.go new file mode 100644 index 000000000..8ae775369 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/compare.go @@ -0,0 +1,184 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package util + +import ( + "encoding/json" + "fmt" + "math/big" + "sort" +) + +// Compare returns 0 if a equals b, -1 if a is less than b, and 1 if b is than a. +// +// For comparison between values of different types, the following ordering is used: +// nil < bool < int, float64 < string < []interface{} < map[string]interface{}. Slices and maps +// are compared recursively. If one slice or map is a subset of the other slice or map +// it is considered "less than". Nil is always equal to nil. +func Compare(a, b interface{}) int { + aSortOrder := sortOrder(a) + bSortOrder := sortOrder(b) + if aSortOrder < bSortOrder { + return -1 + } else if bSortOrder < aSortOrder { + return 1 + } + switch a := a.(type) { + case nil: + return 0 + case bool: + switch b := b.(type) { + case bool: + if a == b { + return 0 + } + if !a { + return -1 + } + return 1 + } + case json.Number: + switch b := b.(type) { + case json.Number: + return compareJSONNumber(a, b) + } + case int: + switch b := b.(type) { + case int: + if a == b { + return 0 + } else if a < b { + return -1 + } + return 1 + } + case float64: + switch b := b.(type) { + case float64: + if a == b { + return 0 + } else if a < b { + return -1 + } + return 1 + } + case string: + switch b := b.(type) { + case string: + if a == b { + return 0 + } else if a < b { + return -1 + } + return 1 + } + case []interface{}: + switch b := b.(type) { + case []interface{}: + bLen := len(b) + aLen := len(a) + minLen := aLen + if bLen < minLen { + minLen = bLen + } + for i := 0; i < minLen; i++ { + cmp := Compare(a[i], b[i]) + if cmp != 0 { + return cmp + } + } + if aLen == bLen { + return 0 + } else if aLen < bLen { + return -1 + } + return 1 + } + case map[string]interface{}: + switch b := b.(type) { + case map[string]interface{}: + var aKeys []string + for k := range a { + aKeys = append(aKeys, k) + } + var bKeys []string + for k := range b { + bKeys = append(bKeys, k) + } + sort.Strings(aKeys) + sort.Strings(bKeys) + aLen := len(aKeys) + bLen := len(bKeys) + minLen := aLen + if bLen < minLen { + minLen = bLen + } + for i := 0; i < minLen; i++ { + if aKeys[i] < bKeys[i] { + return -1 + } else if bKeys[i] < aKeys[i] { + return 1 + } + aVal := a[aKeys[i]] + bVal := b[bKeys[i]] + cmp := Compare(aVal, bVal) + if cmp != 0 { + return cmp + } + } + if aLen == bLen { + return 0 + } else if aLen < bLen { + return -1 + } + return 1 + } + } + + panic(fmt.Sprintf("illegal arguments of type %T and type %T", a, b)) +} + +const ( + nilSort = iota + boolSort = iota + numberSort = iota + stringSort = iota + arraySort = iota + objectSort = iota +) + +func compareJSONNumber(a, b json.Number) int { + bigA, ok := new(big.Float).SetString(string(a)) + if !ok { + panic("illegal value") + } + bigB, ok := new(big.Float).SetString(string(b)) + if !ok { + panic("illegal value") + } + return bigA.Cmp(bigB) +} + +func sortOrder(v interface{}) int { + switch v.(type) { + case nil: + return nilSort + case bool: + return boolSort + case json.Number: + return numberSort + case int: + return numberSort + case float64: + return numberSort + case string: + return stringSort + case []interface{}: + return arraySort + case map[string]interface{}: + return objectSort + } + panic(fmt.Sprintf("illegal argument of type %T", v)) +} diff --git a/vendor/github.com/open-policy-agent/opa/util/decoding/context.go b/vendor/github.com/open-policy-agent/opa/v1/util/decoding/context.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/util/decoding/context.go rename to vendor/github.com/open-policy-agent/opa/v1/util/decoding/context.go diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/doc.go b/vendor/github.com/open-policy-agent/opa/v1/util/doc.go new file mode 100644 index 000000000..900dff8c1 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/doc.go @@ -0,0 +1,6 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package util provides generic utilities used throughout the policy engine. +package util diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/enumflag.go b/vendor/github.com/open-policy-agent/opa/v1/util/enumflag.go new file mode 100644 index 000000000..4796f0269 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/enumflag.go @@ -0,0 +1,59 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package util + +import ( + "fmt" + "strings" +) + +// EnumFlag implements the pflag.Value interface to provide enumerated command +// line parameter values. +type EnumFlag struct { + defaultValue string + vs []string + i int +} + +// NewEnumFlag returns a new EnumFlag that has a defaultValue and vs enumerated +// values. +func NewEnumFlag(defaultValue string, vs []string) *EnumFlag { + f := &EnumFlag{ + i: -1, + vs: vs, + defaultValue: defaultValue, + } + return f +} + +// Type returns the valid enumeration values. +func (f *EnumFlag) Type() string { + return "{" + strings.Join(f.vs, ",") + "}" +} + +// String returns the EnumValue's value as string. +func (f *EnumFlag) String() string { + if f.i == -1 { + return f.defaultValue + } + return f.vs[f.i] +} + +// IsSet will return true if the EnumFlag has been set. +func (f *EnumFlag) IsSet() bool { + return f.i != -1 +} + +// Set sets the enum value. If s is not a valid enum value, an error is +// returned. +func (f *EnumFlag) Set(s string) error { + for i := range f.vs { + if f.vs[i] == s { + f.i = i + return nil + } + } + return fmt.Errorf("must be one of %v", f.Type()) +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/graph.go b/vendor/github.com/open-policy-agent/opa/v1/util/graph.go new file mode 100644 index 000000000..f0e824245 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/graph.go @@ -0,0 +1,90 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package util + +// Traversal defines a basic interface to perform traversals. +type Traversal interface { + + // Edges should return the neighbours of node "u". + Edges(u T) []T + + // Visited should return true if node "u" has already been visited in this + // traversal. If the same traversal is used multiple times, the state that + // tracks visited nodes should be reset. + Visited(u T) bool +} + +// Equals should return true if node "u" equals node "v". +type Equals func(u T, v T) bool + +// Iter should return true to indicate stop. +type Iter func(u T) bool + +// DFS performs a depth first traversal calling f for each node starting from u. +// If f returns true, traversal stops and DFS returns true. +func DFS(t Traversal, f Iter, u T) bool { + lifo := NewLIFO(u) + for lifo.Size() > 0 { + next, _ := lifo.Pop() + if t.Visited(next) { + continue + } + if f(next) { + return true + } + for _, v := range t.Edges(next) { + lifo.Push(v) + } + } + return false +} + +// BFS performs a breadth first traversal calling f for each node starting from +// u. If f returns true, traversal stops and BFS returns true. +func BFS(t Traversal, f Iter, u T) bool { + fifo := NewFIFO(u) + for fifo.Size() > 0 { + next, _ := fifo.Pop() + if t.Visited(next) { + continue + } + if f(next) { + return true + } + for _, v := range t.Edges(next) { + fifo.Push(v) + } + } + return false +} + +// DFSPath returns a path from node a to node z found by performing +// a depth first traversal. If no path is found, an empty slice is returned. +func DFSPath(t Traversal, eq Equals, a, z T) []T { + p := dfsRecursive(t, eq, a, z, []T{}) + for i := len(p)/2 - 1; i >= 0; i-- { + o := len(p) - i - 1 + p[i], p[o] = p[o], p[i] + } + return p +} + +func dfsRecursive(t Traversal, eq Equals, u, z T, path []T) []T { + if t.Visited(u) { + return path + } + for _, v := range t.Edges(u) { + if eq(v, z) { + path = append(path, z) + path = append(path, u) + return path + } + if p := dfsRecursive(t, eq, v, z, path); len(p) > 0 { + path = append(p, u) + return path + } + } + return path +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/hashmap.go b/vendor/github.com/open-policy-agent/opa/v1/util/hashmap.go new file mode 100644 index 000000000..8875a6323 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/hashmap.go @@ -0,0 +1,157 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package util + +import ( + "fmt" + "strings" +) + +// T is a concise way to refer to T. +type T interface{} + +type hashEntry struct { + k T + v T + next *hashEntry +} + +// HashMap represents a key/value map. +type HashMap struct { + eq func(T, T) bool + hash func(T) int + table map[int]*hashEntry + size int +} + +// NewHashMap returns a new empty HashMap. +func NewHashMap(eq func(T, T) bool, hash func(T) int) *HashMap { + return &HashMap{ + eq: eq, + hash: hash, + table: make(map[int]*hashEntry), + size: 0, + } +} + +// Copy returns a shallow copy of this HashMap. +func (h *HashMap) Copy() *HashMap { + cpy := NewHashMap(h.eq, h.hash) + h.Iter(func(k, v T) bool { + cpy.Put(k, v) + return false + }) + return cpy +} + +// Equal returns true if this HashMap equals the other HashMap. +// Two hash maps are equal if they contain the same key/value pairs. +func (h *HashMap) Equal(other *HashMap) bool { + if h.Len() != other.Len() { + return false + } + return !h.Iter(func(k, v T) bool { + ov, ok := other.Get(k) + if !ok { + return true + } + return !h.eq(v, ov) + }) +} + +// Get returns the value for k. +func (h *HashMap) Get(k T) (T, bool) { + hash := h.hash(k) + for entry := h.table[hash]; entry != nil; entry = entry.next { + if h.eq(entry.k, k) { + return entry.v, true + } + } + return nil, false +} + +// Delete removes the key k. +func (h *HashMap) Delete(k T) { + hash := h.hash(k) + var prev *hashEntry + for entry := h.table[hash]; entry != nil; entry = entry.next { + if h.eq(entry.k, k) { + if prev != nil { + prev.next = entry.next + } else { + h.table[hash] = entry.next + } + h.size-- + return + } + prev = entry + } +} + +// Hash returns the hash code for this hash map. +func (h *HashMap) Hash() int { + var hash int + h.Iter(func(k, v T) bool { + hash += h.hash(k) + h.hash(v) + return false + }) + return hash +} + +// Iter invokes the iter function for each element in the HashMap. +// If the iter function returns true, iteration stops and the return value is true. +// If the iter function never returns true, iteration proceeds through all elements +// and the return value is false. +func (h *HashMap) Iter(iter func(T, T) bool) bool { + for _, entry := range h.table { + for ; entry != nil; entry = entry.next { + if iter(entry.k, entry.v) { + return true + } + } + } + return false +} + +// Len returns the current size of this HashMap. +func (h *HashMap) Len() int { + return h.size +} + +// Put inserts a key/value pair into this HashMap. If the key is already present, the existing +// value is overwritten. +func (h *HashMap) Put(k T, v T) { + hash := h.hash(k) + head := h.table[hash] + for entry := head; entry != nil; entry = entry.next { + if h.eq(entry.k, k) { + entry.v = v + return + } + } + h.table[hash] = &hashEntry{k: k, v: v, next: head} + h.size++ +} + +func (h *HashMap) String() string { + var buf []string + h.Iter(func(k T, v T) bool { + buf = append(buf, fmt.Sprintf("%v: %v", k, v)) + return false + }) + return "{" + strings.Join(buf, ", ") + "}" +} + +// Update returns a new HashMap with elements from the other HashMap put into this HashMap. +// If the other HashMap contains elements with the same key as this HashMap, the value +// from the other HashMap overwrites the value from this HashMap. +func (h *HashMap) Update(other *HashMap) *HashMap { + updated := h.Copy() + other.Iter(func(k, v T) bool { + updated.Put(k, v) + return false + }) + return updated +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/json.go b/vendor/github.com/open-policy-agent/opa/v1/util/json.go new file mode 100644 index 000000000..5a4e460b6 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/json.go @@ -0,0 +1,133 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package util + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "reflect" + + "sigs.k8s.io/yaml" + + "github.com/open-policy-agent/opa/v1/loader/extension" +) + +// UnmarshalJSON parses the JSON encoded data and stores the result in the value +// pointed to by x. +// +// This function is intended to be used in place of the standard json.Marshal +// function when json.Number is required. +func UnmarshalJSON(bs []byte, x interface{}) error { + return unmarshalJSON(bs, x, true) +} + +func unmarshalJSON(bs []byte, x interface{}, ext bool) error { + buf := bytes.NewBuffer(bs) + decoder := NewJSONDecoder(buf) + if err := decoder.Decode(x); err != nil { + if handler := extension.FindExtension(".json"); handler != nil && ext { + return handler(bs, x) + } + return err + } + + // Since decoder.Decode validates only the first json structure in bytes, + // check if decoder has more bytes to consume to validate whole input bytes. + tok, err := decoder.Token() + if tok != nil { + return fmt.Errorf("error: invalid character '%s' after top-level value", tok) + } + if err != nil && err != io.EOF { + return err + } + return nil +} + +// NewJSONDecoder returns a new decoder that reads from r. +// +// This function is intended to be used in place of the standard json.NewDecoder +// when json.Number is required. +func NewJSONDecoder(r io.Reader) *json.Decoder { + decoder := json.NewDecoder(r) + decoder.UseNumber() + return decoder +} + +// MustUnmarshalJSON parse the JSON encoded data and returns the result. +// +// If the data cannot be decoded, this function will panic. This function is for +// test purposes. +func MustUnmarshalJSON(bs []byte) interface{} { + var x interface{} + if err := UnmarshalJSON(bs, &x); err != nil { + panic(err) + } + return x +} + +// MustMarshalJSON returns the JSON encoding of x +// +// If the data cannot be encoded, this function will panic. This function is for +// test purposes. +func MustMarshalJSON(x interface{}) []byte { + bs, err := json.Marshal(x) + if err != nil { + panic(err) + } + return bs +} + +// RoundTrip encodes to JSON, and decodes the result again. +// +// Thereby, it is converting its argument to the representation expected by +// rego.Input and inmem's Write operations. Works with both references and +// values. +func RoundTrip(x *interface{}) error { + bs, err := json.Marshal(x) + if err != nil { + return err + } + return UnmarshalJSON(bs, x) +} + +// Reference returns a pointer to its argument unless the argument already is +// a pointer. If the argument is **t, or ***t, etc, it will return *t. +// +// Used for preparing Go types (including pointers to structs) into values to be +// put through util.RoundTrip(). +func Reference(x interface{}) *interface{} { + var y interface{} + rv := reflect.ValueOf(x) + if rv.Kind() == reflect.Ptr { + return Reference(rv.Elem().Interface()) + } + if rv.Kind() != reflect.Invalid { + y = rv.Interface() + return &y + } + return &x +} + +// Unmarshal decodes a YAML, JSON or JSON extension value into the specified type. +func Unmarshal(bs []byte, v interface{}) error { + if len(bs) > 2 && bs[0] == 0xef && bs[1] == 0xbb && bs[2] == 0xbf { + bs = bs[3:] // Strip UTF-8 BOM, see https://www.rfc-editor.org/rfc/rfc8259#section-8.1 + } + + if json.Valid(bs) { + return unmarshalJSON(bs, v, false) + } + nbs, err := yaml.YAMLToJSON(bs) + if err == nil { + return unmarshalJSON(nbs, v, false) + } + // not json or yaml: try extensions + if handler := extension.FindExtension(".json"); handler != nil { + return handler(bs, v) + } + return err +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/maps.go b/vendor/github.com/open-policy-agent/opa/v1/util/maps.go new file mode 100644 index 000000000..d943b4d0a --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/maps.go @@ -0,0 +1,10 @@ +package util + +// Values returns a slice of values from any map. Copied from golang.org/x/exp/maps. +func Values[M ~map[K]V, K comparable, V any](m M) []V { + r := make([]V, 0, len(m)) + for _, v := range m { + r = append(r, v) + } + return r +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/performance.go b/vendor/github.com/open-policy-agent/opa/v1/util/performance.go new file mode 100644 index 000000000..03dc7d060 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/performance.go @@ -0,0 +1,24 @@ +package util + +import "slices" + +// NewPtrSlice returns a slice of pointers to T with length n, +// with only 2 allocations performed no matter the size of n. +// See: +// https://gist.github.com/CAFxX/e96e8a5c3841d152f16d266a1fe7f8bd#slices-of-pointers +func NewPtrSlice[T any](n int) []*T { + return GrowPtrSlice[T](nil, n) +} + +// GrowPtrSlice appends n elements to the slice, each pointing to +// a newly-allocated T. The resulting slice has length equal to len(s)+n. +// +// It performs at most 2 allocations, regardless of n. +func GrowPtrSlice[T any](s []*T, n int) []*T { + s = slices.Grow(s, n) + p := make([]T, n) + for i := 0; i < n; i++ { + s = append(s, &p[i]) + } + return s +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/queue.go b/vendor/github.com/open-policy-agent/opa/v1/util/queue.go new file mode 100644 index 000000000..63a2ffc16 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/queue.go @@ -0,0 +1,113 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package util + +// LIFO represents a simple LIFO queue. +type LIFO struct { + top *queueNode + size int +} + +type queueNode struct { + v T + next *queueNode +} + +// NewLIFO returns a new LIFO queue containing elements ts starting with the +// left-most argument at the bottom. +func NewLIFO(ts ...T) *LIFO { + s := &LIFO{} + for i := range ts { + s.Push(ts[i]) + } + return s +} + +// Push adds a new element onto the LIFO. +func (s *LIFO) Push(t T) { + node := &queueNode{v: t, next: s.top} + s.top = node + s.size++ +} + +// Peek returns the top of the LIFO. If LIFO is empty, returns nil, false. +func (s *LIFO) Peek() (T, bool) { + if s.top == nil { + return nil, false + } + return s.top.v, true +} + +// Pop returns the top of the LIFO and removes it. If LIFO is empty returns +// nil, false. +func (s *LIFO) Pop() (T, bool) { + if s.top == nil { + return nil, false + } + node := s.top + s.top = node.next + s.size-- + return node.v, true +} + +// Size returns the size of the LIFO. +func (s *LIFO) Size() int { + return s.size +} + +// FIFO represents a simple FIFO queue. +type FIFO struct { + front *queueNode + back *queueNode + size int +} + +// NewFIFO returns a new FIFO queue containing elements ts starting with the +// left-most argument at the front. +func NewFIFO(ts ...T) *FIFO { + s := &FIFO{} + for i := range ts { + s.Push(ts[i]) + } + return s +} + +// Push adds a new element onto the LIFO. +func (s *FIFO) Push(t T) { + node := &queueNode{v: t, next: nil} + if s.front == nil { + s.front = node + s.back = node + } else { + s.back.next = node + s.back = node + } + s.size++ +} + +// Peek returns the top of the LIFO. If LIFO is empty, returns nil, false. +func (s *FIFO) Peek() (T, bool) { + if s.front == nil { + return nil, false + } + return s.front.v, true +} + +// Pop returns the top of the LIFO and removes it. If LIFO is empty returns +// nil, false. +func (s *FIFO) Pop() (T, bool) { + if s.front == nil { + return nil, false + } + node := s.front + s.front = node.next + s.size-- + return node.v, true +} + +// Size returns the size of the LIFO. +func (s *FIFO) Size() int { + return s.size +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/read_gzip_body.go b/vendor/github.com/open-policy-agent/opa/v1/util/read_gzip_body.go new file mode 100644 index 000000000..74bca7263 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/read_gzip_body.go @@ -0,0 +1,81 @@ +package util + +import ( + "bytes" + "compress/gzip" + "encoding/binary" + "fmt" + "io" + "net/http" + "strings" + "sync" + + "github.com/open-policy-agent/opa/v1/util/decoding" +) + +var gzipReaderPool = sync.Pool{ + New: func() interface{} { + reader := new(gzip.Reader) + return reader + }, +} + +// Note(philipc): Originally taken from server/server.go +// The DecodingLimitHandler handles validating that the gzip payload is within the +// allowed max size limit. Thus, in the event of a forged payload size trailer, +// the worst that can happen is that we waste memory up to the allowed max gzip +// payload size, but not an unbounded amount of memory, as was potentially +// possible before. +func ReadMaybeCompressedBody(r *http.Request) ([]byte, error) { + var content *bytes.Buffer + // Note(philipc): If the request body is of unknown length (such as what + // happens when 'Transfer-Encoding: chunked' is set), we have to do an + // incremental read of the body. In this case, we can't be too clever, we + // just do the best we can with whatever is streamed over to us. + // Fetch gzip payload size limit from request context. + if maxLength, ok := decoding.GetServerDecodingMaxLen(r.Context()); ok { + bs, err := io.ReadAll(io.LimitReader(r.Body, maxLength)) + if err != nil { + return bs, err + } + content = bytes.NewBuffer(bs) + } else { + // Read content from the request body into a buffer of known size. + content = bytes.NewBuffer(make([]byte, 0, r.ContentLength)) + if _, err := io.CopyN(content, r.Body, r.ContentLength); err != nil { + return content.Bytes(), err + } + } + + // Decompress gzip content by reading from the buffer. + if strings.Contains(r.Header.Get("Content-Encoding"), "gzip") { + // Fetch gzip payload size limit from request context. + gzipMaxLength, _ := decoding.GetServerDecodingGzipMaxLen(r.Context()) + + // Note(philipc): The last 4 bytes of a well-formed gzip blob will + // always be a little-endian uint32, representing the decompressed + // content size, modulo 2^32. We validate that the size is safe, + // earlier in DecodingLimitHandler. + sizeTrailerField := binary.LittleEndian.Uint32(content.Bytes()[content.Len()-4:]) + if sizeTrailerField > uint32(gzipMaxLength) { + return content.Bytes(), fmt.Errorf("gzip payload too large") + } + // Pull a gzip decompressor from the pool, and assign it to the current + // buffer, using Reset(). Later, return it back to the pool for another + // request to use. + gzReader := gzipReaderPool.Get().(*gzip.Reader) + if err := gzReader.Reset(content); err != nil { + return nil, err + } + defer gzReader.Close() + defer gzipReaderPool.Put(gzReader) + decompressedContent := bytes.NewBuffer(make([]byte, 0, sizeTrailerField)) + if _, err := io.CopyN(decompressedContent, gzReader, int64(sizeTrailerField)); err != nil { + return decompressedContent.Bytes(), err + } + return decompressedContent.Bytes(), nil + } + + // Request was not compressed; return the content bytes. + return content.Bytes(), nil +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/time.go b/vendor/github.com/open-policy-agent/opa/v1/util/time.go new file mode 100644 index 000000000..93ef03939 --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/time.go @@ -0,0 +1,48 @@ +package util + +import "time" + +// TimerWithCancel exists because of memory leaks when using +// time.After in select statements. Instead, we now manually create timers, +// wait on them, and manually free them. +// +// See this for more details: +// https://www.arangodb.com/2020/09/a-story-of-a-memory-leak-in-go-how-to-properly-use-time-after/ +// +// Note: This issue is fixed in Go 1.23, but this fix helps us until then. +// +// Warning: the cancel cannot be done concurrent to reading, everything should +// work in the same goroutine. +// +// Example: +// +// for retries := 0; true; retries++ { +// +// ...main logic... +// +// timer, cancel := utils.TimerWithCancel(utils.Backoff(retries)) +// select { +// case <-ctx.Done(): +// cancel() +// return ctx.Err() +// case <-timer.C: +// continue +// } +// } +func TimerWithCancel(delay time.Duration) (*time.Timer, func()) { + timer := time.NewTimer(delay) + + return timer, func() { + // Note: The Stop function returns: + // - true: if the timer is active. (no draining required) + // - false: if the timer was already stopped or fired/expired. + // In this case the channel should be drained to prevent memory + // leaks only if it is not empty. + // This operation is safe only if the cancel function is + // used in same goroutine. Concurrent reading or canceling may + // cause deadlock. + if !timer.Stop() && len(timer.C) > 0 { + <-timer.C + } + } +} diff --git a/vendor/github.com/open-policy-agent/opa/v1/util/wait.go b/vendor/github.com/open-policy-agent/opa/v1/util/wait.go new file mode 100644 index 000000000..b70ab6fcf --- /dev/null +++ b/vendor/github.com/open-policy-agent/opa/v1/util/wait.go @@ -0,0 +1,34 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package util + +import ( + "fmt" + "time" +) + +// WaitFunc will call passed function at an interval and return nil +// as soon this function returns true. +// If timeout is reached before the passed in function returns true +// an error is returned. +func WaitFunc(fun func() bool, interval, timeout time.Duration) error { + if fun() { + return nil + } + ticker := time.NewTicker(interval) + timer := time.NewTimer(timeout) + defer ticker.Stop() + defer timer.Stop() + for { + select { + case <-timer.C: + return fmt.Errorf("timeout") + case <-ticker.C: + if fun() { + return nil + } + } + } +} diff --git a/vendor/github.com/open-policy-agent/opa/version/version.go b/vendor/github.com/open-policy-agent/opa/v1/version/version.go similarity index 97% rename from vendor/github.com/open-policy-agent/opa/version/version.go rename to vendor/github.com/open-policy-agent/opa/v1/version/version.go index 862556bce..3b550b81f 100644 --- a/vendor/github.com/open-policy-agent/opa/version/version.go +++ b/vendor/github.com/open-policy-agent/opa/v1/version/version.go @@ -11,7 +11,7 @@ import ( ) // Version is the canonical version of OPA. -var Version = "0.70.0" +var Version = "1.0.0" // GoVersion is the version of Go this was built with var GoVersion = runtime.Version() diff --git a/vendor/github.com/open-policy-agent/opa/version/wasm.go b/vendor/github.com/open-policy-agent/opa/v1/version/wasm.go similarity index 100% rename from vendor/github.com/open-policy-agent/opa/version/wasm.go rename to vendor/github.com/open-policy-agent/opa/v1/version/wasm.go diff --git a/vendor/go.opentelemetry.io/auto/sdk/CONTRIBUTING.md b/vendor/go.opentelemetry.io/auto/sdk/CONTRIBUTING.md new file mode 100644 index 000000000..773c9b643 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/CONTRIBUTING.md @@ -0,0 +1,27 @@ +# Contributing to go.opentelemetry.io/auto/sdk + +The `go.opentelemetry.io/auto/sdk` module is a purpose built OpenTelemetry SDK. +It is designed to be: + +0. An OpenTelemetry compliant SDK +1. Instrumented by auto-instrumentation (serializable into OTLP JSON) +2. Lightweight +3. User-friendly + +These design choices are listed in the order of their importance. + +The primary design goal of this module is to be an OpenTelemetry SDK. +This means that it needs to implement the Go APIs found in `go.opentelemetry.io/otel`. + +Having met the requirement of SDK compliance, this module needs to provide code that the `go.opentelemetry.io/auto` module can instrument. +The chosen approach to meet this goal is to ensure the telemetry from the SDK is serializable into JSON encoded OTLP. +This ensures then that the serialized form is compatible with other OpenTelemetry systems, and the auto-instrumentation can use these systems to deserialize any telemetry it is sent. + +Outside of these first two goals, the intended use becomes relevant. +This package is intended to be used in the `go.opentelemetry.io/otel` global API as a default when the auto-instrumentation is running. +Because of this, this package needs to not add unnecessary dependencies to that API. +Ideally, it adds none. +It also needs to operate efficiently. + +Finally, this module is designed to be user-friendly to Go development. +It hides complexity in order to provide simpler APIs when the previous goals can all still be met. diff --git a/vendor/go.opentelemetry.io/auto/sdk/LICENSE b/vendor/go.opentelemetry.io/auto/sdk/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/go.opentelemetry.io/auto/sdk/VERSIONING.md b/vendor/go.opentelemetry.io/auto/sdk/VERSIONING.md new file mode 100644 index 000000000..088d19a6c --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/VERSIONING.md @@ -0,0 +1,15 @@ +# Versioning + +This document describes the versioning policy for this module. +This policy is designed so the following goals can be achieved. + +**Users are provided a codebase of value that is stable and secure.** + +## Policy + +* Versioning of this module will be idiomatic of a Go project using [Go modules](https://github.com/golang/go/wiki/Modules). + * [Semantic import versioning](https://github.com/golang/go/wiki/Modules#semantic-import-versioning) will be used. + * Versions will comply with [semver 2.0](https://semver.org/spec/v2.0.0.html). + * Any `v2` or higher version of this module will be included as a `/vN` at the end of the module path used in `go.mod` files and in the package import path. + +* GitHub releases will be made for all releases. diff --git a/vendor/go.opentelemetry.io/auto/sdk/doc.go b/vendor/go.opentelemetry.io/auto/sdk/doc.go new file mode 100644 index 000000000..ad73d8cb9 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/doc.go @@ -0,0 +1,14 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +/* +Package sdk provides an auto-instrumentable OpenTelemetry SDK. + +An [go.opentelemetry.io/auto.Instrumentation] can be configured to target the +process running this SDK. In that case, all telemetry the SDK produces will be +processed and handled by that [go.opentelemetry.io/auto.Instrumentation]. + +By default, if there is no [go.opentelemetry.io/auto.Instrumentation] set to +auto-instrument the SDK, the SDK will not generate any telemetry. +*/ +package sdk diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/attr.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/attr.go new file mode 100644 index 000000000..af6ef171f --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/attr.go @@ -0,0 +1,58 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +// Attr is a key-value pair. +type Attr struct { + Key string `json:"key,omitempty"` + Value Value `json:"value,omitempty"` +} + +// String returns an Attr for a string value. +func String(key, value string) Attr { + return Attr{key, StringValue(value)} +} + +// Int64 returns an Attr for an int64 value. +func Int64(key string, value int64) Attr { + return Attr{key, Int64Value(value)} +} + +// Int returns an Attr for an int value. +func Int(key string, value int) Attr { + return Int64(key, int64(value)) +} + +// Float64 returns an Attr for a float64 value. +func Float64(key string, value float64) Attr { + return Attr{key, Float64Value(value)} +} + +// Bool returns an Attr for a bool value. +func Bool(key string, value bool) Attr { + return Attr{key, BoolValue(value)} +} + +// Bytes returns an Attr for a []byte value. +// The passed slice must not be changed after it is passed. +func Bytes(key string, value []byte) Attr { + return Attr{key, BytesValue(value)} +} + +// Slice returns an Attr for a []Value value. +// The passed slice must not be changed after it is passed. +func Slice(key string, value ...Value) Attr { + return Attr{key, SliceValue(value...)} +} + +// Map returns an Attr for a map value. +// The passed slice must not be changed after it is passed. +func Map(key string, value ...Attr) Attr { + return Attr{key, MapValue(value...)} +} + +// Equal returns if a is equal to b. +func (a Attr) Equal(b Attr) bool { + return a.Key == b.Key && a.Value.Equal(b.Value) +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/doc.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/doc.go new file mode 100644 index 000000000..949e2165c --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/doc.go @@ -0,0 +1,8 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +/* +Package telemetry provides a lightweight representations of OpenTelemetry +telemetry that is compatible with the OTLP JSON protobuf encoding. +*/ +package telemetry diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/id.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/id.go new file mode 100644 index 000000000..e854d7e84 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/id.go @@ -0,0 +1,103 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "encoding/hex" + "errors" + "fmt" +) + +const ( + traceIDSize = 16 + spanIDSize = 8 +) + +// TraceID is a custom data type that is used for all trace IDs. +type TraceID [traceIDSize]byte + +// String returns the hex string representation form of a TraceID. +func (tid TraceID) String() string { + return hex.EncodeToString(tid[:]) +} + +// IsEmpty returns false if id contains at least one non-zero byte. +func (tid TraceID) IsEmpty() bool { + return tid == [traceIDSize]byte{} +} + +// MarshalJSON converts the trace ID into a hex string enclosed in quotes. +func (tid TraceID) MarshalJSON() ([]byte, error) { + if tid.IsEmpty() { + return []byte(`""`), nil + } + return marshalJSON(tid[:]) +} + +// UnmarshalJSON inflates the trace ID from hex string, possibly enclosed in +// quotes. +func (tid *TraceID) UnmarshalJSON(data []byte) error { + *tid = [traceIDSize]byte{} + return unmarshalJSON(tid[:], data) +} + +// SpanID is a custom data type that is used for all span IDs. +type SpanID [spanIDSize]byte + +// String returns the hex string representation form of a SpanID. +func (sid SpanID) String() string { + return hex.EncodeToString(sid[:]) +} + +// IsEmpty returns true if the span ID contains at least one non-zero byte. +func (sid SpanID) IsEmpty() bool { + return sid == [spanIDSize]byte{} +} + +// MarshalJSON converts span ID into a hex string enclosed in quotes. +func (sid SpanID) MarshalJSON() ([]byte, error) { + if sid.IsEmpty() { + return []byte(`""`), nil + } + return marshalJSON(sid[:]) +} + +// UnmarshalJSON decodes span ID from hex string, possibly enclosed in quotes. +func (sid *SpanID) UnmarshalJSON(data []byte) error { + *sid = [spanIDSize]byte{} + return unmarshalJSON(sid[:], data) +} + +// marshalJSON converts id into a hex string enclosed in quotes. +func marshalJSON(id []byte) ([]byte, error) { + // Plus 2 quote chars at the start and end. + hexLen := hex.EncodedLen(len(id)) + 2 + + b := make([]byte, hexLen) + hex.Encode(b[1:hexLen-1], id) + b[0], b[hexLen-1] = '"', '"' + + return b, nil +} + +// unmarshalJSON inflates trace id from hex string, possibly enclosed in quotes. +func unmarshalJSON(dst []byte, src []byte) error { + if l := len(src); l >= 2 && src[0] == '"' && src[l-1] == '"' { + src = src[1 : l-1] + } + nLen := len(src) + if nLen == 0 { + return nil + } + + if len(dst) != hex.DecodedLen(nLen) { + return errors.New("invalid length for ID") + } + + _, err := hex.Decode(dst, src) + if err != nil { + return fmt.Errorf("cannot unmarshal ID from string '%s': %w", string(src), err) + } + return nil +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/number.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/number.go new file mode 100644 index 000000000..29e629d66 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/number.go @@ -0,0 +1,67 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "encoding/json" + "strconv" +) + +// protoInt64 represents the protobuf encoding of integers which can be either +// strings or integers. +type protoInt64 int64 + +// Int64 returns the protoInt64 as an int64. +func (i *protoInt64) Int64() int64 { return int64(*i) } + +// UnmarshalJSON decodes both strings and integers. +func (i *protoInt64) UnmarshalJSON(data []byte) error { + if data[0] == '"' { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + parsedInt, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return err + } + *i = protoInt64(parsedInt) + } else { + var parsedInt int64 + if err := json.Unmarshal(data, &parsedInt); err != nil { + return err + } + *i = protoInt64(parsedInt) + } + return nil +} + +// protoUint64 represents the protobuf encoding of integers which can be either +// strings or integers. +type protoUint64 uint64 + +// Int64 returns the protoUint64 as a uint64. +func (i *protoUint64) Uint64() uint64 { return uint64(*i) } + +// UnmarshalJSON decodes both strings and integers. +func (i *protoUint64) UnmarshalJSON(data []byte) error { + if data[0] == '"' { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + parsedUint, err := strconv.ParseUint(str, 10, 64) + if err != nil { + return err + } + *i = protoUint64(parsedUint) + } else { + var parsedUint uint64 + if err := json.Unmarshal(data, &parsedUint); err != nil { + return err + } + *i = protoUint64(parsedUint) + } + return nil +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/resource.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/resource.go new file mode 100644 index 000000000..cecad8bae --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/resource.go @@ -0,0 +1,66 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" +) + +// Resource information. +type Resource struct { + // Attrs are the set of attributes that describe the resource. Attribute + // keys MUST be unique (it is not allowed to have more than one attribute + // with the same key). + Attrs []Attr `json:"attributes,omitempty"` + // DroppedAttrs is the number of dropped attributes. If the value + // is 0, then no attributes were dropped. + DroppedAttrs uint32 `json:"droppedAttributesCount,omitempty"` +} + +// UnmarshalJSON decodes the OTLP formatted JSON contained in data into r. +func (r *Resource) UnmarshalJSON(data []byte) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + + t, err := decoder.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("invalid Resource type") + } + + for decoder.More() { + keyIface, err := decoder.Token() + if err != nil { + if errors.Is(err, io.EOF) { + // Empty. + return nil + } + return err + } + + key, ok := keyIface.(string) + if !ok { + return fmt.Errorf("invalid Resource field: %#v", keyIface) + } + + switch key { + case "attributes": + err = decoder.Decode(&r.Attrs) + case "droppedAttributesCount", "dropped_attributes_count": + err = decoder.Decode(&r.DroppedAttrs) + default: + // Skip unknown. + } + + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/scope.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/scope.go new file mode 100644 index 000000000..b6f2e28d4 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/scope.go @@ -0,0 +1,67 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" +) + +// Scope is the identifying values of the instrumentation scope. +type Scope struct { + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` + Attrs []Attr `json:"attributes,omitempty"` + DroppedAttrs uint32 `json:"droppedAttributesCount,omitempty"` +} + +// UnmarshalJSON decodes the OTLP formatted JSON contained in data into r. +func (s *Scope) UnmarshalJSON(data []byte) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + + t, err := decoder.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("invalid Scope type") + } + + for decoder.More() { + keyIface, err := decoder.Token() + if err != nil { + if errors.Is(err, io.EOF) { + // Empty. + return nil + } + return err + } + + key, ok := keyIface.(string) + if !ok { + return fmt.Errorf("invalid Scope field: %#v", keyIface) + } + + switch key { + case "name": + err = decoder.Decode(&s.Name) + case "version": + err = decoder.Decode(&s.Version) + case "attributes": + err = decoder.Decode(&s.Attrs) + case "droppedAttributesCount", "dropped_attributes_count": + err = decoder.Decode(&s.DroppedAttrs) + default: + // Skip unknown. + } + + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/span.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/span.go new file mode 100644 index 000000000..a13a6b733 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/span.go @@ -0,0 +1,456 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "time" +) + +// A Span represents a single operation performed by a single component of the +// system. +type Span struct { + // A unique identifier for a trace. All spans from the same trace share + // the same `trace_id`. The ID is a 16-byte array. An ID with all zeroes OR + // of length other than 16 bytes is considered invalid (empty string in OTLP/JSON + // is zero-length and thus is also invalid). + // + // This field is required. + TraceID TraceID `json:"traceId,omitempty"` + // A unique identifier for a span within a trace, assigned when the span + // is created. The ID is an 8-byte array. An ID with all zeroes OR of length + // other than 8 bytes is considered invalid (empty string in OTLP/JSON + // is zero-length and thus is also invalid). + // + // This field is required. + SpanID SpanID `json:"spanId,omitempty"` + // trace_state conveys information about request position in multiple distributed tracing graphs. + // It is a trace_state in w3c-trace-context format: https://www.w3.org/TR/trace-context/#tracestate-header + // See also https://github.com/w3c/distributed-tracing for more details about this field. + TraceState string `json:"traceState,omitempty"` + // The `span_id` of this span's parent span. If this is a root span, then this + // field must be empty. The ID is an 8-byte array. + ParentSpanID SpanID `json:"parentSpanId,omitempty"` + // Flags, a bit field. + // + // Bits 0-7 (8 least significant bits) are the trace flags as defined in W3C Trace + // Context specification. To read the 8-bit W3C trace flag, use + // `flags & SPAN_FLAGS_TRACE_FLAGS_MASK`. + // + // See https://www.w3.org/TR/trace-context-2/#trace-flags for the flag definitions. + // + // Bits 8 and 9 represent the 3 states of whether a span's parent + // is remote. The states are (unknown, is not remote, is remote). + // To read whether the value is known, use `(flags & SPAN_FLAGS_CONTEXT_HAS_IS_REMOTE_MASK) != 0`. + // To read whether the span is remote, use `(flags & SPAN_FLAGS_CONTEXT_IS_REMOTE_MASK) != 0`. + // + // When creating span messages, if the message is logically forwarded from another source + // with an equivalent flags fields (i.e., usually another OTLP span message), the field SHOULD + // be copied as-is. If creating from a source that does not have an equivalent flags field + // (such as a runtime representation of an OpenTelemetry span), the high 22 bits MUST + // be set to zero. + // Readers MUST NOT assume that bits 10-31 (22 most significant bits) will be zero. + // + // [Optional]. + Flags uint32 `json:"flags,omitempty"` + // A description of the span's operation. + // + // For example, the name can be a qualified method name or a file name + // and a line number where the operation is called. A best practice is to use + // the same display name at the same call point in an application. + // This makes it easier to correlate spans in different traces. + // + // This field is semantically required to be set to non-empty string. + // Empty value is equivalent to an unknown span name. + // + // This field is required. + Name string `json:"name"` + // Distinguishes between spans generated in a particular context. For example, + // two spans with the same name may be distinguished using `CLIENT` (caller) + // and `SERVER` (callee) to identify queueing latency associated with the span. + Kind SpanKind `json:"kind,omitempty"` + // start_time_unix_nano is the start time of the span. On the client side, this is the time + // kept by the local machine where the span execution starts. On the server side, this + // is the time when the server's application handler starts running. + // Value is UNIX Epoch time in nanoseconds since 00:00:00 UTC on 1 January 1970. + // + // This field is semantically required and it is expected that end_time >= start_time. + StartTime time.Time `json:"startTimeUnixNano,omitempty"` + // end_time_unix_nano is the end time of the span. On the client side, this is the time + // kept by the local machine where the span execution ends. On the server side, this + // is the time when the server application handler stops running. + // Value is UNIX Epoch time in nanoseconds since 00:00:00 UTC on 1 January 1970. + // + // This field is semantically required and it is expected that end_time >= start_time. + EndTime time.Time `json:"endTimeUnixNano,omitempty"` + // attributes is a collection of key/value pairs. Note, global attributes + // like server name can be set using the resource API. Examples of attributes: + // + // "/http/user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_2) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36" + // "/http/server_latency": 300 + // "example.com/myattribute": true + // "example.com/score": 10.239 + // + // The OpenTelemetry API specification further restricts the allowed value types: + // https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/common/README.md#attribute + // Attribute keys MUST be unique (it is not allowed to have more than one + // attribute with the same key). + Attrs []Attr `json:"attributes,omitempty"` + // dropped_attributes_count is the number of attributes that were discarded. Attributes + // can be discarded because their keys are too long or because there are too many + // attributes. If this value is 0, then no attributes were dropped. + DroppedAttrs uint32 `json:"droppedAttributesCount,omitempty"` + // events is a collection of Event items. + Events []*SpanEvent `json:"events,omitempty"` + // dropped_events_count is the number of dropped events. If the value is 0, then no + // events were dropped. + DroppedEvents uint32 `json:"droppedEventsCount,omitempty"` + // links is a collection of Links, which are references from this span to a span + // in the same or different trace. + Links []*SpanLink `json:"links,omitempty"` + // dropped_links_count is the number of dropped links after the maximum size was + // enforced. If this value is 0, then no links were dropped. + DroppedLinks uint32 `json:"droppedLinksCount,omitempty"` + // An optional final status for this span. Semantically when Status isn't set, it means + // span's status code is unset, i.e. assume STATUS_CODE_UNSET (code = 0). + Status *Status `json:"status,omitempty"` +} + +// MarshalJSON encodes s into OTLP formatted JSON. +func (s Span) MarshalJSON() ([]byte, error) { + startT := s.StartTime.UnixNano() + if s.StartTime.IsZero() || startT < 0 { + startT = 0 + } + + endT := s.EndTime.UnixNano() + if s.EndTime.IsZero() || endT < 0 { + endT = 0 + } + + // Override non-empty default SpanID marshal and omitempty. + var parentSpanId string + if !s.ParentSpanID.IsEmpty() { + b := make([]byte, hex.EncodedLen(spanIDSize)) + hex.Encode(b, s.ParentSpanID[:]) + parentSpanId = string(b) + } + + type Alias Span + return json.Marshal(struct { + Alias + ParentSpanID string `json:"parentSpanId,omitempty"` + StartTime uint64 `json:"startTimeUnixNano,omitempty"` + EndTime uint64 `json:"endTimeUnixNano,omitempty"` + }{ + Alias: Alias(s), + ParentSpanID: parentSpanId, + StartTime: uint64(startT), + EndTime: uint64(endT), + }) +} + +// UnmarshalJSON decodes the OTLP formatted JSON contained in data into s. +func (s *Span) UnmarshalJSON(data []byte) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + + t, err := decoder.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("invalid Span type") + } + + for decoder.More() { + keyIface, err := decoder.Token() + if err != nil { + if errors.Is(err, io.EOF) { + // Empty. + return nil + } + return err + } + + key, ok := keyIface.(string) + if !ok { + return fmt.Errorf("invalid Span field: %#v", keyIface) + } + + switch key { + case "traceId", "trace_id": + err = decoder.Decode(&s.TraceID) + case "spanId", "span_id": + err = decoder.Decode(&s.SpanID) + case "traceState", "trace_state": + err = decoder.Decode(&s.TraceState) + case "parentSpanId", "parent_span_id": + err = decoder.Decode(&s.ParentSpanID) + case "flags": + err = decoder.Decode(&s.Flags) + case "name": + err = decoder.Decode(&s.Name) + case "kind": + err = decoder.Decode(&s.Kind) + case "startTimeUnixNano", "start_time_unix_nano": + var val protoUint64 + err = decoder.Decode(&val) + s.StartTime = time.Unix(0, int64(val.Uint64())) + case "endTimeUnixNano", "end_time_unix_nano": + var val protoUint64 + err = decoder.Decode(&val) + s.EndTime = time.Unix(0, int64(val.Uint64())) + case "attributes": + err = decoder.Decode(&s.Attrs) + case "droppedAttributesCount", "dropped_attributes_count": + err = decoder.Decode(&s.DroppedAttrs) + case "events": + err = decoder.Decode(&s.Events) + case "droppedEventsCount", "dropped_events_count": + err = decoder.Decode(&s.DroppedEvents) + case "links": + err = decoder.Decode(&s.Links) + case "droppedLinksCount", "dropped_links_count": + err = decoder.Decode(&s.DroppedLinks) + case "status": + err = decoder.Decode(&s.Status) + default: + // Skip unknown. + } + + if err != nil { + return err + } + } + return nil +} + +// SpanFlags represents constants used to interpret the +// Span.flags field, which is protobuf 'fixed32' type and is to +// be used as bit-fields. Each non-zero value defined in this enum is +// a bit-mask. To extract the bit-field, for example, use an +// expression like: +// +// (span.flags & SPAN_FLAGS_TRACE_FLAGS_MASK) +// +// See https://www.w3.org/TR/trace-context-2/#trace-flags for the flag definitions. +// +// Note that Span flags were introduced in version 1.1 of the +// OpenTelemetry protocol. Older Span producers do not set this +// field, consequently consumers should not rely on the absence of a +// particular flag bit to indicate the presence of a particular feature. +type SpanFlags int32 + +const ( + // Bits 0-7 are used for trace flags. + SpanFlagsTraceFlagsMask SpanFlags = 255 + // Bits 8 and 9 are used to indicate that the parent span or link span is remote. + // Bit 8 (`HAS_IS_REMOTE`) indicates whether the value is known. + // Bit 9 (`IS_REMOTE`) indicates whether the span or link is remote. + SpanFlagsContextHasIsRemoteMask SpanFlags = 256 + // SpanFlagsContextHasIsRemoteMask indicates the Span is remote. + SpanFlagsContextIsRemoteMask SpanFlags = 512 +) + +// SpanKind is the type of span. Can be used to specify additional relationships between spans +// in addition to a parent/child relationship. +type SpanKind int32 + +const ( + // Indicates that the span represents an internal operation within an application, + // as opposed to an operation happening at the boundaries. Default value. + SpanKindInternal SpanKind = 1 + // Indicates that the span covers server-side handling of an RPC or other + // remote network request. + SpanKindServer SpanKind = 2 + // Indicates that the span describes a request to some remote service. + SpanKindClient SpanKind = 3 + // Indicates that the span describes a producer sending a message to a broker. + // Unlike CLIENT and SERVER, there is often no direct critical path latency relationship + // between producer and consumer spans. A PRODUCER span ends when the message was accepted + // by the broker while the logical processing of the message might span a much longer time. + SpanKindProducer SpanKind = 4 + // Indicates that the span describes consumer receiving a message from a broker. + // Like the PRODUCER kind, there is often no direct critical path latency relationship + // between producer and consumer spans. + SpanKindConsumer SpanKind = 5 +) + +// Event is a time-stamped annotation of the span, consisting of user-supplied +// text description and key-value pairs. +type SpanEvent struct { + // time_unix_nano is the time the event occurred. + Time time.Time `json:"timeUnixNano,omitempty"` + // name of the event. + // This field is semantically required to be set to non-empty string. + Name string `json:"name,omitempty"` + // attributes is a collection of attribute key/value pairs on the event. + // Attribute keys MUST be unique (it is not allowed to have more than one + // attribute with the same key). + Attrs []Attr `json:"attributes,omitempty"` + // dropped_attributes_count is the number of dropped attributes. If the value is 0, + // then no attributes were dropped. + DroppedAttrs uint32 `json:"droppedAttributesCount,omitempty"` +} + +// MarshalJSON encodes e into OTLP formatted JSON. +func (e SpanEvent) MarshalJSON() ([]byte, error) { + t := e.Time.UnixNano() + if e.Time.IsZero() || t < 0 { + t = 0 + } + + type Alias SpanEvent + return json.Marshal(struct { + Alias + Time uint64 `json:"timeUnixNano,omitempty"` + }{ + Alias: Alias(e), + Time: uint64(t), + }) +} + +// UnmarshalJSON decodes the OTLP formatted JSON contained in data into se. +func (se *SpanEvent) UnmarshalJSON(data []byte) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + + t, err := decoder.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("invalid SpanEvent type") + } + + for decoder.More() { + keyIface, err := decoder.Token() + if err != nil { + if errors.Is(err, io.EOF) { + // Empty. + return nil + } + return err + } + + key, ok := keyIface.(string) + if !ok { + return fmt.Errorf("invalid SpanEvent field: %#v", keyIface) + } + + switch key { + case "timeUnixNano", "time_unix_nano": + var val protoUint64 + err = decoder.Decode(&val) + se.Time = time.Unix(0, int64(val.Uint64())) + case "name": + err = decoder.Decode(&se.Name) + case "attributes": + err = decoder.Decode(&se.Attrs) + case "droppedAttributesCount", "dropped_attributes_count": + err = decoder.Decode(&se.DroppedAttrs) + default: + // Skip unknown. + } + + if err != nil { + return err + } + } + return nil +} + +// A pointer from the current span to another span in the same trace or in a +// different trace. For example, this can be used in batching operations, +// where a single batch handler processes multiple requests from different +// traces or when the handler receives a request from a different project. +type SpanLink struct { + // A unique identifier of a trace that this linked span is part of. The ID is a + // 16-byte array. + TraceID TraceID `json:"traceId,omitempty"` + // A unique identifier for the linked span. The ID is an 8-byte array. + SpanID SpanID `json:"spanId,omitempty"` + // The trace_state associated with the link. + TraceState string `json:"traceState,omitempty"` + // attributes is a collection of attribute key/value pairs on the link. + // Attribute keys MUST be unique (it is not allowed to have more than one + // attribute with the same key). + Attrs []Attr `json:"attributes,omitempty"` + // dropped_attributes_count is the number of dropped attributes. If the value is 0, + // then no attributes were dropped. + DroppedAttrs uint32 `json:"droppedAttributesCount,omitempty"` + // Flags, a bit field. + // + // Bits 0-7 (8 least significant bits) are the trace flags as defined in W3C Trace + // Context specification. To read the 8-bit W3C trace flag, use + // `flags & SPAN_FLAGS_TRACE_FLAGS_MASK`. + // + // See https://www.w3.org/TR/trace-context-2/#trace-flags for the flag definitions. + // + // Bits 8 and 9 represent the 3 states of whether the link is remote. + // The states are (unknown, is not remote, is remote). + // To read whether the value is known, use `(flags & SPAN_FLAGS_CONTEXT_HAS_IS_REMOTE_MASK) != 0`. + // To read whether the link is remote, use `(flags & SPAN_FLAGS_CONTEXT_IS_REMOTE_MASK) != 0`. + // + // Readers MUST NOT assume that bits 10-31 (22 most significant bits) will be zero. + // When creating new spans, bits 10-31 (most-significant 22-bits) MUST be zero. + // + // [Optional]. + Flags uint32 `json:"flags,omitempty"` +} + +// UnmarshalJSON decodes the OTLP formatted JSON contained in data into sl. +func (sl *SpanLink) UnmarshalJSON(data []byte) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + + t, err := decoder.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("invalid SpanLink type") + } + + for decoder.More() { + keyIface, err := decoder.Token() + if err != nil { + if errors.Is(err, io.EOF) { + // Empty. + return nil + } + return err + } + + key, ok := keyIface.(string) + if !ok { + return fmt.Errorf("invalid SpanLink field: %#v", keyIface) + } + + switch key { + case "traceId", "trace_id": + err = decoder.Decode(&sl.TraceID) + case "spanId", "span_id": + err = decoder.Decode(&sl.SpanID) + case "traceState", "trace_state": + err = decoder.Decode(&sl.TraceState) + case "attributes": + err = decoder.Decode(&sl.Attrs) + case "droppedAttributesCount", "dropped_attributes_count": + err = decoder.Decode(&sl.DroppedAttrs) + case "flags": + err = decoder.Decode(&sl.Flags) + default: + // Skip unknown. + } + + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/status.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/status.go new file mode 100644 index 000000000..1217776ea --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/status.go @@ -0,0 +1,40 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +// For the semantics of status codes see +// https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/api.md#set-status +type StatusCode int32 + +const ( + // The default status. + StatusCodeUnset StatusCode = 0 + // The Span has been validated by an Application developer or Operator to + // have completed successfully. + StatusCodeOK StatusCode = 1 + // The Span contains an error. + StatusCodeError StatusCode = 2 +) + +var statusCodeStrings = []string{ + "Unset", + "OK", + "Error", +} + +func (s StatusCode) String() string { + if s >= 0 && int(s) < len(statusCodeStrings) { + return statusCodeStrings[s] + } + return "" +} + +// The Status type defines a logical error model that is suitable for different +// programming environments, including REST APIs and RPC APIs. +type Status struct { + // A developer-facing human readable error message. + Message string `json:"message,omitempty"` + // The status code. + Code StatusCode `json:"code,omitempty"` +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/traces.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/traces.go new file mode 100644 index 000000000..69a348f0f --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/traces.go @@ -0,0 +1,189 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" +) + +// Traces represents the traces data that can be stored in a persistent storage, +// OR can be embedded by other protocols that transfer OTLP traces data but do +// not implement the OTLP protocol. +// +// The main difference between this message and collector protocol is that +// in this message there will not be any "control" or "metadata" specific to +// OTLP protocol. +// +// When new fields are added into this message, the OTLP request MUST be updated +// as well. +type Traces struct { + // An array of ResourceSpans. + // For data coming from a single resource this array will typically contain + // one element. Intermediary nodes that receive data from multiple origins + // typically batch the data before forwarding further and in that case this + // array will contain multiple elements. + ResourceSpans []*ResourceSpans `json:"resourceSpans,omitempty"` +} + +// UnmarshalJSON decodes the OTLP formatted JSON contained in data into td. +func (td *Traces) UnmarshalJSON(data []byte) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + + t, err := decoder.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("invalid TracesData type") + } + + for decoder.More() { + keyIface, err := decoder.Token() + if err != nil { + if errors.Is(err, io.EOF) { + // Empty. + return nil + } + return err + } + + key, ok := keyIface.(string) + if !ok { + return fmt.Errorf("invalid TracesData field: %#v", keyIface) + } + + switch key { + case "resourceSpans", "resource_spans": + err = decoder.Decode(&td.ResourceSpans) + default: + // Skip unknown. + } + + if err != nil { + return err + } + } + return nil +} + +// A collection of ScopeSpans from a Resource. +type ResourceSpans struct { + // The resource for the spans in this message. + // If this field is not set then no resource info is known. + Resource Resource `json:"resource"` + // A list of ScopeSpans that originate from a resource. + ScopeSpans []*ScopeSpans `json:"scopeSpans,omitempty"` + // This schema_url applies to the data in the "resource" field. It does not apply + // to the data in the "scope_spans" field which have their own schema_url field. + SchemaURL string `json:"schemaUrl,omitempty"` +} + +// UnmarshalJSON decodes the OTLP formatted JSON contained in data into rs. +func (rs *ResourceSpans) UnmarshalJSON(data []byte) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + + t, err := decoder.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("invalid ResourceSpans type") + } + + for decoder.More() { + keyIface, err := decoder.Token() + if err != nil { + if errors.Is(err, io.EOF) { + // Empty. + return nil + } + return err + } + + key, ok := keyIface.(string) + if !ok { + return fmt.Errorf("invalid ResourceSpans field: %#v", keyIface) + } + + switch key { + case "resource": + err = decoder.Decode(&rs.Resource) + case "scopeSpans", "scope_spans": + err = decoder.Decode(&rs.ScopeSpans) + case "schemaUrl", "schema_url": + err = decoder.Decode(&rs.SchemaURL) + default: + // Skip unknown. + } + + if err != nil { + return err + } + } + return nil +} + +// A collection of Spans produced by an InstrumentationScope. +type ScopeSpans struct { + // The instrumentation scope information for the spans in this message. + // Semantically when InstrumentationScope isn't set, it is equivalent with + // an empty instrumentation scope name (unknown). + Scope *Scope `json:"scope"` + // A list of Spans that originate from an instrumentation scope. + Spans []*Span `json:"spans,omitempty"` + // The Schema URL, if known. This is the identifier of the Schema that the span data + // is recorded in. To learn more about Schema URL see + // https://opentelemetry.io/docs/specs/otel/schemas/#schema-url + // This schema_url applies to all spans and span events in the "spans" field. + SchemaURL string `json:"schemaUrl,omitempty"` +} + +// UnmarshalJSON decodes the OTLP formatted JSON contained in data into ss. +func (ss *ScopeSpans) UnmarshalJSON(data []byte) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + + t, err := decoder.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("invalid ScopeSpans type") + } + + for decoder.More() { + keyIface, err := decoder.Token() + if err != nil { + if errors.Is(err, io.EOF) { + // Empty. + return nil + } + return err + } + + key, ok := keyIface.(string) + if !ok { + return fmt.Errorf("invalid ScopeSpans field: %#v", keyIface) + } + + switch key { + case "scope": + err = decoder.Decode(&ss.Scope) + case "spans": + err = decoder.Decode(&ss.Spans) + case "schemaUrl", "schema_url": + err = decoder.Decode(&ss.SchemaURL) + default: + // Skip unknown. + } + + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/value.go b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/value.go new file mode 100644 index 000000000..0dd01b063 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/internal/telemetry/value.go @@ -0,0 +1,452 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//go:generate stringer -type=ValueKind -trimprefix=ValueKind + +package telemetry + +import ( + "bytes" + "cmp" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math" + "slices" + "strconv" + "unsafe" +) + +// A Value represents a structured value. +// A zero value is valid and represents an empty value. +type Value struct { + // Ensure forward compatibility by explicitly making this not comparable. + noCmp [0]func() //nolint: unused // This is indeed used. + + // num holds the value for Int64, Float64, and Bool. It holds the length + // for String, Bytes, Slice, Map. + num uint64 + // any holds either the KindBool, KindInt64, KindFloat64, stringptr, + // bytesptr, sliceptr, or mapptr. If KindBool, KindInt64, or KindFloat64 + // then the value of Value is in num as described above. Otherwise, it + // contains the value wrapped in the appropriate type. + any any +} + +type ( + // sliceptr represents a value in Value.any for KindString Values. + stringptr *byte + // bytesptr represents a value in Value.any for KindBytes Values. + bytesptr *byte + // sliceptr represents a value in Value.any for KindSlice Values. + sliceptr *Value + // mapptr represents a value in Value.any for KindMap Values. + mapptr *Attr +) + +// ValueKind is the kind of a [Value]. +type ValueKind int + +// ValueKind values. +const ( + ValueKindEmpty ValueKind = iota + ValueKindBool + ValueKindFloat64 + ValueKindInt64 + ValueKindString + ValueKindBytes + ValueKindSlice + ValueKindMap +) + +var valueKindStrings = []string{ + "Empty", + "Bool", + "Float64", + "Int64", + "String", + "Bytes", + "Slice", + "Map", +} + +func (k ValueKind) String() string { + if k >= 0 && int(k) < len(valueKindStrings) { + return valueKindStrings[k] + } + return "" +} + +// StringValue returns a new [Value] for a string. +func StringValue(v string) Value { + return Value{ + num: uint64(len(v)), + any: stringptr(unsafe.StringData(v)), + } +} + +// IntValue returns a [Value] for an int. +func IntValue(v int) Value { return Int64Value(int64(v)) } + +// Int64Value returns a [Value] for an int64. +func Int64Value(v int64) Value { + return Value{num: uint64(v), any: ValueKindInt64} +} + +// Float64Value returns a [Value] for a float64. +func Float64Value(v float64) Value { + return Value{num: math.Float64bits(v), any: ValueKindFloat64} +} + +// BoolValue returns a [Value] for a bool. +func BoolValue(v bool) Value { //nolint:revive // Not a control flag. + var n uint64 + if v { + n = 1 + } + return Value{num: n, any: ValueKindBool} +} + +// BytesValue returns a [Value] for a byte slice. The passed slice must not be +// changed after it is passed. +func BytesValue(v []byte) Value { + return Value{ + num: uint64(len(v)), + any: bytesptr(unsafe.SliceData(v)), + } +} + +// SliceValue returns a [Value] for a slice of [Value]. The passed slice must +// not be changed after it is passed. +func SliceValue(vs ...Value) Value { + return Value{ + num: uint64(len(vs)), + any: sliceptr(unsafe.SliceData(vs)), + } +} + +// MapValue returns a new [Value] for a slice of key-value pairs. The passed +// slice must not be changed after it is passed. +func MapValue(kvs ...Attr) Value { + return Value{ + num: uint64(len(kvs)), + any: mapptr(unsafe.SliceData(kvs)), + } +} + +// AsString returns the value held by v as a string. +func (v Value) AsString() string { + if sp, ok := v.any.(stringptr); ok { + return unsafe.String(sp, v.num) + } + // TODO: error handle + return "" +} + +// asString returns the value held by v as a string. It will panic if the Value +// is not KindString. +func (v Value) asString() string { + return unsafe.String(v.any.(stringptr), v.num) +} + +// AsInt64 returns the value held by v as an int64. +func (v Value) AsInt64() int64 { + if v.Kind() != ValueKindInt64 { + // TODO: error handle + return 0 + } + return v.asInt64() +} + +// asInt64 returns the value held by v as an int64. If v is not of KindInt64, +// this will return garbage. +func (v Value) asInt64() int64 { + // Assumes v.num was a valid int64 (overflow not checked). + return int64(v.num) // nolint: gosec +} + +// AsBool returns the value held by v as a bool. +func (v Value) AsBool() bool { + if v.Kind() != ValueKindBool { + // TODO: error handle + return false + } + return v.asBool() +} + +// asBool returns the value held by v as a bool. If v is not of KindBool, this +// will return garbage. +func (v Value) asBool() bool { return v.num == 1 } + +// AsFloat64 returns the value held by v as a float64. +func (v Value) AsFloat64() float64 { + if v.Kind() != ValueKindFloat64 { + // TODO: error handle + return 0 + } + return v.asFloat64() +} + +// asFloat64 returns the value held by v as a float64. If v is not of +// KindFloat64, this will return garbage. +func (v Value) asFloat64() float64 { return math.Float64frombits(v.num) } + +// AsBytes returns the value held by v as a []byte. +func (v Value) AsBytes() []byte { + if sp, ok := v.any.(bytesptr); ok { + return unsafe.Slice((*byte)(sp), v.num) + } + // TODO: error handle + return nil +} + +// asBytes returns the value held by v as a []byte. It will panic if the Value +// is not KindBytes. +func (v Value) asBytes() []byte { + return unsafe.Slice((*byte)(v.any.(bytesptr)), v.num) +} + +// AsSlice returns the value held by v as a []Value. +func (v Value) AsSlice() []Value { + if sp, ok := v.any.(sliceptr); ok { + return unsafe.Slice((*Value)(sp), v.num) + } + // TODO: error handle + return nil +} + +// asSlice returns the value held by v as a []Value. It will panic if the Value +// is not KindSlice. +func (v Value) asSlice() []Value { + return unsafe.Slice((*Value)(v.any.(sliceptr)), v.num) +} + +// AsMap returns the value held by v as a []Attr. +func (v Value) AsMap() []Attr { + if sp, ok := v.any.(mapptr); ok { + return unsafe.Slice((*Attr)(sp), v.num) + } + // TODO: error handle + return nil +} + +// asMap returns the value held by v as a []Attr. It will panic if the +// Value is not KindMap. +func (v Value) asMap() []Attr { + return unsafe.Slice((*Attr)(v.any.(mapptr)), v.num) +} + +// Kind returns the Kind of v. +func (v Value) Kind() ValueKind { + switch x := v.any.(type) { + case ValueKind: + return x + case stringptr: + return ValueKindString + case bytesptr: + return ValueKindBytes + case sliceptr: + return ValueKindSlice + case mapptr: + return ValueKindMap + default: + return ValueKindEmpty + } +} + +// Empty returns if v does not hold any value. +func (v Value) Empty() bool { return v.Kind() == ValueKindEmpty } + +// Equal returns if v is equal to w. +func (v Value) Equal(w Value) bool { + k1 := v.Kind() + k2 := w.Kind() + if k1 != k2 { + return false + } + switch k1 { + case ValueKindInt64, ValueKindBool: + return v.num == w.num + case ValueKindString: + return v.asString() == w.asString() + case ValueKindFloat64: + return v.asFloat64() == w.asFloat64() + case ValueKindSlice: + return slices.EqualFunc(v.asSlice(), w.asSlice(), Value.Equal) + case ValueKindMap: + sv := sortMap(v.asMap()) + sw := sortMap(w.asMap()) + return slices.EqualFunc(sv, sw, Attr.Equal) + case ValueKindBytes: + return bytes.Equal(v.asBytes(), w.asBytes()) + case ValueKindEmpty: + return true + default: + // TODO: error handle + return false + } +} + +func sortMap(m []Attr) []Attr { + sm := make([]Attr, len(m)) + copy(sm, m) + slices.SortFunc(sm, func(a, b Attr) int { + return cmp.Compare(a.Key, b.Key) + }) + + return sm +} + +// String returns Value's value as a string, formatted like [fmt.Sprint]. +// +// The returned string is meant for debugging; +// the string representation is not stable. +func (v Value) String() string { + switch v.Kind() { + case ValueKindString: + return v.asString() + case ValueKindInt64: + // Assumes v.num was a valid int64 (overflow not checked). + return strconv.FormatInt(int64(v.num), 10) // nolint: gosec + case ValueKindFloat64: + return strconv.FormatFloat(v.asFloat64(), 'g', -1, 64) + case ValueKindBool: + return strconv.FormatBool(v.asBool()) + case ValueKindBytes: + return fmt.Sprint(v.asBytes()) + case ValueKindMap: + return fmt.Sprint(v.asMap()) + case ValueKindSlice: + return fmt.Sprint(v.asSlice()) + case ValueKindEmpty: + return "" + default: + // Try to handle this as gracefully as possible. + // + // Don't panic here. The goal here is to have developers find this + // first if a slog.Kind is is not handled. It is + // preferable to have user's open issue asking why their attributes + // have a "unhandled: " prefix than say that their code is panicking. + return fmt.Sprintf("", v.Kind()) + } +} + +// MarshalJSON encodes v into OTLP formatted JSON. +func (v *Value) MarshalJSON() ([]byte, error) { + switch v.Kind() { + case ValueKindString: + return json.Marshal(struct { + Value string `json:"stringValue"` + }{v.asString()}) + case ValueKindInt64: + return json.Marshal(struct { + Value string `json:"intValue"` + }{strconv.FormatInt(int64(v.num), 10)}) + case ValueKindFloat64: + return json.Marshal(struct { + Value float64 `json:"doubleValue"` + }{v.asFloat64()}) + case ValueKindBool: + return json.Marshal(struct { + Value bool `json:"boolValue"` + }{v.asBool()}) + case ValueKindBytes: + return json.Marshal(struct { + Value []byte `json:"bytesValue"` + }{v.asBytes()}) + case ValueKindMap: + return json.Marshal(struct { + Value struct { + Values []Attr `json:"values"` + } `json:"kvlistValue"` + }{struct { + Values []Attr `json:"values"` + }{v.asMap()}}) + case ValueKindSlice: + return json.Marshal(struct { + Value struct { + Values []Value `json:"values"` + } `json:"arrayValue"` + }{struct { + Values []Value `json:"values"` + }{v.asSlice()}}) + case ValueKindEmpty: + return nil, nil + default: + return nil, fmt.Errorf("unknown Value kind: %s", v.Kind().String()) + } +} + +// UnmarshalJSON decodes the OTLP formatted JSON contained in data into v. +func (v *Value) UnmarshalJSON(data []byte) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + + t, err := decoder.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("invalid Value type") + } + + for decoder.More() { + keyIface, err := decoder.Token() + if err != nil { + if errors.Is(err, io.EOF) { + // Empty. + return nil + } + return err + } + + key, ok := keyIface.(string) + if !ok { + return fmt.Errorf("invalid Value key: %#v", keyIface) + } + + switch key { + case "stringValue", "string_value": + var val string + err = decoder.Decode(&val) + *v = StringValue(val) + case "boolValue", "bool_value": + var val bool + err = decoder.Decode(&val) + *v = BoolValue(val) + case "intValue", "int_value": + var val protoInt64 + err = decoder.Decode(&val) + *v = Int64Value(val.Int64()) + case "doubleValue", "double_value": + var val float64 + err = decoder.Decode(&val) + *v = Float64Value(val) + case "bytesValue", "bytes_value": + var val64 string + if err := decoder.Decode(&val64); err != nil { + return err + } + var val []byte + val, err = base64.StdEncoding.DecodeString(val64) + *v = BytesValue(val) + case "arrayValue", "array_value": + var val struct{ Values []Value } + err = decoder.Decode(&val) + *v = SliceValue(val.Values...) + case "kvlistValue", "kvlist_value": + var val struct{ Values []Attr } + err = decoder.Decode(&val) + *v = MapValue(val.Values...) + default: + // Skip unknown. + continue + } + // Use first valid. Ignore the rest. + return err + } + + // Only unknown fields. Return nil without unmarshaling any value. + return nil +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/limit.go b/vendor/go.opentelemetry.io/auto/sdk/limit.go new file mode 100644 index 000000000..86babf1a8 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/limit.go @@ -0,0 +1,94 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "log/slog" + "os" + "strconv" +) + +// maxSpan are the span limits resolved during startup. +var maxSpan = newSpanLimits() + +type spanLimits struct { + // Attrs is the number of allowed attributes for a span. + // + // This is resolved from the environment variable value for the + // OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT key if it exists. Otherwise, the + // environment variable value for OTEL_ATTRIBUTE_COUNT_LIMIT, or 128 if + // that is not set, is used. + Attrs int + // AttrValueLen is the maximum attribute value length allowed for a span. + // + // This is resolved from the environment variable value for the + // OTEL_SPAN_ATTRIBUTE_VALUE_LENGTH_LIMIT key if it exists. Otherwise, the + // environment variable value for OTEL_ATTRIBUTE_VALUE_LENGTH_LIMIT, or -1 + // if that is not set, is used. + AttrValueLen int + // Events is the number of allowed events for a span. + // + // This is resolved from the environment variable value for the + // OTEL_SPAN_EVENT_COUNT_LIMIT key, or 128 is used if that is not set. + Events int + // EventAttrs is the number of allowed attributes for a span event. + // + // The is resolved from the environment variable value for the + // OTEL_EVENT_ATTRIBUTE_COUNT_LIMIT key, or 128 is used if that is not set. + EventAttrs int + // Links is the number of allowed Links for a span. + // + // This is resolved from the environment variable value for the + // OTEL_SPAN_LINK_COUNT_LIMIT, or 128 is used if that is not set. + Links int + // LinkAttrs is the number of allowed attributes for a span link. + // + // This is resolved from the environment variable value for the + // OTEL_LINK_ATTRIBUTE_COUNT_LIMIT, or 128 is used if that is not set. + LinkAttrs int +} + +func newSpanLimits() spanLimits { + return spanLimits{ + Attrs: firstEnv( + 128, + "OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT", + "OTEL_ATTRIBUTE_COUNT_LIMIT", + ), + AttrValueLen: firstEnv( + -1, // Unlimited. + "OTEL_SPAN_ATTRIBUTE_VALUE_LENGTH_LIMIT", + "OTEL_ATTRIBUTE_VALUE_LENGTH_LIMIT", + ), + Events: firstEnv(128, "OTEL_SPAN_EVENT_COUNT_LIMIT"), + EventAttrs: firstEnv(128, "OTEL_EVENT_ATTRIBUTE_COUNT_LIMIT"), + Links: firstEnv(128, "OTEL_SPAN_LINK_COUNT_LIMIT"), + LinkAttrs: firstEnv(128, "OTEL_LINK_ATTRIBUTE_COUNT_LIMIT"), + } +} + +// firstEnv returns the parsed integer value of the first matching environment +// variable from keys. The defaultVal is returned if the value is not an +// integer or no match is found. +func firstEnv(defaultVal int, keys ...string) int { + for _, key := range keys { + strV := os.Getenv(key) + if strV == "" { + continue + } + + v, err := strconv.Atoi(strV) + if err == nil { + return v + } + slog.Warn( + "invalid limit environment variable", + "error", err, + "key", key, + "value", strV, + ) + } + + return defaultVal +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/span.go b/vendor/go.opentelemetry.io/auto/sdk/span.go new file mode 100644 index 000000000..6ebea12a9 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/span.go @@ -0,0 +1,432 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "encoding/json" + "fmt" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + "unicode/utf8" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" + + "go.opentelemetry.io/auto/sdk/internal/telemetry" +) + +type span struct { + noop.Span + + spanContext trace.SpanContext + sampled atomic.Bool + + mu sync.Mutex + traces *telemetry.Traces + span *telemetry.Span +} + +func (s *span) SpanContext() trace.SpanContext { + if s == nil { + return trace.SpanContext{} + } + // s.spanContext is immutable, do not acquire lock s.mu. + return s.spanContext +} + +func (s *span) IsRecording() bool { + if s == nil { + return false + } + + return s.sampled.Load() +} + +func (s *span) SetStatus(c codes.Code, msg string) { + if s == nil || !s.sampled.Load() { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.span.Status == nil { + s.span.Status = new(telemetry.Status) + } + + s.span.Status.Message = msg + + switch c { + case codes.Unset: + s.span.Status.Code = telemetry.StatusCodeUnset + case codes.Error: + s.span.Status.Code = telemetry.StatusCodeError + case codes.Ok: + s.span.Status.Code = telemetry.StatusCodeOK + } +} + +func (s *span) SetAttributes(attrs ...attribute.KeyValue) { + if s == nil || !s.sampled.Load() { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + limit := maxSpan.Attrs + if limit == 0 { + // No attributes allowed. + s.span.DroppedAttrs += uint32(len(attrs)) + return + } + + m := make(map[string]int) + for i, a := range s.span.Attrs { + m[a.Key] = i + } + + for _, a := range attrs { + val := convAttrValue(a.Value) + if val.Empty() { + s.span.DroppedAttrs++ + continue + } + + if idx, ok := m[string(a.Key)]; ok { + s.span.Attrs[idx] = telemetry.Attr{ + Key: string(a.Key), + Value: val, + } + } else if limit < 0 || len(s.span.Attrs) < limit { + s.span.Attrs = append(s.span.Attrs, telemetry.Attr{ + Key: string(a.Key), + Value: val, + }) + m[string(a.Key)] = len(s.span.Attrs) - 1 + } else { + s.span.DroppedAttrs++ + } + } +} + +// convCappedAttrs converts up to limit attrs into a []telemetry.Attr. The +// number of dropped attributes is also returned. +func convCappedAttrs(limit int, attrs []attribute.KeyValue) ([]telemetry.Attr, uint32) { + if limit == 0 { + return nil, uint32(len(attrs)) + } + + if limit < 0 { + // Unlimited. + return convAttrs(attrs), 0 + } + + limit = min(len(attrs), limit) + return convAttrs(attrs[:limit]), uint32(len(attrs) - limit) +} + +func convAttrs(attrs []attribute.KeyValue) []telemetry.Attr { + if len(attrs) == 0 { + // Avoid allocations if not necessary. + return nil + } + + out := make([]telemetry.Attr, 0, len(attrs)) + for _, attr := range attrs { + key := string(attr.Key) + val := convAttrValue(attr.Value) + if val.Empty() { + continue + } + out = append(out, telemetry.Attr{Key: key, Value: val}) + } + return out +} + +func convAttrValue(value attribute.Value) telemetry.Value { + switch value.Type() { + case attribute.BOOL: + return telemetry.BoolValue(value.AsBool()) + case attribute.INT64: + return telemetry.Int64Value(value.AsInt64()) + case attribute.FLOAT64: + return telemetry.Float64Value(value.AsFloat64()) + case attribute.STRING: + v := truncate(maxSpan.AttrValueLen, value.AsString()) + return telemetry.StringValue(v) + case attribute.BOOLSLICE: + slice := value.AsBoolSlice() + out := make([]telemetry.Value, 0, len(slice)) + for _, v := range slice { + out = append(out, telemetry.BoolValue(v)) + } + return telemetry.SliceValue(out...) + case attribute.INT64SLICE: + slice := value.AsInt64Slice() + out := make([]telemetry.Value, 0, len(slice)) + for _, v := range slice { + out = append(out, telemetry.Int64Value(v)) + } + return telemetry.SliceValue(out...) + case attribute.FLOAT64SLICE: + slice := value.AsFloat64Slice() + out := make([]telemetry.Value, 0, len(slice)) + for _, v := range slice { + out = append(out, telemetry.Float64Value(v)) + } + return telemetry.SliceValue(out...) + case attribute.STRINGSLICE: + slice := value.AsStringSlice() + out := make([]telemetry.Value, 0, len(slice)) + for _, v := range slice { + v = truncate(maxSpan.AttrValueLen, v) + out = append(out, telemetry.StringValue(v)) + } + return telemetry.SliceValue(out...) + } + return telemetry.Value{} +} + +// truncate returns a truncated version of s such that it contains less than +// the limit number of characters. Truncation is applied by returning the limit +// number of valid characters contained in s. +// +// If limit is negative, it returns the original string. +// +// UTF-8 is supported. When truncating, all invalid characters are dropped +// before applying truncation. +// +// If s already contains less than the limit number of bytes, it is returned +// unchanged. No invalid characters are removed. +func truncate(limit int, s string) string { + // This prioritize performance in the following order based on the most + // common expected use-cases. + // + // - Short values less than the default limit (128). + // - Strings with valid encodings that exceed the limit. + // - No limit. + // - Strings with invalid encodings that exceed the limit. + if limit < 0 || len(s) <= limit { + return s + } + + // Optimistically, assume all valid UTF-8. + var b strings.Builder + count := 0 + for i, c := range s { + if c != utf8.RuneError { + count++ + if count > limit { + return s[:i] + } + continue + } + + _, size := utf8.DecodeRuneInString(s[i:]) + if size == 1 { + // Invalid encoding. + b.Grow(len(s) - 1) + _, _ = b.WriteString(s[:i]) + s = s[i:] + break + } + } + + // Fast-path, no invalid input. + if b.Cap() == 0 { + return s + } + + // Truncate while validating UTF-8. + for i := 0; i < len(s) && count < limit; { + c := s[i] + if c < utf8.RuneSelf { + // Optimization for single byte runes (common case). + _ = b.WriteByte(c) + i++ + count++ + continue + } + + _, size := utf8.DecodeRuneInString(s[i:]) + if size == 1 { + // We checked for all 1-byte runes above, this is a RuneError. + i++ + continue + } + + _, _ = b.WriteString(s[i : i+size]) + i += size + count++ + } + + return b.String() +} + +func (s *span) End(opts ...trace.SpanEndOption) { + if s == nil || !s.sampled.Swap(false) { + return + } + + // s.end exists so the lock (s.mu) is not held while s.ended is called. + s.ended(s.end(opts)) +} + +func (s *span) end(opts []trace.SpanEndOption) []byte { + s.mu.Lock() + defer s.mu.Unlock() + + cfg := trace.NewSpanEndConfig(opts...) + if t := cfg.Timestamp(); !t.IsZero() { + s.span.EndTime = cfg.Timestamp() + } else { + s.span.EndTime = time.Now() + } + + b, _ := json.Marshal(s.traces) // TODO: do not ignore this error. + return b +} + +// Expected to be implemented in eBPF. +// +//go:noinline +func (*span) ended(buf []byte) { ended(buf) } + +// ended is used for testing. +var ended = func([]byte) {} + +func (s *span) RecordError(err error, opts ...trace.EventOption) { + if s == nil || err == nil || !s.sampled.Load() { + return + } + + cfg := trace.NewEventConfig(opts...) + + attrs := cfg.Attributes() + attrs = append(attrs, + semconv.ExceptionType(typeStr(err)), + semconv.ExceptionMessage(err.Error()), + ) + if cfg.StackTrace() { + buf := make([]byte, 2048) + n := runtime.Stack(buf, false) + attrs = append(attrs, semconv.ExceptionStacktrace(string(buf[0:n]))) + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.addEvent(semconv.ExceptionEventName, cfg.Timestamp(), attrs) +} + +func typeStr(i any) string { + t := reflect.TypeOf(i) + if t.PkgPath() == "" && t.Name() == "" { + // Likely a builtin type. + return t.String() + } + return fmt.Sprintf("%s.%s", t.PkgPath(), t.Name()) +} + +func (s *span) AddEvent(name string, opts ...trace.EventOption) { + if s == nil || !s.sampled.Load() { + return + } + + cfg := trace.NewEventConfig(opts...) + + s.mu.Lock() + defer s.mu.Unlock() + + s.addEvent(name, cfg.Timestamp(), cfg.Attributes()) +} + +// addEvent adds an event with name and attrs at tStamp to the span. The span +// lock (s.mu) needs to be held by the caller. +func (s *span) addEvent(name string, tStamp time.Time, attrs []attribute.KeyValue) { + limit := maxSpan.Events + + if limit == 0 { + s.span.DroppedEvents++ + return + } + + if limit > 0 && len(s.span.Events) == limit { + // Drop head while avoiding allocation of more capacity. + copy(s.span.Events[:limit-1], s.span.Events[1:]) + s.span.Events = s.span.Events[:limit-1] + s.span.DroppedEvents++ + } + + e := &telemetry.SpanEvent{Time: tStamp, Name: name} + e.Attrs, e.DroppedAttrs = convCappedAttrs(maxSpan.EventAttrs, attrs) + + s.span.Events = append(s.span.Events, e) +} + +func (s *span) AddLink(link trace.Link) { + if s == nil || !s.sampled.Load() { + return + } + + l := maxSpan.Links + + s.mu.Lock() + defer s.mu.Unlock() + + if l == 0 { + s.span.DroppedLinks++ + return + } + + if l > 0 && len(s.span.Links) == l { + // Drop head while avoiding allocation of more capacity. + copy(s.span.Links[:l-1], s.span.Links[1:]) + s.span.Links = s.span.Links[:l-1] + s.span.DroppedLinks++ + } + + s.span.Links = append(s.span.Links, convLink(link)) +} + +func convLinks(links []trace.Link) []*telemetry.SpanLink { + out := make([]*telemetry.SpanLink, 0, len(links)) + for _, link := range links { + out = append(out, convLink(link)) + } + return out +} + +func convLink(link trace.Link) *telemetry.SpanLink { + l := &telemetry.SpanLink{ + TraceID: telemetry.TraceID(link.SpanContext.TraceID()), + SpanID: telemetry.SpanID(link.SpanContext.SpanID()), + TraceState: link.SpanContext.TraceState().String(), + Flags: uint32(link.SpanContext.TraceFlags()), + } + l.Attrs, l.DroppedAttrs = convCappedAttrs(maxSpan.LinkAttrs, link.Attributes) + + return l +} + +func (s *span) SetName(name string) { + if s == nil || !s.sampled.Load() { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.span.Name = name +} + +func (*span) TracerProvider() trace.TracerProvider { return TracerProvider() } diff --git a/vendor/go.opentelemetry.io/auto/sdk/tracer.go b/vendor/go.opentelemetry.io/auto/sdk/tracer.go new file mode 100644 index 000000000..cbcfabde3 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/tracer.go @@ -0,0 +1,124 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "context" + "time" + + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" + + "go.opentelemetry.io/auto/sdk/internal/telemetry" +) + +type tracer struct { + noop.Tracer + + name, schemaURL, version string +} + +var _ trace.Tracer = tracer{} + +func (t tracer) Start(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { + var psc trace.SpanContext + sampled := true + span := new(span) + + // Ask eBPF for sampling decision and span context info. + t.start(ctx, span, &psc, &sampled, &span.spanContext) + + span.sampled.Store(sampled) + + ctx = trace.ContextWithSpan(ctx, span) + + if sampled { + // Only build traces if sampled. + cfg := trace.NewSpanStartConfig(opts...) + span.traces, span.span = t.traces(name, cfg, span.spanContext, psc) + } + + return ctx, span +} + +// Expected to be implemented in eBPF. +// +//go:noinline +func (t *tracer) start( + ctx context.Context, + spanPtr *span, + psc *trace.SpanContext, + sampled *bool, + sc *trace.SpanContext, +) { + start(ctx, spanPtr, psc, sampled, sc) +} + +// start is used for testing. +var start = func(context.Context, *span, *trace.SpanContext, *bool, *trace.SpanContext) {} + +func (t tracer) traces(name string, cfg trace.SpanConfig, sc, psc trace.SpanContext) (*telemetry.Traces, *telemetry.Span) { + span := &telemetry.Span{ + TraceID: telemetry.TraceID(sc.TraceID()), + SpanID: telemetry.SpanID(sc.SpanID()), + Flags: uint32(sc.TraceFlags()), + TraceState: sc.TraceState().String(), + ParentSpanID: telemetry.SpanID(psc.SpanID()), + Name: name, + Kind: spanKind(cfg.SpanKind()), + } + + span.Attrs, span.DroppedAttrs = convCappedAttrs(maxSpan.Attrs, cfg.Attributes()) + + links := cfg.Links() + if limit := maxSpan.Links; limit == 0 { + span.DroppedLinks = uint32(len(links)) + } else { + if limit > 0 { + n := max(len(links)-limit, 0) + span.DroppedLinks = uint32(n) + links = links[n:] + } + span.Links = convLinks(links) + } + + if t := cfg.Timestamp(); !t.IsZero() { + span.StartTime = cfg.Timestamp() + } else { + span.StartTime = time.Now() + } + + return &telemetry.Traces{ + ResourceSpans: []*telemetry.ResourceSpans{ + { + ScopeSpans: []*telemetry.ScopeSpans{ + { + Scope: &telemetry.Scope{ + Name: t.name, + Version: t.version, + }, + Spans: []*telemetry.Span{span}, + SchemaURL: t.schemaURL, + }, + }, + }, + }, + }, span +} + +func spanKind(kind trace.SpanKind) telemetry.SpanKind { + switch kind { + case trace.SpanKindInternal: + return telemetry.SpanKindInternal + case trace.SpanKindServer: + return telemetry.SpanKindServer + case trace.SpanKindClient: + return telemetry.SpanKindClient + case trace.SpanKindProducer: + return telemetry.SpanKindProducer + case trace.SpanKindConsumer: + return telemetry.SpanKindConsumer + } + return telemetry.SpanKind(0) // undefined. +} diff --git a/vendor/go.opentelemetry.io/auto/sdk/tracer_provider.go b/vendor/go.opentelemetry.io/auto/sdk/tracer_provider.go new file mode 100644 index 000000000..dbc477a59 --- /dev/null +++ b/vendor/go.opentelemetry.io/auto/sdk/tracer_provider.go @@ -0,0 +1,33 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +// TracerProvider returns an auto-instrumentable [trace.TracerProvider]. +// +// If an [go.opentelemetry.io/auto.Instrumentation] is configured to instrument +// the process using the returned TracerProvider, all of the telemetry it +// produces will be processed and handled by that Instrumentation. By default, +// if no Instrumentation instruments the TracerProvider it will not generate +// any trace telemetry. +func TracerProvider() trace.TracerProvider { return tracerProviderInstance } + +var tracerProviderInstance = new(tracerProvider) + +type tracerProvider struct{ noop.TracerProvider } + +var _ trace.TracerProvider = tracerProvider{} + +func (p tracerProvider) Tracer(name string, opts ...trace.TracerOption) trace.Tracer { + cfg := trace.NewTracerConfig(opts...) + return tracer{ + name: name, + version: cfg.InstrumentationVersion(), + schemaURL: cfg.SchemaURL(), + } +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/client.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/client.go index 6aae83bfd..b25641c55 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/client.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/client.go @@ -18,7 +18,7 @@ var DefaultClient = &http.Client{Transport: NewTransport(http.DefaultTransport)} // Get is a convenient replacement for http.Get that adds a span around the request. func Get(ctx context.Context, targetURL string) (resp *http.Response, err error) { - req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) if err != nil { return nil, err } @@ -27,7 +27,7 @@ func Get(ctx context.Context, targetURL string) (resp *http.Response, err error) // Head is a convenient replacement for http.Head that adds a span around the request. func Head(ctx context.Context, targetURL string) (resp *http.Response, err error) { - req, err := http.NewRequestWithContext(ctx, "HEAD", targetURL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodHead, targetURL, nil) if err != nil { return nil, err } @@ -36,7 +36,7 @@ func Head(ctx context.Context, targetURL string) (resp *http.Response, err error // Post is a convenient replacement for http.Post that adds a span around the request. func Post(ctx context.Context, targetURL, contentType string, body io.Reader) (resp *http.Response, err error) { - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, body) if err != nil { return nil, err } diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/common.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/common.go index 214acaf58..a83a02627 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/common.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/common.go @@ -18,20 +18,6 @@ const ( WriteErrorKey = attribute.Key("http.write_error") // if an error occurred while writing a reply, the string of the error (io.EOF is not recorded) ) -// Server HTTP metrics. -const ( - serverRequestSize = "http.server.request.size" // Incoming request bytes total - serverResponseSize = "http.server.response.size" // Incoming response bytes total - serverDuration = "http.server.duration" // Incoming end to end duration, milliseconds -) - -// Client HTTP metrics. -const ( - clientRequestSize = "http.client.request.size" // Outgoing request bytes total - clientResponseSize = "http.client.response.size" // Outgoing response bytes total - clientDuration = "http.client.duration" // Outgoing end to end duration, milliseconds -) - // Filter is a predicate used to determine whether a given http.request should // be traced. A Filter must return true if the request should be traced. type Filter func(*http.Request) bool diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/config.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/config.go index f0a9bb9ef..a01bfafbe 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/config.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/config.go @@ -8,6 +8,8 @@ import ( "net/http" "net/http/httptrace" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/propagation" @@ -33,8 +35,9 @@ type config struct { SpanNameFormatter func(string, *http.Request) string ClientTrace func(context.Context) *httptrace.ClientTrace - TracerProvider trace.TracerProvider - MeterProvider metric.MeterProvider + TracerProvider trace.TracerProvider + MeterProvider metric.MeterProvider + MetricAttributesFn func(*http.Request) []attribute.KeyValue } // Option interface used for setting optional config properties. @@ -194,3 +197,11 @@ func WithServerName(server string) Option { c.ServerName = server }) } + +// WithMetricAttributesFn returns an Option to set a function that maps an HTTP request to a slice of attribute.KeyValue. +// These attributes will be included in metrics for every request. +func WithMetricAttributesFn(metricAttributesFn func(r *http.Request) []attribute.KeyValue) Option { + return optionFunc(func(c *config) { + c.MetricAttributesFn = metricAttributesFn + }) +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/handler.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/handler.go index d01bdccf4..e555a475f 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/handler.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/handler.go @@ -9,11 +9,9 @@ import ( "github.com/felixge/httpsnoop" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconvutil" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" ) @@ -24,7 +22,6 @@ type middleware struct { server string tracer trace.Tracer - meter metric.Meter propagators propagation.TextMapPropagator spanStartOptions []trace.SpanStartOption readEvent bool @@ -34,10 +31,7 @@ type middleware struct { publicEndpoint bool publicEndpointFn func(*http.Request) bool - traceSemconv semconv.HTTPServer - requestBytesCounter metric.Int64Counter - responseBytesCounter metric.Int64Counter - serverLatencyMeasure metric.Float64Histogram + semconv semconv.HTTPServer } func defaultHandlerFormatter(operation string, _ *http.Request) string { @@ -56,8 +50,6 @@ func NewHandler(handler http.Handler, operation string, opts ...Option) http.Han func NewMiddleware(operation string, opts ...Option) func(http.Handler) http.Handler { h := middleware{ operation: operation, - - traceSemconv: semconv.NewHTTPServer(), } defaultOpts := []Option{ @@ -67,7 +59,6 @@ func NewMiddleware(operation string, opts ...Option) func(http.Handler) http.Han c := newConfig(append(defaultOpts, opts...)...) h.configure(c) - h.createMeasures() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -78,7 +69,6 @@ func NewMiddleware(operation string, opts ...Option) func(http.Handler) http.Han func (h *middleware) configure(c *config) { h.tracer = c.Tracer - h.meter = c.Meter h.propagators = c.Propagators h.spanStartOptions = c.SpanStartOptions h.readEvent = c.ReadEvent @@ -88,36 +78,7 @@ func (h *middleware) configure(c *config) { h.publicEndpoint = c.PublicEndpoint h.publicEndpointFn = c.PublicEndpointFn h.server = c.ServerName -} - -func handleErr(err error) { - if err != nil { - otel.Handle(err) - } -} - -func (h *middleware) createMeasures() { - var err error - h.requestBytesCounter, err = h.meter.Int64Counter( - serverRequestSize, - metric.WithUnit("By"), - metric.WithDescription("Measures the size of HTTP request messages."), - ) - handleErr(err) - - h.responseBytesCounter, err = h.meter.Int64Counter( - serverResponseSize, - metric.WithUnit("By"), - metric.WithDescription("Measures the size of HTTP response messages."), - ) - handleErr(err) - - h.serverLatencyMeasure, err = h.meter.Float64Histogram( - serverDuration, - metric.WithUnit("ms"), - metric.WithDescription("Measures the duration of inbound HTTP requests."), - ) - handleErr(err) + h.semconv = semconv.NewHTTPServer(c.Meter) } // serveHTTP sets up tracing and calls the given next http.Handler with the span @@ -134,7 +95,7 @@ func (h *middleware) serveHTTP(w http.ResponseWriter, r *http.Request, next http ctx := h.propagators.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) opts := []trace.SpanStartOption{ - trace.WithAttributes(h.traceSemconv.RequestTraceAttrs(h.server, r)...), + trace.WithAttributes(h.semconv.RequestTraceAttrs(h.server, r)...), } opts = append(opts, h.spanStartOptions...) @@ -156,6 +117,11 @@ func (h *middleware) serveHTTP(w http.ResponseWriter, r *http.Request, next http } } + if startTime := StartTimeFromContext(ctx); !startTime.IsZero() { + opts = append(opts, trace.WithTimestamp(startTime)) + requestStartTime = startTime + } + ctx, span := tracer.Start(ctx, h.spanNameFormatter(h.operation, r), opts...) defer span.End() @@ -166,14 +132,12 @@ func (h *middleware) serveHTTP(w http.ResponseWriter, r *http.Request, next http } } - var bw bodyWrapper // if request body is nil or NoBody, we don't want to mutate the body as it // will affect the identity of it in an unforeseeable way because we assert // ReadCloser fulfills a certain interface and it is indeed nil or NoBody. + bw := request.NewBodyWrapper(r.Body, readRecordFunc) if r.Body != nil && r.Body != http.NoBody { - bw.ReadCloser = r.Body - bw.record = readRecordFunc - r.Body = &bw + r.Body = bw } writeRecordFunc := func(int64) {} @@ -183,13 +147,7 @@ func (h *middleware) serveHTTP(w http.ResponseWriter, r *http.Request, next http } } - rww := &respWriterWrapper{ - ResponseWriter: w, - record: writeRecordFunc, - ctx: ctx, - props: h.propagators, - statusCode: http.StatusOK, // default status code in case the Handler doesn't write anything - } + rww := request.NewRespWriterWrapper(w, writeRecordFunc) // Wrap w to use our ResponseWriter methods while also exposing // other interfaces that w may implement (http.CloseNotifier, @@ -217,35 +175,39 @@ func (h *middleware) serveHTTP(w http.ResponseWriter, r *http.Request, next http next.ServeHTTP(w, r.WithContext(ctx)) - span.SetStatus(semconv.ServerStatus(rww.statusCode)) - span.SetAttributes(h.traceSemconv.ResponseTraceAttrs(semconv.ResponseTelemetry{ - StatusCode: rww.statusCode, - ReadBytes: bw.read.Load(), - ReadError: bw.err, - WriteBytes: rww.written, - WriteError: rww.err, + statusCode := rww.StatusCode() + bytesWritten := rww.BytesWritten() + span.SetStatus(h.semconv.Status(statusCode)) + span.SetAttributes(h.semconv.ResponseTraceAttrs(semconv.ResponseTelemetry{ + StatusCode: statusCode, + ReadBytes: bw.BytesRead(), + ReadError: bw.Error(), + WriteBytes: bytesWritten, + WriteError: rww.Error(), })...) - // Add metrics - attributes := append(labeler.Get(), semconvutil.HTTPServerRequestMetrics(h.server, r)...) - if rww.statusCode > 0 { - attributes = append(attributes, semconv.HTTPStatusCode(rww.statusCode)) - } - o := metric.WithAttributeSet(attribute.NewSet(attributes...)) - addOpts := []metric.AddOption{o} // Allocate vararg slice once. - h.requestBytesCounter.Add(ctx, bw.read.Load(), addOpts...) - h.responseBytesCounter.Add(ctx, rww.written, addOpts...) - // Use floating point division here for higher precision (instead of Millisecond method). elapsedTime := float64(time.Since(requestStartTime)) / float64(time.Millisecond) - h.serverLatencyMeasure.Record(ctx, elapsedTime, o) + h.semconv.RecordMetrics(ctx, semconv.ServerMetricData{ + ServerName: h.server, + ResponseSize: bytesWritten, + MetricAttributes: semconv.MetricAttributes{ + Req: r, + StatusCode: statusCode, + AdditionalAttributes: labeler.Get(), + }, + MetricData: semconv.MetricData{ + RequestSize: bw.BytesRead(), + ElapsedTime: elapsedTime, + }, + }) } // WithRouteTag annotates spans and metrics with the provided route name // with HTTP route attribute. func WithRouteTag(route string, h http.Handler) http.Handler { - attr := semconv.NewHTTPServer().Route(route) + attr := semconv.NewHTTPServer(nil).Route(route) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { span := trace.SpanFromContext(r.Context()) span.SetAttributes(attr) diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request/body_wrapper.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request/body_wrapper.go new file mode 100644 index 000000000..a945f5566 --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request/body_wrapper.go @@ -0,0 +1,75 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package request // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request" + +import ( + "io" + "sync" +) + +var _ io.ReadCloser = &BodyWrapper{} + +// BodyWrapper wraps a http.Request.Body (an io.ReadCloser) to track the number +// of bytes read and the last error. +type BodyWrapper struct { + io.ReadCloser + OnRead func(n int64) // must not be nil + + mu sync.Mutex + read int64 + err error +} + +// NewBodyWrapper creates a new BodyWrapper. +// +// The onRead attribute is a callback that will be called every time the data +// is read, with the number of bytes being read. +func NewBodyWrapper(body io.ReadCloser, onRead func(int64)) *BodyWrapper { + return &BodyWrapper{ + ReadCloser: body, + OnRead: onRead, + } +} + +// Read reads the data from the io.ReadCloser, and stores the number of bytes +// read and the error. +func (w *BodyWrapper) Read(b []byte) (int, error) { + n, err := w.ReadCloser.Read(b) + n1 := int64(n) + + w.updateReadData(n1, err) + w.OnRead(n1) + return n, err +} + +func (w *BodyWrapper) updateReadData(n int64, err error) { + w.mu.Lock() + defer w.mu.Unlock() + + w.read += n + if err != nil { + w.err = err + } +} + +// Closes closes the io.ReadCloser. +func (w *BodyWrapper) Close() error { + return w.ReadCloser.Close() +} + +// BytesRead returns the number of bytes read up to this point. +func (w *BodyWrapper) BytesRead() int64 { + w.mu.Lock() + defer w.mu.Unlock() + + return w.read +} + +// Error returns the last error. +func (w *BodyWrapper) Error() error { + w.mu.Lock() + defer w.mu.Unlock() + + return w.err +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request/resp_writer_wrapper.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request/resp_writer_wrapper.go new file mode 100644 index 000000000..fbc344cbd --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request/resp_writer_wrapper.go @@ -0,0 +1,119 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package request // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request" + +import ( + "net/http" + "sync" +) + +var _ http.ResponseWriter = &RespWriterWrapper{} + +// RespWriterWrapper wraps a http.ResponseWriter in order to track the number of +// bytes written, the last error, and to catch the first written statusCode. +// TODO: The wrapped http.ResponseWriter doesn't implement any of the optional +// types (http.Hijacker, http.Pusher, http.CloseNotifier, etc) +// that may be useful when using it in real life situations. +type RespWriterWrapper struct { + http.ResponseWriter + OnWrite func(n int64) // must not be nil + + mu sync.RWMutex + written int64 + statusCode int + err error + wroteHeader bool +} + +// NewRespWriterWrapper creates a new RespWriterWrapper. +// +// The onWrite attribute is a callback that will be called every time the data +// is written, with the number of bytes that were written. +func NewRespWriterWrapper(w http.ResponseWriter, onWrite func(int64)) *RespWriterWrapper { + return &RespWriterWrapper{ + ResponseWriter: w, + OnWrite: onWrite, + statusCode: http.StatusOK, // default status code in case the Handler doesn't write anything + } +} + +// Write writes the bytes array into the [ResponseWriter], and tracks the +// number of bytes written and last error. +func (w *RespWriterWrapper) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + + if !w.wroteHeader { + w.writeHeader(http.StatusOK) + } + + n, err := w.ResponseWriter.Write(p) + n1 := int64(n) + w.OnWrite(n1) + w.written += n1 + w.err = err + return n, err +} + +// WriteHeader persists initial statusCode for span attribution. +// All calls to WriteHeader will be propagated to the underlying ResponseWriter +// and will persist the statusCode from the first call. +// Blocking consecutive calls to WriteHeader alters expected behavior and will +// remove warning logs from net/http where developers will notice incorrect handler implementations. +func (w *RespWriterWrapper) WriteHeader(statusCode int) { + w.mu.Lock() + defer w.mu.Unlock() + + w.writeHeader(statusCode) +} + +// writeHeader persists the status code for span attribution, and propagates +// the call to the underlying ResponseWriter. +// It does not acquire a lock, and therefore assumes that is being handled by a +// parent method. +func (w *RespWriterWrapper) writeHeader(statusCode int) { + if !w.wroteHeader { + w.wroteHeader = true + w.statusCode = statusCode + } + w.ResponseWriter.WriteHeader(statusCode) +} + +// Flush implements [http.Flusher]. +func (w *RespWriterWrapper) Flush() { + w.mu.Lock() + defer w.mu.Unlock() + + if !w.wroteHeader { + w.writeHeader(http.StatusOK) + } + + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// BytesWritten returns the number of bytes written. +func (w *RespWriterWrapper) BytesWritten() int64 { + w.mu.RLock() + defer w.mu.RUnlock() + + return w.written +} + +// BytesWritten returns the HTTP status code that was sent. +func (w *RespWriterWrapper) StatusCode() int { + w.mu.RLock() + defer w.mu.RUnlock() + + return w.statusCode +} + +// Error returns the last error. +func (w *RespWriterWrapper) Error() error { + w.mu.RLock() + defer w.mu.RUnlock() + + return w.err +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/env.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/env.go index 3ec0ad00c..3b036f8a3 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/env.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/env.go @@ -4,13 +4,16 @@ package semconv // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv" import ( + "context" "fmt" "net/http" "os" "strings" + "sync" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" ) type ResponseTelemetry struct { @@ -23,6 +26,11 @@ type ResponseTelemetry struct { type HTTPServer struct { duplicate bool + + // Old metrics + requestBytesCounter metric.Int64Counter + responseBytesCounter metric.Int64Counter + serverLatencyMeasure metric.Float64Histogram } // RequestTraceAttrs returns trace attributes for an HTTP request received by a @@ -43,9 +51,9 @@ type HTTPServer struct { // The req Host will be used to determine the server instead. func (s HTTPServer) RequestTraceAttrs(server string, req *http.Request) []attribute.KeyValue { if s.duplicate { - return append(oldHTTPServer{}.RequestTraceAttrs(server, req), newHTTPServer{}.RequestTraceAttrs(server, req)...) + return append(OldHTTPServer{}.RequestTraceAttrs(server, req), CurrentHTTPServer{}.RequestTraceAttrs(server, req)...) } - return oldHTTPServer{}.RequestTraceAttrs(server, req) + return OldHTTPServer{}.RequestTraceAttrs(server, req) } // ResponseTraceAttrs returns trace attributes for telemetry from an HTTP response. @@ -53,25 +61,20 @@ func (s HTTPServer) RequestTraceAttrs(server string, req *http.Request) []attrib // If any of the fields in the ResponseTelemetry are not set the attribute will be omitted. func (s HTTPServer) ResponseTraceAttrs(resp ResponseTelemetry) []attribute.KeyValue { if s.duplicate { - return append(oldHTTPServer{}.ResponseTraceAttrs(resp), newHTTPServer{}.ResponseTraceAttrs(resp)...) + return append(OldHTTPServer{}.ResponseTraceAttrs(resp), CurrentHTTPServer{}.ResponseTraceAttrs(resp)...) } - return oldHTTPServer{}.ResponseTraceAttrs(resp) + return OldHTTPServer{}.ResponseTraceAttrs(resp) } // Route returns the attribute for the route. func (s HTTPServer) Route(route string) attribute.KeyValue { - return oldHTTPServer{}.Route(route) -} - -func NewHTTPServer() HTTPServer { - env := strings.ToLower(os.Getenv("OTEL_HTTP_CLIENT_COMPATIBILITY_MODE")) - return HTTPServer{duplicate: env == "http/dup"} + return OldHTTPServer{}.Route(route) } -// ServerStatus returns a span status code and message for an HTTP status code +// Status returns a span status code and message for an HTTP status code // value returned by a server. Status codes in the 400-499 range are not // returned as errors. -func ServerStatus(code int) (codes.Code, string) { +func (s HTTPServer) Status(code int) (codes.Code, string) { if code < 100 || code >= 600 { return codes.Error, fmt.Sprintf("Invalid HTTP status code %d", code) } @@ -80,3 +83,155 @@ func ServerStatus(code int) (codes.Code, string) { } return codes.Unset, "" } + +type ServerMetricData struct { + ServerName string + ResponseSize int64 + + MetricData + MetricAttributes +} + +type MetricAttributes struct { + Req *http.Request + StatusCode int + AdditionalAttributes []attribute.KeyValue +} + +type MetricData struct { + RequestSize int64 + ElapsedTime float64 +} + +var metricAddOptionPool = &sync.Pool{ + New: func() interface{} { + return &[]metric.AddOption{} + }, +} + +func (s HTTPServer) RecordMetrics(ctx context.Context, md ServerMetricData) { + if s.requestBytesCounter == nil || s.responseBytesCounter == nil || s.serverLatencyMeasure == nil { + // This will happen if an HTTPServer{} is used instead of NewHTTPServer. + return + } + + attributes := OldHTTPServer{}.MetricAttributes(md.ServerName, md.Req, md.StatusCode, md.AdditionalAttributes) + o := metric.WithAttributeSet(attribute.NewSet(attributes...)) + addOpts := metricAddOptionPool.Get().(*[]metric.AddOption) + *addOpts = append(*addOpts, o) + s.requestBytesCounter.Add(ctx, md.RequestSize, *addOpts...) + s.responseBytesCounter.Add(ctx, md.ResponseSize, *addOpts...) + s.serverLatencyMeasure.Record(ctx, md.ElapsedTime, o) + *addOpts = (*addOpts)[:0] + metricAddOptionPool.Put(addOpts) + + // TODO: Duplicate Metrics +} + +func NewHTTPServer(meter metric.Meter) HTTPServer { + env := strings.ToLower(os.Getenv("OTEL_SEMCONV_STABILITY_OPT_IN")) + duplicate := env == "http/dup" + server := HTTPServer{ + duplicate: duplicate, + } + server.requestBytesCounter, server.responseBytesCounter, server.serverLatencyMeasure = OldHTTPServer{}.createMeasures(meter) + return server +} + +type HTTPClient struct { + duplicate bool + + // old metrics + requestBytesCounter metric.Int64Counter + responseBytesCounter metric.Int64Counter + latencyMeasure metric.Float64Histogram +} + +func NewHTTPClient(meter metric.Meter) HTTPClient { + env := strings.ToLower(os.Getenv("OTEL_SEMCONV_STABILITY_OPT_IN")) + client := HTTPClient{ + duplicate: env == "http/dup", + } + client.requestBytesCounter, client.responseBytesCounter, client.latencyMeasure = OldHTTPClient{}.createMeasures(meter) + return client +} + +// RequestTraceAttrs returns attributes for an HTTP request made by a client. +func (c HTTPClient) RequestTraceAttrs(req *http.Request) []attribute.KeyValue { + if c.duplicate { + return append(OldHTTPClient{}.RequestTraceAttrs(req), CurrentHTTPClient{}.RequestTraceAttrs(req)...) + } + return OldHTTPClient{}.RequestTraceAttrs(req) +} + +// ResponseTraceAttrs returns metric attributes for an HTTP request made by a client. +func (c HTTPClient) ResponseTraceAttrs(resp *http.Response) []attribute.KeyValue { + if c.duplicate { + return append(OldHTTPClient{}.ResponseTraceAttrs(resp), CurrentHTTPClient{}.ResponseTraceAttrs(resp)...) + } + + return OldHTTPClient{}.ResponseTraceAttrs(resp) +} + +func (c HTTPClient) Status(code int) (codes.Code, string) { + if code < 100 || code >= 600 { + return codes.Error, fmt.Sprintf("Invalid HTTP status code %d", code) + } + if code >= 400 { + return codes.Error, "" + } + return codes.Unset, "" +} + +func (c HTTPClient) ErrorType(err error) attribute.KeyValue { + if c.duplicate { + return CurrentHTTPClient{}.ErrorType(err) + } + + return attribute.KeyValue{} +} + +type MetricOpts struct { + measurement metric.MeasurementOption + addOptions metric.AddOption +} + +func (o MetricOpts) MeasurementOption() metric.MeasurementOption { + return o.measurement +} + +func (o MetricOpts) AddOptions() metric.AddOption { + return o.addOptions +} + +func (c HTTPClient) MetricOptions(ma MetricAttributes) MetricOpts { + attributes := OldHTTPClient{}.MetricAttributes(ma.Req, ma.StatusCode, ma.AdditionalAttributes) + // TODO: Duplicate Metrics + set := metric.WithAttributeSet(attribute.NewSet(attributes...)) + return MetricOpts{ + measurement: set, + addOptions: set, + } +} + +func (s HTTPClient) RecordMetrics(ctx context.Context, md MetricData, opts MetricOpts) { + if s.requestBytesCounter == nil || s.latencyMeasure == nil { + // This will happen if an HTTPClient{} is used instead of NewHTTPClient(). + return + } + + s.requestBytesCounter.Add(ctx, md.RequestSize, opts.AddOptions()) + s.latencyMeasure.Record(ctx, md.ElapsedTime, opts.MeasurementOption()) + + // TODO: Duplicate Metrics +} + +func (s HTTPClient) RecordResponseSize(ctx context.Context, responseData int64, opts metric.AddOption) { + if s.responseBytesCounter == nil { + // This will happen if an HTTPClient{} is used instead of NewHTTPClient(). + return + } + + s.responseBytesCounter.Add(ctx, responseData, opts) + // TODO: Duplicate Metrics +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/httpconv.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/httpconv.go new file mode 100644 index 000000000..dc9ec7bc3 --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/httpconv.go @@ -0,0 +1,348 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package semconv // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv" + +import ( + "fmt" + "net/http" + "reflect" + "strconv" + "strings" + + "go.opentelemetry.io/otel/attribute" + semconvNew "go.opentelemetry.io/otel/semconv/v1.26.0" +) + +type CurrentHTTPServer struct{} + +// TraceRequest returns trace attributes for an HTTP request received by a +// server. +// +// The server must be the primary server name if it is known. For example this +// would be the ServerName directive +// (https://httpd.apache.org/docs/2.4/mod/core.html#servername) for an Apache +// server, and the server_name directive +// (http://nginx.org/en/docs/http/ngx_http_core_module.html#server_name) for an +// nginx server. More generically, the primary server name would be the host +// header value that matches the default virtual host of an HTTP server. It +// should include the host identifier and if a port is used to route to the +// server that port identifier should be included as an appropriate port +// suffix. +// +// If the primary server name is not known, server should be an empty string. +// The req Host will be used to determine the server instead. +func (n CurrentHTTPServer) RequestTraceAttrs(server string, req *http.Request) []attribute.KeyValue { + count := 3 // ServerAddress, Method, Scheme + + var host string + var p int + if server == "" { + host, p = SplitHostPort(req.Host) + } else { + // Prioritize the primary server name. + host, p = SplitHostPort(server) + if p < 0 { + _, p = SplitHostPort(req.Host) + } + } + + hostPort := requiredHTTPPort(req.TLS != nil, p) + if hostPort > 0 { + count++ + } + + method, methodOriginal := n.method(req.Method) + if methodOriginal != (attribute.KeyValue{}) { + count++ + } + + scheme := n.scheme(req.TLS != nil) + + if peer, peerPort := SplitHostPort(req.RemoteAddr); peer != "" { + // The Go HTTP server sets RemoteAddr to "IP:port", this will not be a + // file-path that would be interpreted with a sock family. + count++ + if peerPort > 0 { + count++ + } + } + + useragent := req.UserAgent() + if useragent != "" { + count++ + } + + clientIP := serverClientIP(req.Header.Get("X-Forwarded-For")) + if clientIP != "" { + count++ + } + + if req.URL != nil && req.URL.Path != "" { + count++ + } + + protoName, protoVersion := netProtocol(req.Proto) + if protoName != "" && protoName != "http" { + count++ + } + if protoVersion != "" { + count++ + } + + attrs := make([]attribute.KeyValue, 0, count) + attrs = append(attrs, + semconvNew.ServerAddress(host), + method, + scheme, + ) + + if hostPort > 0 { + attrs = append(attrs, semconvNew.ServerPort(hostPort)) + } + if methodOriginal != (attribute.KeyValue{}) { + attrs = append(attrs, methodOriginal) + } + + if peer, peerPort := SplitHostPort(req.RemoteAddr); peer != "" { + // The Go HTTP server sets RemoteAddr to "IP:port", this will not be a + // file-path that would be interpreted with a sock family. + attrs = append(attrs, semconvNew.NetworkPeerAddress(peer)) + if peerPort > 0 { + attrs = append(attrs, semconvNew.NetworkPeerPort(peerPort)) + } + } + + if useragent := req.UserAgent(); useragent != "" { + attrs = append(attrs, semconvNew.UserAgentOriginal(useragent)) + } + + if clientIP != "" { + attrs = append(attrs, semconvNew.ClientAddress(clientIP)) + } + + if req.URL != nil && req.URL.Path != "" { + attrs = append(attrs, semconvNew.URLPath(req.URL.Path)) + } + + if protoName != "" && protoName != "http" { + attrs = append(attrs, semconvNew.NetworkProtocolName(protoName)) + } + if protoVersion != "" { + attrs = append(attrs, semconvNew.NetworkProtocolVersion(protoVersion)) + } + + return attrs +} + +func (n CurrentHTTPServer) method(method string) (attribute.KeyValue, attribute.KeyValue) { + if method == "" { + return semconvNew.HTTPRequestMethodGet, attribute.KeyValue{} + } + if attr, ok := methodLookup[method]; ok { + return attr, attribute.KeyValue{} + } + + orig := semconvNew.HTTPRequestMethodOriginal(method) + if attr, ok := methodLookup[strings.ToUpper(method)]; ok { + return attr, orig + } + return semconvNew.HTTPRequestMethodGet, orig +} + +func (n CurrentHTTPServer) scheme(https bool) attribute.KeyValue { // nolint:revive + if https { + return semconvNew.URLScheme("https") + } + return semconvNew.URLScheme("http") +} + +// TraceResponse returns trace attributes for telemetry from an HTTP response. +// +// If any of the fields in the ResponseTelemetry are not set the attribute will be omitted. +func (n CurrentHTTPServer) ResponseTraceAttrs(resp ResponseTelemetry) []attribute.KeyValue { + var count int + + if resp.ReadBytes > 0 { + count++ + } + if resp.WriteBytes > 0 { + count++ + } + if resp.StatusCode > 0 { + count++ + } + + attributes := make([]attribute.KeyValue, 0, count) + + if resp.ReadBytes > 0 { + attributes = append(attributes, + semconvNew.HTTPRequestBodySize(int(resp.ReadBytes)), + ) + } + if resp.WriteBytes > 0 { + attributes = append(attributes, + semconvNew.HTTPResponseBodySize(int(resp.WriteBytes)), + ) + } + if resp.StatusCode > 0 { + attributes = append(attributes, + semconvNew.HTTPResponseStatusCode(resp.StatusCode), + ) + } + + return attributes +} + +// Route returns the attribute for the route. +func (n CurrentHTTPServer) Route(route string) attribute.KeyValue { + return semconvNew.HTTPRoute(route) +} + +type CurrentHTTPClient struct{} + +// RequestTraceAttrs returns trace attributes for an HTTP request made by a client. +func (n CurrentHTTPClient) RequestTraceAttrs(req *http.Request) []attribute.KeyValue { + /* + below attributes are returned: + - http.request.method + - http.request.method.original + - url.full + - server.address + - server.port + - network.protocol.name + - network.protocol.version + */ + numOfAttributes := 3 // URL, server address, proto, and method. + + var urlHost string + if req.URL != nil { + urlHost = req.URL.Host + } + var requestHost string + var requestPort int + for _, hostport := range []string{urlHost, req.Header.Get("Host")} { + requestHost, requestPort = SplitHostPort(hostport) + if requestHost != "" || requestPort > 0 { + break + } + } + + eligiblePort := requiredHTTPPort(req.URL != nil && req.URL.Scheme == "https", requestPort) + if eligiblePort > 0 { + numOfAttributes++ + } + useragent := req.UserAgent() + if useragent != "" { + numOfAttributes++ + } + + protoName, protoVersion := netProtocol(req.Proto) + if protoName != "" && protoName != "http" { + numOfAttributes++ + } + if protoVersion != "" { + numOfAttributes++ + } + + method, originalMethod := n.method(req.Method) + if originalMethod != (attribute.KeyValue{}) { + numOfAttributes++ + } + + attrs := make([]attribute.KeyValue, 0, numOfAttributes) + + attrs = append(attrs, method) + if originalMethod != (attribute.KeyValue{}) { + attrs = append(attrs, originalMethod) + } + + var u string + if req.URL != nil { + // Remove any username/password info that may be in the URL. + userinfo := req.URL.User + req.URL.User = nil + u = req.URL.String() + // Restore any username/password info that was removed. + req.URL.User = userinfo + } + attrs = append(attrs, semconvNew.URLFull(u)) + + attrs = append(attrs, semconvNew.ServerAddress(requestHost)) + if eligiblePort > 0 { + attrs = append(attrs, semconvNew.ServerPort(eligiblePort)) + } + + if protoName != "" && protoName != "http" { + attrs = append(attrs, semconvNew.NetworkProtocolName(protoName)) + } + if protoVersion != "" { + attrs = append(attrs, semconvNew.NetworkProtocolVersion(protoVersion)) + } + + return attrs +} + +// ResponseTraceAttrs returns trace attributes for an HTTP response made by a client. +func (n CurrentHTTPClient) ResponseTraceAttrs(resp *http.Response) []attribute.KeyValue { + /* + below attributes are returned: + - http.response.status_code + - error.type + */ + var count int + if resp.StatusCode > 0 { + count++ + } + + if isErrorStatusCode(resp.StatusCode) { + count++ + } + + attrs := make([]attribute.KeyValue, 0, count) + if resp.StatusCode > 0 { + attrs = append(attrs, semconvNew.HTTPResponseStatusCode(resp.StatusCode)) + } + + if isErrorStatusCode(resp.StatusCode) { + errorType := strconv.Itoa(resp.StatusCode) + attrs = append(attrs, semconvNew.ErrorTypeKey.String(errorType)) + } + return attrs +} + +func (n CurrentHTTPClient) ErrorType(err error) attribute.KeyValue { + t := reflect.TypeOf(err) + var value string + if t.PkgPath() == "" && t.Name() == "" { + // Likely a builtin type. + value = t.String() + } else { + value = fmt.Sprintf("%s.%s", t.PkgPath(), t.Name()) + } + + if value == "" { + return semconvNew.ErrorTypeOther + } + + return semconvNew.ErrorTypeKey.String(value) +} + +func (n CurrentHTTPClient) method(method string) (attribute.KeyValue, attribute.KeyValue) { + if method == "" { + return semconvNew.HTTPRequestMethodGet, attribute.KeyValue{} + } + if attr, ok := methodLookup[method]; ok { + return attr, attribute.KeyValue{} + } + + orig := semconvNew.HTTPRequestMethodOriginal(method) + if attr, ok := methodLookup[strings.ToUpper(method)]; ok { + return attr, orig + } + return semconvNew.HTTPRequestMethodGet, orig +} + +func isErrorStatusCode(code int) bool { + return code >= 400 || code < 100 +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/util.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/util.go index e7f293761..93e8d0f94 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/util.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/util.go @@ -9,18 +9,19 @@ import ( "strconv" "strings" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" - semconvNew "go.opentelemetry.io/otel/semconv/v1.24.0" + semconvNew "go.opentelemetry.io/otel/semconv/v1.26.0" ) -// splitHostPort splits a network address hostport of the form "host", +// SplitHostPort splits a network address hostport of the form "host", // "host%zone", "[host]", "[host%zone], "host:port", "host%zone:port", // "[host]:port", "[host%zone]:port", or ":port" into host or host%zone and // port. // // An empty host is returned if it is not provided or unparsable. A negative // port is returned if it is not provided or unparsable. -func splitHostPort(hostport string) (host string, port int) { +func SplitHostPort(hostport string) (host string, port int) { port = -1 if strings.HasPrefix(hostport, "[") { @@ -49,7 +50,7 @@ func splitHostPort(hostport string) (host string, port int) { if err != nil { return } - return host, int(p) + return host, int(p) // nolint: gosec // Byte size checked 16 above. } func requiredHTTPPort(https bool, port int) int { // nolint:revive @@ -89,3 +90,9 @@ var methodLookup = map[string]attribute.KeyValue{ http.MethodPut: semconvNew.HTTPRequestMethodPut, http.MethodTrace: semconvNew.HTTPRequestMethodTrace, } + +func handleErr(err error) { + if err != nil { + otel.Handle(err) + } +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/v1.20.0.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/v1.20.0.go index c3e838aaa..c042249dd 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/v1.20.0.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/v1.20.0.go @@ -7,13 +7,17 @@ import ( "errors" "io" "net/http" + "slices" + "strings" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconvutil" "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/metric/noop" semconv "go.opentelemetry.io/otel/semconv/v1.20.0" ) -type oldHTTPServer struct{} +type OldHTTPServer struct{} // RequestTraceAttrs returns trace attributes for an HTTP request received by a // server. @@ -31,14 +35,14 @@ type oldHTTPServer struct{} // // If the primary server name is not known, server should be an empty string. // The req Host will be used to determine the server instead. -func (o oldHTTPServer) RequestTraceAttrs(server string, req *http.Request) []attribute.KeyValue { +func (o OldHTTPServer) RequestTraceAttrs(server string, req *http.Request) []attribute.KeyValue { return semconvutil.HTTPServerRequest(server, req) } // ResponseTraceAttrs returns trace attributes for telemetry from an HTTP response. // // If any of the fields in the ResponseTelemetry are not set the attribute will be omitted. -func (o oldHTTPServer) ResponseTraceAttrs(resp ResponseTelemetry) []attribute.KeyValue { +func (o OldHTTPServer) ResponseTraceAttrs(resp ResponseTelemetry) []attribute.KeyValue { attributes := []attribute.KeyValue{} if resp.ReadBytes > 0 { @@ -63,7 +67,7 @@ func (o oldHTTPServer) ResponseTraceAttrs(resp ResponseTelemetry) []attribute.Ke } // Route returns the attribute for the route. -func (o oldHTTPServer) Route(route string) attribute.KeyValue { +func (o OldHTTPServer) Route(route string) attribute.KeyValue { return semconv.HTTPRoute(route) } @@ -72,3 +76,199 @@ func (o oldHTTPServer) Route(route string) attribute.KeyValue { func HTTPStatusCode(status int) attribute.KeyValue { return semconv.HTTPStatusCode(status) } + +// Server HTTP metrics. +const ( + serverRequestSize = "http.server.request.size" // Incoming request bytes total + serverResponseSize = "http.server.response.size" // Incoming response bytes total + serverDuration = "http.server.duration" // Incoming end to end duration, milliseconds +) + +func (h OldHTTPServer) createMeasures(meter metric.Meter) (metric.Int64Counter, metric.Int64Counter, metric.Float64Histogram) { + if meter == nil { + return noop.Int64Counter{}, noop.Int64Counter{}, noop.Float64Histogram{} + } + var err error + requestBytesCounter, err := meter.Int64Counter( + serverRequestSize, + metric.WithUnit("By"), + metric.WithDescription("Measures the size of HTTP request messages."), + ) + handleErr(err) + + responseBytesCounter, err := meter.Int64Counter( + serverResponseSize, + metric.WithUnit("By"), + metric.WithDescription("Measures the size of HTTP response messages."), + ) + handleErr(err) + + serverLatencyMeasure, err := meter.Float64Histogram( + serverDuration, + metric.WithUnit("ms"), + metric.WithDescription("Measures the duration of inbound HTTP requests."), + ) + handleErr(err) + + return requestBytesCounter, responseBytesCounter, serverLatencyMeasure +} + +func (o OldHTTPServer) MetricAttributes(server string, req *http.Request, statusCode int, additionalAttributes []attribute.KeyValue) []attribute.KeyValue { + n := len(additionalAttributes) + 3 + var host string + var p int + if server == "" { + host, p = SplitHostPort(req.Host) + } else { + // Prioritize the primary server name. + host, p = SplitHostPort(server) + if p < 0 { + _, p = SplitHostPort(req.Host) + } + } + hostPort := requiredHTTPPort(req.TLS != nil, p) + if hostPort > 0 { + n++ + } + protoName, protoVersion := netProtocol(req.Proto) + if protoName != "" { + n++ + } + if protoVersion != "" { + n++ + } + + if statusCode > 0 { + n++ + } + + attributes := slices.Grow(additionalAttributes, n) + attributes = append(attributes, + standardizeHTTPMethodMetric(req.Method), + o.scheme(req.TLS != nil), + semconv.NetHostName(host)) + + if hostPort > 0 { + attributes = append(attributes, semconv.NetHostPort(hostPort)) + } + if protoName != "" { + attributes = append(attributes, semconv.NetProtocolName(protoName)) + } + if protoVersion != "" { + attributes = append(attributes, semconv.NetProtocolVersion(protoVersion)) + } + + if statusCode > 0 { + attributes = append(attributes, semconv.HTTPStatusCode(statusCode)) + } + return attributes +} + +func (o OldHTTPServer) scheme(https bool) attribute.KeyValue { // nolint:revive + if https { + return semconv.HTTPSchemeHTTPS + } + return semconv.HTTPSchemeHTTP +} + +type OldHTTPClient struct{} + +func (o OldHTTPClient) RequestTraceAttrs(req *http.Request) []attribute.KeyValue { + return semconvutil.HTTPClientRequest(req) +} + +func (o OldHTTPClient) ResponseTraceAttrs(resp *http.Response) []attribute.KeyValue { + return semconvutil.HTTPClientResponse(resp) +} + +func (o OldHTTPClient) MetricAttributes(req *http.Request, statusCode int, additionalAttributes []attribute.KeyValue) []attribute.KeyValue { + /* The following semantic conventions are returned if present: + http.method string + http.status_code int + net.peer.name string + net.peer.port int + */ + + n := 2 // method, peer name. + var h string + if req.URL != nil { + h = req.URL.Host + } + var requestHost string + var requestPort int + for _, hostport := range []string{h, req.Header.Get("Host")} { + requestHost, requestPort = SplitHostPort(hostport) + if requestHost != "" || requestPort > 0 { + break + } + } + + port := requiredHTTPPort(req.URL != nil && req.URL.Scheme == "https", requestPort) + if port > 0 { + n++ + } + + if statusCode > 0 { + n++ + } + + attributes := slices.Grow(additionalAttributes, n) + attributes = append(attributes, + standardizeHTTPMethodMetric(req.Method), + semconv.NetPeerName(requestHost), + ) + + if port > 0 { + attributes = append(attributes, semconv.NetPeerPort(port)) + } + + if statusCode > 0 { + attributes = append(attributes, semconv.HTTPStatusCode(statusCode)) + } + return attributes +} + +// Client HTTP metrics. +const ( + clientRequestSize = "http.client.request.size" // Incoming request bytes total + clientResponseSize = "http.client.response.size" // Incoming response bytes total + clientDuration = "http.client.duration" // Incoming end to end duration, milliseconds +) + +func (o OldHTTPClient) createMeasures(meter metric.Meter) (metric.Int64Counter, metric.Int64Counter, metric.Float64Histogram) { + if meter == nil { + return noop.Int64Counter{}, noop.Int64Counter{}, noop.Float64Histogram{} + } + requestBytesCounter, err := meter.Int64Counter( + clientRequestSize, + metric.WithUnit("By"), + metric.WithDescription("Measures the size of HTTP request messages."), + ) + handleErr(err) + + responseBytesCounter, err := meter.Int64Counter( + clientResponseSize, + metric.WithUnit("By"), + metric.WithDescription("Measures the size of HTTP response messages."), + ) + handleErr(err) + + latencyMeasure, err := meter.Float64Histogram( + clientDuration, + metric.WithUnit("ms"), + metric.WithDescription("Measures the duration of outbound HTTP requests."), + ) + handleErr(err) + + return requestBytesCounter, responseBytesCounter, latencyMeasure +} + +func standardizeHTTPMethodMetric(method string) attribute.KeyValue { + method = strings.ToUpper(method) + switch method { + case http.MethodConnect, http.MethodDelete, http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodPatch, http.MethodPost, http.MethodPut, http.MethodTrace: + default: + method = "_OTHER" + } + return semconv.HTTPMethod(method) +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/v1.24.0.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/v1.24.0.go deleted file mode 100644 index 0c5d4c460..000000000 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv/v1.24.0.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -package semconv // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv" - -import ( - "net/http" - "strings" - - "go.opentelemetry.io/otel/attribute" - semconvNew "go.opentelemetry.io/otel/semconv/v1.24.0" -) - -type newHTTPServer struct{} - -// TraceRequest returns trace attributes for an HTTP request received by a -// server. -// -// The server must be the primary server name if it is known. For example this -// would be the ServerName directive -// (https://httpd.apache.org/docs/2.4/mod/core.html#servername) for an Apache -// server, and the server_name directive -// (http://nginx.org/en/docs/http/ngx_http_core_module.html#server_name) for an -// nginx server. More generically, the primary server name would be the host -// header value that matches the default virtual host of an HTTP server. It -// should include the host identifier and if a port is used to route to the -// server that port identifier should be included as an appropriate port -// suffix. -// -// If the primary server name is not known, server should be an empty string. -// The req Host will be used to determine the server instead. -func (n newHTTPServer) RequestTraceAttrs(server string, req *http.Request) []attribute.KeyValue { - count := 3 // ServerAddress, Method, Scheme - - var host string - var p int - if server == "" { - host, p = splitHostPort(req.Host) - } else { - // Prioritize the primary server name. - host, p = splitHostPort(server) - if p < 0 { - _, p = splitHostPort(req.Host) - } - } - - hostPort := requiredHTTPPort(req.TLS != nil, p) - if hostPort > 0 { - count++ - } - - method, methodOriginal := n.method(req.Method) - if methodOriginal != (attribute.KeyValue{}) { - count++ - } - - scheme := n.scheme(req.TLS != nil) - - if peer, peerPort := splitHostPort(req.RemoteAddr); peer != "" { - // The Go HTTP server sets RemoteAddr to "IP:port", this will not be a - // file-path that would be interpreted with a sock family. - count++ - if peerPort > 0 { - count++ - } - } - - useragent := req.UserAgent() - if useragent != "" { - count++ - } - - clientIP := serverClientIP(req.Header.Get("X-Forwarded-For")) - if clientIP != "" { - count++ - } - - if req.URL != nil && req.URL.Path != "" { - count++ - } - - protoName, protoVersion := netProtocol(req.Proto) - if protoName != "" && protoName != "http" { - count++ - } - if protoVersion != "" { - count++ - } - - attrs := make([]attribute.KeyValue, 0, count) - attrs = append(attrs, - semconvNew.ServerAddress(host), - method, - scheme, - ) - - if hostPort > 0 { - attrs = append(attrs, semconvNew.ServerPort(hostPort)) - } - if methodOriginal != (attribute.KeyValue{}) { - attrs = append(attrs, methodOriginal) - } - - if peer, peerPort := splitHostPort(req.RemoteAddr); peer != "" { - // The Go HTTP server sets RemoteAddr to "IP:port", this will not be a - // file-path that would be interpreted with a sock family. - attrs = append(attrs, semconvNew.NetworkPeerAddress(peer)) - if peerPort > 0 { - attrs = append(attrs, semconvNew.NetworkPeerPort(peerPort)) - } - } - - if useragent := req.UserAgent(); useragent != "" { - attrs = append(attrs, semconvNew.UserAgentOriginal(useragent)) - } - - if clientIP != "" { - attrs = append(attrs, semconvNew.ClientAddress(clientIP)) - } - - if req.URL != nil && req.URL.Path != "" { - attrs = append(attrs, semconvNew.URLPath(req.URL.Path)) - } - - if protoName != "" && protoName != "http" { - attrs = append(attrs, semconvNew.NetworkProtocolName(protoName)) - } - if protoVersion != "" { - attrs = append(attrs, semconvNew.NetworkProtocolVersion(protoVersion)) - } - - return attrs -} - -func (n newHTTPServer) method(method string) (attribute.KeyValue, attribute.KeyValue) { - if method == "" { - return semconvNew.HTTPRequestMethodGet, attribute.KeyValue{} - } - if attr, ok := methodLookup[method]; ok { - return attr, attribute.KeyValue{} - } - - orig := semconvNew.HTTPRequestMethodOriginal(method) - if attr, ok := methodLookup[strings.ToUpper(method)]; ok { - return attr, orig - } - return semconvNew.HTTPRequestMethodGet, orig -} - -func (n newHTTPServer) scheme(https bool) attribute.KeyValue { // nolint:revive - if https { - return semconvNew.URLScheme("https") - } - return semconvNew.URLScheme("http") -} - -// TraceResponse returns trace attributes for telemetry from an HTTP response. -// -// If any of the fields in the ResponseTelemetry are not set the attribute will be omitted. -func (n newHTTPServer) ResponseTraceAttrs(resp ResponseTelemetry) []attribute.KeyValue { - var count int - - if resp.ReadBytes > 0 { - count++ - } - if resp.WriteBytes > 0 { - count++ - } - if resp.StatusCode > 0 { - count++ - } - - attributes := make([]attribute.KeyValue, 0, count) - - if resp.ReadBytes > 0 { - attributes = append(attributes, - semconvNew.HTTPRequestBodySize(int(resp.ReadBytes)), - ) - } - if resp.WriteBytes > 0 { - attributes = append(attributes, - semconvNew.HTTPResponseBodySize(int(resp.WriteBytes)), - ) - } - if resp.StatusCode > 0 { - attributes = append(attributes, - semconvNew.HTTPResponseStatusCode(resp.StatusCode), - ) - } - - return attributes -} - -// Route returns the attribute for the route. -func (n newHTTPServer) Route(route string) attribute.KeyValue { - return semconvNew.HTTPRoute(route) -} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconvutil/netconv.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconvutil/netconv.go index a9a9226b3..b80a1db61 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconvutil/netconv.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconvutil/netconv.go @@ -195,7 +195,7 @@ func splitHostPort(hostport string) (host string, port int) { if err != nil { return } - return host, int(p) + return host, int(p) // nolint: gosec // Bitsize checked to be 16 above. } func netProtocol(proto string) (name string, version string) { diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/start_time_context.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/start_time_context.go new file mode 100644 index 000000000..9476ef01b --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/start_time_context.go @@ -0,0 +1,29 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package otelhttp // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + +import ( + "context" + "time" +) + +type startTimeContextKeyType int + +const startTimeContextKey startTimeContextKeyType = 0 + +// ContextWithStartTime returns a new context with the provided start time. The +// start time will be used for metrics and traces emitted by the +// instrumentation. Only one labeller can be injected into the context. +// Injecting it multiple times will override the previous calls. +func ContextWithStartTime(parent context.Context, start time.Time) context.Context { + return context.WithValue(parent, startTimeContextKey, start) +} + +// StartTimeFromContext retrieves a time.Time from the provided context if one +// is available. If no start time was found in the provided context, a new, +// zero start time is returned and the second return value is false. +func StartTimeFromContext(ctx context.Context) time.Time { + t, _ := ctx.Value(startTimeContextKey).(time.Time) + return t +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/transport.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/transport.go index 0d3cb2e4a..39681ad4b 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/transport.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/transport.go @@ -11,13 +11,13 @@ import ( "sync/atomic" "time" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconvutil" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/propagation" - semconv "go.opentelemetry.io/otel/semconv/v1.20.0" + "go.opentelemetry.io/otel/trace" ) @@ -26,17 +26,15 @@ import ( type Transport struct { rt http.RoundTripper - tracer trace.Tracer - meter metric.Meter - propagators propagation.TextMapPropagator - spanStartOptions []trace.SpanStartOption - filters []Filter - spanNameFormatter func(string, *http.Request) string - clientTrace func(context.Context) *httptrace.ClientTrace - - requestBytesCounter metric.Int64Counter - responseBytesCounter metric.Int64Counter - latencyMeasure metric.Float64Histogram + tracer trace.Tracer + propagators propagation.TextMapPropagator + spanStartOptions []trace.SpanStartOption + filters []Filter + spanNameFormatter func(string, *http.Request) string + clientTrace func(context.Context) *httptrace.ClientTrace + metricAttributesFn func(*http.Request) []attribute.KeyValue + + semconv semconv.HTTPClient } var _ http.RoundTripper = &Transport{} @@ -63,43 +61,19 @@ func NewTransport(base http.RoundTripper, opts ...Option) *Transport { c := newConfig(append(defaultOpts, opts...)...) t.applyConfig(c) - t.createMeasures() return &t } func (t *Transport) applyConfig(c *config) { t.tracer = c.Tracer - t.meter = c.Meter t.propagators = c.Propagators t.spanStartOptions = c.SpanStartOptions t.filters = c.Filters t.spanNameFormatter = c.SpanNameFormatter t.clientTrace = c.ClientTrace -} - -func (t *Transport) createMeasures() { - var err error - t.requestBytesCounter, err = t.meter.Int64Counter( - clientRequestSize, - metric.WithUnit("By"), - metric.WithDescription("Measures the size of HTTP request messages."), - ) - handleErr(err) - - t.responseBytesCounter, err = t.meter.Int64Counter( - clientResponseSize, - metric.WithUnit("By"), - metric.WithDescription("Measures the size of HTTP response messages."), - ) - handleErr(err) - - t.latencyMeasure, err = t.meter.Float64Histogram( - clientDuration, - metric.WithUnit("ms"), - metric.WithDescription("Measures the duration of outbound HTTP requests."), - ) - handleErr(err) + t.semconv = semconv.NewHTTPClient(c.Meter) + t.metricAttributesFn = c.MetricAttributesFn } func defaultTransportFormatter(_ string, r *http.Request) string { @@ -143,54 +117,68 @@ func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) { r = r.Clone(ctx) // According to RoundTripper spec, we shouldn't modify the origin request. - // use a body wrapper to determine the request size - var bw bodyWrapper // if request body is nil or NoBody, we don't want to mutate the body as it // will affect the identity of it in an unforeseeable way because we assert // ReadCloser fulfills a certain interface and it is indeed nil or NoBody. + bw := request.NewBodyWrapper(r.Body, func(int64) {}) if r.Body != nil && r.Body != http.NoBody { - bw.ReadCloser = r.Body - // noop to prevent nil panic. not using this record fun yet. - bw.record = func(int64) {} - r.Body = &bw + r.Body = bw } - span.SetAttributes(semconvutil.HTTPClientRequest(r)...) + span.SetAttributes(t.semconv.RequestTraceAttrs(r)...) t.propagators.Inject(ctx, propagation.HeaderCarrier(r.Header)) res, err := t.rt.RoundTrip(r) if err != nil { - span.RecordError(err) + // set error type attribute if the error is part of the predefined + // error types. + // otherwise, record it as an exception + if errType := t.semconv.ErrorType(err); errType.Valid() { + span.SetAttributes(errType) + } else { + span.RecordError(err) + } + span.SetStatus(codes.Error, err.Error()) span.End() return res, err } // metrics - metricAttrs := append(labeler.Get(), semconvutil.HTTPClientRequestMetrics(r)...) - if res.StatusCode > 0 { - metricAttrs = append(metricAttrs, semconv.HTTPStatusCode(res.StatusCode)) - } - o := metric.WithAttributeSet(attribute.NewSet(metricAttrs...)) - addOpts := []metric.AddOption{o} // Allocate vararg slice once. - t.requestBytesCounter.Add(ctx, bw.read.Load(), addOpts...) + metricOpts := t.semconv.MetricOptions(semconv.MetricAttributes{ + Req: r, + StatusCode: res.StatusCode, + AdditionalAttributes: append(labeler.Get(), t.metricAttributesFromRequest(r)...), + }) + // For handling response bytes we leverage a callback when the client reads the http response readRecordFunc := func(n int64) { - t.responseBytesCounter.Add(ctx, n, addOpts...) + t.semconv.RecordResponseSize(ctx, n, metricOpts.AddOptions()) } // traces - span.SetAttributes(semconvutil.HTTPClientResponse(res)...) - span.SetStatus(semconvutil.HTTPClientStatus(res.StatusCode)) + span.SetAttributes(t.semconv.ResponseTraceAttrs(res)...) + span.SetStatus(t.semconv.Status(res.StatusCode)) res.Body = newWrappedBody(span, readRecordFunc, res.Body) // Use floating point division here for higher precision (instead of Millisecond method). elapsedTime := float64(time.Since(requestStartTime)) / float64(time.Millisecond) - t.latencyMeasure.Record(ctx, elapsedTime, o) + t.semconv.RecordMetrics(ctx, semconv.MetricData{ + RequestSize: bw.BytesRead(), + ElapsedTime: elapsedTime, + }, metricOpts) - return res, err + return res, nil +} + +func (t *Transport) metricAttributesFromRequest(r *http.Request) []attribute.KeyValue { + var attributeForRequest []attribute.KeyValue + if t.metricAttributesFn != nil { + attributeForRequest = t.metricAttributesFn(r) + } + return attributeForRequest } // newWrappedBody returns a new and appropriately scoped *wrappedBody as an diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/version.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/version.go index b0957f28c..353e43b91 100644 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/version.go +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/version.go @@ -5,7 +5,7 @@ package otelhttp // import "go.opentelemetry.io/contrib/instrumentation/net/http // Version is the current release version of the otelhttp instrumentation. func Version() string { - return "0.53.0" + return "0.58.0" // This string is updated by the pre_release.sh script during release } diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/wrap.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/wrap.go deleted file mode 100644 index 948f8406c..000000000 --- a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/wrap.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -package otelhttp // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" - -import ( - "context" - "io" - "net/http" - "sync/atomic" - - "go.opentelemetry.io/otel/propagation" -) - -var _ io.ReadCloser = &bodyWrapper{} - -// bodyWrapper wraps a http.Request.Body (an io.ReadCloser) to track the number -// of bytes read and the last error. -type bodyWrapper struct { - io.ReadCloser - record func(n int64) // must not be nil - - read atomic.Int64 - err error -} - -func (w *bodyWrapper) Read(b []byte) (int, error) { - n, err := w.ReadCloser.Read(b) - n1 := int64(n) - w.read.Add(n1) - w.err = err - w.record(n1) - return n, err -} - -func (w *bodyWrapper) Close() error { - return w.ReadCloser.Close() -} - -var _ http.ResponseWriter = &respWriterWrapper{} - -// respWriterWrapper wraps a http.ResponseWriter in order to track the number of -// bytes written, the last error, and to catch the first written statusCode. -// TODO: The wrapped http.ResponseWriter doesn't implement any of the optional -// types (http.Hijacker, http.Pusher, http.CloseNotifier, http.Flusher, etc) -// that may be useful when using it in real life situations. -type respWriterWrapper struct { - http.ResponseWriter - record func(n int64) // must not be nil - - // used to inject the header - ctx context.Context - - props propagation.TextMapPropagator - - written int64 - statusCode int - err error - wroteHeader bool -} - -func (w *respWriterWrapper) Header() http.Header { - return w.ResponseWriter.Header() -} - -func (w *respWriterWrapper) Write(p []byte) (int, error) { - if !w.wroteHeader { - w.WriteHeader(http.StatusOK) - } - n, err := w.ResponseWriter.Write(p) - n1 := int64(n) - w.record(n1) - w.written += n1 - w.err = err - return n, err -} - -// WriteHeader persists initial statusCode for span attribution. -// All calls to WriteHeader will be propagated to the underlying ResponseWriter -// and will persist the statusCode from the first call. -// Blocking consecutive calls to WriteHeader alters expected behavior and will -// remove warning logs from net/http where developers will notice incorrect handler implementations. -func (w *respWriterWrapper) WriteHeader(statusCode int) { - if !w.wroteHeader { - w.wroteHeader = true - w.statusCode = statusCode - } - w.ResponseWriter.WriteHeader(statusCode) -} - -func (w *respWriterWrapper) Flush() { - if !w.wroteHeader { - w.WriteHeader(http.StatusOK) - } - - if f, ok := w.ResponseWriter.(http.Flusher); ok { - f.Flush() - } -} diff --git a/vendor/go.opentelemetry.io/otel/.gitignore b/vendor/go.opentelemetry.io/otel/.gitignore index 895c7664b..ae8577ef3 100644 --- a/vendor/go.opentelemetry.io/otel/.gitignore +++ b/vendor/go.opentelemetry.io/otel/.gitignore @@ -12,11 +12,3 @@ go.work go.work.sum gen/ - -/example/dice/dice -/example/namedtracer/namedtracer -/example/otel-collector/otel-collector -/example/opencensus/opencensus -/example/passthrough/passthrough -/example/prometheus/prometheus -/example/zipkin/zipkin diff --git a/vendor/go.opentelemetry.io/otel/.golangci.yml b/vendor/go.opentelemetry.io/otel/.golangci.yml index d09555506..ce3f40b60 100644 --- a/vendor/go.opentelemetry.io/otel/.golangci.yml +++ b/vendor/go.opentelemetry.io/otel/.golangci.yml @@ -22,6 +22,7 @@ linters: - govet - ineffassign - misspell + - perfsprint - revive - staticcheck - tenv @@ -30,6 +31,7 @@ linters: - unconvert - unused - unparam + - usestdlibvars issues: # Maximum issues count per one linter. @@ -61,10 +63,11 @@ issues: text: "calls to (.+) only in main[(][)] or init[(][)] functions" linters: - revive - # It's okay to not run gosec in a test. + # It's okay to not run gosec and perfsprint in a test. - path: _test\.go linters: - gosec + - perfsprint # Ignoring gosec G404: Use of weak random number generator (math/rand instead of crypto/rand) # as we commonly use it in tests and examples. - text: "G404:" @@ -95,6 +98,13 @@ linters-settings: - pkg: "crypto/md5" - pkg: "crypto/sha1" - pkg: "crypto/**/pkix" + auto/sdk: + files: + - "!internal/global/trace.go" + - "~internal/global/trace_test.go" + deny: + - pkg: "go.opentelemetry.io/auto/sdk" + desc: Do not use SDK from automatic instrumentation. otlp-internal: files: - "!**/exporters/otlp/internal/**/*.go" @@ -127,8 +137,6 @@ linters-settings: - "**/metric/**/*.go" - "**/bridge/*.go" - "**/bridge/**/*.go" - - "**/example/*.go" - - "**/example/**/*.go" - "**/trace/*.go" - "**/trace/**/*.go" - "**/log/*.go" @@ -156,6 +164,12 @@ linters-settings: locale: US ignore-words: - cancelled + perfsprint: + err-error: true + errorf: true + int-conversion: true + sprintf1: true + strconcat: true revive: # Sets the default failure confidence. # This means that linting errors with less than 0.8 confidence will be ignored. diff --git a/vendor/go.opentelemetry.io/otel/CHANGELOG.md b/vendor/go.opentelemetry.io/otel/CHANGELOG.md index 4b361d026..a30988f25 100644 --- a/vendor/go.opentelemetry.io/otel/CHANGELOG.md +++ b/vendor/go.opentelemetry.io/otel/CHANGELOG.md @@ -8,9 +8,84 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Unreleased] +## [1.33.0/0.55.0/0.9.0/0.0.12] 2024-12-12 + +### Added + +- Add `Reset` method to `SpanRecorder` in `go.opentelemetry.io/otel/sdk/trace/tracetest`. (#5994) +- Add `EnabledInstrument` interface in `go.opentelemetry.io/otel/sdk/metric/internal/x`. + This is an experimental interface that is implemented by synchronous instruments provided by `go.opentelemetry.io/otel/sdk/metric`. + Users can use it to avoid performing computationally expensive operations when recording measurements. + It does not fall within the scope of the OpenTelemetry Go versioning and stability [policy](./VERSIONING.md) and it may be changed in backwards incompatible ways or removed in feature releases. (#6016) + +### Changed + +- The default global API now supports full auto-instrumentation from the `go.opentelemetry.io/auto` package. + See that package for more information. (#5920) +- Propagate non-retryable error messages to client in `go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp`. (#5929) +- Propagate non-retryable error messages to client in `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp`. (#5929) +- Propagate non-retryable error messages to client in `go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp`. (#5929) +- Performance improvements for attribute value `AsStringSlice`, `AsFloat64Slice`, `AsInt64Slice`, `AsBoolSlice`. (#6011) +- Change `EnabledParameters` to have a `Severity` field instead of a getter and setter in `go.opentelemetry.io/otel/log`. (#6009) + +### Fixed + +- Fix inconsistent request body closing in `go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp`. (#5954) +- Fix inconsistent request body closing in `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp`. (#5954) +- Fix inconsistent request body closing in `go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp`. (#5954) +- Fix invalid exemplar keys in `go.opentelemetry.io/otel/exporters/prometheus`. (#5995) +- Fix attribute value truncation in `go.opentelemetry.io/otel/sdk/trace`. (#5997) +- Fix attribute value truncation in `go.opentelemetry.io/otel/sdk/log`. (#6032) + +## [1.32.0/0.54.0/0.8.0/0.0.11] 2024-11-08 + +### Added + +- Add `go.opentelemetry.io/otel/sdk/metric/exemplar.AlwaysOffFilter`, which can be used to disable exemplar recording. (#5850) +- Add `go.opentelemetry.io/otel/sdk/metric.WithExemplarFilter`, which can be used to configure the exemplar filter used by the metrics SDK. (#5850) +- Add `ExemplarReservoirProviderSelector` and `DefaultExemplarReservoirProviderSelector` to `go.opentelemetry.io/otel/sdk/metric`, which defines the exemplar reservoir to use based on the aggregation of the metric. (#5861) +- Add `ExemplarReservoirProviderSelector` to `go.opentelemetry.io/otel/sdk/metric.Stream` to allow using views to configure the exemplar reservoir to use for a metric. (#5861) +- Add `ReservoirProvider`, `HistogramReservoirProvider` and `FixedSizeReservoirProvider` to `go.opentelemetry.io/otel/sdk/metric/exemplar` to make it convenient to use providers of Reservoirs. (#5861) +- The `go.opentelemetry.io/otel/semconv/v1.27.0` package. + The package contains semantic conventions from the `v1.27.0` version of the OpenTelemetry Semantic Conventions. (#5894) +- Add `Attributes attribute.Set` field to `Scope` in `go.opentelemetry.io/otel/sdk/instrumentation`. (#5903) +- Add `Attributes attribute.Set` field to `ScopeRecords` in `go.opentelemetry.io/otel/log/logtest`. (#5927) +- `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc` adds instrumentation scope attributes. (#5934) +- `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp` adds instrumentation scope attributes. (#5934) +- `go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc` adds instrumentation scope attributes. (#5935) +- `go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp` adds instrumentation scope attributes. (#5935) +- `go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc` adds instrumentation scope attributes. (#5933) +- `go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp` adds instrumentation scope attributes. (#5933) +- `go.opentelemetry.io/otel/exporters/prometheus` adds instrumentation scope attributes in `otel_scope_info` metric as labels. (#5932) + +### Changed + +- Support scope attributes and make them as identifying for `Tracer` in `go.opentelemetry.io/otel` and `go.opentelemetry.io/otel/sdk/trace`. (#5924) +- Support scope attributes and make them as identifying for `Meter` in `go.opentelemetry.io/otel` and `go.opentelemetry.io/otel/sdk/metric`. (#5926) +- Support scope attributes and make them as identifying for `Logger` in `go.opentelemetry.io/otel` and `go.opentelemetry.io/otel/sdk/log`. (#5925) +- Make schema URL and scope attributes as identifying for `Tracer` in `go.opentelemetry.io/otel/bridge/opentracing`. (#5931) +- Clear unneeded slice elements to allow GC to collect the objects in `go.opentelemetry.io/otel/sdk/metric` and `go.opentelemetry.io/otel/sdk/trace`. (#5804) + +### Fixed + +- Global MeterProvider registration unwraps global instrument Observers, the undocumented Unwrap() methods are now private. (#5881) +- `go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc` now keeps the metadata already present in the context when `WithHeaders` is used. (#5892) +- `go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc` now keeps the metadata already present in the context when `WithHeaders` is used. (#5911) +- `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc` now keeps the metadata already present in the context when `WithHeaders` is used. (#5915) +- Fix `go.opentelemetry.io/otel/exporters/prometheus` trying to add exemplars to Gauge metrics, which is unsupported. (#5912) +- Fix `WithEndpointURL` to always use a secure connection when an https URL is passed in `go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc`. (#5944) +- Fix `WithEndpointURL` to always use a secure connection when an https URL is passed in `go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp`. (#5944) +- Fix `WithEndpointURL` to always use a secure connection when an https URL is passed in `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc`. (#5944) +- Fix `WithEndpointURL` to always use a secure connection when an https URL is passed in `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp`. (#5944) +- Fix incorrect metrics generated from callbacks when multiple readers are used in `go.opentelemetry.io/otel/sdk/metric`. (#5900) + +### Removed + +- Remove all examples under `go.opentelemetry.io/otel/example` as they are moved to [Contrib repository](https://github.com/open-telemetry/opentelemetry-go-contrib/tree/main/examples). (#5930) + ## [1.31.0/0.53.0/0.7.0/0.0.10] 2024-10-11 ### Added @@ -3110,7 +3185,9 @@ It contains api and sdk for trace and meter. - CircleCI build CI manifest files. - CODEOWNERS file to track owners of this project. -[Unreleased]: https://github.com/open-telemetry/opentelemetry-go/compare/v1.31.0...HEAD +[Unreleased]: https://github.com/open-telemetry/opentelemetry-go/compare/v1.33.0...HEAD +[1.33.0/0.55.0/0.9.0/0.0.12]: https://github.com/open-telemetry/opentelemetry-go/releases/tag/v1.33.0 +[1.32.0/0.54.0/0.8.0/0.0.11]: https://github.com/open-telemetry/opentelemetry-go/releases/tag/v1.32.0 [1.31.0/0.53.0/0.7.0/0.0.10]: https://github.com/open-telemetry/opentelemetry-go/releases/tag/v1.31.0 [1.30.0/0.52.0/0.6.0/0.0.9]: https://github.com/open-telemetry/opentelemetry-go/releases/tag/v1.30.0 [1.29.0/0.51.0/0.5.0]: https://github.com/open-telemetry/opentelemetry-go/releases/tag/v1.29.0 diff --git a/vendor/go.opentelemetry.io/otel/CONTRIBUTING.md b/vendor/go.opentelemetry.io/otel/CONTRIBUTING.md index bb3396557..22a2e9dbd 100644 --- a/vendor/go.opentelemetry.io/otel/CONTRIBUTING.md +++ b/vendor/go.opentelemetry.io/otel/CONTRIBUTING.md @@ -629,6 +629,10 @@ should be canceled. ## Approvers and Maintainers +### Triagers + +- [Cheng-Zhen Yang](https://github.com/scorpionknifes), Independent + ### Approvers ### Maintainers @@ -641,13 +645,13 @@ should be canceled. ### Emeritus -- [Aaron Clawson](https://github.com/MadVikingGod), LightStep -- [Anthony Mirabella](https://github.com/Aneurysm9), AWS -- [Chester Cheung](https://github.com/hanyuancheung), Tencent -- [Evan Torrie](https://github.com/evantorrie), Yahoo -- [Gustavo Silva Paiva](https://github.com/paivagustavo), LightStep -- [Josh MacDonald](https://github.com/jmacd), LightStep -- [Liz Fong-Jones](https://github.com/lizthegrey), Honeycomb +- [Aaron Clawson](https://github.com/MadVikingGod) +- [Anthony Mirabella](https://github.com/Aneurysm9) +- [Chester Cheung](https://github.com/hanyuancheung) +- [Evan Torrie](https://github.com/evantorrie) +- [Gustavo Silva Paiva](https://github.com/paivagustavo) +- [Josh MacDonald](https://github.com/jmacd) +- [Liz Fong-Jones](https://github.com/lizthegrey) ### Become an Approver or a Maintainer diff --git a/vendor/go.opentelemetry.io/otel/Makefile b/vendor/go.opentelemetry.io/otel/Makefile index a1228a212..a7f6d8cc6 100644 --- a/vendor/go.opentelemetry.io/otel/Makefile +++ b/vendor/go.opentelemetry.io/otel/Makefile @@ -14,8 +14,8 @@ TIMEOUT = 60 .DEFAULT_GOAL := precommit .PHONY: precommit ci -precommit: generate license-check misspell go-mod-tidy golangci-lint-fix verify-readmes verify-mods test-default -ci: generate license-check lint vanity-import-check verify-readmes verify-mods build test-default check-clean-work-tree test-coverage +precommit: generate toolchain-check license-check misspell go-mod-tidy golangci-lint-fix verify-readmes verify-mods test-default +ci: generate toolchain-check license-check lint vanity-import-check verify-readmes verify-mods build test-default check-clean-work-tree test-coverage # Tools @@ -235,6 +235,16 @@ govulncheck/%: $(GOVULNCHECK) codespell: $(CODESPELL) @$(DOCKERPY) $(CODESPELL) +.PHONY: toolchain-check +toolchain-check: + @toolchainRes=$$(for f in $(ALL_GO_MOD_DIRS); do \ + awk '/^toolchain/ { found=1; next } END { if (found) print FILENAME }' $$f/go.mod; \ + done); \ + if [ -n "$${toolchainRes}" ]; then \ + echo "toolchain checking failed:"; echo "$${toolchainRes}"; \ + exit 1; \ + fi + .PHONY: license-check license-check: @licRes=$$(for f in $$(find . -type f \( -iname '*.go' -o -iname '*.sh' \) ! -path '**/third_party/*' ! -path './.git/*' ) ; do \ @@ -260,7 +270,7 @@ SEMCONVPKG ?= "semconv/" semconv-generate: $(SEMCONVGEN) $(SEMCONVKIT) [ "$(TAG)" ] || ( echo "TAG unset: missing opentelemetry semantic-conventions tag"; exit 1 ) [ "$(OTEL_SEMCONV_REPO)" ] || ( echo "OTEL_SEMCONV_REPO unset: missing path to opentelemetry semantic-conventions repo"; exit 1 ) - $(SEMCONVGEN) -i "$(OTEL_SEMCONV_REPO)/model/." --only=attribute_group -p conventionType=trace -f attribute_group.go -t "$(SEMCONVPKG)/template.j2" -s "$(TAG)" + $(SEMCONVGEN) -i "$(OTEL_SEMCONV_REPO)/model/." --only=attribute_group -p conventionType=trace -f attribute_group.go -z "$(SEMCONVPKG)/capitalizations.txt" -t "$(SEMCONVPKG)/template.j2" -s "$(TAG)" $(SEMCONVGEN) -i "$(OTEL_SEMCONV_REPO)/model/." --only=metric -f metric.go -t "$(SEMCONVPKG)/metric_template.j2" -s "$(TAG)" $(SEMCONVKIT) -output "$(SEMCONVPKG)/$(TAG)" -tag "$(TAG)" diff --git a/vendor/go.opentelemetry.io/otel/VERSIONING.md b/vendor/go.opentelemetry.io/otel/VERSIONING.md index 412f1e362..b8cb605c1 100644 --- a/vendor/go.opentelemetry.io/otel/VERSIONING.md +++ b/vendor/go.opentelemetry.io/otel/VERSIONING.md @@ -26,7 +26,7 @@ is designed so the following goals can be achieved. go.opentelemetry.io/otel/v2 v2.0.1`) and in the package import path (e.g., `import "go.opentelemetry.io/otel/v2/trace"`). This includes the paths used in `go get` commands (e.g., `go get - go.opentelemetry.io/otel/v2@v2.0.1`. Note there is both a `/v2` and a + go.opentelemetry.io/otel/v2@v2.0.1`). Note there is both a `/v2` and a `@v2.0.1` in that example. One way to think about it is that the module name now includes the `/v2`, so include `/v2` whenever you are using the module name). diff --git a/vendor/go.opentelemetry.io/otel/baggage/baggage.go b/vendor/go.opentelemetry.io/otel/baggage/baggage.go index 36f536703..0e1fe2422 100644 --- a/vendor/go.opentelemetry.io/otel/baggage/baggage.go +++ b/vendor/go.opentelemetry.io/otel/baggage/baggage.go @@ -355,7 +355,7 @@ func parseMember(member string) (Member, error) { } // replaceInvalidUTF8Sequences replaces invalid UTF-8 sequences with '�'. -func replaceInvalidUTF8Sequences(cap int, unescapeVal string) string { +func replaceInvalidUTF8Sequences(c int, unescapeVal string) string { if utf8.ValidString(unescapeVal) { return unescapeVal } @@ -363,7 +363,7 @@ func replaceInvalidUTF8Sequences(cap int, unescapeVal string) string { // https://github.com/w3c/baggage/blob/8c215efbeebd3fa4b1aceb937a747e56444f22f3/baggage/HTTP_HEADER_FORMAT.md?plain=1#L69 var b strings.Builder - b.Grow(cap) + b.Grow(c) for i := 0; i < len(unescapeVal); { r, size := utf8.DecodeRuneInString(unescapeVal[i:]) if r == utf8.RuneError && size == 1 { diff --git a/vendor/go.opentelemetry.io/otel/codes/codes.go b/vendor/go.opentelemetry.io/otel/codes/codes.go index 2acbac354..49a35b122 100644 --- a/vendor/go.opentelemetry.io/otel/codes/codes.go +++ b/vendor/go.opentelemetry.io/otel/codes/codes.go @@ -5,6 +5,7 @@ package codes // import "go.opentelemetry.io/otel/codes" import ( "encoding/json" + "errors" "fmt" "strconv" ) @@ -63,7 +64,7 @@ func (c *Code) UnmarshalJSON(b []byte) error { return nil } if c == nil { - return fmt.Errorf("nil receiver passed to UnmarshalJSON") + return errors.New("nil receiver passed to UnmarshalJSON") } var x interface{} diff --git a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform/instrumentation.go b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform/instrumentation.go index f6dd3decc..2e7690e43 100644 --- a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform/instrumentation.go +++ b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform/instrumentation.go @@ -13,7 +13,8 @@ func InstrumentationScope(il instrumentation.Scope) *commonpb.InstrumentationSco return nil } return &commonpb.InstrumentationScope{ - Name: il.Name, - Version: il.Version, + Name: il.Name, + Version: il.Version, + Attributes: Iterator(il.Attributes.Iter()), } } diff --git a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform/span.go b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform/span.go index c3c69c5a0..bf27ef022 100644 --- a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform/span.go +++ b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform/span.go @@ -4,6 +4,8 @@ package tracetransform // import "go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform" import ( + "math" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/sdk/instrumentation" @@ -95,16 +97,16 @@ func span(sd tracesdk.ReadOnlySpan) *tracepb.Span { SpanId: sid[:], TraceState: sd.SpanContext().TraceState().String(), Status: status(sd.Status().Code, sd.Status().Description), - StartTimeUnixNano: uint64(sd.StartTime().UnixNano()), - EndTimeUnixNano: uint64(sd.EndTime().UnixNano()), + StartTimeUnixNano: uint64(max(0, sd.StartTime().UnixNano())), // nolint:gosec // Overflow checked. + EndTimeUnixNano: uint64(max(0, sd.EndTime().UnixNano())), // nolint:gosec // Overflow checked. Links: links(sd.Links()), Kind: spanKind(sd.SpanKind()), Name: sd.Name(), Attributes: KeyValues(sd.Attributes()), Events: spanEvents(sd.Events()), - DroppedAttributesCount: uint32(sd.DroppedAttributes()), - DroppedEventsCount: uint32(sd.DroppedEvents()), - DroppedLinksCount: uint32(sd.DroppedLinks()), + DroppedAttributesCount: clampUint32(sd.DroppedAttributes()), + DroppedEventsCount: clampUint32(sd.DroppedEvents()), + DroppedLinksCount: clampUint32(sd.DroppedLinks()), } if psid := sd.Parent().SpanID(); psid.IsValid() { @@ -115,6 +117,16 @@ func span(sd tracesdk.ReadOnlySpan) *tracepb.Span { return s } +func clampUint32(v int) uint32 { + if v < 0 { + return 0 + } + if int64(v) > math.MaxUint32 { + return math.MaxUint32 + } + return uint32(v) // nolint: gosec // Overflow/Underflow checked. +} + // status transform a span code and message into an OTLP span status. func status(status codes.Code, message string) *tracepb.Status { var c tracepb.Status_StatusCode @@ -153,7 +165,7 @@ func links(links []tracesdk.Link) []*tracepb.Span_Link { TraceId: tid[:], SpanId: sid[:], Attributes: KeyValues(otLink.Attributes), - DroppedAttributesCount: uint32(otLink.DroppedAttributeCount), + DroppedAttributesCount: clampUint32(otLink.DroppedAttributeCount), Flags: flags, }) } @@ -166,7 +178,7 @@ func buildSpanFlags(sc trace.SpanContext) uint32 { flags |= tracepb.SpanFlags_SPAN_FLAGS_CONTEXT_IS_REMOTE_MASK } - return uint32(flags) + return uint32(flags) // nolint:gosec // Flags is a bitmask and can't be negative } // spanEvents transforms span Events to an OTLP span events. @@ -180,9 +192,9 @@ func spanEvents(es []tracesdk.Event) []*tracepb.Span_Event { for i := 0; i < len(es); i++ { events[i] = &tracepb.Span_Event{ Name: es[i].Name, - TimeUnixNano: uint64(es[i].Time.UnixNano()), + TimeUnixNano: uint64(max(0, es[i].Time.UnixNano())), // nolint:gosec // Overflow checked. Attributes: KeyValues(es[i].Attributes), - DroppedAttributesCount: uint32(es[i].DroppedAttributeCount), + DroppedAttributesCount: clampUint32(es[i].DroppedAttributeCount), } } return events diff --git a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/client.go b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/client.go index 3993df927..2171bee3c 100644 --- a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/client.go +++ b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/client.go @@ -229,7 +229,12 @@ func (c *client) exportContext(parent context.Context) (context.Context, context } if c.metadata.Len() > 0 { - ctx = metadata.NewOutgoingContext(ctx, c.metadata) + md := c.metadata + if outMD, ok := metadata.FromOutgoingContext(ctx); ok { + md = metadata.Join(md, outMD) + } + + ctx = metadata.NewOutgoingContext(ctx, md) } // Unify the client stopCtx with the parent. diff --git a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/doc.go b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/doc.go index e783b57ac..b7bd429ff 100644 --- a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/doc.go +++ b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/doc.go @@ -12,9 +12,8 @@ The environment variables described below can be used for configuration. OTEL_EXPORTER_OTLP_ENDPOINT, OTEL_EXPORTER_OTLP_TRACES_ENDPOINT (default: "https://localhost:4317") - target to which the exporter sends telemetry. The target syntax is defined in https://github.com/grpc/grpc/blob/master/doc/naming.md. -The value must contain a host. -The value may additionally a port, a scheme, and a path. -The value accepts "http" and "https" scheme. +The value must contain a scheme ("http" or "https") and host. +The value may additionally contain a port, and a path. The value should not contain a query string or fragment. OTEL_EXPORTER_OTLP_TRACES_ENDPOINT takes precedence over OTEL_EXPORTER_OTLP_ENDPOINT. The configuration can be overridden by [WithEndpoint], [WithEndpointURL], [WithInsecure], and [WithGRPCConn] options. diff --git a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/envconfig/envconfig.go b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/envconfig/envconfig.go index 9513c0a57..4abf48d1f 100644 --- a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/envconfig/envconfig.go +++ b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/envconfig/envconfig.go @@ -15,6 +15,7 @@ import ( "strconv" "strings" "time" + "unicode" "go.opentelemetry.io/otel/internal/global" ) @@ -163,12 +164,16 @@ func stringToHeader(value string) map[string]string { global.Error(errors.New("missing '="), "parse headers", "input", header) continue } - name, err := url.PathUnescape(n) - if err != nil { - global.Error(err, "escape header key", "key", n) + + trimmedName := strings.TrimSpace(n) + + // Validate the key. + if !isValidHeaderKey(trimmedName) { + global.Error(errors.New("invalid header key"), "parse headers", "key", trimmedName) continue } - trimmedName := strings.TrimSpace(name) + + // Only decode the value. value, err := url.PathUnescape(v) if err != nil { global.Error(err, "escape header value", "value", v) @@ -189,3 +194,22 @@ func createCertPool(certBytes []byte) (*x509.CertPool, error) { } return cp, nil } + +func isValidHeaderKey(key string) bool { + if key == "" { + return false + } + for _, c := range key { + if !isTokenChar(c) { + return false + } + } + return true +} + +func isTokenChar(c rune) bool { + return c <= unicode.MaxASCII && (unicode.IsLetter(c) || + unicode.IsDigit(c) || + c == '!' || c == '#' || c == '$' || c == '%' || c == '&' || c == '\'' || c == '*' || + c == '+' || c == '-' || c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~') +} diff --git a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/otlpconfig/options.go b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/otlpconfig/options.go index 8f84a7996..0a317d926 100644 --- a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/otlpconfig/options.go +++ b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/otlpconfig/options.go @@ -98,7 +98,7 @@ func cleanPath(urlPath string, defaultPath string) string { return defaultPath } if !path.IsAbs(tmp) { - tmp = fmt.Sprintf("/%s", tmp) + tmp = "/" + tmp } return tmp } @@ -125,7 +125,7 @@ func NewGRPCConfig(opts ...GRPCOption) Config { if cfg.ServiceConfig != "" { cfg.DialOptions = append(cfg.DialOptions, grpc.WithDefaultServiceConfig(cfg.ServiceConfig)) } - // Priroritize GRPCCredentials over Insecure (passing both is an error). + // Prioritize GRPCCredentials over Insecure (passing both is an error). if cfg.Traces.GRPCCredentials != nil { cfg.DialOptions = append(cfg.DialOptions, grpc.WithTransportCredentials(cfg.Traces.GRPCCredentials)) } else if cfg.Traces.Insecure { @@ -278,9 +278,7 @@ func WithEndpointURL(v string) GenericOption { cfg.Traces.Endpoint = u.Host cfg.Traces.URLPath = u.Path - if u.Scheme != "https" { - cfg.Traces.Insecure = true - } + cfg.Traces.Insecure = u.Scheme != "https" return cfg }) diff --git a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/version.go b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/version.go index 14ad8c33b..8ea156a09 100644 --- a/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/version.go +++ b/vendor/go.opentelemetry.io/otel/exporters/otlp/otlptrace/version.go @@ -5,5 +5,5 @@ package otlptrace // import "go.opentelemetry.io/otel/exporters/otlp/otlptrace" // Version is the current release version of the OpenTelemetry OTLP trace exporter in use. func Version() string { - return "1.28.0" + return "1.33.0" } diff --git a/vendor/go.opentelemetry.io/otel/internal/attribute/attribute.go b/vendor/go.opentelemetry.io/otel/internal/attribute/attribute.go index 822d84794..691d96c75 100644 --- a/vendor/go.opentelemetry.io/otel/internal/attribute/attribute.go +++ b/vendor/go.opentelemetry.io/otel/internal/attribute/attribute.go @@ -49,12 +49,11 @@ func AsBoolSlice(v interface{}) []bool { if rv.Type().Kind() != reflect.Array { return nil } - var zero bool - correctLen := rv.Len() - correctType := reflect.ArrayOf(correctLen, reflect.TypeOf(zero)) - cpy := reflect.New(correctType) - _ = reflect.Copy(cpy.Elem(), rv) - return cpy.Elem().Slice(0, correctLen).Interface().([]bool) + cpy := make([]bool, rv.Len()) + if len(cpy) > 0 { + _ = reflect.Copy(reflect.ValueOf(cpy), rv) + } + return cpy } // AsInt64Slice converts an int64 array into a slice into with same elements as array. @@ -63,12 +62,11 @@ func AsInt64Slice(v interface{}) []int64 { if rv.Type().Kind() != reflect.Array { return nil } - var zero int64 - correctLen := rv.Len() - correctType := reflect.ArrayOf(correctLen, reflect.TypeOf(zero)) - cpy := reflect.New(correctType) - _ = reflect.Copy(cpy.Elem(), rv) - return cpy.Elem().Slice(0, correctLen).Interface().([]int64) + cpy := make([]int64, rv.Len()) + if len(cpy) > 0 { + _ = reflect.Copy(reflect.ValueOf(cpy), rv) + } + return cpy } // AsFloat64Slice converts a float64 array into a slice into with same elements as array. @@ -77,12 +75,11 @@ func AsFloat64Slice(v interface{}) []float64 { if rv.Type().Kind() != reflect.Array { return nil } - var zero float64 - correctLen := rv.Len() - correctType := reflect.ArrayOf(correctLen, reflect.TypeOf(zero)) - cpy := reflect.New(correctType) - _ = reflect.Copy(cpy.Elem(), rv) - return cpy.Elem().Slice(0, correctLen).Interface().([]float64) + cpy := make([]float64, rv.Len()) + if len(cpy) > 0 { + _ = reflect.Copy(reflect.ValueOf(cpy), rv) + } + return cpy } // AsStringSlice converts a string array into a slice into with same elements as array. @@ -91,10 +88,9 @@ func AsStringSlice(v interface{}) []string { if rv.Type().Kind() != reflect.Array { return nil } - var zero string - correctLen := rv.Len() - correctType := reflect.ArrayOf(correctLen, reflect.TypeOf(zero)) - cpy := reflect.New(correctType) - _ = reflect.Copy(cpy.Elem(), rv) - return cpy.Elem().Slice(0, correctLen).Interface().([]string) + cpy := make([]string, rv.Len()) + if len(cpy) > 0 { + _ = reflect.Copy(reflect.ValueOf(cpy), rv) + } + return cpy } diff --git a/vendor/go.opentelemetry.io/otel/internal/global/instruments.go b/vendor/go.opentelemetry.io/otel/internal/global/instruments.go index 3a0cc42f6..ae92a4251 100644 --- a/vendor/go.opentelemetry.io/otel/internal/global/instruments.go +++ b/vendor/go.opentelemetry.io/otel/internal/global/instruments.go @@ -13,7 +13,7 @@ import ( // unwrapper unwraps to return the underlying instrument implementation. type unwrapper interface { - Unwrap() metric.Observable + unwrap() metric.Observable } type afCounter struct { @@ -40,7 +40,7 @@ func (i *afCounter) setDelegate(m metric.Meter) { i.delegate.Store(ctr) } -func (i *afCounter) Unwrap() metric.Observable { +func (i *afCounter) unwrap() metric.Observable { if ctr := i.delegate.Load(); ctr != nil { return ctr.(metric.Float64ObservableCounter) } @@ -71,7 +71,7 @@ func (i *afUpDownCounter) setDelegate(m metric.Meter) { i.delegate.Store(ctr) } -func (i *afUpDownCounter) Unwrap() metric.Observable { +func (i *afUpDownCounter) unwrap() metric.Observable { if ctr := i.delegate.Load(); ctr != nil { return ctr.(metric.Float64ObservableUpDownCounter) } @@ -102,7 +102,7 @@ func (i *afGauge) setDelegate(m metric.Meter) { i.delegate.Store(ctr) } -func (i *afGauge) Unwrap() metric.Observable { +func (i *afGauge) unwrap() metric.Observable { if ctr := i.delegate.Load(); ctr != nil { return ctr.(metric.Float64ObservableGauge) } @@ -133,7 +133,7 @@ func (i *aiCounter) setDelegate(m metric.Meter) { i.delegate.Store(ctr) } -func (i *aiCounter) Unwrap() metric.Observable { +func (i *aiCounter) unwrap() metric.Observable { if ctr := i.delegate.Load(); ctr != nil { return ctr.(metric.Int64ObservableCounter) } @@ -164,7 +164,7 @@ func (i *aiUpDownCounter) setDelegate(m metric.Meter) { i.delegate.Store(ctr) } -func (i *aiUpDownCounter) Unwrap() metric.Observable { +func (i *aiUpDownCounter) unwrap() metric.Observable { if ctr := i.delegate.Load(); ctr != nil { return ctr.(metric.Int64ObservableUpDownCounter) } @@ -195,7 +195,7 @@ func (i *aiGauge) setDelegate(m metric.Meter) { i.delegate.Store(ctr) } -func (i *aiGauge) Unwrap() metric.Observable { +func (i *aiGauge) unwrap() metric.Observable { if ctr := i.delegate.Load(); ctr != nil { return ctr.(metric.Int64ObservableGauge) } diff --git a/vendor/go.opentelemetry.io/otel/internal/global/meter.go b/vendor/go.opentelemetry.io/otel/internal/global/meter.go index e3db438a0..a6acd8dca 100644 --- a/vendor/go.opentelemetry.io/otel/internal/global/meter.go +++ b/vendor/go.opentelemetry.io/otel/internal/global/meter.go @@ -5,6 +5,7 @@ package global // import "go.opentelemetry.io/otel/internal/global" import ( "container/list" + "context" "reflect" "sync" @@ -66,6 +67,7 @@ func (p *meterProvider) Meter(name string, opts ...metric.MeterOption) metric.Me name: name, version: c.InstrumentationVersion(), schema: c.SchemaURL(), + attrs: c.InstrumentationAttributes(), } if p.meters == nil { @@ -472,8 +474,7 @@ func (m *meter) RegisterCallback(f metric.Callback, insts ...metric.Observable) defer m.mtx.Unlock() if m.delegate != nil { - insts = unwrapInstruments(insts) - return m.delegate.RegisterCallback(f, insts...) + return m.delegate.RegisterCallback(unwrapCallback(f), unwrapInstruments(insts)...) } reg := ®istration{instruments: insts, function: f} @@ -487,15 +488,11 @@ func (m *meter) RegisterCallback(f metric.Callback, insts ...metric.Observable) return reg, nil } -type wrapped interface { - unwrap() metric.Observable -} - func unwrapInstruments(instruments []metric.Observable) []metric.Observable { out := make([]metric.Observable, 0, len(instruments)) for _, inst := range instruments { - if in, ok := inst.(wrapped); ok { + if in, ok := inst.(unwrapper); ok { out = append(out, in.unwrap()) } else { out = append(out, inst) @@ -515,9 +512,61 @@ type registration struct { unregMu sync.Mutex } -func (c *registration) setDelegate(m metric.Meter) { - insts := unwrapInstruments(c.instruments) +type unwrapObs struct { + embedded.Observer + obs metric.Observer +} + +// unwrapFloat64Observable returns an expected metric.Float64Observable after +// unwrapping the global object. +func unwrapFloat64Observable(inst metric.Float64Observable) metric.Float64Observable { + if unwrapped, ok := inst.(unwrapper); ok { + if floatObs, ok := unwrapped.unwrap().(metric.Float64Observable); ok { + // Note: if the unwrapped object does not + // unwrap as an observable for either of the + // predicates here, it means an internal bug in + // this package. We avoid logging an error in + // this case, because the SDK has to try its + // own type conversion on the object. The SDK + // will see this and be forced to respond with + // its own error. + // + // This code uses a double-nested if statement + // to avoid creating a branch that is + // impossible to cover. + inst = floatObs + } + } + return inst +} + +// unwrapInt64Observable returns an expected metric.Int64Observable after +// unwrapping the global object. +func unwrapInt64Observable(inst metric.Int64Observable) metric.Int64Observable { + if unwrapped, ok := inst.(unwrapper); ok { + if unint, ok := unwrapped.unwrap().(metric.Int64Observable); ok { + // See the comment in unwrapFloat64Observable(). + inst = unint + } + } + return inst +} + +func (uo *unwrapObs) ObserveFloat64(inst metric.Float64Observable, value float64, opts ...metric.ObserveOption) { + uo.obs.ObserveFloat64(unwrapFloat64Observable(inst), value, opts...) +} + +func (uo *unwrapObs) ObserveInt64(inst metric.Int64Observable, value int64, opts ...metric.ObserveOption) { + uo.obs.ObserveInt64(unwrapInt64Observable(inst), value, opts...) +} +func unwrapCallback(f metric.Callback) metric.Callback { + return func(ctx context.Context, obs metric.Observer) error { + return f(ctx, &unwrapObs{obs: obs}) + } +} + +func (c *registration) setDelegate(m metric.Meter) { c.unregMu.Lock() defer c.unregMu.Unlock() @@ -526,7 +575,7 @@ func (c *registration) setDelegate(m metric.Meter) { return } - reg, err := m.RegisterCallback(c.function, insts...) + reg, err := m.RegisterCallback(unwrapCallback(c.function), unwrapInstruments(c.instruments)...) if err != nil { GetErrorHandler().Handle(err) return diff --git a/vendor/go.opentelemetry.io/otel/internal/global/trace.go b/vendor/go.opentelemetry.io/otel/internal/global/trace.go index e31f442b4..8982aa0dc 100644 --- a/vendor/go.opentelemetry.io/otel/internal/global/trace.go +++ b/vendor/go.opentelemetry.io/otel/internal/global/trace.go @@ -25,6 +25,7 @@ import ( "sync" "sync/atomic" + "go.opentelemetry.io/auto/sdk" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" @@ -87,6 +88,7 @@ func (p *tracerProvider) Tracer(name string, opts ...trace.TracerOption) trace.T name: name, version: c.InstrumentationVersion(), schema: c.SchemaURL(), + attrs: c.InstrumentationAttributes(), } if p.tracers == nil { @@ -102,7 +104,12 @@ func (p *tracerProvider) Tracer(name string, opts ...trace.TracerOption) trace.T return t } -type il struct{ name, version, schema string } +type il struct { + name string + version string + schema string + attrs attribute.Set +} // tracer is a placeholder for a trace.Tracer. // @@ -139,6 +146,30 @@ func (t *tracer) Start(ctx context.Context, name string, opts ...trace.SpanStart return delegate.(trace.Tracer).Start(ctx, name, opts...) } + return t.newSpan(ctx, autoInstEnabled, name, opts) +} + +// autoInstEnabled determines if the auto-instrumentation SDK span is returned +// from the tracer when not backed by a delegate and auto-instrumentation has +// attached to this process. +// +// The auto-instrumentation is expected to overwrite this value to true when it +// attaches. By default, this will point to false and mean a tracer will return +// a nonRecordingSpan by default. +var autoInstEnabled = new(bool) + +func (t *tracer) newSpan(ctx context.Context, autoSpan *bool, name string, opts []trace.SpanStartOption) (context.Context, trace.Span) { + // autoInstEnabled is passed to newSpan via the autoSpan parameter. This is + // so the auto-instrumentation can define a uprobe for (*t).newSpan and be + // provided with the address of the bool autoInstEnabled points to. It + // needs to be a parameter so that pointer can be reliably determined, it + // should not be read from the global. + + if *autoSpan { + tracer := sdk.TracerProvider().Tracer(t.name, t.opts...) + return tracer.Start(ctx, name, opts...) + } + s := nonRecordingSpan{sc: trace.SpanContextFromContext(ctx), tracer: t} ctx = trace.ContextWithSpan(ctx, s) return ctx, s diff --git a/vendor/go.opentelemetry.io/otel/sdk/instrumentation/scope.go b/vendor/go.opentelemetry.io/otel/sdk/instrumentation/scope.go index 728115045..34852a47b 100644 --- a/vendor/go.opentelemetry.io/otel/sdk/instrumentation/scope.go +++ b/vendor/go.opentelemetry.io/otel/sdk/instrumentation/scope.go @@ -3,6 +3,8 @@ package instrumentation // import "go.opentelemetry.io/otel/sdk/instrumentation" +import "go.opentelemetry.io/otel/attribute" + // Scope represents the instrumentation scope. type Scope struct { // Name is the name of the instrumentation scope. This should be the @@ -12,4 +14,6 @@ type Scope struct { Version string // SchemaURL of the telemetry emitted by the scope. SchemaURL string + // Attributes of the telemetry emitted by the scope. + Attributes attribute.Set } diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/auto.go b/vendor/go.opentelemetry.io/otel/sdk/resource/auto.go index 95a61d61d..c02aeefdd 100644 --- a/vendor/go.opentelemetry.io/otel/sdk/resource/auto.go +++ b/vendor/go.opentelemetry.io/otel/sdk/resource/auto.go @@ -7,7 +7,6 @@ import ( "context" "errors" "fmt" - "strings" ) // ErrPartialResource is returned by a detector when complete source @@ -57,62 +56,37 @@ func Detect(ctx context.Context, detectors ...Detector) (*Resource, error) { // these errors will be returned. Otherwise, nil is returned. func detect(ctx context.Context, res *Resource, detectors []Detector) error { var ( - r *Resource - errs detectErrs - err error + r *Resource + err error + e error ) for _, detector := range detectors { if detector == nil { continue } - r, err = detector.Detect(ctx) - if err != nil { - errs = append(errs, err) - if !errors.Is(err, ErrPartialResource) { + r, e = detector.Detect(ctx) + if e != nil { + err = errors.Join(err, e) + if !errors.Is(e, ErrPartialResource) { continue } } - r, err = Merge(res, r) - if err != nil { - errs = append(errs, err) + r, e = Merge(res, r) + if e != nil { + err = errors.Join(err, e) } *res = *r } - if len(errs) == 0 { - return nil - } - if errors.Is(errs, ErrSchemaURLConflict) { - // If there has been a merge conflict, ensure the resource has no - // schema URL. - res.schemaURL = "" - } - return errs -} - -type detectErrs []error - -func (e detectErrs) Error() string { - errStr := make([]string, len(e)) - for i, err := range e { - errStr[i] = fmt.Sprintf("* %s", err) - } - - format := "%d errors occurred detecting resource:\n\t%s" - return fmt.Sprintf(format, len(e), strings.Join(errStr, "\n\t")) -} + if err != nil { + if errors.Is(err, ErrSchemaURLConflict) { + // If there has been a merge conflict, ensure the resource has no + // schema URL. + res.schemaURL = "" + } -func (e detectErrs) Unwrap() error { - switch len(e) { - case 0: - return nil - case 1: - return e[0] + err = fmt.Errorf("error detecting resource: %w", err) } - return e[1:] -} - -func (e detectErrs) Is(target error) bool { - return len(e) != 0 && errors.Is(e[0], target) + return err } diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/builtin.go b/vendor/go.opentelemetry.io/otel/sdk/resource/builtin.go index 6ac1cdbf7..cf3c88e15 100644 --- a/vendor/go.opentelemetry.io/otel/sdk/resource/builtin.go +++ b/vendor/go.opentelemetry.io/otel/sdk/resource/builtin.go @@ -20,15 +20,13 @@ type ( // telemetrySDK is a Detector that provides information about // the OpenTelemetry SDK used. This Detector is included as a // builtin. If these resource attributes are not wanted, use - // the WithTelemetrySDK(nil) or WithoutBuiltin() options to - // explicitly disable them. + // resource.New() to explicitly disable them. telemetrySDK struct{} // host is a Detector that provides information about the host // being run on. This Detector is included as a builtin. If // these resource attributes are not wanted, use the - // WithHost(nil) or WithoutBuiltin() options to explicitly - // disable them. + // resource.New() to explicitly disable them. host struct{} stringDetector struct { diff --git a/vendor/go.opentelemetry.io/otel/sdk/trace/batch_span_processor.go b/vendor/go.opentelemetry.io/otel/sdk/trace/batch_span_processor.go index 4ce757dfd..ccc97e1b6 100644 --- a/vendor/go.opentelemetry.io/otel/sdk/trace/batch_span_processor.go +++ b/vendor/go.opentelemetry.io/otel/sdk/trace/batch_span_processor.go @@ -280,6 +280,7 @@ func (bsp *batchSpanProcessor) exportSpans(ctx context.Context) error { // // It is up to the exporter to implement any type of retry logic if a batch is failing // to be exported, since it is specific to the protocol and backend being sent to. + clear(bsp.batch) // Erase elements to let GC collect objects bsp.batch = bsp.batch[:0] if err != nil { diff --git a/vendor/go.opentelemetry.io/otel/sdk/trace/provider.go b/vendor/go.opentelemetry.io/otel/sdk/trace/provider.go index 14c2e5beb..185aa7c08 100644 --- a/vendor/go.opentelemetry.io/otel/sdk/trace/provider.go +++ b/vendor/go.opentelemetry.io/otel/sdk/trace/provider.go @@ -139,9 +139,10 @@ func (p *TracerProvider) Tracer(name string, opts ...trace.TracerOption) trace.T name = defaultTracerName } is := instrumentation.Scope{ - Name: name, - Version: c.InstrumentationVersion(), - SchemaURL: c.SchemaURL(), + Name: name, + Version: c.InstrumentationVersion(), + SchemaURL: c.SchemaURL(), + Attributes: c.InstrumentationAttributes(), } t, ok := func() (trace.Tracer, bool) { @@ -168,7 +169,7 @@ func (p *TracerProvider) Tracer(name string, opts ...trace.TracerOption) trace.T // slowing down all tracing consumers. // - Logging code may be instrumented with tracing and deadlock because it could try // acquiring the same non-reentrant mutex. - global.Info("Tracer created", "name", name, "version", is.Version, "schemaURL", is.SchemaURL) + global.Info("Tracer created", "name", name, "version", is.Version, "schemaURL", is.SchemaURL, "attributes", is.Attributes) } return t } diff --git a/vendor/go.opentelemetry.io/otel/sdk/trace/sampler_env.go b/vendor/go.opentelemetry.io/otel/sdk/trace/sampler_env.go index d2d1f7246..9b672a1d7 100644 --- a/vendor/go.opentelemetry.io/otel/sdk/trace/sampler_env.go +++ b/vendor/go.opentelemetry.io/otel/sdk/trace/sampler_env.go @@ -5,7 +5,6 @@ package trace // import "go.opentelemetry.io/otel/sdk/trace" import ( "errors" - "fmt" "os" "strconv" "strings" @@ -26,7 +25,7 @@ const ( type errUnsupportedSampler string func (e errUnsupportedSampler) Error() string { - return fmt.Sprintf("unsupported sampler: %s", string(e)) + return "unsupported sampler: " + string(e) } var ( @@ -39,7 +38,7 @@ type samplerArgParseError struct { } func (e samplerArgParseError) Error() string { - return fmt.Sprintf("parsing sampler argument: %s", e.parseErr.Error()) + return "parsing sampler argument: " + e.parseErr.Error() } func (e samplerArgParseError) Unwrap() error { diff --git a/vendor/go.opentelemetry.io/otel/sdk/trace/span.go b/vendor/go.opentelemetry.io/otel/sdk/trace/span.go index 730fb85c3..8f4fc3850 100644 --- a/vendor/go.opentelemetry.io/otel/sdk/trace/span.go +++ b/vendor/go.opentelemetry.io/otel/sdk/trace/span.go @@ -347,54 +347,99 @@ func truncateAttr(limit int, attr attribute.KeyValue) attribute.KeyValue { } switch attr.Value.Type() { case attribute.STRING: - if v := attr.Value.AsString(); len(v) > limit { - return attr.Key.String(safeTruncate(v, limit)) - } + v := attr.Value.AsString() + return attr.Key.String(truncate(limit, v)) case attribute.STRINGSLICE: v := attr.Value.AsStringSlice() for i := range v { - if len(v[i]) > limit { - v[i] = safeTruncate(v[i], limit) - } + v[i] = truncate(limit, v[i]) } return attr.Key.StringSlice(v) } return attr } -// safeTruncate truncates the string and guarantees valid UTF-8 is returned. -func safeTruncate(input string, limit int) string { - if trunc, ok := safeTruncateValidUTF8(input, limit); ok { - return trunc +// truncate returns a truncated version of s such that it contains less than +// the limit number of characters. Truncation is applied by returning the limit +// number of valid characters contained in s. +// +// If limit is negative, it returns the original string. +// +// UTF-8 is supported. When truncating, all invalid characters are dropped +// before applying truncation. +// +// If s already contains less than the limit number of bytes, it is returned +// unchanged. No invalid characters are removed. +func truncate(limit int, s string) string { + // This prioritize performance in the following order based on the most + // common expected use-cases. + // + // - Short values less than the default limit (128). + // - Strings with valid encodings that exceed the limit. + // - No limit. + // - Strings with invalid encodings that exceed the limit. + if limit < 0 || len(s) <= limit { + return s + } + + // Optimistically, assume all valid UTF-8. + var b strings.Builder + count := 0 + for i, c := range s { + if c != utf8.RuneError { + count++ + if count > limit { + return s[:i] + } + continue + } + + _, size := utf8.DecodeRuneInString(s[i:]) + if size == 1 { + // Invalid encoding. + b.Grow(len(s) - 1) + _, _ = b.WriteString(s[:i]) + s = s[i:] + break + } + } + + // Fast-path, no invalid input. + if b.Cap() == 0 { + return s } - trunc, _ := safeTruncateValidUTF8(strings.ToValidUTF8(input, ""), limit) - return trunc -} -// safeTruncateValidUTF8 returns a copy of the input string safely truncated to -// limit. The truncation is ensured to occur at the bounds of complete UTF-8 -// characters. If invalid encoding of UTF-8 is encountered, input is returned -// with false, otherwise, the truncated input will be returned with true. -func safeTruncateValidUTF8(input string, limit int) (string, bool) { - for cnt := 0; cnt <= limit; { - r, size := utf8.DecodeRuneInString(input[cnt:]) - if r == utf8.RuneError { - return input, false + // Truncate while validating UTF-8. + for i := 0; i < len(s) && count < limit; { + c := s[i] + if c < utf8.RuneSelf { + // Optimization for single byte runes (common case). + _ = b.WriteByte(c) + i++ + count++ + continue } - if cnt+size > limit { - return input[:cnt], true + _, size := utf8.DecodeRuneInString(s[i:]) + if size == 1 { + // We checked for all 1-byte runes above, this is a RuneError. + i++ + continue } - cnt += size + + _, _ = b.WriteString(s[i : i+size]) + i += size + count++ } - return input, true + + return b.String() } // End ends the span. This method does nothing if the span is already ended or // is not being recorded. // -// The only SpanOption currently supported is WithTimestamp which will set the -// end time for a Span's life-cycle. +// The only SpanEndOption currently supported are [trace.WithTimestamp], and +// [trace.WithStackTrace]. // // If this method is called while panicking an error event is added to the // Span before ending it and the panic is continued. @@ -639,10 +684,7 @@ func (s *recordingSpan) dedupeAttrsFromRecord(record map[attribute.Key]int) { record[a.Key] = len(unique) - 1 } } - // s.attributes have element types of attribute.KeyValue. These types are - // not pointers and they themselves do not contain pointer fields, - // therefore the duplicate values do not need to be zeroed for them to be - // garbage collected. + clear(s.attributes[len(unique):]) // Erase unneeded elements to let GC collect objects. s.attributes = unique } diff --git a/vendor/go.opentelemetry.io/otel/sdk/trace/tracetest/recorder.go b/vendor/go.opentelemetry.io/otel/sdk/trace/tracetest/recorder.go index 7aababbbf..732669a17 100644 --- a/vendor/go.opentelemetry.io/otel/sdk/trace/tracetest/recorder.go +++ b/vendor/go.opentelemetry.io/otel/sdk/trace/tracetest/recorder.go @@ -69,6 +69,19 @@ func (sr *SpanRecorder) Started() []sdktrace.ReadWriteSpan { return dst } +// Reset clears the recorded spans. +// +// This method is safe to be called concurrently. +func (sr *SpanRecorder) Reset() { + sr.startedMu.Lock() + sr.endedMu.Lock() + defer sr.startedMu.Unlock() + defer sr.endedMu.Unlock() + + sr.started = nil + sr.ended = nil +} + // Ended returns a copy of all ended spans that have been recorded. // // This method is safe to be called concurrently. diff --git a/vendor/go.opentelemetry.io/otel/sdk/version.go b/vendor/go.opentelemetry.io/otel/sdk/version.go index dc1eaa8e9..ba7db4889 100644 --- a/vendor/go.opentelemetry.io/otel/sdk/version.go +++ b/vendor/go.opentelemetry.io/otel/sdk/version.go @@ -5,5 +5,5 @@ package sdk // import "go.opentelemetry.io/otel/sdk" // Version is the current release version of the OpenTelemetry SDK in use. func Version() string { - return "1.31.0" + return "1.33.0" } diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/README.md b/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/README.md deleted file mode 100644 index 0b6cbe960..000000000 --- a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Semconv v1.24.0 - -[![PkgGoDev](https://pkg.go.dev/badge/go.opentelemetry.io/otel/semconv/v1.24.0)](https://pkg.go.dev/go.opentelemetry.io/otel/semconv/v1.24.0) diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/attribute_group.go b/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/attribute_group.go deleted file mode 100644 index 6e688345c..000000000 --- a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/attribute_group.go +++ /dev/null @@ -1,4387 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -// Code generated from semantic convention specification. DO NOT EDIT. - -package semconv // import "go.opentelemetry.io/otel/semconv/v1.24.0" - -import "go.opentelemetry.io/otel/attribute" - -// Describes FaaS attributes. -const ( - // FaaSInvokedNameKey is the attribute Key conforming to the - // "faas.invoked_name" semantic conventions. It represents the name of the - // invoked function. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'my-function' - // Note: SHOULD be equal to the `faas.name` resource attribute of the - // invoked function. - FaaSInvokedNameKey = attribute.Key("faas.invoked_name") - - // FaaSInvokedProviderKey is the attribute Key conforming to the - // "faas.invoked_provider" semantic conventions. It represents the cloud - // provider of the invoked function. - // - // Type: Enum - // RequirementLevel: Required - // Stability: experimental - // Note: SHOULD be equal to the `cloud.provider` resource attribute of the - // invoked function. - FaaSInvokedProviderKey = attribute.Key("faas.invoked_provider") - - // FaaSInvokedRegionKey is the attribute Key conforming to the - // "faas.invoked_region" semantic conventions. It represents the cloud - // region of the invoked function. - // - // Type: string - // RequirementLevel: ConditionallyRequired (For some cloud providers, like - // AWS or GCP, the region in which a function is hosted is essential to - // uniquely identify the function and also part of its endpoint. Since it's - // part of the endpoint being called, the region is always known to - // clients. In these cases, `faas.invoked_region` MUST be set accordingly. - // If the region is unknown to the client or not required for identifying - // the invoked function, setting `faas.invoked_region` is optional.) - // Stability: experimental - // Examples: 'eu-central-1' - // Note: SHOULD be equal to the `cloud.region` resource attribute of the - // invoked function. - FaaSInvokedRegionKey = attribute.Key("faas.invoked_region") - - // FaaSTriggerKey is the attribute Key conforming to the "faas.trigger" - // semantic conventions. It represents the type of the trigger which caused - // this function invocation. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - FaaSTriggerKey = attribute.Key("faas.trigger") -) - -var ( - // Alibaba Cloud - FaaSInvokedProviderAlibabaCloud = FaaSInvokedProviderKey.String("alibaba_cloud") - // Amazon Web Services - FaaSInvokedProviderAWS = FaaSInvokedProviderKey.String("aws") - // Microsoft Azure - FaaSInvokedProviderAzure = FaaSInvokedProviderKey.String("azure") - // Google Cloud Platform - FaaSInvokedProviderGCP = FaaSInvokedProviderKey.String("gcp") - // Tencent Cloud - FaaSInvokedProviderTencentCloud = FaaSInvokedProviderKey.String("tencent_cloud") -) - -var ( - // A response to some data source operation such as a database or filesystem read/write - FaaSTriggerDatasource = FaaSTriggerKey.String("datasource") - // To provide an answer to an inbound HTTP request - FaaSTriggerHTTP = FaaSTriggerKey.String("http") - // A function is set to be executed when messages are sent to a messaging system - FaaSTriggerPubsub = FaaSTriggerKey.String("pubsub") - // A function is scheduled to be executed regularly - FaaSTriggerTimer = FaaSTriggerKey.String("timer") - // If none of the others apply - FaaSTriggerOther = FaaSTriggerKey.String("other") -) - -// FaaSInvokedName returns an attribute KeyValue conforming to the -// "faas.invoked_name" semantic conventions. It represents the name of the -// invoked function. -func FaaSInvokedName(val string) attribute.KeyValue { - return FaaSInvokedNameKey.String(val) -} - -// FaaSInvokedRegion returns an attribute KeyValue conforming to the -// "faas.invoked_region" semantic conventions. It represents the cloud region -// of the invoked function. -func FaaSInvokedRegion(val string) attribute.KeyValue { - return FaaSInvokedRegionKey.String(val) -} - -// Attributes for Events represented using Log Records. -const ( - // EventNameKey is the attribute Key conforming to the "event.name" - // semantic conventions. It represents the identifies the class / type of - // event. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'browser.mouse.click', 'device.app.lifecycle' - // Note: Event names are subject to the same rules as [attribute - // names](https://github.com/open-telemetry/opentelemetry-specification/tree/v1.26.0/specification/common/attribute-naming.md). - // Notably, event names are namespaced to avoid collisions and provide a - // clean separation of semantics for events in separate domains like - // browser, mobile, and kubernetes. - EventNameKey = attribute.Key("event.name") -) - -// EventName returns an attribute KeyValue conforming to the "event.name" -// semantic conventions. It represents the identifies the class / type of -// event. -func EventName(val string) attribute.KeyValue { - return EventNameKey.String(val) -} - -// The attributes described in this section are rather generic. They may be -// used in any Log Record they apply to. -const ( - // LogRecordUIDKey is the attribute Key conforming to the "log.record.uid" - // semantic conventions. It represents a unique identifier for the Log - // Record. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '01ARZ3NDEKTSV4RRFFQ69G5FAV' - // Note: If an id is provided, other log records with the same id will be - // considered duplicates and can be removed safely. This means, that two - // distinguishable log records MUST have different values. - // The id MAY be an [Universally Unique Lexicographically Sortable - // Identifier (ULID)](https://github.com/ulid/spec), but other identifiers - // (e.g. UUID) may be used as needed. - LogRecordUIDKey = attribute.Key("log.record.uid") -) - -// LogRecordUID returns an attribute KeyValue conforming to the -// "log.record.uid" semantic conventions. It represents a unique identifier for -// the Log Record. -func LogRecordUID(val string) attribute.KeyValue { - return LogRecordUIDKey.String(val) -} - -// Describes Log attributes -const ( - // LogIostreamKey is the attribute Key conforming to the "log.iostream" - // semantic conventions. It represents the stream associated with the log. - // See below for a list of well-known values. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - LogIostreamKey = attribute.Key("log.iostream") -) - -var ( - // Logs from stdout stream - LogIostreamStdout = LogIostreamKey.String("stdout") - // Events from stderr stream - LogIostreamStderr = LogIostreamKey.String("stderr") -) - -// A file to which log was emitted. -const ( - // LogFileNameKey is the attribute Key conforming to the "log.file.name" - // semantic conventions. It represents the basename of the file. - // - // Type: string - // RequirementLevel: Recommended - // Stability: experimental - // Examples: 'audit.log' - LogFileNameKey = attribute.Key("log.file.name") - - // LogFileNameResolvedKey is the attribute Key conforming to the - // "log.file.name_resolved" semantic conventions. It represents the - // basename of the file, with symlinks resolved. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'uuid.log' - LogFileNameResolvedKey = attribute.Key("log.file.name_resolved") - - // LogFilePathKey is the attribute Key conforming to the "log.file.path" - // semantic conventions. It represents the full path to the file. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '/var/log/mysql/audit.log' - LogFilePathKey = attribute.Key("log.file.path") - - // LogFilePathResolvedKey is the attribute Key conforming to the - // "log.file.path_resolved" semantic conventions. It represents the full - // path to the file, with symlinks resolved. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '/var/lib/docker/uuid.log' - LogFilePathResolvedKey = attribute.Key("log.file.path_resolved") -) - -// LogFileName returns an attribute KeyValue conforming to the -// "log.file.name" semantic conventions. It represents the basename of the -// file. -func LogFileName(val string) attribute.KeyValue { - return LogFileNameKey.String(val) -} - -// LogFileNameResolved returns an attribute KeyValue conforming to the -// "log.file.name_resolved" semantic conventions. It represents the basename of -// the file, with symlinks resolved. -func LogFileNameResolved(val string) attribute.KeyValue { - return LogFileNameResolvedKey.String(val) -} - -// LogFilePath returns an attribute KeyValue conforming to the -// "log.file.path" semantic conventions. It represents the full path to the -// file. -func LogFilePath(val string) attribute.KeyValue { - return LogFilePathKey.String(val) -} - -// LogFilePathResolved returns an attribute KeyValue conforming to the -// "log.file.path_resolved" semantic conventions. It represents the full path -// to the file, with symlinks resolved. -func LogFilePathResolved(val string) attribute.KeyValue { - return LogFilePathResolvedKey.String(val) -} - -// Describes Database attributes -const ( - // PoolNameKey is the attribute Key conforming to the "pool.name" semantic - // conventions. It represents the name of the connection pool; unique - // within the instrumented application. In case the connection pool - // implementation doesn't provide a name, then the - // [db.connection_string](/docs/database/database-spans.md#connection-level-attributes) - // should be used - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'myDataSource' - PoolNameKey = attribute.Key("pool.name") - - // StateKey is the attribute Key conforming to the "state" semantic - // conventions. It represents the state of a connection in the pool - // - // Type: Enum - // RequirementLevel: Required - // Stability: experimental - // Examples: 'idle' - StateKey = attribute.Key("state") -) - -var ( - // idle - StateIdle = StateKey.String("idle") - // used - StateUsed = StateKey.String("used") -) - -// PoolName returns an attribute KeyValue conforming to the "pool.name" -// semantic conventions. It represents the name of the connection pool; unique -// within the instrumented application. In case the connection pool -// implementation doesn't provide a name, then the -// [db.connection_string](/docs/database/database-spans.md#connection-level-attributes) -// should be used -func PoolName(val string) attribute.KeyValue { - return PoolNameKey.String(val) -} - -// ASP.NET Core attributes -const ( - // AspnetcoreDiagnosticsHandlerTypeKey is the attribute Key conforming to - // the "aspnetcore.diagnostics.handler.type" semantic conventions. It - // represents the full type name of the - // [`IExceptionHandler`](https://learn.microsoft.com/dotnet/api/microsoft.aspnetcore.diagnostics.iexceptionhandler) - // implementation that handled the exception. - // - // Type: string - // RequirementLevel: ConditionallyRequired (if and only if the exception - // was handled by this handler.) - // Stability: experimental - // Examples: 'Contoso.MyHandler' - AspnetcoreDiagnosticsHandlerTypeKey = attribute.Key("aspnetcore.diagnostics.handler.type") - - // AspnetcoreRateLimitingPolicyKey is the attribute Key conforming to the - // "aspnetcore.rate_limiting.policy" semantic conventions. It represents - // the rate limiting policy name. - // - // Type: string - // RequirementLevel: ConditionallyRequired (if the matched endpoint for the - // request had a rate-limiting policy.) - // Stability: experimental - // Examples: 'fixed', 'sliding', 'token' - AspnetcoreRateLimitingPolicyKey = attribute.Key("aspnetcore.rate_limiting.policy") - - // AspnetcoreRateLimitingResultKey is the attribute Key conforming to the - // "aspnetcore.rate_limiting.result" semantic conventions. It represents - // the rate-limiting result, shows whether the lease was acquired or - // contains a rejection reason - // - // Type: Enum - // RequirementLevel: Required - // Stability: experimental - // Examples: 'acquired', 'request_canceled' - AspnetcoreRateLimitingResultKey = attribute.Key("aspnetcore.rate_limiting.result") - - // AspnetcoreRequestIsUnhandledKey is the attribute Key conforming to the - // "aspnetcore.request.is_unhandled" semantic conventions. It represents - // the flag indicating if request was handled by the application pipeline. - // - // Type: boolean - // RequirementLevel: ConditionallyRequired (if and only if the request was - // not handled.) - // Stability: experimental - // Examples: True - AspnetcoreRequestIsUnhandledKey = attribute.Key("aspnetcore.request.is_unhandled") - - // AspnetcoreRoutingIsFallbackKey is the attribute Key conforming to the - // "aspnetcore.routing.is_fallback" semantic conventions. It represents a - // value that indicates whether the matched route is a fallback route. - // - // Type: boolean - // RequirementLevel: ConditionallyRequired (If and only if a route was - // successfully matched.) - // Stability: experimental - // Examples: True - AspnetcoreRoutingIsFallbackKey = attribute.Key("aspnetcore.routing.is_fallback") -) - -var ( - // Lease was acquired - AspnetcoreRateLimitingResultAcquired = AspnetcoreRateLimitingResultKey.String("acquired") - // Lease request was rejected by the endpoint limiter - AspnetcoreRateLimitingResultEndpointLimiter = AspnetcoreRateLimitingResultKey.String("endpoint_limiter") - // Lease request was rejected by the global limiter - AspnetcoreRateLimitingResultGlobalLimiter = AspnetcoreRateLimitingResultKey.String("global_limiter") - // Lease request was canceled - AspnetcoreRateLimitingResultRequestCanceled = AspnetcoreRateLimitingResultKey.String("request_canceled") -) - -// AspnetcoreDiagnosticsHandlerType returns an attribute KeyValue conforming -// to the "aspnetcore.diagnostics.handler.type" semantic conventions. It -// represents the full type name of the -// [`IExceptionHandler`](https://learn.microsoft.com/dotnet/api/microsoft.aspnetcore.diagnostics.iexceptionhandler) -// implementation that handled the exception. -func AspnetcoreDiagnosticsHandlerType(val string) attribute.KeyValue { - return AspnetcoreDiagnosticsHandlerTypeKey.String(val) -} - -// AspnetcoreRateLimitingPolicy returns an attribute KeyValue conforming to -// the "aspnetcore.rate_limiting.policy" semantic conventions. It represents -// the rate limiting policy name. -func AspnetcoreRateLimitingPolicy(val string) attribute.KeyValue { - return AspnetcoreRateLimitingPolicyKey.String(val) -} - -// AspnetcoreRequestIsUnhandled returns an attribute KeyValue conforming to -// the "aspnetcore.request.is_unhandled" semantic conventions. It represents -// the flag indicating if request was handled by the application pipeline. -func AspnetcoreRequestIsUnhandled(val bool) attribute.KeyValue { - return AspnetcoreRequestIsUnhandledKey.Bool(val) -} - -// AspnetcoreRoutingIsFallback returns an attribute KeyValue conforming to -// the "aspnetcore.routing.is_fallback" semantic conventions. It represents a -// value that indicates whether the matched route is a fallback route. -func AspnetcoreRoutingIsFallback(val bool) attribute.KeyValue { - return AspnetcoreRoutingIsFallbackKey.Bool(val) -} - -// SignalR attributes -const ( - // SignalrConnectionStatusKey is the attribute Key conforming to the - // "signalr.connection.status" semantic conventions. It represents the - // signalR HTTP connection closure status. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'app_shutdown', 'timeout' - SignalrConnectionStatusKey = attribute.Key("signalr.connection.status") - - // SignalrTransportKey is the attribute Key conforming to the - // "signalr.transport" semantic conventions. It represents the [SignalR - // transport - // type](https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/TransportProtocols.md) - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'web_sockets', 'long_polling' - SignalrTransportKey = attribute.Key("signalr.transport") -) - -var ( - // The connection was closed normally - SignalrConnectionStatusNormalClosure = SignalrConnectionStatusKey.String("normal_closure") - // The connection was closed due to a timeout - SignalrConnectionStatusTimeout = SignalrConnectionStatusKey.String("timeout") - // The connection was closed because the app is shutting down - SignalrConnectionStatusAppShutdown = SignalrConnectionStatusKey.String("app_shutdown") -) - -var ( - // ServerSentEvents protocol - SignalrTransportServerSentEvents = SignalrTransportKey.String("server_sent_events") - // LongPolling protocol - SignalrTransportLongPolling = SignalrTransportKey.String("long_polling") - // WebSockets protocol - SignalrTransportWebSockets = SignalrTransportKey.String("web_sockets") -) - -// Describes JVM buffer metric attributes. -const ( - // JvmBufferPoolNameKey is the attribute Key conforming to the - // "jvm.buffer.pool.name" semantic conventions. It represents the name of - // the buffer pool. - // - // Type: string - // RequirementLevel: Recommended - // Stability: experimental - // Examples: 'mapped', 'direct' - // Note: Pool names are generally obtained via - // [BufferPoolMXBean#getName()](https://docs.oracle.com/en/java/javase/11/docs/api/java.management/java/lang/management/BufferPoolMXBean.html#getName()). - JvmBufferPoolNameKey = attribute.Key("jvm.buffer.pool.name") -) - -// JvmBufferPoolName returns an attribute KeyValue conforming to the -// "jvm.buffer.pool.name" semantic conventions. It represents the name of the -// buffer pool. -func JvmBufferPoolName(val string) attribute.KeyValue { - return JvmBufferPoolNameKey.String(val) -} - -// Describes JVM memory metric attributes. -const ( - // JvmMemoryPoolNameKey is the attribute Key conforming to the - // "jvm.memory.pool.name" semantic conventions. It represents the name of - // the memory pool. - // - // Type: string - // RequirementLevel: Recommended - // Stability: stable - // Examples: 'G1 Old Gen', 'G1 Eden space', 'G1 Survivor Space' - // Note: Pool names are generally obtained via - // [MemoryPoolMXBean#getName()](https://docs.oracle.com/en/java/javase/11/docs/api/java.management/java/lang/management/MemoryPoolMXBean.html#getName()). - JvmMemoryPoolNameKey = attribute.Key("jvm.memory.pool.name") - - // JvmMemoryTypeKey is the attribute Key conforming to the - // "jvm.memory.type" semantic conventions. It represents the type of - // memory. - // - // Type: Enum - // RequirementLevel: Recommended - // Stability: stable - // Examples: 'heap', 'non_heap' - JvmMemoryTypeKey = attribute.Key("jvm.memory.type") -) - -var ( - // Heap memory - JvmMemoryTypeHeap = JvmMemoryTypeKey.String("heap") - // Non-heap memory - JvmMemoryTypeNonHeap = JvmMemoryTypeKey.String("non_heap") -) - -// JvmMemoryPoolName returns an attribute KeyValue conforming to the -// "jvm.memory.pool.name" semantic conventions. It represents the name of the -// memory pool. -func JvmMemoryPoolName(val string) attribute.KeyValue { - return JvmMemoryPoolNameKey.String(val) -} - -// Describes System metric attributes -const ( - // SystemDeviceKey is the attribute Key conforming to the "system.device" - // semantic conventions. It represents the device identifier - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '(identifier)' - SystemDeviceKey = attribute.Key("system.device") -) - -// SystemDevice returns an attribute KeyValue conforming to the -// "system.device" semantic conventions. It represents the device identifier -func SystemDevice(val string) attribute.KeyValue { - return SystemDeviceKey.String(val) -} - -// Describes System CPU metric attributes -const ( - // SystemCPULogicalNumberKey is the attribute Key conforming to the - // "system.cpu.logical_number" semantic conventions. It represents the - // logical CPU number [0..n-1] - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 1 - SystemCPULogicalNumberKey = attribute.Key("system.cpu.logical_number") - - // SystemCPUStateKey is the attribute Key conforming to the - // "system.cpu.state" semantic conventions. It represents the state of the - // CPU - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'idle', 'interrupt' - SystemCPUStateKey = attribute.Key("system.cpu.state") -) - -var ( - // user - SystemCPUStateUser = SystemCPUStateKey.String("user") - // system - SystemCPUStateSystem = SystemCPUStateKey.String("system") - // nice - SystemCPUStateNice = SystemCPUStateKey.String("nice") - // idle - SystemCPUStateIdle = SystemCPUStateKey.String("idle") - // iowait - SystemCPUStateIowait = SystemCPUStateKey.String("iowait") - // interrupt - SystemCPUStateInterrupt = SystemCPUStateKey.String("interrupt") - // steal - SystemCPUStateSteal = SystemCPUStateKey.String("steal") -) - -// SystemCPULogicalNumber returns an attribute KeyValue conforming to the -// "system.cpu.logical_number" semantic conventions. It represents the logical -// CPU number [0..n-1] -func SystemCPULogicalNumber(val int) attribute.KeyValue { - return SystemCPULogicalNumberKey.Int(val) -} - -// Describes System Memory metric attributes -const ( - // SystemMemoryStateKey is the attribute Key conforming to the - // "system.memory.state" semantic conventions. It represents the memory - // state - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'free', 'cached' - SystemMemoryStateKey = attribute.Key("system.memory.state") -) - -var ( - // used - SystemMemoryStateUsed = SystemMemoryStateKey.String("used") - // free - SystemMemoryStateFree = SystemMemoryStateKey.String("free") - // shared - SystemMemoryStateShared = SystemMemoryStateKey.String("shared") - // buffers - SystemMemoryStateBuffers = SystemMemoryStateKey.String("buffers") - // cached - SystemMemoryStateCached = SystemMemoryStateKey.String("cached") -) - -// Describes System Memory Paging metric attributes -const ( - // SystemPagingDirectionKey is the attribute Key conforming to the - // "system.paging.direction" semantic conventions. It represents the paging - // access direction - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'in' - SystemPagingDirectionKey = attribute.Key("system.paging.direction") - - // SystemPagingStateKey is the attribute Key conforming to the - // "system.paging.state" semantic conventions. It represents the memory - // paging state - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'free' - SystemPagingStateKey = attribute.Key("system.paging.state") - - // SystemPagingTypeKey is the attribute Key conforming to the - // "system.paging.type" semantic conventions. It represents the memory - // paging type - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'minor' - SystemPagingTypeKey = attribute.Key("system.paging.type") -) - -var ( - // in - SystemPagingDirectionIn = SystemPagingDirectionKey.String("in") - // out - SystemPagingDirectionOut = SystemPagingDirectionKey.String("out") -) - -var ( - // used - SystemPagingStateUsed = SystemPagingStateKey.String("used") - // free - SystemPagingStateFree = SystemPagingStateKey.String("free") -) - -var ( - // major - SystemPagingTypeMajor = SystemPagingTypeKey.String("major") - // minor - SystemPagingTypeMinor = SystemPagingTypeKey.String("minor") -) - -// Describes Filesystem metric attributes -const ( - // SystemFilesystemModeKey is the attribute Key conforming to the - // "system.filesystem.mode" semantic conventions. It represents the - // filesystem mode - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'rw, ro' - SystemFilesystemModeKey = attribute.Key("system.filesystem.mode") - - // SystemFilesystemMountpointKey is the attribute Key conforming to the - // "system.filesystem.mountpoint" semantic conventions. It represents the - // filesystem mount path - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '/mnt/data' - SystemFilesystemMountpointKey = attribute.Key("system.filesystem.mountpoint") - - // SystemFilesystemStateKey is the attribute Key conforming to the - // "system.filesystem.state" semantic conventions. It represents the - // filesystem state - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'used' - SystemFilesystemStateKey = attribute.Key("system.filesystem.state") - - // SystemFilesystemTypeKey is the attribute Key conforming to the - // "system.filesystem.type" semantic conventions. It represents the - // filesystem type - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'ext4' - SystemFilesystemTypeKey = attribute.Key("system.filesystem.type") -) - -var ( - // used - SystemFilesystemStateUsed = SystemFilesystemStateKey.String("used") - // free - SystemFilesystemStateFree = SystemFilesystemStateKey.String("free") - // reserved - SystemFilesystemStateReserved = SystemFilesystemStateKey.String("reserved") -) - -var ( - // fat32 - SystemFilesystemTypeFat32 = SystemFilesystemTypeKey.String("fat32") - // exfat - SystemFilesystemTypeExfat = SystemFilesystemTypeKey.String("exfat") - // ntfs - SystemFilesystemTypeNtfs = SystemFilesystemTypeKey.String("ntfs") - // refs - SystemFilesystemTypeRefs = SystemFilesystemTypeKey.String("refs") - // hfsplus - SystemFilesystemTypeHfsplus = SystemFilesystemTypeKey.String("hfsplus") - // ext4 - SystemFilesystemTypeExt4 = SystemFilesystemTypeKey.String("ext4") -) - -// SystemFilesystemMode returns an attribute KeyValue conforming to the -// "system.filesystem.mode" semantic conventions. It represents the filesystem -// mode -func SystemFilesystemMode(val string) attribute.KeyValue { - return SystemFilesystemModeKey.String(val) -} - -// SystemFilesystemMountpoint returns an attribute KeyValue conforming to -// the "system.filesystem.mountpoint" semantic conventions. It represents the -// filesystem mount path -func SystemFilesystemMountpoint(val string) attribute.KeyValue { - return SystemFilesystemMountpointKey.String(val) -} - -// Describes Network metric attributes -const ( - // SystemNetworkStateKey is the attribute Key conforming to the - // "system.network.state" semantic conventions. It represents a stateless - // protocol MUST NOT set this attribute - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'close_wait' - SystemNetworkStateKey = attribute.Key("system.network.state") -) - -var ( - // close - SystemNetworkStateClose = SystemNetworkStateKey.String("close") - // close_wait - SystemNetworkStateCloseWait = SystemNetworkStateKey.String("close_wait") - // closing - SystemNetworkStateClosing = SystemNetworkStateKey.String("closing") - // delete - SystemNetworkStateDelete = SystemNetworkStateKey.String("delete") - // established - SystemNetworkStateEstablished = SystemNetworkStateKey.String("established") - // fin_wait_1 - SystemNetworkStateFinWait1 = SystemNetworkStateKey.String("fin_wait_1") - // fin_wait_2 - SystemNetworkStateFinWait2 = SystemNetworkStateKey.String("fin_wait_2") - // last_ack - SystemNetworkStateLastAck = SystemNetworkStateKey.String("last_ack") - // listen - SystemNetworkStateListen = SystemNetworkStateKey.String("listen") - // syn_recv - SystemNetworkStateSynRecv = SystemNetworkStateKey.String("syn_recv") - // syn_sent - SystemNetworkStateSynSent = SystemNetworkStateKey.String("syn_sent") - // time_wait - SystemNetworkStateTimeWait = SystemNetworkStateKey.String("time_wait") -) - -// Describes System Process metric attributes -const ( - // SystemProcessesStatusKey is the attribute Key conforming to the - // "system.processes.status" semantic conventions. It represents the - // process state, e.g., [Linux Process State - // Codes](https://man7.org/linux/man-pages/man1/ps.1.html#PROCESS_STATE_CODES) - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'running' - SystemProcessesStatusKey = attribute.Key("system.processes.status") -) - -var ( - // running - SystemProcessesStatusRunning = SystemProcessesStatusKey.String("running") - // sleeping - SystemProcessesStatusSleeping = SystemProcessesStatusKey.String("sleeping") - // stopped - SystemProcessesStatusStopped = SystemProcessesStatusKey.String("stopped") - // defunct - SystemProcessesStatusDefunct = SystemProcessesStatusKey.String("defunct") -) - -// These attributes may be used to describe the client in a connection-based -// network interaction where there is one side that initiates the connection -// (the client is the side that initiates the connection). This covers all TCP -// network interactions since TCP is connection-based and one side initiates -// the connection (an exception is made for peer-to-peer communication over TCP -// where the "user-facing" surface of the protocol / API doesn't expose a clear -// notion of client and server). This also covers UDP network interactions -// where one side initiates the interaction, e.g. QUIC (HTTP/3) and DNS. -const ( - // ClientAddressKey is the attribute Key conforming to the "client.address" - // semantic conventions. It represents the client address - domain name if - // available without reverse DNS lookup; otherwise, IP address or Unix - // domain socket name. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: 'client.example.com', '10.1.2.80', '/tmp/my.sock' - // Note: When observed from the server side, and when communicating through - // an intermediary, `client.address` SHOULD represent the client address - // behind any intermediaries, for example proxies, if it's available. - ClientAddressKey = attribute.Key("client.address") - - // ClientPortKey is the attribute Key conforming to the "client.port" - // semantic conventions. It represents the client port number. - // - // Type: int - // RequirementLevel: Optional - // Stability: stable - // Examples: 65123 - // Note: When observed from the server side, and when communicating through - // an intermediary, `client.port` SHOULD represent the client port behind - // any intermediaries, for example proxies, if it's available. - ClientPortKey = attribute.Key("client.port") -) - -// ClientAddress returns an attribute KeyValue conforming to the -// "client.address" semantic conventions. It represents the client address - -// domain name if available without reverse DNS lookup; otherwise, IP address -// or Unix domain socket name. -func ClientAddress(val string) attribute.KeyValue { - return ClientAddressKey.String(val) -} - -// ClientPort returns an attribute KeyValue conforming to the "client.port" -// semantic conventions. It represents the client port number. -func ClientPort(val int) attribute.KeyValue { - return ClientPortKey.Int(val) -} - -// The attributes used to describe telemetry in the context of databases. -const ( - // DBCassandraConsistencyLevelKey is the attribute Key conforming to the - // "db.cassandra.consistency_level" semantic conventions. It represents the - // consistency level of the query. Based on consistency values from - // [CQL](https://docs.datastax.com/en/cassandra-oss/3.0/cassandra/dml/dmlConfigConsistency.html). - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - DBCassandraConsistencyLevelKey = attribute.Key("db.cassandra.consistency_level") - - // DBCassandraCoordinatorDCKey is the attribute Key conforming to the - // "db.cassandra.coordinator.dc" semantic conventions. It represents the - // data center of the coordinating node for a query. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'us-west-2' - DBCassandraCoordinatorDCKey = attribute.Key("db.cassandra.coordinator.dc") - - // DBCassandraCoordinatorIDKey is the attribute Key conforming to the - // "db.cassandra.coordinator.id" semantic conventions. It represents the ID - // of the coordinating node for a query. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'be13faa2-8574-4d71-926d-27f16cf8a7af' - DBCassandraCoordinatorIDKey = attribute.Key("db.cassandra.coordinator.id") - - // DBCassandraIdempotenceKey is the attribute Key conforming to the - // "db.cassandra.idempotence" semantic conventions. It represents the - // whether or not the query is idempotent. - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - DBCassandraIdempotenceKey = attribute.Key("db.cassandra.idempotence") - - // DBCassandraPageSizeKey is the attribute Key conforming to the - // "db.cassandra.page_size" semantic conventions. It represents the fetch - // size used for paging, i.e. how many rows will be returned at once. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 5000 - DBCassandraPageSizeKey = attribute.Key("db.cassandra.page_size") - - // DBCassandraSpeculativeExecutionCountKey is the attribute Key conforming - // to the "db.cassandra.speculative_execution_count" semantic conventions. - // It represents the number of times a query was speculatively executed. - // Not set or `0` if the query was not executed speculatively. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 0, 2 - DBCassandraSpeculativeExecutionCountKey = attribute.Key("db.cassandra.speculative_execution_count") - - // DBCassandraTableKey is the attribute Key conforming to the - // "db.cassandra.table" semantic conventions. It represents the name of the - // primary Cassandra table that the operation is acting upon, including the - // keyspace name (if applicable). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'mytable' - // Note: This mirrors the db.sql.table attribute but references cassandra - // rather than sql. It is not recommended to attempt any client-side - // parsing of `db.statement` just to get this property, but it should be - // set if it is provided by the library being instrumented. If the - // operation is acting upon an anonymous table, or more than one table, - // this value MUST NOT be set. - DBCassandraTableKey = attribute.Key("db.cassandra.table") - - // DBConnectionStringKey is the attribute Key conforming to the - // "db.connection_string" semantic conventions. It represents the - // connection string used to connect to the database. It is recommended to - // remove embedded credentials. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Server=(localdb)\\v11.0;Integrated Security=true;' - DBConnectionStringKey = attribute.Key("db.connection_string") - - // DBCosmosDBClientIDKey is the attribute Key conforming to the - // "db.cosmosdb.client_id" semantic conventions. It represents the unique - // Cosmos client instance id. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '3ba4827d-4422-483f-b59f-85b74211c11d' - DBCosmosDBClientIDKey = attribute.Key("db.cosmosdb.client_id") - - // DBCosmosDBConnectionModeKey is the attribute Key conforming to the - // "db.cosmosdb.connection_mode" semantic conventions. It represents the - // cosmos client connection mode. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - DBCosmosDBConnectionModeKey = attribute.Key("db.cosmosdb.connection_mode") - - // DBCosmosDBContainerKey is the attribute Key conforming to the - // "db.cosmosdb.container" semantic conventions. It represents the cosmos - // DB container name. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'anystring' - DBCosmosDBContainerKey = attribute.Key("db.cosmosdb.container") - - // DBCosmosDBOperationTypeKey is the attribute Key conforming to the - // "db.cosmosdb.operation_type" semantic conventions. It represents the - // cosmosDB Operation Type. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - DBCosmosDBOperationTypeKey = attribute.Key("db.cosmosdb.operation_type") - - // DBCosmosDBRequestChargeKey is the attribute Key conforming to the - // "db.cosmosdb.request_charge" semantic conventions. It represents the rU - // consumed for that operation - // - // Type: double - // RequirementLevel: Optional - // Stability: experimental - // Examples: 46.18, 1.0 - DBCosmosDBRequestChargeKey = attribute.Key("db.cosmosdb.request_charge") - - // DBCosmosDBRequestContentLengthKey is the attribute Key conforming to the - // "db.cosmosdb.request_content_length" semantic conventions. It represents - // the request payload size in bytes - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - DBCosmosDBRequestContentLengthKey = attribute.Key("db.cosmosdb.request_content_length") - - // DBCosmosDBStatusCodeKey is the attribute Key conforming to the - // "db.cosmosdb.status_code" semantic conventions. It represents the cosmos - // DB status code. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 200, 201 - DBCosmosDBStatusCodeKey = attribute.Key("db.cosmosdb.status_code") - - // DBCosmosDBSubStatusCodeKey is the attribute Key conforming to the - // "db.cosmosdb.sub_status_code" semantic conventions. It represents the - // cosmos DB sub status code. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 1000, 1002 - DBCosmosDBSubStatusCodeKey = attribute.Key("db.cosmosdb.sub_status_code") - - // DBElasticsearchClusterNameKey is the attribute Key conforming to the - // "db.elasticsearch.cluster.name" semantic conventions. It represents the - // represents the identifier of an Elasticsearch cluster. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'e9106fc68e3044f0b1475b04bf4ffd5f' - DBElasticsearchClusterNameKey = attribute.Key("db.elasticsearch.cluster.name") - - // DBElasticsearchNodeNameKey is the attribute Key conforming to the - // "db.elasticsearch.node.name" semantic conventions. It represents the - // represents the human-readable identifier of the node/instance to which a - // request was routed. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'instance-0000000001' - DBElasticsearchNodeNameKey = attribute.Key("db.elasticsearch.node.name") - - // DBInstanceIDKey is the attribute Key conforming to the "db.instance.id" - // semantic conventions. It represents an identifier (address, unique name, - // or any other identifier) of the database instance that is executing - // queries or mutations on the current connection. This is useful in cases - // where the database is running in a clustered environment and the - // instrumentation is able to record the node executing the query. The - // client may obtain this value in databases like MySQL using queries like - // `select @@hostname`. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'mysql-e26b99z.example.com' - DBInstanceIDKey = attribute.Key("db.instance.id") - - // DBJDBCDriverClassnameKey is the attribute Key conforming to the - // "db.jdbc.driver_classname" semantic conventions. It represents the - // fully-qualified class name of the [Java Database Connectivity - // (JDBC)](https://docs.oracle.com/javase/8/docs/technotes/guides/jdbc/) - // driver used to connect. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'org.postgresql.Driver', - // 'com.microsoft.sqlserver.jdbc.SQLServerDriver' - DBJDBCDriverClassnameKey = attribute.Key("db.jdbc.driver_classname") - - // DBMongoDBCollectionKey is the attribute Key conforming to the - // "db.mongodb.collection" semantic conventions. It represents the MongoDB - // collection being accessed within the database stated in `db.name`. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'customers', 'products' - DBMongoDBCollectionKey = attribute.Key("db.mongodb.collection") - - // DBMSSQLInstanceNameKey is the attribute Key conforming to the - // "db.mssql.instance_name" semantic conventions. It represents the - // Microsoft SQL Server [instance - // name](https://docs.microsoft.com/sql/connect/jdbc/building-the-connection-url?view=sql-server-ver15) - // connecting to. This name is used to determine the port of a named - // instance. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'MSSQLSERVER' - // Note: If setting a `db.mssql.instance_name`, `server.port` is no longer - // required (but still recommended if non-standard). - DBMSSQLInstanceNameKey = attribute.Key("db.mssql.instance_name") - - // DBNameKey is the attribute Key conforming to the "db.name" semantic - // conventions. It represents the this attribute is used to report the name - // of the database being accessed. For commands that switch the database, - // this should be set to the target database (even if the command fails). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'customers', 'main' - // Note: In some SQL databases, the database name to be used is called - // "schema name". In case there are multiple layers that could be - // considered for database name (e.g. Oracle instance name and schema - // name), the database name to be used is the more specific layer (e.g. - // Oracle schema name). - DBNameKey = attribute.Key("db.name") - - // DBOperationKey is the attribute Key conforming to the "db.operation" - // semantic conventions. It represents the name of the operation being - // executed, e.g. the [MongoDB command - // name](https://docs.mongodb.com/manual/reference/command/#database-operations) - // such as `findAndModify`, or the SQL keyword. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'findAndModify', 'HMSET', 'SELECT' - // Note: When setting this to an SQL keyword, it is not recommended to - // attempt any client-side parsing of `db.statement` just to get this - // property, but it should be set if the operation name is provided by the - // library being instrumented. If the SQL statement has an ambiguous - // operation, or performs more than one operation, this value may be - // omitted. - DBOperationKey = attribute.Key("db.operation") - - // DBRedisDBIndexKey is the attribute Key conforming to the - // "db.redis.database_index" semantic conventions. It represents the index - // of the database being accessed as used in the [`SELECT` - // command](https://redis.io/commands/select), provided as an integer. To - // be used instead of the generic `db.name` attribute. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 0, 1, 15 - DBRedisDBIndexKey = attribute.Key("db.redis.database_index") - - // DBSQLTableKey is the attribute Key conforming to the "db.sql.table" - // semantic conventions. It represents the name of the primary table that - // the operation is acting upon, including the database name (if - // applicable). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'public.users', 'customers' - // Note: It is not recommended to attempt any client-side parsing of - // `db.statement` just to get this property, but it should be set if it is - // provided by the library being instrumented. If the operation is acting - // upon an anonymous table, or more than one table, this value MUST NOT be - // set. - DBSQLTableKey = attribute.Key("db.sql.table") - - // DBStatementKey is the attribute Key conforming to the "db.statement" - // semantic conventions. It represents the database statement being - // executed. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'SELECT * FROM wuser_table', 'SET mykey "WuValue"' - DBStatementKey = attribute.Key("db.statement") - - // DBSystemKey is the attribute Key conforming to the "db.system" semantic - // conventions. It represents an identifier for the database management - // system (DBMS) product being used. See below for a list of well-known - // identifiers. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - DBSystemKey = attribute.Key("db.system") - - // DBUserKey is the attribute Key conforming to the "db.user" semantic - // conventions. It represents the username for accessing the database. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'readonly_user', 'reporting_user' - DBUserKey = attribute.Key("db.user") -) - -var ( - // all - DBCassandraConsistencyLevelAll = DBCassandraConsistencyLevelKey.String("all") - // each_quorum - DBCassandraConsistencyLevelEachQuorum = DBCassandraConsistencyLevelKey.String("each_quorum") - // quorum - DBCassandraConsistencyLevelQuorum = DBCassandraConsistencyLevelKey.String("quorum") - // local_quorum - DBCassandraConsistencyLevelLocalQuorum = DBCassandraConsistencyLevelKey.String("local_quorum") - // one - DBCassandraConsistencyLevelOne = DBCassandraConsistencyLevelKey.String("one") - // two - DBCassandraConsistencyLevelTwo = DBCassandraConsistencyLevelKey.String("two") - // three - DBCassandraConsistencyLevelThree = DBCassandraConsistencyLevelKey.String("three") - // local_one - DBCassandraConsistencyLevelLocalOne = DBCassandraConsistencyLevelKey.String("local_one") - // any - DBCassandraConsistencyLevelAny = DBCassandraConsistencyLevelKey.String("any") - // serial - DBCassandraConsistencyLevelSerial = DBCassandraConsistencyLevelKey.String("serial") - // local_serial - DBCassandraConsistencyLevelLocalSerial = DBCassandraConsistencyLevelKey.String("local_serial") -) - -var ( - // Gateway (HTTP) connections mode - DBCosmosDBConnectionModeGateway = DBCosmosDBConnectionModeKey.String("gateway") - // Direct connection - DBCosmosDBConnectionModeDirect = DBCosmosDBConnectionModeKey.String("direct") -) - -var ( - // invalid - DBCosmosDBOperationTypeInvalid = DBCosmosDBOperationTypeKey.String("Invalid") - // create - DBCosmosDBOperationTypeCreate = DBCosmosDBOperationTypeKey.String("Create") - // patch - DBCosmosDBOperationTypePatch = DBCosmosDBOperationTypeKey.String("Patch") - // read - DBCosmosDBOperationTypeRead = DBCosmosDBOperationTypeKey.String("Read") - // read_feed - DBCosmosDBOperationTypeReadFeed = DBCosmosDBOperationTypeKey.String("ReadFeed") - // delete - DBCosmosDBOperationTypeDelete = DBCosmosDBOperationTypeKey.String("Delete") - // replace - DBCosmosDBOperationTypeReplace = DBCosmosDBOperationTypeKey.String("Replace") - // execute - DBCosmosDBOperationTypeExecute = DBCosmosDBOperationTypeKey.String("Execute") - // query - DBCosmosDBOperationTypeQuery = DBCosmosDBOperationTypeKey.String("Query") - // head - DBCosmosDBOperationTypeHead = DBCosmosDBOperationTypeKey.String("Head") - // head_feed - DBCosmosDBOperationTypeHeadFeed = DBCosmosDBOperationTypeKey.String("HeadFeed") - // upsert - DBCosmosDBOperationTypeUpsert = DBCosmosDBOperationTypeKey.String("Upsert") - // batch - DBCosmosDBOperationTypeBatch = DBCosmosDBOperationTypeKey.String("Batch") - // query_plan - DBCosmosDBOperationTypeQueryPlan = DBCosmosDBOperationTypeKey.String("QueryPlan") - // execute_javascript - DBCosmosDBOperationTypeExecuteJavascript = DBCosmosDBOperationTypeKey.String("ExecuteJavaScript") -) - -var ( - // Some other SQL database. Fallback only. See notes - DBSystemOtherSQL = DBSystemKey.String("other_sql") - // Microsoft SQL Server - DBSystemMSSQL = DBSystemKey.String("mssql") - // Microsoft SQL Server Compact - DBSystemMssqlcompact = DBSystemKey.String("mssqlcompact") - // MySQL - DBSystemMySQL = DBSystemKey.String("mysql") - // Oracle Database - DBSystemOracle = DBSystemKey.String("oracle") - // IBM DB2 - DBSystemDB2 = DBSystemKey.String("db2") - // PostgreSQL - DBSystemPostgreSQL = DBSystemKey.String("postgresql") - // Amazon Redshift - DBSystemRedshift = DBSystemKey.String("redshift") - // Apache Hive - DBSystemHive = DBSystemKey.String("hive") - // Cloudscape - DBSystemCloudscape = DBSystemKey.String("cloudscape") - // HyperSQL DataBase - DBSystemHSQLDB = DBSystemKey.String("hsqldb") - // Progress Database - DBSystemProgress = DBSystemKey.String("progress") - // SAP MaxDB - DBSystemMaxDB = DBSystemKey.String("maxdb") - // SAP HANA - DBSystemHanaDB = DBSystemKey.String("hanadb") - // Ingres - DBSystemIngres = DBSystemKey.String("ingres") - // FirstSQL - DBSystemFirstSQL = DBSystemKey.String("firstsql") - // EnterpriseDB - DBSystemEDB = DBSystemKey.String("edb") - // InterSystems Caché - DBSystemCache = DBSystemKey.String("cache") - // Adabas (Adaptable Database System) - DBSystemAdabas = DBSystemKey.String("adabas") - // Firebird - DBSystemFirebird = DBSystemKey.String("firebird") - // Apache Derby - DBSystemDerby = DBSystemKey.String("derby") - // FileMaker - DBSystemFilemaker = DBSystemKey.String("filemaker") - // Informix - DBSystemInformix = DBSystemKey.String("informix") - // InstantDB - DBSystemInstantDB = DBSystemKey.String("instantdb") - // InterBase - DBSystemInterbase = DBSystemKey.String("interbase") - // MariaDB - DBSystemMariaDB = DBSystemKey.String("mariadb") - // Netezza - DBSystemNetezza = DBSystemKey.String("netezza") - // Pervasive PSQL - DBSystemPervasive = DBSystemKey.String("pervasive") - // PointBase - DBSystemPointbase = DBSystemKey.String("pointbase") - // SQLite - DBSystemSqlite = DBSystemKey.String("sqlite") - // Sybase - DBSystemSybase = DBSystemKey.String("sybase") - // Teradata - DBSystemTeradata = DBSystemKey.String("teradata") - // Vertica - DBSystemVertica = DBSystemKey.String("vertica") - // H2 - DBSystemH2 = DBSystemKey.String("h2") - // ColdFusion IMQ - DBSystemColdfusion = DBSystemKey.String("coldfusion") - // Apache Cassandra - DBSystemCassandra = DBSystemKey.String("cassandra") - // Apache HBase - DBSystemHBase = DBSystemKey.String("hbase") - // MongoDB - DBSystemMongoDB = DBSystemKey.String("mongodb") - // Redis - DBSystemRedis = DBSystemKey.String("redis") - // Couchbase - DBSystemCouchbase = DBSystemKey.String("couchbase") - // CouchDB - DBSystemCouchDB = DBSystemKey.String("couchdb") - // Microsoft Azure Cosmos DB - DBSystemCosmosDB = DBSystemKey.String("cosmosdb") - // Amazon DynamoDB - DBSystemDynamoDB = DBSystemKey.String("dynamodb") - // Neo4j - DBSystemNeo4j = DBSystemKey.String("neo4j") - // Apache Geode - DBSystemGeode = DBSystemKey.String("geode") - // Elasticsearch - DBSystemElasticsearch = DBSystemKey.String("elasticsearch") - // Memcached - DBSystemMemcached = DBSystemKey.String("memcached") - // CockroachDB - DBSystemCockroachdb = DBSystemKey.String("cockroachdb") - // OpenSearch - DBSystemOpensearch = DBSystemKey.String("opensearch") - // ClickHouse - DBSystemClickhouse = DBSystemKey.String("clickhouse") - // Cloud Spanner - DBSystemSpanner = DBSystemKey.String("spanner") - // Trino - DBSystemTrino = DBSystemKey.String("trino") -) - -// DBCassandraCoordinatorDC returns an attribute KeyValue conforming to the -// "db.cassandra.coordinator.dc" semantic conventions. It represents the data -// center of the coordinating node for a query. -func DBCassandraCoordinatorDC(val string) attribute.KeyValue { - return DBCassandraCoordinatorDCKey.String(val) -} - -// DBCassandraCoordinatorID returns an attribute KeyValue conforming to the -// "db.cassandra.coordinator.id" semantic conventions. It represents the ID of -// the coordinating node for a query. -func DBCassandraCoordinatorID(val string) attribute.KeyValue { - return DBCassandraCoordinatorIDKey.String(val) -} - -// DBCassandraIdempotence returns an attribute KeyValue conforming to the -// "db.cassandra.idempotence" semantic conventions. It represents the whether -// or not the query is idempotent. -func DBCassandraIdempotence(val bool) attribute.KeyValue { - return DBCassandraIdempotenceKey.Bool(val) -} - -// DBCassandraPageSize returns an attribute KeyValue conforming to the -// "db.cassandra.page_size" semantic conventions. It represents the fetch size -// used for paging, i.e. how many rows will be returned at once. -func DBCassandraPageSize(val int) attribute.KeyValue { - return DBCassandraPageSizeKey.Int(val) -} - -// DBCassandraSpeculativeExecutionCount returns an attribute KeyValue -// conforming to the "db.cassandra.speculative_execution_count" semantic -// conventions. It represents the number of times a query was speculatively -// executed. Not set or `0` if the query was not executed speculatively. -func DBCassandraSpeculativeExecutionCount(val int) attribute.KeyValue { - return DBCassandraSpeculativeExecutionCountKey.Int(val) -} - -// DBCassandraTable returns an attribute KeyValue conforming to the -// "db.cassandra.table" semantic conventions. It represents the name of the -// primary Cassandra table that the operation is acting upon, including the -// keyspace name (if applicable). -func DBCassandraTable(val string) attribute.KeyValue { - return DBCassandraTableKey.String(val) -} - -// DBConnectionString returns an attribute KeyValue conforming to the -// "db.connection_string" semantic conventions. It represents the connection -// string used to connect to the database. It is recommended to remove embedded -// credentials. -func DBConnectionString(val string) attribute.KeyValue { - return DBConnectionStringKey.String(val) -} - -// DBCosmosDBClientID returns an attribute KeyValue conforming to the -// "db.cosmosdb.client_id" semantic conventions. It represents the unique -// Cosmos client instance id. -func DBCosmosDBClientID(val string) attribute.KeyValue { - return DBCosmosDBClientIDKey.String(val) -} - -// DBCosmosDBContainer returns an attribute KeyValue conforming to the -// "db.cosmosdb.container" semantic conventions. It represents the cosmos DB -// container name. -func DBCosmosDBContainer(val string) attribute.KeyValue { - return DBCosmosDBContainerKey.String(val) -} - -// DBCosmosDBRequestCharge returns an attribute KeyValue conforming to the -// "db.cosmosdb.request_charge" semantic conventions. It represents the rU -// consumed for that operation -func DBCosmosDBRequestCharge(val float64) attribute.KeyValue { - return DBCosmosDBRequestChargeKey.Float64(val) -} - -// DBCosmosDBRequestContentLength returns an attribute KeyValue conforming -// to the "db.cosmosdb.request_content_length" semantic conventions. It -// represents the request payload size in bytes -func DBCosmosDBRequestContentLength(val int) attribute.KeyValue { - return DBCosmosDBRequestContentLengthKey.Int(val) -} - -// DBCosmosDBStatusCode returns an attribute KeyValue conforming to the -// "db.cosmosdb.status_code" semantic conventions. It represents the cosmos DB -// status code. -func DBCosmosDBStatusCode(val int) attribute.KeyValue { - return DBCosmosDBStatusCodeKey.Int(val) -} - -// DBCosmosDBSubStatusCode returns an attribute KeyValue conforming to the -// "db.cosmosdb.sub_status_code" semantic conventions. It represents the cosmos -// DB sub status code. -func DBCosmosDBSubStatusCode(val int) attribute.KeyValue { - return DBCosmosDBSubStatusCodeKey.Int(val) -} - -// DBElasticsearchClusterName returns an attribute KeyValue conforming to -// the "db.elasticsearch.cluster.name" semantic conventions. It represents the -// represents the identifier of an Elasticsearch cluster. -func DBElasticsearchClusterName(val string) attribute.KeyValue { - return DBElasticsearchClusterNameKey.String(val) -} - -// DBElasticsearchNodeName returns an attribute KeyValue conforming to the -// "db.elasticsearch.node.name" semantic conventions. It represents the -// represents the human-readable identifier of the node/instance to which a -// request was routed. -func DBElasticsearchNodeName(val string) attribute.KeyValue { - return DBElasticsearchNodeNameKey.String(val) -} - -// DBInstanceID returns an attribute KeyValue conforming to the -// "db.instance.id" semantic conventions. It represents an identifier (address, -// unique name, or any other identifier) of the database instance that is -// executing queries or mutations on the current connection. This is useful in -// cases where the database is running in a clustered environment and the -// instrumentation is able to record the node executing the query. The client -// may obtain this value in databases like MySQL using queries like `select -// @@hostname`. -func DBInstanceID(val string) attribute.KeyValue { - return DBInstanceIDKey.String(val) -} - -// DBJDBCDriverClassname returns an attribute KeyValue conforming to the -// "db.jdbc.driver_classname" semantic conventions. It represents the -// fully-qualified class name of the [Java Database Connectivity -// (JDBC)](https://docs.oracle.com/javase/8/docs/technotes/guides/jdbc/) driver -// used to connect. -func DBJDBCDriverClassname(val string) attribute.KeyValue { - return DBJDBCDriverClassnameKey.String(val) -} - -// DBMongoDBCollection returns an attribute KeyValue conforming to the -// "db.mongodb.collection" semantic conventions. It represents the MongoDB -// collection being accessed within the database stated in `db.name`. -func DBMongoDBCollection(val string) attribute.KeyValue { - return DBMongoDBCollectionKey.String(val) -} - -// DBMSSQLInstanceName returns an attribute KeyValue conforming to the -// "db.mssql.instance_name" semantic conventions. It represents the Microsoft -// SQL Server [instance -// name](https://docs.microsoft.com/sql/connect/jdbc/building-the-connection-url?view=sql-server-ver15) -// connecting to. This name is used to determine the port of a named instance. -func DBMSSQLInstanceName(val string) attribute.KeyValue { - return DBMSSQLInstanceNameKey.String(val) -} - -// DBName returns an attribute KeyValue conforming to the "db.name" semantic -// conventions. It represents the this attribute is used to report the name of -// the database being accessed. For commands that switch the database, this -// should be set to the target database (even if the command fails). -func DBName(val string) attribute.KeyValue { - return DBNameKey.String(val) -} - -// DBOperation returns an attribute KeyValue conforming to the -// "db.operation" semantic conventions. It represents the name of the operation -// being executed, e.g. the [MongoDB command -// name](https://docs.mongodb.com/manual/reference/command/#database-operations) -// such as `findAndModify`, or the SQL keyword. -func DBOperation(val string) attribute.KeyValue { - return DBOperationKey.String(val) -} - -// DBRedisDBIndex returns an attribute KeyValue conforming to the -// "db.redis.database_index" semantic conventions. It represents the index of -// the database being accessed as used in the [`SELECT` -// command](https://redis.io/commands/select), provided as an integer. To be -// used instead of the generic `db.name` attribute. -func DBRedisDBIndex(val int) attribute.KeyValue { - return DBRedisDBIndexKey.Int(val) -} - -// DBSQLTable returns an attribute KeyValue conforming to the "db.sql.table" -// semantic conventions. It represents the name of the primary table that the -// operation is acting upon, including the database name (if applicable). -func DBSQLTable(val string) attribute.KeyValue { - return DBSQLTableKey.String(val) -} - -// DBStatement returns an attribute KeyValue conforming to the -// "db.statement" semantic conventions. It represents the database statement -// being executed. -func DBStatement(val string) attribute.KeyValue { - return DBStatementKey.String(val) -} - -// DBUser returns an attribute KeyValue conforming to the "db.user" semantic -// conventions. It represents the username for accessing the database. -func DBUser(val string) attribute.KeyValue { - return DBUserKey.String(val) -} - -// Describes deprecated HTTP attributes. -const ( - // HTTPFlavorKey is the attribute Key conforming to the "http.flavor" - // semantic conventions. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: deprecated - // Deprecated: use `network.protocol.name` instead. - HTTPFlavorKey = attribute.Key("http.flavor") - - // HTTPMethodKey is the attribute Key conforming to the "http.method" - // semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 'GET', 'POST', 'HEAD' - // Deprecated: use `http.request.method` instead. - HTTPMethodKey = attribute.Key("http.method") - - // HTTPRequestContentLengthKey is the attribute Key conforming to the - // "http.request_content_length" semantic conventions. - // - // Type: int - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 3495 - // Deprecated: use `http.request.header.content-length` instead. - HTTPRequestContentLengthKey = attribute.Key("http.request_content_length") - - // HTTPResponseContentLengthKey is the attribute Key conforming to the - // "http.response_content_length" semantic conventions. - // - // Type: int - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 3495 - // Deprecated: use `http.response.header.content-length` instead. - HTTPResponseContentLengthKey = attribute.Key("http.response_content_length") - - // HTTPSchemeKey is the attribute Key conforming to the "http.scheme" - // semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 'http', 'https' - // Deprecated: use `url.scheme` instead. - HTTPSchemeKey = attribute.Key("http.scheme") - - // HTTPStatusCodeKey is the attribute Key conforming to the - // "http.status_code" semantic conventions. - // - // Type: int - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 200 - // Deprecated: use `http.response.status_code` instead. - HTTPStatusCodeKey = attribute.Key("http.status_code") - - // HTTPTargetKey is the attribute Key conforming to the "http.target" - // semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: '/search?q=OpenTelemetry#SemConv' - // Deprecated: use `url.path` and `url.query` instead. - HTTPTargetKey = attribute.Key("http.target") - - // HTTPURLKey is the attribute Key conforming to the "http.url" semantic - // conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 'https://www.foo.bar/search?q=OpenTelemetry#SemConv' - // Deprecated: use `url.full` instead. - HTTPURLKey = attribute.Key("http.url") - - // HTTPUserAgentKey is the attribute Key conforming to the - // "http.user_agent" semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 'CERN-LineMode/2.15 libwww/2.17b3', 'Mozilla/5.0 (iPhone; CPU - // iPhone OS 14_7_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) - // Version/14.1.2 Mobile/15E148 Safari/604.1' - // Deprecated: use `user_agent.original` instead. - HTTPUserAgentKey = attribute.Key("http.user_agent") -) - -var ( - // HTTP/1.0 - // - // Deprecated: use `network.protocol.name` instead. - HTTPFlavorHTTP10 = HTTPFlavorKey.String("1.0") - // HTTP/1.1 - // - // Deprecated: use `network.protocol.name` instead. - HTTPFlavorHTTP11 = HTTPFlavorKey.String("1.1") - // HTTP/2 - // - // Deprecated: use `network.protocol.name` instead. - HTTPFlavorHTTP20 = HTTPFlavorKey.String("2.0") - // HTTP/3 - // - // Deprecated: use `network.protocol.name` instead. - HTTPFlavorHTTP30 = HTTPFlavorKey.String("3.0") - // SPDY protocol - // - // Deprecated: use `network.protocol.name` instead. - HTTPFlavorSPDY = HTTPFlavorKey.String("SPDY") - // QUIC protocol - // - // Deprecated: use `network.protocol.name` instead. - HTTPFlavorQUIC = HTTPFlavorKey.String("QUIC") -) - -// HTTPMethod returns an attribute KeyValue conforming to the "http.method" -// semantic conventions. -// -// Deprecated: use `http.request.method` instead. -func HTTPMethod(val string) attribute.KeyValue { - return HTTPMethodKey.String(val) -} - -// HTTPRequestContentLength returns an attribute KeyValue conforming to the -// "http.request_content_length" semantic conventions. -// -// Deprecated: use `http.request.header.content-length` instead. -func HTTPRequestContentLength(val int) attribute.KeyValue { - return HTTPRequestContentLengthKey.Int(val) -} - -// HTTPResponseContentLength returns an attribute KeyValue conforming to the -// "http.response_content_length" semantic conventions. -// -// Deprecated: use `http.response.header.content-length` instead. -func HTTPResponseContentLength(val int) attribute.KeyValue { - return HTTPResponseContentLengthKey.Int(val) -} - -// HTTPScheme returns an attribute KeyValue conforming to the "http.scheme" -// semantic conventions. -// -// Deprecated: use `url.scheme` instead. -func HTTPScheme(val string) attribute.KeyValue { - return HTTPSchemeKey.String(val) -} - -// HTTPStatusCode returns an attribute KeyValue conforming to the -// "http.status_code" semantic conventions. -// -// Deprecated: use `http.response.status_code` instead. -func HTTPStatusCode(val int) attribute.KeyValue { - return HTTPStatusCodeKey.Int(val) -} - -// HTTPTarget returns an attribute KeyValue conforming to the "http.target" -// semantic conventions. -// -// Deprecated: use `url.path` and `url.query` instead. -func HTTPTarget(val string) attribute.KeyValue { - return HTTPTargetKey.String(val) -} - -// HTTPURL returns an attribute KeyValue conforming to the "http.url" -// semantic conventions. -// -// Deprecated: use `url.full` instead. -func HTTPURL(val string) attribute.KeyValue { - return HTTPURLKey.String(val) -} - -// HTTPUserAgent returns an attribute KeyValue conforming to the -// "http.user_agent" semantic conventions. -// -// Deprecated: use `user_agent.original` instead. -func HTTPUserAgent(val string) attribute.KeyValue { - return HTTPUserAgentKey.String(val) -} - -// These attributes may be used for any network related operation. -const ( - // NetHostNameKey is the attribute Key conforming to the "net.host.name" - // semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 'example.com' - // Deprecated: use `server.address`. - NetHostNameKey = attribute.Key("net.host.name") - - // NetHostPortKey is the attribute Key conforming to the "net.host.port" - // semantic conventions. - // - // Type: int - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 8080 - // Deprecated: use `server.port`. - NetHostPortKey = attribute.Key("net.host.port") - - // NetPeerNameKey is the attribute Key conforming to the "net.peer.name" - // semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 'example.com' - // Deprecated: use `server.address` on client spans and `client.address` on - // server spans. - NetPeerNameKey = attribute.Key("net.peer.name") - - // NetPeerPortKey is the attribute Key conforming to the "net.peer.port" - // semantic conventions. - // - // Type: int - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 8080 - // Deprecated: use `server.port` on client spans and `client.port` on - // server spans. - NetPeerPortKey = attribute.Key("net.peer.port") - - // NetProtocolNameKey is the attribute Key conforming to the - // "net.protocol.name" semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 'amqp', 'http', 'mqtt' - // Deprecated: use `network.protocol.name`. - NetProtocolNameKey = attribute.Key("net.protocol.name") - - // NetProtocolVersionKey is the attribute Key conforming to the - // "net.protocol.version" semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: '3.1.1' - // Deprecated: use `network.protocol.version`. - NetProtocolVersionKey = attribute.Key("net.protocol.version") - - // NetSockFamilyKey is the attribute Key conforming to the - // "net.sock.family" semantic conventions. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: deprecated - // Deprecated: use `network.transport` and `network.type`. - NetSockFamilyKey = attribute.Key("net.sock.family") - - // NetSockHostAddrKey is the attribute Key conforming to the - // "net.sock.host.addr" semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: '/var/my.sock' - // Deprecated: use `network.local.address`. - NetSockHostAddrKey = attribute.Key("net.sock.host.addr") - - // NetSockHostPortKey is the attribute Key conforming to the - // "net.sock.host.port" semantic conventions. - // - // Type: int - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 8080 - // Deprecated: use `network.local.port`. - NetSockHostPortKey = attribute.Key("net.sock.host.port") - - // NetSockPeerAddrKey is the attribute Key conforming to the - // "net.sock.peer.addr" semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: '192.168.0.1' - // Deprecated: use `network.peer.address`. - NetSockPeerAddrKey = attribute.Key("net.sock.peer.addr") - - // NetSockPeerNameKey is the attribute Key conforming to the - // "net.sock.peer.name" semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: '/var/my.sock' - // Deprecated: no replacement at this time. - NetSockPeerNameKey = attribute.Key("net.sock.peer.name") - - // NetSockPeerPortKey is the attribute Key conforming to the - // "net.sock.peer.port" semantic conventions. - // - // Type: int - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 65531 - // Deprecated: use `network.peer.port`. - NetSockPeerPortKey = attribute.Key("net.sock.peer.port") - - // NetTransportKey is the attribute Key conforming to the "net.transport" - // semantic conventions. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: deprecated - // Deprecated: use `network.transport`. - NetTransportKey = attribute.Key("net.transport") -) - -var ( - // IPv4 address - // - // Deprecated: use `network.transport` and `network.type`. - NetSockFamilyInet = NetSockFamilyKey.String("inet") - // IPv6 address - // - // Deprecated: use `network.transport` and `network.type`. - NetSockFamilyInet6 = NetSockFamilyKey.String("inet6") - // Unix domain socket path - // - // Deprecated: use `network.transport` and `network.type`. - NetSockFamilyUnix = NetSockFamilyKey.String("unix") -) - -var ( - // ip_tcp - // - // Deprecated: use `network.transport`. - NetTransportTCP = NetTransportKey.String("ip_tcp") - // ip_udp - // - // Deprecated: use `network.transport`. - NetTransportUDP = NetTransportKey.String("ip_udp") - // Named or anonymous pipe - // - // Deprecated: use `network.transport`. - NetTransportPipe = NetTransportKey.String("pipe") - // In-process communication - // - // Deprecated: use `network.transport`. - NetTransportInProc = NetTransportKey.String("inproc") - // Something else (non IP-based) - // - // Deprecated: use `network.transport`. - NetTransportOther = NetTransportKey.String("other") -) - -// NetHostName returns an attribute KeyValue conforming to the -// "net.host.name" semantic conventions. -// -// Deprecated: use `server.address`. -func NetHostName(val string) attribute.KeyValue { - return NetHostNameKey.String(val) -} - -// NetHostPort returns an attribute KeyValue conforming to the -// "net.host.port" semantic conventions. -// -// Deprecated: use `server.port`. -func NetHostPort(val int) attribute.KeyValue { - return NetHostPortKey.Int(val) -} - -// NetPeerName returns an attribute KeyValue conforming to the -// "net.peer.name" semantic conventions. -// -// Deprecated: use `server.address` on client spans and `client.address` on -// server spans. -func NetPeerName(val string) attribute.KeyValue { - return NetPeerNameKey.String(val) -} - -// NetPeerPort returns an attribute KeyValue conforming to the -// "net.peer.port" semantic conventions. -// -// Deprecated: use `server.port` on client spans and `client.port` on server -// spans. -func NetPeerPort(val int) attribute.KeyValue { - return NetPeerPortKey.Int(val) -} - -// NetProtocolName returns an attribute KeyValue conforming to the -// "net.protocol.name" semantic conventions. -// -// Deprecated: use `network.protocol.name`. -func NetProtocolName(val string) attribute.KeyValue { - return NetProtocolNameKey.String(val) -} - -// NetProtocolVersion returns an attribute KeyValue conforming to the -// "net.protocol.version" semantic conventions. -// -// Deprecated: use `network.protocol.version`. -func NetProtocolVersion(val string) attribute.KeyValue { - return NetProtocolVersionKey.String(val) -} - -// NetSockHostAddr returns an attribute KeyValue conforming to the -// "net.sock.host.addr" semantic conventions. -// -// Deprecated: use `network.local.address`. -func NetSockHostAddr(val string) attribute.KeyValue { - return NetSockHostAddrKey.String(val) -} - -// NetSockHostPort returns an attribute KeyValue conforming to the -// "net.sock.host.port" semantic conventions. -// -// Deprecated: use `network.local.port`. -func NetSockHostPort(val int) attribute.KeyValue { - return NetSockHostPortKey.Int(val) -} - -// NetSockPeerAddr returns an attribute KeyValue conforming to the -// "net.sock.peer.addr" semantic conventions. -// -// Deprecated: use `network.peer.address`. -func NetSockPeerAddr(val string) attribute.KeyValue { - return NetSockPeerAddrKey.String(val) -} - -// NetSockPeerName returns an attribute KeyValue conforming to the -// "net.sock.peer.name" semantic conventions. -// -// Deprecated: no replacement at this time. -func NetSockPeerName(val string) attribute.KeyValue { - return NetSockPeerNameKey.String(val) -} - -// NetSockPeerPort returns an attribute KeyValue conforming to the -// "net.sock.peer.port" semantic conventions. -// -// Deprecated: use `network.peer.port`. -func NetSockPeerPort(val int) attribute.KeyValue { - return NetSockPeerPortKey.Int(val) -} - -// These attributes may be used to describe the receiver of a network -// exchange/packet. These should be used when there is no client/server -// relationship between the two sides, or when that relationship is unknown. -// This covers low-level network interactions (e.g. packet tracing) where you -// don't know if there was a connection or which side initiated it. This also -// covers unidirectional UDP flows and peer-to-peer communication where the -// "user-facing" surface of the protocol / API doesn't expose a clear notion of -// client and server. -const ( - // DestinationAddressKey is the attribute Key conforming to the - // "destination.address" semantic conventions. It represents the - // destination address - domain name if available without reverse DNS - // lookup; otherwise, IP address or Unix domain socket name. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'destination.example.com', '10.1.2.80', '/tmp/my.sock' - // Note: When observed from the source side, and when communicating through - // an intermediary, `destination.address` SHOULD represent the destination - // address behind any intermediaries, for example proxies, if it's - // available. - DestinationAddressKey = attribute.Key("destination.address") - - // DestinationPortKey is the attribute Key conforming to the - // "destination.port" semantic conventions. It represents the destination - // port number - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 3389, 2888 - DestinationPortKey = attribute.Key("destination.port") -) - -// DestinationAddress returns an attribute KeyValue conforming to the -// "destination.address" semantic conventions. It represents the destination -// address - domain name if available without reverse DNS lookup; otherwise, IP -// address or Unix domain socket name. -func DestinationAddress(val string) attribute.KeyValue { - return DestinationAddressKey.String(val) -} - -// DestinationPort returns an attribute KeyValue conforming to the -// "destination.port" semantic conventions. It represents the destination port -// number -func DestinationPort(val int) attribute.KeyValue { - return DestinationPortKey.Int(val) -} - -// These attributes may be used for any disk related operation. -const ( - // DiskIoDirectionKey is the attribute Key conforming to the - // "disk.io.direction" semantic conventions. It represents the disk IO - // operation direction. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'read' - DiskIoDirectionKey = attribute.Key("disk.io.direction") -) - -var ( - // read - DiskIoDirectionRead = DiskIoDirectionKey.String("read") - // write - DiskIoDirectionWrite = DiskIoDirectionKey.String("write") -) - -// The shared attributes used to report an error. -const ( - // ErrorTypeKey is the attribute Key conforming to the "error.type" - // semantic conventions. It represents the describes a class of error the - // operation ended with. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: stable - // Examples: 'timeout', 'java.net.UnknownHostException', - // 'server_certificate_invalid', '500' - // Note: The `error.type` SHOULD be predictable and SHOULD have low - // cardinality. - // Instrumentations SHOULD document the list of errors they report. - // - // The cardinality of `error.type` within one instrumentation library - // SHOULD be low. - // Telemetry consumers that aggregate data from multiple instrumentation - // libraries and applications - // should be prepared for `error.type` to have high cardinality at query - // time when no - // additional filters are applied. - // - // If the operation has completed successfully, instrumentations SHOULD NOT - // set `error.type`. - // - // If a specific domain defines its own set of error identifiers (such as - // HTTP or gRPC status codes), - // it's RECOMMENDED to: - // - // * Use a domain-specific attribute - // * Set `error.type` to capture all errors, regardless of whether they are - // defined within the domain-specific set or not. - ErrorTypeKey = attribute.Key("error.type") -) - -var ( - // A fallback error value to be used when the instrumentation doesn't define a custom value - ErrorTypeOther = ErrorTypeKey.String("_OTHER") -) - -// The shared attributes used to report a single exception associated with a -// span or log. -const ( - // ExceptionEscapedKey is the attribute Key conforming to the - // "exception.escaped" semantic conventions. It represents the sHOULD be - // set to true if the exception event is recorded at a point where it is - // known that the exception is escaping the scope of the span. - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - // Note: An exception is considered to have escaped (or left) the scope of - // a span, - // if that span is ended while the exception is still logically "in - // flight". - // This may be actually "in flight" in some languages (e.g. if the - // exception - // is passed to a Context manager's `__exit__` method in Python) but will - // usually be caught at the point of recording the exception in most - // languages. - // - // It is usually not possible to determine at the point where an exception - // is thrown - // whether it will escape the scope of a span. - // However, it is trivial to know that an exception - // will escape, if one checks for an active exception just before ending - // the span, - // as done in the [example for recording span - // exceptions](#recording-an-exception). - // - // It follows that an exception may still escape the scope of the span - // even if the `exception.escaped` attribute was not set or set to false, - // since the event might have been recorded at a time where it was not - // clear whether the exception will escape. - ExceptionEscapedKey = attribute.Key("exception.escaped") - - // ExceptionMessageKey is the attribute Key conforming to the - // "exception.message" semantic conventions. It represents the exception - // message. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Division by zero', "Can't convert 'int' object to str - // implicitly" - ExceptionMessageKey = attribute.Key("exception.message") - - // ExceptionStacktraceKey is the attribute Key conforming to the - // "exception.stacktrace" semantic conventions. It represents a stacktrace - // as a string in the natural representation for the language runtime. The - // representation is to be determined and documented by each language SIG. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Exception in thread "main" java.lang.RuntimeException: Test - // exception\\n at ' - // 'com.example.GenerateTrace.methodB(GenerateTrace.java:13)\\n at ' - // 'com.example.GenerateTrace.methodA(GenerateTrace.java:9)\\n at ' - // 'com.example.GenerateTrace.main(GenerateTrace.java:5)' - ExceptionStacktraceKey = attribute.Key("exception.stacktrace") - - // ExceptionTypeKey is the attribute Key conforming to the "exception.type" - // semantic conventions. It represents the type of the exception (its - // fully-qualified class name, if applicable). The dynamic type of the - // exception should be preferred over the static type in languages that - // support it. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'java.net.ConnectException', 'OSError' - ExceptionTypeKey = attribute.Key("exception.type") -) - -// ExceptionEscaped returns an attribute KeyValue conforming to the -// "exception.escaped" semantic conventions. It represents the sHOULD be set to -// true if the exception event is recorded at a point where it is known that -// the exception is escaping the scope of the span. -func ExceptionEscaped(val bool) attribute.KeyValue { - return ExceptionEscapedKey.Bool(val) -} - -// ExceptionMessage returns an attribute KeyValue conforming to the -// "exception.message" semantic conventions. It represents the exception -// message. -func ExceptionMessage(val string) attribute.KeyValue { - return ExceptionMessageKey.String(val) -} - -// ExceptionStacktrace returns an attribute KeyValue conforming to the -// "exception.stacktrace" semantic conventions. It represents a stacktrace as a -// string in the natural representation for the language runtime. The -// representation is to be determined and documented by each language SIG. -func ExceptionStacktrace(val string) attribute.KeyValue { - return ExceptionStacktraceKey.String(val) -} - -// ExceptionType returns an attribute KeyValue conforming to the -// "exception.type" semantic conventions. It represents the type of the -// exception (its fully-qualified class name, if applicable). The dynamic type -// of the exception should be preferred over the static type in languages that -// support it. -func ExceptionType(val string) attribute.KeyValue { - return ExceptionTypeKey.String(val) -} - -// Semantic convention attributes in the HTTP namespace. -const ( - // HTTPRequestBodySizeKey is the attribute Key conforming to the - // "http.request.body.size" semantic conventions. It represents the size of - // the request payload body in bytes. This is the number of bytes - // transferred excluding headers and is often, but not always, present as - // the - // [Content-Length](https://www.rfc-editor.org/rfc/rfc9110.html#field.content-length) - // header. For requests using transport encoding, this should be the - // compressed size. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 3495 - HTTPRequestBodySizeKey = attribute.Key("http.request.body.size") - - // HTTPRequestMethodKey is the attribute Key conforming to the - // "http.request.method" semantic conventions. It represents the hTTP - // request method. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: stable - // Examples: 'GET', 'POST', 'HEAD' - // Note: HTTP request method value SHOULD be "known" to the - // instrumentation. - // By default, this convention defines "known" methods as the ones listed - // in [RFC9110](https://www.rfc-editor.org/rfc/rfc9110.html#name-methods) - // and the PATCH method defined in - // [RFC5789](https://www.rfc-editor.org/rfc/rfc5789.html). - // - // If the HTTP request method is not known to instrumentation, it MUST set - // the `http.request.method` attribute to `_OTHER`. - // - // If the HTTP instrumentation could end up converting valid HTTP request - // methods to `_OTHER`, then it MUST provide a way to override - // the list of known HTTP methods. If this override is done via environment - // variable, then the environment variable MUST be named - // OTEL_INSTRUMENTATION_HTTP_KNOWN_METHODS and support a comma-separated - // list of case-sensitive known HTTP methods - // (this list MUST be a full override of the default known method, it is - // not a list of known methods in addition to the defaults). - // - // HTTP method names are case-sensitive and `http.request.method` attribute - // value MUST match a known HTTP method name exactly. - // Instrumentations for specific web frameworks that consider HTTP methods - // to be case insensitive, SHOULD populate a canonical equivalent. - // Tracing instrumentations that do so, MUST also set - // `http.request.method_original` to the original value. - HTTPRequestMethodKey = attribute.Key("http.request.method") - - // HTTPRequestMethodOriginalKey is the attribute Key conforming to the - // "http.request.method_original" semantic conventions. It represents the - // original HTTP method sent by the client in the request line. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: 'GeT', 'ACL', 'foo' - HTTPRequestMethodOriginalKey = attribute.Key("http.request.method_original") - - // HTTPRequestResendCountKey is the attribute Key conforming to the - // "http.request.resend_count" semantic conventions. It represents the - // ordinal number of request resending attempt (for any reason, including - // redirects). - // - // Type: int - // RequirementLevel: Optional - // Stability: stable - // Examples: 3 - // Note: The resend count SHOULD be updated each time an HTTP request gets - // resent by the client, regardless of what was the cause of the resending - // (e.g. redirection, authorization failure, 503 Server Unavailable, - // network issues, or any other). - HTTPRequestResendCountKey = attribute.Key("http.request.resend_count") - - // HTTPResponseBodySizeKey is the attribute Key conforming to the - // "http.response.body.size" semantic conventions. It represents the size - // of the response payload body in bytes. This is the number of bytes - // transferred excluding headers and is often, but not always, present as - // the - // [Content-Length](https://www.rfc-editor.org/rfc/rfc9110.html#field.content-length) - // header. For requests using transport encoding, this should be the - // compressed size. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 3495 - HTTPResponseBodySizeKey = attribute.Key("http.response.body.size") - - // HTTPResponseStatusCodeKey is the attribute Key conforming to the - // "http.response.status_code" semantic conventions. It represents the - // [HTTP response status - // code](https://tools.ietf.org/html/rfc7231#section-6). - // - // Type: int - // RequirementLevel: Optional - // Stability: stable - // Examples: 200 - HTTPResponseStatusCodeKey = attribute.Key("http.response.status_code") - - // HTTPRouteKey is the attribute Key conforming to the "http.route" - // semantic conventions. It represents the matched route, that is, the path - // template in the format used by the respective server framework. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: '/users/:userID?', '{controller}/{action}/{id?}' - // Note: MUST NOT be populated when this is not supported by the HTTP - // server framework as the route attribute should have low-cardinality and - // the URI path can NOT substitute it. - // SHOULD include the [application - // root](/docs/http/http-spans.md#http-server-definitions) if there is one. - HTTPRouteKey = attribute.Key("http.route") -) - -var ( - // CONNECT method - HTTPRequestMethodConnect = HTTPRequestMethodKey.String("CONNECT") - // DELETE method - HTTPRequestMethodDelete = HTTPRequestMethodKey.String("DELETE") - // GET method - HTTPRequestMethodGet = HTTPRequestMethodKey.String("GET") - // HEAD method - HTTPRequestMethodHead = HTTPRequestMethodKey.String("HEAD") - // OPTIONS method - HTTPRequestMethodOptions = HTTPRequestMethodKey.String("OPTIONS") - // PATCH method - HTTPRequestMethodPatch = HTTPRequestMethodKey.String("PATCH") - // POST method - HTTPRequestMethodPost = HTTPRequestMethodKey.String("POST") - // PUT method - HTTPRequestMethodPut = HTTPRequestMethodKey.String("PUT") - // TRACE method - HTTPRequestMethodTrace = HTTPRequestMethodKey.String("TRACE") - // Any HTTP method that the instrumentation has no prior knowledge of - HTTPRequestMethodOther = HTTPRequestMethodKey.String("_OTHER") -) - -// HTTPRequestBodySize returns an attribute KeyValue conforming to the -// "http.request.body.size" semantic conventions. It represents the size of the -// request payload body in bytes. This is the number of bytes transferred -// excluding headers and is often, but not always, present as the -// [Content-Length](https://www.rfc-editor.org/rfc/rfc9110.html#field.content-length) -// header. For requests using transport encoding, this should be the compressed -// size. -func HTTPRequestBodySize(val int) attribute.KeyValue { - return HTTPRequestBodySizeKey.Int(val) -} - -// HTTPRequestMethodOriginal returns an attribute KeyValue conforming to the -// "http.request.method_original" semantic conventions. It represents the -// original HTTP method sent by the client in the request line. -func HTTPRequestMethodOriginal(val string) attribute.KeyValue { - return HTTPRequestMethodOriginalKey.String(val) -} - -// HTTPRequestResendCount returns an attribute KeyValue conforming to the -// "http.request.resend_count" semantic conventions. It represents the ordinal -// number of request resending attempt (for any reason, including redirects). -func HTTPRequestResendCount(val int) attribute.KeyValue { - return HTTPRequestResendCountKey.Int(val) -} - -// HTTPResponseBodySize returns an attribute KeyValue conforming to the -// "http.response.body.size" semantic conventions. It represents the size of -// the response payload body in bytes. This is the number of bytes transferred -// excluding headers and is often, but not always, present as the -// [Content-Length](https://www.rfc-editor.org/rfc/rfc9110.html#field.content-length) -// header. For requests using transport encoding, this should be the compressed -// size. -func HTTPResponseBodySize(val int) attribute.KeyValue { - return HTTPResponseBodySizeKey.Int(val) -} - -// HTTPResponseStatusCode returns an attribute KeyValue conforming to the -// "http.response.status_code" semantic conventions. It represents the [HTTP -// response status code](https://tools.ietf.org/html/rfc7231#section-6). -func HTTPResponseStatusCode(val int) attribute.KeyValue { - return HTTPResponseStatusCodeKey.Int(val) -} - -// HTTPRoute returns an attribute KeyValue conforming to the "http.route" -// semantic conventions. It represents the matched route, that is, the path -// template in the format used by the respective server framework. -func HTTPRoute(val string) attribute.KeyValue { - return HTTPRouteKey.String(val) -} - -// Attributes describing telemetry around messaging systems and messaging -// activities. -const ( - // MessagingBatchMessageCountKey is the attribute Key conforming to the - // "messaging.batch.message_count" semantic conventions. It represents the - // number of messages sent, received, or processed in the scope of the - // batching operation. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 0, 1, 2 - // Note: Instrumentations SHOULD NOT set `messaging.batch.message_count` on - // spans that operate with a single message. When a messaging client - // library supports both batch and single-message API for the same - // operation, instrumentations SHOULD use `messaging.batch.message_count` - // for batching APIs and SHOULD NOT use it for single-message APIs. - MessagingBatchMessageCountKey = attribute.Key("messaging.batch.message_count") - - // MessagingClientIDKey is the attribute Key conforming to the - // "messaging.client_id" semantic conventions. It represents a unique - // identifier for the client that consumes or produces a message. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'client-5', 'myhost@8742@s8083jm' - MessagingClientIDKey = attribute.Key("messaging.client_id") - - // MessagingDestinationAnonymousKey is the attribute Key conforming to the - // "messaging.destination.anonymous" semantic conventions. It represents a - // boolean that is true if the message destination is anonymous (could be - // unnamed or have auto-generated name). - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - MessagingDestinationAnonymousKey = attribute.Key("messaging.destination.anonymous") - - // MessagingDestinationNameKey is the attribute Key conforming to the - // "messaging.destination.name" semantic conventions. It represents the - // message destination name - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'MyQueue', 'MyTopic' - // Note: Destination name SHOULD uniquely identify a specific queue, topic - // or other entity within the broker. If - // the broker doesn't have such notion, the destination name SHOULD - // uniquely identify the broker. - MessagingDestinationNameKey = attribute.Key("messaging.destination.name") - - // MessagingDestinationTemplateKey is the attribute Key conforming to the - // "messaging.destination.template" semantic conventions. It represents the - // low cardinality representation of the messaging destination name - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '/customers/{customerID}' - // Note: Destination names could be constructed from templates. An example - // would be a destination name involving a user name or product id. - // Although the destination name in this case is of high cardinality, the - // underlying template is of low cardinality and can be effectively used - // for grouping and aggregation. - MessagingDestinationTemplateKey = attribute.Key("messaging.destination.template") - - // MessagingDestinationTemporaryKey is the attribute Key conforming to the - // "messaging.destination.temporary" semantic conventions. It represents a - // boolean that is true if the message destination is temporary and might - // not exist anymore after messages are processed. - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - MessagingDestinationTemporaryKey = attribute.Key("messaging.destination.temporary") - - // MessagingDestinationPublishAnonymousKey is the attribute Key conforming - // to the "messaging.destination_publish.anonymous" semantic conventions. - // It represents a boolean that is true if the publish message destination - // is anonymous (could be unnamed or have auto-generated name). - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - MessagingDestinationPublishAnonymousKey = attribute.Key("messaging.destination_publish.anonymous") - - // MessagingDestinationPublishNameKey is the attribute Key conforming to - // the "messaging.destination_publish.name" semantic conventions. It - // represents the name of the original destination the message was - // published to - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'MyQueue', 'MyTopic' - // Note: The name SHOULD uniquely identify a specific queue, topic, or - // other entity within the broker. If - // the broker doesn't have such notion, the original destination name - // SHOULD uniquely identify the broker. - MessagingDestinationPublishNameKey = attribute.Key("messaging.destination_publish.name") - - // MessagingGCPPubsubMessageOrderingKeyKey is the attribute Key conforming - // to the "messaging.gcp_pubsub.message.ordering_key" semantic conventions. - // It represents the ordering key for a given message. If the attribute is - // not present, the message does not have an ordering key. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'ordering_key' - MessagingGCPPubsubMessageOrderingKeyKey = attribute.Key("messaging.gcp_pubsub.message.ordering_key") - - // MessagingKafkaConsumerGroupKey is the attribute Key conforming to the - // "messaging.kafka.consumer.group" semantic conventions. It represents the - // name of the Kafka Consumer Group that is handling the message. Only - // applies to consumers, not producers. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'my-group' - MessagingKafkaConsumerGroupKey = attribute.Key("messaging.kafka.consumer.group") - - // MessagingKafkaDestinationPartitionKey is the attribute Key conforming to - // the "messaging.kafka.destination.partition" semantic conventions. It - // represents the partition the message is sent to. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 2 - MessagingKafkaDestinationPartitionKey = attribute.Key("messaging.kafka.destination.partition") - - // MessagingKafkaMessageKeyKey is the attribute Key conforming to the - // "messaging.kafka.message.key" semantic conventions. It represents the - // message keys in Kafka are used for grouping alike messages to ensure - // they're processed on the same partition. They differ from - // `messaging.message.id` in that they're not unique. If the key is `null`, - // the attribute MUST NOT be set. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'myKey' - // Note: If the key type is not string, it's string representation has to - // be supplied for the attribute. If the key has no unambiguous, canonical - // string form, don't include its value. - MessagingKafkaMessageKeyKey = attribute.Key("messaging.kafka.message.key") - - // MessagingKafkaMessageOffsetKey is the attribute Key conforming to the - // "messaging.kafka.message.offset" semantic conventions. It represents the - // offset of a record in the corresponding Kafka partition. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 42 - MessagingKafkaMessageOffsetKey = attribute.Key("messaging.kafka.message.offset") - - // MessagingKafkaMessageTombstoneKey is the attribute Key conforming to the - // "messaging.kafka.message.tombstone" semantic conventions. It represents - // a boolean that is true if the message is a tombstone. - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - MessagingKafkaMessageTombstoneKey = attribute.Key("messaging.kafka.message.tombstone") - - // MessagingMessageBodySizeKey is the attribute Key conforming to the - // "messaging.message.body.size" semantic conventions. It represents the - // size of the message body in bytes. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 1439 - // Note: This can refer to both the compressed or uncompressed body size. - // If both sizes are known, the uncompressed - // body size should be used. - MessagingMessageBodySizeKey = attribute.Key("messaging.message.body.size") - - // MessagingMessageConversationIDKey is the attribute Key conforming to the - // "messaging.message.conversation_id" semantic conventions. It represents - // the conversation ID identifying the conversation to which the message - // belongs, represented as a string. Sometimes called "Correlation ID". - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'MyConversationID' - MessagingMessageConversationIDKey = attribute.Key("messaging.message.conversation_id") - - // MessagingMessageEnvelopeSizeKey is the attribute Key conforming to the - // "messaging.message.envelope.size" semantic conventions. It represents - // the size of the message body and metadata in bytes. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 2738 - // Note: This can refer to both the compressed or uncompressed size. If - // both sizes are known, the uncompressed - // size should be used. - MessagingMessageEnvelopeSizeKey = attribute.Key("messaging.message.envelope.size") - - // MessagingMessageIDKey is the attribute Key conforming to the - // "messaging.message.id" semantic conventions. It represents a value used - // by the messaging system as an identifier for the message, represented as - // a string. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '452a7c7c7c7048c2f887f61572b18fc2' - MessagingMessageIDKey = attribute.Key("messaging.message.id") - - // MessagingOperationKey is the attribute Key conforming to the - // "messaging.operation" semantic conventions. It represents a string - // identifying the kind of messaging operation. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Note: If a custom value is used, it MUST be of low cardinality. - MessagingOperationKey = attribute.Key("messaging.operation") - - // MessagingRabbitmqDestinationRoutingKeyKey is the attribute Key - // conforming to the "messaging.rabbitmq.destination.routing_key" semantic - // conventions. It represents the rabbitMQ message routing key. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'myKey' - MessagingRabbitmqDestinationRoutingKeyKey = attribute.Key("messaging.rabbitmq.destination.routing_key") - - // MessagingRocketmqClientGroupKey is the attribute Key conforming to the - // "messaging.rocketmq.client_group" semantic conventions. It represents - // the name of the RocketMQ producer/consumer group that is handling the - // message. The client type is identified by the SpanKind. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'myConsumerGroup' - MessagingRocketmqClientGroupKey = attribute.Key("messaging.rocketmq.client_group") - - // MessagingRocketmqConsumptionModelKey is the attribute Key conforming to - // the "messaging.rocketmq.consumption_model" semantic conventions. It - // represents the model of message consumption. This only applies to - // consumer spans. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - MessagingRocketmqConsumptionModelKey = attribute.Key("messaging.rocketmq.consumption_model") - - // MessagingRocketmqMessageDelayTimeLevelKey is the attribute Key - // conforming to the "messaging.rocketmq.message.delay_time_level" semantic - // conventions. It represents the delay time level for delay message, which - // determines the message delay time. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 3 - MessagingRocketmqMessageDelayTimeLevelKey = attribute.Key("messaging.rocketmq.message.delay_time_level") - - // MessagingRocketmqMessageDeliveryTimestampKey is the attribute Key - // conforming to the "messaging.rocketmq.message.delivery_timestamp" - // semantic conventions. It represents the timestamp in milliseconds that - // the delay message is expected to be delivered to consumer. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 1665987217045 - MessagingRocketmqMessageDeliveryTimestampKey = attribute.Key("messaging.rocketmq.message.delivery_timestamp") - - // MessagingRocketmqMessageGroupKey is the attribute Key conforming to the - // "messaging.rocketmq.message.group" semantic conventions. It represents - // the it is essential for FIFO message. Messages that belong to the same - // message group are always processed one by one within the same consumer - // group. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'myMessageGroup' - MessagingRocketmqMessageGroupKey = attribute.Key("messaging.rocketmq.message.group") - - // MessagingRocketmqMessageKeysKey is the attribute Key conforming to the - // "messaging.rocketmq.message.keys" semantic conventions. It represents - // the key(s) of message, another way to mark message besides message id. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'keyA', 'keyB' - MessagingRocketmqMessageKeysKey = attribute.Key("messaging.rocketmq.message.keys") - - // MessagingRocketmqMessageTagKey is the attribute Key conforming to the - // "messaging.rocketmq.message.tag" semantic conventions. It represents the - // secondary classifier of message besides topic. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'tagA' - MessagingRocketmqMessageTagKey = attribute.Key("messaging.rocketmq.message.tag") - - // MessagingRocketmqMessageTypeKey is the attribute Key conforming to the - // "messaging.rocketmq.message.type" semantic conventions. It represents - // the type of message. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - MessagingRocketmqMessageTypeKey = attribute.Key("messaging.rocketmq.message.type") - - // MessagingRocketmqNamespaceKey is the attribute Key conforming to the - // "messaging.rocketmq.namespace" semantic conventions. It represents the - // namespace of RocketMQ resources, resources in different namespaces are - // individual. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'myNamespace' - MessagingRocketmqNamespaceKey = attribute.Key("messaging.rocketmq.namespace") - - // MessagingSystemKey is the attribute Key conforming to the - // "messaging.system" semantic conventions. It represents an identifier for - // the messaging system being used. See below for a list of well-known - // identifiers. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - MessagingSystemKey = attribute.Key("messaging.system") -) - -var ( - // One or more messages are provided for publishing to an intermediary. If a single message is published, the context of the "Publish" span can be used as the creation context and no "Create" span needs to be created - MessagingOperationPublish = MessagingOperationKey.String("publish") - // A message is created. "Create" spans always refer to a single message and are used to provide a unique creation context for messages in batch publishing scenarios - MessagingOperationCreate = MessagingOperationKey.String("create") - // One or more messages are requested by a consumer. This operation refers to pull-based scenarios, where consumers explicitly call methods of messaging SDKs to receive messages - MessagingOperationReceive = MessagingOperationKey.String("receive") - // One or more messages are passed to a consumer. This operation refers to push-based scenarios, where consumer register callbacks which get called by messaging SDKs - MessagingOperationDeliver = MessagingOperationKey.String("deliver") -) - -var ( - // Clustering consumption model - MessagingRocketmqConsumptionModelClustering = MessagingRocketmqConsumptionModelKey.String("clustering") - // Broadcasting consumption model - MessagingRocketmqConsumptionModelBroadcasting = MessagingRocketmqConsumptionModelKey.String("broadcasting") -) - -var ( - // Normal message - MessagingRocketmqMessageTypeNormal = MessagingRocketmqMessageTypeKey.String("normal") - // FIFO message - MessagingRocketmqMessageTypeFifo = MessagingRocketmqMessageTypeKey.String("fifo") - // Delay message - MessagingRocketmqMessageTypeDelay = MessagingRocketmqMessageTypeKey.String("delay") - // Transaction message - MessagingRocketmqMessageTypeTransaction = MessagingRocketmqMessageTypeKey.String("transaction") -) - -var ( - // Apache ActiveMQ - MessagingSystemActivemq = MessagingSystemKey.String("activemq") - // Amazon Simple Queue Service (SQS) - MessagingSystemAWSSqs = MessagingSystemKey.String("aws_sqs") - // Azure Event Grid - MessagingSystemAzureEventgrid = MessagingSystemKey.String("azure_eventgrid") - // Azure Event Hubs - MessagingSystemAzureEventhubs = MessagingSystemKey.String("azure_eventhubs") - // Azure Service Bus - MessagingSystemAzureServicebus = MessagingSystemKey.String("azure_servicebus") - // Google Cloud Pub/Sub - MessagingSystemGCPPubsub = MessagingSystemKey.String("gcp_pubsub") - // Java Message Service - MessagingSystemJms = MessagingSystemKey.String("jms") - // Apache Kafka - MessagingSystemKafka = MessagingSystemKey.String("kafka") - // RabbitMQ - MessagingSystemRabbitmq = MessagingSystemKey.String("rabbitmq") - // Apache RocketMQ - MessagingSystemRocketmq = MessagingSystemKey.String("rocketmq") -) - -// MessagingBatchMessageCount returns an attribute KeyValue conforming to -// the "messaging.batch.message_count" semantic conventions. It represents the -// number of messages sent, received, or processed in the scope of the batching -// operation. -func MessagingBatchMessageCount(val int) attribute.KeyValue { - return MessagingBatchMessageCountKey.Int(val) -} - -// MessagingClientID returns an attribute KeyValue conforming to the -// "messaging.client_id" semantic conventions. It represents a unique -// identifier for the client that consumes or produces a message. -func MessagingClientID(val string) attribute.KeyValue { - return MessagingClientIDKey.String(val) -} - -// MessagingDestinationAnonymous returns an attribute KeyValue conforming to -// the "messaging.destination.anonymous" semantic conventions. It represents a -// boolean that is true if the message destination is anonymous (could be -// unnamed or have auto-generated name). -func MessagingDestinationAnonymous(val bool) attribute.KeyValue { - return MessagingDestinationAnonymousKey.Bool(val) -} - -// MessagingDestinationName returns an attribute KeyValue conforming to the -// "messaging.destination.name" semantic conventions. It represents the message -// destination name -func MessagingDestinationName(val string) attribute.KeyValue { - return MessagingDestinationNameKey.String(val) -} - -// MessagingDestinationTemplate returns an attribute KeyValue conforming to -// the "messaging.destination.template" semantic conventions. It represents the -// low cardinality representation of the messaging destination name -func MessagingDestinationTemplate(val string) attribute.KeyValue { - return MessagingDestinationTemplateKey.String(val) -} - -// MessagingDestinationTemporary returns an attribute KeyValue conforming to -// the "messaging.destination.temporary" semantic conventions. It represents a -// boolean that is true if the message destination is temporary and might not -// exist anymore after messages are processed. -func MessagingDestinationTemporary(val bool) attribute.KeyValue { - return MessagingDestinationTemporaryKey.Bool(val) -} - -// MessagingDestinationPublishAnonymous returns an attribute KeyValue -// conforming to the "messaging.destination_publish.anonymous" semantic -// conventions. It represents a boolean that is true if the publish message -// destination is anonymous (could be unnamed or have auto-generated name). -func MessagingDestinationPublishAnonymous(val bool) attribute.KeyValue { - return MessagingDestinationPublishAnonymousKey.Bool(val) -} - -// MessagingDestinationPublishName returns an attribute KeyValue conforming -// to the "messaging.destination_publish.name" semantic conventions. It -// represents the name of the original destination the message was published to -func MessagingDestinationPublishName(val string) attribute.KeyValue { - return MessagingDestinationPublishNameKey.String(val) -} - -// MessagingGCPPubsubMessageOrderingKey returns an attribute KeyValue -// conforming to the "messaging.gcp_pubsub.message.ordering_key" semantic -// conventions. It represents the ordering key for a given message. If the -// attribute is not present, the message does not have an ordering key. -func MessagingGCPPubsubMessageOrderingKey(val string) attribute.KeyValue { - return MessagingGCPPubsubMessageOrderingKeyKey.String(val) -} - -// MessagingKafkaConsumerGroup returns an attribute KeyValue conforming to -// the "messaging.kafka.consumer.group" semantic conventions. It represents the -// name of the Kafka Consumer Group that is handling the message. Only applies -// to consumers, not producers. -func MessagingKafkaConsumerGroup(val string) attribute.KeyValue { - return MessagingKafkaConsumerGroupKey.String(val) -} - -// MessagingKafkaDestinationPartition returns an attribute KeyValue -// conforming to the "messaging.kafka.destination.partition" semantic -// conventions. It represents the partition the message is sent to. -func MessagingKafkaDestinationPartition(val int) attribute.KeyValue { - return MessagingKafkaDestinationPartitionKey.Int(val) -} - -// MessagingKafkaMessageKey returns an attribute KeyValue conforming to the -// "messaging.kafka.message.key" semantic conventions. It represents the -// message keys in Kafka are used for grouping alike messages to ensure they're -// processed on the same partition. They differ from `messaging.message.id` in -// that they're not unique. If the key is `null`, the attribute MUST NOT be -// set. -func MessagingKafkaMessageKey(val string) attribute.KeyValue { - return MessagingKafkaMessageKeyKey.String(val) -} - -// MessagingKafkaMessageOffset returns an attribute KeyValue conforming to -// the "messaging.kafka.message.offset" semantic conventions. It represents the -// offset of a record in the corresponding Kafka partition. -func MessagingKafkaMessageOffset(val int) attribute.KeyValue { - return MessagingKafkaMessageOffsetKey.Int(val) -} - -// MessagingKafkaMessageTombstone returns an attribute KeyValue conforming -// to the "messaging.kafka.message.tombstone" semantic conventions. It -// represents a boolean that is true if the message is a tombstone. -func MessagingKafkaMessageTombstone(val bool) attribute.KeyValue { - return MessagingKafkaMessageTombstoneKey.Bool(val) -} - -// MessagingMessageBodySize returns an attribute KeyValue conforming to the -// "messaging.message.body.size" semantic conventions. It represents the size -// of the message body in bytes. -func MessagingMessageBodySize(val int) attribute.KeyValue { - return MessagingMessageBodySizeKey.Int(val) -} - -// MessagingMessageConversationID returns an attribute KeyValue conforming -// to the "messaging.message.conversation_id" semantic conventions. It -// represents the conversation ID identifying the conversation to which the -// message belongs, represented as a string. Sometimes called "Correlation ID". -func MessagingMessageConversationID(val string) attribute.KeyValue { - return MessagingMessageConversationIDKey.String(val) -} - -// MessagingMessageEnvelopeSize returns an attribute KeyValue conforming to -// the "messaging.message.envelope.size" semantic conventions. It represents -// the size of the message body and metadata in bytes. -func MessagingMessageEnvelopeSize(val int) attribute.KeyValue { - return MessagingMessageEnvelopeSizeKey.Int(val) -} - -// MessagingMessageID returns an attribute KeyValue conforming to the -// "messaging.message.id" semantic conventions. It represents a value used by -// the messaging system as an identifier for the message, represented as a -// string. -func MessagingMessageID(val string) attribute.KeyValue { - return MessagingMessageIDKey.String(val) -} - -// MessagingRabbitmqDestinationRoutingKey returns an attribute KeyValue -// conforming to the "messaging.rabbitmq.destination.routing_key" semantic -// conventions. It represents the rabbitMQ message routing key. -func MessagingRabbitmqDestinationRoutingKey(val string) attribute.KeyValue { - return MessagingRabbitmqDestinationRoutingKeyKey.String(val) -} - -// MessagingRocketmqClientGroup returns an attribute KeyValue conforming to -// the "messaging.rocketmq.client_group" semantic conventions. It represents -// the name of the RocketMQ producer/consumer group that is handling the -// message. The client type is identified by the SpanKind. -func MessagingRocketmqClientGroup(val string) attribute.KeyValue { - return MessagingRocketmqClientGroupKey.String(val) -} - -// MessagingRocketmqMessageDelayTimeLevel returns an attribute KeyValue -// conforming to the "messaging.rocketmq.message.delay_time_level" semantic -// conventions. It represents the delay time level for delay message, which -// determines the message delay time. -func MessagingRocketmqMessageDelayTimeLevel(val int) attribute.KeyValue { - return MessagingRocketmqMessageDelayTimeLevelKey.Int(val) -} - -// MessagingRocketmqMessageDeliveryTimestamp returns an attribute KeyValue -// conforming to the "messaging.rocketmq.message.delivery_timestamp" semantic -// conventions. It represents the timestamp in milliseconds that the delay -// message is expected to be delivered to consumer. -func MessagingRocketmqMessageDeliveryTimestamp(val int) attribute.KeyValue { - return MessagingRocketmqMessageDeliveryTimestampKey.Int(val) -} - -// MessagingRocketmqMessageGroup returns an attribute KeyValue conforming to -// the "messaging.rocketmq.message.group" semantic conventions. It represents -// the it is essential for FIFO message. Messages that belong to the same -// message group are always processed one by one within the same consumer -// group. -func MessagingRocketmqMessageGroup(val string) attribute.KeyValue { - return MessagingRocketmqMessageGroupKey.String(val) -} - -// MessagingRocketmqMessageKeys returns an attribute KeyValue conforming to -// the "messaging.rocketmq.message.keys" semantic conventions. It represents -// the key(s) of message, another way to mark message besides message id. -func MessagingRocketmqMessageKeys(val ...string) attribute.KeyValue { - return MessagingRocketmqMessageKeysKey.StringSlice(val) -} - -// MessagingRocketmqMessageTag returns an attribute KeyValue conforming to -// the "messaging.rocketmq.message.tag" semantic conventions. It represents the -// secondary classifier of message besides topic. -func MessagingRocketmqMessageTag(val string) attribute.KeyValue { - return MessagingRocketmqMessageTagKey.String(val) -} - -// MessagingRocketmqNamespace returns an attribute KeyValue conforming to -// the "messaging.rocketmq.namespace" semantic conventions. It represents the -// namespace of RocketMQ resources, resources in different namespaces are -// individual. -func MessagingRocketmqNamespace(val string) attribute.KeyValue { - return MessagingRocketmqNamespaceKey.String(val) -} - -// These attributes may be used for any network related operation. -const ( - // NetworkCarrierIccKey is the attribute Key conforming to the - // "network.carrier.icc" semantic conventions. It represents the ISO 3166-1 - // alpha-2 2-character country code associated with the mobile carrier - // network. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'DE' - NetworkCarrierIccKey = attribute.Key("network.carrier.icc") - - // NetworkCarrierMccKey is the attribute Key conforming to the - // "network.carrier.mcc" semantic conventions. It represents the mobile - // carrier country code. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '310' - NetworkCarrierMccKey = attribute.Key("network.carrier.mcc") - - // NetworkCarrierMncKey is the attribute Key conforming to the - // "network.carrier.mnc" semantic conventions. It represents the mobile - // carrier network code. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '001' - NetworkCarrierMncKey = attribute.Key("network.carrier.mnc") - - // NetworkCarrierNameKey is the attribute Key conforming to the - // "network.carrier.name" semantic conventions. It represents the name of - // the mobile carrier. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'sprint' - NetworkCarrierNameKey = attribute.Key("network.carrier.name") - - // NetworkConnectionSubtypeKey is the attribute Key conforming to the - // "network.connection.subtype" semantic conventions. It represents the - // this describes more details regarding the connection.type. It may be the - // type of cell technology connection, but it could be used for describing - // details about a wifi connection. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'LTE' - NetworkConnectionSubtypeKey = attribute.Key("network.connection.subtype") - - // NetworkConnectionTypeKey is the attribute Key conforming to the - // "network.connection.type" semantic conventions. It represents the - // internet connection type. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'wifi' - NetworkConnectionTypeKey = attribute.Key("network.connection.type") - - // NetworkIoDirectionKey is the attribute Key conforming to the - // "network.io.direction" semantic conventions. It represents the network - // IO operation direction. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'transmit' - NetworkIoDirectionKey = attribute.Key("network.io.direction") - - // NetworkLocalAddressKey is the attribute Key conforming to the - // "network.local.address" semantic conventions. It represents the local - // address of the network connection - IP address or Unix domain socket - // name. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: '10.1.2.80', '/tmp/my.sock' - NetworkLocalAddressKey = attribute.Key("network.local.address") - - // NetworkLocalPortKey is the attribute Key conforming to the - // "network.local.port" semantic conventions. It represents the local port - // number of the network connection. - // - // Type: int - // RequirementLevel: Optional - // Stability: stable - // Examples: 65123 - NetworkLocalPortKey = attribute.Key("network.local.port") - - // NetworkPeerAddressKey is the attribute Key conforming to the - // "network.peer.address" semantic conventions. It represents the peer - // address of the network connection - IP address or Unix domain socket - // name. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: '10.1.2.80', '/tmp/my.sock' - NetworkPeerAddressKey = attribute.Key("network.peer.address") - - // NetworkPeerPortKey is the attribute Key conforming to the - // "network.peer.port" semantic conventions. It represents the peer port - // number of the network connection. - // - // Type: int - // RequirementLevel: Optional - // Stability: stable - // Examples: 65123 - NetworkPeerPortKey = attribute.Key("network.peer.port") - - // NetworkProtocolNameKey is the attribute Key conforming to the - // "network.protocol.name" semantic conventions. It represents the [OSI - // application layer](https://osi-model.com/application-layer/) or non-OSI - // equivalent. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: 'amqp', 'http', 'mqtt' - // Note: The value SHOULD be normalized to lowercase. - NetworkProtocolNameKey = attribute.Key("network.protocol.name") - - // NetworkProtocolVersionKey is the attribute Key conforming to the - // "network.protocol.version" semantic conventions. It represents the - // version of the protocol specified in `network.protocol.name`. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: '3.1.1' - // Note: `network.protocol.version` refers to the version of the protocol - // used and might be different from the protocol client's version. If the - // HTTP client has a version of `0.27.2`, but sends HTTP version `1.1`, - // this attribute should be set to `1.1`. - NetworkProtocolVersionKey = attribute.Key("network.protocol.version") - - // NetworkTransportKey is the attribute Key conforming to the - // "network.transport" semantic conventions. It represents the [OSI - // transport layer](https://osi-model.com/transport-layer/) or - // [inter-process communication - // method](https://wikipedia.org/wiki/Inter-process_communication). - // - // Type: Enum - // RequirementLevel: Optional - // Stability: stable - // Examples: 'tcp', 'udp' - // Note: The value SHOULD be normalized to lowercase. - // - // Consider always setting the transport when setting a port number, since - // a port number is ambiguous without knowing the transport. For example - // different processes could be listening on TCP port 12345 and UDP port - // 12345. - NetworkTransportKey = attribute.Key("network.transport") - - // NetworkTypeKey is the attribute Key conforming to the "network.type" - // semantic conventions. It represents the [OSI network - // layer](https://osi-model.com/network-layer/) or non-OSI equivalent. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: stable - // Examples: 'ipv4', 'ipv6' - // Note: The value SHOULD be normalized to lowercase. - NetworkTypeKey = attribute.Key("network.type") -) - -var ( - // GPRS - NetworkConnectionSubtypeGprs = NetworkConnectionSubtypeKey.String("gprs") - // EDGE - NetworkConnectionSubtypeEdge = NetworkConnectionSubtypeKey.String("edge") - // UMTS - NetworkConnectionSubtypeUmts = NetworkConnectionSubtypeKey.String("umts") - // CDMA - NetworkConnectionSubtypeCdma = NetworkConnectionSubtypeKey.String("cdma") - // EVDO Rel. 0 - NetworkConnectionSubtypeEvdo0 = NetworkConnectionSubtypeKey.String("evdo_0") - // EVDO Rev. A - NetworkConnectionSubtypeEvdoA = NetworkConnectionSubtypeKey.String("evdo_a") - // CDMA2000 1XRTT - NetworkConnectionSubtypeCdma20001xrtt = NetworkConnectionSubtypeKey.String("cdma2000_1xrtt") - // HSDPA - NetworkConnectionSubtypeHsdpa = NetworkConnectionSubtypeKey.String("hsdpa") - // HSUPA - NetworkConnectionSubtypeHsupa = NetworkConnectionSubtypeKey.String("hsupa") - // HSPA - NetworkConnectionSubtypeHspa = NetworkConnectionSubtypeKey.String("hspa") - // IDEN - NetworkConnectionSubtypeIden = NetworkConnectionSubtypeKey.String("iden") - // EVDO Rev. B - NetworkConnectionSubtypeEvdoB = NetworkConnectionSubtypeKey.String("evdo_b") - // LTE - NetworkConnectionSubtypeLte = NetworkConnectionSubtypeKey.String("lte") - // EHRPD - NetworkConnectionSubtypeEhrpd = NetworkConnectionSubtypeKey.String("ehrpd") - // HSPAP - NetworkConnectionSubtypeHspap = NetworkConnectionSubtypeKey.String("hspap") - // GSM - NetworkConnectionSubtypeGsm = NetworkConnectionSubtypeKey.String("gsm") - // TD-SCDMA - NetworkConnectionSubtypeTdScdma = NetworkConnectionSubtypeKey.String("td_scdma") - // IWLAN - NetworkConnectionSubtypeIwlan = NetworkConnectionSubtypeKey.String("iwlan") - // 5G NR (New Radio) - NetworkConnectionSubtypeNr = NetworkConnectionSubtypeKey.String("nr") - // 5G NRNSA (New Radio Non-Standalone) - NetworkConnectionSubtypeNrnsa = NetworkConnectionSubtypeKey.String("nrnsa") - // LTE CA - NetworkConnectionSubtypeLteCa = NetworkConnectionSubtypeKey.String("lte_ca") -) - -var ( - // wifi - NetworkConnectionTypeWifi = NetworkConnectionTypeKey.String("wifi") - // wired - NetworkConnectionTypeWired = NetworkConnectionTypeKey.String("wired") - // cell - NetworkConnectionTypeCell = NetworkConnectionTypeKey.String("cell") - // unavailable - NetworkConnectionTypeUnavailable = NetworkConnectionTypeKey.String("unavailable") - // unknown - NetworkConnectionTypeUnknown = NetworkConnectionTypeKey.String("unknown") -) - -var ( - // transmit - NetworkIoDirectionTransmit = NetworkIoDirectionKey.String("transmit") - // receive - NetworkIoDirectionReceive = NetworkIoDirectionKey.String("receive") -) - -var ( - // TCP - NetworkTransportTCP = NetworkTransportKey.String("tcp") - // UDP - NetworkTransportUDP = NetworkTransportKey.String("udp") - // Named or anonymous pipe - NetworkTransportPipe = NetworkTransportKey.String("pipe") - // Unix domain socket - NetworkTransportUnix = NetworkTransportKey.String("unix") -) - -var ( - // IPv4 - NetworkTypeIpv4 = NetworkTypeKey.String("ipv4") - // IPv6 - NetworkTypeIpv6 = NetworkTypeKey.String("ipv6") -) - -// NetworkCarrierIcc returns an attribute KeyValue conforming to the -// "network.carrier.icc" semantic conventions. It represents the ISO 3166-1 -// alpha-2 2-character country code associated with the mobile carrier network. -func NetworkCarrierIcc(val string) attribute.KeyValue { - return NetworkCarrierIccKey.String(val) -} - -// NetworkCarrierMcc returns an attribute KeyValue conforming to the -// "network.carrier.mcc" semantic conventions. It represents the mobile carrier -// country code. -func NetworkCarrierMcc(val string) attribute.KeyValue { - return NetworkCarrierMccKey.String(val) -} - -// NetworkCarrierMnc returns an attribute KeyValue conforming to the -// "network.carrier.mnc" semantic conventions. It represents the mobile carrier -// network code. -func NetworkCarrierMnc(val string) attribute.KeyValue { - return NetworkCarrierMncKey.String(val) -} - -// NetworkCarrierName returns an attribute KeyValue conforming to the -// "network.carrier.name" semantic conventions. It represents the name of the -// mobile carrier. -func NetworkCarrierName(val string) attribute.KeyValue { - return NetworkCarrierNameKey.String(val) -} - -// NetworkLocalAddress returns an attribute KeyValue conforming to the -// "network.local.address" semantic conventions. It represents the local -// address of the network connection - IP address or Unix domain socket name. -func NetworkLocalAddress(val string) attribute.KeyValue { - return NetworkLocalAddressKey.String(val) -} - -// NetworkLocalPort returns an attribute KeyValue conforming to the -// "network.local.port" semantic conventions. It represents the local port -// number of the network connection. -func NetworkLocalPort(val int) attribute.KeyValue { - return NetworkLocalPortKey.Int(val) -} - -// NetworkPeerAddress returns an attribute KeyValue conforming to the -// "network.peer.address" semantic conventions. It represents the peer address -// of the network connection - IP address or Unix domain socket name. -func NetworkPeerAddress(val string) attribute.KeyValue { - return NetworkPeerAddressKey.String(val) -} - -// NetworkPeerPort returns an attribute KeyValue conforming to the -// "network.peer.port" semantic conventions. It represents the peer port number -// of the network connection. -func NetworkPeerPort(val int) attribute.KeyValue { - return NetworkPeerPortKey.Int(val) -} - -// NetworkProtocolName returns an attribute KeyValue conforming to the -// "network.protocol.name" semantic conventions. It represents the [OSI -// application layer](https://osi-model.com/application-layer/) or non-OSI -// equivalent. -func NetworkProtocolName(val string) attribute.KeyValue { - return NetworkProtocolNameKey.String(val) -} - -// NetworkProtocolVersion returns an attribute KeyValue conforming to the -// "network.protocol.version" semantic conventions. It represents the version -// of the protocol specified in `network.protocol.name`. -func NetworkProtocolVersion(val string) attribute.KeyValue { - return NetworkProtocolVersionKey.String(val) -} - -// Attributes for remote procedure calls. -const ( - // RPCConnectRPCErrorCodeKey is the attribute Key conforming to the - // "rpc.connect_rpc.error_code" semantic conventions. It represents the - // [error codes](https://connect.build/docs/protocol/#error-codes) of the - // Connect request. Error codes are always string values. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - RPCConnectRPCErrorCodeKey = attribute.Key("rpc.connect_rpc.error_code") - - // RPCGRPCStatusCodeKey is the attribute Key conforming to the - // "rpc.grpc.status_code" semantic conventions. It represents the [numeric - // status - // code](https://github.com/grpc/grpc/blob/v1.33.2/doc/statuscodes.md) of - // the gRPC request. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - RPCGRPCStatusCodeKey = attribute.Key("rpc.grpc.status_code") - - // RPCJsonrpcErrorCodeKey is the attribute Key conforming to the - // "rpc.jsonrpc.error_code" semantic conventions. It represents the - // `error.code` property of response if it is an error response. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: -32700, 100 - RPCJsonrpcErrorCodeKey = attribute.Key("rpc.jsonrpc.error_code") - - // RPCJsonrpcErrorMessageKey is the attribute Key conforming to the - // "rpc.jsonrpc.error_message" semantic conventions. It represents the - // `error.message` property of response if it is an error response. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Parse error', 'User already exists' - RPCJsonrpcErrorMessageKey = attribute.Key("rpc.jsonrpc.error_message") - - // RPCJsonrpcRequestIDKey is the attribute Key conforming to the - // "rpc.jsonrpc.request_id" semantic conventions. It represents the `id` - // property of request or response. Since protocol allows id to be int, - // string, `null` or missing (for notifications), value is expected to be - // cast to string for simplicity. Use empty string in case of `null` value. - // Omit entirely if this is a notification. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '10', 'request-7', '' - RPCJsonrpcRequestIDKey = attribute.Key("rpc.jsonrpc.request_id") - - // RPCJsonrpcVersionKey is the attribute Key conforming to the - // "rpc.jsonrpc.version" semantic conventions. It represents the protocol - // version as in `jsonrpc` property of request/response. Since JSON-RPC 1.0 - // doesn't specify this, the value can be omitted. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2.0', '1.0' - RPCJsonrpcVersionKey = attribute.Key("rpc.jsonrpc.version") - - // RPCMethodKey is the attribute Key conforming to the "rpc.method" - // semantic conventions. It represents the name of the (logical) method - // being called, must be equal to the $method part in the span name. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'exampleMethod' - // Note: This is the logical name of the method from the RPC interface - // perspective, which can be different from the name of any implementing - // method/function. The `code.function` attribute may be used to store the - // latter (e.g., method actually executing the call on the server side, RPC - // client stub method on the client side). - RPCMethodKey = attribute.Key("rpc.method") - - // RPCServiceKey is the attribute Key conforming to the "rpc.service" - // semantic conventions. It represents the full (logical) name of the - // service being called, including its package name, if applicable. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'myservice.EchoService' - // Note: This is the logical name of the service from the RPC interface - // perspective, which can be different from the name of any implementing - // class. The `code.namespace` attribute may be used to store the latter - // (despite the attribute name, it may include a class name; e.g., class - // with method actually executing the call on the server side, RPC client - // stub class on the client side). - RPCServiceKey = attribute.Key("rpc.service") - - // RPCSystemKey is the attribute Key conforming to the "rpc.system" - // semantic conventions. It represents a string identifying the remoting - // system. See below for a list of well-known identifiers. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - RPCSystemKey = attribute.Key("rpc.system") -) - -var ( - // cancelled - RPCConnectRPCErrorCodeCancelled = RPCConnectRPCErrorCodeKey.String("cancelled") - // unknown - RPCConnectRPCErrorCodeUnknown = RPCConnectRPCErrorCodeKey.String("unknown") - // invalid_argument - RPCConnectRPCErrorCodeInvalidArgument = RPCConnectRPCErrorCodeKey.String("invalid_argument") - // deadline_exceeded - RPCConnectRPCErrorCodeDeadlineExceeded = RPCConnectRPCErrorCodeKey.String("deadline_exceeded") - // not_found - RPCConnectRPCErrorCodeNotFound = RPCConnectRPCErrorCodeKey.String("not_found") - // already_exists - RPCConnectRPCErrorCodeAlreadyExists = RPCConnectRPCErrorCodeKey.String("already_exists") - // permission_denied - RPCConnectRPCErrorCodePermissionDenied = RPCConnectRPCErrorCodeKey.String("permission_denied") - // resource_exhausted - RPCConnectRPCErrorCodeResourceExhausted = RPCConnectRPCErrorCodeKey.String("resource_exhausted") - // failed_precondition - RPCConnectRPCErrorCodeFailedPrecondition = RPCConnectRPCErrorCodeKey.String("failed_precondition") - // aborted - RPCConnectRPCErrorCodeAborted = RPCConnectRPCErrorCodeKey.String("aborted") - // out_of_range - RPCConnectRPCErrorCodeOutOfRange = RPCConnectRPCErrorCodeKey.String("out_of_range") - // unimplemented - RPCConnectRPCErrorCodeUnimplemented = RPCConnectRPCErrorCodeKey.String("unimplemented") - // internal - RPCConnectRPCErrorCodeInternal = RPCConnectRPCErrorCodeKey.String("internal") - // unavailable - RPCConnectRPCErrorCodeUnavailable = RPCConnectRPCErrorCodeKey.String("unavailable") - // data_loss - RPCConnectRPCErrorCodeDataLoss = RPCConnectRPCErrorCodeKey.String("data_loss") - // unauthenticated - RPCConnectRPCErrorCodeUnauthenticated = RPCConnectRPCErrorCodeKey.String("unauthenticated") -) - -var ( - // OK - RPCGRPCStatusCodeOk = RPCGRPCStatusCodeKey.Int(0) - // CANCELLED - RPCGRPCStatusCodeCancelled = RPCGRPCStatusCodeKey.Int(1) - // UNKNOWN - RPCGRPCStatusCodeUnknown = RPCGRPCStatusCodeKey.Int(2) - // INVALID_ARGUMENT - RPCGRPCStatusCodeInvalidArgument = RPCGRPCStatusCodeKey.Int(3) - // DEADLINE_EXCEEDED - RPCGRPCStatusCodeDeadlineExceeded = RPCGRPCStatusCodeKey.Int(4) - // NOT_FOUND - RPCGRPCStatusCodeNotFound = RPCGRPCStatusCodeKey.Int(5) - // ALREADY_EXISTS - RPCGRPCStatusCodeAlreadyExists = RPCGRPCStatusCodeKey.Int(6) - // PERMISSION_DENIED - RPCGRPCStatusCodePermissionDenied = RPCGRPCStatusCodeKey.Int(7) - // RESOURCE_EXHAUSTED - RPCGRPCStatusCodeResourceExhausted = RPCGRPCStatusCodeKey.Int(8) - // FAILED_PRECONDITION - RPCGRPCStatusCodeFailedPrecondition = RPCGRPCStatusCodeKey.Int(9) - // ABORTED - RPCGRPCStatusCodeAborted = RPCGRPCStatusCodeKey.Int(10) - // OUT_OF_RANGE - RPCGRPCStatusCodeOutOfRange = RPCGRPCStatusCodeKey.Int(11) - // UNIMPLEMENTED - RPCGRPCStatusCodeUnimplemented = RPCGRPCStatusCodeKey.Int(12) - // INTERNAL - RPCGRPCStatusCodeInternal = RPCGRPCStatusCodeKey.Int(13) - // UNAVAILABLE - RPCGRPCStatusCodeUnavailable = RPCGRPCStatusCodeKey.Int(14) - // DATA_LOSS - RPCGRPCStatusCodeDataLoss = RPCGRPCStatusCodeKey.Int(15) - // UNAUTHENTICATED - RPCGRPCStatusCodeUnauthenticated = RPCGRPCStatusCodeKey.Int(16) -) - -var ( - // gRPC - RPCSystemGRPC = RPCSystemKey.String("grpc") - // Java RMI - RPCSystemJavaRmi = RPCSystemKey.String("java_rmi") - // .NET WCF - RPCSystemDotnetWcf = RPCSystemKey.String("dotnet_wcf") - // Apache Dubbo - RPCSystemApacheDubbo = RPCSystemKey.String("apache_dubbo") - // Connect RPC - RPCSystemConnectRPC = RPCSystemKey.String("connect_rpc") -) - -// RPCJsonrpcErrorCode returns an attribute KeyValue conforming to the -// "rpc.jsonrpc.error_code" semantic conventions. It represents the -// `error.code` property of response if it is an error response. -func RPCJsonrpcErrorCode(val int) attribute.KeyValue { - return RPCJsonrpcErrorCodeKey.Int(val) -} - -// RPCJsonrpcErrorMessage returns an attribute KeyValue conforming to the -// "rpc.jsonrpc.error_message" semantic conventions. It represents the -// `error.message` property of response if it is an error response. -func RPCJsonrpcErrorMessage(val string) attribute.KeyValue { - return RPCJsonrpcErrorMessageKey.String(val) -} - -// RPCJsonrpcRequestID returns an attribute KeyValue conforming to the -// "rpc.jsonrpc.request_id" semantic conventions. It represents the `id` -// property of request or response. Since protocol allows id to be int, string, -// `null` or missing (for notifications), value is expected to be cast to -// string for simplicity. Use empty string in case of `null` value. Omit -// entirely if this is a notification. -func RPCJsonrpcRequestID(val string) attribute.KeyValue { - return RPCJsonrpcRequestIDKey.String(val) -} - -// RPCJsonrpcVersion returns an attribute KeyValue conforming to the -// "rpc.jsonrpc.version" semantic conventions. It represents the protocol -// version as in `jsonrpc` property of request/response. Since JSON-RPC 1.0 -// doesn't specify this, the value can be omitted. -func RPCJsonrpcVersion(val string) attribute.KeyValue { - return RPCJsonrpcVersionKey.String(val) -} - -// RPCMethod returns an attribute KeyValue conforming to the "rpc.method" -// semantic conventions. It represents the name of the (logical) method being -// called, must be equal to the $method part in the span name. -func RPCMethod(val string) attribute.KeyValue { - return RPCMethodKey.String(val) -} - -// RPCService returns an attribute KeyValue conforming to the "rpc.service" -// semantic conventions. It represents the full (logical) name of the service -// being called, including its package name, if applicable. -func RPCService(val string) attribute.KeyValue { - return RPCServiceKey.String(val) -} - -// These attributes may be used to describe the server in a connection-based -// network interaction where there is one side that initiates the connection -// (the client is the side that initiates the connection). This covers all TCP -// network interactions since TCP is connection-based and one side initiates -// the connection (an exception is made for peer-to-peer communication over TCP -// where the "user-facing" surface of the protocol / API doesn't expose a clear -// notion of client and server). This also covers UDP network interactions -// where one side initiates the interaction, e.g. QUIC (HTTP/3) and DNS. -const ( - // ServerAddressKey is the attribute Key conforming to the "server.address" - // semantic conventions. It represents the server domain name if available - // without reverse DNS lookup; otherwise, IP address or Unix domain socket - // name. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: 'example.com', '10.1.2.80', '/tmp/my.sock' - // Note: When observed from the client side, and when communicating through - // an intermediary, `server.address` SHOULD represent the server address - // behind any intermediaries, for example proxies, if it's available. - ServerAddressKey = attribute.Key("server.address") - - // ServerPortKey is the attribute Key conforming to the "server.port" - // semantic conventions. It represents the server port number. - // - // Type: int - // RequirementLevel: Optional - // Stability: stable - // Examples: 80, 8080, 443 - // Note: When observed from the client side, and when communicating through - // an intermediary, `server.port` SHOULD represent the server port behind - // any intermediaries, for example proxies, if it's available. - ServerPortKey = attribute.Key("server.port") -) - -// ServerAddress returns an attribute KeyValue conforming to the -// "server.address" semantic conventions. It represents the server domain name -// if available without reverse DNS lookup; otherwise, IP address or Unix -// domain socket name. -func ServerAddress(val string) attribute.KeyValue { - return ServerAddressKey.String(val) -} - -// ServerPort returns an attribute KeyValue conforming to the "server.port" -// semantic conventions. It represents the server port number. -func ServerPort(val int) attribute.KeyValue { - return ServerPortKey.Int(val) -} - -// These attributes may be used to describe the sender of a network -// exchange/packet. These should be used when there is no client/server -// relationship between the two sides, or when that relationship is unknown. -// This covers low-level network interactions (e.g. packet tracing) where you -// don't know if there was a connection or which side initiated it. This also -// covers unidirectional UDP flows and peer-to-peer communication where the -// "user-facing" surface of the protocol / API doesn't expose a clear notion of -// client and server. -const ( - // SourceAddressKey is the attribute Key conforming to the "source.address" - // semantic conventions. It represents the source address - domain name if - // available without reverse DNS lookup; otherwise, IP address or Unix - // domain socket name. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'source.example.com', '10.1.2.80', '/tmp/my.sock' - // Note: When observed from the destination side, and when communicating - // through an intermediary, `source.address` SHOULD represent the source - // address behind any intermediaries, for example proxies, if it's - // available. - SourceAddressKey = attribute.Key("source.address") - - // SourcePortKey is the attribute Key conforming to the "source.port" - // semantic conventions. It represents the source port number - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 3389, 2888 - SourcePortKey = attribute.Key("source.port") -) - -// SourceAddress returns an attribute KeyValue conforming to the -// "source.address" semantic conventions. It represents the source address - -// domain name if available without reverse DNS lookup; otherwise, IP address -// or Unix domain socket name. -func SourceAddress(val string) attribute.KeyValue { - return SourceAddressKey.String(val) -} - -// SourcePort returns an attribute KeyValue conforming to the "source.port" -// semantic conventions. It represents the source port number -func SourcePort(val int) attribute.KeyValue { - return SourcePortKey.Int(val) -} - -// Semantic convention attributes in the TLS namespace. -const ( - // TLSCipherKey is the attribute Key conforming to the "tls.cipher" - // semantic conventions. It represents the string indicating the - // [cipher](https://datatracker.ietf.org/doc/html/rfc5246#appendix-A.5) - // used during the current connection. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'TLS_RSA_WITH_3DES_EDE_CBC_SHA', - // 'TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256' - // Note: The values allowed for `tls.cipher` MUST be one of the - // `Descriptions` of the [registered TLS Cipher - // Suits](https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#table-tls-parameters-4). - TLSCipherKey = attribute.Key("tls.cipher") - - // TLSClientCertificateKey is the attribute Key conforming to the - // "tls.client.certificate" semantic conventions. It represents the - // pEM-encoded stand-alone certificate offered by the client. This is - // usually mutually-exclusive of `client.certificate_chain` since this - // value also exists in that list. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'MII...' - TLSClientCertificateKey = attribute.Key("tls.client.certificate") - - // TLSClientCertificateChainKey is the attribute Key conforming to the - // "tls.client.certificate_chain" semantic conventions. It represents the - // array of PEM-encoded certificates that make up the certificate chain - // offered by the client. This is usually mutually-exclusive of - // `client.certificate` since that value should be the first certificate in - // the chain. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'MII...', 'MI...' - TLSClientCertificateChainKey = attribute.Key("tls.client.certificate_chain") - - // TLSClientHashMd5Key is the attribute Key conforming to the - // "tls.client.hash.md5" semantic conventions. It represents the - // certificate fingerprint using the MD5 digest of DER-encoded version of - // certificate offered by the client. For consistency with other hash - // values, this value should be formatted as an uppercase hash. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '0F76C7F2C55BFD7D8E8B8F4BFBF0C9EC' - TLSClientHashMd5Key = attribute.Key("tls.client.hash.md5") - - // TLSClientHashSha1Key is the attribute Key conforming to the - // "tls.client.hash.sha1" semantic conventions. It represents the - // certificate fingerprint using the SHA1 digest of DER-encoded version of - // certificate offered by the client. For consistency with other hash - // values, this value should be formatted as an uppercase hash. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '9E393D93138888D288266C2D915214D1D1CCEB2A' - TLSClientHashSha1Key = attribute.Key("tls.client.hash.sha1") - - // TLSClientHashSha256Key is the attribute Key conforming to the - // "tls.client.hash.sha256" semantic conventions. It represents the - // certificate fingerprint using the SHA256 digest of DER-encoded version - // of certificate offered by the client. For consistency with other hash - // values, this value should be formatted as an uppercase hash. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // '0687F666A054EF17A08E2F2162EAB4CBC0D265E1D7875BE74BF3C712CA92DAF0' - TLSClientHashSha256Key = attribute.Key("tls.client.hash.sha256") - - // TLSClientIssuerKey is the attribute Key conforming to the - // "tls.client.issuer" semantic conventions. It represents the - // distinguished name of - // [subject](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6) - // of the issuer of the x.509 certificate presented by the client. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'CN=Example Root CA, OU=Infrastructure Team, DC=example, - // DC=com' - TLSClientIssuerKey = attribute.Key("tls.client.issuer") - - // TLSClientJa3Key is the attribute Key conforming to the "tls.client.ja3" - // semantic conventions. It represents a hash that identifies clients based - // on how they perform an SSL/TLS handshake. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'd4e5b18d6b55c71272893221c96ba240' - TLSClientJa3Key = attribute.Key("tls.client.ja3") - - // TLSClientNotAfterKey is the attribute Key conforming to the - // "tls.client.not_after" semantic conventions. It represents the date/Time - // indicating when client certificate is no longer considered valid. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2021-01-01T00:00:00.000Z' - TLSClientNotAfterKey = attribute.Key("tls.client.not_after") - - // TLSClientNotBeforeKey is the attribute Key conforming to the - // "tls.client.not_before" semantic conventions. It represents the - // date/Time indicating when client certificate is first considered valid. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '1970-01-01T00:00:00.000Z' - TLSClientNotBeforeKey = attribute.Key("tls.client.not_before") - - // TLSClientServerNameKey is the attribute Key conforming to the - // "tls.client.server_name" semantic conventions. It represents the also - // called an SNI, this tells the server which hostname to which the client - // is attempting to connect to. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry.io' - TLSClientServerNameKey = attribute.Key("tls.client.server_name") - - // TLSClientSubjectKey is the attribute Key conforming to the - // "tls.client.subject" semantic conventions. It represents the - // distinguished name of subject of the x.509 certificate presented by the - // client. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'CN=myclient, OU=Documentation Team, DC=example, DC=com' - TLSClientSubjectKey = attribute.Key("tls.client.subject") - - // TLSClientSupportedCiphersKey is the attribute Key conforming to the - // "tls.client.supported_ciphers" semantic conventions. It represents the - // array of ciphers offered by the client during the client hello. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: '"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", - // "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", "..."' - TLSClientSupportedCiphersKey = attribute.Key("tls.client.supported_ciphers") - - // TLSCurveKey is the attribute Key conforming to the "tls.curve" semantic - // conventions. It represents the string indicating the curve used for the - // given cipher, when applicable - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'secp256r1' - TLSCurveKey = attribute.Key("tls.curve") - - // TLSEstablishedKey is the attribute Key conforming to the - // "tls.established" semantic conventions. It represents the boolean flag - // indicating if the TLS negotiation was successful and transitioned to an - // encrypted tunnel. - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - // Examples: True - TLSEstablishedKey = attribute.Key("tls.established") - - // TLSNextProtocolKey is the attribute Key conforming to the - // "tls.next_protocol" semantic conventions. It represents the string - // indicating the protocol being tunneled. Per the values in the [IANA - // registry](https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids), - // this string should be lower case. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'http/1.1' - TLSNextProtocolKey = attribute.Key("tls.next_protocol") - - // TLSProtocolNameKey is the attribute Key conforming to the - // "tls.protocol.name" semantic conventions. It represents the normalized - // lowercase protocol name parsed from original string of the negotiated - // [SSL/TLS protocol - // version](https://www.openssl.org/docs/man1.1.1/man3/SSL_get_version.html#RETURN-VALUES) - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - TLSProtocolNameKey = attribute.Key("tls.protocol.name") - - // TLSProtocolVersionKey is the attribute Key conforming to the - // "tls.protocol.version" semantic conventions. It represents the numeric - // part of the version parsed from the original string of the negotiated - // [SSL/TLS protocol - // version](https://www.openssl.org/docs/man1.1.1/man3/SSL_get_version.html#RETURN-VALUES) - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '1.2', '3' - TLSProtocolVersionKey = attribute.Key("tls.protocol.version") - - // TLSResumedKey is the attribute Key conforming to the "tls.resumed" - // semantic conventions. It represents the boolean flag indicating if this - // TLS connection was resumed from an existing TLS negotiation. - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - // Examples: True - TLSResumedKey = attribute.Key("tls.resumed") - - // TLSServerCertificateKey is the attribute Key conforming to the - // "tls.server.certificate" semantic conventions. It represents the - // pEM-encoded stand-alone certificate offered by the server. This is - // usually mutually-exclusive of `server.certificate_chain` since this - // value also exists in that list. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'MII...' - TLSServerCertificateKey = attribute.Key("tls.server.certificate") - - // TLSServerCertificateChainKey is the attribute Key conforming to the - // "tls.server.certificate_chain" semantic conventions. It represents the - // array of PEM-encoded certificates that make up the certificate chain - // offered by the server. This is usually mutually-exclusive of - // `server.certificate` since that value should be the first certificate in - // the chain. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'MII...', 'MI...' - TLSServerCertificateChainKey = attribute.Key("tls.server.certificate_chain") - - // TLSServerHashMd5Key is the attribute Key conforming to the - // "tls.server.hash.md5" semantic conventions. It represents the - // certificate fingerprint using the MD5 digest of DER-encoded version of - // certificate offered by the server. For consistency with other hash - // values, this value should be formatted as an uppercase hash. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '0F76C7F2C55BFD7D8E8B8F4BFBF0C9EC' - TLSServerHashMd5Key = attribute.Key("tls.server.hash.md5") - - // TLSServerHashSha1Key is the attribute Key conforming to the - // "tls.server.hash.sha1" semantic conventions. It represents the - // certificate fingerprint using the SHA1 digest of DER-encoded version of - // certificate offered by the server. For consistency with other hash - // values, this value should be formatted as an uppercase hash. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '9E393D93138888D288266C2D915214D1D1CCEB2A' - TLSServerHashSha1Key = attribute.Key("tls.server.hash.sha1") - - // TLSServerHashSha256Key is the attribute Key conforming to the - // "tls.server.hash.sha256" semantic conventions. It represents the - // certificate fingerprint using the SHA256 digest of DER-encoded version - // of certificate offered by the server. For consistency with other hash - // values, this value should be formatted as an uppercase hash. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // '0687F666A054EF17A08E2F2162EAB4CBC0D265E1D7875BE74BF3C712CA92DAF0' - TLSServerHashSha256Key = attribute.Key("tls.server.hash.sha256") - - // TLSServerIssuerKey is the attribute Key conforming to the - // "tls.server.issuer" semantic conventions. It represents the - // distinguished name of - // [subject](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6) - // of the issuer of the x.509 certificate presented by the client. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'CN=Example Root CA, OU=Infrastructure Team, DC=example, - // DC=com' - TLSServerIssuerKey = attribute.Key("tls.server.issuer") - - // TLSServerJa3sKey is the attribute Key conforming to the - // "tls.server.ja3s" semantic conventions. It represents a hash that - // identifies servers based on how they perform an SSL/TLS handshake. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'd4e5b18d6b55c71272893221c96ba240' - TLSServerJa3sKey = attribute.Key("tls.server.ja3s") - - // TLSServerNotAfterKey is the attribute Key conforming to the - // "tls.server.not_after" semantic conventions. It represents the date/Time - // indicating when server certificate is no longer considered valid. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2021-01-01T00:00:00.000Z' - TLSServerNotAfterKey = attribute.Key("tls.server.not_after") - - // TLSServerNotBeforeKey is the attribute Key conforming to the - // "tls.server.not_before" semantic conventions. It represents the - // date/Time indicating when server certificate is first considered valid. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '1970-01-01T00:00:00.000Z' - TLSServerNotBeforeKey = attribute.Key("tls.server.not_before") - - // TLSServerSubjectKey is the attribute Key conforming to the - // "tls.server.subject" semantic conventions. It represents the - // distinguished name of subject of the x.509 certificate presented by the - // server. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'CN=myserver, OU=Documentation Team, DC=example, DC=com' - TLSServerSubjectKey = attribute.Key("tls.server.subject") -) - -var ( - // ssl - TLSProtocolNameSsl = TLSProtocolNameKey.String("ssl") - // tls - TLSProtocolNameTLS = TLSProtocolNameKey.String("tls") -) - -// TLSCipher returns an attribute KeyValue conforming to the "tls.cipher" -// semantic conventions. It represents the string indicating the -// [cipher](https://datatracker.ietf.org/doc/html/rfc5246#appendix-A.5) used -// during the current connection. -func TLSCipher(val string) attribute.KeyValue { - return TLSCipherKey.String(val) -} - -// TLSClientCertificate returns an attribute KeyValue conforming to the -// "tls.client.certificate" semantic conventions. It represents the pEM-encoded -// stand-alone certificate offered by the client. This is usually -// mutually-exclusive of `client.certificate_chain` since this value also -// exists in that list. -func TLSClientCertificate(val string) attribute.KeyValue { - return TLSClientCertificateKey.String(val) -} - -// TLSClientCertificateChain returns an attribute KeyValue conforming to the -// "tls.client.certificate_chain" semantic conventions. It represents the array -// of PEM-encoded certificates that make up the certificate chain offered by -// the client. This is usually mutually-exclusive of `client.certificate` since -// that value should be the first certificate in the chain. -func TLSClientCertificateChain(val ...string) attribute.KeyValue { - return TLSClientCertificateChainKey.StringSlice(val) -} - -// TLSClientHashMd5 returns an attribute KeyValue conforming to the -// "tls.client.hash.md5" semantic conventions. It represents the certificate -// fingerprint using the MD5 digest of DER-encoded version of certificate -// offered by the client. For consistency with other hash values, this value -// should be formatted as an uppercase hash. -func TLSClientHashMd5(val string) attribute.KeyValue { - return TLSClientHashMd5Key.String(val) -} - -// TLSClientHashSha1 returns an attribute KeyValue conforming to the -// "tls.client.hash.sha1" semantic conventions. It represents the certificate -// fingerprint using the SHA1 digest of DER-encoded version of certificate -// offered by the client. For consistency with other hash values, this value -// should be formatted as an uppercase hash. -func TLSClientHashSha1(val string) attribute.KeyValue { - return TLSClientHashSha1Key.String(val) -} - -// TLSClientHashSha256 returns an attribute KeyValue conforming to the -// "tls.client.hash.sha256" semantic conventions. It represents the certificate -// fingerprint using the SHA256 digest of DER-encoded version of certificate -// offered by the client. For consistency with other hash values, this value -// should be formatted as an uppercase hash. -func TLSClientHashSha256(val string) attribute.KeyValue { - return TLSClientHashSha256Key.String(val) -} - -// TLSClientIssuer returns an attribute KeyValue conforming to the -// "tls.client.issuer" semantic conventions. It represents the distinguished -// name of -// [subject](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6) of -// the issuer of the x.509 certificate presented by the client. -func TLSClientIssuer(val string) attribute.KeyValue { - return TLSClientIssuerKey.String(val) -} - -// TLSClientJa3 returns an attribute KeyValue conforming to the -// "tls.client.ja3" semantic conventions. It represents a hash that identifies -// clients based on how they perform an SSL/TLS handshake. -func TLSClientJa3(val string) attribute.KeyValue { - return TLSClientJa3Key.String(val) -} - -// TLSClientNotAfter returns an attribute KeyValue conforming to the -// "tls.client.not_after" semantic conventions. It represents the date/Time -// indicating when client certificate is no longer considered valid. -func TLSClientNotAfter(val string) attribute.KeyValue { - return TLSClientNotAfterKey.String(val) -} - -// TLSClientNotBefore returns an attribute KeyValue conforming to the -// "tls.client.not_before" semantic conventions. It represents the date/Time -// indicating when client certificate is first considered valid. -func TLSClientNotBefore(val string) attribute.KeyValue { - return TLSClientNotBeforeKey.String(val) -} - -// TLSClientServerName returns an attribute KeyValue conforming to the -// "tls.client.server_name" semantic conventions. It represents the also called -// an SNI, this tells the server which hostname to which the client is -// attempting to connect to. -func TLSClientServerName(val string) attribute.KeyValue { - return TLSClientServerNameKey.String(val) -} - -// TLSClientSubject returns an attribute KeyValue conforming to the -// "tls.client.subject" semantic conventions. It represents the distinguished -// name of subject of the x.509 certificate presented by the client. -func TLSClientSubject(val string) attribute.KeyValue { - return TLSClientSubjectKey.String(val) -} - -// TLSClientSupportedCiphers returns an attribute KeyValue conforming to the -// "tls.client.supported_ciphers" semantic conventions. It represents the array -// of ciphers offered by the client during the client hello. -func TLSClientSupportedCiphers(val ...string) attribute.KeyValue { - return TLSClientSupportedCiphersKey.StringSlice(val) -} - -// TLSCurve returns an attribute KeyValue conforming to the "tls.curve" -// semantic conventions. It represents the string indicating the curve used for -// the given cipher, when applicable -func TLSCurve(val string) attribute.KeyValue { - return TLSCurveKey.String(val) -} - -// TLSEstablished returns an attribute KeyValue conforming to the -// "tls.established" semantic conventions. It represents the boolean flag -// indicating if the TLS negotiation was successful and transitioned to an -// encrypted tunnel. -func TLSEstablished(val bool) attribute.KeyValue { - return TLSEstablishedKey.Bool(val) -} - -// TLSNextProtocol returns an attribute KeyValue conforming to the -// "tls.next_protocol" semantic conventions. It represents the string -// indicating the protocol being tunneled. Per the values in the [IANA -// registry](https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids), -// this string should be lower case. -func TLSNextProtocol(val string) attribute.KeyValue { - return TLSNextProtocolKey.String(val) -} - -// TLSProtocolVersion returns an attribute KeyValue conforming to the -// "tls.protocol.version" semantic conventions. It represents the numeric part -// of the version parsed from the original string of the negotiated [SSL/TLS -// protocol -// version](https://www.openssl.org/docs/man1.1.1/man3/SSL_get_version.html#RETURN-VALUES) -func TLSProtocolVersion(val string) attribute.KeyValue { - return TLSProtocolVersionKey.String(val) -} - -// TLSResumed returns an attribute KeyValue conforming to the "tls.resumed" -// semantic conventions. It represents the boolean flag indicating if this TLS -// connection was resumed from an existing TLS negotiation. -func TLSResumed(val bool) attribute.KeyValue { - return TLSResumedKey.Bool(val) -} - -// TLSServerCertificate returns an attribute KeyValue conforming to the -// "tls.server.certificate" semantic conventions. It represents the pEM-encoded -// stand-alone certificate offered by the server. This is usually -// mutually-exclusive of `server.certificate_chain` since this value also -// exists in that list. -func TLSServerCertificate(val string) attribute.KeyValue { - return TLSServerCertificateKey.String(val) -} - -// TLSServerCertificateChain returns an attribute KeyValue conforming to the -// "tls.server.certificate_chain" semantic conventions. It represents the array -// of PEM-encoded certificates that make up the certificate chain offered by -// the server. This is usually mutually-exclusive of `server.certificate` since -// that value should be the first certificate in the chain. -func TLSServerCertificateChain(val ...string) attribute.KeyValue { - return TLSServerCertificateChainKey.StringSlice(val) -} - -// TLSServerHashMd5 returns an attribute KeyValue conforming to the -// "tls.server.hash.md5" semantic conventions. It represents the certificate -// fingerprint using the MD5 digest of DER-encoded version of certificate -// offered by the server. For consistency with other hash values, this value -// should be formatted as an uppercase hash. -func TLSServerHashMd5(val string) attribute.KeyValue { - return TLSServerHashMd5Key.String(val) -} - -// TLSServerHashSha1 returns an attribute KeyValue conforming to the -// "tls.server.hash.sha1" semantic conventions. It represents the certificate -// fingerprint using the SHA1 digest of DER-encoded version of certificate -// offered by the server. For consistency with other hash values, this value -// should be formatted as an uppercase hash. -func TLSServerHashSha1(val string) attribute.KeyValue { - return TLSServerHashSha1Key.String(val) -} - -// TLSServerHashSha256 returns an attribute KeyValue conforming to the -// "tls.server.hash.sha256" semantic conventions. It represents the certificate -// fingerprint using the SHA256 digest of DER-encoded version of certificate -// offered by the server. For consistency with other hash values, this value -// should be formatted as an uppercase hash. -func TLSServerHashSha256(val string) attribute.KeyValue { - return TLSServerHashSha256Key.String(val) -} - -// TLSServerIssuer returns an attribute KeyValue conforming to the -// "tls.server.issuer" semantic conventions. It represents the distinguished -// name of -// [subject](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6) of -// the issuer of the x.509 certificate presented by the client. -func TLSServerIssuer(val string) attribute.KeyValue { - return TLSServerIssuerKey.String(val) -} - -// TLSServerJa3s returns an attribute KeyValue conforming to the -// "tls.server.ja3s" semantic conventions. It represents a hash that identifies -// servers based on how they perform an SSL/TLS handshake. -func TLSServerJa3s(val string) attribute.KeyValue { - return TLSServerJa3sKey.String(val) -} - -// TLSServerNotAfter returns an attribute KeyValue conforming to the -// "tls.server.not_after" semantic conventions. It represents the date/Time -// indicating when server certificate is no longer considered valid. -func TLSServerNotAfter(val string) attribute.KeyValue { - return TLSServerNotAfterKey.String(val) -} - -// TLSServerNotBefore returns an attribute KeyValue conforming to the -// "tls.server.not_before" semantic conventions. It represents the date/Time -// indicating when server certificate is first considered valid. -func TLSServerNotBefore(val string) attribute.KeyValue { - return TLSServerNotBeforeKey.String(val) -} - -// TLSServerSubject returns an attribute KeyValue conforming to the -// "tls.server.subject" semantic conventions. It represents the distinguished -// name of subject of the x.509 certificate presented by the server. -func TLSServerSubject(val string) attribute.KeyValue { - return TLSServerSubjectKey.String(val) -} - -// Attributes describing URL. -const ( - // URLFragmentKey is the attribute Key conforming to the "url.fragment" - // semantic conventions. It represents the [URI - // fragment](https://www.rfc-editor.org/rfc/rfc3986#section-3.5) component - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: 'SemConv' - URLFragmentKey = attribute.Key("url.fragment") - - // URLFullKey is the attribute Key conforming to the "url.full" semantic - // conventions. It represents the absolute URL describing a network - // resource according to [RFC3986](https://www.rfc-editor.org/rfc/rfc3986) - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: 'https://www.foo.bar/search?q=OpenTelemetry#SemConv', - // '//localhost' - // Note: For network calls, URL usually has - // `scheme://host[:port][path][?query][#fragment]` format, where the - // fragment is not transmitted over HTTP, but if it is known, it SHOULD be - // included nevertheless. - // `url.full` MUST NOT contain credentials passed via URL in form of - // `https://username:password@www.example.com/`. In such case username and - // password SHOULD be redacted and attribute's value SHOULD be - // `https://REDACTED:REDACTED@www.example.com/`. - // `url.full` SHOULD capture the absolute URL when it is available (or can - // be reconstructed) and SHOULD NOT be validated or modified except for - // sanitizing purposes. - URLFullKey = attribute.Key("url.full") - - // URLPathKey is the attribute Key conforming to the "url.path" semantic - // conventions. It represents the [URI - // path](https://www.rfc-editor.org/rfc/rfc3986#section-3.3) component - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: '/search' - URLPathKey = attribute.Key("url.path") - - // URLQueryKey is the attribute Key conforming to the "url.query" semantic - // conventions. It represents the [URI - // query](https://www.rfc-editor.org/rfc/rfc3986#section-3.4) component - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: 'q=OpenTelemetry' - // Note: Sensitive content provided in query string SHOULD be scrubbed when - // instrumentations can identify it. - URLQueryKey = attribute.Key("url.query") - - // URLSchemeKey is the attribute Key conforming to the "url.scheme" - // semantic conventions. It represents the [URI - // scheme](https://www.rfc-editor.org/rfc/rfc3986#section-3.1) component - // identifying the used protocol. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: 'https', 'ftp', 'telnet' - URLSchemeKey = attribute.Key("url.scheme") -) - -// URLFragment returns an attribute KeyValue conforming to the -// "url.fragment" semantic conventions. It represents the [URI -// fragment](https://www.rfc-editor.org/rfc/rfc3986#section-3.5) component -func URLFragment(val string) attribute.KeyValue { - return URLFragmentKey.String(val) -} - -// URLFull returns an attribute KeyValue conforming to the "url.full" -// semantic conventions. It represents the absolute URL describing a network -// resource according to [RFC3986](https://www.rfc-editor.org/rfc/rfc3986) -func URLFull(val string) attribute.KeyValue { - return URLFullKey.String(val) -} - -// URLPath returns an attribute KeyValue conforming to the "url.path" -// semantic conventions. It represents the [URI -// path](https://www.rfc-editor.org/rfc/rfc3986#section-3.3) component -func URLPath(val string) attribute.KeyValue { - return URLPathKey.String(val) -} - -// URLQuery returns an attribute KeyValue conforming to the "url.query" -// semantic conventions. It represents the [URI -// query](https://www.rfc-editor.org/rfc/rfc3986#section-3.4) component -func URLQuery(val string) attribute.KeyValue { - return URLQueryKey.String(val) -} - -// URLScheme returns an attribute KeyValue conforming to the "url.scheme" -// semantic conventions. It represents the [URI -// scheme](https://www.rfc-editor.org/rfc/rfc3986#section-3.1) component -// identifying the used protocol. -func URLScheme(val string) attribute.KeyValue { - return URLSchemeKey.String(val) -} - -// Describes user-agent attributes. -const ( - // UserAgentOriginalKey is the attribute Key conforming to the - // "user_agent.original" semantic conventions. It represents the value of - // the [HTTP - // User-Agent](https://www.rfc-editor.org/rfc/rfc9110.html#field.user-agent) - // header sent by the client. - // - // Type: string - // RequirementLevel: Optional - // Stability: stable - // Examples: 'CERN-LineMode/2.15 libwww/2.17b3', 'Mozilla/5.0 (iPhone; CPU - // iPhone OS 14_7_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) - // Version/14.1.2 Mobile/15E148 Safari/604.1' - UserAgentOriginalKey = attribute.Key("user_agent.original") -) - -// UserAgentOriginal returns an attribute KeyValue conforming to the -// "user_agent.original" semantic conventions. It represents the value of the -// [HTTP -// User-Agent](https://www.rfc-editor.org/rfc/rfc9110.html#field.user-agent) -// header sent by the client. -func UserAgentOriginal(val string) attribute.KeyValue { - return UserAgentOriginalKey.String(val) -} - -// Session is defined as the period of time encompassing all activities -// performed by the application and the actions executed by the end user. -// Consequently, a Session is represented as a collection of Logs, Events, and -// Spans emitted by the Client Application throughout the Session's duration. -// Each Session is assigned a unique identifier, which is included as an -// attribute in the Logs, Events, and Spans generated during the Session's -// lifecycle. -// When a session reaches end of life, typically due to user inactivity or -// session timeout, a new session identifier will be assigned. The previous -// session identifier may be provided by the instrumentation so that telemetry -// backends can link the two sessions. -const ( - // SessionIDKey is the attribute Key conforming to the "session.id" - // semantic conventions. It represents a unique id to identify a session. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '00112233-4455-6677-8899-aabbccddeeff' - SessionIDKey = attribute.Key("session.id") - - // SessionPreviousIDKey is the attribute Key conforming to the - // "session.previous_id" semantic conventions. It represents the previous - // `session.id` for this user, when known. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '00112233-4455-6677-8899-aabbccddeeff' - SessionPreviousIDKey = attribute.Key("session.previous_id") -) - -// SessionID returns an attribute KeyValue conforming to the "session.id" -// semantic conventions. It represents a unique id to identify a session. -func SessionID(val string) attribute.KeyValue { - return SessionIDKey.String(val) -} - -// SessionPreviousID returns an attribute KeyValue conforming to the -// "session.previous_id" semantic conventions. It represents the previous -// `session.id` for this user, when known. -func SessionPreviousID(val string) attribute.KeyValue { - return SessionPreviousIDKey.String(val) -} diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/doc.go b/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/doc.go deleted file mode 100644 index d27e8a8f8..000000000 --- a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -// Package semconv implements OpenTelemetry semantic conventions. -// -// OpenTelemetry semantic conventions are agreed standardized naming -// patterns for OpenTelemetry things. This package represents the v1.24.0 -// version of the OpenTelemetry semantic conventions. -package semconv // import "go.opentelemetry.io/otel/semconv/v1.24.0" diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/event.go b/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/event.go deleted file mode 100644 index 6c019aafc..000000000 --- a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/event.go +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -// Code generated from semantic convention specification. DO NOT EDIT. - -package semconv // import "go.opentelemetry.io/otel/semconv/v1.24.0" - -import "go.opentelemetry.io/otel/attribute" - -// This event represents an occurrence of a lifecycle transition on the iOS -// platform. -const ( - // IosStateKey is the attribute Key conforming to the "ios.state" semantic - // conventions. It represents the this attribute represents the state the - // application has transitioned into at the occurrence of the event. - // - // Type: Enum - // RequirementLevel: Required - // Stability: experimental - // Note: The iOS lifecycle states are defined in the [UIApplicationDelegate - // documentation](https://developer.apple.com/documentation/uikit/uiapplicationdelegate#1656902), - // and from which the `OS terminology` column values are derived. - IosStateKey = attribute.Key("ios.state") -) - -var ( - // The app has become `active`. Associated with UIKit notification `applicationDidBecomeActive` - IosStateActive = IosStateKey.String("active") - // The app is now `inactive`. Associated with UIKit notification `applicationWillResignActive` - IosStateInactive = IosStateKey.String("inactive") - // The app is now in the background. This value is associated with UIKit notification `applicationDidEnterBackground` - IosStateBackground = IosStateKey.String("background") - // The app is now in the foreground. This value is associated with UIKit notification `applicationWillEnterForeground` - IosStateForeground = IosStateKey.String("foreground") - // The app is about to terminate. Associated with UIKit notification `applicationWillTerminate` - IosStateTerminate = IosStateKey.String("terminate") -) - -// This event represents an occurrence of a lifecycle transition on the Android -// platform. -const ( - // AndroidStateKey is the attribute Key conforming to the "android.state" - // semantic conventions. It represents the this attribute represents the - // state the application has transitioned into at the occurrence of the - // event. - // - // Type: Enum - // RequirementLevel: Required - // Stability: experimental - // Note: The Android lifecycle states are defined in [Activity lifecycle - // callbacks](https://developer.android.com/guide/components/activities/activity-lifecycle#lc), - // and from which the `OS identifiers` are derived. - AndroidStateKey = attribute.Key("android.state") -) - -var ( - // Any time before Activity.onResume() or, if the app has no Activity, Context.startService() has been called in the app for the first time - AndroidStateCreated = AndroidStateKey.String("created") - // Any time after Activity.onPause() or, if the app has no Activity, Context.stopService() has been called when the app was in the foreground state - AndroidStateBackground = AndroidStateKey.String("background") - // Any time after Activity.onResume() or, if the app has no Activity, Context.startService() has been called when the app was in either the created or background states - AndroidStateForeground = AndroidStateKey.String("foreground") -) - -// This semantic convention defines the attributes used to represent a feature -// flag evaluation as an event. -const ( - // FeatureFlagKeyKey is the attribute Key conforming to the - // "feature_flag.key" semantic conventions. It represents the unique - // identifier of the feature flag. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'logo-color' - FeatureFlagKeyKey = attribute.Key("feature_flag.key") - - // FeatureFlagProviderNameKey is the attribute Key conforming to the - // "feature_flag.provider_name" semantic conventions. It represents the - // name of the service provider that performs the flag evaluation. - // - // Type: string - // RequirementLevel: Recommended - // Stability: experimental - // Examples: 'Flag Manager' - FeatureFlagProviderNameKey = attribute.Key("feature_flag.provider_name") - - // FeatureFlagVariantKey is the attribute Key conforming to the - // "feature_flag.variant" semantic conventions. It represents the sHOULD be - // a semantic identifier for a value. If one is unavailable, a stringified - // version of the value can be used. - // - // Type: string - // RequirementLevel: Recommended - // Stability: experimental - // Examples: 'red', 'true', 'on' - // Note: A semantic identifier, commonly referred to as a variant, provides - // a means - // for referring to a value without including the value itself. This can - // provide additional context for understanding the meaning behind a value. - // For example, the variant `red` maybe be used for the value `#c05543`. - // - // A stringified version of the value can be used in situations where a - // semantic identifier is unavailable. String representation of the value - // should be determined by the implementer. - FeatureFlagVariantKey = attribute.Key("feature_flag.variant") -) - -// FeatureFlagKey returns an attribute KeyValue conforming to the -// "feature_flag.key" semantic conventions. It represents the unique identifier -// of the feature flag. -func FeatureFlagKey(val string) attribute.KeyValue { - return FeatureFlagKeyKey.String(val) -} - -// FeatureFlagProviderName returns an attribute KeyValue conforming to the -// "feature_flag.provider_name" semantic conventions. It represents the name of -// the service provider that performs the flag evaluation. -func FeatureFlagProviderName(val string) attribute.KeyValue { - return FeatureFlagProviderNameKey.String(val) -} - -// FeatureFlagVariant returns an attribute KeyValue conforming to the -// "feature_flag.variant" semantic conventions. It represents the sHOULD be a -// semantic identifier for a value. If one is unavailable, a stringified -// version of the value can be used. -func FeatureFlagVariant(val string) attribute.KeyValue { - return FeatureFlagVariantKey.String(val) -} - -// RPC received/sent message. -const ( - // MessageCompressedSizeKey is the attribute Key conforming to the - // "message.compressed_size" semantic conventions. It represents the - // compressed size of the message in bytes. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - MessageCompressedSizeKey = attribute.Key("message.compressed_size") - - // MessageIDKey is the attribute Key conforming to the "message.id" - // semantic conventions. It represents the mUST be calculated as two - // different counters starting from `1` one for sent messages and one for - // received message. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Note: This way we guarantee that the values will be consistent between - // different implementations. - MessageIDKey = attribute.Key("message.id") - - // MessageTypeKey is the attribute Key conforming to the "message.type" - // semantic conventions. It represents the whether this is a received or - // sent message. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - MessageTypeKey = attribute.Key("message.type") - - // MessageUncompressedSizeKey is the attribute Key conforming to the - // "message.uncompressed_size" semantic conventions. It represents the - // uncompressed size of the message in bytes. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - MessageUncompressedSizeKey = attribute.Key("message.uncompressed_size") -) - -var ( - // sent - MessageTypeSent = MessageTypeKey.String("SENT") - // received - MessageTypeReceived = MessageTypeKey.String("RECEIVED") -) - -// MessageCompressedSize returns an attribute KeyValue conforming to the -// "message.compressed_size" semantic conventions. It represents the compressed -// size of the message in bytes. -func MessageCompressedSize(val int) attribute.KeyValue { - return MessageCompressedSizeKey.Int(val) -} - -// MessageID returns an attribute KeyValue conforming to the "message.id" -// semantic conventions. It represents the mUST be calculated as two different -// counters starting from `1` one for sent messages and one for received -// message. -func MessageID(val int) attribute.KeyValue { - return MessageIDKey.Int(val) -} - -// MessageUncompressedSize returns an attribute KeyValue conforming to the -// "message.uncompressed_size" semantic conventions. It represents the -// uncompressed size of the message in bytes. -func MessageUncompressedSize(val int) attribute.KeyValue { - return MessageUncompressedSizeKey.Int(val) -} diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/exception.go b/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/exception.go deleted file mode 100644 index 7235bb51d..000000000 --- a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/exception.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -package semconv // import "go.opentelemetry.io/otel/semconv/v1.24.0" - -const ( - // ExceptionEventName is the name of the Span event representing an exception. - ExceptionEventName = "exception" -) diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/metric.go b/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/metric.go deleted file mode 100644 index a6b953f62..000000000 --- a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/metric.go +++ /dev/null @@ -1,1071 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -// Code generated from semantic convention specification. DO NOT EDIT. - -package semconv // import "go.opentelemetry.io/otel/semconv/v1.24.0" - -const ( - - // DBClientConnectionsUsage is the metric conforming to the - // "db.client.connections.usage" semantic conventions. It represents the number - // of connections that are currently in state described by the `state` - // attribute. - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - DBClientConnectionsUsageName = "db.client.connections.usage" - DBClientConnectionsUsageUnit = "{connection}" - DBClientConnectionsUsageDescription = "The number of connections that are currently in state described by the `state` attribute" - - // DBClientConnectionsIdleMax is the metric conforming to the - // "db.client.connections.idle.max" semantic conventions. It represents the - // maximum number of idle open connections allowed. - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - DBClientConnectionsIdleMaxName = "db.client.connections.idle.max" - DBClientConnectionsIdleMaxUnit = "{connection}" - DBClientConnectionsIdleMaxDescription = "The maximum number of idle open connections allowed" - - // DBClientConnectionsIdleMin is the metric conforming to the - // "db.client.connections.idle.min" semantic conventions. It represents the - // minimum number of idle open connections allowed. - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - DBClientConnectionsIdleMinName = "db.client.connections.idle.min" - DBClientConnectionsIdleMinUnit = "{connection}" - DBClientConnectionsIdleMinDescription = "The minimum number of idle open connections allowed" - - // DBClientConnectionsMax is the metric conforming to the - // "db.client.connections.max" semantic conventions. It represents the maximum - // number of open connections allowed. - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - DBClientConnectionsMaxName = "db.client.connections.max" - DBClientConnectionsMaxUnit = "{connection}" - DBClientConnectionsMaxDescription = "The maximum number of open connections allowed" - - // DBClientConnectionsPendingRequests is the metric conforming to the - // "db.client.connections.pending_requests" semantic conventions. It represents - // the number of pending requests for an open connection, cumulative for the - // entire pool. - // Instrument: updowncounter - // Unit: {request} - // Stability: Experimental - DBClientConnectionsPendingRequestsName = "db.client.connections.pending_requests" - DBClientConnectionsPendingRequestsUnit = "{request}" - DBClientConnectionsPendingRequestsDescription = "The number of pending requests for an open connection, cumulative for the entire pool" - - // DBClientConnectionsTimeouts is the metric conforming to the - // "db.client.connections.timeouts" semantic conventions. It represents the - // number of connection timeouts that have occurred trying to obtain a - // connection from the pool. - // Instrument: counter - // Unit: {timeout} - // Stability: Experimental - DBClientConnectionsTimeoutsName = "db.client.connections.timeouts" - DBClientConnectionsTimeoutsUnit = "{timeout}" - DBClientConnectionsTimeoutsDescription = "The number of connection timeouts that have occurred trying to obtain a connection from the pool" - - // DBClientConnectionsCreateTime is the metric conforming to the - // "db.client.connections.create_time" semantic conventions. It represents the - // time it took to create a new connection. - // Instrument: histogram - // Unit: ms - // Stability: Experimental - DBClientConnectionsCreateTimeName = "db.client.connections.create_time" - DBClientConnectionsCreateTimeUnit = "ms" - DBClientConnectionsCreateTimeDescription = "The time it took to create a new connection" - - // DBClientConnectionsWaitTime is the metric conforming to the - // "db.client.connections.wait_time" semantic conventions. It represents the - // time it took to obtain an open connection from the pool. - // Instrument: histogram - // Unit: ms - // Stability: Experimental - DBClientConnectionsWaitTimeName = "db.client.connections.wait_time" - DBClientConnectionsWaitTimeUnit = "ms" - DBClientConnectionsWaitTimeDescription = "The time it took to obtain an open connection from the pool" - - // DBClientConnectionsUseTime is the metric conforming to the - // "db.client.connections.use_time" semantic conventions. It represents the - // time between borrowing a connection and returning it to the pool. - // Instrument: histogram - // Unit: ms - // Stability: Experimental - DBClientConnectionsUseTimeName = "db.client.connections.use_time" - DBClientConnectionsUseTimeUnit = "ms" - DBClientConnectionsUseTimeDescription = "The time between borrowing a connection and returning it to the pool" - - // AspnetcoreRoutingMatchAttempts is the metric conforming to the - // "aspnetcore.routing.match_attempts" semantic conventions. It represents the - // number of requests that were attempted to be matched to an endpoint. - // Instrument: counter - // Unit: {match_attempt} - // Stability: Experimental - AspnetcoreRoutingMatchAttemptsName = "aspnetcore.routing.match_attempts" - AspnetcoreRoutingMatchAttemptsUnit = "{match_attempt}" - AspnetcoreRoutingMatchAttemptsDescription = "Number of requests that were attempted to be matched to an endpoint." - - // AspnetcoreDiagnosticsExceptions is the metric conforming to the - // "aspnetcore.diagnostics.exceptions" semantic conventions. It represents the - // number of exceptions caught by exception handling middleware. - // Instrument: counter - // Unit: {exception} - // Stability: Experimental - AspnetcoreDiagnosticsExceptionsName = "aspnetcore.diagnostics.exceptions" - AspnetcoreDiagnosticsExceptionsUnit = "{exception}" - AspnetcoreDiagnosticsExceptionsDescription = "Number of exceptions caught by exception handling middleware." - - // AspnetcoreRateLimitingActiveRequestLeases is the metric conforming to the - // "aspnetcore.rate_limiting.active_request_leases" semantic conventions. It - // represents the number of requests that are currently active on the server - // that hold a rate limiting lease. - // Instrument: updowncounter - // Unit: {request} - // Stability: Experimental - AspnetcoreRateLimitingActiveRequestLeasesName = "aspnetcore.rate_limiting.active_request_leases" - AspnetcoreRateLimitingActiveRequestLeasesUnit = "{request}" - AspnetcoreRateLimitingActiveRequestLeasesDescription = "Number of requests that are currently active on the server that hold a rate limiting lease." - - // AspnetcoreRateLimitingRequestLeaseDuration is the metric conforming to the - // "aspnetcore.rate_limiting.request_lease.duration" semantic conventions. It - // represents the duration of rate limiting lease held by requests on the - // server. - // Instrument: histogram - // Unit: s - // Stability: Experimental - AspnetcoreRateLimitingRequestLeaseDurationName = "aspnetcore.rate_limiting.request_lease.duration" - AspnetcoreRateLimitingRequestLeaseDurationUnit = "s" - AspnetcoreRateLimitingRequestLeaseDurationDescription = "The duration of rate limiting lease held by requests on the server." - - // AspnetcoreRateLimitingRequestTimeInQueue is the metric conforming to the - // "aspnetcore.rate_limiting.request.time_in_queue" semantic conventions. It - // represents the time the request spent in a queue waiting to acquire a rate - // limiting lease. - // Instrument: histogram - // Unit: s - // Stability: Experimental - AspnetcoreRateLimitingRequestTimeInQueueName = "aspnetcore.rate_limiting.request.time_in_queue" - AspnetcoreRateLimitingRequestTimeInQueueUnit = "s" - AspnetcoreRateLimitingRequestTimeInQueueDescription = "The time the request spent in a queue waiting to acquire a rate limiting lease." - - // AspnetcoreRateLimitingQueuedRequests is the metric conforming to the - // "aspnetcore.rate_limiting.queued_requests" semantic conventions. It - // represents the number of requests that are currently queued, waiting to - // acquire a rate limiting lease. - // Instrument: updowncounter - // Unit: {request} - // Stability: Experimental - AspnetcoreRateLimitingQueuedRequestsName = "aspnetcore.rate_limiting.queued_requests" - AspnetcoreRateLimitingQueuedRequestsUnit = "{request}" - AspnetcoreRateLimitingQueuedRequestsDescription = "Number of requests that are currently queued, waiting to acquire a rate limiting lease." - - // AspnetcoreRateLimitingRequests is the metric conforming to the - // "aspnetcore.rate_limiting.requests" semantic conventions. It represents the - // number of requests that tried to acquire a rate limiting lease. - // Instrument: counter - // Unit: {request} - // Stability: Experimental - AspnetcoreRateLimitingRequestsName = "aspnetcore.rate_limiting.requests" - AspnetcoreRateLimitingRequestsUnit = "{request}" - AspnetcoreRateLimitingRequestsDescription = "Number of requests that tried to acquire a rate limiting lease." - - // DNSLookupDuration is the metric conforming to the "dns.lookup.duration" - // semantic conventions. It represents the measures the time taken to perform a - // DNS lookup. - // Instrument: histogram - // Unit: s - // Stability: Experimental - DNSLookupDurationName = "dns.lookup.duration" - DNSLookupDurationUnit = "s" - DNSLookupDurationDescription = "Measures the time taken to perform a DNS lookup." - - // HTTPClientOpenConnections is the metric conforming to the - // "http.client.open_connections" semantic conventions. It represents the - // number of outbound HTTP connections that are currently active or idle on the - // client. - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - HTTPClientOpenConnectionsName = "http.client.open_connections" - HTTPClientOpenConnectionsUnit = "{connection}" - HTTPClientOpenConnectionsDescription = "Number of outbound HTTP connections that are currently active or idle on the client." - - // HTTPClientConnectionDuration is the metric conforming to the - // "http.client.connection.duration" semantic conventions. It represents the - // duration of the successfully established outbound HTTP connections. - // Instrument: histogram - // Unit: s - // Stability: Experimental - HTTPClientConnectionDurationName = "http.client.connection.duration" - HTTPClientConnectionDurationUnit = "s" - HTTPClientConnectionDurationDescription = "The duration of the successfully established outbound HTTP connections." - - // HTTPClientActiveRequests is the metric conforming to the - // "http.client.active_requests" semantic conventions. It represents the number - // of active HTTP requests. - // Instrument: updowncounter - // Unit: {request} - // Stability: Experimental - HTTPClientActiveRequestsName = "http.client.active_requests" - HTTPClientActiveRequestsUnit = "{request}" - HTTPClientActiveRequestsDescription = "Number of active HTTP requests." - - // HTTPClientRequestTimeInQueue is the metric conforming to the - // "http.client.request.time_in_queue" semantic conventions. It represents the - // amount of time requests spent on a queue waiting for an available - // connection. - // Instrument: histogram - // Unit: s - // Stability: Experimental - HTTPClientRequestTimeInQueueName = "http.client.request.time_in_queue" - HTTPClientRequestTimeInQueueUnit = "s" - HTTPClientRequestTimeInQueueDescription = "The amount of time requests spent on a queue waiting for an available connection." - - // KestrelActiveConnections is the metric conforming to the - // "kestrel.active_connections" semantic conventions. It represents the number - // of connections that are currently active on the server. - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - KestrelActiveConnectionsName = "kestrel.active_connections" - KestrelActiveConnectionsUnit = "{connection}" - KestrelActiveConnectionsDescription = "Number of connections that are currently active on the server." - - // KestrelConnectionDuration is the metric conforming to the - // "kestrel.connection.duration" semantic conventions. It represents the - // duration of connections on the server. - // Instrument: histogram - // Unit: s - // Stability: Experimental - KestrelConnectionDurationName = "kestrel.connection.duration" - KestrelConnectionDurationUnit = "s" - KestrelConnectionDurationDescription = "The duration of connections on the server." - - // KestrelRejectedConnections is the metric conforming to the - // "kestrel.rejected_connections" semantic conventions. It represents the - // number of connections rejected by the server. - // Instrument: counter - // Unit: {connection} - // Stability: Experimental - KestrelRejectedConnectionsName = "kestrel.rejected_connections" - KestrelRejectedConnectionsUnit = "{connection}" - KestrelRejectedConnectionsDescription = "Number of connections rejected by the server." - - // KestrelQueuedConnections is the metric conforming to the - // "kestrel.queued_connections" semantic conventions. It represents the number - // of connections that are currently queued and are waiting to start. - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - KestrelQueuedConnectionsName = "kestrel.queued_connections" - KestrelQueuedConnectionsUnit = "{connection}" - KestrelQueuedConnectionsDescription = "Number of connections that are currently queued and are waiting to start." - - // KestrelQueuedRequests is the metric conforming to the - // "kestrel.queued_requests" semantic conventions. It represents the number of - // HTTP requests on multiplexed connections (HTTP/2 and HTTP/3) that are - // currently queued and are waiting to start. - // Instrument: updowncounter - // Unit: {request} - // Stability: Experimental - KestrelQueuedRequestsName = "kestrel.queued_requests" - KestrelQueuedRequestsUnit = "{request}" - KestrelQueuedRequestsDescription = "Number of HTTP requests on multiplexed connections (HTTP/2 and HTTP/3) that are currently queued and are waiting to start." - - // KestrelUpgradedConnections is the metric conforming to the - // "kestrel.upgraded_connections" semantic conventions. It represents the - // number of connections that are currently upgraded (WebSockets). . - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - KestrelUpgradedConnectionsName = "kestrel.upgraded_connections" - KestrelUpgradedConnectionsUnit = "{connection}" - KestrelUpgradedConnectionsDescription = "Number of connections that are currently upgraded (WebSockets). ." - - // KestrelTLSHandshakeDuration is the metric conforming to the - // "kestrel.tls_handshake.duration" semantic conventions. It represents the - // duration of TLS handshakes on the server. - // Instrument: histogram - // Unit: s - // Stability: Experimental - KestrelTLSHandshakeDurationName = "kestrel.tls_handshake.duration" - KestrelTLSHandshakeDurationUnit = "s" - KestrelTLSHandshakeDurationDescription = "The duration of TLS handshakes on the server." - - // KestrelActiveTLSHandshakes is the metric conforming to the - // "kestrel.active_tls_handshakes" semantic conventions. It represents the - // number of TLS handshakes that are currently in progress on the server. - // Instrument: updowncounter - // Unit: {handshake} - // Stability: Experimental - KestrelActiveTLSHandshakesName = "kestrel.active_tls_handshakes" - KestrelActiveTLSHandshakesUnit = "{handshake}" - KestrelActiveTLSHandshakesDescription = "Number of TLS handshakes that are currently in progress on the server." - - // SignalrServerConnectionDuration is the metric conforming to the - // "signalr.server.connection.duration" semantic conventions. It represents the - // duration of connections on the server. - // Instrument: histogram - // Unit: s - // Stability: Experimental - SignalrServerConnectionDurationName = "signalr.server.connection.duration" - SignalrServerConnectionDurationUnit = "s" - SignalrServerConnectionDurationDescription = "The duration of connections on the server." - - // SignalrServerActiveConnections is the metric conforming to the - // "signalr.server.active_connections" semantic conventions. It represents the - // number of connections that are currently active on the server. - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - SignalrServerActiveConnectionsName = "signalr.server.active_connections" - SignalrServerActiveConnectionsUnit = "{connection}" - SignalrServerActiveConnectionsDescription = "Number of connections that are currently active on the server." - - // FaaSInvokeDuration is the metric conforming to the "faas.invoke_duration" - // semantic conventions. It represents the measures the duration of the - // function's logic execution. - // Instrument: histogram - // Unit: s - // Stability: Experimental - FaaSInvokeDurationName = "faas.invoke_duration" - FaaSInvokeDurationUnit = "s" - FaaSInvokeDurationDescription = "Measures the duration of the function's logic execution" - - // FaaSInitDuration is the metric conforming to the "faas.init_duration" - // semantic conventions. It represents the measures the duration of the - // function's initialization, such as a cold start. - // Instrument: histogram - // Unit: s - // Stability: Experimental - FaaSInitDurationName = "faas.init_duration" - FaaSInitDurationUnit = "s" - FaaSInitDurationDescription = "Measures the duration of the function's initialization, such as a cold start" - - // FaaSColdstarts is the metric conforming to the "faas.coldstarts" semantic - // conventions. It represents the number of invocation cold starts. - // Instrument: counter - // Unit: {coldstart} - // Stability: Experimental - FaaSColdstartsName = "faas.coldstarts" - FaaSColdstartsUnit = "{coldstart}" - FaaSColdstartsDescription = "Number of invocation cold starts" - - // FaaSErrors is the metric conforming to the "faas.errors" semantic - // conventions. It represents the number of invocation errors. - // Instrument: counter - // Unit: {error} - // Stability: Experimental - FaaSErrorsName = "faas.errors" - FaaSErrorsUnit = "{error}" - FaaSErrorsDescription = "Number of invocation errors" - - // FaaSInvocations is the metric conforming to the "faas.invocations" semantic - // conventions. It represents the number of successful invocations. - // Instrument: counter - // Unit: {invocation} - // Stability: Experimental - FaaSInvocationsName = "faas.invocations" - FaaSInvocationsUnit = "{invocation}" - FaaSInvocationsDescription = "Number of successful invocations" - - // FaaSTimeouts is the metric conforming to the "faas.timeouts" semantic - // conventions. It represents the number of invocation timeouts. - // Instrument: counter - // Unit: {timeout} - // Stability: Experimental - FaaSTimeoutsName = "faas.timeouts" - FaaSTimeoutsUnit = "{timeout}" - FaaSTimeoutsDescription = "Number of invocation timeouts" - - // FaaSMemUsage is the metric conforming to the "faas.mem_usage" semantic - // conventions. It represents the distribution of max memory usage per - // invocation. - // Instrument: histogram - // Unit: By - // Stability: Experimental - FaaSMemUsageName = "faas.mem_usage" - FaaSMemUsageUnit = "By" - FaaSMemUsageDescription = "Distribution of max memory usage per invocation" - - // FaaSCPUUsage is the metric conforming to the "faas.cpu_usage" semantic - // conventions. It represents the distribution of CPU usage per invocation. - // Instrument: histogram - // Unit: s - // Stability: Experimental - FaaSCPUUsageName = "faas.cpu_usage" - FaaSCPUUsageUnit = "s" - FaaSCPUUsageDescription = "Distribution of CPU usage per invocation" - - // FaaSNetIo is the metric conforming to the "faas.net_io" semantic - // conventions. It represents the distribution of net I/O usage per invocation. - // Instrument: histogram - // Unit: By - // Stability: Experimental - FaaSNetIoName = "faas.net_io" - FaaSNetIoUnit = "By" - FaaSNetIoDescription = "Distribution of net I/O usage per invocation" - - // HTTPServerRequestDuration is the metric conforming to the - // "http.server.request.duration" semantic conventions. It represents the - // duration of HTTP server requests. - // Instrument: histogram - // Unit: s - // Stability: Stable - HTTPServerRequestDurationName = "http.server.request.duration" - HTTPServerRequestDurationUnit = "s" - HTTPServerRequestDurationDescription = "Duration of HTTP server requests." - - // HTTPServerActiveRequests is the metric conforming to the - // "http.server.active_requests" semantic conventions. It represents the number - // of active HTTP server requests. - // Instrument: updowncounter - // Unit: {request} - // Stability: Experimental - HTTPServerActiveRequestsName = "http.server.active_requests" - HTTPServerActiveRequestsUnit = "{request}" - HTTPServerActiveRequestsDescription = "Number of active HTTP server requests." - - // HTTPServerRequestBodySize is the metric conforming to the - // "http.server.request.body.size" semantic conventions. It represents the size - // of HTTP server request bodies. - // Instrument: histogram - // Unit: By - // Stability: Experimental - HTTPServerRequestBodySizeName = "http.server.request.body.size" - HTTPServerRequestBodySizeUnit = "By" - HTTPServerRequestBodySizeDescription = "Size of HTTP server request bodies." - - // HTTPServerResponseBodySize is the metric conforming to the - // "http.server.response.body.size" semantic conventions. It represents the - // size of HTTP server response bodies. - // Instrument: histogram - // Unit: By - // Stability: Experimental - HTTPServerResponseBodySizeName = "http.server.response.body.size" - HTTPServerResponseBodySizeUnit = "By" - HTTPServerResponseBodySizeDescription = "Size of HTTP server response bodies." - - // HTTPClientRequestDuration is the metric conforming to the - // "http.client.request.duration" semantic conventions. It represents the - // duration of HTTP client requests. - // Instrument: histogram - // Unit: s - // Stability: Stable - HTTPClientRequestDurationName = "http.client.request.duration" - HTTPClientRequestDurationUnit = "s" - HTTPClientRequestDurationDescription = "Duration of HTTP client requests." - - // HTTPClientRequestBodySize is the metric conforming to the - // "http.client.request.body.size" semantic conventions. It represents the size - // of HTTP client request bodies. - // Instrument: histogram - // Unit: By - // Stability: Experimental - HTTPClientRequestBodySizeName = "http.client.request.body.size" - HTTPClientRequestBodySizeUnit = "By" - HTTPClientRequestBodySizeDescription = "Size of HTTP client request bodies." - - // HTTPClientResponseBodySize is the metric conforming to the - // "http.client.response.body.size" semantic conventions. It represents the - // size of HTTP client response bodies. - // Instrument: histogram - // Unit: By - // Stability: Experimental - HTTPClientResponseBodySizeName = "http.client.response.body.size" - HTTPClientResponseBodySizeUnit = "By" - HTTPClientResponseBodySizeDescription = "Size of HTTP client response bodies." - - // JvmMemoryInit is the metric conforming to the "jvm.memory.init" semantic - // conventions. It represents the measure of initial memory requested. - // Instrument: updowncounter - // Unit: By - // Stability: Experimental - JvmMemoryInitName = "jvm.memory.init" - JvmMemoryInitUnit = "By" - JvmMemoryInitDescription = "Measure of initial memory requested." - - // JvmSystemCPUUtilization is the metric conforming to the - // "jvm.system.cpu.utilization" semantic conventions. It represents the recent - // CPU utilization for the whole system as reported by the JVM. - // Instrument: gauge - // Unit: 1 - // Stability: Experimental - JvmSystemCPUUtilizationName = "jvm.system.cpu.utilization" - JvmSystemCPUUtilizationUnit = "1" - JvmSystemCPUUtilizationDescription = "Recent CPU utilization for the whole system as reported by the JVM." - - // JvmSystemCPULoad1m is the metric conforming to the "jvm.system.cpu.load_1m" - // semantic conventions. It represents the average CPU load of the whole system - // for the last minute as reported by the JVM. - // Instrument: gauge - // Unit: {run_queue_item} - // Stability: Experimental - JvmSystemCPULoad1mName = "jvm.system.cpu.load_1m" - JvmSystemCPULoad1mUnit = "{run_queue_item}" - JvmSystemCPULoad1mDescription = "Average CPU load of the whole system for the last minute as reported by the JVM." - - // JvmBufferMemoryUsage is the metric conforming to the - // "jvm.buffer.memory.usage" semantic conventions. It represents the measure of - // memory used by buffers. - // Instrument: updowncounter - // Unit: By - // Stability: Experimental - JvmBufferMemoryUsageName = "jvm.buffer.memory.usage" - JvmBufferMemoryUsageUnit = "By" - JvmBufferMemoryUsageDescription = "Measure of memory used by buffers." - - // JvmBufferMemoryLimit is the metric conforming to the - // "jvm.buffer.memory.limit" semantic conventions. It represents the measure of - // total memory capacity of buffers. - // Instrument: updowncounter - // Unit: By - // Stability: Experimental - JvmBufferMemoryLimitName = "jvm.buffer.memory.limit" - JvmBufferMemoryLimitUnit = "By" - JvmBufferMemoryLimitDescription = "Measure of total memory capacity of buffers." - - // JvmBufferCount is the metric conforming to the "jvm.buffer.count" semantic - // conventions. It represents the number of buffers in the pool. - // Instrument: updowncounter - // Unit: {buffer} - // Stability: Experimental - JvmBufferCountName = "jvm.buffer.count" - JvmBufferCountUnit = "{buffer}" - JvmBufferCountDescription = "Number of buffers in the pool." - - // JvmMemoryUsed is the metric conforming to the "jvm.memory.used" semantic - // conventions. It represents the measure of memory used. - // Instrument: updowncounter - // Unit: By - // Stability: Stable - JvmMemoryUsedName = "jvm.memory.used" - JvmMemoryUsedUnit = "By" - JvmMemoryUsedDescription = "Measure of memory used." - - // JvmMemoryCommitted is the metric conforming to the "jvm.memory.committed" - // semantic conventions. It represents the measure of memory committed. - // Instrument: updowncounter - // Unit: By - // Stability: Stable - JvmMemoryCommittedName = "jvm.memory.committed" - JvmMemoryCommittedUnit = "By" - JvmMemoryCommittedDescription = "Measure of memory committed." - - // JvmMemoryLimit is the metric conforming to the "jvm.memory.limit" semantic - // conventions. It represents the measure of max obtainable memory. - // Instrument: updowncounter - // Unit: By - // Stability: Stable - JvmMemoryLimitName = "jvm.memory.limit" - JvmMemoryLimitUnit = "By" - JvmMemoryLimitDescription = "Measure of max obtainable memory." - - // JvmMemoryUsedAfterLastGc is the metric conforming to the - // "jvm.memory.used_after_last_gc" semantic conventions. It represents the - // measure of memory used, as measured after the most recent garbage collection - // event on this pool. - // Instrument: updowncounter - // Unit: By - // Stability: Stable - JvmMemoryUsedAfterLastGcName = "jvm.memory.used_after_last_gc" - JvmMemoryUsedAfterLastGcUnit = "By" - JvmMemoryUsedAfterLastGcDescription = "Measure of memory used, as measured after the most recent garbage collection event on this pool." - - // JvmGcDuration is the metric conforming to the "jvm.gc.duration" semantic - // conventions. It represents the duration of JVM garbage collection actions. - // Instrument: histogram - // Unit: s - // Stability: Stable - JvmGcDurationName = "jvm.gc.duration" - JvmGcDurationUnit = "s" - JvmGcDurationDescription = "Duration of JVM garbage collection actions." - - // JvmThreadCount is the metric conforming to the "jvm.thread.count" semantic - // conventions. It represents the number of executing platform threads. - // Instrument: updowncounter - // Unit: {thread} - // Stability: Stable - JvmThreadCountName = "jvm.thread.count" - JvmThreadCountUnit = "{thread}" - JvmThreadCountDescription = "Number of executing platform threads." - - // JvmClassLoaded is the metric conforming to the "jvm.class.loaded" semantic - // conventions. It represents the number of classes loaded since JVM start. - // Instrument: counter - // Unit: {class} - // Stability: Stable - JvmClassLoadedName = "jvm.class.loaded" - JvmClassLoadedUnit = "{class}" - JvmClassLoadedDescription = "Number of classes loaded since JVM start." - - // JvmClassUnloaded is the metric conforming to the "jvm.class.unloaded" - // semantic conventions. It represents the number of classes unloaded since JVM - // start. - // Instrument: counter - // Unit: {class} - // Stability: Stable - JvmClassUnloadedName = "jvm.class.unloaded" - JvmClassUnloadedUnit = "{class}" - JvmClassUnloadedDescription = "Number of classes unloaded since JVM start." - - // JvmClassCount is the metric conforming to the "jvm.class.count" semantic - // conventions. It represents the number of classes currently loaded. - // Instrument: updowncounter - // Unit: {class} - // Stability: Stable - JvmClassCountName = "jvm.class.count" - JvmClassCountUnit = "{class}" - JvmClassCountDescription = "Number of classes currently loaded." - - // JvmCPUCount is the metric conforming to the "jvm.cpu.count" semantic - // conventions. It represents the number of processors available to the Java - // virtual machine. - // Instrument: updowncounter - // Unit: {cpu} - // Stability: Stable - JvmCPUCountName = "jvm.cpu.count" - JvmCPUCountUnit = "{cpu}" - JvmCPUCountDescription = "Number of processors available to the Java virtual machine." - - // JvmCPUTime is the metric conforming to the "jvm.cpu.time" semantic - // conventions. It represents the cPU time used by the process as reported by - // the JVM. - // Instrument: counter - // Unit: s - // Stability: Stable - JvmCPUTimeName = "jvm.cpu.time" - JvmCPUTimeUnit = "s" - JvmCPUTimeDescription = "CPU time used by the process as reported by the JVM." - - // JvmCPURecentUtilization is the metric conforming to the - // "jvm.cpu.recent_utilization" semantic conventions. It represents the recent - // CPU utilization for the process as reported by the JVM. - // Instrument: gauge - // Unit: 1 - // Stability: Stable - JvmCPURecentUtilizationName = "jvm.cpu.recent_utilization" - JvmCPURecentUtilizationUnit = "1" - JvmCPURecentUtilizationDescription = "Recent CPU utilization for the process as reported by the JVM." - - // MessagingPublishDuration is the metric conforming to the - // "messaging.publish.duration" semantic conventions. It represents the - // measures the duration of publish operation. - // Instrument: histogram - // Unit: s - // Stability: Experimental - MessagingPublishDurationName = "messaging.publish.duration" - MessagingPublishDurationUnit = "s" - MessagingPublishDurationDescription = "Measures the duration of publish operation." - - // MessagingReceiveDuration is the metric conforming to the - // "messaging.receive.duration" semantic conventions. It represents the - // measures the duration of receive operation. - // Instrument: histogram - // Unit: s - // Stability: Experimental - MessagingReceiveDurationName = "messaging.receive.duration" - MessagingReceiveDurationUnit = "s" - MessagingReceiveDurationDescription = "Measures the duration of receive operation." - - // MessagingDeliverDuration is the metric conforming to the - // "messaging.deliver.duration" semantic conventions. It represents the - // measures the duration of deliver operation. - // Instrument: histogram - // Unit: s - // Stability: Experimental - MessagingDeliverDurationName = "messaging.deliver.duration" - MessagingDeliverDurationUnit = "s" - MessagingDeliverDurationDescription = "Measures the duration of deliver operation." - - // MessagingPublishMessages is the metric conforming to the - // "messaging.publish.messages" semantic conventions. It represents the - // measures the number of published messages. - // Instrument: counter - // Unit: {message} - // Stability: Experimental - MessagingPublishMessagesName = "messaging.publish.messages" - MessagingPublishMessagesUnit = "{message}" - MessagingPublishMessagesDescription = "Measures the number of published messages." - - // MessagingReceiveMessages is the metric conforming to the - // "messaging.receive.messages" semantic conventions. It represents the - // measures the number of received messages. - // Instrument: counter - // Unit: {message} - // Stability: Experimental - MessagingReceiveMessagesName = "messaging.receive.messages" - MessagingReceiveMessagesUnit = "{message}" - MessagingReceiveMessagesDescription = "Measures the number of received messages." - - // MessagingDeliverMessages is the metric conforming to the - // "messaging.deliver.messages" semantic conventions. It represents the - // measures the number of delivered messages. - // Instrument: counter - // Unit: {message} - // Stability: Experimental - MessagingDeliverMessagesName = "messaging.deliver.messages" - MessagingDeliverMessagesUnit = "{message}" - MessagingDeliverMessagesDescription = "Measures the number of delivered messages." - - // RPCServerDuration is the metric conforming to the "rpc.server.duration" - // semantic conventions. It represents the measures the duration of inbound - // RPC. - // Instrument: histogram - // Unit: ms - // Stability: Experimental - RPCServerDurationName = "rpc.server.duration" - RPCServerDurationUnit = "ms" - RPCServerDurationDescription = "Measures the duration of inbound RPC." - - // RPCServerRequestSize is the metric conforming to the - // "rpc.server.request.size" semantic conventions. It represents the measures - // the size of RPC request messages (uncompressed). - // Instrument: histogram - // Unit: By - // Stability: Experimental - RPCServerRequestSizeName = "rpc.server.request.size" - RPCServerRequestSizeUnit = "By" - RPCServerRequestSizeDescription = "Measures the size of RPC request messages (uncompressed)." - - // RPCServerResponseSize is the metric conforming to the - // "rpc.server.response.size" semantic conventions. It represents the measures - // the size of RPC response messages (uncompressed). - // Instrument: histogram - // Unit: By - // Stability: Experimental - RPCServerResponseSizeName = "rpc.server.response.size" - RPCServerResponseSizeUnit = "By" - RPCServerResponseSizeDescription = "Measures the size of RPC response messages (uncompressed)." - - // RPCServerRequestsPerRPC is the metric conforming to the - // "rpc.server.requests_per_rpc" semantic conventions. It represents the - // measures the number of messages received per RPC. - // Instrument: histogram - // Unit: {count} - // Stability: Experimental - RPCServerRequestsPerRPCName = "rpc.server.requests_per_rpc" - RPCServerRequestsPerRPCUnit = "{count}" - RPCServerRequestsPerRPCDescription = "Measures the number of messages received per RPC." - - // RPCServerResponsesPerRPC is the metric conforming to the - // "rpc.server.responses_per_rpc" semantic conventions. It represents the - // measures the number of messages sent per RPC. - // Instrument: histogram - // Unit: {count} - // Stability: Experimental - RPCServerResponsesPerRPCName = "rpc.server.responses_per_rpc" - RPCServerResponsesPerRPCUnit = "{count}" - RPCServerResponsesPerRPCDescription = "Measures the number of messages sent per RPC." - - // RPCClientDuration is the metric conforming to the "rpc.client.duration" - // semantic conventions. It represents the measures the duration of outbound - // RPC. - // Instrument: histogram - // Unit: ms - // Stability: Experimental - RPCClientDurationName = "rpc.client.duration" - RPCClientDurationUnit = "ms" - RPCClientDurationDescription = "Measures the duration of outbound RPC." - - // RPCClientRequestSize is the metric conforming to the - // "rpc.client.request.size" semantic conventions. It represents the measures - // the size of RPC request messages (uncompressed). - // Instrument: histogram - // Unit: By - // Stability: Experimental - RPCClientRequestSizeName = "rpc.client.request.size" - RPCClientRequestSizeUnit = "By" - RPCClientRequestSizeDescription = "Measures the size of RPC request messages (uncompressed)." - - // RPCClientResponseSize is the metric conforming to the - // "rpc.client.response.size" semantic conventions. It represents the measures - // the size of RPC response messages (uncompressed). - // Instrument: histogram - // Unit: By - // Stability: Experimental - RPCClientResponseSizeName = "rpc.client.response.size" - RPCClientResponseSizeUnit = "By" - RPCClientResponseSizeDescription = "Measures the size of RPC response messages (uncompressed)." - - // RPCClientRequestsPerRPC is the metric conforming to the - // "rpc.client.requests_per_rpc" semantic conventions. It represents the - // measures the number of messages received per RPC. - // Instrument: histogram - // Unit: {count} - // Stability: Experimental - RPCClientRequestsPerRPCName = "rpc.client.requests_per_rpc" - RPCClientRequestsPerRPCUnit = "{count}" - RPCClientRequestsPerRPCDescription = "Measures the number of messages received per RPC." - - // RPCClientResponsesPerRPC is the metric conforming to the - // "rpc.client.responses_per_rpc" semantic conventions. It represents the - // measures the number of messages sent per RPC. - // Instrument: histogram - // Unit: {count} - // Stability: Experimental - RPCClientResponsesPerRPCName = "rpc.client.responses_per_rpc" - RPCClientResponsesPerRPCUnit = "{count}" - RPCClientResponsesPerRPCDescription = "Measures the number of messages sent per RPC." - - // SystemCPUTime is the metric conforming to the "system.cpu.time" semantic - // conventions. It represents the seconds each logical CPU spent on each mode. - // Instrument: counter - // Unit: s - // Stability: Experimental - SystemCPUTimeName = "system.cpu.time" - SystemCPUTimeUnit = "s" - SystemCPUTimeDescription = "Seconds each logical CPU spent on each mode" - - // SystemCPUUtilization is the metric conforming to the - // "system.cpu.utilization" semantic conventions. It represents the difference - // in system.cpu.time since the last measurement, divided by the elapsed time - // and number of logical CPUs. - // Instrument: gauge - // Unit: 1 - // Stability: Experimental - SystemCPUUtilizationName = "system.cpu.utilization" - SystemCPUUtilizationUnit = "1" - SystemCPUUtilizationDescription = "Difference in system.cpu.time since the last measurement, divided by the elapsed time and number of logical CPUs" - - // SystemCPUFrequency is the metric conforming to the "system.cpu.frequency" - // semantic conventions. It represents the reports the current frequency of the - // CPU in Hz. - // Instrument: gauge - // Unit: {Hz} - // Stability: Experimental - SystemCPUFrequencyName = "system.cpu.frequency" - SystemCPUFrequencyUnit = "{Hz}" - SystemCPUFrequencyDescription = "Reports the current frequency of the CPU in Hz" - - // SystemCPUPhysicalCount is the metric conforming to the - // "system.cpu.physical.count" semantic conventions. It represents the reports - // the number of actual physical processor cores on the hardware. - // Instrument: updowncounter - // Unit: {cpu} - // Stability: Experimental - SystemCPUPhysicalCountName = "system.cpu.physical.count" - SystemCPUPhysicalCountUnit = "{cpu}" - SystemCPUPhysicalCountDescription = "Reports the number of actual physical processor cores on the hardware" - - // SystemCPULogicalCount is the metric conforming to the - // "system.cpu.logical.count" semantic conventions. It represents the reports - // the number of logical (virtual) processor cores created by the operating - // system to manage multitasking. - // Instrument: updowncounter - // Unit: {cpu} - // Stability: Experimental - SystemCPULogicalCountName = "system.cpu.logical.count" - SystemCPULogicalCountUnit = "{cpu}" - SystemCPULogicalCountDescription = "Reports the number of logical (virtual) processor cores created by the operating system to manage multitasking" - - // SystemMemoryUsage is the metric conforming to the "system.memory.usage" - // semantic conventions. It represents the reports memory in use by state. - // Instrument: updowncounter - // Unit: By - // Stability: Experimental - SystemMemoryUsageName = "system.memory.usage" - SystemMemoryUsageUnit = "By" - SystemMemoryUsageDescription = "Reports memory in use by state." - - // SystemMemoryLimit is the metric conforming to the "system.memory.limit" - // semantic conventions. It represents the total memory available in the - // system. - // Instrument: updowncounter - // Unit: By - // Stability: Experimental - SystemMemoryLimitName = "system.memory.limit" - SystemMemoryLimitUnit = "By" - SystemMemoryLimitDescription = "Total memory available in the system." - - // SystemMemoryUtilization is the metric conforming to the - // "system.memory.utilization" semantic conventions. - // Instrument: gauge - // Unit: 1 - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemMemoryUtilizationName = "system.memory.utilization" - SystemMemoryUtilizationUnit = "1" - - // SystemPagingUsage is the metric conforming to the "system.paging.usage" - // semantic conventions. It represents the unix swap or windows pagefile usage. - // Instrument: updowncounter - // Unit: By - // Stability: Experimental - SystemPagingUsageName = "system.paging.usage" - SystemPagingUsageUnit = "By" - SystemPagingUsageDescription = "Unix swap or windows pagefile usage" - - // SystemPagingUtilization is the metric conforming to the - // "system.paging.utilization" semantic conventions. - // Instrument: gauge - // Unit: 1 - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemPagingUtilizationName = "system.paging.utilization" - SystemPagingUtilizationUnit = "1" - - // SystemPagingFaults is the metric conforming to the "system.paging.faults" - // semantic conventions. - // Instrument: counter - // Unit: {fault} - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemPagingFaultsName = "system.paging.faults" - SystemPagingFaultsUnit = "{fault}" - - // SystemPagingOperations is the metric conforming to the - // "system.paging.operations" semantic conventions. - // Instrument: counter - // Unit: {operation} - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemPagingOperationsName = "system.paging.operations" - SystemPagingOperationsUnit = "{operation}" - - // SystemDiskIo is the metric conforming to the "system.disk.io" semantic - // conventions. - // Instrument: counter - // Unit: By - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemDiskIoName = "system.disk.io" - SystemDiskIoUnit = "By" - - // SystemDiskOperations is the metric conforming to the - // "system.disk.operations" semantic conventions. - // Instrument: counter - // Unit: {operation} - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemDiskOperationsName = "system.disk.operations" - SystemDiskOperationsUnit = "{operation}" - - // SystemDiskIoTime is the metric conforming to the "system.disk.io_time" - // semantic conventions. It represents the time disk spent activated. - // Instrument: counter - // Unit: s - // Stability: Experimental - SystemDiskIoTimeName = "system.disk.io_time" - SystemDiskIoTimeUnit = "s" - SystemDiskIoTimeDescription = "Time disk spent activated" - - // SystemDiskOperationTime is the metric conforming to the - // "system.disk.operation_time" semantic conventions. It represents the sum of - // the time each operation took to complete. - // Instrument: counter - // Unit: s - // Stability: Experimental - SystemDiskOperationTimeName = "system.disk.operation_time" - SystemDiskOperationTimeUnit = "s" - SystemDiskOperationTimeDescription = "Sum of the time each operation took to complete" - - // SystemDiskMerged is the metric conforming to the "system.disk.merged" - // semantic conventions. - // Instrument: counter - // Unit: {operation} - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemDiskMergedName = "system.disk.merged" - SystemDiskMergedUnit = "{operation}" - - // SystemFilesystemUsage is the metric conforming to the - // "system.filesystem.usage" semantic conventions. - // Instrument: updowncounter - // Unit: By - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemFilesystemUsageName = "system.filesystem.usage" - SystemFilesystemUsageUnit = "By" - - // SystemFilesystemUtilization is the metric conforming to the - // "system.filesystem.utilization" semantic conventions. - // Instrument: gauge - // Unit: 1 - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemFilesystemUtilizationName = "system.filesystem.utilization" - SystemFilesystemUtilizationUnit = "1" - - // SystemNetworkDropped is the metric conforming to the - // "system.network.dropped" semantic conventions. It represents the count of - // packets that are dropped or discarded even though there was no error. - // Instrument: counter - // Unit: {packet} - // Stability: Experimental - SystemNetworkDroppedName = "system.network.dropped" - SystemNetworkDroppedUnit = "{packet}" - SystemNetworkDroppedDescription = "Count of packets that are dropped or discarded even though there was no error" - - // SystemNetworkPackets is the metric conforming to the - // "system.network.packets" semantic conventions. - // Instrument: counter - // Unit: {packet} - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemNetworkPacketsName = "system.network.packets" - SystemNetworkPacketsUnit = "{packet}" - - // SystemNetworkErrors is the metric conforming to the "system.network.errors" - // semantic conventions. It represents the count of network errors detected. - // Instrument: counter - // Unit: {error} - // Stability: Experimental - SystemNetworkErrorsName = "system.network.errors" - SystemNetworkErrorsUnit = "{error}" - SystemNetworkErrorsDescription = "Count of network errors detected" - - // SystemNetworkIo is the metric conforming to the "system.network.io" semantic - // conventions. - // Instrument: counter - // Unit: By - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemNetworkIoName = "system.network.io" - SystemNetworkIoUnit = "By" - - // SystemNetworkConnections is the metric conforming to the - // "system.network.connections" semantic conventions. - // Instrument: updowncounter - // Unit: {connection} - // Stability: Experimental - // NOTE: The description (brief) for this metric is not defined in the semantic-conventions repository. - SystemNetworkConnectionsName = "system.network.connections" - SystemNetworkConnectionsUnit = "{connection}" - - // SystemProcessesCount is the metric conforming to the - // "system.processes.count" semantic conventions. It represents the total - // number of processes in each state. - // Instrument: updowncounter - // Unit: {process} - // Stability: Experimental - SystemProcessesCountName = "system.processes.count" - SystemProcessesCountUnit = "{process}" - SystemProcessesCountDescription = "Total number of processes in each state" - - // SystemProcessesCreated is the metric conforming to the - // "system.processes.created" semantic conventions. It represents the total - // number of processes created over uptime of the host. - // Instrument: counter - // Unit: {process} - // Stability: Experimental - SystemProcessesCreatedName = "system.processes.created" - SystemProcessesCreatedUnit = "{process}" - SystemProcessesCreatedDescription = "Total number of processes created over uptime of the host" - - // SystemLinuxMemoryAvailable is the metric conforming to the - // "system.linux.memory.available" semantic conventions. It represents an - // estimate of how much memory is available for starting new applications, - // without causing swapping. - // Instrument: updowncounter - // Unit: By - // Stability: Experimental - SystemLinuxMemoryAvailableName = "system.linux.memory.available" - SystemLinuxMemoryAvailableUnit = "By" - SystemLinuxMemoryAvailableDescription = "An estimate of how much memory is available for starting new applications, without causing swapping" -) diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/resource.go b/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/resource.go deleted file mode 100644 index d66bbe9c2..000000000 --- a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/resource.go +++ /dev/null @@ -1,2545 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -// Code generated from semantic convention specification. DO NOT EDIT. - -package semconv // import "go.opentelemetry.io/otel/semconv/v1.24.0" - -import "go.opentelemetry.io/otel/attribute" - -// A cloud environment (e.g. GCP, Azure, AWS). -const ( - // CloudAccountIDKey is the attribute Key conforming to the - // "cloud.account.id" semantic conventions. It represents the cloud account - // ID the resource is assigned to. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '111111111111', 'opentelemetry' - CloudAccountIDKey = attribute.Key("cloud.account.id") - - // CloudAvailabilityZoneKey is the attribute Key conforming to the - // "cloud.availability_zone" semantic conventions. It represents the cloud - // regions often have multiple, isolated locations known as zones to - // increase availability. Availability zone represents the zone where the - // resource is running. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'us-east-1c' - // Note: Availability zones are called "zones" on Alibaba Cloud and Google - // Cloud. - CloudAvailabilityZoneKey = attribute.Key("cloud.availability_zone") - - // CloudPlatformKey is the attribute Key conforming to the "cloud.platform" - // semantic conventions. It represents the cloud platform in use. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Note: The prefix of the service SHOULD match the one specified in - // `cloud.provider`. - CloudPlatformKey = attribute.Key("cloud.platform") - - // CloudProviderKey is the attribute Key conforming to the "cloud.provider" - // semantic conventions. It represents the name of the cloud provider. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - CloudProviderKey = attribute.Key("cloud.provider") - - // CloudRegionKey is the attribute Key conforming to the "cloud.region" - // semantic conventions. It represents the geographical region the resource - // is running. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'us-central1', 'us-east-1' - // Note: Refer to your provider's docs to see the available regions, for - // example [Alibaba Cloud - // regions](https://www.alibabacloud.com/help/doc-detail/40654.htm), [AWS - // regions](https://aws.amazon.com/about-aws/global-infrastructure/regions_az/), - // [Azure - // regions](https://azure.microsoft.com/global-infrastructure/geographies/), - // [Google Cloud regions](https://cloud.google.com/about/locations), or - // [Tencent Cloud - // regions](https://www.tencentcloud.com/document/product/213/6091). - CloudRegionKey = attribute.Key("cloud.region") - - // CloudResourceIDKey is the attribute Key conforming to the - // "cloud.resource_id" semantic conventions. It represents the cloud - // provider-specific native identifier of the monitored cloud resource - // (e.g. an - // [ARN](https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html) - // on AWS, a [fully qualified resource - // ID](https://learn.microsoft.com/rest/api/resources/resources/get-by-id) - // on Azure, a [full resource - // name](https://cloud.google.com/apis/design/resource_names#full_resource_name) - // on GCP) - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'arn:aws:lambda:REGION:ACCOUNT_ID:function:my-function', - // '//run.googleapis.com/projects/PROJECT_ID/locations/LOCATION_ID/services/SERVICE_ID', - // '/subscriptions//resourceGroups//providers/Microsoft.Web/sites//functions/' - // Note: On some cloud providers, it may not be possible to determine the - // full ID at startup, - // so it may be necessary to set `cloud.resource_id` as a span attribute - // instead. - // - // The exact value to use for `cloud.resource_id` depends on the cloud - // provider. - // The following well-known definitions MUST be used if you set this - // attribute and they apply: - // - // * **AWS Lambda:** The function - // [ARN](https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html). - // Take care not to use the "invoked ARN" directly but replace any - // [alias - // suffix](https://docs.aws.amazon.com/lambda/latest/dg/configuration-aliases.html) - // with the resolved function version, as the same runtime instance may - // be invokable with - // multiple different aliases. - // * **GCP:** The [URI of the - // resource](https://cloud.google.com/iam/docs/full-resource-names) - // * **Azure:** The [Fully Qualified Resource - // ID](https://docs.microsoft.com/rest/api/resources/resources/get-by-id) - // of the invoked function, - // *not* the function app, having the form - // `/subscriptions//resourceGroups//providers/Microsoft.Web/sites//functions/`. - // This means that a span attribute MUST be used, as an Azure function - // app can host multiple functions that would usually share - // a TracerProvider. - CloudResourceIDKey = attribute.Key("cloud.resource_id") -) - -var ( - // Alibaba Cloud Elastic Compute Service - CloudPlatformAlibabaCloudECS = CloudPlatformKey.String("alibaba_cloud_ecs") - // Alibaba Cloud Function Compute - CloudPlatformAlibabaCloudFc = CloudPlatformKey.String("alibaba_cloud_fc") - // Red Hat OpenShift on Alibaba Cloud - CloudPlatformAlibabaCloudOpenshift = CloudPlatformKey.String("alibaba_cloud_openshift") - // AWS Elastic Compute Cloud - CloudPlatformAWSEC2 = CloudPlatformKey.String("aws_ec2") - // AWS Elastic Container Service - CloudPlatformAWSECS = CloudPlatformKey.String("aws_ecs") - // AWS Elastic Kubernetes Service - CloudPlatformAWSEKS = CloudPlatformKey.String("aws_eks") - // AWS Lambda - CloudPlatformAWSLambda = CloudPlatformKey.String("aws_lambda") - // AWS Elastic Beanstalk - CloudPlatformAWSElasticBeanstalk = CloudPlatformKey.String("aws_elastic_beanstalk") - // AWS App Runner - CloudPlatformAWSAppRunner = CloudPlatformKey.String("aws_app_runner") - // Red Hat OpenShift on AWS (ROSA) - CloudPlatformAWSOpenshift = CloudPlatformKey.String("aws_openshift") - // Azure Virtual Machines - CloudPlatformAzureVM = CloudPlatformKey.String("azure_vm") - // Azure Container Instances - CloudPlatformAzureContainerInstances = CloudPlatformKey.String("azure_container_instances") - // Azure Kubernetes Service - CloudPlatformAzureAKS = CloudPlatformKey.String("azure_aks") - // Azure Functions - CloudPlatformAzureFunctions = CloudPlatformKey.String("azure_functions") - // Azure App Service - CloudPlatformAzureAppService = CloudPlatformKey.String("azure_app_service") - // Azure Red Hat OpenShift - CloudPlatformAzureOpenshift = CloudPlatformKey.String("azure_openshift") - // Google Bare Metal Solution (BMS) - CloudPlatformGCPBareMetalSolution = CloudPlatformKey.String("gcp_bare_metal_solution") - // Google Cloud Compute Engine (GCE) - CloudPlatformGCPComputeEngine = CloudPlatformKey.String("gcp_compute_engine") - // Google Cloud Run - CloudPlatformGCPCloudRun = CloudPlatformKey.String("gcp_cloud_run") - // Google Cloud Kubernetes Engine (GKE) - CloudPlatformGCPKubernetesEngine = CloudPlatformKey.String("gcp_kubernetes_engine") - // Google Cloud Functions (GCF) - CloudPlatformGCPCloudFunctions = CloudPlatformKey.String("gcp_cloud_functions") - // Google Cloud App Engine (GAE) - CloudPlatformGCPAppEngine = CloudPlatformKey.String("gcp_app_engine") - // Red Hat OpenShift on Google Cloud - CloudPlatformGCPOpenshift = CloudPlatformKey.String("gcp_openshift") - // Red Hat OpenShift on IBM Cloud - CloudPlatformIbmCloudOpenshift = CloudPlatformKey.String("ibm_cloud_openshift") - // Tencent Cloud Cloud Virtual Machine (CVM) - CloudPlatformTencentCloudCvm = CloudPlatformKey.String("tencent_cloud_cvm") - // Tencent Cloud Elastic Kubernetes Service (EKS) - CloudPlatformTencentCloudEKS = CloudPlatformKey.String("tencent_cloud_eks") - // Tencent Cloud Serverless Cloud Function (SCF) - CloudPlatformTencentCloudScf = CloudPlatformKey.String("tencent_cloud_scf") -) - -var ( - // Alibaba Cloud - CloudProviderAlibabaCloud = CloudProviderKey.String("alibaba_cloud") - // Amazon Web Services - CloudProviderAWS = CloudProviderKey.String("aws") - // Microsoft Azure - CloudProviderAzure = CloudProviderKey.String("azure") - // Google Cloud Platform - CloudProviderGCP = CloudProviderKey.String("gcp") - // Heroku Platform as a Service - CloudProviderHeroku = CloudProviderKey.String("heroku") - // IBM Cloud - CloudProviderIbmCloud = CloudProviderKey.String("ibm_cloud") - // Tencent Cloud - CloudProviderTencentCloud = CloudProviderKey.String("tencent_cloud") -) - -// CloudAccountID returns an attribute KeyValue conforming to the -// "cloud.account.id" semantic conventions. It represents the cloud account ID -// the resource is assigned to. -func CloudAccountID(val string) attribute.KeyValue { - return CloudAccountIDKey.String(val) -} - -// CloudAvailabilityZone returns an attribute KeyValue conforming to the -// "cloud.availability_zone" semantic conventions. It represents the cloud -// regions often have multiple, isolated locations known as zones to increase -// availability. Availability zone represents the zone where the resource is -// running. -func CloudAvailabilityZone(val string) attribute.KeyValue { - return CloudAvailabilityZoneKey.String(val) -} - -// CloudRegion returns an attribute KeyValue conforming to the -// "cloud.region" semantic conventions. It represents the geographical region -// the resource is running. -func CloudRegion(val string) attribute.KeyValue { - return CloudRegionKey.String(val) -} - -// CloudResourceID returns an attribute KeyValue conforming to the -// "cloud.resource_id" semantic conventions. It represents the cloud -// provider-specific native identifier of the monitored cloud resource (e.g. an -// [ARN](https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html) -// on AWS, a [fully qualified resource -// ID](https://learn.microsoft.com/rest/api/resources/resources/get-by-id) on -// Azure, a [full resource -// name](https://cloud.google.com/apis/design/resource_names#full_resource_name) -// on GCP) -func CloudResourceID(val string) attribute.KeyValue { - return CloudResourceIDKey.String(val) -} - -// A container instance. -const ( - // ContainerCommandKey is the attribute Key conforming to the - // "container.command" semantic conventions. It represents the command used - // to run the container (i.e. the command name). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'otelcontribcol' - // Note: If using embedded credentials or sensitive data, it is recommended - // to remove them to prevent potential leakage. - ContainerCommandKey = attribute.Key("container.command") - - // ContainerCommandArgsKey is the attribute Key conforming to the - // "container.command_args" semantic conventions. It represents the all the - // command arguments (including the command/executable itself) run by the - // container. [2] - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'otelcontribcol, --config, config.yaml' - ContainerCommandArgsKey = attribute.Key("container.command_args") - - // ContainerCommandLineKey is the attribute Key conforming to the - // "container.command_line" semantic conventions. It represents the full - // command run by the container as a single string representing the full - // command. [2] - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'otelcontribcol --config config.yaml' - ContainerCommandLineKey = attribute.Key("container.command_line") - - // ContainerIDKey is the attribute Key conforming to the "container.id" - // semantic conventions. It represents the container ID. Usually a UUID, as - // for example used to [identify Docker - // containers](https://docs.docker.com/engine/reference/run/#container-identification). - // The UUID might be abbreviated. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'a3bf90e006b2' - ContainerIDKey = attribute.Key("container.id") - - // ContainerImageIDKey is the attribute Key conforming to the - // "container.image.id" semantic conventions. It represents the runtime - // specific image identifier. Usually a hash algorithm followed by a UUID. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // 'sha256:19c92d0a00d1b66d897bceaa7319bee0dd38a10a851c60bcec9474aa3f01e50f' - // Note: Docker defines a sha256 of the image id; `container.image.id` - // corresponds to the `Image` field from the Docker container inspect - // [API](https://docs.docker.com/engine/api/v1.43/#tag/Container/operation/ContainerInspect) - // endpoint. - // K8S defines a link to the container registry repository with digest - // `"imageID": "registry.azurecr.io - // /namespace/service/dockerfile@sha256:bdeabd40c3a8a492eaf9e8e44d0ebbb84bac7ee25ac0cf8a7159d25f62555625"`. - // The ID is assinged by the container runtime and can vary in different - // environments. Consider using `oci.manifest.digest` if it is important to - // identify the same image in different environments/runtimes. - ContainerImageIDKey = attribute.Key("container.image.id") - - // ContainerImageNameKey is the attribute Key conforming to the - // "container.image.name" semantic conventions. It represents the name of - // the image the container was built on. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'gcr.io/opentelemetry/operator' - ContainerImageNameKey = attribute.Key("container.image.name") - - // ContainerImageRepoDigestsKey is the attribute Key conforming to the - // "container.image.repo_digests" semantic conventions. It represents the - // repo digests of the container image as provided by the container - // runtime. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // 'example@sha256:afcc7f1ac1b49db317a7196c902e61c6c3c4607d63599ee1a82d702d249a0ccb', - // 'internal.registry.example.com:5000/example@sha256:b69959407d21e8a062e0416bf13405bb2b71ed7a84dde4158ebafacfa06f5578' - // Note: - // [Docker](https://docs.docker.com/engine/api/v1.43/#tag/Image/operation/ImageInspect) - // and - // [CRI](https://github.com/kubernetes/cri-api/blob/c75ef5b473bbe2d0a4fc92f82235efd665ea8e9f/pkg/apis/runtime/v1/api.proto#L1237-L1238) - // report those under the `RepoDigests` field. - ContainerImageRepoDigestsKey = attribute.Key("container.image.repo_digests") - - // ContainerImageTagsKey is the attribute Key conforming to the - // "container.image.tags" semantic conventions. It represents the container - // image tags. An example can be found in [Docker Image - // Inspect](https://docs.docker.com/engine/api/v1.43/#tag/Image/operation/ImageInspect). - // Should be only the `` section of the full name for example from - // `registry.example.com/my-org/my-image:`. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'v1.27.1', '3.5.7-0' - ContainerImageTagsKey = attribute.Key("container.image.tags") - - // ContainerNameKey is the attribute Key conforming to the "container.name" - // semantic conventions. It represents the container name used by container - // runtime. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry-autoconf' - ContainerNameKey = attribute.Key("container.name") - - // ContainerRuntimeKey is the attribute Key conforming to the - // "container.runtime" semantic conventions. It represents the container - // runtime managing this container. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'docker', 'containerd', 'rkt' - ContainerRuntimeKey = attribute.Key("container.runtime") -) - -// ContainerCommand returns an attribute KeyValue conforming to the -// "container.command" semantic conventions. It represents the command used to -// run the container (i.e. the command name). -func ContainerCommand(val string) attribute.KeyValue { - return ContainerCommandKey.String(val) -} - -// ContainerCommandArgs returns an attribute KeyValue conforming to the -// "container.command_args" semantic conventions. It represents the all the -// command arguments (including the command/executable itself) run by the -// container. [2] -func ContainerCommandArgs(val ...string) attribute.KeyValue { - return ContainerCommandArgsKey.StringSlice(val) -} - -// ContainerCommandLine returns an attribute KeyValue conforming to the -// "container.command_line" semantic conventions. It represents the full -// command run by the container as a single string representing the full -// command. [2] -func ContainerCommandLine(val string) attribute.KeyValue { - return ContainerCommandLineKey.String(val) -} - -// ContainerID returns an attribute KeyValue conforming to the -// "container.id" semantic conventions. It represents the container ID. Usually -// a UUID, as for example used to [identify Docker -// containers](https://docs.docker.com/engine/reference/run/#container-identification). -// The UUID might be abbreviated. -func ContainerID(val string) attribute.KeyValue { - return ContainerIDKey.String(val) -} - -// ContainerImageID returns an attribute KeyValue conforming to the -// "container.image.id" semantic conventions. It represents the runtime -// specific image identifier. Usually a hash algorithm followed by a UUID. -func ContainerImageID(val string) attribute.KeyValue { - return ContainerImageIDKey.String(val) -} - -// ContainerImageName returns an attribute KeyValue conforming to the -// "container.image.name" semantic conventions. It represents the name of the -// image the container was built on. -func ContainerImageName(val string) attribute.KeyValue { - return ContainerImageNameKey.String(val) -} - -// ContainerImageRepoDigests returns an attribute KeyValue conforming to the -// "container.image.repo_digests" semantic conventions. It represents the repo -// digests of the container image as provided by the container runtime. -func ContainerImageRepoDigests(val ...string) attribute.KeyValue { - return ContainerImageRepoDigestsKey.StringSlice(val) -} - -// ContainerImageTags returns an attribute KeyValue conforming to the -// "container.image.tags" semantic conventions. It represents the container -// image tags. An example can be found in [Docker Image -// Inspect](https://docs.docker.com/engine/api/v1.43/#tag/Image/operation/ImageInspect). -// Should be only the `` section of the full name for example from -// `registry.example.com/my-org/my-image:`. -func ContainerImageTags(val ...string) attribute.KeyValue { - return ContainerImageTagsKey.StringSlice(val) -} - -// ContainerName returns an attribute KeyValue conforming to the -// "container.name" semantic conventions. It represents the container name used -// by container runtime. -func ContainerName(val string) attribute.KeyValue { - return ContainerNameKey.String(val) -} - -// ContainerRuntime returns an attribute KeyValue conforming to the -// "container.runtime" semantic conventions. It represents the container -// runtime managing this container. -func ContainerRuntime(val string) attribute.KeyValue { - return ContainerRuntimeKey.String(val) -} - -// Describes device attributes. -const ( - // DeviceIDKey is the attribute Key conforming to the "device.id" semantic - // conventions. It represents a unique identifier representing the device - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2ab2916d-a51f-4ac8-80ee-45ac31a28092' - // Note: The device identifier MUST only be defined using the values - // outlined below. This value is not an advertising identifier and MUST NOT - // be used as such. On iOS (Swift or Objective-C), this value MUST be equal - // to the [vendor - // identifier](https://developer.apple.com/documentation/uikit/uidevice/1620059-identifierforvendor). - // On Android (Java or Kotlin), this value MUST be equal to the Firebase - // Installation ID or a globally unique UUID which is persisted across - // sessions in your application. More information can be found - // [here](https://developer.android.com/training/articles/user-data-ids) on - // best practices and exact implementation details. Caution should be taken - // when storing personal data or anything which can identify a user. GDPR - // and data protection laws may apply, ensure you do your own due - // diligence. - DeviceIDKey = attribute.Key("device.id") - - // DeviceManufacturerKey is the attribute Key conforming to the - // "device.manufacturer" semantic conventions. It represents the name of - // the device manufacturer - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Apple', 'Samsung' - // Note: The Android OS provides this field via - // [Build](https://developer.android.com/reference/android/os/Build#MANUFACTURER). - // iOS apps SHOULD hardcode the value `Apple`. - DeviceManufacturerKey = attribute.Key("device.manufacturer") - - // DeviceModelIdentifierKey is the attribute Key conforming to the - // "device.model.identifier" semantic conventions. It represents the model - // identifier for the device - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'iPhone3,4', 'SM-G920F' - // Note: It's recommended this value represents a machine-readable version - // of the model identifier rather than the market or consumer-friendly name - // of the device. - DeviceModelIdentifierKey = attribute.Key("device.model.identifier") - - // DeviceModelNameKey is the attribute Key conforming to the - // "device.model.name" semantic conventions. It represents the marketing - // name for the device model - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'iPhone 6s Plus', 'Samsung Galaxy S6' - // Note: It's recommended this value represents a human-readable version of - // the device model rather than a machine-readable alternative. - DeviceModelNameKey = attribute.Key("device.model.name") -) - -// DeviceID returns an attribute KeyValue conforming to the "device.id" -// semantic conventions. It represents a unique identifier representing the -// device -func DeviceID(val string) attribute.KeyValue { - return DeviceIDKey.String(val) -} - -// DeviceManufacturer returns an attribute KeyValue conforming to the -// "device.manufacturer" semantic conventions. It represents the name of the -// device manufacturer -func DeviceManufacturer(val string) attribute.KeyValue { - return DeviceManufacturerKey.String(val) -} - -// DeviceModelIdentifier returns an attribute KeyValue conforming to the -// "device.model.identifier" semantic conventions. It represents the model -// identifier for the device -func DeviceModelIdentifier(val string) attribute.KeyValue { - return DeviceModelIdentifierKey.String(val) -} - -// DeviceModelName returns an attribute KeyValue conforming to the -// "device.model.name" semantic conventions. It represents the marketing name -// for the device model -func DeviceModelName(val string) attribute.KeyValue { - return DeviceModelNameKey.String(val) -} - -// A host is defined as a computing instance. For example, physical servers, -// virtual machines, switches or disk array. -const ( - // HostArchKey is the attribute Key conforming to the "host.arch" semantic - // conventions. It represents the CPU architecture the host system is - // running on. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - HostArchKey = attribute.Key("host.arch") - - // HostCPUCacheL2SizeKey is the attribute Key conforming to the - // "host.cpu.cache.l2.size" semantic conventions. It represents the amount - // of level 2 memory cache available to the processor (in Bytes). - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 12288000 - HostCPUCacheL2SizeKey = attribute.Key("host.cpu.cache.l2.size") - - // HostCPUFamilyKey is the attribute Key conforming to the - // "host.cpu.family" semantic conventions. It represents the family or - // generation of the CPU. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '6', 'PA-RISC 1.1e' - HostCPUFamilyKey = attribute.Key("host.cpu.family") - - // HostCPUModelIDKey is the attribute Key conforming to the - // "host.cpu.model.id" semantic conventions. It represents the model - // identifier. It provides more granular information about the CPU, - // distinguishing it from other CPUs within the same family. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '6', '9000/778/B180L' - HostCPUModelIDKey = attribute.Key("host.cpu.model.id") - - // HostCPUModelNameKey is the attribute Key conforming to the - // "host.cpu.model.name" semantic conventions. It represents the model - // designation of the processor. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz' - HostCPUModelNameKey = attribute.Key("host.cpu.model.name") - - // HostCPUSteppingKey is the attribute Key conforming to the - // "host.cpu.stepping" semantic conventions. It represents the stepping or - // core revisions. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 1 - HostCPUSteppingKey = attribute.Key("host.cpu.stepping") - - // HostCPUVendorIDKey is the attribute Key conforming to the - // "host.cpu.vendor.id" semantic conventions. It represents the processor - // manufacturer identifier. A maximum 12-character string. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'GenuineIntel' - // Note: [CPUID](https://wiki.osdev.org/CPUID) command returns the vendor - // ID string in EBX, EDX and ECX registers. Writing these to memory in this - // order results in a 12-character string. - HostCPUVendorIDKey = attribute.Key("host.cpu.vendor.id") - - // HostIDKey is the attribute Key conforming to the "host.id" semantic - // conventions. It represents the unique host ID. For Cloud, this must be - // the instance_id assigned by the cloud provider. For non-containerized - // systems, this should be the `machine-id`. See the table below for the - // sources to use to determine the `machine-id` based on operating system. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'fdbf79e8af94cb7f9e8df36789187052' - HostIDKey = attribute.Key("host.id") - - // HostImageIDKey is the attribute Key conforming to the "host.image.id" - // semantic conventions. It represents the vM image ID or host OS image ID. - // For Cloud, this value is from the provider. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'ami-07b06b442921831e5' - HostImageIDKey = attribute.Key("host.image.id") - - // HostImageNameKey is the attribute Key conforming to the - // "host.image.name" semantic conventions. It represents the name of the VM - // image or OS install the host was instantiated from. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'infra-ami-eks-worker-node-7d4ec78312', 'CentOS-8-x86_64-1905' - HostImageNameKey = attribute.Key("host.image.name") - - // HostImageVersionKey is the attribute Key conforming to the - // "host.image.version" semantic conventions. It represents the version - // string of the VM image or host OS as defined in [Version - // Attributes](/docs/resource/README.md#version-attributes). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '0.1' - HostImageVersionKey = attribute.Key("host.image.version") - - // HostIPKey is the attribute Key conforming to the "host.ip" semantic - // conventions. It represents the available IP addresses of the host, - // excluding loopback interfaces. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: '192.168.1.140', 'fe80::abc2:4a28:737a:609e' - // Note: IPv4 Addresses MUST be specified in dotted-quad notation. IPv6 - // addresses MUST be specified in the [RFC - // 5952](https://www.rfc-editor.org/rfc/rfc5952.html) format. - HostIPKey = attribute.Key("host.ip") - - // HostMacKey is the attribute Key conforming to the "host.mac" semantic - // conventions. It represents the available MAC addresses of the host, - // excluding loopback interfaces. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'AC-DE-48-23-45-67', 'AC-DE-48-23-45-67-01-9F' - // Note: MAC Addresses MUST be represented in [IEEE RA hexadecimal - // form](https://standards.ieee.org/wp-content/uploads/import/documents/tutorials/eui.pdf): - // as hyphen-separated octets in uppercase hexadecimal form from most to - // least significant. - HostMacKey = attribute.Key("host.mac") - - // HostNameKey is the attribute Key conforming to the "host.name" semantic - // conventions. It represents the name of the host. On Unix systems, it may - // contain what the hostname command returns, or the fully qualified - // hostname, or another name specified by the user. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry-test' - HostNameKey = attribute.Key("host.name") - - // HostTypeKey is the attribute Key conforming to the "host.type" semantic - // conventions. It represents the type of host. For Cloud, this must be the - // machine type. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'n1-standard-1' - HostTypeKey = attribute.Key("host.type") -) - -var ( - // AMD64 - HostArchAMD64 = HostArchKey.String("amd64") - // ARM32 - HostArchARM32 = HostArchKey.String("arm32") - // ARM64 - HostArchARM64 = HostArchKey.String("arm64") - // Itanium - HostArchIA64 = HostArchKey.String("ia64") - // 32-bit PowerPC - HostArchPPC32 = HostArchKey.String("ppc32") - // 64-bit PowerPC - HostArchPPC64 = HostArchKey.String("ppc64") - // IBM z/Architecture - HostArchS390x = HostArchKey.String("s390x") - // 32-bit x86 - HostArchX86 = HostArchKey.String("x86") -) - -// HostCPUCacheL2Size returns an attribute KeyValue conforming to the -// "host.cpu.cache.l2.size" semantic conventions. It represents the amount of -// level 2 memory cache available to the processor (in Bytes). -func HostCPUCacheL2Size(val int) attribute.KeyValue { - return HostCPUCacheL2SizeKey.Int(val) -} - -// HostCPUFamily returns an attribute KeyValue conforming to the -// "host.cpu.family" semantic conventions. It represents the family or -// generation of the CPU. -func HostCPUFamily(val string) attribute.KeyValue { - return HostCPUFamilyKey.String(val) -} - -// HostCPUModelID returns an attribute KeyValue conforming to the -// "host.cpu.model.id" semantic conventions. It represents the model -// identifier. It provides more granular information about the CPU, -// distinguishing it from other CPUs within the same family. -func HostCPUModelID(val string) attribute.KeyValue { - return HostCPUModelIDKey.String(val) -} - -// HostCPUModelName returns an attribute KeyValue conforming to the -// "host.cpu.model.name" semantic conventions. It represents the model -// designation of the processor. -func HostCPUModelName(val string) attribute.KeyValue { - return HostCPUModelNameKey.String(val) -} - -// HostCPUStepping returns an attribute KeyValue conforming to the -// "host.cpu.stepping" semantic conventions. It represents the stepping or core -// revisions. -func HostCPUStepping(val int) attribute.KeyValue { - return HostCPUSteppingKey.Int(val) -} - -// HostCPUVendorID returns an attribute KeyValue conforming to the -// "host.cpu.vendor.id" semantic conventions. It represents the processor -// manufacturer identifier. A maximum 12-character string. -func HostCPUVendorID(val string) attribute.KeyValue { - return HostCPUVendorIDKey.String(val) -} - -// HostID returns an attribute KeyValue conforming to the "host.id" semantic -// conventions. It represents the unique host ID. For Cloud, this must be the -// instance_id assigned by the cloud provider. For non-containerized systems, -// this should be the `machine-id`. See the table below for the sources to use -// to determine the `machine-id` based on operating system. -func HostID(val string) attribute.KeyValue { - return HostIDKey.String(val) -} - -// HostImageID returns an attribute KeyValue conforming to the -// "host.image.id" semantic conventions. It represents the vM image ID or host -// OS image ID. For Cloud, this value is from the provider. -func HostImageID(val string) attribute.KeyValue { - return HostImageIDKey.String(val) -} - -// HostImageName returns an attribute KeyValue conforming to the -// "host.image.name" semantic conventions. It represents the name of the VM -// image or OS install the host was instantiated from. -func HostImageName(val string) attribute.KeyValue { - return HostImageNameKey.String(val) -} - -// HostImageVersion returns an attribute KeyValue conforming to the -// "host.image.version" semantic conventions. It represents the version string -// of the VM image or host OS as defined in [Version -// Attributes](/docs/resource/README.md#version-attributes). -func HostImageVersion(val string) attribute.KeyValue { - return HostImageVersionKey.String(val) -} - -// HostIP returns an attribute KeyValue conforming to the "host.ip" semantic -// conventions. It represents the available IP addresses of the host, excluding -// loopback interfaces. -func HostIP(val ...string) attribute.KeyValue { - return HostIPKey.StringSlice(val) -} - -// HostMac returns an attribute KeyValue conforming to the "host.mac" -// semantic conventions. It represents the available MAC addresses of the host, -// excluding loopback interfaces. -func HostMac(val ...string) attribute.KeyValue { - return HostMacKey.StringSlice(val) -} - -// HostName returns an attribute KeyValue conforming to the "host.name" -// semantic conventions. It represents the name of the host. On Unix systems, -// it may contain what the hostname command returns, or the fully qualified -// hostname, or another name specified by the user. -func HostName(val string) attribute.KeyValue { - return HostNameKey.String(val) -} - -// HostType returns an attribute KeyValue conforming to the "host.type" -// semantic conventions. It represents the type of host. For Cloud, this must -// be the machine type. -func HostType(val string) attribute.KeyValue { - return HostTypeKey.String(val) -} - -// Kubernetes resource attributes. -const ( - // K8SClusterNameKey is the attribute Key conforming to the - // "k8s.cluster.name" semantic conventions. It represents the name of the - // cluster. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry-cluster' - K8SClusterNameKey = attribute.Key("k8s.cluster.name") - - // K8SClusterUIDKey is the attribute Key conforming to the - // "k8s.cluster.uid" semantic conventions. It represents a pseudo-ID for - // the cluster, set to the UID of the `kube-system` namespace. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '218fc5a9-a5f1-4b54-aa05-46717d0ab26d' - // Note: K8S doesn't have support for obtaining a cluster ID. If this is - // ever - // added, we will recommend collecting the `k8s.cluster.uid` through the - // official APIs. In the meantime, we are able to use the `uid` of the - // `kube-system` namespace as a proxy for cluster ID. Read on for the - // rationale. - // - // Every object created in a K8S cluster is assigned a distinct UID. The - // `kube-system` namespace is used by Kubernetes itself and will exist - // for the lifetime of the cluster. Using the `uid` of the `kube-system` - // namespace is a reasonable proxy for the K8S ClusterID as it will only - // change if the cluster is rebuilt. Furthermore, Kubernetes UIDs are - // UUIDs as standardized by - // [ISO/IEC 9834-8 and ITU-T - // X.667](https://www.itu.int/ITU-T/studygroups/com17/oid.html). - // Which states: - // - // > If generated according to one of the mechanisms defined in Rec. - // ITU-T X.667 | ISO/IEC 9834-8, a UUID is either guaranteed to be - // different from all other UUIDs generated before 3603 A.D., or is - // extremely likely to be different (depending on the mechanism chosen). - // - // Therefore, UIDs between clusters should be extremely unlikely to - // conflict. - K8SClusterUIDKey = attribute.Key("k8s.cluster.uid") - - // K8SContainerNameKey is the attribute Key conforming to the - // "k8s.container.name" semantic conventions. It represents the name of the - // Container from Pod specification, must be unique within a Pod. Container - // runtime usually uses different globally unique name (`container.name`). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'redis' - K8SContainerNameKey = attribute.Key("k8s.container.name") - - // K8SContainerRestartCountKey is the attribute Key conforming to the - // "k8s.container.restart_count" semantic conventions. It represents the - // number of times the container was restarted. This attribute can be used - // to identify a particular container (running or stopped) within a - // container spec. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 0, 2 - K8SContainerRestartCountKey = attribute.Key("k8s.container.restart_count") - - // K8SCronJobNameKey is the attribute Key conforming to the - // "k8s.cronjob.name" semantic conventions. It represents the name of the - // CronJob. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry' - K8SCronJobNameKey = attribute.Key("k8s.cronjob.name") - - // K8SCronJobUIDKey is the attribute Key conforming to the - // "k8s.cronjob.uid" semantic conventions. It represents the UID of the - // CronJob. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '275ecb36-5aa8-4c2a-9c47-d8bb681b9aff' - K8SCronJobUIDKey = attribute.Key("k8s.cronjob.uid") - - // K8SDaemonSetNameKey is the attribute Key conforming to the - // "k8s.daemonset.name" semantic conventions. It represents the name of the - // DaemonSet. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry' - K8SDaemonSetNameKey = attribute.Key("k8s.daemonset.name") - - // K8SDaemonSetUIDKey is the attribute Key conforming to the - // "k8s.daemonset.uid" semantic conventions. It represents the UID of the - // DaemonSet. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '275ecb36-5aa8-4c2a-9c47-d8bb681b9aff' - K8SDaemonSetUIDKey = attribute.Key("k8s.daemonset.uid") - - // K8SDeploymentNameKey is the attribute Key conforming to the - // "k8s.deployment.name" semantic conventions. It represents the name of - // the Deployment. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry' - K8SDeploymentNameKey = attribute.Key("k8s.deployment.name") - - // K8SDeploymentUIDKey is the attribute Key conforming to the - // "k8s.deployment.uid" semantic conventions. It represents the UID of the - // Deployment. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '275ecb36-5aa8-4c2a-9c47-d8bb681b9aff' - K8SDeploymentUIDKey = attribute.Key("k8s.deployment.uid") - - // K8SJobNameKey is the attribute Key conforming to the "k8s.job.name" - // semantic conventions. It represents the name of the Job. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry' - K8SJobNameKey = attribute.Key("k8s.job.name") - - // K8SJobUIDKey is the attribute Key conforming to the "k8s.job.uid" - // semantic conventions. It represents the UID of the Job. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '275ecb36-5aa8-4c2a-9c47-d8bb681b9aff' - K8SJobUIDKey = attribute.Key("k8s.job.uid") - - // K8SNamespaceNameKey is the attribute Key conforming to the - // "k8s.namespace.name" semantic conventions. It represents the name of the - // namespace that the pod is running in. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'default' - K8SNamespaceNameKey = attribute.Key("k8s.namespace.name") - - // K8SNodeNameKey is the attribute Key conforming to the "k8s.node.name" - // semantic conventions. It represents the name of the Node. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'node-1' - K8SNodeNameKey = attribute.Key("k8s.node.name") - - // K8SNodeUIDKey is the attribute Key conforming to the "k8s.node.uid" - // semantic conventions. It represents the UID of the Node. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '1eb3a0c6-0477-4080-a9cb-0cb7db65c6a2' - K8SNodeUIDKey = attribute.Key("k8s.node.uid") - - // K8SPodNameKey is the attribute Key conforming to the "k8s.pod.name" - // semantic conventions. It represents the name of the Pod. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry-pod-autoconf' - K8SPodNameKey = attribute.Key("k8s.pod.name") - - // K8SPodUIDKey is the attribute Key conforming to the "k8s.pod.uid" - // semantic conventions. It represents the UID of the Pod. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '275ecb36-5aa8-4c2a-9c47-d8bb681b9aff' - K8SPodUIDKey = attribute.Key("k8s.pod.uid") - - // K8SReplicaSetNameKey is the attribute Key conforming to the - // "k8s.replicaset.name" semantic conventions. It represents the name of - // the ReplicaSet. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry' - K8SReplicaSetNameKey = attribute.Key("k8s.replicaset.name") - - // K8SReplicaSetUIDKey is the attribute Key conforming to the - // "k8s.replicaset.uid" semantic conventions. It represents the UID of the - // ReplicaSet. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '275ecb36-5aa8-4c2a-9c47-d8bb681b9aff' - K8SReplicaSetUIDKey = attribute.Key("k8s.replicaset.uid") - - // K8SStatefulSetNameKey is the attribute Key conforming to the - // "k8s.statefulset.name" semantic conventions. It represents the name of - // the StatefulSet. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry' - K8SStatefulSetNameKey = attribute.Key("k8s.statefulset.name") - - // K8SStatefulSetUIDKey is the attribute Key conforming to the - // "k8s.statefulset.uid" semantic conventions. It represents the UID of the - // StatefulSet. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '275ecb36-5aa8-4c2a-9c47-d8bb681b9aff' - K8SStatefulSetUIDKey = attribute.Key("k8s.statefulset.uid") -) - -// K8SClusterName returns an attribute KeyValue conforming to the -// "k8s.cluster.name" semantic conventions. It represents the name of the -// cluster. -func K8SClusterName(val string) attribute.KeyValue { - return K8SClusterNameKey.String(val) -} - -// K8SClusterUID returns an attribute KeyValue conforming to the -// "k8s.cluster.uid" semantic conventions. It represents a pseudo-ID for the -// cluster, set to the UID of the `kube-system` namespace. -func K8SClusterUID(val string) attribute.KeyValue { - return K8SClusterUIDKey.String(val) -} - -// K8SContainerName returns an attribute KeyValue conforming to the -// "k8s.container.name" semantic conventions. It represents the name of the -// Container from Pod specification, must be unique within a Pod. Container -// runtime usually uses different globally unique name (`container.name`). -func K8SContainerName(val string) attribute.KeyValue { - return K8SContainerNameKey.String(val) -} - -// K8SContainerRestartCount returns an attribute KeyValue conforming to the -// "k8s.container.restart_count" semantic conventions. It represents the number -// of times the container was restarted. This attribute can be used to identify -// a particular container (running or stopped) within a container spec. -func K8SContainerRestartCount(val int) attribute.KeyValue { - return K8SContainerRestartCountKey.Int(val) -} - -// K8SCronJobName returns an attribute KeyValue conforming to the -// "k8s.cronjob.name" semantic conventions. It represents the name of the -// CronJob. -func K8SCronJobName(val string) attribute.KeyValue { - return K8SCronJobNameKey.String(val) -} - -// K8SCronJobUID returns an attribute KeyValue conforming to the -// "k8s.cronjob.uid" semantic conventions. It represents the UID of the -// CronJob. -func K8SCronJobUID(val string) attribute.KeyValue { - return K8SCronJobUIDKey.String(val) -} - -// K8SDaemonSetName returns an attribute KeyValue conforming to the -// "k8s.daemonset.name" semantic conventions. It represents the name of the -// DaemonSet. -func K8SDaemonSetName(val string) attribute.KeyValue { - return K8SDaemonSetNameKey.String(val) -} - -// K8SDaemonSetUID returns an attribute KeyValue conforming to the -// "k8s.daemonset.uid" semantic conventions. It represents the UID of the -// DaemonSet. -func K8SDaemonSetUID(val string) attribute.KeyValue { - return K8SDaemonSetUIDKey.String(val) -} - -// K8SDeploymentName returns an attribute KeyValue conforming to the -// "k8s.deployment.name" semantic conventions. It represents the name of the -// Deployment. -func K8SDeploymentName(val string) attribute.KeyValue { - return K8SDeploymentNameKey.String(val) -} - -// K8SDeploymentUID returns an attribute KeyValue conforming to the -// "k8s.deployment.uid" semantic conventions. It represents the UID of the -// Deployment. -func K8SDeploymentUID(val string) attribute.KeyValue { - return K8SDeploymentUIDKey.String(val) -} - -// K8SJobName returns an attribute KeyValue conforming to the "k8s.job.name" -// semantic conventions. It represents the name of the Job. -func K8SJobName(val string) attribute.KeyValue { - return K8SJobNameKey.String(val) -} - -// K8SJobUID returns an attribute KeyValue conforming to the "k8s.job.uid" -// semantic conventions. It represents the UID of the Job. -func K8SJobUID(val string) attribute.KeyValue { - return K8SJobUIDKey.String(val) -} - -// K8SNamespaceName returns an attribute KeyValue conforming to the -// "k8s.namespace.name" semantic conventions. It represents the name of the -// namespace that the pod is running in. -func K8SNamespaceName(val string) attribute.KeyValue { - return K8SNamespaceNameKey.String(val) -} - -// K8SNodeName returns an attribute KeyValue conforming to the -// "k8s.node.name" semantic conventions. It represents the name of the Node. -func K8SNodeName(val string) attribute.KeyValue { - return K8SNodeNameKey.String(val) -} - -// K8SNodeUID returns an attribute KeyValue conforming to the "k8s.node.uid" -// semantic conventions. It represents the UID of the Node. -func K8SNodeUID(val string) attribute.KeyValue { - return K8SNodeUIDKey.String(val) -} - -// K8SPodName returns an attribute KeyValue conforming to the "k8s.pod.name" -// semantic conventions. It represents the name of the Pod. -func K8SPodName(val string) attribute.KeyValue { - return K8SPodNameKey.String(val) -} - -// K8SPodUID returns an attribute KeyValue conforming to the "k8s.pod.uid" -// semantic conventions. It represents the UID of the Pod. -func K8SPodUID(val string) attribute.KeyValue { - return K8SPodUIDKey.String(val) -} - -// K8SReplicaSetName returns an attribute KeyValue conforming to the -// "k8s.replicaset.name" semantic conventions. It represents the name of the -// ReplicaSet. -func K8SReplicaSetName(val string) attribute.KeyValue { - return K8SReplicaSetNameKey.String(val) -} - -// K8SReplicaSetUID returns an attribute KeyValue conforming to the -// "k8s.replicaset.uid" semantic conventions. It represents the UID of the -// ReplicaSet. -func K8SReplicaSetUID(val string) attribute.KeyValue { - return K8SReplicaSetUIDKey.String(val) -} - -// K8SStatefulSetName returns an attribute KeyValue conforming to the -// "k8s.statefulset.name" semantic conventions. It represents the name of the -// StatefulSet. -func K8SStatefulSetName(val string) attribute.KeyValue { - return K8SStatefulSetNameKey.String(val) -} - -// K8SStatefulSetUID returns an attribute KeyValue conforming to the -// "k8s.statefulset.uid" semantic conventions. It represents the UID of the -// StatefulSet. -func K8SStatefulSetUID(val string) attribute.KeyValue { - return K8SStatefulSetUIDKey.String(val) -} - -// An OCI image manifest. -const ( - // OciManifestDigestKey is the attribute Key conforming to the - // "oci.manifest.digest" semantic conventions. It represents the digest of - // the OCI image manifest. For container images specifically is the digest - // by which the container image is known. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // 'sha256:e4ca62c0d62f3e886e684806dfe9d4e0cda60d54986898173c1083856cfda0f4' - // Note: Follows [OCI Image Manifest - // Specification](https://github.com/opencontainers/image-spec/blob/main/manifest.md), - // and specifically the [Digest - // property](https://github.com/opencontainers/image-spec/blob/main/descriptor.md#digests). - // An example can be found in [Example Image - // Manifest](https://docs.docker.com/registry/spec/manifest-v2-2/#example-image-manifest). - OciManifestDigestKey = attribute.Key("oci.manifest.digest") -) - -// OciManifestDigest returns an attribute KeyValue conforming to the -// "oci.manifest.digest" semantic conventions. It represents the digest of the -// OCI image manifest. For container images specifically is the digest by which -// the container image is known. -func OciManifestDigest(val string) attribute.KeyValue { - return OciManifestDigestKey.String(val) -} - -// The operating system (OS) on which the process represented by this resource -// is running. -const ( - // OSBuildIDKey is the attribute Key conforming to the "os.build_id" - // semantic conventions. It represents the unique identifier for a - // particular build or compilation of the operating system. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'TQ3C.230805.001.B2', '20E247', '22621' - OSBuildIDKey = attribute.Key("os.build_id") - - // OSDescriptionKey is the attribute Key conforming to the "os.description" - // semantic conventions. It represents the human readable (not intended to - // be parsed) OS version information, like e.g. reported by `ver` or - // `lsb_release -a` commands. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Microsoft Windows [Version 10.0.18363.778]', 'Ubuntu 18.04.1 - // LTS' - OSDescriptionKey = attribute.Key("os.description") - - // OSNameKey is the attribute Key conforming to the "os.name" semantic - // conventions. It represents the human readable operating system name. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'iOS', 'Android', 'Ubuntu' - OSNameKey = attribute.Key("os.name") - - // OSTypeKey is the attribute Key conforming to the "os.type" semantic - // conventions. It represents the operating system type. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - OSTypeKey = attribute.Key("os.type") - - // OSVersionKey is the attribute Key conforming to the "os.version" - // semantic conventions. It represents the version string of the operating - // system as defined in [Version - // Attributes](/docs/resource/README.md#version-attributes). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '14.2.1', '18.04.1' - OSVersionKey = attribute.Key("os.version") -) - -var ( - // Microsoft Windows - OSTypeWindows = OSTypeKey.String("windows") - // Linux - OSTypeLinux = OSTypeKey.String("linux") - // Apple Darwin - OSTypeDarwin = OSTypeKey.String("darwin") - // FreeBSD - OSTypeFreeBSD = OSTypeKey.String("freebsd") - // NetBSD - OSTypeNetBSD = OSTypeKey.String("netbsd") - // OpenBSD - OSTypeOpenBSD = OSTypeKey.String("openbsd") - // DragonFly BSD - OSTypeDragonflyBSD = OSTypeKey.String("dragonflybsd") - // HP-UX (Hewlett Packard Unix) - OSTypeHPUX = OSTypeKey.String("hpux") - // AIX (Advanced Interactive eXecutive) - OSTypeAIX = OSTypeKey.String("aix") - // SunOS, Oracle Solaris - OSTypeSolaris = OSTypeKey.String("solaris") - // IBM z/OS - OSTypeZOS = OSTypeKey.String("z_os") -) - -// OSBuildID returns an attribute KeyValue conforming to the "os.build_id" -// semantic conventions. It represents the unique identifier for a particular -// build or compilation of the operating system. -func OSBuildID(val string) attribute.KeyValue { - return OSBuildIDKey.String(val) -} - -// OSDescription returns an attribute KeyValue conforming to the -// "os.description" semantic conventions. It represents the human readable (not -// intended to be parsed) OS version information, like e.g. reported by `ver` -// or `lsb_release -a` commands. -func OSDescription(val string) attribute.KeyValue { - return OSDescriptionKey.String(val) -} - -// OSName returns an attribute KeyValue conforming to the "os.name" semantic -// conventions. It represents the human readable operating system name. -func OSName(val string) attribute.KeyValue { - return OSNameKey.String(val) -} - -// OSVersion returns an attribute KeyValue conforming to the "os.version" -// semantic conventions. It represents the version string of the operating -// system as defined in [Version -// Attributes](/docs/resource/README.md#version-attributes). -func OSVersion(val string) attribute.KeyValue { - return OSVersionKey.String(val) -} - -// An operating system process. -const ( - // ProcessCommandKey is the attribute Key conforming to the - // "process.command" semantic conventions. It represents the command used - // to launch the process (i.e. the command name). On Linux based systems, - // can be set to the zeroth string in `proc/[pid]/cmdline`. On Windows, can - // be set to the first parameter extracted from `GetCommandLineW`. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'cmd/otelcol' - ProcessCommandKey = attribute.Key("process.command") - - // ProcessCommandArgsKey is the attribute Key conforming to the - // "process.command_args" semantic conventions. It represents the all the - // command arguments (including the command/executable itself) as received - // by the process. On Linux-based systems (and some other Unixoid systems - // supporting procfs), can be set according to the list of null-delimited - // strings extracted from `proc/[pid]/cmdline`. For libc-based executables, - // this would be the full argv vector passed to `main`. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'cmd/otecol', '--config=config.yaml' - ProcessCommandArgsKey = attribute.Key("process.command_args") - - // ProcessCommandLineKey is the attribute Key conforming to the - // "process.command_line" semantic conventions. It represents the full - // command used to launch the process as a single string representing the - // full command. On Windows, can be set to the result of `GetCommandLineW`. - // Do not set this if you have to assemble it just for monitoring; use - // `process.command_args` instead. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'C:\\cmd\\otecol --config="my directory\\config.yaml"' - ProcessCommandLineKey = attribute.Key("process.command_line") - - // ProcessExecutableNameKey is the attribute Key conforming to the - // "process.executable.name" semantic conventions. It represents the name - // of the process executable. On Linux based systems, can be set to the - // `Name` in `proc/[pid]/status`. On Windows, can be set to the base name - // of `GetProcessImageFileNameW`. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'otelcol' - ProcessExecutableNameKey = attribute.Key("process.executable.name") - - // ProcessExecutablePathKey is the attribute Key conforming to the - // "process.executable.path" semantic conventions. It represents the full - // path to the process executable. On Linux based systems, can be set to - // the target of `proc/[pid]/exe`. On Windows, can be set to the result of - // `GetProcessImageFileNameW`. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '/usr/bin/cmd/otelcol' - ProcessExecutablePathKey = attribute.Key("process.executable.path") - - // ProcessOwnerKey is the attribute Key conforming to the "process.owner" - // semantic conventions. It represents the username of the user that owns - // the process. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'root' - ProcessOwnerKey = attribute.Key("process.owner") - - // ProcessParentPIDKey is the attribute Key conforming to the - // "process.parent_pid" semantic conventions. It represents the parent - // Process identifier (PPID). - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 111 - ProcessParentPIDKey = attribute.Key("process.parent_pid") - - // ProcessPIDKey is the attribute Key conforming to the "process.pid" - // semantic conventions. It represents the process identifier (PID). - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 1234 - ProcessPIDKey = attribute.Key("process.pid") - - // ProcessRuntimeDescriptionKey is the attribute Key conforming to the - // "process.runtime.description" semantic conventions. It represents an - // additional description about the runtime of the process, for example a - // specific vendor customization of the runtime environment. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Eclipse OpenJ9 Eclipse OpenJ9 VM openj9-0.21.0' - ProcessRuntimeDescriptionKey = attribute.Key("process.runtime.description") - - // ProcessRuntimeNameKey is the attribute Key conforming to the - // "process.runtime.name" semantic conventions. It represents the name of - // the runtime of this process. For compiled native binaries, this SHOULD - // be the name of the compiler. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'OpenJDK Runtime Environment' - ProcessRuntimeNameKey = attribute.Key("process.runtime.name") - - // ProcessRuntimeVersionKey is the attribute Key conforming to the - // "process.runtime.version" semantic conventions. It represents the - // version of the runtime of this process, as returned by the runtime - // without modification. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '14.0.2' - ProcessRuntimeVersionKey = attribute.Key("process.runtime.version") -) - -// ProcessCommand returns an attribute KeyValue conforming to the -// "process.command" semantic conventions. It represents the command used to -// launch the process (i.e. the command name). On Linux based systems, can be -// set to the zeroth string in `proc/[pid]/cmdline`. On Windows, can be set to -// the first parameter extracted from `GetCommandLineW`. -func ProcessCommand(val string) attribute.KeyValue { - return ProcessCommandKey.String(val) -} - -// ProcessCommandArgs returns an attribute KeyValue conforming to the -// "process.command_args" semantic conventions. It represents the all the -// command arguments (including the command/executable itself) as received by -// the process. On Linux-based systems (and some other Unixoid systems -// supporting procfs), can be set according to the list of null-delimited -// strings extracted from `proc/[pid]/cmdline`. For libc-based executables, -// this would be the full argv vector passed to `main`. -func ProcessCommandArgs(val ...string) attribute.KeyValue { - return ProcessCommandArgsKey.StringSlice(val) -} - -// ProcessCommandLine returns an attribute KeyValue conforming to the -// "process.command_line" semantic conventions. It represents the full command -// used to launch the process as a single string representing the full command. -// On Windows, can be set to the result of `GetCommandLineW`. Do not set this -// if you have to assemble it just for monitoring; use `process.command_args` -// instead. -func ProcessCommandLine(val string) attribute.KeyValue { - return ProcessCommandLineKey.String(val) -} - -// ProcessExecutableName returns an attribute KeyValue conforming to the -// "process.executable.name" semantic conventions. It represents the name of -// the process executable. On Linux based systems, can be set to the `Name` in -// `proc/[pid]/status`. On Windows, can be set to the base name of -// `GetProcessImageFileNameW`. -func ProcessExecutableName(val string) attribute.KeyValue { - return ProcessExecutableNameKey.String(val) -} - -// ProcessExecutablePath returns an attribute KeyValue conforming to the -// "process.executable.path" semantic conventions. It represents the full path -// to the process executable. On Linux based systems, can be set to the target -// of `proc/[pid]/exe`. On Windows, can be set to the result of -// `GetProcessImageFileNameW`. -func ProcessExecutablePath(val string) attribute.KeyValue { - return ProcessExecutablePathKey.String(val) -} - -// ProcessOwner returns an attribute KeyValue conforming to the -// "process.owner" semantic conventions. It represents the username of the user -// that owns the process. -func ProcessOwner(val string) attribute.KeyValue { - return ProcessOwnerKey.String(val) -} - -// ProcessParentPID returns an attribute KeyValue conforming to the -// "process.parent_pid" semantic conventions. It represents the parent Process -// identifier (PPID). -func ProcessParentPID(val int) attribute.KeyValue { - return ProcessParentPIDKey.Int(val) -} - -// ProcessPID returns an attribute KeyValue conforming to the "process.pid" -// semantic conventions. It represents the process identifier (PID). -func ProcessPID(val int) attribute.KeyValue { - return ProcessPIDKey.Int(val) -} - -// ProcessRuntimeDescription returns an attribute KeyValue conforming to the -// "process.runtime.description" semantic conventions. It represents an -// additional description about the runtime of the process, for example a -// specific vendor customization of the runtime environment. -func ProcessRuntimeDescription(val string) attribute.KeyValue { - return ProcessRuntimeDescriptionKey.String(val) -} - -// ProcessRuntimeName returns an attribute KeyValue conforming to the -// "process.runtime.name" semantic conventions. It represents the name of the -// runtime of this process. For compiled native binaries, this SHOULD be the -// name of the compiler. -func ProcessRuntimeName(val string) attribute.KeyValue { - return ProcessRuntimeNameKey.String(val) -} - -// ProcessRuntimeVersion returns an attribute KeyValue conforming to the -// "process.runtime.version" semantic conventions. It represents the version of -// the runtime of this process, as returned by the runtime without -// modification. -func ProcessRuntimeVersion(val string) attribute.KeyValue { - return ProcessRuntimeVersionKey.String(val) -} - -// The Android platform on which the Android application is running. -const ( - // AndroidOSAPILevelKey is the attribute Key conforming to the - // "android.os.api_level" semantic conventions. It represents the uniquely - // identifies the framework API revision offered by a version - // (`os.version`) of the android operating system. More information can be - // found - // [here](https://developer.android.com/guide/topics/manifest/uses-sdk-element#APILevels). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '33', '32' - AndroidOSAPILevelKey = attribute.Key("android.os.api_level") -) - -// AndroidOSAPILevel returns an attribute KeyValue conforming to the -// "android.os.api_level" semantic conventions. It represents the uniquely -// identifies the framework API revision offered by a version (`os.version`) of -// the android operating system. More information can be found -// [here](https://developer.android.com/guide/topics/manifest/uses-sdk-element#APILevels). -func AndroidOSAPILevel(val string) attribute.KeyValue { - return AndroidOSAPILevelKey.String(val) -} - -// The web browser in which the application represented by the resource is -// running. The `browser.*` attributes MUST be used only for resources that -// represent applications running in a web browser (regardless of whether -// running on a mobile or desktop device). -const ( - // BrowserBrandsKey is the attribute Key conforming to the "browser.brands" - // semantic conventions. It represents the array of brand name and version - // separated by a space - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: ' Not A;Brand 99', 'Chromium 99', 'Chrome 99' - // Note: This value is intended to be taken from the [UA client hints - // API](https://wicg.github.io/ua-client-hints/#interface) - // (`navigator.userAgentData.brands`). - BrowserBrandsKey = attribute.Key("browser.brands") - - // BrowserLanguageKey is the attribute Key conforming to the - // "browser.language" semantic conventions. It represents the preferred - // language of the user using the browser - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'en', 'en-US', 'fr', 'fr-FR' - // Note: This value is intended to be taken from the Navigator API - // `navigator.language`. - BrowserLanguageKey = attribute.Key("browser.language") - - // BrowserMobileKey is the attribute Key conforming to the "browser.mobile" - // semantic conventions. It represents a boolean that is true if the - // browser is running on a mobile device - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - // Note: This value is intended to be taken from the [UA client hints - // API](https://wicg.github.io/ua-client-hints/#interface) - // (`navigator.userAgentData.mobile`). If unavailable, this attribute - // SHOULD be left unset. - BrowserMobileKey = attribute.Key("browser.mobile") - - // BrowserPlatformKey is the attribute Key conforming to the - // "browser.platform" semantic conventions. It represents the platform on - // which the browser is running - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Windows', 'macOS', 'Android' - // Note: This value is intended to be taken from the [UA client hints - // API](https://wicg.github.io/ua-client-hints/#interface) - // (`navigator.userAgentData.platform`). If unavailable, the legacy - // `navigator.platform` API SHOULD NOT be used instead and this attribute - // SHOULD be left unset in order for the values to be consistent. - // The list of possible values is defined in the [W3C User-Agent Client - // Hints - // specification](https://wicg.github.io/ua-client-hints/#sec-ch-ua-platform). - // Note that some (but not all) of these values can overlap with values in - // the [`os.type` and `os.name` attributes](./os.md). However, for - // consistency, the values in the `browser.platform` attribute should - // capture the exact value that the user agent provides. - BrowserPlatformKey = attribute.Key("browser.platform") -) - -// BrowserBrands returns an attribute KeyValue conforming to the -// "browser.brands" semantic conventions. It represents the array of brand name -// and version separated by a space -func BrowserBrands(val ...string) attribute.KeyValue { - return BrowserBrandsKey.StringSlice(val) -} - -// BrowserLanguage returns an attribute KeyValue conforming to the -// "browser.language" semantic conventions. It represents the preferred -// language of the user using the browser -func BrowserLanguage(val string) attribute.KeyValue { - return BrowserLanguageKey.String(val) -} - -// BrowserMobile returns an attribute KeyValue conforming to the -// "browser.mobile" semantic conventions. It represents a boolean that is true -// if the browser is running on a mobile device -func BrowserMobile(val bool) attribute.KeyValue { - return BrowserMobileKey.Bool(val) -} - -// BrowserPlatform returns an attribute KeyValue conforming to the -// "browser.platform" semantic conventions. It represents the platform on which -// the browser is running -func BrowserPlatform(val string) attribute.KeyValue { - return BrowserPlatformKey.String(val) -} - -// Resources used by AWS Elastic Container Service (ECS). -const ( - // AWSECSClusterARNKey is the attribute Key conforming to the - // "aws.ecs.cluster.arn" semantic conventions. It represents the ARN of an - // [ECS - // cluster](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/clusters.html). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'arn:aws:ecs:us-west-2:123456789123:cluster/my-cluster' - AWSECSClusterARNKey = attribute.Key("aws.ecs.cluster.arn") - - // AWSECSContainerARNKey is the attribute Key conforming to the - // "aws.ecs.container.arn" semantic conventions. It represents the Amazon - // Resource Name (ARN) of an [ECS container - // instance](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/ECS_instances.html). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // 'arn:aws:ecs:us-west-1:123456789123:container/32624152-9086-4f0e-acae-1a75b14fe4d9' - AWSECSContainerARNKey = attribute.Key("aws.ecs.container.arn") - - // AWSECSLaunchtypeKey is the attribute Key conforming to the - // "aws.ecs.launchtype" semantic conventions. It represents the [launch - // type](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/launch_types.html) - // for an ECS task. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - AWSECSLaunchtypeKey = attribute.Key("aws.ecs.launchtype") - - // AWSECSTaskARNKey is the attribute Key conforming to the - // "aws.ecs.task.arn" semantic conventions. It represents the ARN of an - // [ECS task - // definition](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task_definitions.html). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // 'arn:aws:ecs:us-west-1:123456789123:task/10838bed-421f-43ef-870a-f43feacbbb5b' - AWSECSTaskARNKey = attribute.Key("aws.ecs.task.arn") - - // AWSECSTaskFamilyKey is the attribute Key conforming to the - // "aws.ecs.task.family" semantic conventions. It represents the task - // definition family this task definition is a member of. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'opentelemetry-family' - AWSECSTaskFamilyKey = attribute.Key("aws.ecs.task.family") - - // AWSECSTaskRevisionKey is the attribute Key conforming to the - // "aws.ecs.task.revision" semantic conventions. It represents the revision - // for this task definition. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '8', '26' - AWSECSTaskRevisionKey = attribute.Key("aws.ecs.task.revision") -) - -var ( - // ec2 - AWSECSLaunchtypeEC2 = AWSECSLaunchtypeKey.String("ec2") - // fargate - AWSECSLaunchtypeFargate = AWSECSLaunchtypeKey.String("fargate") -) - -// AWSECSClusterARN returns an attribute KeyValue conforming to the -// "aws.ecs.cluster.arn" semantic conventions. It represents the ARN of an [ECS -// cluster](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/clusters.html). -func AWSECSClusterARN(val string) attribute.KeyValue { - return AWSECSClusterARNKey.String(val) -} - -// AWSECSContainerARN returns an attribute KeyValue conforming to the -// "aws.ecs.container.arn" semantic conventions. It represents the Amazon -// Resource Name (ARN) of an [ECS container -// instance](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/ECS_instances.html). -func AWSECSContainerARN(val string) attribute.KeyValue { - return AWSECSContainerARNKey.String(val) -} - -// AWSECSTaskARN returns an attribute KeyValue conforming to the -// "aws.ecs.task.arn" semantic conventions. It represents the ARN of an [ECS -// task -// definition](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task_definitions.html). -func AWSECSTaskARN(val string) attribute.KeyValue { - return AWSECSTaskARNKey.String(val) -} - -// AWSECSTaskFamily returns an attribute KeyValue conforming to the -// "aws.ecs.task.family" semantic conventions. It represents the task -// definition family this task definition is a member of. -func AWSECSTaskFamily(val string) attribute.KeyValue { - return AWSECSTaskFamilyKey.String(val) -} - -// AWSECSTaskRevision returns an attribute KeyValue conforming to the -// "aws.ecs.task.revision" semantic conventions. It represents the revision for -// this task definition. -func AWSECSTaskRevision(val string) attribute.KeyValue { - return AWSECSTaskRevisionKey.String(val) -} - -// Resources used by AWS Elastic Kubernetes Service (EKS). -const ( - // AWSEKSClusterARNKey is the attribute Key conforming to the - // "aws.eks.cluster.arn" semantic conventions. It represents the ARN of an - // EKS cluster. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'arn:aws:ecs:us-west-2:123456789123:cluster/my-cluster' - AWSEKSClusterARNKey = attribute.Key("aws.eks.cluster.arn") -) - -// AWSEKSClusterARN returns an attribute KeyValue conforming to the -// "aws.eks.cluster.arn" semantic conventions. It represents the ARN of an EKS -// cluster. -func AWSEKSClusterARN(val string) attribute.KeyValue { - return AWSEKSClusterARNKey.String(val) -} - -// Resources specific to Amazon Web Services. -const ( - // AWSLogGroupARNsKey is the attribute Key conforming to the - // "aws.log.group.arns" semantic conventions. It represents the Amazon - // Resource Name(s) (ARN) of the AWS log group(s). - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // 'arn:aws:logs:us-west-1:123456789012:log-group:/aws/my/group:*' - // Note: See the [log group ARN format - // documentation](https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/iam-access-control-overview-cwl.html#CWL_ARN_Format). - AWSLogGroupARNsKey = attribute.Key("aws.log.group.arns") - - // AWSLogGroupNamesKey is the attribute Key conforming to the - // "aws.log.group.names" semantic conventions. It represents the name(s) of - // the AWS log group(s) an application is writing to. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: '/aws/lambda/my-function', 'opentelemetry-service' - // Note: Multiple log groups must be supported for cases like - // multi-container applications, where a single application has sidecar - // containers, and each write to their own log group. - AWSLogGroupNamesKey = attribute.Key("aws.log.group.names") - - // AWSLogStreamARNsKey is the attribute Key conforming to the - // "aws.log.stream.arns" semantic conventions. It represents the ARN(s) of - // the AWS log stream(s). - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // 'arn:aws:logs:us-west-1:123456789012:log-group:/aws/my/group:log-stream:logs/main/10838bed-421f-43ef-870a-f43feacbbb5b' - // Note: See the [log stream ARN format - // documentation](https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/iam-access-control-overview-cwl.html#CWL_ARN_Format). - // One log group can contain several log streams, so these ARNs necessarily - // identify both a log group and a log stream. - AWSLogStreamARNsKey = attribute.Key("aws.log.stream.arns") - - // AWSLogStreamNamesKey is the attribute Key conforming to the - // "aws.log.stream.names" semantic conventions. It represents the name(s) - // of the AWS log stream(s) an application is writing to. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'logs/main/10838bed-421f-43ef-870a-f43feacbbb5b' - AWSLogStreamNamesKey = attribute.Key("aws.log.stream.names") -) - -// AWSLogGroupARNs returns an attribute KeyValue conforming to the -// "aws.log.group.arns" semantic conventions. It represents the Amazon Resource -// Name(s) (ARN) of the AWS log group(s). -func AWSLogGroupARNs(val ...string) attribute.KeyValue { - return AWSLogGroupARNsKey.StringSlice(val) -} - -// AWSLogGroupNames returns an attribute KeyValue conforming to the -// "aws.log.group.names" semantic conventions. It represents the name(s) of the -// AWS log group(s) an application is writing to. -func AWSLogGroupNames(val ...string) attribute.KeyValue { - return AWSLogGroupNamesKey.StringSlice(val) -} - -// AWSLogStreamARNs returns an attribute KeyValue conforming to the -// "aws.log.stream.arns" semantic conventions. It represents the ARN(s) of the -// AWS log stream(s). -func AWSLogStreamARNs(val ...string) attribute.KeyValue { - return AWSLogStreamARNsKey.StringSlice(val) -} - -// AWSLogStreamNames returns an attribute KeyValue conforming to the -// "aws.log.stream.names" semantic conventions. It represents the name(s) of -// the AWS log stream(s) an application is writing to. -func AWSLogStreamNames(val ...string) attribute.KeyValue { - return AWSLogStreamNamesKey.StringSlice(val) -} - -// Resource used by Google Cloud Run. -const ( - // GCPCloudRunJobExecutionKey is the attribute Key conforming to the - // "gcp.cloud_run.job.execution" semantic conventions. It represents the - // name of the Cloud Run - // [execution](https://cloud.google.com/run/docs/managing/job-executions) - // being run for the Job, as set by the - // [`CLOUD_RUN_EXECUTION`](https://cloud.google.com/run/docs/container-contract#jobs-env-vars) - // environment variable. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'job-name-xxxx', 'sample-job-mdw84' - GCPCloudRunJobExecutionKey = attribute.Key("gcp.cloud_run.job.execution") - - // GCPCloudRunJobTaskIndexKey is the attribute Key conforming to the - // "gcp.cloud_run.job.task_index" semantic conventions. It represents the - // index for a task within an execution as provided by the - // [`CLOUD_RUN_TASK_INDEX`](https://cloud.google.com/run/docs/container-contract#jobs-env-vars) - // environment variable. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 0, 1 - GCPCloudRunJobTaskIndexKey = attribute.Key("gcp.cloud_run.job.task_index") -) - -// GCPCloudRunJobExecution returns an attribute KeyValue conforming to the -// "gcp.cloud_run.job.execution" semantic conventions. It represents the name -// of the Cloud Run -// [execution](https://cloud.google.com/run/docs/managing/job-executions) being -// run for the Job, as set by the -// [`CLOUD_RUN_EXECUTION`](https://cloud.google.com/run/docs/container-contract#jobs-env-vars) -// environment variable. -func GCPCloudRunJobExecution(val string) attribute.KeyValue { - return GCPCloudRunJobExecutionKey.String(val) -} - -// GCPCloudRunJobTaskIndex returns an attribute KeyValue conforming to the -// "gcp.cloud_run.job.task_index" semantic conventions. It represents the index -// for a task within an execution as provided by the -// [`CLOUD_RUN_TASK_INDEX`](https://cloud.google.com/run/docs/container-contract#jobs-env-vars) -// environment variable. -func GCPCloudRunJobTaskIndex(val int) attribute.KeyValue { - return GCPCloudRunJobTaskIndexKey.Int(val) -} - -// Resources used by Google Compute Engine (GCE). -const ( - // GCPGceInstanceHostnameKey is the attribute Key conforming to the - // "gcp.gce.instance.hostname" semantic conventions. It represents the - // hostname of a GCE instance. This is the full value of the default or - // [custom - // hostname](https://cloud.google.com/compute/docs/instances/custom-hostname-vm). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'my-host1234.example.com', - // 'sample-vm.us-west1-b.c.my-project.internal' - GCPGceInstanceHostnameKey = attribute.Key("gcp.gce.instance.hostname") - - // GCPGceInstanceNameKey is the attribute Key conforming to the - // "gcp.gce.instance.name" semantic conventions. It represents the instance - // name of a GCE instance. This is the value provided by `host.name`, the - // visible name of the instance in the Cloud Console UI, and the prefix for - // the default hostname of the instance as defined by the [default internal - // DNS - // name](https://cloud.google.com/compute/docs/internal-dns#instance-fully-qualified-domain-names). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'instance-1', 'my-vm-name' - GCPGceInstanceNameKey = attribute.Key("gcp.gce.instance.name") -) - -// GCPGceInstanceHostname returns an attribute KeyValue conforming to the -// "gcp.gce.instance.hostname" semantic conventions. It represents the hostname -// of a GCE instance. This is the full value of the default or [custom -// hostname](https://cloud.google.com/compute/docs/instances/custom-hostname-vm). -func GCPGceInstanceHostname(val string) attribute.KeyValue { - return GCPGceInstanceHostnameKey.String(val) -} - -// GCPGceInstanceName returns an attribute KeyValue conforming to the -// "gcp.gce.instance.name" semantic conventions. It represents the instance -// name of a GCE instance. This is the value provided by `host.name`, the -// visible name of the instance in the Cloud Console UI, and the prefix for the -// default hostname of the instance as defined by the [default internal DNS -// name](https://cloud.google.com/compute/docs/internal-dns#instance-fully-qualified-domain-names). -func GCPGceInstanceName(val string) attribute.KeyValue { - return GCPGceInstanceNameKey.String(val) -} - -// Heroku dyno metadata -const ( - // HerokuAppIDKey is the attribute Key conforming to the "heroku.app.id" - // semantic conventions. It represents the unique identifier for the - // application - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2daa2797-e42b-4624-9322-ec3f968df4da' - HerokuAppIDKey = attribute.Key("heroku.app.id") - - // HerokuReleaseCommitKey is the attribute Key conforming to the - // "heroku.release.commit" semantic conventions. It represents the commit - // hash for the current release - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'e6134959463efd8966b20e75b913cafe3f5ec' - HerokuReleaseCommitKey = attribute.Key("heroku.release.commit") - - // HerokuReleaseCreationTimestampKey is the attribute Key conforming to the - // "heroku.release.creation_timestamp" semantic conventions. It represents - // the time and date the release was created - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2022-10-23T18:00:42Z' - HerokuReleaseCreationTimestampKey = attribute.Key("heroku.release.creation_timestamp") -) - -// HerokuAppID returns an attribute KeyValue conforming to the -// "heroku.app.id" semantic conventions. It represents the unique identifier -// for the application -func HerokuAppID(val string) attribute.KeyValue { - return HerokuAppIDKey.String(val) -} - -// HerokuReleaseCommit returns an attribute KeyValue conforming to the -// "heroku.release.commit" semantic conventions. It represents the commit hash -// for the current release -func HerokuReleaseCommit(val string) attribute.KeyValue { - return HerokuReleaseCommitKey.String(val) -} - -// HerokuReleaseCreationTimestamp returns an attribute KeyValue conforming -// to the "heroku.release.creation_timestamp" semantic conventions. It -// represents the time and date the release was created -func HerokuReleaseCreationTimestamp(val string) attribute.KeyValue { - return HerokuReleaseCreationTimestampKey.String(val) -} - -// The software deployment. -const ( - // DeploymentEnvironmentKey is the attribute Key conforming to the - // "deployment.environment" semantic conventions. It represents the name of - // the [deployment - // environment](https://wikipedia.org/wiki/Deployment_environment) (aka - // deployment tier). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'staging', 'production' - // Note: `deployment.environment` does not affect the uniqueness - // constraints defined through - // the `service.namespace`, `service.name` and `service.instance.id` - // resource attributes. - // This implies that resources carrying the following attribute - // combinations MUST be - // considered to be identifying the same service: - // - // * `service.name=frontend`, `deployment.environment=production` - // * `service.name=frontend`, `deployment.environment=staging`. - DeploymentEnvironmentKey = attribute.Key("deployment.environment") -) - -// DeploymentEnvironment returns an attribute KeyValue conforming to the -// "deployment.environment" semantic conventions. It represents the name of the -// [deployment environment](https://wikipedia.org/wiki/Deployment_environment) -// (aka deployment tier). -func DeploymentEnvironment(val string) attribute.KeyValue { - return DeploymentEnvironmentKey.String(val) -} - -// A serverless instance. -const ( - // FaaSInstanceKey is the attribute Key conforming to the "faas.instance" - // semantic conventions. It represents the execution environment ID as a - // string, that will be potentially reused for other invocations to the - // same function/function version. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2021/06/28/[$LATEST]2f399eb14537447da05ab2a2e39309de' - // Note: * **AWS Lambda:** Use the (full) log stream name. - FaaSInstanceKey = attribute.Key("faas.instance") - - // FaaSMaxMemoryKey is the attribute Key conforming to the - // "faas.max_memory" semantic conventions. It represents the amount of - // memory available to the serverless function converted to Bytes. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 134217728 - // Note: It's recommended to set this attribute since e.g. too little - // memory can easily stop a Java AWS Lambda function from working - // correctly. On AWS Lambda, the environment variable - // `AWS_LAMBDA_FUNCTION_MEMORY_SIZE` provides this information (which must - // be multiplied by 1,048,576). - FaaSMaxMemoryKey = attribute.Key("faas.max_memory") - - // FaaSNameKey is the attribute Key conforming to the "faas.name" semantic - // conventions. It represents the name of the single function that this - // runtime instance executes. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'my-function', 'myazurefunctionapp/some-function-name' - // Note: This is the name of the function as configured/deployed on the - // FaaS - // platform and is usually different from the name of the callback - // function (which may be stored in the - // [`code.namespace`/`code.function`](/docs/general/attributes.md#source-code-attributes) - // span attributes). - // - // For some cloud providers, the above definition is ambiguous. The - // following - // definition of function name MUST be used for this attribute - // (and consequently the span name) for the listed cloud - // providers/products: - // - // * **Azure:** The full name `/`, i.e., function app name - // followed by a forward slash followed by the function name (this form - // can also be seen in the resource JSON for the function). - // This means that a span attribute MUST be used, as an Azure function - // app can host multiple functions that would usually share - // a TracerProvider (see also the `cloud.resource_id` attribute). - FaaSNameKey = attribute.Key("faas.name") - - // FaaSVersionKey is the attribute Key conforming to the "faas.version" - // semantic conventions. It represents the immutable version of the - // function being executed. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '26', 'pinkfroid-00002' - // Note: Depending on the cloud provider and platform, use: - // - // * **AWS Lambda:** The [function - // version](https://docs.aws.amazon.com/lambda/latest/dg/configuration-versions.html) - // (an integer represented as a decimal string). - // * **Google Cloud Run (Services):** The - // [revision](https://cloud.google.com/run/docs/managing/revisions) - // (i.e., the function name plus the revision suffix). - // * **Google Cloud Functions:** The value of the - // [`K_REVISION` environment - // variable](https://cloud.google.com/functions/docs/env-var#runtime_environment_variables_set_automatically). - // * **Azure Functions:** Not applicable. Do not set this attribute. - FaaSVersionKey = attribute.Key("faas.version") -) - -// FaaSInstance returns an attribute KeyValue conforming to the -// "faas.instance" semantic conventions. It represents the execution -// environment ID as a string, that will be potentially reused for other -// invocations to the same function/function version. -func FaaSInstance(val string) attribute.KeyValue { - return FaaSInstanceKey.String(val) -} - -// FaaSMaxMemory returns an attribute KeyValue conforming to the -// "faas.max_memory" semantic conventions. It represents the amount of memory -// available to the serverless function converted to Bytes. -func FaaSMaxMemory(val int) attribute.KeyValue { - return FaaSMaxMemoryKey.Int(val) -} - -// FaaSName returns an attribute KeyValue conforming to the "faas.name" -// semantic conventions. It represents the name of the single function that -// this runtime instance executes. -func FaaSName(val string) attribute.KeyValue { - return FaaSNameKey.String(val) -} - -// FaaSVersion returns an attribute KeyValue conforming to the -// "faas.version" semantic conventions. It represents the immutable version of -// the function being executed. -func FaaSVersion(val string) attribute.KeyValue { - return FaaSVersionKey.String(val) -} - -// A service instance. -const ( - // ServiceNameKey is the attribute Key conforming to the "service.name" - // semantic conventions. It represents the logical name of the service. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'shoppingcart' - // Note: MUST be the same for all instances of horizontally scaled - // services. If the value was not specified, SDKs MUST fallback to - // `unknown_service:` concatenated with - // [`process.executable.name`](process.md#process), e.g. - // `unknown_service:bash`. If `process.executable.name` is not available, - // the value MUST be set to `unknown_service`. - ServiceNameKey = attribute.Key("service.name") - - // ServiceVersionKey is the attribute Key conforming to the - // "service.version" semantic conventions. It represents the version string - // of the service API or implementation. The format is not defined by these - // conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2.0.0', 'a01dbef8a' - ServiceVersionKey = attribute.Key("service.version") -) - -// ServiceName returns an attribute KeyValue conforming to the -// "service.name" semantic conventions. It represents the logical name of the -// service. -func ServiceName(val string) attribute.KeyValue { - return ServiceNameKey.String(val) -} - -// ServiceVersion returns an attribute KeyValue conforming to the -// "service.version" semantic conventions. It represents the version string of -// the service API or implementation. The format is not defined by these -// conventions. -func ServiceVersion(val string) attribute.KeyValue { - return ServiceVersionKey.String(val) -} - -// A service instance. -const ( - // ServiceInstanceIDKey is the attribute Key conforming to the - // "service.instance.id" semantic conventions. It represents the string ID - // of the service instance. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'my-k8s-pod-deployment-1', - // '627cc493-f310-47de-96bd-71410b7dec09' - // Note: MUST be unique for each instance of the same - // `service.namespace,service.name` pair (in other words - // `service.namespace,service.name,service.instance.id` triplet MUST be - // globally unique). The ID helps to distinguish instances of the same - // service that exist at the same time (e.g. instances of a horizontally - // scaled service). It is preferable for the ID to be persistent and stay - // the same for the lifetime of the service instance, however it is - // acceptable that the ID is ephemeral and changes during important - // lifetime events for the service (e.g. service restarts). If the service - // has no inherent unique ID that can be used as the value of this - // attribute it is recommended to generate a random Version 1 or Version 4 - // RFC 4122 UUID (services aiming for reproducible UUIDs may also use - // Version 5, see RFC 4122 for more recommendations). - ServiceInstanceIDKey = attribute.Key("service.instance.id") - - // ServiceNamespaceKey is the attribute Key conforming to the - // "service.namespace" semantic conventions. It represents a namespace for - // `service.name`. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Shop' - // Note: A string value having a meaning that helps to distinguish a group - // of services, for example the team name that owns a group of services. - // `service.name` is expected to be unique within the same namespace. If - // `service.namespace` is not specified in the Resource then `service.name` - // is expected to be unique for all services that have no explicit - // namespace defined (so the empty/unspecified namespace is simply one more - // valid namespace). Zero-length namespace string is assumed equal to - // unspecified namespace. - ServiceNamespaceKey = attribute.Key("service.namespace") -) - -// ServiceInstanceID returns an attribute KeyValue conforming to the -// "service.instance.id" semantic conventions. It represents the string ID of -// the service instance. -func ServiceInstanceID(val string) attribute.KeyValue { - return ServiceInstanceIDKey.String(val) -} - -// ServiceNamespace returns an attribute KeyValue conforming to the -// "service.namespace" semantic conventions. It represents a namespace for -// `service.name`. -func ServiceNamespace(val string) attribute.KeyValue { - return ServiceNamespaceKey.String(val) -} - -// The telemetry SDK used to capture data recorded by the instrumentation -// libraries. -const ( - // TelemetrySDKLanguageKey is the attribute Key conforming to the - // "telemetry.sdk.language" semantic conventions. It represents the - // language of the telemetry SDK. - // - // Type: Enum - // RequirementLevel: Required - // Stability: experimental - TelemetrySDKLanguageKey = attribute.Key("telemetry.sdk.language") - - // TelemetrySDKNameKey is the attribute Key conforming to the - // "telemetry.sdk.name" semantic conventions. It represents the name of the - // telemetry SDK as defined above. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'opentelemetry' - // Note: The OpenTelemetry SDK MUST set the `telemetry.sdk.name` attribute - // to `opentelemetry`. - // If another SDK, like a fork or a vendor-provided implementation, is - // used, this SDK MUST set the - // `telemetry.sdk.name` attribute to the fully-qualified class or module - // name of this SDK's main entry point - // or another suitable identifier depending on the language. - // The identifier `opentelemetry` is reserved and MUST NOT be used in this - // case. - // All custom identifiers SHOULD be stable across different versions of an - // implementation. - TelemetrySDKNameKey = attribute.Key("telemetry.sdk.name") - - // TelemetrySDKVersionKey is the attribute Key conforming to the - // "telemetry.sdk.version" semantic conventions. It represents the version - // string of the telemetry SDK. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: '1.2.3' - TelemetrySDKVersionKey = attribute.Key("telemetry.sdk.version") -) - -var ( - // cpp - TelemetrySDKLanguageCPP = TelemetrySDKLanguageKey.String("cpp") - // dotnet - TelemetrySDKLanguageDotnet = TelemetrySDKLanguageKey.String("dotnet") - // erlang - TelemetrySDKLanguageErlang = TelemetrySDKLanguageKey.String("erlang") - // go - TelemetrySDKLanguageGo = TelemetrySDKLanguageKey.String("go") - // java - TelemetrySDKLanguageJava = TelemetrySDKLanguageKey.String("java") - // nodejs - TelemetrySDKLanguageNodejs = TelemetrySDKLanguageKey.String("nodejs") - // php - TelemetrySDKLanguagePHP = TelemetrySDKLanguageKey.String("php") - // python - TelemetrySDKLanguagePython = TelemetrySDKLanguageKey.String("python") - // ruby - TelemetrySDKLanguageRuby = TelemetrySDKLanguageKey.String("ruby") - // rust - TelemetrySDKLanguageRust = TelemetrySDKLanguageKey.String("rust") - // swift - TelemetrySDKLanguageSwift = TelemetrySDKLanguageKey.String("swift") - // webjs - TelemetrySDKLanguageWebjs = TelemetrySDKLanguageKey.String("webjs") -) - -// TelemetrySDKName returns an attribute KeyValue conforming to the -// "telemetry.sdk.name" semantic conventions. It represents the name of the -// telemetry SDK as defined above. -func TelemetrySDKName(val string) attribute.KeyValue { - return TelemetrySDKNameKey.String(val) -} - -// TelemetrySDKVersion returns an attribute KeyValue conforming to the -// "telemetry.sdk.version" semantic conventions. It represents the version -// string of the telemetry SDK. -func TelemetrySDKVersion(val string) attribute.KeyValue { - return TelemetrySDKVersionKey.String(val) -} - -// The telemetry SDK used to capture data recorded by the instrumentation -// libraries. -const ( - // TelemetryDistroNameKey is the attribute Key conforming to the - // "telemetry.distro.name" semantic conventions. It represents the name of - // the auto instrumentation agent or distribution, if used. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'parts-unlimited-java' - // Note: Official auto instrumentation agents and distributions SHOULD set - // the `telemetry.distro.name` attribute to - // a string starting with `opentelemetry-`, e.g. - // `opentelemetry-java-instrumentation`. - TelemetryDistroNameKey = attribute.Key("telemetry.distro.name") - - // TelemetryDistroVersionKey is the attribute Key conforming to the - // "telemetry.distro.version" semantic conventions. It represents the - // version string of the auto instrumentation agent or distribution, if - // used. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '1.2.3' - TelemetryDistroVersionKey = attribute.Key("telemetry.distro.version") -) - -// TelemetryDistroName returns an attribute KeyValue conforming to the -// "telemetry.distro.name" semantic conventions. It represents the name of the -// auto instrumentation agent or distribution, if used. -func TelemetryDistroName(val string) attribute.KeyValue { - return TelemetryDistroNameKey.String(val) -} - -// TelemetryDistroVersion returns an attribute KeyValue conforming to the -// "telemetry.distro.version" semantic conventions. It represents the version -// string of the auto instrumentation agent or distribution, if used. -func TelemetryDistroVersion(val string) attribute.KeyValue { - return TelemetryDistroVersionKey.String(val) -} - -// Resource describing the packaged software running the application code. Web -// engines are typically executed using process.runtime. -const ( - // WebEngineDescriptionKey is the attribute Key conforming to the - // "webengine.description" semantic conventions. It represents the - // additional description of the web engine (e.g. detailed version and - // edition information). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'WildFly Full 21.0.0.Final (WildFly Core 13.0.1.Final) - - // 2.2.2.Final' - WebEngineDescriptionKey = attribute.Key("webengine.description") - - // WebEngineNameKey is the attribute Key conforming to the "webengine.name" - // semantic conventions. It represents the name of the web engine. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'WildFly' - WebEngineNameKey = attribute.Key("webengine.name") - - // WebEngineVersionKey is the attribute Key conforming to the - // "webengine.version" semantic conventions. It represents the version of - // the web engine. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '21.0.0' - WebEngineVersionKey = attribute.Key("webengine.version") -) - -// WebEngineDescription returns an attribute KeyValue conforming to the -// "webengine.description" semantic conventions. It represents the additional -// description of the web engine (e.g. detailed version and edition -// information). -func WebEngineDescription(val string) attribute.KeyValue { - return WebEngineDescriptionKey.String(val) -} - -// WebEngineName returns an attribute KeyValue conforming to the -// "webengine.name" semantic conventions. It represents the name of the web -// engine. -func WebEngineName(val string) attribute.KeyValue { - return WebEngineNameKey.String(val) -} - -// WebEngineVersion returns an attribute KeyValue conforming to the -// "webengine.version" semantic conventions. It represents the version of the -// web engine. -func WebEngineVersion(val string) attribute.KeyValue { - return WebEngineVersionKey.String(val) -} - -// Attributes used by non-OTLP exporters to represent OpenTelemetry Scope's -// concepts. -const ( - // OTelScopeNameKey is the attribute Key conforming to the - // "otel.scope.name" semantic conventions. It represents the name of the - // instrumentation scope - (`InstrumentationScope.Name` in OTLP). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'io.opentelemetry.contrib.mongodb' - OTelScopeNameKey = attribute.Key("otel.scope.name") - - // OTelScopeVersionKey is the attribute Key conforming to the - // "otel.scope.version" semantic conventions. It represents the version of - // the instrumentation scope - (`InstrumentationScope.Version` in OTLP). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '1.0.0' - OTelScopeVersionKey = attribute.Key("otel.scope.version") -) - -// OTelScopeName returns an attribute KeyValue conforming to the -// "otel.scope.name" semantic conventions. It represents the name of the -// instrumentation scope - (`InstrumentationScope.Name` in OTLP). -func OTelScopeName(val string) attribute.KeyValue { - return OTelScopeNameKey.String(val) -} - -// OTelScopeVersion returns an attribute KeyValue conforming to the -// "otel.scope.version" semantic conventions. It represents the version of the -// instrumentation scope - (`InstrumentationScope.Version` in OTLP). -func OTelScopeVersion(val string) attribute.KeyValue { - return OTelScopeVersionKey.String(val) -} - -// Span attributes used by non-OTLP exporters to represent OpenTelemetry -// Scope's concepts. -const ( - // OTelLibraryNameKey is the attribute Key conforming to the - // "otel.library.name" semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: 'io.opentelemetry.contrib.mongodb' - // Deprecated: use the `otel.scope.name` attribute. - OTelLibraryNameKey = attribute.Key("otel.library.name") - - // OTelLibraryVersionKey is the attribute Key conforming to the - // "otel.library.version" semantic conventions. - // - // Type: string - // RequirementLevel: Optional - // Stability: deprecated - // Examples: '1.0.0' - // Deprecated: use the `otel.scope.version` attribute. - OTelLibraryVersionKey = attribute.Key("otel.library.version") -) - -// OTelLibraryName returns an attribute KeyValue conforming to the -// "otel.library.name" semantic conventions. -// -// Deprecated: use the `otel.scope.name` attribute. -func OTelLibraryName(val string) attribute.KeyValue { - return OTelLibraryNameKey.String(val) -} - -// OTelLibraryVersion returns an attribute KeyValue conforming to the -// "otel.library.version" semantic conventions. -// -// Deprecated: use the `otel.scope.version` attribute. -func OTelLibraryVersion(val string) attribute.KeyValue { - return OTelLibraryVersionKey.String(val) -} diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/schema.go b/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/schema.go deleted file mode 100644 index fe80b1731..000000000 --- a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/schema.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -package semconv // import "go.opentelemetry.io/otel/semconv/v1.24.0" - -// SchemaURL is the schema URL that matches the version of the semantic conventions -// that this package defines. Semconv packages starting from v1.4.0 must declare -// non-empty schema URL in the form https://opentelemetry.io/schemas/ -const SchemaURL = "https://opentelemetry.io/schemas/1.24.0" diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/trace.go b/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/trace.go deleted file mode 100644 index c1718234e..000000000 --- a/vendor/go.opentelemetry.io/otel/semconv/v1.24.0/trace.go +++ /dev/null @@ -1,1323 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -// Code generated from semantic convention specification. DO NOT EDIT. - -package semconv // import "go.opentelemetry.io/otel/semconv/v1.24.0" - -import "go.opentelemetry.io/otel/attribute" - -// Operations that access some remote service. -const ( - // PeerServiceKey is the attribute Key conforming to the "peer.service" - // semantic conventions. It represents the - // [`service.name`](/docs/resource/README.md#service) of the remote - // service. SHOULD be equal to the actual `service.name` resource attribute - // of the remote service if any. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'AuthTokenCache' - PeerServiceKey = attribute.Key("peer.service") -) - -// PeerService returns an attribute KeyValue conforming to the -// "peer.service" semantic conventions. It represents the -// [`service.name`](/docs/resource/README.md#service) of the remote service. -// SHOULD be equal to the actual `service.name` resource attribute of the -// remote service if any. -func PeerService(val string) attribute.KeyValue { - return PeerServiceKey.String(val) -} - -// These attributes may be used for any operation with an authenticated and/or -// authorized enduser. -const ( - // EnduserIDKey is the attribute Key conforming to the "enduser.id" - // semantic conventions. It represents the username or client_id extracted - // from the access token or - // [Authorization](https://tools.ietf.org/html/rfc7235#section-4.2) header - // in the inbound request from outside the system. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'username' - EnduserIDKey = attribute.Key("enduser.id") - - // EnduserRoleKey is the attribute Key conforming to the "enduser.role" - // semantic conventions. It represents the actual/assumed role the client - // is making the request under extracted from token or application security - // context. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'admin' - EnduserRoleKey = attribute.Key("enduser.role") - - // EnduserScopeKey is the attribute Key conforming to the "enduser.scope" - // semantic conventions. It represents the scopes or granted authorities - // the client currently possesses extracted from token or application - // security context. The value would come from the scope associated with an - // [OAuth 2.0 Access - // Token](https://tools.ietf.org/html/rfc6749#section-3.3) or an attribute - // value in a [SAML 2.0 - // Assertion](http://docs.oasis-open.org/security/saml/Post2.0/sstc-saml-tech-overview-2.0.html). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'read:message, write:files' - EnduserScopeKey = attribute.Key("enduser.scope") -) - -// EnduserID returns an attribute KeyValue conforming to the "enduser.id" -// semantic conventions. It represents the username or client_id extracted from -// the access token or -// [Authorization](https://tools.ietf.org/html/rfc7235#section-4.2) header in -// the inbound request from outside the system. -func EnduserID(val string) attribute.KeyValue { - return EnduserIDKey.String(val) -} - -// EnduserRole returns an attribute KeyValue conforming to the -// "enduser.role" semantic conventions. It represents the actual/assumed role -// the client is making the request under extracted from token or application -// security context. -func EnduserRole(val string) attribute.KeyValue { - return EnduserRoleKey.String(val) -} - -// EnduserScope returns an attribute KeyValue conforming to the -// "enduser.scope" semantic conventions. It represents the scopes or granted -// authorities the client currently possesses extracted from token or -// application security context. The value would come from the scope associated -// with an [OAuth 2.0 Access -// Token](https://tools.ietf.org/html/rfc6749#section-3.3) or an attribute -// value in a [SAML 2.0 -// Assertion](http://docs.oasis-open.org/security/saml/Post2.0/sstc-saml-tech-overview-2.0.html). -func EnduserScope(val string) attribute.KeyValue { - return EnduserScopeKey.String(val) -} - -// These attributes allow to report this unit of code and therefore to provide -// more context about the span. -const ( - // CodeColumnKey is the attribute Key conforming to the "code.column" - // semantic conventions. It represents the column number in `code.filepath` - // best representing the operation. It SHOULD point within the code unit - // named in `code.function`. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 16 - CodeColumnKey = attribute.Key("code.column") - - // CodeFilepathKey is the attribute Key conforming to the "code.filepath" - // semantic conventions. It represents the source code file name that - // identifies the code unit as uniquely as possible (preferably an absolute - // file path). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '/usr/local/MyApplication/content_root/app/index.php' - CodeFilepathKey = attribute.Key("code.filepath") - - // CodeFunctionKey is the attribute Key conforming to the "code.function" - // semantic conventions. It represents the method or function name, or - // equivalent (usually rightmost part of the code unit's name). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'serveRequest' - CodeFunctionKey = attribute.Key("code.function") - - // CodeLineNumberKey is the attribute Key conforming to the "code.lineno" - // semantic conventions. It represents the line number in `code.filepath` - // best representing the operation. It SHOULD point within the code unit - // named in `code.function`. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 42 - CodeLineNumberKey = attribute.Key("code.lineno") - - // CodeNamespaceKey is the attribute Key conforming to the "code.namespace" - // semantic conventions. It represents the "namespace" within which - // `code.function` is defined. Usually the qualified class or module name, - // such that `code.namespace` + some separator + `code.function` form a - // unique identifier for the code unit. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'com.example.MyHTTPService' - CodeNamespaceKey = attribute.Key("code.namespace") - - // CodeStacktraceKey is the attribute Key conforming to the - // "code.stacktrace" semantic conventions. It represents a stacktrace as a - // string in the natural representation for the language runtime. The - // representation is to be determined and documented by each language SIG. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'at - // com.example.GenerateTrace.methodB(GenerateTrace.java:13)\\n at ' - // 'com.example.GenerateTrace.methodA(GenerateTrace.java:9)\\n at ' - // 'com.example.GenerateTrace.main(GenerateTrace.java:5)' - CodeStacktraceKey = attribute.Key("code.stacktrace") -) - -// CodeColumn returns an attribute KeyValue conforming to the "code.column" -// semantic conventions. It represents the column number in `code.filepath` -// best representing the operation. It SHOULD point within the code unit named -// in `code.function`. -func CodeColumn(val int) attribute.KeyValue { - return CodeColumnKey.Int(val) -} - -// CodeFilepath returns an attribute KeyValue conforming to the -// "code.filepath" semantic conventions. It represents the source code file -// name that identifies the code unit as uniquely as possible (preferably an -// absolute file path). -func CodeFilepath(val string) attribute.KeyValue { - return CodeFilepathKey.String(val) -} - -// CodeFunction returns an attribute KeyValue conforming to the -// "code.function" semantic conventions. It represents the method or function -// name, or equivalent (usually rightmost part of the code unit's name). -func CodeFunction(val string) attribute.KeyValue { - return CodeFunctionKey.String(val) -} - -// CodeLineNumber returns an attribute KeyValue conforming to the "code.lineno" -// semantic conventions. It represents the line number in `code.filepath` best -// representing the operation. It SHOULD point within the code unit named in -// `code.function`. -func CodeLineNumber(val int) attribute.KeyValue { - return CodeLineNumberKey.Int(val) -} - -// CodeNamespace returns an attribute KeyValue conforming to the -// "code.namespace" semantic conventions. It represents the "namespace" within -// which `code.function` is defined. Usually the qualified class or module -// name, such that `code.namespace` + some separator + `code.function` form a -// unique identifier for the code unit. -func CodeNamespace(val string) attribute.KeyValue { - return CodeNamespaceKey.String(val) -} - -// CodeStacktrace returns an attribute KeyValue conforming to the -// "code.stacktrace" semantic conventions. It represents a stacktrace as a -// string in the natural representation for the language runtime. The -// representation is to be determined and documented by each language SIG. -func CodeStacktrace(val string) attribute.KeyValue { - return CodeStacktraceKey.String(val) -} - -// These attributes may be used for any operation to store information about a -// thread that started a span. -const ( - // ThreadIDKey is the attribute Key conforming to the "thread.id" semantic - // conventions. It represents the current "managed" thread ID (as opposed - // to OS thread ID). - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 42 - ThreadIDKey = attribute.Key("thread.id") - - // ThreadNameKey is the attribute Key conforming to the "thread.name" - // semantic conventions. It represents the current thread name. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'main' - ThreadNameKey = attribute.Key("thread.name") -) - -// ThreadID returns an attribute KeyValue conforming to the "thread.id" -// semantic conventions. It represents the current "managed" thread ID (as -// opposed to OS thread ID). -func ThreadID(val int) attribute.KeyValue { - return ThreadIDKey.Int(val) -} - -// ThreadName returns an attribute KeyValue conforming to the "thread.name" -// semantic conventions. It represents the current thread name. -func ThreadName(val string) attribute.KeyValue { - return ThreadNameKey.String(val) -} - -// Span attributes used by AWS Lambda (in addition to general `faas` -// attributes). -const ( - // AWSLambdaInvokedARNKey is the attribute Key conforming to the - // "aws.lambda.invoked_arn" semantic conventions. It represents the full - // invoked ARN as provided on the `Context` passed to the function - // (`Lambda-Runtime-Invoked-Function-ARN` header on the - // `/runtime/invocation/next` applicable). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'arn:aws:lambda:us-east-1:123456:function:myfunction:myalias' - // Note: This may be different from `cloud.resource_id` if an alias is - // involved. - AWSLambdaInvokedARNKey = attribute.Key("aws.lambda.invoked_arn") -) - -// AWSLambdaInvokedARN returns an attribute KeyValue conforming to the -// "aws.lambda.invoked_arn" semantic conventions. It represents the full -// invoked ARN as provided on the `Context` passed to the function -// (`Lambda-Runtime-Invoked-Function-ARN` header on the -// `/runtime/invocation/next` applicable). -func AWSLambdaInvokedARN(val string) attribute.KeyValue { - return AWSLambdaInvokedARNKey.String(val) -} - -// Attributes for CloudEvents. CloudEvents is a specification on how to define -// event data in a standard way. These attributes can be attached to spans when -// performing operations with CloudEvents, regardless of the protocol being -// used. -const ( - // CloudeventsEventIDKey is the attribute Key conforming to the - // "cloudevents.event_id" semantic conventions. It represents the - // [event_id](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#id) - // uniquely identifies the event. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: '123e4567-e89b-12d3-a456-426614174000', '0001' - CloudeventsEventIDKey = attribute.Key("cloudevents.event_id") - - // CloudeventsEventSourceKey is the attribute Key conforming to the - // "cloudevents.event_source" semantic conventions. It represents the - // [source](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#source-1) - // identifies the context in which an event happened. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'https://github.com/cloudevents', - // '/cloudevents/spec/pull/123', 'my-service' - CloudeventsEventSourceKey = attribute.Key("cloudevents.event_source") - - // CloudeventsEventSpecVersionKey is the attribute Key conforming to the - // "cloudevents.event_spec_version" semantic conventions. It represents the - // [version of the CloudEvents - // specification](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#specversion) - // which the event uses. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '1.0' - CloudeventsEventSpecVersionKey = attribute.Key("cloudevents.event_spec_version") - - // CloudeventsEventSubjectKey is the attribute Key conforming to the - // "cloudevents.event_subject" semantic conventions. It represents the - // [subject](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#subject) - // of the event in the context of the event producer (identified by - // source). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'mynewfile.jpg' - CloudeventsEventSubjectKey = attribute.Key("cloudevents.event_subject") - - // CloudeventsEventTypeKey is the attribute Key conforming to the - // "cloudevents.event_type" semantic conventions. It represents the - // [event_type](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#type) - // contains a value describing the type of event related to the originating - // occurrence. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'com.github.pull_request.opened', - // 'com.example.object.deleted.v2' - CloudeventsEventTypeKey = attribute.Key("cloudevents.event_type") -) - -// CloudeventsEventID returns an attribute KeyValue conforming to the -// "cloudevents.event_id" semantic conventions. It represents the -// [event_id](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#id) -// uniquely identifies the event. -func CloudeventsEventID(val string) attribute.KeyValue { - return CloudeventsEventIDKey.String(val) -} - -// CloudeventsEventSource returns an attribute KeyValue conforming to the -// "cloudevents.event_source" semantic conventions. It represents the -// [source](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#source-1) -// identifies the context in which an event happened. -func CloudeventsEventSource(val string) attribute.KeyValue { - return CloudeventsEventSourceKey.String(val) -} - -// CloudeventsEventSpecVersion returns an attribute KeyValue conforming to -// the "cloudevents.event_spec_version" semantic conventions. It represents the -// [version of the CloudEvents -// specification](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#specversion) -// which the event uses. -func CloudeventsEventSpecVersion(val string) attribute.KeyValue { - return CloudeventsEventSpecVersionKey.String(val) -} - -// CloudeventsEventSubject returns an attribute KeyValue conforming to the -// "cloudevents.event_subject" semantic conventions. It represents the -// [subject](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#subject) -// of the event in the context of the event producer (identified by source). -func CloudeventsEventSubject(val string) attribute.KeyValue { - return CloudeventsEventSubjectKey.String(val) -} - -// CloudeventsEventType returns an attribute KeyValue conforming to the -// "cloudevents.event_type" semantic conventions. It represents the -// [event_type](https://github.com/cloudevents/spec/blob/v1.0.2/cloudevents/spec.md#type) -// contains a value describing the type of event related to the originating -// occurrence. -func CloudeventsEventType(val string) attribute.KeyValue { - return CloudeventsEventTypeKey.String(val) -} - -// Semantic conventions for the OpenTracing Shim -const ( - // OpentracingRefTypeKey is the attribute Key conforming to the - // "opentracing.ref_type" semantic conventions. It represents the - // parent-child Reference type - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Note: The causal relationship between a child Span and a parent Span. - OpentracingRefTypeKey = attribute.Key("opentracing.ref_type") -) - -var ( - // The parent Span depends on the child Span in some capacity - OpentracingRefTypeChildOf = OpentracingRefTypeKey.String("child_of") - // The parent Span doesn't depend in any way on the result of the child Span - OpentracingRefTypeFollowsFrom = OpentracingRefTypeKey.String("follows_from") -) - -// Span attributes used by non-OTLP exporters to represent OpenTelemetry Span's -// concepts. -const ( - // OTelStatusCodeKey is the attribute Key conforming to the - // "otel.status_code" semantic conventions. It represents the name of the - // code, either "OK" or "ERROR". MUST NOT be set if the status code is - // UNSET. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - OTelStatusCodeKey = attribute.Key("otel.status_code") - - // OTelStatusDescriptionKey is the attribute Key conforming to the - // "otel.status_description" semantic conventions. It represents the - // description of the Status if it has a value, otherwise not set. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'resource not found' - OTelStatusDescriptionKey = attribute.Key("otel.status_description") -) - -var ( - // The operation has been validated by an Application developer or Operator to have completed successfully - OTelStatusCodeOk = OTelStatusCodeKey.String("OK") - // The operation contains an error - OTelStatusCodeError = OTelStatusCodeKey.String("ERROR") -) - -// OTelStatusDescription returns an attribute KeyValue conforming to the -// "otel.status_description" semantic conventions. It represents the -// description of the Status if it has a value, otherwise not set. -func OTelStatusDescription(val string) attribute.KeyValue { - return OTelStatusDescriptionKey.String(val) -} - -// This semantic convention describes an instance of a function that runs -// without provisioning or managing of servers (also known as serverless -// functions or Function as a Service (FaaS)) with spans. -const ( - // FaaSInvocationIDKey is the attribute Key conforming to the - // "faas.invocation_id" semantic conventions. It represents the invocation - // ID of the current function invocation. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'af9d5aa4-a685-4c5f-a22b-444f80b3cc28' - FaaSInvocationIDKey = attribute.Key("faas.invocation_id") -) - -// FaaSInvocationID returns an attribute KeyValue conforming to the -// "faas.invocation_id" semantic conventions. It represents the invocation ID -// of the current function invocation. -func FaaSInvocationID(val string) attribute.KeyValue { - return FaaSInvocationIDKey.String(val) -} - -// Semantic Convention for FaaS triggered as a response to some data source -// operation such as a database or filesystem read/write. -const ( - // FaaSDocumentCollectionKey is the attribute Key conforming to the - // "faas.document.collection" semantic conventions. It represents the name - // of the source on which the triggering operation was performed. For - // example, in Cloud Storage or S3 corresponds to the bucket name, and in - // Cosmos DB to the database name. - // - // Type: string - // RequirementLevel: Required - // Stability: experimental - // Examples: 'myBucketName', 'myDBName' - FaaSDocumentCollectionKey = attribute.Key("faas.document.collection") - - // FaaSDocumentNameKey is the attribute Key conforming to the - // "faas.document.name" semantic conventions. It represents the document - // name/table subjected to the operation. For example, in Cloud Storage or - // S3 is the name of the file, and in Cosmos DB the table name. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'myFile.txt', 'myTableName' - FaaSDocumentNameKey = attribute.Key("faas.document.name") - - // FaaSDocumentOperationKey is the attribute Key conforming to the - // "faas.document.operation" semantic conventions. It represents the - // describes the type of the operation that was performed on the data. - // - // Type: Enum - // RequirementLevel: Required - // Stability: experimental - FaaSDocumentOperationKey = attribute.Key("faas.document.operation") - - // FaaSDocumentTimeKey is the attribute Key conforming to the - // "faas.document.time" semantic conventions. It represents a string - // containing the time when the data was accessed in the [ISO - // 8601](https://www.iso.org/iso-8601-date-and-time-format.html) format - // expressed in [UTC](https://www.w3.org/TR/NOTE-datetime). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2020-01-23T13:47:06Z' - FaaSDocumentTimeKey = attribute.Key("faas.document.time") -) - -var ( - // When a new object is created - FaaSDocumentOperationInsert = FaaSDocumentOperationKey.String("insert") - // When an object is modified - FaaSDocumentOperationEdit = FaaSDocumentOperationKey.String("edit") - // When an object is deleted - FaaSDocumentOperationDelete = FaaSDocumentOperationKey.String("delete") -) - -// FaaSDocumentCollection returns an attribute KeyValue conforming to the -// "faas.document.collection" semantic conventions. It represents the name of -// the source on which the triggering operation was performed. For example, in -// Cloud Storage or S3 corresponds to the bucket name, and in Cosmos DB to the -// database name. -func FaaSDocumentCollection(val string) attribute.KeyValue { - return FaaSDocumentCollectionKey.String(val) -} - -// FaaSDocumentName returns an attribute KeyValue conforming to the -// "faas.document.name" semantic conventions. It represents the document -// name/table subjected to the operation. For example, in Cloud Storage or S3 -// is the name of the file, and in Cosmos DB the table name. -func FaaSDocumentName(val string) attribute.KeyValue { - return FaaSDocumentNameKey.String(val) -} - -// FaaSDocumentTime returns an attribute KeyValue conforming to the -// "faas.document.time" semantic conventions. It represents a string containing -// the time when the data was accessed in the [ISO -// 8601](https://www.iso.org/iso-8601-date-and-time-format.html) format -// expressed in [UTC](https://www.w3.org/TR/NOTE-datetime). -func FaaSDocumentTime(val string) attribute.KeyValue { - return FaaSDocumentTimeKey.String(val) -} - -// Semantic Convention for FaaS scheduled to be executed regularly. -const ( - // FaaSCronKey is the attribute Key conforming to the "faas.cron" semantic - // conventions. It represents a string containing the schedule period as - // [Cron - // Expression](https://docs.oracle.com/cd/E12058_01/doc/doc.1014/e12030/cron_expressions.htm). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '0/5 * * * ? *' - FaaSCronKey = attribute.Key("faas.cron") - - // FaaSTimeKey is the attribute Key conforming to the "faas.time" semantic - // conventions. It represents a string containing the function invocation - // time in the [ISO - // 8601](https://www.iso.org/iso-8601-date-and-time-format.html) format - // expressed in [UTC](https://www.w3.org/TR/NOTE-datetime). - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '2020-01-23T13:47:06Z' - FaaSTimeKey = attribute.Key("faas.time") -) - -// FaaSCron returns an attribute KeyValue conforming to the "faas.cron" -// semantic conventions. It represents a string containing the schedule period -// as [Cron -// Expression](https://docs.oracle.com/cd/E12058_01/doc/doc.1014/e12030/cron_expressions.htm). -func FaaSCron(val string) attribute.KeyValue { - return FaaSCronKey.String(val) -} - -// FaaSTime returns an attribute KeyValue conforming to the "faas.time" -// semantic conventions. It represents a string containing the function -// invocation time in the [ISO -// 8601](https://www.iso.org/iso-8601-date-and-time-format.html) format -// expressed in [UTC](https://www.w3.org/TR/NOTE-datetime). -func FaaSTime(val string) attribute.KeyValue { - return FaaSTimeKey.String(val) -} - -// Contains additional attributes for incoming FaaS spans. -const ( - // FaaSColdstartKey is the attribute Key conforming to the "faas.coldstart" - // semantic conventions. It represents a boolean that is true if the - // serverless function is executed for the first time (aka cold-start). - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - FaaSColdstartKey = attribute.Key("faas.coldstart") -) - -// FaaSColdstart returns an attribute KeyValue conforming to the -// "faas.coldstart" semantic conventions. It represents a boolean that is true -// if the serverless function is executed for the first time (aka cold-start). -func FaaSColdstart(val bool) attribute.KeyValue { - return FaaSColdstartKey.Bool(val) -} - -// The `aws` conventions apply to operations using the AWS SDK. They map -// request or response parameters in AWS SDK API calls to attributes on a Span. -// The conventions have been collected over time based on feedback from AWS -// users of tracing and will continue to evolve as new interesting conventions -// are found. -// Some descriptions are also provided for populating general OpenTelemetry -// semantic conventions based on these APIs. -const ( - // AWSRequestIDKey is the attribute Key conforming to the "aws.request_id" - // semantic conventions. It represents the AWS request ID as returned in - // the response headers `x-amz-request-id` or `x-amz-requestid`. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '79b9da39-b7ae-508a-a6bc-864b2829c622', 'C9ER4AJX75574TDJ' - AWSRequestIDKey = attribute.Key("aws.request_id") -) - -// AWSRequestID returns an attribute KeyValue conforming to the -// "aws.request_id" semantic conventions. It represents the AWS request ID as -// returned in the response headers `x-amz-request-id` or `x-amz-requestid`. -func AWSRequestID(val string) attribute.KeyValue { - return AWSRequestIDKey.String(val) -} - -// Attributes that exist for multiple DynamoDB request types. -const ( - // AWSDynamoDBAttributesToGetKey is the attribute Key conforming to the - // "aws.dynamodb.attributes_to_get" semantic conventions. It represents the - // value of the `AttributesToGet` request parameter. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'lives', 'id' - AWSDynamoDBAttributesToGetKey = attribute.Key("aws.dynamodb.attributes_to_get") - - // AWSDynamoDBConsistentReadKey is the attribute Key conforming to the - // "aws.dynamodb.consistent_read" semantic conventions. It represents the - // value of the `ConsistentRead` request parameter. - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - AWSDynamoDBConsistentReadKey = attribute.Key("aws.dynamodb.consistent_read") - - // AWSDynamoDBConsumedCapacityKey is the attribute Key conforming to the - // "aws.dynamodb.consumed_capacity" semantic conventions. It represents the - // JSON-serialized value of each item in the `ConsumedCapacity` response - // field. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: '{ "CapacityUnits": number, "GlobalSecondaryIndexes": { - // "string" : { "CapacityUnits": number, "ReadCapacityUnits": number, - // "WriteCapacityUnits": number } }, "LocalSecondaryIndexes": { "string" : - // { "CapacityUnits": number, "ReadCapacityUnits": number, - // "WriteCapacityUnits": number } }, "ReadCapacityUnits": number, "Table": - // { "CapacityUnits": number, "ReadCapacityUnits": number, - // "WriteCapacityUnits": number }, "TableName": "string", - // "WriteCapacityUnits": number }' - AWSDynamoDBConsumedCapacityKey = attribute.Key("aws.dynamodb.consumed_capacity") - - // AWSDynamoDBIndexNameKey is the attribute Key conforming to the - // "aws.dynamodb.index_name" semantic conventions. It represents the value - // of the `IndexName` request parameter. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'name_to_group' - AWSDynamoDBIndexNameKey = attribute.Key("aws.dynamodb.index_name") - - // AWSDynamoDBItemCollectionMetricsKey is the attribute Key conforming to - // the "aws.dynamodb.item_collection_metrics" semantic conventions. It - // represents the JSON-serialized value of the `ItemCollectionMetrics` - // response field. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: '{ "string" : [ { "ItemCollectionKey": { "string" : { "B": - // blob, "BOOL": boolean, "BS": [ blob ], "L": [ "AttributeValue" ], "M": { - // "string" : "AttributeValue" }, "N": "string", "NS": [ "string" ], - // "NULL": boolean, "S": "string", "SS": [ "string" ] } }, - // "SizeEstimateRangeGB": [ number ] } ] }' - AWSDynamoDBItemCollectionMetricsKey = attribute.Key("aws.dynamodb.item_collection_metrics") - - // AWSDynamoDBLimitKey is the attribute Key conforming to the - // "aws.dynamodb.limit" semantic conventions. It represents the value of - // the `Limit` request parameter. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 10 - AWSDynamoDBLimitKey = attribute.Key("aws.dynamodb.limit") - - // AWSDynamoDBProjectionKey is the attribute Key conforming to the - // "aws.dynamodb.projection" semantic conventions. It represents the value - // of the `ProjectionExpression` request parameter. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Title', 'Title, Price, Color', 'Title, Description, - // RelatedItems, ProductReviews' - AWSDynamoDBProjectionKey = attribute.Key("aws.dynamodb.projection") - - // AWSDynamoDBProvisionedReadCapacityKey is the attribute Key conforming to - // the "aws.dynamodb.provisioned_read_capacity" semantic conventions. It - // represents the value of the `ProvisionedThroughput.ReadCapacityUnits` - // request parameter. - // - // Type: double - // RequirementLevel: Optional - // Stability: experimental - // Examples: 1.0, 2.0 - AWSDynamoDBProvisionedReadCapacityKey = attribute.Key("aws.dynamodb.provisioned_read_capacity") - - // AWSDynamoDBProvisionedWriteCapacityKey is the attribute Key conforming - // to the "aws.dynamodb.provisioned_write_capacity" semantic conventions. - // It represents the value of the - // `ProvisionedThroughput.WriteCapacityUnits` request parameter. - // - // Type: double - // RequirementLevel: Optional - // Stability: experimental - // Examples: 1.0, 2.0 - AWSDynamoDBProvisionedWriteCapacityKey = attribute.Key("aws.dynamodb.provisioned_write_capacity") - - // AWSDynamoDBSelectKey is the attribute Key conforming to the - // "aws.dynamodb.select" semantic conventions. It represents the value of - // the `Select` request parameter. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'ALL_ATTRIBUTES', 'COUNT' - AWSDynamoDBSelectKey = attribute.Key("aws.dynamodb.select") - - // AWSDynamoDBTableNamesKey is the attribute Key conforming to the - // "aws.dynamodb.table_names" semantic conventions. It represents the keys - // in the `RequestItems` object field. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Users', 'Cats' - AWSDynamoDBTableNamesKey = attribute.Key("aws.dynamodb.table_names") -) - -// AWSDynamoDBAttributesToGet returns an attribute KeyValue conforming to -// the "aws.dynamodb.attributes_to_get" semantic conventions. It represents the -// value of the `AttributesToGet` request parameter. -func AWSDynamoDBAttributesToGet(val ...string) attribute.KeyValue { - return AWSDynamoDBAttributesToGetKey.StringSlice(val) -} - -// AWSDynamoDBConsistentRead returns an attribute KeyValue conforming to the -// "aws.dynamodb.consistent_read" semantic conventions. It represents the value -// of the `ConsistentRead` request parameter. -func AWSDynamoDBConsistentRead(val bool) attribute.KeyValue { - return AWSDynamoDBConsistentReadKey.Bool(val) -} - -// AWSDynamoDBConsumedCapacity returns an attribute KeyValue conforming to -// the "aws.dynamodb.consumed_capacity" semantic conventions. It represents the -// JSON-serialized value of each item in the `ConsumedCapacity` response field. -func AWSDynamoDBConsumedCapacity(val ...string) attribute.KeyValue { - return AWSDynamoDBConsumedCapacityKey.StringSlice(val) -} - -// AWSDynamoDBIndexName returns an attribute KeyValue conforming to the -// "aws.dynamodb.index_name" semantic conventions. It represents the value of -// the `IndexName` request parameter. -func AWSDynamoDBIndexName(val string) attribute.KeyValue { - return AWSDynamoDBIndexNameKey.String(val) -} - -// AWSDynamoDBItemCollectionMetrics returns an attribute KeyValue conforming -// to the "aws.dynamodb.item_collection_metrics" semantic conventions. It -// represents the JSON-serialized value of the `ItemCollectionMetrics` response -// field. -func AWSDynamoDBItemCollectionMetrics(val string) attribute.KeyValue { - return AWSDynamoDBItemCollectionMetricsKey.String(val) -} - -// AWSDynamoDBLimit returns an attribute KeyValue conforming to the -// "aws.dynamodb.limit" semantic conventions. It represents the value of the -// `Limit` request parameter. -func AWSDynamoDBLimit(val int) attribute.KeyValue { - return AWSDynamoDBLimitKey.Int(val) -} - -// AWSDynamoDBProjection returns an attribute KeyValue conforming to the -// "aws.dynamodb.projection" semantic conventions. It represents the value of -// the `ProjectionExpression` request parameter. -func AWSDynamoDBProjection(val string) attribute.KeyValue { - return AWSDynamoDBProjectionKey.String(val) -} - -// AWSDynamoDBProvisionedReadCapacity returns an attribute KeyValue -// conforming to the "aws.dynamodb.provisioned_read_capacity" semantic -// conventions. It represents the value of the -// `ProvisionedThroughput.ReadCapacityUnits` request parameter. -func AWSDynamoDBProvisionedReadCapacity(val float64) attribute.KeyValue { - return AWSDynamoDBProvisionedReadCapacityKey.Float64(val) -} - -// AWSDynamoDBProvisionedWriteCapacity returns an attribute KeyValue -// conforming to the "aws.dynamodb.provisioned_write_capacity" semantic -// conventions. It represents the value of the -// `ProvisionedThroughput.WriteCapacityUnits` request parameter. -func AWSDynamoDBProvisionedWriteCapacity(val float64) attribute.KeyValue { - return AWSDynamoDBProvisionedWriteCapacityKey.Float64(val) -} - -// AWSDynamoDBSelect returns an attribute KeyValue conforming to the -// "aws.dynamodb.select" semantic conventions. It represents the value of the -// `Select` request parameter. -func AWSDynamoDBSelect(val string) attribute.KeyValue { - return AWSDynamoDBSelectKey.String(val) -} - -// AWSDynamoDBTableNames returns an attribute KeyValue conforming to the -// "aws.dynamodb.table_names" semantic conventions. It represents the keys in -// the `RequestItems` object field. -func AWSDynamoDBTableNames(val ...string) attribute.KeyValue { - return AWSDynamoDBTableNamesKey.StringSlice(val) -} - -// DynamoDB.CreateTable -const ( - // AWSDynamoDBGlobalSecondaryIndexesKey is the attribute Key conforming to - // the "aws.dynamodb.global_secondary_indexes" semantic conventions. It - // represents the JSON-serialized value of each item of the - // `GlobalSecondaryIndexes` request field - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: '{ "IndexName": "string", "KeySchema": [ { "AttributeName": - // "string", "KeyType": "string" } ], "Projection": { "NonKeyAttributes": [ - // "string" ], "ProjectionType": "string" }, "ProvisionedThroughput": { - // "ReadCapacityUnits": number, "WriteCapacityUnits": number } }' - AWSDynamoDBGlobalSecondaryIndexesKey = attribute.Key("aws.dynamodb.global_secondary_indexes") - - // AWSDynamoDBLocalSecondaryIndexesKey is the attribute Key conforming to - // the "aws.dynamodb.local_secondary_indexes" semantic conventions. It - // represents the JSON-serialized value of each item of the - // `LocalSecondaryIndexes` request field. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: '{ "IndexARN": "string", "IndexName": "string", - // "IndexSizeBytes": number, "ItemCount": number, "KeySchema": [ { - // "AttributeName": "string", "KeyType": "string" } ], "Projection": { - // "NonKeyAttributes": [ "string" ], "ProjectionType": "string" } }' - AWSDynamoDBLocalSecondaryIndexesKey = attribute.Key("aws.dynamodb.local_secondary_indexes") -) - -// AWSDynamoDBGlobalSecondaryIndexes returns an attribute KeyValue -// conforming to the "aws.dynamodb.global_secondary_indexes" semantic -// conventions. It represents the JSON-serialized value of each item of the -// `GlobalSecondaryIndexes` request field -func AWSDynamoDBGlobalSecondaryIndexes(val ...string) attribute.KeyValue { - return AWSDynamoDBGlobalSecondaryIndexesKey.StringSlice(val) -} - -// AWSDynamoDBLocalSecondaryIndexes returns an attribute KeyValue conforming -// to the "aws.dynamodb.local_secondary_indexes" semantic conventions. It -// represents the JSON-serialized value of each item of the -// `LocalSecondaryIndexes` request field. -func AWSDynamoDBLocalSecondaryIndexes(val ...string) attribute.KeyValue { - return AWSDynamoDBLocalSecondaryIndexesKey.StringSlice(val) -} - -// DynamoDB.ListTables -const ( - // AWSDynamoDBExclusiveStartTableKey is the attribute Key conforming to the - // "aws.dynamodb.exclusive_start_table" semantic conventions. It represents - // the value of the `ExclusiveStartTableName` request parameter. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'Users', 'CatsTable' - AWSDynamoDBExclusiveStartTableKey = attribute.Key("aws.dynamodb.exclusive_start_table") - - // AWSDynamoDBTableCountKey is the attribute Key conforming to the - // "aws.dynamodb.table_count" semantic conventions. It represents the the - // number of items in the `TableNames` response parameter. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 20 - AWSDynamoDBTableCountKey = attribute.Key("aws.dynamodb.table_count") -) - -// AWSDynamoDBExclusiveStartTable returns an attribute KeyValue conforming -// to the "aws.dynamodb.exclusive_start_table" semantic conventions. It -// represents the value of the `ExclusiveStartTableName` request parameter. -func AWSDynamoDBExclusiveStartTable(val string) attribute.KeyValue { - return AWSDynamoDBExclusiveStartTableKey.String(val) -} - -// AWSDynamoDBTableCount returns an attribute KeyValue conforming to the -// "aws.dynamodb.table_count" semantic conventions. It represents the the -// number of items in the `TableNames` response parameter. -func AWSDynamoDBTableCount(val int) attribute.KeyValue { - return AWSDynamoDBTableCountKey.Int(val) -} - -// DynamoDB.Query -const ( - // AWSDynamoDBScanForwardKey is the attribute Key conforming to the - // "aws.dynamodb.scan_forward" semantic conventions. It represents the - // value of the `ScanIndexForward` request parameter. - // - // Type: boolean - // RequirementLevel: Optional - // Stability: experimental - AWSDynamoDBScanForwardKey = attribute.Key("aws.dynamodb.scan_forward") -) - -// AWSDynamoDBScanForward returns an attribute KeyValue conforming to the -// "aws.dynamodb.scan_forward" semantic conventions. It represents the value of -// the `ScanIndexForward` request parameter. -func AWSDynamoDBScanForward(val bool) attribute.KeyValue { - return AWSDynamoDBScanForwardKey.Bool(val) -} - -// DynamoDB.Scan -const ( - // AWSDynamoDBCountKey is the attribute Key conforming to the - // "aws.dynamodb.count" semantic conventions. It represents the value of - // the `Count` response parameter. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 10 - AWSDynamoDBCountKey = attribute.Key("aws.dynamodb.count") - - // AWSDynamoDBScannedCountKey is the attribute Key conforming to the - // "aws.dynamodb.scanned_count" semantic conventions. It represents the - // value of the `ScannedCount` response parameter. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 50 - AWSDynamoDBScannedCountKey = attribute.Key("aws.dynamodb.scanned_count") - - // AWSDynamoDBSegmentKey is the attribute Key conforming to the - // "aws.dynamodb.segment" semantic conventions. It represents the value of - // the `Segment` request parameter. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 10 - AWSDynamoDBSegmentKey = attribute.Key("aws.dynamodb.segment") - - // AWSDynamoDBTotalSegmentsKey is the attribute Key conforming to the - // "aws.dynamodb.total_segments" semantic conventions. It represents the - // value of the `TotalSegments` request parameter. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 100 - AWSDynamoDBTotalSegmentsKey = attribute.Key("aws.dynamodb.total_segments") -) - -// AWSDynamoDBCount returns an attribute KeyValue conforming to the -// "aws.dynamodb.count" semantic conventions. It represents the value of the -// `Count` response parameter. -func AWSDynamoDBCount(val int) attribute.KeyValue { - return AWSDynamoDBCountKey.Int(val) -} - -// AWSDynamoDBScannedCount returns an attribute KeyValue conforming to the -// "aws.dynamodb.scanned_count" semantic conventions. It represents the value -// of the `ScannedCount` response parameter. -func AWSDynamoDBScannedCount(val int) attribute.KeyValue { - return AWSDynamoDBScannedCountKey.Int(val) -} - -// AWSDynamoDBSegment returns an attribute KeyValue conforming to the -// "aws.dynamodb.segment" semantic conventions. It represents the value of the -// `Segment` request parameter. -func AWSDynamoDBSegment(val int) attribute.KeyValue { - return AWSDynamoDBSegmentKey.Int(val) -} - -// AWSDynamoDBTotalSegments returns an attribute KeyValue conforming to the -// "aws.dynamodb.total_segments" semantic conventions. It represents the value -// of the `TotalSegments` request parameter. -func AWSDynamoDBTotalSegments(val int) attribute.KeyValue { - return AWSDynamoDBTotalSegmentsKey.Int(val) -} - -// DynamoDB.UpdateTable -const ( - // AWSDynamoDBAttributeDefinitionsKey is the attribute Key conforming to - // the "aws.dynamodb.attribute_definitions" semantic conventions. It - // represents the JSON-serialized value of each item in the - // `AttributeDefinitions` request field. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: '{ "AttributeName": "string", "AttributeType": "string" }' - AWSDynamoDBAttributeDefinitionsKey = attribute.Key("aws.dynamodb.attribute_definitions") - - // AWSDynamoDBGlobalSecondaryIndexUpdatesKey is the attribute Key - // conforming to the "aws.dynamodb.global_secondary_index_updates" semantic - // conventions. It represents the JSON-serialized value of each item in the - // the `GlobalSecondaryIndexUpdates` request field. - // - // Type: string[] - // RequirementLevel: Optional - // Stability: experimental - // Examples: '{ "Create": { "IndexName": "string", "KeySchema": [ { - // "AttributeName": "string", "KeyType": "string" } ], "Projection": { - // "NonKeyAttributes": [ "string" ], "ProjectionType": "string" }, - // "ProvisionedThroughput": { "ReadCapacityUnits": number, - // "WriteCapacityUnits": number } }' - AWSDynamoDBGlobalSecondaryIndexUpdatesKey = attribute.Key("aws.dynamodb.global_secondary_index_updates") -) - -// AWSDynamoDBAttributeDefinitions returns an attribute KeyValue conforming -// to the "aws.dynamodb.attribute_definitions" semantic conventions. It -// represents the JSON-serialized value of each item in the -// `AttributeDefinitions` request field. -func AWSDynamoDBAttributeDefinitions(val ...string) attribute.KeyValue { - return AWSDynamoDBAttributeDefinitionsKey.StringSlice(val) -} - -// AWSDynamoDBGlobalSecondaryIndexUpdates returns an attribute KeyValue -// conforming to the "aws.dynamodb.global_secondary_index_updates" semantic -// conventions. It represents the JSON-serialized value of each item in the the -// `GlobalSecondaryIndexUpdates` request field. -func AWSDynamoDBGlobalSecondaryIndexUpdates(val ...string) attribute.KeyValue { - return AWSDynamoDBGlobalSecondaryIndexUpdatesKey.StringSlice(val) -} - -// Attributes that exist for S3 request types. -const ( - // AWSS3BucketKey is the attribute Key conforming to the "aws.s3.bucket" - // semantic conventions. It represents the S3 bucket name the request - // refers to. Corresponds to the `--bucket` parameter of the [S3 - // API](https://docs.aws.amazon.com/cli/latest/reference/s3api/index.html) - // operations. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'some-bucket-name' - // Note: The `bucket` attribute is applicable to all S3 operations that - // reference a bucket, i.e. that require the bucket name as a mandatory - // parameter. - // This applies to almost all S3 operations except `list-buckets`. - AWSS3BucketKey = attribute.Key("aws.s3.bucket") - - // AWSS3CopySourceKey is the attribute Key conforming to the - // "aws.s3.copy_source" semantic conventions. It represents the source - // object (in the form `bucket`/`key`) for the copy operation. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'someFile.yml' - // Note: The `copy_source` attribute applies to S3 copy operations and - // corresponds to the `--copy-source` parameter - // of the [copy-object operation within the S3 - // API](https://docs.aws.amazon.com/cli/latest/reference/s3api/copy-object.html). - // This applies in particular to the following operations: - // - // - - // [copy-object](https://docs.aws.amazon.com/cli/latest/reference/s3api/copy-object.html) - // - - // [upload-part-copy](https://docs.aws.amazon.com/cli/latest/reference/s3api/upload-part-copy.html) - AWSS3CopySourceKey = attribute.Key("aws.s3.copy_source") - - // AWSS3DeleteKey is the attribute Key conforming to the "aws.s3.delete" - // semantic conventions. It represents the delete request container that - // specifies the objects to be deleted. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: - // 'Objects=[{Key=string,VersionID=string},{Key=string,VersionID=string}],Quiet=boolean' - // Note: The `delete` attribute is only applicable to the - // [delete-object](https://docs.aws.amazon.com/cli/latest/reference/s3api/delete-object.html) - // operation. - // The `delete` attribute corresponds to the `--delete` parameter of the - // [delete-objects operation within the S3 - // API](https://docs.aws.amazon.com/cli/latest/reference/s3api/delete-objects.html). - AWSS3DeleteKey = attribute.Key("aws.s3.delete") - - // AWSS3KeyKey is the attribute Key conforming to the "aws.s3.key" semantic - // conventions. It represents the S3 object key the request refers to. - // Corresponds to the `--key` parameter of the [S3 - // API](https://docs.aws.amazon.com/cli/latest/reference/s3api/index.html) - // operations. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'someFile.yml' - // Note: The `key` attribute is applicable to all object-related S3 - // operations, i.e. that require the object key as a mandatory parameter. - // This applies in particular to the following operations: - // - // - - // [copy-object](https://docs.aws.amazon.com/cli/latest/reference/s3api/copy-object.html) - // - - // [delete-object](https://docs.aws.amazon.com/cli/latest/reference/s3api/delete-object.html) - // - - // [get-object](https://docs.aws.amazon.com/cli/latest/reference/s3api/get-object.html) - // - - // [head-object](https://docs.aws.amazon.com/cli/latest/reference/s3api/head-object.html) - // - - // [put-object](https://docs.aws.amazon.com/cli/latest/reference/s3api/put-object.html) - // - - // [restore-object](https://docs.aws.amazon.com/cli/latest/reference/s3api/restore-object.html) - // - - // [select-object-content](https://docs.aws.amazon.com/cli/latest/reference/s3api/select-object-content.html) - // - - // [abort-multipart-upload](https://docs.aws.amazon.com/cli/latest/reference/s3api/abort-multipart-upload.html) - // - - // [complete-multipart-upload](https://docs.aws.amazon.com/cli/latest/reference/s3api/complete-multipart-upload.html) - // - - // [create-multipart-upload](https://docs.aws.amazon.com/cli/latest/reference/s3api/create-multipart-upload.html) - // - - // [list-parts](https://docs.aws.amazon.com/cli/latest/reference/s3api/list-parts.html) - // - - // [upload-part](https://docs.aws.amazon.com/cli/latest/reference/s3api/upload-part.html) - // - - // [upload-part-copy](https://docs.aws.amazon.com/cli/latest/reference/s3api/upload-part-copy.html) - AWSS3KeyKey = attribute.Key("aws.s3.key") - - // AWSS3PartNumberKey is the attribute Key conforming to the - // "aws.s3.part_number" semantic conventions. It represents the part number - // of the part being uploaded in a multipart-upload operation. This is a - // positive integer between 1 and 10,000. - // - // Type: int - // RequirementLevel: Optional - // Stability: experimental - // Examples: 3456 - // Note: The `part_number` attribute is only applicable to the - // [upload-part](https://docs.aws.amazon.com/cli/latest/reference/s3api/upload-part.html) - // and - // [upload-part-copy](https://docs.aws.amazon.com/cli/latest/reference/s3api/upload-part-copy.html) - // operations. - // The `part_number` attribute corresponds to the `--part-number` parameter - // of the - // [upload-part operation within the S3 - // API](https://docs.aws.amazon.com/cli/latest/reference/s3api/upload-part.html). - AWSS3PartNumberKey = attribute.Key("aws.s3.part_number") - - // AWSS3UploadIDKey is the attribute Key conforming to the - // "aws.s3.upload_id" semantic conventions. It represents the upload ID - // that identifies the multipart upload. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'dfRtDYWFbkRONycy.Yxwh66Yjlx.cph0gtNBtJ' - // Note: The `upload_id` attribute applies to S3 multipart-upload - // operations and corresponds to the `--upload-id` parameter - // of the [S3 - // API](https://docs.aws.amazon.com/cli/latest/reference/s3api/index.html) - // multipart operations. - // This applies in particular to the following operations: - // - // - - // [abort-multipart-upload](https://docs.aws.amazon.com/cli/latest/reference/s3api/abort-multipart-upload.html) - // - - // [complete-multipart-upload](https://docs.aws.amazon.com/cli/latest/reference/s3api/complete-multipart-upload.html) - // - - // [list-parts](https://docs.aws.amazon.com/cli/latest/reference/s3api/list-parts.html) - // - - // [upload-part](https://docs.aws.amazon.com/cli/latest/reference/s3api/upload-part.html) - // - - // [upload-part-copy](https://docs.aws.amazon.com/cli/latest/reference/s3api/upload-part-copy.html) - AWSS3UploadIDKey = attribute.Key("aws.s3.upload_id") -) - -// AWSS3Bucket returns an attribute KeyValue conforming to the -// "aws.s3.bucket" semantic conventions. It represents the S3 bucket name the -// request refers to. Corresponds to the `--bucket` parameter of the [S3 -// API](https://docs.aws.amazon.com/cli/latest/reference/s3api/index.html) -// operations. -func AWSS3Bucket(val string) attribute.KeyValue { - return AWSS3BucketKey.String(val) -} - -// AWSS3CopySource returns an attribute KeyValue conforming to the -// "aws.s3.copy_source" semantic conventions. It represents the source object -// (in the form `bucket`/`key`) for the copy operation. -func AWSS3CopySource(val string) attribute.KeyValue { - return AWSS3CopySourceKey.String(val) -} - -// AWSS3Delete returns an attribute KeyValue conforming to the -// "aws.s3.delete" semantic conventions. It represents the delete request -// container that specifies the objects to be deleted. -func AWSS3Delete(val string) attribute.KeyValue { - return AWSS3DeleteKey.String(val) -} - -// AWSS3Key returns an attribute KeyValue conforming to the "aws.s3.key" -// semantic conventions. It represents the S3 object key the request refers to. -// Corresponds to the `--key` parameter of the [S3 -// API](https://docs.aws.amazon.com/cli/latest/reference/s3api/index.html) -// operations. -func AWSS3Key(val string) attribute.KeyValue { - return AWSS3KeyKey.String(val) -} - -// AWSS3PartNumber returns an attribute KeyValue conforming to the -// "aws.s3.part_number" semantic conventions. It represents the part number of -// the part being uploaded in a multipart-upload operation. This is a positive -// integer between 1 and 10,000. -func AWSS3PartNumber(val int) attribute.KeyValue { - return AWSS3PartNumberKey.Int(val) -} - -// AWSS3UploadID returns an attribute KeyValue conforming to the -// "aws.s3.upload_id" semantic conventions. It represents the upload ID that -// identifies the multipart upload. -func AWSS3UploadID(val string) attribute.KeyValue { - return AWSS3UploadIDKey.String(val) -} - -// Semantic conventions to apply when instrumenting the GraphQL implementation. -// They map GraphQL operations to attributes on a Span. -const ( - // GraphqlDocumentKey is the attribute Key conforming to the - // "graphql.document" semantic conventions. It represents the GraphQL - // document being executed. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'query findBookByID { bookByID(id: ?) { name } }' - // Note: The value may be sanitized to exclude sensitive information. - GraphqlDocumentKey = attribute.Key("graphql.document") - - // GraphqlOperationNameKey is the attribute Key conforming to the - // "graphql.operation.name" semantic conventions. It represents the name of - // the operation being executed. - // - // Type: string - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'findBookByID' - GraphqlOperationNameKey = attribute.Key("graphql.operation.name") - - // GraphqlOperationTypeKey is the attribute Key conforming to the - // "graphql.operation.type" semantic conventions. It represents the type of - // the operation being executed. - // - // Type: Enum - // RequirementLevel: Optional - // Stability: experimental - // Examples: 'query', 'mutation', 'subscription' - GraphqlOperationTypeKey = attribute.Key("graphql.operation.type") -) - -var ( - // GraphQL query - GraphqlOperationTypeQuery = GraphqlOperationTypeKey.String("query") - // GraphQL mutation - GraphqlOperationTypeMutation = GraphqlOperationTypeKey.String("mutation") - // GraphQL subscription - GraphqlOperationTypeSubscription = GraphqlOperationTypeKey.String("subscription") -) - -// GraphqlDocument returns an attribute KeyValue conforming to the -// "graphql.document" semantic conventions. It represents the GraphQL document -// being executed. -func GraphqlDocument(val string) attribute.KeyValue { - return GraphqlDocumentKey.String(val) -} - -// GraphqlOperationName returns an attribute KeyValue conforming to the -// "graphql.operation.name" semantic conventions. It represents the name of the -// operation being executed. -func GraphqlOperationName(val string) attribute.KeyValue { - return GraphqlOperationNameKey.String(val) -} diff --git a/vendor/go.opentelemetry.io/otel/trace/config.go b/vendor/go.opentelemetry.io/otel/trace/config.go index 273d58e00..9c0b720a4 100644 --- a/vendor/go.opentelemetry.io/otel/trace/config.go +++ b/vendor/go.opentelemetry.io/otel/trace/config.go @@ -213,7 +213,7 @@ var _ SpanStartEventOption = attributeOption{} // WithAttributes adds the attributes related to a span life-cycle event. // These attributes are used to describe the work a Span represents when this -// option is provided to a Span's start or end events. Otherwise, these +// option is provided to a Span's start event. Otherwise, these // attributes provide additional information about the event being recorded // (e.g. error, state change, processing progress, system event). // diff --git a/vendor/go.opentelemetry.io/otel/version.go b/vendor/go.opentelemetry.io/otel/version.go index 6d3c7b1f4..fb7d12673 100644 --- a/vendor/go.opentelemetry.io/otel/version.go +++ b/vendor/go.opentelemetry.io/otel/version.go @@ -5,5 +5,5 @@ package otel // import "go.opentelemetry.io/otel" // Version is the current release version of OpenTelemetry in use. func Version() string { - return "1.31.0" + return "1.33.0" } diff --git a/vendor/go.opentelemetry.io/otel/versions.yaml b/vendor/go.opentelemetry.io/otel/versions.yaml index cdebdb5eb..9f878cd1f 100644 --- a/vendor/go.opentelemetry.io/otel/versions.yaml +++ b/vendor/go.opentelemetry.io/otel/versions.yaml @@ -3,19 +3,13 @@ module-sets: stable-v1: - version: v1.31.0 + version: v1.33.0 modules: - go.opentelemetry.io/otel - go.opentelemetry.io/otel/bridge/opencensus - go.opentelemetry.io/otel/bridge/opencensus/test - go.opentelemetry.io/otel/bridge/opentracing - go.opentelemetry.io/otel/bridge/opentracing/test - - go.opentelemetry.io/otel/example/dice - - go.opentelemetry.io/otel/example/namedtracer - - go.opentelemetry.io/otel/example/opencensus - - go.opentelemetry.io/otel/example/otel-collector - - go.opentelemetry.io/otel/example/passthrough - - go.opentelemetry.io/otel/example/zipkin - go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc - go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp - go.opentelemetry.io/otel/exporters/otlp/otlptrace @@ -29,12 +23,11 @@ module-sets: - go.opentelemetry.io/otel/sdk/metric - go.opentelemetry.io/otel/trace experimental-metrics: - version: v0.53.0 + version: v0.55.0 modules: - - go.opentelemetry.io/otel/example/prometheus - go.opentelemetry.io/otel/exporters/prometheus experimental-logs: - version: v0.7.0 + version: v0.9.0 modules: - go.opentelemetry.io/otel/log - go.opentelemetry.io/otel/sdk/log @@ -42,7 +35,7 @@ module-sets: - go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp - go.opentelemetry.io/otel/exporters/stdout/stdoutlog experimental-schema: - version: v0.0.10 + version: v0.0.12 modules: - go.opentelemetry.io/otel/schema excluded-modules: diff --git a/vendor/google.golang.org/genproto/googleapis/api/httpbody/httpbody.pb.go b/vendor/google.golang.org/genproto/googleapis/api/httpbody/httpbody.pb.go index e7d3805e3..f388426b0 100644 --- a/vendor/google.golang.org/genproto/googleapis/api/httpbody/httpbody.pb.go +++ b/vendor/google.golang.org/genproto/googleapis/api/httpbody/httpbody.pb.go @@ -159,14 +159,14 @@ var file_google_api_httpbody_proto_rawDesc = []byte{ 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x34, 0x0a, 0x0a, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, - 0x0a, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x42, 0x68, 0x0a, 0x0e, 0x63, + 0x0a, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x42, 0x65, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x61, 0x70, 0x69, 0x42, 0x0d, 0x48, 0x74, 0x74, 0x70, 0x42, 0x6f, 0x64, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x65, 0x6e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x61, 0x70, 0x69, 0x73, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x68, 0x74, 0x74, 0x70, 0x62, 0x6f, - 0x64, 0x79, 0x3b, 0x68, 0x74, 0x74, 0x70, 0x62, 0x6f, 0x64, 0x79, 0xf8, 0x01, 0x01, 0xa2, 0x02, - 0x04, 0x47, 0x41, 0x50, 0x49, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x64, 0x79, 0x3b, 0x68, 0x74, 0x74, 0x70, 0x62, 0x6f, 0x64, 0x79, 0xa2, 0x02, 0x04, 0x47, 0x41, + 0x50, 0x49, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/vendor/modules.txt b/vendor/modules.txt index bd161994d..3fa1023a1 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -34,7 +34,7 @@ github.com/cespare/xxhash/v2 github.com/cncf/xds/go/udpa/annotations github.com/cncf/xds/go/xds/annotations/v3 github.com/cncf/xds/go/xds/core/v3 -# github.com/containerd/containerd v1.7.23 +# github.com/containerd/containerd v1.7.24 ## explicit; go 1.21 github.com/containerd/containerd/archive/compression github.com/containerd/containerd/content @@ -52,7 +52,7 @@ github.com/containerd/containerd/remotes/docker/schema1 github.com/containerd/containerd/remotes/errors github.com/containerd/containerd/tracing github.com/containerd/containerd/version -# github.com/containerd/errdefs v0.3.0 +# github.com/containerd/errdefs v1.0.0 ## explicit; go 1.20 github.com/containerd/errdefs # github.com/containerd/log v0.1.0 @@ -96,9 +96,10 @@ github.com/envoyproxy/protoc-gen-validate/validate # github.com/felixge/httpsnoop v1.0.4 ## explicit; go 1.13 github.com/felixge/httpsnoop -# github.com/fsnotify/fsnotify v1.7.0 +# github.com/fsnotify/fsnotify v1.8.0 ## explicit; go 1.17 github.com/fsnotify/fsnotify +github.com/fsnotify/fsnotify/internal # github.com/go-ini/ini v1.67.0 ## explicit github.com/go-ini/ini @@ -147,8 +148,8 @@ github.com/google/uuid # github.com/gorilla/mux v1.8.1 ## explicit; go 1.20 github.com/gorilla/mux -# github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 -## explicit; go 1.20 +# github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 +## explicit; go 1.22.7 github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule github.com/grpc-ecosystem/grpc-gateway/v2/runtime github.com/grpc-ecosystem/grpc-gateway/v2/utilities @@ -194,26 +195,16 @@ github.com/munnerz/goautoneg # github.com/olekukonko/tablewriter v0.0.5 ## explicit; go 1.12 github.com/olekukonko/tablewriter -# github.com/open-policy-agent/opa v0.70.0 -## explicit; go 1.21 +# github.com/open-policy-agent/opa v1.0.0 +## explicit; go 1.22.7 github.com/open-policy-agent/opa/ast -github.com/open-policy-agent/opa/ast/internal/scanner -github.com/open-policy-agent/opa/ast/internal/tokens github.com/open-policy-agent/opa/ast/json -github.com/open-policy-agent/opa/ast/location github.com/open-policy-agent/opa/bundle github.com/open-policy-agent/opa/capabilities github.com/open-policy-agent/opa/cmd github.com/open-policy-agent/opa/cmd/internal/env github.com/open-policy-agent/opa/cmd/internal/exec -github.com/open-policy-agent/opa/compile github.com/open-policy-agent/opa/config -github.com/open-policy-agent/opa/cover -github.com/open-policy-agent/opa/dependencies -github.com/open-policy-agent/opa/download -github.com/open-policy-agent/opa/features/tracing -github.com/open-policy-agent/opa/features/wasm -github.com/open-policy-agent/opa/format github.com/open-policy-agent/opa/hooks github.com/open-policy-agent/opa/internal/bundle github.com/open-policy-agent/opa/internal/bundle/inspect @@ -277,56 +268,90 @@ github.com/open-policy-agent/opa/internal/wasm/sdk/opa/capabilities github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors github.com/open-policy-agent/opa/internal/wasm/types github.com/open-policy-agent/opa/internal/wasm/util -github.com/open-policy-agent/opa/ir -github.com/open-policy-agent/opa/keys github.com/open-policy-agent/opa/loader -github.com/open-policy-agent/opa/loader/extension -github.com/open-policy-agent/opa/loader/filter github.com/open-policy-agent/opa/logging github.com/open-policy-agent/opa/logging/test github.com/open-policy-agent/opa/metrics github.com/open-policy-agent/opa/plugins -github.com/open-policy-agent/opa/plugins/bundle -github.com/open-policy-agent/opa/plugins/discovery github.com/open-policy-agent/opa/plugins/logs -github.com/open-policy-agent/opa/plugins/logs/status -github.com/open-policy-agent/opa/plugins/rest -github.com/open-policy-agent/opa/plugins/server/decoding -github.com/open-policy-agent/opa/plugins/server/encoding -github.com/open-policy-agent/opa/plugins/server/metrics -github.com/open-policy-agent/opa/plugins/status -github.com/open-policy-agent/opa/profiler -github.com/open-policy-agent/opa/refactor github.com/open-policy-agent/opa/rego -github.com/open-policy-agent/opa/repl -github.com/open-policy-agent/opa/resolver github.com/open-policy-agent/opa/resolver/wasm github.com/open-policy-agent/opa/runtime -github.com/open-policy-agent/opa/schemas -github.com/open-policy-agent/opa/sdk github.com/open-policy-agent/opa/server -github.com/open-policy-agent/opa/server/authorizer -github.com/open-policy-agent/opa/server/handlers -github.com/open-policy-agent/opa/server/identifier -github.com/open-policy-agent/opa/server/types -github.com/open-policy-agent/opa/server/writer github.com/open-policy-agent/opa/storage -github.com/open-policy-agent/opa/storage/disk github.com/open-policy-agent/opa/storage/inmem -github.com/open-policy-agent/opa/storage/internal/errors -github.com/open-policy-agent/opa/storage/internal/ptr -github.com/open-policy-agent/opa/tester github.com/open-policy-agent/opa/topdown github.com/open-policy-agent/opa/topdown/builtins github.com/open-policy-agent/opa/topdown/cache -github.com/open-policy-agent/opa/topdown/copypropagation -github.com/open-policy-agent/opa/topdown/lineage github.com/open-policy-agent/opa/topdown/print github.com/open-policy-agent/opa/tracing -github.com/open-policy-agent/opa/types github.com/open-policy-agent/opa/util -github.com/open-policy-agent/opa/util/decoding -github.com/open-policy-agent/opa/version +github.com/open-policy-agent/opa/v1/ast +github.com/open-policy-agent/opa/v1/ast/internal/scanner +github.com/open-policy-agent/opa/v1/ast/internal/tokens +github.com/open-policy-agent/opa/v1/ast/json +github.com/open-policy-agent/opa/v1/ast/location +github.com/open-policy-agent/opa/v1/bundle +github.com/open-policy-agent/opa/v1/capabilities +github.com/open-policy-agent/opa/v1/compile +github.com/open-policy-agent/opa/v1/config +github.com/open-policy-agent/opa/v1/cover +github.com/open-policy-agent/opa/v1/dependencies +github.com/open-policy-agent/opa/v1/download +github.com/open-policy-agent/opa/v1/features/tracing +github.com/open-policy-agent/opa/v1/features/wasm +github.com/open-policy-agent/opa/v1/format +github.com/open-policy-agent/opa/v1/hooks +github.com/open-policy-agent/opa/v1/ir +github.com/open-policy-agent/opa/v1/keys +github.com/open-policy-agent/opa/v1/loader +github.com/open-policy-agent/opa/v1/loader/extension +github.com/open-policy-agent/opa/v1/loader/filter +github.com/open-policy-agent/opa/v1/logging +github.com/open-policy-agent/opa/v1/logging/test +github.com/open-policy-agent/opa/v1/metrics +github.com/open-policy-agent/opa/v1/plugins +github.com/open-policy-agent/opa/v1/plugins/bundle +github.com/open-policy-agent/opa/v1/plugins/discovery +github.com/open-policy-agent/opa/v1/plugins/logs +github.com/open-policy-agent/opa/v1/plugins/logs/status +github.com/open-policy-agent/opa/v1/plugins/rest +github.com/open-policy-agent/opa/v1/plugins/server/decoding +github.com/open-policy-agent/opa/v1/plugins/server/encoding +github.com/open-policy-agent/opa/v1/plugins/server/metrics +github.com/open-policy-agent/opa/v1/plugins/status +github.com/open-policy-agent/opa/v1/profiler +github.com/open-policy-agent/opa/v1/refactor +github.com/open-policy-agent/opa/v1/rego +github.com/open-policy-agent/opa/v1/repl +github.com/open-policy-agent/opa/v1/resolver +github.com/open-policy-agent/opa/v1/resolver/wasm +github.com/open-policy-agent/opa/v1/runtime +github.com/open-policy-agent/opa/v1/schemas +github.com/open-policy-agent/opa/v1/sdk +github.com/open-policy-agent/opa/v1/server +github.com/open-policy-agent/opa/v1/server/authorizer +github.com/open-policy-agent/opa/v1/server/handlers +github.com/open-policy-agent/opa/v1/server/identifier +github.com/open-policy-agent/opa/v1/server/types +github.com/open-policy-agent/opa/v1/server/writer +github.com/open-policy-agent/opa/v1/storage +github.com/open-policy-agent/opa/v1/storage/disk +github.com/open-policy-agent/opa/v1/storage/inmem +github.com/open-policy-agent/opa/v1/storage/internal/errors +github.com/open-policy-agent/opa/v1/storage/internal/ptr +github.com/open-policy-agent/opa/v1/tester +github.com/open-policy-agent/opa/v1/topdown +github.com/open-policy-agent/opa/v1/topdown/builtins +github.com/open-policy-agent/opa/v1/topdown/cache +github.com/open-policy-agent/opa/v1/topdown/copypropagation +github.com/open-policy-agent/opa/v1/topdown/lineage +github.com/open-policy-agent/opa/v1/topdown/print +github.com/open-policy-agent/opa/v1/tracing +github.com/open-policy-agent/opa/v1/types +github.com/open-policy-agent/opa/v1/util +github.com/open-policy-agent/opa/v1/util/decoding +github.com/open-policy-agent/opa/v1/version # github.com/opencontainers/go-digest v1.0.0 ## explicit; go 1.13 github.com/opencontainers/go-digest @@ -426,8 +451,6 @@ github.com/spf13/viper/internal/encoding/json github.com/spf13/viper/internal/encoding/toml github.com/spf13/viper/internal/encoding/yaml github.com/spf13/viper/internal/features -# github.com/stretchr/testify v1.10.0 -## explicit; go 1.17 # github.com/subosito/gotenv v1.6.0 ## explicit; go 1.18 github.com/subosito/gotenv @@ -450,20 +473,25 @@ go.opencensus.io/internal go.opencensus.io/trace go.opencensus.io/trace/internal go.opencensus.io/trace/tracestate +# go.opentelemetry.io/auto/sdk v1.1.0 +## explicit; go 1.22.0 +go.opentelemetry.io/auto/sdk +go.opentelemetry.io/auto/sdk/internal/telemetry # go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 ## explicit; go 1.21 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal -# go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 -## explicit; go 1.21 +# go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 +## explicit; go 1.22.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconvutil # go.opentelemetry.io/contrib/propagators/b3 v1.28.0 ## explicit; go 1.21 go.opentelemetry.io/contrib/propagators/b3 -# go.opentelemetry.io/otel v1.31.0 -## explicit; go 1.22 +# go.opentelemetry.io/otel v1.33.0 +## explicit; go 1.22.0 go.opentelemetry.io/otel go.opentelemetry.io/otel/attribute go.opentelemetry.io/otel/baggage @@ -477,27 +505,26 @@ go.opentelemetry.io/otel/semconv/internal go.opentelemetry.io/otel/semconv/v1.17.0 go.opentelemetry.io/otel/semconv/v1.20.0 go.opentelemetry.io/otel/semconv/v1.21.0 -go.opentelemetry.io/otel/semconv/v1.24.0 go.opentelemetry.io/otel/semconv/v1.26.0 go.opentelemetry.io/otel/semconv/v1.7.0 -# go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 -## explicit; go 1.21 +# go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.33.0 +## explicit; go 1.22.7 go.opentelemetry.io/otel/exporters/otlp/otlptrace go.opentelemetry.io/otel/exporters/otlp/otlptrace/internal/tracetransform -# go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0 -## explicit; go 1.21 +# go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.33.0 +## explicit; go 1.22.7 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/envconfig go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/otlpconfig go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc/internal/retry -# go.opentelemetry.io/otel/metric v1.31.0 -## explicit; go 1.22 +# go.opentelemetry.io/otel/metric v1.33.0 +## explicit; go 1.22.0 go.opentelemetry.io/otel/metric go.opentelemetry.io/otel/metric/embedded go.opentelemetry.io/otel/metric/noop -# go.opentelemetry.io/otel/sdk v1.31.0 -## explicit; go 1.22 +# go.opentelemetry.io/otel/sdk v1.33.0 +## explicit; go 1.22.0 go.opentelemetry.io/otel/sdk go.opentelemetry.io/otel/sdk/instrumentation go.opentelemetry.io/otel/sdk/internal/env @@ -505,13 +532,13 @@ go.opentelemetry.io/otel/sdk/internal/x go.opentelemetry.io/otel/sdk/resource go.opentelemetry.io/otel/sdk/trace go.opentelemetry.io/otel/sdk/trace/tracetest -# go.opentelemetry.io/otel/trace v1.31.0 -## explicit; go 1.22 +# go.opentelemetry.io/otel/trace v1.33.0 +## explicit; go 1.22.0 go.opentelemetry.io/otel/trace go.opentelemetry.io/otel/trace/embedded go.opentelemetry.io/otel/trace/noop -# go.opentelemetry.io/proto/otlp v1.3.1 -## explicit; go 1.17 +# go.opentelemetry.io/proto/otlp v1.4.0 +## explicit; go 1.22.7 go.opentelemetry.io/proto/otlp/collector/trace/v1 go.opentelemetry.io/proto/otlp/common/v1 go.opentelemetry.io/proto/otlp/resource/v1 @@ -543,7 +570,7 @@ golang.org/x/lint/golint golang.org/x/mod/internal/lazyregexp golang.org/x/mod/module golang.org/x/mod/semver -# golang.org/x/net v0.32.0 +# golang.org/x/net v0.33.0 ## explicit; go 1.18 golang.org/x/net/http/httpguts golang.org/x/net/http2 @@ -573,7 +600,7 @@ golang.org/x/text/secure/bidirule golang.org/x/text/transform golang.org/x/text/unicode/bidi golang.org/x/text/unicode/norm -# golang.org/x/time v0.7.0 +# golang.org/x/time v0.8.0 ## explicit; go 1.18 golang.org/x/time/rate # golang.org/x/tools v0.28.0 @@ -598,10 +625,10 @@ golang.org/x/tools/internal/stdlib golang.org/x/tools/internal/typeparams golang.org/x/tools/internal/typesinternal golang.org/x/tools/internal/versions -# google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53 +# google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 ## explicit; go 1.21 google.golang.org/genproto/googleapis/api/httpbody -# google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 +# google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 ## explicit; go 1.21 google.golang.org/genproto/googleapis/rpc/code google.golang.org/genproto/googleapis/rpc/errdetails