Implemented a decoder dataset for training Mamba.
Adibvafa committed May 1, 2024
1 parent cef4afe commit 17ebdb7
Showing 1 changed file with 189 additions and 0 deletions.
189 changes: 189 additions & 0 deletions odyssey/data/dataset.py
@@ -583,3 +583,192 @@ def balance_labels(self, task: str, positive_ratio: float) -> None:

        # Combine the kept negatives with all positives
        self.task_to_index[task] = positives + negatives_kept


class FinetuneDatasetDecoder(Dataset):
    """Dataset for finetuning a decoder-based model.

    Parameters
    ----------
    data : pd.DataFrame
        The input data containing sequences to be tokenized and masked.
    tokenizer : ConceptTokenizer
        An instance of the ConceptTokenizer class used for tokenizing sequences.
    tasks : List[str]
        A list of tasks (labels) that need to be predicted.
    balance_guide : Optional[Dict[str, float]], optional
        A dictionary containing the desired positive ratios for each task,
        by default None.
    max_len : int, optional
        The maximum length of the tokenized sequences, by default 2048.
    nan_indicator : int, optional
        Value used to represent missing labels in the dataset, by default -1.

    Attributes
    ----------
    data : pd.DataFrame
        Stores the input data.
    tokenizer : ConceptTokenizer
        Tokenizer used for tokenizing sequences.
    tasks : List[str]
        A list of tasks (labels) that need to be predicted.
    balance_guide : Optional[Dict[str, float]]
        A dictionary containing the desired positive ratios for each task.
    max_len : int
        Maximum length of the tokenized sequences.
    nan_indicator : int
        Value used to represent missing labels in the dataset.
    task_to_index : Dict[str, List[Tuple[int, str, int, Optional[int]]]]
        A dictionary mapping each task to a list of tuples containing the
        index, task, label, and cutoff; built during initialization and
        deleted once index_mapper has been constructed.
    index_mapper : List[Tuple[int, str, int, Optional[int]]]
        A list of all datapoints to be used by __getitem__.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: ConceptTokenizer,
        tasks: List[str],
        balance_guide: Optional[Dict[str, float]] = None,
        max_len: int = 2048,
        nan_indicator: int = -1,
    ):
        """Initialize the class."""
        super().__init__()

        self.data = data
        self.tokenizer = tokenizer
        self.tasks = tasks  # List of tasks for which the model is being finetuned.
        self.balance_guide = balance_guide
        self.max_len = max_len
        self.nan_indicator = (
            nan_indicator  # Value used to indicate missing data in labels.
        )

        # Precompute indices for quick mapping in __getitem__,
        # filtering out entries where the label is missing
        # for the specified tasks.
        self.task_to_index = {task: [] for task in self.tasks}
        self.data.reset_index(drop=True, inplace=True)

        for patient in self.data.itertuples():
            index = patient.Index

            for task in self.tasks:
                label_col = f"label_{task}"
                # Skip this task for the current patient if the label is missing.
                if getattr(patient, label_col) == self.nan_indicator:
                    continue

                label = getattr(patient, label_col)
                # Use a task-specific cutoff if the data provides one,
                # else None.
                if f"cutoff_{task}" in self.data.columns:
                    cutoff = getattr(patient, f"cutoff_{task}")
                else:
                    cutoff = None
                # Append a tuple containing the necessary information
                # for training to task_to_index.
                datapoint = (index, task, label, cutoff)
                self.task_to_index[task].append(datapoint)

        # Balance labels for the specified tasks.
        if self.balance_guide:
            for task in self.balance_guide:
                self.balance_labels(task=task, positive_ratio=self.balance_guide[task])
        # Create a list of all datapoints to be used by __getitem__.
        self.index_mapper = [
            datapoints
            for task_data in self.task_to_index.values()
            for datapoints in task_data
        ]
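        # The per-task mapping is no longer needed once index_mapper is built.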
        del self.task_to_index

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.index_mapper)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get data at corresponding index.

        Parameters
        ----------
        idx : int
            The index of the data to be retrieved.

        Returns
        -------
        Dict[str, Any]
            A dictionary containing all different token sequences along with labels.
        """
        index, task, labels, cutoff = self.index_mapper[idx]
        data = self.data.iloc[index]

        # Replace the first and last tokens with the task token.
        data[f"event_tokens_{self.max_len}"][0] = self.tokenizer.task_to_token(task)
        data[f"event_tokens_{self.max_len}"][-1] = self.tokenizer.task_to_token(task)

        # Truncate the sequence at the task cutoff and pad it to max_len.
        data = truncate_and_pad(data, cutoff, self.max_len)

        # Prepare model input.
        tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"])
        concept_ids = tokenized_input["input_ids"].squeeze()
        labels = torch.tensor(labels)

        return {
            "concept_ids": concept_ids,
            "labels": labels,
            "task": task,
        }

    def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:
        """Tokenize the sequence and return input_ids and attention mask.

        Parameters
        ----------
        sequence : Union[str, List[str]]
            The sequence to be tokenized.

        Returns
        -------
        Any
            A dictionary containing input_ids and attention_mask.
        """
        return self.tokenizer(sequence, max_length=self.max_len)

    def balance_labels(self, task: str, positive_ratio: float) -> None:
        """Balance the labels for the specified task in the dataset.

        This function modifies the dataset so that the ratio of positive samples
        to the total number of samples matches the specified positive_ratio,
        while keeping all positive datapoints.

        Parameters
        ----------
        task : str
            The task for which the labels need to be balanced.
        positive_ratio : float
            The desired positive ratio for the task.
        """
        # Separate positive and negative datapoints.
        datapoints = self.task_to_index[task]
        positives = [data for data in datapoints if data[LABEL_INDEX] == 1]
        negatives = [data for data in datapoints if data[LABEL_INDEX] == 0]

        # Calculate the number of negatives needed to achieve the
        # desired positive ratio.
        num_positives = len(positives)
        total_needed = int(num_positives / positive_ratio) - num_positives
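        # Example: with 100 positives and positive_ratio = 0.25,
        # total_needed = int(100 / 0.25) - 100 = 300 negatives.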
        num_negatives_to_keep = min(len(negatives), total_needed)

        # Randomly select the negatives to keep.
        negatives_kept = random.sample(negatives, num_negatives_to_keep)

        # Combine the kept negatives with all positives.
        self.task_to_index[task] = positives + negatives_kept
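
A minimal usage sketch (not part of the commit; the DataFrame, task names, and
tokenizer setup below are illustrative assumptions, not the project's actual
schema):

from torch.utils.data import DataLoader

# Hypothetical inputs: a preprocessed patient DataFrame with
# event_tokens_2048, label_<task>, and cutoff_<task> columns,
# plus a fitted ConceptTokenizer.
dataset = FinetuneDatasetDecoder(
    data=patient_df,
    tokenizer=tokenizer,
    tasks=["mortality_1month"],  # illustrative task name
    balance_guide={"mortality_1month": 0.5},  # keep a 50% positive ratio
    max_len=2048,
)

loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
# batch["concept_ids"]: (32, 2048) token IDs; batch["labels"]: (32,) targets;
# batch["task"]: the task name for each sample.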
