diff --git a/check.go b/check.go index 445f6de4..64df9e24 100644 --- a/check.go +++ b/check.go @@ -14,7 +14,13 @@ import ( func Check(request CheckRequest, manager Github) (CheckResponse, error) { var response CheckResponse - pulls, err := manager.ListOpenPullRequests() + // Filter out pull request if it does not have a filtered state + filterStates := []githubv4.PullRequestState{githubv4.PullRequestStateOpen} + if len(request.Source.States) > 0 { + filterStates = request.Source.States + } + + pulls, err := manager.ListPullRequests(filterStates) if err != nil { return nil, fmt.Errorf("failed to get last commits: %s", err) } @@ -38,23 +44,6 @@ Loop: continue } - // Filter out pull request if it does not have a filtered state - filterStates := []githubv4.PullRequestState{githubv4.PullRequestStateOpen} - if len(request.Source.States) > 0 { - filterStates = request.Source.States - } - - stateFound := false - for _, state := range filterStates { - if p.State == state { - stateFound = true - break - } - } - if !stateFound { - continue - } - // Filter out commits that are too old. if !p.UpdatedDate().Time.After(request.Version.CommittedDate) { continue diff --git a/check_test.go b/check_test.go index 3bf35c46..8c422914 100644 --- a/check_test.go +++ b/check_test.go @@ -267,7 +267,20 @@ func TestCheck(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { github := new(fakes.FakeGithub) - github.ListOpenPullRequestsReturns(tc.pullRequests, nil) + pullRequests := []*resource.PullRequest{} + filterStates := []githubv4.PullRequestState{githubv4.PullRequestStateOpen} + if len(tc.source.States) > 0 { + filterStates = tc.source.States + } + for i := range tc.pullRequests { + for j := range filterStates { + if filterStates[j] == tc.pullRequests[i].PullRequestObject.State { + pullRequests = append(pullRequests, tc.pullRequests[i]) + break + } + } + } + github.ListPullRequestsReturns(pullRequests, nil) for i, file := range tc.files { github.ListModifiedFilesReturnsOnCall(i, file, nil) @@ -279,7 +292,7 @@ func TestCheck(t *testing.T) { if assert.NoError(t, err) { assert.Equal(t, tc.expected, output) } - assert.Equal(t, 1, github.ListOpenPullRequestsCallCount()) + assert.Equal(t, 1, github.ListPullRequestsCallCount()) }) } } diff --git a/fakes/fake_github.go b/fakes/fake_github.go index ac8242a6..1847478f 100644 --- a/fakes/fake_github.go +++ b/fakes/fake_github.go @@ -4,6 +4,7 @@ package fakes import ( "sync" + "github.com/shurcooL/githubv4" resource "github.com/telia-oss/github-pr-resource" ) @@ -60,15 +61,16 @@ type FakeGithub struct { result1 []string result2 error } - ListOpenPullRequestsStub func() ([]*resource.PullRequest, error) - listOpenPullRequestsMutex sync.RWMutex - listOpenPullRequestsArgsForCall []struct { + ListPullRequestsStub func([]githubv4.PullRequestState) ([]*resource.PullRequest, error) + listPullRequestsMutex sync.RWMutex + listPullRequestsArgsForCall []struct { + arg1 []githubv4.PullRequestState } - listOpenPullRequestsReturns struct { + listPullRequestsReturns struct { result1 []*resource.PullRequest result2 error } - listOpenPullRequestsReturnsOnCall map[int]struct { + listPullRequestsReturnsOnCall map[int]struct { result1 []*resource.PullRequest result2 error } @@ -355,56 +357,69 @@ func (fake *FakeGithub) ListModifiedFilesReturnsOnCall(i int, result1 []string, }{result1, result2} } -func (fake *FakeGithub) ListOpenPullRequests() ([]*resource.PullRequest, error) { - fake.listOpenPullRequestsMutex.Lock() - ret, specificReturn := fake.listOpenPullRequestsReturnsOnCall[len(fake.listOpenPullRequestsArgsForCall)] - fake.listOpenPullRequestsArgsForCall = append(fake.listOpenPullRequestsArgsForCall, struct { - }{}) - fake.recordInvocation("ListOpenPullRequests", []interface{}{}) - fake.listOpenPullRequestsMutex.Unlock() - if fake.ListOpenPullRequestsStub != nil { - return fake.ListOpenPullRequestsStub() +func (fake *FakeGithub) ListPullRequests(arg1 []githubv4.PullRequestState) ([]*resource.PullRequest, error) { + var arg1Copy []githubv4.PullRequestState + if arg1 != nil { + arg1Copy = make([]githubv4.PullRequestState, len(arg1)) + copy(arg1Copy, arg1) + } + fake.listPullRequestsMutex.Lock() + ret, specificReturn := fake.listPullRequestsReturnsOnCall[len(fake.listPullRequestsArgsForCall)] + fake.listPullRequestsArgsForCall = append(fake.listPullRequestsArgsForCall, struct { + arg1 []githubv4.PullRequestState + }{arg1Copy}) + fake.recordInvocation("ListPullRequests", []interface{}{arg1Copy}) + fake.listPullRequestsMutex.Unlock() + if fake.ListPullRequestsStub != nil { + return fake.ListPullRequestsStub(arg1) } if specificReturn { return ret.result1, ret.result2 } - fakeReturns := fake.listOpenPullRequestsReturns + fakeReturns := fake.listPullRequestsReturns return fakeReturns.result1, fakeReturns.result2 } -func (fake *FakeGithub) ListOpenPullRequestsCallCount() int { - fake.listOpenPullRequestsMutex.RLock() - defer fake.listOpenPullRequestsMutex.RUnlock() - return len(fake.listOpenPullRequestsArgsForCall) +func (fake *FakeGithub) ListPullRequestsCallCount() int { + fake.listPullRequestsMutex.RLock() + defer fake.listPullRequestsMutex.RUnlock() + return len(fake.listPullRequestsArgsForCall) +} + +func (fake *FakeGithub) ListPullRequestsCalls(stub func([]githubv4.PullRequestState) ([]*resource.PullRequest, error)) { + fake.listPullRequestsMutex.Lock() + defer fake.listPullRequestsMutex.Unlock() + fake.ListPullRequestsStub = stub } -func (fake *FakeGithub) ListOpenPullRequestsCalls(stub func() ([]*resource.PullRequest, error)) { - fake.listOpenPullRequestsMutex.Lock() - defer fake.listOpenPullRequestsMutex.Unlock() - fake.ListOpenPullRequestsStub = stub +func (fake *FakeGithub) ListPullRequestsArgsForCall(i int) []githubv4.PullRequestState { + fake.listPullRequestsMutex.RLock() + defer fake.listPullRequestsMutex.RUnlock() + argsForCall := fake.listPullRequestsArgsForCall[i] + return argsForCall.arg1 } -func (fake *FakeGithub) ListOpenPullRequestsReturns(result1 []*resource.PullRequest, result2 error) { - fake.listOpenPullRequestsMutex.Lock() - defer fake.listOpenPullRequestsMutex.Unlock() - fake.ListOpenPullRequestsStub = nil - fake.listOpenPullRequestsReturns = struct { +func (fake *FakeGithub) ListPullRequestsReturns(result1 []*resource.PullRequest, result2 error) { + fake.listPullRequestsMutex.Lock() + defer fake.listPullRequestsMutex.Unlock() + fake.ListPullRequestsStub = nil + fake.listPullRequestsReturns = struct { result1 []*resource.PullRequest result2 error }{result1, result2} } -func (fake *FakeGithub) ListOpenPullRequestsReturnsOnCall(i int, result1 []*resource.PullRequest, result2 error) { - fake.listOpenPullRequestsMutex.Lock() - defer fake.listOpenPullRequestsMutex.Unlock() - fake.ListOpenPullRequestsStub = nil - if fake.listOpenPullRequestsReturnsOnCall == nil { - fake.listOpenPullRequestsReturnsOnCall = make(map[int]struct { +func (fake *FakeGithub) ListPullRequestsReturnsOnCall(i int, result1 []*resource.PullRequest, result2 error) { + fake.listPullRequestsMutex.Lock() + defer fake.listPullRequestsMutex.Unlock() + fake.ListPullRequestsStub = nil + if fake.listPullRequestsReturnsOnCall == nil { + fake.listPullRequestsReturnsOnCall = make(map[int]struct { result1 []*resource.PullRequest result2 error }) } - fake.listOpenPullRequestsReturnsOnCall[i] = struct { + fake.listPullRequestsReturnsOnCall[i] = struct { result1 []*resource.PullRequest result2 error }{result1, result2} @@ -547,8 +562,8 @@ func (fake *FakeGithub) Invocations() map[string][][]interface{} { defer fake.getPullRequestMutex.RUnlock() fake.listModifiedFilesMutex.RLock() defer fake.listModifiedFilesMutex.RUnlock() - fake.listOpenPullRequestsMutex.RLock() - defer fake.listOpenPullRequestsMutex.RUnlock() + fake.listPullRequestsMutex.RLock() + defer fake.listPullRequestsMutex.RUnlock() fake.postCommentMutex.RLock() defer fake.postCommentMutex.RUnlock() fake.updateCommitStatusMutex.RLock() diff --git a/github.go b/github.go index 33372ccc..ab10cbdc 100644 --- a/github.go +++ b/github.go @@ -20,7 +20,7 @@ import ( // Github for testing purposes. //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -o fakes/fake_github.go . Github type Github interface { - ListOpenPullRequests() ([]*PullRequest, error) + ListPullRequests([]githubv4.PullRequestState) ([]*PullRequest, error) ListModifiedFiles(int) ([]string, error) PostComment(string, string) error GetPullRequest(string, string) (*PullRequest, error) @@ -97,8 +97,8 @@ func NewGithubClient(s *Source) (*GithubClient, error) { }, nil } -// ListOpenPullRequests gets the last commit on all open pull requests. -func (m *GithubClient) ListOpenPullRequests() ([]*PullRequest, error) { +// ListPullRequests gets the last commit on all pull requests with the matching state. +func (m *GithubClient) ListPullRequests(prStates []githubv4.PullRequestState) ([]*PullRequest, error) { var query struct { Repository struct { PullRequests struct { @@ -136,7 +136,7 @@ func (m *GithubClient) ListOpenPullRequests() ([]*PullRequest, error) { "repositoryOwner": githubv4.String(m.Owner), "repositoryName": githubv4.String(m.Repository), "prFirst": githubv4.Int(100), - "prStates": []githubv4.PullRequestState{githubv4.PullRequestStateOpen, githubv4.PullRequestStateClosed, githubv4.PullRequestStateMerged}, + "prStates": prStates, "prCursor": (*githubv4.String)(nil), "commitsLast": githubv4.Int(1), "prReviewStates": []githubv4.PullRequestReviewState{githubv4.PullRequestReviewStateApproved},