提交
This commit is contained in:
167
internal/controller/copilot/embedding_client.go
Normal file
167
internal/controller/copilot/embedding_client.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
defaultTimeout = 30 * time.Second
|
||||
contentTypeJSON = "application/json"
|
||||
)
|
||||
|
||||
// EmbeddingRequest 表示向嵌入API发送的请求
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse 表示从嵌入API接收的响应
|
||||
type EmbeddingResponse struct {
|
||||
Data []EmbeddingData `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// EmbeddingData 表示单个嵌入数据
|
||||
type EmbeddingData struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// Usage 表示API使用情况
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// 移除未使用的类型
|
||||
// Parameters 和 EmbeddingsRequest, EmbeddingsResponse 已被移除
|
||||
|
||||
// EmbeddingClient 封装了与嵌入API交互的功能
|
||||
type EmbeddingClient struct {
|
||||
apiURL string
|
||||
apiKey string
|
||||
model string
|
||||
dimensions int
|
||||
httpClient *http.Client
|
||||
clientMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewEmbeddingClient 创建一个新的嵌入客户端
|
||||
func NewEmbeddingClient(dimensions int) (*EmbeddingClient, error) {
|
||||
apiURL := os.Getenv("EMBEDDING_API_BASE")
|
||||
apiKey := os.Getenv("EMBEDDING_API_KEY")
|
||||
|
||||
if apiURL == "" || apiKey == "" {
|
||||
return nil, fmt.Errorf("EMBEDDING_API_BASE or EMBEDDING_API_KEY environment variable not set")
|
||||
}
|
||||
|
||||
if os.Getenv("EMBEDDING_API_MODEL_NAME") == "" {
|
||||
return nil, fmt.Errorf("EMBEDDING_API_MODEL_NAME environment variable not set")
|
||||
}
|
||||
|
||||
// 解析超时时间,如果未设置或解析失败则使用默认值
|
||||
timeout := defaultTimeout
|
||||
if timeoutStr := os.Getenv("HTTP_CLIENT_TIMEOUT"); timeoutStr != "" {
|
||||
if parsedTimeout, err := time.ParseDuration(timeoutStr + "s"); err == nil {
|
||||
timeout = parsedTimeout
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
|
||||
return &EmbeddingClient{
|
||||
apiURL: apiURL,
|
||||
apiKey: apiKey,
|
||||
model: os.Getenv("EMBEDDING_API_MODEL_NAME"),
|
||||
dimensions: dimensions,
|
||||
httpClient: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetModel 设置嵌入模型
|
||||
func (c *EmbeddingClient) SetModel(model string) {
|
||||
c.clientMutex.Lock()
|
||||
defer c.clientMutex.Unlock()
|
||||
c.model = model
|
||||
}
|
||||
|
||||
// GetEmbedding 获取单个文本的嵌入
|
||||
func (c *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
||||
resp, err := c.GetEmbeddings(ctx, []string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
return nil, fmt.Errorf("no embeddings returned")
|
||||
}
|
||||
|
||||
return resp.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
// GetEmbeddings 批量获取多个文本的嵌入
|
||||
func (c *EmbeddingClient) GetEmbeddings(ctx context.Context, texts []string) (*EmbeddingResponse, error) {
|
||||
c.clientMutex.RLock()
|
||||
dimensions := c.dimensions
|
||||
c.clientMutex.RUnlock()
|
||||
|
||||
reqBody := EmbeddingRequest{
|
||||
Model: os.Getenv("EMBEDDING_API_MODEL_NAME"),
|
||||
Input: texts,
|
||||
Dimensions: dimensions,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", contentTypeJSON)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var embeddingResp EmbeddingResponse
|
||||
if err := json.Unmarshal(body, &embeddingResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
return &embeddingResp, nil
|
||||
}
|
||||
Reference in New Issue
Block a user