Skip to content

Commit

Permalink
Merge pull request #498 from innogames/dalle
Browse files Browse the repository at this point in the history
WIP: DALL-E  integration
  • Loading branch information
brainexe authored Nov 20, 2023
2 parents cfad085 + f32f5d2 commit d093bbc
Show file tree
Hide file tree
Showing 9 changed files with 285 additions and 25 deletions.
73 changes: 70 additions & 3 deletions command/openai/api.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package openai

import "github.com/pkg/errors"
import (
"bytes"
"net/http"

"github.com/pkg/errors"
)

const (
apiHost = "https://api.openai.com"
apiCompletionURL = "/v1/chat/completions"
apiHost = "https://api.openai.com"
apiCompletionURL = "/v1/chat/completions"
apiDalleGenerateImageURL = "/v1/images/generations"
)

const (
Expand All @@ -13,6 +19,18 @@ const (
roleAssistant = "assistant"
)

func doRequest(cfg Config, apiEndpoint string, data []byte) (*http.Response, error) {
req, err := http.NewRequest("POST", cfg.APIHost+apiEndpoint, bytes.NewBuffer(data))
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)

return client.Do(req)
}

// https://platform.openai.com/docs/api-reference/chat
type ChatRequest struct {
Model string `json:"model"`
Expand Down Expand Up @@ -72,3 +90,52 @@ type ChatChoice struct {
FinishReason string `json:"finish_reason"`
Delta ChatMessage `json:"delta"`
}

/*
{
"model": "dall-e-3",
"prompt": "a white siamese cat",
"n": 1,
"size": "1024x1024"
}
*/
type DalleRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
}

/*
{
"created": 1700233554,
"data": [
{
"url": "https://XXXX"
}
]
}
or:
{
"error": {
"code": "invalid_size",
"message": "The size is not supported by this model.",
"param": null,
"type": "invalid_request_error"
}
}
*/
type DalleResponse struct {
Data []DalleResponseImage `json:"data"`
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}

type DalleResponseImage struct {
URL string `json:"url"`
RevisedPrompt string `json:"revised_prompt"`
}
28 changes: 9 additions & 19 deletions command/openai/client.go → command/openai/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package openai

