Files
copilot-app/internal/controller/copilot/chunks.go
2025-08-13 19:03:20 +08:00

274 lines
6.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package copilot
import (
"context"
"crypto/sha256"
"fmt"
"github.com/gofrs/uuid"
"net/http"
"os"
"strconv"
"strings"
"sync"
"github.com/gin-gonic/gin"
)
// Package-level constants.
const (
	defaultDimensionSize = 1536 // default embedding vector dimension
	markdownFilePrefix   = "File: `%s`\n```shell\n"
	markdownFileSuffix   = "```"
)

// getDimensionSize returns the embedding vector dimension to use.
//
// It reads EMBEDDING_DIMENSION_SIZE and uses the value when it parses as a
// positive integer; surrounding whitespace is ignored. When the variable is
// unset, empty, unparsable, zero or negative, defaultDimensionSize is
// returned. Non-positive values are rejected because getChunkSize derives
// the chunk size from this value. A zero or negative chunk size would make
// chunk splitting divide by zero or allocate with a negative capacity.
func getDimensionSize() int {
	raw := strings.TrimSpace(os.Getenv("EMBEDDING_DIMENSION_SIZE"))
	if raw == "" {
		return defaultDimensionSize
	}
	dimSize, err := strconv.Atoi(raw)
	if err != nil || dimSize <= 0 {
		return defaultDimensionSize
	}
	return dimSize
}

