Config and proxy #27

Merged · 4 commits · Mar 29, 2024
Changes from all commits
6 changes: 5 additions & 1 deletion .github/workflows/test.yml
@@ -43,7 +43,11 @@ jobs:
- name: Test Code
env:
CHATGPT_API_KEY: ${{ secrets.CHATGPT_API_KEY }}
CHATGPT_BASE_URL: ${{ secrets.CHATGPT_BASE_URL }}
CHATGPT_PROXY: ${{ secrets.CHATGPT_PROXY }}
CHATGPT_ENDPOINT: ${{ secrets.CHATGPT_ENDPOINT }}
CHATGPT_MODEL: ${{ secrets.CHATGPT_MODEL }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
GEMINI_PROXY: ${{ secrets.GEMINI_PROXY }}
GEMINI_ENDPOINT: ${{ secrets.GEMINI_ENDPOINT }}
GEMINI_MODEL: ${{ secrets.GEMINI_MODEL }}
run: go test -v
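
Each secret above is surfaced to `go test` as an environment variable, and the tests in ai_test.go below return early when a provider's key is absent. A minimal sketch of that guard under the new names — the `requireEnv` helper is illustrative only, not part of the repository:

```go
package ai_test

import (
	"os"
	"testing"
)

// requireEnv returns the named variable, or skips the test when it is
// unset. The real tests below inline this check and return instead of
// skipping.
func requireEnv(t *testing.T, name string) string {
	t.Helper()
	v := os.Getenv(name)
	if v == "" {
		t.Skipf("%s not set", name)
	}
	return v
}

func TestEnvWiring(t *testing.T) {
	_ = requireEnv(t, "CHATGPT_API_KEY")
	_ = requireEnv(t, "CHATGPT_ENDPOINT") // replaces the old CHATGPT_BASE_URL secret
}
```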
11 changes: 0 additions & 11 deletions ai.go
@@ -3,8 +3,6 @@ package ai
import (
"context"
"errors"
"net/http"
"net/url"
)

var ErrAIClosed = errors.New("AI client is nil or already closed")
@@ -53,12 +51,3 @@ type ChatStream interface {
type ChatResponse interface {
Results() []string
}

func SetProxy(proxy string) error {
u, err := url.Parse(proxy)
if err != nil {
return err
}
http.DefaultTransport.(*http.Transport).Proxy = http.ProxyURL(u)
return nil
}
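
Removing SetProxy is the point of this PR: it mutated http.DefaultTransport, so a proxy set for one provider leaked into every HTTP client in the process. The proxy now travels with the individual client instead. A usage sketch with placeholder values:

```go
package main

import (
	"log"

	"github.com/sunshineplan/ai"
	"github.com/sunshineplan/ai/chatgpt"
)

func main() {
	// The proxy is scoped to this client; http.DefaultTransport is no
	// longer modified. Key and proxy URL are placeholders.
	client, err := chatgpt.New(
		ai.WithAPIKey("example-key"),
		ai.WithProxy("http://127.0.0.1:8080"),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()
}
```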
27 changes: 19 additions & 8 deletions ai_test.go
@@ -13,12 +13,6 @@ import (
"github.com/sunshineplan/ai/gemini"
)

func init() {
if proxy := os.Getenv("AI_PROXY"); proxy != "" {
ai.SetProxy(proxy)
}
}

func testChat(ai ai.AI, prompt string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
@@ -97,11 +91,18 @@ func TestGemini(t *testing.T) {
if apiKey == "" {
return
}
gemini, err := gemini.NewWithEndpoint(apiKey, os.Getenv("GEMINI_ENDPOINT"))
gemini, err := gemini.New(
ai.WithAPIKey(apiKey),
ai.WithEndpoint(os.Getenv("GEMINI_ENDPOINT")),
ai.WithProxy(os.Getenv("GEMINI_PROXY")),
)
if err != nil {
t.Fatal(err)
}
defer gemini.Close()
if model := os.Getenv("GEMINI_MODEL"); model != "" {
gemini.SetModel(model)
}
if err := testChat(gemini, "Who are you?"); err != nil {
t.Error(err)
}
@@ -118,8 +119,18 @@ func TestChatGPT(t *testing.T) {
if apiKey == "" {
return
}
chatgpt := chatgpt.NewWithBaseURL(apiKey, os.Getenv("CHATGPT_BASE_URL"))
chatgpt, err := chatgpt.New(
ai.WithAPIKey(apiKey),
ai.WithEndpoint(os.Getenv("CHATGPT_ENDPOINT")),
ai.WithProxy(os.Getenv("CHATGPT_PROXY")),
)
if err != nil {
t.Fatal(err)
}
defer chatgpt.Close()
if model := os.Getenv("CHATGPT_MODEL"); model != "" {
chatgpt.SetModel(model)
}
if err := testChat(chatgpt, "Who are you?"); err != nil {
t.Error(err)
}
43 changes: 32 additions & 11 deletions chatgpt/chatgpt.go
@@ -3,6 +3,8 @@ package chatgpt
import (
"context"
"io"
"net/http"
"net/url"

"github.com/sunshineplan/ai"

@@ -25,23 +27,42 @@ type ChatGPT struct {
limiter *rate.Limiter
}

func New(authToken string) ai.AI {
return NewWithBaseURL(authToken, "")
}

func NewWithBaseURL(authToken, baseURL string) ai.AI {
cfg := openai.DefaultConfig(authToken)
if baseURL != "" {
cfg.BaseURL = baseURL
func New(opts ...ai.ClientOption) (ai.AI, error) {
cfg := new(ai.ClientConfig)
for _, i := range opts {
i.Apply(cfg)
}
config := openai.DefaultConfig(cfg.APIKey)
if cfg.Endpoint != "" {
config.BaseURL = cfg.Endpoint
}
if cfg.Proxy != "" {
u, err := url.Parse(cfg.Proxy)
if err != nil {
return nil, err
}
if t, ok := http.DefaultTransport.(*http.Transport); ok {
t = t.Clone()
t.Proxy = http.ProxyURL(u)
config.HTTPClient = &http.Client{Transport: t}
}
}
return NewWithClient(openai.NewClientWithConfig(cfg))
c := NewWithClient(openai.NewClientWithConfig(config), cfg.Model)
if cfg.Limit != nil {
c.SetLimit(*cfg.Limit)
}
ai.ApplyModelConfig(c, cfg.ModelConfig)
return c, nil
}

func NewWithClient(client *openai.Client) ai.AI {
func NewWithClient(client *openai.Client, model string) ai.AI {
if client == nil {
panic("cannot create AI from nil client")
}
return &ChatGPT{c: client, model: defaultModel}
if model == "" {
model = defaultModel
}
return &ChatGPT{c: client, model: model}
}

func (ChatGPT) LLMs() ai.LLMs {
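
The proxy branch above clones the default transport instead of mutating it, which is what makes the setting client-local. Note that when http.DefaultTransport has been replaced by something other than an *http.Transport, the type assertion fails and the proxy is silently skipped. The same pattern in isolation:

```go
package main

import (
	"fmt"
	"net/http"
	"net/url"
)

// proxiedClient builds an *http.Client that routes through proxyURL,
// leaving http.DefaultTransport untouched (same pattern as New above).
func proxiedClient(proxyURL string) (*http.Client, error) {
	u, err := url.Parse(proxyURL)
	if err != nil {
		return nil, err
	}
	t, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		// Mirrors the diff: nothing to clone, so no proxy is applied.
		return http.DefaultClient, nil
	}
	t = t.Clone() // work on a copy
	t.Proxy = http.ProxyURL(u)
	return &http.Client{Transport: t}, nil
}

func main() {
	c, err := proxiedClient("http://127.0.0.1:8080")
	fmt.Println(c != nil, err)
}
```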
46 changes: 16 additions & 30 deletions client/client.go
@@ -8,40 +8,26 @@ import (
"github.com/sunshineplan/ai/gemini"
)

func New(cfg ai.Config) (client ai.AI, err error) {
switch cfg.LLMs {
case "":
func New(cfg ai.ClientConfig) (client ai.AI, err error) {
if cfg.LLMs == "" {
return nil, errors.New("empty AI")
case ai.ChatGPT:
client = chatgpt.NewWithBaseURL(cfg.APIKey, cfg.BaseURL)
case ai.Gemini:
client, err = gemini.NewWithEndpoint(cfg.APIKey, cfg.BaseURL)
if err != nil {
return
}
default:
return nil, errors.New("unknown LLMs: " + string(cfg.LLMs))
}
if cfg.Model != "" {
client.SetModel(cfg.Model)
}
if cfg.Proxy != "" {
ai.SetProxy(cfg.Proxy)
}
if cfg.Count != nil {
client.SetCount(*cfg.Count)
}
if cfg.MaxTokens != nil {
client.SetMaxTokens(*cfg.MaxTokens)
}
if cfg.Temperature != nil {
client.SetTemperature(*cfg.Temperature)
}
if cfg.TopP != nil {
client.SetTopP(*cfg.TopP)
opts := []ai.ClientOption{
ai.WithAPIKey(cfg.APIKey),
ai.WithEndpoint(cfg.Endpoint),
ai.WithProxy(cfg.Proxy),
ai.WithModelConfig(cfg.ModelConfig),
}
if cfg.Limit != nil {
client.SetLimit(*cfg.Limit)
opts = append(opts, ai.WithLimit(*cfg.Limit))
}
switch cfg.LLMs {
case ai.ChatGPT:
client, err = chatgpt.New(opts...)
case ai.Gemini:
client, err = gemini.New(opts...)
default:
err = errors.New("unknown LLMs: " + string(cfg.LLMs))
}
return
}
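
With the per-provider switch now only selecting a constructor, a caller fills in ai.ClientConfig once and the options fan out from there. A sketch with placeholder values:

```go
package main

import (
	"log"

	"github.com/sunshineplan/ai"
	"github.com/sunshineplan/ai/client"
)

func main() {
	temperature := float32(0.2)
	c, err := client.New(ai.ClientConfig{
		LLMs:        ai.ChatGPT,
		APIKey:      "example-key",           // placeholder
		Proxy:       "http://127.0.0.1:8080", // optional; empty means a direct connection
		ModelConfig: ai.ModelConfig{Temperature: &temperature},
	})
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()
}
```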
63 changes: 56 additions & 7 deletions config.go
@@ -2,18 +2,67 @@ package ai

import "golang.org/x/time/rate"

type Config struct {
LLMs LLMs
APIKey string
type ClientConfig struct {
LLMs LLMs

BaseURL string
Model string
Proxy string
APIKey string
Endpoint string
Proxy string

Limit *rate.Limit

Model string
ModelConfig ModelConfig
}

type ModelConfig struct {
Count *int32
MaxTokens *int32
Temperature *float32
TopP *float32
}

Limit *rate.Limit
func ApplyModelConfig(ai AI, cfg ModelConfig) {
if cfg.Count != nil {
ai.SetCount(*cfg.Count)
}
if cfg.MaxTokens != nil {
ai.SetMaxTokens(*cfg.MaxTokens)
}
if cfg.Temperature != nil {
ai.SetTemperature(*cfg.Temperature)
}
if cfg.TopP != nil {
ai.SetTopP(*cfg.TopP)
}
}

type ClientOption interface {
Apply(*ClientConfig)
}

func WithAPIKey(apiKey string) ClientOption { return withAPIKey(apiKey) }
func WithEndpoint(endpoint string) ClientOption { return withEndpoint(endpoint) }
func WithProxy(proxy string) ClientOption { return withProxy(proxy) }
func WithLimit(limit rate.Limit) ClientOption { return withLimit(limit) }
func WithModelConfig(config ModelConfig) ClientOption { return withModel(config) }

type withAPIKey string

func (w withAPIKey) Apply(cfg *ClientConfig) { cfg.APIKey = string(w) }

type withEndpoint string

func (w withEndpoint) Apply(cfg *ClientConfig) { cfg.Endpoint = string(w) }

type withProxy string

func (w withProxy) Apply(cfg *ClientConfig) { cfg.Proxy = string(w) }

type withLimit rate.Limit

func (w withLimit) Apply(cfg *ClientConfig) { cfg.Limit = (*rate.Limit)(&w) }

type withModel ModelConfig

func (w withModel) Apply(cfg *ClientConfig) { cfg.ModelConfig = ModelConfig(w) }
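
Because ClientOption is a single-method interface rather than a func type, callers can supply their own options next to the With* helpers above. A hypothetical example that only sets an endpoint when no earlier option did:

```go
package main

import (
	"fmt"

	"github.com/sunshineplan/ai"
)

// withFallbackEndpoint is a hypothetical caller-defined option; it is
// not part of this package.
type withFallbackEndpoint string

func (w withFallbackEndpoint) Apply(cfg *ai.ClientConfig) {
	if cfg.Endpoint == "" {
		cfg.Endpoint = string(w)
	}
}

func main() {
	cfg := new(ai.ClientConfig)
	// Options are applied in order, exactly as the constructors do.
	for _, opt := range []ai.ClientOption{
		ai.WithAPIKey("example-key"),
		withFallbackEndpoint("https://example.invalid/v1"),
	} {
		opt.Apply(cfg)
	}
	fmt.Println(cfg.Endpoint) // https://example.invalid/v1
}
```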
17 changes: 17 additions & 0 deletions gemini/apikey.go
@@ -0,0 +1,17 @@
package gemini

import "net/http"

var _ http.RoundTripper = new(apikey)

type apikey struct {
key string
rt http.RoundTripper
}

func (t *apikey) RoundTrip(req *http.Request) (*http.Response, error) {
args := req.URL.Query()
args.Set("key", t.key)
req.URL.RawQuery = args.Encode()
return t.rt.RoundTrip(req)
}
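
This transport exists because the Gemini constructor below passes a custom HTTP client for proxying, and my reading of the diff is that option.WithHTTPClient bypasses the credential wiring option.WithAPIKey would otherwise set up — that library behavior is an assumption here. The wrapper re-attaches the key as a query parameter. A self-contained demonstration against a throwaway server:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// apikey mirrors the unexported transport above: set ?key=... on every
// request, then delegate to the wrapped RoundTripper.
type apikey struct {
	key string
	rt  http.RoundTripper
}

func (t *apikey) RoundTrip(req *http.Request) (*http.Response, error) {
	args := req.URL.Query()
	args.Set("key", t.key)
	req.URL.RawQuery = args.Encode()
	return t.rt.RoundTrip(req)
}

func main() {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, r.URL.Query().Get("key")) // echo the received key
	}))
	defer srv.Close()

	c := &http.Client{Transport: &apikey{"example-key", http.DefaultTransport}}
	resp, err := c.Get(srv.URL)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body)) // example-key
}
```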
45 changes: 33 additions & 12 deletions gemini/gemini.go
@@ -4,6 +4,8 @@ import (
"context"
"errors"
"io"
"net/http"
"net/url"
"strings"

"github.com/sunshineplan/ai"
@@ -26,24 +28,43 @@ type Gemini struct {
limiter *rate.Limiter
}

func New(apiKey string) (ai.AI, error) {
return NewWithEndpoint(apiKey, "")
}

func NewWithEndpoint(apiKey, endpoint string) (ai.AI, error) {
opts := []option.ClientOption{option.WithAPIKey(apiKey)}
if endpoint != "" {
opts = append(opts, option.WithEndpoint(endpoint))
func New(opts ...ai.ClientOption) (ai.AI, error) {
cfg := new(ai.ClientConfig)
for _, i := range opts {
i.Apply(cfg)
}
o := []option.ClientOption{option.WithAPIKey(cfg.APIKey)}
if cfg.Proxy != "" {
u, err := url.Parse(cfg.Proxy)
if err != nil {
return nil, err
}
if t, ok := http.DefaultTransport.(*http.Transport); ok {
t = t.Clone()
t.Proxy = http.ProxyURL(u)
o = append(o, option.WithHTTPClient(&http.Client{Transport: &apikey{cfg.APIKey, t}}))
}
}
if cfg.Endpoint != "" {
o = append(o, option.WithEndpoint(cfg.Endpoint))
}
client, err := genai.NewClient(context.Background(), opts...)
client, err := genai.NewClient(context.Background(), o...)
if err != nil {
return nil, err
}
return NewWithClient(client), nil
c := NewWithClient(client, cfg.Model)
if cfg.Limit != nil {
c.SetLimit(*cfg.Limit)
}
ai.ApplyModelConfig(c, cfg.ModelConfig)
return c, nil
}

func NewWithClient(client *genai.Client) ai.AI {
return &Gemini{c: client, model: client.GenerativeModel(defaultModel)}
func NewWithClient(client *genai.Client, model string) ai.AI {
if model == "" {
model = defaultModel
}
return &Gemini{c: client, model: client.GenerativeModel(model)}
}

func (Gemini) LLMs() ai.LLMs {
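
End to end, the Gemini constructor now takes the same options as the ChatGPT one. A usage sketch with placeholder values, including the rate-limit option from config.go:

```go
package main

import (
	"log"

	"golang.org/x/time/rate"

	"github.com/sunshineplan/ai"
	"github.com/sunshineplan/ai/gemini"
)

func main() {
	client, err := gemini.New(
		ai.WithAPIKey("example-key"),          // placeholder
		ai.WithProxy("http://127.0.0.1:8080"), // placeholder; routed through the apikey transport
		ai.WithLimit(rate.Limit(1)),           // one request per second, fed to the internal limiter
	)
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()
}
```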