diff --git a/prompt/prompt.go b/prompt/prompt.go index d6eff03..aeca83e 100644 --- a/prompt/prompt.go +++ b/prompt/prompt.go @@ -84,7 +84,7 @@ func (prompt *Prompt) SetWorkers(n int) *Prompt { return prompt } -func (prompt *Prompt) execute(input []string, prefix string) (prompts []string, err error) { +func (prompt *Prompt) Prompts(input []string, prefix string) (prompts []string, err error) { length := len(input) if length == 0 { return @@ -123,7 +123,7 @@ type Result struct { } func (prompt *Prompt) Execute(ai ai.AI, input []string, prefix string) (<-chan Result, int, error) { - prompts, err := prompt.execute(input, prefix) + prompts, err := prompt.Prompts(input, prefix) if err != nil { return nil, 0, err } @@ -131,7 +131,13 @@ func (prompt *Prompt) Execute(ai ai.AI, input []string, prefix string) (<-chan R c := make(chan Result, n) go func() { workers.RunSlice(prompt.workers, prompts, func(i int, p string) { - ctx, cancel := context.WithTimeout(context.Background(), prompt.d) + var ctx context.Context + var cancel context.CancelFunc + if prompt.d != 0 { + ctx, cancel = context.WithTimeout(context.Background(), prompt.d) + } else { + ctx, cancel = context.WithCancel(context.Background()) + } defer cancel() resp, err := ai.Chat(ctx, p) if err != nil { diff --git a/prompt/prompt_test.go b/prompt/prompt_test.go index 2f53a5d..77ea36a 100644 --- a/prompt/prompt_test.go +++ b/prompt/prompt_test.go @@ -62,7 +62,7 @@ func TestPrompt(t *testing.T) { }, }, } { - if prompts, err := tc.prompt.execute(tc.input, tc.prefix); err != nil { + if prompts, err := tc.prompt.Prompts(tc.input, tc.prefix); err != nil { t.Errorf("#%d: error: %s", i, err) } else if !reflect.DeepEqual(prompts, tc.prompts) { t.Errorf("#%d: expected %q; got %q", i, tc.prompts, prompts)