Deepspeech workload variants #627

Closed. This pull request wants to merge 74 commits.

Commits (74)
cae33e2
modify jax model to add variants
priyakasimbeg Dec 20, 2023
b17fc26
add jax deepspeech workload variants
priyakasimbeg Dec 20, 2023
b042167
add tanh variant
priyakasimbeg Dec 20, 2023
83f079b
add deepspeech pytorch variants
priyakasimbeg Dec 20, 2023
7e2ca31
fix deepspeech batchnorm layer
priyakasimbeg Dec 20, 2023
50ff71f
add deepspeech workload variant names to docker script
priyakasimbeg Dec 20, 2023
a8187b8
Add pass/fail thresholds to traindiffs test
runame Jan 13, 2024
2373e15
Add traindiffs_test option to docker startup script
runame Jan 13, 2024
d1da9c7
Rename PytWorkload to PyTorchWorkload
runame Jan 13, 2024
6a5d63a
Add traindiffs tests to workflows (self-hosted)
runame Jan 13, 2024
047475c
Merge branch 'dev' into traindiffs
runame Jan 17, 2024
1683ba3
add variant scoring conditions
priyakasimbeg Jan 18, 2024
370687d
add flag for self-tuning ruleset
priyakasimbeg Jan 18, 2024
2128ce8
score group of submissions
priyakasimbeg Jan 18, 2024
c65794d
Merge branch 'dev' into scoring_fixes
priyakasimbeg Jan 19, 2024
d43ccf4
correct max number of steps
priyakasimbeg Jan 19, 2024
fb81436
add heldout workloads"
priyakasimbeg Jan 19, 2024
1ea2282
add trial args to docker startup.sh
priyakasimbeg Jan 23, 2024
0bcb969
add script for sampling held out workloads
priyakasimbeg Jan 24, 2024
ce5f202
add code for run workloads
priyakasimbeg Jan 25, 2024
f431eef
add workload sampling
priyakasimbeg Jan 25, 2024
f260497
formatting
priyakasimbeg Jan 25, 2024
1a41f8b
imports
priyakasimbeg Jan 25, 2024
87df162
make seed splitting parallelizable
priyakasimbeg Jan 25, 2024
9d9cdb9
fix
priyakasimbeg Jan 25, 2024
1775307
formatting
priyakasimbeg Jan 25, 2024
8108c00
Merge pull request #613 from runame/traindiffs
priyakasimbeg Jan 25, 2024
2a11708
held out workloads example
priyakasimbeg Jan 25, 2024
a8385a2
add docker for run_workloads.py
priyakasimbeg Jan 25, 2024
ffddbdc
fix run_workloads.py
priyakasimbeg Jan 25, 2024
91cdf34
fix
priyakasimbeg Jan 25, 2024
95572ad
add rng seed to startup.sh docker script
priyakasimbeg Jan 25, 2024
d577d5c
fix
priyakasimbeg Jan 25, 2024
91ff705
fix
priyakasimbeg Jan 25, 2024
296dc1e
fix
priyakasimbeg Jan 26, 2024
a5b1154
fix
priyakasimbeg Jan 26, 2024
226544d
fix
priyakasimbeg Jan 26, 2024
6faad04
fix
priyakasimbeg Jan 26, 2024
9e7def9
fix log message
priyakasimbeg Jan 26, 2024
9b410b7
fix
priyakasimbeg Jan 26, 2024
7634a0b
debug
priyakasimbeg Jan 26, 2024
235bc69
debugging
priyakasimbeg Jan 26, 2024
a8d04cc
debugging
priyakasimbeg Jan 26, 2024
b2571b2
fix
priyakasimbeg Jan 26, 2024
18bc347
remove debugging statements
priyakasimbeg Jan 26, 2024
4a98698
fix
priyakasimbeg Jan 26, 2024
4d38e55
formatting
priyakasimbeg Jan 26, 2024
4d413f4
take into account median of studies for scoring
priyakasimbeg Jan 27, 2024
84c87b9
remove debugging
priyakasimbeg Jan 27, 2024
d6e2a36
formatting
priyakasimbeg Jan 27, 2024
f34838a
documentation
priyakasimbeg Jan 27, 2024
be263a2
Merge branch 'dev' into scoring_fixes
priyakasimbeg Jan 29, 2024
84dbb07
fix
priyakasimbeg Jan 30, 2024
af0b608
Merge branch 'scoring_fixes' of github.com:mlcommons/algorithmic-effi…
priyakasimbeg Jan 30, 2024
6d3b0ae
remove indexing for rng_subkeys
priyakasimbeg Jan 31, 2024
7b23443
add documentation
priyakasimbeg Jan 31, 2024
4b77ddd
fix documentation
priyakasimbeg Jan 31, 2024
2f48009
add warning
priyakasimbeg Jan 31, 2024
d39eb24
typo
priyakasimbeg Jan 31, 2024
aecb37f
fix documentation
priyakasimbeg Jan 31, 2024
5135cc8
remove prng import from generate_held_out_workloads.py
priyakasimbeg Jan 31, 2024
6d4f82e
fix technical documentation
priyakasimbeg Feb 1, 2024
aaa1014
formatting
priyakasimbeg Feb 1, 2024
761a877
add default for workload metadata config file
priyakasimbeg Feb 1, 2024
6b3827a
yapf fix
priyakasimbeg Feb 1, 2024
c0e1aad
import order
priyakasimbeg Feb 1, 2024
ff3c9b0
Merge pull request #618 from mlcommons/scoring_fixes
priyakasimbeg Feb 1, 2024
34996e9
minor
pomonam Feb 6, 2024
72be296
Revert batch size
pomonam Feb 6, 2024
daef103
minor
pomonam Feb 6, 2024
3a78765
minor
pomonam Feb 6, 2024
d718663
minor
pomonam Feb 6, 2024
4bc8fec
minor
pomonam Feb 6, 2024
0dced90
Merge conflicts
pomonam Feb 6, 2024
32 changes: 32 additions & 0 deletions .github/workflows/traindiffs_tests.yml
@@ -0,0 +1,32 @@
```yaml
name: Containerized training differences tests between Jax and PyTorch

on:
  pull_request:
    branches:
      - 'main'

jobs:
  build_and_push_docker_image:
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v2
      - name: Build and push docker image
        run: |
          GIT_BRANCH=${{ github.head_ref || github.ref_name }}
          FRAMEWORK=both
          IMAGE_NAME="algoperf_${GIT_BRANCH}"
          cd $HOME/algorithmic-efficiency/docker
          docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH
          BUILD_RETURN=$?
          if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi
          docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
          docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
  traindiffs_tests:
    runs-on: self-hosted
    needs: build_and_push_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized traindiffs test
        run: |
          docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }}
          docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }} algoperf_${{ github.head_ref || github.ref_name }} --traindiffs_test true
```
2 changes: 1 addition & 1 deletion DOCUMENTATION.md
@@ -409,7 +409,7 @@ The held-out workloads function similarly to a holdout test set discouraging sub

Modifications could, for example, include changing the number of layers or units (drawn from an interval), swapping the activation function (drawn from a set of applicable functions), or using different data augmentations (drawn from a list of possible pre-processing steps). The sample space should be wide enough to discourage submitters from simply trying them all out, but at the same time should be restricted enough to produce realistic workloads with acceptable achievable performances.

- In the first iteration of this benchmark, we manually designed three different workload variants for each fixed workload. The variants are designed such that they achieve a comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each fixed workload.
+ In the first iteration of this benchmark, we manually designed three different workload variants for each fixed workload. The variants are designed such that they achieve a comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each dataset.

Our scoring procedure uses the held-out workloads only to penalize submissions that can't handle the introduced modifications (see the [Scoring](#scoring) section for further details).
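
The `use_tanh` flag added in this PR is exactly this kind of modification. A minimal Flax sketch of the activation-swap pattern (a simplified stand-in, not the benchmark's actual module):

```python
# Minimal sketch of the activation-swap pattern used by the variants
# (simplified, not the benchmark's actual module): `use_tanh` selects
# tanh instead of the default ReLU.
import flax.linen as nn


class FeedForwardBlock(nn.Module):
  features: int
  use_tanh: bool = False

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.features)(x)
    return nn.tanh(x) if self.use_tanh else nn.relu(x)
```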

42 changes: 40 additions & 2 deletions GETTING_STARTED.md
@@ -336,11 +336,49 @@ docker exec -it <container_id> /bin/bash
```

## Score your Submission
To score your submission, we will score over all workloads, held-out workloads, and studies as described in the rules.
We will sample 1 held-out workload per dataset, for a total of 6 held-out workloads, and will use the sampled
held-out workloads in the scoring criteria for the matching base workloads.
In other words, the total number of runs expected for official scoring is (see the quick check below):

- for the external tuning ruleset: (8 workloads + 6 held-out workloads) x 5 studies x 5 trials
- for the self-tuning ruleset: (8 workloads + 6 held-out workloads) x 5 studies
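
A quick sanity check of those totals, using only the numbers from the bullets above:

```python
# Expected run counts for official scoring (numbers from the bullets above).
base_workloads, held_out, studies, trials = 8, 6, 5, 5
print((base_workloads + held_out) * studies * trials)  # 350 runs (external tuning)
print((base_workloads + held_out) * studies)           # 70 runs (self-tuning)
```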

The sections below describe how to run the required workloads and how to produce performance profiles and a performance table.


### Running workloads
To run workloads for scoring you may specify a "virtual" list of held-out workloads. It is important
to note that the official set of held-out workloads will be sampled by the competition organizers at scoring time.

An example config for held-out workloads is stored in `scoring/held_out_workloads_example.json`.
To generate a new sample of held-out workloads, run:

```bash
python3 generate_held_out_workloads.py --seed <optional_rng_seed> --output_filename <output_filename>
```
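
Conceptually, the sampling draws one variant per base workload with a seeded RNG. A rough sketch of the idea (the variant registry below is hypothetical; the real `generate_held_out_workloads.py` defines its own):

```python
# Rough sketch of held-out workload sampling. The registry below is
# hypothetical and the actual generate_held_out_workloads.py differs.
import json
import random

HELD_OUT_VARIANTS = {
    'librispeech_deepspeech': [
        'librispeech_deepspeech_tanh',
        'librispeech_deepspeech_no_resnet',
        'librispeech_deepspeech_norm_and_spec_aug',
    ],
    # ... one entry per dataset ...
}


def sample_held_out_workloads(seed, output_filename):
  rng = random.Random(seed)  # seeded so the sample is reproducible
  sampled = [rng.choice(variants) for variants in HELD_OUT_VARIANTS.values()]
  with open(output_filename, 'w') as f:
    json.dump(sampled, f, indent=2)


sample_held_out_workloads(seed=0, output_filename='held_out_workloads.json')
```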

To run a number of studies and trials over all workloads, using a Docker container for each run:

```bash
python scoring/run_workloads.py \
  --framework <framework> \
  --experiment_name <experiment_name> \
  --docker_image_url <docker_image_url> \
  --submission_path <submission_path> \
  --tuning_search_space <tuning_search_space_path> \
  --held_out_workloads_config_path held_out_workloads_example.json \
  --num_studies <num_studies> \
  --seed <rng_seed>
```
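
Under the hood, the script roughly amounts to launching one container per (study, workload) combination. An assumed sketch for illustration only; the actual `scoring/run_workloads.py` and its container flags differ:

```python
# Assumed sketch of the run loop in scoring/run_workloads.py; the container
# flags shown here are hypothetical. Illustration only.
import itertools
import subprocess

workloads = ['fastmri', 'ogbg', 'wmt']  # base workloads + sampled held-out ones
num_studies = 5

for study, workload in itertools.product(range(num_studies), workloads):
  subprocess.run(
      ['docker', 'run', '--gpus', 'all', '<docker_image_url>',
       '--workload', workload,                   # hypothetical container flag
       '--experiment_name', f'study_{study}'],   # hypothetical container flag
      check=True)
```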

Note that to run the above script you will need at least the `jax_cpu` and `pytorch_cpu` installations of the `algorithmic-efficiency` package (e.g., `pip3 install -e '.[jax_cpu]'` and `pip3 install -e '.[pytorch_cpu]'`, assuming the package's standard CPU extras).

During submission development, it might be useful to do faster, approximate scoring (e.g. without 5 different studies
or when some trials are missing), so the scoring scripts allow some flexibility. To simulate official scoring,
pass the `--strict=True` flag to `score_submissions.py`. To get the raw scores and performance profiles of a group of
submissions or a single submission:

```diff
- python3 scoring/score_submission.py --experiment_path=<path_to_experiment_dir> --output_dir=<output_dir>
+ python score_submissions.py --submission_directory <directory_with_submissions> --output_dir <output_dir> --compute_performance_profiles
```
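
For intuition, a performance profile (as defined in the benchmark paper) reports, for each submission, the fraction of workloads on which its training time is within a factor tau of the best submission's time. A minimal sketch, assuming per-workload times-to-target with `inf` for missed targets:

```python
# Minimal sketch of a performance profile; assumes `times` maps each
# submission name to per-workload times-to-target (np.inf if missed).
import numpy as np


def performance_profile(times, taus):
  best = np.min(np.stack(list(times.values())), axis=0)  # per-workload best time
  return {name: [(t <= tau * best).mean() for tau in taus]
          for name, t in times.items()}


times = {'sub_a': np.array([1.0, 2.0, np.inf]),
         'sub_b': np.array([1.5, 1.0, 3.0])}
print(performance_profile(times, taus=[1.0, 2.0, 4.0]))
```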

We provide the scores and performance profiles for the [paper baseline algorithms](/reference_algorithms/paper_baselines/) in the "Baseline Results" section in [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179).
@@ -58,7 +58,8 @@ class DeepspeechConfig:
   enable_residual_connections: bool = True
   enable_decoder_layer_norm: bool = True
   bidirectional: bool = True
-
+  use_tanh: bool = False
+  layernorm_everywhere: bool = False
 
 class Subsample(nn.Module):
   """Module to perform strided convolution in order to subsample inputs.
@@ -80,15 +81,18 @@ def __call__(self, inputs, output_paddings, train):
         batch_norm_momentum=config.batch_norm_momentum,
         batch_norm_epsilon=config.batch_norm_epsilon,
         input_channels=1,
-        output_channels=config.encoder_dim)(outputs, output_paddings, train)
+        output_channels=config.encoder_dim,
+        use_tanh=config.use_tanh
+    )(outputs, output_paddings, train)
 
     outputs, output_paddings = Conv2dSubsampling(
         encoder_dim=config.encoder_dim,
         dtype=config.dtype,
         batch_norm_momentum=config.batch_norm_momentum,
         batch_norm_epsilon=config.batch_norm_epsilon,
         input_channels=config.encoder_dim,
-        output_channels=config.encoder_dim)(outputs, output_paddings, train)
+        output_channels=config.encoder_dim,
+        use_tanh=config.use_tanh)(outputs, output_paddings, train)
 
     batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape

@@ -127,6 +131,7 @@ class Conv2dSubsampling(nn.Module):
   dtype: Any = jnp.float32
   batch_norm_momentum: float = 0.999
   batch_norm_epsilon: float = 0.001
+  use_tanh: bool = False
 
   def setup(self):
     self.filter_shape = (3, 3, self.input_channels, self.output_channels)
@@ -150,7 +155,12 @@ def __call__(self, inputs, paddings, train):
         feature_group_count=feature_group_count)
 
     outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
-    outputs = nn.relu(outputs)
+
+    if self.use_tanh:
+      outputs = nn.tanh(outputs)
+    else:
+      outputs = nn.relu(outputs)
+
 
     # Computing correct paddings post input convolution.
     input_length = paddings.shape[1]
@@ -182,16 +192,22 @@ def __call__(self, inputs, input_paddings=None, train=False):
     padding_mask = jnp.expand_dims(1 - input_paddings, -1)
     config = self.config
 
-    inputs = BatchNorm(config.encoder_dim,
-                       config.dtype,
-                       config.batch_norm_momentum,
-                       config.batch_norm_epsilon)(inputs, input_paddings, train)
+    if config.layernorm_everywhere:
+      inputs = LayerNorm(config.encoder_dim)(inputs)
+    else:
+      inputs = BatchNorm(config.encoder_dim,
+                         config.dtype,
+                         config.batch_norm_momentum,
+                         config.batch_norm_epsilon)(inputs, input_paddings, train)
     inputs = nn.Dense(
         config.encoder_dim,
         use_bias=True,
         kernel_init=nn.initializers.xavier_uniform())(
             inputs)
-    inputs = nn.relu(inputs)
+    if config.use_tanh:
+      inputs = nn.tanh(inputs)
+    else:
+      inputs = nn.relu(inputs)
     inputs *= padding_mask
 
     if config.feed_forward_dropout_rate is None:
@@ -416,10 +432,13 @@ class BatchRNN(nn.Module):
   def __call__(self, inputs, input_paddings, train):
     config = self.config
 
-    inputs = BatchNorm(config.encoder_dim,
-                       config.dtype,
-                       config.batch_norm_momentum,
-                       config.batch_norm_epsilon)(inputs, input_paddings, train)
+    if config.layernorm_everywhere:
+      inputs = LayerNorm(config.encoder_dim)(inputs)
+    else:
+      inputs = BatchNorm(config.encoder_dim,
+                         config.dtype,
+                         config.batch_norm_momentum,
+                         config.batch_norm_epsilon)(inputs, input_paddings, train)
     output = CudnnLSTM(
         features=config.encoder_dim // 2,
         bidirectional=config.bidirectional,
@@ -29,7 +29,14 @@ def init_model_fn(
     model_config = models.DeepspeechConfig(
         feed_forward_dropout_rate=dropout_rate,
         use_specaug=self.use_specaug,
-        input_dropout_rate=aux_dropout_rate)
+        input_dropout_rate=aux_dropout_rate,
+        use_tanh=self.use_tanh,
+        enable_residual_connections=self.enable_residual_connections,
+        enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+        layernorm_everywhere=self.layernorm_everywhere,
+        freq_mask_count=self.freq_mask_count,
+        time_mask_count=self.time_mask_count,
+    )
     self._model = models.Deepspeech(model_config)
     input_shape = [(320000,), (320000,)]
     fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
@@ -67,3 +74,64 @@ def step_hint(self) -> int:
   @property
   def max_allowed_runtime_sec(self) -> int:
     return 55_506  # ~15.4 hours
+
+  @property
+  def use_tanh(self) -> bool:
+    return False
+
+  @property
+  def enable_residual_connections(self) -> bool:
+    return True
+
+  @property
+  def enable_decoder_layer_norm(self) -> bool:
+    return True
+
+  @property
+  def layernorm_everywhere(self) -> bool:
+    return False
+
+  @property
+  def freq_mask_count(self) -> int:
+    return 2
+
+  @property
+  def time_mask_count(self) -> int:
+    return 10
+
+
+class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload):
+
+  @property
+  def use_tanh(self) -> bool:
+    return True
+
+
+class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload):
+
+  @property
+  def enable_residual_connections(self) -> bool:
+    return False
+
+
+class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload):
+
+  @property
+  def eval_batch_size(self) -> int:
+    return 128
+
+  @property
+  def enable_decoder_layer_norm(self) -> bool:
+    return False
+
+  @property
+  def layernorm_everywhere(self) -> bool:
+    return True
+
+  @property
+  def freq_mask_count(self) -> int:
+    return 4
+
+  @property
+  def time_mask_count(self) -> int:
+    return 15
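
Since `init_model_fn` forwards every flag from workload properties into `DeepspeechConfig`, each variant is defined purely by property overrides. A hypothetical usage sketch:

```python
# Hypothetical sketch: variants differ only in overridden properties,
# which init_model_fn forwards into DeepspeechConfig.
base = LibriSpeechDeepSpeechWorkload()
tanh = LibriSpeechDeepSpeechTanhWorkload()

print(base.use_tanh, tanh.use_tanh)                # False True
print(base.freq_mask_count, tanh.freq_mask_count)  # 2 2 (inherited from base)
```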