Apply linter to existing files (#39)
wconstab authored Feb 2, 2024
1 parent b99af33 commit bfa5db1
Showing 8 changed files with 201 additions and 169 deletions.
torchtrain/datasets/__init__.py: 6 changes (5 additions, 1 deletion)
@@ -1,7 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from torchtrain.datasets.alpaca import build_alpaca_data_loader
from torchtrain.datasets.tokenizer import create_tokenizer
from torchtrain.datasets.pad_batch_sequence import pad_batch_to_longest_seq
from torchtrain.datasets.tokenizer import create_tokenizer

__all__ = ["build_alpaca_data_loader", "create_tokenizer", "pad_batch_to_longest_seq"]

dataloader_fn = {
"alpaca": build_alpaca_data_loader,
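For orientation (not part of this commit), a minimal sketch of how the dataloader_fn registry above is meant to be used, with arguments taken from build_alpaca_data_loader's signature; the create_tokenizer call and the tokenizer path are assumptions based on the download script further down.

from torchtrain.datasets import create_tokenizer, dataloader_fn

# Hypothetical call site: resolve a loader builder by dataset name.
# The create_tokenizer argument is a placeholder; its signature is not shown in this diff.
tokenizer = create_tokenizer("torchtrain/datasets/tokenizer/tokenizer.model")
data_loader = dataloader_fn["alpaca"](
    tokenizer, batch_size=8, seq_len=2048, world_size=1, rank=0
)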
torchtrain/datasets/alpaca.py: 27 changes (8 additions, 19 deletions)
@@ -1,18 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import List, Tuple
from typing import List

import torch

from datasets import load_dataset
from torch.utils.data import IterableDataset, DataLoader, DistributedSampler
from torch.utils.data import DataLoader, IterableDataset

from torchtrain.datasets.tokenizer import TokenizerIf

from datasets import load_dataset


class AlpacaDataset(IterableDataset):
"""PyTorch Representation of the Alpaca Dataset from Hugging Face.
@@ -37,11 +34,7 @@ class AlpacaDataset(IterableDataset):
Batch size: 8
"""

def __init__(self,
tokenizer: TokenizerIf,
seq_len: int = 2048,
**kwargs
) -> None:
def __init__(self, tokenizer: TokenizerIf, seq_len: int = 2048, **kwargs) -> None:
self._data = load_dataset("tatsu-lab/alpaca", split="train")
self._tokenizer = tokenizer
self.data_iterator = iter(self._data)
@@ -52,7 +45,7 @@ def __len__(self):
return len(self._data)

def __iter__(self):
max_buffer_token_len = (1 + self.seq_len)
max_buffer_token_len = 1 + self.seq_len
all_tokens: List[int] = []

for sample in self.data_iterator:
@@ -71,11 +64,7 @@ def __iter__(self):


def build_alpaca_data_loader(
tokenizer: TokenizerIf,
batch_size: int,
seq_len: int,
world_size,
rank
tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
):
alpaca_ds = AlpacaDataset(tokenizer=tokenizer, seq_len=seq_len)
# TODO: sampler can't work with iterable dataset, figure out a way
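The __iter__ hunk above hints at the sample-packing scheme: tokenized samples accumulate in all_tokens until at least seq_len + 1 tokens are buffered. A standalone sketch of that pattern, with the chunking and the one-token label shift as assumptions (only the buffer setup appears in the hunk):

import torch

def pack_tokens(token_stream, seq_len=2048):
    # Accumulate tokens across samples, then emit (input, label) pairs where the
    # label is the input shifted by one position. Illustrative only.
    max_buffer_token_len = 1 + seq_len
    all_tokens = []
    for sample_tokens in token_stream:
        all_tokens.extend(sample_tokens)
        while len(all_tokens) >= max_buffer_token_len:
            x = torch.LongTensor(all_tokens[:max_buffer_token_len])
            all_tokens = all_tokens[max_buffer_token_len:]
            yield x[:-1], x[1:]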
torchtrain/datasets/download_tokenizer.py: 35 changes (27 additions, 8 deletions)
@@ -1,8 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Optional

@@ -11,20 +12,38 @@

def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
from huggingface_hub import hf_hub_download

os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
try:
hf_hub_download(repo_id, "tokenizer.model", local_dir=f"torchtrain/datasets/tokenizer/", local_dir_use_symlinks=False, token=hf_token)
hf_hub_download(
repo_id,
"tokenizer.model",
local_dir="torchtrain/datasets/tokenizer/",
local_dir_use_symlinks=False,
token=hf_token,
)
except HTTPError as e:
if e.response.status_code == 401:
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
print(
"You need to pass a valid `--hf_token=...` to download private checkpoints."
)
else:
raise e

if __name__ == '__main__':

if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Download tokenizer from HuggingFace.')
parser.add_argument('--repo_id', type=str, default="meta-llama/llama-2-70b", help='Repository ID to download from.')
parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')

parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
parser.add_argument(
"--repo_id",
type=str,
default="meta-llama/llama-2-70b",
help="Repository ID to download from.",
)
parser.add_argument(
"--hf_token", type=str, default=None, help="HuggingFace API token."
)

args = parser.parse_args()
hf_download(args.repo_id, args.hf_token)
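For reference, the reformatted script downloads tokenizer.model into torchtrain/datasets/tokenizer/; it can be run as python torchtrain/datasets/download_tokenizer.py --repo_id meta-llama/llama-2-70b --hf_token <token>, or called directly. A minimal sketch; the token value is a placeholder, and per the error message above a valid token is only needed for private checkpoints:

from torchtrain.datasets.download_tokenizer import hf_download

# Equivalent to the CLI invocation above; "hf_xxx" is a placeholder token.
hf_download(repo_id="meta-llama/llama-2-70b", hf_token="hf_xxx")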
torchtrain/models/llama/__init__.py: 14 changes (13 additions, 1 deletion)
@@ -1,8 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from torchtrain.models.llama.model import ModelArgs, Transformer

__all__ = ["Transformer"]

llama_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16),
"7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
"13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
"70B": ModelArgs(dim=8192, n_layers=80, n_heads=64, n_kv_heads=8, ffn_dim_multiplier=1.3, multiple_of=4096),
"70B": ModelArgs(
dim=8192,
n_layers=80,
n_heads=64,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=4096,
),
}
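A sketch of how these configs plug into the exported model class; that Transformer is constructed directly from a ModelArgs instance is an assumption, not shown in this diff:

from torchtrain.models.llama import Transformer, llama_configs

model_args = llama_configs["debugmodel"]  # dim=256, n_layers=1, n_heads=16
model = Transformer(model_args)           # assumed constructor; see model.py for the actual API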