// getChunkSize returns the target chunk size in bytes, derived from the
// embedding dimension as roughly 1.5x that dimension (integer division
// truncates).
func getChunkSize() int {
	return getDimensionSize() * 3 / 2
}
type (
	// ChunkRequest is the JSON body of a chunking request. Content and Path
	// are marked binding:"required" for gin; Embed asks for embeddings to be
	// generated for the resulting chunks.
	ChunkRequest struct {
		Content string `json:"content" binding:"required"`
		Path    string `json:"path" binding:"required"`
		Embed   bool   `json:"embed"`
	}

	// Chunk is one piece of split content: its hash, its display text, the
	// byte range it covers, and its (possibly empty) embedding.
	Chunk struct {
		Hash  string `json:"hash"`
		Text  string `json:"text"`
		Range Range  `json:"range"`
		// NOTE: encoding/json ignores omitempty on struct-typed fields, so
		// "embedding" is always present in the serialized output.
		Embedding Embedding `json:"embedding,omitempty"`
	}

	// Range is a span of text given by Start and End offsets.
	Range struct {
		Start int `json:"start"`
		End   int `json:"end"`
	}

	// Embedding wraps an embedding vector.
	Embedding struct {
		Embedding []float32 `json:"embedding"`
	}

	// ChunkResponse is the JSON body returned for a chunking request.
	ChunkResponse struct {
		Chunks         []Chunk `json:"chunks"`
		EmbeddingModel string  `json:"embedding_model"`
	}
)
// ChunkService splits text into chunks and, on request, generates embedding
// vectors for them.
type ChunkService struct {
	// embeddingClient is used to request an embedding for each chunk's text.
	embeddingClient *EmbeddingClient
	// modelName is reported back to clients as the embedding model; NewChunkService
	// reads it from the EMBEDDING_API_MODEL_NAME environment variable.
	modelName string
}
// NewChunkService builds a ChunkService. Its embedding client is created for
// the dimension returned by getDimensionSize. Its model name is read from the
// EMBEDDING_API_MODEL_NAME environment variable and is empty when unset.
//
// If NewEmbeddingClient fails, the error is returned wrapped and the service
// is nil.
func NewChunkService() (*ChunkService, error) {
	dims := getDimensionSize()
	embedder, err := NewEmbeddingClient(dims)
	if err != nil {
		return nil, fmt.Errorf("failed to create embedding client: %w", err)
	}
	svc := &ChunkService{
		modelName:       os.Getenv("EMBEDDING_API_MODEL_NAME"),
		embeddingClient: embedder,
	}
	return svc, nil
}
// HandleChunks is the gin handler for chunking requests.
//
// Flow:
//   - Bind a ChunkRequest from the JSON body; a bind failure returns 400.
//   - Split req.Content into chunks.
//   - When req.Embed is true, generate embeddings for the chunks.
//   - Respond 200 with a ChunkResponse.
//
// Service construction and embedding failures return 500 with an "error"
// field. Every response, including error responses, carries an
// "x-github-request-id" header holding a fresh UUIDv4 when one can be
// generated.
func HandleChunks(c *gin.Context) {
	// Set the request ID before any response is written so error responses
	// carry it too. Failing to generate an ID should not take down the
	// request, so the header is skipped instead of panicking (as uuid.Must
	// would).
	if id, err := uuid.NewV4(); err == nil {
		c.Header("x-github-request-id", id.String())
	}

	var req ChunkRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// NOTE(review): a new service (and embedding client) is built on every
	// request; consider sharing one if NewEmbeddingClient is costly.
	service, err := NewChunkService()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to initialize service: %v", err)})
		return
	}

	chunks := service.SplitIntoChunks(req.Content, req.Path)
	if req.Embed {
		if err := service.GenerateEmbeddings(c.Request.Context(), chunks); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to generate embeddings: %v", err)})
			return
		}
	}

	c.JSON(http.StatusOK, ChunkResponse{
		Chunks:         chunks,
		EmbeddingModel: service.modelName,
	})
}
// SplitIntoChunks splits content into line-aligned chunks of about
// getChunkSize() bytes each. Every chunk is built with createChunk, which
// wraps its text in a Markdown code block labelled with path.
//
// Lines are never split. A chunk is closed when appending the next line would
// push it past the size limit, so a single line longer than the limit becomes
// an oversized chunk of its own. Each chunk's Range holds the half-open byte
// offsets [Start, End) of its raw text within content, and the raw texts of
// all chunks concatenate back to content exactly.
//
// Empty content yields an empty, non-nil slice.
func (s *ChunkService) SplitIntoChunks(content, path string) []Chunk {
	chunkSize := getChunkSize()
	if chunkSize <= 0 {
		// A non-positive size (e.g. from a bad EMBEDDING_DIMENSION_SIZE) would
		// cause a division by zero below; fall back to the default sizing.
		chunkSize = defaultDimensionSize * 3 / 2
	}

	// Pre-size the result to reduce reallocations.
	chunks := make([]Chunk, 0, len(content)/chunkSize+1)

	var sb strings.Builder
	start := 0
	emit := func() {
		text := sb.String()
		chunks = append(chunks, s.createChunk(text, path, start, start+len(text)))
		start += len(text)
		sb.Reset()
	}

	// SplitAfter keeps each "\n" attached to its line. No newline is invented
	// after the last line, so chunk text, hashes and byte offsets all match
	// content.
	for _, line := range strings.SplitAfter(content, "\n") {
		if line == "" {
			// Only occurs as the final element, when content is empty or ends
			// with "\n".
			continue
		}
		if sb.Len() > 0 && sb.Len()+len(line) > chunkSize {
			emit()
		}
		sb.WriteString(line)
	}
	if sb.Len() > 0 {
		emit()
	}
	return chunks
}
// createChunk builds a Chunk covering [start, end):
//   - Text is the raw text wrapped between markdownFilePrefix (filled with
//     path) and markdownFileSuffix.
//   - Hash is the hex-encoded SHA-256 of the raw text, not of the wrapped
//     Markdown.
//   - Embedding starts as an empty, non-nil slice, so it serializes as []
//     rather than null.
func (s *ChunkService) createChunk(text, path string, start, end int) Chunk {
	digest := sha256.Sum256([]byte(text))
	wrapped := fmt.Sprintf(markdownFilePrefix, path) + text + markdownFileSuffix

	var c Chunk
	c.Hash = fmt.Sprintf("%x", digest)
	c.Text = wrapped
	c.Range = Range{Start: start, End: end}
	c.Embedding = Embedding{Embedding: []float32{}}
	return c
}
// GenerateEmbeddings fills in the embedding of every chunk, modifying the
// slice in place. It returns nil immediately for an empty slice. Up to five
// chunks are processed one at a time; larger batches go through the
// concurrent path. The error, if any, is whatever the chosen helper returns.
func (s *ChunkService) GenerateEmbeddings(ctx context.Context, chunks []Chunk) error {
	const serialThreshold = 5

	switch n := len(chunks); {
	case n == 0:
		return nil
	case n <= serialThreshold:
		return s.generateEmbeddingsSerial(ctx, chunks)
	default:
		return s.generateEmbeddingsParallel(ctx, chunks)
	}
}
// generateEmbeddingsSerial requests embeddings one chunk at a time, in order.
// Each request uses the plain text recovered by extractPlainText. It stops at
// the first failure and returns that error wrapped with the chunk index.
// Chunks before the failing one keep the embeddings already assigned.
func (s *ChunkService) generateEmbeddingsSerial(ctx context.Context, chunks []Chunk) error {
	for idx := 0; idx < len(chunks); idx++ {
		c := &chunks[idx]
		vec, err := s.embeddingClient.GetEmbedding(ctx, s.extractPlainText(c.Text))
		if err != nil {
			return fmt.Errorf("failed to generate embedding for chunk %d: %w", idx, err)
		}
		c.Embedding.Embedding = vec
	}
	return nil
}
// generateEmbeddingsParallel fills in embeddings for chunks concurrently,
// with at most maxConcurrency requests in flight at once. Each goroutine
// writes only to its own chunks[idx], so no extra locking is needed.
//
// The first failure cancels the context seen by in-flight requests and stops
// new requests from starting. That first error, wrapped with its chunk index,
// is returned. If ctx is cancelled before every chunk has been dispatched,
// ctx's error is returned instead. Chunks that were not processed keep their
// previous Embedding value.
func (s *ChunkService) generateEmbeddingsParallel(ctx context.Context, chunks []Chunk) error {
	// Limit concurrency to avoid flooding the embedding backend.
	const maxConcurrency = 10

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		wg       sync.WaitGroup
		errOnce  sync.Once
		firstErr error
	)
	recordErr := func(err error) {
		errOnce.Do(func() {
			firstErr = err
			cancel() // stop the remaining work as soon as one request fails
		})
	}

	sem := make(chan struct{}, maxConcurrency)
	dispatched := 0
