
Commit

fix(PositionalEncoding): better positional scaling
marcpinet committed Nov 22, 2024
1 parent acdea91 commit 760e603
Showing 1 changed file with 83 additions and 50 deletions.
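Previously the positional encoding was added either unscaled or with a fixed 0.1 factor; this commit introduces a linear warmup of the positional scale from initial_scale to final_scale over warmup_steps, multiplied by the 1/sqrt(embedding_dim) base factor when scale_embeddings is enabled. Below is a minimal, illustrative sketch of that schedule using the defaults from the diff; warmup_scale is a hypothetical stand-alone helper that mirrors get_warmup_scale further down, not part of the commit.

import numpy as np

def warmup_scale(step, warmup_steps=1000, initial_scale=0.1, final_scale=1.0):
    # Linear interpolation from initial_scale to final_scale over warmup_steps,
    # then held constant at final_scale (same logic as get_warmup_scale below).
    if step >= warmup_steps:
        return final_scale
    progress = step / warmup_steps
    return initial_scale + (final_scale - initial_scale) * progress

embedding_dim = 512
base = 1.0 / np.sqrt(embedding_dim)  # base_scale_factor when scale_embeddings=True

for step in (0, 250, 500, 1000, 2000):
    scale = warmup_scale(step)
    print(step, round(scale, 3), round(scale * base, 5))  # step, current_scale, effective_scale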
133 changes: 83 additions & 50 deletions neuralnetlib/layers.py
@@ -2684,6 +2684,9 @@ def __init__(
self,
max_sequence_length: int,
embedding_dim: int,
warmup_steps: int = 1000,
initial_scale: float = 0.1,
final_scale: float = 1.0,
scale_embeddings: bool = True,
trainable: bool = False,
random_state: int = None,
@@ -2692,27 +2695,32 @@ def __init__(
super().__init__()
self.max_sequence_length = max_sequence_length
self.embedding_dim = embedding_dim
self.warmup_steps = warmup_steps
self.initial_scale = initial_scale
self.final_scale = final_scale
self.scale_embeddings = scale_embeddings
self.trainable = trainable
self.random_state = random_state

self.current_step = 0
self.current_scale = initial_scale

        self.base_scale_factor = 1.0 if not scale_embeddings else 1.0 / np.sqrt(embedding_dim)

        self.weights = None
        self.d_weights = None
        self.seq_length = None

        if trainable:
            self.rng = np.random.default_rng(
                random_state if random_state is not None else int(time.time_ns())
            )
            self.initialize_weights()
        else:
            self._build_sinusoidal_encoding()

for key, value in kwargs.items():
setattr(self, key, value)

def _build_sinusoidal_encoding(self) -> None:
position = np.arange(self.max_sequence_length)[:, np.newaxis]
@@ -2723,74 +2731,99 @@ def _build_sinusoidal_encoding(self) -> None:
)

pe = np.zeros((self.max_sequence_length, self.embedding_dim), dtype=np.float32)

pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)

self.weights = pe[np.newaxis, :, :]

self.d_weights = np.zeros_like(self.weights)

    def initialize_weights(self) -> None:
        if self.trainable:
            self._build_sinusoidal_encoding()

            noise = self.rng.normal(
                0,
                0.02,
                self.weights.shape
            )
            self.weights = self.weights + noise
            self.d_weights = np.zeros_like(self.weights)

    def get_warmup_scale(self) -> float:
        if self.current_step >= self.warmup_steps:
            return self.final_scale

        progress = self.current_step / self.warmup_steps
        return self.initial_scale + (self.final_scale - self.initial_scale) * progress

    def forward_pass(self, input_data: np.ndarray) -> np.ndarray:
        batch_size, seq_len, _ = input_data.shape

        if seq_len > self.max_sequence_length:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds maximum length {self.max_sequence_length}"
            )

        pos_encoding = self.weights[:, :seq_len, :]

        if batch_size > 1:
            pos_encoding = np.repeat(pos_encoding, batch_size, axis=0)

        if self.trainable:
            self.current_scale = self.get_warmup_scale()
        else:
            self.current_scale = self.final_scale

        effective_scale = self.current_scale * self.base_scale_factor if self.scale_embeddings else self.current_scale

        output = input_data + (pos_encoding * effective_scale)

        self.metadata = {
            'step': self.current_step,
            'scale': self.current_scale,
            'effective_scale': effective_scale,
            'embedding_contribution': np.mean(np.abs(pos_encoding * effective_scale)) / np.mean(np.abs(input_data))
        }

        return output

    def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
        if self.trainable:
            _, seq_len, _ = output_error.shape

            effective_scale = self.current_scale * self.base_scale_factor if self.scale_embeddings else self.current_scale
            scaled_error = output_error * effective_scale

            self.d_weights[:, :seq_len, :] += np.sum(scaled_error, axis=0, keepdims=True)

        self.current_step += 1

        return output_error

    def get_config(self) -> dict:
        return {
            'name': self.__class__.__name__,
            'max_sequence_length': self.max_sequence_length,
            'embedding_dim': self.embedding_dim,
            'warmup_steps': self.warmup_steps,
            'initial_scale': self.initial_scale,
            'final_scale': self.final_scale,
            'scale_embeddings': self.scale_embeddings,
            'trainable': self.trainable,
            'random_state': self.random_state,
            'current_step': self.current_step,
            'current_scale': self.current_scale
        }

    @staticmethod
    def from_config(config: dict) -> "PositionalEncoding":
        layer = PositionalEncoding(
            max_sequence_length=config['max_sequence_length'],
            embedding_dim=config['embedding_dim'],
            warmup_steps=config.get('warmup_steps', 1000),
            initial_scale=config.get('initial_scale', 0.1),
            final_scale=config.get('final_scale', 1.0),
            scale_embeddings=config.get('scale_embeddings', True),
            trainable=config['trainable'],
            random_state=config['random_state']
        )
        if config.get('weights') is not None:
            layer.weights = np.array(config['weights'])
            layer.d_weights = np.zeros_like(layer.weights)
        layer.current_step = config.get('current_step', 0)
        layer.current_scale = config.get('current_scale', layer.initial_scale)
        return layer
    def __str__(self) -> str:
        return (
            f'PositionalEncoding('
            f'max_sequence_length={self.max_sequence_length}, '
            f'embedding_dim={self.embedding_dim}, '
            f'trainable={self.trainable}, '
            f'warmup_steps={self.warmup_steps}, '
            f'current_step={self.current_step}, '
            f'current_scale={self.current_scale:.4f})'
        )


class FeedForward(Layer):
Expand All @@ -2799,7 +2832,7 @@ def __init__(
d_ff: int,
d_model: int,
dropout_rate: float = 0.1,
        activation: str = 'gelu',
kernel_initializer: str = "glorot_uniform",
bias_initializer: str = "zeros",
random_state: int = None,
@@ -3152,7 +3185,7 @@ def __init__(
d_ff: int,
dropout_rate: float = 0.1,
attention_dropout: float = 0.0,
        activation: str = 'gelu',
kernel_initializer: str = "glorot_uniform",
bias_initializer: str = "zeros",
random_state: int | None = None,
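For orientation, a minimal usage sketch of the updated PositionalEncoding layer (not part of the commit; the import path and constructor arguments are assumed from the file and signature shown above):

import numpy as np
from neuralnetlib.layers import PositionalEncoding  # import path assumed from the file shown

pe = PositionalEncoding(
    max_sequence_length=128,
    embedding_dim=64,
    warmup_steps=1000,
    initial_scale=0.1,
    final_scale=1.0,
    trainable=True,
    random_state=42,
)

x = np.random.default_rng(0).normal(size=(2, 16, 64)).astype(np.float32)

for _ in range(3):  # a few training steps: the scale warms up as current_step grows
    out = pe.forward_pass(x)    # adds pos_encoding * effective_scale to the input
    grad = np.ones_like(out)    # dummy upstream gradient
    pe.backward_pass(grad)      # accumulates d_weights, advances current_step

print(pe.metadata['effective_scale'], pe.current_step)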
