Interp #25

Merged · 9 commits · Apr 18, 2024

Changes from 7 commits
7 changes: 6 additions & 1 deletion odyssey/data/dataset.py
@@ -281,6 +281,11 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"attention_mask": attention_mask,
}

def __iter__(self) -> Any:
"""Return an iterator over the dataset."""
for i in range(len(self)):
yield self[i]


class FinetuneMultiDataset(Dataset):
"""Dataset for finetuning the model on multi dataset.
@@ -408,7 +413,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
data = self.data.iloc[index]

# Swap the first token with the task token.
data["event_tokens_2048"][0] = self.tokenizer.task_to_token(task)
data[f"event_tokens_{self.max_len}"][0] = self.tokenizer.task_to_token(task)

# Truncate and pad the data to the specified cutoff.
data = truncate_and_pad(data, cutoff, self.max_len)
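The new `__iter__` lets a dataset instance be consumed directly in a `for` loop, which is convenient for interpretability passes that walk samples in order. Below is a minimal, self-contained sketch of the same iteration contract using a toy stand-in class; the real dataset class in `odyssey/data/dataset.py` returns dicts of token tensors and is not reproduced here.

```python
# Toy stand-in illustrating the iteration pattern added in this PR; names and
# the returned dict are illustrative assumptions, not the actual dataset.
from typing import Any, Dict, Iterator

from torch.utils.data import Dataset


class ToyDataset(Dataset):
    def __init__(self, num_samples: int) -> None:
        self.num_samples = num_samples

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return {"index": idx}

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        # Same pattern as the PR: yield samples in index order.
        for i in range(len(self)):
            yield self[i]


for sample in ToyDataset(3):
    print(sample)  # {'index': 0}, {'index': 1}, {'index': 2}
```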
125 changes: 96 additions & 29 deletions odyssey/data/tokenizer.py
@@ -20,20 +20,20 @@ def truncate_and_pad(
"""Truncate and pad the input row to the maximum length.

This function assumes the presence of the following columns in row:
- 'event_tokens_2048'
- 'type_tokens_2048'
- 'age_tokens_2048'
- 'time_tokens_2048'
- 'visit_tokens_2048'
- 'position_tokens_2048'
- 'elapsed_tokens_2048'
- 'event_tokens_{max_len}'
- 'type_tokens_{max_len}'
- 'age_tokens_{max_len}'
- 'time_tokens_{max_len}'
- 'visit_tokens_{max_len}'
- 'position_tokens_{max_len}'
- 'elapsed_tokens_{max_len}'

Parameters
----------
row: pd.Series
The input row.
cutoff: Optional[int]
The cutoff length. Will be set to length of 'event_tokens_2048' if None.
The cutoff length. Will be set to length of 'event_tokens_{max_len}' if None.
max_len: int
The maximum length to pad to.

@@ -47,41 +47,42 @@ def truncate_and_pad(
row = row.copy()

if not cutoff:
cutoff = len(row["event_tokens_2048"])
cutoff = min(max_len, len(row[f"event_tokens_{max_len}"]))

row["event_tokens_2048"] = row["event_tokens_2048"][:cutoff]
row["type_tokens_2048"] = np.pad(
row["type_tokens_2048"][:cutoff],
row[f"event_tokens_{max_len}"] = row[f"event_tokens_{max_len}"][:cutoff]

row[f"type_tokens_{max_len}"] = np.pad(
row[f"type_tokens_{max_len}"][:cutoff],
(0, max_len - cutoff),
mode="constant",
)
row["age_tokens_2048"] = np.pad(
row["age_tokens_2048"][:cutoff],
row[f"age_tokens_{max_len}"] = np.pad(
row[f"age_tokens_{max_len}"][:cutoff],
(0, max_len - cutoff),
mode="constant",
)
row["time_tokens_2048"] = np.pad(
row["time_tokens_2048"][:cutoff],
row[f"time_tokens_{max_len}"] = np.pad(
row[f"time_tokens_{max_len}"][:cutoff],
(0, max_len - cutoff),
mode="constant",
)
row["visit_tokens_2048"] = np.pad(
row["visit_tokens_2048"][:cutoff],
row[f"visit_tokens_{max_len}"] = np.pad(
row[f"visit_tokens_{max_len}"][:cutoff],
(0, max_len - cutoff),
mode="constant",
)
row["position_tokens_2048"] = np.pad(
row["position_tokens_2048"][:cutoff],
row[f"position_tokens_{max_len}"] = np.pad(
row[f"position_tokens_{max_len}"][:cutoff],
(0, max_len - cutoff),
mode="constant",
)
row["elapsed_tokens_2048"] = np.pad(
row["elapsed_tokens_2048"][:cutoff],
row[f"elapsed_tokens_{max_len}"] = np.pad(
row[f"elapsed_tokens_{max_len}"][:cutoff],
(0, max_len - cutoff),
mode="constant",
)

row["event_tokens_2048"] = " ".join(row["event_tokens_2048"])
row[f"event_tokens_{max_len}"] = " ".join(row[f"event_tokens_{max_len}"])

return row
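
For orientation, a hedged usage sketch of the reworked `truncate_and_pad`: the row columns are now keyed by `max_len` instead of a hard-coded 2048, the event tokens are truncated to the cutoff and joined into a whitespace-separated string, and the numeric token columns are zero-padded out to `max_len`. The synthetic row, its values, and the import path below are illustrative assumptions, not part of the PR.

```python
# Hedged usage sketch; assumes the odyssey package is importable and that the
# row carries the seven max_len-keyed token columns the function expects.
import numpy as np
import pandas as pd

from odyssey.data.tokenizer import truncate_and_pad

max_len = 8
row = pd.Series(
    {
        f"event_tokens_{max_len}": ["[CLS]", "[VS]", "E1", "E2", "[VE]"],
        f"type_tokens_{max_len}": np.array([1, 2, 3, 3, 2]),
        f"age_tokens_{max_len}": np.array([0, 60, 60, 60, 60]),
        f"time_tokens_{max_len}": np.array([0, 10, 10, 10, 10]),
        f"visit_tokens_{max_len}": np.array([0, 1, 1, 1, 1]),
        f"position_tokens_{max_len}": np.array([0, 1, 1, 1, 1]),
        f"elapsed_tokens_{max_len}": np.array([0, 0, 1, 2, 3]),
    }
)

padded = truncate_and_pad(row, cutoff=None, max_len=max_len)
# Event tokens become a single space-joined string; numeric token columns are
# right-padded with zeros up to max_len.
print(padded[f"event_tokens_{max_len}"])      # "[CLS] [VS] E1 E2 [VE]"
print(len(padded[f"type_tokens_{max_len}"]))  # 8
```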

@@ -142,8 +143,12 @@ def __init__(
tokenizer_object: Optional[Tokenizer] = None,
tokenizer: Optional[PreTrainedTokenizerFast] = None,
) -> None:
self.mask_token = mask_token
self.pad_token = pad_token
self.mask_token = mask_token
self.start_token = start_token
self.end_token = end_token
self.class_token = class_token
self.reg_token = reg_token
self.unknown_token = unknown_token
self.task_tokens = ["[MOR_1M]", "[LOS_1W]", "[REA_1M]"] + [
f"[C{i}]" for i in range(0, 5)
@@ -178,7 +183,7 @@ def __init__(
self.first_token_index: Optional[int] = None
self.last_token_index: Optional[int] = None

def fit_on_vocab(self) -> None:
def fit_on_vocab(self, with_tasks: bool = True) -> None:
"""Fit the tokenizer on the vocabulary."""
# Create dictionary of all possible medical concepts
self.token_type_vocab["special_tokens"] = self.special_tokens
@@ -190,12 +195,13 @@ def fit_on_vocab(self) -> None:
vocab_type = file.split("/")[-1].split(".")[0]
self.token_type_vocab[vocab_type] = vocab

self.token_type_vocab["task_tokens"] = self.task_tokens
if with_tasks:
self.special_tokens += self.task_tokens
self.token_type_vocab["task_tokens"] = self.task_tokens

# Create the tokenizer dictionary
tokens = list(chain.from_iterable(list(self.token_type_vocab.values())))
self.tokenizer_vocab = {token: i for i, token in enumerate(tokens)}
self.special_tokens += self.task_tokens

# Create the tokenizer object
self.tokenizer_object = Tokenizer(
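
A hedged sketch of how the new `with_tasks` flag might be used: a pretraining tokenizer can skip the task tokens entirely, while a finetuning tokenizer registers them as special tokens with vocabulary ids. The `data_dir` value is a placeholder and assumes the vocabulary JSON files are present there.

```python
# Placeholder data_dir; fit_on_vocab reads the vocabulary files from it.
from odyssey.data.tokenizer import ConceptTokenizer

pretrain_tokenizer = ConceptTokenizer(data_dir="data_files")
pretrain_tokenizer.fit_on_vocab(with_tasks=False)  # no [MOR_1M]/[LOS_1W]/... ids

finetune_tokenizer = ConceptTokenizer(data_dir="data_files")
finetune_tokenizer.fit_on_vocab(with_tasks=True)   # task tokens added as special tokens
```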
@@ -445,7 +451,7 @@ def get_special_token_ids(self) -> List[int]:

return self.special_token_ids

def save_tokenizer_to_disk(self, save_dir: str) -> None:
def save(self, save_dir: str) -> None:
"""Save the tokenizer object to disk as a JSON file.

Parameters
@@ -454,7 +460,68 @@ def save_tokenizer_to_disk(self, save_dir: str) -> None:
Directory to save the tokenizer.

"""
self.tokenizer.save(path=save_dir)
os.makedirs(save_dir, exist_ok=True)
tokenizer_config = {
"pad_token": self.pad_token,
"mask_token": self.mask_token,
"unknown_token": self.unknown_token,
"start_token": self.start_token,
"end_token": self.end_token,
"class_token": self.class_token,
"reg_token": self.reg_token,
"special_tokens": self.special_tokens,
"tokenizer_vocab": self.tokenizer_vocab,
"token_type_vocab": self.token_type_vocab,
"data_dir": self.data_dir,
}
save_path = os.path.join(save_dir, "tokenizer.json")
with open(save_path, "w") as file:
json.dump(tokenizer_config, file, indent=4)

@classmethod
def load(cls, load_path: str) -> "ConceptTokenizer":
"""
Load the tokenizer configuration from a file.

Parameters
----------
load_path : str
The path from where the tokenizer configuration will be loaded.

Returns
-------
ConceptTokenizer
An instance of ConceptTokenizer initialized with the loaded configuration.
"""
with open(load_path, "r") as file:
tokenizer_config = json.load(file)

tokenizer = cls(
pad_token=tokenizer_config["pad_token"],
mask_token=tokenizer_config["mask_token"],
unknown_token=tokenizer_config["unknown_token"],
start_token=tokenizer_config.get("start_token", "[VS]"),
end_token=tokenizer_config.get("end_token", "[VE]"),
class_token=tokenizer_config.get("class_token", "[CLS]"),
reg_token=tokenizer_config.get("reg_token", "[REG]"),
data_dir=tokenizer_config.get("data_dir", "data_files"),
)

tokenizer.special_tokens = tokenizer_config["special_tokens"]
tokenizer.tokenizer_vocab = tokenizer_config["tokenizer_vocab"]
tokenizer.token_type_vocab = tokenizer_config["token_type_vocab"]

tokenizer.tokenizer_object = Tokenizer(
models.WordPiece(
vocab=tokenizer.tokenizer_vocab,
unk_token=tokenizer.unknown_token,
max_input_chars_per_word=1000,
),
)
tokenizer.tokenizer_object.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
tokenizer.tokenizer = tokenizer.create_tokenizer(tokenizer.tokenizer_object)

return tokenizer
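
A hedged round-trip sketch for the new JSON-based `save`/`load`: `save` writes a `tokenizer.json` file into `save_dir`, and `load` takes the full path to that file. The directory names and `data_dir` below are placeholders, not values from the PR.

```python
# Placeholder paths; assumes the vocabulary files exist under data_files.
from odyssey.data.tokenizer import ConceptTokenizer

tokenizer = ConceptTokenizer(data_dir="data_files")
tokenizer.fit_on_vocab()
tokenizer.save("checkpoints/tokenizer")  # writes checkpoints/tokenizer/tokenizer.json

restored = ConceptTokenizer.load("checkpoints/tokenizer/tokenizer.json")
assert restored.tokenizer_vocab == tokenizer.tokenizer_vocab
```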

def create_task_to_token_dict(self) -> Dict[str, str]:
"""Create a dictionary mapping each task to its respective special token.
1 change: 1 addition & 0 deletions odyssey/interp/__init__.py
@@ -0,0 +1 @@
"""Interpretability sub-package."""