Commit

keras serializable

brandnewchoppa authored Sep 21, 2023
1 parent a654bf6 commit eb85768
Showing 1 changed file with 46 additions and 37 deletions.
83 changes: 46 additions & 37 deletions gau_tensorflow/gau_tensorflow.py
@@ -6,11 +6,7 @@

from keras import Model, Sequential
from keras.layers import Layer
-from keras.saving import (
-    serialize_keras_object,
-    deserialize_keras_object,
-    register_keras_serializable
-)
+from keras.saving import register_keras_serializable

from keras.layers import (
Dense,
@@ -19,6 +15,7 @@
Embedding
)

+@register_keras_serializable(package = 'GAUTensorFlow')
class ScaleNorm(Layer):
"""
Scale Normalization (ScaleNorm)
@@ -64,6 +61,7 @@ def get_config(self):
config.update({'eps': self.eps})
return config
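A note on the pattern this commit introduces: register_keras_serializable records each class in Keras's global registry under 'package>ClassName', so models saved in the native Keras format can rebuild these layers by name, with no custom_objects dict, as long as get_config returns plain constructor arguments. A minimal sketch of the idiom (the ToyScale layer is illustrative, not repository code):

from keras.layers import Layer
from keras.saving import register_keras_serializable

@register_keras_serializable(package = 'GAUTensorFlow')
class ToyScale(Layer):
    # Toy layer: multiplies its input by a constant factor.
    def __init__(self, factor = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        # Plain constructor arguments only; Keras rebuilds the
        # layer at load time via cls(**config).
        config = super().get_config()
        config.update({'factor': self.factor})
        return config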

+@register_keras_serializable(package = 'GAUTensorFlow')
class RMSNorm(Layer):
"""
Root Mean Square Layer Normalization (RMSNorm)
@@ -115,6 +113,7 @@ def get_config(self):
})
return config

+@register_keras_serializable(package = 'GAUTensorFlow')
class OffsetScale(Layer):
"""
Offset Scale (OffsetScale)
@@ -161,7 +160,8 @@ def get_config(self):
config = super().get_config()
config.update({'splits': self.splits})
return config


+@register_keras_serializable(package = 'GAUTensorFlow')
class RelativePositionBias(Layer):
"""
Relative Position Bias (RelativePositionBias)
@@ -221,11 +221,11 @@ def get_config(self):
        config.update({
            'scale': self.scale,
            'n_buckets': self.n_buckets,
-           'max_distance': self.max_distance,
-           'relative_attention_bias': self.relative_attention_bias
+           'max_distance': self.max_distance
        })
        return config
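Worth noting on the change above: relative_attention_bias is a trainable weight created in build(), so it never belonged in the config; a variable is not JSON-serializable, and weights are restored from the saved weights file, not from get_config. What remains must round-trip through __init__, roughly as below (argument names and values are illustrative, assuming the constructor mirrors the config keys):

rpb = RelativePositionBias(scale = 1.0, n_buckets = 32, max_distance = 128)
config = rpb.get_config()                          # plain Python values only
clone = RelativePositionBias.from_config(config)   # default path: cls(**config)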

+@register_keras_serializable(package = 'GAUTensorFlow')
class ReLUSquared(Layer):
"""
ReLU Squared (ReLUSquared)
@@ -242,6 +242,7 @@ def __init__(self, **kwargs):
def call(self, x):
return math.square(tf.nn.relu(x))

+@register_keras_serializable(package = 'GAUTensorFlow')
class LaplacianAttnFn(Layer):
"""
Laplacian Attention Function (LaplacianAttnFn)
@@ -328,13 +329,13 @@ def build(self, x_shape):
elif self.norm_type == 'rms_norm':
self.norm = RMSNorm()

-        self.to_uv = tf.recompute_grad(Dense(
-            (d * self.expansion_factor) * 2,
-            activation = 'silu'))
+        self.to_uv = Dense(
+            (d * self.expansion_factor) * 2,
+            activation = 'silu')

-        self.to_qk = tf.recompute_grad(Dense(
-            self.qk_dim,
-            activation = 'silu'))
+        self.to_qk = Dense(
+            self.qk_dim,
+            activation = 'silu')

self.scale_offset = OffsetScale(
splits = 2)
@@ -360,6 +361,7 @@ def build(self, x_shape):

self.built = True

+    @tf.recompute_grad
def _attn(self, x, v):
n = cast(x.shape[-2], x.dtype)
z = self.to_qk(x)
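The gradient-checkpointing strategy moves as well: previously tf.recompute_grad wrapped the Dense objects in build() (see the hunk above), leaving non-Layer wrapper objects as attributes, which presumably got in the way of serialization; here the decorator wraps the _attn computation instead, so its intermediate activations are dropped after the forward pass and recomputed during backprop. A standalone sketch of the mechanism on a free function (illustrative, not repository code):

import tensorflow as tf

@tf.recompute_grad
def heavy_block(x):
    # Intermediates here are freed after the forward pass and
    # recomputed when gradients are requested, trading compute for memory.
    return tf.nn.silu(tf.matmul(x, tf.transpose(x)))

x = tf.random.normal([4, 8])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(heavy_block(x))
grad = tape.gradient(y, x)

Whether tf.recompute_grad composes cleanly with a bound method, where self arrives as a non-tensor argument, is worth verifying; the sketch above sidesteps that by using a free function.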
@@ -402,31 +404,11 @@ def get_config(self):
            'norm_type': self.norm_type,
            'shift_tokens': self.shift_tokens,
            'use_rope': self.use_rope,
-           'laplace_attn_fn': self.laplace_attn_fn,
-
-           'to_uv': serialize_keras_object(self.to_uv),
-           'to_qk': serialize_keras_object(self.to_qk),
-           'scale_offset': serialize_keras_object(self.scale_offset),
-           'rotary_pos_embs': serialize_keras_object(self.rotary_pos_embs),
-           'rel_pos_bias': serialize_keras_object(self.rel_pos_bias),
-           'dropout': serialize_keras_object(self.dropout),
-           'to_out': serialize_keras_object(self.to_out),
-           'attn_fn': serialize_keras_object(self.attn_fn)
+           'laplace_attn_fn': self.laplace_attn_fn
        })
        return config

-    @classmethod
-    def from_config(cls, config):
-        config['to_uv'] = deserialize_keras_object(config['to_uv'])
-        config['to_qk'] = deserialize_keras_object(config['to_qk'])
-        config['scale_offset'] = deserialize_keras_object(config['scale_offset'])
-        config['rotary_pos_embs'] = deserialize_keras_object(config['rotary_pos_embs'])
-        config['rel_pos_bias'] = deserialize_keras_object(config['rel_pos_bias'])
-        config['dropout'] = deserialize_keras_object(config['dropuot'])
-        config['to_out'] = deserialize_keras_object(config['to_out'])
-        config['attn_fn'] = deserialize_keras_object(config['attn_fn'])
-        return cls(**config)
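The hand-written from_config override can go because, once the config holds only plain values, the behaviour inherited from the base class is enough; it amounts to roughly:

@classmethod
def from_config(cls, config):
    return cls(**config)

The override was also a liability: note the 'dropuot' key in the removed lines, a typo that would have raised a KeyError on every attempted deserialization.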

+@register_keras_serializable(package = 'GAUTensorFlow')
class ScaledSin(Layer):
"""
Sinusoidal Position Embedding with scaling factor. (ScaledSin)
@@ -457,7 +439,8 @@ def call(self, x):
pos = einsum('s, d -> sd', pos, self.inv_freq)
scaled_emb = tf.concat([ math.sin(pos), math.cos(pos) ], axis = -1)
return tf.cast(scaled_emb, x.dtype) * self.scale
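For reference, the call above is the standard sinusoidal position embedding with a scale factor applied. Assuming the collapsed part of the class builds the usual inverse frequencies, a self-contained equivalent would look roughly like:

import tensorflow as tf

def scaled_sin(seq_len, dim, scale = 1.0):
    # inv_freq_i = 1 / 10000^(2i / dim), the usual transformer frequencies.
    inv_freq = 1.0 / (10000.0 ** (tf.range(0, dim, 2, dtype = tf.float32) / dim))
    pos = tf.range(seq_len, dtype = tf.float32)
    angles = tf.einsum('s, d -> sd', pos, inv_freq)
    return tf.concat([tf.sin(angles), tf.cos(angles)], axis = -1) * scale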


+@register_keras_serializable(package = 'GAUTensorFlow')
class GAUTransformer(Model):
def __init__(self,
*,
@@ -474,7 +457,17 @@ def __init__(self,
laplace_attn_fn : bool = False,
**kwargs):
super().__init__(**kwargs)
        self.emb_dim = emb_dim
+        self.n_tokens = n_tokens
+        self.depth = depth
+        self.qk_dim = qk_dim
+        self.expansion_factor = expansion_factor
+        self.causal = causal
+        self.dropout_rate = dropout_rate
+        self.norm_type = norm_type
+        self.shift_tokens = shift_tokens
+        self.use_rope = use_rope
+        self.laplace_attn_fn = laplace_attn_fn

self.token_emb = Embedding(n_tokens, emb_dim, name = 'embeddings')
self.abs_pos_emb = ScaledSin(name = 'scaled_sin')
@@ -492,13 +485,29 @@ def __init__(self,
name = f'gau{i}'
) for i in range(depth)], name = 'blocks')

-        self.to_logits = tf.recompute_grad(Sequential([
-            LayerNormalization(),
-            Dense(n_tokens)
-        ], name = 'logits'))
+        self.to_logits = Sequential([
+            LayerNormalization(),
+            Dense(n_tokens)
+        ], name = 'logits')

def call(self, x):
x = self.token_emb(x)
x = self.abs_pos_emb(x) + x
x = self.blocks(x)
return self.to_logits(x)

+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            'emb_dim': self.emb_dim,
+            'n_tokens': self.n_tokens,
+            'depth': self.depth,
+            'qk_dim': self.qk_dim,
+            'expansion_factor': self.expansion_factor,
+            'causal': self.causal,
+            'dropout_rate': self.dropout_rate,
+            'norm_type': self.norm_type,
+            'shift_tokens': self.shift_tokens,
+            'use_rope': self.use_rope,
+            'laplace_attn_fn': self.laplace_attn_fn
+        })
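With every class registered and each config reduced to constructor arguments, a model saved in the native Keras format should reload without a custom_objects mapping. One caveat before relying on this: GAUTransformer.get_config as committed updates config but never returns it (compare the layers above, which all end with return config), so a save/load round trip would presumably fail until a return config line is appended. A sketch of the intended usage, with illustrative hyperparameters and an assumed import path:

import tensorflow as tf
from keras.models import load_model
from gau_tensorflow import GAUTransformer   # assumed import path

model = GAUTransformer(
    emb_dim = 128,      # illustrative sizes, not from the commit
    n_tokens = 256,
    depth = 2,
    qk_dim = 64
)
_ = model(tf.zeros([1, 16], dtype = tf.int32))  # build weights on a dummy batch
model.save('gau.keras')

# No custom_objects needed: the classes resolve through the registry.
restored = load_model('gau.keras')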
