Skip to content

Commit

Permalink
support for recursive structs
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianiacobghiula committed Jun 23, 2024
1 parent e2e849d commit cd10c42
Show file tree
Hide file tree
Showing 22 changed files with 475 additions and 214 deletions.
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ For security reasons, the configuration `Config.MaxByteSliceSize` restricts the
by the `Reader`. The default maximum size is `1MiB` and is configurable. This is required to stop untrusted input from consuming all memory and
crashing the application. Should this not be needed, setting a negative number will disable the behaviour.

### Recursive Structs

At this moment recursive structs are not supported. It is planned for the future.

## Benchmark

Benchmark source code can be found at: [https://github.com/nrwiersma/avro-benchmarks](https://github.com/nrwiersma/avro-benchmarks)
Expand Down
120 changes: 80 additions & 40 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,48 +71,88 @@ func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder {
}

ptrType := typ.(*reflect2.UnsafePtrType)
decoder = decoderOfType(c, schema, ptrType.Elem())
decoder = newDecoderCreator(c).ofType(schema, ptrType.Elem())
c.addDecoderToCache(schema.CacheFingerprint(), rtype, decoder)
return decoder
}

func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
if dec := createDecoderOfMarshaler(cfg, schema, typ); dec != nil {
// deferDecoder is a ValDecoder whose Decode forwards to the decoder field.
// decoderCreator.ofType registers an empty deferDecoder for a record before
// building that record's decoder and assigns decoder afterwards, so a Ref back
// to the same record met during construction can be handed this placeholder.
type deferDecoder struct {
// decoder is the decoder that Decode delegates to; it is unset until assigned.
decoder ValDecoder
}

// Decode hands ptr and r to the wrapped decoder unchanged. The wrapped decoder
// is looked up at call time, so it may be assigned after the deferDecoder
// itself has been created.
func (d *deferDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
	target := d.decoder
	target.Decode(ptr, r)
}

// deferEncoder is a ValEncoder whose Encode forwards to the encoder field.
// encoderCreator.ofType registers an empty deferEncoder for a record before
// building that record's encoder and assigns encoder afterwards, so a Ref back
// to the same record met during construction can be handed this placeholder.
type deferEncoder struct {
// encoder is the encoder that Encode delegates to; it is unset until assigned.
encoder ValEncoder
}

// Encode hands ptr and w to the wrapped encoder unchanged. The wrapped encoder
// is looked up at call time, so it may be assigned after the deferEncoder
// itself has been created.
func (d *deferEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
	target := d.encoder
	target.Encode(ptr, w)
}

// decoderCreator builds ValDecoders for schema/type pairs. It bundles the
// frozen configuration with a map of decoders registered during the current
// build, which ofType consults when it resolves a Ref schema.
type decoderCreator struct {
// cfg is the configuration the created decoders are built against.
cfg *frozenConfig
// decoders holds the decoders registered so far, keyed by cacheKey
// (schema fingerprint plus runtime type).
decoders map[cacheKey]ValDecoder
}

// newDecoderCreator returns a decoderCreator bound to cfg with an empty,
// ready-to-write decoder registry.
func newDecoderCreator(cfg *frozenConfig) *decoderCreator {
	creator := new(decoderCreator)
	creator.cfg = cfg
	creator.decoders = map[cacheKey]ValDecoder{}
	return creator
}

// encoderCreator builds ValEncoders for schema/type pairs. It bundles the
// frozen configuration with a map of encoders registered during the current
// build, which ofType consults when it resolves a Ref schema.
type encoderCreator struct {
// cfg is the configuration the created encoders are built against.
cfg *frozenConfig
// encoders holds the encoders registered so far, keyed by cacheKey
// (schema fingerprint plus runtime type).
encoders map[cacheKey]ValEncoder
}

// newEncoderCreator returns an encoderCreator bound to cfg with an empty,
// ready-to-write encoder registry.
func newEncoderCreator(cfg *frozenConfig) *encoderCreator {
	creator := new(encoderCreator)
	creator.cfg = cfg
	creator.encoders = map[cacheKey]ValEncoder{}
	return creator
}

func (d decoderCreator) ofType(schema Schema, typ reflect2.Type) ValDecoder {
if dec := d.ofMarshaler(schema, typ); dec != nil {
return dec
}

// Handle eface case when it isnt a union
// Handle eface case when it isn't a union
if typ.Kind() == reflect.Interface && schema.Type() != Union {
if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
return newEfaceDecoder(cfg, schema)
return d.newEfaceDecoder(schema)
}
}

switch schema.Type() {
case String, Bytes, Int, Long, Float, Double, Boolean:
return createDecoderOfNative(schema.(*PrimitiveSchema), typ)

return d.ofNative(schema.(*PrimitiveSchema), typ)
case Record:
return createDecoderOfRecord(cfg, schema, typ)

key := cacheKey{fingerprint: schema.Fingerprint(), rtype: typ.RType()}
defDec := &deferDecoder{}
d.decoders[key] = defDec
defDec.decoder = d.ofRecord(schema.(*RecordSchema), typ)
return defDec.decoder
case Ref:
return decoderOfType(cfg, schema.(*RefSchema).Schema(), typ)

key := cacheKey{fingerprint: schema.(*RefSchema).Schema().Fingerprint(), rtype: typ.RType()}
if dec, f := d.decoders[key]; f {
return dec
}
return d.ofType(schema.(*RefSchema).Schema(), typ)
case Enum:
return createDecoderOfEnum(schema, typ)

return d.ofEnum(schema.(*EnumSchema), typ)
case Array:
return createDecoderOfArray(cfg, schema, typ)

return d.ofArray(schema.(*ArraySchema), typ)
case Map:
return createDecoderOfMap(cfg, schema, typ)

return d.ofMap(schema.(*MapSchema), typ)
case Union:
return createDecoderOfUnion(cfg, schema, typ)

return d.ofUnion(schema.(*UnionSchema), typ)
case Fixed:
return createDecoderOfFixed(schema, typ)

return d.ofFixed(schema.(*FixedSchema), typ)
default:
// It is impossible to get here with a valid schema
return &errorDecoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())}
Expand All @@ -130,7 +170,7 @@ func (c *frozenConfig) EncoderOf(schema Schema, typ reflect2.Type) ValEncoder {
return encoder
}

