From 87019d0ada6488fbb706a49100b8a3ac0fb22672 Mon Sep 17 00:00:00 2001 From: ssbuild <462304@qq.cn> Date: Mon, 20 Nov 2023 14:30:35 +0800 Subject: [PATCH] support seed Signed-off-by: ssbuild <462304@qq.cn> --- README.MD | 1 + serving/model_handler/baichuan2_13b/infer.py | 13 +++++---- serving/model_handler/baichuan2_7b/infer.py | 13 +++++---- serving/model_handler/baichuan_13b/infer.py | 17 ++++++------ serving/model_handler/baichuan_7b/infer.py | 17 ++++++------ serving/model_handler/base/__init__.py | 2 +- serving/model_handler/base/data_process.py | 28 +++++++++++++++++--- serving/model_handler/bluelm/infer.py | 17 ++++++------ serving/model_handler/chatglm/infer.py | 17 ++++++------ serving/model_handler/chatglm2/infer.py | 17 ++++++------ serving/model_handler/chatglm3/infer.py | 17 ++++++------ serving/model_handler/internlm/infer.py | 17 ++++++------ serving/model_handler/llama/infer.py | 17 ++++++------ serving/model_handler/llm/infer.py | 13 +++++---- serving/model_handler/moss/infer.py | 17 ++++++------ serving/model_handler/qwen/infer.py | 17 ++++++------ serving/model_handler/rwkv/infer.py | 17 ++++++------ serving/model_handler/skywork/infer.py | 17 ++++++------ serving/model_handler/t5/infer.py | 17 ++++++------ serving/model_handler/xverse/infer.py | 17 ++++++------ serving/model_handler/yi/infer.py | 17 ++++++------ serving/openai_api/custom.py | 18 ++++++++++++- serving/openai_api/openai_api_protocol.py | 2 +- serving/serve/api_react.py | 7 ++++- tests/test_openai.py | 1 + tests/test_openai_chat.py | 4 ++- tests/web_demo.py | 2 +- 27 files changed, 193 insertions(+), 166 deletions(-) diff --git a/README.MD b/README.MD index 729baa3..4613874 100644 --- a/README.MD +++ b/README.MD @@ -10,6 +10,7 @@ ## update information ```text + 11-20 support seed for generator sample 11-06 fix pydantic 2 and support api_keys in config 11-04 support yi aigc-zoo>=0.2.7.post2 , 支持 pydantic >= 2 11-01 support bluelm aigc-zoo>=0.2.7.post1 diff --git a/serving/model_handler/baichuan2_13b/infer.py b/serving/model_handler/baichuan2_13b/infer.py index 428879f..d093b90 100644 --- a/serving/model_handler/baichuan2_13b/infer.py +++ b/serving/model_handler/baichuan2_13b/infer.py @@ -13,7 +13,7 @@ from transformers import HfArgumentParser, BitsAndBytesConfig, GenerationConfig from aigc_zoo.model_zoo.baichuan.baichuan2_13b.llm_model import MyTransformer,BaichuanConfig,BaichuanTokenizer,\ MyBaichuanForCausalLM,PetlArguments,PetlModel -from serving.model_handler.base import EngineAPI_Base,CompletionResult, CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode +from serving.model_handler.base import EngineAPI_Base,CompletionResult, CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode from serving.prompt import * @@ -123,12 +123,11 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) chunk = args_process.chunk default_kwargs=self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) stopping_criteria = default_kwargs.pop('stopping_criteria', None) generation_config = GenerationConfig(**default_kwargs) for response in self.get_model().chat(tokenizer=self.tokenizer, @@ -156,11 +155,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) stopping_criteria = default_kwargs.pop('stopping_criteria', None) generation_config = GenerationConfig(**default_kwargs) response = self.get_model().chat(tokenizer=self.tokenizer, @@ -177,6 +175,7 @@ def chat(self,messages: List[Dict], **kwargs): def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/baichuan2_7b/infer.py b/serving/model_handler/baichuan2_7b/infer.py index 1fb735b..5a512e7 100644 --- a/serving/model_handler/baichuan2_7b/infer.py +++ b/serving/model_handler/baichuan2_7b/infer.py @@ -14,7 +14,7 @@ from aigc_zoo.model_zoo.baichuan.baichuan2_7b.llm_model import MyTransformer,BaichuanConfig,BaichuanTokenizer,\ MyBaichuanForCausalLM,PetlArguments,PetlModel from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, \ - load_lora_config,GenerateProcess,WorkMode + load_lora_config,GenArgs,WorkMode from serving.prompt import * class NN_DataHelper(DataHelper):pass @@ -125,12 +125,11 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) chunk = args_process.chunk default_kwargs=self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) stopping_criteria = default_kwargs.pop('stopping_criteria', None) generation_config = GenerationConfig(**default_kwargs) for response in self.get_model().chat(tokenizer=self.tokenizer, @@ -158,11 +157,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) stopping_criteria = default_kwargs.pop('stopping_criteria', None) generation_config = GenerationConfig(**default_kwargs) response = self.get_model().chat(tokenizer=self.tokenizer, @@ -177,6 +175,7 @@ def chat(self,messages: List[Dict], **kwargs): def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/baichuan_13b/infer.py b/serving/model_handler/baichuan_13b/infer.py index f7783ba..3c4d611 100644 --- a/serving/model_handler/baichuan_13b/infer.py +++ b/serving/model_handler/baichuan_13b/infer.py @@ -13,7 +13,7 @@ from transformers import HfArgumentParser, BitsAndBytesConfig, GenerationConfig from aigc_zoo.model_zoo.baichuan.baichuan_13b.llm_model import MyTransformer,BaichuanConfig,BaichuanTokenizer,\ MyBaichuanForCausalLM,PetlArguments,PetlModel -from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode +from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode from serving.prompt import * @@ -123,12 +123,11 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) chunk = args_process.chunk default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) stopping_criteria = default_kwargs.pop('stopping_criteria', None) generation_config = GenerationConfig(**default_kwargs) for response in self.get_model().chat(tokenizer=self.tokenizer, @@ -155,11 +154,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) stopping_criteria = default_kwargs.pop('stopping_criteria', None) generation_config = GenerationConfig(**default_kwargs) response = self.get_model().chat(tokenizer=self.tokenizer, @@ -174,10 +172,10 @@ def chat(self,messages: List[Dict], **kwargs): }) def generate(self, messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages, chat_format="generate") response = self.get_model().generate(query=query, **kwargs) return CompletionResult(result={ @@ -187,6 +185,7 @@ def generate(self, messages: List[Dict], **kwargs): def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/baichuan_7b/infer.py b/serving/model_handler/baichuan_7b/infer.py index a60a66e..34e91f5 100644 --- a/serving/model_handler/baichuan_7b/infer.py +++ b/serving/model_handler/baichuan_7b/infer.py @@ -14,7 +14,7 @@ from transformers import HfArgumentParser, BitsAndBytesConfig, GenerationConfig from aigc_zoo.model_zoo.baichuan.baichuan_7b.llm_model import MyTransformer,BaiChuanConfig,BaiChuanTokenizer,PetlArguments,PetlModel from aigc_zoo.generator_utils.generator_llm import Generate -from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config,GenerateProcess,WorkMode +from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config,GenArgs,WorkMode from serving.prompt import * class NN_DataHelper(DataHelper):pass @@ -134,12 +134,11 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) chunk = args_process.chunk default_kwargs= self.get_default_gen_args() default_kwargs.update(kwargs) - generation_config = GenerationConfig(**args_process.postprocess(default_kwargs)) + generation_config = GenerationConfig(**args_process.build_args(default_kwargs)) query, history = args_process.get_chat_info(messages) prompt = get_chat_default(self.tokenizer, query, history) @@ -177,13 +176,12 @@ def stream_generator(): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) query, history = args_process.get_chat_info(messages) prompt = get_chat_default(self.tokenizer, query, history) - response = self.gen_core.generate(query=prompt, **args_process.postprocess(default_kwargs)) + response = self.gen_core.generate(query=prompt, **args_process.build_args(default_kwargs)) response = args_process.postprocess_response(response, **kwargs) # history = history + [(query, response)] return CompletionResult(result={ @@ -192,10 +190,10 @@ def chat(self,messages: List[Dict], **kwargs): }) def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") response = self.gen_core.generate(query=query, **kwargs) return CompletionResult(result={ @@ -204,6 +202,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/base/__init__.py b/serving/model_handler/base/__init__.py index 5d3b2a6..207d0aa 100644 --- a/serving/model_handler/base/__init__.py +++ b/serving/model_handler/base/__init__.py @@ -4,6 +4,6 @@ from .infer import EngineAPI_Base,flat_input,CompletionResult from .data_define import ChunkData, LoraModelState, WorkMode -from .data_process import GenerateProcess +from .data_process import GenArgs from .loaders import load_lora_config from .utils import is_quantization_bnb diff --git a/serving/model_handler/base/data_process.py b/serving/model_handler/base/data_process.py index e51f7ef..7bdbb2d 100644 --- a/serving/model_handler/base/data_process.py +++ b/serving/model_handler/base/data_process.py @@ -156,22 +156,42 @@ def _calc_stopped_samples(self, input_ids: torch.LongTensor) -> bool: del ids return False -class GenerateProcess: - def __init__(self,this_obj,is_stream=False): +class GenArgs: + def __init__(self,args_dict:Dict, this_obj,is_stream=False): + if args_dict is None: + args_dict = {} + self.tokenizer: Optional[PreTrainedTokenizer] = this_obj.tokenizer self.config: Optional[PretrainedConfig] = this_obj.config self.is_stream = is_stream self.chunk: Optional[ChunkData] = None self.this_obj = this_obj - def preprocess(self, args_dict: dict): + # support seed + self.multinomial_fn = torch.multinomial + self.__preprocess(args_dict) + def __del__(self): + # restore + if torch.multinomial != self.multinomial_fn: + torch.multinomial = self.multinomial_fn + + def __preprocess(self, args_dict): if self.is_stream: nchar = args_dict.pop('nchar',1) gtype = args_dict.pop('gtype',"total") self.chunk = ChunkData(nchar=nchar, stop=args_dict.get('stop', None), mode=gtype) + + seed = args_dict.pop('seed',None) + + #进程隔离,互不影响 + if isinstance(seed,int): + device = self.this_obj.get_model().device + torch.multinomial = lambda *args, **kwargs: self.multinomial_fn(*args, + generator=torch.Generator(device=device).manual_seed(seed), + **kwargs) return args_dict - def postprocess(self, args_dict): + def build_args(self, args_dict): stop = args_dict.pop('stop',None) if stop is None: return args_dict diff --git a/serving/model_handler/bluelm/infer.py b/serving/model_handler/bluelm/infer.py index c790a67..54d7d3c 100644 --- a/serving/model_handler/bluelm/infer.py +++ b/serving/model_handler/bluelm/infer.py @@ -18,7 +18,7 @@ from aigc_zoo.model_zoo.bluelm.llm_model import MyBlueLMForCausalLM,BlueLMTokenizer,BlueLMConfig from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments,PetlModel,AutoConfig from aigc_zoo.generator_utils.generator_llm import Generate -from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenerateProcess,WorkMode,ChunkData +from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenArgs,WorkMode,ChunkData from serving.prompt import * class NN_DataHelper(DataHelper):pass @@ -147,11 +147,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) prompt = get_chat_bluelm(self.tokenizer, query, history=history, prefix=prefix) skip_word_list = default_kwargs.get('eos_token_id', None) or [self.tokenizer.eos_token_id] @@ -164,11 +163,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) prompt = get_chat_bluelm(self.tokenizer, query, history=history, prefix=prefix) response = self.gen_core.generate(query=prompt, **default_kwargs) @@ -180,10 +178,10 @@ def chat(self,messages: List[Dict], **kwargs): def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") response = self.gen_core.generate(query=query, **default_kwargs) return CompletionResult(result={ @@ -192,6 +190,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/chatglm/infer.py b/serving/model_handler/chatglm/infer.py index 4333591..b8ec92e 100644 --- a/serving/model_handler/chatglm/infer.py +++ b/serving/model_handler/chatglm/infer.py @@ -13,7 +13,7 @@ from transformers import HfArgumentParser from aigc_zoo.model_zoo.chatglm.llm_model import MyTransformer, ChatGLMTokenizer, PetlArguments, setup_model_profile, \ ChatGLMConfig,PetlModel -from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode +from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode from serving.prompt import * class NN_DataHelper(DataHelper):pass @@ -148,12 +148,11 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self, messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) chunk = args_process.chunk default_kwargs=self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query,history = args_process.get_chat_info(messages) for response, history in self.model.stream_chat(self.tokenizer, query=query,history=history, **kwargs): chunk.step(response) @@ -175,11 +174,10 @@ def chat_stream(self, messages: List[Dict], **kwargs): }, complete=False) def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs=self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) response, history = self.model.chat(self.tokenizer, query=query,history=history, **default_kwargs) response = args_process.postprocess_response(response, **kwargs) @@ -189,10 +187,10 @@ def chat(self,messages: List[Dict], **kwargs): }) def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs=self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") output,_ = self.model.chat(self.tokenizer, query=query,**default_kwargs) output_scores = default_kwargs.get('output_scores', False) @@ -205,6 +203,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/chatglm2/infer.py b/serving/model_handler/chatglm2/infer.py index f4644f4..4867a81 100644 --- a/serving/model_handler/chatglm2/infer.py +++ b/serving/model_handler/chatglm2/infer.py @@ -13,7 +13,7 @@ from transformers import HfArgumentParser from aigc_zoo.model_zoo.chatglm2.llm_model import MyTransformer, ChatGLMTokenizer, PetlArguments, \ setup_model_profile, ChatGLMConfig -from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode +from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode from serving.prompt import * @@ -152,12 +152,11 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self, messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) chunk = args_process.chunk default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) for response, history in self.get_model().stream_chat(self.tokenizer, query=query,history=history, **default_kwargs): chunk.step(response) @@ -178,11 +177,10 @@ def chat_stream(self, messages: List[Dict], **kwargs): }, complete=False) def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) response, history = self.model.chat(self.tokenizer, query=query,history=history, **default_kwargs) response = args_process.postprocess_response(response, **kwargs) @@ -192,10 +190,10 @@ def chat(self,messages: List[Dict], **kwargs): }) def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") output,_ = self.model.chat(self.tokenizer, query=query,**default_kwargs) output_scores = default_kwargs.get('output_scores', False) @@ -208,6 +206,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/chatglm3/infer.py b/serving/model_handler/chatglm3/infer.py index befe5c8..fed7fab 100644 --- a/serving/model_handler/chatglm3/infer.py +++ b/serving/model_handler/chatglm3/infer.py @@ -14,7 +14,7 @@ from transformers import HfArgumentParser from aigc_zoo.model_zoo.chatglm3.llm_model import MyTransformer, ChatGLMTokenizer, PetlArguments, \ setup_model_profile, ChatGLMConfig -from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode +from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode from serving.prompt import * @@ -169,12 +169,11 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self, messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) chunk = args_process.chunk default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) role,query,history = _preprocess_messages_for_chatglm3(messages) for response, history in self.get_model().stream_chat(self.tokenizer, query=query,role=role,history=history,with_postprocess=False,**default_kwargs): chunk.step(response) @@ -195,11 +194,10 @@ def chat_stream(self, messages: List[Dict], **kwargs): }, complete=False) def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) role, query, history = _preprocess_messages_for_chatglm3(messages) response, history = self.model.chat(self.tokenizer, query=query,role=role,history=messages,with_postprocess=False,**default_kwargs) if isinstance(response,str): @@ -213,10 +211,10 @@ def chat(self,messages: List[Dict], **kwargs): }) def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = messages[0]["content"] output,_ = self.model.chat(self.tokenizer,query=query,with_postprocess=False, **default_kwargs) output_scores = default_kwargs.get('output_scores', False) @@ -229,6 +227,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/internlm/infer.py b/serving/model_handler/internlm/infer.py index 95caefa..35fd49a 100644 --- a/serving/model_handler/internlm/infer.py +++ b/serving/model_handler/internlm/infer.py @@ -15,7 +15,7 @@ from aigc_zoo.model_zoo.internlm.llm_model import MyTransformer,InternLMConfig,InternLMTokenizer,\ InternLMForCausalLM,PetlArguments,PetlModel from serving.model_handler.base import EngineAPI_Base, CompletionResult,LoraModelState, load_lora_config, \ - GenerateProcess, WorkMode + GenArgs, WorkMode from serving.prompt import * class NN_DataHelper(DataHelper):pass @@ -134,11 +134,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self, messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) skip_word_list = [self.tokenizer.eos_token_id,2, 103028] streamer = args_process.get_streamer(skip_word_list) @@ -148,11 +147,10 @@ def chat_stream(self, messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) response, history = self.model.chat(self.tokenizer, query=query,history=history, **default_kwargs) response = args_process.postprocess_response(response, **kwargs) @@ -162,10 +160,10 @@ def chat(self,messages: List[Dict], **kwargs): }) def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") output,_ = self.model.chat(self.tokenizer, query=query **default_kwargs) output_scores = default_kwargs.get('output_scores', False) @@ -178,6 +176,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/llama/infer.py b/serving/model_handler/llama/infer.py index a34da1e..a847b06 100644 --- a/serving/model_handler/llama/infer.py +++ b/serving/model_handler/llama/infer.py @@ -17,7 +17,7 @@ from deep_training.nlp.models.rellama.modeling_llama import LlamaForCausalLM from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments,PetlModel,AutoConfig from aigc_zoo.generator_utils.generator_llm import Generate -from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenerateProcess,WorkMode,ChunkData +from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenArgs,WorkMode,ChunkData from serving.prompt import * class NN_DataHelper(DataHelper):pass @@ -152,11 +152,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) if self.is_openbuddy: prompt = get_chat_openbuddy(self.tokenizer, query, history=history, prefix=prefix) @@ -176,11 +175,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) if self.is_openbuddy: prompt = get_chat_openbuddy(self.tokenizer, query, history=history, prefix=prefix) @@ -199,10 +197,10 @@ def chat(self,messages: List[Dict], **kwargs): def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") response = self.gen_core.generate(query=query, **default_kwargs) return CompletionResult(result={ @@ -211,6 +209,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/llm/infer.py b/serving/model_handler/llm/infer.py index 9e8a86f..f0de39a 100644 --- a/serving/model_handler/llm/infer.py +++ b/serving/model_handler/llm/infer.py @@ -14,7 +14,7 @@ from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments,PetlModel,AutoConfig from aigc_zoo.generator_utils.generator_llm import Generate from serving.model_handler.base import EngineAPI_Base, CompletionResult,LoraModelState, load_lora_config, \ - GenerateProcess, WorkMode + GenArgs, WorkMode from serving.prompt import * @@ -116,11 +116,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) prompt = get_chat_default(self.tokenizer, query, history=history, prefix=prefix) skip_word_list = default_kwargs.get('eos_token_id', None) or [self.tokenizer.eos_token_id] @@ -131,11 +130,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) prompt = get_chat_default(self.tokenizer, query, history=history, prefix=prefix) response = self.gen_core.generate(query=prompt, **default_kwargs) @@ -149,6 +147,7 @@ def chat(self,messages: List[Dict], **kwargs): def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/moss/infer.py b/serving/model_handler/moss/infer.py index 2d5d80f..6bf3842 100644 --- a/serving/model_handler/moss/infer.py +++ b/serving/model_handler/moss/infer.py @@ -14,7 +14,7 @@ from aigc_zoo.model_zoo.moss.llm_model import MyTransformer,MossConfig,MossTokenizer,PetlArguments,PetlModel from aigc_zoo.generator_utils.generator_moss import Generate from serving.model_handler.base import EngineAPI_Base, CompletionResult, CompletionResult, LoraModelState, \ - load_lora_config, GenerateProcess, WorkMode + load_lora_config, GenArgs, WorkMode from serving.prompt import * @@ -126,11 +126,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) skip_word_list = default_kwargs.get('eos_token_id', None) or [self.tokenizer.eos_token_id] streamer = args_process.get_streamer(skip_word_list) @@ -140,11 +139,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) response,history = self.gen_core.chat(query=query, history=history, **default_kwargs) response = args_process.postprocess_response(response, **kwargs) @@ -155,10 +153,10 @@ def chat(self,messages: List[Dict], **kwargs): def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") response = self.model.generate(query=query, **kwargs) return CompletionResult(result={ @@ -167,6 +165,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/qwen/infer.py b/serving/model_handler/qwen/infer.py index db84170..e805e84 100644 --- a/serving/model_handler/qwen/infer.py +++ b/serving/model_handler/qwen/infer.py @@ -12,7 +12,7 @@ from transformers import HfArgumentParser, BitsAndBytesConfig from aigc_zoo.model_zoo.qwen.llm_model import MyTransformer, QWenTokenizer, PetlArguments, \ setup_model_profile, QWenConfig -from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode +from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode from serving.prompt import * @@ -139,11 +139,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self, messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) skip_word_list = [self.tokenizer.im_end_id, self.tokenizer.im_start_id, self.tokenizer.eos_token_id or 151643] skip_word_list += default_kwargs.get('stop_words_ids', []) @@ -155,11 +154,10 @@ def chat_stream(self, messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) response, history = self.model.chat(self.tokenizer, query=query,history=history, **default_kwargs) response = args_process.postprocess_response(response, **kwargs) @@ -169,10 +167,10 @@ def chat(self,messages: List[Dict], **kwargs): }) def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") output = self.model.chat(self.tokenizer, query=query,**default_kwargs) output_scores = default_kwargs.get('output_scores', False) @@ -185,6 +183,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) from deep_training.nlp.models.qwen.modeling_qwen import QWenLMHeadModel model: QWenLMHeadModel = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") diff --git a/serving/model_handler/rwkv/infer.py b/serving/model_handler/rwkv/infer.py index 4b459ce..c7dd671 100644 --- a/serving/model_handler/rwkv/infer.py +++ b/serving/model_handler/rwkv/infer.py @@ -14,7 +14,7 @@ from aigc_zoo.model_zoo.rwkv4.llm_model import MyTransformer, RwkvConfig, \ set_model_profile,PetlArguments,PetlModel from aigc_zoo.utils.rwkv4_generate import Generate -from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode +from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode from serving.prompt import * @@ -122,11 +122,10 @@ def get_default_gen_args(self): def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) prompt = get_chat_default(self.tokenizer, query, history) skip_word_list = default_kwargs.get('eos_token_id', None) or [self.tokenizer.eos_token_id] @@ -136,11 +135,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): return None def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) prompt = get_chat_default(self.tokenizer, query, history) response = Generate.generate(self.get_model(), @@ -154,10 +152,10 @@ def chat(self,messages: List[Dict], **kwargs): def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") response = Generate.generate(self.get_model(), tokenizer=self.tokenizer, @@ -168,6 +166,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/skywork/infer.py b/serving/model_handler/skywork/infer.py index b791254..cd62be7 100644 --- a/serving/model_handler/skywork/infer.py +++ b/serving/model_handler/skywork/infer.py @@ -17,7 +17,7 @@ from aigc_zoo.model_zoo.skywork.llm_model import MySkyworkForCausalLM,SkyworkConfig,SkyworkTokenizer from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments,PetlModel,AutoConfig from aigc_zoo.generator_utils.generator_llm import Generate -from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenerateProcess,WorkMode,ChunkData +from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenArgs,WorkMode,ChunkData from serving.prompt import * class NN_DataHelper(DataHelper):pass @@ -153,11 +153,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) prompt = get_chat_skywork(self.tokenizer, query, history=history, prefix=prefix) skip_word_list = default_kwargs.get('eos_token_id', None) or [self.tokenizer.eos_token_id] @@ -170,11 +169,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) prompt = get_chat_skywork(self.tokenizer, query, history=history, prefix=prefix) response = self.gen_core.generate(query=prompt, **default_kwargs) @@ -186,10 +184,10 @@ def chat(self,messages: List[Dict], **kwargs): def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") response = self.gen_core.generate(query=query, **default_kwargs) return CompletionResult(result={ @@ -198,6 +196,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/t5/infer.py b/serving/model_handler/t5/infer.py index 432531d..7924de9 100644 --- a/serving/model_handler/t5/infer.py +++ b/serving/model_handler/t5/infer.py @@ -14,7 +14,7 @@ from aigc_zoo.model_zoo.t5.llm_model import MyTransformer,PetlArguments,PetlModel,AutoConfig from aigc_zoo.generator_utils.generator_llm import Generate from serving.model_handler.base import EngineAPI_Base, CompletionResult,LoraModelState, load_lora_config, \ - GenerateProcess, WorkMode + GenArgs, WorkMode from serving.prompt import * @@ -143,11 +143,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) prompt = get_chat_chatyaun(self.tokenizer, query, history) skip_word_list = default_kwargs.get('eos_token_id', None) or [self.tokenizer.eos_token_id] @@ -158,11 +157,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query, history = args_process.get_chat_info(messages) prompt = get_chat_chatyaun(self.tokenizer, query, history) response = self.gen_core.generate(prompt, **default_kwargs) @@ -175,10 +173,10 @@ def chat(self,messages: List[Dict], **kwargs): }) def generate(self, messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages, chat_format="generate") response = Generate.generate(self.get_model(), tokenizer=self.tokenizer, @@ -191,6 +189,7 @@ def generate(self, messages: List[Dict], **kwargs): def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/xverse/infer.py b/serving/model_handler/xverse/infer.py index ff1d349..9c2436a 100644 --- a/serving/model_handler/xverse/infer.py +++ b/serving/model_handler/xverse/infer.py @@ -13,7 +13,7 @@ from deep_training.nlp.layers.rope_scale.patch import RotaryNtkScaledArguments from transformers import HfArgumentParser, GenerationConfig from aigc_zoo.utils.xverse_generate import Generate -from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode +from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode from transformers import AutoModelForCausalLM from deep_training.utils.hf import register_transformer_model, register_transformer_config # noqa # from deep_training.nlp.models.xverse.modeling_xverse import XverseForCausalLM, XverseConfig @@ -142,11 +142,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) stopping_criteria = default_kwargs.pop('stopping_criteria',None) generation_config = GenerationConfig(**default_kwargs) @@ -164,11 +163,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) stopping_criteria = default_kwargs.pop('stopping_criteria', None) generation_config = GenerationConfig(**default_kwargs) response = self.get_model().chat(tokenizer=self.tokenizer, @@ -184,14 +182,14 @@ def chat(self,messages: List[Dict], **kwargs): def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = dict( eos_token_id=self.config.eos_token_id, pad_token_id=self.config.eos_token_id, do_sample=True, top_p=0.7, temperature=0.95, ) default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages, chat_format="generate") response = Generate.generate(self.get_model(), tokenizer=self.tokenizer, @@ -202,6 +200,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/model_handler/yi/infer.py b/serving/model_handler/yi/infer.py index 4240c73..bbae135 100644 --- a/serving/model_handler/yi/infer.py +++ b/serving/model_handler/yi/infer.py @@ -17,7 +17,7 @@ from aigc_zoo.model_zoo.yi.llm_model import MyYiForCausalLM,YiConfig,YiTokenizer from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments,PetlModel,AutoConfig from aigc_zoo.generator_utils.generator_llm import Generate -from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenerateProcess,WorkMode,ChunkData +from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenArgs,WorkMode,ChunkData from serving.prompt import * class NN_DataHelper(DataHelper):pass @@ -153,11 +153,10 @@ def get_default_gen_args(self): return default_kwargs def chat_stream(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self,is_stream=True) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self, is_stream=True) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) # 模板 prompt = get_chat_yi(self.tokenizer, query, history=history, prefix=prefix) @@ -171,11 +170,10 @@ def chat_stream(self,messages: List[Dict], **kwargs): def chat(self,messages: List[Dict], **kwargs): - args_process = GenerateProcess(self) - args_process.preprocess(kwargs) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) prefix,query, history = args_process.get_chat_info_with_system(messages) # 模板 prompt = get_chat_yi(self.tokenizer, query, history=history, prefix=prefix) @@ -188,10 +186,10 @@ def chat(self,messages: List[Dict], **kwargs): def generate(self,messages: List[Dict],**kwargs): - args_process = GenerateProcess(self) + args_process = GenArgs(kwargs, self) default_kwargs = self.get_default_gen_args() default_kwargs.update(kwargs) - args_process.postprocess(default_kwargs) + args_process.build_args(default_kwargs) query = args_process.get_chat_info(messages,chat_format="generate") response = self.gen_core.generate(query=query, **default_kwargs) return CompletionResult(result={ @@ -200,6 +198,7 @@ def generate(self,messages: List[Dict],**kwargs): }) def embedding(self, query, **kwargs): + args_process = GenArgs(kwargs, self) model = self.get_model() inputs = self.tokenizer(query, return_tensors="pt") inputs = inputs.to(model.device) diff --git a/serving/openai_api/custom.py b/serving/openai_api/custom.py index a98099c..3bef583 100644 --- a/serving/openai_api/custom.py +++ b/serving/openai_api/custom.py @@ -10,6 +10,11 @@ ] +class Tools(BaseModel): + type: Optional[str] = "function" + function: Optional[Dict] = None + + class CustomChatParams(BaseModel): model: str temperature: Optional[float] = 0.7 @@ -43,8 +48,18 @@ class CustomChatParams(BaseModel): forced_eos_token_id: Optional[int] = None guidance_scale: Optional[float] = None low_memory: Optional[bool] = None + + # Deprecated functions: Optional[List[Dict[str, Any]]] = None - function_call: Union[str, Dict[str, str]] = "auto" + # Deprecated + function_call: Optional[Union[str, Dict[str, str]]] = "auto" + + seed: Optional[int] = None + + # 取代 functions + tools: Optional[List[Tools]] = None + # 取代 function_call + tool_choice: Optional[Union[str, Dict[str, str]]] = "auto" # class Config: # underscore_attrs_are_private = True @@ -77,6 +92,7 @@ def _update_params(self, r): "forced_eos_token_id": self.forced_eos_token_id, "guidance_scale": self.guidance_scale, "low_memory": self.low_memory, + "seed": self.seed, } if self.frequency_penalty is not None and self.frequency_penalty > 0: params["repetition_penalty"] = self.frequency_penalty diff --git a/serving/openai_api/openai_api_protocol.py b/serving/openai_api/openai_api_protocol.py index 74e542e..8134ee5 100644 --- a/serving/openai_api/openai_api_protocol.py +++ b/serving/openai_api/openai_api_protocol.py @@ -83,7 +83,7 @@ class ChatCompletionRequest(CustomChatParams): messages: List[ChatMessage] def build_messages(self): - messages = [message.dict() for message in self.messages] + messages = [message.model_dump() for message in self.messages] assert self.messages[-1].role in [Role.USER,Role.OBSERVATION] return [messages] diff --git a/serving/serve/api_react.py b/serving/serve/api_react.py index b361f13..7bbfc26 100644 --- a/serving/serve/api_react.py +++ b/serving/serve/api_react.py @@ -19,6 +19,12 @@ def build_react_functions(request: Union[CompletionRequest,ChatCompletionRequest function_call = request.function_call functions = request.functions + tools = request.tools + tool_choice = request.tool_choice + if tools is not None and tool_choice is not None: + functions = [_.function for _ in tools if _.function] + function_call = tool_choice + messages = request.messages use_function = False if isinstance(messages, list) and isinstance(messages[0], dict): @@ -32,7 +38,6 @@ def build_react_functions(request: Union[CompletionRequest,ChatCompletionRequest request.messages, functions = get_react_prompt_for_qwen(messages, functions, function_call) elif model_type in ["chatglm","chatglm3"]: request.messages, functions = get_react_prompt_for_chatglm3(messages, functions, function_call) - return functions diff --git a/tests/test_openai.py b/tests/test_openai.py index 48dd4d7..50a85ca 100644 --- a/tests/test_openai.py +++ b/tests/test_openai.py @@ -25,6 +25,7 @@ "n": 1, # 返回 n 个choices "max_tokens": 512, # "stop": ["Observation:","Observation:\n"] + "seed": None, } diff --git a/tests/test_openai_chat.py b/tests/test_openai_chat.py index 67a4038..be49626 100644 --- a/tests/test_openai_chat.py +++ b/tests/test_openai_chat.py @@ -24,7 +24,9 @@ "nchar": 1,# stream 字符 "n": 1, # 返回 n 个choices "max_tokens": 512, - "stop": ["Observation:"] + "stop": ["Observation:"], + "seed": None, + # "seed": "46", } diff --git a/tests/web_demo.py b/tests/web_demo.py index e2235f0..4b7c019 100644 --- a/tests/web_demo.py +++ b/tests/web_demo.py @@ -38,7 +38,7 @@ def postprocess(self, y): return y -gr.Chatbot.postprocess = postprocess +gr.Chatbot.build_args = postprocess def parse_text(text): """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""