Skip to content

Commit

Permalink
Merge pull request #7 from otavioon/experiments
Browse files Browse the repository at this point in the history
Updated experiments API
  • Loading branch information
otavioon authored Feb 1, 2024
2 parents 06461f0 + 2c56799 commit 5feb802
Show file tree
Hide file tree
Showing 17 changed files with 693 additions and 772 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ requires-python = ">=3.8"
version = "0.0.1-dev"

dependencies = [
"ipywidgets",
"jsonargparse[all]",
"librep@git+https://github.com/discovery-unicamp/hiaac-librep.git@0.0.4-dev",
"lightly",
Expand Down
9 changes: 7 additions & 2 deletions ssl_tools/experiments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from .lightning_cli import LightningTrain, LightningTest
from .ssl_experiment import SSLTrain, SSLTest
from .experiment import Experiment, auto_main
from .lightning_experiment import (
LightningExperiment,
LightningTrain,
LightningTest,
LightningSSLTrain,
)
79 changes: 79 additions & 0 deletions ssl_tools/experiments/experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from pathlib import Path
from typing import Any, Dict, Union
from abc import ABC, abstractmethod
from datetime import datetime
from jsonargparse import ArgumentParser

EXPERIMENT_VERSION_FORMAT = "%Y-%m-%d_%H-%M-%S"


class Experiment(ABC):
def __init__(
self,
name: str = "experiment",
run_id: Union[str, int] = None,
log_dir: str = "logs",
seed: int = None,
):
self.name = name
self.run_id = run_id or datetime.now().strftime(
EXPERIMENT_VERSION_FORMAT
)
self.log_dir = log_dir
self.seed = seed

@property
def experiment_dir(self) -> Path:
return Path(self.log_dir) / self.name / str(self.run_id)

def setup(self):
pass

@abstractmethod
def run(self) -> Any:
raise NotImplementedError

def teardown(self):
pass

def execute(self):
print(f"Setting up experiment: {self.name}...")
self.setup()
print(f"Running experiment: {self.name}...")
result = self.run()
print(f"Teardown experiment: {self.name}...")
self.teardown()
return result

def __call__(self):
return self.execute()

def __str__(self):
return f"Experiment(name={self.name}, run_id={self.run_id}, cwd={self.experiment_dir})"

def __repr__(self) -> str:
return str(self)


def get_parser(commands: Dict[str, Experiment]):
parser = ArgumentParser()
subcommands = parser.add_subcommands()

for name, command in commands.items():
subparser = ArgumentParser()
subparser.add_class_arguments(command)
subcommands.add_subcommand(name, subparser)

return parser


def auto_main(commands: Dict[str, Experiment]):
parser = get_parser(commands)
args = parser.parse_args()
# print(args)

experiment = commands[args.subcommand](**args[args.subcommand])
experiment.execute()

# command = args.subcommand
# command(**args).execute()
49 changes: 22 additions & 27 deletions ssl_tools/experiments/har_classification/cpc.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
#!/usr/bin/env python

# TODO: A way of removing the need to add the path to the root of
# the project
import sys
from jsonargparse import CLI
import lightning as L
import torch

sys.path.append("../../../")


from ssl_tools.experiments import SSLTrain, SSLTest
from ssl_tools.experiments import LightningSSLTrain, LightningTest, auto_main
from ssl_tools.models.ssl.cpc import build_cpc
from ssl_tools.data.data_modules import (
MultiModalHARSeriesDataModule,
Expand All @@ -21,7 +14,7 @@
from ssl_tools.models.ssl.modules.heads import CPCPredictionHead


class CPCTrain(SSLTrain):
class CPCTrain(LightningSSLTrain):
_MODEL_NAME = "CPC"

def __init__(
Expand Down Expand Up @@ -64,7 +57,7 @@ def __init__(
self.num_classes = num_classes
self.update_backbone = update_backbone

def _get_pretrain_model(self) -> L.LightningModule:
def get_pretrain_model(self) -> L.LightningModule:
model = build_cpc(
encoding_size=self.encoding_size,
in_channel=self.in_channel,
Expand All @@ -74,7 +67,7 @@ def _get_pretrain_model(self) -> L.LightningModule:
)
return model

def _get_pretrain_data_module(self) -> L.LightningDataModule:
def get_pretrain_data_module(self) -> L.LightningDataModule:
data_module = UserActivityFolderDataModule(
data_path=self.data,
batch_size=self.batch_size,
Expand All @@ -83,17 +76,17 @@ def _get_pretrain_data_module(self) -> L.LightningDataModule:
)
return data_module

def _get_finetune_model(
def get_finetune_model(
self, load_backbone: str = None
) -> L.LightningModule:
model = self._get_pretrain_model()
model = self.get_pretrain_model()

if load_backbone is not None:
self._load_model(model, load_backbone)
self.load_checkpoint(model, load_backbone)

classifier = CPCPredictionHead(
input_dim=self.encoding_size,
hidden_dim2=self.num_classes,
output_dim=self.num_classes,
)

task = "multiclass" if self.num_classes > 2 else "binary"
Expand All @@ -107,18 +100,19 @@ def _get_finetune_model(
)
return model

def _get_finetune_data_module(self) -> L.LightningDataModule:
def get_finetune_data_module(self) -> L.LightningDataModule:
data_module = MultiModalHARSeriesDataModule(
data_path=self.data,
batch_size=self.batch_size,
label="standard activity code",
features_as_channels=True,
num_workers=self.num_workers,
)

return data_module


class CPCTest(SSLTest):
class CPCTest(LightningTest):
_MODEL_NAME = "CPC"

def __init__(
Expand Down Expand Up @@ -154,17 +148,20 @@ def __init__(
self.window_size = window_size
self.num_classes = num_classes

def _get_test_model(self, load_backbone: str = None) -> L.LightningModule:
def get_model(self, load_backbone: str = None) -> L.LightningModule:
model = build_cpc(
encoding_size=self.encoding_size,
in_channel=self.in_channel,
window_size=self.window_size,
n_size=5,
)

if load_backbone is not None:
self.load_checkpoint(model, load_backbone)

classifier = CPCPredictionHead(
input_dim=self.encoding_size,
hidden_dim2=self.num_classes,
output_dim=self.num_classes,
)

task = "multiclass" if self.num_classes > 2 else "binary"
Expand All @@ -176,23 +173,21 @@ def _get_test_model(self, load_backbone: str = None) -> L.LightningModule:
)
return model

def _get_test_data_module(self) -> L.LightningDataModule:
def get_data_module(self) -> L.LightningDataModule:
data_module = MultiModalHARSeriesDataModule(
data_path=self.data,
batch_size=self.batch_size,
label="standard activity code",
features_as_channels=True,
num_workers=self.num_workers,
)

return data_module


def main():
components = {
if __name__ == "__main__":
options = {
"fit": CPCTrain,
"test": CPCTest,
}
CLI(components=components, as_positional=False)()


if __name__ == "__main__":
main()
auto_main(options)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ cd ..
--batch_size 128 \
--accelerator gpu \
--devices 1 \
--load_backbone logs/pretrain/CPC/2024-01-29_19-54-51/checkpoints/last.ckpt \
--load_backbone logs/pretrain/CPC/2024-01-31_21-14-31/checkpoints/last.ckpt \
--training_mode finetune \
--window_size 60 \
--num_classes 6 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ cd ..
--batch_size 128 \
--accelerator gpu \
--devices 1 \
--load_backbone logs/pretrain/TFC/2024-01-28_22-20-23/checkpoints/last.ckpt \
--load_backbone logs/pretrain/TFC/2024-01-31_21-07-06/checkpoints/last.ckpt \
--training_mode finetune \
--encoding_size 128 \
--features_as_channels True \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ cd ..
--batch_size 128 \
--accelerator gpu \
--devices 1 \
--load_backbone logs/pretrain/TNC/2024-01-29_19-28-11/checkpoints/last.ckpt \
--load_backbone logs/pretrain/TNC/2024-01-31_21-13-05/checkpoints/last.ckpt \
--training_mode finetune \
--repeat 5 \
--mc_sample_size 20 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ for dset in "KuHar" "MotionSense" "RealWorld_thigh" "RealWorld_waist" "UCI";
do
./cpc.py test \
--data /workspaces/hiaac-m4/ssl_tools/data/standartized_balanced/${dset} \
--load logs/finetune/CPC/2024-01-29_19-58-58/checkpoints/last.ckpt \
--load logs/finetune/CPC/2024-01-31_21-19-23/checkpoints/last.ckpt \
--batch_size 128 \
--accelerator gpu \
--devices 1 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ for dset in "KuHar" "MotionSense" "RealWorld_thigh" "RealWorld_waist" "UCI";
do
./tfc.py test \
--data /workspaces/hiaac-m4/ssl_tools/data/standartized_balanced/${dset} \
--load logs/finetune/TFC/2023-12-06_19-35-44/checkpoints/last.ckpt \
--load logs/finetune/TFC/2024-01-31_21-17-05/checkpoints/last.ckpt \
--batch_size 128 \
--accelerator gpu \
--devices 1 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ for dset in "KuHar" "MotionSense" "RealWorld_thigh" "RealWorld_waist" "UCI";
do
./tnc.py test \
--data /workspaces/hiaac-m4/ssl_tools/data/standartized_balanced/${dset} \
--load logs/finetune/TNC/2024-01-29_19-36-22/checkpoints/last.ckpt \
--load logs/finetune/TNC/2024-01-31_21-22-05/checkpoints/last.ckpt \
--batch_size 128 \
--accelerator gpu \
--devices 1 \
Expand Down
Loading

0 comments on commit 5feb802

Please sign in to comment.