Skip to content

Commit

Permalink
Merge pull request #74 from VectorInstitute/pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
[pre-commit.ci] pre-commit autoupdate
  • Loading branch information
amrit110 authored Aug 20, 2024
2 parents 6a2c470 + a45fc62 commit 3bb2f1c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: check-toml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.5.6'
rev: 'v0.6.1'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
12 changes: 8 additions & 4 deletions odyssey/data/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,13 +529,17 @@ def get_pretrain_test_split(
test_size=test_size,
)
else:
test_patients = dataset.sample(n=int(test_size * len(dataset)), random_state=seed)
test_patients = dataset.sample(
n=int(test_size * len(dataset)), random_state=seed
)
test_ids = test_patients["patient_id"].tolist()
pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)]["patient_id"].tolist()

pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)][
"patient_id"
].tolist()

random.seed(seed)
random.shuffle(pretrain_ids)

return pretrain_ids, test_ids


Expand Down
6 changes: 4 additions & 2 deletions odyssey/models/ehr_mamba/mamba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def forward(
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Returns
-------
"""
if inputs_embeds is not None:
sequence_outputs = self.backbone(
Expand Down Expand Up @@ -222,7 +223,8 @@ def forward(
task_indices (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Task indices to specify which classification head to use for each example in the batch.
Returns:
Returns
-------
"""
if inputs_embeds is not None:
sequence_outputs = self.backbone(
Expand Down

0 comments on commit 3bb2f1c

Please sign in to comment.