From 121a1c939315b9ec4ea2db0c939383676060dc48 Mon Sep 17 00:00:00 2001 From: Karl Higley Date: Mon, 8 Jul 2024 12:12:01 -0400 Subject: [PATCH] Create and apply a `RecommendationPipeline` class (#56) This introduces a simple-ish `RecommendationPipeline` that works with the components created in #43 and #53. Conceptually, the components run in the order they're added to the pipeline and modify shared state that each component can select inputs from and write outputs to. It could be generalized (e.g. to support types other than `ArticleSet` and `InterestProfile`), but this is enough to support the recommenders we're currently building/using. --- src/poprox_recommender/default.py | 92 +++++++-------------- src/poprox_recommender/diversifiers/pfar.py | 5 +- src/poprox_recommender/embedders/article.py | 4 +- src/poprox_recommender/embedders/user.py | 4 +- src/poprox_recommender/filters/__init__.py | 3 + src/poprox_recommender/filters/topic.py | 20 +++++ src/poprox_recommender/handler.py | 26 +++++- src/poprox_recommender/pipeline.py | 56 +++++++++++++ src/poprox_recommender/samplers/__init__.py | 3 + src/poprox_recommender/samplers/uniform.py | 26 ++++++ src/poprox_recommender/scorers/article.py | 4 +- tests/basic-request.json | 1 - tests/test_pfar.py | 41 +++------ tests/test_select.py | 30 +++++-- tests/test_smoke.py | 26 +++++- 15 files changed, 227 insertions(+), 114 deletions(-) create mode 100644 src/poprox_recommender/filters/__init__.py create mode 100644 src/poprox_recommender/filters/topic.py create mode 100644 src/poprox_recommender/pipeline.py create mode 100644 src/poprox_recommender/samplers/__init__.py create mode 100644 src/poprox_recommender/samplers/uniform.py diff --git a/src/poprox_recommender/default.py b/src/poprox_recommender/default.py index 120afea9..7141bd07 100644 --- a/src/poprox_recommender/default.py +++ b/src/poprox_recommender/default.py @@ -1,91 +1,55 @@ -import random from typing import Any -from uuid import UUID -from poprox_concepts 
import Article, ArticleSet, InterestProfile +from poprox_concepts import ArticleSet, InterestProfile from poprox_recommender.diversifiers import MMRDiversifier, PFARDiversifier from poprox_recommender.embedders import ArticleEmbedder, UserEmbedder +from poprox_recommender.filters import TopicFilter from poprox_recommender.model import get_model +from poprox_recommender.pipeline import RecommendationPipeline +from poprox_recommender.samplers import UniformSampler from poprox_recommender.scorers import ArticleScorer -from poprox_recommender.topics import user_topic_preference def select_articles( - candidate_articles: list[Article], - past_articles: list[Article], + candidate_articles: ArticleSet, + clicked_articles: ArticleSet, interest_profile: InterestProfile, num_slots: int, algo_params: dict[str, Any] | None = None, -) -> dict[UUID, list[Article]]: - candidate_articles = ArticleSet(articles=candidate_articles) - past_articles = ArticleSet(articles=past_articles) - - click_history = interest_profile.click_history - clicked_articles = list(filter(lambda a: a.article_id in set(click_history.article_ids), past_articles.articles)) - clicked_articles = ArticleSet(articles=clicked_articles) - - # This could be a component but should likely be moved upstream to the platform - interest_profile.click_topic_counts = user_topic_preference(past_articles.articles, interest_profile.click_history) - profile_id = interest_profile.profile_id - +) -> ArticleSet: algo_params = algo_params or {} diversify = str(algo_params.get("diversity_algo", "pfar")) model = get_model() - recommendations = {} - # The following code should ONLY access the InterestProfile and ArticleSets defined above - if model and click_history.article_ids: + if model and interest_profile.click_history.article_ids: article_embedder = ArticleEmbedder(model.model, model.tokenizer, model.device) user_embedder = UserEmbedder(model.model, model.device) article_scorer = ArticleScorer(model.model) - candidate_articles = 
article_embedder(candidate_articles) - clicked_articles = article_embedder(clicked_articles) - - interest_profile = user_embedder(clicked_articles, interest_profile) - candidate_articles = article_scorer(candidate_articles, interest_profile) - if diversify == "mmr": diversifier = MMRDiversifier(algo_params, num_slots) - recs = diversifier(candidate_articles) - elif diversify == "pfar": diversifier = PFARDiversifier(algo_params, num_slots) - recs = diversifier(candidate_articles, interest_profile) - - recommendations[profile_id] = recs.articles - else: - recommendations[profile_id] = select_by_topic( - candidate_articles.articles, - interest_profile, - num_slots, - ) - - return recommendations - - -def select_by_topic(todays_articles: list[Article], interest_profile: InterestProfile, num_slots: int): - # Preference values from onboarding are 1-indexed, where 1 means "absolutely no interest." - # We might want to normalize them to 0-indexed somewhere upstream, but in the mean time - # this is one of the simpler ways to filter out topics people aren't interested in from - # their early newsletters - profile_topics = { - interest.entity_name for interest in interest_profile.onboarding_topics if interest.preference > 1 - } - - other_articles = [] - topical_articles = [] - for article in todays_articles: - article_topics = {mention.entity.name for mention in article.mentions} - if len(profile_topics.intersection(article_topics)) > 0: - topical_articles.append(article) - else: - other_articles.append(article) - if len(topical_articles) >= num_slots: - return random.sample(topical_articles, num_slots) + pipeline = RecommendationPipeline(name=diversify) + pipeline.add(article_embedder, inputs=["candidate"], output="candidate") + pipeline.add(article_embedder, inputs=["clicked"], output="clicked") + pipeline.add(user_embedder, inputs=["clicked", "profile"], output="profile") + pipeline.add(article_scorer, inputs=["candidate", "profile"], output="candidate") + 
pipeline.add(diversifier, inputs=["candidate", "profile"], output="recs") else: - return random.sample(topical_articles, len(topical_articles)) + random.sample( - other_articles, num_slots - len(topical_articles) - ) + topic_filter = TopicFilter() + sampler = UniformSampler(num_slots=num_slots) + + pipeline = RecommendationPipeline(name="random_topical") + pipeline.add(topic_filter, inputs=["candidate", "profile"], output="topical") + pipeline.add(sampler, inputs=["topical", "candidate"], output="recs") + + return pipeline( + { + "candidate": candidate_articles, + "clicked": clicked_articles, + "profile": interest_profile, + } + ) diff --git a/src/poprox_recommender/diversifiers/pfar.py b/src/poprox_recommender/diversifiers/pfar.py index 668ad988..9d7de7b9 100644 --- a/src/poprox_recommender/diversifiers/pfar.py +++ b/src/poprox_recommender/diversifiers/pfar.py @@ -20,8 +20,9 @@ def __call__(self, candidate_articles: ArticleSet, interest_profile: InterestPro for interest in interest_profile.onboarding_topics: topic_preferences[interest.entity_name] = max(interest.preference - 1, 0) - for topic, click_count in interest_profile.click_topic_counts.items(): - topic_preferences[topic] = click_count + if interest_profile.click_topic_counts: + for topic, click_count in interest_profile.click_topic_counts.items(): + topic_preferences[topic] = click_count normalized_topic_prefs = normalized_topic_count(topic_preferences) diff --git a/src/poprox_recommender/embedders/article.py b/src/poprox_recommender/embedders/article.py index 3875e4e8..e547a196 100644 --- a/src/poprox_recommender/embedders/article.py +++ b/src/poprox_recommender/embedders/article.py @@ -20,6 +20,6 @@ def __call__(self, article_set: ArticleSet) -> ArticleSet: if len(title_tensor.shape) == 1: title_tensor = title_tensor.unsqueeze(dim=0) - article_embeddings = self.model.get_news_vector(title_tensor) + article_set.embeddings = self.model.get_news_vector(title_tensor) - return 
article_set.model_copy(update={"embeddings": article_embeddings}) + return article_set diff --git a/src/poprox_recommender/embedders/user.py b/src/poprox_recommender/embedders/user.py index 426222aa..0b9317fa 100644 --- a/src/poprox_recommender/embedders/user.py +++ b/src/poprox_recommender/embedders/user.py @@ -17,7 +17,7 @@ def __call__(self, clicked_articles: ArticleSet, interest_profile: InterestProfi embedding_lookup["PADDED_NEWS"] = th.zeros(list(embedding_lookup.values())[0].size(), device=self.device) - user_embedding = build_user_embedding( + interest_profile.embedding = build_user_embedding( interest_profile.click_history, embedding_lookup, self.model, @@ -25,7 +25,7 @@ def __call__(self, clicked_articles: ArticleSet, interest_profile: InterestProfi self.max_clicks, ) - return interest_profile.model_copy(update={"embedding": user_embedding}) + return interest_profile # Compute a vector for each user diff --git a/src/poprox_recommender/filters/__init__.py b/src/poprox_recommender/filters/__init__.py new file mode 100644 index 00000000..b7b9e254 --- /dev/null +++ b/src/poprox_recommender/filters/__init__.py @@ -0,0 +1,3 @@ +from poprox_recommender.filters.topic import TopicFilter + +__all__ = ["TopicFilter"] diff --git a/src/poprox_recommender/filters/topic.py b/src/poprox_recommender/filters/topic.py new file mode 100644 index 00000000..1cab0100 --- /dev/null +++ b/src/poprox_recommender/filters/topic.py @@ -0,0 +1,20 @@ +from poprox_concepts import ArticleSet, InterestProfile + + +class TopicFilter: + def __call__(self, candidate: ArticleSet, interest_profile: InterestProfile) -> ArticleSet: + # Preference values from onboarding are 1-indexed, where 1 means "absolutely no interest." 
+ # We might want to normalize them to 0-indexed somewhere upstream, but in the meantime + # this is one of the simpler ways to filter out topics people aren't interested in from + # their early newsletters + profile_topics = { + interest.entity_name for interest in interest_profile.onboarding_topics if interest.preference > 1 + } + + topical_articles = [] + for article in candidate.articles: + article_topics = {mention.entity.name for mention in article.mentions} + if len(profile_topics.intersection(article_topics)) > 0: + topical_articles.append(article) + + return ArticleSet(articles=topical_articles) diff --git a/src/poprox_recommender/handler.py b/src/poprox_recommender/handler.py index d15b6ee3..7234f01f 100644 --- a/src/poprox_recommender/handler.py +++ b/src/poprox_recommender/handler.py @@ -1,11 +1,13 @@ import base64 import logging +from poprox_concepts import ArticleSet from poprox_concepts.api.recommendations import ( RecommendationRequest, RecommendationResponse, ) from poprox_recommender.default import select_articles +from poprox_recommender.topics import user_topic_preference logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -29,16 +31,32 @@ def generate_recs(event, context): logger.info("Using default parameters") logger.info("Selecting articles...") + + # The platform should send an ArticleSet but we'll do it here for now + candidate_articles = ArticleSet(articles=req.todays_articles) + + # Similarly, the platform should provide pre-filtered clicked articles + # and compute the topic counts but this shim lets us ignore that issue + # in the actual article selection + profile = req.interest_profile + click_history = profile.click_history + clicked_articles = list(filter(lambda a: a.article_id in set(click_history.article_ids), req.past_articles)) + clicked_articles = ArticleSet(articles=clicked_articles) + + profile.click_topic_counts = user_topic_preference(req.past_articles, profile.click_history) + recommendations = 
select_articles( - req.todays_articles, - req.past_articles, - req.interest_profile, + candidate_articles, + clicked_articles, + profile, req.num_recs, algo_params, ) logger.info("Constructing response...") - resp_body = RecommendationResponse.model_validate({"recommendations": recommendations}) + resp_body = RecommendationResponse.model_validate( + {"recommendations": {profile.profile_id: recommendations.articles}} + ) logger.info("Serializing response...") response = {"statusCode": 200, "body": resp_body.model_dump_json()} diff --git a/src/poprox_recommender/pipeline.py b/src/poprox_recommender/pipeline.py new file mode 100644 index 00000000..f6acce6a --- /dev/null +++ b/src/poprox_recommender/pipeline.py @@ -0,0 +1,56 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Callable + +from poprox_concepts import ArticleSet, InterestProfile + + +@dataclass +class ComponentSpec: + component: Callable + inputs: list[str] + output: str + + +class RecommendationPipeline: + def __init__(self, name): + self.name = name + self.components = [] + + def add(self, component: Callable, inputs: list[str], output: str): + self.components.append(ComponentSpec(component, inputs, output)) + + def __call__(self, inputs: dict[str, ArticleSet | InterestProfile]) -> ArticleSet: + # Avoid modifying the inputs + state = deepcopy(inputs) + + # Run each component in the order it was added + for component_spec in self.components: + state = self.run_component(component_spec, state) + + recs = state[self.components[-1].output] + + # Double check that we're returning the right type for recs + if not isinstance(recs, ArticleSet): + msg = f"The final pipeline component must return ArticleSet, but received {type(recs)}" + raise TypeError(msg) + + return recs + + def run_component(self, component_spec: ComponentSpec, state: dict[str, ArticleSet | InterestProfile]): + arguments = [] + for input_name in component_spec.inputs: + arguments.append(state[input_name]) + + 
output = component_spec.component(*arguments) + + if not isinstance(output, (ArticleSet, InterestProfile)): + msg = ( + f"Pipeline components must return ArticleSet or InterestProfile, " + f"but received {type(output)} from {type(component_spec.component)}" + ) + raise TypeError(msg) + + state[component_spec.output] = output + + return state diff --git a/src/poprox_recommender/samplers/__init__.py b/src/poprox_recommender/samplers/__init__.py new file mode 100644 index 00000000..1c8c71aa --- /dev/null +++ b/src/poprox_recommender/samplers/__init__.py @@ -0,0 +1,3 @@ +from poprox_recommender.samplers.uniform import UniformSampler + +__all__ = ["UniformSampler"] diff --git a/src/poprox_recommender/samplers/uniform.py b/src/poprox_recommender/samplers/uniform.py new file mode 100644 index 00000000..ea0fc8fe --- /dev/null +++ b/src/poprox_recommender/samplers/uniform.py @@ -0,0 +1,26 @@ +import random + +from poprox_concepts import ArticleSet + + +class UniformSampler: + def __init__(self, num_slots): + self.num_slots = num_slots + + def __call__(self, candidate: ArticleSet, backup: ArticleSet | None = None): + backup_articles = list(filter(lambda article: article not in candidate.articles, backup.articles)) + + num_backups = ( + self.num_slots - len(candidate.articles) + if len(candidate.articles) + len(backup_articles) >= self.num_slots + else len(backup_articles) + ) + + if len(candidate.articles) < self.num_slots and backup_articles: + sampled = random.sample(candidate.articles, len(candidate.articles)) + random.sample( + backup_articles, num_backups + ) + else: + sampled = random.sample(candidate.articles, self.num_slots) + + return ArticleSet(articles=sampled) diff --git a/src/poprox_recommender/scorers/article.py b/src/poprox_recommender/scorers/article.py index 2c7e818f..46bef6b6 100644 --- a/src/poprox_recommender/scorers/article.py +++ b/src/poprox_recommender/scorers/article.py @@ -10,4 +10,6 @@ def __call__(self, candidate_articles: ArticleSet, 
interest_profile: InterestPro user_embedding = interest_profile.embedding pred = self.model.get_prediction(candidate_embeddings, user_embedding.squeeze()) - return candidate_articles.model_copy(update={"scores": pred.cpu().detach().numpy()}) + candidate_articles.scores = pred.cpu().detach().numpy() + + return candidate_articles diff --git a/tests/basic-request.json b/tests/basic-request.json index fb141ec7..0ca86f5a 100644 --- a/tests/basic-request.json +++ b/tests/basic-request.json @@ -18,7 +18,6 @@ "interest_profile": { "profile_id": "28838f05-23f5-4f23-bea2-30b51f67c538", "click_history": { - "account_id": "977a3c88-937a-46fb-bbfe-94dc5dcb68c8", "article_ids": [ "e7605f12-a37a-4326-bf3c-3f9b72d0738d" ] diff --git a/tests/test_pfar.py b/tests/test_pfar.py index 310409b2..6100dcd4 100644 --- a/tests/test_pfar.py +++ b/tests/test_pfar.py @@ -3,9 +3,12 @@ import random from pathlib import Path -from poprox_concepts import Article, ClickHistory, InterestProfile -from poprox_recommender.default import select_articles, user_topic_preference -from poprox_recommender.topics import GENERAL_TOPICS, extract_general_topics, match_news_topics_to_general +from poprox_concepts import Article, ArticleSet, ClickHistory +from poprox_recommender.topics import ( + GENERAL_TOPICS, + extract_general_topics, + match_news_topics_to_general, +) logger = logging.getLogger(__name__) @@ -15,20 +18,17 @@ def load_test_articles(): with open(event_path, "r") as j: req_body = json.loads(j.read()) - todays_articles = [Article.model_validate(attrs) for attrs in req_body["todays_articles"]] - - past_articles = [Article.model_validate(attrs) for attrs in req_body["past_articles"]] - + candidate = ArticleSet(articles=[Article.model_validate(attrs) for attrs in req_body["todays_articles"]]) + past = ArticleSet(articles=[Article.model_validate(attrs) for attrs in req_body["past_articles"]]) click_history = [ClickHistory.model_validate(attrs) for attrs in req_body["click_data"]] - num_recs = 
req_body["num_recs"] - return todays_articles, past_articles, click_history, num_recs + return candidate, past, click_history, num_recs def test_topic_classification(): - todays_articles, _, _, _ = load_test_articles() - topic_matched_dict, todays_article_matched_topics = match_news_topics_to_general(todays_articles) + candidate, _, _, _ = load_test_articles() + topic_matched_dict, todays_article_matched_topics = match_news_topics_to_general(candidate.articles) assert len(todays_article_matched_topics.keys()) > 0 random_10_topic = random.sample(list(topic_matched_dict.keys()), 10) @@ -37,23 +37,8 @@ def test_topic_classification(): def test_extract_generalized_topic(): - todays_articles, _, _, _ = load_test_articles() - for article in todays_articles: + candidate, _, _, _ = load_test_articles() + for article in candidate.articles: generalized_topics = extract_general_topics(article) for topic in generalized_topics: assert topic in GENERAL_TOPICS - - -def test_user_topic_pref(): - todays_articles, past_articles, click_data, num_recs = load_test_articles() - interest_profile = InterestProfile.model_validate( - { - "click_history": click_data[0], - "onboarding_topics": [], - } - ) - - user_preference_dict = user_topic_preference(past_articles, click_data[0]) - algo_params = {"user_topic_preference": user_preference_dict} - recommendations = select_articles(todays_articles, past_articles, interest_profile, num_recs, algo_params) - assert len(recommendations) > 0 diff --git a/tests/test_select.py b/tests/test_select.py index 5222e814..a143e6d9 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -1,8 +1,10 @@ from uuid import uuid4 -from poprox_concepts import Article, ClickHistory, Entity, Mention +from poprox_concepts import Article, ArticleSet, ClickHistory, Entity, Mention from poprox_concepts.domain.profile import AccountInterest, InterestProfile -from poprox_recommender.default import select_by_topic +from poprox_recommender.filters import TopicFilter 
+from poprox_recommender.pipeline import RecommendationPipeline +from poprox_recommender.samplers import UniformSampler def test_select_by_topic_filters_articles(): @@ -37,17 +39,31 @@ def test_select_by_topic_filters_articles(): ), ] + topic_filter = TopicFilter() + sampler = UniformSampler(num_slots=2) + + pipeline = RecommendationPipeline(name="random_topical") + pipeline.add(topic_filter, inputs=["candidate", "profile"], output="topical") + pipeline.add(sampler, inputs=["topical", "candidate"], output="recs") + # If we can, only select articles matching interests - recs = select_by_topic(articles, profile, num_slots=2) + inputs = { + "candidate": ArticleSet(articles=articles), + "clicked": ArticleSet(articles=[]), + "profile": profile, + } + recs = pipeline(inputs) - for article in recs: + for article in recs.articles: topics = [mention.entity.name for mention in article.mentions] assert "U.S. News" in topics or "Politics" in topics # If we need to, fill out the end of the list with other random articles - recs = select_by_topic(articles, profile, num_slots=3) - assert len(recs) == 3 + sampler.num_slots = 3 + recs = pipeline(inputs) + + assert len(recs.articles) == 3 - for article in recs[:2]: + for article in recs.articles[:2]: topics = [mention.entity.name for mention in article.mentions] assert "U.S. 
News" in topics or "Politics" in topics diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 12fcbcbd..b806d887 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -5,6 +5,7 @@ import logging from pathlib import Path +from poprox_concepts import ArticleSet, ClickHistory from poprox_concepts.api.recommendations import RecommendationRequest from poprox_recommender.default import select_articles @@ -18,10 +19,29 @@ def test_direct_basic_request(): logger.info("generating recommendations") recs = select_articles( - req.todays_articles, - req.past_articles, + ArticleSet(articles=req.todays_articles), + ArticleSet(articles=req.past_articles), req.interest_profile, req.num_recs, ) # do we get recommendations? - assert len(recs) > 0 + assert len(recs.articles) > 0 + + +def test_direct_basic_request_without_clicks(): + test_dir = Path(__file__) + req_f = test_dir.parent / "basic-request.json" + req = RecommendationRequest.model_validate_json(req_f.read_text()) + + logger.info("generating recommendations") + + profile = req.interest_profile + profile.click_history = ClickHistory(article_ids=[]) + recs = select_articles( + ArticleSet(articles=req.todays_articles), + ArticleSet(articles=req.past_articles), + req.interest_profile, + req.num_recs, + ) + # do we get recommendations? + assert len(recs.articles) > 0