Skip to content

Commit

Permalink
feat: allow custom template in avrogen (#392)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianiacobghiula authored May 6, 2024
1 parent 5dde47b commit f17a001
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 110 deletions.
48 changes: 36 additions & 12 deletions cmd/avrogen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ import (
"errors"
"flag"
"fmt"
"go/format"
"io"
"os"
"path/filepath"
"strings"

"github.com/hamba/avro/v2"
"github.com/hamba/avro/v2/gen"
"golang.org/x/tools/imports"
)

type config struct {
TemplateFileName string

Pkg string
Out string
Tags string
Expand All @@ -38,6 +40,7 @@ func realMain(args []string, stdout, stderr io.Writer) int {
flgs.BoolVar(&cfg.FullName, "fullname", false, "Use the full name of the Record schema to create the struct name.")
flgs.BoolVar(&cfg.Encoders, "encoders", false, "Generate encoders for the structs.")
flgs.StringVar(&cfg.Initialisms, "initialisms", "", "Custom initialisms <VAL>[,...] for struct and field names.")
flgs.StringVar(&cfg.TemplateFileName, "templateFileName", "", "Override output template with one loaded from file.")
flgs.Usage = func() {
_, _ = fmt.Fprintln(stderr, "Usage: avrogen [options] schemas")
_, _ = fmt.Fprintln(stderr, "Options:")
Expand All @@ -64,10 +67,17 @@ func realMain(args []string, stdout, stderr io.Writer) int {
return 1
}

template, err := loadTemplate(cfg.TemplateFileName)
if err != nil {
_, _ = fmt.Fprintln(stderr, "Error: "+err.Error())
return 1
}

opts := []gen.OptsFunc{
gen.WithFullName(cfg.FullName),
gen.WithEncoders(cfg.Encoders),
gen.WithInitialisms(initialisms),
gen.WithTemplate(string(template)),
}
g := gen.NewGenerator(cfg.Pkg, tags, opts...)
for _, file := range flgs.Args() {
Expand All @@ -84,30 +94,37 @@ func realMain(args []string, stdout, stderr io.Writer) int {
_, _ = fmt.Fprintf(stderr, "Error: could not generate code: %v\n", err)
return 3
}
formatted, err := format.Source(buf.Bytes())
formatted, err := imports.Process("", buf.Bytes(), nil)
if err != nil {
_, _ = fmt.Fprintf(stderr, "Error: could not format code: %v\n", err)
_ = writeOut(cfg.Out, stdout, buf.Bytes())
_, _ = fmt.Fprintf(stderr, "Error: generated code could not be formatted: %v\n", err)
return 3
}

err = writeOut(cfg.Out, stdout, formatted)
if err != nil {
_, _ = fmt.Fprintf(stderr, "Error: %v\n", err)
return 4
}
return 0
}

func writeOut(filename string, stdout io.Writer, bytes []byte) error {
writer := stdout
if cfg.Out != "" {
file, err := os.Create(cfg.Out)
if filename != "" {
file, err := os.Create(filepath.Clean(filename))
if err != nil {
_, _ = fmt.Fprintf(stderr, "Error: could not create output file: %v\n", err)
return 4
return fmt.Errorf("could not create output file: %w", err)
}
defer func() { _ = file.Close() }()

writer = file
}

if _, err := writer.Write(formatted); err != nil {
_, _ = fmt.Fprintf(stderr, "Error: could not write code: %v\n", err)
return 4
if _, err := writer.Write(bytes); err != nil {
return fmt.Errorf("could not write code: %w", err)
}

return 0
return nil
}

func validateOpts(nargs int, cfg config) error {
Expand Down Expand Up @@ -172,3 +189,10 @@ func parseInitialisms(raw string) ([]string, error) {

return result, nil
}

func loadTemplate(templateFileName string) ([]byte, error) {
if templateFileName == "" {
return nil, nil
}
return os.ReadFile(filepath.Clean(templateFileName))
}
147 changes: 49 additions & 98 deletions gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ package gen

import (
"bytes"
_ "embed"
"errors"
"fmt"
"go/format"
"io"
"maps"
"strings"
"text/template"

"github.com/ettle/strcase"
"github.com/hamba/avro/v2"
"golang.org/x/tools/imports"
)

// Config configures the code generation.
Expand Down Expand Up @@ -39,60 +41,8 @@ const (
UpperCamel TagStyle = "upper-camel"
)

const outputTemplate = `package {{ .PackageName }}
// Code generated by avro/gen. DO NOT EDIT.
{{- $encoders := .WithEncoders }}
{{ if len .Imports }}
import (
{{- range .Imports }}
"{{ . }}"
{{- end }}
{{ if len .ThirdPartyImports }}
{{- range .ThirdPartyImports }}
"{{ . }}"
{{- end }}
{{ end }}
)
{{ else if len .ThirdPartyImports }}
import (
{{- range .ThirdPartyImports }}
"{{ . }}"
{{- end }}
)
{{ end }}
{{- range .Typedefs }}
// {{ .Name }} is a generated struct.
type {{ .Name }} struct {
{{- range .Fields }}
{{ .Name }} {{ .Type }} {{ .Tag }}
{{- end }}
}
{{- if $encoders }}
var schema{{ .Name }} = avro.MustParse(` + "`{{ .Schema }}`" + `)
// Schema returns the schema for {{ .Name }}.
func (o *{{ .Name }}) Schema() avro.Schema {
return schema{{ .Name }}
}
// Unmarshal decodes b into the receiver.
func (o *{{ .Name }}) Unmarshal(b []byte) error {
return avro.Unmarshal(o.Schema(), b, o)
}
// Marshal encodes the receiver.
func (o *{{ .Name }}) Marshal() ([]byte, error) {
return avro.Marshal(o.Schema(), o)
}
{{- end }}
{{ end }}`
//go:embed output_template.tmpl
var outputTemplate string

var primitiveMappings = map[avro.Type]string{
"string": "string",
Expand Down Expand Up @@ -133,9 +83,10 @@ func StructFromSchema(schema avro.Schema, w io.Writer, cfg Config) error {
return err
}

formatted, err := format.Source(buf.Bytes())
formatted, err := imports.Process("", buf.Bytes(), nil)
if err != nil {
return fmt.Errorf("could not format code: %w", err)
_, _ = w.Write(buf.Bytes())
return fmt.Errorf("generated code could not be formatted: %w", err)
}

_, err = w.Write(formatted)
Expand Down Expand Up @@ -172,8 +123,19 @@ func WithInitialisms(ss []string) OptsFunc {
}
}

// WithTemplate configures the generator to use a custom template provided by the user.
func WithTemplate(template string) OptsFunc {
return func(g *Generator) {
if template == "" {
return
}
g.template = template
}
}

// Generator generates Go structs from schemas.
type Generator struct {
template string
pkg string
tags map[string]TagStyle
fullName bool
Expand All @@ -189,9 +151,13 @@ type Generator struct {

// NewGenerator returns a generator.
func NewGenerator(pkg string, tags map[string]TagStyle, opts ...OptsFunc) *Generator {
clonedTags := maps.Clone(tags)
delete(clonedTags, "avro")

g := &Generator{
pkg: pkg,
tags: tags,
template: outputTemplate,
pkg: pkg,
tags: clonedTags,
}

for _, opt := range opts {
Expand Down Expand Up @@ -266,8 +232,7 @@ func (g *Generator) resolveRecordSchema(schema *avro.RecordSchema) string {
fields := make([]field, len(schema.Fields()))
for i, f := range schema.Fields() {
typ := g.generate(f.Type())
tag := f.Name()
fields[i] = g.newField(g.nameCaser.ToPascal(f.Name()), typ, tag)
fields[i] = g.newField(g.nameCaser.ToPascal(f.Name()), typ, f.Doc(), f.Name())
}

typeName := g.resolveTypeName(schema)
Expand Down Expand Up @@ -334,35 +299,13 @@ func (g *Generator) resolveLogicalSchema(logicalType avro.LogicalType) string {
return typ
}

func (g *Generator) newField(name, typ, tag string) field {
tagLine := fmt.Sprintf(`avro:"%s"`, tag)
for tagName, style := range g.tags {
if tagName == "avro" {
continue
}
tagLine += fmt.Sprintf(` %s:"%s"`, tagName, formatTag(tag, style))
}
func (g *Generator) newField(name, typ, avroFieldDoc, avroFieldName string) field {
return field{
Name: name,
Type: typ,
Tag: fmt.Sprintf("`%s`", tagLine),
}
}

func formatTag(tag string, style TagStyle) string {
switch style {
case Kebab:
return strcase.ToKebab(tag)
case UpperCamel:
return strcase.ToPascal(tag)
case Camel:
return strcase.ToCamel(tag)
case Snake:
return strcase.ToSnake(tag)
case Original:
fallthrough
default:
return tag
Name: name,
Type: typ,
AvroFieldName: avroFieldName,
AvroFieldDoc: avroFieldDoc,
Tags: g.tags,
}
}

Expand All @@ -386,7 +329,14 @@ func (g *Generator) addThirdPartyImport(pkg string) {

// Write writes Go code from the parsed schemas.
func (g *Generator) Write(w io.Writer) error {
parsed, err := template.New("out").Parse(outputTemplate)
parsed, err := template.New("out").
Funcs(template.FuncMap{
"kebab": strcase.ToKebab,
"upperCamel": strcase.ToPascal,
"camel": strcase.ToCamel,
"snake": strcase.ToSnake,
}).
Parse(g.template)
if err != nil {
return err
}
Expand All @@ -398,11 +348,10 @@ func (g *Generator) Write(w io.Writer) error {
ThirdPartyImports []string
Typedefs []typedef
}{
WithEncoders: g.encoders,
PackageName: g.pkg,
Imports: g.imports,
ThirdPartyImports: g.thirdPartyImports,
Typedefs: g.typedefs,
WithEncoders: g.encoders,
PackageName: g.pkg,
Imports: append(g.imports, g.thirdPartyImports...),
Typedefs: g.typedefs,
}
return parsed.Execute(w, data)
}
Expand All @@ -422,7 +371,9 @@ func newType(name string, fields []field, schema string) typedef {
}

type field struct {
Name string
Type string
Tag string
Name string
Type string
AvroFieldName string
AvroFieldDoc string
Tags map[string]TagStyle
}
50 changes: 50 additions & 0 deletions gen/output_template.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package {{ .PackageName }}

// Code generated by avro/gen. DO NOT EDIT.

{{- $encoders := .WithEncoders }}
{{ if len .Imports }}
import (
{{- range .Imports }}
"{{ . }}"
{{- end }}
)
{{ end }}

{{- range .Typedefs }}
// {{ .Name }} is a generated struct.
type {{ .Name }} struct {
{{- range .Fields }}
{{- $f := . }}
{{ .Name }} {{ .Type }} `avro:"{{ $f.AvroFieldName }}"
{{- range $tag, $style := .Tags }}
{{- " "}}{{ $tag }}:"
{{- if eq $style "kebab" }}{{ kebab $f.AvroFieldName }}
{{- else if eq $style "upper-camel"}}{{ upperCamel $f.AvroFieldName }}
{{- else if eq $style "camel"}}{{ camel $f.AvroFieldName }}
{{- else if eq $style "snake"}}{{ snake $f.AvroFieldName }}
{{- else}}{{ $f.AvroFieldName }}
{{- end}}"
{{- end }}`
{{- end }}
}

{{- if $encoders }}
var schema{{ .Name }} = avro.MustParse(`{{ .Schema }}`)

// Schema returns the schema for {{ .Name }}.
func (o *{{ .Name }}) Schema() avro.Schema {
return schema{{ .Name }}
}

// Unmarshal decodes b into the receiver.
func (o *{{ .Name }}) Unmarshal(b []byte) error {
return avro.Unmarshal(o.Schema(), b, o)
}

// Marshal encodes the receiver.
func (o *{{ .Name }}) Marshal() ([]byte, error) {
return avro.Marshal(o.Schema(), o)
}
{{- end }}
{{ end }}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/tools v0.20.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading

0 comments on commit f17a001

Please sign in to comment.