
Commit 577b4a9

Merge branch 'main' into dependabot/pip/aiohttp-3.10.2
amrit110 authored Aug 20, 2024
2 parents: a2aa8dd + 3e7400e
Showing 4 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/code_checks.yml
@@ -43,6 +43,6 @@ jobs:
           poetry install --with test --all-extras
           pre-commit run --all-files
       - name: pip-audit (gh-action-pip-audit)
-        uses: pypa/gh-action-pip-audit@v1.0.8
+        uses: pypa/gh-action-pip-audit@v1.1.0
         with:
           virtual-environment: .venv/
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -16,7 +16,7 @@ repos:
       - id: check-toml

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: 'v0.5.6'
+    rev: 'v0.6.1'
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
8 changes: 5 additions & 3 deletions odyssey/data/processor.py
@@ -528,14 +528,16 @@ def get_pretrain_test_split(
             target=stratify_target,
             test_size=test_size,
         )
+
     else:
-        test_patients = dataset.sample(n=test_size, random_state=seed)
+        test_patients = dataset.sample(
+            n=int(test_size * len(dataset)), random_state=seed
+        )
         test_ids = test_patients["patient_id"].tolist()
-        pretrain_ids = dataset[~dataset["patient_id"].isin(test_patients)][
+        pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)][
             "patient_id"
         ].tolist()

     random.seed(seed)
     random.shuffle(pretrain_ids)

     return pretrain_ids, test_ids
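
Note on the odyssey/data/processor.py change: the non-stratified branch previously passed test_size straight to DataFrame.sample as an absolute row count and filtered pretraining patients against the sampled DataFrame rather than the extracted ID list. The fix treats test_size as a fraction of the dataset and filters on the patient IDs. Below is a minimal sketch of the corrected branch, assuming a pandas DataFrame with a patient_id column and a fractional test_size; the function name, defaults, and the omitted stratified branch are illustrative, not the repository's exact signature.

import random

import pandas as pd


def sample_pretrain_test_ids(dataset: pd.DataFrame, test_size: float = 0.2, seed: int = 23):
    # Draw a fraction of the rows for the test set; test_size is a ratio, not a row count.
    test_patients = dataset.sample(n=int(test_size * len(dataset)), random_state=seed)
    test_ids = test_patients["patient_id"].tolist()
    # Keep every patient that was not drawn for the test set.
    pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)]["patient_id"].tolist()
    # Shuffle the pretraining IDs deterministically, matching the diff's use of random.seed.
    random.seed(seed)
    random.shuffle(pretrain_ids)
    return pretrain_ids, test_ids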
6 changes: 4 additions & 2 deletions odyssey/models/ehr_mamba/mamba_utils.py
@@ -112,7 +112,8 @@ def forward(
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

-        Returns: # noqa: D407
+        Returns
+        -------
         """
         if inputs_embeds is not None:
             sequence_outputs = self.backbone(
@@ -222,7 +223,8 @@ def forward(
         task_indices (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Task indices to specify which classification head to use for each example in the batch.

-        Returns: # noqa: D407
+        Returns
+        -------
         """
         if inputs_embeds is not None:
             sequence_outputs = self.backbone(
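
Note on the odyssey/models/ehr_mamba/mamba_utils.py change: both forward docstrings previously used a bare "Returns:" heading and suppressed pydocstyle's D407 (missing dashed underline after section) with a noqa comment. Adding the NumPy-style underline satisfies the rule, so the suppression can be dropped. A minimal, hypothetical illustration of the convention follows; the return type and description are placeholders, not taken from the repository.

# Hypothetical sketch of the NumPy-style "Returns" section; names and types are placeholders.
def forward(self, input_ids, inputs_embeds=None):
    """Run a forward pass through a classification head.

    Returns
    -------
    SequenceClassifierOutput
        Loss, logits, and hidden states for the batch.
    """
    ...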