dispatch:
	for i := range chunks {
		// Take a slot before starting the goroutine so that at most
		// maxConcurrency goroutines exist at a time. Stop dispatching once
		// the context is cancelled, by the caller or by a failure.
		select {
		case sem <- struct{}{}:
		case <-ctx.Done():
			break dispatch
		}
		dispatched++
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			defer func() { <-sem }()

			text := s.extractPlainText(chunks[idx].Text)
			embedding, err := s.embeddingClient.GetEmbedding(ctx, text)
			if err != nil {
				recordErr(fmt.Errorf("failed to generate embedding for chunk %d: %w", idx, err))
				return
			}
			chunks[idx].Embedding.Embedding = embedding
		}(i)
	}
	wg.Wait()

	if firstErr != nil {
		return firstErr
	}
	if dispatched < len(chunks) {
		// Dispatch stopped early without a recorded failure, so the caller's
		// context was cancelled; the derived ctx reports the same error.
		return ctx.Err()
	}
	return nil
}
// extractPlainText strips the Markdown wrapping from a chunk's text:
//   - Drop everything up to and including the first newline (the "File:"
//     header line). Text without a newline keeps its first line.
//   - Remove a leading "```shell\n" fence, if present.
//   - Remove a trailing "```" fence, if present.
func (s *ChunkService) extractPlainText(text string) string {
	body := text
	if nl := strings.IndexByte(body, '\n'); nl >= 0 {
		body = body[nl+1:]
	}
	body = strings.TrimPrefix(body, "```shell\n")
	return strings.TrimSuffix(body, "```")
}