diff --git a/src/poprox_recommender/default.py b/src/poprox_recommender/default.py
index 120afea9..7141bd07 100644
--- a/src/poprox_recommender/default.py
+++ b/src/poprox_recommender/default.py
@@ -1,91 +1,55 @@
-import random
 from typing import Any
-from uuid import UUID
 
-from poprox_concepts import Article, ArticleSet, InterestProfile
+from poprox_concepts import ArticleSet, InterestProfile
 from poprox_recommender.diversifiers import MMRDiversifier, PFARDiversifier
 from poprox_recommender.embedders import ArticleEmbedder, UserEmbedder
+from poprox_recommender.filters import TopicFilter
 from poprox_recommender.model import get_model
+from poprox_recommender.pipeline import RecommendationPipeline
+from poprox_recommender.samplers import UniformSampler
 from poprox_recommender.scorers import ArticleScorer
-from poprox_recommender.topics import user_topic_preference
 
 
 def select_articles(
-    candidate_articles: list[Article],
-    past_articles: list[Article],
+    candidate_articles: ArticleSet,
+    clicked_articles: ArticleSet,
     interest_profile: InterestProfile,
     num_slots: int,
     algo_params: dict[str, Any] | None = None,
-) -> dict[UUID, list[Article]]:
-    candidate_articles = ArticleSet(articles=candidate_articles)
-    past_articles = ArticleSet(articles=past_articles)
-
-    click_history = interest_profile.click_history
-    clicked_articles = list(filter(lambda a: a.article_id in set(click_history.article_ids), past_articles.articles))
-    clicked_articles = ArticleSet(articles=clicked_articles)
-
-    # This could be a component but should likely be moved upstream to the platform
-    interest_profile.click_topic_counts = user_topic_preference(past_articles.articles, interest_profile.click_history)
-    profile_id = interest_profile.profile_id
-
+) -> ArticleSet:
     algo_params = algo_params or {}
     diversify = str(algo_params.get("diversity_algo", "pfar"))
 
     model = get_model()
-    recommendations = {}
 
-    # The following code should ONLY access the InterestProfile and ArticleSets defined above
-    if model and click_history.article_ids:
+    if model and interest_profile.click_history.article_ids:
         article_embedder = ArticleEmbedder(model.model, model.tokenizer, model.device)
         user_embedder = UserEmbedder(model.model, model.device)
         article_scorer = ArticleScorer(model.model)
 
-        candidate_articles = article_embedder(candidate_articles)
-        clicked_articles = article_embedder(clicked_articles)
-
-        interest_profile = user_embedder(clicked_articles, interest_profile)
-        candidate_articles = article_scorer(candidate_articles, interest_profile)
-
         if diversify == "mmr":
             diversifier = MMRDiversifier(algo_params, num_slots)
-            recs = diversifier(candidate_articles)
-
         elif diversify == "pfar":
             diversifier = PFARDiversifier(algo_params, num_slots)
-            recs = diversifier(candidate_articles, interest_profile)
-
-        recommendations[profile_id] = recs.articles
-    else:
-        recommendations[profile_id] = select_by_topic(
-            candidate_articles.articles,
-            interest_profile,
-            num_slots,
-        )
-
-    return recommendations
-
-
-def select_by_topic(todays_articles: list[Article], interest_profile: InterestProfile, num_slots: int):
-    # Preference values from onboarding are 1-indexed, where 1 means "absolutely no interest."
-    # We might want to normalize them to 0-indexed somewhere upstream, but in the mean time
-    # this is one of the simpler ways to filter out topics people aren't interested in from
-    # their early newsletters
-    profile_topics = {
-        interest.entity_name for interest in interest_profile.onboarding_topics if interest.preference > 1
-    }
-
-    other_articles = []
-    topical_articles = []
-    for article in todays_articles:
-        article_topics = {mention.entity.name for mention in article.mentions}
-        if len(profile_topics.intersection(article_topics)) > 0:
-            topical_articles.append(article)
-        else:
-            other_articles.append(article)
-    if len(topical_articles) >= num_slots:
-        return random.sample(topical_articles, num_slots)
+        pipeline = RecommendationPipeline(name=diversify)
+        pipeline.add(article_embedder, inputs=["candidate"], output="candidate")
+        pipeline.add(article_embedder, inputs=["clicked"], output="clicked")
+        pipeline.add(user_embedder, inputs=["clicked", "profile"], output="profile")
+        pipeline.add(article_scorer, inputs=["candidate", "profile"], output="candidate")
+        pipeline.add(diversifier, inputs=["candidate", "profile"], output="recs")
     else:
-        return random.sample(topical_articles, len(topical_articles)) + random.sample(
-            other_articles, num_slots - len(topical_articles)
-        )
+        topic_filter = TopicFilter()
+        sampler = UniformSampler(num_slots=num_slots)
+
+        pipeline = RecommendationPipeline(name="random_topical")
+        pipeline.add(topic_filter, inputs=["candidate", "profile"], output="topical")
+        pipeline.add(sampler, inputs=["topical", "candidate"], output="recs")
+
+    return pipeline(
+        {
+            "candidate": candidate_articles,
+            "clicked": clicked_articles,
+            "profile": interest_profile,
+        }
+    )
diff --git a/src/poprox_recommender/diversifiers/pfar.py b/src/poprox_recommender/diversifiers/pfar.py
index 668ad988..9d7de7b9 100644
--- a/src/poprox_recommender/diversifiers/pfar.py
+++ b/src/poprox_recommender/diversifiers/pfar.py
@@ -20,8 +20,9 @@ def __call__(self, candidate_articles: ArticleSet, interest_profile: InterestPro
         for interest in interest_profile.onboarding_topics:
             topic_preferences[interest.entity_name] = max(interest.preference - 1, 0)
 
-        for topic, click_count in interest_profile.click_topic_counts.items():
-            topic_preferences[topic] = click_count
+        if interest_profile.click_topic_counts:
+            for topic, click_count in interest_profile.click_topic_counts.items():
+                topic_preferences[topic] = click_count
 
         normalized_topic_prefs = normalized_topic_count(topic_preferences)
diff --git a/src/poprox_recommender/embedders/article.py b/src/poprox_recommender/embedders/article.py
index 3875e4e8..e547a196 100644
--- a/src/poprox_recommender/embedders/article.py
+++ b/src/poprox_recommender/embedders/article.py
@@ -20,6 +20,6 @@ def __call__(self, article_set: ArticleSet) -> ArticleSet:
         if len(title_tensor.shape) == 1:
             title_tensor = title_tensor.unsqueeze(dim=0)
 
-        article_embeddings = self.model.get_news_vector(title_tensor)
+        article_set.embeddings = self.model.get_news_vector(title_tensor)
 
-        return article_set.model_copy(update={"embeddings": article_embeddings})
+        return article_set
diff --git a/src/poprox_recommender/embedders/user.py b/src/poprox_recommender/embedders/user.py
index 426222aa..0b9317fa 100644
--- a/src/poprox_recommender/embedders/user.py
+++ b/src/poprox_recommender/embedders/user.py
@@ -17,7 +17,7 @@ def __call__(self, clicked_articles: ArticleSet, interest_profile: InterestProfi
         embedding_lookup["PADDED_NEWS"] = th.zeros(list(embedding_lookup.values())[0].size(), device=self.device)
 
-        user_embedding = build_user_embedding(
+        interest_profile.embedding = build_user_embedding(
             interest_profile.click_history,
             embedding_lookup,
             self.model,
@@ -25,7 +25,7 @@ def __call__(self, clicked_articles: ArticleSet, interest_profile: InterestProfi
             self.max_clicks,
         )
 
-        return interest_profile.model_copy(update={"embedding": user_embedding})
+        return interest_profile
 
 
 # Compute a vector for each user
diff --git a/src/poprox_recommender/filters/__init__.py b/src/poprox_recommender/filters/__init__.py
new file mode 100644
index 00000000..b7b9e254
--- /dev/null
+++ b/src/poprox_recommender/filters/__init__.py
@@ -0,0 +1,3 @@
+from poprox_recommender.filters.topic import TopicFilter
+
+__all__ = ["TopicFilter"]
diff --git a/src/poprox_recommender/filters/topic.py b/src/poprox_recommender/filters/topic.py
new file mode 100644
index 00000000..1cab0100
--- /dev/null
+++ b/src/poprox_recommender/filters/topic.py
@@ -0,0 +1,20 @@
+from poprox_concepts import ArticleSet, InterestProfile
+
+
+class TopicFilter:
+    def __call__(self, candidate: ArticleSet, interest_profile: InterestProfile) -> ArticleSet:
+        # Preference values from onboarding are 1-indexed, where 1 means "absolutely no interest."
+        # We might want to normalize them to 0-indexed somewhere upstream, but in the mean time
+        # this is one of the simpler ways to filter out topics people aren't interested in from
+        # their early newsletters
+        profile_topics = {
+            interest.entity_name for interest in interest_profile.onboarding_topics if interest.preference > 1
+        }
+
+        topical_articles = []
+        for article in candidate.articles:
+            article_topics = {mention.entity.name for mention in article.mentions}
+            if len(profile_topics.intersection(article_topics)) > 0:
+                topical_articles.append(article)
+
+        return ArticleSet(articles=topical_articles)
diff --git a/src/poprox_recommender/handler.py b/src/poprox_recommender/handler.py
index d15b6ee3..7234f01f 100644
--- a/src/poprox_recommender/handler.py
+++ b/src/poprox_recommender/handler.py
@@ -1,11 +1,13 @@
 import base64
 import logging
 
+from poprox_concepts import ArticleSet
 from poprox_concepts.api.recommendations import (
     RecommendationRequest,
     RecommendationResponse,
 )
 from poprox_recommender.default import select_articles
+from poprox_recommender.topics import user_topic_preference
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -29,16 +31,32 @@ def generate_recs(event, context):
         logger.info("Using default parameters")
 
     logger.info("Selecting articles...")
+
+    # The platform should send an ArticleSet but we'll do it here for now
+    candidate_articles = ArticleSet(articles=req.todays_articles)
+
+    # Similarly, the platform should provided pre-filtered clicked articles
+    # and compute the topic counts but this shim lets us ignore that issue
+    # in the actual article selection
+    profile = req.interest_profile
+    click_history = profile.click_history
+    clicked_articles = list(filter(lambda a: a.article_id in set(click_history.article_ids), req.past_articles))
+    clicked_articles = ArticleSet(articles=clicked_articles)
+
+    profile.click_topic_counts = user_topic_preference(req.past_articles, profile.click_history)
+
     recommendations = select_articles(
-        req.todays_articles,
-        req.past_articles,
-        req.interest_profile,
+        candidate_articles,
+        clicked_articles,
+        profile,
         req.num_recs,
         algo_params,
     )
 
     logger.info("Constructing response...")
-    resp_body = RecommendationResponse.model_validate({"recommendations": recommendations})
+    resp_body = RecommendationResponse.model_validate(
+        {"recommendations": {profile.profile_id: recommendations.articles}}
+    )
 
     logger.info("Serializing response...")
     response = {"statusCode": 200, "body": resp_body.model_dump_json()}
diff --git a/src/poprox_recommender/pipeline.py b/src/poprox_recommender/pipeline.py
new file mode 100644
index 00000000..f6acce6a
--- /dev/null
+++ b/src/poprox_recommender/pipeline.py
@@ -0,0 +1,56 @@
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Callable
+
+from poprox_concepts import ArticleSet, InterestProfile
+
+
+@dataclass
+class ComponentSpec:
+    component: Callable
+    inputs: list[str]
+    output: str
+
+
+class RecommendationPipeline:
+    def __init__(self, name):
+        self.name = name
+        self.components = []
+
+    def add(self, component: Callable, inputs: list[str], output: str):
+        self.components.append(ComponentSpec(component, inputs, output))
+
+    def __call__(self, inputs: dict[str, ArticleSet | InterestProfile]) -> ArticleSet:
+        # Avoid modifying the inputs
+        state = deepcopy(inputs)
+
+        # Run each component in the order it was added
+        for component_spec in self.components:
+            state = self.run_component(component_spec, state)
+
+        recs = state[self.components[-1].output]
+
+        # Double check that we're returning the right type for recs
+        if not isinstance(recs, ArticleSet):
+            msg = f"The final pipeline component must return ArticleSet, but received {type(recs)}"
+            raise TypeError(msg)
+
+        return recs
+
+    def run_component(self, component_spec: ComponentSpec, state: dict[str, ArticleSet | InterestProfile]):
+        arguments = []
+        for input_name in component_spec.inputs:
+            arguments.append(state[input_name])
+
+        output = component_spec.component(*arguments)
+
+        if not isinstance(output, (ArticleSet, InterestProfile)):
+            msg = (
+                f"Pipeline components must return ArticleSet or InterestProfile, "
+                f"but received {type(output)} from {type(component_spec.component)}"
+            )
+            raise TypeError(msg)
+
+        state[component_spec.output] = output
+
+        return state
diff --git a/src/poprox_recommender/samplers/__init__.py b/src/poprox_recommender/samplers/__init__.py
new file mode 100644
index 00000000..1c8c71aa
--- /dev/null
+++ b/src/poprox_recommender/samplers/__init__.py
@@ -0,0 +1,3 @@
+from poprox_recommender.samplers.uniform import UniformSampler
+
+__all__ = ["UniformSampler"]
diff --git a/src/poprox_recommender/samplers/uniform.py b/src/poprox_recommender/samplers/uniform.py
new file mode 100644
index 00000000..ea0fc8fe
--- /dev/null
+++ b/src/poprox_recommender/samplers/uniform.py
@@ -0,0 +1,26 @@
+import random
+
+from poprox_concepts import ArticleSet
+
+
+class UniformSampler:
+    def __init__(self, num_slots):
+        self.num_slots = num_slots
+
+    def __call__(self, candidate: ArticleSet, backup: ArticleSet | None = None):
+        backup_articles = list(filter(lambda article: article not in candidate.articles, backup.articles))
+
+        num_backups = (
+            self.num_slots - len(candidate.articles)
+            if len(candidate.articles) + len(backup_articles) >= self.num_slots
+            else len(backup_articles)
+        )
+
+        if len(candidate.articles) < self.num_slots and backup_articles:
+            sampled = random.sample(candidate.articles, len(candidate.articles)) + random.sample(
+                backup_articles, num_backups
+            )
+        else:
+            sampled = random.sample(candidate.articles, self.num_slots)
+
+        return ArticleSet(articles=sampled)
diff --git a/src/poprox_recommender/scorers/article.py b/src/poprox_recommender/scorers/article.py
index 2c7e818f..46bef6b6 100644
--- a/src/poprox_recommender/scorers/article.py
+++ b/src/poprox_recommender/scorers/article.py
@@ -10,4 +10,6 @@ def __call__(self, candidate_articles: ArticleSet, interest_profile: InterestPro
         user_embedding = interest_profile.embedding
 
         pred = self.model.get_prediction(candidate_embeddings, user_embedding.squeeze())
-        return candidate_articles.model_copy(update={"scores": pred.cpu().detach().numpy()})
+        candidate_articles.scores = pred.cpu().detach().numpy()
+
+        return candidate_articles
diff --git a/tests/basic-request.json b/tests/basic-request.json
index fb141ec7..0ca86f5a 100644
--- a/tests/basic-request.json
+++ b/tests/basic-request.json
@@ -18,7 +18,6 @@
   "interest_profile": {
     "profile_id": "28838f05-23f5-4f23-bea2-30b51f67c538",
     "click_history": {
-      "account_id": "977a3c88-937a-46fb-bbfe-94dc5dcb68c8",
       "article_ids": [
         "e7605f12-a37a-4326-bf3c-3f9b72d0738d"
       ]
diff --git a/tests/test_pfar.py b/tests/test_pfar.py
index 310409b2..6100dcd4 100644
--- a/tests/test_pfar.py
+++ b/tests/test_pfar.py
@@ -3,9 +3,12 @@
 import random
 from pathlib import Path
 
-from poprox_concepts import Article, ClickHistory, InterestProfile
-from poprox_recommender.default import select_articles, user_topic_preference
-from poprox_recommender.topics import GENERAL_TOPICS, extract_general_topics, match_news_topics_to_general
+from poprox_concepts import Article, ArticleSet, ClickHistory
+from poprox_recommender.topics import (
+    GENERAL_TOPICS,
+    extract_general_topics,
+    match_news_topics_to_general,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -15,20 +18,17 @@ def load_test_articles():
     with open(event_path, "r") as j:
         req_body = json.loads(j.read())
 
-    todays_articles = [Article.model_validate(attrs) for attrs in req_body["todays_articles"]]
-
-    past_articles = [Article.model_validate(attrs) for attrs in req_body["past_articles"]]
-
+    candidate = ArticleSet(articles=[Article.model_validate(attrs) for attrs in req_body["todays_articles"]])
+    past = ArticleSet(articles=[Article.model_validate(attrs) for attrs in req_body["past_articles"]])
     click_history = [ClickHistory.model_validate(attrs) for attrs in req_body["click_data"]]
-
     num_recs = req_body["num_recs"]
 
-    return todays_articles, past_articles, click_history, num_recs
+    return candidate, past, click_history, num_recs
 
 
 def test_topic_classification():
-    todays_articles, _, _, _ = load_test_articles()
-    topic_matched_dict, todays_article_matched_topics = match_news_topics_to_general(todays_articles)
+    candidate, _, _, _ = load_test_articles()
+    topic_matched_dict, todays_article_matched_topics = match_news_topics_to_general(candidate.articles)
 
     assert len(todays_article_matched_topics.keys()) > 0
     random_10_topic = random.sample(list(topic_matched_dict.keys()), 10)
@@ -37,23 +37,8 @@
 
 
 def test_extract_generalized_topic():
-    todays_articles, _, _, _ = load_test_articles()
-    for article in todays_articles:
+    candidate, _, _, _ = load_test_articles()
+    for article in candidate.articles:
         generalized_topics = extract_general_topics(article)
         for topic in generalized_topics:
             assert topic in GENERAL_TOPICS
-
-
-def test_user_topic_pref():
-    todays_articles, past_articles, click_data, num_recs = load_test_articles()
-    interest_profile = InterestProfile.model_validate(
-        {
-            "click_history": click_data[0],
-            "onboarding_topics": [],
-        }
-    )
-
-    user_preference_dict = user_topic_preference(past_articles, click_data[0])
-    algo_params = {"user_topic_preference": user_preference_dict}
-    recommendations = select_articles(todays_articles, past_articles, interest_profile, num_recs, algo_params)
-    assert len(recommendations) > 0
diff --git a/tests/test_select.py b/tests/test_select.py
index 5222e814..a143e6d9 100644
--- a/tests/test_select.py
+++ b/tests/test_select.py
@@ -1,8 +1,10 @@
 from uuid import uuid4
 
-from poprox_concepts import Article, ClickHistory, Entity, Mention
+from poprox_concepts import Article, ArticleSet, ClickHistory, Entity, Mention
 from poprox_concepts.domain.profile import AccountInterest, InterestProfile
-from poprox_recommender.default import select_by_topic
+from poprox_recommender.filters import TopicFilter
+from poprox_recommender.pipeline import RecommendationPipeline
+from poprox_recommender.samplers import UniformSampler
 
 
 def test_select_by_topic_filters_articles():
@@ -37,17 +39,31 @@ def test_select_by_topic_filters_articles():
         ),
     ]
 
+    topic_filter = TopicFilter()
+    sampler = UniformSampler(num_slots=2)
+
+    pipeline = RecommendationPipeline(name="random_topical")
+    pipeline.add(topic_filter, inputs=["candidate", "profile"], output="topical")
+    pipeline.add(sampler, inputs=["topical", "candidate"], output="recs")
+
     # If we can, only select articles matching interests
-    recs = select_by_topic(articles, profile, num_slots=2)
+    inputs = {
+        "candidate": ArticleSet(articles=articles),
+        "clicked": ArticleSet(articles=[]),
+        "profile": profile,
+    }
+    recs = pipeline(inputs)
 
-    for article in recs:
+    for article in recs.articles:
         topics = [mention.entity.name for mention in article.mentions]
         assert "U.S. News" in topics or "Politics" in topics
 
     # If we need to, fill out the end of the list with other random articles
-    recs = select_by_topic(articles, profile, num_slots=3)
-    assert len(recs) == 3
+    sampler.num_slots = 3
+    recs = pipeline(inputs)
+
+    assert len(recs.articles) == 3
 
-    for article in recs[:2]:
+    for article in recs.articles[:2]:
         topics = [mention.entity.name for mention in article.mentions]
         assert "U.S. News" in topics or "Politics" in topics
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index 12fcbcbd..b806d887 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -5,6 +5,7 @@
 import logging
 from pathlib import Path
 
+from poprox_concepts import ArticleSet, ClickHistory
 from poprox_concepts.api.recommendations import RecommendationRequest
 from poprox_recommender.default import select_articles
 
@@ -18,10 +19,29 @@ def test_direct_basic_request():
 
     logger.info("generating recommendations")
     recs = select_articles(
-        req.todays_articles,
-        req.past_articles,
+        ArticleSet(articles=req.todays_articles),
+        ArticleSet(articles=req.past_articles),
         req.interest_profile,
         req.num_recs,
     )
     # do we get recommendations?
-    assert len(recs) > 0
+    assert len(recs.articles) > 0
+
+
+def test_direct_basic_request_without_clicks():
+    test_dir = Path(__file__)
+    req_f = test_dir.parent / "basic-request.json"
+    req = RecommendationRequest.model_validate_json(req_f.read_text())
+
+    logger.info("generating recommendations")
+
+    profile = req.interest_profile
+    profile.click_history = ClickHistory(article_ids=[])
+    recs = select_articles(
+        ArticleSet(articles=req.todays_articles),
+        ArticleSet(articles=req.past_articles),
+        req.interest_profile,
+        req.num_recs,
+    )
+    # do we get recommendations?
+    assert len(recs.articles) > 0
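
Usage note: the patch replaces the monolithic select_articles flow with a RecommendationPipeline that passes named ArticleSet and InterestProfile values between components and returns whatever ends up in the last component's output slot. The sketch below shows how the non-personalized branch is wired and invoked; it is assembled from the patch's own code and tests, and the fixture path ("tests/basic-request.json", resolved from the repo root) is an assumption for illustration, not part of the patch.

    from pathlib import Path

    from poprox_concepts import ArticleSet
    from poprox_concepts.api.recommendations import RecommendationRequest
    from poprox_recommender.filters import TopicFilter
    from poprox_recommender.pipeline import RecommendationPipeline
    from poprox_recommender.samplers import UniformSampler

    # Load the same request fixture the tests use (path assumed relative to the repo root)
    req = RecommendationRequest.model_validate_json(Path("tests/basic-request.json").read_text())

    # Each component reads its inputs from named slots in the pipeline state and writes
    # its output back to a slot; the last component's output slot ("recs") is what the
    # pipeline returns
    pipeline = RecommendationPipeline(name="random_topical")
    pipeline.add(TopicFilter(), inputs=["candidate", "profile"], output="topical")
    pipeline.add(UniformSampler(num_slots=req.num_recs), inputs=["topical", "candidate"], output="recs")

    recs = pipeline(
        {
            "candidate": ArticleSet(articles=req.todays_articles),
            "clicked": ArticleSet(articles=[]),
            "profile": req.interest_profile,
        }
    )
    print(f"{len(recs.articles)} recommended articles")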
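
The pipeline asks very little of a component: it must be a callable whose positional arguments are the values stored in its input slots, and it must return an ArticleSet or InterestProfile (run_component rejects anything else). A hypothetical component written against that contract, purely for illustration and not part of the patch:

    from poprox_concepts import ArticleSet


    class TakeFirstN:
        # Hypothetical example component: keep only the first n candidate articles
        def __init__(self, n: int):
            self.n = n

        def __call__(self, candidate: ArticleSet) -> ArticleSet:
            return ArticleSet(articles=candidate.articles[: self.n])

Registered with pipeline.add(TakeFirstN(10), inputs=["candidate"], output="candidate"), it would overwrite the "candidate" slot before any later components run.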