Skip to content

Commit

Permalink
Merge pull request #13 from atomiechen/dev
Browse files Browse the repository at this point in the history
Bump version to 0.5.3
  • Loading branch information
atomiechen committed Dec 19, 2023
2 parents 2d98a7e + a7bb221 commit 8772ef7
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Example scripts are placed in [tests](./tests) folder.

### Endpoints

Each API request will connect to an endpoint along with some API configurations, which include: `api_key`, `organization`, `api_base`, `api_type`, `api_version` and `model_engine_map`.
Each API request will connect to an endpoint along with some API configurations, which include: `api_key`, `organization`, `api_base`, `api_type`, `api_version`, `model_engine_map`, `dest_url`.

An `Endpoint` object contains these information. An `EndpointManager` acts like a list and can be used to rotate the next endpoint. See [test_endpoint.py](./tests/test_endpoint.py).

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "HandyLLM"
version = "0.5.2"
version = "0.5.3"
authors = [
{ name="Atomie CHEN", email="atomic_cwh@163.com" },
]
Expand Down
3 changes: 3 additions & 0 deletions src/handyllm/api_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def api_request(
timeout=None,
files=None,
raw_response=False,
dest_url=None,
**kwargs):
if api_key is None:
raise Exception("OpenAI API key is not set")
Expand Down Expand Up @@ -50,6 +51,8 @@ def api_request(
headers['Authorization'] = 'Bearer ' + api_key
if organization is not None:
headers['OpenAI-Organization'] = organization
if dest_url is not None:
headers['Destination-URL'] = dest_url
if method == 'post':
if files is None:
headers['Content-Type'] = 'application/json'
Expand Down
4 changes: 4 additions & 0 deletions src/handyllm/endpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(
api_type=None,
api_version=None,
model_engine_map=None,
dest_url=None,
):
self.name = name if name else f"ep_{id(self)}"
self.api_key = api_key
Expand All @@ -20,6 +21,7 @@ def __init__(
self.api_type = api_type
self.api_version = api_version
self.model_engine_map = model_engine_map if model_engine_map else {}
self.dest_url = dest_url

def __str__(self) -> str:
# do not print api_key
Expand All @@ -31,6 +33,7 @@ def __str__(self) -> str:
f'api_type={repr(self.api_type)}' if self.api_type else None,
f'api_version={repr(self.api_version)}' if self.api_version else None,
f'model_engine_map={repr(self.model_engine_map)}' if self.model_engine_map else None,
f'dest_url={repr(self.dest_url)}' if self.dest_url else None,
]
# remove None in listed_attributes
listed_attributes = [item for item in listed_attributes if item]
Expand All @@ -44,6 +47,7 @@ def get_api_info(self):
self.api_type,
self.api_version,
self.model_engine_map,
self.dest_url,
)


Expand Down
28 changes: 17 additions & 11 deletions src/handyllm/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,28 +107,28 @@ def api_request_endpoint(
request_url,
**kwargs
):
api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs)
api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs)
url = utils.join_url(api_base, request_url)
return api_request(url, api_key, organization=organization, api_type=api_type, **kwargs)
return api_request(url, api_key, organization=organization, api_type=api_type, dest_url=dest_url, **kwargs)

@classmethod
def consume_kwargs(cls, kwargs):
api_key = organization = api_base = api_type = api_version = engine = model_engine_map = None
api_key = organization = api_base = api_type = api_version = engine = model_engine_map = dest_url = None

# read API info from endpoint_manager
endpoint_manager = kwargs.pop('endpoint_manager', None)
if endpoint_manager is not None:
if not isinstance(endpoint_manager, EndpointManager):
raise Exception("endpoint_manager must be an instance of EndpointManager")
# get_next_endpoint() will be called once for each request
api_key, organization, api_base, api_type, api_version, model_engine_map = endpoint_manager.get_next_endpoint().get_api_info()
api_key, organization, api_base, api_type, api_version, model_engine_map, dest_url = endpoint_manager.get_next_endpoint().get_api_info()

# read API info from endpoint (override API info from endpoint_manager)
endpoint = kwargs.pop('endpoint', None)
if endpoint is not None:
if not isinstance(endpoint, Endpoint):
raise Exception("endpoint must be an instance of Endpoint")
api_key, organization, api_base, api_type, api_version, model_engine_map = endpoint.get_api_info()
api_key, organization, api_base, api_type, api_version, model_engine_map, dest_url = endpoint.get_api_info()

