Skip to content

Commit

Permalink
'Remove step' feature for local workspaces (#588)
Browse files Browse the repository at this point in the history
Co-authored-by: Pranjali Basmatkar <pranjalib@Pranjali-Basmatkars-MacBook-Pro.local>
Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com>
  • Loading branch information
3 people authored Oct 5, 2023
1 parent 3857415 commit 11b5229
Show file tree
Hide file tree
Showing 17 changed files with 208 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ on:
- "v*.*.*"

env:
CACHE_PREFIX: v3 # Change this to invalidate existing cache.
CACHE_PREFIX: v5 # Change this to invalidate existing cache.
PYTHON_PATH: ./
DEFAULT_PYTHON: 3.9
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Added

- Added the `Workspace.remove_step()` method to safely remove steps.
- The `GSWorkspace()` can now be initialized with Google Cloud bucket subfolders.

### Fixed
Expand Down
7 changes: 7 additions & 0 deletions tango/integrations/beaker/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,10 @@ def _update_step_info(self, step_info: StepInfo):
self.Constants.STEP_INFO_FNAME, # step info filename
quiet=True,
)

def _remove_step_info(self, step_info: StepInfo) -> None:
    """
    Delete the Beaker dataset that stores the given step, if there is one.

    The dataset name is derived with ``self.Constants.step_artifact_name``.
    A dataset that does not exist is treated as "nothing to remove".

    :param step_info: Info of the step whose dataset should be deleted.
    """
    # Function-scope import keeps this block self-contained.
    from beaker import DatasetNotFound

    dataset_name = self.Constants.step_artifact_name(step_info)
    try:
        step_dataset = self.beaker.dataset.get(dataset_name)
    except DatasetNotFound:
        # ``dataset.get`` reports a missing dataset by raising rather than
        # returning ``None``, so the ``None`` check alone never skipped anything.
        return
    if step_dataset is not None:
        self.beaker.dataset.delete(step_dataset)
9 changes: 9 additions & 0 deletions tango/integrations/gs/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,15 @@ def _update_step_info(self, step_info: StepInfo):

self._ds.put(step_info_entity)

def _remove_step_info(self, step_info: StepInfo) -> None:
    """
    Remove the remote data kept for the given step.

    Two things are removed: the step's artifact (fetched and deleted through
    ``self.client``) and the ``"stepinfo"`` entity keyed by the step's unique id
    (deleted through ``self._ds``).

    :param step_info: Info of the step to remove.
    """
    # Bucket side: only delete when the client actually returned an artifact.
    artifact_name = self.Constants.step_artifact_name(step_info)
    artifact = self.client.get(artifact_name)
    if artifact is not None:
        self.client.delete(artifact)

    # Datastore side: drop the entity stored under the step's unique id.
    entity_key = self._ds.key("stepinfo", step_info.unique_id)
    self._ds.delete(key=entity_key)

def _save_run_log(self, name: str, log_file: Path):
"""
The logs are stored in the bucket. The Run object details are stored in
Expand Down
1 change: 1 addition & 0 deletions tango/integrations/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from tango.integrations.transformers import *
available_models = []
for name in sorted(Model.list_available()):
if name.startswith("transformers::AutoModel"):
available_models.append(name)
Expand Down
7 changes: 7 additions & 0 deletions tango/integrations/wandb/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,13 @@ def step_failed(self, step: Step, e: BaseException) -> None:
if step.unique_id in self._running_step_info:
del self._running_step_info[step.unique_id]

def remove_step(self, step_unique_id: str):
    """
    Removing a cached step is not supported by this workspace yet.

    :param step_unique_id: Unique id of the step that should be removed.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("remove_step() is not implemented for this workspace.")

def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
all_steps = set(targets)
for step in targets:
Expand Down
5 changes: 5 additions & 0 deletions tango/step_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def __setitem__(self, step: Step, value: Any) -> None:
"""Writes the results for the given step. Throws an exception if the step is already cached."""
raise NotImplementedError()

