diff --git a/pegasus/Pegasus.py b/pegasus/Pegasus.py index f62fa6b..8d2a816 100644 --- a/pegasus/Pegasus.py +++ b/pegasus/Pegasus.py @@ -15,7 +15,9 @@ logger = loguru.logger class Pegasus: - def __init__(self, output_dir, exclude_selectors=None, include_domain=None, exclude_keywords=None, output_extension=".md", dust_size=1000, max_depth=None, system_message=None, classification_prompt=None, max_retries=3): + def __init__(self, output_dir, exclude_selectors=None, include_domain=None, exclude_keywords=None, output_extension=".md", + dust_size=1000, max_depth=None, system_message=None, classification_prompt=None, max_retries=3, + model='gemini/gemini-1.5-pro-latest', rate_limit_sleep=60, other_error_sleep=10): self.output_dir = output_dir self.exclude_selectors = exclude_selectors self.include_domain = include_domain @@ -28,6 +30,9 @@ def __init__(self, output_dir, exclude_selectors=None, include_domain=None, excl self.system_message = system_message self.classification_prompt = classification_prompt self.max_retries = max_retries + self.model = model + self.rate_limit_sleep = rate_limit_sleep + self.other_error_sleep = other_error_sleep tprint(" Pegasus ", font="rnd-xlarge") logger.info("初期化パラメータ:") logger.info(f" output_dir: {output_dir}") @@ -40,6 +45,9 @@ def __init__(self, output_dir, exclude_selectors=None, include_domain=None, excl logger.info(f" system_message: {system_message}") logger.info(f" classification_prompt: {classification_prompt}") logger.info(f" max_retries: {max_retries}") + logger.info(f" model: {model}") + logger.info(f" rate_limit_sleep: {rate_limit_sleep}") + logger.info(f" other_error_sleep: {other_error_sleep}") def filter_site(self, markdown_content): if(self.classification_prompt is None): @@ -69,9 +77,9 @@ def filter_site(self, markdown_content): logger.warning(f"フィルタリングでエラーが発生しました。リトライします。({retry_count}/{self.max_retries})\nError: {e}") if "429" in str(e): - sleep_time = 60 # 60秒スリープ + sleep_time = self.rate_limit_sleep # レート制限エラー時のスリープ時間をself.rate_limit_sleepから取得 else: - sleep_time = 10 # その他のエラーの場合は10秒スリープ + sleep_time = self.other_error_sleep # その他のエラー時のスリープ時間をself.other_error_sleepから取得 for _ in tqdm(range(sleep_time), desc="Sleeping", unit="s"): time.sleep(1) diff --git a/pegasus/cli.py b/pegasus/cli.py index 1c41460..3967257 100644 --- a/pegasus/cli.py +++ b/pegasus/cli.py @@ -17,7 +17,10 @@ def main(): parser.add_argument('--system-message', default=None, help='LiteLLMのシステムメッセージ(サイトの分類に使用)') parser.add_argument('--classification-prompt', default=None, help='LiteLLMのサイト分類プロンプト(TrueまたはFalseを返すようにしてください)') parser.add_argument('--max-retries', type=int, default=3, help='フィルタリングのリトライ回数の上限(デフォルト:3)') - + parser.add_argument('--model', default='gemini/gemini-1.5-pro-latest', help='LiteLLMのモデル名 (デフォルト: gemini/gemini-1.5-pro-latest)') + parser.add_argument('--rate-limit-sleep', type=int, default=60, help='レート制限エラー時のスリープ時間(秒) (デフォルト: 60)') + parser.add_argument('--other-error-sleep', type=int, default=10, help='その他のエラー時のスリープ時間(秒) (デフォルト: 10)') + args = parser.parse_args() pegasus = Pegasus(