diff --git a/pkg/model/article.go b/pkg/model/article.go index 0b1aa2d..b2ed5e4 100644 --- a/pkg/model/article.go +++ b/pkg/model/article.go @@ -12,6 +12,8 @@ type Article struct { Link string } +type ArticlesStream <-chan Article + func NewArticleWithContent(title, link, content string) Article { return Article{ Title: title, @@ -43,10 +45,11 @@ func (a *Article) IsValid() bool { return isUrl(a.Link) && contentIsValid && titleIsValid } -func TakeRandomArticles(articles []Article, take int) []Article { +func TakeRandomArticles(stream ArticlesStream, take int) []Article { if take == 0 { return make([]Article, 0) } + articles := ToArticlesArray(stream) if take >= len(articles) { return articles } @@ -58,3 +61,11 @@ func TakeRandomArticles(articles []Article, take int) []Article { return randomArticles } + +func ToArticlesArray(s ArticlesStream) []Article { + res := make([]Article, 0) + for v := range s { + res = append(res, v) + } + return res +} diff --git a/pkg/model/article_test.go b/pkg/model/article_test.go index 36d5f7a..ea6f6ba 100644 --- a/pkg/model/article_test.go +++ b/pkg/model/article_test.go @@ -43,19 +43,31 @@ func TestArticleValidationWhenUrlIsCorrect(t *testing.T) { } func TestGetRandomArticlesWhenTakeIsZero(t *testing.T) { - articles := []Article{NewArticle("x", "2"), NewArticle("d", "1"), NewArticle("xd", "37")} + articles := make(chan Article, 10) + for _, a := range []Article{NewArticle("x", "2"), NewArticle("d", "1"), NewArticle("xd", "37")} { + articles <- a + } + close(articles) randomArticles := TakeRandomArticles(articles, 0) assert.Len(t, randomArticles, 0) } func TestGetRandomArticlesWhenTakeIsGreaterThanLenOfArticlesArray(t *testing.T) { - articles := []Article{NewArticle("x", "2"), NewArticle("d", "1"), NewArticle("xd", "37")} + articles := make(chan Article, 10) + for _, a := range []Article{NewArticle("x", "2"), NewArticle("d", "1"), NewArticle("xd", "37")} { + articles <- a + } + close(articles) randomArticles := TakeRandomArticles(articles, 5) assert.Len(t, randomArticles, len(articles)) } func TestGetRandomArticlesWhenTakeIsSmallerThanLenOfArticlesArray(t *testing.T) { - articles := []Article{NewArticle("x", "2"), NewArticle("d", "1"), NewArticle("xd", "37")} + articles := make(chan Article, 10) + for _, a := range []Article{NewArticle("x", "2"), NewArticle("d", "1"), NewArticle("xd", "37")} { + articles <- a + } + close(articles) randomArticles := TakeRandomArticles(articles, 2) assert.Len(t, randomArticles, 2) } diff --git a/pkg/providers/articles.go b/pkg/providers/articles.go index 019c46d..e5211f1 100644 --- a/pkg/providers/articles.go +++ b/pkg/providers/articles.go @@ -10,7 +10,7 @@ import ( ) type ArticlesProvider interface { - Provide(ctx context.Context) ([]model.Article, error) + Provide(ctx context.Context) model.ArticlesStream } type articlesProvider struct { @@ -26,7 +26,11 @@ func fanIn(ctx context.Context, stream ...chan model.Article) chan model.Article out := make(chan model.Article) output := func(c <-chan model.Article) { for v := range c { - out <- v + select { + case <-ctx.Done(): + return + case out <- v: + } } wg.Done() } @@ -49,7 +53,11 @@ func (f *articlesProvider) parse(ctx context.Context, parser parsers.ArticlesPar log.WithError(err).WithContext(ctx).Error("Error while parsing articles") } else { for _, v := range res { - stream <- v + select { + case <-ctx.Done(): + return + case stream <- v: + } } } close(stream) @@ -57,23 +65,10 @@ func (f *articlesProvider) parse(ctx context.Context, parser parsers.ArticlesPar return stream } -func (f *articlesProvider) Provide(ctx context.Context) ([]model.Article, error) { +func (f *articlesProvider) Provide(ctx context.Context) model.ArticlesStream { streams := make([]chan model.Article, 0, len(f.parsers)) - result := make([]model.Article, 0) for _, parser := range f.parsers { streams = append(streams, f.parse(ctx, parser)) } - finalStream := fanIn(ctx, streams...) - for { - select { - case v, ok := <-finalStream: - if ok { - result = append(result, v) - } else { - return result, nil - } - case <-ctx.Done(): - return result, nil - } - } + return fanIn(ctx, streams...) } diff --git a/pkg/providers/articles_test.go b/pkg/providers/articles_test.go index 446c5d8..52786dd 100644 --- a/pkg/providers/articles_test.go +++ b/pkg/providers/articles_test.go @@ -33,8 +33,7 @@ func (f *fakeErrorParser) Parse(ctx context.Context) ([]model.Article, error) { func TestArticlesProvider(t *testing.T) { articlesProvider := NewArticlesProvider([]parsers.ArticlesParser{&fakeParser{}, &fakeParser2{}}) - subject, err := articlesProvider.Provide(context.Background()) - assert.Nil(t, err) + subject := model.ToArticlesArray(articlesProvider.Provide(context.Background())) assert.Len(t, subject, 2) assert.Equal(t, "test", subject[0].Title) assert.Equal(t, "test", subject[1].Title) @@ -42,8 +41,7 @@ func TestArticlesProvider(t *testing.T) { func TestArticlesProviderWhenError(t *testing.T) { articlesProvider := NewArticlesProvider([]parsers.ArticlesParser{&fakeParser{}, &fakeErrorParser{}}) - subject, err := articlesProvider.Provide(context.Background()) - assert.Nil(t, err) + subject := model.ToArticlesArray(articlesProvider.Provide(context.Background())) assert.Len(t, subject, 1) assert.Equal(t, "test", subject[0].Title) } diff --git a/pkg/usecase/parser.go b/pkg/usecase/parser.go index 7d7ad75..d60bc95 100644 --- a/pkg/usecase/parser.go +++ b/pkg/usecase/parser.go @@ -2,6 +2,7 @@ package usecase import ( "context" + "sync" "github.com/dominikus1993/dev-news-bot/pkg/model" "github.com/dominikus1993/dev-news-bot/pkg/notifications" @@ -20,42 +21,49 @@ func NewParseArticlesAndSendItUseCase(articlesProvider providers.ArticlesProvide return &ParseArticlesAndSendItUseCase{articlesProvider: articlesProvider, repository: repository, broadcaster: broadcaster} } -func (u *ParseArticlesAndSendItUseCase) filterNewArticles(ctx context.Context, articles []model.Article) []model.Article { - filteredArticles := make([]model.Article, 0, len(articles)) - for _, article := range articles { +func pipe(ctx context.Context, articles model.ArticlesStream, f func(ctx context.Context, article model.Article, articles chan<- model.Article)) model.ArticlesStream { + filteredArticles := make(chan model.Article, 10) + go func() { + var wg sync.WaitGroup + for article := range articles { + wg.Add(1) + go func(a model.Article) { + f(ctx, a, filteredArticles) + wg.Done() + }(article) + } + wg.Wait() + close(filteredArticles) + }() + return filteredArticles +} + +func (u *ParseArticlesAndSendItUseCase) filterNewArticles(ctx context.Context, articles model.ArticlesStream) model.ArticlesStream { + return pipe(ctx, articles, func(ctx context.Context, article model.Article, articles chan<- model.Article) { isNew, err := u.repository.IsNew(ctx, article) if err != nil { log.WithField("ArticleLink", article.Link).WithError(err).WithContext(ctx).Error("error while checking if article exists") } if isNew { - filteredArticles = append(filteredArticles, article) + articles <- article } - } - return filteredArticles + }) } -func (u *ParseArticlesAndSendItUseCase) filterValid(ctx context.Context, articles []model.Article) []model.Article { - filteredArticles := make([]model.Article, 0, len(articles)) - for _, article := range articles { +func (u *ParseArticlesAndSendItUseCase) filterValid(ctx context.Context, articles model.ArticlesStream) model.ArticlesStream { + return pipe(ctx, articles, func(ctx context.Context, article model.Article, articles chan<- model.Article) { if article.IsValid() { - filteredArticles = append(filteredArticles, article) + articles <- article } - } - return filteredArticles + }) } func (u *ParseArticlesAndSendItUseCase) Execute(ctx context.Context, articlesQuantity int) error { - articles, err := u.articlesProvider.Provide(ctx) - if err != nil { - return err - } - log.Infoln("Found articles:", len(articles)) + articles := u.articlesProvider.Provide(ctx) validArticles := u.filterValid(ctx, articles) - log.Infoln("Found valid articles:", len(validArticles)) newArticles := u.filterNewArticles(ctx, validArticles) - log.Infoln("Found new articles:", len(newArticles)) randomArticles := model.TakeRandomArticles(newArticles, articlesQuantity) - err = u.broadcaster.Broadcast(ctx, randomArticles) + err := u.broadcaster.Broadcast(ctx, randomArticles) if err != nil { return err }