Skip to content

Commit

Permalink
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 12, 2024
1 parent 46956d7 commit 38df916
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
12 changes: 8 additions & 4 deletions odyssey/data/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,13 +529,17 @@ def get_pretrain_test_split(
test_size=test_size,
)
else:
test_patients = dataset.sample(n=int(test_size * len(dataset)), random_state=seed)
test_patients = dataset.sample(
n=int(test_size * len(dataset)), random_state=seed
)
test_ids = test_patients["patient_id"].tolist()
pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)]["patient_id"].tolist()

pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)][
"patient_id"
].tolist()

random.seed(seed)
random.shuffle(pretrain_ids)

return pretrain_ids, test_ids


Expand Down
6 changes: 4 additions & 2 deletions odyssey/models/ehr_mamba/mamba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def forward(
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Returns
-------
"""
if inputs_embeds is not None:
sequence_outputs = self.backbone(
Expand Down Expand Up @@ -222,7 +223,8 @@ def forward(
task_indices (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Task indices to specify which classification head to use for each example in the batch.
Returns:
Returns
-------
"""
if inputs_embeds is not None:
sequence_outputs = self.backbone(
Expand Down

0 comments on commit 38df916

Please sign in to comment.