From 5244ba92eade5f76b2e627aabe9fcb9530f890b0 Mon Sep 17 00:00:00 2001
From: sunshineplan
Date: Mon, 1 Apr 2024 15:07:53 +0800
Subject: [PATCH] AI interface add Model method to return model name (#32)

---
 ai.go              |  1 +
 chatgpt/chatgpt.go | 18 +++++++++++-------
 gemini/gemini.go   | 18 +++++++++++-------
 3 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/ai.go b/ai.go
index c4e5ecc..b196898 100644
--- a/ai.go
+++ b/ai.go
@@ -9,6 +9,7 @@ var ErrAIClosed = errors.New("AI client is nil or already closed")
 
 type AI interface {
 	LLMs() LLMs
+	Model(context.Context) (string, error)
 
 	Limiter
 
diff --git a/chatgpt/chatgpt.go b/chatgpt/chatgpt.go
index 1d0cb30..95f1ab3 100644
--- a/chatgpt/chatgpt.go
+++ b/chatgpt/chatgpt.go
@@ -17,7 +17,7 @@ const defaultModel = openai.GPT3Dot5Turbo
 var _ ai.AI = new(ChatGPT)
 
 type ChatGPT struct {
-	c           *openai.Client
+	*openai.Client
 	model       string
 	maxTokens   *int32
 	temperature *float32
@@ -62,13 +62,17 @@ func NewWithClient(client *openai.Client, model string) ai.AI {
 	if model == "" {
 		model = defaultModel
 	}
-	return &ChatGPT{c: client, model: model}
+	return &ChatGPT{Client: client, model: model}
 }
 
 func (ChatGPT) LLMs() ai.LLMs {
 	return ai.ChatGPT
 }
 
+func (chatgpt *ChatGPT) Model(_ context.Context) (string, error) {
+	return chatgpt.model, nil
+}
+
 func (chatgpt *ChatGPT) SetLimit(limit rate.Limit) {
 	chatgpt.limiter = ai.NewLimiter(limit)
 }
@@ -151,14 +155,14 @@ func (chatgpt *ChatGPT) chat(
 	history []openai.ChatCompletionMessage,
 	messages ...string,
 ) (resp openai.ChatCompletionResponse, err error) {
-	if chatgpt.c == nil {
+	if chatgpt.Client == nil {
 		err = ai.ErrAIClosed
 		return
 	}
 	if err = chatgpt.wait(ctx); err != nil {
 		return
 	}
-	return chatgpt.c.CreateChatCompletion(ctx, chatgpt.createRequest(session, history, messages...))
+	return chatgpt.CreateChatCompletion(ctx, chatgpt.createRequest(session, history, messages...))
 }
 
 func (ai *ChatGPT) Chat(ctx context.Context, messages ...string) (ai.ChatResponse, error) {
@@ -204,15 +208,15 @@ func (chatgpt *ChatGPT) chatStream(
 	history []openai.ChatCompletionMessage,
 	messages ...string,
 ) (*openai.ChatCompletionStream, error) {
-	if chatgpt.c == nil {
+	if chatgpt.Client == nil {
 		return nil, ai.ErrAIClosed
 	}
 	if err := chatgpt.wait(ctx); err != nil {
 		return nil, err
 	}
 	req := chatgpt.createRequest(true, history, messages...)
 	req.Stream = true
-	return chatgpt.c.CreateChatCompletionStream(ctx, req)
+	return chatgpt.CreateChatCompletionStream(ctx, req)
 }
 
 func (ai *ChatGPT) ChatStream(ctx context.Context, messages ...string) (ai.ChatStream, error) {
@@ -270,6 +274,6 @@ func (ai *ChatGPT) ChatSession() ai.ChatSession {
 }
 
 func (ai *ChatGPT) Close() error {
-	ai.c = nil
+	ai.Client = nil
 	return nil
 }
diff --git a/gemini/gemini.go b/gemini/gemini.go
index cb50980..7f59fc8 100644
--- a/gemini/gemini.go
+++ b/gemini/gemini.go
@@ -21,7 +21,7 @@ const defaultModel = "gemini-1.0-pro"
 var _ ai.AI = new(Gemini)
 
 type Gemini struct {
-	c     *genai.Client
+	*genai.Client
 	model *genai.GenerativeModel
 
 	config genai.GenerationConfig
@@ -64,13 +64,21 @@ func NewWithClient(client *genai.Client, model string) ai.AI {
 	if model == "" {
 		model = defaultModel
 	}
-	return &Gemini{c: client, model: client.GenerativeModel(model)}
+	return &Gemini{Client: client, model: client.GenerativeModel(model)}
 }
 
 func (Gemini) LLMs() ai.LLMs {
 	return ai.Gemini
 }
 
+func (gemini *Gemini) Model(ctx context.Context) (string, error) {
+	info, err := gemini.model.Info(ctx)
+	if err != nil {
+		return "", err
+	}
+	return info.Name, nil
+}
+
 func (gemini *Gemini) SetLimit(limit rate.Limit) {
 	gemini.limiter = ai.NewLimiter(limit)
 }
@@ -83,7 +91,7 @@ func (ai *Gemini) wait(ctx context.Context) error {
 }
 
 func (ai *Gemini) SetModel(model string) {
-	ai.model = ai.c.GenerativeModel(model)
+	ai.model = ai.GenerativeModel(model)
 	ai.model.GenerationConfig = ai.config
 }
 
@@ -208,7 +216,3 @@ func (session *ChatSession) History() (history []ai.Message) {
 func (ai *Gemini) ChatSession() ai.ChatSession {
 	return &ChatSession{ai, ai.model.StartChat()}
 }
-
-func (ai *Gemini) Close() error {
-	return ai.c.Close()
-}
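
Usage sketch: a minimal example of calling the new Model method through the
ai.AI interface. NewWithClient and Model come from the patch above; the module
path github.com/sunshineplan/ai, the placeholder API key, and the surrounding
main function are assumptions for illustration, not part of this change.

	package main

	import (
		"context"
		"fmt"
		"log"

		openai "github.com/sashabaranov/go-openai"
		"github.com/sunshineplan/ai/chatgpt" // assumed module path
	)

	func main() {
		// An empty model name falls back to defaultModel (openai.GPT3Dot5Turbo).
		c := chatgpt.NewWithClient(openai.NewClient("YOUR_API_KEY"), "")
		defer c.Close() // ai.AI appears to include Close; both backends implement it

		// ChatGPT stores the model name locally, so its Model ignores the
		// context and never fails; Gemini's implementation calls
		// GenerativeModel.Info over the network and can return an error,
		// which is why the interface signature is Model(context.Context) (string, error).
		model, err := c.Model(context.Background())
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println("model:", model) // "gpt-3.5-turbo" for the default
	}

Design note: the patch also replaces the named field c with an embedded
*openai.Client / *genai.Client, so the clients' methods are promoted onto
ChatGPT and Gemini. That promotion is what lets gemini.go drop its
hand-written Close wrapper: the embedded genai.Client already provides Close.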