Add beam search decoding util (#237)
* beam search

* minor fixes

* minor changes

* minor style changes

* Fixed bug

* naming and docstring updates

* temporary change in setup

* undo setup change

* updated init to export utils

* init change

* style changes

* converted to loop based beam search

* docstring description changes
jessechancy authored Jun 30, 2022
1 parent 2b69891 commit f9abc8f
Showing 3 changed files with 309 additions and 4 deletions.
1 change: 1 addition & 0 deletions keras_nlp/utils/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.utils.text_generation import beam_search
from keras_nlp.utils.text_generation import greedy_search
from keras_nlp.utils.text_generation import random_search
from keras_nlp.utils.text_generation import top_k_search
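With this export in place, the utility is importable directly from `keras_nlp.utils`. A one-line check, assuming a keras-nlp build that includes this commit:

```python
from keras_nlp.utils import beam_search
```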
154 changes: 152 additions & 2 deletions keras_nlp/utils/text_generation.py
@@ -160,6 +160,156 @@ def token_probability_fn(inputs):
    return prompt


def beam_search(
    token_probability_fn,
    prompt,
    max_length,
    num_beams,
    from_logits=False,
    end_token_id=None,
    pad_token_id=0,
):
"""
Text generation utility based on beam search algorithm.
At each time-step, beam search keeps the beams (sequences) of the top
`num_beams` highest accumulated probabilities, and uses each one of the
beams to predict candidate next tokens.
Args:
token_probability_fn: a callable, which takes in input_sequence
and output the probability distribution of the next token. If
`from_logits` set to True, it should output the logits of the next
token. The input shape would be `[batch_size, length]` and the
output should be `[batch_size, vocab_size]`, where batch_size is
variable.
prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
append generated tokens. The initial beam for beam search.
max_length: int. The max length of generated text.
num_beams: int. The number of beams that should be kept at each
time-step. `num_beams` should be strictly positive.
from_logits: bool. Indicates whether `token_probability_fn` outputs
logits or probabilities.
end_token_id: int, defaults to None. The token marking the end of the
sequence, once encountered the generation is finished for the exact
sequence. If None, every sequence is generated up to `max_length`.
If set, all tokens after encountering `end_token_id` will be
replaced with `pad_token_id`.
pad_token_id: int, defaults to 0. The pad token after `end_token_id`
is received.
Returns:
A 1D int Tensor, or 2D int Tensor representing the generated
sequences.
Examples:
```python
BATCH_SIZE = 8
VOCAB_SIZE = 10
FEATURE_SIZE = 16
START_ID = 1
END_ID = 2
# Create a dummy model to predict the next token.
model = tf.keras.Sequential(
[
tf.keras.Input(shape=[None]),
tf.keras.layers.Embedding(
input_dim=VOCAB_SIZE,
output_dim=FEATURE_SIZE,
),
tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
]
)
# Define a function that outputs the next token's probability given the
# input sequence.
def token_probability_fn(inputs):
return model(inputs)[:, -1, :]
prompt = tf.fill((BATCH_SIZE, 1), START_ID)
# Print the generated sequence (token ids).
keras_nlp.utils.beam_search(
token_probability_fn,
prompt,
max_length=10,
num_beams=5,
end_token_id=END_ID,
)
```
"""
    if not tf.executing_eagerly():
        raise RuntimeError(
            "`keras_nlp.utils.beam_search` currently requires an eager "
            "execution context. Please call `beam_search` outside "
            "tf.function or run `tf.config.run_functions_eagerly(True)` to "
            "run tf.function in eager mode."
        )
    if num_beams <= 0:
        raise ValueError(
            "`num_beams` should be strictly positive. "
            f"Received: `num_beams={num_beams}`."
        )

    prompt = validate_prompt(prompt)

    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]
    validate_token_probability_fn(token_probability_fn, prompt)

    batch_size, length = prompt.shape
    if length < max_length:
        # Initialize the beam and its accumulated log-probability
        # (log(1) = 0 for the initial prompt).
        beams = tf.expand_dims(prompt, 1)
        beams_prob = tf.zeros([batch_size, 1])
        i = length
        while i < max_length:
            beam_size = beams.shape[1]
            beam_preds = []
            for j in range(beam_size):
                preds = token_probability_fn(beams[:, j, :])
                if from_logits:
                    preds = tf.keras.activations.softmax(preds, axis=-1)
                beam_preds.append(preds)
            stacked_preds = tf.stack(beam_preds, axis=1)
            vocab_size = stacked_preds.shape[2]
            # Flatten the (beam, vocab) axes so a single top-k call can
            # rank every candidate continuation at once.
            flattened_preds = tf.reshape(
                stacked_preds, [batch_size, beam_size * vocab_size]
            )
            # `beams_prob` holds accumulated log-probabilities, so scores
            # are summed in log space rather than multiplied.
            probs = tf.math.log(flattened_preds) + tf.repeat(
                beams_prob, repeats=vocab_size, axis=1
            )
            num_beams = min(beam_size * vocab_size, num_beams)
            candidate_prob, candidate_indexes = tf.math.top_k(
                probs, k=num_beams, sorted=False
            )
            candidate_beam_indexes = candidate_indexes // vocab_size
            next_token = candidate_indexes % vocab_size
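            # For example, with vocab_size=4 a flattened index of 6 decodes
            # to beam 6 // 4 = 1 and token 6 % 4 = 2.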

            beams = tf.gather(
                beams, candidate_beam_indexes, axis=1, batch_dims=1
            )
            beams = tf.concat([beams, next_token[..., tf.newaxis]], axis=-1)
            beams_prob = candidate_prob
            i += 1
        # Get the beam with the maximum probability.
        max_indexes = tf.math.argmax(beams_prob, axis=-1)
        max_beams = tf.gather(
            beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
        )
        # Squeeze only the beam axis so a batch of size one keeps its rank.
        prompt = tf.squeeze(max_beams, axis=1)

    if end_token_id is not None:
        prompt = mask_tokens_after_end_token(
            prompt, max_length, end_token_id, pad_token_id
        )
    if input_is_1d:
        return tf.squeeze(prompt)
    return prompt


def random_search(
    token_probability_fn,
    prompt,
@@ -361,7 +511,7 @@ def token_probability_fn(inputs):
"tf.function in eager mode."
)
if k <= 0:
raise ValueError(f"`k` should strictly positive. Received: `k={k}`.")
raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")

    prompt = validate_prompt(prompt)
    input_is_1d = prompt.shape.rank == 1
@@ -378,7 +528,7 @@ def token_probability_fn(inputs):
        # If k is greater than the vocabulary size, use the entire vocabulary.
        k = min(k, pred.shape[1])
        # Filter out top-k tokens. `sorted=False` skips an unneeded sort;
        # the categorical sampling below does not depend on candidate order
        # (seeded draws can differ, hence the seeded-result update below).
-        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k)
+        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k, sorted=False)
        # Sample the next token from the probability distribution.
        next_token = tf.random.categorical(
            tf.math.log(top_k_pred), 1, seed=seed
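The core trick in the beam search loop above is flattening the `(beam, vocab)` candidate scores so one `tf.math.top_k` call ranks every continuation at once; integer division and modulo then recover which beam and which token each winner corresponds to. A minimal standalone sketch with made-up scores (not taken from the library):

```python
import tensorflow as tf

beam_size, vocab_size = 2, 4
# Accumulated log-probabilities for each (beam, token) continuation,
# already flattened to shape [batch=1, beam_size * vocab_size].
scores = tf.constant(
    [[-1.2, -0.7, -3.0, -2.1,   # beam 0
      -0.3, -2.5, -1.9, -0.9]]  # beam 1
)
top_scores, top_idx = tf.math.top_k(scores, k=2, sorted=False)
beam_idx = top_idx // vocab_size   # which beam each candidate extends
token_idx = top_idx % vocab_size   # which token it appends
# The two best candidates are beam 1 + token 0 (-0.3) and
# beam 0 + token 1 (-0.7); with sorted=False their order may vary.
```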
158 changes: 156 additions & 2 deletions keras_nlp/utils/text_generation_test.py
@@ -13,10 +13,12 @@
# limitations under the License.
"""Tests for Text Generation Utils."""

import random

import numpy as np
import tensorflow as tf

from keras_nlp.utils.text_generation import beam_search
from keras_nlp.utils.text_generation import greedy_search
from keras_nlp.utils.text_generation import random_search
from keras_nlp.utils.text_generation import top_k_search
@@ -111,6 +113,160 @@ def token_probability_fn(inputs):
        self.assertAllEqual(outputs, expected_outputs)


class BeamSearchTextGenerationTest(tf.test.TestCase):
    def setUp(self):
        super().setUp()
        vocab_size = 10
        feature_size = 16

        # Create a dummy model to predict the next token.
        model = tf.keras.Sequential(
            [
                tf.keras.Input(shape=[None]),
                tf.keras.layers.Embedding(
                    input_dim=vocab_size,
                    output_dim=feature_size,
                ),
                tf.keras.layers.Dense(vocab_size),
                tf.keras.layers.Softmax(),
            ]
        )

        def token_probability_fn(inputs):
            return model(inputs)[:, -1, :]

        self.token_probability_fn = token_probability_fn

    def test_generate_with_empty_prompt(self):
        inputs = tf.constant([])
        with self.assertRaises(ValueError):
            beam_search(
                self.token_probability_fn, inputs, max_length=5, num_beams=5
            )
        inputs = tf.constant([[]])
        with self.assertRaises(ValueError):
            beam_search(
                self.token_probability_fn, inputs, max_length=5, num_beams=5
            )

    def test_generate_with_1d_prompt(self):
        inputs = tf.constant([1])
        outputs = beam_search(
            self.token_probability_fn,
            inputs,
            max_length=5,
            num_beams=5,
        )
        self.assertEqual(outputs.shape, [5])

    def test_generate_with_2d_prompt(self):
        inputs = tf.constant([[1], [1]])
        outputs = beam_search(
            self.token_probability_fn,
            inputs,
            max_length=5,
            num_beams=5,
        )
        self.assertEqual(outputs.shape, [2, 5])

    def test_generate_with_list_prompt(self):
        inputs = [[1], [1]]
        outputs = beam_search(
            self.token_probability_fn,
            inputs,
            max_length=5,
            num_beams=5,
        )
        self.assertEqual(outputs.shape, [2, 5])

    def test_generate_with_ragged_prompt(self):
        inputs = tf.ragged.constant([[1], [2, 3]])
        with self.assertRaises(ValueError):
            beam_search(
                self.token_probability_fn,
                inputs,
                max_length=5,
                num_beams=5,
            )

    def test_one_beam_generation(self):
        for _ in range(50):
            inputs = tf.constant([random.randint(0, 9)])
            beam_output = beam_search(
                self.token_probability_fn,
                inputs,
                max_length=5,
                num_beams=1,
            )
            greedy_output = greedy_search(
                self.token_probability_fn,
                inputs,
                max_length=5,
            )
            self.assertAllEqual(beam_output, greedy_output)

    def test_multiple_beam_generation(self):
        def token_probability_fn(inputs):
            if inputs.shape[1] == 1:
                prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
            elif inputs[0][1] == 2:
                prob = tf.constant([[0.9, 0.08, 0.01, 0.01]])
            elif inputs[0][1] == 3:
                prob = tf.constant([[0.25, 0.25, 0.25, 0.25]])
            return prob

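        # Greedy search would pick token 3 (p=0.4) and finish with total
        # probability 0.4 * 0.25 = 0.1, while keeping token 2 (p=0.3)
        # lets a beam reach 0.3 * 0.9 = 0.27, so beam search should
        # return [1, 2, 0].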
        inputs = tf.constant([[1]])
        beam_output = beam_search(
            token_probability_fn,
            inputs,
            max_length=3,
            num_beams=2,
        )
        self.assertAllEqual(
            beam_output, tf.constant([1, 2, 0], dtype=beam_output.dtype)
        )

    def test_assert_generation_is_correct(self):
        def token_probability_fn(inputs):
            batch_size = inputs.shape[0]
            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
            return tf.repeat(prob, batch_size, axis=0)

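        # Token 3 dominates with p=0.9, so for every beam width the best
        # sequence keeps choosing token 3 up to `max_length`.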
        batch_size = 10
        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
        max_length = 3
        for i in range(1, 10):
            outputs = beam_search(
                token_probability_fn,
                inputs,
                max_length=max_length,
                num_beams=i,
            )
            self.assertAllEqual(
                outputs, 3 * tf.ones(shape=[batch_size, max_length])
            )

    def test_end_token_id(self):
        def token_probability_fn(inputs):
            batch_size = inputs.shape[0]
            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
            return tf.repeat(prob, batch_size, axis=0)

        max_length = 5
        inputs = tf.constant([[0, 1], [1, 2]])
        outputs = beam_search(
            token_probability_fn,
            inputs,
            max_length=max_length,
            num_beams=2,
            end_token_id=2,
            pad_token_id=0,
        )
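        # The second prompt already ends with `end_token_id=2`, so its
        # generated tokens are masked to `pad_token_id=0`; the first prompt
        # keeps generating token 3 (p=0.9) up to `max_length`.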
        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
        self.assertAllEqual(outputs, expected_outputs)


class RandomSearchTextGenerationTest(tf.test.TestCase):
    def setUp(self):
        super().setUp()
@@ -334,8 +490,6 @@ def token_probability_fn(inputs):
        )
        # Top-k sampling result with seed 42.
        seeded_result = 3 * np.ones(shape=[batch_size, max_length])
-        seeded_result[3][1] = 2
-        seeded_result[7][1] = 2
        self.assertAllEqual(outputs, seeded_result)

    def test_assert_probability_distribution_generation_is_correct(self):
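For completeness, a minimal end-to-end sketch of calling the new utility, mirroring the docstring example; it assumes a keras-nlp build that includes this commit, and the toy model is untrained, so the generated ids are arbitrary:

```python
import tensorflow as tf
import keras_nlp

VOCAB_SIZE = 10
FEATURE_SIZE = 16

# Dummy next-token model: per-position softmax over the vocabulary.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=[None]),
        tf.keras.layers.Embedding(
            input_dim=VOCAB_SIZE, output_dim=FEATURE_SIZE
        ),
        tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
    ]
)

def token_probability_fn(inputs):
    # Probability distribution of the next token given the sequence so far.
    return model(inputs)[:, -1, :]

prompt = tf.fill((8, 1), 1)  # batch of 8 prompts, all starting with token 1
output = keras_nlp.utils.beam_search(
    token_probability_fn,
    prompt,
    max_length=10,
    num_beams=5,
    end_token_id=2,
)
print(output.shape)  # (8, 10)
```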
