Skip to content

Commit

Permalink
Create prompt package (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshineplan authored Mar 12, 2024
1 parent be5270d commit f459b4d
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 5 deletions.
5 changes: 0 additions & 5 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,6 @@ func (stream *ChatStream) Next() (ai.ChatResponse, error) {
return &ChatResponse[openai.ChatCompletionStreamResponse]{resp}, nil
}

func (stream *ChatStream) Close() error {
stream.ChatCompletionStream.Close()
return nil
}

func (ai *ChatGPT) chatStream(
ctx context.Context,
history []openai.ChatCompletionMessage,
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.22
require (
github.com/google/generative-ai-go v0.8.0
github.com/sashabaranov/go-openai v1.20.3
github.com/sunshineplan/utils v0.1.63
golang.org/x/time v0.5.0
google.golang.org/api v0.169.0
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/sunshineplan/utils v0.1.63 h1:QNcigCt9SDeDXpKPZ7w4JTcZ7t/3UepHQFFQuvi+IRo=
github.com/sunshineplan/utils v0.1.63/go.mod h1:7zhDUGgKo2FMFzs7j6IL7B/lh3BRuE7rb7R7IgGOAfc=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
Expand Down
20 changes: 20 additions & 0 deletions prompt/example.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package prompt

import "fmt"

type Example struct {
Input []string
Output string
Prefix string
}

func (ex Example) String() string {
switch len(ex.Input) {
case 0:
return ""
case 1:
return fmt.Sprintf("Input: %s\nOutput: %s", ex.Input[0], ex.Output)
default:
return fmt.Sprintf("Input:%s\nOutput: %s", printBatch(ex.Input, ex.Prefix, 0), ex.Output)
}
}
19 changes: 19 additions & 0 deletions prompt/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package prompt

import "testing"

func TestExample(t *testing.T) {
for i, tc := range []struct {
ex Example
output string
}{
{Example{nil, "result", "%d|"}, ""},
{Example{[]string{"abc"}, "result", "%d|"}, "Input: abc\nOutput: result"},
{Example{[]string{"abc", "def", "ghi"}, "result", ""}, "Input:\"\"\"\nabc\ndef\nghi\n\"\"\"\nOutput: result"},
{Example{[]string{"abc", "def", "ghi"}, "result", "%d|"}, "Input:\"\"\"\n1|abc\n2|def\n3|ghi\n\"\"\"\nOutput: result"},
} {
if output := tc.ex.String(); output != tc.output {
t.Errorf("#%d: expected %q; got %q", i, tc.output, output)
}
}
}
140 changes: 140 additions & 0 deletions prompt/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package prompt

import (
"context"
"fmt"
"strings"
"text/template"
"time"

"github.com/sunshineplan/ai"
"github.com/sunshineplan/utils/workers"
)

const (
defaultTimeout = time.Minute
defaultWorkers = 3
)

const defaultTemplate = `{{.Request}}{{with .Example}}
Example:
{{.}}
###{{end}}{{if .Input}}
Input:{{if gt (len .Input) 1}}{{printBatch .Input .Prefix .Start}}{{else}} {{index .Input 0}}{{end}}{{end}}
Output:`

func printBatch(s []string, prefix string, start int) string {
var b strings.Builder
fmt.Fprintln(&b, `"""`)
for i, s := range s {
if prefix == "" {
fmt.Fprintln(&b, s)
} else {
fmt.Fprintf(&b, prefix+"%s\n", start+i+1, s)
}
}
fmt.Fprint(&b, `"""`)
return b.String()
}

var defaultFuncMap = template.FuncMap{
"printBatch": printBatch,
}

type Prompt struct {
prompt string
t *template.Template
ex *Example
limit int

d time.Duration
workers int
}

func New(prompt string) *Prompt {
p := &Prompt{prompt: prompt, d: defaultTimeout, workers: defaultWorkers}
p.t = template.Must(template.New("prompt").Funcs(defaultFuncMap).Parse(defaultTemplate))
return p
}

func (prompt *Prompt) SetTemplate(t *template.Template) *Prompt {
prompt.t = t
return prompt
}

func (prompt *Prompt) SetExample(ex Example) *Prompt {
prompt.ex = &ex
return prompt
}

func (prompt *Prompt) SetLimit(limit int) *Prompt {
prompt.limit = limit
return prompt
}

func (prompt *Prompt) SetAITimeout(d time.Duration) *Prompt {
prompt.d = d
return prompt
}

func (prompt *Prompt) SetWorkers(n int) *Prompt {
prompt.workers = n
return prompt
}

func (prompt *Prompt) execute(input []string, prefix string) (prompts []string, err error) {
length := len(input)
if length == 0 {
return
}
limit := prompt.limit
if limit == 0 {
limit = length
}
for n := 0; n < length; n = n + limit {
var s []string
if n+limit < length {
s = input[n : n+limit]
} else {
s = input[n:]
}
var b strings.Builder
if err = prompt.t.Execute(&b, struct {
Request string
Example *Example
Input []string
Prefix string
Start int
}{prompt.prompt, prompt.ex, s, prefix, n}); err != nil {
return nil, err
}
prompts = append(prompts, b.String())
}
return
}

type Result struct {
Index int
Prompt string
Result []string
Error error
}

func (prompt *Prompt) Execute(ai ai.AI, input []string, prefix string) (<-chan Result, error) {
prompts, err := prompt.execute(input, prefix)
if err != nil {
return nil, err
}
c := make(chan Result, len(prompts))
go workers.RunSlice(prompt.workers, prompts, func(i int, p string) {
ctx, cancel := context.WithTimeout(context.Background(), prompt.d)
defer cancel()
resp, err := ai.Chat(ctx, p)
if err != nil {
c <- Result{i, p, nil, err}
} else {
c <- Result{i, p, resp.Results(), nil}
}
})
return c, nil
}
65 changes: 65 additions & 0 deletions prompt/prompt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package prompt

import (
"reflect"
"testing"
)

func TestPrompt(t *testing.T) {
for i, tc := range []struct {
prompt *Prompt
input []string
prefix string
prompts []string
}{
{
New("no example single input"),
[]string{"test"},
"",
[]string{"no example single input\nInput: test\nOutput:"},
},
{
New("has example single input").SetExample(Example{[]string{"abc", "def"}, "example", ""}),
[]string{"test"},
"",
[]string{
"has example single input\nExample:\nInput:\"\"\"\nabc\ndef\n\"\"\"\nOutput: example\n###\nInput: test\nOutput:",
},
},
{
New("has example with prefix single input").SetExample(Example{[]string{"abc", "def"}, "example", "%d|"}),
[]string{"test"},
"",
[]string{
"has example with prefix single input\nExample:\nInput:\"\"\"\n1|abc\n2|def\n\"\"\"\nOutput: example\n###\nInput: test\nOutput:",
},
},
{
New("no example multiple inputs"),
[]string{"test1", "test2"},
"",
[]string{"no example multiple inputs\nInput:\"\"\"\ntest1\ntest2\n\"\"\"\nOutput:"},
},
{
New("no example multiple inputs with prefix"),
[]string{"test1", "test2"},
"%d|",
[]string{"no example multiple inputs with prefix\nInput:\"\"\"\n1|test1\n2|test2\n\"\"\"\nOutput:"},
},
{
New("test limit").SetExample(Example{[]string{"abc", "def"}, "example", "%d|"}).SetLimit(2),
[]string{"test1", "test2", "test3", "test4"},
"%d|",
[]string{
"test limit\nExample:\nInput:\"\"\"\n1|abc\n2|def\n\"\"\"\nOutput: example\n###\nInput:\"\"\"\n1|test1\n2|test2\n\"\"\"\nOutput:",
"test limit\nExample:\nInput:\"\"\"\n1|abc\n2|def\n\"\"\"\nOutput: example\n###\nInput:\"\"\"\n3|test3\n4|test4\n\"\"\"\nOutput:",
},
},
} {
if prompts, err := tc.prompt.execute(tc.input, tc.prefix); err != nil {
t.Errorf("#%d: error: %s", i, err)
} else if !reflect.DeepEqual(prompts, tc.prompts) {
t.Errorf("#%d: expected %q; got %q", i, tc.prompts, prompts)
}
}
}

0 comments on commit f459b4d

Please sign in to comment.