Skip to content

Commit

Permalink
chore: more test cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 committed Dec 30, 2024
1 parent 0178bd8 commit 4e44d60
Showing 1 changed file with 42 additions and 121 deletions.
163 changes: 42 additions & 121 deletions cmd/api/src/api/v2/auth/sso_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@
package auth_test

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/specterops/bloodhound/src/api"
"github.com/specterops/bloodhound/src/api/v2/apitest"
Expand All @@ -34,14 +31,16 @@ import (
"github.com/specterops/bloodhound/src/database/types/null"
"github.com/specterops/bloodhound/src/model"
"github.com/specterops/bloodhound/src/utils/test"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func TestManagementResource_ListAuthProviders(t *testing.T) {
const endpoint = "/api/v2/sso-providers"

var (
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
reqCtx = &ctx.Context{Host: &url.URL{}}

oidcProvider = model.OIDCProvider{
SSOProviderID: 1,
Expand Down Expand Up @@ -94,144 +93,66 @@ func TestManagementResource_ListAuthProviders(t *testing.T) {

test.Request(t).
WithMethod(http.MethodGet).
WithContext(&ctx.Context{Host: &url.URL{}}).
WithURL("/api/v2/sso-providers").
WithContext(reqCtx).
WithURL(endpoint).
OnHandlerFunc(resources.ListAuthProviders).
Require().
ResponseStatusCode(http.StatusOK)
})

t.Run("successfully list auth providers with sorting", func(t *testing.T) {

// sorting by name descending
mockDB.EXPECT().GetAllSSOProviders(
gomock.Any(),
"name desc",
model.SQLFilter{SQLString: "", Params: nil},
).Return(ssoProviders, nil)

endpoint := "/api/v2/sso-providers?sort_by=-name"

bhCtx := &ctx.Context{
Host: &url.URL{
Scheme: "http",
Host: "example.com",
},
}
requestContext := context.WithValue(context.Background(), ctx.ValueKey, bhCtx)

req, err := http.NewRequestWithContext(requestContext, "GET", endpoint, nil)
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
req.Host = "example.com"

router := mux.NewRouter()
router.HandleFunc("/api/v2/sso-providers", resources.ListAuthProviders).Methods("GET")
mockDB.EXPECT().GetAllSSOProviders(gomock.Any(), "name desc", model.SQLFilter{SQLString: "", Params: nil}).Return(ssoProviders, nil)
const reqUrl = endpoint + "?sort_by=-name"

rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)

require.Equal(t, http.StatusOK, rr.Code)
test.Request(t).
WithMethod(http.MethodGet).
WithContext(reqCtx).
WithURL(reqUrl).
OnHandlerFunc(resources.ListAuthProviders).
Require().
ResponseStatusCode(http.StatusOK)
})

t.Run("successfully list auth providers with filtering", func(t *testing.T) {
oidcProvider := model.OIDCProvider{
SSOProviderID: 1,
ClientID: "client-id-1",
Issuer: "https://issuer1.com",
}
ssoProviders := []model.SSOProvider{
{
Serial: model.Serial{ID: 1},
Name: "OIDC Provider 1",
Slug: "oidc-provider-1",
Type: model.SessionAuthProviderOIDC,
OIDCProvider: &oidcProvider,
},
}

// filtering by name
mockDB.EXPECT().GetAllSSOProviders(
gomock.Any(),
"created_at",
model.SQLFilter{
SQLString: "name = ?",
Params: []interface{}{"OIDC Provider 1"},
},
).Return(ssoProviders, nil)

endpoint := "/api/v2/sso-providers?name=eq:OIDC Provider 1"

bhCtx := &ctx.Context{
Host: &url.URL{
Scheme: "http",
Host: "example.com",
},
}
requestContext := context.WithValue(context.Background(), ctx.ValueKey, bhCtx)

req, err := http.NewRequestWithContext(requestContext, "GET", endpoint, nil)
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
req.Host = "example.com"

router := mux.NewRouter()
router.HandleFunc("/api/v2/sso-providers", resources.ListAuthProviders).Methods("GET")

rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
mockDB.EXPECT().GetAllSSOProviders(gomock.Any(), "created_at", model.SQLFilter{
SQLString: "name = ?",
Params: []interface{}{"OIDC Provider 1"},
}).Return([]model.SSOProvider{ssoProviders[0]}, nil)
const reqUrl = endpoint + "?name=eq:OIDC Provider 1"

require.Equal(t, http.StatusOK, rr.Code)
test.Request(t).
WithMethod(http.MethodGet).
WithContext(reqCtx).
WithURL(reqUrl).
OnHandlerFunc(resources.ListAuthProviders).
Require().
ResponseStatusCode(http.StatusOK)
})

t.Run("fail to list auth providers with invalid sort field", func(t *testing.T) {
endpoint := "/api/v2/sso-providers?sort_by=invalid_field"

bhCtx := &ctx.Context{
Host: &url.URL{
Scheme: "http",
Host: "example.com",
},
}
requestContext := context.WithValue(context.Background(), ctx.ValueKey, bhCtx)
const reqUrl = endpoint + "?sort_by=invalid_field"

req, err := http.NewRequestWithContext(requestContext, "GET", endpoint, nil)
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
req.Host = "example.com"

router := mux.NewRouter()
router.HandleFunc("/api/v2/sso-providers", resources.ListAuthProviders).Methods("GET")

rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)

require.Equal(t, http.StatusBadRequest, rr.Code)
test.Request(t).
WithMethod(http.MethodGet).
WithContext(reqCtx).
WithURL(reqUrl).
OnHandlerFunc(resources.ListAuthProviders).
Require().
ResponseStatusCode(http.StatusBadRequest)
})

t.Run("fail to list auth providers with invalid filter predicate", func(t *testing.T) {
endpoint := "/api/v2/sso-providers?name=invalid_predicate:Provider"

bhCtx := &ctx.Context{
Host: &url.URL{
Scheme: "http",
Host: "example.com",
},
}
requestContext := context.WithValue(context.Background(), ctx.ValueKey, bhCtx)

req, err := http.NewRequestWithContext(requestContext, "GET", endpoint, nil)
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
req.Host = "example.com"
const reqUrl = endpoint + "?name=invalid_predicate:Provider"

router := mux.NewRouter()
router.HandleFunc("/api/v2/sso-providers", resources.ListAuthProviders).Methods("GET")

rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)

require.Equal(t, http.StatusBadRequest, rr.Code)
test.Request(t).
WithMethod(http.MethodGet).
WithContext(reqCtx).
WithURL(reqUrl).
OnHandlerFunc(resources.ListAuthProviders).
Require().
ResponseStatusCode(http.StatusBadRequest)
})
}

Expand Down

0 comments on commit 4e44d60

Please sign in to comment.