Skip to content

Commit

Permalink
Instrument GitHub source with a ChunkReporter (#3296)
Browse files (browse the repository at this point in the history)
  • Loading branch information
mcastorina authored Sep 16, 2024
1 parent 661984c commit 401bc46
Showing 1 changed file with 50 additions and 60 deletions.
110 changes: 50 additions & 60 deletions pkg/sources/github/github.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -317,11 +317,12 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) source_metada

// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
chunksReporter := sources.ChanReporter{Ch: chunksChan}
// If targets are provided, we're only scanning the data in those targets.
// Otherwise, we're scanning all data.
// This allows us to only scan the commit where a vulnerability was found.
if len(targets) > 0 {
errs := s.scanTargets(ctx, targets, chunksChan)
errs := s.scanTargets(ctx, targets, chunksReporter)
return errors.Join(errs...)
}

Expand All @@ -335,7 +336,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, tar
return fmt.Errorf("error enumerating: %w", err)
}

return s.scan(ctx, chunksChan)
return s.scan(ctx, chunksReporter)
}

func (s *Source) enumerate(ctx context.Context) error {
Expand Down Expand Up @@ -564,7 +565,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl
return github.NewClient(httpClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
}

func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) scan(ctx context.Context, reporter sources.ChunkReporter) error {
var scannedCount uint64 = 1

ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos))
Expand Down Expand Up @@ -609,7 +610,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
return nil
}
repoCtx := context.WithValues(ctx, "repo", repoURL)
duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, chunksChan)
duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, reporter)
if err != nil {
scanErrs.Add(err)
return nil
Expand All @@ -620,7 +621,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
wikiURL := strings.TrimSuffix(repoURL, ".git") + ".wiki.git"
wikiCtx := context.WithValue(ctx, "repo", wikiURL)

_, err := s.cloneAndScanRepo(wikiCtx, wikiURL, repoInfo, chunksChan)
_, err := s.cloneAndScanRepo(wikiCtx, wikiURL, repoInfo, reporter)
if err != nil {
// Ignore "Repository not found" errors.
// It's common for GitHub's API to say a repo has a wiki when it doesn't.
Expand All @@ -634,7 +635,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error

// Scan comments, if enabled.
if s.includeGistComments || s.includeIssueComments || s.includePRComments {
if err = s.scanComments(repoCtx, repoURL, repoInfo, chunksChan); err != nil {
if err = s.scanComments(repoCtx, repoURL, repoInfo, reporter); err != nil {
scanErrs.Add(fmt.Errorf("error scanning comments in repo %s: %w", repoURL, err))
return nil
}
Expand All @@ -656,7 +657,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
return nil
}

func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo repoInfo, chunksChan chan *sources.Chunk) (time.Duration, error) {
func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo repoInfo, reporter sources.ChunkReporter) (time.Duration, error) {
var duration time.Duration

ctx.Logger().V(2).Info("attempting to clone repo")
Expand All @@ -679,7 +680,7 @@ func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo
logger.V(2).Info("scanning repo")

start := time.Now()
if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, sources.ChanReporter{Ch: chunksChan}); err != nil {
if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter); err != nil {
return duration, fmt.Errorf("error scanning repo %s: %w", repoURL, err)
}
duration = time.Since(start)
Expand Down Expand Up @@ -948,16 +949,16 @@ func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL stri
s.SetProgressComplete(index+offset, len(s.repos)+offset, fmt.Sprintf("Repo: %s", repoURL), encodedResumeInfo)
}

func (s *Source) scanComments(ctx context.Context, repoPath string, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) scanComments(ctx context.Context, repoPath string, repoInfo repoInfo, reporter sources.ChunkReporter) error {
urlString, urlParts, err := getRepoURLParts(repoPath)
if err != nil {
return err
}

if s.includeGistComments && isGistUrl(urlParts) {
return s.processGistComments(ctx, urlString, urlParts, repoInfo, chunksChan)
return s.processGistComments(ctx, urlString, urlParts, repoInfo, reporter)
} else if s.includeIssueComments || s.includePRComments {
return s.processRepoComments(ctx, repoInfo, chunksChan)
return s.processRepoComments(ctx, repoInfo, reporter)
}
return nil
}
Expand Down Expand Up @@ -1017,7 +1018,7 @@ func getRepoURLParts(repoURLString string) (string, []string, error) {

const initialPage = 1 // page to start listing from

func (s *Source) processGistComments(ctx context.Context, gistURL string, urlParts []string, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processGistComments(ctx context.Context, gistURL string, urlParts []string, repoInfo repoInfo, reporter sources.ChunkReporter) error {
ctx.Logger().V(2).Info("Scanning GitHub Gist comments")

// GitHub Gist URL.
Expand All @@ -1036,7 +1037,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar
return err
}

if err = s.chunkGistComments(ctx, gistURL, repoInfo, comments, chunksChan); err != nil {
if err = s.chunkGistComments(ctx, gistURL, repoInfo, comments, reporter); err != nil {
return err
}

Expand All @@ -1056,10 +1057,10 @@ func isGistUrl(urlParts []string) bool {
return strings.EqualFold(urlParts[0], "gist.github.com") || (len(urlParts) == 4 && strings.EqualFold(urlParts[1], "gist"))
}

func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo repoInfo, comments []*github.GistComment, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo repoInfo, comments []*github.GistComment, reporter sources.ChunkReporter) error {
for _, comment := range comments {
// Create chunk and send it to the channel.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
Expand All @@ -1080,10 +1081,8 @@ func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
Expand All @@ -1104,23 +1103,23 @@ var (
state = "all"
)

func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
if s.includeIssueComments {
ctx.Logger().V(2).Info("Scanning issues")
if err := s.processIssues(ctx, repoInfo, chunksChan); err != nil {
if err := s.processIssues(ctx, repoInfo, reporter); err != nil {
return err
}
if err := s.processIssueComments(ctx, repoInfo, chunksChan); err != nil {
if err := s.processIssueComments(ctx, repoInfo, reporter); err != nil {
return err
}
}

if s.includePRComments {
ctx.Logger().V(2).Info("Scanning pull requests")
if err := s.processPRs(ctx, repoInfo, chunksChan); err != nil {
if err := s.processPRs(ctx, repoInfo, reporter); err != nil {
return err
}
if err := s.processPRComments(ctx, repoInfo, chunksChan); err != nil {
if err := s.processPRComments(ctx, repoInfo, reporter); err != nil {
return err
}
}
Expand All @@ -1129,7 +1128,7 @@ func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, chu

}

func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
bodyTextsOpts := &github.IssueListByRepoOptions{
Sort: sortType,
Direction: directionType,
Expand All @@ -1150,7 +1149,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha
return err
}

if err = s.chunkIssues(ctx, repoInfo, issues, chunksChan); err != nil {
if err = s.chunkIssues(ctx, repoInfo, issues, reporter); err != nil {
return err
}

Expand All @@ -1163,7 +1162,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha
return nil
}

func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*github.Issue, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*github.Issue, reporter sources.ChunkReporter) error {
for _, issue := range issues {

// Skip pull requests since covered by processPRs.
Expand All @@ -1172,7 +1171,7 @@ func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*g
}

// Create chunk and send it to the channel.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
JobID: s.JobID(),
Expand All @@ -1193,16 +1192,14 @@ func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*g
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
}

func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
issueOpts := &github.IssueListCommentsOptions{
Sort: &sortType,
Direction: &directionType,
Expand All @@ -1221,7 +1218,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch
return err
}

if err = s.chunkIssueComments(ctx, repoInfo, issueComments, chunksChan); err != nil {
if err = s.chunkIssueComments(ctx, repoInfo, issueComments, reporter); err != nil {
return err
}

Expand All @@ -1233,10 +1230,10 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch
return nil
}

func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comments []*github.IssueComment, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comments []*github.IssueComment, reporter sources.ChunkReporter) error {
for _, comment := range comments {
// Create chunk and send it to the channel.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
JobID: s.JobID(),
Expand All @@ -1257,16 +1254,14 @@ func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comm
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
}

func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
prOpts := &github.PullRequestListOptions{
Sort: sortType,
Direction: directionType,
Expand All @@ -1286,7 +1281,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c
return err
}

if err = s.chunkPullRequests(ctx, repoInfo, prs, chunksChan); err != nil {
if err = s.chunkPullRequests(ctx, repoInfo, prs, reporter); err != nil {
return err
}

Expand All @@ -1299,7 +1294,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c
return nil
}

func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
prOpts := &github.PullRequestListCommentsOptions{
Sort: sortType,
Direction: directionType,
Expand All @@ -1318,7 +1313,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk
return err
}

if err = s.chunkPullRequestComments(ctx, repoInfo, prComments, chunksChan); err != nil {
if err = s.chunkPullRequestComments(ctx, repoInfo, prComments, reporter); err != nil {
return err
}

Expand All @@ -1331,10 +1326,10 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk
return nil
}

func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs []*github.PullRequest, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs []*github.PullRequest, reporter sources.ChunkReporter) error {
for _, pr := range prs {
// Create chunk and send it to the channel.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
Expand All @@ -1355,19 +1350,17 @@ func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs [
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
}

func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo, comments []*github.PullRequestComment, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo, comments []*github.PullRequestComment, reporter sources.ChunkReporter) error {
for _, comment := range comments {
// Create chunk and send it to the channel.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
Expand All @@ -1388,19 +1381,17 @@ func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
}

func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, chunksChan chan *sources.Chunk) []error {
func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, reporter sources.ChunkReporter) []error {
var errs []error
for _, tgt := range targets {
if err := s.scanTarget(ctx, tgt, chunksChan); err != nil {
if err := s.scanTarget(ctx, tgt, reporter); err != nil {
ctx.Logger().Error(err, "error scanning target")
errs = append(errs, &sources.TargetedScanError{Err: err, SecretID: tgt.SecretID})
}
Expand All @@ -1409,7 +1400,7 @@ func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarg
return errs
}

func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, chunksChan chan *sources.Chunk) error {
func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, reporter sources.ChunkReporter) error {
metaType, ok := target.QueryCriteria.GetData().(*source_metadatapb.MetaData_Github)
if !ok {
return fmt.Errorf("unable to cast metadata type for targeted scan")
Expand Down Expand Up @@ -1446,7 +1437,6 @@ func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget,
return fmt.Errorf("unexpected HTTP response status when trying to download file for scan: %v", resp.Status)
}

reporter := sources.ChanReporter{Ch: chunksChan}
chunkSkel := sources.Chunk{
SourceType: s.Type(),
SourceName: s.name,
Expand Down

0 comments on commit 401bc46

Please sign in to comment.