Add batching support to t5 denoise preprocessor.
PiperOrigin-RevId: 645488181
gauravmishra authored and t5-copybara committed Jun 21, 2024
1 parent 69cf3b5 commit dd26419
167 changes: 118 additions & 49 deletions t5/data/preprocessors.py
@@ -2744,17 +2744,19 @@ def __call__(self, tokens: tf.Tensor, noise_mask: tf.Tensor, vocabulary,
"""Computes the target tokens. Seeds should have shape [2, 2]."""


def single_example_denoise(features: FeatureType,
seed: tf.Tensor,
*,
output_features: Mapping[str, Any],
noise_density: float,
noise_mask_fn: DenoiseNoiseMaskFn,
inputs_fn: DenoiseInputsFn,
targets_fn: Optional[DenoiseTargetsFn] = None,
passthrough_feature_keys: Optional[
Sequence[str]] = None,
input_feature_key: str = 'inputs') -> FeatureType:
def single_example_denoise(
features: FeatureType,
seed: tf.Tensor,
*,
output_features: Mapping[str, Any],
noise_density: float,
noise_mask_fn: DenoiseNoiseMaskFn,
inputs_fn: DenoiseInputsFn,
targets_fn: Optional[DenoiseTargetsFn] = None,
passthrough_feature_keys: Optional[Sequence[str]] = None,
input_feature_key: str = 'inputs',
batch_size: int | None = None,
) -> FeatureType:
"""Preprocessing function for self-supervised denoising tasks.
This function takes a dataset containing "targets" sequences,
@@ -2796,6 +2798,8 @@ def single_example_denoise(features: FeatureType,
targets_fn: a function from (tokens, noise_mask, vocabulary) -> tokens
passthrough_feature_keys: names of additional features to include in output
input_feature_key: name of feature to use as inputs
batch_size: an optional int indicating batch size if `features` is a dict of
batched features.
Returns:
The preprocessed features.
@@ -2814,7 +2818,14 @@ def single_example_denoise(features: FeatureType,
raise ValueError(
'denoise creates inputs based on tokenized targets but was applied '
'to a task that uses different vocabularies for inputs and targets.')
noise_mask = noise_mask_fn(tf.size(tokens), noise_density, seeds=seeds[:2])
if batch_size:
# This step will fail if the noise_mask_fn, inputs_fn or targets_fn don't
# support a batch_size arg.
noise_mask_fn = functools.partial(noise_mask_fn, batch_size=batch_size)
inputs_fn = functools.partial(inputs_fn, batch_size=batch_size)
targets_fn = functools.partial(targets_fn, batch_size=batch_size)
length = tf.size(tokens) // (batch_size or 1)
noise_mask = noise_mask_fn(length, noise_density, seeds=seeds[:2])
inputs = inputs_fn(tokens, noise_mask, vocabulary, seeds=seeds[2:4])
if targets_fn:
targets = targets_fn(tokens, noise_mask, vocabulary, seeds=seeds[4:6])
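
Note: a minimal sketch (not part of the commit) of the per-row length arithmetic in the batched branch above; the shapes are hypothetical.

import tensorflow as tf

# With batched [batch_size, length] tokens, tf.size counts every element,
# so dividing by batch_size recovers the per-row length handed to the
# partially-bound noise_mask_fn.
tokens = tf.zeros([4, 128], dtype=tf.int32)  # hypothetical batch
batch_size = 4
length = tf.size(tokens) // batch_size
print(int(length))  # 128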
@@ -2916,11 +2927,14 @@ def regular_noise_mask(length,


@gin.configurable()
def random_spans_noise_mask(length,
noise_density,
seeds,
mean_noise_span_length=3.0,
random_roll=False):
def random_spans_noise_mask(
length,
noise_density,
seeds,
mean_noise_span_length=3.0,
random_roll=False,
batch_size=None,
):
"""Noise mask consisting of random spans of noise tokens.
The number of noise tokens and the number of noise spans and non-noise spans
@@ -2943,9 +2957,12 @@ def random_spans_noise_mask(length,
of masked positions. Specifically, when random_roll is False (default) and
a single span is enough to satisfy the noise density requirement, this
function masks only the last few positions.
batch_size: an int32; if set, a batch of masks of shape [batch_size, length]
is returned.
Returns:
a boolean tensor with shape [length]
a boolean tensor with shape [length] or [batch_size, length] if batch_size
is set.
"""

if noise_density == 0.0:
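
Note: a back-of-the-envelope sketch of the counting logic the docstring describes; the TF code below rounds with casts, so exact values can differ by one at the margins.

# Hypothetical numbers using the T5 defaults.
length, noise_density, mean_noise_span_length = 100, 0.15, 3.0
num_noise_tokens = round(length * noise_density)  # 15
num_noise_spans = max(round(num_noise_tokens / mean_noise_span_length), 1)  # 5
num_nonnoise_tokens = length - num_noise_tokens  # 85
# Five noise spans (~3 tokens each) interleave with five non-noise spans.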
@@ -2966,6 +2983,11 @@ def to_float(x):
# avoid degeneracy by ensuring positive number of noise spans
num_noise_spans = tf.maximum(num_noise_spans, 1)
num_nonnoise_tokens = length - num_noise_tokens
if batch_size:
# Create seeds to generate masks for each row.
seeds = tf.unstack(
tf.random.experimental.stateless_split(seeds[0], batch_size * 2)
)
# pick the lengths of the noise spans and the non-noise spans
def _random_segmentation(num_items, num_segments, seed):
"""Partition a sequence of items randomly into non-empty segments.
@@ -2986,29 +3008,35 @@ def _random_segmentation(num_items, num_segments, seed):
segment_id = tf.cumsum(first_in_segment)
segment_length = tf.math.segment_sum(tf.ones_like(segment_id), segment_id)
return segment_length
noise_span_lengths = _random_segmentation(
num_noise_tokens, num_noise_spans, seeds[0])
nonnoise_span_lengths = _random_segmentation(
num_nonnoise_tokens, num_noise_spans, seeds[1])
interleaved_span_lengths = tf.reshape(
tf.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
[num_noise_spans * 2])
span_starts = tf.cumsum(interleaved_span_lengths)[:-1]
span_start_indicator = tf.math.unsorted_segment_sum(
tf.ones_like(span_starts), span_starts, length)
span_num = tf.cumsum(span_start_indicator)
is_noise = tf.equal(span_num % 2, 1)

mask = is_noise[:orig_length]

if random_roll:
roll_seed = (seeds[0][0]+seeds[1][1], seeds[0][1]-seeds[1][0]) # new seed.
# Roll the mask by a random offset e.g. for offset=2: [1,2,3,4] => [3,4,1,2]
offset = tf.random.stateless_uniform(
[1], seed=roll_seed, dtype=tf.int32, minval=0, maxval=length)[0]
mask = tf.roll(mask, shift=offset, axis=0)

return mask
masks = []
for i in range(batch_size or 1):
noise_span_lengths = _random_segmentation(
num_noise_tokens, num_noise_spans, seeds[2 * i])
nonnoise_span_lengths = _random_segmentation(
num_nonnoise_tokens, num_noise_spans, seeds[2 * i + 1])
interleaved_span_lengths = tf.reshape(
tf.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
[num_noise_spans * 2])
span_starts = tf.cumsum(interleaved_span_lengths)[:-1]
span_start_indicator = tf.math.unsorted_segment_sum(
tf.ones_like(span_starts), span_starts, length)
span_num = tf.cumsum(span_start_indicator)
is_noise = tf.equal(span_num % 2, 1)

mask = is_noise[:orig_length]

if random_roll:
roll_seed = (seeds[0][0]+seeds[1][1], seeds[0][1]-seeds[1][0]) # new seed
# Roll the mask by a random offset e.g. for offset=2: [1,2,3,4] =>
# [3,4,1,2]
offset = tf.random.stateless_uniform(
[1], seed=roll_seed, dtype=tf.int32, minval=0, maxval=length)[0]
mask = tf.roll(mask, shift=offset, axis=0)
masks.append(mask)

if not batch_size:
return masks[0]
return tf.stack(masks, axis=0)
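
Note: a usage sketch of the new batched path; it assumes the function is importable from t5.data.preprocessors, and the seed values are arbitrary (only seeds[0] is consumed when batch_size is set).

import tensorflow as tf
from t5.data import preprocessors

seeds = tf.constant([[1, 2], [3, 4]], dtype=tf.int32)
mask = preprocessors.random_spans_noise_mask(
    length=16, noise_density=0.25, seeds=seeds, batch_size=4)
print(mask.shape)  # (4, 16): one boolean mask per row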


@gin.configurable()
@@ -3110,7 +3138,9 @@ def nonnoise_span_to_sentinel(tokens, noise_mask, vocabulary, seeds):


@gin.configurable()
def noise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds):
def noise_span_to_unique_sentinel(
tokens, noise_mask, vocabulary, seeds, batch_size=None
):
"""Replace each run of consecutive noise tokens with a different sentinel.
The idea here is to be able to align the dropped spans in the inputs
@@ -3132,28 +3162,67 @@ def noise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds):
noise_mask: a boolean Tensor with the same shape as tokens
vocabulary: a vocabulary.Vocabulary
seeds: an unused int32 Tensor
batch_size: an optional int32 giving the batch size if `tokens` is batched.
Returns:
a Tensor with the same shape and dtype as tokens
"""
del seeds

prev_token_is_noise = tf.pad(noise_mask[:-1], [[1, 0]])
if batch_size:
def shift_batched_right_by_one(arr, fill_value):
if not (arr.dtype.is_integer or arr.dtype.is_floating):
raise ValueError(f'Only numeric types are supported. Got: {arr.dtype}')
# tf.roll wraps around the axis.
rolled = tf.roll(arr, shift=1, axis=1)

# Zero out the first position by multiplying with [0, 1, 1, ..., 1].
depth = tf.shape(arr)[1]
mask = tf.one_hot(
0, depth=depth, on_value=0, off_value=1, dtype=arr.dtype
)
# Tile the mask to match batch size
shape = tf.shape(arr)
mask = tf.tile(tf.expand_dims(mask, axis=0), [batch_size, 1])
# Broadcast mask to match shape of rolled
for _ in range(len(shape) - 2):
mask = tf.expand_dims(mask, axis=-1)
return rolled * mask + (1 - mask) * fill_value
int_mask = tf.cast(noise_mask, tf.int32)
shifted_mask = shift_batched_right_by_one(int_mask, fill_value=0)
prev_token_is_noise = tf.cast(shifted_mask, tf.bool)
else:
prev_token_is_noise = tf.pad(noise_mask[:-1], [[1, 0]])

first_noise_tokens = tf.logical_and(
noise_mask, tf.logical_not(prev_token_is_noise))
subsequent_noise_tokens = tf.logical_and(noise_mask, prev_token_is_noise)

sentinel = sentinel_id(vocabulary) + 1 - tf.cumsum(
tf.cast(first_noise_tokens, tokens.dtype))
sentinel = (
sentinel_id(vocabulary)
+ 1
- tf.cumsum(tf.cast(first_noise_tokens, tokens.dtype), axis=-1)
)

tokens = tf.where(first_noise_tokens, sentinel, tokens)
return tf.boolean_mask(tokens, tf.logical_not(subsequent_noise_tokens))
denoised = tf.ragged.boolean_mask(
tokens, tf.logical_not(subsequent_noise_tokens)
)
if isinstance(denoised, tf.RaggedTensor):
denoised = denoised.to_tensor()
return denoised
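
Note: a toy sketch (not in the commit) of the roll-and-mask idiom shift_batched_right_by_one uses.

import tensorflow as tf

arr = tf.constant([[1, 2, 3],
                   [4, 5, 6]])
rolled = tf.roll(arr, shift=1, axis=1)  # wraps around: [[3, 1, 2], [6, 4, 5]]
# A [0, 1, 1] one-hot mask zeroes the wrapped-around first column.
mask = tf.one_hot(0, depth=3, on_value=0, off_value=1, dtype=arr.dtype)
shifted = rolled * mask  # the fill_value == 0 case: [[0, 1, 2], [0, 4, 5]]
print(shifted.numpy())

Because rows can drop different numbers of subsequent noise tokens, the batched branch then masks with tf.ragged.boolean_mask and pads back to a dense tensor via to_tensor(), which zero-pads shorter rows.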


@gin.configurable()
def nonnoise_span_to_unique_sentinel(tokens, noise_mask, vocabulary, seeds):
def nonnoise_span_to_unique_sentinel(
tokens, noise_mask, vocabulary, seeds, batch_size=None
):
return noise_span_to_unique_sentinel(
tokens, tf.logical_not(noise_mask), vocabulary, seeds)
tokens,
tf.logical_not(noise_mask),
vocabulary,
seeds,
batch_size=batch_size,
)


@gin.configurable()
