Skip to content

Commit

Permalink
review feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Hall <jason@chainguard.dev>
  • Loading branch information
imjasonh committed Sep 30, 2023
1 parent e6dacc7 commit facbc9e
Showing 1 changed file with 26 additions and 30 deletions.
56 changes: 26 additions & 30 deletions reflect/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,10 @@ import (
)

func GenerateType(v any) (attr.Type, error) {
return generateTypeReflect("", reflect.TypeOf(v))
return generateTypeReflect("", sets.New[string](), reflect.TypeOf(v))
}

var inProgress = sets.NewString()

func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) {
if inProgress.Has(t.String()) {
log.Println("detected recursive type:", path)
// If we're already trying to figure out this type, then we're in a recursive loop. Avoid this by just returning an empty object.
return basetypes.ObjectType{}, nil
}
inProgress.Insert(t.String())
defer func() { inProgress.Delete(t.String()) }()

func generateTypeReflect(path string, inProgress sets.Set[string], t reflect.Type) (attr.Type, error) {
switch t.Kind() {
case reflect.String:
return basetypes.StringType{}, nil
Expand All @@ -38,10 +28,10 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) {
case reflect.Float32, reflect.Float64:
return basetypes.Float64Type{}, nil
case reflect.Ptr:
return generateTypeReflect("*"+path, t.Elem())
return generateTypeReflect("*"+path, inProgress, t.Elem())

case reflect.Array, reflect.Slice:
st, err := generateTypeReflect(path+"[]", t.Elem())
st, err := generateTypeReflect(path+"[]", inProgress, t.Elem())
if err != nil {
return nil, fmt.Errorf("[]%v: %w", t.Elem(), err)
}
Expand All @@ -54,7 +44,7 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) {
return nil, fmt.Errorf("%v only string map keys are supported", t.Key())
}

et, err := generateTypeReflect(path+"{}", t.Elem())
et, err := generateTypeReflect(path+"{}", inProgress, t.Elem())
if err != nil {
return nil, fmt.Errorf("map[string]%v: %w", t.Elem(), err)
}
Expand All @@ -63,6 +53,14 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) {
}, nil

case reflect.Struct:
if inProgress.Has(t.String()) {
log.Println("detected recursive type:", path)
// If we're already trying to figure out this type, then we're in a recursive loop. Avoid this by just returning an empty object.
return basetypes.ObjectType{}, nil
}
inProgress.Insert(t.String())
defer func() { inProgress.Delete(t.String()) }()

ot := basetypes.ObjectType{
AttrTypes: make(map[string]attr.Type, t.NumField()),
}
Expand All @@ -76,7 +74,7 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) {
continue
}

ft, err := generateTypeReflect(path+"."+*tag, sf.Type)
ft, err := generateTypeReflect(path+"."+*tag, inProgress, sf.Type)
if err != nil {
return nil, fmt.Errorf("struct %w", err)
}
Expand All @@ -90,24 +88,22 @@ func generateTypeReflect(path string, t reflect.Type) (attr.Type, error) {
}

func GenerateValue(v any) (attr.Value, diag.Diagnostics) {
return generateValueReflect("", reflect.ValueOf(v))
return generateValueReflect("", sets.New[string](), reflect.ValueOf(v))
}

var valueInProgress = sets.NewString()

func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout diag.Diagnostics) {
if valueInProgress.Has(v.String()) {
func generateValueReflect(path string, inProgress sets.Set[string], v reflect.Value) (out attr.Value, diagout diag.Diagnostics) {
if inProgress.Has(v.String()) {
log.Println("detected recursive type:", path)
// If we're already trying to figure out this value, then we're in a recursive loop. Avoid this by just returning an empty object.
return basetypes.NewObjectValue(nil, nil)
}
valueInProgress.Insert(v.String())
defer func() { valueInProgress.Delete(v.String()) }()
inProgress.Insert(v.String())
defer func() { inProgress.Delete(v.String()) }()

t := v.Type()
switch t.Kind() {
case reflect.Pointer:
return generateValueReflect(path, v.Elem())
return generateValueReflect(path, inProgress, v.Elem())
case reflect.String:
return basetypes.NewStringValue(v.String()), nil
case reflect.Bool:
Expand All @@ -120,13 +116,13 @@ func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout
return basetypes.NewFloat64Value(v.Float()), nil

case reflect.Array, reflect.Slice:
st, err := generateTypeReflect(path+"[]", t.Elem())
st, err := generateTypeReflect(path+"[]", inProgress, t.Elem())
if err != nil {
return nil, []diag.Diagnostic{diag.NewErrorDiagnostic(err.Error(), "")}
}
ets := make([]attr.Value, 0, v.Len())
for i := 0; i < v.Len(); i++ {
et, diags := generateValueReflect(path+"[]", v.Index(i))
et, diags := generateValueReflect(fmt.Sprintf("%s[%d]", path, i), inProgress, v.Index(i))
if diags.HasError() {
return nil, diags
}
Expand All @@ -135,14 +131,14 @@ func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout
return basetypes.NewListValue(st, ets)

case reflect.Map:
et, err := generateTypeReflect(path+"{}", t.Elem())
et, err := generateTypeReflect(path+"{}", inProgress, t.Elem())
if err != nil {
return nil, []diag.Diagnostic{diag.NewErrorDiagnostic(err.Error(), "")}
}

em := make(map[string]attr.Value, v.Len())
for _, key := range v.MapKeys() {
et, diags := generateValueReflect(path+"{"+key.String()+"}", v.MapIndex(key))
et, diags := generateValueReflect(fmt.Sprintf("%s[%q]", path, key.String()), inProgress, v.MapIndex(key))
if diags.HasError() {
return nil, diags
}
Expand All @@ -151,7 +147,7 @@ func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout
return basetypes.NewMapValue(et, em)

case reflect.Struct:
ot, err := generateTypeReflect(path+"{}", t)
ot, err := generateTypeReflect(path+"{}", inProgress, t)
if err != nil {
return nil, []diag.Diagnostic{diag.NewErrorDiagnostic(err.Error(), "")}
}
Expand All @@ -163,7 +159,7 @@ func generateValueReflect(path string, v reflect.Value) (out attr.Value, diagout
if tag == nil {
continue
}
ft, diags := generateValueReflect(path+"{}", v.Field(i))
ft, diags := generateValueReflect(path+"{}", inProgress, v.Field(i))
if diags.HasError() {
return nil, diags
}
Expand Down

0 comments on commit facbc9e

Please sign in to comment.