168 lines
4.2 KiB
Go
168 lines
4.2 KiB
Go
package copilot
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// 常量定义
|
|
const (
|
|
defaultTimeout = 30 * time.Second
|
|
contentTypeJSON = "application/json"
|
|
)
|
|
|
|
// EmbeddingRequest 表示向嵌入API发送的请求
|
|
type EmbeddingRequest struct {
|
|
Model string `json:"model"`
|
|
Input []string `json:"input"`
|
|
Dimensions int `json:"dimensions"`
|
|
}
|
|
|
|
// EmbeddingResponse 表示从嵌入API接收的响应
|
|
type EmbeddingResponse struct {
|
|
Data []EmbeddingData `json:"data"`
|
|
Model string `json:"model"`
|
|
Object string `json:"object"`
|
|
Usage Usage `json:"usage"`
|
|
}
|
|
|
|
// EmbeddingData 表示单个嵌入数据
|
|
type EmbeddingData struct {
|
|
Embedding []float32 `json:"embedding"`
|
|
Index int `json:"index"`
|
|
Object string `json:"object"`
|
|
}
|
|
|
|
// Usage 表示API使用情况
|
|
type Usage struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
}
|
|
|
|
// 移除未使用的类型
|
|
// Parameters 和 EmbeddingsRequest, EmbeddingsResponse 已被移除
|
|
|
|
// EmbeddingClient 封装了与嵌入API交互的功能
|
|
type EmbeddingClient struct {
|
|
apiURL string
|
|
apiKey string
|
|
model string
|
|
dimensions int
|
|
httpClient *http.Client
|
|
clientMutex sync.RWMutex
|
|
}
|
|
|
|
// NewEmbeddingClient 创建一个新的嵌入客户端
|
|
func NewEmbeddingClient(dimensions int) (*EmbeddingClient, error) {
|
|
apiURL := os.Getenv("EMBEDDING_API_BASE")
|
|
apiKey := os.Getenv("EMBEDDING_API_KEY")
|
|
|
|
if apiURL == "" || apiKey == "" {
|
|
return nil, fmt.Errorf("EMBEDDING_API_BASE or EMBEDDING_API_KEY environment variable not set")
|
|
}
|
|
|
|
if os.Getenv("EMBEDDING_API_MODEL_NAME") == "" {
|
|
return nil, fmt.Errorf("EMBEDDING_API_MODEL_NAME environment variable not set")
|
|
}
|
|
|
|
// 解析超时时间,如果未设置或解析失败则使用默认值
|
|
timeout := defaultTimeout
|
|
if timeoutStr := os.Getenv("HTTP_CLIENT_TIMEOUT"); timeoutStr != "" {
|
|
if parsedTimeout, err := time.ParseDuration(timeoutStr + "s"); err == nil {
|
|
timeout = parsedTimeout
|
|
}
|
|
}
|
|
|
|
client := &http.Client{
|
|
Timeout: timeout,
|
|
Transport: &http.Transport{
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
},
|
|
}
|
|
|
|
return &EmbeddingClient{
|
|
apiURL: apiURL,
|
|
apiKey: apiKey,
|
|
model: os.Getenv("EMBEDDING_API_MODEL_NAME"),
|
|
dimensions: dimensions,
|
|
httpClient: client,
|
|
}, nil
|
|
}
|
|
|
|
// SetModel 设置嵌入模型
|
|
func (c *EmbeddingClient) SetModel(model string) {
|
|
c.clientMutex.Lock()
|
|
defer c.clientMutex.Unlock()
|
|
c.model = model
|
|
}
|
|
|
|
// GetEmbedding 获取单个文本的嵌入
|
|
func (c *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
|
resp, err := c.GetEmbeddings(ctx, []string{text})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(resp.Data) == 0 {
|
|
return nil, fmt.Errorf("no embeddings returned")
|
|
}
|
|
|
|
return resp.Data[0].Embedding, nil
|
|
}
|
|
|
|
// GetEmbeddings 批量获取多个文本的嵌入
|
|
func (c *EmbeddingClient) GetEmbeddings(ctx context.Context, texts []string) (*EmbeddingResponse, error) {
|
|
c.clientMutex.RLock()
|
|
dimensions := c.dimensions
|
|
c.clientMutex.RUnlock()
|
|
|
|
reqBody := EmbeddingRequest{
|
|
Model: os.Getenv("EMBEDDING_API_MODEL_NAME"),
|
|
Input: texts,
|
|
Dimensions: dimensions,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %v", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %v", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", contentTypeJSON)
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to make request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response body: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var embeddingResp EmbeddingResponse
|
|
if err := json.Unmarshal(body, &embeddingResp); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal response: %v", err)
|
|
}
|
|
|
|
return &embeddingResp, nil
|
|
}
|