@abstractmethod
def __delitem__(self, step_unique_id: Union[Step, StepInfo]) -> None:
    """
    Removes a step from the step cache.

    Note that, despite its name, ``step_unique_id`` is annotated as a step or
    step info object rather than a plain id string.

    Concrete caches have to override this; this base version always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()

@abstractmethod
def __len__(self) -> int:
"""Returns the number of results saved in this cache."""
Expand Down
21 changes: 21 additions & 0 deletions tango/step_caches/local_step_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections
import logging
import os
import shutil
import warnings
import weakref
from pathlib import Path
Expand Down Expand Up @@ -89,6 +91,17 @@ def _get_from_cache(self, key: str) -> Optional[Any]:
except KeyError:
return None

def _remove_from_cache(self, key: str) -> None:
    """
    Drop ``key`` from both in-memory caches, ``strong_cache`` and ``weak_cache``.

    A key that is missing from either cache is ignored.

    :param key: The cache key to forget.
    """
    # ``pop`` with a default does the lookup and the removal in one step, so
    # nothing can raise if the entry disappears in between (presumably possible
    # for ``weak_cache`` entries -- confirm against how that cache is built).
    self.strong_cache.pop(key, None)
    self.weak_cache.pop(key, None)

def _metadata_path(self, step_or_unique_id: Union[Step, StepInfo, str]) -> Path:
return self.step_dir(step_or_unique_id) / self.METADATA_FILE_NAME

Expand Down Expand Up @@ -147,6 +160,14 @@ def __setitem__(self, step: Step, value: Any) -> None:
pass
raise

def __delitem__(self, step: Union[Step, StepInfo]) -> None:
    """
    Remove a step's results from this cache, in memory and on disk.

    :param step: The step (or step info) whose ``unique_id`` names the
        directory below ``self.dir`` that is deleted.
    :raises OSError: If that directory does not exist; the underlying
        ``FileNotFoundError`` is chained as the cause. Any other filesystem
        error raised by ``shutil.rmtree`` propagates unchanged.
    """
    # Forget the in-memory copies first so they can never outlive the files,
    # whatever happens to the directory removal below.
    self._remove_from_cache(step.unique_id)

    location = Path(self.dir) / str(step.unique_id)
    try:
        shutil.rmtree(location)
    except FileNotFoundError as e:
        # Only a missing directory is reported as "not found"; other failures
        # (e.g. permissions) keep their own, accurate error.
        raise OSError(
            f"Step cache folder for '{step.unique_id}' not found. Cannot be deleted."
        ) from e

def __len__(self) -> int:
return sum(1 for _ in self.dir.glob(f"*/{self.METADATA_FILE_NAME}"))

Expand Down
6 changes: 6 additions & 0 deletions tango/step_caches/memory_step_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def __setitem__(self, step: Step, value: Any) -> None:
UserWarning,
)

def __delitem__(self, step: Union[Step, StepInfo]) -> None:
    """
    Remove the entry stored under the step's unique id from ``self.cache``.

    :raises KeyError: If the cache has no entry for that unique id.
    """
    if step.unique_id not in self.cache:
        raise KeyError(f"{step.unique_id} not present in the memory cache. Cannot be deleted.")
    del self.cache[step.unique_id]

def __contains__(self, step: object) -> bool:
if isinstance(step, (Step, StepInfo)):
return step.unique_id in self.cache
Expand Down
8 changes: 8 additions & 0 deletions tango/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,14 @@ def step_result(self, step_name: str) -> Any:
return self.step_cache[run.steps[step_name]]
raise KeyError(f"No step named '{step_name}' found in previous runs")

@abstractmethod
def remove_step(self, step_unique_id: str):
    """
    Removes the cached step identified by the given unique step id.

    Concrete workspaces have to override this; this base version always raises
    ``NotImplementedError``.

    :param step_unique_id: Unique id of the step to remove.
    :raises KeyError: If there is no step with the given unique id.
    """
    raise NotImplementedError()

def capture_logs_for_run(self, name: str) -> ContextManager[None]:
"""
Should return a context manager that can be used to capture the logs for a run.
Expand Down
14 changes: 14 additions & 0 deletions tango/workspaces/local_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,20 @@ def step_failed(self, step: Step, e: BaseException) -> None:
lock.release()
del self.locks[step]

