From d0bf6eba7402a13f1bad512efc173a6dcf3e2394 Mon Sep 17 00:00:00 2001 From: Himanshu Rai <36773027+hi-rai@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:16:05 +0530 Subject: [PATCH] Fix handling for nested generic types (#12) * Fix handling for nested generic types. For non-defined generic types, the type arguments are appended to the generated schema names * Only keep necessary types/methods/functions as exported * Update README * Remove name and generic types arguments from CustomFn type. Their use was unclear and if required, can be derived from the type argument * Update golangci-lint-action and golangci-lint --- .github/workflows/ci.yml | 5 +- README.md | 89 +++++++++++++---- custom/decimal/decimal.go | 2 +- custom/optional/optional.go | 4 +- zod.go | 191 ++++++++++++++++++++++-------------- zod_test.go | 38 ++++++- 6 files changed, 232 insertions(+), 97 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 48f7a13..dfe1186 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,9 +36,10 @@ jobs: run: diff <(echo -n) <(gofumpt -d .) - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: - version: v1.52.2 + version: v1.59 + args: --verbose --timeout=3m - name: Test run: make test diff --git a/README.md b/README.md index ef0b947..7fcf48e 100644 --- a/README.md +++ b/README.md @@ -4,33 +4,65 @@ Zod + Generate = Zen Converts Go structs with go-validator validations to Zod schemas. -Zen supports self-referential types. +Zen supports self-referential types and generic types. Other cyclic types (apart from self referential types) are not supported +as they are not supported by zod itself. -## Usage: +## Usage ```go type Post struct { Title string `validate:"required"` } type User struct { - Name string `validate:"required"` - Nickname *string // pointers become optional - Age int `validate:"min=18"` - Height float64 `validate:"min=0,max=3"` - Tags []string `validate:"min=1"` + Name string `validate:"required"` + Nickname *string // pointers become optional + Age int `validate:"min=18"` + Height float64 `validate:"min=0,max=3"` + Tags []string `validate:"min=1"` Favourites []struct { // nested structs are kept inline Name string `validate:"required"` } Posts []Post // external structs are emitted as separate exports } -StructToZodSchema(User{}) +fmt.Print(zen.StructToZodSchema(User{})) + +// Self referential types are supported +type Tree struct { + Value int + Children []Tree +} +fmt.Print(zen.StructToZodSchema(Tree{})) + +// We can also use create a converter and convert multiple types together +c := zen.NewConverter(nil) + +// Generic types are also supported +type GenericPair[T any, U any] struct { + First T + Second U +} +type StringIntPair GenericPair[string, int] +c.AddType(StringIntPair{}) + +// For non-defined types, the type arguments are appended to the generic type +// name to get the type name +c.AddType(GenericPair[int, bool]{}) + +// Even nested generic types are supported +type PairMap[K comparable, T any, U any] struct { + Items map[K]GenericPair[T, U] `json:"items"` +} +c.AddType(PairMap[string, int, bool]{}) + +// Now export the generated schemas. Duplicate schemas are skipped +fmt.Print(c.Export()) ``` Outputs: ```typescript export const PostSchema = z.object({ - Title: z.string().min(1), + Title: z.string().min(1), }) export type Post = z.infer @@ -46,9 +78,33 @@ export const UserSchema = z.object({ Posts: PostSchema.array().nullable(), }) export type User = z.infer -``` -It also works without any validations. +export type Tree = { + Value: number, + Children: Tree[] | null, +} +export const TreeSchema: z.ZodType = z.object({ + Value: z.number(), + Children: z.lazy(() => TreeSchema).array().nullable(), +}) + +export const StringIntPairSchema = z.object({ + First: z.string(), + Second: z.number(), +}) +export type StringIntPair = z.infer + +export const GenericPairIntBoolSchema = z.object({ + First: z.number(), + Second: z.boolean(), +}) +export type GenericPairIntBool = z.infer + +export const PairMapStringIntBoolSchema = z.object({ + items: z.record(z.string(), GenericPairIntBoolSchema).nullable(), +}) +export type PairMapStringIntBool = z.infer +``` ### How we use it at Hypersequent @@ -61,7 +117,7 @@ It also works without any validations. converter := zen.NewConverter(make(map[string]zen.CustomFn)) {{range .TypesToGenerate}} - converter.AddType(types.{{.}}{}) + converter.AddType(types.{{.}}{}) {{end}} schema := converter.Export() @@ -69,11 +125,11 @@ It also works without any validations. ## Custom Types -You can pass type name mappings to custom conversion functions: +We can pass type name mappings to custom conversion functions: ```go c := zen.NewConverter(map[string]zen.CustomFn{ - "github.com/shopspring/decimal.Decimal": func (c *zen.Converter, t reflect.Type, s, g string, i int) string { + "github.com/shopspring/decimal.Decimal": func (c *zen.Converter, t reflect.Type, v string, i int) string { // Shopspring's decimal type serialises to a string. return "z.string()" }, @@ -98,11 +154,10 @@ There are some custom types with tests in the "custom" directory. The function signature for custom type handlers is: ```go -func(c *zen.Converter, t reflect.Type, typeName, genericTypeName string, indentLevel int) string +func(c *Converter, t reflect.Type, validate string, indent int) string ``` -You can use the Converter to process nested types. The `genericTypeName` is the name of the `T` in `Generic[T]` and the -indent level is for passing to other converter APIs. +We can use `c` to process nested types. Indent level is for passing to other converter APIs. ## Supported validations diff --git a/custom/decimal/decimal.go b/custom/decimal/decimal.go index 9289ed3..26952af 100644 --- a/custom/decimal/decimal.go +++ b/custom/decimal/decimal.go @@ -8,7 +8,7 @@ import ( var ( DecimalType = "github.com/shopspring/decimal.Decimal" - DecimalFunc = func(c *zen.Converter, t reflect.Type, s, g string, validate string, i int) string { + DecimalFunc = func(c *zen.Converter, t reflect.Type, validate string, i int) string { // Shopspring's decimal type serialises to a string. return "z.string()" } diff --git a/custom/optional/optional.go b/custom/optional/optional.go index 68b0d7d..52974a0 100644 --- a/custom/optional/optional.go +++ b/custom/optional/optional.go @@ -9,7 +9,7 @@ import ( var ( OptionalType = "4d63.com/optional.Optional" - OptionalFunc = func(c *zen.Converter, t reflect.Type, s string, g string, validate string, i int) string { - return fmt.Sprintf("%s.optional().nullish()", c.ConvertType(t.Elem(), s, validate, i)) + OptionalFunc = func(c *zen.Converter, t reflect.Type, validate string, i int) string { + return fmt.Sprintf("%s.optional().nullish()", c.ConvertType(t.Elem(), validate, i)) } ) diff --git a/zod.go b/zod.go index 6e2221e..2ddf9d4 100644 --- a/zod.go +++ b/zod.go @@ -9,6 +9,9 @@ import ( "strings" ) +// NewConverter initializes and returns a new converter instance. The custom handler +// function map should be keyed on the fully qualified type name (excluding generic +// type arguments), ie. package.typename. func NewConverter(custom map[string]CustomFn) Converter { c := Converter{ prefix: "", @@ -19,10 +22,16 @@ func NewConverter(custom map[string]CustomFn) Converter { return c } +// AddType converts a struct type to corresponding zod schema. AddType can be called +// multiple times, followed by Export to get the corresonding zod schemas. func (c *Converter) AddType(input interface{}) { t := reflect.TypeOf(input) - name := t.Name() + if t.Kind() != reflect.Struct { + panic("input must be a struct") + } + + name := typeName(t) if _, ok := c.outputs[name]; ok { return } @@ -33,12 +42,20 @@ func (c *Converter) AddType(input interface{}) { c.structs = order + 1 } +// Convert returns zod schema corresponding to a struct type. Its a shorthand for +// call to AddType followed by Export. So calling Convert after other calls to +// AddType/Convert/ConvertSlice, returns schemas from those previous calls as well. +// Calling AddType followed by Export might be more appropriate in such scenarios. func (c *Converter) Convert(input interface{}) string { c.AddType(input) return c.Export() } +// ConvertSlice returns zod schemas corresponding to multiple struct types passed +// in the argument. Similar to Convert, calling ConvertSlice after other calls to +// AddType/Convert/ConvertSlice, returns schemas from those previous calls as well. +// Calling AddType followed by Export might be more appropriate in such scenarios. func (c *Converter) ConvertSlice(inputs []interface{}) string { for _, input := range inputs { c.AddType(input) @@ -47,26 +64,25 @@ func (c *Converter) ConvertSlice(inputs []interface{}) string { return c.Export() } +// StructToZodSchema returns zod schema corresponding to a struct type. func StructToZodSchema(input interface{}) string { c := Converter{ prefix: "", outputs: make(map[string]entry), } - c.AddType(input) - - return c.Export() + return c.Convert(input) } +// StructToZodSchemaWithPrefix returns zod schema corresponding to a struct type. +// The prefix is added to the generated schema and type names. func StructToZodSchemaWithPrefix(prefix string, input interface{}) string { c := Converter{ prefix: prefix, outputs: make(map[string]entry), } - c.AddType(input) - - return c.Export() + return c.Convert(input) } var typeMapping = map[reflect.Kind]string{ @@ -95,17 +111,17 @@ type entry struct { data string } -type ByOrder []entry +type byOrder []entry -func (a ByOrder) Len() int { return len(a) } -func (a ByOrder) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a ByOrder) Less(i, j int) bool { return a[i].order < a[j].order } +func (a byOrder) Len() int { return len(a) } +func (a byOrder) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byOrder) Less(i, j int) bool { return a[i].order < a[j].order } -type CustomFn func(*Converter, reflect.Type, string, string, string, int) string +type CustomFn func(c *Converter, t reflect.Type, validate string, indent int) string -type Meta struct { - Name string - SelfRef bool +type meta struct { + name string + selfRef bool } type Converter struct { @@ -113,7 +129,7 @@ type Converter struct { structs int outputs map[string]entry custom map[string]CustomFn - stack []Meta + stack []meta } func (c *Converter) addSchema(name string, data string) { @@ -126,6 +142,8 @@ func (c *Converter) addSchema(name string, data string) { } } +// Export returns the zod schemas corresponding to all types that have been +// converted so far. func (c *Converter) Export() string { output := strings.Builder{} var sorted []entry @@ -133,12 +151,13 @@ func (c *Converter) Export() string { sorted = append(sorted, ent) } - sort.Sort(ByOrder(sorted)) + sort.Sort(byOrder(sorted)) for _, ent := range sorted { output.WriteString(ent.data) output.WriteString("\n\n") } + return output.String() } @@ -165,7 +184,7 @@ func fieldName(input reflect.StructField) string { func typeName(t reflect.Type) string { if t.Kind() == reflect.Struct { - return t.Name() + return getTypeNameWithGenerics(t.Name()) } if t.Kind() == reflect.Ptr { return typeName(t.Elem()) @@ -183,14 +202,14 @@ func typeName(t reflect.Type) string { func (c *Converter) convertStructTopLevel(t reflect.Type) string { output := strings.Builder{} - name := t.Name() - c.stack = append(c.stack, Meta{name, false}) + name := typeName(t) + c.stack = append(c.stack, meta{name, false}) data := c.convertStruct(t, 0) fullName := c.prefix + name top := c.stack[len(c.stack)-1] - if top.SelfRef { + if top.selfRef { output.WriteString(fmt.Sprintf(`export type %s = %s `, fullName, c.getTypeStruct(t, 0))) @@ -270,61 +289,67 @@ func (c *Converter) getTypeStruct(input reflect.Type, indent int) string { var matchGenericTypeName = regexp.MustCompile(`(.+)\[(.+)]`) -// checking it a reflected type is a generic isn't supported as far as I can see -// so this simple check looks for a `[` character in the type name: `T1[T2]`. +// Checking if a reflected type is a generic isn't supported as far as I can see. +// So this simple check looks for a `[` character in the type name: `T1[T2]`. func isGeneric(t reflect.Type) bool { return strings.Contains(t.Name(), "[") } -// gets the full name and if it's a generic type, strips out the [T] part. -func getFullName(t reflect.Type) (string, string) { +// Gets the full type name (package+type), stripping out generic type arguments. +func getFullName(t reflect.Type) string { var typename string - var generic string if isGeneric(t) { m := matchGenericTypeName.FindAllStringSubmatch(t.Name(), 1)[0] - typename = m[1] - generic = m[2] } else { typename = t.Name() } - return fmt.Sprintf("%s.%s", t.PkgPath(), typename), generic + return fmt.Sprintf("%s.%s", t.PkgPath(), typename) } -func (c *Converter) handleCustomType(t reflect.Type, name, validate string, indent int) (string, bool) { - fullName, generic := getFullName(t) +func (c *Converter) handleCustomType(t reflect.Type, validate string, indent int) (string, bool) { + fullName := getFullName(t) custom, ok := c.custom[fullName] if ok { - return custom(c, t, name, generic, validate, indent), true + return custom(c, t, validate, indent), true } return "", false } -func (c *Converter) ConvertType(t reflect.Type, name string, validate string, indent int) string { +// ConvertType should be called from custom converter functions. +func (c *Converter) ConvertType(t reflect.Type, validate string, indent int) string { if t.Kind() == reflect.Ptr { inner := t.Elem() validate = strings.TrimPrefix(validate, "omitempty") validate = strings.TrimPrefix(validate, ",") - return c.ConvertType(inner, name, validate, indent) + return c.ConvertType(inner, validate, indent) } - if custom, ok := c.handleCustomType(t, name, validate, indent); ok { + // Custom types should be handled before maps/slices, as we might have + // custom types that are maps/slices. + if custom, ok := c.handleCustomType(t, validate, indent); ok { return custom } if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { - return c.convertSliceAndArray(t, name, validate, indent) + return c.convertSliceAndArray(t, validate, indent) + } + + if t.Kind() == reflect.Map { + return c.convertMap(t, validate, indent) } if t.Kind() == reflect.Struct { - // Handle nested un-named structs - these are inline. - if t.Name() == "" { + name := typeName(t) + + if name == "" { + // Handle fields with non-defined types - these are inline. return c.convertStruct(t, indent) - } else if t.Name() == "Time" { + } else if name == "Time" { var validateStr string if validate != "" { // We compare with both the zero value from go and the zero value that zod coerces to @@ -335,8 +360,8 @@ func (c *Converter) ConvertType(t reflect.Type, name string, validate string, in // timestamps are to be coerced to date by zod. JSON.parse only serializes to string return "z.coerce.date()" + validateStr } else { - if c.stack[len(c.stack)-1].Name == name { - c.stack[len(c.stack)-1].SelfRef = true + if c.stack[len(c.stack)-1].name == name { + c.stack[len(c.stack)-1].selfRef = true return fmt.Sprintf("z.lazy(() => %s)", schemaName(c.prefix, name)) } // throws panic if there is a cycle @@ -346,10 +371,6 @@ func (c *Converter) ConvertType(t reflect.Type, name string, validate string, in } } - if t.Kind() == reflect.Map { - return c.convertMap(t, name, validate, indent) - } - // boolean, number, string, any zodType, ok := typeMapping[t.Kind()] if !ok { @@ -372,21 +393,27 @@ func (c *Converter) ConvertType(t reflect.Type, name string, validate string, in return fmt.Sprintf("z.%s()%s", zodType, validateStr) } -func (c *Converter) getType(t reflect.Type, name string, indent int) string { +func (c *Converter) getType(t reflect.Type, indent int) string { if t.Kind() == reflect.Ptr { inner := t.Elem() - return c.getType(inner, name, indent) + return c.getType(inner, indent) } // TODO: handle types for custom types if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { - return c.getTypeSliceAndArray(t, name, indent) + return c.getTypeSliceAndArray(t, indent) + } + + if t.Kind() == reflect.Map { + return c.getTypeMap(t, indent) } if t.Kind() == reflect.Struct { - // Handle nested un-named structs - these are inline. + name := typeName(t) + if t.Name() == "" { + // Handle fields with non-defined types - these are inline. return c.getTypeStruct(t, indent) } else if t.Name() == "Time" { return "date" @@ -395,10 +422,6 @@ func (c *Converter) getType(t reflect.Type, name string, indent int) string { } } - if t.Kind() == reflect.Map { - return c.getTypeMap(t, name, indent) - } - zodType, ok := typeMapping[t.Kind()] if !ok { panic(fmt.Sprint("cannot handle: ", t.Kind())) @@ -416,7 +439,7 @@ func (c *Converter) convertField(f reflect.StructField, indent int, optional, nu // because nullability is processed before custom types, this makes sure // the custom type has control over nullability. - fullName, _ := getFullName(f.Type) + fullName := getFullName(f.Type) _, isCustom := c.custom[fullName] optionalCall := "" @@ -428,7 +451,7 @@ func (c *Converter) convertField(f reflect.StructField, indent int, optional, nu nullableCall = ".nullable()" } - t := c.ConvertType(f.Type, typeName(f.Type), f.Tag.Get("validate"), indent) + t := c.ConvertType(f.Type, f.Tag.Get("validate"), indent) if !anonymous { return fmt.Sprintf( "%s%s: %s%s%s,\n", @@ -452,7 +475,7 @@ func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nu // because nullability is processed before custom types, this makes sure // the custom type has control over nullability. - fullName, _ := getFullName(f.Type) + fullName := getFullName(f.Type) _, isCustom := c.custom[fullName] optionalCallPre := "" @@ -471,16 +494,16 @@ func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nu indentation(indent), name, optionalCallPre, - c.getType(f.Type, typeName(f.Type), indent), + c.getType(f.Type, indent), nullableCall, optionalCallUndef) } -func (c *Converter) convertSliceAndArray(t reflect.Type, name, validate string, indent int) string { +func (c *Converter) convertSliceAndArray(t reflect.Type, validate string, indent int) string { if t.Kind() == reflect.Array { return fmt.Sprintf( "%s.array()%s", - c.ConvertType(t.Elem(), name, getValidateAfterDive(validate), indent), fmt.Sprintf(".length(%d)", t.Len())) + c.ConvertType(t.Elem(), getValidateAfterDive(validate), indent), fmt.Sprintf(".length(%d)", t.Len())) } var validateStr strings.Builder @@ -540,16 +563,16 @@ func (c *Converter) convertSliceAndArray(t reflect.Type, name, validate string, return fmt.Sprintf( "%s.array()%s", - c.ConvertType(t.Elem(), name, getValidateAfterDive(validate), indent), validateStr.String()) + c.ConvertType(t.Elem(), getValidateAfterDive(validate), indent), validateStr.String()) } -func (c *Converter) getTypeSliceAndArray(t reflect.Type, name string, indent int) string { +func (c *Converter) getTypeSliceAndArray(t reflect.Type, indent int) string { return fmt.Sprintf( "%s[]", - c.getType(t.Elem(), name, indent)) + c.getType(t.Elem(), indent)) } -func (c *Converter) convertKeyType(t reflect.Type, name, validate string, indent int) string { +func (c *Converter) convertKeyType(t reflect.Type, validate string) string { if t.Name() == "Time" { return "z.coerce.date()" } @@ -582,7 +605,7 @@ func (c *Converter) convertKeyType(t reflect.Type, name, validate string, indent return fmt.Sprintf("z.coerce.%s()%s", zodType, validateStr) } -func (c *Converter) convertMap(t reflect.Type, name, validate string, indent int) string { +func (c *Converter) convertMap(t reflect.Type, validate string, indent int) string { var validateStr strings.Builder if validate != "" { parts := strings.Split(validate, ",") @@ -619,19 +642,19 @@ func (c *Converter) convertMap(t reflect.Type, name, validate string, indent int } return fmt.Sprintf(`z.record(%s, %s)%s`, - c.convertKeyType(t.Key(), name, getValidateKeys(validate), indent), - c.ConvertType(t.Elem(), name, getValidateValues(validate), indent), + c.convertKeyType(t.Key(), getValidateKeys(validate)), + c.ConvertType(t.Elem(), getValidateValues(validate), indent), validateStr.String()) } -func (c *Converter) getTypeMap(t reflect.Type, name string, indent int) string { +func (c *Converter) getTypeMap(t reflect.Type, indent int) string { return fmt.Sprintf(`Record<%s, %s>`, - c.getType(t.Key(), name, indent), - c.getType(t.Elem(), name, indent)) + c.getType(t.Key(), indent), + c.getType(t.Elem(), indent)) } +// Select part of validate string after dive, if it exists. func getValidateAfterDive(validate string) string { - // select part of validate string after dive, if it exists var validateNext string if validate != "" { parts := strings.Split(validate, ",") @@ -1013,10 +1036,12 @@ func isOptional(field reflect.StructField) bool { // structs do not have an empty value. // Interfaces are currently exported with "any" type, which already includes "undefined" if field.Type.Kind() == reflect.Struct || isInterface(field) || - strings.Contains(getValidateCurrent(field.Tag.Get("validate")), "required") { + strings.Contains(validateCurrent, "required") { return false } + // If some comparison is present min=1 or max=2 or len=4 etc. then go-validator requires the value + // to be non-nil unless omitempty is also present if strings.Contains(validateCurrent, "=") && !strings.Contains(validateCurrent, "omitempty") { return false } @@ -1029,12 +1054,12 @@ func indentation(level int) string { return strings.Repeat(" ", level*2) } -func detectCycle(name string, stack []Meta) { +func detectCycle(name string, stack []meta) { var found bool var cycle strings.Builder for _, m := range stack { - cycle.WriteString(m.Name) - if m.Name == name { + cycle.WriteString(m.name) + if m.name == name { found = true break } @@ -1045,3 +1070,21 @@ func detectCycle(name string, stack []Meta) { panic(fmt.Sprintf("circular dependency detected: %s", cycle.String())) } } + +func getTypeNameWithGenerics(name string) string { + typeArgsIdx := strings.Index(name, "[") + if typeArgsIdx == -1 { + return name + } + + var sb strings.Builder + sb.WriteString(name[:typeArgsIdx]) + + typeArgs := strings.Split(name[typeArgsIdx+1:len(name)-1], ",") + for _, arg := range typeArgs { + sb.WriteString(strings.ToTitle(arg[:1])) // Capitalize first letter + sb.WriteString(arg[1:]) + } + + return sb.String() +} diff --git a/zod_test.go b/zod_test.go index 3693e4a..d717c09 100644 --- a/zod_test.go +++ b/zod_test.go @@ -1976,7 +1976,7 @@ export type User = z.infer func TestCustom(t *testing.T) { c := NewConverter(map[string]CustomFn{ - "github.com/hypersequent/zen.Decimal": func(c *Converter, t reflect.Type, s, g, validate string, i int) string { + "github.com/hypersequent/zen.Decimal": func(c *Converter, t reflect.Type, validate string, i int) string { return "z.string()" }, }) @@ -2071,3 +2071,39 @@ func TestCyclic(t *testing.T) { StructToZodSchema(TestCyclicA{}) }) } + +type GenericPair[T any, U any] struct { + First T + Second U +} + +type StringIntPair GenericPair[string, int] + +type PairMap[K comparable, T any, U any] struct { + Items map[K]GenericPair[T, U] `json:"items"` +} + +func TestGenerics(t *testing.T) { + c := NewConverter(nil) + c.AddType(StringIntPair{}) + c.AddType(GenericPair[int, bool]{}) + c.AddType(PairMap[string, int, bool]{}) + assert.Equal(t, `export const StringIntPairSchema = z.object({ + First: z.string(), + Second: z.number(), +}) +export type StringIntPair = z.infer + +export const GenericPairIntBoolSchema = z.object({ + First: z.number(), + Second: z.boolean(), +}) +export type GenericPairIntBool = z.infer + +export const PairMapStringIntBoolSchema = z.object({ + items: z.record(z.string(), GenericPairIntBoolSchema).nullable(), +}) +export type PairMapStringIntBool = z.infer + +`, c.Export()) +}