diff --git a/index/circular_reference_result.go b/index/circular_reference_result.go index a710d6a9..feec1763 100644 --- a/index/circular_reference_result.go +++ b/index/circular_reference_result.go @@ -8,8 +8,10 @@ type CircularReferenceResult struct { Start *Reference LoopIndex int LoopPoint *Reference - IsPolymorphicResult bool // if this result comes from a polymorphic loop. - IsInfiniteLoop bool // if all the definitions in the reference loop are marked as required, this is an infinite circular reference, thus is not allowed. + IsArrayResult bool // if this result comes from an array loop. + PolymorphicType string // which type of polymorphic loop is this? (oneOf, anyOf, allOf) + IsPolymorphicResult bool // if this result comes from a polymorphic loop. + IsInfiniteLoop bool // if all the definitions in the reference loop are marked as required, this is an infinite circular reference, thus is not allowed. } func (c *CircularReferenceResult) GenerateJourneyPath() string { diff --git a/index/index_model.go b/index/index_model.go index 18031f02..08f629b5 100644 --- a/index/index_model.go +++ b/index/index_model.go @@ -28,6 +28,7 @@ type Reference struct { Name string Node *yaml.Node ParentNode *yaml.Node + ParentNodeSchemaType string // used to determine if the parent node is an array or not. Resolved bool Circular bool Seen bool diff --git a/resolver/resolver.go b/resolver/resolver.go index 8f06462c..ad191e7b 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -42,6 +42,8 @@ type Resolver struct { indexesVisited int journeysTaken int relativesSeen int + ignorePoly bool + ignoreArray bool } // NewResolver will create a new resolver from a *index.SpecIndex @@ -92,10 +94,21 @@ func (resolver *Resolver) GetNonPolymorphicCircularErrors() []*index.CircularRef res = append(res, resolver.circularReferences[i]) } } - return res } +// IgnorePolymorphicCircularReferences will ignore any circular references that are polymorphic (oneOf, anyOf, allOf) +// This must be set before any resolving is done. +func (resolver *Resolver) IgnorePolymorphicCircularReferences() { + resolver.ignorePoly = true +} + +// IgnoreArrayCircularReferences will ignore any circular references that stem from arrays. This must be set before +// any resolving is done. +func (resolver *Resolver) IgnoreArrayCircularReferences() { + resolver.ignoreArray = true +} + // GetJourneysTaken returns the number of journeys taken by the resolver func (resolver *Resolver) GetJourneysTaken() int { return resolver.journeysTaken @@ -231,7 +244,7 @@ func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]b } journey = append(journey, ref) - relatives := resolver.extractRelatives(ref.Node, seen, journey, resolve) + relatives := resolver.extractRelatives(ref.Node, nil, seen, journey, resolve) seen = make(map[string]bool) @@ -254,11 +267,17 @@ func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]b visitedDefinitions := map[string]bool{} isInfiniteLoop, _ := resolver.isInfiniteCircularDependency(foundDup, visitedDefinitions, nil) + + isArray := false + if r.ParentNodeSchemaType == "array" { + isArray = true + } circRef = &index.CircularReferenceResult{ Journey: loop, Start: foundDup, LoopIndex: i, LoopPoint: foundDup, + IsArrayResult: isArray, IsInfiniteLoop: isInfiniteLoop, } resolver.circularReferences = append(resolver.circularReferences, circRef) @@ -321,7 +340,7 @@ func (resolver *Resolver) isInfiniteCircularDependency(ref *index.Reference, vis return false, visitedDefinitions } -func (resolver *Resolver) extractRelatives(node *yaml.Node, +func (resolver *Resolver) extractRelatives(node, parent *yaml.Node, foundRelatives map[string]bool, journey []*index.Reference, resolve bool) []*index.Reference { @@ -333,7 +352,30 @@ func (resolver *Resolver) extractRelatives(node *yaml.Node, if len(node.Content) > 0 { for i, n := range node.Content { if utils.IsNodeMap(n) || utils.IsNodeArray(n) { - found = append(found, resolver.extractRelatives(n, foundRelatives, journey, resolve)...) + + var anyvn, allvn, onevn, arrayTypevn *yaml.Node + + // extract polymorphic references + if len(n.Content) > 1 { + _, anyvn = utils.FindKeyNodeTop("anyOf", n.Content) + _, allvn = utils.FindKeyNodeTop("allOf", n.Content) + _, onevn = utils.FindKeyNodeTop("oneOf", n.Content) + _, arrayTypevn = utils.FindKeyNodeTop("type", n.Content) + } + if anyvn != nil || allvn != nil || onevn != nil { + if resolver.ignorePoly { + continue + } + } + if arrayTypevn != nil { + if arrayTypevn.Value == "array" { + if resolver.ignoreArray { + continue + } + } + } + + found = append(found, resolver.extractRelatives(n, node, foundRelatives, journey, resolve)...) } if i%2 == 0 && n.Value == "$ref" { @@ -357,10 +399,22 @@ func (resolver *Resolver) extractRelatives(node *yaml.Node, continue } + schemaType := "" + if parent != nil { + _, arrayTypevn := utils.FindKeyNodeTop("type", parent.Content) + if arrayTypevn != nil { + if arrayTypevn.Value == "array" { + schemaType = "array" + } + } + } + r := &index.Reference{ - Definition: value, - Name: value, - Node: node, + Definition: value, + Name: value, + Node: node, + ParentNode: parent, + ParentNodeSchemaType: schemaType, } found = append(found, r) @@ -398,6 +452,7 @@ func (resolver *Resolver) extractRelatives(node *yaml.Node, Start: ref, LoopIndex: i, LoopPoint: ref, + PolymorphicType: n.Value, IsPolymorphicResult: true, } @@ -435,6 +490,7 @@ func (resolver *Resolver) extractRelatives(node *yaml.Node, Start: ref, LoopIndex: i, LoopPoint: ref, + PolymorphicType: n.Value, IsPolymorphicResult: true, } @@ -449,6 +505,7 @@ func (resolver *Resolver) extractRelatives(node *yaml.Node, } break } + } } } diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go index 1db450ec..f3c33cbf 100644 --- a/resolver/resolver_test.go +++ b/resolver/resolver_test.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "io/ioutil" - "net/url" "testing" "github.com/pb33f/libopenapi/index" @@ -63,22 +62,104 @@ func TestResolver_CheckForCircularReferences(t *testing.T) { assert.NoError(t, err) } -func TestResolver_CheckForCircularReferences_DigitalOcean(t *testing.T) { - circular, _ := ioutil.ReadFile("../test_specs/digitalocean.yaml") +func TestResolver_CheckForCircularReferences_CatchArray(t *testing.T) { + circular := []byte(`openapi: 3.0.0 +components: + schemas: + ProductCategory: + type: "object" + properties: + name: + type: "string" + children: + type: "array" + items: + $ref: "#/components/schemas/ProductCategory" + description: "Array of sub-categories in the same format." + required: + - "name" + - "children"`) + var rootNode yaml.Node + yaml.Unmarshal(circular, &rootNode) + + index := index.NewSpecIndex(&rootNode) + + resolver := NewResolver(index) + assert.NotNil(t, resolver) + + circ := resolver.CheckForCircularReferences() + assert.Len(t, circ, 1) + assert.Len(t, resolver.GetResolvingErrors(), 1) // infinite loop is a resolving error. + assert.Len(t, resolver.GetCircularErrors(), 1) + assert.True(t, resolver.GetCircularErrors()[0].IsArrayResult) + + _, err := yaml.Marshal(resolver.resolvedRoot) + assert.NoError(t, err) +} + +func TestResolver_CheckForCircularReferences_IgnoreArray(t *testing.T) { + circular := []byte(`openapi: 3.0.0 +components: + schemas: + ProductCategory: + type: "object" + properties: + name: + type: "string" + children: + type: "array" + items: + $ref: "#/components/schemas/ProductCategory" + description: "Array of sub-categories in the same format." + required: + - "name" + - "children"`) var rootNode yaml.Node yaml.Unmarshal(circular, &rootNode) - baseURL, _ := url.Parse("https://raw.githubusercontent.com/digitalocean/openapi/main/specification") + index := index.NewSpecIndex(&rootNode) + + resolver := NewResolver(index) + assert.NotNil(t, resolver) + + resolver.IgnoreArrayCircularReferences() + + circ := resolver.CheckForCircularReferences() + assert.Len(t, circ, 0) + assert.Len(t, resolver.GetResolvingErrors(), 0) + assert.Len(t, resolver.GetCircularErrors(), 0) + + _, err := yaml.Marshal(resolver.resolvedRoot) + assert.NoError(t, err) +} - index := index.NewSpecIndexWithConfig(&rootNode, &index.SpecIndexConfig{ - AllowRemoteLookup: true, - AllowFileLookup: true, - BaseURL: baseURL, - }) +func TestResolver_CheckForCircularReferences_IgnorePoly_Any(t *testing.T) { + circular := []byte(`openapi: 3.0.0 +components: + schemas: + ProductCategory: + type: "object" + properties: + name: + type: "string" + children: + type: "object" + anyOf: + - $ref: "#/components/schemas/ProductCategory" + description: "Array of sub-categories in the same format." + required: + - "name" + - "children"`) + var rootNode yaml.Node + yaml.Unmarshal(circular, &rootNode) + + index := index.NewSpecIndex(&rootNode) resolver := NewResolver(index) assert.NotNil(t, resolver) + resolver.IgnorePolymorphicCircularReferences() + circ := resolver.CheckForCircularReferences() assert.Len(t, circ, 0) assert.Len(t, resolver.GetResolvingErrors(), 0) @@ -88,6 +169,172 @@ func TestResolver_CheckForCircularReferences_DigitalOcean(t *testing.T) { assert.NoError(t, err) } +func TestResolver_CheckForCircularReferences_IgnorePoly_All(t *testing.T) { + circular := []byte(`openapi: 3.0.0 +components: + schemas: + ProductCategory: + type: "object" + properties: + name: + type: "string" + children: + type: "object" + allOf: + - $ref: "#/components/schemas/ProductCategory" + description: "Array of sub-categories in the same format." + required: + - "name" + - "children"`) + var rootNode yaml.Node + yaml.Unmarshal(circular, &rootNode) + + index := index.NewSpecIndex(&rootNode) + + resolver := NewResolver(index) + assert.NotNil(t, resolver) + + resolver.IgnorePolymorphicCircularReferences() + + circ := resolver.CheckForCircularReferences() + assert.Len(t, circ, 0) + assert.Len(t, resolver.GetResolvingErrors(), 0) + assert.Len(t, resolver.GetCircularErrors(), 0) + + _, err := yaml.Marshal(resolver.resolvedRoot) + assert.NoError(t, err) +} + +func TestResolver_CheckForCircularReferences_IgnorePoly_One(t *testing.T) { + circular := []byte(`openapi: 3.0.0 +components: + schemas: + ProductCategory: + type: "object" + properties: + name: + type: "string" + children: + type: "object" + oneOf: + - $ref: "#/components/schemas/ProductCategory" + description: "Array of sub-categories in the same format." + required: + - "name" + - "children"`) + var rootNode yaml.Node + yaml.Unmarshal(circular, &rootNode) + + index := index.NewSpecIndex(&rootNode) + + resolver := NewResolver(index) + assert.NotNil(t, resolver) + + resolver.IgnorePolymorphicCircularReferences() + + circ := resolver.CheckForCircularReferences() + assert.Len(t, circ, 0) + assert.Len(t, resolver.GetResolvingErrors(), 0) + assert.Len(t, resolver.GetCircularErrors(), 0) + + _, err := yaml.Marshal(resolver.resolvedRoot) + assert.NoError(t, err) +} + +func TestResolver_CheckForCircularReferences_CatchPoly_Any(t *testing.T) { + circular := []byte(`openapi: 3.0.0 +components: + schemas: + ProductCategory: + type: "object" + properties: + name: + type: "string" + children: + type: "object" + anyOf: + - $ref: "#/components/schemas/ProductCategory" + description: "Array of sub-categories in the same format." + required: + - "name" + - "children"`) + var rootNode yaml.Node + yaml.Unmarshal(circular, &rootNode) + + index := index.NewSpecIndex(&rootNode) + + resolver := NewResolver(index) + assert.NotNil(t, resolver) + + circ := resolver.CheckForCircularReferences() + assert.Len(t, circ, 0) + assert.Len(t, resolver.GetResolvingErrors(), 0) // not an infinite loop if poly. + assert.Len(t, resolver.GetCircularErrors(), 1) + assert.Equal(t, "anyOf", resolver.GetCircularErrors()[0].PolymorphicType) + _, err := yaml.Marshal(resolver.resolvedRoot) + assert.NoError(t, err) +} + +func TestResolver_CheckForCircularReferences_CatchPoly_All(t *testing.T) { + circular := []byte(`openapi: 3.0.0 +components: + schemas: + ProductCategory: + type: "object" + properties: + name: + type: "string" + children: + type: "object" + allOf: + - $ref: "#/components/schemas/ProductCategory" + description: "Array of sub-categories in the same format." + required: + - "name" + - "children"`) + var rootNode yaml.Node + yaml.Unmarshal(circular, &rootNode) + + index := index.NewSpecIndex(&rootNode) + + resolver := NewResolver(index) + assert.NotNil(t, resolver) + + circ := resolver.CheckForCircularReferences() + assert.Len(t, circ, 0) + assert.Len(t, resolver.GetResolvingErrors(), 0) // not an infinite loop if poly. + assert.Len(t, resolver.GetCircularErrors(), 1) + assert.Equal(t, "allOf", resolver.GetCircularErrors()[0].PolymorphicType) + assert.True(t, resolver.GetCircularErrors()[0].IsPolymorphicResult) + _, err := yaml.Marshal(resolver.resolvedRoot) + assert.NoError(t, err) +} + +//func TestResolver_CheckForCircularReferences_DigitalOcean(t *testing.T) { +// circular, _ := ioutil.ReadFile("../test_specs/digitalocean.yaml") +// var rootNode yaml.Node +// yaml.Unmarshal(circular, &rootNode) +// +// baseURL, _ := url.Parse("https://raw.githubusercontent.com/digitalocean/openapi/main/specification") +// +// index := index.NewSpecIndexWithConfig(&rootNode, &index.SpecIndexConfig{ +// AllowRemoteLookup: true, +// AllowFileLookup: true, +// BaseURL: baseURL, +// }) +// +// resolver := NewResolver(index) +// assert.NotNil(t, resolver) +// +// circ := resolver.CheckForCircularReferences() +// assert.Len(t, circ, 0) +// assert.Len(t, resolver.GetResolvingErrors(), 0) +// assert.Len(t, resolver.GetCircularErrors(), 0) +// +// _, err := yaml.Marshal(resolver.resolvedRoot) +// assert.NoError(t, err) +//} + func TestResolver_CircularReferencesRequiredValid(t *testing.T) { circular, _ := ioutil.ReadFile("../test_specs/swagger-valid-recursive-model.yaml") var rootNode yaml.Node @@ -129,7 +376,7 @@ func TestResolver_DeepJourney(t *testing.T) { } index := index.NewSpecIndex(nil) resolver := NewResolver(index) - assert.Nil(t, resolver.extractRelatives(nil, nil, journey, false)) + assert.Nil(t, resolver.extractRelatives(nil, nil, nil, journey, false)) } func TestResolver_ResolveComponents_Stripe(t *testing.T) { diff --git a/utils/utils.go b/utils/utils.go index e384906e..b99ea843 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -236,6 +236,9 @@ func FindKeyNodeTop(key string, nodes []*yaml.Node) (keyNode *yaml.Node, valueNo } if strings.EqualFold(key, v.Value) { + if i+1 >= len(nodes) { + return v, nodes[i] + } return v, nodes[i+1] // next node is what we need. } }