Skip to content

Commit

Permalink
Merge feature/pegasus-scraping-filter-enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
Sunwood-ai-labs committed Jun 9, 2024
2 parents 560914f + 5d972a3 commit b1cd378
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
14 changes: 11 additions & 3 deletions pegasus/Pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
logger = loguru.logger

class Pegasus:
def __init__(self, output_dir, exclude_selectors=None, include_domain=None, exclude_keywords=None, output_extension=".md", dust_size=1000, max_depth=None, system_message=None, classification_prompt=None, max_retries=3):
def __init__(self, output_dir, exclude_selectors=None, include_domain=None, exclude_keywords=None, output_extension=".md",
dust_size=1000, max_depth=None, system_message=None, classification_prompt=None, max_retries=3,
model='gemini/gemini-1.5-pro-latest', rate_limit_sleep=60, other_error_sleep=10):
self.output_dir = output_dir
self.exclude_selectors = exclude_selectors
self.include_domain = include_domain
Expand All @@ -28,6 +30,9 @@ def __init__(self, output_dir, exclude_selectors=None, include_domain=None, excl
self.system_message = system_message
self.classification_prompt = classification_prompt
self.max_retries = max_retries
self.model = model
self.rate_limit_sleep = rate_limit_sleep
self.other_error_sleep = other_error_sleep
tprint(" Pegasus ", font="rnd-xlarge")
logger.info("初期化パラメータ:")
logger.info(f" output_dir: {output_dir}")
Expand All @@ -40,6 +45,9 @@ def __init__(self, output_dir, exclude_selectors=None, include_domain=None, excl
logger.info(f" system_message: {system_message}")
logger.info(f" classification_prompt: {classification_prompt}")
logger.info(f" max_retries: {max_retries}")
logger.info(f" model: {model}")
logger.info(f" rate_limit_sleep: {rate_limit_sleep}")
logger.info(f" other_error_sleep: {other_error_sleep}")

def filter_site(self, markdown_content):
if(self.classification_prompt is None):
Expand Down Expand Up @@ -69,9 +77,9 @@ def filter_site(self, markdown_content):
logger.warning(f"フィルタリングでエラーが発生しました。リトライします。({retry_count}/{self.max_retries}\nError: {e}")

if "429" in str(e):
sleep_time = 60 # 60秒スリープ
sleep_time = self.rate_limit_sleep # レート制限エラー時のスリープ時間をself.rate_limit_sleepから取得
else:
sleep_time = 10 # その他のエラーの場合は10秒スリープ
sleep_time = self.other_error_sleep # その他のエラー時のスリープ時間をself.other_error_sleepから取得

for _ in tqdm(range(sleep_time), desc="Sleeping", unit="s"):
time.sleep(1)
Expand Down
5 changes: 4 additions & 1 deletion pegasus/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ def main():
parser.add_argument('--system-message', default=None, help='LiteLLMのシステムメッセージ(サイトの分類に使用)')
parser.add_argument('--classification-prompt', default=None, help='LiteLLMのサイト分類プロンプト(TrueまたはFalseを返すようにしてください)')
parser.add_argument('--max-retries', type=int, default=3, help='フィルタリングのリトライ回数の上限(デフォルト:3)')

parser.add_argument('--model', default='gemini/gemini-1.5-pro-latest', help='LiteLLMのモデル名 (デフォルト: gemini/gemini-1.5-pro-latest)')
parser.add_argument('--rate-limit-sleep', type=int, default=60, help='レート制限エラー時のスリープ時間(秒) (デフォルト: 60)')
parser.add_argument('--other-error-sleep', type=int, default=10, help='その他のエラー時のスリープ時間(秒) (デフォルト: 10)')

args = parser.parse_args()

pegasus = Pegasus(
Expand Down

0 comments on commit b1cd378

Please sign in to comment.