// langgraph.ts

import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages";
import { tool } from "@langchain/core/tools";
import { MemorySaver, StateGraph, StateGraphArgs } from "@langchain/langgraph";
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { ChatOpenAI } from "@langchain/openai";
import { green, yellow } from "cli-color";
import "dotenv/config";
import { z } from "zod";
import { LiteralClient } from "@literalai/client";
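
// "dotenv/config" loads the local .env file. This example assumes it
// provides OPENAI_API_KEY (for ChatOpenAI) and LITERAL_API_KEY (for
// LiteralClient); adjust to however your keys are configured.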

// Create a new Literal AI client and a LangChain-compatible callback
const literalClient = new LiteralClient();
const cb = literalClient.instrumentation.langchain.literalCallback();
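// Passing `cb` in the `callbacks` option of a LangChain or LangGraph call
// sends that run (prompts, completions, tool calls) to Literal AI.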

// Define the state interface
interface AgentState {
  messages: BaseMessage[];
}

// Define the graph state. The reducer appends newly returned messages
// to the accumulated message list on every state update.
const graphState: StateGraphArgs<AgentState>["channels"] = {
  messages: {
    reducer: (x: BaseMessage[], y: BaseMessage[]) => x.concat(y),
  },
};

// Define the tools for the agent to use
const weatherTool = tool(
  async ({ query }) => {
    if (
      query.toLowerCase().includes("sf") ||
      query.toLowerCase().includes("san francisco")
    ) {
      return "It's 60 degrees and foggy.";
    }
    return "It's 90 degrees and sunny.";
  },
  {
    name: "weather",
    description: "Call to get the current weather for a location.",
    schema: z.object({
      query: z.string().describe("The query to use in your search."),
    }),
  }
);
const tools = [weatherTool];
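// ToolNode executes the tool calls found on the last AIMessage in state
// and appends the tool results to the message list.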
const toolNode = new ToolNode<AgentState>(tools);

// Bind the tool schemas to the model so it can emit tool calls
const model = new ChatOpenAI({
  model: "gpt-4o-mini",
  temperature: 0,
}).bindTools(tools);

// Define the function that determines whether to continue or not
function shouldContinue(state: AgentState) {
  const messages = state.messages;
  const lastMessage = messages[messages.length - 1] as AIMessage;
  // If the LLM makes a tool call, then we route to the "tools" node
  if (lastMessage.tool_calls?.length) {
    return "tools";
  }
  // Otherwise, we stop (reply to the user)
  return "__end__";
}

// Define the function that calls the model
async function callModel(state: AgentState) {
  const messages = state.messages;
  const response = await model.invoke(messages);
  // Return a list so the reducer appends it to the existing messages
  return { messages: [response] };
}

// Define a new graph
const workflow = new StateGraph<AgentState>({ channels: graphState })
  .addNode("agent", callModel)
  .addNode("tools", toolNode)
  .addEdge("__start__", "agent")
  .addConditionalEdges("agent", shouldContinue)
  .addEdge("tools", "agent");
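
// Resulting loop: __start__ -> agent -> (tools -> agent)* -> __end__,
// cycling for as long as the model keeps requesting tool calls.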

// Initialize memory to persist state between graph runs
const checkpointer = new MemorySaver();

// Finally, compile the graph into a LangChain Runnable
const app = workflow.compile({ checkpointer });
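
// MemorySaver checkpoints state per thread_id, so a second invoke with the
// same thread_id resumes the saved conversation instead of starting fresh.
// That is what lets the follow-up "what about ny" below resolve its context.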

async function main() {
  console.log(green("> what is an LLM"));
  const response = await model.invoke([new HumanMessage("what is an LLM")], {
    callbacks: [cb],
  });
  console.log(yellow(response.content));

  // Wrap the two weather calls in a single Literal AI thread
  await literalClient.thread({ name: "Weather Wrap" }).wrap(async () => {
    console.log(green("> what is the weather in sf"));
    // Use the Runnable
    const finalState = await app.invoke(
      { messages: [new HumanMessage("what is the weather in sf")] },
      {
        configurable: { thread_id: "Weather Thread" },
        runName: "weather",
        callbacks: [cb],
      }
    );
    console.log(
      yellow(finalState.messages[finalState.messages.length - 1].content)
    );

    console.log(green("> what about ny"));
    const nextState = await app.invoke(
      { messages: [new HumanMessage("what about ny")] },
      {
        configurable: { thread_id: "Weather Thread" },
        runName: "weather",
        callbacks: [cb],
      }
    );
    console.log(
      yellow(nextState.messages[nextState.messages.length - 1].content)
    );
  });
}

main();
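
// One possible way to run this example (assuming tsx is available):
//   npx tsx langgraph.ts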