forked from google/gemma.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gemma.h
207 lines (177 loc) · 6.92 KB
/
gemma.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#include <algorithm>
#include <cctype>
#include <functional>
#include <memory>
#include <random>
#include <string>
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream
// copybara:import_next_line:gemma_cpp
#include "configs.h" // kSeqLen
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // ArgsBase
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
namespace gcpp {
// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
// float, hwy::bfloat16_t, SfpStream, NuqStream
// When the build does not define GEMMA_WEIGHT_T, it defaults to SfpStream.
#ifndef GEMMA_WEIGHT_T
#define GEMMA_WEIGHT_T SfpStream
#endif // !GEMMA_WEIGHT_T
// Weight element type, resolved from the GEMMA_WEIGHT_T macro above.
using WeightT = GEMMA_WEIGHT_T;
// NOTE(review): presumably the element type of the embedding table - confirm
// against the implementation file.
using EmbedderInputT = hwy::bfloat16_t;
// NOTE(review): presumably the number of prompt tokens processed per prefill
// step - confirm against the implementation file.
constexpr size_t kPrefillBatchSize = 16;
// NOTE(review): not used in this header; meaning must be confirmed where it
// is consumed.
constexpr bool kSystemPrompt = false;
// Holds the two aligned float buffers for cached keys and values. Each buffer
// is owned through hwy::AlignedFreeUniquePtr; this header does not show where
// they are allocated.
struct KVCache {
// Cached keys; size per the original annotation below.
hwy::AlignedFreeUniquePtr<float[]>
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
// Cached values; same size as key_cache per the original annotation below.
hwy::AlignedFreeUniquePtr<float[]>
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
};
// Model variants: see configs.h for details.
enum class Model {
  GEMMA_2B,  // selected by the "2b-pt" / "2b-it" model strings
  GEMMA_7B,  // selected by every other model string
};

// Training flavour of the loaded weights.
enum class ModelTraining {
  GEMMA_IT,  // instruction-tuned ("*-it" model strings)
  GEMMA_PT,  // pretrained ("*-pt" model strings)
};
// Command-line arguments describing which model to load and where its files
// are: the tokenizer, the compressed weights, optionally the uncompressed
// weights, and the model type string.
//
// Call Validate() before ModelType() / ModelTraining(): those accessors do not
// reject unknown strings, they fall back to a default.
struct LoaderArgs : public ArgsBase<LoaderArgs> {
  // Populates all fields from the command line. InitAndParse is not defined
  // in this struct (presumably provided by ArgsBase).
  LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  // Returns a copy of `text` with each character passed through std::tolower.
  static std::string ToLower(const std::string& text) {
    std::string result = text;
    // The parameter is unsigned char because std::tolower requires a value
    // representable as unsigned char; the result is cast back explicitly to
    // avoid an implicit int -> char narrowing.
    std::transform(begin(result), end(result), begin(result),
                   [](unsigned char c) {
                     return static_cast<char>(std::tolower(c));
                   });
    return result;
  }

  // Returns GEMMA_2B for "2b-pt" / "2b-it" (case-insensitive) and GEMMA_7B for
  // anything else, including unrecognized strings.
  gcpp::Model ModelType() const {
    const std::string model_type_lc = ToLower(model_type);
    if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
      return gcpp::Model::GEMMA_2B;
    } else {
      return gcpp::Model::GEMMA_7B;
    }
  }

