Skip to content

Commit

Permalink
Merge pull request #22 from otiai10/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
otiai10 authored Jun 26, 2023
2 parents 4d20a5b + f6307b5 commit c9efe56
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 39 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ request := openaigo.ChatRequest{
}
```

if you want **shorthand**, use [`functioncall`](https://pkg.go.dev/github.com/otiai10/openaigo@v1.4.0/functioncall).

```go
import fc "github.com/otiai10/openaigo/functioncall"

request.Functions = fc.Funcs{
"get_weather": {GetWeather, "Get weather of the location", fc.Params{
{"location", "string", "location of the weather", true},
{"date", "string", "ISO 8601 date string", true},
}},
}
```

See [test app](https://github.com/otiai10/openaigo/blob/main/testapp/main.go) as a working example.

# Need `stream`?
Expand Down
31 changes: 23 additions & 8 deletions chat.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package openaigo

import "encoding/json"

// ChatCompletionRequestBody:
// https://platform.openai.com/docs/guides/chat/chat-completions-beta
// https://platform.openai.com/docs/api-reference/chat
Expand Down Expand Up @@ -80,12 +82,22 @@ type ChatCompletionRequestBody struct {
User string `json:"user,omitempty"`

// Functions: A list of functions which GPT is allowed to request to call.
Functions []Function `json:"functions,omitempty"`
// Functions []Function `json:"functions,omitempty"`
Functions json.Marshaler `json:"functions,omitempty"`

// FunctionCall: You ain't need it. Default is "auto".
FunctionCall string `json:"function_call,omitempty"`
}

type Functions []Function

func (funcs Functions) MarshalJSON() ([]byte, error) {
if len(funcs) == 0 {
return []byte("[]"), nil
}
return json.Marshal([]Function(funcs))
}

type Function struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Expand Down Expand Up @@ -129,17 +141,20 @@ type Message struct {
}

type FunctionCall struct {
Name string `json:"name,omitempty"`
NameRaw string `json:"name,omitempty"`
ArgumentsRaw string `json:"arguments,omitempty"`
// Arguments map[string]any `json:"arguments,omitempty"`
}

// func Arg[T any](fc FunctionCall, name string) (res T) {
// if fc.Arguments == nil || fc.Arguments[name] == nil {
// return
// }
// return fc.Arguments[name].(T)
// }
func (fc *FunctionCall) Name() string {
return fc.NameRaw
}

func (fc *FunctionCall) Args() map[string]any {
var args map[string]any
json.Unmarshal([]byte(fc.ArgumentsRaw), &args)
return args
}

type ChatCompletionResponse struct {
ID string `json:"id"`
Expand Down
2 changes: 1 addition & 1 deletion chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestClient_ChatCompletion_FunctionCall(t *testing.T) {
Role: "user", Content: "Hello, I'm John.",
},
},
Functions: []Function{
Functions: Functions{
{
Name: "test_method",
Parameters: Parameters{
Expand Down
59 changes: 59 additions & 0 deletions functioncall/all_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package functioncall

import (
"encoding/json"
"testing"

. "github.com/otiai10/mint"
)

func TestFunctions(t *testing.T) {
funcs := Funcs{}
Expect(t, funcs).TypeOf("functioncall.Funcs")
}

func TestFunctions_MarshalJSON(t *testing.T) {
repeat := func(word string, count int) (r string) {
for i := 0; i < count; i++ {
r += word
}
return r
}
funcs := Funcs{
"repeat": Func{repeat, "Repeat given string N times", Params{
{"word", "string", "String to be repeated", true},
{"count", "number", "How many times to repeat", true},
}},
}
b, err := funcs.MarshalJSON()
Expect(t, err).ToBe(nil)

v := []map[string]any{}
err = json.Unmarshal(b, &v)
Expect(t, err).ToBe(nil)

Expect(t, v).Query("0.name").ToBe("repeat")
Expect(t, v).Query("0.description").ToBe("Repeat given string N times")
Expect(t, v).Query("0.parameters.type").ToBe("object")
Expect(t, v).Query("0.parameters.properties.word.type").ToBe("string")
Expect(t, v).Query("0.parameters.required.1").ToBe("count")
}

func TestAs(t *testing.T) {
repeat := func(word string, count int) (r string) {
for i := 0; i < count; i++ {
r += word
}
return r
}
funcs := Funcs{
"repeat": Func{repeat, "Repeat given string N times", Params{
{"word", "string", "String to be repeated", true},
{"count", "number", "How many times to repeat", true},
}},
}
a := As[[]map[string]any](funcs)
Expect(t, a).TypeOf("[]map[string]interface {}")
Expect(t, a).Query("0.name").ToBe("repeat")
Expect(t, a).Query("0.parameters.type").ToBe("object")
}
66 changes: 66 additions & 0 deletions functioncall/functioncall.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package functioncall

import (
"encoding/json"
)

type Funcs map[string]Func

type Func struct {
Value any `json:"-"`
Description string `json:"description,omitempty"`
Parameters Params `json:"parameters,omitempty"`
}

type Params []Param

type Param struct {
Name string `json:"-"`
Type string `json:"type,omitempty"`
Description string `json:"description,omitempty"`
Required bool `json:"-"`
// Enum []any `json:"enum,omitempty"`
}

func (funcs Funcs) MarshalJSON() ([]byte, error) {
// Convert map to slice
sl := []map[string]any{}
for key, fun := range funcs {
f := map[string]any{
"name": key,
"description": fun.Description,
"parameters": fun.Parameters,
}
sl = append(sl, f)
}
return json.Marshal(sl)
}

func (params Params) MarshalJSON() ([]byte, error) {
required := []string{}
props := map[string]Param{}
for _, p := range params {
if p.Required {
required = append(required, p.Name)
}
props[p.Name] = p
}
schema := map[string]any{
"type": "object",
"properties": props,
"required": required,
}
return json.Marshal(schema)
}

func As[T any](funcs Funcs) (dest T) {
b, err := funcs.MarshalJSON()
if err != nil {
panic(err)
}
err = json.Unmarshal(b, &dest)
if err != nil {
panic(err)
}
return dest
}
52 changes: 52 additions & 0 deletions functioncall/invoke.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package functioncall

import (
"encoding/json"
"fmt"
"reflect"
)

type Invocation interface {
Name() string
Args() map[string]any
}

func (funcs Funcs) Call(invocation Invocation) string {
b, err := json.Marshal(funcs.Invoke(invocation))
if err != nil {
return err.Error()
}
return string(b)
}

func (funcs Funcs) Invoke(invocation Invocation) any {
f, ok := funcs[invocation.Name()]
if !ok {
return fmt.Sprintf("function not found: %s", invocation.Name())
}
v := reflect.ValueOf(f.Value)
if !v.IsValid() || v.IsZero() {
return fmt.Sprintf("function is invalid: %s", invocation.Name())
}
if v.Kind() != reflect.Func {
return fmt.Sprintf("function is not a function: %s", invocation.Name())
}
if v.Type().NumIn() != len(invocation.Args()) {
return fmt.Sprintf("function argument length mismatch: %s", invocation.Name())
}
// Call the function with given arguments by using `reflect` package
args := invocation.Args()
params := []reflect.Value{}
for i, p := range f.Parameters {
if arg, ok := args[p.Name]; ok {
params = append(params, reflect.ValueOf(arg))
} else {
params = append(params, reflect.Zero(v.Type().In(i)))
}
}
rets := []any{}
for _, r := range v.Call(params) {
rets = append(rets, r.Interface())
}
return rets
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ module github.com/otiai10/openaigo

go 1.18

require github.com/otiai10/mint v1.4.1
require github.com/otiai10/mint v1.6.1
5 changes: 2 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
github.com/otiai10/curr v1.0.0 h1:TJIWdbX0B+kpNagQrjgq8bCMrbhiuX73M2XwgtDMoOI=
github.com/otiai10/mint v1.4.1 h1:HOVBfKP1oXIc0wWo9hZ8JLdZtyCPWqjvmFDuVZ0yv2Y=
github.com/otiai10/mint v1.4.1/go.mod h1:gifjb2MYOoULtKLqUAEILUG/9KONW6f7YsJ6vQLTlFI=
github.com/otiai10/mint v1.6.1 h1:kgbTJmOpp/0ce7hk3H8jiSuR0MXmpwWRfqUdKww17qg=
github.com/otiai10/mint v1.6.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM=
76 changes: 50 additions & 26 deletions testapp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/otiai10/openaigo"
fc "github.com/otiai10/openaigo/functioncall"
)

type Scenario struct {
Expand All @@ -20,6 +21,15 @@ const (
SKIP = "\033[0;33m====> SKIP\033[0m\n\n"
)

func GetWeather(location string, date float64) (string, error) {
return "sunny", nil
}

func GetDate() int {
now := time.Now()
return now.Year()*10000 + int(now.Month())*100 + now.Day()
}

var (
OPENAI_API_KEY string

Expand Down Expand Up @@ -141,37 +151,51 @@ var (
Name: "function_call",
Run: func() (any, error) {
conversation := []openaigo.Message{
{Role: "user", Content: "What's the weather in Tokyo today?"},
{Role: "user", Content: "Should I bring an umbrella tomorrow? I'm living around Tokyo."},
}
funcs := fc.Funcs{
"GetDate": fc.Func{GetDate, "A function to get date today", fc.Params{}},
"GetWeather": fc.Func{GetWeather, "A function to get weather information", fc.Params{
{"location", "string", "location of the wather", true},
{"date", "integer", "date MMDD as number", true},
}},
}
client := openaigo.NewClient(OPENAI_API_KEY)
request := openaigo.ChatRequest{
Model: openaigo.GPT3_5Turbo_0613,
Messages: conversation,
Functions: []openaigo.Function{
{
Name: "get_weather",
Description: "A function to get weather information",
Parameters: openaigo.Parameters{
Type: "object",
Properties: map[string]map[string]any{
"location": {"type": "string"},
"date": {"type": "string", "description": "ISO 8601 date string"},
},
Required: []string{"location"},
},
},
},
Model: openaigo.GPT3_5Turbo_0613,
Messages: conversation,
Functions: funcs,
}
res_1, err := client.Chat(nil, request)
if err != nil {
return nil, err
}
conversation = append(conversation, res_1.Choices[0].Message)
if res_1.Choices[0].Message.FunctionCall != nil {
fmt.Printf("%+v\n", res_1.Choices[0].Message.FunctionCall)
conversation = append(conversation, openaigo.Message{
Role: "function",
Name: res_1.Choices[0].Message.FunctionCall.Name(),
Content: funcs.Call(res_1.Choices[0].Message.FunctionCall),
})
}
request.Messages = conversation
res_2, err := client.Chat(nil, request)
if err != nil {
return nil, err
}
conversation = append(conversation, res_2.Choices[0].Message)
if res_2.Choices[0].Message.FunctionCall != nil {
fmt.Printf("%+v\n", res_2.Choices[0].Message.FunctionCall)
conversation = append(conversation, openaigo.Message{
Role: "function",
Name: res_2.Choices[0].Message.FunctionCall.Name(),
Content: funcs.Call(res_2.Choices[0].Message.FunctionCall),
})
}
res0, err := client.Chat(nil, request)
conversation = append(conversation, res0.Choices[0].Message)
conversation = append(conversation, openaigo.Message{
Role: "function",
Name: "get_weather",
Content: "20%:thunderstorm,70%:sandstorm,10%:snowy",
})
request.Messages = conversation
res, err := client.Chat(nil, request)
return res, err
res_3, err := client.Chat(nil, request)
return res_3, err
},
},
}
Expand Down

0 comments on commit c9efe56

Please sign in to comment.