-
Notifications
You must be signed in to change notification settings - Fork 0
/
alibaba_strategy.go
143 lines (122 loc) · 3.66 KB
/
alibaba_strategy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package llmconnector
import (
"context"
"encoding/json"
"fmt"
"github.com/simp-lee/gohttpclient"
)
// AlibabaStrategy talks to Alibaba Cloud's DashScope API. It holds one HTTP
// client dedicated to chat requests, one dedicated to embedding requests, and
// the configuration (endpoints, API key, common settings) both were built from.
type AlibabaStrategy struct {
	chatClient, embedClient *gohttpclient.Client
	config                  *Config
}
// NewAlibabaStrategy builds an AlibabaStrategy (DashScope) from config.
//
// config.APIKey is required. An empty ChatURL or EmbedURL falls back to the
// native DashScope endpoint, and a zero-value CommonConfig is replaced by
// DefaultCommonConfig(). Two HTTP clients are created from the same
// CommonConfig and API key: one for chat, one for embeddings.
func NewAlibabaStrategy(config Config) (*AlibabaStrategy, error) {
	if config.APIKey == "" {
		return nil, fmt.Errorf("Alibaba API key is required")
	}

	// Default to the native DashScope endpoints. Chat sends a native request
	// body (input.messages / parameters) and responses are decoded from the
	// native "output" object, so the URLs must point at that API.
	if config.ChatURL == "" {
		config.ChatURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
	}
	if config.EmbedURL == "" {
		config.EmbedURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
	}

	// Use default common config if not set.
	if config.CommonConfig == (CommonConfig{}) {
		config.CommonConfig = DefaultCommonConfig()
	}

	chatClient, err := createClient(config.CommonConfig, config.APIKey)
	if err != nil {
		return nil, fmt.Errorf("failed to create Alibaba chat client: %w", err)
	}

	embedClient, err := createClient(config.CommonConfig, config.APIKey)
	if err != nil {
		return nil, fmt.Errorf("failed to create Alibaba embedding client: %w", err)
	}

	return &AlibabaStrategy{
		chatClient:  chatClient,
		embedClient: embedClient,
		config:      &config,
	}, nil
}

// Chat sends chatMessages to the configured DashScope chat endpoint and
// returns the decoded response.
//
// options must be non-nil; options.Model selects the model. Temperature,
// MaxTokens, TopP and Stop are forwarded under "parameters" only when set, so
// the service defaults apply otherwise. Transport and decode failures are
// returned wrapped.
func (s *AlibabaStrategy) Chat(ctx context.Context, chatMessages []ChatMessage, options *ChatOptions) (ChatResponse, error) {
	if options == nil {
		return nil, fmt.Errorf("Alibaba chat options are required")
	}

	// The native DashScope API nests the conversation under "input" and the
	// generation settings under "parameters", rather than at the top level as
	// in the OpenAI-compatible layout.
	parameters := map[string]interface{}{}
	if options.Temperature != nil {
		parameters["temperature"] = *options.Temperature
	}
	if options.MaxTokens != nil {
		parameters["max_tokens"] = *options.MaxTokens
	}
	if options.TopP != nil {
		parameters["top_p"] = *options.TopP
	}
	if options.Stop != nil {
		parameters["stop"] = options.Stop
	}

	request := map[string]interface{}{
		"model": options.Model,
		"input": map[string]interface{}{
			"messages": chatMessages,
		},
	}
	if len(parameters) > 0 {
		request["parameters"] = parameters
	}

	resp, err := s.chatClient.Post(ctx, s.config.ChatURL, request)
	if err != nil {
		return nil, fmt.Errorf("Alibaba chat request failed: %w", err)
	}

	var alibabaResp AlibabaChatResponse
	if err := json.Unmarshal(resp, &alibabaResp); err != nil {
		return nil, fmt.Errorf("failed to unmarshal Alibaba chat response: %w", err)
	}

	return &alibabaResp, nil
}
// AlibabaChatResponse mirrors the "output" object of a DashScope
// text-generation response. Depending on the request's result_format, the
// reply arrives either as output.text or as output.choices[].message.content;
// both shapes are decoded.
type AlibabaChatResponse struct {
	Output struct {
		Text    string `json:"text"`
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	} `json:"output"`
}

