diff --git a/hack/ccp/go.mod b/hack/ccp/go.mod index 23c448948..9dca4d93e 100644 --- a/hack/ccp/go.mod +++ b/hack/ccp/go.mod @@ -4,6 +4,7 @@ go 1.22.5 require ( buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1 + connectrpc.com/authn v0.1.0 connectrpc.com/connect v1.16.2 connectrpc.com/grpchealth v1.3.0 connectrpc.com/grpcreflect v1.2.0 diff --git a/hack/ccp/go.sum b/hack/ccp/go.sum index 1af2917da..841238d4e 100644 --- a/hack/ccp/go.sum +++ b/hack/ccp/go.sum @@ -1,6 +1,8 @@ buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1 h1:LEXWFH/xZ5oOWrC3oOtHbUyBdzRWMCPpAQmKC9v05mA= buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1/go.mod h1:XF+P8+RmfdufmIYpGUC+6bF7S+IlmHDEnCrO3OXaUAQ= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +connectrpc.com/authn v0.1.0 h1:m5weACjLWwgwcjttvUDyTPICJKw74+p2obBVrf8hT9E= +connectrpc.com/authn v0.1.0/go.mod h1:AwNZK/KYbqaJzRYadTuAaoz6sYQSPdORPqh1TOPIkgY= connectrpc.com/connect v1.16.2 h1:ybd6y+ls7GOlb7Bh5C8+ghA6SvCBajHwxssO2CGFjqE= connectrpc.com/connect v1.16.2/go.mod h1:n2kgwskMHXC+lVqb18wngEpF95ldBHXjZYJussz5FRc= connectrpc.com/grpchealth v1.3.0 h1:FA3OIwAvuMokQIXQrY5LbIy8IenftksTP/lG4PbYN+E= diff --git a/hack/ccp/internal/api/admin/admin.go b/hack/ccp/internal/api/admin/admin.go index 0dd1919ad..675ab95f8 100644 --- a/hack/ccp/internal/api/admin/admin.go +++ b/hack/ccp/internal/api/admin/admin.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "connectrpc.com/authn" "connectrpc.com/connect" "connectrpc.com/grpchealth" "connectrpc.com/grpcreflect" @@ -63,7 +64,7 @@ func New(logger logr.Logger, config Config, ctrl *controller.Controller, store s srv.v = v } - srv.cache = ttlcache.New[adminv1.PackageType, *adminv1.ListPackagesResponse]( + srv.cache = ttlcache.New( ttlcache.WithTTL[adminv1.PackageType, *adminv1.ListPackagesResponse](1 * time.Second), ) srv.wg.Add(1) @@ -98,10 +99,13 @@ func (s *Server) Run() error { compress1KB, )) + auth := authenticate(s.logger, s.store) + handler := authn.NewMiddleware(auth).Wrap(mux) + s.server = &http.Server{ Addr: s.config.Addr, Handler: h2c.NewHandler( - corsutil.New(nil).Handler(mux), + corsutil.New(nil).Handler(handler), &http2.Server{}, ), ReadHeaderTimeout: time.Second, diff --git a/hack/ccp/internal/api/admin/auth.go b/hack/ccp/internal/api/admin/auth.go new file mode 100644 index 000000000..8963c6113 --- /dev/null +++ b/hack/ccp/internal/api/admin/auth.go @@ -0,0 +1,95 @@ +package admin + +import ( + "context" + "strings" + + "connectrpc.com/authn" + "github.com/go-logr/logr" + + "github.com/artefactual/archivematica/hack/ccp/internal/store" +) + +var errInvalidAuth = authn.Errorf("invalid authorization") + +func authenticate(logger logr.Logger, store store.Store) authn.AuthFunc { + return multiAuthenticate( + authApiKey(logger, store), + ) +} + +func multiAuthenticate(methods ...authn.AuthFunc) authn.AuthFunc { + return func(ctx context.Context, req authn.Request) (any, error) { + var lastErr error + for _, method := range methods { + result, err := method(ctx, req) + if err == nil { + return result, nil + } + lastErr = err + } + return nil, lastErr + } +} + +func authApiKey(logger logr.Logger, store store.Store) authn.AuthFunc { + return func(ctx context.Context, req authn.Request) (any, error) { + auth := req.Header().Get("Authorization") + if auth == "" { + return nil, errInvalidAuth + } + + username, key, ok := parseApiKey(auth) + if !ok { + return nil, errInvalidAuth + } + + ok, err := store.ValidateUserAPIKey(ctx, username, key) + if err != nil { + logger.Error(err, "Cannot look up user details.") + return nil, errInvalidAuth + } + if !ok { + return nil, errInvalidAuth + } + + return username, nil + } +} + +// parseApiKey parses the ApiKey string. +// "ApiKey test:test" returns ("test", "test", true). +func parseApiKey(auth string) (username, key string, ok bool) { + const prefix = "ApiKey " + // Case insensitive prefix match. + if len(auth) < len(prefix) || !equalFold(auth[:len(prefix)], prefix) { + return "", "", false + } + username, key, ok = strings.Cut(auth[len(prefix):], ":") + if !ok { + return "", "", false + } + return username, key, true +} + +// equalFold is [strings.EqualFold], ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func equalFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := range len(s) { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} diff --git a/hack/ccp/internal/api/admin/auth_test.go b/hack/ccp/internal/api/admin/auth_test.go new file mode 100644 index 000000000..282cc2ce6 --- /dev/null +++ b/hack/ccp/internal/api/admin/auth_test.go @@ -0,0 +1,56 @@ +package admin + +import ( + "net/http" + "net/http/httptest" + "testing" + + "connectrpc.com/authn" + "github.com/artefactual/archivematica/hack/ccp/internal/store/storemock" + "github.com/go-logr/logr" + "go.artefactual.dev/tools/mockutil" + "go.uber.org/mock/gomock" + "gotest.tools/v3/assert" +) + +func TestAuthentication(t *testing.T) { + t.Parallel() + + t.Run("Accepts API key", func(t *testing.T) { + t.Parallel() + + store := storemock.NewMockStore(gomock.NewController(t)) + store.EXPECT().ValidateUserAPIKey(mockutil.Context(), "test", "test").Return(true, nil) + + auth := multiAuthenticate(authApiKey(logr.Discard(), store)) + handler := authn.NewMiddleware(auth).Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + req.Header.Set("Authorization", "ApiKey test:test") + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + resp := w.Result() + + assert.Equal(t, resp.StatusCode, http.StatusOK) + }) + + t.Run("Rejects invalid API key", func(t *testing.T) { + t.Parallel() + + store := storemock.NewMockStore(gomock.NewController(t)) + store.EXPECT().ValidateUserAPIKey(mockutil.Context(), "test", "12345").Return(false, nil) + + auth := multiAuthenticate(authApiKey(logr.Discard(), store)) + handler := authn.NewMiddleware(auth).Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + req.Header.Set("Authorization", "ApiKey test:12345") + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + resp := w.Result() + + assert.Equal(t, resp.StatusCode, http.StatusUnauthorized) + }) +} diff --git a/hack/ccp/internal/store/mysql.go b/hack/ccp/internal/store/mysql.go index 420423db0..996c0fa76 100644 --- a/hack/ccp/internal/store/mysql.go +++ b/hack/ccp/internal/store/mysql.go @@ -975,6 +975,20 @@ func (s *mysqlStoreImpl) ReadStorageServiceConfig(ctx context.Context) (ret Stor return ret, nil } +func (s *mysqlStoreImpl) ValidateUserAPIKey(ctx context.Context, username, key string) (_ bool, err error) { + defer wrap(&err, "ValidateUserAPIKey(%q, %q)", username, key) + + _, err = s.queries.ReadUserWithKey(ctx, &sqlc.ReadUserWithKeyParams{ + Username: username, + Key: key, + }) + if err == sql.ErrNoRows { + return false, nil + } + + return true, nil +} + func (s *mysqlStoreImpl) Running() bool { return s != nil } diff --git a/hack/ccp/internal/store/sqlcmysql/db.go b/hack/ccp/internal/store/sqlcmysql/db.go index 60e38688d..24bf03da1 100644 --- a/hack/ccp/internal/store/sqlcmysql/db.go +++ b/hack/ccp/internal/store/sqlcmysql/db.go @@ -96,6 +96,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.readUnitVarsStmt, err = db.PrepareContext(ctx, readUnitVars); err != nil { return nil, fmt.Errorf("error preparing query ReadUnitVars: %w", err) } + if q.readUserWithKeyStmt, err = db.PrepareContext(ctx, readUserWithKey); err != nil { + return nil, fmt.Errorf("error preparing query ReadUserWithKey: %w", err) + } if q.updateJobStatusStmt, err = db.PrepareContext(ctx, updateJobStatus); err != nil { return nil, fmt.Errorf("error preparing query UpdateJobStatus: %w", err) } @@ -239,6 +242,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing readUnitVarsStmt: %w", cerr) } } + if q.readUserWithKeyStmt != nil { + if cerr := q.readUserWithKeyStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing readUserWithKeyStmt: %w", cerr) + } + } if q.updateJobStatusStmt != nil { if cerr := q.updateJobStatusStmt.Close(); cerr != nil { err = fmt.Errorf("error closing updateJobStatusStmt: %w", cerr) @@ -332,6 +340,7 @@ type Queries struct { readTransferWithLocationStmt *sql.Stmt readUnitVarStmt *sql.Stmt readUnitVarsStmt *sql.Stmt + readUserWithKeyStmt *sql.Stmt updateJobStatusStmt *sql.Stmt updateSIPLocationStmt *sql.Stmt updateSIPStatusStmt *sql.Stmt @@ -368,6 +377,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { readTransferWithLocationStmt: q.readTransferWithLocationStmt, readUnitVarStmt: q.readUnitVarStmt, readUnitVarsStmt: q.readUnitVarsStmt, + readUserWithKeyStmt: q.readUserWithKeyStmt, updateJobStatusStmt: q.updateJobStatusStmt, updateSIPLocationStmt: q.updateSIPLocationStmt, updateSIPStatusStmt: q.updateSIPStatusStmt, diff --git a/hack/ccp/internal/store/sqlcmysql/query.sql.go b/hack/ccp/internal/store/sqlcmysql/query.sql.go index 52acba2cd..aa25309de 100644 --- a/hack/ccp/internal/store/sqlcmysql/query.sql.go +++ b/hack/ccp/internal/store/sqlcmysql/query.sql.go @@ -617,6 +617,34 @@ func (q *Queries) ReadUnitVars(ctx context.Context, arg *ReadUnitVarsParams) ([] return items, nil } +const readUserWithKey = `-- name: ReadUserWithKey :one + +SELECT auth_user.id, auth_user.username, auth_user.is_active +FROM auth_user +JOIN tastypie_apikey ON auth_user.id = tastypie_apikey.user_id +WHERE auth_user.username = ? AND tastypie_apikey.key = ? AND auth_user.is_active = 1 +LIMIT 1 +` + +type ReadUserWithKeyParams struct { + Username string + Key string +} + +type ReadUserWithKeyRow struct { + ID int32 + Username string + IsActive bool +} + +// Authorization +func (q *Queries) ReadUserWithKey(ctx context.Context, arg *ReadUserWithKeyParams) (*ReadUserWithKeyRow, error) { + row := q.queryRow(ctx, q.readUserWithKeyStmt, readUserWithKey, arg.Username, arg.Key) + var i ReadUserWithKeyRow + err := row.Scan(&i.ID, &i.Username, &i.IsActive) + return &i, err +} + const updateJobStatus = `-- name: UpdateJobStatus :exec UPDATE Jobs SET currentStep = ? WHERE jobUUID = ? ` diff --git a/hack/ccp/internal/store/store.go b/hack/ccp/internal/store/store.go index 0f647fab7..42c2c148a 100644 --- a/hack/ccp/internal/store/store.go +++ b/hack/ccp/internal/store/store.go @@ -115,6 +115,10 @@ type Store interface { // Archivematica Storage Service associated to this pipeline. ReadStorageServiceConfig(ctx context.Context) (StorageServiceConfig, error) + // ValidateUserAPIKey confirms that the username with the given API key + // exists in the database and is active. + ValidateUserAPIKey(ctx context.Context, username, key string) (bool, error) + Running() bool Close() error } diff --git a/hack/ccp/internal/store/storemock/mock_store.go b/hack/ccp/internal/store/storemock/mock_store.go index bd839f190..d1c5830ac 100644 --- a/hack/ccp/internal/store/storemock/mock_store.go +++ b/hack/ccp/internal/store/storemock/mock_store.go @@ -1206,3 +1206,42 @@ func (c *MockStoreUpsertTransferCall) DoAndReturn(f func(context.Context, uuid.U c.Call = c.Call.DoAndReturn(f) return c } + +// ValidateUserAPIKey mocks base method. +func (m *MockStore) ValidateUserAPIKey(ctx context.Context, username, key string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateUserAPIKey", ctx, username, key) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ValidateUserAPIKey indicates an expected call of ValidateUserAPIKey. +func (mr *MockStoreMockRecorder) ValidateUserAPIKey(ctx, username, key any) *MockStoreValidateUserAPIKeyCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserAPIKey", reflect.TypeOf((*MockStore)(nil).ValidateUserAPIKey), ctx, username, key) + return &MockStoreValidateUserAPIKeyCall{Call: call} +} + +// MockStoreValidateUserAPIKeyCall wrap *gomock.Call +type MockStoreValidateUserAPIKeyCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStoreValidateUserAPIKeyCall) Return(arg0 bool, arg1 error) *MockStoreValidateUserAPIKeyCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStoreValidateUserAPIKeyCall) Do(f func(context.Context, string, string) (bool, error)) *MockStoreValidateUserAPIKeyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStoreValidateUserAPIKeyCall) DoAndReturn(f func(context.Context, string, string) (bool, error)) *MockStoreValidateUserAPIKeyCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/hack/ccp/sqlc/mysql/query.sql b/hack/ccp/sqlc/mysql/query.sql index be25aabc3..674da13b3 100644 --- a/hack/ccp/sqlc/mysql/query.sql +++ b/hack/ccp/sqlc/mysql/query.sql @@ -161,3 +161,14 @@ SELECT name, value, scope FROM DashboardSettings WHERE name LIKE ?; -- name: ReadDashboardSetting :one SELECT name, value, scope FROM DashboardSettings WHERE name = ?; + +-- +-- Authorization +-- + +-- name: ReadUserWithKey :one +SELECT auth_user.id, auth_user.username, auth_user.is_active +FROM auth_user +JOIN tastypie_apikey ON auth_user.id = tastypie_apikey.user_id +WHERE auth_user.username = ? AND tastypie_apikey.key = ? AND auth_user.is_active = 1 +LIMIT 1;