From 07bc9c10ca066f4f68bc51eca298e6ab29291ab2 Mon Sep 17 00:00:00 2001 From: Prajwal S N Date: Wed, 1 Nov 2023 21:45:51 +0530 Subject: [PATCH] fix: handle dereference in proto2 bindings Signed-off-by: Prajwal S N --- Makefile | 2 +- processor.go | 17 +++ protogetter.go | 1 + testdata/Makefile | 2 +- testdata/proto/test_proto2.pb.go | 216 +++++++++++++++++++++++++++++++ testdata/proto/test_proto2.proto | 15 +++ testdata/test_proto2.go | 17 +++ testdata/test_proto2.go.golden | 17 +++ 8 files changed, 285 insertions(+), 2 deletions(-) create mode 100644 testdata/proto/test_proto2.pb.go create mode 100644 testdata/proto/test_proto2.proto create mode 100644 testdata/test_proto2.go create mode 100644 testdata/test_proto2.go.golden diff --git a/Makefile b/Makefile index af4b62b..4c2a62a 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: test test: - cd testdata && make vendor + $(MAKE) -C testdata vendor go test -v ./... .PHONY: install diff --git a/processor.go b/processor.go index 4026c41..167248d 100644 --- a/processor.go +++ b/processor.go @@ -70,6 +70,23 @@ func (c *processor) process(n ast.Node) (*Result, error) { c.processInner(x) + case *ast.StarExpr: + f, ok := x.X.(*ast.SelectorExpr) + if !ok { + return &Result{}, nil + } + + if !isProtoMessage(c.info, f.X) { + return &Result{}, nil + } + + // proto2 generates fields as pointers. Hence, the indirection + // must be removed when generating the fix for the case. + // The `*` is retained in `c.from`, but excluded from the fix + // present in the `c.to`. + c.writeFrom("*") + c.processInner(x.X) + default: return nil, fmt.Errorf("not implemented for type: %s (%s)", reflect.TypeOf(x), formatNode(n)) } diff --git a/protogetter.go b/protogetter.go index 4328492..0061994 100644 --- a/protogetter.go +++ b/protogetter.go @@ -99,6 +99,7 @@ func Run(pass *analysis.Pass, cfg *Config) ([]Issue, error) { (*ast.AssignStmt)(nil), (*ast.CallExpr)(nil), (*ast.SelectorExpr)(nil), + (*ast.StarExpr)(nil), (*ast.IncDecStmt)(nil), (*ast.UnaryExpr)(nil), } diff --git a/testdata/Makefile b/testdata/Makefile index b3e7b34..79570b4 100644 --- a/testdata/Makefile +++ b/testdata/Makefile @@ -17,4 +17,4 @@ protoc: --go_opt paths=source_relative \ --go-grpc_out proto \ --go-grpc_opt paths=source_relative \ - proto/test.proto \ No newline at end of file + proto/test.proto proto/test_proto2.proto diff --git a/testdata/proto/test_proto2.pb.go b/testdata/proto/test_proto2.pb.go new file mode 100644 index 0000000..fddec6a --- /dev/null +++ b/testdata/proto/test_proto2.pb.go @@ -0,0 +1,216 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.31.0 +// protoc v4.24.4 +// source: testdata/proto/test_proto2.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type TestProto2 struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + D *float64 `protobuf:"fixed64,1,req,name=d" json:"d,omitempty"` + F *float32 `protobuf:"fixed32,2,req,name=f" json:"f,omitempty"` + I32 *int32 `protobuf:"varint,3,req,name=i32" json:"i32,omitempty"` + I64 *int64 `protobuf:"varint,4,req,name=i64" json:"i64,omitempty"` + U32 *uint32 `protobuf:"varint,5,opt,name=u32" json:"u32,omitempty"` + U64 *uint64 `protobuf:"varint,6,opt,name=u64" json:"u64,omitempty"` + T *bool `protobuf:"varint,7,opt,name=t" json:"t,omitempty"` + B []byte `protobuf:"bytes,8,opt,name=b" json:"b,omitempty"` + S *string `protobuf:"bytes,9,opt,name=s" json:"s,omitempty"` +} + +func (x *TestProto2) Reset() { + *x = TestProto2{} + if protoimpl.UnsafeEnabled { + mi := &file_testdata_proto_test_proto2_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TestProto2) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TestProto2) ProtoMessage() {} + +func (x *TestProto2) ProtoReflect() protoreflect.Message { + mi := &file_testdata_proto_test_proto2_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TestProto2.ProtoReflect.Descriptor instead. +func (*TestProto2) Descriptor() ([]byte, []int) { + return file_testdata_proto_test_proto2_proto_rawDescGZIP(), []int{0} +} + +func (x *TestProto2) GetD() float64 { + if x != nil && x.D != nil { + return *x.D + } + return 0 +} + +func (x *TestProto2) GetF() float32 { + if x != nil && x.F != nil { + return *x.F + } + return 0 +} + +func (x *TestProto2) GetI32() int32 { + if x != nil && x.I32 != nil { + return *x.I32 + } + return 0 +} + +func (x *TestProto2) GetI64() int64 { + if x != nil && x.I64 != nil { + return *x.I64 + } + return 0 +} + +func (x *TestProto2) GetU32() uint32 { + if x != nil && x.U32 != nil { + return *x.U32 + } + return 0 +} + +func (x *TestProto2) GetU64() uint64 { + if x != nil && x.U64 != nil { + return *x.U64 + } + return 0 +} + +func (x *TestProto2) GetT() bool { + if x != nil && x.T != nil { + return *x.T + } + return false +} + +func (x *TestProto2) GetB() []byte { + if x != nil { + return x.B + } + return nil +} + +func (x *TestProto2) GetS() string { + if x != nil && x.S != nil { + return *x.S + } + return "" +} + +var File_testdata_proto_test_proto2_proto protoreflect.FileDescriptor + +var file_testdata_proto_test_proto2_proto_rawDesc = []byte{ + 0x0a, 0x20, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x2f, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x32, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x22, 0x9a, 0x01, 0x0a, 0x0a, 0x54, 0x65, 0x73, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x32, 0x12, 0x0c, 0x0a, 0x01, 0x64, 0x18, 0x01, 0x20, 0x02, 0x28, 0x01, 0x52, 0x01, 0x64, 0x12, + 0x0c, 0x0a, 0x01, 0x66, 0x18, 0x02, 0x20, 0x02, 0x28, 0x02, 0x52, 0x01, 0x66, 0x12, 0x10, 0x0a, + 0x03, 0x69, 0x33, 0x32, 0x18, 0x03, 0x20, 0x02, 0x28, 0x05, 0x52, 0x03, 0x69, 0x33, 0x32, 0x12, + 0x10, 0x0a, 0x03, 0x69, 0x36, 0x34, 0x18, 0x04, 0x20, 0x02, 0x28, 0x03, 0x52, 0x03, 0x69, 0x36, + 0x34, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x33, 0x32, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, + 0x75, 0x33, 0x32, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x36, 0x34, 0x18, 0x06, 0x20, 0x01, 0x28, 0x04, + 0x52, 0x03, 0x75, 0x36, 0x34, 0x12, 0x0c, 0x0a, 0x01, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x01, 0x74, 0x12, 0x0c, 0x0a, 0x01, 0x62, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x01, + 0x62, 0x12, 0x0c, 0x0a, 0x01, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x01, 0x73, 0x42, + 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x68, + 0x6f, 0x73, 0x74, 0x69, 0x61, 0x6d, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x67, 0x65, 0x74, 0x74, + 0x65, 0x72, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, +} + +var ( + file_testdata_proto_test_proto2_proto_rawDescOnce sync.Once + file_testdata_proto_test_proto2_proto_rawDescData = file_testdata_proto_test_proto2_proto_rawDesc +) + +func file_testdata_proto_test_proto2_proto_rawDescGZIP() []byte { + file_testdata_proto_test_proto2_proto_rawDescOnce.Do(func() { + file_testdata_proto_test_proto2_proto_rawDescData = protoimpl.X.CompressGZIP(file_testdata_proto_test_proto2_proto_rawDescData) + }) + return file_testdata_proto_test_proto2_proto_rawDescData +} + +var file_testdata_proto_test_proto2_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_testdata_proto_test_proto2_proto_goTypes = []interface{}{ + (*TestProto2)(nil), // 0: TestProto2 +} +var file_testdata_proto_test_proto2_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_testdata_proto_test_proto2_proto_init() } +func file_testdata_proto_test_proto2_proto_init() { + if File_testdata_proto_test_proto2_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_testdata_proto_test_proto2_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TestProto2); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_testdata_proto_test_proto2_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_testdata_proto_test_proto2_proto_goTypes, + DependencyIndexes: file_testdata_proto_test_proto2_proto_depIdxs, + MessageInfos: file_testdata_proto_test_proto2_proto_msgTypes, + }.Build() + File_testdata_proto_test_proto2_proto = out.File + file_testdata_proto_test_proto2_proto_rawDesc = nil + file_testdata_proto_test_proto2_proto_goTypes = nil + file_testdata_proto_test_proto2_proto_depIdxs = nil +} diff --git a/testdata/proto/test_proto2.proto b/testdata/proto/test_proto2.proto new file mode 100644 index 0000000..2067f26 --- /dev/null +++ b/testdata/proto/test_proto2.proto @@ -0,0 +1,15 @@ +syntax = "proto2"; + +option go_package = "github.com/ghostiam/protogetter/testdata/proto"; + +message TestProto2 { + required double d = 1; + required float f = 2; + required int32 i32 = 3; + required int64 i64 = 4; + optional uint32 u32 = 5; + optional uint64 u64 = 6; + optional bool t = 7; + optional bytes b = 8; + optional string s = 9; +} diff --git a/testdata/test_proto2.go b/testdata/test_proto2.go new file mode 100644 index 0000000..5f6f204 --- /dev/null +++ b/testdata/test_proto2.go @@ -0,0 +1,17 @@ +package testdata + +import ( + "github.com/ghostiam/protogetter/testdata/proto" +) + +func testInvalidProto2(t *proto.TestProto2) { + _ = *t.D // want `avoid direct access to proto field \*t\.D, use t\.GetD\(\) instead` + _ = *t.F // want `avoid direct access to proto field \*t\.F, use t\.GetF\(\) instead` + _ = *t.I32 // want `avoid direct access to proto field \*t\.I32, use t\.GetI32\(\) instead` + _ = *t.I64 // want `avoid direct access to proto field \*t\.I64, use t\.GetI64\(\) instead` + _ = *t.U32 // want `avoid direct access to proto field \*t\.U32, use t\.GetU32\(\) instead` + _ = *t.U64 // want `avoid direct access to proto field \*t\.U64, use t\.GetU64\(\) instead` + _ = *t.T // want `avoid direct access to proto field \*t\.T, use t\.GetT\(\) instead` + _ = t.B // want `avoid direct access to proto field t\.B, use t\.GetB\(\) instead` + _ = *t.S // want `avoid direct access to proto field \*t\.S, use t\.GetS\(\) instead` +} diff --git a/testdata/test_proto2.go.golden b/testdata/test_proto2.go.golden new file mode 100644 index 0000000..627687a --- /dev/null +++ b/testdata/test_proto2.go.golden @@ -0,0 +1,17 @@ +package testdata + +import ( + "github.com/ghostiam/protogetter/testdata/proto" +) + +func testInvalidProto2(t *proto.TestProto2) { + _ = t.GetD() // want `avoid direct access to proto field \*t\.D, use t\.GetD\(\) instead` + _ = t.GetF() // want `avoid direct access to proto field \*t\.F, use t\.GetF\(\) instead` + _ = t.GetI32() // want `avoid direct access to proto field \*t\.I32, use t\.GetI32\(\) instead` + _ = t.GetI64() // want `avoid direct access to proto field \*t\.I64, use t\.GetI64\(\) instead` + _ = t.GetU32() // want `avoid direct access to proto field \*t\.U32, use t\.GetU32\(\) instead` + _ = t.GetU64() // want `avoid direct access to proto field \*t\.U64, use t\.GetU64\(\) instead` + _ = t.GetT() // want `avoid direct access to proto field \*t\.T, use t\.GetT\(\) instead` + _ = t.GetB() // want `avoid direct access to proto field t\.B, use t\.GetB\(\) instead` + _ = t.GetS() // want `avoid direct access to proto field \*t\.S, use t\.GetS\(\) instead` +}