From facbc9ed02ae8c2163ec0eacde7b49d8bd909d7b Mon Sep 17 00:00:00 2001 From: Jason Hall Date: Sat, 30 Sep 2023 14:53:02 -0400 Subject: [PATCH] review feedback Signed-off-by: Jason Hall --- reflect/reflect.go | 56 +++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/reflect/reflect.go b/reflect/reflect.go index 0474707..72638c9 100644 --- a/reflect/reflect.go +++ b/reflect/reflect.go @@ -13,20 +13,10 @@ import ( ) func GenerateType(v any) (attr.Type, error) { - return generateTypeReflect("", reflect.TypeOf(v)) + return generateTypeReflect("", sets.New[string](), reflect.TypeOf(v)) } -var inProgress = sets.NewString() - -func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) { - if inProgress.Has(t.String()) { - log.Println("detected recursive type:", path) - // If we're already trying to figure out this type, then we're in a recursive loop. Avoid this by just returning an empty object. - return basetypes.ObjectType{}, nil - } - inProgress.Insert(t.String()) - defer func() { inProgress.Delete(t.String()) }() - +func generateTypeReflect(path string, inProgress sets.Set[string], t reflect.Type) (attr.Type, error) { switch t.Kind() { case reflect.String: return basetypes.StringType{}, nil @@ -38,10 +28,10 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) { case reflect.Float32, reflect.Float64: return basetypes.Float64Type{}, nil case reflect.Ptr: - return generateTypeReflect("*"+path, t.Elem()) + return generateTypeReflect("*"+path, inProgress, t.Elem()) case reflect.Array, reflect.Slice: - st, err := generateTypeReflect(path+"[]", t.Elem()) + st, err := generateTypeReflect(path+"[]", inProgress, t.Elem()) if err != nil { return nil, fmt.Errorf("[]%v: %w", t.Elem(), err) } @@ -54,7 +44,7 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) { return nil, fmt.Errorf("%v only string map keys are supported", t.Key()) } - et, err := generateTypeReflect(path+"{}", t.Elem()) + et, err := generateTypeReflect(path+"{}", inProgress, t.Elem()) if err != nil { return nil, fmt.Errorf("map[string]%v: %w", t.Elem(), err) } @@ -63,6 +53,14 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) { }, nil case reflect.Struct: + if inProgress.Has(t.String()) { + log.Println("detected recursive type:", path) + // If we're already trying to figure out this type, then we're in a recursive loop. Avoid this by just returning an empty object. + return basetypes.ObjectType{}, nil + } + inProgress.Insert(t.String()) + defer func() { inProgress.Delete(t.String()) }() + ot := basetypes.ObjectType{ AttrTypes: make(map[string]attr.Type, t.NumField()), } @@ -76,7 +74,7 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) { continue } - ft, err := generateTypeReflect(path+"."+*tag, sf.Type) + ft, err := generateTypeReflect(path+"."+*tag, inProgress, sf.Type) if err != nil { return nil, fmt.Errorf("struct %w", err) } @@ -90,24 +88,22 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) { } func GenerateValue(v any) (attr.Value, diag.Diagnostics) { - return generateValueReflect("", reflect.ValueOf(v)) + return generateValueReflect("", sets.New[string](), reflect.ValueOf(v)) } -var valueInProgress = sets.NewString() - -func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout diag.Diagnostics) { - if valueInProgress.Has(v.String()) { +func generateValueReflect(path string, inProgress sets.Set[string], v reflect.Value) (out attr.Value, diagout diag.Diagnostics) { + if inProgress.Has(v.String()) { log.Println("detected recursive type:", path) // If we're already trying to figure out this value, then we're in a recursive loop. Avoid this by just returning an empty object. return basetypes.NewObjectValue(nil, nil) } - valueInProgress.Insert(v.String()) - defer func() { valueInProgress.Delete(v.String()) }() + inProgress.Insert(v.String()) + defer func() { inProgress.Delete(v.String()) }() t := v.Type() switch t.Kind() { case reflect.Pointer: - return generateValueReflect(path, v.Elem()) + return generateValueReflect(path, inProgress, v.Elem()) case reflect.String: return basetypes.NewStringValue(v.String()), nil case reflect.Bool: @@ -120,13 +116,13 @@ func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout return basetypes.NewFloat64Value(v.Float()), nil case reflect.Array, reflect.Slice: - st, err := generateTypeReflect(path+"[]", t.Elem()) + st, err := generateTypeReflect(path+"[]", inProgress, t.Elem()) if err != nil { return nil, []diag.Diagnostic{diag.NewErrorDiagnostic(err.Error(), "")} } ets := make([]attr.Value, 0, v.Len()) for i := 0; i < v.Len(); i++ { - et, diags := generateValueReflect(path+"[]", v.Index(i)) + et, diags := generateValueReflect(fmt.Sprintf("%s[%d]", path, i), inProgress, v.Index(i)) if diags.HasError() { return nil, diags } @@ -135,14 +131,14 @@ func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout return basetypes.NewListValue(st, ets) case reflect.Map: - et, err := generateTypeReflect(path+"{}", t.Elem()) + et, err := generateTypeReflect(path+"{}", inProgress, t.Elem()) if err != nil { return nil, []diag.Diagnostic{diag.NewErrorDiagnostic(err.Error(), "")} } em := make(map[string]attr.Value, v.Len()) for _, key := range v.MapKeys() { - et, diags := generateValueReflect(path+"{"+key.String()+"}", v.MapIndex(key)) + et, diags := generateValueReflect(fmt.Sprintf("%s[%q]", path, key.String()), inProgress, v.MapIndex(key)) if diags.HasError() { return nil, diags } @@ -151,7 +147,7 @@ func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout return basetypes.NewMapValue(et, em) case reflect.Struct: - ot, err := generateTypeReflect(path+"{}", t) + ot, err := generateTypeReflect(path+"{}", inProgress, t) if err != nil { return nil, []diag.Diagnostic{diag.NewErrorDiagnostic(err.Error(), "")} } @@ -163,7 +159,7 @@ func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout if tag == nil { continue } - ft, diags := generateValueReflect(path+"{}", v.Field(i)) + ft, diags := generateValueReflect(path+"{}", inProgress, v.Field(i)) if diags.HasError() { return nil, diags }