From ed750bfc99b2472eb73eb22705ae781d19129b50 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 7 May 2024 16:34:27 +0800 Subject: [PATCH 1/9] feat(hprompt): make RunConfig.record_request optional string type; add validation --- src/handyllm/hprompt.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index 333b6ac..9fec101 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -132,7 +132,20 @@ def load_var_map(path: PathType) -> dict[str, str]: return substitute_map -class RecordRequestMode(Enum): +class AutoName(Enum): + @staticmethod + def _generate_next_value_(name, start, count, last_values): + return name.lower() # use lower case as the value + + def __eq__(self, other: object) -> bool: + if isinstance(other, Enum): + return self.value == other.value + elif isinstance(other, str): + return self.value == other.lower() + return False + + +class RecordRequestMode(AutoName): BLACKLIST = auto() # record all request arguments except specified ones WHITELIST = auto() # record only specified request arguments NONE = auto() # record no request arguments @@ -165,6 +178,22 @@ class RunConfig: # verbose output to stderr verbose: Optional[bool] = None # default: False + def __setattr__(self, name: str, value: object): + if name == "record_request": + # validate record_request value + if isinstance(value, str): + option = value.upper() + if option not in RecordRequestMode.__members__: + raise ValueError(f"unsupported record_request value: {value}") + value = RecordRequestMode[option].value + elif isinstance(value, RecordRequestMode): + value = value.value + elif value is None: # this field is optional + pass + else: + raise ValueError(f"unsupported record_request value: {value}") + super().__setattr__(name, value) + def __len__(self): return len([f for f in fields(self) if getattr(self, f.name) is not None]) @@ -174,10 +203,6 @@ def from_dict(cls, obj: dict, base_path: Optional[PathType] = None): for field in fields(cls): if field.name in obj: input_kwargs[field.name] = obj[field.name] - # convert string to Enum - record_str = input_kwargs.get("record_request") - if record_str is not None: - input_kwargs["record_request"] = RecordRequestMode[record_str.upper()] # add base_path to path fields and convert to resolved path if base_path: for path_field in ("output_path", "output_evaled_prompt_path", "var_map_path", "credential_path"): @@ -212,10 +237,6 @@ def to_dict(self, retain_fd=False, base_path: Optional[PathType] = None) -> dict # keep file descriptors obj["output_fd"] = self.output_fd obj["output_evaled_prompt_fd"] = self.output_evaled_prompt_fd - # convert Enum to string - record_enum = obj.get("record_request") - if record_enum is not None: - obj["record_request"] = obj["record_request"].name # convert path to relative path if base_path: for path_field in ("output_path", "output_evaled_prompt_path", "var_map_path", "credential_path"): From ef49e402699e15c0d82f7b76564a680dfe367407 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 7 May 2024 16:43:39 +0800 Subject: [PATCH 2/9] feat(hprompt): improve record_request validation --- src/handyllm/hprompt.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index 9fec101..905ffd0 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -182,10 +182,8 @@ def __setattr__(self, name: str, value: object): if name == "record_request": # validate record_request value if isinstance(value, str): - option = value.upper() - if option not in RecordRequestMode.__members__: + if value not in iter(RecordRequestMode): raise ValueError(f"unsupported record_request value: {value}") - value = RecordRequestMode[option].value elif isinstance(value, RecordRequestMode): value = value.value elif value is None: # this field is optional From df015b9c2d389deadd488e5ad3455fdafc08f188 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 7 May 2024 17:06:47 +0800 Subject: [PATCH 3/9] feat(hprompt): improve RecordRequestMode class methods --- src/handyllm/hprompt.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index 905ffd0..53b5766 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -23,7 +23,7 @@ from pathlib import Path from datetime import datetime from typing import Optional, Union, TypeVar -from enum import Enum, auto +from enum import Enum, EnumMeta, auto from abc import abstractmethod, ABC from dataclasses import dataclass, asdict, fields, replace @@ -132,7 +132,14 @@ def load_var_map(path: PathType) -> dict[str, str]: return substitute_map -class AutoName(Enum): +class StrEnumMeta(EnumMeta): + def __contains__(cls, item): + if isinstance(item, str): + return item in iter(cls) + return super().__contains__(item) + + +class AutoStrEnum(Enum, metaclass=StrEnumMeta): @staticmethod def _generate_next_value_(name, start, count, last_values): return name.lower() # use lower case as the value @@ -145,7 +152,7 @@ def __eq__(self, other: object) -> bool: return False -class RecordRequestMode(AutoName): +class RecordRequestMode(AutoStrEnum): BLACKLIST = auto() # record all request arguments except specified ones WHITELIST = auto() # record only specified request arguments NONE = auto() # record no request arguments @@ -182,7 +189,7 @@ def __setattr__(self, name: str, value: object): if name == "record_request": # validate record_request value if isinstance(value, str): - if value not in iter(RecordRequestMode): + if value not in RecordRequestMode: raise ValueError(f"unsupported record_request value: {value}") elif isinstance(value, RecordRequestMode): value = value.value From 184e8ad25ace2dc68d7d351f245668d988e753c1 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 7 May 2024 17:09:00 +0800 Subject: [PATCH 4/9] refactor(hprompt): move AutoStrEnum to separate module --- src/handyllm/_str_enum.py | 22 ++++++++++++++++++++++ src/handyllm/hprompt.py | 24 ++---------------------- 2 files changed, 24 insertions(+), 22 deletions(-) create mode 100644 src/handyllm/_str_enum.py diff --git a/src/handyllm/_str_enum.py b/src/handyllm/_str_enum.py new file mode 100644 index 0000000..b612066 --- /dev/null +++ b/src/handyllm/_str_enum.py @@ -0,0 +1,22 @@ +from enum import Enum, EnumMeta + + +class StrEnumMeta(EnumMeta): + def __contains__(cls, item): + if isinstance(item, str): + return item in iter(cls) + return super().__contains__(item) + + +class AutoStrEnum(Enum, metaclass=StrEnumMeta): + @staticmethod + def _generate_next_value_(name, start, count, last_values): + return name.lower() # use lower case as the value + + def __eq__(self, other: object) -> bool: + if isinstance(other, Enum): + return self.value == other.value + elif isinstance(other, str): + return self.value == other.lower() + return False + diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index 53b5766..5ec672f 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -23,7 +23,7 @@ from pathlib import Path from datetime import datetime from typing import Optional, Union, TypeVar -from enum import Enum, EnumMeta, auto +from enum import auto from abc import abstractmethod, ABC from dataclasses import dataclass, asdict, fields, replace @@ -38,6 +38,7 @@ astream_chat_with_role, astream_completions, stream_chat_with_role, stream_completions, ) +from ._str_enum import AutoStrEnum PromptType = TypeVar('PromptType', bound='HandyPrompt') @@ -131,27 +132,6 @@ def load_var_map(path: PathType) -> dict[str, str]: substitute_map[key] = value.strip() return substitute_map - -class StrEnumMeta(EnumMeta): - def __contains__(cls, item): - if isinstance(item, str): - return item in iter(cls) - return super().__contains__(item) - - -class AutoStrEnum(Enum, metaclass=StrEnumMeta): - @staticmethod - def _generate_next_value_(name, start, count, last_values): - return name.lower() # use lower case as the value - - def __eq__(self, other: object) -> bool: - if isinstance(other, Enum): - return self.value == other.value - elif isinstance(other, str): - return self.value == other.lower() - return False - - class RecordRequestMode(AutoStrEnum): BLACKLIST = auto() # record all request arguments except specified ones WHITELIST = auto() # record only specified request arguments From 5cdbe6ff6ed50a83bb9f8d53dc4ec35b1bf16d21 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 7 May 2024 17:23:23 +0800 Subject: [PATCH 5/9] feat(hprompt): make credential type enum --- src/handyllm/hprompt.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index 5ec672f..2fa2b38 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -139,6 +139,14 @@ class RecordRequestMode(AutoStrEnum): ALL = auto() # record all request arguments +class CredentialType(AutoStrEnum): + # load environment variables from the credential file + ENV = auto() + # load the content of the file as request arguments + JSON = auto() + YAML = auto() + + @dataclass class RunConfig: # record request arguments @@ -160,7 +168,7 @@ class RunConfig: # credential type: env, json, yaml # if env, load environment variables from the credential file # if json or yaml, load the content of the file as request arguments - credential_type: Optional[str] = None # default: guess from the file extension + credential_type: Optional[CredentialType] = None # default: guess from the file extension # verbose output to stderr verbose: Optional[bool] = None # default: False @@ -177,6 +185,19 @@ def __setattr__(self, name: str, value: object): pass else: raise ValueError(f"unsupported record_request value: {value}") + elif name == "credential_type": + # validate credential_type value + if isinstance(value, str): + if value == 'yml': + value = CredentialType.YAML.value + elif value not in CredentialType: + raise ValueError(f"unsupported credential_type value: {value}") + elif isinstance(value, CredentialType): + value = value.value + elif value is None: # this field is optional + pass + else: + raise ValueError(f"unsupported credential_type value: {value}") super().__setattr__(name, value) def __len__(self): @@ -358,11 +379,8 @@ def eval_run_config( p = Path(run_config.credential_path) if p.suffix: run_config.credential_type = p.suffix[1:] - if run_config.credential_type == "yml": - run_config.credential_type = "yaml" else: - run_config.credential_type = 'env' - run_config.credential_type = run_config.credential_type.lower() + run_config.credential_type = CredentialType.ENV return run_config @abstractmethod @@ -447,11 +465,11 @@ def _prepare_run(self: PromptType, run_config: RunConfig, kwargs: dict): # load the credential file if run_config.credential_path: - if run_config.credential_type == "env": + if run_config.credential_type == CredentialType.ENV: load_dotenv(run_config.credential_path, override=True) - elif run_config.credential_type in ("json", "yaml"): + elif run_config.credential_type in (CredentialType.JSON, CredentialType.YAML): with open(run_config.credential_path, 'r', encoding='utf-8') as fin: - if run_config.credential_type == "json": + if run_config.credential_type == CredentialType.JSON: credential_dict = json.load(fin) else: credential_dict = yaml.safe_load(fin) From 2fae6e102a5a85e5e9bc20e38fce653f3493eb02 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 7 May 2024 17:34:26 +0800 Subject: [PATCH 6/9] feat(test_hprompt): update --- tests/test_hprompt.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_hprompt.py b/tests/test_hprompt.py index 29a45fd..ee81df7 100644 --- a/tests/test_hprompt.py +++ b/tests/test_hprompt.py @@ -22,7 +22,10 @@ prompt += result_prompt # chain another hprompt prompt += hprompt.load_from(cur_dir / './assets/magic.hprompt') -# run again -result2 = prompt.run() +# create a new run config +run_config = hprompt.RunConfig() +run_config.record_request = hprompt.RecordRequestMode.NONE # record no request args +# run again, with run config +result2 = prompt.run(run_config=run_config) From 1dd61b25b0b948174467575df5fde7b250a9a657 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 7 May 2024 20:07:18 +0800 Subject: [PATCH 7/9] feat: remove hprompt import from __init__ --- src/handyllm/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/handyllm/__init__.py b/src/handyllm/__init__.py index c95b646..b2ee3c9 100644 --- a/src/handyllm/__init__.py +++ b/src/handyllm/__init__.py @@ -4,4 +4,3 @@ from .endpoint_manager import Endpoint, EndpointManager from .prompt_converter import PromptConverter from .utils import * -from .hprompt import * From 438162e5b37810c45b21cb82c31d9d686130beef Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 7 May 2024 20:08:55 +0800 Subject: [PATCH 8/9] feat(hprompt): add CredentialType to __all__ --- src/handyllm/hprompt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/handyllm/hprompt.py b/src/handyllm/hprompt.py index 2fa2b38..f0ce1ce 100644 --- a/src/handyllm/hprompt.py +++ b/src/handyllm/hprompt.py @@ -12,6 +12,7 @@ "load_var_map", "RunConfig", "RecordRequestMode", + "CredentialType", ] import json From 943dfeb529b44155ce84b5578646ae569ccd0520 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Tue, 7 May 2024 20:10:01 +0800 Subject: [PATCH 9/9] Bump version to 0.7.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 16426d5..97ab797 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "HandyLLM" -version = "0.7.0" +version = "0.7.1" authors = [ { name="Atomie CHEN", email="atomic_cwh@163.com" }, ]