Refactor ArtifactManager and restart TaskManager #529

Merged 4 commits on Apr 18, 2023
30 changes: 27 additions & 3 deletions azimuth/app.py
@@ -25,14 +25,13 @@

from azimuth.config import AzimuthConfig, load_azimuth_config
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.modules.base_classes import DaskModule
from azimuth.modules.base_classes import ArtifactManager, DaskModule
from azimuth.startup import startup_tasks
from azimuth.task_manager import TaskManager
from azimuth.types import DatasetSplitName, ModuleOptions, SupportedModule
from azimuth.utils.cluster import default_cluster
from azimuth.utils.conversion import JSONResponseIgnoreNan
from azimuth.utils.logs import set_logger_config
from azimuth.utils.project import load_dataset_split_managers_from_config
from azimuth.utils.validation import assert_not_none

_dataset_split_managers: Dict[DatasetSplitName, Optional[DatasetSplitManager]] = {}
@@ -296,6 +295,32 @@ def create_app() -> FastAPI:
return app


def load_dataset_split_managers_from_config(
Contributor Author:
I needed to move this function here because of circular import issues.
One big change I made is to use the ArtifactManager in the main thread; otherwise I was afraid that multiple DatasetSplitManagers would be created.

azimuth_config: AzimuthConfig,
) -> Dict[DatasetSplitName, Optional[DatasetSplitManager]]:
"""
Load all dataset splits for the application.

Args:
azimuth_config: Azimuth Configuration.

Returns:
For all DatasetSplitName, None or a dataset_split manager.

"""
artifact_manager = ArtifactManager.instance()
dataset = artifact_manager.get_dataset_dict(azimuth_config)

return {
dataset_split_name: None
if dataset_split_name not in dataset
else artifact_manager.get_dataset_split_manager(
azimuth_config, DatasetSplitName[dataset_split_name]
)
for dataset_split_name in [DatasetSplitName.eval, DatasetSplitName.train]
}


def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
"""Initialize DatasetSplitManagers and TaskManagers.

@@ -346,7 +371,6 @@ def run_validation_module(pipeline_index=None):
else:
for pipeline_index in range(len(config.pipelines)):
run_validation_module(pipeline_index)
task_manager.clear_worker_cache()


def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
10 changes: 10 additions & 0 deletions azimuth/config.py
@@ -358,6 +358,16 @@ def check_pipeline_names(cls, pipeline_definitions):
raise ValueError(f"Duplicated pipeline names {pipeline_names}.")
return pipeline_definitions

def get_model_contract_hash(self):
"""Hash for fields related to model contract only (excluding fields from the parents)."""
return md5_hash(
self.dict(
include=ModelContractConfig.__fields__.keys()
- CommonFieldsConfig.__fields__.keys(),
by_alias=True,
)
)


class MetricsConfig(ModelContractConfig):
# Custom HuggingFace metrics
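For context on the new `get_model_contract_hash`, here is a minimal sketch of the field-subset hashing it performs, assuming pydantic v1 semantics for `__fields__` and `dict(include=...)`; the field names and the `md5_hash` helper below are simplified stand-ins, not the actual Azimuth definitions.

```python
import hashlib
import json

from pydantic import BaseModel  # assumes pydantic v1 semantics


def md5_hash(data: dict) -> str:
    """Stable hash of a JSON-serializable dict (stand-in for Azimuth's md5_hash helper)."""
    return hashlib.md5(json.dumps(data, sort_keys=True, default=str).encode()).hexdigest()


class CommonFieldsConfig(BaseModel):
    name: str = "demo project"
    artifact_path: str = "cache"


class ModelContractConfig(CommonFieldsConfig):
    model_contract: str = "hf_text_classification"  # illustrative field names only
    saliency_layer: str = "embeddings"

    def get_model_contract_hash(self) -> str:
        # Hash only the fields declared on ModelContractConfig, excluding the inherited
        # CommonFieldsConfig ones, so changing e.g. artifact_path keeps the hash stable.
        return md5_hash(
            self.dict(
                include=ModelContractConfig.__fields__.keys()
                - CommonFieldsConfig.__fields__.keys(),
                by_alias=True,
            )
        )


config = ModelContractConfig()
print(config.get_model_contract_hash())  # unchanged if only common fields change
```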
96 changes: 65 additions & 31 deletions azimuth/modules/base_classes/artifact_manager.py
@@ -1,9 +1,10 @@
# Copyright ServiceNow, Inc. 2021 – 2022
# This source code is licensed under the Apache 2.0 license found in the LICENSE file
# in the root directory of this source tree.
from multiprocessing import Lock
from typing import Callable, Dict, Optional
from collections import defaultdict
from typing import Callable, Dict

import structlog
from datasets import DatasetDict

from azimuth.config import AzimuthConfig
@@ -18,31 +19,67 @@
Hash = int


log = structlog.get_logger()


class Singleton:
"""
A non-thread-safe helper class to ease implementing singletons.
This should be used as a decorator -- not a metaclass -- to the
class that should be a singleton.

To get the singleton instance, use the `instance` method. Trying
to use `__call__` will result in a `TypeError` being raised.

Args:
decorated: Decorated class
"""

def __init__(self, decorated):
self._decorated = decorated

def instance(self):
"""
Returns the singleton instance. Upon its first call, it creates a
new instance of the decorated class and calls its `__init__` method.
On all subsequent calls, the already created instance is returned.

Returns:
Instance of the decorated class
"""
try:
return self._instance
except AttributeError:
self._instance = self._decorated()
return self._instance

def __call__(self):
raise TypeError("Singletons must be accessed through `instance()`.")

def clear_instance(self):
"""For test purposes only"""
if hasattr(self, "_instance"):
delattr(self, "_instance")


@Singleton
class ArtifactManager:
"""This class is a singleton which holds different artifacts.

Artifacts include dataset_split_managers, datasets and models for each config, so they don't
need to be reloaded many times for a same module.
need to be reloaded many times for a same module. Inspired from
https://stackoverflow.com/questions/31875/is-there-a-simple-elegant-way-to-define-singletons.
"""

instance: Optional["ArtifactManager"] = None

def __init__(self):
# The keys of the dict are a hash of the config.
self.dataset_dict_mapping: Dict[Hash, DatasetDict] = {}
self.dataset_split_managers_mapping: Dict[
Hash, Dict[DatasetSplitName, DatasetSplitManager]
] = {}
self.models_mapping: Dict[Hash, Dict[int, Callable]] = {}
self.tokenizer = None
] = defaultdict(dict)
self.models_mapping: Dict[Hash, Dict[int, Callable]] = defaultdict(dict)
self.metrics = {}

@classmethod
def get_instance(cls):
with Lock():
if cls.instance is None:
cls.instance = cls()
return cls.instance
log.debug(f"Creating new Artifact Manager {id(self)}.")

def get_dataset_split_manager(
self, config: AzimuthConfig, name: DatasetSplitName
@@ -68,8 +105,6 @@ def get_dataset_split_manager(
f"Found {tuple(dataset_dict.keys())}."
)
project_hash: Hash = config.get_project_hash()
if project_hash not in self.dataset_split_managers_mapping:
self.dataset_split_managers_mapping[project_hash] = {}
if name not in self.dataset_split_managers_mapping[project_hash]:
self.dataset_split_managers_mapping[project_hash][name] = DatasetSplitManager(
name=name,
Expand All @@ -78,6 +113,7 @@ def get_dataset_split_manager(
initial_prediction_tags=ALL_PREDICTION_TAGS,
dataset_split=dataset_dict[name],
)
log.debug(f"New {name} DM in Artifact Manager {id(self)}")
return self.dataset_split_managers_mapping[project_hash][name]

def get_dataset_dict(self, config) -> DatasetDict:
@@ -106,25 +142,23 @@ def get_model(self, config: AzimuthConfig, pipeline_idx: int):
Returns:
Loaded model.
"""

project_hash: Hash = config.get_project_hash()
if project_hash not in self.models_mapping:
self.models_mapping[project_hash] = {}
if pipeline_idx not in self.models_mapping[project_hash]:
model_contract_hash: Hash = config.get_model_contract_hash()
if pipeline_idx not in self.models_mapping[model_contract_hash]:
log.debug(f"Loading pipeline {pipeline_idx}.")
pipelines = assert_not_none(config.pipelines)
self.models_mapping[project_hash][pipeline_idx] = load_custom_object(
self.models_mapping[model_contract_hash][pipeline_idx] = load_custom_object(
assert_not_none(pipelines[pipeline_idx].model), azimuth_config=config
)

return self.models_mapping[project_hash][pipeline_idx]
return self.models_mapping[model_contract_hash][pipeline_idx]

def get_metric(self, config, name: str, **kwargs):
hash: Hash = md5_hash({"name": name, **kwargs})
if hash not in self.metrics:
self.metrics[hash] = load_custom_object(config.metrics[name], **kwargs)
return self.metrics[hash]
metric_hash: Hash = md5_hash({"name": name, **kwargs})
if metric_hash not in self.metrics:
self.metrics[metric_hash] = load_custom_object(config.metrics[name], **kwargs)
return self.metrics[metric_hash]

@classmethod
def clear_cache(cls) -> None:
with Lock():
cls.instance = None
def instance(cls):
# Implemented in decorator
raise NotImplementedError
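For readers unfamiliar with decorator-based singletons, here is a small, self-contained sketch of the access pattern the new `Singleton` class enforces; the decorator is condensed from the diff above and the `ArtifactManager` body is a toy stand-in.

```python
# Condensed copy of the Singleton decorator from the diff above, plus a usage sketch.
class Singleton:
    def __init__(self, decorated):
        self._decorated = decorated

    def instance(self):
        # Create the wrapped class lazily on first access; reuse it afterwards.
        try:
            return self._instance
        except AttributeError:
            self._instance = self._decorated()
            return self._instance

    def __call__(self):
        raise TypeError("Singletons must be accessed through `instance()`.")

    def clear_instance(self):
        # For test purposes only.
        if hasattr(self, "_instance"):
            delattr(self, "_instance")


@Singleton
class ArtifactManager:
    def __init__(self):
        self.models_mapping: dict = {}


manager = ArtifactManager.instance()           # first call builds the instance
assert manager is ArtifactManager.instance()   # later calls return the same object

try:
    ArtifactManager()                          # direct instantiation is blocked
except TypeError as error:
    print(error)

ArtifactManager.clear_instance()               # test-only reset; next instance() rebuilds it
```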
5 changes: 1 addition & 4 deletions azimuth/modules/base_classes/module.py
@@ -79,7 +79,7 @@ def get_indices(self, name: Optional[DatasetSplitName] = None) -> List[int]:
def artifact_manager(self):
"""This is set as a property so the Module always have access to the current version of
the ArtifactManager on the worker."""
return ArtifactManager.get_instance()
return ArtifactManager.instance()

@property
def available_dataset_splits(self) -> Set[DatasetSplitName]:
@@ -215,6 +215,3 @@ def get_pipeline_definition(self) -> PipelineDefinition:
pipeline_index = assert_not_none(self.mod_options.pipeline_index)
current_pipeline = pipelines[pipeline_index]
return current_pipeline

def clear_cache(self):
self.artifact_manager.clear_cache()
5 changes: 3 additions & 2 deletions azimuth/routers/config.py
@@ -78,25 +78,26 @@ def patch_config(
)

try:
log.info(f"Validating config change with {partial_config}.")
new_config = update_config(old_config=config, partial_config=partial_config)
if attribute_changed_in_config("large_dask_cluster", partial_config, config):
cluster = default_cluster(partial_config["large_dask_cluster"])
else:
cluster = task_manager.cluster
run_startup_tasks(new_config, cluster)
log.info(f"Config successfully updated with {partial_config}.")
except Exception as e:
log.error("Rollback config update due to error", exc_info=e)
new_config = config
initialize_managers(new_config, task_manager.cluster)
log.info("Config update cancelled.")
if isinstance(e, (AzimuthValidationError, ValidationError)):
raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
else:
raise HTTPException(
HTTP_500_INTERNAL_SERVER_ERROR, detail="Error when loading the new config."
)

# Clear workers so that they load the correct config.
task_manager.clear_worker_cache()
return new_config


14 changes: 10 additions & 4 deletions azimuth/routers/export.py
@@ -5,7 +5,7 @@
import os
import time
from os.path import join as pjoin
from typing import Dict, Generator, List, Optional, cast
from typing import Dict, Generator, List, Optional

import pandas as pd
from fastapi import APIRouter, Depends, HTTPException
@@ -155,7 +155,10 @@ def get_export_perturbed_set(

output = list(
make_utterance_level_result(
dataset_split_manager, task_result, pipeline_index=pipeline_index_not_null
dataset_split_manager,
task_result,
pipeline_index=pipeline_index_not_null,
config=config,
)
)
with open(path, "w") as f:
Expand All @@ -164,20 +167,23 @@ def get_export_perturbed_set(


def make_utterance_level_result(
dm: DatasetSplitManager, results: List[List[PerturbedUtteranceResult]], pipeline_index: int
dm: DatasetSplitManager,
results: List[List[PerturbedUtteranceResult]],
pipeline_index: int,
config: AzimuthConfig,
Contributor Author:
This was causing an error in a test. We shouldn't rely on the config stored in the dataset split manager, because the manager only depends on the project_hash, so its config can be stale (see the sketch after this diff).

) -> Generator[Dict, None, None]:
"""Massage perturbation testing results for the frontend.

Args:
dm: Current DatasetSplitManager.
results: Output of Perturbation Testing.
pipeline_index: Index of the pipeline that made the results.
config: Azimuth config

Returns:
Generator that yield json-able object for the frontend.

"""
config = cast(AzimuthConfig, dm.config)
for idx, (utterance, test_results) in enumerate(
zip(
dm.get_dataset_split(
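A minimal sketch of the stale-config concern raised in the comment above; the classes and field names below are hypothetical stand-ins, not the real Azimuth API.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class AzimuthConfig:
    dataset_name: str         # contributes to the project hash
    behavioral_testing: bool  # does NOT contribute to the project hash


@dataclass
class DatasetSplitManager:
    config: AzimuthConfig  # cached per project hash, so this copy can go stale


original = AzimuthConfig(dataset_name="clinc_demo", behavioral_testing=False)
dm = DatasetSplitManager(config=original)  # built once, then reused from the cache

# The user later toggles a field that is not part of the project hash; the hash is
# unchanged, so the cached DatasetSplitManager (and the config it holds) is reused as-is.
current = replace(original, behavioral_testing=True)

assert dm.config.behavioral_testing is False  # stale when read off the manager
assert current.behavioral_testing is True     # correct when passed in explicitly
```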
2 changes: 0 additions & 2 deletions azimuth/routers/utterances.py
@@ -229,7 +229,6 @@ def patch_utterances(
utterances: List[UtterancePatch] = Body(...),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
task_manager: TaskManager = Depends(get_task_manager),
ignore_not_found: bool = Query(False),
) -> List[UtterancePatch]:
if ignore_not_found:
Expand All @@ -250,7 +249,6 @@ def patch_utterances(

dataset_split_manager.add_tags(data_actions)

task_manager.clear_worker_cache()
updated_tags = dataset_split_manager.get_tags(row_indices)

return [
19 changes: 10 additions & 9 deletions azimuth/startup.py
@@ -153,8 +153,6 @@ def on_end(fut: Future, module: DaskModule, dm: DatasetSplitManager, task_manage
# Task is done, save the result.
if isinstance(module, DatasetResultModule):
module.save_result(module.result(), dm)
# We only need to clear cache when the dataset is modified.
task_manager.clear_worker_cache()
else:
log.exception("Error in", module=module, fut=fut, exc_info=fut.exception())

@@ -257,12 +255,16 @@ def startup_tasks(

mods = start_tasks_for_dms(config, dataset_split_managers, task_manager, start_up_tasks)

# Start a thread to monitor the status.
th = threading.Thread(
target=wait_for_startup, args=(mods, task_manager), name=START_UP_THREAD_NAME
)
th.setDaemon(True)
th.start()
startup_ready = all(m.done() for m in mods.values())
if startup_ready:
log.info("Loading the application from cache. It should be accessible now.")
else:
# Start a thread to monitor the status.
th = threading.Thread(
target=wait_for_startup, args=(mods, task_manager), name=START_UP_THREAD_NAME
)
th.setDaemon(True)
th.start()

return mods

@@ -376,4 +378,3 @@ def log_progress():
task_manager.restart()
# After restarting, it is safe to unlock the task manager.
task_manager.unlock()
log.info("Cluster restarted to free memory.")
9 changes: 3 additions & 6 deletions azimuth/task_manager.py
@@ -7,7 +7,7 @@
from distributed import Client, SpecCluster

from azimuth.config import AzimuthConfig
from azimuth.modules.base_classes import ArtifactManager, DaskModule, ExpirableMixin
from azimuth.modules.base_classes import DaskModule, ExpirableMixin
from azimuth.modules.task_mapping import model_contract_methods, modules
from azimuth.types import (
DatasetSplitName,
@@ -66,7 +66,6 @@ def close(self):
mod.future.cancel()
except Exception:
pass
self.clear_worker_cache()
self.client.close()

def register_task(self, name, cls):
@@ -213,10 +212,8 @@ def status(self):
**self.get_all_tasks_status(task=None),
}

def clear_worker_cache(self):
self.client.run(ArtifactManager.clear_cache)

def restart(self):
# Clear futures to free memory.
for task_name, module in self.current_tasks.items():
module.future = None
self.client.restart()
log.info("Cluster restarted to free memory.")