Skip to content

Commit

Permalink
Simplifying the indexing of action tokens (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Sep 24, 2024
1 parent 350edd4 commit 6ea52b8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
13 changes: 2 additions & 11 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from paperqa.docs import Docs
from paperqa.llms import EmbeddingModel, LiteLLMModel
from paperqa.settings import Settings
from paperqa.types import Answer, LLMResult
from paperqa.types import Answer
from paperqa.utils import get_year

from .models import QueryRequest
Expand Down Expand Up @@ -150,16 +150,7 @@ def export_frame(self) -> Frame:
async def step(
self, action: ToolRequestMessage
) -> tuple[list[Message], float, bool, bool]:

# add usage for action if it has usage
info = action.info
if info and "usage" in info and "model" in info:
r = LLMResult(
model=info["model"],
prompt_count=info["usage"][0],
completion_count=info["usage"][1],
)
self.state.answer.add_tokens(r)
self.state.answer.add_tokens(action) # Add usage for action if present

# If the action has empty tool_calls, the agent can later take that into account
msgs = cast(
Expand Down
13 changes: 11 additions & 2 deletions paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import litellm # for cost
import tiktoken
from aviary.message import Message
from pybtex.database import BibliographyData, Entry, Person
from pybtex.database.input.bibtex import Parser
from pybtex.scanner import PybtexSyntaxError
Expand Down Expand Up @@ -198,8 +199,16 @@ def get_citation(self, name: str) -> str:
raise ValueError(f"Could not find docname {name} in contexts.") from exc
return doc.citation

def add_tokens(self, result: LLMResult) -> None:
"""Update the token counts for the given result."""
def add_tokens(self, result: LLMResult | Message) -> None:
"""Update the token counts for the given LLM result or message."""
if isinstance(result, Message):
if not result.info or any(x not in result.info for x in ("model", "usage")):
return
result = LLMResult(
model=result.info["model"],
prompt_count=result.info["usage"][0],
completion_count=result.info["usage"][1],
)
if result.model not in self.token_counts:
self.token_counts[result.model] = [
result.prompt_count,
Expand Down

0 comments on commit 6ea52b8

Please sign in to comment.