diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ec656db..c24db4e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,7 +16,7 @@ repos:
       - id: check-toml
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: 'v0.5.6'
+    rev: 'v0.6.1'
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
diff --git a/odyssey/data/processor.py b/odyssey/data/processor.py
index 9564b1e..e4c7d50 100644
--- a/odyssey/data/processor.py
+++ b/odyssey/data/processor.py
@@ -529,13 +529,17 @@ def get_pretrain_test_split(
             test_size=test_size,
         )
     else:
-        test_patients = dataset.sample(n=int(test_size * len(dataset)), random_state=seed)
+        test_patients = dataset.sample(
+            n=int(test_size * len(dataset)), random_state=seed
+        )
         test_ids = test_patients["patient_id"].tolist()
-        pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)]["patient_id"].tolist()
-
+        pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)][
+            "patient_id"
+        ].tolist()
+
     random.seed(seed)
     random.shuffle(pretrain_ids)
-
+
     return pretrain_ids, test_ids
diff --git a/odyssey/models/ehr_mamba/mamba_utils.py b/odyssey/models/ehr_mamba/mamba_utils.py
index de227f6..bcd3645 100644
--- a/odyssey/models/ehr_mamba/mamba_utils.py
+++ b/odyssey/models/ehr_mamba/mamba_utils.py
@@ -112,7 +112,8 @@ def forward(
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
             (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        Returns:
+        Returns
+        -------
         """
         if inputs_embeds is not None:
             sequence_outputs = self.backbone(
@@ -222,7 +223,8 @@ def forward(
         task_indices (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Task indices to specify which classification head to use for each example in the batch.
-        Returns:
+        Returns
+        -------
         """
         if inputs_embeds is not None:
             sequence_outputs = self.backbone(
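
For reference, the processor.py changes above are line-wrapping only; the split behavior is unchanged. Below is a minimal standalone sketch (not part of the PR) of that pretrain/test split logic on a toy DataFrame. Only the `patient_id` column name, the `dataset.sample` call, and the seeded shuffle come from the diff; the DataFrame contents, `test_size`, and `seed` values are illustrative assumptions.

```python
# Standalone sketch of the split logic touched by this diff (toy data, not the
# actual odyssey.data.processor function).
import random

import pandas as pd

dataset = pd.DataFrame({"patient_id": [f"p{i}" for i in range(10)]})
test_size, seed = 0.2, 23  # illustrative values

# Sample a fraction of patients as the held-out test set.
test_patients = dataset.sample(n=int(test_size * len(dataset)), random_state=seed)
test_ids = test_patients["patient_id"].tolist()

# Every patient not in the test set goes into the pretraining split,
# shuffled deterministically with the same seed.
pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)]["patient_id"].tolist()
random.seed(seed)
random.shuffle(pretrain_ids)

print(len(pretrain_ids), len(test_ids))  # -> 8 2
```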