def remove_step(self, step_unique_id: str) -> None:
    """
    Remove the step with the given unique id from this workspace.

    The step's entry is deleted from the SQLite step-info file and committed,
    after which the step is deleted from ``self.cache``.

    :param step_unique_id: Unique id of the step to remove.
    :raises KeyError: If the lookup or one of the deletions raises ``KeyError``.
    """
    with SqliteDict(self.step_info_file) as step_infos:
        try:
            info = self.step_info(step_unique_id)
            del step_infos[step_unique_id]
            # NOTE(review): the deletion is committed before the cache entry is
            # removed, so a failure on the next line leaves cached data without
            # step info -- confirm this ordering is intended.
            step_infos.commit()
            del self.cache[info]
        except KeyError:
            raise KeyError(f"No step named '{step_unique_id}' found")

def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
# sanity check targets
targets = list(targets)
Expand Down
12 changes: 12 additions & 0 deletions tango/workspaces/memory_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ def step_failed(self, step: Step, e: BaseException) -> None:
existing_step_info.end_time = utc_now_datetime()
existing_step_info.error = exception_to_string(e)

def remove_step(self, step_unique_id: str) -> None:
    """
    Remove the step with the given unique id from this in-memory workspace.

    Deletes the step's entry from ``self.unique_id_to_info`` and then its entry
    from ``self.step_cache``.

    :param step_unique_id: Unique id of the step to remove.
    :raises KeyError: If the lookup or one of the deletions raises ``KeyError``.
    """
    try:
        info = self.step_info(step_unique_id)
        # Bookkeeping entry first, cached result second.
        del self.unique_id_to_info[step_unique_id]
        del self.step_cache[info]
    except KeyError:
        raise KeyError(f"{step_unique_id} step info not found, step cache cannot be deleted")

def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
if name is None:
name = petname.generate()
Expand Down
20 changes: 20 additions & 0 deletions tango/workspaces/remote_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,22 @@ def step_failed(self, step: Step, e: BaseException) -> None:
finally:
self.locks.pop(step).release()

def remove_step(self, step_unique_id: str) -> None:
    """
    Remove the step with the given unique id from the remote store and the cache.

    The remote objects are removed through ``self._remove_step_info`` first,
    then the step is deleted from ``self.cache``.

    :param step_unique_id: Unique id of the step to remove.
    :raises KeyError: If the lookup or one of the removals raises ``KeyError``.
    """
    try:
        info = self.step_info(step_unique_id)
        self._remove_step_info(info)  # remote objects
        del self.cache[info]  # cached result
    except KeyError:
        raise KeyError(f"No step named '{step_unique_id}' found.")

def _get_run_step_info(self, targets: Iterable[Step]) -> Tuple[Dict, Dict]:
import concurrent.futures

Expand Down Expand Up @@ -229,3 +245,7 @@ def capture_logs_for_run(self, name: str) -> Generator[None, None, None]:
@abstractmethod
def _update_step_info(self, step_info: StepInfo):
    """
    Abstract hook that receives a step info object; subclasses must override it.

    This base version always raises ``NotImplementedError``.
    """
    raise NotImplementedError()

