-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
84 lines (61 loc) · 2.17 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import logging
import os
import subprocess
from typing import List, Optional, Tuple
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def get_hash_from_bucket(
    bucket_uri: str, s3_sync_args: Optional[List[str]] = None
) -> str:
    """Fetch ``refs/main`` from the bucket and return the hash stored in it.

    The file is copied with ``awsv2 s3 cp`` into the current working
    directory (as ``./main``) and then read back.

    Args:
        bucket_uri: Root URI of the model in the bucket; ``refs/main`` is
            appended to it.
        s3_sync_args: Extra command-line arguments placed before the source
            and destination of the ``awsv2 s3 cp`` call.

    Returns:
        The contents of the downloaded file with surrounding whitespace
        stripped.

    Raises:
        subprocess.CalledProcessError: If the copy command exits non-zero.
    """
    import posixpath

    extra_args = s3_sync_args or []
    # Bucket URIs always use "/" as separator, whatever the host OS is, so
    # build the remote path with posixpath rather than os.path.
    refs_uri = posixpath.join(bucket_uri, "refs", "main")
    # check=True: without it a failed copy goes unnoticed and a stale ./main
    # left behind by an earlier call would be read and returned instead.
    subprocess.run(
        ["awsv2", "s3", "cp", "--quiet", *extra_args, refs_uri, "."],
        check=True,
    )
    with open(os.path.join(".", "main"), "r") as f:
        return f.read().strip()
def get_checkpoint_and_refs_dir(
    model_id: str,
    bucket_uri: str,
    s3_sync_args: Optional[List[str]] = None,
    mkdir: bool = False,
) -> Tuple[str, str]:
    """Return the snapshot and refs directories for a model in TRANSFORMERS_CACHE.

    The snapshot directory name is the hash read from the bucket's
    ``refs/main`` file via ``get_hash_from_bucket``.

    Args:
        model_id: Model identifier; ``/`` is replaced by ``--`` when forming
            the cache directory name.
        bucket_uri: Root URI of the model in the bucket.
        s3_sync_args: Extra arguments forwarded to ``get_hash_from_bucket``.
        mkdir: When true, create both directories if they do not exist.

    Returns:
        A ``(checkpoint_dir, refs_dir)`` tuple, where ``checkpoint_dir`` is
        ``<model dir>/snapshots/<hash>`` and ``refs_dir`` is
        ``<model dir>/refs``.
    """
    # The cache location is resolved before the bucket is touched, matching
    # the original order (transformers import first, then the S3 copy).
    model_dir = get_download_path(model_id)
    f_hash = get_hash_from_bucket(bucket_uri, s3_sync_args)

    refs_dir = os.path.join(model_dir, "refs")
    checkpoint_dir = os.path.join(model_dir, "snapshots", f_hash)

    if mkdir:
        for directory in (refs_dir, checkpoint_dir):
            os.makedirs(directory, exist_ok=True)

    return checkpoint_dir, refs_dir
def get_download_path(model_id: str):
    """Return the directory under TRANSFORMERS_CACHE used for ``model_id``.

    The directory name is ``models--`` followed by ``model_id`` with every
    ``/`` replaced by ``--``.
    """
    # Imported lazily so that importing this module does not pull in
    # transformers.
    from transformers.utils.hub import TRANSFORMERS_CACHE

    repo_dir_name = "models--" + model_id.replace("/", "--")
    return os.path.join(TRANSFORMERS_CACHE, repo_dir_name)
def download_model(
    model_id: str,
    bucket_uri: str,
    s3_sync_args: Optional[List[str]] = None,
    tokenizer_only: bool = False,
) -> None:
    """
    Download a model from an S3 bucket and save it in TRANSFORMERS_CACHE for
    seamless interoperability with Hugging Face's Transformers library.

    The downloaded model may have a 'hash' file containing the commit hash
    corresponding to the commit on Hugging Face Hub.

    The transfer is done with ``awsv2 s3 sync`` from ``bucket_uri`` into the
    directory returned by ``get_download_path(model_id)``. The command and a
    final "done" line are printed to stdout. A non-zero exit status does not
    raise; it is reported through the module logger.

    Args:
        model_id: Model identifier used to derive the local cache directory.
        bucket_uri: Source URI passed to ``awsv2 s3 sync``.
        s3_sync_args: Extra command-line arguments inserted after
            ``awsv2 s3 sync``.
        tokenizer_only: When true, add ``--exclude * --include *token*`` so
            only files whose names contain "token" are synced.
    """
    extra_args = s3_sync_args or []
    target_dir = get_download_path(model_id)

    cmd = ["awsv2", "s3", "sync"] + extra_args
    if tokenizer_only:
        cmd += ["--exclude", "*", "--include", "*token*"]
    cmd += [bucket_uri, target_dir]

    print(f"RUN({cmd})")
    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        # Previously a failed sync was indistinguishable from a successful
        # one. Surface it without raising so best-effort callers still work.
        logger.error(
            "Command %s exited with non-zero status %d",
            cmd,
            completed.returncode,
        )
    print("done")
def get_mirror_link(model_id: str, mirror_root: str = "s3://llama-2-weights") -> str:
    """Return the mirror URI for ``model_id``.

    Args:
        model_id: Model identifier; every ``/`` is replaced by ``--``.
        mirror_root: Root URI of the mirror. Defaults to the bucket that was
            previously hard-coded, so existing callers get the same result.

    Returns:
        ``<mirror_root>/models--<model_id with "/" replaced by "-->``.
    """
    # Drop trailing slashes so a root given as "s3://bucket/" does not
    # produce a double slash in the result.
    root = mirror_root.rstrip("/")
    return f"{root}/models--{model_id.replace('/', '--')}"