// GetContent returns the generated reply. It returns output.text when it is
// non-empty, otherwise the first choice's message content, and "" when
// neither is present.
func (r *AlibabaChatResponse) GetContent() string {
	if r.Output.Text != "" {
		return r.Output.Text
	}
	if len(r.Output.Choices) > 0 {
		return r.Output.Choices[0].Message.Content
	}
	return ""
}
// Embed requests one embedding vector per entry in texts from the configured
// DashScope embedding endpoint.
//
// options must be non-nil; options.Model selects the model. A non-empty
// options.EmbeddingType is sent as parameters.text_type (presumably "query" or
// "document" — confirm against the model's documentation). Transport and
// decode failures are returned wrapped.
func (s *AlibabaStrategy) Embed(ctx context.Context, texts []string, options *EmbedOptions) (EmbedResponse, error) {
	if options == nil {
		return nil, fmt.Errorf("Alibaba embed options are required")
	}

	request := map[string]interface{}{
		"model": options.Model,
		"input": map[string]interface{}{
			"texts": texts,
		},
	}
	if options.EmbeddingType != "" {
		// DashScope reads optional settings from the "parameters" object;
		// under any other key the text_type would be silently ignored.
		request["parameters"] = map[string]string{
			"text_type": options.EmbeddingType,
		}
	}

	resp, err := s.embedClient.Post(ctx, s.config.EmbedURL, request)
	if err != nil {
		return nil, fmt.Errorf("Alibaba embed request failed: %w", err)
	}

	var alibabaResp AlibabaEmbeddingResponse
	if err := json.Unmarshal(resp, &alibabaResp); err != nil {
		return nil, fmt.Errorf("failed to unmarshal Alibaba embed response: %w", err)
	}

	return &AlibabaEmbedResponseWrapper{alibabaResp}, nil
}
// AlibabaEmbeddingResponse mirrors the JSON body returned by the DashScope
// text-embedding endpoint.
type AlibabaEmbeddingResponse struct {
	Output struct {
		// Embeddings holds one vector per input text; TextIndex is the
		// position of the corresponding text in the request.
		Embeddings []struct {
			TextIndex int       `json:"text_index"`
			Embedding []float32 `json:"embedding"`
		} `json:"embeddings"`
	} `json:"output"`
	Usage struct {
		TotalTokens int `json:"total_tokens"`
	} `json:"usage"`
	RequestID string `json:"request_id"`
}

// AlibabaEmbedResponseWrapper exposes an AlibabaEmbeddingResponse through the
// GetEmbeddings accessor; Embed returns it as an EmbedResponse.
type AlibabaEmbedResponseWrapper struct {
	AlibabaEmbeddingResponse
}

// GetEmbeddings returns the embedding vectors ordered to match the input
// texts. When the TextIndex values are distinct and all lie in [0, n), each
// vector is placed at its TextIndex. Otherwise (e.g. the field was absent and
// every index decoded as 0) the response order is used unchanged. The result
// is never nil.
func (r *AlibabaEmbedResponseWrapper) GetEmbeddings() [][]float32 {
	items := r.Output.Embeddings
	embeddings := make([][]float32, len(items))
	seen := make([]bool, len(items))

	for _, item := range items {
		idx := item.TextIndex
		if idx < 0 || idx >= len(items) || seen[idx] {
			// Indices are missing or inconsistent: trust the response order.
			for i := range items {
				embeddings[i] = items[i].Embedding
			}
			return embeddings
		}
		seen[idx] = true
		embeddings[idx] = item.Embedding
	}

	return embeddings
}