diff --git a/schema.go b/schema.go index d2b93df..ecd11c9 100644 --- a/schema.go +++ b/schema.go @@ -76,6 +76,69 @@ const ( Duration LogicalType = "duration" ) +// customLogicalSchema is a custom logical type schema that is not part of the Avro specification. +// It wraps a primitive type schema and thus supports no additional properties. +type customLogicalSchema struct { + PrimitiveLogicalSchema +} + +type customSchemaKey = struct { + typ Type + ltyp LogicalType +} + +var customLogicalSchemas sync.Map // map[customSchemaKey]*CustomLogicalSchema + +func addCustomLogicalSchema(typ Type, ltyp LogicalType) { + key := customSchemaKey{typ, ltyp} + customLogicalSchemas.Store(key, &customLogicalSchema{ + PrimitiveLogicalSchema: PrimitiveLogicalSchema{typ: ltyp}, + }) +} + +func getCustomLogicalSchema(typ Type, ltyp LogicalType) LogicalSchema { + key := customSchemaKey{typ, ltyp} + if ls, ok := customLogicalSchemas.Load(key); ok { + return ls.(*customLogicalSchema) + } + return nil +} + +// RegisterCustomLogicalType registers a custom logical type that is not part of the +// Avro specification for the given types. +// It returns an error if the logical type conflicts with a predefined logical type. +func RegisterCustomLogicalType(ltyp LogicalType, types ...Type) error { + // Ensure that the custom logical type does not overwrite a primitive type + switch ltyp { + case Decimal, + UUID, + Date, + TimeMillis, + TimeMicros, + TimestampMillis, + TimestampMicros, + LocalTimestampMillis, + LocalTimestampMicros, + Duration: + return errors.New("logical type conflicts with a predefined logical type") + } + + // Check that all of the given type supports logical types + for _, typ := range types { + switch typ { + case Ref, Union, Null: + return fmt.Errorf("type %q does not support logical types", typ) + } + } + + // Register the custom logical type + for _, typ := range types { + addCustomLogicalSchema(typ, ltyp) + } + + return nil +} + // Action is a field action used during decoding process. type Action string @@ -396,12 +459,13 @@ func (p properties) marshalPropertiesToJSON(buf *bytes.Buffer) error { } type schemaConfig struct { - aliases []string - doc string - def any - order Order - props map[string]any - wfp *[32]byte + aliases []string + doc string + def any + order Order + props map[string]any + wfp *[32]byte + customLogicalType LogicalType } // SchemaOption is a function that sets a schema option. @@ -414,6 +478,16 @@ func WithAliases(aliases []string) SchemaOption { } } +// WithCustomLogicalType sets a custom logical type on a schema. +// Make sure to register the custom logical type before using it, +// otherwise it will be ignored. +// See RegisterCustomLogicalType. +func WithCustomLogicalType(ltyp LogicalType) SchemaOption { + return func(opts *schemaConfig) { + opts.customLogicalType = ltyp + } +} + // WithDoc sets the doc on a schema. func WithDoc(doc string) SchemaOption { return func(opts *schemaConfig) { @@ -477,6 +551,11 @@ func NewPrimitiveSchema(t Type, l LogicalSchema, opts ...SchemaOption) *Primitiv opt(&cfg) } + // If the logical schema is nil, use the custom logical schema. + if l == nil { + l = getCustomLogicalSchema(t, cfg.customLogicalType) + } + return &PrimitiveSchema{ properties: newProperties(cfg.props, schemaReserved), cacheFingerprinter: cacheFingerprinter{writerFingerprint: cfg.wfp}, @@ -552,6 +631,7 @@ type RecordSchema struct { isError bool fields []*Field doc string + logical LogicalSchema } // NewRecordSchema creates a new record schema instance. @@ -572,6 +652,7 @@ func NewRecordSchema(name, namespace string, fields []*Field, opts ...SchemaOpti cacheFingerprinter: cacheFingerprinter{writerFingerprint: cfg.wfp}, fields: fields, doc: cfg.doc, + logical: getCustomLogicalSchema(Record, cfg.customLogicalType), }, nil } @@ -592,6 +673,11 @@ func (s *RecordSchema) Type() Type { return Record } +// Logical returns the logical schema or nil. +func (s *RecordSchema) Logical() LogicalSchema { + return s.logical +} + // Doc returns the documentation of a record. func (s *RecordSchema) Doc() string { return s.doc @@ -622,6 +708,12 @@ func (s *RecordSchema) String() string { fields = fields[:len(fields)-1] } + if s.logical != nil { + return fmt.Sprintf("{\"name\":\"%s\", \"type\":\"%s\", \"fields\":[%s]\", %s}", + s.FullName(), typ, fields, s.logical.String(), + ) + } + return `{"name":"` + s.FullName() + `","type":"` + typ + `","fields":[` + fields + `]}` } @@ -659,6 +751,9 @@ func (s *RecordSchema) MarshalJSON() ([]byte, error) { if err := s.marshalPropertiesToJSON(buf); err != nil { return nil, err } + if s.logical != nil { + buf.WriteString(`,"logicalType":"` + string(s.logical.Type()) + `"`) + } buf.WriteString("}") return buf.Bytes(), nil } @@ -876,6 +971,7 @@ type EnumSchema struct { symbols []string def string doc string + logical LogicalSchema // encodedSymbols is the symbols of the encoded value. // It's only used in the context of write-read schema resolution. @@ -918,6 +1014,7 @@ func NewEnumSchema(name, namespace string, symbols []string, opts ...SchemaOptio symbols: symbols, def: def, doc: cfg.doc, + logical: getCustomLogicalSchema(Enum, cfg.customLogicalType), }, nil } @@ -979,6 +1076,11 @@ func (s *EnumSchema) HasDefault() bool { return s.def != "" } +// Logical returns the logical schema or nil. +func (s *EnumSchema) Logical() LogicalSchema { + return s.logical +} + // String returns the canonical form of the schema. func (s *EnumSchema) String() string { symbols := "" @@ -989,6 +1091,11 @@ func (s *EnumSchema) String() string { symbols = symbols[:len(symbols)-1] } + if s.logical != nil { + return fmt.Sprintf("{\"name\":\"%s\", \"type\":\"enum\", \"symbols\":[%s]\", %s}", + s.FullName(), symbols, s.logical.String()) + } + return `{"name":"` + s.FullName() + `","type":"enum","symbols":[` + symbols + `]}` } @@ -1025,6 +1132,9 @@ func (s *EnumSchema) MarshalJSON() ([]byte, error) { if err := s.marshalPropertiesToJSON(buf); err != nil { return nil, err } + if s.logical != nil { + buf.WriteString(`,"logicalType":"` + string(s.logical.Type()) + `"`) + } buf.WriteString("}") return buf.Bytes(), nil } @@ -1055,7 +1165,8 @@ type ArraySchema struct { fingerprinter cacheFingerprinter - items Schema + items Schema + logical LogicalSchema } // NewArraySchema creates an array schema instance. @@ -1069,6 +1180,7 @@ func NewArraySchema(items Schema, opts ...SchemaOption) *ArraySchema { properties: newProperties(cfg.props, schemaReserved), cacheFingerprinter: cacheFingerprinter{writerFingerprint: cfg.wfp}, items: items, + logical: getCustomLogicalSchema(Array, cfg.customLogicalType), } } @@ -1082,8 +1194,16 @@ func (s *ArraySchema) Items() Schema { return s.items } +// Logical returns the logical schema or nil. +func (s *ArraySchema) Logical() LogicalSchema { + return s.logical +} + // String returns the canonical form of the schema. func (s *ArraySchema) String() string { + if s.logical != nil { + return `{"type":"array","items":` + s.items.String() + `,"` + s.logical.String() + `"}` + } return `{"type":"array","items":` + s.items.String() + `}` } @@ -1100,6 +1220,9 @@ func (s *ArraySchema) MarshalJSON() ([]byte, error) { if err = s.marshalPropertiesToJSON(buf); err != nil { return nil, err } + if s.logical != nil { + buf.WriteString(`,"logicalType":"` + string(s.logical.Type()) + `"`) + } buf.WriteString("}") return buf.Bytes(), nil } @@ -1125,7 +1248,8 @@ type MapSchema struct { fingerprinter cacheFingerprinter - values Schema + values Schema + logical LogicalSchema } // NewMapSchema creates a map schema instance. @@ -1139,6 +1263,7 @@ func NewMapSchema(values Schema, opts ...SchemaOption) *MapSchema { properties: newProperties(cfg.props, schemaReserved), cacheFingerprinter: cacheFingerprinter{writerFingerprint: cfg.wfp}, values: values, + logical: getCustomLogicalSchema(Map, cfg.customLogicalType), } } @@ -1152,8 +1277,16 @@ func (s *MapSchema) Values() Schema { return s.values } +// Logical returns the logical schema or nil. +func (s *MapSchema) Logical() LogicalSchema { + return s.logical +} + // String returns the canonical form of the schema. func (s *MapSchema) String() string { + if s.logical != nil { + return `{"type":"map","values":` + s.values.String() + `,"` + s.logical.String() + `"}` + } return `{"type":"map","values":` + s.values.String() + `}` } @@ -1170,6 +1303,9 @@ func (s *MapSchema) MarshalJSON() ([]byte, error) { if err := s.marshalPropertiesToJSON(buf); err != nil { return nil, err } + if s.logical != nil { + buf.WriteString(`,"logicalType":"` + string(s.logical.Type()) + `"`) + } buf.WriteString("}") return buf.Bytes(), nil } diff --git a/schema_parse.go b/schema_parse.go index 3022116..0c39f14 100644 --- a/schema_parse.go +++ b/schema_parse.go @@ -202,17 +202,18 @@ func parsePrimitiveLogicalType(typ Type, lt string, prec, scale int) LogicalSche return parseDecimalLogicalType(-1, prec, scale) } - return nil + return getCustomLogicalSchema(typ, ltyp) } type recordSchema struct { - Type string `mapstructure:"type"` - Name string `mapstructure:"name"` - Namespace string `mapstructure:"namespace"` - Aliases []string `mapstructure:"aliases"` - Doc string `mapstructure:"doc"` - Fields []map[string]any `mapstructure:"fields"` - Props map[string]any `mapstructure:",remain"` + Type string `mapstructure:"type"` + Name string `mapstructure:"name"` + Namespace string `mapstructure:"namespace"` + Aliases []string `mapstructure:"aliases"` + Doc string `mapstructure:"doc"` + Fields []map[string]any `mapstructure:"fields"` + LogicalType string `mapstructure:"logicalType"` + Props map[string]any `mapstructure:",remain"` } func parseRecord(typ Type, namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) { @@ -243,11 +244,11 @@ func parseRecord(typ Type, namespace string, m map[string]any, seen seenCache, c switch typ { case Record: rec, err = NewRecordSchema(r.Name, r.Namespace, fields, - WithAliases(r.Aliases), WithDoc(r.Doc), WithProps(r.Props), + WithAliases(r.Aliases), WithDoc(r.Doc), WithProps(r.Props), WithCustomLogicalType(LogicalType(r.LogicalType)), ) case Error: rec, err = NewErrorRecordSchema(r.Name, r.Namespace, fields, - WithAliases(r.Aliases), WithDoc(r.Doc), WithProps(r.Props), + WithAliases(r.Aliases), WithDoc(r.Doc), WithProps(r.Props), WithCustomLogicalType(LogicalType(r.LogicalType)), ) } if err != nil { @@ -276,13 +277,14 @@ func parseRecord(typ Type, namespace string, m map[string]any, seen seenCache, c } type fieldSchema struct { - Name string `mapstructure:"name"` - Aliases []string `mapstructure:"aliases"` - Type any `mapstructure:"type"` - Doc string `mapstructure:"doc"` - Default any `mapstructure:"default"` - Order Order `mapstructure:"order"` - Props map[string]any `mapstructure:",remain"` + Name string `mapstructure:"name"` + Aliases []string `mapstructure:"aliases"` + Type any `mapstructure:"type"` + Doc string `mapstructure:"doc"` + Default any `mapstructure:"default"` + Order Order `mapstructure:"order"` + LogicalType string `mapstructure:"logicalType"` + Props map[string]any `mapstructure:",remain"` } func parseField(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (*Field, error) { @@ -312,6 +314,7 @@ func parseField(namespace string, m map[string]any, seen seenCache, cache *Schem field, err := NewField(f.Name, typ, WithDefault(f.Default), WithAliases(f.Aliases), WithDoc(f.Doc), WithOrder(f.Order), WithProps(f.Props), + WithCustomLogicalType(LogicalType(f.LogicalType)), ) if err != nil { return nil, err @@ -321,14 +324,15 @@ func parseField(namespace string, m map[string]any, seen seenCache, cache *Schem } type enumSchema struct { - Name string `mapstructure:"name"` - Namespace string `mapstructure:"namespace"` - Aliases []string `mapstructure:"aliases"` - Type string `mapstructure:"type"` - Doc string `mapstructure:"doc"` - Symbols []string `mapstructure:"symbols"` - Default string `mapstructure:"default"` - Props map[string]any `mapstructure:",remain"` + Name string `mapstructure:"name"` + Namespace string `mapstructure:"namespace"` + Aliases []string `mapstructure:"aliases"` + Type string `mapstructure:"type"` + Doc string `mapstructure:"doc"` + Symbols []string `mapstructure:"symbols"` + Default string `mapstructure:"default"` + LogicalType string `mapstructure:"logicalType"` + Props map[string]any `mapstructure:",remain"` } func parseEnum(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) { @@ -349,6 +353,7 @@ func parseEnum(namespace string, m map[string]any, seen seenCache, cache *Schema enum, err := NewEnumSchema(e.Name, e.Namespace, e.Symbols, WithDefault(e.Default), WithAliases(e.Aliases), WithDoc(e.Doc), WithProps(e.Props), + WithCustomLogicalType(LogicalType(e.LogicalType)), ) if err != nil { return nil, err @@ -368,8 +373,9 @@ func parseEnum(namespace string, m map[string]any, seen seenCache, cache *Schema } type arraySchema struct { - Items any `mapstructure:"items"` - Props map[string]any `mapstructure:",remain"` + Items any `mapstructure:"items"` + LogicalType string `mapstructure:"logicalType"` + Props map[string]any `mapstructure:",remain"` } func parseArray(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) { @@ -389,12 +395,13 @@ func parseArray(namespace string, m map[string]any, seen seenCache, cache *Schem return nil, err } - return NewArraySchema(schema, WithProps(a.Props)), nil + return NewArraySchema(schema, WithProps(a.Props), WithCustomLogicalType(LogicalType(a.LogicalType))), nil } type mapSchema struct { - Values any `mapstructure:"values"` - Props map[string]any `mapstructure:",remain"` + Values any `mapstructure:"values"` + LogicalType string `mapstructure:"logicalType"` + Props map[string]any `mapstructure:",remain"` } func parseMap(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) { @@ -414,7 +421,7 @@ func parseMap(namespace string, m map[string]any, seen seenCache, cache *SchemaC return nil, err } - return NewMapSchema(schema, WithProps(ms.Props)), nil + return NewMapSchema(schema, WithProps(ms.Props), WithCustomLogicalType(LogicalType(ms.LogicalType))), nil } func parseUnion(namespace string, v []any, seen seenCache, cache *SchemaCache) (Schema, error) { @@ -494,7 +501,7 @@ func parseFixedLogicalType(size int, lt string, prec, scale int) LogicalSchema { return parseDecimalLogicalType(size, prec, scale) } - return nil + return getCustomLogicalSchema(Fixed, ltyp) } func parseDecimalLogicalType(size, prec, scale int) LogicalSchema { diff --git a/schema_test.go b/schema_test.go index 1fa1461..b5c2481 100644 --- a/schema_test.go +++ b/schema_test.go @@ -983,6 +983,14 @@ func TestFixedSchema_HandlesProps(t *testing.T) { } func TestSchema_LogicalTypes(t *testing.T) { + customType := avro.LogicalType("customType") + err := avro.RegisterCustomLogicalType(customType, avro.Int, avro.Enum, avro.Array, avro.Map, avro.Record) + require.NoError(t, err) + + // should not be able to register a type with the same name as a built-in type + err = avro.RegisterCustomLogicalType(avro.Date, avro.Double) + require.Error(t, err) + tests := []struct { name string schema string @@ -997,6 +1005,47 @@ func TestSchema_LogicalTypes(t *testing.T) { wantType: avro.Int, wantLogical: false, }, + { + name: "Invalid", + schema: `{"type": "long", "logicalType": "customType"}`, + wantType: avro.Long, + wantLogical: false, + }, + { + name: "Primitive Custom Type", + schema: `{"type": "int", "logicalType": "customType"}`, + wantType: avro.Int, + wantLogical: true, + wantLogicalType: customType, + }, + { + name: "Enum Custom Type", + schema: `{"type":"enum", "name":"test", "namespace": "org.hamba.avro", "symbols":["TEST"], "logicalType": "customType"}`, + wantType: avro.Enum, + wantLogical: true, + wantLogicalType: customType, + }, + { + name: "Array Custom Type", + schema: `{"type":"array", "items": "int", "logicalType": "customType"}`, + wantType: avro.Array, + wantLogical: true, + wantLogicalType: customType, + }, + { + name: "Map Custom Type", + schema: `{"type":"map", "values": "int", "logicalType": "customType"}`, + wantType: avro.Map, + wantLogical: true, + wantLogicalType: customType, + }, + { + name: "Record Custom Type", + schema: `{"type": "record", "name": "Foo", "fields": [{"name": "baz", "type": "string"}], "logicalType": "customType"}`, + wantType: avro.Record, + wantLogical: true, + wantLogicalType: customType, + }, { name: "Date", schema: `{"type": "int", "logicalType": "date"}`,