Fixing the dependency situation (#613)
dirkgr authored Apr 5, 2024
1 parent 3400aa8 commit f585693
Showing 11 changed files with 35 additions and 12 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/main.yml
@@ -211,7 +211,7 @@ jobs:
 if: steps.virtualenv-cache.outputs.cache-hit != 'true' && (contains(matrix.task.extras, 'flax') || contains(matrix.task.extras, 'all'))
 run: |
 . .venv/bin/activate
-pip install flax==0.6.1 jax==0.4.1 jaxlib==0.4.1 tensorflow-cpu==2.9.1 optax==0.1.3
+pip install flax jax jaxlib "tensorflow-cpu>=2.9.1" optax
 - name: Install editable (no cache hit)
 if: steps.virtualenv-cache.outputs.cache-hit != 'true'
@@ -282,6 +282,7 @@ jobs:
 spec: |
 version: v2
 description: GPU Tests
+budget: ai2/oe-training
 tasks:
 - name: tests
 image:
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Fixed
+
+- Fixed a bunch of dependencies
+
 ## [v1.3.2](https://github.com/allenai/tango/releases/tag/v1.3.2) - 2023-10-27
 
 ### Fixed
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -89,10 +89,10 @@ fairscale = [
 ]
 flax = [
 "datasets>=1.12,<3.0",
-"jax>=0.4.1,<=0.4.13",
-"jaxlib>=0.4.1,<=0.4.13",
-"flax>=0.6.1,<=0.7.0",
-"optax>=0.1.2",
+"jax",
+"jaxlib",
+"flax",
+"optax",
 "tensorflow-cpu>=2.9.1"
 ]
 wandb = [
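
Note: with the pins removed, the flax extra resolves to whatever current releases satisfy these looser constraints at install time; for example, assuming the package's published PyPI name ai2-tango, pip install "ai2-tango[flax]" now pulls the latest compatible jax, jaxlib, flax, and optax.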
2 changes: 1 addition & 1 deletion tango/__main__.py
@@ -115,7 +115,7 @@ class SettingsObject(NamedTuple):
 called_by_executor: bool
 
 
-@click.group(**_CLICK_GROUP_DEFAULTS)
+@click.group(name=None, **_CLICK_GROUP_DEFAULTS)
 @click.version_option(version=VERSION)
 @click.option(
 "--settings",
10 changes: 9 additions & 1 deletion tango/integrations/beaker/executor.py
@@ -355,6 +355,7 @@ def __init__(
 priority: Optional[Union[str, Priority]] = None,
 allow_dirty: bool = False,
 scheduler: Optional[BeakerScheduler] = None,
+budget: Optional[str] = None,
 **kwargs,
 ):
 # Pre-validate arguments.
@@ -365,6 +366,11 @@ def __init__(
 "Either 'beaker_image' or 'docker_image' must be specified for BeakerExecutor, but not both."
 )
 
+if budget is None:
+    raise ConfigurationError("You must specify a budget to use the beaker executor.")
+else:
+    self._budget = budget
+
 from tango.workspaces import LocalWorkspace, MemoryWorkspace
 
 if isinstance(workspace, MemoryWorkspace):
@@ -1029,7 +1035,9 @@ def _build_experiment_spec(
 return (
 experiment_name,
 ExperimentSpec(
-tasks=[task_spec], description=f'Tango step "{step_name}" ({step.unique_id})'
+tasks=[task_spec],
+description=f'Tango step "{step_name}" ({step.unique_id})',
+budget=self._budget,
 ),
 [step_graph_dataset],
 )
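
For illustration, a minimal sketch of constructing the executor now that a budget is required. This is not code from the commit; it mirrors the parameter names used in the updated tests below, and the workspace, cluster, and budget values are placeholders:

    from tango.integrations.beaker import BeakerExecutor, BeakerWorkspace

    executor = BeakerExecutor(
        workspace=BeakerWorkspace(workspace="ai2/my-workspace"),  # placeholder workspace
        clusters=["ai2/allennlp-cirrascale"],                     # placeholder cluster
        beaker_image="ai2/conda",
        github_token="FAKE_TOKEN",                                # placeholder token
        budget="ai2/allennlp",  # omitting this now raises ConfigurationError
    )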
2 changes: 1 addition & 1 deletion tango/integrations/flax/data.py
@@ -40,7 +40,7 @@ def __init__(
 
 self.logger = logging.getLogger(FlaxDataLoader.__name__)
 
-def __call__(self, rng: jax.random.PRNGKeyArray, do_distributed: bool):
+def __call__(self, rng: jax._src.random.KeyArrayLike, do_distributed: bool):
 steps_per_epoch = self.dataset_size // self.batch_size
 
 if self.shuffle:
3 changes: 2 additions & 1 deletion tango/integrations/flax/optim.py
@@ -28,6 +28,7 @@ class Optimizer(Registrable):
 :options: +ELLIPSIS
 optax::adabelief
 optax::adadelta
+optax::adafactor
 optax::adagrad
 optax::adam
@@ -100,7 +101,7 @@ def factory_func():
 Optimizer.register("optax::" + name)(factory_func)
 
 # Register all learning rate schedulers.
-for name, cls in optax._src.schedule.__dict__.items():
+for name, cls in optax.schedules.__dict__.items():
 if isfunction(cls) and not name.startswith("_") and cls.__annotations__:
 factory_func = scheduler_factory(cls)
 LRScheduler.register("optax::" + name)(factory_func)
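
Illustrative usage, not part of this commit: once registered, these names can be looked up through Tango's Registrable machinery, assuming the usual by_name lookup and that Optimizer and LRScheduler are exported from tango.integrations.flax:

    from tango.integrations.flax import LRScheduler, Optimizer

    adam_factory = Optimizer.by_name("optax::adam")
    cosine_factory = LRScheduler.by_name("optax::cosine_decay_schedule")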
4 changes: 2 additions & 2 deletions tango/integrations/flax/util.py
@@ -3,15 +3,15 @@
 import jax
 
 
-def get_PRNGkey(seed: int = 42) -> Union[Any, jax.random.PRNGKeyArray]:
+def get_PRNGkey(seed: int = 42) -> Union[Any, jax._src.random.KeyArray]:
 """
 Utility function to create a pseudo-random number generator key
 given a seed.
 """
 return jax.random.PRNGKey(seed)
 
 
-def get_multiple_keys(key, multiple: int = 1) -> Union[Any, jax.random.PRNGKeyArray]:
+def get_multiple_keys(key, multiple: int = 1) -> Union[Any, jax._src.random.KeyArray]:
 """
 Utility function to split a PRNG key into multiple new keys.
 Used in distributed training.
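
A minimal usage sketch for these helpers (not part of the commit; the seed and the use of local_device_count are illustrative):

    import jax

    from tango.integrations.flax.util import get_multiple_keys, get_PRNGkey

    key = get_PRNGkey(seed=13)                    # single PRNG key from a seed
    per_device_keys = get_multiple_keys(key, multiple=jax.local_device_count())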
1 change: 1 addition & 0 deletions tango/integrations/transformers/__init__.py
@@ -70,6 +70,7 @@
 transformers::Adafactor
 transformers::AdamW
+transformers::LayerWiseDummyOptimizer
 - :class:`~tango.integrations.torch.LRScheduler`: All learning rate scheduler function from transformers
 are registered according to their type name (e.g. "transformers::linear").
3 changes: 3 additions & 0 deletions tests/integrations/beaker/executor_test.py
@@ -19,6 +19,7 @@ def test_from_params(beaker_workspace_name: str):
 beaker_image="ai2/conda",
 github_token="FAKE_TOKEN",
 datasets=[{"source": {"beaker": "some-dataset"}, "mount_path": "/input"}],
+budget="ai2/allennlp",
 ),
 workspace=BeakerWorkspace(workspace=beaker_workspace_name),
 clusters=["fake-cluster"],
@@ -38,6 +39,7 @@ def test_init_with_mem_workspace(beaker_workspace_name: str):
 beaker_image="ai2/conda",
 github_token="FAKE_TOKEN",
 clusters=["fake-cluster"],
+budget="ai2/allennlp",
 )
 
 
@@ -50,6 +52,7 @@ def settings(beaker_workspace_name: str) -> TangoGlobalSettings:
 "beaker_workspace": beaker_workspace_name,
 "install_cmd": "pip install .[beaker]",
 "clusters": ["ai2/allennlp-cirrascale", "ai2/general-cirrascale"],
+"budget": "ai2/allennlp",
 },
 )
 
7 changes: 6 additions & 1 deletion tests/integrations/flax/train_test.py
@@ -20,5 +20,10 @@ def test_trainer(self):
 ],
 )
 assert (
-result_dir / "train" / "work" / "checkpoint_state_latest" / "checkpoint_0"
+result_dir
+/ "train"
+/ "work"
+/ "checkpoint_state_latest"
+/ "checkpoint_0"
+/ "checkpoint"
 ).is_file()
