Files
copilot-app/internal/controller/copilot/embedding_client.go
2025-08-13 19:03:20 +08:00

168 lines
4.2 KiB
Go

package copilot
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"sync"
"time"
)
// Package-level defaults used by the embedding client.
const (
	// contentTypeJSON is the Content-Type header value sent with request bodies.
	contentTypeJSON = "application/json"
	// defaultTimeout is the fallback HTTP client timeout used when
	// HTTP_CLIENT_TIMEOUT does not provide a usable value.
	defaultTimeout = time.Second * 30
)
// EmbeddingRequest is the JSON body POSTed to the embedding API.
type EmbeddingRequest struct {
	// Model is the embedding model name.
	Model string `json:"model"`
	// Input is the batch of texts to embed.
	Input []string `json:"input"`
	// Dimensions is the requested embedding size. A zero value is omitted from
	// the JSON instead of being sent as "dimensions":0, which is not a valid
	// size; presumably the server then uses the model's native size — confirm
	// with the provider.
	Dimensions int `json:"dimensions,omitempty"`
}
// EmbeddingResponse is the decoded JSON body returned by the embedding API.
type EmbeddingResponse struct {
	// Data holds the returned embeddings, one entry per element.
	Data []EmbeddingData `json:"data"`
	// Model is the model name reported by the server.
	Model string `json:"model"`
	// Object is the object-type string reported by the server.
	Object string `json:"object"`
	// Usage carries the token accounting reported by the server.
	Usage Usage `json:"usage"`
}
// EmbeddingData is a single embedding entry inside an EmbeddingResponse.
type EmbeddingData struct {
	// Embedding is the vector returned for one input text.
	Embedding []float32 `json:"embedding"`
	// Index is the position reported by the server for this entry.
	Index int `json:"index"`
	// Object is the object-type string reported by the server.
	Object string `json:"object"`
}
// Usage is the token accounting block of an embedding API response.
type Usage struct {
	// PromptTokens is the prompt token count reported by the server.
	PromptTokens int `json:"prompt_tokens"`
	// TotalTokens is the total token count reported by the server.
	TotalTokens int `json:"total_tokens"`
}
// 移除未使用的类型
// Parameters 和 EmbeddingsRequest, EmbeddingsResponse 已被移除
// EmbeddingClient talks to an HTTP embedding API. Construct it with
// NewEmbeddingClient; the zero value has no URL, key or HTTP client.
type EmbeddingClient struct {
	// apiURL is the endpoint requests are POSTed to.
	apiURL string
	// apiKey is sent as a Bearer token in the Authorization header.
	apiKey string
	// model is the model name stored by NewEmbeddingClient and SetModel.
	model string
	// dimensions is the requested embedding size.
	dimensions int
	// httpClient performs the requests.
	httpClient *http.Client
	// clientMutex guards model and dimensions.
	clientMutex sync.RWMutex
}
// NewEmbeddingClient creates an EmbeddingClient configured from environment
// variables.
//
// Required environment variables:
//   - EMBEDDING_API_BASE: URL that embedding requests are POSTed to.
//   - EMBEDDING_API_KEY: sent as a Bearer token in the Authorization header.
//   - EMBEDDING_API_MODEL_NAME: model name stored on the client.
//
// Optional environment variables:
//   - HTTP_CLIENT_TIMEOUT: overall request timeout. A bare number is taken as
//     seconds ("30", "1.5"); a value with a unit ("45s", "2m") is parsed by
//     time.ParseDuration. Missing, malformed, zero or negative values fall
//     back to defaultTimeout.
//   - EMBEDDING_API_INSECURE_SKIP_VERIFY: "true" or "1" disables TLS
//     certificate verification (e.g. for self-signed endpoints). Verification
//     is on by default because the API key travels in every request.
//
// dimensions is the requested embedding size and must not be negative.
// An error is returned if a required variable is unset or dimensions < 0.
func NewEmbeddingClient(dimensions int) (*EmbeddingClient, error) {
	apiURL := os.Getenv("EMBEDDING_API_BASE")
	apiKey := os.Getenv("EMBEDDING_API_KEY")
	if apiURL == "" || apiKey == "" {
		return nil, fmt.Errorf("EMBEDDING_API_BASE or EMBEDDING_API_KEY environment variable not set")
	}
	model := os.Getenv("EMBEDDING_API_MODEL_NAME")
	if model == "" {
		return nil, fmt.Errorf("EMBEDDING_API_MODEL_NAME environment variable not set")
	}
	if dimensions < 0 {
		return nil, fmt.Errorf("invalid embedding dimensions %d: must not be negative", dimensions)
	}

	skipVerify := os.Getenv("EMBEDDING_API_INSECURE_SKIP_VERIFY")
	insecure := skipVerify == "true" || skipVerify == "1"

	return &EmbeddingClient{
		apiURL:     apiURL,
		apiKey:     apiKey,
		model:      model,
		dimensions: dimensions,
		httpClient: &http.Client{
			Timeout:   parseEmbeddingHTTPTimeout(os.Getenv("HTTP_CLIENT_TIMEOUT"), defaultTimeout),
			Transport: newEmbeddingTransport(insecure),
		},
	}, nil
}

