Updated data.py
afallah committed Mar 11, 2024
1 parent 69430fe commit 43a01ee
Showing 1 changed file with 0 additions and 25 deletions.
25 changes: 0 additions & 25 deletions lib/data.py
@@ -36,19 +36,11 @@ def __len__(self) -> int:
         return len(self.data)
 
     def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:
-<<<<<<< HEAD:models/big_bird_cehr/data.py
-        """ Tokenize the sequence and return input_ids and attention mask. """
-        return self.tokenizer(sequence)
-
-    def mask_tokens(self, sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """ Mask the tokens in the sequence using vectorized operations."""
-=======
         """Tokenize the sequence and return input_ids and attention mask."""
         return self.tokenizer(sequence, max_length=self.max_len)
 
     def mask_tokens(self, sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """Mask the tokens in the sequence using vectorized operations."""
->>>>>>> main:lib/data.py
         mask_token_id = self.tokenizer.get_mask_token_id()
 
         masked_sequence = sequence.clone()
@@ -63,12 +55,6 @@ def mask_tokens(self, sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         masked_sequence[replaced] = mask_token_id
 
         # 10% of the time, we replace masked input tokens with random vector.
-<<<<<<< HEAD:models/big_bird_cehr/data.py
-        randomized = torch.bernoulli(torch.full(selected.shape, 0.1)).bool() & selected & ~replaced
-        random_idx = torch.randint(low=self.tokenizer.get_first_token_index(),
-                                   high=self.tokenizer.get_last_token_index(),
-                                   size=prob_matrix.shape, dtype=torch.long)
-=======
         randomized = (
             torch.bernoulli(torch.full(selected.shape, 0.1)).bool()
             & selected
@@ -80,7 +66,6 @@ def mask_tokens(self, sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
             size=prob_matrix.shape,
             dtype=torch.long,
         )
->>>>>>> main:lib/data.py
         masked_sequence[randomized] = random_idx[randomized]
 
         labels = torch.where(selected, sequence, -100)
@@ -95,13 +80,8 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
         """
         data = self.data.iloc[idx]
         tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"])
-<<<<<<< HEAD:models/big_bird_cehr/data.py
-        concept_tokens = tokenized_input['input_ids'].squeeze()
-        attention_mask = tokenized_input['attention_mask'].squeeze()
-=======
         concept_tokens = tokenized_input["input_ids"].squeeze()
         attention_mask = tokenized_input["attention_mask"].squeeze()
->>>>>>> main:lib/data.py
 
         type_tokens = data[f"type_tokens_{self.max_len}"]
         age_tokens = data[f"age_tokens_{self.max_len}"]
@@ -161,13 +141,8 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
         """
         data = self.data.iloc[idx]
         tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"])
-<<<<<<< HEAD:models/big_bird_cehr/data.py
-        concept_tokens = tokenized_input['input_ids'].squeeze()
-        attention_mask = tokenized_input['attention_mask'].squeeze()
-=======
         concept_tokens = tokenized_input["input_ids"].squeeze()
         attention_mask = tokenized_input["attention_mask"].squeeze()
->>>>>>> main:lib/data.py
 
         type_tokens = data[f"type_tokens_{self.max_len}"]
         age_tokens = data[f"age_tokens_{self.max_len}"]
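For context, the code that survives the conflict resolution is a standard BERT-style masked-language-model corruption step: a fraction of positions is selected, most of those become the mask token, some become a random in-vocabulary token, the rest stay unchanged, and unselected positions receive the ignore label -100. Below is a minimal self-contained sketch of that scheme; it is an illustration only, not the repository's code. The function name and the explicit mask_token_id / vocabulary-bound parameters are hypothetical (the real class reads them from its tokenizer), and the 0.5-on-the-remainder draw follows the common Hugging Face recipe, whereas the diff's own code draws the random-replacement mask with probability 0.1, which yields a slightly different split.

from typing import Tuple

import torch


def mask_tokens_sketch(
    sequence: torch.Tensor,
    mask_token_id: int,
    vocab_low: int,
    vocab_high: int,
    mask_prob: float = 0.15,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Vectorized BERT-style masking over a 1-D tensor of token ids."""
    # Select ~mask_prob of all positions for corruption.
    prob_matrix = torch.full(sequence.shape, mask_prob)
    selected = torch.bernoulli(prob_matrix).bool()

    masked_sequence = sequence.clone()

    # 80% of selected positions become the mask token.
    replaced = torch.bernoulli(torch.full(selected.shape, 0.8)).bool() & selected
    masked_sequence[replaced] = mask_token_id

    # Half of the remaining selected positions (~10% of selected overall)
    # become a random token drawn from [vocab_low, vocab_high).
    randomized = (
        torch.bernoulli(torch.full(selected.shape, 0.5)).bool()
        & selected
        & ~replaced
    )
    random_idx = torch.randint(
        low=vocab_low, high=vocab_high, size=sequence.shape, dtype=torch.long
    )
    masked_sequence[randomized] = random_idx[randomized]

    # Unselected positions are ignored by the loss via the -100 label.
    labels = torch.where(selected, sequence, torch.full_like(sequence, -100))
    return masked_sequence, labels


if __name__ == "__main__":
    seq = torch.randint(5, 100, (16,))
    corrupted, labels = mask_tokens_sketch(seq, mask_token_id=4, vocab_low=5, vocab_high=100)
    print(corrupted, labels, sep="\n")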