  // Returns GEMMA_PT for "7b-pt" / "2b-pt" (case-insensitive) and GEMMA_IT for
  // anything else, including unrecognized strings.
  gcpp::ModelTraining ModelTraining() const {
    const std::string model_type_lc = ToLower(model_type);
    if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
      return gcpp::ModelTraining::GEMMA_PT;
    } else {
      return gcpp::ModelTraining::GEMMA_IT;
    }
  }

  // Returns error string or nullptr if OK.
  const char* Validate() const {
    // The missing-flag check must precede the allow-list check: an empty
    // string is never in the allow-list, so testing the list first would make
    // this more specific message unreachable.
    if (model_type.empty()) {
      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
             "2b-it, or 7b-it.";
    }
    const std::string model_type_lc = ToLower(model_type);
    if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
        model_type_lc != "2b-it" && model_type_lc != "7b-it") {
      return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
             "7b-it.";
    }
    if (tokenizer.path.empty()) {
      return "Missing --tokenizer flag, a file for the tokenizer is required.";
    }
    if (cache.path.empty()) {
      return "Missing --compressed_weights flag, a file for the compressed "
             "model.";
    }
    return nullptr;
  }

  Path tokenizer;
  Path model;  // uncompressed weights OR
  Path cache;  // compressed weights
  std::string model_type;

  // Calls `visitor` once per flag with the member, the flag name, a value that
  // looks like the default, and the help text. How the visitor uses these is
  // defined outside this header.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(tokenizer, "tokenizer", Path(),
            "Path name of tokenizer model file. (required)");
    visitor(
        cache, "compressed_weights", Path(),
        "Path name of compressed weights file, regenerated from `--weights` "
        "file if "
        "the compressed weights file does not exist. (required)");
    visitor(model_type, "model", std::string(),
            "Model type - can be 2b-it (2B parameters, instruction-tuned), "
            "2b-pt (2B parameters, pretrained), 7b-it (7B parameters, "
            "instruction-tuned), or 7b-pt (7B parameters, pretrained). "
            "(required)");
    visitor(model, "weights", Path(),
            "Path name of model weights (.sbs) file. Only required if "
            "compressed_weights file is not present and needs to be "
            "regenerated. Otherwise, not needed");
  }
};
// Forward declaration only; the definition is not in this header, so the
// concrete model implementation stays out of the public interface.
struct GemmaInterface;
// Handle to a loaded model. Holds the implementation through a unique_ptr to
// the forward-declared GemmaInterface.
struct Gemma {
// Constructs from loader arguments and a thread pool; defined outside this
// header.
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
// Returns a const reference to a SentencePiece processor; defined outside
// this header.
const sentencepiece::SentencePieceProcessor& Tokenizer() const;
// Owning pointer to the implementation object.
std::unique_ptr<GemmaInterface> impl_;
// NOTE(review): presumably set from LoaderArgs::ModelTraining() in the
// constructor - confirm in the implementation file.
gcpp::ModelTraining model_training;
};
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f.
// NOTE(review): the meaning of the bool return value is not established in
// this header (presumably "continue generating") - confirm at the call site.
using StreamFunc = std::function<bool(int, float)>;
// AcceptFunc takes a single int and returns bool.
// NOTE(review): presumably the argument is a candidate token id and the
// result says whether it may be sampled - confirm at the call site.
using AcceptFunc = std::function<bool(int)>;
// Command-line arguments that control generation: token budgets, sampling
// temperature, determinism and multiturn behaviour.
struct InferenceArgs : public ArgsBase<InferenceArgs> {
  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  // Returns error string or nullptr if OK.
  const char* Validate() const {
    // The total budget is checked against gcpp::kSeqLen first.
    const bool exceeds_sequence_length = max_tokens > gcpp::kSeqLen;
    if (exceeds_sequence_length) {
      return "max_tokens is larger than the maximum sequence length (see "
             "configs.h).";
    }
    // The generated-token budget must fit inside the total budget.
    const bool generation_exceeds_total = max_generated_tokens > max_tokens;
    return generation_exceeds_total
               ? "Maximum number of generated tokens is larger than the "
                 "maximum "
                 "total tokens."
               : nullptr;
  }

  // Calls `visitor` once per flag, in declaration order, with the member, the
  // flag name, a value that looks like the default, and the help text. Some
  // calls pass a trailing integer whose meaning is defined by the visitor.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(max_tokens, "max_tokens", size_t{3072},
            "Maximum number of tokens in prompt + generation.");

    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
            "Maximum number of tokens to generate.");

    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);

    visitor(deterministic, "deterministic", false,
            "Make top-k sampling deterministic", 2);

    visitor(multiturn, "multiturn", true,
            "Multiturn mode (if 0, this clears the KV cache after every "
            "interaction without quitting)",
            2);
  }

  // Flag storage, filled in by the constructor's InitAndParse call.
  size_t max_tokens;
  size_t max_generated_tokens;
  float temperature;
  bool deterministic;
  bool multiturn;
};
// Runs generation for `prompt` on `gemma`; declared here, defined outside
// this header.
//   gemma        - the loaded model (taken by non-const reference).
//   args         - inference settings (see InferenceArgs).
//   prompt       - input as a vector of ints (presumably token ids - confirm).
//   start_pos    - NOTE(review): presumably the sequence position at which
//                  `prompt` begins - confirm in the implementation.
//   pool,
//   inner_pool   - two thread pools; how work is split between them is not
//                  visible in this header.
//   stream_token - StreamFunc callback, called with (token, probability).
//   accept_token - AcceptFunc callback taking a single int.
//   g            - std::mt19937 random generator, taken by non-const
//                  reference.
//   verbosity    - integer verbosity level; its effect is not visible here.
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& g,
int verbosity);
// NOTE(review): presumably the tokenizer's end-of-sequence token id - confirm
// against the tokenizer model in use.
constexpr int EOS_ID = 1;
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_