From aaeaa94a02f78bbd7e4e518f7345e7e049ef28ab Mon Sep 17 00:00:00 2001
From: glorenzo972
Date: Sat, 31 Aug 2024 16:53:36 +0200
Subject: [PATCH] add: citations

---
 CHANGELOG.md                          |   4 +
 pyproject.toml                        |   4 +-
 tilellm/__main__.py                   |  22 ++-
 tilellm/agents/shopify_agent.py       | 130 ++++++++++++++++
 tilellm/controller/controller.py      | 216 +++++++++++++++++---------
 tilellm/models/item_model.py          |  56 ++++++-
 tilellm/shared/const.py               | 119 +++++++++++++-
 tilellm/tools/document_tool_simple.py |   2 +-
 tilellm/tools/shopify_tool.py         |  74 +++++++++
 9 files changed, 549 insertions(+), 78 deletions(-)
 create mode 100644 tilellm/agents/shopify_agent.py
 create mode 100644 tilellm/tools/shopify_tool.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 07945eb..1680bc9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,10 @@
 *Andrea Sponziello*
 ### **Copyright**: *Tiledesk SRL*
 
+## [2024-08-31]
+### 0.2.12
+- add: citations
+
 ## [2024-07-31]
 ### 0.2.11
 - fix: log
diff --git a/pyproject.toml b/pyproject.toml
index 89b01d2..682fea7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tilellm"
-version = "0.2.11"
+version = "0.2.12"
 description = "tiledesk for RAG"
 authors = ["Gianluca Lorenzo "]
 repository = "https://github.com/Tiledesk/tiledesk-llm"
@@ -41,6 +41,8 @@ docx2txt= "0.8"
 wikipedia= "1.4.0"
 html2text= "2024.2.26"
 psutil= "6.0.0"
+httpx= "0.27.0"
+gql= "3.5.0"
 
 
 [tool.poetry.dependencies.uvicorn]
diff --git a/tilellm/__main__.py b/tilellm/__main__.py
index f1a1260..1bbfc8d 100644
--- a/tilellm/__main__.py
+++ b/tilellm/__main__.py
@@ -22,7 +22,8 @@
                                        ScrapeStatusReq, ScrapeStatusResponse,
                                        PineconeIndexingResult, RetrievalResult, PineconeNamespaceResult,
-                                       PineconeDescNamespaceResult, PineconeItems, QuestionToLLM, SimpleAnswer)
+                                       PineconeDescNamespaceResult, PineconeItems, QuestionToLLM, SimpleAnswer,
+                                       QuestionToAgent)
 from tilellm.store.redis_repository import redis_xgroup_create
 from tilellm.controller.controller import (ask_with_memory,
@@ -36,7 +37,7 @@
                                            get_desc_namespace,
                                            get_list_namespace,
                                            get_sources_namespace,
-                                           ask_to_llm)
+                                           ask_to_llm, ask_to_agent)
 
 import logging
@@ -350,6 +351,23 @@ async def post_ask_with_memory_main(question_answer: QuestionAnswer):
     return JSONResponse(content=result.model_dump())
 
 
+@app.post("/api/agent", response_model=SimpleAnswer)
+async def post_ask_to_agent_main(question_to_agent: QuestionToAgent):
+    """
+    Query and answer with chat history
+    :param question_to_agent:
+    :return: SimpleAnswer
+    """
+    print(question_to_agent)
+    logger.debug(question_to_agent)
+
+    result = await ask_to_agent(question_to_agent)
+
+    logger.debug(result)
+    return JSONResponse(content=result.model_dump())
+
+
+
 @app.post("/api/ask", response_model=SimpleAnswer)
 async def post_ask_to_llm_main(question: QuestionToLLM):
     """
diff --git a/tilellm/agents/shopify_agent.py b/tilellm/agents/shopify_agent.py
new file mode 100644
index 0000000..c89cce4
--- /dev/null
+++ b/tilellm/agents/shopify_agent.py
@@ -0,0 +1,130 @@
+from langchain_core.prompts import PromptTemplate
+
+
+
+from langchain.agents import create_react_agent, AgentExecutor
+from langchain_core.tools import Tool
+from langchain_openai import ChatOpenAI
+
+from tilellm.tools.shopify_tool import get_graphql_answer
+from tilellm.shared.const import react_prompt_template
+from functools import partial
+
+
+def lookup(question_to_agent, chat_model, chat_history: str):
+    # You are an API agent.
+    template1 = """
+    Given the question {question}, I want you to find the answer.
+    The first step is to create a GraphQL query
+    for the Shopify Admin client that answers the question, then pass the query to the tool in GraphQL format.
+    (USE ONLY GraphQL format. No comments are needed! Use the parameters in the same language as the question.)
+    Use this schema {schema} for GraphQL and do not exceed 10 items.
+    Examples of GraphQL queries:
+    - query {{ products(first: 10) {{ edges {{ node {{ id title price }} }} }} }}
+    - query {{ products(first: 10) {{ edges {{ node {{ id title }} }} }} }}
+    - query {{ products(first: 10, query: "price:<50") {{ edges {{ node {{ title variants(first: 1) {{ edges {{ node {{ price }} }} }} }} }} }} }}
+
+
+    In your final answer, use the same language as the question; the response should interpret and summarize the key information from the query result
+    in a clear and concise manner. If no products answer the question, simply say that there are no products.
+    """
+    template = """
+    Follow these instructions exactly to answer the question: {question}
+
+    1. Create a GraphQL query for the Shopify Admin client that answers the question.
+       - Use ONLY GraphQL format, no comments.
+       - Use parameters in the same language as the question.
+       - Use this schema: {schema}
+       - Limit results to a maximum of 10 items.
+
+    2. Query format:
+       query {{
+         # your fields here
+       }}
+
+    3. Examples of valid queries:
+       - query {{ products(first: 10) {{ edges {{ node {{ id title price }} }} }} }}
+       - query {{ products(first: 10, query: "price:<50") {{ edges {{ node {{ title variants(first: 1) {{ edges {{ node {{ price }} }} }} }} }} }} }}
+
+    4. Present ONLY the GraphQL query, nothing else.
+
+    5. After receiving the query results, provide the final answer:
+       - Use the same language as the original question.
+       - Interpret and summarize key information from the query results.
+       - Be clear and concise.
+       - If there are no products that answer the question, state this explicitly.
+
+    Remember: Follow these instructions to the letter. Do not add explanations or comments that are not requested.
+    """
+    question = question_to_agent.question
+
+    # Default to None so a missing 'shopify' tool fails clearly instead of raising NameError
+    api_key = None
+    url = None
+    for tool in question_to_agent.tools:
+        if 'shopify' in tool:
+            shopify_tool = tool['shopify']
+
+            # Safely access the root dictionary
+            api_key = shopify_tool.root.get('api_key')
+            url = shopify_tool.root.get('url')
+            break  # Exit the loop once 'shopify' is found
+
+    get_graphql_answer_with_key = partial(get_graphql_answer, url=url, api_key=api_key)
+    tools_for_agent_shopify = [
+        Tool(
+            name="Retrieve content from Shopify given GraphQL query",
+            func=get_graphql_answer_with_key,
+            description="Useful when you need to retrieve information from Shopify."
+        ),
+    ]
+
+    schema = """
+    products(first: 10, query: "") {
+        edges {
+          node {
+            title
+            variants(first: 10) {
+              edges {
+                node {
+                  price
+                  availableForSale
+                  barcode
+                  displayName
+                  id
+                }
+              }
+            }
+            bodyHtml
+            descriptionHtml
+            id
+            productType
+            tags
+            totalInventory
+          }
+        }
+    }
+    """
+    prompt_template = PromptTemplate(
+        input_variables=["question", "schema"], template=template
+    )
+
+    # react_prompt = hub.pull("hwchase17/react")
+    react_prompt = PromptTemplate.from_template(react_prompt_template)
+
+    agent = create_react_agent(
+        llm=chat_model, tools=tools_for_agent_shopify, prompt=react_prompt
+    )
+
+    agent_executor = AgentExecutor(
+        agent=agent,
+        tools=tools_for_agent_shopify,
+        verbose=True,
+        max_iterations=4,
+        early_stopping_method="force",
+        handle_parsing_errors=True
+    )
+
+    result = agent_executor.invoke(
+        input={"input": prompt_template.format_prompt(question=question, schema=schema),
+               "chat_history": chat_history}
+    )
+
+    return result
diff --git a/tilellm/controller/controller.py b/tilellm/controller/controller.py
index 03c583d..5084e22 100644
--- a/tilellm/controller/controller.py
+++ b/tilellm/controller/controller.py
@@ -1,18 +1,23 @@
 import uuid
+from typing import List
 
 import fastapi
 from langchain.chains import ConversationalRetrievalChain, LLMChain  # Deprecated
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import DocumentCompressorPipeline
 from langchain_community.document_transformers import EmbeddingsRedundantFilter
+from langchain_core.documents import Document
 from langchain_core.prompts import PromptTemplate, SystemMessagePromptTemplate
+from langchain_core.runnables import RunnablePassthrough
 from langchain_openai import ChatOpenAI
 
 # from tilellm.store.pinecone_repository import add_pc_item as pinecone_add_item
 # from tilellm.store.pinecone_repository import create_pc_index, get_embeddings_dimension
 from langchain_openai import OpenAIEmbeddings
 from langchain_community.callbacks.openai_info import OpenAICallbackHandler
+from pydantic.v1 import BaseModel, Field
+
 from tilellm.models.item_model import RetrievalResult, ChatEntry, PineconeIndexingResult, PineconeNamespaceResult, \
-    PineconeDescNamespaceResult, PineconeItems, SimpleAnswer
+    PineconeDescNamespaceResult, PineconeItems, SimpleAnswer, QuotedAnswer
 from tilellm.shared.utility import inject_repo, inject_llm
 import tilellm.shared.const as const
 # from tilellm.store.pinecone_repository_base import PineconeRepositoryBase
@@ -25,6 +30,10 @@
 from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain_core.chat_history import BaseChatMessageHistory
 
+from langchain_community.agent_toolkits.load_tools import load_tools
+from langchain.agents import AgentType, initialize_agent
+from tilellm.agents.shopify_agent import lookup as shopify_lookup_agent
+
 from langchain.schema import( AIMessage, HumanMessage,
@@ -295,44 +304,60 @@ async def ask_with_memory(question_answer, repo=None) -> RetrievalResult:
             base_compressor=pipeline_compressor, base_retriever=vs_retriever
         )
 
+        # Contextualize question
+        contextualize_q_system_prompt = const.contextualize_q_system_prompt
+        contextualize_q_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", contextualize_q_system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
+            ]
+        )
+        history_aware_retriever = create_history_aware_retriever(
+            llm, retriever, contextualize_q_prompt
+        )
+
         if question_answer.system_context is not None and question_answer.system_context:
+            # Answer question - prompt from user
+            qa_system_prompt = question_answer.system_context
+        else:
+            # Answer question - default prompt
+            qa_system_prompt = const.qa_system_prompt
 
-            # Contextualize question
-            contextualize_q_system_prompt = const.contextualize_q_system_prompt
-            contextualize_q_prompt = ChatPromptTemplate.from_messages(
-                [
-                    ("system", contextualize_q_system_prompt),
-                    MessagesPlaceholder("chat_history"),
-                    ("human", "{input}"),
-                ]
-            )
-            history_aware_retriever = create_history_aware_retriever(
-                llm, retriever, contextualize_q_prompt
-            )
+        qa_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", qa_system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
+            ]
+        )
 
-            # Answer question
-            qa_system_prompt = question_answer.system_context
-            qa_prompt = ChatPromptTemplate.from_messages(
-                [
-                    ("system", qa_system_prompt),
-                    MessagesPlaceholder("chat_history"),
-                    ("human", "{input}"),
-                ]
-            )
+        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+        store = {}
 
-            question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+        def get_session_history(session_id: str) -> BaseChatMessageHistory:
+            if session_id not in store:
+                store[session_id] = load_session_history(question_answer.chat_history_dict)
+            return store[session_id]
 
-            rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-            store = {}
+        if question_answer.citations:
+            rag_chain_from_docs = (
+                RunnablePassthrough.assign(context=(lambda x: format_docs_with_id(x["context"])))
+                | qa_prompt
+                | llm.with_structured_output(QuotedAnswer)
+            )
 
-            def get_session_history(session_id: str) -> BaseChatMessageHistory:
-                if session_id not in store:
-                    store[session_id] = load_session_history(question_answer.chat_history_dict)
-                return store[session_id]
+            retrieve_docs = (lambda x: x["input"]) | retriever
 
+            chain_w_citations = RunnablePassthrough.assign(context=retrieve_docs).assign(
+                answer=rag_chain_from_docs
+            )
 
             conversational_rag_chain = RunnableWithMessageHistory(
-                rag_chain,
+                chain_w_citations,
                 get_session_history,
                 input_messages_key="input",
                 history_messages_key="chat_history",
@@ -340,47 +365,17 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory:
             )
 
             result = conversational_rag_chain.invoke(
-                {"input": question_answer.question, },  # 'chat_history': chat_history_list},
+                {"input": question_answer.question},  # 'chat_history': chat_history_list},
                 config={"configurable": {"session_id": uuid.uuid4().hex}
-                        },  # constructs a key "abc123" in `store`.
+                        }  # constructs a per-request session key in `store`.
             )
-        else:
-            # Contextualize question
-            contextualize_q_system_prompt = const.contextualize_q_system_prompt
-            contextualize_q_prompt = ChatPromptTemplate.from_messages(
-                [
-                    ("system", contextualize_q_system_prompt),
-                    MessagesPlaceholder("chat_history"),
-                    ("human", "{input}"),
-                ]
-            )
-            # logger.info(contextualize_q_prompt)
-            history_aware_retriever = create_history_aware_retriever(
-                llm, retriever, contextualize_q_prompt
-            )
-
-            # Answer question
-            qa_system_prompt = const.qa_system_prompt
-            qa_prompt = ChatPromptTemplate.from_messages(
-                [
-                    ("system", qa_system_prompt),
-                    MessagesPlaceholder("chat_history"),
-                    ("human", "{input}"),
-                ]
-            )
-
-            question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-
-            rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-
-            store = {}
-
-            def get_session_history(session_id: str) -> BaseChatMessageHistory:
-                if session_id not in store:
-                    store[session_id] = load_session_history(question_answer.chat_history_dict)
-                return store[session_id]
+            # from pprint import pprint
+            # pprint(result["answer"])
+            citations = result['answer'].citations
+            result['answer'], success = verify_answer(result['answer'].answer)
 
+        else:
             conversational_rag_chain = RunnableWithMessageHistory(
                 rag_chain,
                 get_session_history,
@@ -390,10 +385,12 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory:
             )
 
             result = conversational_rag_chain.invoke(
-                {"input": question_answer.question},  # 'chat_history': chat_history_list},
+                {"input": question_answer.question},  # 'chat_history': chat_history_list},
                 config={"configurable": {"session_id": uuid.uuid4().hex}
-                        },  # constructs a key "abc123" in `store`.
+                        }  # constructs a per-request session key in `store`.
             )
+            result['answer'], success = verify_answer(result['answer'])
+            citations = None
 
         docs = result["context"]
 
         # from pprint import pprint
@@ -422,7 +419,7 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory:
             logger.info(f"chat_history: {result['chat_history']}")
             logger.info(f"answer: {result['answer']}")
 
-        result['answer'], success = verify_answer(result['answer'])
+
         question_answer_list.append((result['input'], result['answer']))
@@ -439,6 +436,7 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory:
             ids=ids,
             source=source,
             id=metadata_id,
+            citations=citations,
             prompt_token_size=prompt_token_size,
             content_chunks=content_chunks,
             success=success,
@@ -464,6 +462,75 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory:
         )
         raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump())
 
+@inject_llm
+async def ask_to_agent(question_to_agent, chat_model=None):
+    try:
+        logger.info(question_to_agent)
+        # chat_history_list = []
+        # tools = load_tools(
+        #     question_to_agent.tools,
+        #     endpoint="https://swapi-graphql.netlify.app/.netlify/functions/index",
+        #     api_key="ciccio"
+        # )
+        # agent = initialize_agent(
+        #     tools, chat_model, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+        # )
+        # result = agent.invoke(question_to_agent.question)
+        result_history = ""
+        if question_to_agent.chat_history_dict is not None:
+            # for key, entry in question_to_agent.chat_history_dict.items():
+            #     chat_history_list.append(HumanMessage(content=entry.question))  # ('human', entry.question))
+            #     chat_history_list.append(AIMessage(content=entry.answer))
+            # "chat_history": "Human: My name is Bob\nAI: Hello Bob!",
 
+            for i in range(len(question_to_agent.chat_history_dict)):
+                entry = question_to_agent.chat_history_dict[str(i)]
result_history += f"Human: {entry.question}\n" + result_history += f"AI: {entry.answer}\n" + + result_history = result_history.strip() # Remove trailing newline + print(result_history) + + + #qa_prompt = ChatPromptTemplate.from_messages( + # [ + # ("system", question_to_agent.system_context), + # MessagesPlaceholder("tools_result"), + # MessagesPlaceholder("chat_history", n_messages=question_to_agent.n_messages), + # ("human", "{input}"), + # ] + #) + + #store = {} + + #def get_session_history(session_id: str) -> BaseChatMessageHistory: + # if session_id not in store: + # store[session_id] = load_session_history(question_to_agent.chat_history_dict) # ChatMessageHistory() + # return store[session_id] + + result_shopify = shopify_lookup_agent(question_to_agent=question_to_agent, chat_model=chat_model, chat_history=result_history) + print(f"RESULT: {result_shopify.get('output')} type: {type(result_shopify.get('output'))}") + + + if not question_to_agent.chat_history_dict: + question_to_agent.chat_history_dict = {} + + num = len(question_to_agent.chat_history_dict.keys()) + question_to_agent.chat_history_dict[str(num)] = dict({"question": question_to_agent.question, "answer": result_shopify.get("output")}) + + answer_to_agent = SimpleAnswer(answer=result_shopify.get("output"), chat_history_dict=question_to_agent.chat_history_dict) + print(answer_to_agent) + return answer_to_agent + + except Exception as e: + import traceback + traceback.print_exc() + question_answer_list = [] + + result_to_return = SimpleAnswer(answer=repr(e), + chat_history_dict={}) + raise fastapi.exceptions.HTTPException(status_code=400, detail=result_to_return.model_dump()) + @inject_repo async def ask_with_sequence(question_answer, repo=None) -> RetrievalResult: @@ -760,3 +827,12 @@ def load_session_history(history) -> BaseChatMessageHistory: chat_history.add_message(HumanMessage(content=entry.question)) # ('human', entry.question)) chat_history.add_message(AIMessage(content=entry.answer)) return chat_history + + +def format_docs_with_id(docs: List[Document]) -> str: + formatted = [ + f"Source ID: {i}\nArticle Source: {doc.metadata['source']}\nArticle Snippet: {doc.page_content}" + for i, doc in enumerate(docs) + ] + return "\n\n" + "\n\n".join(formatted) + diff --git a/tilellm/models/item_model.py b/tilellm/models/item_model.py index 8deb78a..6ac4563 100644 --- a/tilellm/models/item_model.py +++ b/tilellm/models/item_model.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel, Field, field_validator, ValidationError, model_validator -from typing import Dict, Optional, List, Union +from pydantic import BaseModel, Field, field_validator, ValidationError, model_validator, RootModel +from typing import Dict, Optional, List, Union, Any import datetime @@ -89,6 +89,7 @@ class QuestionAnswer(BaseModel): embedding: str = Field(default_factory=lambda: "text-embedding-ada-002") similarity_threshold: float = Field(default_factory=lambda: 1.0) debug: bool = Field(default_factory=lambda: False) + citations: bool = Field(default_factory=lambda: False) system_context: Optional[str] = None search_type: str = Field(default_factory=lambda: "similarity") chat_history_dict: Optional[Dict[str, ChatEntry]] = None @@ -148,6 +149,56 @@ def max_tokens_range(cls, v): return v +class ToolOptions(RootModel[Dict[str, Any]]): + #__root__: Dict[str, Any] = Field(default_factory=dict) + pass + + +class QuestionToAgent(BaseModel): + question: str + llm_key: Union[str, AWSAuthentication] + llm: str + model: str = Field(default="gpt-3.5-turbo") + tools: 
Optional[List[Dict[str, ToolOptions]]] = Field(default_factory=dict) + system_context: str = Field(default="You are a helpful AI bot. Always reply in the same language of the question.") + temperature: float = Field(default=0.0) + max_tokens: int = Field(default=128) + chat_history_dict: Optional[Dict[str, ChatEntry]] = None + n_messages: int = Field(default_factory=lambda: None) + + @field_validator("n_messages") + def n_messages_range(cls, v): + """Ensures n_messages is within greater than 0""" + if not v > 0: + raise ValueError("n_messages must be greater than 0") + return v + +class Citation(BaseModel): + source_id: int = Field( + ..., + description="The integer ID of a SPECIFIC source which justifies the answer.", + ) + source_name: str = Field( + ..., + description="The Article Source of a SPECIFIC source which justifies the answer.", + ) + quote: str = Field( + ..., + description="The VERBATIM quote from the specified source that justifies the answer.", + ) + +class QuotedAnswer(BaseModel): + """Answer the user question based only on the given sources, and cite the sources used.""" + + answer: str = Field( + ..., + description="The answer to the user question, which is based only on the given sources.", + ) + citations: List[Citation] = Field( + ..., description="Citations from the given sources that justify the answer." + ) + + class SimpleAnswer(BaseModel): answer: str = Field(default="No answer") chat_history_dict: Optional[Dict[str, ChatEntry]] @@ -161,6 +212,7 @@ class RetrievalResult(BaseModel): ids: Optional[List[str]] | None = None source: str | None = None sources: Optional[List[str]] | None = None + citations: Optional[List[Citation]] | None = None content_chunks: Optional[List[str]] | None = None prompt_token_size: int = Field(default=0) error_message: Optional[str] | None = None diff --git a/tilellm/shared/const.py b/tilellm/shared/const.py index 25db834..4c7dca6 100644 --- a/tilellm/shared/const.py +++ b/tilellm/shared/const.py @@ -19,8 +19,7 @@ The first step is to extrac relevant information to the question from retrieved context. If you don't know the answer, just say that you don't know. \ Respond with "No relevant information were found " if no relevant information were found. - - + #### @@ -28,6 +27,122 @@ #### """ +qa_system_prompt1 = """You are an AI assistant specialized in question-answering tasks. \ + Your goal is to provide accurate and helpful answers based solely on the given context. \ + Follow these instructions carefully: + + 1. You will be provided with a context delimited by #### tags. \ + This context contains the only information you should use to answer the question. \ + Do not use any external knowledge or information not present in the given context. + + 2. Here is the context you must use: + + #### + {CONTEXT} + #### + + + 3. Now, here is the question you need to answer: + + {QUESTION} + + + 4. To answer the question, follow these steps: + a. Carefully read through the context and extract all information relevant to the question. \ + If you find relevant information, proceed to step b. \ + If you don't find any relevant information, skip to step c. + + b. Using only the relevant information you extracted, formulate a clear and concise answer \ + to the question. Make sure your answer is directly based on the context provided and does not\ + include any external knowledge or assumptions. + + c. 
+    respond with exactly this phrase: "No relevant information were found "
+
+    5. Present your answer in the following format:
+
+    [Your answer goes here. If you found relevant information, provide your answer based on the \
+    context. If no relevant information was found, write the phrase specified in step 4c.]
+
+    Remember, if you're unsure or don't have enough information to answer the question accurately, \
+    it's better to admit that you don't know rather than making guesses or using information not \
+    provided in the context.
+"""
+
+qa_system_prompt_citations = """
+You are an AI assistant tasked with answering questions based on a given document while providing citations for the information used. Follow these instructions carefully:
+
+1. You will be provided with a document to analyze. The document is enclosed in <document> tags:
+
+<document>
+{{DOCUMENT}}
+</document>
+
+2. You will then be given a question to answer based on the information in the document. The question is enclosed in <question> tags:
+
+<question>
+{{QUESTION}}
+</question>
+
+3. Carefully read through the document and identify the most relevant parts that contain information to answer the question.
+
+4. When you find relevant information, make a mental note of its location in the document. You will use this to provide citations later.
+
+5. Formulate your answer based on the relevant information you've found. Your answer should be comprehensive and accurate.
+
+6. When writing your answer, you should cite the relevant parts of the document. To do this, use square brackets with a number inside, like this: [1], [2], etc. Place these citations immediately after the information you're referencing.
+
+7. After your main answer, provide a "References" section. In this section, list out the actual text from the document that you cited, preceded by the corresponding number in square brackets.
+
+8. Format your entire response as follows:
+
+<answer>
+[Your comprehensive answer here, with citations in square brackets]
+
+References:
+[1] [Exact quote from the document]
+[2] [Exact quote from the document]
+...
+</answer>
+
+9. If the question cannot be answered using the information in the document, state this clearly in your answer and explain why.
+
+10. Do not include any information that is not present in the given document.
+
+Remember, your goal is to provide an accurate, well-cited answer based solely on the information in the given document.
+"""
+
+react_prompt_template = """
+Answer the following questions as best you can. You have access to the following tools:
+
+{tools}
+
+Use the following format:
+
+Question: the input question you must answer
+
+Thought: you should always think about what to do
+
+Action: the action to take, should be one of [{tool_names}]
+
+Action Input: the input to the action
+
+Observation: the result of the action
+
+... (this Thought/Action/Action Input/Observation can repeat N times)
+
+Thought: I now know the final answer
+
+Final Answer: the final answer to the original input question
+
+Begin!
+
+Question: {input}
+
+Thought:{agent_scratchpad}
+"""
+
 
 def populate_constant():
     global PINECONE_API_KEY, PINECONE_INDEX, PINECONE_TEXT_KEY, VOYAGEAI_API_KEY
diff --git a/tilellm/tools/document_tool_simple.py b/tilellm/tools/document_tool_simple.py
index 8a3488b..32f8dcd 100644
--- a/tilellm/tools/document_tool_simple.py
+++ b/tilellm/tools/document_tool_simple.py
@@ -59,7 +59,7 @@ async def get_content_by_url(url: str, scrape_type: int, **kwargs) -> list[Docu
         bs_transformer = BeautifulSoupTransformer()
         docs_transformed = bs_transformer.transform_documents(docs,
                                                               tags_to_extract=params_type_4.tags_to_extract,
-                                                              unwanted_tags =params_type_4.unwanted_tags,
+                                                              unwanted_tags=params_type_4.unwanted_tags,
                                                               unwanted_classnames=params_type_4.unwanted_classnames,
                                                               remove_lines=params_type_4.remove_lines,
                                                               remove_comments=params_type_4.remove_comments
diff --git a/tilellm/tools/shopify_tool.py b/tilellm/tools/shopify_tool.py
new file mode 100644
index 0000000..4251c7b
--- /dev/null
+++ b/tilellm/tools/shopify_tool.py
@@ -0,0 +1,74 @@
+import httpx
+import json
+import re
+
+from gql import gql, Client
+from gql.transport.httpx import HTTPXTransport  # synchronous transport; use AIOHTTPTransport for async
+
+
+def get_graphql_answer(input, url, api_key):
+    """Run an LLM-generated GraphQL query against the Shopify Admin API and return the raw result."""
+    print(f"==========> {input}")
+    # print("==========>" + url)
+    # print("==========>" + api_key)
+
+    headers = {
+        "content-type": "application/json",
+        "X-Shopify-Access-Token": api_key}
+
+    gql_query = clean_graphql_query(input)
+    # gql_query = input
+    # inputparam = json.loads(input)
+    # response = httpx.post(url, json=inputparam["query"], headers=headers)
+
+    # Print the response
+    # print(response.json())
+
+    # Select the transport with a defined url endpoint
+    # transport = AIOHTTPTransport(url=url, headers=headers)
+    transport = HTTPXTransport(url=url, headers=headers)
+
+    # Create a GraphQL client using the defined transport
+    client = Client(transport=transport, fetch_schema_from_transport=True)
+
+    # Provide a GraphQL query
+    print(f"========= {gql_query}")
+    query = gql(gql_query)
+    # print(f"========= query gql {query}")
+
+    # Execute the query on the transport
+    result = client.execute(query)
+    print(f"query result: {result}")
+    return result
+    # print(f"endpoint: {url}, api_key: {api_key}")
+
+
+def clean_graphql_query(query):
+    """Normalize an LLM-generated GraphQL query: strip code fences, quotes and assignment wrappers."""
+    # Remove leading and trailing whitespace
+    query = query.strip()
+
+    # Handle the case where the query is in the format query='QUERY'
+    match = re.match(r'^(?:const\s+)?(?:var\s+)?(?:let\s+)?(?:query\s*=\s*[\'"]*)(.*?)([\'"]*)\s*$', query, re.DOTALL)
+    if match:
+        query = match.group(1)
+
+    # Remove enclosing code fences and language specifiers (```graphql ... ```)
+    query = re.sub(r'^```(?:graphql|query)?\s*', '', query)
+    query = re.sub(r'\s*```$', '', query)
+
+    # Remove any remaining backticks
+    query = query.replace('`', '')
+
+    # Unescape any escaped quotes
+    query = query.replace('\\"', '"').replace("\\'", "'")
+
+    # Ensure the query is properly closed by balancing braces
+    open_braces = query.count('{')
+    close_braces = query.count('}')
+    if open_braces > close_braces:
+        query += '}' * (open_braces - close_braces)
+
+    return query.strip()
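Reviewer note: a minimal sketch for exercising the new `/api/agent` endpoint. The route, the `QuestionToAgent` field names, and the `{"shopify": {"api_key": ..., "url": ...}}` tool shape come from this patch; the host/port, the credentials, the shop URL, and the `llm`/`model` identifiers are placeholders, not values confirmed by the patch.

```python
# Hypothetical smoke test for the new /api/agent endpoint.
# Assumes the service listens on localhost:8000; every credential and URL
# below is a placeholder.
import httpx

payload = {
    "question": "Which products cost less than 50 euros?",
    "llm": "openai",          # assumed provider identifier
    "llm_key": "sk-...",      # placeholder API key
    "model": "gpt-4o",        # any supported chat model
    "tools": [
        {
            "shopify": {
                "api_key": "shpat_...",  # placeholder Shopify Admin API token
                "url": "https://example-shop.myshopify.com/admin/api/2024-07/graphql.json",
            }
        }
    ],
    "chat_history_dict": {},
}

response = httpx.post("http://localhost:8000/api/agent", json=payload, timeout=120.0)
print(response.json())  # SimpleAnswer: {"answer": "...", "chat_history_dict": {...}}
```

On the RAG side, when a `QuestionAnswer` request sets `citations` to true, `ask_with_memory` routes the answer through `llm.with_structured_output(QuotedAnswer)`, so the returned `RetrievalResult` carries a `citations` list next to the answer, where each `Citation` exposes `source_id`, `source_name`, and `quote` (see `tilellm/models/item_model.py`).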