configs.py
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from dataclasses import dataclass, field
from typing import List, Optional, Union

# Third Party
import torch
import transformers

# Local
from tuning.trackers.tracker_factory import FILE_LOGGING_TRACKER

DEFAULT_CONTEXT_LENGTH = 4096
DEFAULT_OPTIMIZER = "adamw_torch"
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "<PAD>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


@dataclass
class ModelArguments:
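    """Arguments for selecting the base model and tokenizer to tune."""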
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    use_flash_attn: bool = field(
        default=True,
        metadata={"help": "Use Flash attention v2 from transformers, default is True"},
    )
    torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16
    embedding_size_multiple_of: Optional[int] = field(
        default=1,
        metadata={
            "help": "Resize model embedding layer to the nearest multiple of \
                the given number after tokenizer modifications. \
                NOTE: This involves extending \
                the embedding layer without any corresponding real tokens."
        },
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to custom tokenizer. \
                If not provided it defaults to model_name_or_path \
                and special tokens will be added as needed for specific tokenizer classes. \
                For prompt tuning, if tokenizer_name_or_path is provided, special tokens are not added; \
                otherwise, it defaults to model_name_or_path with special tokens for specific \
                tokenizer classes."
        },
    )


@dataclass
class DataArguments:
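    """Arguments describing the training/validation data and how it is formatted."""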
    training_data_path: str = field(
        default=None,
        metadata={"help": "Path to the training data in JSON/JSONL format."},
    )
    response_template: str = field(
        default=None,
        metadata={"help": "Response template, separator to train on completions only"},
    )
    dataset_text_field: str = field(
        default=None,
        metadata={
            "help": "Training dataset text field containing a single sequence. \
                Either the dataset_text_field \
                or data_formatter_template needs to be supplied."
        },
    )
    validation_data_path: str = field(
        default=None,
        metadata={"help": "Path to the validation data in JSON/JSONL format."},
    )
    data_formatter_template: str = field(
        default=None,
        metadata={
            "help": "Formatter template to format a single sequence \
                from each instance in JSONL files. \
                Keys of the JSON can be referred to as {{key}} in the template. \
                Either the dataset_text_field \
                or data_formatter_template needs to be supplied."
        },
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # pylint: disable=too-many-instance-attributes
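    """Extends transformers.TrainingArguments with tuning-specific options
    such as max sequence length, packing, save/logging strategies, and
    experiment trackers."""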
    cache_dir: Optional[str] = field(default=None)
    # optim: str = field(default=DEFAULT_OPTIMIZER)
    max_seq_length: int = field(
        default=DEFAULT_CONTEXT_LENGTH,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded \
                (and possibly truncated)."
        },
    )
    packing: bool = field(
        default=False,
        metadata={"help": "Packing to be enabled in SFT Trainer, default is False"},
    )
    save_strategy: str = field(
        default="epoch",
        metadata={
            "help": "The checkpoint save strategy to adopt during training. \
                Possible values are 'no' (no save is done during training), \
                'epoch' (save is done at the end of each epoch), \
                'steps' (save is done every `save_steps`)."
        },
    )
    save_model_dir: str = field(
        default=None,
        metadata={
            "help": "Directory where the tuned model will be saved \
                using SFTTrainer.save_model()."
        },
    )
    logging_strategy: str = field(
        default="epoch",
        metadata={
            "help": "The logging strategy to adopt during training. \
                Possible values are 'no' (no logging is done during training), \
                'epoch' (logging is done at the end of each epoch), \
                'steps' (logging is done every `logging_steps`)."
        },
    )
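    # NOTE: `str.lower` is used as the list element type so that argparse-based
    # parsers (e.g. transformers.HfArgumentParser) lowercase each tracker name
    # supplied on the command line.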
    trackers: Optional[List[str.lower]] = field(
        default_factory=lambda: [FILE_LOGGING_TRACKER],
        metadata={
            "help": "Experiment trackers to use.\n"
            + "Available trackers are - file_logger(default), aim, none\n"
            + "Requires additional configs, see tuning.configs/tracker_configs.py"
        },
    )
    log_level: str = field(
        default="passive",
        metadata={
            "help": "The log level to adopt during training. \
                By default, the 'passive' level is set, which keeps the \
                current log level of the Transformers library (which will be 'warning' by default). \
                Other possible values are 'debug', 'info', 'warning', 'error' and 'critical'."
        },
    )


@dataclass
class TrainerControllerArguments:
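    """Arguments pointing to an optional trainer-controller configuration file."""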
    trainer_controller_config_file: str = field(
        default=None,
        metadata={
            "help": (
                "Trainer controller configuration file (e.g. trainercontroller_config.yaml) \
                in YAML format."
            )
        },
    )
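

# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the upstream module): these
# dataclasses follow the transformers.HfArgumentParser convention, where every
# field becomes a command-line flag and each metadata["help"] string becomes
# that flag's --help text. The argv values below are hypothetical examples.
if __name__ == "__main__":
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        args=[
            "--model_name_or_path", "facebook/opt-125m",
            "--training_data_path", "train.jsonl",
            "--dataset_text_field", "output",
            "--output_dir", "./checkpoints",
        ]
    )
    print(model_args)
    print(data_args)
    print(training_args)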