encoder = encoderOfType(c, schema, typ)
encoder = newEncoderCreator(c).ofType(schema, typ)
if typ.LikePtr() {
encoder = &onePtrEncoder{encoder}
}
Expand All @@ -146,8 +186,8 @@ func (e *onePtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
e.enc.Encode(noescape(unsafe.Pointer(&ptr)), w)
}

func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
if enc := createEncoderOfMarshaler(cfg, schema, typ); enc != nil {
func (e encoderCreator) ofType(schema Schema, typ reflect2.Type) ValEncoder {
if enc := e.ofMarshaler(schema, typ); enc != nil {
return enc
}

Expand All @@ -157,29 +197,29 @@ func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncod

switch schema.Type() {
case String, Bytes, Int, Long, Float, Double, Boolean, Null:
return createEncoderOfNative(schema, typ)

return e.ofNative(schema, typ)
case Record:
return createEncoderOfRecord(cfg, schema, typ)

key := cacheKey{fingerprint: schema.Fingerprint(), rtype: typ.RType()}
defEnc := &deferEncoder{}
e.encoders[key] = defEnc
defEnc.encoder = e.ofRecord(schema.(*RecordSchema), typ)
return defEnc.encoder
case Ref:
return encoderOfType(cfg, schema.(*RefSchema).Schema(), typ)

key := cacheKey{fingerprint: schema.(*RefSchema).Schema().Fingerprint(), rtype: typ.RType()}
if enc, f := e.encoders[key]; f {
return enc
}
return e.ofType(schema.(*RefSchema).Schema(), typ)
case Enum:
return createEncoderOfEnum(schema, typ)

return e.ofEnum(schema.(*EnumSchema), typ)
case Array:
return createEncoderOfArray(cfg, schema, typ)

return e.ofArray(schema.(*ArraySchema), typ)
case Map:
return createEncoderOfMap(cfg, schema, typ)

return e.ofMap(schema.(*MapSchema), typ)
case Union:
return createEncoderOfUnion(cfg, schema, typ)

return e.ofUnion(schema.(*UnionSchema), typ)
case Fixed:
return createEncoderOfFixed(schema, typ)

return e.ofFixed(schema.(*FixedSchema), typ)
default:
// It is impossible to get here with a valid schema
return &errorEncoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())}
Expand Down
23 changes: 10 additions & 13 deletions codec_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,25 @@ import (
"github.com/modern-go/reflect2"
)

func createDecoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
func (d decoderCreator) ofArray(schema *ArraySchema, typ reflect2.Type) ValDecoder {
if typ.Kind() == reflect.Slice {
return decoderOfArray(cfg, schema, typ)
return d.decoderOfArray(schema, typ)
}

return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func createEncoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
func (e encoderCreator) ofArray(schema *ArraySchema, typ reflect2.Type) ValEncoder {
if typ.Kind() == reflect.Slice {
return encoderOfArray(cfg, schema, typ)
return e.encoderOfArray(schema, typ.(*reflect2.UnsafeSliceType))
}

return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func decoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
arr := schema.(*ArraySchema)
func (d decoderCreator) decoderOfArray(schema *ArraySchema, typ reflect2.Type) ValDecoder {
sliceType := typ.(*reflect2.UnsafeSliceType)
decoder := decoderOfType(cfg, arr.Items(), sliceType.Elem())
decoder := d.ofType(schema.Items(), sliceType.Elem())

return &arrayDecoder{typ: sliceType, decoder: decoder}
}
Expand Down Expand Up @@ -74,14 +73,12 @@ func (d *arrayDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
}
}

func encoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
arr := schema.(*ArraySchema)
sliceType := typ.(*reflect2.UnsafeSliceType)
encoder := encoderOfType(cfg, arr.Items(), sliceType.Elem())
func (e encoderCreator) encoderOfArray(schema *ArraySchema, typ *reflect2.UnsafeSliceType) ValEncoder {
encoder := e.ofType(schema.Items(), typ.Elem())

return &arrayEncoder{
blockLength: cfg.getBlockLength(),
typ: sliceType,
blockLength: e.cfg.getBlockLength(),
typ: typ,
encoder: encoder,
}
}
Expand Down
4 changes: 2 additions & 2 deletions codec_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) Va
if defaultType == nil {
defaultType = reflect2.TypeOf((*null)(nil))
}
defaultEncoder := encoderOfType(cfg, field.Type(), defaultType)
defaultEncoder := newEncoderCreator(cfg).ofType(field.Type(), defaultType)
if defaultType.LikePtr() {
defaultEncoder = &onePtrEncoder{defaultEncoder}
}
Expand All @@ -37,7 +37,7 @@ func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) Va
}
return &defaultDecoder{
data: b,
decoder: decoderOfType(cfg, field.Type(), typ),
decoder: newDecoderCreator(cfg).ofType(field.Type(), typ),
}
}

