diff --git a/command/openai/api.go b/command/openai/api.go index 61e3129a..f6f8bcda 100644 --- a/command/openai/api.go +++ b/command/openai/api.go @@ -21,7 +21,7 @@ const ( ) // we don't use our default clients.HttpClient as we need longer timeouts... -var client = http.Client{ +var httpClient = http.Client{ Timeout: 60 * time.Second, } @@ -34,7 +34,7 @@ func doRequest(cfg Config, apiEndpoint string, data []byte) (*http.Response, err req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+cfg.APIKey) - return client.Do(req) + return httpClient.Do(req) } // https://platform.openai.com/docs/api-reference/chat diff --git a/command/openai/command.go b/command/openai/command.go index 95f80295..7c32d4aa 100644 --- a/command/openai/command.go +++ b/command/openai/command.go @@ -85,7 +85,7 @@ func (c *openaiCommand) startConversation(message msg.Ref, text string) bool { var storageIdentifier string if message.GetThread() != "" { - // "openai" was triggerd within a existing thread. -> fetch the whole thread history as context + // "openai" was triggered within a existing thread. -> fetch the whole thread history as context threadMessages, err := c.SlackClient.GetThreadMessages(message) if err != nil { c.ReplyError(message, fmt.Errorf("can't load thread messages: %w", err)) @@ -236,9 +236,12 @@ func (c *openaiCommand) callAndStore(messages []ChatMessage, storageIdentifier s Role: roleAssistant, Content: responseText.String(), }) + + // truncate the history if needed to the last X messages if len(messages) > c.cfg.HistorySize { messages = messages[len(messages)-c.cfg.HistorySize:] } + err = storage.Write(storageKey, storageIdentifier, messages) if err != nil { log.Warnf("Error while storing openai history: %s", err) diff --git a/command/openai/dalle.go b/command/openai/dalle.go index 16916483..24eedecb 100644 --- a/command/openai/dalle.go +++ b/command/openai/dalle.go @@ -10,6 +10,7 @@ import ( "github.com/innogames/slack-bot/v2/bot/msg" "github.com/innogames/slack-bot/v2/bot/stats" "github.com/innogames/slack-bot/v2/bot/util" + "github.com/innogames/slack-bot/v2/client" log "github.com/sirupsen/logrus" "github.com/slack-go/slack" ) @@ -51,7 +52,7 @@ func (c *openaiCommand) sendImageInSlack(image DalleResponseImage, message msg.M if err != nil { return err } - resp, err := client.Do(req) + resp, err := httpClient.Do(req) if err != nil { return err } @@ -63,9 +64,20 @@ func (c *openaiCommand) sendImageInSlack(image DalleResponseImage, message msg.M Channels: []string{message.Channel}, ThreadTimestamp: message.Timestamp, Reader: resp.Body, - InitialComment: image.RevisedPrompt, + InitialComment: fmt.Sprintf("Dall-e prompt: %s", image.RevisedPrompt), }) + c.SlackClient.SendBlockMessage( + message, + []slack.Block{ + slack.NewActionBlock( + "", + client.GetInteractionButton("dalle", "Regenerate", fmt.Sprintf("dall-e %s", image.RevisedPrompt)), + ), + }, + slack.MsgOptionTS(message.Timestamp), + ) + return err } diff --git a/command/openai/dalle_test.go b/command/openai/dalle_test.go index 116b8d18..c0cbdb50 100644 --- a/command/openai/dalle_test.go +++ b/command/openai/dalle_test.go @@ -100,6 +100,12 @@ func TestDalle(t *testing.T) { mocks.AssertRemoveReaction(slackClient, ":coffee:", message) mocks.AssertReaction(slackClient, ":outbox_tray:", message) mocks.AssertRemoveReaction(slackClient, ":outbox_tray:", message) + mocks.AssertSlackBlocks( + t, + slackClient, + message, + `[{"type":"actions","elements":[{"type":"button","text":{"type":"plain_text","text":"Regenerate","emoji":true},"action_id":"dalle","value":"dall-e revised prompt 1234"}]}]`, + ) slackClient.On( "UploadFile", diff --git a/mocks/testing.go b/mocks/testing.go index 71f068c5..7ffb6659 100644 --- a/mocks/testing.go +++ b/mocks/testing.go @@ -148,7 +148,7 @@ func AssertSlackBlocks(t *testing.T, slackClient *SlackClient, message msg.Ref, } return expectedJSON == string(givenJSON) - })).Once().Return("") + }), mock.Anything).Once().Return("") } // AssertContainsSlackBlocks is a small test helper to check for certain slack.Block