-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
be5270d
commit f459b4d
Showing
7 changed files
with
247 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package prompt | ||
|
||
import "fmt" | ||
|
||
type Example struct { | ||
Input []string | ||
Output string | ||
Prefix string | ||
} | ||
|
||
func (ex Example) String() string { | ||
switch len(ex.Input) { | ||
case 0: | ||
return "" | ||
case 1: | ||
return fmt.Sprintf("Input: %s\nOutput: %s", ex.Input[0], ex.Output) | ||
default: | ||
return fmt.Sprintf("Input:%s\nOutput: %s", printBatch(ex.Input, ex.Prefix, 0), ex.Output) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package prompt | ||
|
||
import "testing" | ||
|
||
func TestExample(t *testing.T) { | ||
for i, tc := range []struct { | ||
ex Example | ||
output string | ||
}{ | ||
{Example{nil, "result", "%d|"}, ""}, | ||
{Example{[]string{"abc"}, "result", "%d|"}, "Input: abc\nOutput: result"}, | ||
{Example{[]string{"abc", "def", "ghi"}, "result", ""}, "Input:\"\"\"\nabc\ndef\nghi\n\"\"\"\nOutput: result"}, | ||
{Example{[]string{"abc", "def", "ghi"}, "result", "%d|"}, "Input:\"\"\"\n1|abc\n2|def\n3|ghi\n\"\"\"\nOutput: result"}, | ||
} { | ||
if output := tc.ex.String(); output != tc.output { | ||
t.Errorf("#%d: expected %q; got %q", i, tc.output, output) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
package prompt | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"strings" | ||
"text/template" | ||
"time" | ||
|
||
"github.com/sunshineplan/ai" | ||
"github.com/sunshineplan/utils/workers" | ||
) | ||
|
||
const ( | ||
defaultTimeout = time.Minute | ||
defaultWorkers = 3 | ||
) | ||
|
||
const defaultTemplate = `{{.Request}}{{with .Example}} | ||
Example: | ||
{{.}} | ||
###{{end}}{{if .Input}} | ||
Input:{{if gt (len .Input) 1}}{{printBatch .Input .Prefix .Start}}{{else}} {{index .Input 0}}{{end}}{{end}} | ||
Output:` | ||
|
||
func printBatch(s []string, prefix string, start int) string { | ||
var b strings.Builder | ||
fmt.Fprintln(&b, `"""`) | ||
for i, s := range s { | ||
if prefix == "" { | ||
fmt.Fprintln(&b, s) | ||
} else { | ||
fmt.Fprintf(&b, prefix+"%s\n", start+i+1, s) | ||
} | ||
} | ||
fmt.Fprint(&b, `"""`) | ||
return b.String() | ||
} | ||
|
||
var defaultFuncMap = template.FuncMap{ | ||
"printBatch": printBatch, | ||
} | ||
|
||
type Prompt struct { | ||
prompt string | ||
t *template.Template | ||
ex *Example | ||
limit int | ||
|
||
d time.Duration | ||
workers int | ||
} | ||
|
||
func New(prompt string) *Prompt { | ||
p := &Prompt{prompt: prompt, d: defaultTimeout, workers: defaultWorkers} | ||
p.t = template.Must(template.New("prompt").Funcs(defaultFuncMap).Parse(defaultTemplate)) | ||
return p | ||
} | ||
|
||
func (prompt *Prompt) SetTemplate(t *template.Template) *Prompt { | ||
prompt.t = t | ||
return prompt | ||
} | ||
|
||
func (prompt *Prompt) SetExample(ex Example) *Prompt { | ||
prompt.ex = &ex | ||
return prompt | ||
} | ||
|
||
func (prompt *Prompt) SetLimit(limit int) *Prompt { | ||
prompt.limit = limit | ||
return prompt | ||
} | ||
|
||
func (prompt *Prompt) SetAITimeout(d time.Duration) *Prompt { | ||
prompt.d = d | ||
return prompt | ||
} | ||
|
||
func (prompt *Prompt) SetWorkers(n int) *Prompt { | ||
prompt.workers = n | ||
return prompt | ||
} | ||
|
||
func (prompt *Prompt) execute(input []string, prefix string) (prompts []string, err error) { | ||
length := len(input) | ||
if length == 0 { | ||
return | ||
} | ||
limit := prompt.limit | ||
if limit == 0 { | ||
limit = length | ||
} | ||
for n := 0; n < length; n = n + limit { | ||
var s []string | ||
if n+limit < length { | ||
s = input[n : n+limit] | ||
} else { | ||
s = input[n:] | ||
} | ||
var b strings.Builder | ||
if err = prompt.t.Execute(&b, struct { | ||
Request string | ||
Example *Example | ||
Input []string | ||
Prefix string | ||
Start int | ||
}{prompt.prompt, prompt.ex, s, prefix, n}); err != nil { | ||
return nil, err | ||
} | ||
prompts = append(prompts, b.String()) | ||
} | ||
return | ||
} | ||
|
||
type Result struct { | ||
Index int | ||
Prompt string | ||
Result []string | ||
Error error | ||
} | ||
|
||
func (prompt *Prompt) Execute(ai ai.AI, input []string, prefix string) (<-chan Result, error) { | ||
prompts, err := prompt.execute(input, prefix) | ||
if err != nil { | ||
return nil, err | ||
} | ||
c := make(chan Result, len(prompts)) | ||
go workers.RunSlice(prompt.workers, prompts, func(i int, p string) { | ||
ctx, cancel := context.WithTimeout(context.Background(), prompt.d) | ||
defer cancel() | ||
resp, err := ai.Chat(ctx, p) | ||
if err != nil { | ||
c <- Result{i, p, nil, err} | ||
} else { | ||
c <- Result{i, p, resp.Results(), nil} | ||
} | ||
}) | ||
return c, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
package prompt | ||
|
||
import ( | ||
"reflect" | ||
"testing" | ||
) | ||
|
||
func TestPrompt(t *testing.T) { | ||
for i, tc := range []struct { | ||
prompt *Prompt | ||
input []string | ||
prefix string | ||
prompts []string | ||
}{ | ||
{ | ||
New("no example single input"), | ||
[]string{"test"}, | ||
"", | ||
[]string{"no example single input\nInput: test\nOutput:"}, | ||
}, | ||
{ | ||
New("has example single input").SetExample(Example{[]string{"abc", "def"}, "example", ""}), | ||
[]string{"test"}, | ||
"", | ||
[]string{ | ||
"has example single input\nExample:\nInput:\"\"\"\nabc\ndef\n\"\"\"\nOutput: example\n###\nInput: test\nOutput:", | ||
}, | ||
}, | ||
{ | ||
New("has example with prefix single input").SetExample(Example{[]string{"abc", "def"}, "example", "%d|"}), | ||
[]string{"test"}, | ||
"", | ||
[]string{ | ||
"has example with prefix single input\nExample:\nInput:\"\"\"\n1|abc\n2|def\n\"\"\"\nOutput: example\n###\nInput: test\nOutput:", | ||
}, | ||
}, | ||
{ | ||
New("no example multiple inputs"), | ||
[]string{"test1", "test2"}, | ||
"", | ||
[]string{"no example multiple inputs\nInput:\"\"\"\ntest1\ntest2\n\"\"\"\nOutput:"}, | ||
}, | ||
{ | ||
New("no example multiple inputs with prefix"), | ||
[]string{"test1", "test2"}, | ||
"%d|", | ||
[]string{"no example multiple inputs with prefix\nInput:\"\"\"\n1|test1\n2|test2\n\"\"\"\nOutput:"}, | ||
}, | ||
{ | ||
New("test limit").SetExample(Example{[]string{"abc", "def"}, "example", "%d|"}).SetLimit(2), | ||
[]string{"test1", "test2", "test3", "test4"}, | ||
"%d|", | ||
[]string{ | ||
"test limit\nExample:\nInput:\"\"\"\n1|abc\n2|def\n\"\"\"\nOutput: example\n###\nInput:\"\"\"\n1|test1\n2|test2\n\"\"\"\nOutput:", | ||
"test limit\nExample:\nInput:\"\"\"\n1|abc\n2|def\n\"\"\"\nOutput: example\n###\nInput:\"\"\"\n3|test3\n4|test4\n\"\"\"\nOutput:", | ||
}, | ||
}, | ||
} { | ||
if prompts, err := tc.prompt.execute(tc.input, tc.prefix); err != nil { | ||
t.Errorf("#%d: error: %s", i, err) | ||
} else if !reflect.DeepEqual(prompts, tc.prompts) { | ||
t.Errorf("#%d: expected %q; got %q", i, tc.prompts, prompts) | ||
} | ||
} | ||
} |