import (
"bufio"
"bytes"
"encoding/json"
"io"
"net/http"
Expand All @@ -15,30 +14,21 @@ import (
var client http.Client

func CallChatGPT(cfg Config, inputMessages []ChatMessage, stream bool) (<-chan string, error) {
jsonData, _ := json.Marshal(ChatRequest{
Model: cfg.Model,
Temperature: cfg.Temperature,
Seed: cfg.Seed,
MaxTokens: cfg.MaxTokens,
Stream: stream,
Messages: inputMessages,
})

req, err := http.NewRequest("POST", cfg.APIHost+apiCompletionURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)

messageUpdates := make(chan string, 2)

// return a chan of all message updates here and listen here in the background in the event stream
go func() {
defer close(messageUpdates)

resp, err := client.Do(req)
jsonData, _ := json.Marshal(ChatRequest{
Model: cfg.Model,
Temperature: cfg.Temperature,
Seed: cfg.Seed,
MaxTokens: cfg.MaxTokens,
Stream: stream,
Messages: inputMessages,
})
resp, err := doRequest(cfg, apiCompletionURL, jsonData)
if err != nil {
messageUpdates <- err.Error()
return
Expand Down
2 changes: 2 additions & 0 deletions command/openai/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func (c *chatGPTCommand) GetMatcher() matcher.Matcher {
matchers := []matcher.Matcher{
matcher.NewPrefixMatcher("openai", c.newConversation),
matcher.NewPrefixMatcher("chatgpt", c.newConversation),
matcher.NewPrefixMatcher("dalle", c.dalleGenerateImage),
matcher.NewPrefixMatcher("generate image", c.dalleGenerateImage),
matcher.WildcardMatcher(c.reply),
}

Expand Down
10 changes: 10 additions & 0 deletions command/openai/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ type Config struct {

// log all input+output text to the logger. This could include personal information, therefore disabled by default!
LogTexts bool `mapstructure:"log_texts"`

// Dall-E image generation
DalleModel string `mapstructure:"dalle_model"`
DalleImageSize string `mapstructure:"dalle_image_size"`
DalleNumberOfImages int `mapstructure:"dalle_number_of_images"`
}

// IsEnabled checks if token is set
Expand All @@ -40,6 +45,11 @@ var defaultConfig = Config{
UpdateInterval: time.Second,
HistorySize: 15,
InitialSystemMessage: "You are a helpful Slack bot. By default, keep your answer short and truthful",

// default dall-e config
DalleModel: "dall-e-3",
DalleImageSize: "1024x1024",
DalleNumberOfImages: 1,
}

func loadConfig(config *config.Config) Config {
Expand Down
69 changes: 69 additions & 0 deletions command/openai/dalle.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package openai

import (
"encoding/json"
"fmt"
"time"

"github.com/innogames/slack-bot/v2/bot/matcher"
"github.com/innogames/slack-bot/v2/bot/msg"
"github.com/innogames/slack-bot/v2/bot/util"
log "github.com/sirupsen/logrus"
)

// bot function to generate images with Dall-E
func (c *chatGPTCommand) dalleGenerateImage(match matcher.Result, message msg.Message) {
prompt := match.GetString(util.FullMatch)

go func() {
c.AddReaction(":coffee:", message)
defer c.RemoveReaction(":coffee:", message)

images, err := generateImage(c.cfg, prompt)
if err != nil {
c.ReplyError(message, err)
return
}

text := ""
for _, image := range images {
text += fmt.Sprintf(
" - %s: <%s|open image>\n",
image.RevisedPrompt,
image.URL,
)
}
c.SendMessage(message, text)
}()
}

func generateImage(cfg Config, prompt string) ([]DalleResponseImage, error) {
jsonData, _ := json.Marshal(DalleRequest{
Model: cfg.DalleModel,
Size: cfg.DalleImageSize,
N: cfg.DalleNumberOfImages,
Prompt: prompt,
})

start := time.Now()
resp, err := doRequest(cfg, apiDalleGenerateImageURL, jsonData)
if err != nil {
return nil, err
}
defer resp.Body.Close()

var response DalleResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return nil, err
}

if response.Error.Message != "" {
return nil, fmt.Errorf(response.Error.Message)
}

log.WithField("model", cfg.DalleModel).
Infof("Dall-E image generation took %s", time.Since(start))

return response.Data, nil
}
106 changes: 106 additions & 0 deletions command/openai/dalle_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package openai

import (
"net/http"
"testing"
"time"

"github.com/innogames/slack-bot/v2/bot"
"github.com/innogames/slack-bot/v2/bot/config"
"github.com/innogames/slack-bot/v2/bot/msg"
"github.com/innogames/slack-bot/v2/bot/storage"
"github.com/innogames/slack-bot/v2/mocks"
"github.com/stretchr/testify/assert"
)

func TestDalle(t *testing.T) {
// init memory based storage
storage.InitStorage("")

slackClient := &mocks.SlackClient{}
base := bot.BaseCommand{SlackClient: slackClient}

t.Run("test http error", func(t *testing.T) {
ts := startTestServer(
t,
apiDalleGenerateImageURL,
[]testRequest{
{
`{"model":"dall-e-3","prompt":"a nice cat","n":1,"size":"1024x1024"}`,
`{
"error": {
"code": "invalid_api_key",
"message": "Incorrect API key provided: sk-1234**************************************567.",
"type": "invalid_request_error"
}
}`,
http.StatusUnauthorized,
},
},
)
openaiCfg := defaultConfig
openaiCfg.APIHost = ts.URL
openaiCfg.APIKey = "0815pass"

cfg := &config.Config{}
cfg.Set("openai", openaiCfg)

defer ts.Close()

commands := GetCommands(base, cfg)

message := msg.Message{}
message.Text = "dalle a nice cat"

mocks.AssertReaction(slackClient, ":coffee:", message)
mocks.AssertRemoveReaction(slackClient, ":coffee:", message)
mocks.AssertError(slackClient, message, "Incorrect API key provided: sk-1234**************************************567.")

actual := commands.Run(message)
time.Sleep(time.Millisecond * 100)
assert.True(t, actual)
})

t.Run("test generate image", func(t *testing.T) {
ts := startTestServer(
t,
apiDalleGenerateImageURL,
[]testRequest{
{
`{"model":"dall-e-3","prompt":"a nice cat","n":1,"size":"1024x1024"}`,
` {
"created": 1700233554,
"data": [
{
"url": "https://example.com/image123",
"revised_prompt": "revised prompt 1234"
}
]
}`,
http.StatusUnauthorized,
},
},
)
openaiCfg := defaultConfig
openaiCfg.APIHost = ts.URL
openaiCfg.APIKey = "0815pass"

cfg := &config.Config{}
cfg.Set("openai", openaiCfg)

defer ts.Close()

commands := GetCommands(base, cfg)

message := msg.Message{}
message.Text = "dalle a nice cat"

mocks.AssertReaction(slackClient, ":coffee:", message)
mocks.AssertRemoveReaction(slackClient, ":coffee:", message)
mocks.AssertSlackMessage(slackClient, message, " - revised prompt 1234: <https://example.com/image123|open image>\n")

actual := commands.Run(message)
time.Sleep(time.Millisecond * 100)
assert.True(t, actual)
})
}
Loading

0 comments on commit d093bbc

Please sign in to comment.