-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create and apply a `RecommendationPipeline` class (#56)
This introduces a simple-ish `RecommendationPipeline` that works with the components created in #43 and #53. Conceptually, the components run in the order they're added to the pipeline and modify shared state that each component can select inputs from and write outputs to. It could be generalized (e.g. to support types other than `ArticleSet` and `InterestProfile`), but this is enough to support the recommenders we're currently building/using.
- Loading branch information
1 parent
1eaa423
commit 121a1c9
Showing
15 changed files
with
227 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,91 +1,55 @@ | ||
import random | ||
from typing import Any | ||
from uuid import UUID | ||
|
||
from poprox_concepts import Article, ArticleSet, InterestProfile | ||
from poprox_concepts import ArticleSet, InterestProfile | ||
from poprox_recommender.diversifiers import MMRDiversifier, PFARDiversifier | ||
from poprox_recommender.embedders import ArticleEmbedder, UserEmbedder | ||
from poprox_recommender.filters import TopicFilter | ||
from poprox_recommender.model import get_model | ||
from poprox_recommender.pipeline import RecommendationPipeline | ||
from poprox_recommender.samplers import UniformSampler | ||
from poprox_recommender.scorers import ArticleScorer | ||
from poprox_recommender.topics import user_topic_preference | ||
|
||
|
||
def select_articles(
    candidate_articles: ArticleSet,
    clicked_articles: ArticleSet,
    interest_profile: InterestProfile,
    num_slots: int,
    algo_params: dict[str, Any] | None = None,
) -> ArticleSet:
    """Build a recommendation pipeline and run it over the given inputs.

    When a model is available and the profile has click history, the pipeline
    embeds the articles and the user, scores the candidates, and diversifies
    the result. Otherwise it falls back to topic filtering plus uniform
    sampling.

    Args:
        candidate_articles: Articles that may be recommended.
        clicked_articles: Articles passed to the pipeline as the "clicked" input.
        interest_profile: Profile passed to the pipeline as the "profile" input.
        num_slots: Number of slots handed to the diversifier or the sampler.
        algo_params: Optional parameters. The "diversity_algo" key selects the
            diversifier ("mmr" or "pfar") and defaults to "pfar".

    Returns:
        The ArticleSet produced by the final pipeline component.

    Raises:
        ValueError: If the model-based path is taken and "diversity_algo" is
            neither "mmr" nor "pfar".
    """
    algo_params = algo_params or {}
    diversify = str(algo_params.get("diversity_algo", "pfar"))

    model = get_model()

    if model and interest_profile.click_history.article_ids:
        diversifier_classes = {"mmr": MMRDiversifier, "pfar": PFARDiversifier}
        # Fail with a clear message instead of an UnboundLocalError further
        # down when the requested algorithm is not one we know about.
        if diversify not in diversifier_classes:
            supported = ", ".join(sorted(diversifier_classes))
            msg = f"Unsupported diversity_algo '{diversify}'; expected one of: {supported}"
            raise ValueError(msg)

        article_embedder = ArticleEmbedder(model.model, model.tokenizer, model.device)
        user_embedder = UserEmbedder(model.model, model.device)
        article_scorer = ArticleScorer(model.model)
        diversifier = diversifier_classes[diversify](algo_params, num_slots)

        pipeline = RecommendationPipeline(name=diversify)
        pipeline.add(article_embedder, inputs=["candidate"], output="candidate")
        pipeline.add(article_embedder, inputs=["clicked"], output="clicked")
        pipeline.add(user_embedder, inputs=["clicked", "profile"], output="profile")
        pipeline.add(article_scorer, inputs=["candidate", "profile"], output="candidate")
        pipeline.add(diversifier, inputs=["candidate", "profile"], output="recs")
    else:
        topic_filter = TopicFilter()
        sampler = UniformSampler(num_slots=num_slots)

        pipeline = RecommendationPipeline(name="random_topical")
        pipeline.add(topic_filter, inputs=["candidate", "profile"], output="topical")
        # The full candidate set is the sampler's backup pool when too few
        # articles match the user's topics.
        pipeline.add(sampler, inputs=["topical", "candidate"], output="recs")

    return pipeline(
        {
            "candidate": candidate_articles,
            "clicked": clicked_articles,
            "profile": interest_profile,
        }
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from poprox_recommender.filters.topic import TopicFilter | ||
|
||
__all__ = ["TopicFilter"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from poprox_concepts import ArticleSet, InterestProfile | ||
|
||
|
||
class TopicFilter:
    """Pipeline component that keeps articles matching the user's onboarding topics."""

    def __call__(self, candidate: ArticleSet, interest_profile: InterestProfile) -> ArticleSet:
        """Return the candidate articles that mention at least one topic of interest.

        Args:
            candidate: Articles to filter.
            interest_profile: Profile whose ``onboarding_topics`` define the
                topics of interest.

        Returns:
            A new ArticleSet holding the matching articles in their original order.
        """
        # Onboarding preferences are 1-indexed and 1 means "absolutely no interest",
        # so only values above 1 count as interest. Normalizing to 0-indexed upstream
        # might be cleaner, but until then this is a simple way to keep unwanted
        # topics out of people's early newsletters.
        liked_topics = set()
        for interest in interest_profile.onboarding_topics:
            if interest.preference > 1:
                liked_topics.add(interest.entity_name)

        def mentions_liked_topic(article) -> bool:
            mentioned = {mention.entity.name for mention in article.mentions}
            return not liked_topics.isdisjoint(mentioned)

        matching = [article for article in candidate.articles if mentions_liked_topic(article)]
        return ArticleSet(articles=matching)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from copy import deepcopy | ||
from dataclasses import dataclass | ||
from typing import Callable | ||
|
||
from poprox_concepts import ArticleSet, InterestProfile | ||
|
||
|
||
@dataclass
class ComponentSpec:
    """One pipeline step: a callable plus the state names it reads from and writes to."""

    # Callable invoked with the selected inputs as positional arguments
    component: Callable
    # Names of the state entries passed to the component, in argument order
    inputs: list[str]
    # Name of the state entry that receives the component's return value
    output: str
|
||
|
||
class RecommendationPipeline:
    """Run a named sequence of components over a shared, name-keyed state.

    Components run in the order they were added. Each one reads its arguments
    from the state by name and writes its return value back under its output
    name, so later components can consume (or overwrite) earlier results.
    """

    def __init__(self, name):
        """Create an empty pipeline identified by ``name``."""
        self.name = name
        self.components = []

    def add(self, component: Callable, inputs: list[str], output: str):
        """Append a component to the end of the pipeline.

        Args:
            component: Callable invoked with the selected inputs as positional
                arguments.
            inputs: Names of the state entries to pass, in argument order.
            output: Name of the state entry that stores the return value.
        """
        self.components.append(ComponentSpec(component, inputs, output))

    def __call__(self, inputs: "dict[str, ArticleSet | InterestProfile]") -> "ArticleSet":
        """Run every component and return the output of the last one.

        Args:
            inputs: Initial state, keyed by the names components refer to.

        Returns:
            The ArticleSet stored under the last component's output name.

        Raises:
            ValueError: If no components have been added.
            KeyError: If a component asks for a name that is not in the state.
            TypeError: If a component returns something other than ArticleSet
                or InterestProfile, or the final output is not an ArticleSet.
        """
        # Without components there is no "last output" to return; say so
        # explicitly rather than failing on the index lookup below.
        if not self.components:
            msg = f"Pipeline '{self.name}' has no components, so it cannot produce recommendations"
            raise ValueError(msg)

        # Avoid modifying the inputs
        state = deepcopy(inputs)

        # Run each component in the order it was added
        for component_spec in self.components:
            state = self.run_component(component_spec, state)

        recs = state[self.components[-1].output]

        # Double check that we're returning the right type for recs
        if not isinstance(recs, ArticleSet):
            msg = f"The final pipeline component must return ArticleSet, but received {type(recs)}"
            raise TypeError(msg)

        return recs

    def run_component(self, component_spec: "ComponentSpec", state: "dict[str, ArticleSet | InterestProfile]"):
        """Run one component against ``state`` and store its output there.

        Args:
            component_spec: The component together with its input and output names.
            state: Current pipeline state; updated in place and returned.

        Returns:
            The same ``state`` dict with the component's output added.

        Raises:
            KeyError: If any requested input name is missing from ``state``.
            TypeError: If the component returns neither ArticleSet nor
                InterestProfile.
        """
        missing = [name for name in component_spec.inputs if name not in state]
        if missing:
            # Same exception type as a plain dict lookup, but with enough
            # context to tell which component is misconfigured.
            msg = (
                f"Pipeline '{self.name}' cannot run {type(component_spec.component)}: "
                f"missing inputs {missing}, available names are {sorted(state)}"
            )
            raise KeyError(msg)

        arguments = [state[input_name] for input_name in component_spec.inputs]
        output = component_spec.component(*arguments)

        if not isinstance(output, (ArticleSet, InterestProfile)):
            msg = (
                f"Pipeline components must return ArticleSet or InterestProfile, "
                f"but received {type(output)} from {type(component_spec.component)}"
            )
            raise TypeError(msg)

        state[component_spec.output] = output

        return state
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from poprox_recommender.samplers.uniform import UniformSampler | ||
|
||
__all__ = ["UniformSampler"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import random | ||
|
||
from poprox_concepts import ArticleSet | ||
|
||
|
||
class UniformSampler:
    """Pipeline component that fills a fixed number of slots by uniform random sampling."""

    def __init__(self, num_slots):
        """Remember how many articles each call should try to return."""
        self.num_slots = num_slots

    def __call__(self, candidate: ArticleSet, backup: ArticleSet | None = None):
        """Sample up to ``num_slots`` articles, preferring ``candidate`` over ``backup``.

        If there are at least ``num_slots`` candidates, the result is a random
        sample of them. Otherwise every candidate is returned in random order,
        followed by randomly chosen backup articles that are not already
        candidates. The result is shorter than ``num_slots`` when both sources
        together cannot fill it.

        Args:
            candidate: Preferred articles to sample from.
            backup: Optional extra articles used to fill remaining slots.

        Returns:
            An ArticleSet with the sampled articles.
        """
        articles = candidate.articles

        # ``backup`` is optional, so treat a missing one as an empty pool.
        # Membership is tested against the list so that articles only need to
        # support equality, not hashing.
        if backup is None:
            backup_articles = []
        else:
            backup_articles = [article for article in backup.articles if article not in articles]

        if len(articles) >= self.num_slots:
            sampled = random.sample(articles, self.num_slots)
        else:
            # Never request more backups than exist; random.sample raises
            # ValueError when the sample is larger than the population.
            num_backups = min(self.num_slots - len(articles), len(backup_articles))
            sampled = random.sample(articles, len(articles)) + random.sample(backup_articles, num_backups)

        return ArticleSet(articles=sampled)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.