-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bump version to 0.5.0: Merge pull request #10 from atomiechen/dev
Bump version to 0.5.0
- Loading branch information
Showing
8 changed files
with
237 additions
and
150 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .openai_api import OpenAIAPI | ||
from .endpoint_manager import EndpointManager | ||
from .endpoint_manager import Endpoint, EndpointManager | ||
from .prompt_converter import PromptConverter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,84 @@ | ||
from threading import Lock | ||
from . import OpenAIAPI | ||
from collections.abc import MutableSequence | ||
|
||
|
||
class Endpoint: | ||
def __init__( | ||
self, | ||
name=None, | ||
api_key=None, | ||
organization=None, | ||
api_base=None, | ||
api_type=None, | ||
api_version=None, | ||
): | ||
self.name = name if name else f"ep_{id(self)}" | ||
self.api_key = api_key | ||
self.organization = organization | ||
self.api_base = api_base | ||
self.api_type = api_type | ||
self.api_version = api_version | ||
|
||
def __str__(self) -> str: | ||
# do not print api_key | ||
listed_attributes = [ | ||
f'name={repr(self.name)}' if self.name else None, | ||
f'api_key=*' if self.api_key else None, | ||
f'organization={repr(self.organization)}' if self.organization else None, | ||
f'api_base={repr(self.api_base)}' if self.api_base else None, | ||
f'api_type={repr(self.api_type)}' if self.api_type else None, | ||
f'api_version={repr(self.api_version)}' if self.api_version else None, | ||
] | ||
# remove None in listed_attributes | ||
listed_attributes = [item for item in listed_attributes if item] | ||
return f"Endpoint({', '.join(listed_attributes)})" | ||
|
||
def get_api_info(self): | ||
return ( | ||
self.api_key, | ||
self.organization, | ||
self.api_base, | ||
self.api_type, | ||
self.api_version | ||
) | ||
|
||
|
||
class EndpointManager(MutableSequence): | ||
|
||
class EndpointManager: | ||
|
||
def __init__(self): | ||
self._lock = Lock() | ||
self._last_idx_endpoint = 0 | ||
self._endpoints = [] | ||
|
||
def clear(self): | ||
self._last_idx_endpoint = 0 | ||
self._endpoints.clear() | ||
|
||
self._base_urls = [] | ||
self._last_idx_url = 0 | ||
|
||
self._keys = [] | ||
self._last_idx_key = 0 | ||
|
||
self._organizations = [] | ||
self._last_idx_organization = 0 | ||
def __len__(self) -> int: | ||
return len(self._endpoints) | ||
|
||
def add_base_url(self, base_url: str): | ||
if isinstance(base_url, str) and base_url.strip() != '': | ||
self._base_urls.append(base_url) | ||
def __getitem__(self, idx: int) -> Endpoint: | ||
return self._endpoints[idx] | ||
|
||
def add_key(self, key: str): | ||
if isinstance(key, str) and key.strip() != '': | ||
self._keys.append(key) | ||
|
||
def add_organization(self, organization: str): | ||
if isinstance(organization, str) and organization.strip() != '': | ||
self._organizations.append(organization) | ||
|
||
def set_base_urls(self, base_urls): | ||
self._base_urls = [url for url in base_urls if isinstance(url, str) and url.strip() != ''] | ||
|
||
def set_keys(self, keys): | ||
self._keys = [key for key in keys if isinstance(key, str) and key.strip() != ''] | ||
|
||
def set_organizations(self, organizations): | ||
self._organizations = [organization for organization in organizations if isinstance(organization, str) and organization.strip() != ''] | ||
def __setitem__(self, idx: int, endpoint: Endpoint): | ||
self._endpoints[idx] = endpoint | ||
|
||
def get_base_url(self): | ||
if len(self._base_urls) == 0: | ||
return OpenAIAPI.get_api_base() | ||
else: | ||
base_url = self._base_urls[self._last_idx_url] | ||
if self._last_idx_url == len(self._base_urls) - 1: | ||
self._last_idx_url = 0 | ||
else: | ||
self._last_idx_url += 1 | ||
return base_url | ||
def __delitem__(self, idx: int): | ||
del self._endpoints[idx] | ||
|
||
def get_key(self): | ||
if len(self._keys) == 0: | ||
return OpenAIAPI.get_api_key() | ||
else: | ||
key = self._keys[self._last_idx_key] | ||
if self._last_idx_key == len(self._keys) - 1: | ||
self._last_idx_key = 0 | ||
else: | ||
self._last_idx_key += 1 | ||
return key | ||
|
||
def get_organization(self): | ||
if len(self._organizations) == 0: | ||
return OpenAIAPI.get_organization() | ||
else: | ||
organization = self._organizations[self._last_idx_organization] | ||
if self._last_idx_organization == len(self._keys) - 1: | ||
self._last_idx_organization = 0 | ||
else: | ||
self._last_idx_organization += 1 | ||
return organization | ||
|
||
def get_endpoint(self): | ||
def insert(self, idx: int, endpoint: Endpoint): | ||
self._endpoints.insert(idx, endpoint) | ||
|
||
def add_endpoint_by_info(self, **kwargs): | ||
endpoint = Endpoint(**kwargs) | ||
self.append(endpoint) | ||
|
||
def get_next_endpoint(self) -> Endpoint: | ||
with self._lock: | ||
# compose full url | ||
base_url = self.get_base_url() | ||
# get API key | ||
api_key = self.get_key() | ||
# get organization | ||
organization = self.get_organization() | ||
return base_url, api_key, organization | ||
endpoint = self._endpoints[self._last_idx_endpoint] | ||
if self._last_idx_endpoint == len(self._endpoints) - 1: | ||
self._last_idx_endpoint = 0 | ||
else: | ||
self._last_idx_endpoint += 1 | ||
return endpoint | ||
|
Oops, something went wrong.