-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* LGBM integration added * Tests for lgbm integration * test_lgbm: added tmp_dir fixture * Update dvclive/lgbm.py Save model at the end of each iteration Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com> * Huggingface integration * Add on_log event * Update dvclive/huggingface.py Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com> * fix: huggingface test after callback changes * revert last commit * fix: huggingface test after callback changes * Updated test_huggingface * Updated test_huggingface * Catalyst integration * Callback rename * Rename callback in test * Update dvclive/catalyst.py Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com> * upd: catalyst tests * Fix: cross platform tests Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>
- Loading branch information
Showing
3 changed files
with
132 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from catalyst.core.callback import Callback, CallbackOrder | ||
|
||
import dvclive | ||
|
||
|
||
class DvcLiveCallback(Callback):
    """Catalyst callback that reports epoch metrics through ``dvclive``.

    At the end of every epoch each metric found in ``runner.epoch_metrics``
    is logged as ``<loader>/<metric>``; if a model file was configured, a
    checkpoint is written to it as well.

    Args:
        model_file: Destination handed to ``runner.engine.save_checkpoint``.
            When falsy (the default), no checkpoint is saved.
    """

    def __init__(self, model_file=None):
        super().__init__(order=CallbackOrder.external)
        self.model_file = model_file

    def _log_epoch_metrics(self, runner) -> None:
        """Send every per-loader metric of the current epoch to dvclive."""
        epoch = runner.stage_epoch_step
        for loader_name, metrics in runner.epoch_metrics.items():
            for metric_name, metric_value in metrics.items():
                # "/" separates loader from metric in the logged name, so
                # any slash inside the metric name itself is flattened.
                flat_name = metric_name.replace("/", "_")
                dvclive.log(
                    f"{loader_name}/{flat_name}", float(metric_value), epoch
                )

    def _save_model(self, runner) -> None:
        """Pack the runner's training state and write it to ``model_file``."""
        engine = runner.engine
        packed = engine.pack_checkpoint(
            model=runner.model,
            criterion=runner.criterion,
            optimizer=runner.optimizer,
            scheduler=runner.scheduler,
        )
        engine.save_checkpoint(packed, self.model_file)

    def on_epoch_end(self, runner) -> None:
        """Log metrics, optionally save a checkpoint, then advance dvclive."""
        self._log_epoch_metrics(runner)
        if self.model_file:
            self._save_model(runner)
        dvclive.next_step()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import os | ||
|
||
import pytest | ||
from catalyst import dl | ||
from catalyst.contrib.datasets import MNIST | ||
from catalyst.data import ToTensor | ||
from catalyst.utils.torch import get_available_engine | ||
from torch import nn, optim | ||
from torch.utils.data import DataLoader | ||
|
||
import dvclive | ||
from dvclive.catalyst import DvcLiveCallback | ||
|
||
# pylint: disable=redefined-outer-name, unused-argument | ||
|
||
|
||
@pytest.fixture
def loaders():
    """Return ``train`` and ``valid`` MNIST data loaders (batch size 32).

    The datasets are requested with ``download=True`` and rooted at the
    current working directory.
    """
    root = os.getcwd()
    # Build both datasets first, then wrap each one in a DataLoader.
    datasets = {
        split: MNIST(
            root, train=is_train, download=True, transform=ToTensor()
        )
        for split, is_train in (("train", True), ("valid", False))
    }
    return {
        split: DataLoader(dataset, batch_size=32)
        for split, dataset in datasets.items()
    }
|
||
|
||
@pytest.fixture
def runner():
    """Return a ``SupervisedRunner`` bound to the available engine."""
    engine = get_available_engine()
    supervised_runner = dl.SupervisedRunner(
        engine=engine,
        input_key="features",
        output_key="logits",
        target_key="targets",
        loss_key="loss",
    )
    return supervised_runner
|
||
|
||
def test_catalyst_callback(tmp_dir, runner, loaders):
    """Train for two epochs and check dvclive wrote per-loader metrics."""
    dvclive.init("dvc_logs")

    # Minimal linear classifier over flattened 28x28 inputs.
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(net.parameters(), lr=0.02)

    callbacks = [
        dl.AccuracyCallback(input_key="logits", target_key="targets"),
        DvcLiveCallback(),
    ]
    runner.train(
        model=net,
        criterion=loss_fn,
        optimizer=adam,
        loaders=loaders,
        num_epochs=2,
        callbacks=callbacks,
        logdir="./logs",
        valid_loader="valid",
        valid_metric="loss",
        minimize_valid_metric=True,
        verbose=True,
        load_best_on_end=True,
    )

    assert os.path.exists("dvc_logs")

    # One sub-directory per loader key is expected under the log dir.
    train_logs = tmp_dir / "dvc_logs/train"
    valid_logs = tmp_dir / "dvc_logs/valid"

    assert train_logs.is_dir()
    assert valid_logs.is_dir()
    assert (train_logs / "accuracy.tsv").exists()
|
||
|
||
def test_catalyst_model_file(tmp_dir, runner, loaders):
    """Train with ``model_file`` set and check the checkpoint file exists."""
    dvclive.init("dvc_logs")

    # Minimal linear classifier over flattened 28x28 inputs.
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(net.parameters(), lr=0.02)

    callbacks = [
        dl.AccuracyCallback(input_key="logits", target_key="targets"),
        DvcLiveCallback("model.pth"),
    ]
    runner.train(
        model=net,
        engine=runner.engine,
        criterion=loss_fn,
        optimizer=adam,
        loaders=loaders,
        num_epochs=2,
        callbacks=callbacks,
        logdir="./logs",
        valid_loader="valid",
        valid_metric="loss",
        minimize_valid_metric=True,
        verbose=True,
        load_best_on_end=True,
    )

    checkpoint_path = tmp_dir / "model.pth"
    assert checkpoint_path.is_file()