basic_chain.py (60 lines, 44 loc, 1.79 KB) — forked from streamlit/example-app-langchain-rag.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_community.llms import HuggingFaceHub
from langchain_community.chat_models.huggingface import ChatHuggingFace
from dotenv import load_dotenv
# Hugging Face Hub repo ids for the instruct-tuned chat models used below.
MISTRAL_ID = "mistralai/Mistral-7B-Instruct-v0.1"
ZEPHYR_ID = "HuggingFaceH4/zephyr-7b-beta"  # default repo for get_model()
def get_model(repo_id=ZEPHYR_ID, **kwargs):
    """Build a chat model for the given model identifier.

    Args:
        repo_id: Hugging Face Hub repo id of the model to load, or the
            literal string "ChatGPT" to use OpenAI's API instead.
        **kwargs: For "ChatGPT", forwarded verbatim to ``ChatOpenAI``.
            Otherwise, may carry ``HUGGINGFACEHUB_API_TOKEN`` to override
            the environment variable of the same name.

    Returns:
        A LangChain chat model (``ChatOpenAI`` or ``ChatHuggingFace``).

    Raises:
        ValueError: if a Hugging Face model is requested but no API token
            is available from either kwargs or the environment.
    """
    if repo_id == "ChatGPT":
        return ChatOpenAI(temperature=0, **kwargs)

    # Token may come from kwargs, falling back to the environment.
    huggingfacehub_api_token = kwargs.get("HUGGINGFACEHUB_API_TOKEN", None)
    if not huggingfacehub_api_token:
        huggingfacehub_api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN", None)
    if not huggingfacehub_api_token:
        # Fail fast with a clear message; the original code would instead
        # crash with an opaque TypeError on os.environ["HF_TOKEN"] = None.
        raise ValueError(
            "No Hugging Face API token found: pass HUGGINGFACEHUB_API_TOKEN "
            "as a keyword argument or set the environment variable.")
    os.environ["HF_TOKEN"] = huggingfacehub_api_token
    llm = HuggingFaceHub(
        repo_id=repo_id,
        task="text-generation",
        # The token is a client credential, not a generation parameter,
        # so it must not be sent to the endpoint inside model_kwargs.
        huggingfacehub_api_token=huggingfacehub_api_token,
        model_kwargs={
            "max_new_tokens": 512,
            "top_k": 30,
            "temperature": 0.1,
            "repetition_penalty": 1.03,
        })
    return ChatHuggingFace(llm=llm)
def basic_chain(model=None, prompt=None):
    """Compose a prompt with a chat model into a runnable chain.

    Either argument may be omitted: the model defaults to ``get_model()``
    and the prompt defaults to an author-books question template.
    """
    chosen_model = model if model else get_model()
    chosen_prompt = prompt if prompt else ChatPromptTemplate.from_template(
        "Tell me the most noteworthy books by the author {author}")
    return chosen_prompt | chosen_model
def main():
    """Demo entry point: ask the default model about an author's notable books."""
    load_dotenv()  # load API tokens from a local .env file, if present
    author_prompt = ChatPromptTemplate.from_template(
        "Tell me the most noteworthy books by the author {author}")
    pipeline = basic_chain(prompt=author_prompt) | StrOutputParser()
    answer = pipeline.invoke({"author": "William Faulkner"})
    print(answer)


if __name__ == '__main__':
    main()