update bench cluster #229

Merged: 59 commits, Sep 5, 2024

Commits
33fd672
readme
zzhhjjj Apr 22, 2024
b1872e1
Adding checkpoint after traning ends
angegonzalez May 7, 2024
7dd5beb
fix row parallel
C-TC May 16, 2024
ad028e6
datatrove is all you need
TJ-Solergibert May 31, 2024
df1632b
Refractored preprocessing script to work with datatrove
TJ-Solergibert May 31, 2024
395a4db
Updated dataset_weights default value
TJ-Solergibert Jun 1, 2024
5f8a52b
Support for jsonl files for preprocess_data.py and updated preprocess…
TJ-Solergibert Jun 1, 2024
9fcd071
Updated nanoset docs
TJ-Solergibert Jun 1, 2024
d7cfc3f
Install datatrove from source
TJ-Solergibert Jun 2, 2024
bcf405d
Implemented global memory buffer to reduce activation memory of diffe…
AleHD Jun 27, 2024
ed1ca7d
GLU fusion
AleHD Jun 27, 2024
9b0de5b
precommit
AleHD Jun 27, 2024
bbc259f
Merge branch 'main' into fix_tp_mem_cache
AleHD Jul 11, 2024
803b6da
Wrong backward fixed
AleHD Jul 16, 2024
59bfb6b
Removed useless prints
AleHD Jul 16, 2024
2c69e9a
Minor fixes
AleHD Jul 17, 2024
30439fd
precommit
AleHD Jul 17, 2024
1e02a9c
Added tp_recompute_allgather option
AleHD Jul 17, 2024
9cc81bb
Changed recompute default
AleHD Jul 17, 2024
956fbfd
Changed recompute default
AleHD Jul 17, 2024
9992f1c
Little fixes
TJ-Solergibert Jul 17, 2024
b9e9201
Moved ColumnLinearNoAsync module for consistency
AleHD Jul 17, 2024
cd2ff64
Merge branch 'fix-row-parallel' of github.com:C-TC/nanotron into asyn…
AleHD Jul 17, 2024
25acc0e
Merge branch 'async_fix' into mem_fix_async
AleHD Jul 17, 2024
7cc6653
memory efficient async linear
AleHD Jul 18, 2024
cb0f260
precommit
AleHD Jul 18, 2024
6d85d03
Added no_recompute_allgather mode to async
AleHD Jul 18, 2024
5f82f7a
Merge pull request #145 from zzhhjjj/readme
zzhhjjj Jul 21, 2024
49633df
Merge branch 'main' into fix_tp_mem_cache
AleHD Jul 23, 2024
2afd007
Fixed List not found
AleHD Jul 23, 2024
81e7a54
Merge branch 'main' into mem_fix_async
AleHD Jul 23, 2024
7e758db
Fixed tp=1 case
AleHD Jul 23, 2024
41f11f0
fix row parallel
C-TC May 16, 2024
b713f7b
Merge branch 'fix-row-parallel' of https://github.com/c-tc/nanotron i…
3outeille Jul 29, 2024
2e48d66
Merge pull request #172 from C-TC/fix-row-parallel
3outeille Jul 30, 2024
ce2a96b
Merge branch 'main' into fix_tp_mem_cache
AleHD Jul 30, 2024
cd84d4f
Fixed column parallel
AleHD Jul 30, 2024
d3db06a
Added tp_recompute_allgather test
AleHD Jul 30, 2024
6f82050
Merge branch 'fix_tp_mem_cache' into mem_fix_async
AleHD Jul 30, 2024
4c94b99
Added tp_recompute_allgather test
AleHD Jul 30, 2024
3c35611
change to correct config_nanoset.yaml path
xrsrke Jul 31, 2024
6ad5994
Merge pull request #212 from huggingface/pr/TJ-Solergibert/189
xrsrke Jul 31, 2024
2793c92
Merge pull request #165 from angegonzalez/feature/save-train-ending-c…
xrsrke Aug 2, 2024
7daa186
Minor restyling
AleHD Aug 2, 2024
31c3c5a
Fixed names
AleHD Aug 2, 2024
0adb368
Merge pull request #1 from AleHD/fix_tp_mem_cache
AleHD Aug 2, 2024
4eb520f
Merge pull request #203 from AleHD/fix_tp_mem_cache
3outeille Aug 2, 2024
03d67f2
Merge pull request #208 from AleHD/mem_fix_async
3outeille Aug 5, 2024
f40fa05
exact output logits match for LLaMA-3
zzhhjjj Aug 14, 2024
3a45a34
remove torch compile
eliebak Aug 27, 2024
3340a7b
only log ckpt on rank 0
eliebak Aug 27, 2024
6cf2d63
Merge pull request #224 from eliebak/fix-logging-checkpoint
zzhhjjj Aug 29, 2024
4a2ddca
Merge pull request #223 from eliebak/fix-torch-compile
zzhhjjj Aug 29, 2024
b15fbc7
clean code
zzhhjjj Sep 2, 2024
7323ce1
make nanoset compatible with python
eliebak Sep 5, 2024
3be44ef
Merge pull request #228 from eliebak/nanoset-python-compatible
3outeille Sep 5, 2024
5ded28c
inference part
zzhhjjj Sep 5, 2024
2677380
refactor
zzhhjjj Sep 5, 2024
1456446
Merge pull request #214 from zzhhjjj/match-transformers
3outeille Sep 5, 2024
62 changes: 33 additions & 29 deletions docs/nanoset.md
@@ -1,49 +1,50 @@
# Nanosets
Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a kind of datasets based on [numpy memory-mapped arrays](https://numpy.org/doc/stable/reference/generated/numpy.memmap.html). `Nanosets` are capable of serving batches from files containing pre-tokenized datasets. They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches.
Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a dataset for processing tokenized documents with [`datatrove`](https://github.com/huggingface/datatrove). They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches.
## Install
To use `Nanosets`, it's necessary to install Nanotron with the `nanosets` flavor.
```
pip install -e '.[nanosets]'
pip install nanotron[nanosets]
```
This will install the following dependencies:
- `transformers`: To tokenize the datasets
- `datasets`: To preprocess the datasets
- `datatrove`: To preprocess the datasets
- `numba`: To compile helper functions in order to speed up the creation of `Nanosets`
- `transformers`: For the tokenizers
## Data pre-processing
To use these datasets, first, we need to preprocess the data. The input format can either be a column of a Hugging Face Dataset or a .json file containing a text sample per line. For example:
To use this dataset, we first need to preprocess the data using `datatrove`'s `DocumentTokenizer` pipeline. We invite you to take a look at `datatrove`, since it contains multiple features that allow you, for example, to filter out documents based on specific rules/criteria, extract text content from raw formats or schedule the preprocessing on a Slurm cluster. We have also added a simple script capable of tokenizing datasets.

<pre>
{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
</pre>

The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. Below we show an example for processing a corpus with the Llama2 tokenizer.
The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. The input format can either be a Hugging Face Dataset, a path to a `.jsonl` file or a path to a folder containing multiple `.jsonl` files. Below we show an example of processing a Hugging Face Dataset from the Hub with the Llama3 tokenizer.

<pre>
torchrun --nproc-per-node 16 tools/preprocess_data.py \
--input HuggingFaceH4/testing_alpaca_small \
--split train \
--column completion \
--output-prefix datasets/testing_alpaca_small \
--tokenizer-name-or-path openai-community/gpt2
python3 tools/preprocess_data.py \
--tokenizer-name-or-path meta-llama/Meta-Llama-3-8B \
--output-folder datasets/emotion \
--n-tasks 16 \
hf \
--dataset dair-ai/emotion
</pre>

The preprocessing script has to be launched with `torchrun` in order to spawn `--nproc-per-node` workers that will preprocess the dataset concurrently. The `--input` dataset can be either a Hugging Face Dataset from the Hub or a `.json` file. The processed dataset will be stored in *`--output-prefix`_input_ids.npy*. In `--tokenizer-name-or-path`, we will have to specify a tokenizer in the same way as we do when using `AutoTokenizers.from_pretrained(...)`.
First, with `--tokenizer-name-or-path` we specify a tokenizer in the same way as we do when using `AutoTokenizer.from_pretrained(...)`. Then we specify the `--output-folder` where we will store the tokenized documents and the number of workers with `--n-tasks`. Finally, we indicate the type of dataset (whether it's a Hugging Face Dataset ["**hf**"] or in jsonl ["**jsonl**"] format) and the dataset that we want to preprocess. Check the different settings with `python3 tools/preprocess_data.py --help`, `python3 tools/preprocess_data.py hf --help` & `python3 tools/preprocess_data.py jsonl --help`.

The output will be one file named, in this case, `datasets/testing_alpaca_small_input_ids.npy`. We will then have to specify this file in the `dataset_path` field in the config file.
Every worker will store 3 different kinds of files in `--output-folder`:
- `*.ds` Containing the tokenized documents
- `*.ds.index` Containing the bounds of each tokenized document
- `*.ds.metadata` Containing the number of tokens and tokenizer used
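
As an illustrative aside (not part of the diff): a quick way to sanity-check a tokenized shard, assuming the `*.ds` file is a flat binary stream of token ids whose width matches the `token_size` recorded in `*.ds.metadata` (2 bytes for vocabularies that fit in `uint16`, 4 otherwise). The shard name below is hypothetical.

```python
# Hedged sketch: peek at a tokenized shard produced by the preprocessing step above.
import numpy as np

token_size = 4  # Llama-3's vocabulary does not fit in uint16, so 4 bytes per token
dtype = np.uint16 if token_size == 2 else np.uint32
tokens = np.fromfile("datasets/emotion/00000_emotion.ds", dtype=dtype)  # hypothetical shard name
print(f"{tokens.size} tokens, first 10 ids: {tokens[:10]}")
```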

> [!IMPORTANT]
> Remember to specify the type of dataset to process, e.g. python3 tools/preprocess_data.py --tokenizer-name-or-path gpt2 --n-tasks 16 **jsonl** --dataset raw_datasets/c4-es-json-files

## Working with Nanosets

To work with `Nanosets`, we just need to configure 1 argument:
1. `dataset_path`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it:
1. `dataset_folder`: This argument specifies the folder or folders that will compose the `Nanoset`. There are 3 ways to specify it:
1. If we specify a single path, we will create a `Nanoset` from a single dataset folder.
```yaml
data_stages:
- name: General purpose training (Single dataset)
start_training_step: 1
data:
dataset:
dataset_path: datasets/SlimPajama-6B_input_ids.npy
dataset_folder: datasets/SlimPajama-6B
num_loading_workers: 0
seed: 1234
```
@@ -54,9 +55,9 @@ To work with `Nanosets`, we just need to configure 1 argument:
start_training_step: 15
data:
dataset:
dataset_path:
- datasets/SlimPajama-6B_input_ids.npy
- datasets/testing_alpaca_small_input_ids.npy
dataset_folder:
- datasets/SlimPajama-6B
- datasets/testing_alpaca_small
num_loading_workers: 0
seed: 1234
```
@@ -67,9 +68,9 @@ To work with `Nanosets`, we just need to configure 1 argument:
start_training_step: 25
data:
dataset:
dataset_path:
datasets/SlimPajama-6B_input_ids.npy: 0.8
datasets/testing_alpaca_small_input_ids.npy: 0.2
dataset_folder:
datasets/SlimPajama-6B: 0.8
datasets/testing_alpaca_small: 0.2
num_loading_workers: 0
seed: 1234
```
@@ -78,11 +79,14 @@ To work with `Nanosets`, we just need to configure 1 argument:

Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py).
```shell
torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml
torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml
```

## Under the hood
`Nanosets` are responsible of building samples of `sequence length + 1` tokens from the preprocessed dataset files. The `dataset lengths` of each dataset will be determined by the `(dataset_number_of_tokens - 1) / sequence length`, discarding the last sample if its length < `sequence length`.
`Nanosets` are responsible for building samples of `sequence length + 1` tokens from the preprocessed dataset files. Although most of the extraction logic lives in `DatatroveFolderDataset`, `Nanosets` take care of the following:
1. Creating dataset mixtures from different dataset folder paths
2. Ensuring that in each epoch we consume each sample only once
3. Ensuring that we never exhaust the `DataLoader`

Based on the `dataset lengths`, the `dataset weights` and the `number of samples per epoch` (defined as the `sum(dataset lengths)`), we build the two indexes we need in order to extract samples from the `Nanoset` ([build_nanoset_index_helper](../src/nanotron/data/nanoset.py)):
- `dataset index`: Contains the index of the dataset from the list of `dataset paths` from which to extract the sample, respecting the established dataset weight.
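As an illustrative aside (not part of the diff): a simplified sketch of the weighted index construction described in the "Under the hood" section above, standing in for the numba-compiled [build_nanoset_index_helper](../src/nanotron/data/nanoset.py); the real helper differs in details, and the function name below is illustrative only.

```python
# Simplified sketch of building the `dataset index` and `dataset sample index`
# from dataset lengths and weights (illustration only, not nanotron's implementation).
import numpy as np

def build_index_sketch(n_samples, weights, dataset_lengths):
    dataset_index = np.empty(n_samples, dtype=np.int64)         # which dataset each sample comes from
    dataset_sample_index = np.empty(n_samples, dtype=np.int64)  # position inside that dataset
    consumed = [0] * len(weights)
    for i in range(n_samples):
        # Pick the dataset whose consumed share lags most behind its target weight.
        errors = [weights[d] * (i + 1) - consumed[d] for d in range(len(weights))]
        d = int(np.argmax(errors))
        dataset_index[i] = d
        dataset_sample_index[i] = consumed[d] % dataset_lengths[d]  # wrap to reuse samples across epochs
        consumed[d] += 1
    return dataset_index, dataset_sample_index

# 80/20 blend of two tokenized folders, 10 samples per epoch:
idx, sample_idx = build_index_sketch(10, [0.8, 0.2], [1000, 500])
print(idx)         # 8 entries for dataset 0, 2 for dataset 1
print(sample_idx)  # sample position inside the chosen dataset
```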
24 changes: 12 additions & 12 deletions examples/config_nanoset.yaml
@@ -7,25 +7,25 @@ checkpoints:
data_stages:
- data:
dataset:
dataset_path: datasets/testing_alpaca_small_input_ids.npy
dataset_folder: datasets/c4-es/tokenized
num_loading_workers: 1
seed: 42
name: General purpose training (Single dataset)
start_training_step: 1
- data:
dataset:
dataset_path:
- datasets/yelp_review_full_input_ids.npy
- datasets/testing_alpaca_small_input_ids.npy
dataset_folder:
- datasets/SlimPajama-6B/tokenized
- datasets/c4-es/tokenized
num_loading_workers: 1
seed: 42
name: Second purpose training (> 1 dataset)
start_training_step: 15
- data:
dataset:
dataset_path:
datasets/testing_alpaca_small_input_ids.npy: 0.8
datasets/yelp_review_full_input_ids.npy: 0.2
dataset_folder:
datasets/SlimPajama-6B/tokenized: 0.8
datasets/c4-es/tokenized: 0.2
num_loading_workers: 1
seed: 42
name: Third purpose training (Blended dataset)
@@ -57,7 +57,7 @@ model:
initializer_range: 0.02
intermediate_size: 64
is_llama_config: true
max_position_embeddings: 256
max_position_embeddings: 1024
num_attention_heads: 4
num_hidden_layers: 2
num_key_value_heads: 4
@@ -67,7 +67,7 @@ model:
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 32000
vocab_size: 50257
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
@@ -88,11 +88,11 @@ optimizer:
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 2
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
@@ -105,6 +105,6 @@ tokens:
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 128
sequence_length: 1024
train_steps: 200
val_check_interval: -1
12 changes: 12 additions & 0 deletions examples/mamba/README.md
@@ -18,6 +18,18 @@ pip install -r requirements.txt

> https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5

## Bug related to nanotron
Encountered the following issue when running train_mamba.sh:
```
causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
```
Solved this by doing:
- `pip uninstall mamba-ssm`
- `pip install causal_conv1d==1.1.1`
- `pip install mamba-ssm --no-cache-dir`

See https://github.com/state-spaces/mamba/issues/169


## Credits
Credits to the following repositories from which the code was adapted:
- https://github.com/state-spaces/mamba
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -49,7 +49,7 @@ fast-modeling = [

nanosets = [
"transformers",
"datasets",
"datatrove[io,processing]@git+https://github.com/huggingface/datatrove",
"numba",
]

6 changes: 3 additions & 3 deletions run_train.py
@@ -143,17 +143,17 @@ def get_dataloader_from_data_stage(
elif isinstance(data.dataset, NanosetDatasetsArgs):
# Get tokenizer cardinality
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer
# Create Nanoset
from nanotron.data.nanoset import Nanoset

with main_rank_first(trainer.parallel_context.world_pg):
train_dataset = Nanoset(
dataset_paths=data.dataset.dataset_path,
dataset_folders=data.dataset.dataset_folder,
dataset_weights=data.dataset.dataset_weights,
sequence_length=trainer.sequence_length,
token_dtype=token_dtype,
token_size=token_size,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
random_seed=data.seed,
)
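As an illustrative aside (not part of the diff): the `token_size` switch above picks 2 bytes per token when the tokenizer vocabulary fits in `uint16` and 4 bytes otherwise; the snippet below simply evaluates that rule for gpt2 (50,257 tokens).

```python
# Evaluate the token_size rule from run_train.py for a small-vocabulary tokenizer.
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
print(len(tokenizer), token_size)  # 50257 -> 2; a ~128k-vocab tokenizer such as Llama-3 gives 4
```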
16 changes: 6 additions & 10 deletions src/nanotron/config/config.py
@@ -93,25 +93,20 @@ def __post_init__(self):

@dataclass
class NanosetDatasetsArgs:
dataset_path: Union[str, dict, List[str]]
dataset_folder: Union[str, List[str]]
dataset_weights: Optional[List[float]] = None

def __post_init__(self):
if isinstance(self.dataset_path, str): # Case 1: 1 Dataset file
self.dataset_path = [self.dataset_path]
if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder
self.dataset_folder = [self.dataset_folder]
self.dataset_weights = [1]
elif isinstance(self.dataset_path, List): # Case 2: > 1 Dataset file
self.dataset_weights = None # Set to None so we consume all the samples randomly
elif isinstance(self.dataset_path, dict): # Case 3: dict with > 1 dataset_path and weights
tmp_dataset_path = self.dataset_path.copy()
self.dataset_path = list(tmp_dataset_path.keys())
self.dataset_weights = list(tmp_dataset_path.values())


@dataclass
class DataArgs:
"""Arguments related to the data and data files processing"""

dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
seed: Optional[int]
num_loading_workers: Optional[int] = 1

@@ -145,6 +140,7 @@ class CheckpointsArgs:
checkpoints_path: Path
checkpoint_interval: int
save_initial_state: Optional[bool] = False
save_final_state: Optional[bool] = False
resume_checkpoint_path: Optional[Path] = None
checkpoints_path_is_shared_file_system: Optional[bool] = False

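As an illustrative aside (not part of the diff): the hunk above truncates `NanosetDatasetsArgs.__post_init__`, so here is a hedged sketch of the normalization it describes, covering the single-folder, list, and dict (folder -> weight) forms shown in `docs/nanoset.md`; the helper name is made up.

```python
# Hedged sketch of normalizing `dataset_folder` into parallel folder/weight lists.
from typing import Dict, List, Optional, Tuple, Union

def normalize_dataset_folder(
    dataset_folder: Union[str, List[str], Dict[str, float]]
) -> Tuple[List[str], Optional[List[float]]]:
    if isinstance(dataset_folder, str):    # Case 1: a single dataset folder
        return [dataset_folder], [1]
    if isinstance(dataset_folder, list):   # Case 2: several folders, consume all samples
        return dataset_folder, None
    # Case 3: dict of folder -> weight, as in the blended-dataset YAML example
    return list(dataset_folder.keys()), list(dataset_folder.values())

print(normalize_dataset_folder({"datasets/SlimPajama-6B": 0.8, "datasets/c4-es": 0.2}))
```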
3 changes: 3 additions & 0 deletions src/nanotron/config/models_config.py
@@ -48,6 +48,9 @@ class LlamaConfig:
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope_interleaved: bool = (
False # The default value has been True, but for loading Llama3 checkpoints you have to set it to False
)
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
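As an illustrative aside (not part of the diff): a hedged example of setting the new field when building the config programmatically, assuming the remaining `LlamaConfig` fields keep their defaults as in the dataclass above; the same key applies in YAML configs.

```python
# Hedged sketch: Llama-3 checkpoints need rope_interleaved=False (see the comment in the diff).
from nanotron.config.models_config import LlamaConfig

config = LlamaConfig(rope_interleaved=False)
print(config.rope_interleaved, config.rope_theta)
```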
2 changes: 2 additions & 0 deletions src/nanotron/config/parallelism_config.py
@@ -34,6 +34,8 @@ class ParallelismArgs:
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False

tp_recompute_allgather: bool = True

expert_parallel_size: int = 1

def __post_init__(self):
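As an illustrative aside (not part of the diff): a hedged example of where the new flag is set; the reading of its effect in the comment is inferred from this PR's activation-memory commits rather than stated in the diff.

```python
# Hedged sketch: the flag defaults to True (recompute the TP all-gather in backward to
# save activation memory); set it to False to keep the gathered tensor and skip the recompute.
from nanotron.config.parallelism_config import ParallelismArgs

parallelism = ParallelismArgs(dp=1, pp=1, tp=1, tp_recompute_allgather=True)
print(parallelism.tp_recompute_allgather)
```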
80 changes: 80 additions & 0 deletions src/nanotron/data/collator.py
@@ -0,0 +1,80 @@
import dataclasses
from typing import Dict, List, Union

import numpy as np
import torch
from nanotron import distributed as dist
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer


@dataclasses.dataclass
class NanosetDataCollatorForCLM:
"""
Data collator used for causal language modeling with Nanosets dataset.

- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""

sequence_length: int
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext

def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"input_mask": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}

# Make sure we load only what's necessary, ie we only load a `input_ids` column.
assert all(list(example.keys()) == ["input_ids"] for example in examples)

# TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor?
input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
batch_size, expanded_input_length = input_ids.shape

result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {}

result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)

assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"

# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)

return result
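
As an illustrative aside (not part of the diff): the collator's core convention, isolated from the pipeline-parallel plumbing above, is that inputs drop the last token and labels drop the first token of each `(sequence_length + 1)`-long sample.

```python
# Minimal illustration of the shift-by-one rule in NanosetDataCollatorForCLM.
import torch

sequence_length = 8
block = torch.arange(2 * (sequence_length + 1)).reshape(2, sequence_length + 1)  # (b, s + 1)
input_ids = block[:, :-1]  # kept on the input_pp_rank
label_ids = block[:, 1:]   # kept on the output_pp_rank: inputs shifted left by one
assert input_ids.shape == label_ids.shape == (2, sequence_length)
print(input_ids[0].tolist(), label_ids[0].tolist())
```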
4 changes: 2 additions & 2 deletions src/nanotron/data/dataloader_builder.py
@@ -1,7 +1,7 @@
import nanotron.distributed as dist
from nanotron import logging
from nanotron.data.collator import NanosetDataCollatorForCLM
from nanotron.dataloader import (
DataCollatorForCLM,
EmptyInfiniteDataset,
get_dataloader_worker_init,
get_sampler,
@@ -32,7 +32,7 @@ def build_nanoset_dataloader(
# No need to spawn a lot of workers, we can just use main
dataloader_num_workers = 0

data_collator = DataCollatorForCLM(
data_collator = NanosetDataCollatorForCLM(
sequence_length=sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,