query.py
import os
import pymongo
from dotenv import load_dotenv
from langchain_community.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain_nomic import NomicEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch


def chatbot(ticker, query):
    '''
    Performs vector search on the MongoDB Atlas vector store to retrieve relevant documents,
    filtering results by the specified ticker.
    Uses OpenAI's gpt-3.5-turbo to generate a response from the retrieved documents.

    Args:
        ticker (str): The stock ticker (e.g. AAPL) used for filtering.
        query (str): The question entered by the user.

    Returns:
        retriever_output (str): The chatbot's answer.
    '''
    # Load connection details from environment variables
    CLUSTER_NAME = os.getenv("CLUSTER_NAME")
    DB_NAME = os.getenv("DB_NAME")
    COLLECTION_NAME = os.getenv("COLLECTION_NAME")

    client = pymongo.MongoClient(CLUSTER_NAME)
    database = client[DB_NAME]
    collection = database[COLLECTION_NAME]
    ATLAS_VECTOR_SEARCH_INDEX_NAME = "vector_index"

    embeddings = NomicEmbeddings(
        nomic_api_key=os.getenv("NOMIC_API_KEY"),
        model='nomic-embed-text-v1.5',
    )

    vector_search = MongoDBAtlasVectorSearch(
        embedding=embeddings,
        collection=collection,
        index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
    )

    # Define the filter based on the metadata field and value
    filter_dict = {"metadata.ticker": ticker}

    llm = OpenAI(
        openai_api_key=os.getenv("OPENAI_API_KEY"),
        temperature=0
    )
    # Apply the ticker filter as a pre-filter during the Atlas vector search
    # (the filtered field must be declared as a filter field in the vector index)
    retriever = vector_search.as_retriever(search_kwargs={"pre_filter": filter_dict})
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
    )
    retriever_output = qa.run(query)
    return retriever_output


if __name__ == "__main__":
    load_dotenv()
    print(chatbot("TSLA", "Is the market for energy storage products competitive?"))
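
For reference, below is a minimal sketch of the setup this script assumes: the environment variables it reads and an Atlas Vector Search index named "vector_index" on the collection. The embedding field path, the 768-dimension setting for nomic-embed-text-v1.5, and the cosine similarity metric are assumptions, not taken from this file.

# Sketch of the assumed configuration (not part of query.py).
#
# Environment variables expected in the .env file:
#   CLUSTER_NAME      - MongoDB Atlas connection string
#   DB_NAME           - database name
#   COLLECTION_NAME   - collection holding the embedded documents
#   NOMIC_API_KEY     - API key for Nomic embeddings
#   OPENAI_API_KEY    - API key for OpenAI
#
# Assumed definition of the "vector_index" Atlas Vector Search index.
# Field names, dimensions, and similarity are assumptions.
vector_index_definition = {
    "fields": [
        {
            "type": "vector",
            "path": "embedding",        # assumed field where document embeddings are stored
            "numDimensions": 768,       # nomic-embed-text-v1.5's default output size
            "similarity": "cosine",
        },
        {
            "type": "filter",
            "path": "metadata.ticker",  # enables the pre_filter on ticker
        },
    ]
}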