Skip to content

Commit

Permalink
fix: citations
Browse files Browse the repository at this point in the history
  • Loading branch information
glorenzo972 committed Sep 4, 2024
1 parent d6eb2e9 commit 47db8f6
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
*Andrea Sponziello*
### **Copyright**: *Tiledesk SRL*

## [2024-09-04]
### 0.2.16
- fix: max_tokens=1024 if citations=True

## [2024-09-04]
### 0.2.15
- fix: citations without quote
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tilellm"
version = "0.2.15"
version = "0.2.16"
description = "tiledesk for RAG"
authors = ["Gianluca Lorenzo <gianluca.lorenzo@gmail.com>"]
repository = "https://github.com/Tiledesk/tiledesk-llm"
Expand Down
9 changes: 7 additions & 2 deletions tilellm/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory:
citations = result['answer'].citations
result['answer'], success = verify_answer(result['answer'].answer)


else:
conversational_rag_chain = RunnableWithMessageHistory(
rag_chain,
Expand All @@ -396,6 +397,7 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory:
result['answer'], success = verify_answer(result['answer'])
citations = None


docs = result["context"]
# from pprint import pprint
# pprint(docs)
Expand All @@ -417,8 +419,11 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory:
ids = list(set(ids))
sources = list(set(sources))

# source = " ".join(sources)
source = " ".join(set([cit.source_name for cit in citations]))
if question_answer.citations:
source = " ".join(set([cit.source_name for cit in citations]))
else:
source = " ".join(sources)

metadata_id = ids[0]

logger.info(f"input: {result['input']}")
Expand Down
13 changes: 11 additions & 2 deletions tilellm/models/item_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from pydantic import BaseModel, Field, field_validator, ValidationError, model_validator, RootModel
from pydantic import BaseModel, Field, field_validator, ValidationError, model_validator, RootModel, root_validator
from typing import Dict, Optional, List, Union, Any
import datetime




class ParametersScrapeType4(BaseModel):
unwanted_tags: Optional[List[str]] = Field(default_factory=list)
tags_to_extract: Optional[List[str]] = Field(default_factory=list)
Expand Down Expand Up @@ -108,6 +110,13 @@ def top_k_range(cls, v):
raise ValueError("top_k must be a positive integer.")
return v

@model_validator(mode='after')
def check_citations_max_tokens(self):
    """Raise ``max_tokens`` to at least 1024 when ``citations`` is enabled.

    Declared as an instance method because Pydantic v2 invokes
    ``mode='after'`` model validators with the already-validated model
    instance; a ``(cls, values)`` signature only works through Pydantic's
    implicit classmethod wrapping.

    Returns:
        The same model instance, with ``max_tokens`` bumped to 1024 if
        ``citations`` is truthy and the configured value was lower.
    """
    # Only ever raise the limit; a caller-supplied value >= 1024 is kept.
    if self.citations and self.max_tokens < 1024:
        self.max_tokens = 1024
    return self


class AWSAuthentication(BaseModel):
aws_access_key_id: str
Expand Down Expand Up @@ -180,7 +189,7 @@ class Citation(BaseModel):
)
source_name: str = Field(
...,
description="The Article Source (URL if available) of a SPECIFIC source which justifies the answer.",
description="The Article Source as URL (if available) of a SPECIFIC source which justifies the answer.",
)
#quote: str = Field(
# ...,
Expand Down

0 comments on commit 47db8f6

Please sign in to comment.