-
Notifications
You must be signed in to change notification settings - Fork 3
/
embd.go
33 lines (30 loc) · 1007 Bytes
/
embd.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
package gemini
import (
"github.com/Limit-LAB/go-gemini/models"
)
// EmbedContent requests an embedding for a single piece of content, built
// from parts, using the given embedding model.
//
// The request is sent via c.post to the "embedContent" method URL that c.url
// builds for the model, and the Embedding field of the decoded
// EmbeddingContentResponse is returned. On error the zero EmbeddingValue is
// returned instead of anything the decoder may have partially filled in.
func (c *Client) EmbedContent(model models.EmbeddingModel, parts []models.Part) (models.EmbeddingValue, error) {
	endpoint := c.url(string(model), "embedContent")
	req := models.EmbeddingContentRequest{
		// The request body names the model with a "models/" prefix, while
		// c.url above is given the bare model name.
		Model:   "models/" + string(model),
		Content: models.Content{Parts: parts},
	}
	rst, err := unjson[models.EmbeddingContentResponse](c.post(endpoint, req))
	if err != nil {
		var zero models.EmbeddingValue
		return zero, err
	}
	return rst.Embedding, nil
}
// BatchEmbedContent requests embeddings for several pieces of content in a
// single call using the given embedding model.
//
// Each element of parts becomes one EmbeddingContentRequest, in the same
// order, inside a BatchEmbeddingContentsRequest. That request is sent via
// c.post to the "batchEmbedContents" method URL that c.url builds for the
// model. The Embeddings field of the decoded response is returned. These are
// presumably one per request, in request order; confirm against the API.
// On error a nil slice is returned instead of any partially decoded result.
func (c *Client) BatchEmbedContent(model models.EmbeddingModel, parts [][]models.Part) ([]models.EmbeddingValue, error) {
	endpoint := c.url(string(model), "batchEmbedContents")
	// Every sub-request names the model with a "models/" prefix, while c.url
	// above is given the bare model name.
	modelStr := "models/" + string(model)

	// The number of sub-requests is known up front, so size the slice once
	// rather than growing it through repeated appends. An empty parts yields
	// an empty, non-nil slice.
	requests := make([]models.EmbeddingContentRequest, 0, len(parts))
	for _, content := range parts {
		requests = append(requests, models.EmbeddingContentRequest{
			Model:   modelStr,
			Content: models.Content{Parts: content},
		})
	}
	req := models.BatchEmbeddingContentsRequest{Requests: requests}

	rst, err := unjson[models.BatchEmbeddingContentsResponse](c.post(endpoint, req))
	if err != nil {
		return nil, err
	}
	return rst.Embeddings, nil
}