From 48cc9983bc98ca5bd3146cea7f7aa02cf6770134 Mon Sep 17 00:00:00 2001 From: Anthony Regeda Date: Sat, 4 Jan 2025 04:39:39 +0100 Subject: [PATCH] refact: use protoreflect to convert check request Signed-off-by: Anthony Regeda --- Makefile | 2 +- envoyauth/protomap.go | 60 +++++++++++++++++ envoyauth/protomap_test.go | 131 +++++++++++++++++++++++++++++++++++++ envoyauth/request.go | 24 ++----- go.mod | 1 + internal/internal_test.go | 4 +- 6 files changed, 202 insertions(+), 20 deletions(-) create mode 100644 envoyauth/protomap.go create mode 100644 envoyauth/protomap_test.go diff --git a/Makefile b/Makefile index 348e99ecc..9926a11a2 100644 --- a/Makefile +++ b/Makefile @@ -175,7 +175,7 @@ deploy-ci: docker-login ensure-release-dir start-builder ci-build-linux ci-build .PHONY: test test: generate - $(DISABLE_CGO) $(GO) test -v -bench=. $(PACKAGES) + $(DISABLE_CGO) $(GO) test -v -bench=. -benchmem $(PACKAGES) .PHONY: test-e2e test-e2e: diff --git a/envoyauth/protomap.go b/envoyauth/protomap.go new file mode 100644 index 000000000..ae440ec83 --- /dev/null +++ b/envoyauth/protomap.go @@ -0,0 +1,60 @@ +package envoyauth + +import ( + "google.golang.org/protobuf/reflect/protoreflect" +) + +// protomap converts protobuf message into map[string]any type using json names. +func protomap(msg protoreflect.Message) map[string]any { + v := msg.Interface() + // handle structpb.Struct + if mapper, ok := v.(interface{ AsMap() map[string]any }); ok { + return mapper.AsMap() + } + + result := make(map[string]any, msg.Descriptor().Fields().Len()) + + msg.Range(func(fd protoreflect.FieldDescriptor, value protoreflect.Value) bool { + name := fd.JSONName() + + switch { + case fd.IsMap(): + mapValue := value.Map() + mapResult := make(map[string]any, mapValue.Len()) + if fd.MapValue().Kind() == protoreflect.MessageKind { + mapValue.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool { + mapResult[key.String()] = protomap(val.Message()) + return true + }) + } else { + mapValue.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool { + mapResult[key.String()] = val.Interface() + return true + }) + } + result[name] = mapResult + + case fd.IsList(): + list := value.List() + listResult := make([]any, list.Len()) + for i := 0; i < list.Len(); i++ { + elem := list.Get(i) + if fd.Kind() == protoreflect.MessageKind { + listResult[i] = protomap(elem.Message()) + } else { + listResult[i] = elem.Interface() + } + } + result[name] = listResult + + case fd.Kind() == protoreflect.MessageKind: + result[name] = protomap(value.Message()) + default: + result[name] = value.Interface() + } + + return true + }) + + return result +} diff --git a/envoyauth/protomap_test.go b/envoyauth/protomap_test.go new file mode 100644 index 000000000..e94631645 --- /dev/null +++ b/envoyauth/protomap_test.go @@ -0,0 +1,131 @@ +package envoyauth + +import ( + "testing" + + "google.golang.org/protobuf/encoding/protojson" + + ext_authz "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" +) + +const extAuthzRequest = `{ + "attributes": { + "source": { + "address": { + "socketAddress": { + "address": "127.0.0.1" + } + }, + "service": "dummy", + "labels": { + "foo": "bar" + } + }, + "metadataContext": { + "filterMetadata": { + "dummy": { + "hello": "world", + "count": 1 + } + } + }, + "contextExtensions": { + "hello": "world" + }, + "request": { + "http": { + "id": "13359530607844510314", + "method": "GET", + "headers": { + ":authority": "192.168.99.100:31380", + ":method": "GET", + ":path": "/api/v1/products", + "accept": "*/*" + }, + "path": "/api/v1/products", + "host": "192.168.99.100:31380", + "protocol": "HTTP/1.1", + "body": "{\"firstname\": \"foo\", \"lastname\": \"bar\"}" + } + } + } +}` + +func Test_protomap(t *testing.T) { + var req ext_authz.CheckRequest + + if err := protojson.Unmarshal([]byte(extAuthzRequest), &req); err != nil { + t.Fatal(err) + } + + result := protomap(req.ProtoReflect()) + + if result == nil { + t.Fatal("not nil expected") + } + + assertMap(t, result, map[string]any{ + "attributes": map[string]any{ + "source": map[string]any{ + "service": "dummy", + "labels": map[string]any{ + "foo": "bar", + }, + "address": map[string]any{ + "socketAddress": map[string]any{ + "address": "127.0.0.1", + }, + }, + }, + "metadataContext": map[string]any{ + "filterMetadata": map[string]any{ + "dummy": map[string]any{ + "hello": "world", + "count": float64(1), + }, + }, + }, + "contextExtensions": map[string]any{ + "hello": "world", + }, + "request": map[string]any{ + "http": map[string]any{ + "id": "13359530607844510314", + "method": "GET", + "path": "/api/v1/products", + "host": "192.168.99.100:31380", + "protocol": "HTTP/1.1", + "body": "{\"firstname\": \"foo\", \"lastname\": \"bar\"}", + "headers": map[string]any{ + ":authority": "192.168.99.100:31380", + ":method": "GET", + ":path": "/api/v1/products", + "accept": "*/*", + }, + }, + }, + }, + }) +} + +func assertMap(t *testing.T, actual map[string]any, expected map[string]any) { + t.Helper() + if len(actual) != len(expected) { + t.Fatalf("different len of maps, actual %v, expected %v", actual, expected) + } + for k, ev := range expected { + av, ok := actual[k] + if !ok { + t.Fatalf("expected key %s not found", k) + } + if em, ok := ev.(map[string]any); ok { + am, ok := av.(map[string]any) + if !ok { + t.Fatalf("both values must be map[string]any, actual %T", av) + } + assertMap(t, em, am) + } else if ev != av { + t.Fatalf("values of key %s are different, actual %v (%[2]T), expected %v (%[3]T)", k, av, ev) + } + } +} diff --git a/envoyauth/request.go b/envoyauth/request.go index f422d13c1..a0e93bc96 100644 --- a/envoyauth/request.go +++ b/envoyauth/request.go @@ -2,7 +2,6 @@ package envoyauth import ( "encoding/binary" - "encoding/json" "fmt" "io" "mime" @@ -23,17 +22,17 @@ import ( "github.com/open-policy-agent/opa/util" ) -var v2Info = map[string]string{"ext_authz": "v2", "encoding": "encoding/json"} -var v3Info = map[string]string{"ext_authz": "v3", "encoding": "protojson"} +var v2Info = map[string]string{"ext_authz": "v2", "encoding": "protoreflect"} +var v3Info = map[string]string{"ext_authz": "v3", "encoding": "protoreflect"} // RequestToInput - Converts a CheckRequest in either protobuf 2 or 3 to an input map func RequestToInput(req interface{}, logger logging.Logger, protoSet *protoregistry.Files, skipRequestBodyParse bool) (map[string]interface{}, error) { var err error - var input map[string]interface{} - var bs, rawBody []byte + var rawBody []byte var path, body string var headers, version map[string]string + var protoreq protoreflect.Message // NOTE: The path/body/headers blocks look silly, but they allow us to retrieve // the parts of the incoming request we care about, without having to convert @@ -41,30 +40,21 @@ func RequestToInput(req interface{}, logger logging.Logger, protoSet *protoregis // etc -- we only care for its JSON representation as fed into evaluation later. switch req := req.(type) { case *ext_authz_v3.CheckRequest: - bs, err = protojson.Marshal(req) - if err != nil { - return nil, err - } + protoreq = req.ProtoReflect() path = req.GetAttributes().GetRequest().GetHttp().GetPath() body = req.GetAttributes().GetRequest().GetHttp().GetBody() headers = req.GetAttributes().GetRequest().GetHttp().GetHeaders() rawBody = req.GetAttributes().GetRequest().GetHttp().GetRawBody() version = v3Info case *ext_authz_v2.CheckRequest: - bs, err = json.Marshal(req) - if err != nil { - return nil, err - } + protoreq = req.ProtoReflect() path = req.GetAttributes().GetRequest().GetHttp().GetPath() body = req.GetAttributes().GetRequest().GetHttp().GetBody() headers = req.GetAttributes().GetRequest().GetHttp().GetHeaders() version = v2Info } - err = util.UnmarshalJSON(bs, &input) - if err != nil { - return nil, err - } + input := protomap(protoreq) input["version"] = version parsedPath, parsedQuery, err := getParsedPathAndQuery(path) diff --git a/go.mod b/go.mod index 7ec25914a..fd66d34cb 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/open-policy-agent/opa-envoy-plugin go 1.22.0 + toolchain go1.23.1 require ( diff --git a/internal/internal_test.go b/internal/internal_test.go index 852a554df..9400262b2 100644 --- a/internal/internal_test.go +++ b/internal/internal_test.go @@ -2185,7 +2185,7 @@ func TestVersionInfoInputV3(t *testing.T) { allow { input.version.ext_authz == "v3" - input.version.encoding == "protojson" + input.version.encoding == "protoreflect" } ` server := testAuthzServerWithModule(module, "envoy/authz/allow", nil, withCustomLogger(customLogger)) @@ -2212,7 +2212,7 @@ func TestVersionInfoInputV2(t *testing.T) { allow { input.version.ext_authz == "v2" - input.version.encoding == "encoding/json" + input.version.encoding == "protoreflect" } ` serverV3 := testAuthzServerWithModule(module, "envoy/authz/allow", nil, withCustomLogger(customLogger))