Skip to content

Commit

Permalink
Catalyst integration (#139)
Browse files Browse the repository at this point in the history
* LGBM integration added

* Tests for lgbm integration

* test_lgbm: added tmp_dir fixture

* Update dvclive/lgbm.py

Save model at the end of each iteration

Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>

* Huggingface integration

* Add on_log event

* Update dvclive/huggingface.py

Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>

* fix: huggingface test after callback changes

* revert last commit

* fix: huggingface test after calback changes

* Updated test_huggingface

* Updated test_huggingface

* Catalyst integration

* Callback rename

* Rename callback in test

* Update dvclive/catalyst.py

Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>

* upd: catalyst tests

* Fix: cross platform tests

Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>
  • Loading branch information
pacifikus and daavoo authored Aug 24, 2021
1 parent b00ffda commit 138debe
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 1 deletion.
27 changes: 27 additions & 0 deletions dvclive/catalyst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from catalyst.core.callback import Callback, CallbackOrder

import dvclive


class DvcLiveCallback(Callback):
def __init__(self, model_file=None):
super().__init__(order=CallbackOrder.external)
self.model_file = model_file

def on_epoch_end(self, runner) -> None:
step = runner.stage_epoch_step

for loader_key, per_loader_metrics in runner.epoch_metrics.items():
for key, value in per_loader_metrics.items():
key = key.replace("/", "_")
dvclive.log(f"{loader_key}/{key}", float(value), step)

if self.model_file:
checkpoint = runner.engine.pack_checkpoint(
model=runner.model,
criterion=runner.criterion,
optimizer=runner.optimizer,
scheduler=runner.scheduler,
)
runner.engine.save_checkpoint(checkpoint, self.model_file)
dvclive.next_step()
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ def run(self):
xgb = ["xgboost"]
lgbm = ["lightgbm"]
hugginface = ["transformers", "datasets"]
catalyst = ["catalyst"]

all_libs = mmcv + tf + xgb + lgbm + hugginface
all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst

tests_requires = [
"pylint==2.5.3",
Expand Down Expand Up @@ -75,6 +76,7 @@ def run(self):
"xgb": xgb,
"lgbm": lgbm,
"huggingface": hugginface,
"catalyst": catalyst,
},
keywords="data-science metrics machine-learning developer-tools ai",
python_requires=">=3.6",
Expand Down
102 changes: 102 additions & 0 deletions tests/test_catalyst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os

import pytest
from catalyst import dl
from catalyst.contrib.datasets import MNIST
from catalyst.data import ToTensor
from catalyst.utils.torch import get_available_engine
from torch import nn, optim
from torch.utils.data import DataLoader

import dvclive
from dvclive.catalyst import DvcLiveCallback

# pylint: disable=redefined-outer-name, unused-argument


@pytest.fixture
def loaders():
train_data = MNIST(
os.getcwd(), train=True, download=True, transform=ToTensor()
)
valid_data = MNIST(
os.getcwd(), train=False, download=True, transform=ToTensor()
)
return {
"train": DataLoader(train_data, batch_size=32),
"valid": DataLoader(valid_data, batch_size=32),
}


@pytest.fixture
def runner():
return dl.SupervisedRunner(
engine=get_available_engine(),
input_key="features",
output_key="logits",
target_key="targets",
loss_key="loss",
)


def test_catalyst_callback(tmp_dir, runner, loaders):
dvclive.init("dvc_logs")

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
loaders=loaders,
num_epochs=2,
callbacks=[
dl.AccuracyCallback(input_key="logits", target_key="targets"),
DvcLiveCallback(),
],
logdir="./logs",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
verbose=True,
load_best_on_end=True,
)

assert os.path.exists("dvc_logs")

train_path = tmp_dir / "dvc_logs/train"
valid_path = tmp_dir / "dvc_logs/valid"

assert train_path.is_dir()
assert valid_path.is_dir()
assert (train_path / "accuracy.tsv").exists()


def test_catalyst_model_file(tmp_dir, runner, loaders):
dvclive.init("dvc_logs")

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner.train(
model=model,
engine=runner.engine,
criterion=criterion,
optimizer=optimizer,
loaders=loaders,
num_epochs=2,
callbacks=[
dl.AccuracyCallback(input_key="logits", target_key="targets"),
DvcLiveCallback("model.pth"),
],
logdir="./logs",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
verbose=True,
load_best_on_end=True,
)
assert (tmp_dir / "model.pth").is_file()

0 comments on commit 138debe

Please sign in to comment.