
Commit 7a6dec9: Format

Fabian Degen committed Jan 7, 2025
1 parent ffc194b commit 7a6dec9
Showing 5 changed files with 19 additions and 8 deletions.
16 changes: 13 additions & 3 deletions transformer_lens/HookedEncoder.py
@@ -131,7 +131,11 @@ def forward(
task: str = None,
token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], Float[torch.Tensor, "batch 2"], str, List[str]]]:
) -> Optional[
Union[
Float[torch.Tensor, "batch pos d_vocab"], Float[torch.Tensor, "batch 2"], str, List[str]
]
]:
"""Forward pass through the HookedEncoder.
Args:
@@ -152,13 +156,15 @@ def forward(
"[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
[0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A,
`1` from Sentence B. If not provided, BERT assumes a single sequence input.
This parameter gets inferred from the tokenizer if input is a string or list of strings.
Shape is (batch_size, sequence_length).
one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates
which tokens should be attended to (1) and which should be ignored (0).
Primarily used for padding variable-length sentences in a batch.
For instance, in a batch with sentences of differing lengths, shorter
sentences are padded with 0s on the right. If not provided, the model
assumes all tokens should be attended to.
This parameter gets inferred from the tokenizer if input is a string or list of strings.
Shape is (batch_size, sequence_length).
Returns:
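
For readers following the docstring above, here is a minimal illustrative sketch (not part of this commit) of what token_type_ids and the attention mask look like when they come from a standard Hugging Face BERT tokenizer; the model name and the commented values are assumptions for illustration.

# Illustrative only: a BERT tokenizer producing token_type_ids and an
# attention mask for a two-sentence input (assumed model name).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
enc = tokenizer("Sentence A", "Sentence B", return_tensors="pt")

print(enc["token_type_ids"])  # 0s over "[CLS] Sentence A [SEP]", 1s over "Sentence B [SEP]"
print(enc["attention_mask"])  # all 1s here; padded positions in a batch would be 0
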
@@ -191,7 +197,7 @@ def forward(
raise ValueError(
"Next sentence prediction task requires exactly two sentences, please provide a list of strings with each sentence as an element."
)

# We need to input the two sentences separately for NSP
encodings = self.tokenizer(
input[0],
@@ -214,7 +220,11 @@ def forward(

# If token_type_ids or attention mask are not provided, use the ones from the tokenizer
token_type_ids = encodings.token_type_ids if token_type_ids is None else token_type_ids
one_zero_attention_mask = encodings.attention_mask if one_zero_attention_mask is None else one_zero_attention_mask
one_zero_attention_mask = (
encodings.attention_mask
if one_zero_attention_mask is None
else one_zero_attention_mask
)
else:
if task == "NSP" and token_type_ids is None:
raise ValueError(
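
As a hedged usage sketch of the NSP branch these hunks touch: the model name and the exact call pattern below are assumptions inferred only from the code visible in this diff, not taken from the repository.

from transformer_lens import HookedEncoder

bert = HookedEncoder.from_pretrained("bert-base-cased")

# NSP expects exactly two sentences, passed as a list with one sentence per
# element, per the ValueError raised above; token_type_ids and the attention
# mask are then inferred from the tokenizer, as the hunk at -214,7 +220,11 shows.
nsp_logits = bert(["The cat sat on the mat.", "It was very comfortable."], task="NSP")
print(nsp_logits.shape)  # expected (batch, 2), matching the return annotation above
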
2 changes: 1 addition & 1 deletion transformer_lens/HookedTransformer.py
@@ -744,7 +744,7 @@ def to_tokens(
self.tokenizer.padding_side. Specifies which side to pad when tokenizing
multiple strings of different lengths.
move_to_device (bool): Whether to move the output tensor of tokens to the device the
model lives on. Defaults to True
truncate (bool): If the output tokens are too long,
whether to truncate the output tokens to the model's max context window. Does nothing
for shorter inputs. Defaults to True.
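
A small illustrative sketch of the to_tokens options documented in this hunk; the model name and prompts are assumptions, and only arguments named in the docstring are used.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

tokens = model.to_tokens(
    ["a short prompt", "a somewhat longer prompt than the first"],
    move_to_device=True,  # place the token tensor on the model's device (default True)
    truncate=True,        # clip inputs past the model's context window (default True)
)
print(tokens.shape)  # (batch, pos); shorter strings are padded per tokenizer.padding_side
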
2 changes: 1 addition & 1 deletion transformer_lens/components/bert_embed.py
@@ -7,7 +7,7 @@
import einops
import torch
import torch.nn as nn
from jaxtyping import Int, Float
from jaxtyping import Float, Int

from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed
from transformer_lens.hook_points import HookPoint
5 changes: 3 additions & 2 deletions transformer_lens/components/bert_mlm_head.py
@@ -4,7 +4,6 @@
"""
from typing import Dict, Union

import einops
import torch
import torch.nn as nn
from jaxtyping import Float
@@ -26,7 +25,9 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
self.act_fn = nn.GELU()
self.ln = LayerNorm(self.cfg)

def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> Float[torch.Tensor, "batch pos d_model"]:
def forward(
self, resid: Float[torch.Tensor, "batch pos d_model"]
) -> Float[torch.Tensor, "batch pos d_model"]:
resid = torch.matmul(resid, self.W) + self.b
resid = self.act_fn(resid)
resid = self.ln(resid)
2 changes: 1 addition & 1 deletion transformer_lens/components/bert_pooler.py
@@ -14,7 +14,7 @@

class BertPooler(nn.Module):
"""
Transforms the [CLS] token representation into a fixed-size sequence embedding.
Transforms the [CLS] token representation into a fixed-size sequence embedding.
The purpose of this module is to convert variable-length sequence inputs into a single vector representation suitable for downstream tasks.
(e.g. Next Sentence Prediction)
"""
