add: history on /api/ask
glorenzo972 committed Jul 27, 2024
1 parent 3803d70 commit 9767f10
Showing 5 changed files with 70 additions and 23 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@
 *Andrea Sponziello*
 ### **Copyright**: *Tiledesk SRL*

+## [2024-07-27]
+### 0.2.8
+- add: history on /api/ask
+
+
 ## [2024-07-26]
 ### 0.2.7
 - add: scrape_type=3|4
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tilellm"
-version = "0.2.7"
+version = "0.2.8"
 description = "tiledesk for RAG"
 authors = ["Gianluca Lorenzo <gianluca.lorenzo@gmail.com>"]
 repository = "https://github.com/Tiledesk/tiledesk-llm"
62 changes: 48 additions & 14 deletions tilellm/controller/controller.py
@@ -186,30 +186,64 @@ async def ask_with_memory1(question_answer, repo=None):
 async def ask_to_llm(question, chat_model=None):
     try:
         logger.info(question)
-        if question.llm == "cohere":
-            quest = question.system_context+" "+question.question
-            messages = [
-                HumanMessage(content=quest)
-            ]
-        else:
-            messages = [
-                SystemMessage(content=question.system_context),
-                HumanMessage(content=question.question)
-            ]
-
-        a = chat_model.invoke(messages)
-        return SimpleAnswer(content=a.content)
+        chat_history_list = []
+
+        if question.chat_history_dict is not None:
+            for key, entry in question.chat_history_dict.items():
+                chat_history_list.append(HumanMessage(content=entry.question))  # ('human', entry.question))
+                chat_history_list.append(AIMessage(content=entry.answer))
+
+        # from pprint import pprint
+        # pprint(chat_history_list)
+
+        qa_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", question.system_context),
+                MessagesPlaceholder("chat_history_a"),
+                ("human", "{input}"),
+            ]
+        )
+
+        store = {}
+
+        def get_session_history(session_id: str) -> BaseChatMessageHistory:
+            if session_id not in store:
+                store[session_id] = ChatMessageHistory()
+            return store[session_id]
+
+        runnable = qa_prompt | chat_model
+
+        runnable_with_history = RunnableWithMessageHistory(
+            runnable,
+            get_session_history,
+            input_messages_key="input",
+        )
+
+        result = await runnable_with_history.ainvoke(
+            {"input": question.question, "chat_history_a": chat_history_list},
+            config={"configurable": {"session_id": uuid.uuid4().hex}},
+        )
+
+        if not question.chat_history_dict:
+            question.chat_history_dict = {}
+
+        num = len(question.chat_history_dict.keys())
+        question.chat_history_dict[str(num)] = {"question": question.question, "answer": result.content}
+
+        return SimpleAnswer(answer=result.content, chat_history_dict=question.chat_history_dict)

     except Exception as e:
         import traceback
         traceback.print_exc()
         question_answer_list = []

-        result_to_return = SimpleAnswer(
-            content=repr(e)
-        )
+        result_to_return = SimpleAnswer(answer=repr(e),
+                                        chat_history_dict={})
         raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump())


 @inject_repo
 async def ask_with_memory(question_answer, repo=None) -> RetrievalResult:
     try:
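Note: the essence of this change is that prior turns arrive in `chat_history_dict`, are rebuilt as LangChain messages, and are spliced into the prompt through a `MessagesPlaceholder`. Below is a minimal, runnable sketch of that pattern; `FakeListChatModel` is a stand-in for the injected provider model, and the sample history is invented for illustration. The commit additionally wraps the chain in `RunnableWithMessageHistory` with a fresh `uuid` session id per call, so the in-memory `store` starts empty on every request.

```python
# Sketch of the history pattern used by the new ask_to_llm (assumption:
# langchain-core style imports; FakeListChatModel replaces the real model).
import asyncio

from langchain_core.language_models.fake_chat_models import FakeListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


async def main():
    # Shape of QuestionToLLM.chat_history_dict: turn index -> question/answer.
    chat_history_dict = {"0": {"question": "Hi", "answer": "Hello! How can I help?"}}

    # Rebuild the conversation as alternating Human/AI messages.
    chat_history_list = []
    for _, entry in chat_history_dict.items():
        chat_history_list.append(HumanMessage(content=entry["question"]))
        chat_history_list.append(AIMessage(content=entry["answer"]))

    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful AI bot."),
        MessagesPlaceholder("chat_history_a"),  # prior turns land here
        ("human", "{input}"),
    ])

    chain = qa_prompt | FakeListChatModel(responses=["Nice to meet you."])
    result = await chain.ainvoke(
        {"input": "What did I just say?", "chat_history_a": chat_history_list}
    )
    print(result.content)


asyncio.run(main())
```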
4 changes: 3 additions & 1 deletion tilellm/models/item_model.py
@@ -123,6 +123,7 @@ class QuestionToLLM(BaseModel):
     max_tokens: int = Field(default=128)
     debug: bool = Field(default_factory=lambda: False)
     system_context: str = Field(default="You are a helpful AI bot. Always reply in the same language of the question.")
+    chat_history_dict: Optional[Dict[str, ChatEntry]] = None

     @field_validator("temperature")
     def temperature_range(cls, v):
@@ -140,7 +141,8 @@ def max_tokens_range(cls, v):


 class SimpleAnswer(BaseModel):
-    content: str
+    answer: str = Field(default="No answer")
+    chat_history_dict: Optional[Dict[str, ChatEntry]]


 class RetrievalResult(BaseModel):
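Note: with these fields, each response carries the updated history, and the client can echo it back on the next request. A minimal sketch of the round-trip follows (pydantic v2, matching the repo's `model_dump` usage); `ChatEntry`'s fields are inferred from `entry.question` / `entry.answer` in controller.py, not shown in this diff.

```python
# Sketch of the models touched by this commit; ChatEntry's shape is an
# inference from how controller.py reads it.
from typing import Dict, Optional

from pydantic import BaseModel, Field


class ChatEntry(BaseModel):
    question: str
    answer: str


class SimpleAnswer(BaseModel):
    answer: str = Field(default="No answer")
    chat_history_dict: Optional[Dict[str, ChatEntry]]  # required, may be None


# A response's chat_history_dict is sent back with the next /api/ask call,
# keyed by turn index ("0", "1", ...), which is how conversation state
# round-trips without server-side storage.
reply = SimpleAnswer(
    answer="Hello! How can I help?",
    chat_history_dict={"0": ChatEntry(question="Hi", answer="Hello! How can I help?")},
)
print(reply.model_dump())
```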
20 changes: 13 additions & 7 deletions tilellm/shared/utility.py
@@ -87,7 +87,7 @@ async def wrapper(self, item, *args, **kwargs):
 def inject_llm(func):
     @wraps(func)
     async def wrapper(question, *args, **kwargs):
-        print(question)
+        logger.debug(question)
         if question.llm == "openai":
             chat_model = ChatOpenAI(api_key=question.llm_key,
                                     model=question.model,
@@ -131,16 +131,21 @@ async def wrapper(question, *args, **kwargs):
             #                     region_name="eu-central-1"
             #                     )

-            import boto3
-            #session = boto3.Session(
-            #    aws_access_key_id=question.llm_key.aws_secret_access_key,
+            # import boto3
+
+            # client_br = boto3.client('bedrock-runtime',
+            #                          aws_access_key_id=question.llm_key.aws_secret_access_key,
+            #                          aws_secret_access_key=question.llm_key.aws_secret_access_key,
+            #                          region_name=question.llm_key.region_name
+            #                          )
+            # session = boto3.Session(aws_access_key_id=question.llm_key.aws_secret_access_key,
             #                         aws_secret_access_key=question.llm_key.aws_secret_access_key,
             #                         region_name=question.llm_key.region_name
             #                         )

             # client_ss = session.client("bedrock-runtime")

             chat_model = ChatBedrockConverse(
+                # client=client_br,
                 model=question.model,
                 temperature=question.temperature,
                 max_tokens=question.max_tokens,
@@ -150,7 +155,8 @@ async def wrapper(question, *args, **kwargs):

             )  # model_kwargs={"temperature": 0.001},

-            # print(chat_model.session)
+            # print(chat_model.client._get_credentials().access_key)
+

         else:
             chat_model = ChatOpenAI(api_key=question.llm_key,
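Note: `inject_llm` is a decorator that builds the provider-specific chat model from the request and hands it to the endpoint as `chat_model`, which is how `ask_to_llm` above receives its model. A condensed, runnable sketch of that pattern follows; `FakeListChatModel` stands in for `ChatOpenAI` / `ChatBedrockConverse`, which need real credentials, and the branch set is abridged.

```python
# Sketch of the inject_llm pattern: choose a chat model from question.llm and
# inject it as the chat_model keyword argument.
import asyncio
import logging
from dataclasses import dataclass
from functools import wraps

from langchain_core.language_models.fake_chat_models import FakeListChatModel

logger = logging.getLogger(__name__)


def inject_llm(func):
    @wraps(func)
    async def wrapper(question, *args, **kwargs):
        logger.debug(question)  # the commit swaps print(question) for this
        if question.llm == "openai":
            chat_model = FakeListChatModel(responses=["openai branch"])
        else:  # the real decorator also handles cohere, aws/Bedrock, etc.
            chat_model = FakeListChatModel(responses=["default branch"])
        return await func(question, *args, chat_model=chat_model, **kwargs)
    return wrapper


@dataclass
class Question:
    llm: str
    question: str


@inject_llm
async def ask_to_llm(question, chat_model=None):
    return chat_model.invoke(question.question).content


print(asyncio.run(ask_to_llm(Question(llm="openai", question="Hi"))))
```

Keeping model construction in a decorator means each endpoint body stays provider-agnostic; adding a new backend only touches `inject_llm`.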
