Use tf.while_loop to sample slots in PenalizedPlackettLuce.
This is more efficient than using a for loop, especially when the number of slots is large.

PiperOrigin-RevId: 650275537
Change-Id: I6c8f4d5975020da9aff9c98d6405da84920eecd5
TF-Agents Team authored and copybara-github committed Jul 8, 2024
1 parent c846013 commit c0ef082
Showing 2 changed files with 117 additions and 33 deletions.
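
For context before the diff: a minimal, self-contained sketch (not part of this commit) contrasting the old Python for-loop sampling with the tf.while_loop pattern this change adopts. A Python loop unrolls into num_slots copies of the sampling ops when traced into a graph, whereas tf.while_loop traces the body once and iterates it. Shapes and names here are illustrative, mirroring but not reproducing the diff below.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch_size, num_items, num_slots = 4, 10, 3
logits = tf.random.normal([batch_size, num_items])


def sample_with_for_loop(logits):
  # Unrolled: each iteration adds a fresh copy of these ops to the graph.
  slots = []
  for _ in range(num_slots):
    items = tfd.Categorical(logits=logits).sample()
    slots.append(items)
    # Mask out sampled items so they cannot be drawn again.
    logits -= tf.one_hot(items, num_items, on_value=np.inf)
  return tf.stack(slots, axis=-1)


def sample_with_while_loop(logits):
  slots = tf.zeros([batch_size, num_slots], dtype=tf.int32)

  def body(slot_idx, slots, logits):
    items = tfd.Categorical(logits=logits).sample()
    # Write `items` into column `slot_idx` of `slots` via a one-hot mask.
    update = tf.expand_dims(items, -1) * tf.expand_dims(
        tf.one_hot(slot_idx, num_slots, dtype=tf.int32), 0)
    logits -= tf.one_hot(items, num_items, on_value=np.inf)
    return slot_idx + 1, slots + update, logits

  # Traced once: the same body ops execute num_slots times.
  _, slots, _ = tf.while_loop(
      cond=lambda slot_idx, slots, logits: True,
      body=body,
      loop_vars=(tf.constant(0), slots, logits),
      maximum_iterations=num_slots,
  )
  return slots

Both variants return a [batch_size, num_slots] tensor of distinct item indices; only the traced graph size differs.
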
143 changes: 110 additions & 33 deletions tf_agents/bandits/policies/ranking_policy.py
@@ -14,7 +14,7 @@
# limitations under the License.

"""Ranking policy."""
from typing import Optional, Sequence, Text
from typing import Optional, Text

import numpy as np
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
@@ -62,32 +62,85 @@ def __init__(
def _penalizer_fn(
self,
logits: types.Float,
features: types.Float,
slots: Sequence[types.Int],
slots: tf.Tensor,
num_slotted: tf.Tensor,
):
"""Downscores items by their similarity to already selected items.
Args:
logits: The current logits of all items.
features: the feature vectors of the items.
slots: list of indices of already selected items.
logits: The current logits of all items, shaped as [batch_size,
num_items].
slots: A tensor of indices of the selected items, shaped as [batch_size,
num_slots]. Only the first `num_slotted` columns correspond to valid
indices.
num_slotted: The number of slots filled so far.
Returns:
New logits.
"""
raise NotImplementedError()

def _sample_n(self, n, seed=None):
# TODO(b/251139151): Support n > 1.
del n
# The scores (logits) of all items, shaped as [batch_size, num_items].
logits = tf.convert_to_tensor(self.scores)
sample_shape = tf.concat([[n], tf.shape(logits)], axis=0)
slots = []
for _ in range(self._num_slots):
items = tfd.Categorical(logits=logits).sample()
slots.append(items)
logits -= tf.one_hot(items, sample_shape[-1], on_value=np.inf)
logits = self._penalizer_fn(logits, self._features, slots)
sample = tf.expand_dims(tf.stack(slots, axis=-1), axis=0)
return sample
# The index of the next slot to sample.
slot_idx = tf.constant(0, dtype=tf.int32)
# The indices of the items that have been sampled, shaped as
# [batch_size, num_slots].
slots = tf.zeros(
shape=(tf.shape(logits)[0], self._num_slots), dtype=tf.int32
)

def _sample_next_slot(slot_idx, slots, logits):
# Samples the batch of item indices for the next slot.
items = tf.ensure_shape(
tfd.Categorical(logits=logits).sample(),
(self._features.shape[0],),
name='ensure_shape_items',
)
slots = tf.ensure_shape(
slots,
(self._features.shape[0], self._num_slots),
name='ensure_shape_slots',
)
# Updates the indices of sampled items by incorporating the sampled item
# indices for the next slot.
slots = tf.ensure_shape(
slots
+ tf.expand_dims(items, axis=-1)
* tf.expand_dims(
tf.one_hot(
slot_idx,
self._num_slots,
dtype=tf.int32,
name='one_hot_for_slot_idx',
),
0,
),
(self._features.shape[0], self._num_slots),
name='ensure_shape_slots_after_update',
)
# Discounts the scores (logits) of the items that have been sampled, so
# they will not be selected again.
logits -= tf.one_hot(
items, logits.shape[-1], on_value=np.inf, name='one_hot_for_items'
)
# Applies the penalty function to the logits.
logits = tf.ensure_shape(
self._penalizer_fn(logits, slots, num_slotted=slot_idx + 1),
self.scores.shape,
)
return slot_idx + 1, slots, logits

_, slots, _ = tf.while_loop(
cond=lambda slot_idx, slots, logits: True,
body=_sample_next_slot,
loop_vars=(slot_idx, slots, logits),
maximum_iterations=self._num_slots,
)
return tf.expand_dims(slots, axis=0)

def _event_shape(self, scores=None):
return self._num_slots
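
To make the new _penalizer_fn contract above concrete, here is a toy subclass (an illustration, not from this commit) that downscores items whose index is close to an already slotted item. It assumes the base-class constructor and the _penalty_mixture_coefficient attribute shown elsewhere in this diff.

import tensorflow as tf
from tf_agents.bandits.policies import ranking_policy


class IndexDistancePenalizedPlackettLuce(ranking_policy.PenalizedPlackettLuce):
  """Toy penalizer: downscores items whose index is near a slotted item."""

  def _penalizer_fn(self, logits, slots, num_slotted):
    num_items = logits.shape[-1]
    # Only the first `num_slotted` columns of `slots` hold valid indices.
    slotted = tf.cast(slots[..., :num_slotted], tf.float32)
    all_idx = tf.range(num_items, dtype=tf.float32)
    # [batch_size, num_slotted, num_items] pairwise index distances.
    dist = tf.abs(
        slotted[..., tf.newaxis] - all_idx[tf.newaxis, tf.newaxis, :])
    # Distance from each candidate item to its closest slotted item.
    min_dist = tf.reduce_min(dist, axis=1)
    # `_penalty_mixture_coefficient` is assumed to be set by the base class,
    # as the diff's CosinePenalizedPlackettLuce suggests.
    return logits - self._penalty_mixture_coefficient / (min_dist + 1.0)
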
@@ -96,31 +149,55 @@ def _event_shape(self, scores=None):
class CosinePenalizedPlackettLuce(PenalizedPlackettLuce):
"""A distribution that samples items based on scores and cosine similarity."""

def _penalizer_fn(self, logits, features, slots):
def __init__(
self,
features: types.Tensor,
num_slots: int,
logits: types.Tensor,
penalty_mixture_coefficient: float = 1.0,
):
"""Initializes an instance of CosinePenalizedPlackettLuce.
Args:
features: Item features based on which similarity is calculated.
num_slots: The number of slots to fill: this many items will be sampled.
logits: Unnormalized log probabilities for the PlackettLuce distribution.
Shape is `[num_items]`.
penalty_mixture_coefficient: A parameter responsible for the balance
between selecting high-scoring items and enforcing diversity.
"""
super().__init__(features, num_slots, logits, penalty_mixture_coefficient)
num_items = features.shape[1]
# Computes the cosine similarity matrix between all items, shaped as
# [batch_size, num_items, num_items].
self._sim_matrix = tf.reshape(
tf.keras.losses.cosine_similarity(
tf.repeat(features, num_items, axis=1, name='repeat_features'),
tf.tile(
features,
[1, num_items, 1],
name='tile_features',
),
)
- 1,
shape=[-1, num_items, num_items],
)

def _penalizer_fn(self, logits, slots, num_slotted):
num_items = logits.shape[-1]
num_slotted = len(slots)
slot_tensor = tf.stack(slots, axis=-1)
# Gathers the pairwise similarity matrix between all items and the items
# that have been selected, shaped as [batch_size, num_slotted, num_items].
# The tfd.Categorical distribution will give the sample `num_items` if all
# the logits are `-inf`. Hence, we need to apply minimum. This happens when
# `num_actions` is less than `num_slots`. To this end, the action taken by
# the policy always has to be taken together with the `num_actions`
# observation, to know how many slots are filled with valid items.
slotted_features = tf.gather(
features, tf.minimum(slot_tensor, num_items - 1), batch_dims=1
)

# Calculate the similarity between all pairs from
# `slotted_features x all_features`.
all_sims = (
tf.keras.losses.cosine_similarity(
tf.repeat(features, num_slotted, axis=1),
tf.tile(slotted_features, [1, num_items, 1]),
)
- 1
sim_matrix_against_slotted = tf.gather(
self._sim_matrix,
tf.minimum(slots[..., :num_slotted], num_items - 1),
batch_dims=1,
)

sim_matrix = tf.reshape(all_sims, shape=[-1, num_items, num_slotted])
similarity_boosts = tf.reduce_min(sim_matrix, axis=-1)
similarity_boosts = tf.reduce_min(sim_matrix_against_slotted, axis=1)
adjusted_logits = logits + (
self._penalty_mixture_coefficient * similarity_boosts
)
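
A quick standalone check of the construction in __init__ above (a sketch with made-up shapes, not from the commit). tf.keras.losses.cosine_similarity returns the negative of the cosine similarity, so entry [b, i, j] is -cos(features[b, i], features[b, j]) - 1: values lie in [-2, 0], with -2 on the diagonal since every item is maximally similar to itself. reduce_min over the gathered slotted rows then assigns each candidate the penalty of its most similar already-slotted item.

import tensorflow as tf

batch_size, num_items, feat_dim = 2, 5, 8
features = tf.random.normal([batch_size, num_items, feat_dim])

# Same pairing as in __init__: every item against every item.
sim_matrix = tf.reshape(
    tf.keras.losses.cosine_similarity(
        tf.repeat(features, num_items, axis=1),
        tf.tile(features, [1, num_items, 1]),
    )
    - 1,
    shape=[-1, num_items, num_items],
)
print(sim_matrix.shape)                 # (2, 5, 5)
print(tf.linalg.diag_part(sim_matrix))  # rows of values close to -2.0
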
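And a sketch of the edge case the tf.minimum clamp in _penalizer_fn guards against, per the comment in the diff: when every logit is -inf (all valid items already slotted), tfd.Categorical samples the out-of-range index num_items, which would make the gather fail without clamping. Illustrative values:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Two items, both already masked out with -inf logits.
logits = tf.constant([[-np.inf, -np.inf]])
items = tfd.Categorical(logits=logits).sample()
print(items)  # [2]: out of range, hence tf.minimum(..., num_items - 1)
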
7 changes: 7 additions & 0 deletions tf_agents/bandits/policies/ranking_policy_test.py
@@ -14,7 +14,9 @@
# limitations under the License.

"""Tests for ranking_policy."""

from absl.testing import parameterized
import numpy as np
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.bandits.networks import global_and_arm_feature_network as arm_net
from tf_agents.bandits.policies import ranking_policy
@@ -84,8 +86,13 @@ def testPolicy(self, policy_class, batch_size, num_items, num_slots):
)
time_spec = ts.restart(observation, batch_size=batch_size)
action_step = policy.action(time_spec)
unique_item_counts = tf.map_fn(
lambda action: tf.unique_with_counts(action)[2], action_step.action
)
self.evaluate(tf.compat.v1.global_variables_initializer())
self.assertAllEqual(action_step.action.shape, [batch_size, num_slots])
# All ranked items should appear exactly once in the ranked list.
self.assertAllEqual(unique_item_counts, np.ones((batch_size, num_slots)))

def testTemperature(self):
if not tf.executing_eagerly():
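The new assertion works because tf.unique_with_counts returns an all-ones count vector of length num_slots exactly when no item repeats within a ranking, so tf.map_fn can stack the per-row counts into a dense [batch_size, num_slots] tensor. A small sketch with illustrative values:

import tensorflow as tf

actions = tf.constant([[3, 1, 4], [0, 2, 1]])  # two rankings, no repeats
counts = tf.map_fn(lambda a: tf.unique_with_counts(a)[2], actions)
print(counts.numpy())  # [[1 1 1], [1 1 1]]
# A repeated item (e.g. [2, 2, 0]) would produce a shorter count vector,
# breaking the stack and failing the all-ones assertion.
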
