From f459b4d310b8f4bbb485945221382efe170ec4fc Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Tue, 12 Mar 2024 17:55:18 +0800 Subject: [PATCH] Create prompt package (#10) --- chatgpt/chatgpt.go | 5 -- go.mod | 1 + go.sum | 2 + prompt/example.go | 20 ++++++ prompt/example_test.go | 19 ++++++ prompt/prompt.go | 140 +++++++++++++++++++++++++++++++++++++++++ prompt/prompt_test.go | 65 +++++++++++++++++++ 7 files changed, 247 insertions(+), 5 deletions(-) create mode 100644 prompt/example.go create mode 100644 prompt/example_test.go create mode 100644 prompt/prompt.go create mode 100644 prompt/prompt_test.go diff --git a/chatgpt/chatgpt.go b/chatgpt/chatgpt.go index 2fe588d..e8e8c21 100644 --- a/chatgpt/chatgpt.go +++ b/chatgpt/chatgpt.go @@ -143,11 +143,6 @@ func (stream *ChatStream) Next() (ai.ChatResponse, error) { return &ChatResponse[openai.ChatCompletionStreamResponse]{resp}, nil } -func (stream *ChatStream) Close() error { - stream.ChatCompletionStream.Close() - return nil -} - func (ai *ChatGPT) chatStream( ctx context.Context, history []openai.ChatCompletionMessage, diff --git a/go.mod b/go.mod index e45f5ba..6b80426 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22 require ( github.com/google/generative-ai-go v0.8.0 github.com/sashabaranov/go-openai v1.20.3 + github.com/sunshineplan/utils v0.1.63 golang.org/x/time v0.5.0 google.golang.org/api v0.169.0 ) diff --git a/go.sum b/go.sum index 1ed617c..f62f9f6 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/sunshineplan/utils v0.1.63 h1:QNcigCt9SDeDXpKPZ7w4JTcZ7t/3UepHQFFQuvi+IRo= +github.com/sunshineplan/utils v0.1.63/go.mod h1:7zhDUGgKo2FMFzs7j6IL7B/lh3BRuE7rb7R7IgGOAfc= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= diff --git a/prompt/example.go b/prompt/example.go new file mode 100644 index 0000000..2bea119 --- /dev/null +++ b/prompt/example.go @@ -0,0 +1,20 @@ +package prompt + +import "fmt" + +type Example struct { + Input []string + Output string + Prefix string +} + +func (ex Example) String() string { + switch len(ex.Input) { + case 0: + return "" + case 1: + return fmt.Sprintf("Input: %s\nOutput: %s", ex.Input[0], ex.Output) + default: + return fmt.Sprintf("Input:%s\nOutput: %s", printBatch(ex.Input, ex.Prefix, 0), ex.Output) + } +} diff --git a/prompt/example_test.go b/prompt/example_test.go new file mode 100644 index 0000000..11c8251 --- /dev/null +++ b/prompt/example_test.go @@ -0,0 +1,19 @@ +package prompt + +import "testing" + +func TestExample(t *testing.T) { + for i, tc := range []struct { + ex Example + output string + }{ + {Example{nil, "result", "%d|"}, ""}, + {Example{[]string{"abc"}, "result", "%d|"}, "Input: abc\nOutput: result"}, + {Example{[]string{"abc", "def", "ghi"}, "result", ""}, "Input:\"\"\"\nabc\ndef\nghi\n\"\"\"\nOutput: result"}, + {Example{[]string{"abc", "def", "ghi"}, "result", "%d|"}, "Input:\"\"\"\n1|abc\n2|def\n3|ghi\n\"\"\"\nOutput: result"}, + } { + if output := tc.ex.String(); output != tc.output { + t.Errorf("#%d: expected %q; got %q", i, tc.output, output) + } + } +} diff --git a/prompt/prompt.go b/prompt/prompt.go new file mode 100644 index 0000000..f2659ed --- /dev/null +++ b/prompt/prompt.go @@ -0,0 +1,140 @@ +package prompt + +import ( + "context" + "fmt" + "strings" + "text/template" + "time" + + "github.com/sunshineplan/ai" + "github.com/sunshineplan/utils/workers" +) + +const ( + defaultTimeout = time.Minute + defaultWorkers = 3 +) + +const defaultTemplate = `{{.Request}}{{with .Example}} +Example: +{{.}} +###{{end}}{{if .Input}} +Input:{{if gt (len .Input) 1}}{{printBatch .Input .Prefix .Start}}{{else}} {{index .Input 0}}{{end}}{{end}} +Output:` + +func printBatch(s []string, prefix string, start int) string { + var b strings.Builder + fmt.Fprintln(&b, `"""`) + for i, s := range s { + if prefix == "" { + fmt.Fprintln(&b, s) + } else { + fmt.Fprintf(&b, prefix+"%s\n", start+i+1, s) + } + } + fmt.Fprint(&b, `"""`) + return b.String() +} + +var defaultFuncMap = template.FuncMap{ + "printBatch": printBatch, +} + +type Prompt struct { + prompt string + t *template.Template + ex *Example + limit int + + d time.Duration + workers int +} + +func New(prompt string) *Prompt { + p := &Prompt{prompt: prompt, d: defaultTimeout, workers: defaultWorkers} + p.t = template.Must(template.New("prompt").Funcs(defaultFuncMap).Parse(defaultTemplate)) + return p +} + +func (prompt *Prompt) SetTemplate(t *template.Template) *Prompt { + prompt.t = t + return prompt +} + +func (prompt *Prompt) SetExample(ex Example) *Prompt { + prompt.ex = &ex + return prompt +} + +func (prompt *Prompt) SetLimit(limit int) *Prompt { + prompt.limit = limit + return prompt +} + +func (prompt *Prompt) SetAITimeout(d time.Duration) *Prompt { + prompt.d = d + return prompt +} + +func (prompt *Prompt) SetWorkers(n int) *Prompt { + prompt.workers = n + return prompt +} + +func (prompt *Prompt) execute(input []string, prefix string) (prompts []string, err error) { + length := len(input) + if length == 0 { + return + } + limit := prompt.limit + if limit == 0 { + limit = length + } + for n := 0; n < length; n = n + limit { + var s []string + if n+limit < length { + s = input[n : n+limit] + } else { + s = input[n:] + } + var b strings.Builder + if err = prompt.t.Execute(&b, struct { + Request string + Example *Example + Input []string + Prefix string + Start int + }{prompt.prompt, prompt.ex, s, prefix, n}); err != nil { + return nil, err + } + prompts = append(prompts, b.String()) + } + return +} + +type Result struct { + Index int + Prompt string + Result []string + Error error +} + +func (prompt *Prompt) Execute(ai ai.AI, input []string, prefix string) (<-chan Result, error) { + prompts, err := prompt.execute(input, prefix) + if err != nil { + return nil, err + } + c := make(chan Result, len(prompts)) + go workers.RunSlice(prompt.workers, prompts, func(i int, p string) { + ctx, cancel := context.WithTimeout(context.Background(), prompt.d) + defer cancel() + resp, err := ai.Chat(ctx, p) + if err != nil { + c <- Result{i, p, nil, err} + } else { + c <- Result{i, p, resp.Results(), nil} + } + }) + return c, nil +} diff --git a/prompt/prompt_test.go b/prompt/prompt_test.go new file mode 100644 index 0000000..5285b32 --- /dev/null +++ b/prompt/prompt_test.go @@ -0,0 +1,65 @@ +package prompt + +import ( + "reflect" + "testing" +) + +func TestPrompt(t *testing.T) { + for i, tc := range []struct { + prompt *Prompt + input []string + prefix string + prompts []string + }{ + { + New("no example single input"), + []string{"test"}, + "", + []string{"no example single input\nInput: test\nOutput:"}, + }, + { + New("has example single input").SetExample(Example{[]string{"abc", "def"}, "example", ""}), + []string{"test"}, + "", + []string{ + "has example single input\nExample:\nInput:\"\"\"\nabc\ndef\n\"\"\"\nOutput: example\n###\nInput: test\nOutput:", + }, + }, + { + New("has example with prefix single input").SetExample(Example{[]string{"abc", "def"}, "example", "%d|"}), + []string{"test"}, + "", + []string{ + "has example with prefix single input\nExample:\nInput:\"\"\"\n1|abc\n2|def\n\"\"\"\nOutput: example\n###\nInput: test\nOutput:", + }, + }, + { + New("no example multiple inputs"), + []string{"test1", "test2"}, + "", + []string{"no example multiple inputs\nInput:\"\"\"\ntest1\ntest2\n\"\"\"\nOutput:"}, + }, + { + New("no example multiple inputs with prefix"), + []string{"test1", "test2"}, + "%d|", + []string{"no example multiple inputs with prefix\nInput:\"\"\"\n1|test1\n2|test2\n\"\"\"\nOutput:"}, + }, + { + New("test limit").SetExample(Example{[]string{"abc", "def"}, "example", "%d|"}).SetLimit(2), + []string{"test1", "test2", "test3", "test4"}, + "%d|", + []string{ + "test limit\nExample:\nInput:\"\"\"\n1|abc\n2|def\n\"\"\"\nOutput: example\n###\nInput:\"\"\"\n1|test1\n2|test2\n\"\"\"\nOutput:", + "test limit\nExample:\nInput:\"\"\"\n1|abc\n2|def\n\"\"\"\nOutput: example\n###\nInput:\"\"\"\n3|test3\n4|test4\n\"\"\"\nOutput:", + }, + }, + } { + if prompts, err := tc.prompt.execute(tc.input, tc.prefix); err != nil { + t.Errorf("#%d: error: %s", i, err) + } else if !reflect.DeepEqual(prompts, tc.prompts) { + t.Errorf("#%d: expected %q; got %q", i, tc.prompts, prompts) + } + } +}