
refactor of repository classes
glorenzo972 committed Jun 15, 2024
1 parent b98f302 commit e210874
Showing 4 changed files with 255 additions and 8 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@
*Andrea Sponziello*
### **Copyright**: *Tiledesk SRL*

## [2024-06-15]
### 0.2.1
- update: langchain v. 0.1.16
- modified: prompt for Q&A

## [2024-06-08]
### 0.2.0
- refactor: refactor repository in order to manage pod and serverless
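For context on the "pod and serverless" refactor noted in the changelog above, the sketch below shows one way the repository hierarchy could be organised. Only the method names create_pc_index and get_embeddings_dimension are taken from the controller diff further down; the class names, the embedding-dimension lookup, and the pod/serverless split are illustrative assumptions, not the commit's actual implementation.

    # Illustrative sketch only -- not the commit's real repository classes.
    from abc import ABC, abstractmethod


    class PineconeRepositoryBase(ABC):
        """Behaviour shared by pod-based and serverless Pinecone indexes (assumed)."""

        def get_embeddings_dimension(self, embedding_model: str) -> int:
            # Hypothetical lookup; the real repository may resolve dimensions differently.
            dimensions = {"text-embedding-ada-002": 1536, "text-embedding-3-large": 3072}
            return dimensions.get(embedding_model, 1536)

        @abstractmethod
        async def create_pc_index(self, embeddings, emb_dimension: int):
            """Return a vector store bound to the configured Pinecone index."""


    class PineconeRepositoryPod(PineconeRepositoryBase):
        async def create_pc_index(self, embeddings, emb_dimension: int):
            # A pod-based index is provisioned against a fixed environment and pod type.
            raise NotImplementedError("sketch only")


    class PineconeRepositoryServerless(PineconeRepositoryBase):
        async def create_pc_index(self, embeddings, emb_dimension: int):
            # A serverless index is provisioned against a cloud/region spec instead.
            raise NotImplementedError("sketch only")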
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tilellm"
version = "0.2.0"
version = "0.2.1"
description = "tiledesk for RAG"
authors = ["Gianluca Lorenzo <gianluca.lorenzo@gmail.com>"]
repository = "https://github.com/Tiledesk/tiledesk-llm"
@@ -18,14 +18,14 @@ jsonschema= "^4.20.0"
redis= "^5.0.0"
aioredis= "^2.0.0"
#redismutex = "^1.0.0"
langchain = "^0.1.9"
langchain = "^0.1.16"
jq = "^1.6.0"
openai = "^1.12.0"
langchain_openai = "^0.0.7"
langchain_openai = "0.0.x"
pinecone-client = "^3.1.0"
python-dotenv = "^1.0.1"
langchain_community = "^0.0.24"
tiktoken = "^0.6.0"
langchain_community = "0.0.x"
tiktoken = "0.6.x"
beautifulsoup4 ="^4.12.3"
#uvicorn = "^0.28"
unstructured= "^0.12.6"
235 changes: 232 additions & 3 deletions tilellm/controller/openai_controller.py
@@ -1,5 +1,7 @@
import uuid

import fastapi
from langchain.chains import ConversationalRetrievalChain, LLMChain  # This class is used for the conversation
from langchain.chains import ConversationalRetrievalChain, LLMChain  # Deprecated
from langchain_core.prompts import PromptTemplate, SystemMessagePromptTemplate
from langchain_openai import ChatOpenAI
# from tilellm.store.pinecone_repository import add_pc_item as pinecone_add_item
@@ -8,15 +8,24 @@
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from tilellm.models.item_model import RetrievalResult, ChatEntry
from tilellm.shared.utility import inject_repo
import tilellm.shared.const as const
# from tilellm.store.pinecone_repository_base import PineconeRepositoryBase

from langchain.chains import create_history_aware_retriever
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory

import logging

logger = logging.getLogger(__name__)


@inject_repo
async def ask_with_memory(question_answer, repo=None):
async def ask_with_memory1(question_answer, repo=None):

try:
logger.info(question_answer)
@@ -59,6 +70,7 @@ async def ask_with_memory(question_answer, repo=None):
# pprint(len(mydocs))

if question_answer.system_context is not None and question_answer.system_context:
print("blocco if")
from langchain.chains import LLMChain

# prompt_template = "Tell me a {adjective} joke"
@@ -78,6 +90,7 @@ async def ask_with_memory(question_answer, repo=None):
llm=llm,
retriever=retriever,
return_source_documents=True,
verbose=True,
combine_docs_chain_kwargs={"prompt": sys_prompt}
)
# from pprint import pprint
@@ -90,15 +103,22 @@ async def ask_with_memory(question_answer, repo=None):
)

else:
print("blocco else")
#PromptTemplate.from_template()
crc = ConversationalRetrievalChain.from_llm(llm=llm,
retriever=retriever,
return_source_documents=True)
return_source_documents=True,
verbose=True)


# 'Use the following pieces of context to answer the user\'s question. If you don\'t know the answer, just say that you don\'t know, don\'t try to make up an answer.',
result = crc.invoke({'question': question_answer.question,
'chat_history': question_answer_list}
)

docs = result["source_documents"]
from pprint import pprint
pprint(result)

ids = []
sources = []
@@ -153,6 +173,206 @@ async def ask_with_memory(question_answer, repo=None):
raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump())


@inject_repo
async def ask_with_memory(question_answer, repo=None):
try:
logger.info(question_answer)
# question = str
# namespace: str
# gptkey: str
# model: str =Field(default="gpt-3.5-turbo")
# temperature: float = Field(default=0.0)
# top_k: int = Field(default=5)
# max_tokens: int = Field(default=128)
# system_context: Optional[str]
# chat_history_dict : Dict[str, ChatEntry]

question_answer_list = []
if question_answer.chat_history_dict is not None:
for key, entry in question_answer.chat_history_dict.items():
question_answer_list.append((entry.question, entry.answer))

logger.info(question_answer_list)
openai_callback_handler = OpenAICallbackHandler()

llm = ChatOpenAI(model_name=question_answer.model,
temperature=question_answer.temperature,
openai_api_key=question_answer.gptkey,
max_tokens=question_answer.max_tokens,
callbacks=[openai_callback_handler])

emb_dimension = repo.get_embeddings_dimension(question_answer.embedding)
oai_embeddings = OpenAIEmbeddings(api_key=question_answer.gptkey, model=question_answer.embedding)

vector_store = await repo.create_pc_index(oai_embeddings, emb_dimension)

retriever = vector_store.as_retriever(search_type='similarity',
search_kwargs={'k': question_answer.top_k,
'namespace': question_answer.namespace}
)

if question_answer.system_context is not None and question_answer.system_context:

# Contextualize question
contextualize_q_system_prompt = const.contextualize_q_system_prompt
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
history_aware_retriever = create_history_aware_retriever(
llm, retriever, contextualize_q_prompt
)

# Answer question
qa_system_prompt = question_answer.system_context
qa_prompt = ChatPromptTemplate.from_messages(
[
("system", qa_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)

question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
if session_id not in store:
store[session_id] = ChatMessageHistory()
return store[session_id]

conversational_rag_chain = RunnableWithMessageHistory(
rag_chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history",
output_messages_key="answer",
)

result = conversational_rag_chain.invoke(
{"input": question_answer.question, 'chat_history': question_answer_list},
config={"configurable": {"session_id": uuid.uuid4().hex}
}, # constructs a per-request session key (the uuid hex) in `store`.
)

else:
# Contextualize question
contextualize_q_system_prompt = const.contextualize_q_system_prompt
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
history_aware_retriever = create_history_aware_retriever(
llm, retriever, contextualize_q_prompt
)

# Answer question
qa_system_prompt = const.qa_system_prompt
qa_prompt = ChatPromptTemplate.from_messages(
[
("system", qa_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)

question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
if session_id not in store:
store[session_id] = ChatMessageHistory()
return store[session_id]

conversational_rag_chain = RunnableWithMessageHistory(
rag_chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history",
output_messages_key="answer",
)

result = conversational_rag_chain.invoke(
{"input": question_answer.question, 'chat_history': question_answer_list},
config={"configurable": {"session_id": uuid.uuid4().hex}
}, # constructs a per-request session key (the uuid hex) in `store`.
)

# print(store)
# print(question_answer_list)

docs = result["context"]
from pprint import pprint
pprint(docs)

ids = []
sources = []
for doc in docs:
ids.append(doc.metadata['id'])
sources.append(doc.metadata['source'])

ids = list(set(ids))
sources = list(set(sources))
source = " ".join(sources)
metadata_id = ids[0]

logger.info(result)
print(result['answer'])
result['answer'], success = verify_answer(result['answer'])

question_answer_list.append((result['input'], result['answer']))

chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list]
chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)}



# success = bool(openai_callback_handler.successful_requests)
prompt_token_size = openai_callback_handler.total_tokens

result_to_return = RetrievalResult(
answer=result['answer'],
namespace=question_answer.namespace,
sources=sources,
ids=ids,
source=source,
id=metadata_id,
prompt_token_size=prompt_token_size,
success=success,
error_message=None,
chat_history_dict=chat_history_dict
)

return result_to_return.dict()
except Exception as e:
import traceback
traceback.print_exc()
question_answer_list = []
if question_answer.chat_history_dict is not None:
for key, entry in question_answer.chat_history_dict.items():
question_answer_list.append((entry.question, entry.answer))
chat_entries = [ChatEntry(question=q, answer=a) for q, a in question_answer_list]
chat_history_dict = {str(i): entry for i, entry in enumerate(chat_entries)}

result_to_return = RetrievalResult(
namespace=question_answer.namespace,
error_message=repr(e),
chat_history_dict=chat_history_dict
)
raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump())

@inject_repo
async def ask_with_sequence(question_answer, repo=None):
try:
@@ -407,3 +627,12 @@ def get_idproduct_chain(llm) -> LLMChain:
)

return LLMChain(llm=llm, prompt=summary_prompt_template)


def verify_answer(s):
if s.endswith("<NOANS>"):
s = s[:-7]  # Removes <NOANS> from the end of the string
success = False
else:
success = True
return s, success
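
As a quick orientation aid, here is a minimal, hypothetical driver for the new ask_with_memory flow. The field names mirror the commented attribute list inside the function; the QuestionAnswer class name and import path are assumptions (only RetrievalResult and ChatEntry are imported in this diff), and the key and namespace values are placeholders.

    # Hypothetical usage sketch -- not part of the commit.
    import asyncio

    from tilellm.controller.openai_controller import ask_with_memory
    from tilellm.models.item_model import QuestionAnswer  # assumed class name


    async def main():
        payload = QuestionAnswer(
            question="Which plans does Tiledesk offer?",
            namespace="customer-kb",
            gptkey="sk-...",                      # placeholder OpenAI API key
            model="gpt-3.5-turbo",
            temperature=0.0,
            top_k=5,
            max_tokens=256,
            embedding="text-embedding-ada-002",
            system_context=None,                  # None -> falls back to const.qa_system_prompt
            chat_history_dict=None,
        )
        # repo is supplied by the @inject_repo decorator, so it is not passed here.
        result = await ask_with_memory(payload)
        print(result["answer"], result["success"])


    if __name__ == "__main__":
        asyncio.run(main())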
13 changes: 13 additions & 0 deletions tilellm/shared/const.py
@@ -8,6 +8,19 @@
PINECONE_INDEX = None
PINECONE_TEXT_KEY = None

contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""

qa_system_prompt = """You are an helpful assistant for question-answering tasks. \
Use ONLY the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
If none of the retrieved context answer the question, add this word to the end <NOANS> \
{context}"""


def populate_constant():
global PINECONE_API_KEY, PINECONE_INDEX, PINECONE_TEXT_KEY
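
The <NOANS> sentinel in qa_system_prompt above is consumed by verify_answer in openai_controller.py: when the model appends the token, the controller strips it and marks the response as unsuccessful. A small illustration (the answer strings are invented):

    from tilellm.controller.openai_controller import verify_answer

    answer, success = verify_answer("The starter plan includes 5 seats. <NOANS>")
    # answer == "The starter plan includes 5 seats. ", success is False -> context did not cover it
    answer, success = verify_answer("The starter plan includes 5 seats.")
    # success is True -> the model answered from the retrieved context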

