diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index d7eff9d7c..b26fe133e 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -49,12 +49,15 @@ def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): Returns: The sum of the number of tokens in the system and user strings. """ - environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(system).render(vars) - user_prompt = environment.from_string(user).render(vars) - system_prompt_tokens = len(encoder.encode(system_prompt)) - user_prompt_tokens = len(encoder.encode(user_prompt)) - return system_prompt_tokens + user_prompt_tokens + try: + environment = Environment(undefined=StrictUndefined) + system_prompt = environment.from_string(system).render(vars) + user_prompt = environment.from_string(user).render(vars) + system_prompt_tokens = len(encoder.encode(system_prompt)) + user_prompt_tokens = len(encoder.encode(user_prompt)) + return system_prompt_tokens + user_prompt_tokens + except: + return -1 def count_tokens(self, patch: str) -> int: """ @@ -66,4 +69,4 @@ def count_tokens(self, patch: str) -> int: Returns: The number of tokens in the patch string. """ - return len(self.encoder.encode(patch, disallowed_special=())) \ No newline at end of file + return len(self.encoder.encode(patch, disallowed_special=()))