# read API info from kwargs, class variables, and environment variables
api_key = cls.get_api_key(kwargs.pop('api_key', api_key))
Expand All @@ -150,7 +150,8 @@ def consume_kwargs(cls, kwargs):
engine = model_engine_map.get(model, model)
else:
engine = model
return api_key, organization, api_base, api_type, api_version, engine
dest_url = kwargs.pop('dest_url', dest_url)
return api_key, organization, api_base, api_type, api_version, engine, dest_url

@staticmethod
def get_request_url(request_url, api_type, api_version, engine):
Expand All @@ -166,7 +167,7 @@ def get_request_url(request_url, api_type, api_version, engine):

@classmethod
def chat(cls, messages, logger=None, log_marks=[], **kwargs):
api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs)
api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs)
request_url = cls.get_request_url('/chat/completions', api_type, api_version, engine)

if logger is not None:
Expand All @@ -192,6 +193,7 @@ def chat(cls, messages, logger=None, log_marks=[], **kwargs):
organization=organization,
api_base=api_base,
api_type=api_type,
dest_url=dest_url,
**kwargs
)

Expand Down Expand Up @@ -237,7 +239,7 @@ def wrapper(response):

@classmethod
def completions(cls, prompt, logger=None, log_marks=[], **kwargs):
api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs)
api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs)
request_url = cls.get_request_url('/completions', api_type, api_version, engine)

if logger is not None:
Expand All @@ -263,6 +265,7 @@ def completions(cls, prompt, logger=None, log_marks=[], **kwargs):
organization=organization,
api_base=api_base,
api_type=api_type,
dest_url=dest_url,
**kwargs
)

Expand Down Expand Up @@ -308,7 +311,7 @@ def edits(cls, **kwargs):

@classmethod
def embeddings(cls, **kwargs):
api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs)
api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs)
request_url = cls.get_request_url('/embeddings', api_type, api_version, engine)
return cls.api_request_endpoint(
request_url,
Expand All @@ -317,12 +320,13 @@ def embeddings(cls, **kwargs):
organization=organization,
api_base=api_base,
api_type=api_type,
dest_url=dest_url,
**kwargs
)

@classmethod
def models_list(cls, **kwargs):
api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs)
api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs)
request_url = cls.get_request_url('/models', api_type, api_version, engine)
return cls.api_request_endpoint(
request_url,
Expand All @@ -331,6 +335,7 @@ def models_list(cls, **kwargs):
organization=organization,
api_base=api_base,
api_type=api_type,
dest_url=dest_url,
**kwargs
)

Expand All @@ -346,7 +351,7 @@ def moderations(cls, **kwargs):

@classmethod
def images_generations(cls, **kwargs):
api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs)
api_key, organization, api_base, api_type, api_version, engine, dest_url = cls.consume_kwargs(kwargs)
if api_type and api_type.lower() in _API_TYPES_AZURE:
request_url = f'/openai/images/generations:submit?api-version={api_version}'
raw_response = True
Expand All @@ -361,6 +366,7 @@ def images_generations(cls, **kwargs):
api_base=api_base,
api_type=api_type,
raw_response=raw_response,
dest_url=dest_url,
**kwargs
)
if api_type and api_type.lower() in _API_TYPES_AZURE:
Expand Down
8 changes: 8 additions & 0 deletions src/handyllm/prompt_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,11 @@ def chat_replace_variables(self, chat, variable_map: dict, inplace=False):
new_chat.append(new_message)
return new_chat

def chat_append_msg(self, chat, content: str, role: str = 'user', inplace=False):
if inplace:
chat.append({"role": role, "content": content})
return chat
else:
new_chat = chat.copy()
new_chat.append({"role": role, "content": content})
return new_chat
4 changes: 4 additions & 0 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@
)
# print(json.dumps(new_chat, indent=2))
print(converter.chat2raw(new_chat))
print(converter.chat_append_msg(new_chat, '''{
"item1": "It is really a good day.",
"item2": "Indeed."
}''', role='assistant'))

0 comments on commit 8772ef7

Please sign in to comment.