Skip to content

Commit

Permalink
Resolved review feedback. Added a base_model docstring. Added versio…
Browse files Browse the repository at this point in the history
…n for polars. Added github workflow test matrix over python versions, removed redundant run_command definition from test_configs
  • Loading branch information
Oufattole committed Sep 9, 2024
1 parent 2601fca commit a4ad03c
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 46 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@ jobs:

strategy:
fail-fast: false
matrix:
python-version: ["3.11", "3.12"]

timeout-minutes: 30

steps:
- name: Checkout
uses: actions/checkout@v3

- name: Set up Python 3.12
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}

- name: Install packages
run: |
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"polars", "pyarrow", "loguru", "hydra-core==1.3.2", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost",
"polars==1.6.0", "pyarrow", "loguru", "hydra-core==1.3.2", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost",
"scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins", "meds==0.3.3", "meds-transforms==0.0.7",
]

Expand Down
2 changes: 2 additions & 0 deletions src/MEDS_tabular_automl/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@


class BaseModel(ABC, TimeableMixin):
"""Defines the interface for a model that can be trained and evaluated via the launch_model script."""

@abstractmethod
def __init__(self):
pass
Expand Down
Empty file added tests/__init__.py
Empty file.
32 changes: 2 additions & 30 deletions tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,16 @@

root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

import subprocess

import hydra
import polars as pl
import pytest
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig

from MEDS_tabular_automl.sklearn_model import SklearnModel
from MEDS_tabular_automl.xgboost_model import XGBoostModel


def run_command(script: str, args: list[str], hydra_kwargs: dict[str, str], test_name: str):
    """Run a script through the shell and return its decoded output streams.

    The command line is built from ``script``, followed by ``args``, followed by
    one ``key=value`` token per entry of ``hydra_kwargs``; all parts are joined
    with single spaces and handed to the shell as one string.

    Args:
        script: The executable or script to invoke.
        args: Positional arguments placed right after ``script``.
        hydra_kwargs: Overrides rendered as ``key=value`` tokens.
        test_name: Label used in the failure message.

    Returns:
        A ``(stderr, stdout)`` tuple of decoded strings (stderr comes first).

    Raises:
        AssertionError: If the command exits with a non-zero return code. The
            message contains the captured stdout and stderr.
    """
    overrides = [f"{name}={value}" for name, value in hydra_kwargs.items()]
    command_line = " ".join([script, *args, *overrides])
    # shell=True is deliberate: the joined string is interpreted by the shell.
    result = subprocess.run(command_line, shell=True, capture_output=True)
    stderr = result.stderr.decode()
    stdout = result.stdout.decode()
    if result.returncode == 0:
        return stderr, stdout
    raise AssertionError(f"{test_name} failed!\nstdout:\n{stdout}\nstderr:\n{stderr}")


def make_config_mutable(cfg):
    """Recursively clear the read-only flag on an OmegaConf config.

    Args:
        cfg: Any value. OmegaConf containers are processed in place; every other
            value is returned untouched.

    Returns:
        The same ``cfg`` object. If it is an OmegaConf container, its read-only
        flag and those of all nested containers have been set to ``False``.
    """
    if not OmegaConf.is_config(cfg):
        return cfg

    OmegaConf.set_readonly(cfg, False)
    # List configs expose no ``keys()``, so walk them by index; dict configs are
    # walked over a snapshot of their keys because values are reassigned below.
    children = range(len(cfg)) if OmegaConf.is_list(cfg) else list(cfg.keys())
    for child in children:
        cfg[child] = make_config_mutable(cfg[child])
    return cfg
from tests.test_integration import run_command


@pytest.mark.parametrize(
Expand Down
20 changes: 10 additions & 10 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,6 @@

import polars as pl
from hydra import compose, initialize
from test_tabularize import (
CODE_COLS,
EXPECTED_STATIC_FILES,
MEDS_OUTPUTS,
NUM_SHARDS,
SPLITS_JSON,
STATIC_FIRST_COLS,
STATIC_PRESENT_COLS,
VALUE_COLS,
)

from MEDS_tabular_automl.describe_codes import get_feature_columns
from MEDS_tabular_automl.file_name import list_subdir_files
Expand All @@ -30,6 +20,16 @@
get_unique_time_events_df,
load_matrix,
)
from tests.test_tabularize import (
CODE_COLS,
EXPECTED_STATIC_FILES,
MEDS_OUTPUTS,
NUM_SHARDS,
SPLITS_JSON,
STATIC_FIRST_COLS,
STATIC_PRESENT_COLS,
VALUE_COLS,
)


def run_command(script: str, args: list[str], hydra_kwargs: dict[str, str], test_name: str):
Expand Down
3 changes: 0 additions & 3 deletions tests/test_tabularize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import polars as pl
from hydra import compose, initialize
from loguru import logger

from MEDS_tabular_automl.describe_codes import get_feature_columns
from MEDS_tabular_automl.file_name import list_subdir_files
Expand All @@ -30,8 +29,6 @@
load_matrix,
)

logger.disable("MEDS_tabular_automl")

SPLITS_JSON = """{"train/0": [239684, 1195293], "train/1": [68729, 814703], "tuning/0": [754281], "held_out/0": [1500733]}""" # noqa: E501
NUM_SHARDS = 4

Expand Down

0 comments on commit a4ad03c

Please sign in to comment.