diff --git a/document.go b/document.go index 199feed8..654d003c 100644 --- a/document.go +++ b/document.go @@ -92,7 +92,7 @@ type document struct { // skip optional errors like circular reference errors. // if configured - errorFilter func(error) bool + errorFilter []func(error) bool } // DocumentModel represents either a Swagger document (version 2) or an OpenAPI document (version 3) that is @@ -118,12 +118,12 @@ type DocumentModel[T v2high.Swagger | v3high.Document] struct { func NewDocument(specByteArray []byte, options ...ConfigurationOption) (Document, error) { // sane defaults directly visible to the user. config := &Configuration{ - RemoteURLHandler: http.Get, - AllowFileReferences: false, - AllowRemoteReferences: false, - AvoidIndexBuild: false, - BypassDocumentCheck: false, - AllowCircularReferenceResolving: true, + RemoteURLHandler: http.Get, + AllowFileReferences: false, + AllowRemoteReferences: false, + AvoidIndexBuild: false, + BypassDocumentCheck: false, + ForbidCircularReferenceResolving: false, } var err error @@ -150,7 +150,7 @@ func NewDocumentWithConfiguration(specByteArray []byte, config *Configuration) ( info: info, highOpenAPI3Model: nil, highSwaggerModel: nil, - errorFilter: defaultErrorFilter, + errorFilter: []func(error) bool{defaultErrorFilter}, } d.SetConfiguration(config) @@ -162,14 +162,19 @@ func (d *document) SetConfiguration(config *Configuration) { d.config = config if config == nil { - d.errorFilter = defaultErrorFilter + d.errorFilter = []func(error) bool{defaultErrorFilter} return } - d.errorFilter = errorutils.AndFilter( + if config.RemoteURLHandler == nil { + // set default handler if + config.RemoteURLHandler = http.Get + } + + d.errorFilter = []func(error) bool{ // more filters can be added here if needed - circularReferenceErrorFilter(config.AllowCircularReferenceResolving), - ) + circularReferenceErrorFilter(config.ForbidCircularReferenceResolving), + } } func NewDocumentWithTypeCheck(specByteArray []byte, bypassCheck bool) (Document, error) { @@ -250,7 +255,7 @@ func (d *document) BuildV2Model() (*DocumentModel[v2high.Swagger], error) { } lowDoc, err := v2low.CreateDocumentFromConfig(d.info, d.config.toModelConfig()) - err = errorutils.Filtered(err, d.errorFilter) + err = errorutils.Filtered(err, d.errorFilter...) if err != nil { return nil, err } @@ -282,7 +287,7 @@ func (d *document) BuildV3Model() (*DocumentModel[v3high.Document], error) { } lowDoc, err = v3low.CreateDocumentFromConfig(d.info, d.config.toModelConfig()) - err = errorutils.Filtered(err, d.errorFilter) + err = errorutils.Filtered(err, d.errorFilter...) if err != nil { return nil, err } diff --git a/document_config.go b/document_config.go index d27dbf6d..101f4b4d 100644 --- a/document_config.go +++ b/document_config.go @@ -40,9 +40,9 @@ type Configuration struct { // passed in and used. Only enable this when parsing non openapi documents. BypassDocumentCheck bool - // AllowCircularReferences will allow circular references to be resolved. This is disabled by default. - // Will return an error in case of circular references. - AllowCircularReferenceResolving bool + // ForbidCircularReferenceResolving will forbid circular references to be resolved. This is disabled by default. + // Will return an error in case of circular references if set to true. + ForbidCircularReferenceResolving bool } func (c *Configuration) toModelConfig() *datamodel.DocumentConfiguration { @@ -109,11 +109,11 @@ func WithBypassDocumentCheck(bypass bool) ConfigurationOption { } } -// WithAllowCircularReferenceResolving returns an error for every detected circular reference if set to false. -// If set to true, circular references will be resolved (default behavior). -func WithAllowCircularReferenceResolving(allow bool) ConfigurationOption { +// WithForbidCircularReferenceResolving returns an error for every detected circular reference if set to true. +// If set to false, circular references will be resolved (default behavior). +func WithForbidCircularReferenceResolving(forbidden bool) ConfigurationOption { return func(o *Configuration) error { - o.AllowCircularReferenceResolving = allow + o.ForbidCircularReferenceResolving = forbidden return nil } } diff --git a/document_examples_test.go b/document_examples_test.go index 31e5ce22..d02dfa8d 100644 --- a/document_examples_test.go +++ b/document_examples_test.go @@ -401,7 +401,7 @@ components: - testThing ` // create a new document from specification bytes - doc, err := NewDocument([]byte(spec), WithAllowCircularReferenceResolving(false)) + doc, err := NewDocument([]byte(spec), WithForbidCircularReferenceResolving(true)) // if anything went wrong, an error is thrown if err != nil { diff --git a/document_test.go b/document_test.go index 8580f15a..8221a4fe 100644 --- a/document_test.go +++ b/document_test.go @@ -174,7 +174,7 @@ func TestDocument_RenderAndReload_ChangeCheck_Burgershop(t *testing.T) { func TestDocument_RenderAndReload_ChangeCheck_Stripe(t *testing.T) { bs, _ := os.ReadFile("test_specs/stripe.yaml") - doc, err := NewDocument(bs, WithAllowCircularReferenceResolving(true)) + doc, err := NewDocument(bs) require.NoError(t, err) _, err = doc.BuildV3Model() require.NoError(t, err) @@ -321,9 +321,10 @@ func TestDocument_AnyDocWithConfig(t *testing.T) { func TestDocument_BuildModelCircular(t *testing.T) { petstore, _ := os.ReadFile("test_specs/circular-tests.yaml") - doc, err := NewDocument(petstore, WithAllowCircularReferenceResolving(false)) + doc, err := NewDocument(petstore, WithForbidCircularReferenceResolving(true)) require.NoError(t, err) m, err := doc.BuildV3Model() + require.Error(t, err) // top level library does not return broken objects // with an error, only one or the other diff --git a/errors.go b/errors.go index 661e2abc..3a98f768 100644 --- a/errors.go +++ b/errors.go @@ -32,18 +32,15 @@ func isCircularErr(err error) bool { // returns a filter function that checks if a given error is a circular reference error // and in case that circular references are allowed or not, it returns false // in order to skip the error or true in order to keep the error in the wrapped error list. -func circularReferenceErrorFilter(refAllowed bool) func(error) (keep bool) { +func circularReferenceErrorFilter(forbidden bool) func(error) (keep bool) { return func(err error) bool { if err == nil { return false } if isCircularErr(err) { - if refAllowed { - return false - } else { - return true - } + // if forbidded -> keep the error and pass it to the user + return forbidden } // keep unknown error diff --git a/index/find_component_test.go b/index/find_component_test.go index f92d12c4..a2191850 100644 --- a/index/find_component_test.go +++ b/index/find_component_test.go @@ -177,8 +177,10 @@ paths: // extract crs param from index crsParam := index.GetMappedReferences()["https://schemas.opengis.net/ogcapi/features/part2/1.0/openapi/ogcapi-features-2.yaml#/components/parameters/crs"] - assert.NotNil(t, crsParam) + require.NotNil(t, crsParam) assert.True(t, crsParam.IsRemote) + require.NotNil(t, crsParam.Node) + require.GreaterOrEqual(t, len(crsParam.Node.Content), 10) assert.Equal(t, "crs", crsParam.Node.Content[1].Value) assert.Equal(t, "query", crsParam.Node.Content[3].Value) assert.Equal(t, "form", crsParam.Node.Content[9].Value) diff --git a/internal/errorutils/filter.go b/internal/errorutils/filter.go index f19f5ff8..71c64dec 100644 --- a/internal/errorutils/filter.go +++ b/internal/errorutils/filter.go @@ -5,7 +5,7 @@ func Filtered(err error, filters ...func(error) (keep bool)) error { return nil } errs := ShallowUnwrap(err) - filtered := Filter(errs, AndFilter(filters...)) + filtered := Filter(errs, and(filters...)) if len(filtered) == 0 { return nil } @@ -14,21 +14,26 @@ func Filtered(err error, filters ...func(error) (keep bool)) error { func Filter(errs []error, filter func(error) (keep bool)) []error { var result []error + var keep bool for _, err := range errs { - if filter(err) { + keep = filter(err) + if keep { result = append(result, err) } } return result } -func AndFilter(filters ...func(error) (keep bool)) func(error) (keep bool) { +func and(filters ...func(error) (keep bool)) func(error) (keep bool) { return func(err error) bool { + var keep bool for _, filter := range filters { - if !filter(err) { + keep = filter(err) + if !keep { return false } } + // all true -> true return true } } diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go index 17d0d329..5729ac9b 100644 --- a/resolver/resolver_test.go +++ b/resolver/resolver_test.go @@ -81,12 +81,12 @@ func TestResolver_CheckForCircularReferences_DigitalOcean(t *testing.T) { require.NotNil(t, resolver) circ := resolver.CheckForCircularReferences() - assert.Len(t, circ, 0) - assert.Len(t, resolver.GetResolvingErrors(), 0) - assert.Len(t, resolver.GetCircularErrors(), 0) + require.Len(t, circ, 0) + require.Len(t, resolver.GetResolvingErrors(), 0) + require.Len(t, resolver.GetCircularErrors(), 0) _, err := yaml.Marshal(resolver.resolvedRoot) - assert.NoError(t, err) + require.NoError(t, err) } func TestResolver_CircularReferencesRequiredValid(t *testing.T) {