Skip to content

Commit

Permalink
AI interface: add Model method to return the model name (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshineplan authored Apr 1, 2024
1 parent 237104e commit 5244ba9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
1 change: 1 addition & 0 deletions ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ var ErrAIClosed = errors.New("AI client is nil or already closed")

type AI interface {
LLMs() LLMs
Model(context.Context) (string, error)

Limiter

Expand Down
18 changes: 11 additions & 7 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const defaultModel = openai.GPT3Dot5Turbo
var _ ai.AI = new(ChatGPT)

type ChatGPT struct {
c *openai.Client
*openai.Client
model string
maxTokens *int32
temperature *float32
Expand Down Expand Up @@ -62,13 +62,17 @@ func NewWithClient(client *openai.Client, model string) ai.AI {
if model == "" {
model = defaultModel
}
return &ChatGPT{c: client, model: model}
return &ChatGPT{Client: client, model: model}
}

// LLMs reports which LLM provider this client implements: ai.ChatGPT.
func (ChatGPT) LLMs() ai.LLMs {
return ai.ChatGPT
}

// Model returns the configured model name. The context parameter is
// unused because the name is stored locally, and the error is always nil;
// both exist only to satisfy the ai.AI interface.
func (chatgpt *ChatGPT) Model(context.Context) (string, error) {
	return chatgpt.model, nil
}

// SetLimit installs a rate limiter (built via ai.NewLimiter) that
// throttles this client's API calls to the given rate.
func (chatgpt *ChatGPT) SetLimit(limit rate.Limit) {
chatgpt.limiter = ai.NewLimiter(limit)
}
Expand Down Expand Up @@ -151,14 +155,14 @@ func (chatgpt *ChatGPT) chat(
history []openai.ChatCompletionMessage,
messages ...string,
) (resp openai.ChatCompletionResponse, err error) {
if chatgpt.c == nil {
if chatgpt.Client == nil {
err = ai.ErrAIClosed
return
}
if err = chatgpt.wait(ctx); err != nil {
return
}
return chatgpt.c.CreateChatCompletion(ctx, chatgpt.createRequest(session, history, messages...))
return chatgpt.CreateChatCompletion(ctx, chatgpt.createRequest(session, history, messages...))
}

func (ai *ChatGPT) Chat(ctx context.Context, messages ...string) (ai.ChatResponse, error) {
Expand Down Expand Up @@ -204,15 +208,15 @@ func (chatgpt *ChatGPT) chatStream(
history []openai.ChatCompletionMessage,
messages ...string,
) (*openai.ChatCompletionStream, error) {
if chatgpt.c == nil {
if chatgpt.Client == nil {
return nil, ai.ErrAIClosed
}
if err := chatgpt.wait(ctx); err != nil {
return nil, err
}
req := chatgpt.createRequest(true, history, messages...)
req.Stream = true
return chatgpt.c.CreateChatCompletionStream(ctx, req)
return chatgpt.CreateChatCompletionStream(ctx, req)
}

func (ai *ChatGPT) ChatStream(ctx context.Context, messages ...string) (ai.ChatStream, error) {
Expand Down Expand Up @@ -270,6 +274,6 @@ func (ai *ChatGPT) ChatSession() ai.ChatSession {
}

// Close marks the client as closed by dropping the embedded client
// reference; subsequent calls observe a nil Client and report
// ai.ErrAIClosed. Close itself never fails and always returns nil.
//
// NOTE(review): the scraped diff showed both the pre-change line
// (`ai.c = nil`) and the post-change line; only the post-change
// assignment to the embedded Client is kept. The receiver is renamed
// from `ai` to `chatgpt` to match the type's other methods and to stop
// shadowing the imported ai package.
func (chatgpt *ChatGPT) Close() error {
	chatgpt.Client = nil
	return nil
}
18 changes: 11 additions & 7 deletions gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const defaultModel = "gemini-1.0-pro"
var _ ai.AI = new(Gemini)

type Gemini struct {
c *genai.Client
*genai.Client
model *genai.GenerativeModel
config genai.GenerationConfig

Expand Down Expand Up @@ -64,13 +64,21 @@ func NewWithClient(client *genai.Client, model string) ai.AI {
if model == "" {
model = defaultModel
}
return &Gemini{c: client, model: client.GenerativeModel(model)}
return &Gemini{Client: client, model: client.GenerativeModel(model)}
}

// LLMs reports which LLM provider this client implements: ai.Gemini.
func (Gemini) LLMs() ai.LLMs {
return ai.Gemini
}

// Model queries the Gemini API for the metadata of the configured
// generative model and returns its name. Unlike the ChatGPT
// implementation this performs a network call, so it can fail; on
// failure the empty string is returned alongside the error.
func (gemini *Gemini) Model(ctx context.Context) (string, error) {
	var name string
	info, err := gemini.model.Info(ctx)
	if err == nil {
		name = info.Name
	}
	return name, err
}

// SetLimit installs a rate limiter (built via ai.NewLimiter) that
// throttles this client's API calls to the given rate.
func (gemini *Gemini) SetLimit(limit rate.Limit) {
gemini.limiter = ai.NewLimiter(limit)
}
Expand All @@ -83,7 +91,7 @@ func (ai *Gemini) wait(ctx context.Context) error {
}

// SetModel replaces the active generative model with one created for the
// given model name, re-applying the previously saved generation
// configuration so tuning settings survive the switch.
//
// NOTE(review): the scraped diff showed both the pre-change line
// (`ai.c.GenerativeModel`, referencing the removed `c` field) and the
// post-change line; only the post-change call through the embedded
// client is kept.
func (ai *Gemini) SetModel(model string) {
	ai.model = ai.GenerativeModel(model)
	ai.model.GenerationConfig = ai.config
}

Expand Down Expand Up @@ -208,7 +216,3 @@ func (session *ChatSession) History() (history []ai.Message) {
// ChatSession creates a new ai.ChatSession bound to this client and a
// fresh chat started on the configured model.
// NOTE(review): the ChatSession literal is positional — assumes the
// struct's fields are declared as (client, chat session) in that order;
// field declarations are outside this view, confirm before reordering.
func (ai *Gemini) ChatSession() ai.ChatSession {
return &ChatSession{ai, ai.model.StartChat()}
}

func (ai *Gemini) Close() error {
return ai.c.Close()
}

0 comments on commit 5244ba9

Please sign in to comment.