diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 8d0f0f046a78..10f8f7fade13 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -317,11 +317,12 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) source_metada // Chunks emits chunks of bytes over a channel. func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error { + chunksReporter := sources.ChanReporter{Ch: chunksChan} // If targets are provided, we're only scanning the data in those targets. // Otherwise, we're scanning all data. // This allows us to only scan the commit where a vulnerability was found. if len(targets) > 0 { - errs := s.scanTargets(ctx, targets, chunksChan) + errs := s.scanTargets(ctx, targets, chunksReporter) return errors.Join(errs...) } @@ -335,7 +336,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, tar return fmt.Errorf("error enumerating: %w", err) } - return s.scan(ctx, chunksChan) + return s.scan(ctx, chunksReporter) } func (s *Source) enumerate(ctx context.Context) error { @@ -564,7 +565,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl return github.NewClient(httpClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint) } -func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error { +func (s *Source) scan(ctx context.Context, reporter sources.ChunkReporter) error { var scannedCount uint64 = 1 ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos)) @@ -609,7 +610,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error return nil } repoCtx := context.WithValues(ctx, "repo", repoURL) - duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, chunksChan) + duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, reporter) if err != nil { scanErrs.Add(err) return nil @@ -620,7 +621,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error wikiURL := strings.TrimSuffix(repoURL, ".git") + ".wiki.git" wikiCtx := context.WithValue(ctx, "repo", wikiURL) - _, err := s.cloneAndScanRepo(wikiCtx, wikiURL, repoInfo, chunksChan) + _, err := s.cloneAndScanRepo(wikiCtx, wikiURL, repoInfo, reporter) if err != nil { // Ignore "Repository not found" errors. // It's common for GitHub's API to say a repo has a wiki when it doesn't. @@ -634,7 +635,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error // Scan comments, if enabled. if s.includeGistComments || s.includeIssueComments || s.includePRComments { - if err = s.scanComments(repoCtx, repoURL, repoInfo, chunksChan); err != nil { + if err = s.scanComments(repoCtx, repoURL, repoInfo, reporter); err != nil { scanErrs.Add(fmt.Errorf("error scanning comments in repo %s: %w", repoURL, err)) return nil } @@ -656,7 +657,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error return nil } -func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo repoInfo, chunksChan chan *sources.Chunk) (time.Duration, error) { +func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo repoInfo, reporter sources.ChunkReporter) (time.Duration, error) { var duration time.Duration ctx.Logger().V(2).Info("attempting to clone repo") @@ -679,7 +680,7 @@ func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo logger.V(2).Info("scanning repo") start := time.Now() - if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, sources.ChanReporter{Ch: chunksChan}); err != nil { + if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter); err != nil { return duration, fmt.Errorf("error scanning repo %s: %w", repoURL, err) } duration = time.Since(start) @@ -948,16 +949,16 @@ func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL stri s.SetProgressComplete(index+offset, len(s.repos)+offset, fmt.Sprintf("Repo: %s", repoURL), encodedResumeInfo) } -func (s *Source) scanComments(ctx context.Context, repoPath string, repoInfo repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) scanComments(ctx context.Context, repoPath string, repoInfo repoInfo, reporter sources.ChunkReporter) error { urlString, urlParts, err := getRepoURLParts(repoPath) if err != nil { return err } if s.includeGistComments && isGistUrl(urlParts) { - return s.processGistComments(ctx, urlString, urlParts, repoInfo, chunksChan) + return s.processGistComments(ctx, urlString, urlParts, repoInfo, reporter) } else if s.includeIssueComments || s.includePRComments { - return s.processRepoComments(ctx, repoInfo, chunksChan) + return s.processRepoComments(ctx, repoInfo, reporter) } return nil } @@ -1017,7 +1018,7 @@ func getRepoURLParts(repoURLString string) (string, []string, error) { const initialPage = 1 // page to start listing from -func (s *Source) processGistComments(ctx context.Context, gistURL string, urlParts []string, repoInfo repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) processGistComments(ctx context.Context, gistURL string, urlParts []string, repoInfo repoInfo, reporter sources.ChunkReporter) error { ctx.Logger().V(2).Info("Scanning GitHub Gist comments") // GitHub Gist URL. @@ -1036,7 +1037,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar return err } - if err = s.chunkGistComments(ctx, gistURL, repoInfo, comments, chunksChan); err != nil { + if err = s.chunkGistComments(ctx, gistURL, repoInfo, comments, reporter); err != nil { return err } @@ -1056,10 +1057,10 @@ func isGistUrl(urlParts []string) bool { return strings.EqualFold(urlParts[0], "gist.github.com") || (len(urlParts) == 4 && strings.EqualFold(urlParts[1], "gist")) } -func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo repoInfo, comments []*github.GistComment, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo repoInfo, comments []*github.GistComment, reporter sources.ChunkReporter) error { for _, comment := range comments { // Create chunk and send it to the channel. - chunk := &sources.Chunk{ + chunk := sources.Chunk{ SourceName: s.name, SourceID: s.SourceID(), SourceType: s.Type(), @@ -1080,10 +1081,8 @@ func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo Verify: s.verify, } - select { - case <-ctx.Done(): - return ctx.Err() - case chunksChan <- chunk: + if err := reporter.ChunkOk(ctx, chunk); err != nil { + return err } } return nil @@ -1104,23 +1103,23 @@ var ( state = "all" ) -func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error { if s.includeIssueComments { ctx.Logger().V(2).Info("Scanning issues") - if err := s.processIssues(ctx, repoInfo, chunksChan); err != nil { + if err := s.processIssues(ctx, repoInfo, reporter); err != nil { return err } - if err := s.processIssueComments(ctx, repoInfo, chunksChan); err != nil { + if err := s.processIssueComments(ctx, repoInfo, reporter); err != nil { return err } } if s.includePRComments { ctx.Logger().V(2).Info("Scanning pull requests") - if err := s.processPRs(ctx, repoInfo, chunksChan); err != nil { + if err := s.processPRs(ctx, repoInfo, reporter); err != nil { return err } - if err := s.processPRComments(ctx, repoInfo, chunksChan); err != nil { + if err := s.processPRComments(ctx, repoInfo, reporter); err != nil { return err } } @@ -1129,7 +1128,7 @@ func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, chu } -func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error { bodyTextsOpts := &github.IssueListByRepoOptions{ Sort: sortType, Direction: directionType, @@ -1150,7 +1149,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha return err } - if err = s.chunkIssues(ctx, repoInfo, issues, chunksChan); err != nil { + if err = s.chunkIssues(ctx, repoInfo, issues, reporter); err != nil { return err } @@ -1163,7 +1162,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha return nil } -func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*github.Issue, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*github.Issue, reporter sources.ChunkReporter) error { for _, issue := range issues { // Skip pull requests since covered by processPRs. @@ -1172,7 +1171,7 @@ func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*g } // Create chunk and send it to the channel. - chunk := &sources.Chunk{ + chunk := sources.Chunk{ SourceName: s.name, SourceID: s.SourceID(), JobID: s.JobID(), @@ -1193,16 +1192,14 @@ func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*g Verify: s.verify, } - select { - case <-ctx.Done(): - return ctx.Err() - case chunksChan <- chunk: + if err := reporter.ChunkOk(ctx, chunk); err != nil { + return err } } return nil } -func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error { issueOpts := &github.IssueListCommentsOptions{ Sort: &sortType, Direction: &directionType, @@ -1221,7 +1218,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch return err } - if err = s.chunkIssueComments(ctx, repoInfo, issueComments, chunksChan); err != nil { + if err = s.chunkIssueComments(ctx, repoInfo, issueComments, reporter); err != nil { return err } @@ -1233,10 +1230,10 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch return nil } -func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comments []*github.IssueComment, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comments []*github.IssueComment, reporter sources.ChunkReporter) error { for _, comment := range comments { // Create chunk and send it to the channel. - chunk := &sources.Chunk{ + chunk := sources.Chunk{ SourceName: s.name, SourceID: s.SourceID(), JobID: s.JobID(), @@ -1257,16 +1254,14 @@ func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comm Verify: s.verify, } - select { - case <-ctx.Done(): - return ctx.Err() - case chunksChan <- chunk: + if err := reporter.ChunkOk(ctx, chunk); err != nil { + return err } } return nil } -func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error { prOpts := &github.PullRequestListOptions{ Sort: sortType, Direction: directionType, @@ -1286,7 +1281,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c return err } - if err = s.chunkPullRequests(ctx, repoInfo, prs, chunksChan); err != nil { + if err = s.chunkPullRequests(ctx, repoInfo, prs, reporter); err != nil { return err } @@ -1299,7 +1294,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c return nil } -func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error { prOpts := &github.PullRequestListCommentsOptions{ Sort: sortType, Direction: directionType, @@ -1318,7 +1313,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk return err } - if err = s.chunkPullRequestComments(ctx, repoInfo, prComments, chunksChan); err != nil { + if err = s.chunkPullRequestComments(ctx, repoInfo, prComments, reporter); err != nil { return err } @@ -1331,10 +1326,10 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk return nil } -func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs []*github.PullRequest, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs []*github.PullRequest, reporter sources.ChunkReporter) error { for _, pr := range prs { // Create chunk and send it to the channel. - chunk := &sources.Chunk{ + chunk := sources.Chunk{ SourceName: s.name, SourceID: s.SourceID(), SourceType: s.Type(), @@ -1355,19 +1350,17 @@ func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs [ Verify: s.verify, } - select { - case <-ctx.Done(): - return ctx.Err() - case chunksChan <- chunk: + if err := reporter.ChunkOk(ctx, chunk); err != nil { + return err } } return nil } -func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo, comments []*github.PullRequestComment, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo, comments []*github.PullRequestComment, reporter sources.ChunkReporter) error { for _, comment := range comments { // Create chunk and send it to the channel. - chunk := &sources.Chunk{ + chunk := sources.Chunk{ SourceName: s.name, SourceID: s.SourceID(), SourceType: s.Type(), @@ -1388,19 +1381,17 @@ func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo Verify: s.verify, } - select { - case <-ctx.Done(): - return ctx.Err() - case chunksChan <- chunk: + if err := reporter.ChunkOk(ctx, chunk); err != nil { + return err } } return nil } -func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, chunksChan chan *sources.Chunk) []error { +func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, reporter sources.ChunkReporter) []error { var errs []error for _, tgt := range targets { - if err := s.scanTarget(ctx, tgt, chunksChan); err != nil { + if err := s.scanTarget(ctx, tgt, reporter); err != nil { ctx.Logger().Error(err, "error scanning target") errs = append(errs, &sources.TargetedScanError{Err: err, SecretID: tgt.SecretID}) } @@ -1409,7 +1400,7 @@ func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarg return errs } -func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, chunksChan chan *sources.Chunk) error { +func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, reporter sources.ChunkReporter) error { metaType, ok := target.QueryCriteria.GetData().(*source_metadatapb.MetaData_Github) if !ok { return fmt.Errorf("unable to cast metadata type for targeted scan") @@ -1446,7 +1437,6 @@ func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, return fmt.Errorf("unexpected HTTP response status when trying to download file for scan: %v", resp.Status) } - reporter := sources.ChanReporter{Ch: chunksChan} chunkSkel := sources.Chunk{ SourceType: s.Type(), SourceName: s.name,