@abstractmethod
def _remove_step_info(self, step_info: StepInfo):
    """
    Abstract hook for removing a step's remote data; subclasses must override it.

    This base version always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
28 changes: 28 additions & 0 deletions tests/integrations/beaker/workspace_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import pytest
from beaker import DatasetNotFound

from tango.common.testing.steps import FloatStep
from tango.integrations.beaker.workspace import BeakerWorkspace
from tango.step_info import StepState
from tango.workspace import Workspace


def test_from_url(beaker_workspace: str):
print(beaker_workspace)
workspace = Workspace.from_url(f"beaker://{beaker_workspace}")
assert isinstance(workspace, BeakerWorkspace)

Expand All @@ -22,3 +26,27 @@ def test_direct_usage(beaker_workspace: str):
workspace.step_finished(step, 1.0)
assert workspace.step_info(step).state == StepState.COMPLETED
assert workspace.step_result_for_run(run.name, "float") == 1.0


def test_remove_step(beaker_workspace: str):
    """Removing a finished step deletes both its Beaker dataset and its cache entry."""
    # NOTE(review): the fixture value is replaced by a fixed workspace name --
    # confirm this is intended.
    beaker_workspace = "ai2/tango_remove_cache_test"
    workspace = BeakerWorkspace(beaker_workspace)
    step = FloatStep(step_name="float", result=1.0)

    # Run the step through its lifecycle so there is something to remove.
    workspace.step_starting(step)
    workspace.step_finished(step, 1.0)

    step_info = workspace.step_info(step)
    dataset_name = workspace.Constants.step_artifact_name(step_info)
    cache = workspace.step_cache

    # Before removal: dataset and cache entry both exist.
    assert workspace.beaker.dataset.get(dataset_name) is not None
    assert step in cache

    workspace.remove_step(step.unique_id)
    cache = workspace.step_cache
    dataset_name = workspace.Constants.step_artifact_name(step_info)

    # After removal: the dataset lookup fails and the cache no longer has the step.
    with pytest.raises(DatasetNotFound):
        workspace.beaker.dataset.get(dataset_name)
    assert step not in cache
27 changes: 27 additions & 0 deletions tests/integrations/gs/workspace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,30 @@ def test_direct_usage(self, gs_path: str):
workspace.step_finished(step, 1.0)
assert workspace.step_info(step).state == StepState.COMPLETED
assert workspace.step_result_for_run(run.name, "float") == 1.0

def test_remove_step(self):
    """Removing a finished step deletes its bucket artifact, datastore entity and cache entry."""
    workspace = GSWorkspace(GS_BUCKET_NAME)
    step = FloatStep(step_name="float", result=1.0)
    step_info = workspace.step_info(step)

    def stored_entity():
        # Looks up the "stepinfo" datastore entity for the step under test.
        return workspace._ds.get(key=workspace._ds.key("stepinfo", step_info.unique_id))

    # Run the step through its lifecycle so there is something to remove.
    workspace.step_starting(step)
    workspace.step_finished(step, 1.0)
    bucket_artifact = workspace.Constants.step_artifact_name(step_info)
    ds_entity = stored_entity()
    cache = workspace.step_cache

    # Before removal: artifact, entity and cache entry all exist.
    assert workspace.client.artifacts(prefix=bucket_artifact) is not None
    assert ds_entity is not None
    assert step in cache

    workspace.remove_step(step.unique_id)
    cache = workspace.step_cache

    ds_entity = stored_entity()

    # After removal: listing the artifact fails, and entity and cache entry are gone.
    with pytest.raises(Exception) as excinfo:
        workspace.client.artifacts(prefix=bucket_artifact)

    assert "KeyError" in str(excinfo)
    assert ds_entity is None
    assert step not in cache
21 changes: 21 additions & 0 deletions tests/workspaces/local_workspace_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from shutil import copytree

import pytest
from sqlitedict import SqliteDict

from tango import Step
from tango.common.testing import TangoTestCase
Expand Down Expand Up @@ -73,3 +74,23 @@ def test_local_workspace_upgrade_v1_to_v2(self):
while len(dependencies) > 0:
step_info = workspace.step_info(dependencies.pop())
dependencies.extend(step_info.dependencies)

def test_remove_step(self):
    """Removing a finished step deletes its step-info row and its cache entry."""
    workspace = LocalWorkspace(self.TEST_DIR)
    step = AdditionStep(a=1, b=2)
    workspace.step_starting(step)
    workspace.step_finished(step, 1.0)

    # Before removal: the step is recorded and cached.
    with SqliteDict(workspace.step_info_file) as step_infos:
        assert step.unique_id in step_infos
    assert step in workspace.step_cache

    workspace.remove_step(step.unique_id)

    # After removal: neither the record nor the cache entry remains.
    with SqliteDict(workspace.step_info_file) as step_infos:
        assert step.unique_id not in step_infos
    assert step not in workspace.step_cache
20 changes: 20 additions & 0 deletions tests/workspaces/memory_workspace_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from tango.common.testing.steps import FloatStep
from tango.workspaces import MemoryWorkspace


def test_remove_step():
    """Removing a finished step deletes its info entry and its cache entry."""
    workspace = MemoryWorkspace()
    step = FloatStep(step_name="float", result=1.0)

    # Run the step through its lifecycle so there is something to remove.
    workspace.step_starting(step)
    workspace.step_finished(step, 1.0)
    cache = workspace.step_cache

    # Before removal: the step is recorded and cached.
    assert step.unique_id in workspace.unique_id_to_info
    assert step in cache

    workspace.remove_step(step.unique_id)
    cache = workspace.step_cache

    # After removal: neither the record nor the cache entry remains.
    assert step.unique_id not in workspace.unique_id_to_info
    assert step not in cache

0 comments on commit 11b5229

Please sign in to comment.