diff --git a/neuralnetlib/layers.py b/neuralnetlib/layers.py
index dea8059..fc9569a 100644
--- a/neuralnetlib/layers.py
+++ b/neuralnetlib/layers.py
@@ -2684,6 +2684,9 @@ def __init__(
         self,
         max_sequence_length: int,
         embedding_dim: int,
+        warmup_steps: int = 1000,
+        initial_scale: float = 0.1,
+        final_scale: float = 1.0,
         scale_embeddings: bool = True,
         trainable: bool = False,
         random_state: int = None,
@@ -2692,27 +2695,32 @@ def __init__(
         super().__init__()
         self.max_sequence_length = max_sequence_length
         self.embedding_dim = embedding_dim
+        self.warmup_steps = warmup_steps
+        self.initial_scale = initial_scale
+        self.final_scale = final_scale
         self.scale_embeddings = scale_embeddings
         self.trainable = trainable
         self.random_state = random_state
+
+        self.current_step = 0
+        self.current_scale = initial_scale
+
+        self.base_scale_factor = 1.0 if not scale_embeddings else 1.0 / np.sqrt(embedding_dim)
+
         self.weights = None
        self.d_weights = None
         self.seq_length = None
-        self.scale_factor = 1.0 if not scale_embeddings else 1.0 / np.sqrt(embedding_dim)
-
+
         if trainable:
-            self.rng = np.random.default_rng(random_state if random_state is not None else int(time.time_ns()))
+            self.rng = np.random.default_rng(
+                random_state if random_state is not None else int(time.time_ns())
+            )
             self.initialize_weights()
         else:
             self._build_sinusoidal_encoding()
-
+
         for key, value in kwargs.items():
             setattr(self, key, value)
-
-    def __str__(self) -> str:
-        return (f'PositionalEncoding(seq_length={self.seq_length}, '
-                f'embedding_dim={self.embedding_dim}, trainable={self.trainable}, '
-                f'scale_embeddings={self.scale_embeddings})')
 
     def _build_sinusoidal_encoding(self) -> None:
         position = np.arange(self.max_sequence_length)[:, np.newaxis]
@@ -2723,74 +2731,99 @@ def _build_sinusoidal_encoding(self) -> None:
         )
         pe = np.zeros((self.max_sequence_length, self.embedding_dim), dtype=np.float32)
-
         pe[:, 0::2] = np.sin(position * div_term)
         pe[:, 1::2] = np.cos(position * div_term)
         self.weights = pe[np.newaxis, :, :]
-        self.d_weights = np.zeros_like(self.weights)
 
     def initialize_weights(self) -> None:
         if self.trainable:
-            limit = 0.02
-            self.weights = self.rng.uniform(
-                -limit, limit,
-                (1, self.max_sequence_length, self.embedding_dim)
+            self._build_sinusoidal_encoding()
+
+            noise = self.rng.normal(
+                0,
+                0.02,
+                self.weights.shape
             )
+            self.weights = self.weights + noise
             self.d_weights = np.zeros_like(self.weights)
 
-    def forward_pass(self, input_data: np.ndarray) -> np.ndarray:
-        batch_size, seq_len, _ = input_data.shape
+    def get_warmup_scale(self) -> float:
+        if self.current_step >= self.warmup_steps:
+            return self.final_scale
 
-        if seq_len > self.max_sequence_length:
-            raise ValueError(
-                f"Input sequence length {seq_len} exceeds maximum length {self.max_sequence_length}"
-            )
-
-        pos_encoding = self.weights[:, :seq_len, :]
-
-        if batch_size > 1:
-            pos_encoding = np.repeat(pos_encoding, batch_size, axis=0)
+        progress = self.current_step / self.warmup_steps
+        return self.initial_scale + (self.final_scale - self.initial_scale) * progress
+
+    def forward_pass(self, input_data: np.ndarray) -> np.ndarray:
+        batch_size, seq_len, _ = input_data.shape
 
-        if self.scale_embeddings:
-            return pos_encoding + input_data
-        else:
-            return input_data + (pos_encoding * 0.1)
+        if seq_len > self.max_sequence_length:
+            raise ValueError(
+                f"Input sequence length {seq_len} exceeds maximum length {self.max_sequence_length}"
+            )
+
+        pos_encoding = self.weights[:, :seq_len, :]
+
+        if batch_size > 1:
+            pos_encoding = np.repeat(pos_encoding, batch_size, axis=0)
+
+        if self.trainable:
+            self.current_scale = self.get_warmup_scale()
+        else:
+            self.current_scale = self.final_scale
+
+        effective_scale = self.current_scale * self.base_scale_factor if self.scale_embeddings else self.current_scale
+
+        output = input_data + (pos_encoding * effective_scale)
+
+        self.metadata = {
+            'step': self.current_step,
+            'scale': self.current_scale,
+            'effective_scale': effective_scale,
+            'embedding_contribution': np.mean(np.abs(pos_encoding * effective_scale)) / np.mean(np.abs(input_data))
+        }
+
+        return output
 
     def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
         if self.trainable:
             _, seq_len, _ = output_error.shape
-            if self.scale_embeddings:
-                output_error = output_error * self.scale_factor
-
-            self.d_weights[:, :seq_len, :] += np.sum(output_error, axis=0, keepdims=True)
+            effective_scale = self.current_scale * self.base_scale_factor if self.scale_embeddings else self.current_scale
+            scaled_error = output_error * effective_scale
+            self.d_weights[:, :seq_len, :] += np.sum(scaled_error, axis=0, keepdims=True)
+
+        self.current_step += 1
+
         return output_error
 
     def get_config(self) -> dict:
         return {
-            'name': self.__class__.__name__,
-            'seq_length': self.seq_length,
+            'max_sequence_length': self.max_sequence_length,
             'embedding_dim': self.embedding_dim,
+            'warmup_steps': self.warmup_steps,
+            'initial_scale': self.initial_scale,
+            'final_scale': self.final_scale,
+            'scale_embeddings': self.scale_embeddings,
             'trainable': self.trainable,
             'random_state': self.random_state,
-            'weights': self.weights.tolist() if self.weights is not None else None
+            'current_step': self.current_step,
+            'current_scale': self.current_scale
         }
 
-    @staticmethod
-    def from_config(config: dict) -> "PositionalEncoding":
-        layer = PositionalEncoding(
-            seq_length=config['seq_length'],
-            embedding_dim=config['embedding_dim'],
-            trainable=config['trainable'],
-            random_state=config['random_state']
+    def __str__(self) -> str:
+        return (
+            f'PositionalEncodingWithWarmup('
+            f'seq_length={self.seq_length}, '
+            f'embedding_dim={self.embedding_dim}, '
+            f'trainable={self.trainable}, '
+            f'warmup_steps={self.warmup_steps}, '
+            f'current_step={self.current_step}, '
+            f'current_scale={self.current_scale:.4f})'
         )
-        if config['weights'] is not None:
-            layer.weights = np.array(config['weights'])
-            layer.d_weights = np.zeros_like(layer.weights)
-        return layer
 
 
 class FeedForward(Layer):
@@ -2799,7 +2832,7 @@ def __init__(
         self,
         d_ff: int,
         d_model: int,
         dropout_rate: float = 0.1,
-        activation: str = 'relu',
+        activation: str = 'gelu',
         kernel_initializer: str = "glorot_uniform",
         bias_initializer: str = "zeros",
         random_state: int = None,
@@ -3152,7 +3185,7 @@ def __init__(
         self,
         d_ff: int,
         dropout_rate: float = 0.1,
         attention_dropout: float = 0.0,
-        activation: str = 'relu',
+        activation: str = 'gelu',
         kernel_initializer: str = "glorot_uniform",
         bias_initializer: str = "zeros",
         random_state: int | None = None,
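
For reference, a minimal standalone sketch of the linear warmup schedule this patch introduces. It mirrors get_warmup_scale and the effective-scale computation in forward_pass; the embedding_dim and step values below are illustrative only, not taken from the patch.

import numpy as np

def warmup_scale(step: int, warmup_steps: int = 1000,
                 initial_scale: float = 0.1, final_scale: float = 1.0) -> float:
    # Same rule as the new get_warmup_scale: linear ramp from
    # initial_scale to final_scale over warmup_steps, constant afterwards.
    if step >= warmup_steps:
        return final_scale
    progress = step / warmup_steps
    return initial_scale + (final_scale - initial_scale) * progress

# Effective scale as computed in forward_pass when scale_embeddings=True:
# current_scale * base_scale_factor, with base_scale_factor = 1/sqrt(embedding_dim).
embedding_dim = 512  # illustrative value
base_scale_factor = 1.0 / np.sqrt(embedding_dim)
for step in (0, 250, 500, 1000, 2000):
    scale = warmup_scale(step)
    print(f"step={step:5d}  scale={scale:.3f}  effective={scale * base_scale_factor:.5f}")

With the defaults, the positional contribution ramps from 10% to 100% of its base magnitude over the first 1000 steps, then stays constant.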
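A quick finite-difference check of the backward rule the patch adopts: since forward_pass now computes output = input_data + pos_encoding * effective_scale, the weight gradient is the upstream error times effective_scale summed over the batch, and the input gradient passes through unchanged. Shapes and values here are illustrative, not from the patch.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 16))           # (batch, seq_len, embedding_dim)
pe = rng.normal(size=(1, 8, 16))          # positional table, broadcast over batch
effective_scale = 0.55 / np.sqrt(16)      # current_scale * base_scale_factor

output_error = rng.normal(size=x.shape)   # upstream dL/d(output)

# Backward rule from the patch: scale the error, sum over the batch axis.
d_weights = np.sum(output_error * effective_scale, axis=0, keepdims=True)

# Numerical check at one coordinate of the table.
def loss(p: np.ndarray) -> float:
    return float(np.sum((x + p * effective_scale) * output_error))

eps, idx = 1e-6, (0, 3, 7)
pe_plus = pe.copy()
pe_plus[idx] += eps
numeric = (loss(pe_plus) - loss(pe)) / eps
print(np.isclose(numeric, d_weights[idx]))  # True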