Skip to content

Commit

Permalink
Added URL attribution and improved message handling
Browse files Browse the repository at this point in the history
- Introduced a new UrlAttr struct to handle URL attributions.
- Enhanced the Handler function with additional checks for message metadata.
- Implemented a map to store URL attributions, reducing redundant API calls.
- Added getURLAttribution function to fetch URL attributions from the API.
- Refactored variable names for clarity and consistency.
  • Loading branch information
leokwsw committed Apr 5, 2024
1 parent 58989d7 commit c3e4720
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
5 changes: 5 additions & 0 deletions api/chatgpt/typings.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,8 @@ type FileInfo struct {
DownloadURL string `json:"download_url"`
Status string `json:"status"`
}

type UrlAttr struct {
Url string `json:"url"`
Attribution string `json:"attribution"`
}
56 changes: 50 additions & 6 deletions api/imitate/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/leokwsw/go-chatgpt-api/api"
"github.com/leokwsw/go-chatgpt-api/api/chatgpt"
"io"
"net/url"
"os"
"regexp"
"strconv"
Expand Down Expand Up @@ -440,6 +441,9 @@ func Handler(c *gin.Context, resp *http.Response, token string, uuid string, str
if !(originalResponse.Message.Author.Role == "assistant" || (originalResponse.Message.Author.Role == "tool" && originalResponse.Message.Content.ContentType != "text")) || originalResponse.Message.Content.Parts == nil {
continue
}
if originalResponse.Message.Metadata.MessageType == "" {
continue
}
if originalResponse.Message.Metadata.MessageType != "next" && originalResponse.Message.Metadata.MessageType != "continue" || !strings.HasSuffix(originalResponse.Message.Content.ContentType, "text") {
continue
}
Expand All @@ -459,9 +463,17 @@ func Handler(c *gin.Context, resp *http.Response, token string, uuid string, str
}
}
offset := 0
for i, citation := range originalResponse.Message.Metadata.Citations {
for _, citation := range originalResponse.Message.Metadata.Citations {
rl := len(r)
originalResponse.Message.Content.Parts[0] = string(r[:citation.StartIx+offset]) + "[^" + strconv.Itoa(i+1) + "^](" + citation.Metadata.URL + " \"" + citation.Metadata.Title + "\")" + string(r[citation.EndIx+offset:])
attr := urlAttrMap[citation.Metadata.URL]
if attr == "" {
u, _ := url.Parse(citation.Metadata.URL)
baseURL := u.Scheme + "://" + u.Host + "/"
attr = getURLAttribution(token, api.PUID, baseURL)
if attr != "" {
urlAttrMap[citation.Metadata.URL] = attr
}
}
r = []rune(originalResponse.Message.Content.Parts[0].(string))
offset += len(r) - rl
}
Expand All @@ -487,9 +499,9 @@ func Handler(c *gin.Context, resp *http.Response, token string, uuid string, str
if err != nil {
continue
}
url := apiUrl + strings.Split(dalleContent.AssetPointer, "//")[1] + "/download"
newUrl := apiUrl + strings.Split(dalleContent.AssetPointer, "//")[1] + "/download"
waitGroup.Add(1)
go GetImageSource(&waitGroup, url, dalleContent.Metadata.Dalle.Prompt, token, index, imgSource)
go GetImageSource(&waitGroup, newUrl, dalleContent.Metadata.Dalle.Prompt, token, index, imgSource)
}
waitGroup.Wait()
translatedResponse := NewChatCompletionChunk(strings.Join(imgSource, ""))
Expand Down Expand Up @@ -531,8 +543,8 @@ func Handler(c *gin.Context, resp *http.Response, token string, uuid string, str
}
if isEnd {
if stream {
final_line := StopChunk(finishReason)
c.Writer.WriteString("data: " + final_line.String() + "\n\n")
finalLine := StopChunk(finishReason)
c.Writer.WriteString("data: " + finalLine.String() + "\n\n")
}
break
}
Expand All @@ -546,3 +558,35 @@ func Handler(c *gin.Context, resp *http.Response, token string, uuid string, str
ParentID: originalResponse.Message.ID,
}
}

var urlAttrMap = make(map[string]string)

func getURLAttribution(token string, puid string, url string) string {
req, err := http.NewRequest(http.MethodPost, chatgpt.ApiPrefix+"/attributions", bytes.NewBuffer([]byte(`{"urls":["`+url+`"]}`)))
if err != nil {
return ""
}
if puid != "" {
req.Header.Set("Cookie", "_puid="+puid+";")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", api.UserAgent)
req.Header.Set("Oai-Language", api.Language)
if token != "" {
req.Header.Set("Authorization", api.GetAccessToken(token))
}
if err != nil {
return ""
}
resp, err := api.Client.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
var urlAttr chatgpt.UrlAttr
err = json.NewDecoder(resp.Body).Decode(&urlAttr)
if err != nil {
return ""
}
return urlAttr.Attribution
}

0 comments on commit c3e4720

Please sign in to comment.