Skip to content

Commit

Permalink
Merge pull request #22 from atomiechen/dev
Browse files Browse the repository at this point in the history
Bump version to 0.7.1
  • Loading branch information
atomiechen committed May 7, 2024
2 parents 893e952 + 943dfeb commit b0f815b
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "HandyLLM"
version = "0.7.0"
version = "0.7.1"
authors = [
{ name="Atomie CHEN", email="atomic_cwh@163.com" },
]
Expand Down
1 change: 0 additions & 1 deletion src/handyllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@
from .endpoint_manager import Endpoint, EndpointManager
from .prompt_converter import PromptConverter
from .utils import *
from .hprompt import *
22 changes: 22 additions & 0 deletions src/handyllm/_str_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from enum import Enum, EnumMeta


class StrEnumMeta(EnumMeta):
def __contains__(cls, item):
if isinstance(item, str):
return item in iter(cls)
return super().__contains__(item)


class AutoStrEnum(Enum, metaclass=StrEnumMeta):
@staticmethod
def _generate_next_value_(name, start, count, last_values):
return name.lower() # use lower case as the value

def __eq__(self, other: object) -> bool:
if isinstance(other, Enum):
return self.value == other.value
elif isinstance(other, str):
return self.value == other.lower()
return False

63 changes: 44 additions & 19 deletions src/handyllm/hprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"load_var_map",
"RunConfig",
"RecordRequestMode",
"CredentialType",
]

import json
Expand All @@ -23,7 +24,7 @@
from pathlib import Path
from datetime import datetime
from typing import Optional, Union, TypeVar
from enum import Enum, auto
from enum import auto
from abc import abstractmethod, ABC
from dataclasses import dataclass, asdict, fields, replace

Expand All @@ -38,6 +39,7 @@
astream_chat_with_role, astream_completions,
stream_chat_with_role, stream_completions,
)
from ._str_enum import AutoStrEnum


PromptType = TypeVar('PromptType', bound='HandyPrompt')
Expand Down Expand Up @@ -131,14 +133,21 @@ def load_var_map(path: PathType) -> dict[str, str]:
substitute_map[key] = value.strip()
return substitute_map


class RecordRequestMode(Enum):
class RecordRequestMode(AutoStrEnum):
BLACKLIST = auto() # record all request arguments except specified ones
WHITELIST = auto() # record only specified request arguments
NONE = auto() # record no request arguments
ALL = auto() # record all request arguments


class CredentialType(AutoStrEnum):
# load environment variables from the credential file
ENV = auto()
# load the content of the file as request arguments
JSON = auto()
YAML = auto()


@dataclass
class RunConfig:
# record request arguments
Expand All @@ -160,11 +169,38 @@ class RunConfig:
# credential type: env, json, yaml
# if env, load environment variables from the credential file
# if json or yaml, load the content of the file as request arguments
credential_type: Optional[str] = None # default: guess from the file extension
credential_type: Optional[CredentialType] = None # default: guess from the file extension

# verbose output to stderr
verbose: Optional[bool] = None # default: False

def __setattr__(self, name: str, value: object):
if name == "record_request":
# validate record_request value
if isinstance(value, str):
if value not in RecordRequestMode:
raise ValueError(f"unsupported record_request value: {value}")
elif isinstance(value, RecordRequestMode):
value = value.value
elif value is None: # this field is optional
pass
else:
raise ValueError(f"unsupported record_request value: {value}")
elif name == "credential_type":
# validate credential_type value
if isinstance(value, str):
if value == 'yml':
value = CredentialType.YAML.value
elif value not in CredentialType:
raise ValueError(f"unsupported credential_type value: {value}")
elif isinstance(value, CredentialType):
value = value.value
elif value is None: # this field is optional
pass
else:
raise ValueError(f"unsupported credential_type value: {value}")
super().__setattr__(name, value)

def __len__(self):
return len([f for f in fields(self) if getattr(self, f.name) is not None])

Expand All @@ -174,10 +210,6 @@ def from_dict(cls, obj: dict, base_path: Optional[PathType] = None):
for field in fields(cls):
if field.name in obj:
input_kwargs[field.name] = obj[field.name]
# convert string to Enum
record_str = input_kwargs.get("record_request")
if record_str is not None:
input_kwargs["record_request"] = RecordRequestMode[record_str.upper()]
# add base_path to path fields and convert to resolved path
if base_path:
for path_field in ("output_path", "output_evaled_prompt_path", "var_map_path", "credential_path"):
Expand Down Expand Up @@ -212,10 +244,6 @@ def to_dict(self, retain_fd=False, base_path: Optional[PathType] = None) -> dict
# keep file descriptors
obj["output_fd"] = self.output_fd
obj["output_evaled_prompt_fd"] = self.output_evaled_prompt_fd
# convert Enum to string
record_enum = obj.get("record_request")
if record_enum is not None:
obj["record_request"] = obj["record_request"].name
# convert path to relative path
if base_path:
for path_field in ("output_path", "output_evaled_prompt_path", "var_map_path", "credential_path"):
Expand Down Expand Up @@ -352,11 +380,8 @@ def eval_run_config(
p = Path(run_config.credential_path)
if p.suffix:
run_config.credential_type = p.suffix[1:]
if run_config.credential_type == "yml":
run_config.credential_type = "yaml"
else:
run_config.credential_type = 'env'
run_config.credential_type = run_config.credential_type.lower()
run_config.credential_type = CredentialType.ENV
return run_config

@abstractmethod
Expand Down Expand Up @@ -441,11 +466,11 @@ def _prepare_run(self: PromptType, run_config: RunConfig, kwargs: dict):

# load the credential file
if run_config.credential_path:
if run_config.credential_type == "env":
if run_config.credential_type == CredentialType.ENV:
load_dotenv(run_config.credential_path, override=True)
elif run_config.credential_type in ("json", "yaml"):
elif run_config.credential_type in (CredentialType.JSON, CredentialType.YAML):
with open(run_config.credential_path, 'r', encoding='utf-8') as fin:
if run_config.credential_type == "json":
if run_config.credential_type == CredentialType.JSON:
credential_dict = json.load(fin)
else:
credential_dict = yaml.safe_load(fin)
Expand Down
7 changes: 5 additions & 2 deletions tests/test_hprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
prompt += result_prompt
# chain another hprompt
prompt += hprompt.load_from(cur_dir / './assets/magic.hprompt')
# run again
result2 = prompt.run()
# create a new run config
run_config = hprompt.RunConfig()
run_config.record_request = hprompt.RecordRequestMode.NONE # record no request args
# run again, with run config
result2 = prompt.run(run_config=run_config)


0 comments on commit b0f815b

Please sign in to comment.