Skip to content

Commit

Permalink
Create and apply a RecommendationPipeline class (#56)
Browse files Browse the repository at this point in the history
This introduces a simple-ish `RecommendationPipeline` that works with
the components created in #43 and #53. Conceptually, the components run
in the order they're added to the pipeline and modify shared state that
each component can select inputs from and write outputs to. It could be
generalized (e.g. to support types other than `ArticleSet` and
`InterestProfile`), but this is enough to support the recommenders we're
currently building/using.
  • Loading branch information
karlhigley authored Jul 8, 2024
1 parent 1eaa423 commit 121a1c9
Show file tree
Hide file tree
Showing 15 changed files with 227 additions and 114 deletions.
92 changes: 28 additions & 64 deletions src/poprox_recommender/default.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,55 @@
import random
from typing import Any
from uuid import UUID

from poprox_concepts import Article, ArticleSet, InterestProfile
from poprox_concepts import ArticleSet, InterestProfile
from poprox_recommender.diversifiers import MMRDiversifier, PFARDiversifier
from poprox_recommender.embedders import ArticleEmbedder, UserEmbedder
from poprox_recommender.filters import TopicFilter
from poprox_recommender.model import get_model
from poprox_recommender.pipeline import RecommendationPipeline
from poprox_recommender.samplers import UniformSampler
from poprox_recommender.scorers import ArticleScorer
from poprox_recommender.topics import user_topic_preference


def select_articles(
candidate_articles: list[Article],
past_articles: list[Article],
candidate_articles: ArticleSet,
clicked_articles: ArticleSet,
interest_profile: InterestProfile,
num_slots: int,
algo_params: dict[str, Any] | None = None,
) -> dict[UUID, list[Article]]:
candidate_articles = ArticleSet(articles=candidate_articles)
past_articles = ArticleSet(articles=past_articles)

click_history = interest_profile.click_history
clicked_articles = list(filter(lambda a: a.article_id in set(click_history.article_ids), past_articles.articles))
clicked_articles = ArticleSet(articles=clicked_articles)

# This could be a component but should likely be moved upstream to the platform
interest_profile.click_topic_counts = user_topic_preference(past_articles.articles, interest_profile.click_history)
profile_id = interest_profile.profile_id

) -> ArticleSet:
algo_params = algo_params or {}
diversify = str(algo_params.get("diversity_algo", "pfar"))

model = get_model()
recommendations = {}

# The following code should ONLY access the InterestProfile and ArticleSets defined above
if model and click_history.article_ids:
if model and interest_profile.click_history.article_ids:
article_embedder = ArticleEmbedder(model.model, model.tokenizer, model.device)
user_embedder = UserEmbedder(model.model, model.device)
article_scorer = ArticleScorer(model.model)

candidate_articles = article_embedder(candidate_articles)
clicked_articles = article_embedder(clicked_articles)

interest_profile = user_embedder(clicked_articles, interest_profile)
candidate_articles = article_scorer(candidate_articles, interest_profile)

if diversify == "mmr":
diversifier = MMRDiversifier(algo_params, num_slots)
recs = diversifier(candidate_articles)

elif diversify == "pfar":
diversifier = PFARDiversifier(algo_params, num_slots)
recs = diversifier(candidate_articles, interest_profile)

recommendations[profile_id] = recs.articles
else:
recommendations[profile_id] = select_by_topic(
candidate_articles.articles,
interest_profile,
num_slots,
)

return recommendations


def select_by_topic(todays_articles: list[Article], interest_profile: InterestProfile, num_slots: int):
# Preference values from onboarding are 1-indexed, where 1 means "absolutely no interest."
# We might want to normalize them to 0-indexed somewhere upstream, but in the mean time
# this is one of the simpler ways to filter out topics people aren't interested in from
# their early newsletters
profile_topics = {
interest.entity_name for interest in interest_profile.onboarding_topics if interest.preference > 1
}

other_articles = []
topical_articles = []
for article in todays_articles:
article_topics = {mention.entity.name for mention in article.mentions}
if len(profile_topics.intersection(article_topics)) > 0:
topical_articles.append(article)
else:
other_articles.append(article)

