diff --git a/README.md b/README.md index 84dc4c1..9c2f068 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Example scripts are placed in [tests](./tests) folder. ### Endpoints -Each API request will connect to an endpoint along with some API configurations, which include: `api_key`, `organization`, `api_base`, `api_type`, `api_version` and `model_engine_map`. +Each API request will connect to an endpoint along with some API configurations, which include: `api_key`, `organization`, `api_base`, `api_type`, `api_version`, `model_engine_map`, `dest_url`. An `Endpoint` object contains these information. An `EndpointManager` acts like a list and can be used to rotate the next endpoint. See [test_endpoint.py](./tests/test_endpoint.py). diff --git a/pyproject.toml b/pyproject.toml index a231e0b..78ebde9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "HandyLLM" -version = "0.5.2" +version = "0.5.3" authors = [ { name="Atomie CHEN", email="atomic_cwh@163.com" }, ] diff --git a/src/handyllm/api_request.py b/src/handyllm/api_request.py index bb902e0..720258a 100644 --- a/src/handyllm/api_request.py +++ b/src/handyllm/api_request.py @@ -19,6 +19,7 @@ def api_request( timeout=None, files=None, raw_response=False, + dest_url=None, **kwargs): if api_key is None: raise Exception("OpenAI API key is not set") @@ -50,6 +51,8 @@ def api_request( headers['Authorization'] = 'Bearer ' + api_key if organization is not None: headers['OpenAI-Organization'] = organization + if dest_url is not None: + headers['Destination-URL'] = dest_url if method == 'post': if files is None: headers['Content-Type'] = 'application/json' diff --git a/src/handyllm/endpoint_manager.py b/src/handyllm/endpoint_manager.py index 3319425..accbb90 100644 --- a/src/handyllm/endpoint_manager.py +++ b/src/handyllm/endpoint_manager.py @@ -12,6 +12,7 @@ def __init__( api_type=None, api_version=None, model_engine_map=None, + dest_url=None, ): self.name = name if name else f"ep_{id(self)}" self.api_key = api_key @@ -20,6 +21,7 @@ def __init__( self.api_type = api_type self.api_version = api_version self.model_engine_map = model_engine_map if model_engine_map else {} + self.dest_url = dest_url def __str__(self) -> str: # do not print api_key @@ -31,6 +33,7 @@ def __str__(self) -> str: f'api_type={repr(self.api_type)}' if self.api_type else None, f'api_version={repr(self.api_version)}' if self.api_version else None, f'model_engine_map={repr(self.model_engine_map)}' if self.model_engine_map else None, + f'dest_url={repr(self.dest_url)}' if self.dest_url else None, ] # remove None in listed_attributes listed_attributes = [item for item in listed_attributes if item] @@ -44,6 +47,7 @@ def get_api_info(self): self.api_type, self.api_version, self.model_engine_map, + self.dest_url, ) diff --git a/src/handyllm/openai_api.py b/src/handyllm/openai_api.py index fa6243d..7401b0d 100644 --- a/src/handyllm/openai_api.py +++ b/src/handyllm/openai_api.py @@ -107,13 +107,13 @@ def api_request_endpoint( request_url, **kwargs ): - api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) + api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs) url = utils.join_url(api_base, request_url) - return api_request(url, api_key, organization=organization, api_type=api_type, **kwargs) + return api_request(url, api_key, organization=organization, api_type=api_type, dest_url=dest_url, **kwargs) @classmethod def consume_kwargs(cls, kwargs): - api_key = organization = api_base = api_type = api_version = engine = model_engine_map = None + api_key = organization = api_base = api_type = api_version = engine = model_engine_map = dest_url = None # read API info from endpoint_manager endpoint_manager = kwargs.pop('endpoint_manager', None) @@ -121,14 +121,14 @@ def consume_kwargs(cls, kwargs): if not isinstance(endpoint_manager, EndpointManager): raise Exception("endpoint_manager must be an instance of EndpointManager") # get_next_endpoint() will be called once for each request - api_key, organization, api_base, api_type, api_version, model_engine_map = endpoint_manager.get_next_endpoint().get_api_info() + api_key, organization, api_base, api_type, api_version, model_engine_map, dest_url = endpoint_manager.get_next_endpoint().get_api_info() # read API info from endpoint (override API info from endpoint_manager) endpoint = kwargs.pop('endpoint', None) if endpoint is not None: if not isinstance(endpoint, Endpoint): raise Exception("endpoint must be an instance of Endpoint") - api_key, organization, api_base, api_type, api_version, model_engine_map = endpoint.get_api_info() + api_key, organization, api_base, api_type, api_version, model_engine_map, dest_url = endpoint.get_api_info() # read API info from kwargs, class variables, and environment variables api_key = cls.get_api_key(kwargs.pop('api_key', api_key)) @@ -150,7 +150,8 @@ def consume_kwargs(cls, kwargs): engine = model_engine_map.get(model, model) else: engine = model - return api_key, organization, api_base, api_type, api_version, engine + dest_url = kwargs.pop('dest_url', dest_url) + return api_key, organization, api_base, api_type, api_version, engine, dest_url @staticmethod def get_request_url(request_url, api_type, api_version, engine): @@ -166,7 +167,7 @@ def get_request_url(request_url, api_type, api_version, engine): @classmethod def chat(cls, messages, logger=None, log_marks=[], **kwargs): - api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) + api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs) request_url = cls.get_request_url('/chat/completions', api_type, api_version, engine) if logger is not None: @@ -192,6 +193,7 @@ def chat(cls, messages, logger=None, log_marks=[], **kwargs): organization=organization, api_base=api_base, api_type=api_type, + dest_url=dest_url, **kwargs ) @@ -237,7 +239,7 @@ def wrapper(response): @classmethod def completions(cls, prompt, logger=None, log_marks=[], **kwargs): - api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) + api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs) request_url = cls.get_request_url('/completions', api_type, api_version, engine) if logger is not None: @@ -263,6 +265,7 @@ def completions(cls, prompt, logger=None, log_marks=[], **kwargs): organization=organization, api_base=api_base, api_type=api_type, + dest_url=dest_url, **kwargs ) @@ -308,7 +311,7 @@ def edits(cls, **kwargs): @classmethod def embeddings(cls, **kwargs): - api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) + api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs) request_url = cls.get_request_url('/embeddings', api_type, api_version, engine) return cls.api_request_endpoint( request_url, @@ -317,12 +320,13 @@ def embeddings(cls, **kwargs): organization=organization, api_base=api_base, api_type=api_type, + dest_url=dest_url, **kwargs ) @classmethod def models_list(cls, **kwargs): - api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) + api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs) request_url = cls.get_request_url('/models', api_type, api_version, engine) return cls.api_request_endpoint( request_url, @@ -331,6 +335,7 @@ def models_list(cls, **kwargs): organization=organization, api_base=api_base, api_type=api_type, + dest_url=dest_url, **kwargs ) @@ -346,7 +351,7 @@ def moderations(cls, **kwargs): @classmethod def images_generations(cls, **kwargs): - api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) + api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs) if api_type and api_type.lower() in _API_TYPES_AZURE: request_url = f'/openai/images/generations:submit?api-version={api_version}' raw_response = True @@ -361,6 +366,7 @@ def images_generations(cls, **kwargs): api_base=api_base, api_type=api_type, raw_response=raw_response, + dest_url=dest_url, **kwargs ) if api_type and api_type.lower() in _API_TYPES_AZURE: diff --git a/src/handyllm/prompt_converter.py b/src/handyllm/prompt_converter.py index 21e76fa..ca187e4 100644 --- a/src/handyllm/prompt_converter.py +++ b/src/handyllm/prompt_converter.py @@ -69,3 +69,11 @@ def chat_replace_variables(self, chat, variable_map: dict, inplace=False): new_chat.append(new_message) return new_chat + def chat_append_msg(self, chat, content: str, role: str = 'user', inplace=False): + if inplace: + chat.append({"role": role, "content": content}) + return chat + else: + new_chat = chat.copy() + new_chat.append({"role": role, "content": content}) + return new_chat diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 9796013..981c9f1 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -20,3 +20,7 @@ ) # print(json.dumps(new_chat, indent=2)) print(converter.chat2raw(new_chat)) +print(converter.chat_append_msg(new_chat, '''{ + "item1": "It is really a good day.", + "item2": "Indeed." +}''', role='assistant'))