// parseEmbeddingHTTPTimeout converts a timeout setting into a positive duration,
// returning fallback when raw is empty, malformed, zero or negative.
func parseEmbeddingHTTPTimeout(raw string, fallback time.Duration) time.Duration {
	if raw == "" {
		return fallback
	}
	d, err := time.ParseDuration(raw)
	if err != nil {
		// No unit given: keep the historical "plain number means seconds" format.
		d, err = time.ParseDuration(raw + "s")
	}
	if err != nil || d <= 0 {
		return fallback
	}
	return d
}

// newEmbeddingTransport returns a copy of http.DefaultTransport (keeping proxy
// support, connection pooling, dial/TLS timeouts and HTTP/2), optionally with
// TLS certificate verification disabled.
func newEmbeddingTransport(insecure bool) *http.Transport {
	var tr *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		tr = base.Clone()
	} else {
		// DefaultTransport was replaced with another RoundTripper; build a minimal one.
		tr = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	if insecure {
		if tr.TLSClientConfig == nil {
			tr.TLSClientConfig = &tls.Config{}
		}
		// Explicit opt-in only: skipping verification exposes the API key to MITM.
		tr.TLSClientConfig.InsecureSkipVerify = true
	}
	return tr
}
// SetModel replaces the model name stored on the client. The write happens
// under clientMutex, so it may run concurrently with readers that take the
// same lock.
func (c *EmbeddingClient) SetModel(model string) {
	c.clientMutex.Lock()
	c.model = model
	c.clientMutex.Unlock()
}
// GetEmbedding returns the embedding vector for a single text. It sends a
// one-element batch through GetEmbeddings and returns the first entry of the
// response; it fails if that call errors or the response has no data.
func (c *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
	batch, err := c.GetEmbeddings(ctx, []string{text})
	switch {
	case err != nil:
		return nil, err
	case len(batch.Data) == 0:
		return nil, fmt.Errorf("no embeddings returned")
	}
	first := batch.Data[0]
	return first.Embedding, nil
}
// GetEmbeddings requests embeddings for texts in a single batch call.
//
// The request carries the client's current model and dimensions, both read
// under clientMutex, so a model set via SetModel is used by later calls. The
// JSON body is POSTed to apiURL with a Bearer token.
//
// Errors:
//   - An empty texts slice is rejected without contacting the API.
//   - A non-200 status is reported with the status code and raw response body.
//   - Marshalling, transport, read and decode errors are wrapped with %w, so
//     callers can use errors.Is / errors.As (for example to detect
//     context.DeadlineExceeded).
func (c *EmbeddingClient) GetEmbeddings(ctx context.Context, texts []string) (*EmbeddingResponse, error) {
	if len(texts) == 0 {
		return nil, fmt.Errorf("no input texts provided")
	}

	// Snapshot mutable settings under the read lock; previously the model was
	// re-read from the environment here, which made SetModel a no-op.
	c.clientMutex.RLock()
	model, dimensions := c.model, c.dimensions
	c.clientMutex.RUnlock()

	payload, err := json.Marshal(EmbeddingRequest{
		Model:      model,
		Input:      texts,
		Dimensions: dimensions,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", contentTypeJSON)
	req.Header.Set("Authorization", "Bearer "+c.apiKey)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var out EmbeddingResponse
	if err := json.Unmarshal(body, &out); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return &out, nil
}