diff --git a/keras_nlp/utils/__init__.py b/keras_nlp/utils/__init__.py
index 905bfd23f1..cf1975a26a 100644
--- a/keras_nlp/utils/__init__.py
+++ b/keras_nlp/utils/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from keras_nlp.utils.text_generation import beam_search
 from keras_nlp.utils.text_generation import greedy_search
 from keras_nlp.utils.text_generation import random_search
 from keras_nlp.utils.text_generation import top_k_search
diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py
index 231d4162dd..fbae45ce9b 100644
--- a/keras_nlp/utils/text_generation.py
+++ b/keras_nlp/utils/text_generation.py
@@ -160,6 +160,157 @@ def token_probability_fn(inputs):
     return prompt
 
 
+def beam_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    num_beams,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on the beam search algorithm.
+
+    At each time-step, beam search keeps the `num_beams` sequences (beams)
+    with the highest accumulated probabilities, and uses each beam to
+    predict candidate next tokens.
+
+    Args:
+        token_probability_fn: a callable, which takes in input_sequence
+            and outputs the probability distribution of the next token. If
+            `from_logits` is set to True, it should output the logits of the
+            next token. The input shape would be `[batch_size, length]` and
+            the output should be `[batch_size, vocab_size]`, where
+            `batch_size` is variable.
+        prompt: a list or a Tensor, can be 1D or 2D. The initial tokens to
+            which generated tokens are appended; used as the initial beam.
+        max_length: int. The max length of generated text.
+        num_beams: int. The number of beams that should be kept at each
+            time-step. `num_beams` should be strictly positive.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of
+            the sequence; once encountered, generation is finished for that
+            sequence. If None, every sequence is generated up to
+            `max_length`. If set, all tokens after encountering
+            `end_token_id` will be replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The token used to pad sequences
+            after `end_token_id` is encountered.
+
+    Returns:
+        A 1D or 2D int Tensor representing the generated sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Generate the sequences (token ids).
+    keras_nlp.utils.beam_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        num_beams=5,
+        end_token_id=END_ID,
+    )
+    ```
+
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.beam_search` currently requires an eager "
+            "execution context. Please call `beam_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if num_beams <= 0:
+        raise ValueError(
+            "`num_beams` should be strictly positive. "
+            f"Received: `num_beams={num_beams}`."
+        )
+
+    prompt = validate_prompt(prompt)
+
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    batch_size, length = prompt.shape
+    if length < max_length:
+        # Initialize beam.
+        beams = tf.expand_dims(prompt, 1)
+        beams_prob = tf.zeros([batch_size, 1])
+        i = length
+        while i < max_length:
+            beam_size = beams.shape[1]
+            beam_preds = []
+            for j in range(beam_size):
+                preds = token_probability_fn(beams[:, j, :])
+                if from_logits:
+                    preds = tf.keras.activations.softmax(preds, axis=-1)
+                beam_preds.append(preds)
+            stacked_preds = tf.stack(beam_preds, axis=1)
+            vocab_size = stacked_preds.shape[2]
+            # Flatten beam and vocab axes to rank all candidates in one top_k.
+            flat_preds = tf.reshape(
+                stacked_preds, [batch_size, beam_size * vocab_size]
+            )
+            log_probs = tf.math.log(flat_preds) + tf.repeat(
+                beams_prob, repeats=vocab_size, axis=1
+            )
+            num_beams = min(beam_size * vocab_size, num_beams)
+            candidate_prob, candidate_indexes = tf.math.top_k(
+                log_probs, k=num_beams, sorted=False
+            )
+            candidate_beam_indexes = candidate_indexes // vocab_size
+            next_token = candidate_indexes % vocab_size
+
+            beams = tf.gather(
+                beams, candidate_beam_indexes, axis=1, batch_dims=1
+            )
+            beams = tf.concat([beams, next_token[..., tf.newaxis]], axis=-1)
+            beams_prob = candidate_prob
+            i += 1
+        # Get the beam with the maximum probability.
+        max_indexes = tf.math.argmax(beams_prob, axis=-1)
+        max_beams = tf.gather(
+            beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
+        )
+        prompt = tf.squeeze(max_beams, axis=1)
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
+
+
 def random_search(
     token_probability_fn,
     prompt,
@@ -361,7 +512,7 @@ def token_probability_fn(inputs):
             "tf.function in eager mode."
         )
     if k <= 0:
-        raise ValueError(f"`k` should strictly positive. Received: `k={k}`.")
+        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")
 
     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
@@ -378,7 +529,7 @@ def token_probability_fn(inputs):
         # If k is greater than the vocabulary size, use the entire vocabulary.
         k = min(k, pred.shape[1])
         # Filter out top-k tokens.
-        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k)
+        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k, sorted=False)
         # Sample the next token from the probability distribution.
         next_token = tf.random.categorical(
             tf.math.log(top_k_pred), 1, seed=seed
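
Reviewer note, not part of the patch: the candidate-selection step above flattens the beam-by-vocab score matrix so a single `tf.math.top_k` call can rank every possible continuation, then recovers the source beam with integer division and the token id with the remainder. Here is a minimal, self-contained sketch of one such step; all shapes, names, and numbers are toy values chosen for illustration, mirroring the patch's variable names.

```python
import tensorflow as tf

# Toy state: batch of 1, 2 beams, vocabulary of 4 tokens.
beams_prob = tf.constant([[0.0, -1.0]])  # Accumulated log-probs, [batch, beam].
stacked_preds = tf.constant(
    [[[0.1, 0.2, 0.3, 0.4], [0.7, 0.1, 0.1, 0.1]]]
)  # Next-token probabilities, [batch, beam, vocab].
batch_size, beam_size, vocab_size = stacked_preds.shape

# Flatten the (beam, vocab) axes so one top_k call ranks every candidate.
flat_preds = tf.reshape(stacked_preds, [batch_size, beam_size * vocab_size])
log_probs = tf.math.log(flat_preds) + tf.repeat(
    beams_prob, repeats=vocab_size, axis=1
)

# Keep the 2 best candidates; integer division recovers the source beam,
# and the remainder recovers the token id within the vocabulary.
candidate_prob, candidate_indexes = tf.math.top_k(log_probs, k=2, sorted=False)
candidate_beam_indexes = candidate_indexes // vocab_size
next_token = candidate_indexes % vocab_size

print(candidate_beam_indexes.numpy(), next_token.numpy())
# Both winning candidates extend beam 0, with tokens 3 and 2 (in some order).
```

Because scores are accumulated log-probabilities, initializing `beams_prob` to zeros corresponds to a probability of 1 for the initial prompt.
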
diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py
index 5b1da81965..5e6f6a2ca0 100644
--- a/keras_nlp/utils/text_generation_test.py
+++ b/keras_nlp/utils/text_generation_test.py
@@ -13,10 +13,12 @@
 # limitations under the License.
 """Tests for Text Generation Utils."""
 
+import random
 import numpy as np
 import tensorflow as tf
 
+from keras_nlp.utils.text_generation import beam_search
 from keras_nlp.utils.text_generation import greedy_search
 from keras_nlp.utils.text_generation import random_search
 from keras_nlp.utils.text_generation import top_k_search
 
@@ -111,6 +113,160 @@ def token_probability_fn(inputs):
         self.assertAllEqual(outputs, expected_outputs)
 
 
+class BeamSearchTextGenerationTest(tf.test.TestCase):
+    def setUp(self):
+        super().setUp()
+        vocab_size = 10
+        feature_size = 16
+
+        # Create a dummy model to predict the next token.
+        model = tf.keras.Sequential(
+            [
+                tf.keras.Input(shape=[None]),
+                tf.keras.layers.Embedding(
+                    input_dim=vocab_size,
+                    output_dim=feature_size,
+                ),
+                tf.keras.layers.Dense(vocab_size),
+                tf.keras.layers.Softmax(),
+            ]
+        )
+
+        def token_probability_fn(inputs):
+            return model(inputs)[:, -1, :]
+
+        self.token_probability_fn = token_probability_fn
+
+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            beam_search(
+                self.token_probability_fn, inputs, max_length=5, num_beams=5
+            )
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            beam_search(
+                self.token_probability_fn, inputs, max_length=5, num_beams=5
+            )
+
+    def test_generate_with_1d_prompt(self):
+        inputs = tf.constant([1])
+        outputs = beam_search(
+            self.token_probability_fn,
+            inputs,
+            max_length=5,
+            num_beams=5,
+        )
+        self.assertEqual(outputs.shape, [5])
+
+    def test_generate_with_2d_prompt(self):
+        inputs = tf.constant([[1], [1]])
+        outputs = beam_search(
+            self.token_probability_fn,
+            inputs,
+            max_length=5,
+            num_beams=5,
+        )
+        self.assertEqual(outputs.shape, [2, 5])
+
+    def test_generate_with_list_prompt(self):
+        inputs = [[1], [1]]
+        outputs = beam_search(
+            self.token_probability_fn,
+            inputs,
+            max_length=5,
+            num_beams=5,
+        )
+        self.assertEqual(outputs.shape, [2, 5])
+
+    def test_generate_with_ragged_prompt(self):
+        inputs = tf.ragged.constant([[1], [2, 3]])
+        with self.assertRaises(ValueError):
+            beam_search(
+                self.token_probability_fn,
+                inputs,
+                max_length=5,
+                num_beams=5,
+            )
+
+    def test_one_beam_generation(self):
+        for _ in range(50):
+            inputs = tf.constant([random.randint(0, 9)])
+            beam_output = beam_search(
+                self.token_probability_fn,
+                inputs,
+                max_length=5,
+                num_beams=1,
+            )
+            greedy_output = greedy_search(
+                self.token_probability_fn,
+                inputs,
+                max_length=5,
+            )
+            self.assertAllEqual(beam_output, greedy_output)
+
+    def test_multiple_beam_generation(self):
+        def token_probability_fn(inputs):
+            if inputs.shape[1] == 1:
+                prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            elif inputs[0][1] == 2:
+                prob = tf.constant([[0.9, 0.08, 0.01, 0.01]])
+            elif inputs[0][1] == 3:
+                prob = tf.constant([[0.25, 0.25, 0.25, 0.25]])
+            return prob
+
+        inputs = tf.constant([[1]])
+        beam_output = beam_search(
+            token_probability_fn,
+            inputs,
+            max_length=3,
+            num_beams=2,
+        )
+        self.assertAllEqual(
+            beam_output, tf.constant([1, 2, 0], dtype=beam_output.dtype)
+        )
+
+    def test_assert_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+        for i in range(1, 10):
+            outputs = beam_search(
+                token_probability_fn,
+                inputs,
+                max_length=max_length,
+                num_beams=i,
+            )
+            self.assertAllEqual(
+                outputs, 3 * tf.ones(shape=[batch_size, max_length])
+            )
+
+    def test_end_token_id(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        max_length = 5
+        inputs = tf.constant([[0, 1], [1, 2]])
+        outputs = beam_search(
+            token_probability_fn,
+            inputs,
+            max_length=max_length,
+            num_beams=2,
+            end_token_id=2,
+            pad_token_id=0,
+        )
+        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
+        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
+        self.assertAllEqual(outputs, expected_outputs)
+
+
 class RandomSearchTextGenerationTest(tf.test.TestCase):
     def setUp(self):
         super().setUp()
@@ -334,8 +490,6 @@ def token_probability_fn(inputs):
         )
         # Top-k sampling result with seed 42.
         seeded_result = 3 * np.ones(shape=[batch_size, max_length])
-        seeded_result[3][1] = 2
-        seeded_result[7][1] = 2
         self.assertAllEqual(outputs, seeded_result)
 
     def test_assert_probability_distribution_generation_is_correct(self):
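
Reviewer note, not part of the patch: `test_one_beam_generation` relies on beam search with `num_beams=1` degenerating to greedy decoding. A quick standalone check of that property, using a hypothetical fixed next-token distribution (the prompt values and probabilities below are illustrative, not from the test suite):

```python
import tensorflow as tf

from keras_nlp.utils.text_generation import beam_search
from keras_nlp.utils.text_generation import greedy_search


# Hypothetical next-token distribution: token 3 is always the most likely.
def token_probability_fn(inputs):
    prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
    return tf.repeat(prob, inputs.shape[0], axis=0)


prompt = tf.constant([[3], [3]])

# With a single beam, beam search keeps only the argmax candidate at each
# step, so it should match greedy decoding exactly.
beam_output = beam_search(
    token_probability_fn, prompt, max_length=4, num_beams=1
)
greedy_output = greedy_search(token_probability_fn, prompt, max_length=4)
assert (beam_output.numpy() == greedy_output.numpy()).all()
```
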