Commit: checkpoint
goliaro committed Jul 21, 2023
1 parent 1fa3c77 commit 51760fd
Showing 8 changed files with 61 additions and 15 deletions.
3 changes: 2 additions & 1 deletion include/flexflow/flexflow_c.h
@@ -202,9 +202,10 @@ flexflow_tensor_t
 flexflow_tensor_t
     flexflow_model_add_embedding(flexflow_model_t handle,
                                  const flexflow_tensor_t input,
-                                 int num_entires,
+                                 int num_entries,
                                  int out_dim,
                                  enum AggrMode aggr,
+                                 DataType dtype,
                                  flexflow_op_t shared_op,
                                  flexflow_initializer_t kernel_initializer,
                                  char const *name);
2 changes: 1 addition & 1 deletion include/flexflow/model.h
@@ -467,7 +467,7 @@ class FFModel {
                   char const *name = NULL);
   // Add an embedding layer
   Tensor embedding(const Tensor input,
-                   int num_entires,
+                   int num_entries,
                    int outDim,
                    AggrMode aggr,
                    DataType dtype = DT_FLOAT,
7 changes: 5 additions & 2 deletions python/flexflow/core/flexflow_cffi.py
@@ -1299,7 +1299,7 @@ def conv2d(self, input, out_channels,
         return Tensor(handle, owner_op_type=OpType.CONV2D)

     def embedding(self, input, num_embeddings, embedding_dim,
-                  aggr, shared_op=None, kernel_initializer=None, name=None):
+                  aggr, dtype=DataType.DT_FLOAT, shared_op=None, kernel_initializer=None, name=None):
         """Layer that turns positive integers into dense vectors of fixed size
         :param input: the input Tensor.
@@ -1313,6 +1313,9 @@ def embedding(self, input, num_embeddings, embedding_dim,
         :param aggr: aggregation mode. Options are AGGR_MODE_NONE, AGGR_MODE_SUM and AGGR_MODE_AVG.
         :type aggr: AggrMode
+        :param dtype: the tensor data type. Options are DT_BOOLEAN, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT4, DT_INT8, DT_NONE
+        :type dtype: DataType
         :param shared_op: the layer whose parameters are shared with. Default is None.
         :type shared_op: Op
@@ -1336,7 +1339,7 @@ def embedding(self, input, num_embeddings, embedding_dim,
             (type(kernel_initializer) is NormInitializer), \
             f"Unknown initializer type: {kernel_initializer}"
         handle = ffc.flexflow_model_add_embedding(
-            self.handle, input.handle, num_embeddings, embedding_dim, c_aggr,
+            self.handle, input.handle, num_embeddings, embedding_dim, c_aggr, dtype,
             shared_op_handle, kernel_initializer.handle, c_name,
         )
         # NOTE: We must keep a reference to the initializer or else it will be
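The hunks above thread the new dtype argument from the Python binding down to the C API. Below is a minimal usage sketch (not part of this commit) of the extended call, mirroring the way llama.py invokes it further down; it assumes a working FlexFlow build in which FFConfig, FFModel, DataType, AggrMode, and UniformInitializer are exported by flexflow.core.

from flexflow.core import *
import sys, random

ffconfig = FFConfig()
ffmodel = FFModel(ffconfig)

# One int32 token id per position; the shape is chosen only for illustration.
tokens = ffmodel.create_tensor([64, 1], DataType.DT_INT32)

embed_init = UniformInitializer(random.randint(0, sys.maxsize), 0, 0)
# The new dtype argument selects the precision of the embedding weights,
# which was previously hard-wired to DT_FLOAT inside flexflow_c.cc.
emb = ffmodel.embedding(
    tokens,                    # input
    32000,                     # num_embeddings (illustrative vocabulary size)
    4096,                      # embedding_dim
    AggrMode.AGGR_MODE_NONE,   # aggr
    DataType.DT_HALF,          # dtype
    None,                      # shared_op
    embed_init,                # kernel_initializer
)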
2 changes: 1 addition & 1 deletion python/flexflow/serve/models/falcon.py
@@ -15,5 +15,5 @@
 from flexflow.core import *

 class FlexFlowFalcon:
-    def __init__(self):
+    def __init__(self, max_batch_size=1, max_seq_length=256, max_tokens_per_batch=64):
         pass
44 changes: 42 additions & 2 deletions python/flexflow/serve/models/llama.py
@@ -13,7 +13,47 @@
 # limitations under the License.

 from flexflow.core import *
+import sys, random

-class FlexFlowLLAMA:
+class LLAMAConfig(dict):
     def __init__(self):
-        pass
+        self.n_layers = 32
+        self.vocab_size = 3200
+        self.n_heads = 32
+        self.dim = 4096
+        self.multiple_of = 256
+        self.norm_eps = 1e-6
+        self.total_requests = 2560
+        self.incremental_mode = True
+        self.hidden_dim = 11008
+        self.max_seq_len = 256
+        self.max_num_tokens = 64
+        self.max_beam_width = 1
+        self.max_beam_depth = 8
+
+    def __getattr__(self, name):
+        if name in self:
+            return self[name]
+        else:
+            raise AttributeError(f"'LLAMAConfig' object has no attribute '{name}'")
+
+class FlexFlowLLAMA:
+    def __init__(self, max_batch_size=1, max_seq_length=256, max_tokens_per_batch=64, use_full_precision=False):
+        self.max_batch_size = max_batch_size
+        self.use_full_precision = use_full_precision
+        self.llama_config = LLAMAConfig()
+        self.llama_config.max_seq_length = max_seq_length
+        self.llama_config.max_num_tokens = max_tokens_per_batch
+
+        self.build_model()
+
+    def build_model(self):
+        ffconfig = FFConfig()
+        ffmodel = FFModel(ffconfig)
+
+        tokens_dims = [self.max_tokens_per_batch, 1]
+        input_tensor = ffmodel.create_tensor(tokens_dims, DataType.DT_INT32)
+
+        embed_init = UniformInitializer(random.randint(0, sys.maxsize), 0, 0)
+        token = ffmodel.embedding(input_tensor, self.llama_config.vocab_size, self.llama_config.dim, AggrMode.AGGR_MODE_NONE, DataType.DT_FLOAT if self.use_full_precision else DataType.DT_HALF, None, embed_init)
+
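A small sketch (not from the commit) of how the new LLAMAConfig behaves. The values set in __init__ are plain instance attributes, so normal attribute lookup finds them; the __getattr__ fallback only fires for names that are neither attributes nor dict items.

# Hypothetical usage; assumes flexflow.serve.models.llama is importable.
from flexflow.serve.models.llama import LLAMAConfig

config = LLAMAConfig()
print(config.dim, config.n_heads, config.norm_eps)   # 4096 32 1e-06

# FlexFlowLLAMA.__init__ overrides limits the same way, via attribute assignment:
config.max_num_tokens = 128
print(config.max_num_tokens)                         # 128

# Unknown names fall through to __getattr__ and raise:
try:
    config.unknown_field
except AttributeError as err:
    print(err)   # 'LLAMAConfig' object has no attribute 'unknown_field'

Constructing FlexFlowLLAMA itself immediately calls build_model(), which as of this checkpoint reads self.max_tokens_per_batch even though __init__ only stores that value on llama_config, so the end-to-end build path is still incomplete.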
2 changes: 1 addition & 1 deletion python/flexflow/serve/models/opt.py
@@ -15,5 +15,5 @@
 from flexflow.core import *

 class FlexFlowOPT:
-    def __init__(self):
+    def __init__(self, max_batch_size=1, max_seq_length=256, max_tokens_per_batch=64):
         pass
2 changes: 1 addition & 1 deletion python/flexflow/serve/serve.py
@@ -34,7 +34,6 @@ def __init__(self, model_name, data_type="half"):
         self.model_type = self.__get_ff_model_type(model_name)
         self.data_type = data_type
         self.default_config = SamplingConfig()
-        self.model = self.model_type()

     def __get_ff_model_type(self, model_name):
         hf_config = AutoConfig.from_pretrained(model_name)
@@ -62,6 +61,7 @@ def compile(
         self.tensor_parallel_degree = tensor_parallel_degree
         self.pipeline_parallel_degree = pipeline_parallel_degree
         self.ssms = ssms
+        self.model = self.model_type(max_batch_size, max_seq_length, max_tokens_per_batch)
         assert False and "Not implemented yet"

     def generate(self, prompt, sampling=None):
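With model construction moved out of __init__ and into compile(), the batching limits can be chosen at compile time. The driver sketch below is hypothetical and not from this commit: it assumes the wrapper class defined in serve.py is named LLM, that it is importable from flexflow.serve, and that compile() accepts the limits as keyword arguments, none of which is confirmed by the hunks above.

# Hypothetical; the model name and keyword names are illustrative only.
from flexflow.serve import LLM

llm = LLM("huggyllama/llama-7b", data_type="half")
llm.compile(
    max_batch_size=1,
    max_seq_length=256,
    max_tokens_per_batch=64,
)
# As of this checkpoint, compile() still ends with
# `assert False and "Not implemented yet"`, so the call raises AssertionError.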
14 changes: 8 additions & 6 deletions src/c/flexflow_c.cc
@@ -462,9 +462,10 @@ flexflow_tensor_t
 flexflow_tensor_t
     flexflow_model_add_embedding(flexflow_model_t handle_,
                                  const flexflow_tensor_t input_,
-                                 int num_entires,
+                                 int num_entries,
                                  int out_dim,
                                  enum AggrMode aggr,
+                                 DataType dtype,
                                  flexflow_op_t shared_op_,
                                  flexflow_initializer_t kernel_initializer_,
                                  char const *name) {
@@ -476,20 +477,21 @@ flexflow_tensor_t
   // TODO: update the flexflow_c and Python API to support other data types
   // Currently we assume it's float
   Tensor tensor = handle->embedding(input,
-                                    num_entires,
+                                    num_entries,
                                     out_dim,
                                     aggr,
-                                    DT_FLOAT,
+                                    dtype,
                                     shared_op,
                                     kernel_initializer,
                                     name);
-  DEBUG_PRINT("[Embedding] new Tensor %p, input %p, num_entires %d, out_dim "
-              "%d, aggr %d, shared_op %p, kernel_init %p, name %s",
+  DEBUG_PRINT("[Embedding] new Tensor %p, input %p, num_entries %d, out_dim "
+              "%d, aggr %d, dtype %d, shared_op %p, kernel_init %p, name %s",
               tensor,
               input,
-              num_entires,
+              num_entries,
               out_dim,
               aggr,
+              dtype,
               shared_op,
               kernel_initializer,
               name);
