Skip to content

Commit

Permalink
New config and Set up proxy separately (#27)
Browse files Browse the repository at this point in the history
* New AI client by config

* New config and Set up proxy separately

* Fix
  • Loading branch information
sunshineplan authored Mar 29, 2024
1 parent 49dbbcd commit 299d1a1
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 80 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ jobs:
- name: Test Code
env:
CHATGPT_API_KEY: ${{ secrets.CHATGPT_API_KEY }}
CHATGPT_BASE_URL: ${{ secrets.CHATGPT_BASE_URL }}
CHATGPT_PROXY: ${{ secrets.CHATGPT_PROXY }}
CHATGPT_ENDPOINT: ${{ secrets.CHATGPT_ENDPOINT }}
CHATGPT_MODEL: ${{ secrets.CHATGPT_MODEL }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
GEMINI_PROXY: ${{ secrets.GEMINI_PROXY }}
GEMINI_ENDPOINT: ${{ secrets.GEMINI_ENDPOINT }}
GEMINI_MODEL: ${{ secrets.GEMINI_MODEL }}
run: go test -v
11 changes: 0 additions & 11 deletions ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package ai
import (
"context"
"errors"
"net/http"
"net/url"
)

var ErrAIClosed = errors.New("AI client is nil or already closed")
Expand Down Expand Up @@ -53,12 +51,3 @@ type ChatStream interface {
type ChatResponse interface {
Results() []string
}

func SetProxy(proxy string) error {
u, err := url.Parse(proxy)
if err != nil {
return err
}
http.DefaultTransport.(*http.Transport).Proxy = http.ProxyURL(u)
return nil
}
27 changes: 19 additions & 8 deletions ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ import (
"github.com/sunshineplan/ai/gemini"
)

func init() {
if proxy := os.Getenv("AI_PROXY"); proxy != "" {
ai.SetProxy(proxy)
}
}

func testChat(ai ai.AI, prompt string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
Expand Down Expand Up @@ -97,11 +91,18 @@ func TestGemini(t *testing.T) {
if apiKey == "" {
return
}
gemini, err := gemini.NewWithEndpoint(apiKey, os.Getenv("GEMINI_ENDPOINT"))
gemini, err := gemini.New(
ai.WithAPIKey(apiKey),
ai.WithEndpoint(os.Getenv("GEMINI_ENDPOINT")),
ai.WithProxy(os.Getenv("GEMINI_PROXY")),
)
if err != nil {
t.Fatal(err)
}
defer gemini.Close()
if model := os.Getenv("GEMINI_MODEL"); model != "" {
gemini.SetModel(model)
}
if err := testChat(gemini, "Who are you?"); err != nil {
t.Error(err)
}
Expand All @@ -118,8 +119,18 @@ func TestChatGPT(t *testing.T) {
if apiKey == "" {
return
}
chatgpt := chatgpt.NewWithBaseURL(apiKey, os.Getenv("CHATGPT_BASE_URL"))
chatgpt, err := chatgpt.New(
ai.WithAPIKey(apiKey),
ai.WithEndpoint(os.Getenv("CHATGPT_ENDPOINT")),
ai.WithProxy(os.Getenv("CHATGPT_PROXY")),
)
if err != nil {
t.Fatal(err)
}
defer chatgpt.Close()
if model := os.Getenv("CHATGPT_MODEL"); model != "" {
chatgpt.SetModel(model)
}
if err := testChat(chatgpt, "Who are you?"); err != nil {
t.Error(err)
}
Expand Down
43 changes: 32 additions & 11 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package chatgpt
import (
"context"
"io"
"net/http"
"net/url"

"github.com/sunshineplan/ai"

Expand All @@ -25,23 +27,42 @@ type ChatGPT struct {
limiter *rate.Limiter
}

func New(authToken string) ai.AI {
return NewWithBaseURL(authToken, "")
}

func NewWithBaseURL(authToken, baseURL string) ai.AI {
cfg := openai.DefaultConfig(authToken)
if baseURL != "" {
cfg.BaseURL = baseURL
func New(opts ...ai.ClientOption) (ai.AI, error) {
cfg := new(ai.ClientConfig)
for _, i := range opts {
i.Apply(cfg)
}
config := openai.DefaultConfig(cfg.APIKey)
if cfg.Endpoint != "" {
config.BaseURL = cfg.Endpoint
}
if cfg.Proxy != "" {
u, err := url.Parse(cfg.Proxy)
if err != nil {
return nil, err
}
if t, ok := http.DefaultTransport.(*http.Transport); ok {
t = t.Clone()
t.Proxy = http.ProxyURL(u)
config.HTTPClient = &http.Client{Transport: t}
}
}
return NewWithClient(openai.NewClientWithConfig(cfg))
c := NewWithClient(openai.NewClientWithConfig(config), cfg.Model)
if cfg.Limit != nil {
c.SetLimit(*cfg.Limit)
}
ai.ApplyModelConfig(c, cfg.ModelConfig)
return c, nil
}

func NewWithClient(client *openai.Client) ai.AI {
func NewWithClient(client *openai.Client, model string) ai.AI {
if client == nil {
panic("cannot create AI from nil client")
}
return &ChatGPT{c: client, model: defaultModel}
if model == "" {
model = defaultModel
}
return &ChatGPT{c: client, model: model}
}

func (ChatGPT) LLMs() ai.LLMs {
Expand Down
46 changes: 16 additions & 30 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,26 @@ import (
"github.com/sunshineplan/ai/gemini"
)

func New(cfg ai.Config) (client ai.AI, err error) {
switch cfg.LLMs {
case "":
func New(cfg ai.ClientConfig) (client ai.AI, err error) {
if cfg.LLMs == "" {
return nil, errors.New("empty AI")
case ai.ChatGPT:
client = chatgpt.NewWithBaseURL(cfg.APIKey, cfg.BaseURL)
case ai.Gemini:
client, err = gemini.NewWithEndpoint(cfg.APIKey, cfg.BaseURL)
if err != nil {
return
}
default:
return nil, errors.New("unknown LLMs: " + string(cfg.LLMs))
}
if cfg.Model != "" {
client.SetModel(cfg.Model)
}
if cfg.Proxy != "" {
ai.SetProxy(cfg.Proxy)
}
if cfg.Count != nil {
client.SetCount(*cfg.Count)
}
if cfg.MaxTokens != nil {
client.SetMaxTokens(*cfg.MaxTokens)
}
if cfg.Temperature != nil {
client.SetTemperature(*cfg.Temperature)
}
if cfg.TopP != nil {
client.SetTopP(*cfg.TopP)
opts := []ai.ClientOption{
ai.WithAPIKey(cfg.APIKey),
ai.WithEndpoint(cfg.Endpoint),
ai.WithProxy(cfg.Proxy),
ai.WithModelConfig(cfg.ModelConfig),
}
if cfg.Limit != nil {
client.SetLimit(*cfg.Limit)
opts = append(opts, ai.WithLimit(*cfg.Limit))
}
switch cfg.LLMs {
case ai.ChatGPT:
client, err = chatgpt.New(opts...)
case ai.Gemini:
client, err = gemini.New(opts...)
default:
err = errors.New("unknown LLMs: " + string(cfg.LLMs))
}
return
}
63 changes: 56 additions & 7 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,67 @@ package ai

import "golang.org/x/time/rate"

type Config struct {
LLMs LLMs
APIKey string
type ClientConfig struct {
LLMs LLMs

BaseURL string
Model string
Proxy string
APIKey string
Endpoint string
Proxy string

Limit *rate.Limit

Model string
ModelConfig ModelConfig
}

type ModelConfig struct {
Count *int32
MaxTokens *int32
Temperature *float32
TopP *float32
}

Limit *rate.Limit
func ApplyModelConfig(ai AI, cfg ModelConfig) {
if cfg.Count != nil {
ai.SetCount(*cfg.Count)
}
if cfg.MaxTokens != nil {
ai.SetMaxTokens(*cfg.MaxTokens)
}
if cfg.Temperature != nil {
ai.SetTemperature(*cfg.Temperature)
}
if cfg.TopP != nil {
ai.SetTopP(*cfg.TopP)
}
}

type ClientOption interface {
Apply(*ClientConfig)
}

func WithAPIKey(apiKey string) ClientOption { return withAPIKey(apiKey) }
func WithEndpoint(endpoint string) ClientOption { return withEndpoint(endpoint) }
func WithProxy(proxy string) ClientOption { return withProxy(proxy) }
func WithLimit(limit rate.Limit) ClientOption { return withLimit(limit) }
func WithModelConfig(config ModelConfig) ClientOption { return withModel(config) }

type withAPIKey string

func (w withAPIKey) Apply(cfg *ClientConfig) { cfg.APIKey = string(w) }

type withEndpoint string

func (w withEndpoint) Apply(cfg *ClientConfig) { cfg.Endpoint = string(w) }

type withProxy string

func (w withProxy) Apply(cfg *ClientConfig) { cfg.Proxy = string(w) }

type withLimit rate.Limit

func (w withLimit) Apply(cfg *ClientConfig) { cfg.Limit = (*rate.Limit)(&w) }

type withModel ModelConfig

func (w withModel) Apply(cfg *ClientConfig) { cfg.ModelConfig = ModelConfig(w) }
17 changes: 17 additions & 0 deletions gemini/apikey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package gemini

import "net/http"

var _ http.RoundTripper = new(apikey)

type apikey struct {
key string
rt http.RoundTripper
}

func (t *apikey) RoundTrip(req *http.Request) (*http.Response, error) {
args := req.URL.Query()
args.Set("key", t.key)
req.URL.RawQuery = args.Encode()
return t.rt.RoundTrip(req)
}
45 changes: 33 additions & 12 deletions gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"errors"
"io"
"net/http"
"net/url"
"strings"

"github.com/sunshineplan/ai"
Expand All @@ -26,24 +28,43 @@ type Gemini struct {
limiter *rate.Limiter
}

func New(apiKey string) (ai.AI, error) {
return NewWithEndpoint(apiKey, "")
}

func NewWithEndpoint(apiKey, endpoint string) (ai.AI, error) {
opts := []option.ClientOption{option.WithAPIKey(apiKey)}
if endpoint != "" {
opts = append(opts, option.WithEndpoint(endpoint))
func New(opts ...ai.ClientOption) (ai.AI, error) {
cfg := new(ai.ClientConfig)
for _, i := range opts {
i.Apply(cfg)
}
o := []option.ClientOption{option.WithAPIKey(cfg.APIKey)}
if cfg.Proxy != "" {
u, err := url.Parse(cfg.Proxy)
if err != nil {
return nil, err
}
if t, ok := http.DefaultTransport.(*http.Transport); ok {
t = t.Clone()
t.Proxy = http.ProxyURL(u)
o = append(o, option.WithHTTPClient(&http.Client{Transport: &apikey{cfg.APIKey, t}}))
}
}
if cfg.Endpoint != "" {
o = append(o, option.WithEndpoint(cfg.Endpoint))
}
client, err := genai.NewClient(context.Background(), opts...)
client, err := genai.NewClient(context.Background(), o...)
if err != nil {
return nil, err
}
return NewWithClient(client), nil
c := NewWithClient(client, cfg.Model)
if cfg.Limit != nil {
c.SetLimit(*cfg.Limit)
}
ai.ApplyModelConfig(c, cfg.ModelConfig)
return c, nil
}

func NewWithClient(client *genai.Client) ai.AI {
return &Gemini{c: client, model: client.GenerativeModel(defaultModel)}
func NewWithClient(client *genai.Client, model string) ai.AI {
if model == "" {
model = defaultModel
}
return &Gemini{c: client, model: client.GenerativeModel(model)}
}

func (Gemini) LLMs() ai.LLMs {
Expand Down

0 comments on commit 299d1a1

Please sign in to comment.