From 100f2b4380ee72985364b8dc185a339fb47a20eb Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Mon, 26 Aug 2024 16:25:00 +0300 Subject: [PATCH] Add per-entity tables as views to migrate Signed-off-by: Juan Antonio Osorio --- cmd/server/app/history_purge_test.go | 58 ++- .../000101_remove_entities_tables.up.sql | 23 - ...=> 000102_remove_entities_tables.down.sql} | 0 .../000102_remove_entities_tables.up.sql | 91 ++++ database/mock/store.go | 255 +++++++++++ database/query/artifacts.sql | 13 + database/query/eval_history.sql | 4 +- database/query/pull_requests.sql | 7 + database/query/repositories.sql | 61 +++ internal/db/artifacts.sql.go | 107 +++++ internal/db/eval_history.sql.go | 6 +- internal/db/eval_history_test.go | 28 -- internal/db/models.go | 38 ++ internal/db/profiles_test.go | 5 - internal/db/pull_requests.sql.go | 51 +++ internal/db/querier.go | 20 + internal/db/repositories.sql.go | 409 ++++++++++++++++++ 17 files changed, 1098 insertions(+), 78 deletions(-) delete mode 100644 database/migrations/000101_remove_entities_tables.up.sql rename database/migrations/{000101_remove_entities_tables.down.sql => 000102_remove_entities_tables.down.sql} (100%) create mode 100644 database/migrations/000102_remove_entities_tables.up.sql create mode 100644 database/query/artifacts.sql create mode 100644 database/query/pull_requests.sql create mode 100644 database/query/repositories.sql create mode 100644 internal/db/artifacts.sql.go create mode 100644 internal/db/pull_requests.sql.go create mode 100644 internal/db/repositories.sql.go diff --git a/cmd/server/app/history_purge_test.go b/cmd/server/app/history_purge_test.go index d0990adde3..11b40a2fc6 100644 --- a/cmd/server/app/history_purge_test.go +++ b/cmd/server/app/history_purge_test.go @@ -41,9 +41,12 @@ func TestRecordSize(t *testing.T) { db.ListEvaluationHistoryStaleRecordsRow{ ID: uuid.Nil, EvaluationTime: time.Now(), - EntityType: int32(1), - EntityID: uuid.Nil, - RuleID: uuid.Nil, + EntityType: db.NullEntities{ + Entities: db.EntitiesRepository, + Valid: true, + }, + EntityID: uuid.Nil, + RuleID: uuid.Nil, }, ) @@ -76,8 +79,11 @@ func TestPurgeLoop(t *testing.T) { EvaluationTime: time.Now(), ID: uuid1, RuleID: ruleID1, - EntityType: int32(1), - EntityID: entityID1, + EntityType: db.NullEntities{ + Entities: db.EntitiesRepository, + Valid: true, + }, + EntityID: entityID1, }, ), withTransactionStuff(), @@ -104,8 +110,11 @@ func TestPurgeLoop(t *testing.T) { EvaluationTime: time.Now(), ID: uuid1, RuleID: ruleID1, - EntityType: int32(1), - EntityID: entityID1, + EntityType: db.NullEntities{ + Entities: db.EntitiesRepository, + Valid: true, + }, + EntityID: entityID1, }, ), ), @@ -126,22 +135,31 @@ func TestPurgeLoop(t *testing.T) { EvaluationTime: time.Now(), ID: uuid1, RuleID: ruleID1, - EntityType: int32(1), - EntityID: entityID1, + EntityType: db.NullEntities{ + Entities: db.EntitiesRepository, + Valid: true, + }, + EntityID: entityID1, }, db.ListEvaluationHistoryStaleRecordsRow{ EvaluationTime: time.Now(), ID: uuid2, RuleID: ruleID2, - EntityType: int32(1), - EntityID: entityID2, + EntityType: db.NullEntities{ + Entities: db.EntitiesRepository, + Valid: true, + }, + EntityID: entityID2, }, db.ListEvaluationHistoryStaleRecordsRow{ EvaluationTime: time.Now(), ID: uuid3, RuleID: ruleID3, - EntityType: int32(1), - EntityID: entityID3, + EntityType: db.NullEntities{ + Entities: db.EntitiesRepository, + Valid: true, + }, + EntityID: entityID3, }, ), withTransactionStuff(), @@ -201,8 +219,11 @@ func TestPurgeLoop(t *testing.T) { EvaluationTime: time.Now(), ID: uuid1, RuleID: ruleID1, - EntityType: int32(1), - EntityID: entityID1, + EntityType: db.NullEntities{ + Entities: db.EntitiesRepository, + Valid: true, + }, + EntityID: entityID1, }, ), withTransactionStuff(), @@ -434,14 +455,17 @@ var ( ruleID3 = uuid.MustParse("00000000-0000-0000-0000-000000000333") evaluatedAt1 = time.Now() evaluatedAt2 = evaluatedAt1.Add(-1 * time.Hour) - entityType = int32(1) + entityType = db.NullEntities{ + Entities: db.EntitiesRepository, + Valid: true, + } ) //nolint:unparam func makeHistoryRow( id uuid.UUID, evaluatedAt time.Time, - entityType int32, + entityType db.NullEntities, entityID uuid.UUID, ruleID uuid.UUID, ) db.ListEvaluationHistoryStaleRecordsRow { diff --git a/database/migrations/000101_remove_entities_tables.up.sql b/database/migrations/000101_remove_entities_tables.up.sql deleted file mode 100644 index df7d94f6a6..0000000000 --- a/database/migrations/000101_remove_entities_tables.up.sql +++ /dev/null @@ -1,23 +0,0 @@ --- Copyright 2024 Stacklok, Inc --- --- Licensed under the Apache License, Version 2.0 (the "License"); --- you may not use this file except in compliance with the License. --- You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. - -BEGIN; - --- remove the repositories, artifacts and pull_request tables - -DROP TABLE IF EXISTS repositories; -DROP TABLE IF EXISTS artifacts; -DROP TABLE IF EXISTS pull_requests; - -COMMIT; \ No newline at end of file diff --git a/database/migrations/000101_remove_entities_tables.down.sql b/database/migrations/000102_remove_entities_tables.down.sql similarity index 100% rename from database/migrations/000101_remove_entities_tables.down.sql rename to database/migrations/000102_remove_entities_tables.down.sql diff --git a/database/migrations/000102_remove_entities_tables.up.sql b/database/migrations/000102_remove_entities_tables.up.sql new file mode 100644 index 0000000000..0aafe5df04 --- /dev/null +++ b/database/migrations/000102_remove_entities_tables.up.sql @@ -0,0 +1,91 @@ +-- Copyright 2024 Stacklok, Inc +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +BEGIN; + +-- remove the repositories, artifacts and pull_request tables + +DROP TABLE IF EXISTS repositories; +DROP TABLE IF EXISTS artifacts; +DROP TABLE IF EXISTS pull_requests; + +CREATE VIEW repositories AS +SELECT + ei.id, + ei.project_id, + pr.name AS provider, + ei.provider_id, + (prop_owner.value->>'text')::TEXT AS repo_owner, + (prop_name.value->>'text')::TEXT AS repo_name, + (prop_repo_id.value->>'number')::BIGINT AS repo_id, + (prop_is_private.value->>'boolean')::BOOLEAN AS is_private, + (prop_is_fork.value->>'boolean')::BOOLEAN AS is_fork, + (prop_webhook_id.value->>'number')::BIGINT AS webhook_id, + (prop_webhook_url.value->>'text')::TEXT AS webhook_url, + (prop_deploy_url.value->>'text')::TEXT AS deploy_url, + (prop_clone_url.value->>'text')::TEXT AS clone_url, + (prop_default_branch.value->>'text')::TEXT AS default_branch, + (prop_license.value->>'text')::TEXT AS license, + ei.created_at +FROM + entity_instances ei + JOIN providers pr ON ei.provider_id = pr.id + LEFT JOIN properties prop_owner ON ei.id = prop_owner.entity_id AND prop_owner.key = 'repo_owner' + LEFT JOIN properties prop_name ON ei.id = prop_name.entity_id AND prop_name.key = 'repo_name' + LEFT JOIN properties prop_repo_id ON ei.id = prop_repo_id.entity_id AND prop_repo_id.key = 'repo_id' + LEFT JOIN properties prop_is_private ON ei.id = prop_is_private.entity_id AND prop_is_private.key = 'is_private' + LEFT JOIN properties prop_is_fork ON ei.id = prop_is_fork.entity_id AND prop_is_fork.key = 'is_fork' + LEFT JOIN properties prop_webhook_id ON ei.id = prop_webhook_id.entity_id AND prop_webhook_id.key = 'webhook_id' + LEFT JOIN properties prop_webhook_url ON ei.id = prop_webhook_url.entity_id AND prop_webhook_url.key = 'webhook_url' + LEFT JOIN properties prop_deploy_url ON ei.id = prop_deploy_url.entity_id AND prop_deploy_url.key = 'deploy_url' + LEFT JOIN properties prop_clone_url ON ei.id = prop_clone_url.entity_id AND prop_clone_url.key = 'clone_url' + LEFT JOIN properties prop_default_branch ON ei.id = prop_default_branch.entity_id AND prop_default_branch.key = 'default_branch' + LEFT JOIN properties prop_license ON ei.id = prop_license.entity_id AND prop_license.key = 'license' +WHERE + ei.entity_type = 'repository'; + +CREATE VIEW artifacts AS +SELECT + ei.id, + ei.project_id, + pr.name AS provider_name, + ei.provider_id, + ei.originated_from AS repository_id, + (prop_artifact_name.value->>'text')::TEXT AS artifact_name, + (prop_artifact_type.value->>'text')::TEXT AS artifact_type, + (prop_artifact_visibility.value->>'text')::TEXT AS artifact_visibility, + ei.created_at +FROM + entity_instances ei + JOIN providers pr ON ei.provider_id = pr.id + LEFT JOIN properties prop_artifact_name ON ei.id = prop_artifact_name.entity_id AND prop_artifact_name.key = 'artifact_name' + LEFT JOIN properties prop_artifact_type ON ei.id = prop_artifact_type.entity_id AND prop_artifact_type.key = 'artifact_type' + LEFT JOIN properties prop_artifact_visibility ON ei.id = prop_artifact_visibility.entity_id AND prop_artifact_visibility.key = 'artifact_visibility' +WHERE + ei.entity_type = 'artifact'; + +CREATE VIEW pull_requests AS +SELECT + ei.id, + ei.originated_from AS repository_id, + (prop_pr_number.value->>'number')::BIGINT AS pr_number, + ei.created_at +FROM + entity_instances ei + LEFT JOIN properties prop_pr_number ON ei.id = prop_pr_number.entity_id AND prop_pr_number.key = 'pr_number' +WHERE + ei.entity_type = 'pull_request'; + + +COMMIT; \ No newline at end of file diff --git a/database/mock/store.go b/database/mock/store.go index 9d314956c2..50dac44e3d 100644 --- a/database/mock/store.go +++ b/database/mock/store.go @@ -131,6 +131,21 @@ func (mr *MockStoreMockRecorder) CountProfilesByName(arg0, arg1 any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProfilesByName", reflect.TypeOf((*MockStore)(nil).CountProfilesByName), arg0, arg1) } +// CountRepositories mocks base method. +func (m *MockStore) CountRepositories(arg0 context.Context) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountRepositories", arg0) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountRepositories indicates an expected call of CountRepositories. +func (mr *MockStoreMockRecorder) CountRepositories(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountRepositories", reflect.TypeOf((*MockStore)(nil).CountRepositories), arg0) +} + // CountUsers mocks base method. func (m *MockStore) CountUsers(arg0 context.Context) (int64, error) { m.ctrl.T.Helper() @@ -747,6 +762,36 @@ func (mr *MockStoreMockRecorder) GetAllPropertyValuesV1(arg0, arg1 any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPropertyValuesV1", reflect.TypeOf((*MockStore)(nil).GetAllPropertyValuesV1), arg0, arg1) } +// GetArtifactByID mocks base method. +func (m *MockStore) GetArtifactByID(arg0 context.Context, arg1 db.GetArtifactByIDParams) (db.Artifact, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetArtifactByID", arg0, arg1) + ret0, _ := ret[0].(db.Artifact) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetArtifactByID indicates an expected call of GetArtifactByID. +func (mr *MockStoreMockRecorder) GetArtifactByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetArtifactByID", reflect.TypeOf((*MockStore)(nil).GetArtifactByID), arg0, arg1) +} + +// GetArtifactByName mocks base method. +func (m *MockStore) GetArtifactByName(arg0 context.Context, arg1 db.GetArtifactByNameParams) (db.Artifact, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetArtifactByName", arg0, arg1) + ret0, _ := ret[0].(db.Artifact) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetArtifactByName indicates an expected call of GetArtifactByName. +func (mr *MockStoreMockRecorder) GetArtifactByName(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetArtifactByName", reflect.TypeOf((*MockStore)(nil).GetArtifactByName), arg0, arg1) +} + // GetBundle mocks base method. func (m *MockStore) GetBundle(arg0 context.Context, arg1 db.GetBundleParams) (db.Bundle, error) { m.ctrl.T.Helper() @@ -1242,6 +1287,51 @@ func (mr *MockStoreMockRecorder) GetProviderByName(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProviderByName", reflect.TypeOf((*MockStore)(nil).GetProviderByName), arg0, arg1) } +// GetProviderWebhooks mocks base method. +func (m *MockStore) GetProviderWebhooks(arg0 context.Context, arg1 uuid.UUID) ([]db.GetProviderWebhooksRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProviderWebhooks", arg0, arg1) + ret0, _ := ret[0].([]db.GetProviderWebhooksRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProviderWebhooks indicates an expected call of GetProviderWebhooks. +func (mr *MockStoreMockRecorder) GetProviderWebhooks(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProviderWebhooks", reflect.TypeOf((*MockStore)(nil).GetProviderWebhooks), arg0, arg1) +} + +// GetPullRequest mocks base method. +func (m *MockStore) GetPullRequest(arg0 context.Context, arg1 db.GetPullRequestParams) (db.PullRequest, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPullRequest", arg0, arg1) + ret0, _ := ret[0].(db.PullRequest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPullRequest indicates an expected call of GetPullRequest. +func (mr *MockStoreMockRecorder) GetPullRequest(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPullRequest", reflect.TypeOf((*MockStore)(nil).GetPullRequest), arg0, arg1) +} + +// GetPullRequestByID mocks base method. +func (m *MockStore) GetPullRequestByID(arg0 context.Context, arg1 uuid.UUID) (db.PullRequest, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPullRequestByID", arg0, arg1) + ret0, _ := ret[0].(db.PullRequest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPullRequestByID indicates an expected call of GetPullRequestByID. +func (mr *MockStoreMockRecorder) GetPullRequestByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPullRequestByID", reflect.TypeOf((*MockStore)(nil).GetPullRequestByID), arg0, arg1) +} + // GetQuerierWithTransaction mocks base method. func (m *MockStore) GetQuerierWithTransaction(arg0 *sql.Tx) db.ExtendQuerier { m.ctrl.T.Helper() @@ -1256,6 +1346,96 @@ func (mr *MockStoreMockRecorder) GetQuerierWithTransaction(arg0 any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQuerierWithTransaction", reflect.TypeOf((*MockStore)(nil).GetQuerierWithTransaction), arg0) } +// GetRepoPathFromArtifactID mocks base method. +func (m *MockStore) GetRepoPathFromArtifactID(arg0 context.Context, arg1 uuid.UUID) (db.GetRepoPathFromArtifactIDRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRepoPathFromArtifactID", arg0, arg1) + ret0, _ := ret[0].(db.GetRepoPathFromArtifactIDRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRepoPathFromArtifactID indicates an expected call of GetRepoPathFromArtifactID. +func (mr *MockStoreMockRecorder) GetRepoPathFromArtifactID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRepoPathFromArtifactID", reflect.TypeOf((*MockStore)(nil).GetRepoPathFromArtifactID), arg0, arg1) +} + +// GetRepoPathFromPullRequestID mocks base method. +func (m *MockStore) GetRepoPathFromPullRequestID(arg0 context.Context, arg1 uuid.UUID) (db.GetRepoPathFromPullRequestIDRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRepoPathFromPullRequestID", arg0, arg1) + ret0, _ := ret[0].(db.GetRepoPathFromPullRequestIDRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRepoPathFromPullRequestID indicates an expected call of GetRepoPathFromPullRequestID. +func (mr *MockStoreMockRecorder) GetRepoPathFromPullRequestID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRepoPathFromPullRequestID", reflect.TypeOf((*MockStore)(nil).GetRepoPathFromPullRequestID), arg0, arg1) +} + +// GetRepositoryByID mocks base method. +func (m *MockStore) GetRepositoryByID(arg0 context.Context, arg1 uuid.UUID) (db.Repository, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRepositoryByID", arg0, arg1) + ret0, _ := ret[0].(db.Repository) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRepositoryByID indicates an expected call of GetRepositoryByID. +func (mr *MockStoreMockRecorder) GetRepositoryByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRepositoryByID", reflect.TypeOf((*MockStore)(nil).GetRepositoryByID), arg0, arg1) +} + +// GetRepositoryByIDAndProject mocks base method. +func (m *MockStore) GetRepositoryByIDAndProject(arg0 context.Context, arg1 db.GetRepositoryByIDAndProjectParams) (db.Repository, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRepositoryByIDAndProject", arg0, arg1) + ret0, _ := ret[0].(db.Repository) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRepositoryByIDAndProject indicates an expected call of GetRepositoryByIDAndProject. +func (mr *MockStoreMockRecorder) GetRepositoryByIDAndProject(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRepositoryByIDAndProject", reflect.TypeOf((*MockStore)(nil).GetRepositoryByIDAndProject), arg0, arg1) +} + +// GetRepositoryByRepoID mocks base method. +func (m *MockStore) GetRepositoryByRepoID(arg0 context.Context, arg1 int64) (db.Repository, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRepositoryByRepoID", arg0, arg1) + ret0, _ := ret[0].(db.Repository) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRepositoryByRepoID indicates an expected call of GetRepositoryByRepoID. +func (mr *MockStoreMockRecorder) GetRepositoryByRepoID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRepositoryByRepoID", reflect.TypeOf((*MockStore)(nil).GetRepositoryByRepoID), arg0, arg1) +} + +// GetRepositoryByRepoName mocks base method. +func (m *MockStore) GetRepositoryByRepoName(arg0 context.Context, arg1 db.GetRepositoryByRepoNameParams) (db.Repository, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRepositoryByRepoName", arg0, arg1) + ret0, _ := ret[0].(db.Repository) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRepositoryByRepoName indicates an expected call of GetRepositoryByRepoName. +func (mr *MockStoreMockRecorder) GetRepositoryByRepoName(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRepositoryByRepoName", reflect.TypeOf((*MockStore)(nil).GetRepositoryByRepoName), arg0, arg1) +} + // GetRuleEvaluationByProfileIdAndRuleType mocks base method. func (m *MockStore) GetRuleEvaluationByProfileIdAndRuleType(arg0 context.Context, arg1 uuid.UUID, arg2 db.NullEntities, arg3 sql.NullString, arg4 uuid.NullUUID, arg5 sql.NullString) (*db.ListRuleEvaluationsByProfileIdRow, error) { m.ctrl.T.Helper() @@ -1554,6 +1734,21 @@ func (mr *MockStoreMockRecorder) InsertRemediationEvent(arg0, arg1 any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertRemediationEvent", reflect.TypeOf((*MockStore)(nil).InsertRemediationEvent), arg0, arg1) } +// ListArtifactsByRepoID mocks base method. +func (m *MockStore) ListArtifactsByRepoID(arg0 context.Context, arg1 uuid.NullUUID) ([]db.Artifact, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListArtifactsByRepoID", arg0, arg1) + ret0, _ := ret[0].([]db.Artifact) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListArtifactsByRepoID indicates an expected call of ListArtifactsByRepoID. +func (mr *MockStoreMockRecorder) ListArtifactsByRepoID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListArtifactsByRepoID", reflect.TypeOf((*MockStore)(nil).ListArtifactsByRepoID), arg0, arg1) +} + // ListEvaluationHistory mocks base method. func (m *MockStore) ListEvaluationHistory(arg0 context.Context, arg1 db.ListEvaluationHistoryParams) ([]db.ListEvaluationHistoryRow, error) { m.ctrl.T.Helper() @@ -1689,6 +1884,51 @@ func (mr *MockStoreMockRecorder) ListProvidersByProjectIDPaginated(arg0, arg1 an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListProvidersByProjectIDPaginated", reflect.TypeOf((*MockStore)(nil).ListProvidersByProjectIDPaginated), arg0, arg1) } +// ListRegisteredRepositoriesByProjectIDAndProvider mocks base method. +func (m *MockStore) ListRegisteredRepositoriesByProjectIDAndProvider(arg0 context.Context, arg1 db.ListRegisteredRepositoriesByProjectIDAndProviderParams) ([]db.Repository, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListRegisteredRepositoriesByProjectIDAndProvider", arg0, arg1) + ret0, _ := ret[0].([]db.Repository) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListRegisteredRepositoriesByProjectIDAndProvider indicates an expected call of ListRegisteredRepositoriesByProjectIDAndProvider. +func (mr *MockStoreMockRecorder) ListRegisteredRepositoriesByProjectIDAndProvider(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRegisteredRepositoriesByProjectIDAndProvider", reflect.TypeOf((*MockStore)(nil).ListRegisteredRepositoriesByProjectIDAndProvider), arg0, arg1) +} + +// ListRepositoriesAfterID mocks base method. +func (m *MockStore) ListRepositoriesAfterID(arg0 context.Context, arg1 db.ListRepositoriesAfterIDParams) ([]db.Repository, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListRepositoriesAfterID", arg0, arg1) + ret0, _ := ret[0].([]db.Repository) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListRepositoriesAfterID indicates an expected call of ListRepositoriesAfterID. +func (mr *MockStoreMockRecorder) ListRepositoriesAfterID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRepositoriesAfterID", reflect.TypeOf((*MockStore)(nil).ListRepositoriesAfterID), arg0, arg1) +} + +// ListRepositoriesByProjectID mocks base method. +func (m *MockStore) ListRepositoriesByProjectID(arg0 context.Context, arg1 db.ListRepositoriesByProjectIDParams) ([]db.Repository, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListRepositoriesByProjectID", arg0, arg1) + ret0, _ := ret[0].([]db.Repository) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListRepositoriesByProjectID indicates an expected call of ListRepositoriesByProjectID. +func (mr *MockStoreMockRecorder) ListRepositoriesByProjectID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRepositoriesByProjectID", reflect.TypeOf((*MockStore)(nil).ListRepositoriesByProjectID), arg0, arg1) +} + // ListRuleEvaluationsByProfileId mocks base method. func (m *MockStore) ListRuleEvaluationsByProfileId(arg0 context.Context, arg1 db.ListRuleEvaluationsByProfileIdParams) ([]db.ListRuleEvaluationsByProfileIdRow, error) { m.ctrl.T.Helper() @@ -1793,6 +2033,21 @@ func (mr *MockStoreMockRecorder) ReleaseLock(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseLock", reflect.TypeOf((*MockStore)(nil).ReleaseLock), arg0, arg1) } +// RepositoryExistsAfterID mocks base method. +func (m *MockStore) RepositoryExistsAfterID(arg0 context.Context, arg1 uuid.UUID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RepositoryExistsAfterID", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RepositoryExistsAfterID indicates an expected call of RepositoryExistsAfterID. +func (mr *MockStoreMockRecorder) RepositoryExistsAfterID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RepositoryExistsAfterID", reflect.TypeOf((*MockStore)(nil).RepositoryExistsAfterID), arg0, arg1) +} + // Rollback mocks base method. func (m *MockStore) Rollback(arg0 *sql.Tx) error { m.ctrl.T.Helper() diff --git a/database/query/artifacts.sql b/database/query/artifacts.sql new file mode 100644 index 0000000000..3ec075f37b --- /dev/null +++ b/database/query/artifacts.sql @@ -0,0 +1,13 @@ +-- name: GetArtifactByID :one +SELECT * FROM artifacts +WHERE artifacts.id = $1 AND artifacts.project_id = $2; + +-- name: GetArtifactByName :one +SELECT * FROM artifacts +WHERE lower(artifacts.artifact_name) = lower(sqlc.arg(artifact_name)) +AND artifacts.repository_id = $1 AND artifacts.project_id = $2; + +-- name: ListArtifactsByRepoID :many +SELECT * FROM artifacts +WHERE repository_id = $1 +ORDER BY id; \ No newline at end of file diff --git a/database/query/eval_history.sql b/database/query/eval_history.sql index 3c33b6da7a..8c9275f40f 100644 --- a/database/query/eval_history.sql +++ b/database/query/eval_history.sql @@ -130,7 +130,7 @@ WHERE s.id = sqlc.arg(evaluation_id) AND j.id = sqlc.arg(project_id); -- name: ListEvaluationHistory :many SELECT s.id::uuid AS evaluation_id, s.evaluation_time as evaluated_at, - ere.entity_type, + ei.entity_type, ere.entity_instance_id as entity_id, -- raw fields for entity names ei.name as entity_name, @@ -153,7 +153,7 @@ SELECT s.id::uuid AS evaluation_id, JOIN rule_instances ri ON ere.rule_id = ri.id JOIN rule_type rt ON ri.rule_type_id = rt.id JOIN profiles p ON ri.profile_id = p.id - LEFT JOIN entity_instances ei ON ei.id = ere.entity_instance_id + JOIN entity_instances ei ON ei.id = ere.entity_instance_id LEFT JOIN remediation_events re ON re.evaluation_id = s.id LEFT JOIN alert_events ae ON ae.evaluation_id = s.id LEFT JOIN projects j ON r.project_id = j.id diff --git a/database/query/pull_requests.sql b/database/query/pull_requests.sql new file mode 100644 index 0000000000..ba9d7554cd --- /dev/null +++ b/database/query/pull_requests.sql @@ -0,0 +1,7 @@ +-- name: GetPullRequest :one +SELECT * FROM pull_requests +WHERE repository_id = $1 AND pr_number = $2; + +-- name: GetPullRequestByID :one +SELECT * FROM pull_requests +WHERE id = $1; \ No newline at end of file diff --git a/database/query/repositories.sql b/database/query/repositories.sql new file mode 100644 index 0000000000..fbe89a68bb --- /dev/null +++ b/database/query/repositories.sql @@ -0,0 +1,61 @@ +-- name: GetRepositoryByRepoID :one +SELECT * FROM repositories WHERE repo_id = $1; + +-- name: GetRepositoryByRepoName :one +SELECT * FROM repositories + WHERE repo_owner = $1 AND repo_name = $2 AND project_id = $3 + AND (lower(provider) = lower(sqlc.narg('provider')::text) OR sqlc.narg('provider')::text IS NULL); + +-- avoid using this, where possible use GetRepositoryByIDAndProject instead +-- name: GetRepositoryByID :one +SELECT * FROM repositories WHERE id = $1; + +-- name: GetRepositoryByIDAndProject :one +SELECT * FROM repositories WHERE id = $1 AND project_id = $2; + +-- name: ListRepositoriesByProjectID :many +SELECT * FROM repositories +WHERE project_id = $1 + AND (repo_id >= sqlc.narg('repo_id') OR sqlc.narg('repo_id') IS NULL) + AND lower(provider) = lower(COALESCE(sqlc.narg('provider'), provider)::text) +ORDER BY project_id, provider, repo_id +LIMIT sqlc.narg('limit')::bigint; + +-- name: ListRegisteredRepositoriesByProjectIDAndProvider :many +SELECT * FROM repositories +WHERE project_id = $1 AND webhook_id IS NOT NULL + AND (lower(provider) = lower(sqlc.narg('provider')::text) OR sqlc.narg('provider')::text IS NULL) +ORDER BY repo_name; + +-- name: ListRepositoriesAfterID :many +SELECT * +FROM repositories +WHERE id > $1 +ORDER BY id +LIMIT sqlc.arg('limit')::bigint; + +-- name: RepositoryExistsAfterID :one +SELECT EXISTS ( + SELECT 1 + FROM repositories + WHERE id > $1) +AS exists; + +-- name: CountRepositories :one +SELECT COUNT(*) FROM repositories; + +-- get a list of repos with webhooks belonging to a provider +-- is used for webhook cleanup during provider deletion +-- name: GetProviderWebhooks :many +SELECT repo_owner, repo_name, webhook_id FROM repositories +WHERE webhook_id IS NOT NULL AND provider_id = $1; + +-- name: GetRepoPathFromArtifactID :one +SELECT r.repo_owner AS owner , r.repo_name AS name FROM repositories AS r +JOIN artifacts AS a ON a.repository_id = r.id +WHERE a.id = $1; + +-- name: GetRepoPathFromPullRequestID :one +SELECT r.repo_owner AS owner , r.repo_name AS name FROM repositories AS r +JOIN pull_requests AS p ON p.repository_id = r.id +WHERE p.id = $1; \ No newline at end of file diff --git a/internal/db/artifacts.sql.go b/internal/db/artifacts.sql.go new file mode 100644 index 0000000000..fbb542b217 --- /dev/null +++ b/internal/db/artifacts.sql.go @@ -0,0 +1,107 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: artifacts.sql + +package db + +import ( + "context" + + "github.com/google/uuid" +) + +const getArtifactByID = `-- name: GetArtifactByID :one +SELECT id, project_id, provider_name, provider_id, repository_id, artifact_name, artifact_type, artifact_visibility, created_at FROM artifacts +WHERE artifacts.id = $1 AND artifacts.project_id = $2 +` + +type GetArtifactByIDParams struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` +} + +func (q *Queries) GetArtifactByID(ctx context.Context, arg GetArtifactByIDParams) (Artifact, error) { + row := q.db.QueryRowContext(ctx, getArtifactByID, arg.ID, arg.ProjectID) + var i Artifact + err := row.Scan( + &i.ID, + &i.ProjectID, + &i.ProviderName, + &i.ProviderID, + &i.RepositoryID, + &i.ArtifactName, + &i.ArtifactType, + &i.ArtifactVisibility, + &i.CreatedAt, + ) + return i, err +} + +const getArtifactByName = `-- name: GetArtifactByName :one +SELECT id, project_id, provider_name, provider_id, repository_id, artifact_name, artifact_type, artifact_visibility, created_at FROM artifacts +WHERE lower(artifacts.artifact_name) = lower($3) +AND artifacts.repository_id = $1 AND artifacts.project_id = $2 +` + +type GetArtifactByNameParams struct { + RepositoryID uuid.NullUUID `json:"repository_id"` + ProjectID uuid.UUID `json:"project_id"` + ArtifactName string `json:"artifact_name"` +} + +func (q *Queries) GetArtifactByName(ctx context.Context, arg GetArtifactByNameParams) (Artifact, error) { + row := q.db.QueryRowContext(ctx, getArtifactByName, arg.RepositoryID, arg.ProjectID, arg.ArtifactName) + var i Artifact + err := row.Scan( + &i.ID, + &i.ProjectID, + &i.ProviderName, + &i.ProviderID, + &i.RepositoryID, + &i.ArtifactName, + &i.ArtifactType, + &i.ArtifactVisibility, + &i.CreatedAt, + ) + return i, err +} + +const listArtifactsByRepoID = `-- name: ListArtifactsByRepoID :many +SELECT id, project_id, provider_name, provider_id, repository_id, artifact_name, artifact_type, artifact_visibility, created_at FROM artifacts +WHERE repository_id = $1 +ORDER BY id +` + +func (q *Queries) ListArtifactsByRepoID(ctx context.Context, repositoryID uuid.NullUUID) ([]Artifact, error) { + rows, err := q.db.QueryContext(ctx, listArtifactsByRepoID, repositoryID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Artifact{} + for rows.Next() { + var i Artifact + if err := rows.Scan( + &i.ID, + &i.ProjectID, + &i.ProviderName, + &i.ProviderID, + &i.RepositoryID, + &i.ArtifactName, + &i.ArtifactType, + &i.ArtifactVisibility, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/db/eval_history.sql.go b/internal/db/eval_history.sql.go index 9d94b0b9c6..87b2e8ab3f 100644 --- a/internal/db/eval_history.sql.go +++ b/internal/db/eval_history.sql.go @@ -312,7 +312,7 @@ func (q *Queries) InsertRemediationEvent(ctx context.Context, arg InsertRemediat const listEvaluationHistory = `-- name: ListEvaluationHistory :many SELECT s.id::uuid AS evaluation_id, s.evaluation_time as evaluated_at, - ere.entity_type, + ei.entity_type, ere.entity_instance_id as entity_id, -- raw fields for entity names ei.name as entity_name, @@ -335,7 +335,7 @@ SELECT s.id::uuid AS evaluation_id, JOIN rule_instances ri ON ere.rule_id = ri.id JOIN rule_type rt ON ri.rule_type_id = rt.id JOIN profiles p ON ri.profile_id = p.id - LEFT JOIN entity_instances ei ON ei.id = ere.entity_instance_id + JOIN entity_instances ei ON ei.id = ere.entity_instance_id LEFT JOIN remediation_events re ON re.evaluation_id = s.id LEFT JOIN alert_events ae ON ae.evaluation_id = s.id LEFT JOIN projects j ON r.project_id = j.id @@ -396,7 +396,7 @@ type ListEvaluationHistoryRow struct { EvaluatedAt time.Time `json:"evaluated_at"` EntityType Entities `json:"entity_type"` EntityID uuid.UUID `json:"entity_id"` - EntityName sql.NullString `json:"entity_name"` + EntityName string `json:"entity_name"` RuleType string `json:"rule_type"` RuleName string `json:"rule_name"` RuleSeverity Severity `json:"rule_severity"` diff --git a/internal/db/eval_history_test.go b/internal/db/eval_history_test.go index 42b1d5dfd7..f0b8a83c7e 100644 --- a/internal/db/eval_history_test.go +++ b/internal/db/eval_history_test.go @@ -70,8 +70,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, @@ -94,8 +92,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -148,8 +144,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -170,8 +164,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -210,8 +202,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -264,8 +254,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -286,8 +274,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -326,8 +312,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -380,8 +364,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -402,8 +384,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -464,8 +444,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -489,8 +467,6 @@ func TestListEvaluationHistoryFilters(t *testing.T) { require.Equal(t, es1, row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repo1.ID, row.EntityID) - require.Equal(t, repo1.RepoOwner, row.RepoOwner.String) - require.Equal(t, repo1.RepoName, row.RepoName.String) }, }, { @@ -601,8 +577,6 @@ func TestListEvaluationHistoryPagination(t *testing.T) { require.Equal(t, ess[9], row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repos[9].ID, row.EntityID) - require.Equal(t, repos[9].RepoOwner, row.RepoOwner.String) - require.Equal(t, repos[9].RepoName, row.RepoName.String) }, }, { @@ -646,8 +620,6 @@ func TestListEvaluationHistoryPagination(t *testing.T) { require.Equal(t, ess[0], row.EvaluationID) require.Equal(t, EntitiesRepository, row.EntityType) require.Equal(t, repos[0].ID, row.EntityID) - require.Equal(t, repos[0].RepoOwner, row.RepoOwner.String) - require.Equal(t, repos[0].RepoName, row.RepoName.String) }, }, { diff --git a/internal/db/models.go b/internal/db/models.go index 3a68950b5d..506ad9adc1 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -476,6 +476,18 @@ type AlertEvent struct { CreatedAt time.Time `json:"created_at"` } +type Artifact struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + ProviderName string `json:"provider_name"` + ProviderID uuid.UUID `json:"provider_id"` + RepositoryID uuid.NullUUID `json:"repository_id"` + ArtifactName string `json:"artifact_name"` + ArtifactType string `json:"artifact_type"` + ArtifactVisibility string `json:"artifact_visibility"` + CreatedAt time.Time `json:"created_at"` +} + type Bundle struct { ID uuid.UUID `json:"id"` Namespace string `json:"namespace"` @@ -660,6 +672,13 @@ type ProviderGithubAppInstallation struct { IsOrg bool `json:"is_org"` } +type PullRequest struct { + ID uuid.UUID `json:"id"` + RepositoryID uuid.NullUUID `json:"repository_id"` + PrNumber int64 `json:"pr_number"` + CreatedAt time.Time `json:"created_at"` +} + type RemediationEvent struct { ID uuid.UUID `json:"id"` EvaluationID uuid.UUID `json:"evaluation_id"` @@ -669,6 +688,25 @@ type RemediationEvent struct { CreatedAt time.Time `json:"created_at"` } +type Repository struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + Provider string `json:"provider"` + ProviderID uuid.UUID `json:"provider_id"` + RepoOwner string `json:"repo_owner"` + RepoName string `json:"repo_name"` + RepoID int64 `json:"repo_id"` + IsPrivate bool `json:"is_private"` + IsFork bool `json:"is_fork"` + WebhookID int64 `json:"webhook_id"` + WebhookUrl string `json:"webhook_url"` + DeployUrl string `json:"deploy_url"` + CloneUrl string `json:"clone_url"` + DefaultBranch string `json:"default_branch"` + License string `json:"license"` + CreatedAt time.Time `json:"created_at"` +} + type RuleInstance struct { ID uuid.UUID `json:"id"` ProfileID uuid.UUID `json:"profile_id"` diff --git a/internal/db/profiles_test.go b/internal/db/profiles_test.go index fbe8dcd5f2..ecda022ab1 100644 --- a/internal/db/profiles_test.go +++ b/internal/db/profiles_test.go @@ -3652,11 +3652,6 @@ func verifyRow( require.Equal(t, rt.ID, row.RuleTypeID) require.Equal(t, rt.Name, row.RuleTypeName) - - require.Equal(t, randomEntities.repo.RepoName, row.RepoName.String) - require.Equal(t, randomEntities.repo.RepoOwner, row.RepoOwner.String) - - require.Equal(t, randomEntities.prov.Name, row.Provider.String) } func TestListRuleEvaluations(t *testing.T) { diff --git a/internal/db/pull_requests.sql.go b/internal/db/pull_requests.sql.go new file mode 100644 index 0000000000..30da54c2b8 --- /dev/null +++ b/internal/db/pull_requests.sql.go @@ -0,0 +1,51 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: pull_requests.sql + +package db + +import ( + "context" + + "github.com/google/uuid" +) + +const getPullRequest = `-- name: GetPullRequest :one +SELECT id, repository_id, pr_number, created_at FROM pull_requests +WHERE repository_id = $1 AND pr_number = $2 +` + +type GetPullRequestParams struct { + RepositoryID uuid.NullUUID `json:"repository_id"` + PrNumber int64 `json:"pr_number"` +} + +func (q *Queries) GetPullRequest(ctx context.Context, arg GetPullRequestParams) (PullRequest, error) { + row := q.db.QueryRowContext(ctx, getPullRequest, arg.RepositoryID, arg.PrNumber) + var i PullRequest + err := row.Scan( + &i.ID, + &i.RepositoryID, + &i.PrNumber, + &i.CreatedAt, + ) + return i, err +} + +const getPullRequestByID = `-- name: GetPullRequestByID :one +SELECT id, repository_id, pr_number, created_at FROM pull_requests +WHERE id = $1 +` + +func (q *Queries) GetPullRequestByID(ctx context.Context, id uuid.UUID) (PullRequest, error) { + row := q.db.QueryRowContext(ctx, getPullRequestByID, id) + var i PullRequest + err := row.Scan( + &i.ID, + &i.RepositoryID, + &i.PrNumber, + &i.CreatedAt, + ) + return i, err +} diff --git a/internal/db/querier.go b/internal/db/querier.go index ae92770305..3d88746886 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -16,6 +16,7 @@ type Querier interface { BulkGetProfilesByID(ctx context.Context, profileIds []uuid.UUID) ([]BulkGetProfilesByIDRow, error) CountProfilesByEntityType(ctx context.Context) ([]CountProfilesByEntityTypeRow, error) CountProfilesByName(ctx context.Context, name string) (int64, error) + CountRepositories(ctx context.Context) (int64, error) CountUsers(ctx context.Context) (int64, error) // CreateEntity adds an entry to the entity_instances table so it can be tracked by Minder. CreateEntity(ctx context.Context, arg CreateEntityParams) (EntityInstance, error) @@ -74,6 +75,8 @@ type Querier interface { GetAccessTokenByProvider(ctx context.Context, provider string) ([]ProviderAccessToken, error) GetAccessTokenSinceDate(ctx context.Context, arg GetAccessTokenSinceDateParams) (ProviderAccessToken, error) GetAllPropertiesForEntity(ctx context.Context, entityID uuid.UUID) ([]Property, error) + GetArtifactByID(ctx context.Context, arg GetArtifactByIDParams) (Artifact, error) + GetArtifactByName(ctx context.Context, arg GetArtifactByNameParams) (Artifact, error) GetBundle(ctx context.Context, arg GetBundleParams) (Bundle, error) GetChildrenProjects(ctx context.Context, id uuid.UUID) ([]GetChildrenProjectsRow, error) // GetEntitiesByType retrieves all entities of a given type for a project or hierarchy of projects. @@ -142,6 +145,18 @@ type Querier interface { // if it exists in the project or any of its ancestors. It'll return the first // provider that matches the name. GetProviderByName(ctx context.Context, arg GetProviderByNameParams) (Provider, error) + // get a list of repos with webhooks belonging to a provider + // is used for webhook cleanup during provider deletion + GetProviderWebhooks(ctx context.Context, providerID uuid.UUID) ([]GetProviderWebhooksRow, error) + GetPullRequest(ctx context.Context, arg GetPullRequestParams) (PullRequest, error) + GetPullRequestByID(ctx context.Context, id uuid.UUID) (PullRequest, error) + GetRepoPathFromArtifactID(ctx context.Context, id uuid.UUID) (GetRepoPathFromArtifactIDRow, error) + GetRepoPathFromPullRequestID(ctx context.Context, id uuid.UUID) (GetRepoPathFromPullRequestIDRow, error) + // avoid using this, where possible use GetRepositoryByIDAndProject instead + GetRepositoryByID(ctx context.Context, id uuid.UUID) (Repository, error) + GetRepositoryByIDAndProject(ctx context.Context, arg GetRepositoryByIDAndProjectParams) (Repository, error) + GetRepositoryByRepoID(ctx context.Context, repoID int64) (Repository, error) + GetRepositoryByRepoName(ctx context.Context, arg GetRepositoryByRepoNameParams) (Repository, error) GetRuleInstancesEntityInProjects(ctx context.Context, arg GetRuleInstancesEntityInProjectsParams) ([]RuleInstance, error) GetRuleInstancesForProfile(ctx context.Context, profileID uuid.UUID) ([]RuleInstance, error) GetRuleTypeByID(ctx context.Context, id uuid.UUID) (RuleType, error) @@ -165,6 +180,7 @@ type Querier interface { InsertEvaluationRuleEntity(ctx context.Context, arg InsertEvaluationRuleEntityParams) (uuid.UUID, error) InsertEvaluationStatus(ctx context.Context, arg InsertEvaluationStatusParams) (uuid.UUID, error) InsertRemediationEvent(ctx context.Context, arg InsertRemediationEventParams) error + ListArtifactsByRepoID(ctx context.Context, repositoryID uuid.NullUUID) ([]Artifact, error) ListEvaluationHistory(ctx context.Context, arg ListEvaluationHistoryParams) ([]ListEvaluationHistoryRow, error) ListEvaluationHistoryStaleRecords(ctx context.Context, arg ListEvaluationHistoryStaleRecordsParams) ([]ListEvaluationHistoryStaleRecordsRow, error) ListFlushCache(ctx context.Context) ([]FlushCache, error) @@ -184,6 +200,9 @@ type Querier interface { // ListProvidersByProjectIDPaginated allows us to lits all providers for a given project // with pagination taken into account. In this case, the cursor is the creation date. ListProvidersByProjectIDPaginated(ctx context.Context, arg ListProvidersByProjectIDPaginatedParams) ([]Provider, error) + ListRegisteredRepositoriesByProjectIDAndProvider(ctx context.Context, arg ListRegisteredRepositoriesByProjectIDAndProviderParams) ([]Repository, error) + ListRepositoriesAfterID(ctx context.Context, arg ListRepositoriesAfterIDParams) ([]Repository, error) + ListRepositoriesByProjectID(ctx context.Context, arg ListRepositoriesByProjectIDParams) ([]Repository, error) ListRuleEvaluationsByProfileId(ctx context.Context, arg ListRuleEvaluationsByProfileIdParams) ([]ListRuleEvaluationsByProfileIdRow, error) ListRuleTypesByProject(ctx context.Context, projectID uuid.UUID) ([]RuleType, error) // When doing a key/algorithm rotation, identify the secrets which need to be @@ -208,6 +227,7 @@ type Querier interface { // entity_execution_lock record if the lock is held by the given locked_by // value. ReleaseLock(ctx context.Context, arg ReleaseLockParams) error + RepositoryExistsAfterID(ctx context.Context, id uuid.UUID) (bool, error) SetCurrentVersion(ctx context.Context, arg SetCurrentVersionParams) error UpdateEncryptedSecret(ctx context.Context, arg UpdateEncryptedSecretParams) error // UpdateInvitationRole updates an invitation by its code. This is intended to be diff --git a/internal/db/repositories.sql.go b/internal/db/repositories.sql.go new file mode 100644 index 0000000000..78a33aa0ed --- /dev/null +++ b/internal/db/repositories.sql.go @@ -0,0 +1,409 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: repositories.sql + +package db + +import ( + "context" + "database/sql" + + "github.com/google/uuid" +) + +const countRepositories = `-- name: CountRepositories :one +SELECT COUNT(*) FROM repositories +` + +func (q *Queries) CountRepositories(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, countRepositories) + var count int64 + err := row.Scan(&count) + return count, err +} + +const getProviderWebhooks = `-- name: GetProviderWebhooks :many +SELECT repo_owner, repo_name, webhook_id FROM repositories +WHERE webhook_id IS NOT NULL AND provider_id = $1 +` + +type GetProviderWebhooksRow struct { + RepoOwner string `json:"repo_owner"` + RepoName string `json:"repo_name"` + WebhookID int64 `json:"webhook_id"` +} + +// get a list of repos with webhooks belonging to a provider +// is used for webhook cleanup during provider deletion +func (q *Queries) GetProviderWebhooks(ctx context.Context, providerID uuid.UUID) ([]GetProviderWebhooksRow, error) { + rows, err := q.db.QueryContext(ctx, getProviderWebhooks, providerID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []GetProviderWebhooksRow{} + for rows.Next() { + var i GetProviderWebhooksRow + if err := rows.Scan(&i.RepoOwner, &i.RepoName, &i.WebhookID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getRepoPathFromArtifactID = `-- name: GetRepoPathFromArtifactID :one +SELECT r.repo_owner AS owner , r.repo_name AS name FROM repositories AS r +JOIN artifacts AS a ON a.repository_id = r.id +WHERE a.id = $1 +` + +type GetRepoPathFromArtifactIDRow struct { + Owner string `json:"owner"` + Name string `json:"name"` +} + +func (q *Queries) GetRepoPathFromArtifactID(ctx context.Context, id uuid.UUID) (GetRepoPathFromArtifactIDRow, error) { + row := q.db.QueryRowContext(ctx, getRepoPathFromArtifactID, id) + var i GetRepoPathFromArtifactIDRow + err := row.Scan(&i.Owner, &i.Name) + return i, err +} + +const getRepoPathFromPullRequestID = `-- name: GetRepoPathFromPullRequestID :one +SELECT r.repo_owner AS owner , r.repo_name AS name FROM repositories AS r +JOIN pull_requests AS p ON p.repository_id = r.id +WHERE p.id = $1 +` + +type GetRepoPathFromPullRequestIDRow struct { + Owner string `json:"owner"` + Name string `json:"name"` +} + +func (q *Queries) GetRepoPathFromPullRequestID(ctx context.Context, id uuid.UUID) (GetRepoPathFromPullRequestIDRow, error) { + row := q.db.QueryRowContext(ctx, getRepoPathFromPullRequestID, id) + var i GetRepoPathFromPullRequestIDRow + err := row.Scan(&i.Owner, &i.Name) + return i, err +} + +const getRepositoryByID = `-- name: GetRepositoryByID :one +SELECT id, project_id, provider, provider_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, default_branch, license, created_at FROM repositories WHERE id = $1 +` + +// avoid using this, where possible use GetRepositoryByIDAndProject instead +func (q *Queries) GetRepositoryByID(ctx context.Context, id uuid.UUID) (Repository, error) { + row := q.db.QueryRowContext(ctx, getRepositoryByID, id) + var i Repository + err := row.Scan( + &i.ID, + &i.ProjectID, + &i.Provider, + &i.ProviderID, + &i.RepoOwner, + &i.RepoName, + &i.RepoID, + &i.IsPrivate, + &i.IsFork, + &i.WebhookID, + &i.WebhookUrl, + &i.DeployUrl, + &i.CloneUrl, + &i.DefaultBranch, + &i.License, + &i.CreatedAt, + ) + return i, err +} + +const getRepositoryByIDAndProject = `-- name: GetRepositoryByIDAndProject :one +SELECT id, project_id, provider, provider_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, default_branch, license, created_at FROM repositories WHERE id = $1 AND project_id = $2 +` + +type GetRepositoryByIDAndProjectParams struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` +} + +func (q *Queries) GetRepositoryByIDAndProject(ctx context.Context, arg GetRepositoryByIDAndProjectParams) (Repository, error) { + row := q.db.QueryRowContext(ctx, getRepositoryByIDAndProject, arg.ID, arg.ProjectID) + var i Repository + err := row.Scan( + &i.ID, + &i.ProjectID, + &i.Provider, + &i.ProviderID, + &i.RepoOwner, + &i.RepoName, + &i.RepoID, + &i.IsPrivate, + &i.IsFork, + &i.WebhookID, + &i.WebhookUrl, + &i.DeployUrl, + &i.CloneUrl, + &i.DefaultBranch, + &i.License, + &i.CreatedAt, + ) + return i, err +} + +const getRepositoryByRepoID = `-- name: GetRepositoryByRepoID :one +SELECT id, project_id, provider, provider_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, default_branch, license, created_at FROM repositories WHERE repo_id = $1 +` + +func (q *Queries) GetRepositoryByRepoID(ctx context.Context, repoID int64) (Repository, error) { + row := q.db.QueryRowContext(ctx, getRepositoryByRepoID, repoID) + var i Repository + err := row.Scan( + &i.ID, + &i.ProjectID, + &i.Provider, + &i.ProviderID, + &i.RepoOwner, + &i.RepoName, + &i.RepoID, + &i.IsPrivate, + &i.IsFork, + &i.WebhookID, + &i.WebhookUrl, + &i.DeployUrl, + &i.CloneUrl, + &i.DefaultBranch, + &i.License, + &i.CreatedAt, + ) + return i, err +} + +const getRepositoryByRepoName = `-- name: GetRepositoryByRepoName :one +SELECT id, project_id, provider, provider_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, default_branch, license, created_at FROM repositories + WHERE repo_owner = $1 AND repo_name = $2 AND project_id = $3 + AND (lower(provider) = lower($4::text) OR $4::text IS NULL) +` + +type GetRepositoryByRepoNameParams struct { + RepoOwner string `json:"repo_owner"` + RepoName string `json:"repo_name"` + ProjectID uuid.UUID `json:"project_id"` + Provider sql.NullString `json:"provider"` +} + +func (q *Queries) GetRepositoryByRepoName(ctx context.Context, arg GetRepositoryByRepoNameParams) (Repository, error) { + row := q.db.QueryRowContext(ctx, getRepositoryByRepoName, + arg.RepoOwner, + arg.RepoName, + arg.ProjectID, + arg.Provider, + ) + var i Repository + err := row.Scan( + &i.ID, + &i.ProjectID, + &i.Provider, + &i.ProviderID, + &i.RepoOwner, + &i.RepoName, + &i.RepoID, + &i.IsPrivate, + &i.IsFork, + &i.WebhookID, + &i.WebhookUrl, + &i.DeployUrl, + &i.CloneUrl, + &i.DefaultBranch, + &i.License, + &i.CreatedAt, + ) + return i, err +} + +const listRegisteredRepositoriesByProjectIDAndProvider = `-- name: ListRegisteredRepositoriesByProjectIDAndProvider :many +SELECT id, project_id, provider, provider_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, default_branch, license, created_at FROM repositories +WHERE project_id = $1 AND webhook_id IS NOT NULL + AND (lower(provider) = lower($2::text) OR $2::text IS NULL) +ORDER BY repo_name +` + +type ListRegisteredRepositoriesByProjectIDAndProviderParams struct { + ProjectID uuid.UUID `json:"project_id"` + Provider sql.NullString `json:"provider"` +} + +func (q *Queries) ListRegisteredRepositoriesByProjectIDAndProvider(ctx context.Context, arg ListRegisteredRepositoriesByProjectIDAndProviderParams) ([]Repository, error) { + rows, err := q.db.QueryContext(ctx, listRegisteredRepositoriesByProjectIDAndProvider, arg.ProjectID, arg.Provider) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Repository{} + for rows.Next() { + var i Repository + if err := rows.Scan( + &i.ID, + &i.ProjectID, + &i.Provider, + &i.ProviderID, + &i.RepoOwner, + &i.RepoName, + &i.RepoID, + &i.IsPrivate, + &i.IsFork, + &i.WebhookID, + &i.WebhookUrl, + &i.DeployUrl, + &i.CloneUrl, + &i.DefaultBranch, + &i.License, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listRepositoriesAfterID = `-- name: ListRepositoriesAfterID :many +SELECT id, project_id, provider, provider_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, default_branch, license, created_at +FROM repositories +WHERE id > $1 +ORDER BY id +LIMIT $2::bigint +` + +type ListRepositoriesAfterIDParams struct { + ID uuid.UUID `json:"id"` + Limit int64 `json:"limit"` +} + +func (q *Queries) ListRepositoriesAfterID(ctx context.Context, arg ListRepositoriesAfterIDParams) ([]Repository, error) { + rows, err := q.db.QueryContext(ctx, listRepositoriesAfterID, arg.ID, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Repository{} + for rows.Next() { + var i Repository + if err := rows.Scan( + &i.ID, + &i.ProjectID, + &i.Provider, + &i.ProviderID, + &i.RepoOwner, + &i.RepoName, + &i.RepoID, + &i.IsPrivate, + &i.IsFork, + &i.WebhookID, + &i.WebhookUrl, + &i.DeployUrl, + &i.CloneUrl, + &i.DefaultBranch, + &i.License, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listRepositoriesByProjectID = `-- name: ListRepositoriesByProjectID :many +SELECT id, project_id, provider, provider_id, repo_owner, repo_name, repo_id, is_private, is_fork, webhook_id, webhook_url, deploy_url, clone_url, default_branch, license, created_at FROM repositories +WHERE project_id = $1 + AND (repo_id >= $2 OR $2 IS NULL) + AND lower(provider) = lower(COALESCE($3, provider)::text) +ORDER BY project_id, provider, repo_id +LIMIT $4::bigint +` + +type ListRepositoriesByProjectIDParams struct { + ProjectID uuid.UUID `json:"project_id"` + RepoID sql.NullInt64 `json:"repo_id"` + Provider sql.NullString `json:"provider"` + Limit sql.NullInt64 `json:"limit"` +} + +func (q *Queries) ListRepositoriesByProjectID(ctx context.Context, arg ListRepositoriesByProjectIDParams) ([]Repository, error) { + rows, err := q.db.QueryContext(ctx, listRepositoriesByProjectID, + arg.ProjectID, + arg.RepoID, + arg.Provider, + arg.Limit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Repository{} + for rows.Next() { + var i Repository + if err := rows.Scan( + &i.ID, + &i.ProjectID, + &i.Provider, + &i.ProviderID, + &i.RepoOwner, + &i.RepoName, + &i.RepoID, + &i.IsPrivate, + &i.IsFork, + &i.WebhookID, + &i.WebhookUrl, + &i.DeployUrl, + &i.CloneUrl, + &i.DefaultBranch, + &i.License, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const repositoryExistsAfterID = `-- name: RepositoryExistsAfterID :one +SELECT EXISTS ( + SELECT 1 + FROM repositories + WHERE id > $1) +AS exists +` + +func (q *Queries) RepositoryExistsAfterID(ctx context.Context, id uuid.UUID) (bool, error) { + row := q.db.QueryRowContext(ctx, repositoryExistsAfterID, id) + var exists bool + err := row.Scan(&exists) + return exists, err +}