Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve close #12

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package ai

import (
"context"
"errors"
"net/http"
"net/url"
)

var ErrAIClosed = errors.New("AI client is nil or already closed")

type AI interface {
Limiter

Expand Down
63 changes: 44 additions & 19 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const defaultModel = openai.GPT3Dot5Turbo
var _ ai.AI = new(ChatGPT)

type ChatGPT struct {
*openai.Client
c *openai.Client
model string
maxTokens *int32
temperature *float32
Expand All @@ -30,7 +30,10 @@ func New(authToken string) ai.AI {
}

func NewWithClient(client *openai.Client) ai.AI {
return &ChatGPT{Client: client, model: defaultModel}
if client == nil {
panic("cannot create AI from nil client")
}
return &ChatGPT{c: client, model: defaultModel}
}

func (chatgpt *ChatGPT) SetLimit(limit rate.Limit) {
Expand Down Expand Up @@ -74,12 +77,23 @@ func (resp *ChatResponse[Response]) Results() (res []string) {
return
}

func (ai *ChatGPT) createRequest(history []openai.ChatCompletionMessage, messages ...string) (req openai.ChatCompletionRequest) {
func (resp *ChatResponse[Response]) String() string {
if res := resp.Results(); len(res) > 0 {
return res[0]
}
return ""
}

func (ai *ChatGPT) createRequest(
stream bool,
history []openai.ChatCompletionMessage,
messages ...string,
) (req openai.ChatCompletionRequest) {
req.Model = ai.model
if ai.maxTokens != nil {
req.MaxTokens = int(*ai.maxTokens)
}
if ai.count != nil {
if !stream && ai.count != nil {
req.N = int(*ai.count)
}
if ai.temperature != nil {
Expand All @@ -98,15 +112,19 @@ func (ai *ChatGPT) createRequest(history []openai.ChatCompletionMessage, message
return
}

func (ai *ChatGPT) chat(
func (chatgpt *ChatGPT) chat(
ctx context.Context,
history []openai.ChatCompletionMessage,
messages ...string,
) (resp openai.ChatCompletionResponse, err error) {
if err = ai.wait(ctx); err != nil {
if chatgpt.c == nil {
err = ai.ErrAIClosed
return
}
return ai.CreateChatCompletion(ctx, ai.createRequest(history, messages...))
if err = chatgpt.wait(ctx); err != nil {
return
}
return chatgpt.c.CreateChatCompletion(ctx, chatgpt.createRequest(false, history, messages...))
}

func (ai *ChatGPT) Chat(ctx context.Context, messages ...string) (ai.ChatResponse, error) {
Expand All @@ -120,40 +138,47 @@ func (ai *ChatGPT) Chat(ctx context.Context, messages ...string) (ai.ChatRespons
var _ ai.ChatStream = new(ChatStream)

type ChatStream struct {
*openai.ChatCompletionStream
cs *ChatSession
content string
sr *openai.ChatCompletionStream
cs *ChatSession
merged string
}

func (stream *ChatStream) Next() (ai.ChatResponse, error) {
resp, err := stream.Recv()
resp, err := stream.sr.Recv()
if err != nil {
if err == io.EOF {
if stream.cs != nil {
stream.cs.History = append(stream.cs.History, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant, Content: stream.content})
Role: openai.ChatMessageRoleAssistant, Content: stream.merged})
}
}
stream.content = ""
stream.merged = ""
return nil, err
}
if stream.cs != nil {
stream.content += resp.Choices[0].Delta.Content
stream.merged += resp.Choices[0].Delta.Content
}
return &ChatResponse[openai.ChatCompletionStreamResponse]{resp}, nil
}

func (ai *ChatGPT) chatStream(
func (stream *ChatStream) Close() error {
return stream.sr.Close()
}

func (chatgpt *ChatGPT) chatStream(
ctx context.Context,
history []openai.ChatCompletionMessage,
messages ...string,
) (*openai.ChatCompletionStream, error) {
if err := ai.wait(ctx); err != nil {
if chatgpt.c == nil {
return nil, ai.ErrAIClosed
}
if err := chatgpt.wait(ctx); err != nil {
return nil, err
}
req := ai.createRequest(history, messages...)
req := chatgpt.createRequest(true, history, messages...)
req.Stream = true
return ai.CreateChatCompletionStream(ctx, req)
return chatgpt.c.CreateChatCompletionStream(ctx, req)
}

func (ai *ChatGPT) ChatStream(ctx context.Context, messages ...string) (ai.ChatStream, error) {
Expand Down Expand Up @@ -189,10 +214,10 @@ func (session *ChatSession) ChatStream(ctx context.Context, messages ...string)
}

func (ai *ChatGPT) ChatSession() ai.Chatbot {
ai.count = nil
return &ChatSession{ai: ai}
}

func (ai *ChatGPT) Close() error {
ai.c = nil
return nil
}
32 changes: 24 additions & 8 deletions gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gemini

import (
"context"
"errors"
"io"
"strings"

Expand All @@ -18,7 +19,7 @@ const defaultModel = "gemini-1.0-pro"
var _ ai.AI = new(Gemini)

type Gemini struct {
*genai.Client
c *genai.Client
model *genai.GenerativeModel
config genai.GenerationConfig

Expand All @@ -34,7 +35,7 @@ func New(apiKey string) (ai.AI, error) {
}

func NewWithClient(client *genai.Client) ai.AI {
return &Gemini{Client: client, model: client.GenerativeModel(defaultModel)}
return &Gemini{c: client, model: client.GenerativeModel(defaultModel)}
}

func (gemini *Gemini) SetLimit(limit rate.Limit) {
Expand All @@ -49,7 +50,7 @@ func (ai *Gemini) wait(ctx context.Context) error {
}

func (ai *Gemini) SetModel(model string) {
ai.model = ai.GenerativeModel(model)
ai.model = ai.c.GenerativeModel(model)
ai.model.GenerationConfig = ai.config
}

Expand Down Expand Up @@ -79,6 +80,13 @@ func (resp *ChatResponse) Results() (res []string) {
return
}

func (resp *ChatResponse) String() string {
if res := resp.Results(); len(res) > 0 {
return res[0]
}
return ""
}

func texts2parts(texts []string) (parts []genai.Part) {
for _, i := range texts {
parts = append(parts, genai.Text(i))
Expand All @@ -100,11 +108,14 @@ func (ai *Gemini) Chat(ctx context.Context, parts ...string) (ai.ChatResponse, e
var _ ai.ChatStream = new(ChatStream)

type ChatStream struct {
*genai.GenerateContentResponseIterator
iter *genai.GenerateContentResponseIterator
}

func (stream *ChatStream) Next() (ai.ChatResponse, error) {
resp, err := stream.GenerateContentResponseIterator.Next()
if stream.iter == nil {
return nil, errors.New("stream iterator is nil or already closed")
}
resp, err := stream.iter.Next()
if err != nil {
if err == iterator.Done {
return nil, io.EOF
Expand All @@ -115,6 +126,7 @@ func (stream *ChatStream) Next() (ai.ChatResponse, error) {
}

func (stream *ChatStream) Close() error {
stream.iter = nil
return nil
}

Expand All @@ -129,14 +141,14 @@ var _ ai.Chatbot = new(ChatSession)

type ChatSession struct {
ai *Gemini
*genai.ChatSession
cs *genai.ChatSession
}

func (session *ChatSession) Chat(ctx context.Context, parts ...string) (ai.ChatResponse, error) {
if err := session.ai.wait(ctx); err != nil {
return nil, err
}
resp, err := session.SendMessage(ctx, texts2parts(parts)...)
resp, err := session.cs.SendMessage(ctx, texts2parts(parts)...)
if err != nil {
return nil, err
}
Expand All @@ -147,9 +159,13 @@ func (session *ChatSession) ChatStream(ctx context.Context, parts ...string) (ai
if err := session.ai.wait(ctx); err != nil {
return nil, err
}
return &ChatStream{session.SendMessageStream(ctx, texts2parts(parts)...)}, nil
return &ChatStream{session.cs.SendMessageStream(ctx, texts2parts(parts)...)}, nil
}

func (ai *Gemini) ChatSession() ai.Chatbot {
return &ChatSession{ai, ai.model.StartChat()}
}

func (ai *Gemini) Close() error {
return ai.c.Close()
}
3 changes: 3 additions & 0 deletions gemini/gemini_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func TestGemini(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer gemini.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
fmt.Println("Who are you?")
Expand All @@ -39,6 +40,7 @@ func TestGemini(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stream.Close()
for {
resp, err := stream.Next()
if err != nil {
Expand Down Expand Up @@ -66,6 +68,7 @@ func TestGemini(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stream.Close()
for {
resp, err := stream.Next()
if err != nil {
Expand Down