diff --git a/filters/ratelimit/fail_closed_test.go b/filters/ratelimit/fail_closed_test.go index 3c64684184..c2a7d073b8 100644 --- a/filters/ratelimit/fail_closed_test.go +++ b/filters/ratelimit/fail_closed_test.go @@ -3,11 +3,9 @@ package ratelimit_test import ( "net/http" "net/http/httptest" - "net/url" "testing" "github.com/zalando/skipper/eskip" - "github.com/zalando/skipper/filters" "github.com/zalando/skipper/filters/builtin" fratelimit "github.com/zalando/skipper/filters/ratelimit" snet "github.com/zalando/skipper/net" @@ -19,68 +17,64 @@ import ( func TestFailureMode(t *testing.T) { for _, tt := range []struct { - name string - ratelimitFilterName string - failClosed bool - wantLimit bool - limitStatusCode int + name string + filters string + wantLimit bool + limitStatusCode int }{ { - name: "test clusterRatelimit fail open", - ratelimitFilterName: "clusterRatelimit", - wantLimit: false, - limitStatusCode: http.StatusTooManyRequests, + name: "test clusterRatelimit fail open", + filters: `clusterRatelimit("t", 1, "1s")`, + wantLimit: false, + limitStatusCode: http.StatusTooManyRequests, }, { - name: "test clusterRatelimit fail closed", - ratelimitFilterName: "clusterRatelimit", - failClosed: true, - wantLimit: true, - limitStatusCode: http.StatusTooManyRequests, + name: "test clusterRatelimit fail closed", + filters: `ratelimitFailClosed() -> clusterRatelimit("t", 1, "1s")`, + wantLimit: true, + limitStatusCode: http.StatusTooManyRequests, }, { - name: "test clusterClientRatelimit fail open", - ratelimitFilterName: "clusterClientRatelimit", - wantLimit: false, - limitStatusCode: http.StatusTooManyRequests, + name: "test clusterClientRatelimit fail open", + filters: `clusterClientRatelimit("t", 1, "1s", "X-Test")`, + wantLimit: false, + limitStatusCode: http.StatusTooManyRequests, }, { - name: "test clusterClientRatelimit fail closed", - ratelimitFilterName: "clusterClientRatelimit", - failClosed: true, - wantLimit: true, - limitStatusCode: http.StatusTooManyRequests, + name: "test clusterClientRatelimit fail closed", + filters: `ratelimitFailClosed() -> clusterClientRatelimit("t", 1, "1s", "X-Test")`, + wantLimit: true, + limitStatusCode: http.StatusTooManyRequests, }, { - name: "test backendRatelimit fail open", - ratelimitFilterName: "backendRatelimit", - wantLimit: false, - limitStatusCode: http.StatusServiceUnavailable, + name: "test backendRatelimit fail open", + filters: `backendRatelimit("t", 1, "1s")`, + wantLimit: false, + limitStatusCode: http.StatusServiceUnavailable, }, { - name: "test backendRatelimit fail closed", - ratelimitFilterName: "backendRatelimit", - failClosed: true, - wantLimit: true, - limitStatusCode: http.StatusServiceUnavailable, + name: "test backendRatelimit fail closed", + filters: `ratelimitFailClosed() -> backendRatelimit("t", 1, "1s")`, + wantLimit: true, + limitStatusCode: http.StatusServiceUnavailable, }, { - name: "test clusterLeakyBucketRatelimit fail open", - ratelimitFilterName: "clusterLeakyBucketRatelimit", - wantLimit: false, - limitStatusCode: http.StatusTooManyRequests, + name: "test clusterLeakyBucketRatelimit fail open", + filters: `clusterLeakyBucketRatelimit("t", 1, "1s")`, + wantLimit: false, + limitStatusCode: http.StatusTooManyRequests, }, { - name: "test clusterLeakyBucketRatelimit fail closed", - ratelimitFilterName: "clusterLeakyBucketRatelimit", - failClosed: true, - wantLimit: true, - limitStatusCode: http.StatusTooManyRequests, - }} { + name: "test clusterLeakyBucketRatelimit fail closed", + filters: `ratelimitFailClosed() -> clusterLeakyBucketRatelimit("t", 1, "1s", 10, 1)`, + wantLimit: true, + limitStatusCode: http.StatusTooManyRequests, + }, + } { t.Run(tt.name, func(t *testing.T) { fr := builtin.MakeRegistry() - reg := ratelimit.NewSwarmRegistry(nil, &snet.RedisOptions{Addrs: []string{"127.0.0.2:6379"}}, ratelimit.Settings{}) + reg := ratelimit.NewSwarmRegistry(nil, &snet.RedisOptions{Addrs: []string{"fails.test:6379"}}, ratelimit.Settings{}) defer reg.Close() provider := fratelimit.NewRatelimitProvider(reg) @@ -96,19 +90,9 @@ func TestFailureMode(t *testing.T) { })) defer backend.Close() - args := []interface{}{"t", 1, "1s"} - switch tt.ratelimitFilterName { - case filters.ClusterLeakyBucketRatelimitName: - args = append(args, 10, 1) - case filters.ClusterClientRatelimitName: - args = append(args, "X-Test") - } - - r := &eskip.Route{Filters: []*eskip.Filter{ - {Name: tt.ratelimitFilterName, Args: args}}, Backend: backend.URL} - if tt.failClosed { - r.Filters = append([]*eskip.Filter{{Name: fratelimit.NewFailClosed().Name()}}, - r.Filters...) + r := &eskip.Route{ + Filters: eskip.MustParseFilters(tt.filters), + Backend: backend.URL, } proxy := proxytest.WithParamsAndRoutingOptions( @@ -123,22 +107,16 @@ func TestFailureMode(t *testing.T) { }, r) defer proxy.Close() - reqURL, err := url.Parse(proxy.URL) - if err != nil { - t.Fatalf("Failed to parse url %s: %v", proxy.URL, err) - } - req, err := http.NewRequest("GET", reqURL.String(), nil) + req, err := http.NewRequest("GET", proxy.URL, nil) if err != nil { t.Fatal(err) - return } req.Header.Set("X-Test", "foo") - rsp, err := http.DefaultClient.Do(req) + rsp, err := proxy.Client().Do(req) if err != nil { t.Fatal(err) } - defer rsp.Body.Close() limited := rsp.StatusCode == tt.limitStatusCode @@ -149,5 +127,4 @@ func TestFailureMode(t *testing.T) { } }) } - }