-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathagent_config.py
131 lines (109 loc) · 4.51 KB
/
agent_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from typing import List
import os
from string import Template
from council.chains import Chain
from council.evaluators import BasicEvaluator
from council.llm import OpenAILLM, LLMMessage
from council.skills import LLMSkill, PromptToMessages
from council.contexts import SkillContext
from council.prompt import PromptBuilder
from council.runners import Parallel
import constants
from config import Config
from skills import (
DocRetrievalSkill,
GoogleAggregatorSkill,
PandasSkill,
CustomGoogleNewsSkill,
CustomGoogleSearchSkill,
)
from retrieval import Retriever
from controller import Controller
from filter import LLMFilter
import dotenv
# Load variables from a local .env file into the process environment at
# import time, so that later environment lookups made while building an
# AgentConfig (e.g. os.getenv("OPENAI_API_KEY")) can see them.
dotenv.load_dotenv()
class AgentConfig:
    """Assemble the agent's skills, chains, controller, evaluator and filter.

    All setup happens eagerly in the constructor:
    - the document index and its retriever are built;
    - two OpenAILLM clients are created;
    - the skills and the three chains (document retrieval, Google search,
      pandas) are created;
    - the controller, evaluator and filter are created.

    ``load_config`` then exposes the controller, evaluator and filter.
    """

    def __init__(self):
        # --- Document retrieval dependencies -------------------------------
        cfg = Config(
            encoding_name=constants.ENCODING_NAME,
            embedding_model_name=constants.EMBEDDING_MODEL_NAME,
        )
        self.config = cfg
        self.index = cfg.initialize()
        self.index_retriever = self.index.as_retriever(
            similarity_top_k=constants.NUM_RETRIEVED_DOCUMENTS
        )
        self.retriever = Retriever(cfg, self.index_retriever)

        # --- LLM clients ---------------------------------------------------
        # One client backs the answer-writing LLMSkill; the other is shared
        # by the controller and the response filter below.
        self._llm_skill_model = OpenAILLM.from_env(
            model=constants.DOC_AND_GOOGLE_RETRIEVAL_LLM
        )
        self._controller_model = OpenAILLM.from_env(model=constants.CONTROLLER_LLM)

        # --- Skills, chains and orchestration ------------------------------
        # Skills must exist before the chains that reference them.
        self._init_skills()
        self.chains = self._init_chains()
        self.controller = Controller(
            llm=self._controller_model,
            chains=self.chains,
            response_threshold=5,
        )
        self.evaluator = BasicEvaluator()
        self.filter = LLMFilter(llm=self._controller_model)

    def load_config(self):
        """Return the controller, evaluator and filter, keyed by role name."""
        return dict(
            controller=self.controller,
            evaluator=self.evaluator,
            filter=self.filter,
        )

    def _init_skills(self):
        """Create every skill used by the chains and store each on ``self``."""
        # Document retrieval backed by the retriever built in __init__.
        self.doc_retrieval_skill = DocRetrievalSkill(self.retriever)

        # Google search and news skills, plus the aggregator that runs after
        # them in the search chain.
        self.google_search_skill = CustomGoogleSearchSkill()
        self.google_news_skill = CustomGoogleNewsSkill()
        self.google_aggregator_skill = GoogleAggregatorSkill()

        # Pandas skill; its API token is read from the environment.
        openai_key = os.getenv("OPENAI_API_KEY")
        self.pandas_skill = PandasSkill(api_token=openai_key, model=constants.PANDAS_LLM)

        # Answer-writing LLM skill, shared by the doc-retrieval and search chains.
        persona = Template(
            "You are a financial analyst whose job is to answer user questions "
            "about $company with the provided context."
        )
        self.llm_skill = LLMSkill(
            llm=self._llm_skill_model,
            system_prompt=persona.substitute(company=constants.COMPANY_NAME),
            context_messages=self._build_context_messages,
        )

    def _init_chains(self) -> List[Chain]:
        """Build the three chains, keep each on ``self`` and return them in order.

        Returns:
            ``[doc_retrieval_chain, search_chain, pandas_chain]``.
        """
        subject = f"{constants.COMPANY_NAME} ({constants.COMPANY_TICKER})"

        # Document retrieval first; the LLM skill then answers from the
        # chain's last message.
        self.doc_retrieval_chain = Chain(
            name="doc_retrieval_chain",
            description=(
                f"Information from {subject} 10-K from their 2022 fiscal year, "
                "a document that contain important updates for investors about "
                "company performance and operations"
            ),
            runners=[self.doc_retrieval_skill, self.llm_skill],
        )

        # Search and news run in parallel, then the aggregator, then the LLM skill.
        self.search_chain = Chain(
            name="search_chain",
            description=f"Information about {subject} using a Google search",
            runners=[
                Parallel(self.google_search_skill, self.google_news_skill),
                self.google_aggregator_skill,
                self.llm_skill,
            ],
        )

        # A single pandas skill handles price / trading-data questions.
        self.pandas_chain = Chain(
            name="pandas_chain",
            description=f"{subject} historical stock price and trading data information",
            runners=[self.pandas_skill],
        )

        return [self.doc_retrieval_chain, self.search_chain, self.pandas_chain]

    @staticmethod
    def _build_context_messages(context: SkillContext) -> List[LLMMessage]:
        """Context messages function for LLMSkill.

        Renders a user message from a template. The template has two
        placeholders:
        - ``chain_history.last_message`` fills the CONTEXT section;
        - ``chat_history.user.last_message`` fills the QUERY section.

        It also instructs the model to say it does not know when the
        context lacks the answer.
        """
        template_text = (
            "Use the following pieces of context to answer the query.\n"
            "If the answer is not provided in the context, do not make up an answer. "
            "Instead, respond that you do not know.\n"
            "CONTEXT:\n"
            "{{chain_history.last_message}}\n"
            "END CONTEXT.\n"
            "QUERY:\n"
            "{{chat_history.user.last_message}}\n"
            "END QUERY.\n"
            "YOUR ANSWER:\n"
        )
        builder = PromptBuilder(template_text)
        return PromptToMessages(prompt_builder=builder).to_user_message(context)