if len(topical_articles) >= num_slots:
return random.sample(topical_articles, num_slots)
pipeline = RecommendationPipeline(name=diversify)
pipeline.add(article_embedder, inputs=["candidate"], output="candidate")
pipeline.add(article_embedder, inputs=["clicked"], output="clicked")
pipeline.add(user_embedder, inputs=["clicked", "profile"], output="profile")
pipeline.add(article_scorer, inputs=["candidate", "profile"], output="candidate")
pipeline.add(diversifier, inputs=["candidate", "profile"], output="recs")
else:
return random.sample(topical_articles, len(topical_articles)) + random.sample(
other_articles, num_slots - len(topical_articles)
)
topic_filter = TopicFilter()
sampler = UniformSampler(num_slots=num_slots)

pipeline = RecommendationPipeline(name="random_topical")
pipeline.add(topic_filter, inputs=["candidate", "profile"], output="topical")
pipeline.add(sampler, inputs=["topical", "candidate"], output="recs")

return pipeline(
{
"candidate": candidate_articles,
"clicked": clicked_articles,
"profile": interest_profile,
}
)
5 changes: 3 additions & 2 deletions src/poprox_recommender/diversifiers/pfar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def __call__(self, candidate_articles: ArticleSet, interest_profile: InterestPro
for interest in interest_profile.onboarding_topics:
topic_preferences[interest.entity_name] = max(interest.preference - 1, 0)

for topic, click_count in interest_profile.click_topic_counts.items():
topic_preferences[topic] = click_count
if interest_profile.click_topic_counts:
for topic, click_count in interest_profile.click_topic_counts.items():
topic_preferences[topic] = click_count

normalized_topic_prefs = normalized_topic_count(topic_preferences)

Expand Down
4 changes: 2 additions & 2 deletions src/poprox_recommender/embedders/article.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ def __call__(self, article_set: ArticleSet) -> ArticleSet:
if len(title_tensor.shape) == 1:
title_tensor = title_tensor.unsqueeze(dim=0)

article_embeddings = self.model.get_news_vector(title_tensor)
article_set.embeddings = self.model.get_news_vector(title_tensor)

return article_set.model_copy(update={"embeddings": article_embeddings})
return article_set
4 changes: 2 additions & 2 deletions src/poprox_recommender/embedders/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ def __call__(self, clicked_articles: ArticleSet, interest_profile: InterestProfi

embedding_lookup["PADDED_NEWS"] = th.zeros(list(embedding_lookup.values())[0].size(), device=self.device)

user_embedding = build_user_embedding(
interest_profile.embedding = build_user_embedding(
interest_profile.click_history,
embedding_lookup,
self.model,
self.device,
self.max_clicks,
)

return interest_profile.model_copy(update={"embedding": user_embedding})
return interest_profile


# Compute a vector for each user
Expand Down
3 changes: 3 additions & 0 deletions src/poprox_recommender/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from poprox_recommender.filters.topic import TopicFilter

__all__ = ["TopicFilter"]
20 changes: 20 additions & 0 deletions src/poprox_recommender/filters/topic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from poprox_concepts import ArticleSet, InterestProfile


class TopicFilter:
def __call__(self, candidate: ArticleSet, interest_profile: InterestProfile) -> ArticleSet:
# Preference values from onboarding are 1-indexed, where 1 means "absolutely no interest."
# We might want to normalize them to 0-indexed somewhere upstream, but in the mean time
# this is one of the simpler ways to filter out topics people aren't interested in from
# their early newsletters
profile_topics = {
interest.entity_name for interest in interest_profile.onboarding_topics if interest.preference > 1
}

topical_articles = []
for article in candidate.articles:
article_topics = {mention.entity.name for mention in article.mentions}
if len(profile_topics.intersection(article_topics)) > 0:
topical_articles.append(article)

return ArticleSet(articles=topical_articles)
26 changes: 22 additions & 4 deletions src/poprox_recommender/handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import base64
import logging

from poprox_concepts import ArticleSet
from poprox_concepts.api.recommendations import (
RecommendationRequest,
RecommendationResponse,
)
from poprox_recommender.default import select_articles
from poprox_recommender.topics import user_topic_preference

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
Expand All @@ -29,16 +31,32 @@ def generate_recs(event, context):
logger.info("Using default parameters")

logger.info("Selecting articles...")

# The platform should send an ArticleSet but we'll do it here for now
candidate_articles = ArticleSet(articles=req.todays_articles)

# Similarly, the platform should provided pre-filtered clicked articles
# and compute the topic counts but this shim lets us ignore that issue
# in the actual article selection
profile = req.interest_profile
click_history = profile.click_history
clicked_articles = list(filter(lambda a: a.article_id in set(click_history.article_ids), req.past_articles))
clicked_articles = ArticleSet(articles=clicked_articles)

profile.click_topic_counts = user_topic_preference(req.past_articles, profile.click_history)

recommendations = select_articles(
req.todays_articles,
req.past_articles,
req.interest_profile,
candidate_articles,
clicked_articles,
profile,
req.num_recs,
algo_params,
)

logger.info("Constructing response...")
resp_body = RecommendationResponse.model_validate({"recommendations": recommendations})
resp_body = RecommendationResponse.model_validate(
{"recommendations": {profile.profile_id: recommendations.articles}}
)

logger.info("Serializing response...")
response = {"statusCode": 200, "body": resp_body.model_dump_json()}
Expand Down
56 changes: 56 additions & 0 deletions src/poprox_recommender/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable

from poprox_concepts import ArticleSet, InterestProfile


@dataclass
class ComponentSpec:
component: Callable
inputs: list[str]
output: str


class RecommendationPipeline:
def __init__(self, name):
self.name = name
self.components = []

def add(self, component: Callable, inputs: list[str], output: str):
self.components.append(ComponentSpec(component, inputs, output))

def __call__(self, inputs: dict[str, ArticleSet | InterestProfile]) -> ArticleSet:
# Avoid modifying the inputs
state = deepcopy(inputs)

# Run each component in the order it was added
for component_spec in self.components:
state = self.run_component(component_spec, state)

recs = state[self.components[-1].output]

# Double check that we're returning the right type for recs
if not isinstance(recs, ArticleSet):
msg = f"The final pipeline component must return ArticleSet, but received {type(recs)}"
raise TypeError(msg)

return recs

def run_component(self, component_spec: ComponentSpec, state: dict[str, ArticleSet | InterestProfile]):
arguments = []
for input_name in component_spec.inputs:
arguments.append(state[input_name])

output = component_spec.component(*arguments)

if not isinstance(output, (ArticleSet, InterestProfile)):
msg = (
f"Pipeline components must return ArticleSet or InterestProfile, "
f"but received {type(output)} from {type(component_spec.component)}"
)
raise TypeError(msg)

state[component_spec.output] = output

return state
3 changes: 3 additions & 0 deletions src/poprox_recommender/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from poprox_recommender.samplers.uniform import UniformSampler

__all__ = ["UniformSampler"]
26 changes: 26 additions & 0 deletions src/poprox_recommender/samplers/uniform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import random

from poprox_concepts import ArticleSet


class UniformSampler:
def __init__(self, num_slots):
self.num_slots = num_slots

def __call__(self, candidate: ArticleSet, backup: ArticleSet | None = None):
backup_articles = list(filter(lambda article: article not in candidate.articles, backup.articles))

num_backups = (
self.num_slots - len(candidate.articles)
if len(candidate.articles) + len(backup_articles) >= self.num_slots
else len(backup_articles)
)

if len(candidate.articles) < self.num_slots and backup_articles:
sampled = random.sample(candidate.articles, len(candidate.articles)) + random.sample(
backup_articles, num_backups
)
else:
sampled = random.sample(candidate.articles, self.num_slots)

return ArticleSet(articles=sampled)
4 changes: 3 additions & 1 deletion src/poprox_recommender/scorers/article.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ def __call__(self, candidate_articles: ArticleSet, interest_profile: InterestPro
user_embedding = interest_profile.embedding

pred = self.model.get_prediction(candidate_embeddings, user_embedding.squeeze())
return candidate_articles.model_copy(update={"scores": pred.cpu().detach().numpy()})
candidate_articles.scores = pred.cpu().detach().numpy()

return candidate_articles
1 change: 0 additions & 1 deletion tests/basic-request.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"interest_profile": {
"profile_id": "28838f05-23f5-4f23-bea2-30b51f67c538",
"click_history": {
"account_id": "977a3c88-937a-46fb-bbfe-94dc5dcb68c8",
"article_ids": [
"e7605f12-a37a-4326-bf3c-3f9b72d0738d"
]
Expand Down
Loading

0 comments on commit 121a1c9

Please sign in to comment.