Skip to content

Commit

Permalink
feat: add ability to omit root when writing avro
Browse files Browse the repository at this point in the history
  • Loading branch information
jepp2078 authored and ericwenn committed Aug 30, 2022
1 parent c402786 commit f5eaf84
Show file tree
Hide file tree
Showing 13 changed files with 1,279 additions and 103 deletions.
22 changes: 11 additions & 11 deletions encoding/protoavro/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (

// decodeJSON decodes the JSON encoded avro data and places the
// result in msg.
func decodeJSON(data interface{}, msg proto.Message) error {
return decodeMessage(data, msg.ProtoReflect())
func (o *SchemaOptions) decodeJSON(data interface{}, msg proto.Message) error {
return o.decodeMessage(data, msg.ProtoReflect())
}

func decodeMessage(data interface{}, msg protoreflect.Message) error {
func (o *SchemaOptions) decodeMessage(data interface{}, msg protoreflect.Message) error {
if data == nil {
return nil
}
Expand All @@ -28,28 +28,28 @@ func decodeMessage(data interface{}, msg protoreflect.Message) error {
// unwrap union
desc := msg.Descriptor()
if msgData, ok := d[string(desc.FullName())]; len(d) == 1 && ok {
return decodeMessage(msgData, msg)
return o.decodeMessage(msgData, msg)
}
for fieldName, fieldValue := range d {
fd, ok := findField(desc, fieldName)
if !ok {
return fmt.Errorf("unexpected field %s", fieldName)
}
if err := decodeField(fieldValue, msg, fd); err != nil {
if err := o.decodeField(fieldValue, msg, fd); err != nil {
return err
}
}
return nil
}

func decodeField(data interface{}, val protoreflect.Message, f protoreflect.FieldDescriptor) error {
func (o *SchemaOptions) decodeField(data interface{}, val protoreflect.Message, f protoreflect.FieldDescriptor) error {
if data == nil {
return nil
}
switch {
case f.IsMap():
mp := val.NewField(f).Map()
if err := decodeMap(data, f, mp); err != nil {
if err := o.decodeMap(data, f, mp); err != nil {
return err
}
val.Set(f, protoreflect.ValueOfMap(mp))
Expand All @@ -65,7 +65,7 @@ func decodeField(data interface{}, val protoreflect.Message, f protoreflect.Fiel
list.Append(list.NewElement())
continue
}
fieldValue, err := decodeFieldKind(el, list.NewElement(), f)
fieldValue, err := o.decodeFieldKind(el, list.NewElement(), f)
if err != nil {
return err
}
Expand All @@ -74,7 +74,7 @@ func decodeField(data interface{}, val protoreflect.Message, f protoreflect.Fiel
val.Set(f, protoreflect.ValueOfList(list))
return nil
default:
fieldValue, err := decodeFieldKind(data, val.NewField(f), f)
fieldValue, err := o.decodeFieldKind(data, val.NewField(f), f)
if err != nil {
return err
}
Expand All @@ -83,14 +83,14 @@ func decodeField(data interface{}, val protoreflect.Message, f protoreflect.Fiel
return nil
}

func decodeFieldKind(
func (o *SchemaOptions) decodeFieldKind(
data interface{},
mutable protoreflect.Value,
f protoreflect.FieldDescriptor,
) (protoreflect.Value, error) {
switch f.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
if err := decodeMessage(data, mutable.Message()); err != nil {
if err := o.decodeMessage(data, mutable.Message()); err != nil {
return protoreflect.Value{}, err
}
return mutable, nil
Expand Down
59 changes: 35 additions & 24 deletions encoding/protoavro/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ import (
)

// encodeJSON returns the Avro JSON encoding of message.
func encodeJSON(message proto.Message) (interface{}, error) {
return messageJSON(message.ProtoReflect())
func (o SchemaOptions) encodeJSON(message proto.Message) (interface{}, error) {
return o.messageJSON(message.ProtoReflect(), 0)
}

func unionValue(key string, value interface{}) map[string]interface{} {
func (o SchemaOptions) unionValue(key string, value interface{}) map[string]interface{} {
return map[string]interface{}{
key: value,
}
}

func messageJSON(message protoreflect.Message) (interface{}, error) {
func (o SchemaOptions) messageJSON(message protoreflect.Message, recursiveIndex int) (interface{}, error) {
if !message.IsValid() {
return nil, nil
}
if isWKT(message.Descriptor().FullName()) {
value, err := encodeWKT(message)
value, err := o.encodeWKT(message)
if err != nil {
return nil, err
}
Expand All @@ -38,7 +38,7 @@ func messageJSON(message protoreflect.Message) (interface{}, error) {
record[string(field.Name())] = nil
} else {
value := message.Get(field)
jsonValue, err := fieldJSON(field, value)
jsonValue, err := o.fieldJSON(field, value, recursiveIndex+1)
if err != nil {
return nil, err
}
Expand All @@ -47,69 +47,80 @@ func messageJSON(message protoreflect.Message) (interface{}, error) {
continue
}
value := message.Get(field)
jsonValue, err := fieldJSON(field, value)
jsonValue, err := o.fieldJSON(field, value, recursiveIndex+1)
if err != nil {
return nil, err
}
record[string(field.Name())] = jsonValue
}
if o.OmitRootElement && recursiveIndex == 0 {
return record, nil
}
return map[string]interface{}{
string(desc.FullName()): record,
}, nil
}

func fieldJSON(field protoreflect.FieldDescriptor, value protoreflect.Value) (interface{}, error) {
func (o SchemaOptions) fieldJSON(
field protoreflect.FieldDescriptor,
value protoreflect.Value,
recursiveIndex int,
) (interface{}, error) {
if field.IsList() {
list := make([]interface{}, 0, value.List().Len())
for i := 0; i < value.List().Len(); i++ {
v := value.List().Get(i)
fieldValue, err := fieldKindJSON(field, v)
fieldValue, err := o.fieldKindJSON(field, v, recursiveIndex)
if err != nil {
return nil, err
}
list = append(list, fieldValue)
}
return unionValue("array", list), nil
return o.unionValue("array", list), nil
}
if field.IsMap() {
return encodeMap(field, value.Map())
return o.encodeMap(field, value.Map(), recursiveIndex)
}
return fieldKindJSON(field, value)
return o.fieldKindJSON(field, value, recursiveIndex)
}

func fieldKindJSON(field protoreflect.FieldDescriptor, value protoreflect.Value) (interface{}, error) {
func (o SchemaOptions) fieldKindJSON(
field protoreflect.FieldDescriptor,
value protoreflect.Value,
recursiveIndex int,
) (interface{}, error) {
switch field.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
return messageJSON(value.Message())
return o.messageJSON(value.Message(), recursiveIndex)
case protoreflect.EnumKind:
return unionValue(
return o.unionValue(
string(field.Enum().FullName()),
string(field.Enum().Values().Get(int(value.Enum())).Name()),
), nil
case protoreflect.StringKind:
return unionValue("string", value.String()), nil
return o.unionValue("string", value.String()), nil
case protoreflect.Int32Kind,
protoreflect.Fixed32Kind,
protoreflect.Sfixed32Kind,
protoreflect.Sint32Kind:
return unionValue("int", int32(value.Int())), nil
return o.unionValue("int", int32(value.Int())), nil
case protoreflect.Uint32Kind:
return unionValue("int", int32(value.Uint())), nil
return o.unionValue("int", int32(value.Uint())), nil
case protoreflect.Int64Kind,
protoreflect.Fixed64Kind,
protoreflect.Sfixed64Kind,
protoreflect.Sint64Kind:
return unionValue("long", value.Int()), nil
return o.unionValue("long", value.Int()), nil
case protoreflect.Uint64Kind:
return unionValue("long", int64(value.Uint())), nil
return o.unionValue("long", int64(value.Uint())), nil
case protoreflect.BoolKind:
return unionValue("boolean", value.Bool()), nil
return o.unionValue("boolean", value.Bool()), nil
case protoreflect.BytesKind:
return unionValue("bytes", value.Bytes()), nil
return o.unionValue("bytes", value.Bytes()), nil
case protoreflect.DoubleKind:
return unionValue("double", value.Float()), nil
return o.unionValue("double", value.Float()), nil
case protoreflect.FloatKind:
return unionValue("float", float32(value.Float())), nil
return o.unionValue("float", float32(value.Float())), nil
}
return value.Interface(), nil
}
Loading

0 comments on commit f5eaf84

Please sign in to comment.