Expand Down
4 changes: 2 additions & 2 deletions codec_dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ type efaceDecoder struct {
dec ValDecoder
}

func newEfaceDecoder(cfg *frozenConfig, schema Schema) *efaceDecoder {
func (d decoderCreator) newEfaceDecoder(schema Schema) *efaceDecoder {
typ, _ := genericReceiver(schema)
dec := decoderOfType(cfg, schema, typ)
dec := d.ofType(schema, typ)

return &efaceDecoder{
schema: schema,
Expand Down
16 changes: 8 additions & 8 deletions codec_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,27 @@ import (
"github.com/modern-go/reflect2"
)

func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder {
func (d decoderCreator) ofEnum(schema *EnumSchema, typ reflect2.Type) ValDecoder {
switch {
case typ.Kind() == reflect.String:
return &enumCodec{enum: schema.(*EnumSchema)}
return &enumCodec{enum: schema}
case typ.Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
return &enumTextMarshalerCodec{typ: typ, enum: schema}
case reflect2.PtrTo(typ).Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
return &enumTextMarshalerCodec{typ: typ, enum: schema, ptr: true}
}

return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder {
func (e encoderCreator) ofEnum(schema *EnumSchema, typ reflect2.Type) ValEncoder {
switch {
case typ.Kind() == reflect.String:
return &enumCodec{enum: schema.(*EnumSchema)}
return &enumCodec{enum: schema}
case typ.Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
return &enumTextMarshalerCodec{typ: typ, enum: schema}
case reflect2.PtrTo(typ).Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
return &enumTextMarshalerCodec{typ: typ, enum: schema, ptr: true}
}

return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
Expand Down
Loading

0 comments on commit cd10c42

Please sign in to comment.