diff --git a/CHANGELOG.md b/CHANGELOG.md index d45073d..fb0137b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ *Andrea Sponziello* ### **Copyrigth**: *Tiledesk SRL* +## [2024-09-04] +### 0.2.16 +- fix: max_tokens=1024 if citations=True + ## [2024-09-04] ### 0.2.15 - fix: citations without quote diff --git a/pyproject.toml b/pyproject.toml index 9141dc1..52ad5c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tilellm" -version = "0.2.15" +version = "0.2.16" description = "tiledesk for RAG" authors = ["Gianluca Lorenzo "] repository = "https://github.com/Tiledesk/tiledesk-llm" diff --git a/tilellm/controller/controller.py b/tilellm/controller/controller.py index 37b4e02..5b17fff 100644 --- a/tilellm/controller/controller.py +++ b/tilellm/controller/controller.py @@ -379,6 +379,7 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory: citations = result['answer'].citations result['answer'], success = verify_answer(result['answer'].answer) + else: conversational_rag_chain = RunnableWithMessageHistory( rag_chain, @@ -396,6 +397,7 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory: result['answer'], success = verify_answer(result['answer']) citations = None + docs = result["context"] # from pprint import pprint # pprint(docs) @@ -417,8 +419,11 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory: ids = list(set(ids)) sources = list(set(sources)) - # source = " ".join(sources) - source = " ".join(set([cit.source_name for cit in citations])) + if question_answer.citations: + source = " ".join(set([cit.source_name for cit in citations])) + else: + source = " ".join(sources) + metadata_id = ids[0] logger.info(f"input: {result['input']}") diff --git a/tilellm/models/item_model.py b/tilellm/models/item_model.py index e8473d1..60a98f5 100644 --- a/tilellm/models/item_model.py +++ b/tilellm/models/item_model.py @@ -1,8 +1,10 @@ -from pydantic import BaseModel, Field, field_validator, ValidationError, model_validator, RootModel +from pydantic import BaseModel, Field, field_validator, ValidationError, model_validator, RootModel, root_validator from typing import Dict, Optional, List, Union, Any import datetime + + class ParametersScrapeType4(BaseModel): unwanted_tags: Optional[List[str]] = Field(default_factory=list) tags_to_extract: Optional[List[str]] = Field(default_factory=list) @@ -108,6 +110,13 @@ def top_k_range(cls, v): raise ValueError("top_k must be a positive integer.") return v + @model_validator(mode='after') + def check_citations_max_tokens(cls, values): + """Sets max_tokens to at least 1024 if citations=True.""" + if values.citations and values.max_tokens < 1024: + values.max_tokens = 1024 + return values + class AWSAuthentication(BaseModel): aws_access_key_id: str @@ -180,7 +189,7 @@ class Citation(BaseModel): ) source_name: str = Field( ..., - description="The Article Source (URL if available) of a SPECIFIC source which justifies the answer.", + description="The Article Source as URL (if available) of a SPECIFIC source which justifies the answer.", ) #quote: str = Field( # ...,