
changes per review comments
freedomtan committed Aug 1, 2023
1 parent 21b28dc commit c7ea016
Showing 1 changed file with 8 additions and 8 deletions.
@@ -3,7 +3,7 @@
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2021/05/26
Last modified: 2023/02/25
Description: Implementing a sequence-to-sequene Transformer and training it on a machine translation task.
Description: Implementing a sequence-to-sequence Transformer and training it on a machine translation task.
Accelerator: GPU
"""
"""
@@ -38,8 +38,12 @@
import re
import numpy as np

import tensorflow.data as tf_data
import tensorflow.strings as tf_strings

import keras_core as keras
from keras_core import layers
from keras_core import ops
from keras_core.layers import TextVectorization

"""
@@ -115,8 +119,6 @@
which you could achieve by providing a custom `split` function to the `TextVectorization` layer.
"""

import tensorflow.strings as tf_strings

strip_chars = string.punctuation + "¿"
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")
@@ -162,8 +164,6 @@ def custom_standardization(input_string):
it provides the next words in the target sentence -- what the model will try to predict.
"""

import tensorflow.data as tf_data

def format_dataset(eng, spa):
eng = eng_vectorization(eng)
spa = spa_vectorization(spa)
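
The body of `format_dataset` is truncated by the hunk above. As a sketch of the next-word setup it implements (the decoder input is the target sentence without its final token, and the label is the same sentence shifted one position ahead), it could look like the following; the dictionary keys are illustrative rather than quoted from the diff.

```python
def format_dataset(eng, spa):
    eng = eng_vectorization(eng)
    spa = spa_vectorization(spa)
    # The decoder consumes the target sentence minus its final token;
    # the label is the same sentence shifted one step ahead, so at each
    # position the model is trained to predict the next word.
    return (
        {"encoder_inputs": eng, "decoder_inputs": spa[:, :-1]},
        spa[:, 1:],
    )
```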
@@ -242,7 +242,7 @@ def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):

def call(self, inputs, mask=None):
if mask is not None:
padding_mask = ops.cast(mask[:, np.newaxis, :], dtype="int32")
padding_mask = ops.cast(mask[:, None, :], dtype="int32")
else:
padding_mask = None

@@ -329,7 +329,7 @@ def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
def call(self, inputs, encoder_outputs, mask=None):
causal_mask = self.get_causal_attention_mask(inputs)
if mask is not None:
padding_mask = ops.cast(mask[:, np.newaxis, :], dtype="int32")
padding_mask = ops.cast(mask[:, None, :], dtype="int32")
padding_mask = ops.minimum(padding_mask, causal_mask)
else:
padding_mask = None
@@ -353,7 +353,7 @@ def call(self, inputs, encoder_outputs, mask=None):
def get_causal_attention_mask(self, inputs):
input_shape = ops.shape(inputs)
batch_size, sequence_length = input_shape[0], input_shape[1]
i = ops.arange(sequence_length)[:, np.newaxis]
i = ops.arange(sequence_length)[:, None]
j = ops.arange(sequence_length)
mask = ops.cast(i >= j, dtype="int32")
mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
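
In `get_causal_attention_mask` above, broadcasting a column of query positions `i` against a row of key positions `j` and comparing `i >= j` yields a lower-triangular matrix, so each position may attend only to itself and earlier positions. A small standalone sketch with a sequence length of 4:

```python
from keras_core import ops

sequence_length = 4
i = ops.arange(sequence_length)[:, None]  # query positions, shape (4, 1)
j = ops.arange(sequence_length)           # key positions,   shape (4,)
mask = ops.cast(i >= j, dtype="int32")    # 1 where attending is allowed
print(ops.convert_to_numpy(mask))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
```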
