diff --git a/ai.go b/ai.go index 65398eb..05e54cc 100644 --- a/ai.go +++ b/ai.go @@ -14,6 +14,8 @@ type AI interface { Chatbot ChatSession() Chatbot + + Close() error } type Model interface { @@ -30,6 +32,7 @@ type Chatbot interface { type ChatStream interface { Next() (ChatResponse, error) + Close() error } type ChatResponse interface { diff --git a/chatgpt/chatgpt.go b/chatgpt/chatgpt.go index c69d4e2..c16558b 100644 --- a/chatgpt/chatgpt.go +++ b/chatgpt/chatgpt.go @@ -144,6 +144,11 @@ func (stream *ChatStream) Next() (ai.ChatResponse, error) { return &ChatResponse[openai.ChatCompletionStreamResponse]{resp}, nil } +func (stream *ChatStream) Close() error { + stream.ChatCompletionStream.Close() + return nil +} + func (ai *ChatGPT) chatStream( ctx context.Context, history []openai.ChatCompletionMessage, @@ -193,3 +198,7 @@ func (ai *ChatGPT) ChatSession() ai.Chatbot { ai.count = nil return &ChatSession{ai: ai} } + +func (ai *ChatGPT) Close() error { + return nil +} diff --git a/gemini/gemini.go b/gemini/gemini.go index 325f197..6f9f649 100644 --- a/gemini/gemini.go +++ b/gemini/gemini.go @@ -114,6 +114,10 @@ func (stream *ChatStream) Next() (ai.ChatResponse, error) { return &ChatResponse{resp}, nil } +func (stream *ChatStream) Close() error { + return nil +} + func (ai *Gemini) ChatStream(ctx context.Context, parts ...string) (ai.ChatStream, error) { if err := ai.wait(ctx); err != nil { return nil, err