
Python SDK v1 (determined-ai#8005)
Python SDK V1: various refactors, bug fixes, feature additions, enhancements, and deprecations.

- Deprecates `get_checkpoints`, `select_checkpoint`, `top_checkpoint`, `top_n_checkpoints` on `Trial` and `Experiment`
- Introduces `reload()` on class objects and a caching strategy
- Renames training `Trial` to `LegacyTrial` and the SDK's `TrialReference` to `Trial`
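The second bullet names a caching strategy plus a `reload()` method but the commit message doesn't spell the pattern out. The following is a minimal conceptual sketch of that lazy-fetch/memoize/invalidate pattern — the `CachedResource` class and `fake_fetch` helper are illustrative stand-ins, not the SDK's actual implementation:

```python
# Conceptual sketch (NOT the SDK's real classes) of the caching-plus-reload()
# pattern described above: attributes are fetched lazily, memoized, and
# invalidated by an explicit reload().

from typing import Any, Callable, Dict, Optional


class CachedResource:
    """Lazily fetches remote state once, until reload() clears the cache."""

    def __init__(self, fetch: Callable[[], Dict[str, Any]]) -> None:
        self._fetch = fetch  # callable standing in for a server round-trip
        self._cache: Optional[Dict[str, Any]] = None

    @property
    def attrs(self) -> Dict[str, Any]:
        if self._cache is None:  # only hit the "server" on a cache miss
            self._cache = self._fetch()
        return self._cache

    def reload(self) -> None:
        """Drop cached state so the next access re-fetches it."""
        self._cache = None


calls = []

def fake_fetch() -> Dict[str, Any]:
    calls.append(1)
    return {"state": "COMPLETED", "fetches": len(calls)}

res = CachedResource(fake_fetch)
res.attrs, res.attrs            # second access is served from the cache
assert len(calls) == 1
res.reload()                    # invalidate
assert res.attrs["fetches"] == 2
```

The design choice this illustrates: reads are cheap and repeatable between explicit `reload()` calls, at the cost of possibly stale data.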
azhou-determined authored Oct 13, 2023
1 parent 3a24611 commit 67f6dc1
Showing 76 changed files with 10,208 additions and 2,349 deletions.
107 changes: 53 additions & 54 deletions docs/model-dev-guide/model-management/checkpoints.rst
@@ -25,83 +25,82 @@ The Checkpoint Export API is a subset of the features found in the
Querying Checkpoints
**********************

The :class:`~determined.experimental.client.ExperimentReference` class is a reference to an
experiment. The reference contains the
:meth:`~determined.experimental.client.ExperimentReference.top_checkpoint` method. Without
arguments, the method will check the experiment configuration searcher field for the ``metric`` and
``smaller_is_better`` values. These values are used to sort the experiment's checkpoints by
validation performance. The searcher settings in the following snippet from an experiment
configuration file will result in checkpoints being sorted by the loss metric in ascending order.
Use the :class:`~determined.experimental.client.Experiment` class to reference an experiment. The
:meth:`~determined.experimental.client.Experiment.list_checkpoints` method, when called without
arguments, returns checkpoints sorted based on the ``metric`` and ``smaller_is_better`` values from
the experiment configuration's searcher field.

For example, in the following experiment configuration file snippet, Determined will sort
checkpoints by the loss metric and in ascending order.

.. code:: yaml

   searcher:
     metric: "loss"
     smaller_is_better: true

The following snippet of Python code can be run after the specified experiment has generated a
checkpoint. It returns an instance of :class:`~determined.experimental.client.Checkpoint`
representing the checkpoint that has the best validation metric.
After generating a checkpoint for the specified experiment, you can run the Python code below. This
code retrieves a list of sorted :class:`~determined.experimental.client.Checkpoint` instances
associated with the experiment and selects the checkpoint with the best validation metric.

.. code:: python
from determined.experimental import client
checkpoint = client.get_experiment(id).top_checkpoint()
checkpoint = client.get_experiment(id).list_checkpoints()[0]
Checkpoints can be sorted by any metric using the ``sort_by`` keyword argument, which defines which
metric to use, and ``smaller_is_better``, which defines whether to sort the checkpoints in ascending
or descending order with respect to the specified metric.
To sort checkpoints by any metric, use the ``sort_by`` argument to specify the metric and
``order_by`` to define the sorting order (ascending or descending).

.. code:: python
from determined.experimental import client
from determined.experimental import checkpoint, client
checkpoint = (
client.get_experiment(id).top_checkpoint(sort_by="accuracy", smaller_is_better=False)
checkpoints = (
client.get_experiment(id).list_checkpoints(
sort_by="accuracy",
order_by=checkpoint.CheckpointOrderBy.DESC
)
)
You may also query multiple checkpoints at the same time using the
:meth:`~determined.experimental.client.ExperimentReference.top_n_checkpoints` method. Only the
single best checkpoint from each trial is considered; out of those, the checkpoints with the best
validation metric values are returned in sorted order, with the best one first. For example, the
following snippet returns the top five checkpoints from distinct trials of a specified experiment.
To sort checkpoints using preset checkpoint parameters, use the
:class:`~determined.experimental.checkpoint.CheckpointSortBy` class. The example below fetches all
checkpoints for an experiment, sorting them by trial ID in descending order.

.. code:: python
from determined.experimental import client
checkpoints = client.get_experiment(id).top_n_checkpoints(5)
This method also accepts ``sort_by`` and ``smaller_is_better`` arguments.
from determined.experimental import checkpoint, client
:class:`~determined.experimental.client.TrialReference` is used for fine-grained control over
checkpoint selection within a trial. It contains a
:meth:`~determined.experimental.client.TrialReference.top_checkpoint` method, which mirrors
:meth:`~determined.experimental.client.ExperimentReference.top_checkpoint` for an experiment. It
also contains :meth:`~determined.experimental.client.TrialReference.select_checkpoint`, which offers
three ways to query checkpoints:
checkpoints = client.get_experiment(id).list_checkpoints(
sort_by=checkpoint.CheckpointSortBy.TRIAL_ID,
order_by=checkpoint.CheckpointOrderBy.DESC
)
#. ``best``: Returns the best checkpoint based on validation metrics as discussed above. When using
``best``, ``smaller_is_better`` and ``sort_by`` are also accepted.
#. ``latest``: Returns the most recent checkpoint for the trial.
#. ``uuid``: Returns the checkpoint with the specified UUID.
:class:`~determined.experimental.client.Trial` is used for fine-grained control over checkpoint
selection within a trial. It contains a
:meth:`~determined.experimental.client.Trial.list_checkpoints` method, which mirrors
:meth:`~determined.experimental.client.Experiment.list_checkpoints` for an experiment.

The following snippet showcases how to use the different modes for selecting checkpoints.
The following code illustrates methods to select specific checkpoints from a trial:

.. code:: python
from determined.experimental import client
from determined.experimental import checkpoint, client
trial = client.get_trial(id)
best_checkpoint = trial.top_checkpoint()
most_accurate_checkpoint = trial.select_checkpoint(
best=True, sort_by="accuracy", smaller_is_better=False
)
most_recent_checkpoint = trial.list_checkpoints(
sort_by=checkpoint.CheckpointSortBy.END_TIME,
order_by=checkpoint.CheckpointOrderBy.DESC,
max_results=1
)[0]
most_recent_checkpoint = trial.select_checkpoint(latest=True)
# Sort checkpoints by "accuracy" metric, if your training code reports it.
most_accurate_checkpoint = trial.list_checkpoints(
sort_by="accuracy",
order_by=checkpoint.CheckpointOrderBy.DESC,
max_results=1
)[0]
specific_checkpoint = client.get_checkpoint(uuid="uuid-for-checkpoint")
@@ -122,7 +121,7 @@ parameter, which changes the checkpoint download location.
from determined.experimental import client
checkpoint = client.get_experiment(id).top_checkpoint()
checkpoint = client.get_experiment(id).list_checkpoints()[0]
checkpoint_path = checkpoint.download()
specific_path = checkpoint.download(path="specific-checkpoint-path")
@@ -144,7 +143,7 @@ the ``model`` attribute of the ``Trial`` object, as shown in the following snipp
from determined.experimental import client
from determined import pytorch
checkpoint = client.get_experiment(id).top_checkpoint()
checkpoint = client.get_experiment(id).list_checkpoints()[0]
path = checkpoint.download()
trial = pytorch.load_trial_from_checkpoint_path(path)
model = trial.model
@@ -169,7 +168,7 @@ predictions as shown in the following snippet.
from determined.experimental import client
from determined import keras
checkpoint = client.get_experiment(id).top_checkpoint()
checkpoint = client.get_experiment(id).list_checkpoints()[0]
path = checkpoint.download()
model = keras.load_model_from_checkpoint_path(path)
@@ -192,7 +191,7 @@ useful for storing post-training metrics, labels, information related to deploym
from determined.experimental import client
checkpoint = client.get_experiment(id).top_checkpoint()
checkpoint = client.get_experiment(id).list_checkpoints()[0]
checkpoint.add_metadata({"environment": "production"})
# Metadata will be stored in Determined and accessible on the checkpoint object.
@@ -206,7 +205,7 @@ exists the entire tree beneath it will be overwritten.
from determined.experimental import client
checkpoint = client.get_experiment(id).top_checkpoint()
checkpoint = client.get_experiment(id).list_checkpoints()[0]
checkpoint.add_metadata({"metrics": {"loss": 0.12}})
checkpoint.add_metadata({"metrics": {"acc": 0.92}})
@@ -220,7 +219,7 @@ deleted.
from determined.experimental import client
checkpoint = client.get_experiment(id).top_checkpoint()
checkpoint = client.get_experiment(id).list_checkpoints()[0]
checkpoint.remove_metadata(["metrics"])
***************************************
@@ -287,8 +286,8 @@ checkpoint.
| | }
The ``det trial download`` command downloads checkpoints for a specified trial. Similar to the
:class:`~determined.experimental.client.TrialReference` API, the ``det trial download`` command
accepts ``--best``, ``--latest``, and ``--uuid`` options.
:class:`~determined.experimental.client.Trial` API, the ``det trial download`` command accepts
``--best``, ``--latest``, and ``--uuid`` options.

.. code::
@@ -328,7 +327,7 @@ The ``--latest`` and ``--uuid`` options are used as follows:
det trial download <trial_id> --uuid <uuid-for-checkpoint>
Finally, the ``det experiment download`` command provides a similar experience to using the
:class:`Python SDK <determined.experimental.client.ExperimentReference>`.
:class:`Python SDK <determined.experimental.client.Experiment>`.

.. code:: bash
@@ -219,7 +219,7 @@ The following snippet registers a new version of a model.
d = Determined()
checkpoint = d.get_experiment(exp_id).top_checkpoint()
checkpoint = d.get_experiment(exp_id).list_checkpoints()[0]
model = d.get_model("model_name")
19 changes: 9 additions & 10 deletions docs/reference/python-sdk.rst
@@ -33,9 +33,8 @@ The next step is to call create_experiment():
exp = client.create_experiment(config="my_config.yaml", model_dir=".")
print(f"started experiment {exp.id}")
The returned object will be an ``ExperimentReference`` object, which has methods for controlling the
lifetime of the experiment running on the cluster. In this example, we will just wait for the
experiment to complete.
The returned object is an ``Experiment`` object, which offers methods to manage the experiment's
lifecycle. In the following example, we simply await the experiment's completion.

.. code:: python
@@ -46,7 +45,7 @@ Now that the experiment has completed, you can grab the top-performing checkpoin

.. code:: python
best_checkpoint = exp.top_checkpoint()
best_checkpoint = exp.list_checkpoints()[0]
print(f"best checkpoint was {best_checkpoint.uuid}")
.. _python-sdk-reference:
@@ -76,10 +75,10 @@ Now that the experiment has completed, you can grab the top-performing checkpoin
:members:
:member-order: bysource

``ExperimentReference``
=======================
``Experiment``
==============

.. autoclass:: determined.experimental.client.ExperimentReference
.. autoclass:: determined.experimental.client.Experiment
:members:
:member-order: bysource

@@ -118,10 +117,10 @@ Now that the experiment has completed, you can grab the top-performing checkpoin
:members:
:member-order: bysource

``TrialReference``
==================
``Trial``
=========

.. autoclass:: determined.experimental.client.TrialReference
.. autoclass:: determined.experimental.client.Trial
:members:
:exclude-members: stream_training_metrics, stream_validation_metrics
:member-order: bysource
23 changes: 11 additions & 12 deletions docs/release-notes.rst
@@ -523,8 +523,8 @@ Version 0.21.1

- :meth:`~determined.experimental.client.stream_trials_training_metrics`
- :meth:`~determined.experimental.client.stream_trials_validation_metrics`
- :meth:`~determined.experimental.client.TrialReference.stream_training_metrics`
- :meth:`~determined.experimental.client.TrialReference.stream_validation_metrics`
- :meth:`~determined.experimental.client.Trial.stream_training_metrics`
- :meth:`~determined.experimental.client.Trial.stream_validation_metrics`

**Removed Features**

@@ -884,12 +884,12 @@ Version 0.19.7
Hugging Face's Diffusers.

- Python SDK now supports reading logs from trials, via the new
:meth:`~determined.experimental.client.TrialReference.logs` method. Additionally, the Python SDK
also supports a new blocking call on an experiment to get the first trial created for an
experiment via the
:meth:`~determined.experimental.client.ExperimentReference.await_first_trial()` method. Users who
have been writing automation around the ``det e create --follow-first-trial`` CLI command may now
use the Python SDK instead, by combining ``.await_first_trial()`` and ``.logs()``.
:meth:`~determined.experimental.client.Trial.logs` method. Additionally, the Python SDK also
supports a new blocking call on an experiment to get the first trial created for an experiment
via the :meth:`~determined.experimental.client.ExperimentReference.await_first_trial()` method.
Users who have been writing automation around the ``det e create --follow-first-trial`` CLI
command may now use the Python SDK instead, by combining ``.await_first_trial()`` and
``.logs()``.
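The release note above describes replacing ``det e create --follow-first-trial`` automation by combining ``.await_first_trial()`` and ``.logs()``, without showing the combination. A minimal sketch of that pattern follows — the ``StubExperiment``/``StubTrial`` classes are stand-ins so the example is self-contained, not the real SDK objects:

```python
# Sketch of the automation pattern described above: block until the
# experiment's first trial exists, then stream its logs. StubExperiment and
# StubTrial are hypothetical stand-ins for the SDK's Experiment and Trial.

from typing import Iterator, List


class StubTrial:
    def logs(self) -> Iterator[str]:
        # The real method streams log lines from the cluster.
        yield from ["trial started", "epoch 1 done"]


class StubExperiment:
    def await_first_trial(self) -> StubTrial:
        # The real call blocks until the first trial has been created.
        return StubTrial()


def follow_first_trial(exp) -> List[str]:
    """Python-SDK replacement for `det e create --follow-first-trial`."""
    trial = exp.await_first_trial()
    return [line for line in trial.logs()]


lines = follow_first_trial(StubExperiment())
assert lines[0] == "trial started"
```

With a real experiment object in place of the stub, the same ``follow_first_trial`` function would print or collect live trial logs.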

- RBAC: the enterprise edition of Determined (`HPE Machine Learning Development Environment
<https://www.hpe.com/us/en/solutions/artificial-intelligence/machine-learning-development-environment.html>`_)
@@ -3789,9 +3788,8 @@ Version 0.12.3
- Add support for locally testing experiments via ``det e create --local``.

- Add :class:`determined.experimental.Determined` class for accessing
:class:`~determined.experimental.ExperimentReference`,
:class:`~determined.experimental.TrialReference`, and
:class:`~determined.experimental.Checkpoint` objects.
:class:`~determined.experimental.ExperimentReference`, :class:`~determined.experimental.Trial`,
and :class:`~determined.experimental.Checkpoint` objects.

- TensorBoard logs now appear under the ``storage_path`` for ``shared_fs`` checkpoint
configurations.
@@ -3979,7 +3978,7 @@ Version 0.12.2

- Add support for gradient aggregation in Keras trials for TensorFlow 2.1.

- Add TrialReference and Checkpoint experimental APIs for exporting and loading checkpoints.
- Add Trial and Checkpoint experimental APIs for exporting and loading checkpoints.

- Improve performance when starting many tasks simultaneously.

3 changes: 3 additions & 0 deletions e2e_tests/tests/cluster/test_model_registry.py
@@ -51,6 +51,7 @@ def test_model_registry() -> None:
assert mnist.metadata == {"testing": "metadata"}

# Confirm we can look up a model by its ID
assert mnist.model_id is not None, "mnist.model_id set by create_model"
db_model = d.get_model_by_id(mnist.model_id)
assert db_model.name == "mnist"
db_model = d.get_model(mnist.model_id)
@@ -106,6 +107,7 @@ def test_model_registry() -> None:

latest_version = mnist.get_version()
assert latest_version is not None
assert latest_version.checkpoint
assert latest_version.checkpoint.uuid == checkpoint.uuid

latest_version.set_name("Test 2021")
@@ -131,6 +133,7 @@ def test_model_registry() -> None:

latest_version = mnist.get_version()
assert latest_version is not None
assert latest_version.checkpoint
assert latest_version.checkpoint.uuid == checkpoint.uuid

# Ensure the correct number of versions are present.
6 changes: 4 additions & 2 deletions e2e_tests/tests/cluster/test_resource_pools.py
@@ -22,7 +22,8 @@ def test_default_pool_task_container_defaults() -> None:
model_dir=conf.fixtures_path("no_op"),
)

e1_config = e1.get_config()
assert e1.config
e1_config = e1.config

assert len(e1_config["environment"]["environment_variables"]["cpu"]) > 0

@@ -35,7 +36,8 @@ def test_default_pool_task_container_defaults() -> None:
config=parsed_config,
model_dir=conf.fixtures_path("no_op"),
)
e2_config = e2.get_config()
assert e2.config
e2_config = e2.config

assert (
e1_config["environment"]["environment_variables"]["cpu"]
3 changes: 2 additions & 1 deletion e2e_tests/tests/experiment/test_core.py
@@ -173,7 +173,7 @@ def test_end_to_end_adaptive() -> None:
assert best is not None
assert best > 0.93

# Check that ExperimentReference returns a sorted order of top checkpoints
# Check that the Experiment returns a sorted order of top checkpoints
# without gaps. The top 2 checkpoints should be the first 2 of the top k
# checkpoints if sorting is stable.
d = Determined(conf.make_master_url())
@@ -212,6 +212,7 @@ def test_end_to_end_adaptive() -> None:
db_check = d.get_checkpoint(checkpoint.uuid)
# Make sure the checkpoint metadata is correct and correctly saved to the db.
# Beginning with 0.18 the TrialController contributes a few items to the dict.
assert checkpoint.metadata
assert checkpoint.metadata.get("testing") == "metadata"
assert checkpoint.metadata.keys() == {
"determined_version",
4 changes: 2 additions & 2 deletions e2e_tests/tests/fixtures/no_op/model_def.py
@@ -102,7 +102,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.trained_steps = collections.Counter()

@staticmethod
def from_trial(trial_inst: det.Trial, *args: Any, **kwargs: Any) -> det.TrialController:
def from_trial(trial_inst: det.LegacyTrial, *args: Any, **kwargs: Any) -> det.TrialController:
return NoOpTrialController(*args, **kwargs)

@staticmethod
@@ -251,7 +251,7 @@ def chaos_failure(self, probability: Optional[float]) -> None:
raise Exception("CHAOS! Executing random failure.")


class NoOpTrial(det.Trial):
class NoOpTrial(det.LegacyTrial):
trial_context_class = NoOpTrialContext
trial_controller_class = NoOpTrialController
