提交
This commit is contained in:
29
internal/controller/copilot/agents.go
Normal file
29
internal/controller/copilot/agents.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// GetAgents 获取代理列表
|
||||
func GetAgents(c *gin.Context) {
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"agents": []gin.H{
|
||||
{
|
||||
"id": "github/copilot-workspace",
|
||||
"name": "@workspace",
|
||||
"description": "Ask questions and get answers about your codebase.",
|
||||
"version": "1.0.0",
|
||||
"publisher": "github",
|
||||
"model": "gpt-4o-mini-2024-07-18",
|
||||
"capabilities": "workspace",
|
||||
"default_model": "gpt-4o-mini-2024-07-18",
|
||||
"capabilities_model": "gpt-4o-mini-2024-07-18",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
188
internal/controller/copilot/chat_completions.go
Normal file
188
internal/controller/copilot/chat_completions.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChatCompletions chat对话接口
|
||||
func ChatCompletions(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 添加响应头, 解决vscode校验github所属问题
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if nil != err {
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
apiModelName := gjson.GetBytes(body, "model").String()
|
||||
// 默认设置的对话模型
|
||||
envModelName := os.Getenv("CHAT_API_MODEL_NAME")
|
||||
// 默认设置的对话请求地址
|
||||
chatAPIURL := os.Getenv("CHAT_API_BASE")
|
||||
// 默认设置的对话模型key
|
||||
apiKey := os.Getenv("CHAT_API_KEY")
|
||||
|
||||
// 轻量模型直接走代码补全接口, 节约成本
|
||||
if strings.Contains(apiModelName, os.Getenv("LIGHTWEIGHT_MODEL")) {
|
||||
envModelName = os.Getenv("CODEX_API_MODEL_NAME")
|
||||
codexAPIURL := os.Getenv("CODEX_API_BASE")
|
||||
parsedURL, err := url.Parse(codexAPIURL)
|
||||
if err != nil {
|
||||
fmt.Println("URL解析错误:", err)
|
||||
return
|
||||
}
|
||||
chatAPIURL = "https://" + parsedURL.Hostname() + "/v1/chat/completions"
|
||||
apiKey = os.Getenv("CODEX_API_KEY")
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
|
||||
body, _ = sjson.SetBytes(body, "model", envModelName)
|
||||
body, _ = sjson.SetBytes(body, "stream", true) // 强制流式输出
|
||||
|
||||
if !gjson.GetBytes(body, "function_call").Exists() {
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
for i, msg := range messages {
|
||||
toolCalls := msg.Get("tool_calls").Array()
|
||||
if len(toolCalls) == 0 {
|
||||
body, _ = sjson.DeleteBytes(body, fmt.Sprintf("messages.%d.tool_calls", i))
|
||||
}
|
||||
}
|
||||
lastIndex := len(messages) - 1
|
||||
chatLocale := os.Getenv("CHAT_LOCALE")
|
||||
if chatLocale != "" && !strings.Contains(messages[lastIndex].Get("content").String(), "Respond in the following locale") {
|
||||
body, _ = sjson.SetBytes(body, "messages."+strconv.Itoa(lastIndex)+".content", messages[lastIndex].Get("content").String()+"Respond in the following locale: "+chatLocale+".")
|
||||
}
|
||||
}
|
||||
|
||||
body, _ = sjson.DeleteBytes(body, "intent")
|
||||
body, _ = sjson.DeleteBytes(body, "intent_threshold")
|
||||
body, _ = sjson.DeleteBytes(body, "intent_content")
|
||||
body, _ = sjson.DeleteBytes(body, "logprobs") // #IBZYCA
|
||||
|
||||
// 是否支持使用工具, 避免模型不支持相关功能报错
|
||||
chatUseTools, _ := strconv.ParseBool(os.Getenv("CHAT_USE_TOOLS"))
|
||||
if !chatUseTools {
|
||||
body, _ = sjson.DeleteBytes(body, "tools")
|
||||
body, _ = sjson.DeleteBytes(body, "tool_call")
|
||||
body, _ = sjson.DeleteBytes(body, "functions")
|
||||
body, _ = sjson.DeleteBytes(body, "function_call")
|
||||
body, _ = sjson.DeleteBytes(body, "tool_choice")
|
||||
}
|
||||
|
||||
ChatMaxTokens, _ := strconv.Atoi(os.Getenv("CHAT_MAX_TOKENS"))
|
||||
if int(gjson.GetBytes(body, "max_tokens").Int()) > ChatMaxTokens {
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", ChatMaxTokens)
|
||||
}
|
||||
|
||||
if gjson.GetBytes(body, "n").Int() > 1 {
|
||||
body, _ = sjson.SetBytes(body, "n", 1)
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
// 拦截处理vscode对话首次预处理请求, 减少等待时间
|
||||
firstRole := gjson.GetBytes(body, "messages.0.role").String()
|
||||
firstContent := gjson.GetBytes(body, "messages.0.content").String()
|
||||
if strings.Contains(firstRole, "system") && strings.Contains(firstContent, "You are a helpful AI programming assistant to a user") &&
|
||||
!strings.Contains(firstContent, "If you cannot choose just one category, or if none of the categories seem like they would provide the user with a better result, you must always respond with") &&
|
||||
!gjson.GetBytes(body, "tool_choice").Exists() {
|
||||
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
|
||||
c.Writer.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
// vs2022客户端的兼容处理
|
||||
if strings.Contains(userAgent, "VSCopilotClient") {
|
||||
lastMessage := messages[len(messages)-1]
|
||||
messageRole := lastMessage.Get("role").String()
|
||||
messageContent := lastMessage.Get("content").String()
|
||||
if strings.Contains(firstRole, "system") && strings.Contains(firstContent, "You are an AI programming assistant") {
|
||||
vs2022FirstChatTemplate(c)
|
||||
return
|
||||
}
|
||||
if messageRole == "user" && messageContent == "Write a short one-sentence question that I can ask that naturally follows from the previous few questions and answers. It should not ask a question which is already answered in the conversation. It should be a question that you are capable of answering. Reply with only the text of the question and nothing else." {
|
||||
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
|
||||
c.Writer.Flush()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatAPIURL, io.NopCloser(bytes.NewBuffer(body)))
|
||||
if nil != err {
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
|
||||
client := &http.Client{
|
||||
Timeout: httpClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if nil != err {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
c.AbortWithStatus(http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("request conversation failed:", err.Error())
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer CloseIO(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Println("request completions failed:", string(body))
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
_, _ = io.Copy(c.Writer, resp.Body)
|
||||
}
|
||||
|
||||
// vs2022FirstChatTemplate is a template for the first chat completion response
|
||||
func vs2022FirstChatTemplate(c *gin.Context) {
|
||||
fixedOutput := `data: {"id":"f6202f6f-9d13-4518-b34f-65e945b0a1a2","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"b2ab39cb-9a84-4006-b470-93a5965c6d69","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"df5f9ce7-b653-4ffb-8d92-e21856ce1ffc","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":"Explain"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"fb58d66e-bb16-43f2-8470-2de0c8662533","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"22ea16e2-766f-4b10-84d0-68399abc9181","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
`
|
||||
_, _ = c.Writer.WriteString(fixedOutput)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
273
internal/controller/copilot/chunks.go
Normal file
273
internal/controller/copilot/chunks.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"github.com/gofrs/uuid"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
defaultDimensionSize = 1536 // 默认向量维度
|
||||
markdownFilePrefix = "File: `%s`\n```shell\n"
|
||||
markdownFileSuffix = "```"
|
||||
)
|
||||
|
||||
// 获取向量维度大小
|
||||
func getDimensionSize() int {
|
||||
dimensionSize := defaultDimensionSize
|
||||
if dimSizeStr := os.Getenv("EMBEDDING_DIMENSION_SIZE"); dimSizeStr != "" {
|
||||
if dimSize, err := strconv.Atoi(dimSizeStr); err == nil {
|
||||
dimensionSize = dimSize
|
||||
}
|
||||
}
|
||||
return dimensionSize
|
||||
}
|
||||
|
||||
// 计算块大小
|
||||
func getChunkSize() int {
|
||||
// 根据维度大小调整块大小,这里设置为维度的1.5倍左右
|
||||
return getDimensionSize() * 3 / 2
|
||||
}
|
||||
|
||||
// ChunkRequest 表示分块请求
|
||||
type ChunkRequest struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
Path string `json:"path" binding:"required"`
|
||||
Embed bool `json:"embed"`
|
||||
}
|
||||
|
||||
// Chunk 表示内容块
|
||||
type Chunk struct {
|
||||
Hash string `json:"hash"`
|
||||
Text string `json:"text"`
|
||||
Range Range `json:"range"`
|
||||
Embedding Embedding `json:"embedding,omitempty"`
|
||||
}
|
||||
|
||||
// Range 表示文本范围
|
||||
type Range struct {
|
||||
Start int `json:"start"`
|
||||
End int `json:"end"`
|
||||
}
|
||||
|
||||
// Embedding 表示向量嵌入
|
||||
type Embedding struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
// ChunkResponse 表示分块响应
|
||||
type ChunkResponse struct {
|
||||
Chunks []Chunk `json:"chunks"`
|
||||
EmbeddingModel string `json:"embedding_model"`
|
||||
}
|
||||
|
||||
// ChunkService 处理文本分块和嵌入的服务
|
||||
type ChunkService struct {
|
||||
embeddingClient *EmbeddingClient
|
||||
modelName string
|
||||
}
|
||||
|
||||
// NewChunkService 创建新的分块服务
|
||||
func NewChunkService() (*ChunkService, error) {
|
||||
client, err := NewEmbeddingClient(getDimensionSize())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create embedding client: %w", err)
|
||||
}
|
||||
|
||||
return &ChunkService{
|
||||
embeddingClient: client,
|
||||
modelName: os.Getenv("EMBEDDING_API_MODEL_NAME"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HandleChunks 处理分块请求的HTTP处理器
|
||||
func HandleChunks(c *gin.Context) {
|
||||
var req ChunkRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
service, err := NewChunkService()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to initialize service: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
chunks := service.SplitIntoChunks(req.Content, req.Path)
|
||||
|
||||
if req.Embed {
|
||||
if err := service.GenerateEmbeddings(c.Request.Context(), chunks); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to generate embeddings: %v", err)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp := ChunkResponse{
|
||||
Chunks: chunks,
|
||||
EmbeddingModel: service.modelName,
|
||||
}
|
||||
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// SplitIntoChunks 将内容分割成块
|
||||
func (s *ChunkService) SplitIntoChunks(content, path string) []Chunk {
|
||||
var chunks []Chunk
|
||||
lines := strings.Split(content, "\n")
|
||||
chunkSize := getChunkSize()
|
||||
|
||||
// 预分配切片容量,减少内存重新分配
|
||||
estimatedChunks := len(content)/chunkSize + 1
|
||||
chunks = make([]Chunk, 0, estimatedChunks)
|
||||
|
||||
var sb strings.Builder
|
||||
start := 0
|
||||
|
||||
for _, line := range lines {
|
||||
lineWithNewline := line + "\n"
|
||||
|
||||
// 如果当前块加上新行会超过chunkSize,并且当前块不为空
|
||||
if sb.Len()+len(lineWithNewline) > chunkSize && sb.Len() > 0 {
|
||||
// 创建新的chunk
|
||||
chunkText := sb.String()
|
||||
chunk := s.createChunk(chunkText, path, start, start+len(chunkText))
|
||||
chunks = append(chunks, chunk)
|
||||
|
||||
start += len(chunkText)
|
||||
sb.Reset()
|
||||
sb.WriteString(lineWithNewline)
|
||||
} else {
|
||||
sb.WriteString(lineWithNewline)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加最后一个chunk
|
||||
if sb.Len() > 0 {
|
||||
chunkText := sb.String()
|
||||
chunk := s.createChunk(chunkText, path, start, start+len(chunkText))
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// createChunk 创建一个新的内容块
|
||||
func (s *ChunkService) createChunk(text, path string, start, end int) Chunk {
|
||||
// 计算文本的SHA-256哈希
|
||||
hash := sha256.Sum256([]byte(text))
|
||||
|
||||
return Chunk{
|
||||
Hash: fmt.Sprintf("%x", hash),
|
||||
Text: fmt.Sprintf(markdownFilePrefix+"%s"+markdownFileSuffix, path, text),
|
||||
Range: Range{
|
||||
Start: start,
|
||||
End: end,
|
||||
},
|
||||
Embedding: Embedding{
|
||||
Embedding: make([]float32, 0), // 初始化为空切片
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateEmbeddings 为所有块生成嵌入向量
|
||||
func (s *ChunkService) GenerateEmbeddings(ctx context.Context, chunks []Chunk) error {
|
||||
if len(chunks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 对于少量块,直接串行处理
|
||||
if len(chunks) <= 5 {
|
||||
return s.generateEmbeddingsSerial(ctx, chunks)
|
||||
}
|
||||
|
||||
// 对于大量块,使用并行处理
|
||||
return s.generateEmbeddingsParallel(ctx, chunks)
|
||||
}
|
||||
|
||||
// generateEmbeddingsSerial 串行生成嵌入向量
|
||||
func (s *ChunkService) generateEmbeddingsSerial(ctx context.Context, chunks []Chunk) error {
|
||||
for i := range chunks {
|
||||
text := s.extractPlainText(chunks[i].Text)
|
||||
|
||||
embedding, err := s.embeddingClient.GetEmbedding(ctx, text)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate embedding for chunk %d: %w", i, err)
|
||||
}
|
||||
|
||||
chunks[i].Embedding.Embedding = embedding
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateEmbeddingsParallel 并行生成嵌入向量
|
||||
func (s *ChunkService) generateEmbeddingsParallel(ctx context.Context, chunks []Chunk) error {
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(chunks))
|
||||
|
||||
// 限制并发数量,避免过多的并发请求
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
for i := range chunks {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
// 获取信号量
|
||||
semaphore <- struct{}{}
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
text := s.extractPlainText(chunks[idx].Text)
|
||||
|
||||
embedding, err := s.embeddingClient.GetEmbedding(ctx, text)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to generate embedding for chunk %d: %w", idx, err)
|
||||
return
|
||||
}
|
||||
|
||||
chunks[idx].Embedding.Embedding = embedding
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 等待所有goroutine完成
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// 检查是否有错误
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
// 没有错误
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractPlainText 从markdown格式的文本中提取纯文本
|
||||
func (s *ChunkService) extractPlainText(text string) string {
|
||||
// 移除第一行 File: 标记
|
||||
if idx := strings.Index(text, "\n"); idx != -1 {
|
||||
text = text[idx+1:]
|
||||
}
|
||||
|
||||
// 移除 ```shell 和结尾的 ```
|
||||
text = strings.TrimPrefix(text, "```shell\n")
|
||||
text = strings.TrimSuffix(text, "```")
|
||||
|
||||
return text
|
||||
}
|
||||
373
internal/controller/copilot/code_completions.go
Normal file
373
internal/controller/copilot/code_completions.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// CodeCompletions 代码补全
|
||||
func CodeCompletions(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
|
||||
debounceTime, _ := strconv.Atoi(os.Getenv("COPILOT_DEBOUNCE"))
|
||||
time.Sleep(time.Duration(debounceTime) * time.Millisecond)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
abortCodex(c, http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
codexServiceType := os.Getenv("CODEX_SERVICE_TYPE")
|
||||
body = ConstructRequestBody(body, codexServiceType)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, os.Getenv("CODEX_API_BASE"), io.NopCloser(bytes.NewBuffer(body)))
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
apiKeys := strings.Split(os.Getenv("CODEX_API_KEY"), ",")
|
||||
|
||||
// 检查 apiKeys 是否有效
|
||||
if len(apiKeys) == 0 || (len(apiKeys) == 1 && apiKeys[0] == "") {
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
randGen := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
selectedKey := strings.TrimSpace(apiKeys[randGen.Intn(len(apiKeys))])
|
||||
|
||||
if selectedKey == "" {
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+selectedKey)
|
||||
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
|
||||
client := &http.Client{
|
||||
Timeout: httpClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if nil != err {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
abortCodex(c, http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("request completions failed:", err.Error())
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer CloseIO(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Println("request completions failed:", string(body))
|
||||
|
||||
abortCodex(c, resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
// 处理 Ollama 服务的流式响应
|
||||
if codexServiceType == "ollama" {
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// json解析 line
|
||||
lineJson := gjson.Parse(line)
|
||||
uuid := uuid.Must(uuid.NewV4()).String()
|
||||
done := lineJson.Get("done").Bool()
|
||||
doneReason := lineJson.Get("done_reason").Str
|
||||
response := lineJson.Get("response").Str
|
||||
timestamp := time.Now().Unix()
|
||||
choice := map[string]interface{}{
|
||||
"text": response,
|
||||
"index": 0,
|
||||
"logprobs": nil,
|
||||
"finish_reason": doneReason,
|
||||
}
|
||||
choices := []map[string]interface{}{choice}
|
||||
constructLineData := map[string]interface{}{
|
||||
"id": uuid,
|
||||
"choices": choices,
|
||||
"created": timestamp,
|
||||
"model": lineJson.Get("model").Str,
|
||||
"system_fingerprint": "fp_1c141eb703",
|
||||
"object": "text_completion",
|
||||
}
|
||||
|
||||
if done && strings.Contains(doneReason, "stop") {
|
||||
usage := map[string]interface{}{
|
||||
"prompt_tokens": lineJson.Get("prompt_eval_count").Int(),
|
||||
"completion_tokens": lineJson.Get("eval_count").Int(),
|
||||
"total_tokens": lineJson.Get("prompt_eval_count").Int(),
|
||||
"prompt_cache_hit_tokens": lineJson.Get("prompt_eval_count").Int(),
|
||||
"prompt_cache_miss_tokens": lineJson.Get("eval_count").Int(),
|
||||
}
|
||||
constructLineData["usage"] = usage
|
||||
}
|
||||
|
||||
// 将修改后的数据重新编码为 JSON
|
||||
modifiedJSON, err := json.Marshal(constructLineData)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送修改后的数据
|
||||
_, _ = c.Writer.WriteString("data: " + string(modifiedJSON) + "\n\n")
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
|
||||
c.Writer.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
// 处理默认服务的响应
|
||||
_, _ = io.Copy(c.Writer, resp.Body)
|
||||
}
|
||||
|
||||
// ConstructRequestBody 重新构建请求体
|
||||
func ConstructRequestBody(body []byte, codexServiceType string) []byte {
|
||||
envCodexModel := os.Getenv("CODEX_API_MODEL_NAME")
|
||||
body, _ = sjson.SetBytes(body, "model", envCodexModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true) // 强制流式输出
|
||||
body, _ = sjson.DeleteBytes(body, "extra")
|
||||
body, _ = sjson.DeleteBytes(body, "nwo")
|
||||
|
||||
// 限制 prompt 和 suffix 的长度
|
||||
body = applyPromptLengthLimit(body)
|
||||
|
||||
temperature, _ := strconv.ParseFloat(os.Getenv("CODEX_TEMPERATURE"), 64)
|
||||
if temperature != -1 {
|
||||
body, _ = sjson.SetBytes(body, "temperature", temperature)
|
||||
}
|
||||
|
||||
codeMaxTokens, _ := strconv.Atoi(os.Getenv("CODEX_MAX_TOKENS"))
|
||||
if int(gjson.GetBytes(body, "max_tokens").Int()) > codeMaxTokens {
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", codeMaxTokens)
|
||||
}
|
||||
|
||||
if gjson.GetBytes(body, "n").Int() > 1 {
|
||||
body, _ = sjson.SetBytes(body, "n", 1)
|
||||
}
|
||||
|
||||
// https://ollama.com/library/stable-code || https://ollama.com/library/codegemma
|
||||
if strings.Contains(envCodexModel, "stable-code") || strings.Contains(envCodexModel, "codegemma") {
|
||||
return constructWithStableCodeModel(body)
|
||||
}
|
||||
|
||||
// https://ollama.com/library/codellama
|
||||
if strings.Contains(envCodexModel, "codellama") {
|
||||
return constructWithCodeLlamaModel(body)
|
||||
}
|
||||
|
||||
// https://help.aliyun.com/zh/model-studio/user-guide/qwen-coder?spm=a2c4g.11186623.0.0.a5234823I6LvAG
|
||||
if strings.Contains(envCodexModel, "qwen-coder-turbo") {
|
||||
return constructWithQwenCoderTurboModel(body)
|
||||
}
|
||||
|
||||
// 支持 Ollama FIM 的模型, 如:https://ollama.com/library/deepseek-coder-v2
|
||||
if codexServiceType == "ollama" {
|
||||
return constructWithOllamaModel(body, codeMaxTokens)
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// applyPromptLengthLimit 对 prompt 和 suffix 应用长度限制
|
||||
func applyPromptLengthLimit(body []byte) []byte {
|
||||
envLimitPrompt := os.Getenv("CODEX_LIMIT_PROMPT")
|
||||
limitPrompt, err := strconv.Atoi(envLimitPrompt)
|
||||
if err != nil || limitPrompt <= 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
body = limitPromptLength(body, limitPrompt)
|
||||
body = limitSuffixLength(body, limitPrompt)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// limitPromptLength 限制 prompt 长度
|
||||
func limitPromptLength(body []byte, limitRows int) []byte {
|
||||
prompt := gjson.GetBytes(body, "prompt")
|
||||
if !prompt.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
rows := strings.Split(prompt.Str, "\n")
|
||||
if len(rows) <= limitRows {
|
||||
return body
|
||||
}
|
||||
|
||||
// 保留后面的内容
|
||||
newPrompt := strings.Join(rows[len(rows)-limitRows:], "\n")
|
||||
body, _ = sjson.SetBytes(body, "prompt", newPrompt)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// limitSuffixLength 限制 suffix 长度
|
||||
func limitSuffixLength(body []byte, limitRows int) []byte {
|
||||
suffix := gjson.GetBytes(body, "suffix")
|
||||
if !suffix.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
rows := strings.Split(suffix.Str, "\n")
|
||||
if len(rows) <= limitRows {
|
||||
return body
|
||||
}
|
||||
|
||||
// 保留前面的内容
|
||||
newSuffix := strings.Join(rows[:limitRows], "\n")
|
||||
body, _ = sjson.SetBytes(body, "suffix", newSuffix)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// constructWithCodeLlamaModel 重写codeLlama模型要求的请求体
|
||||
func constructWithCodeLlamaModel(body []byte) []byte {
|
||||
suffix := gjson.GetBytes(body, "suffix")
|
||||
prompt := gjson.GetBytes(body, "prompt")
|
||||
content := fmt.Sprintf("<PRE> %s <SUF> %s <MID>", prompt, suffix)
|
||||
|
||||
return constructWithChatModel(body, content)
|
||||
}
|
||||
|
||||
// constructWithStableCodeModel 重写StableCode模型要求的请求体
|
||||
func constructWithStableCodeModel(body []byte) []byte {
|
||||
suffix := gjson.GetBytes(body, "suffix")
|
||||
prompt := gjson.GetBytes(body, "prompt")
|
||||
content := fmt.Sprintf("<fim_prefix>%s<fim_suffix>%s<fim_middle>", prompt, suffix)
|
||||
|
||||
return constructWithChatModel(body, content)
|
||||
}
|
||||
|
||||
// constructWithChatModel 重写Chat请求体
|
||||
func constructWithChatModel(body []byte, content string) []byte {
|
||||
// 创建新的 JSON 对象并添加到 body 中
|
||||
messages := []map[string]string{
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
},
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, "messages", messages)
|
||||
|
||||
jsonStr := string(body)
|
||||
jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<")
|
||||
jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">")
|
||||
return []byte(jsonStr)
|
||||
}
|
||||
|
||||
// constructWithQwenCoderTurboModel 重写QwenCoderTurbo模型要求的请求体
|
||||
func constructWithQwenCoderTurboModel(body []byte) []byte {
|
||||
if gjson.GetBytes(body, "n").Int() > 1 {
|
||||
body, _ = sjson.SetBytes(body, "n", 1)
|
||||
}
|
||||
suffix := gjson.GetBytes(body, "suffix")
|
||||
prompt := gjson.GetBytes(body, "prompt")
|
||||
codeLanguage := gjson.GetBytes(body, "extra.language")
|
||||
|
||||
messages := []map[string]interface{}{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an expert in " + codeLanguage.Str + " programming, highly skilled at understanding and continuing to write code.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Combined with subsequent code snippets, help me complete the code:\n\n" +
|
||||
"Code subsequent content:\n```" + codeLanguage.Str + "\n" + suffix.Str + "```\n\n" +
|
||||
"Remember:\n" +
|
||||
"- Do not generate content outside of the code.\n" +
|
||||
"- Do not directly fill in all the code content, the maximum number of lines of code should not exceed 5 lines.\n" +
|
||||
"- Answer must refer to the code suffix content, do not exceed the boundary, otherwise repeated code will occur.\n" +
|
||||
"- If you don't know how to answer, just reply with an empty string.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": prompt.Str,
|
||||
"partial": true,
|
||||
},
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, "messages", messages)
|
||||
body, _ = sjson.DeleteBytes(body, "prompt")
|
||||
return body
|
||||
}
|
||||
|
||||
// constructWithOllamaModel 重写Ollama模型要求的请求体
|
||||
func constructWithOllamaModel(body []byte, codeMaxTokens int) []byte {
|
||||
body, _ = sjson.SetBytes(body, "options.temperature", 0)
|
||||
// stop参数处理
|
||||
stopArray := gjson.GetBytes(body, "stop").Array()
|
||||
stopSlice := make([]interface{}, len(stopArray))
|
||||
for i, v := range stopArray {
|
||||
stopSlice[i] = v.String()
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "options.stop", stopSlice)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
|
||||
maxTokens := gjson.GetBytes(body, "max_tokens").Int()
|
||||
if int(maxTokens) > codeMaxTokens {
|
||||
body, _ = sjson.SetBytes(body, "options.num_predict", codeMaxTokens)
|
||||
} else {
|
||||
body, _ = sjson.SetBytes(body, "options.num_predict", maxTokens)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// abortCodex 中断请求
|
||||
func abortCodex(c *gin.Context, status int) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.String(status, "data: [DONE]\n\n")
|
||||
c.Abort()
|
||||
}
|
||||
167
internal/controller/copilot/embedding_client.go
Normal file
167
internal/controller/copilot/embedding_client.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
defaultTimeout = 30 * time.Second
|
||||
contentTypeJSON = "application/json"
|
||||
)
|
||||
|
||||
// EmbeddingRequest 表示向嵌入API发送的请求
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse 表示从嵌入API接收的响应
|
||||
type EmbeddingResponse struct {
|
||||
Data []EmbeddingData `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// EmbeddingData 表示单个嵌入数据
|
||||
type EmbeddingData struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// Usage 表示API使用情况
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// 移除未使用的类型
|
||||
// Parameters 和 EmbeddingsRequest, EmbeddingsResponse 已被移除
|
||||
|
||||
// EmbeddingClient 封装了与嵌入API交互的功能
|
||||
type EmbeddingClient struct {
|
||||
apiURL string
|
||||
apiKey string
|
||||
model string
|
||||
dimensions int
|
||||
httpClient *http.Client
|
||||
clientMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewEmbeddingClient 创建一个新的嵌入客户端
|
||||
func NewEmbeddingClient(dimensions int) (*EmbeddingClient, error) {
|
||||
apiURL := os.Getenv("EMBEDDING_API_BASE")
|
||||
apiKey := os.Getenv("EMBEDDING_API_KEY")
|
||||
|
||||
if apiURL == "" || apiKey == "" {
|
||||
return nil, fmt.Errorf("EMBEDDING_API_BASE or EMBEDDING_API_KEY environment variable not set")
|
||||
}
|
||||
|
||||
if os.Getenv("EMBEDDING_API_MODEL_NAME") == "" {
|
||||
return nil, fmt.Errorf("EMBEDDING_API_MODEL_NAME environment variable not set")
|
||||
}
|
||||
|
||||
// 解析超时时间,如果未设置或解析失败则使用默认值
|
||||
timeout := defaultTimeout
|
||||
if timeoutStr := os.Getenv("HTTP_CLIENT_TIMEOUT"); timeoutStr != "" {
|
||||
if parsedTimeout, err := time.ParseDuration(timeoutStr + "s"); err == nil {
|
||||
timeout = parsedTimeout
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
|
||||
return &EmbeddingClient{
|
||||
apiURL: apiURL,
|
||||
apiKey: apiKey,
|
||||
model: os.Getenv("EMBEDDING_API_MODEL_NAME"),
|
||||
dimensions: dimensions,
|
||||
httpClient: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetModel 设置嵌入模型
|
||||
func (c *EmbeddingClient) SetModel(model string) {
|
||||
c.clientMutex.Lock()
|
||||
defer c.clientMutex.Unlock()
|
||||
c.model = model
|
||||
}
|
||||
|
||||
// GetEmbedding 获取单个文本的嵌入
|
||||
func (c *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
||||
resp, err := c.GetEmbeddings(ctx, []string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
return nil, fmt.Errorf("no embeddings returned")
|
||||
}
|
||||
|
||||
return resp.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
// GetEmbeddings 批量获取多个文本的嵌入
|
||||
func (c *EmbeddingClient) GetEmbeddings(ctx context.Context, texts []string) (*EmbeddingResponse, error) {
|
||||
c.clientMutex.RLock()
|
||||
dimensions := c.dimensions
|
||||
c.clientMutex.RUnlock()
|
||||
|
||||
reqBody := EmbeddingRequest{
|
||||
Model: os.Getenv("EMBEDDING_API_MODEL_NAME"),
|
||||
Input: texts,
|
||||
Dimensions: dimensions,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", contentTypeJSON)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var embeddingResp EmbeddingResponse
|
||||
if err := json.Unmarshal(body, &embeddingResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
return &embeddingResp, nil
|
||||
}
|
||||
58
internal/controller/copilot/embeddings.go
Normal file
58
internal/controller/copilot/embeddings.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"github.com/gofrs/uuid"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// EmbeddingsAPIRequest 表示嵌入API的请求结构
|
||||
type EmbeddingsAPIRequest struct {
|
||||
Input []string `json:"input" binding:"required"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
// HandleEmbeddings 处理嵌入请求的HTTP处理器
|
||||
func HandleEmbeddings(c *gin.Context) {
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
|
||||
var req EmbeddingsAPIRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 从环境变量获取维度大小,默认为1536
|
||||
dimensionSize := 1536
|
||||
if dimSizeStr := os.Getenv("EMBEDDING_DIMENSION_SIZE"); dimSizeStr != "" {
|
||||
if dimSize, err := strconv.Atoi(dimSizeStr); err == nil {
|
||||
dimensionSize = dimSize
|
||||
}
|
||||
}
|
||||
|
||||
// 创建嵌入客户端
|
||||
client, err := NewEmbeddingClient(dimensionSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果请求中指定了模型,则使用请求中的模型
|
||||
if req.Model != "" {
|
||||
client.SetModel(req.Model)
|
||||
}
|
||||
|
||||
// 获取嵌入,使用请求上下文以支持取消操作
|
||||
resp, err := client.GetEmbeddings(c.Request.Context(), req.Input)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
147
internal/controller/copilot/get_copilot_internal_v2_token.go
Normal file
147
internal/controller/copilot/get_copilot_internal_v2_token.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"ripper/internal/app/github_auth"
|
||||
"ripper/internal/cache"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetDisguiseCopilotInternalV2Token 返回伪装的token
|
||||
func GetDisguiseCopilotInternalV2Token(ctx *gin.Context) {
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
ctx.Header("x-github-request-id", requestID)
|
||||
|
||||
trackingId, _ := uuid.NewV4()
|
||||
now := time.Now().Unix()
|
||||
dcAt, _ := strconv.Atoi(os.Getenv("DISGUISE_COPILOT_TOKEN_EXPIRES_AT"))
|
||||
expiresAt := now + int64(dcAt)
|
||||
sku := "copilot_for_business_seat"
|
||||
|
||||
copilotToken := github_auth.JsonMap2SignToken(map[string]interface{}{
|
||||
"tid": trackingId,
|
||||
"exp": expiresAt,
|
||||
"sku": sku,
|
||||
"st": "dotcom",
|
||||
"chat": 1,
|
||||
"u": "github",
|
||||
})
|
||||
|
||||
endpoints := make(map[string]interface{})
|
||||
endpoints["api"] = os.Getenv("API_BASE_URL")
|
||||
endpoints["origin-tracker"] = "https://origin-tracker.individual.githubcopilot.com"
|
||||
endpoints["proxy"] = os.Getenv("PROXY_BASE_URL")
|
||||
endpoints["telemetry"] = os.Getenv("TELEMETRY_BASE_URL")
|
||||
|
||||
gout := gin.H{
|
||||
"annotations_enabled": true,
|
||||
"chat_enabled": true,
|
||||
"chat_jetbrains_enabled": true,
|
||||
"code_quote_enabled": true,
|
||||
"code_review_enabled": false,
|
||||
"codesearch": true,
|
||||
"copilot_ide_agent_chat_gpt4_small_prompt": false,
|
||||
"copilotignore_enabled": false,
|
||||
"endpoints": endpoints,
|
||||
"expires_at": expiresAt,
|
||||
"individual": true,
|
||||
"nes_enabled": false,
|
||||
"prompt_8k": true,
|
||||
"public_suggestions": "disabled",
|
||||
"refresh_in": 1500,
|
||||
"sku": sku,
|
||||
"snippy_load_test_enabled": false,
|
||||
"telemetry": "disabled",
|
||||
"token": copilotToken,
|
||||
"tracking_id": trackingId,
|
||||
"intellij_editor_fetcher": false,
|
||||
"vsc_electron_fetcher": false,
|
||||
"vs_editor_fetcher": false,
|
||||
"vsc_panel_v2": false,
|
||||
"xcode": true,
|
||||
"xcode_chat": true,
|
||||
"limited_user_quotas": nil,
|
||||
"limited_user_reset_date": nil,
|
||||
"vsc_electron_fetcher_v2": false,
|
||||
}
|
||||
ctx.JSON(http.StatusOK, gout)
|
||||
}
|
||||
|
||||
// GetCopilotInternalV2Token 获取github copilot官方token
|
||||
func GetCopilotInternalV2Token(c *gin.Context) {
|
||||
ghuTokens := strings.Split(os.Getenv("COPILOT_GHU_TOKEN"), ",")
|
||||
if len(ghuTokens) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
ghu := ghuTokens[rand.Intn(len(ghuTokens))]
|
||||
if ghu == "" {
|
||||
log.Println("ghu token is empty")
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{
|
||||
"message": "ghu token is empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cacheKey := "copilot_internal_v2_token"
|
||||
token, err := cache.Get(cacheKey)
|
||||
if err != nil {
|
||||
log.Println(err.Error())
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
cache.Del(cacheKey)
|
||||
return
|
||||
}
|
||||
|
||||
if token != nil {
|
||||
c.JSON(http.StatusOK, token)
|
||||
return
|
||||
}
|
||||
|
||||
url := "https://api.github.com/copilot_internal/v2/token"
|
||||
req, err := http.NewRequestWithContext(c, "GET", url, nil)
|
||||
if err != nil {
|
||||
log.Println(err.Error())
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("authorization", "token "+ghu)
|
||||
req.Header.Set("editor-plugin-version", "copilot-intellij/1.5.21.6667")
|
||||
req.Header.Set("editor-version", "JetBrains-IU/242.21829.142")
|
||||
req.Header.Set("user-agent", "GithubCopilot/1.228.0")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Println(err.Error())
|
||||
c.JSON(resp.StatusCode, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
errorMsg := "获取 Token 失败, 当前 ghu_token 账户可能并未订阅 github copilot 服务!" + ghu
|
||||
c.JSON(resp.StatusCode, gin.H{"error": errorMsg})
|
||||
log.Println(errorMsg)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
if err != nil {
|
||||
log.Println(err.Error())
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
cache.Set(cacheKey, result, 1500)
|
||||
c.JSON(resp.StatusCode, result)
|
||||
}
|
||||
373
internal/controller/copilot/github_completions.go
Normal file
373
internal/controller/copilot/github_completions.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gofrs/uuid"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"ripper/internal/cache"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CodexCompletions 全代理GitHub的代码补全接口
|
||||
func CodexCompletions(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
|
||||
urlModelName := c.Param("model-name")
|
||||
debounceTime, _ := strconv.Atoi(os.Getenv("COPILOT_DEBOUNCE"))
|
||||
time.Sleep(time.Duration(debounceTime) * time.Millisecond)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
abortCodex(c, http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
copilotAccountType := os.Getenv("COPILOT_ACCOUNT_TYPE")
|
||||
url := "https://proxy." + copilotAccountType + ".githubcopilot.com/v1/engines/" + urlModelName + "/completions"
|
||||
req, err := http.NewRequestWithContext(c, "POST", url, bytes.NewBuffer(body))
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 合并请求头
|
||||
if err := mergeHeaders(c.Request.Header, req); err != nil {
|
||||
log.Println(err)
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
|
||||
client := &http.Client{
|
||||
Timeout: httpClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if nil != err {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
abortCodex(c, http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("request completions failed:", err.Error())
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer CloseIO(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Println("请求GitHub官方补全接口失败:", string(body))
|
||||
|
||||
abortCodex(c, resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
_, _ = io.Copy(c.Writer, resp.Body)
|
||||
}
|
||||
|
||||
// ChatsCompletions 全代理GitHub的聊天补全接口
|
||||
func ChatsCompletions(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
if ctx.Err() != nil {
|
||||
abortCodex(c, http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
copilotAccountType := os.Getenv("COPILOT_ACCOUNT_TYPE")
|
||||
url := "https://api." + copilotAccountType + ".githubcopilot.com/chat/completions"
|
||||
req, err := http.NewRequestWithContext(c, "POST", url, bytes.NewBuffer(body))
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 合并请求头
|
||||
if err := mergeHeaders(c.Request.Header, req); err != nil {
|
||||
log.Println(err)
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
|
||||
client := &http.Client{
|
||||
Timeout: httpClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if nil != err {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
abortCodex(c, http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("request completions failed:", err.Error())
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer CloseIO(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Println("请求GitHub官方对话接口失败:", string(body))
|
||||
|
||||
abortCodex(c, resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
_, _ = io.Copy(c.Writer, resp.Body)
|
||||
}
|
||||
|
||||
// ChatEditCompletions 聊天编辑补全接口
|
||||
func ChatEditCompletions(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
if ctx.Err() != nil {
|
||||
abortCodex(c, http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
copilotAccountType := os.Getenv("COPILOT_ACCOUNT_TYPE")
|
||||
url := "https://proxy." + copilotAccountType + ".githubcopilot.com/v1/engines/copilot-centralus-h100/speculation"
|
||||
req, err := http.NewRequestWithContext(c, "POST", url, bytes.NewBuffer(body))
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 合并请求头
|
||||
if err := mergeHeaders(c.Request.Header, req); err != nil {
|
||||
log.Println(err)
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
|
||||
client := &http.Client{
|
||||
Timeout: httpClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if nil != err {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
abortCodex(c, http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("request failed:", err.Error())
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer CloseIO(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Println("请求 Chat 编辑接口失败:", string(body))
|
||||
|
||||
abortCodex(c, resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
_, _ = io.Copy(c.Writer, resp.Body)
|
||||
}
|
||||
|
||||
// getAuthToken 获取GitHub Copilot的临时Token
|
||||
func getAuthToken() (string, error) {
|
||||
ghuTokens := strings.Split(os.Getenv("COPILOT_GHU_TOKEN"), ",")
|
||||
if len(ghuTokens) == 0 {
|
||||
return "", fmt.Errorf("COPILOT_GHU_TOKEN environment variable is empty or malformed")
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
ghu := ghuTokens[rand.Intn(len(ghuTokens))]
|
||||
cacheKey := "github:copilot_internal_v2_token:" + ghu
|
||||
token, err := cache.Get(cacheKey)
|
||||
if err != nil {
|
||||
cache.Del(cacheKey)
|
||||
return "", err
|
||||
}
|
||||
if token != nil {
|
||||
return token.(string), nil
|
||||
}
|
||||
|
||||
url := "https://api.github.com/copilot_internal/v2/token"
|
||||
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
|
||||
client := &http.Client{
|
||||
Timeout: httpClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req.Header.Set("authorization", "token "+ghu)
|
||||
req.Header.Set("host", "api.github.com")
|
||||
req.Header.Set("accept", "*/*")
|
||||
req.Header.Set("editor-plugin-version", "copilot-intellij/1.5.21.6667")
|
||||
req.Header.Set("copilot-language-server-version", "1.228.0")
|
||||
req.Header.Set("user-agent", "GithubCopilot/1.228.0")
|
||||
req.Header.Set("editor-version", "JetBrains-IU/242.21829.142")
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("获取 Token 失败" + ghu)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 解析json
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
newToken := result["token"].(string)
|
||||
err = cache.Set(cacheKey, newToken, 1500)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
// mergeHeaders 合并请求头,固定请求头会覆盖原有请求头
|
||||
func mergeHeaders(originalHeader http.Header, req *http.Request) error {
|
||||
// 复制原始请求头
|
||||
for key, values := range originalHeader {
|
||||
for _, value := range values {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取token
|
||||
token, err := getAuthToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取GitHub Copilot的临时Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 固定请求头
|
||||
fixedHeaders := map[string]string{
|
||||
"authorization": "Bearer " + token,
|
||||
"editor-plugin-version": "copilot-intellij/1.5.21.6667",
|
||||
"copilot-language-server-version": "1.228.0",
|
||||
"user-agent": "GithubCopilot/1.228.0",
|
||||
"editor-version": "JetBrains-IU/242.21829.142",
|
||||
}
|
||||
|
||||
// 设置固定请求头(覆盖原有的)
|
||||
for key, value := range fixedHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCopilotModels 获取GitHub Copilot的模型列表
|
||||
func GetCopilotModels(c *gin.Context) {
|
||||
copilotAccountType := os.Getenv("COPILOT_ACCOUNT_TYPE")
|
||||
url := "https://api." + copilotAccountType + ".githubcopilot.com/models"
|
||||
req, err := http.NewRequestWithContext(c, "GET", url, nil)
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 合并请求头
|
||||
if err := mergeHeaders(c.Request.Header, req); err != nil {
|
||||
log.Println(err)
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
|
||||
client := &http.Client{
|
||||
Timeout: httpClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if nil != err {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
abortCodex(c, http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("获取模型列表失败:", err.Error())
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer CloseIO(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Println("请求GitHub Copilot模型列表失败:", string(body))
|
||||
|
||||
abortCodex(c, resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发原始响应
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
_, _ = io.Copy(c.Writer, resp.Body)
|
||||
}
|
||||
24
internal/controller/copilot/membership.go
Normal file
24
internal/controller/copilot/membership.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// GetMembership 获取团队成员信息
|
||||
func GetMembership(c *gin.Context) {
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
|
||||
teamID := c.Param("teamID")
|
||||
username := c.Param("username")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Not Found",
|
||||
"documentation_url": "https://docs.github.com/rest/teams/members#get-team-membership-for-a-user-legacy",
|
||||
"status": "404",
|
||||
"teamID": teamID,
|
||||
"username": username,
|
||||
})
|
||||
}
|
||||
49
internal/controller/copilot/meta.go
Normal file
49
internal/controller/copilot/meta.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func V3meta(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
}
|
||||
|
||||
func Cliv3(c *gin.Context) {
|
||||
c.Header("X-OAuth-Scopes", "gist, read:org, repo, user, workflow, write:public_key")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"current_user_url": "https://api.github.com/user",
|
||||
"current_user_authorizations_html_url": "https://github.com/settings/connections/applications{/client_id}",
|
||||
"authorizations_url": "https://api.github.com/authorizations",
|
||||
"code_search_url": "https://api.github.com/search/code?q={query}{&page,per_page,sort,order}",
|
||||
"commit_search_url": "https://api.github.com/search/commits?q={query}{&page,per_page,sort,order}",
|
||||
"emails_url": "https://api.github.com/user/emails",
|
||||
"emojis_url": "https://api.github.com/emojis",
|
||||
"events_url": "https://api.github.com/events",
|
||||
"feeds_url": "https://api.github.com/feeds",
|
||||
"followers_url": "https://api.github.com/user/followers",
|
||||
"following_url": "https://api.github.com/user/following{/target}",
|
||||
"gists_url": "https://api.github.com/gists{/gist_id}",
|
||||
"hub_url": "https://api.github.com/hub",
|
||||
"issue_search_url": "https://api.github.com/search/issues?q={query}{&page,per_page,sort,order}",
|
||||
"issues_url": "https://api.github.com/issues",
|
||||
"keys_url": "https://api.github.com/user/keys",
|
||||
"label_search_url": "https://api.github.com/search/labels?q={query}&repository_id={repository_id}{&page,per_page}",
|
||||
"notifications_url": "https://api.github.com/notifications",
|
||||
"organization_url": "https://api.github.com/orgs/{org}",
|
||||
"organization_repositories_url": "https://api.github.com/orgs/{org}/repos{?type,page,per_page,sort}",
|
||||
"organization_teams_url": "https://api.github.com/orgs/{org}/teams",
|
||||
"public_gists_url": "https://api.github.com/gists/public",
|
||||
"rate_limit_url": "https://api.github.com/rate_limit",
|
||||
"repository_url": "https://api.github.com/repos/{owner}/{repo}",
|
||||
"repository_search_url": "https://api.github.com/search/repositories?q={query}{&page,per_page,sort,order}",
|
||||
"current_user_repositories_url": "https://api.github.com/user/repos{?type,page,per_page,sort}",
|
||||
"starred_url": "https://api.github.com/user/starred{/owner}{/repo}",
|
||||
"starred_gists_url": "https://api.github.com/gists/starred",
|
||||
"topic_search_url": "https://api.github.com/search/topics?q={query}{&page,per_page}",
|
||||
"user_url": "https://api.github.com/users/{user}",
|
||||
"user_organizations_url": "https://api.github.com/user/orgs",
|
||||
"user_repositories_url": "https://api.github.com/users/{user}/repos{?type,page,per_page,sort}",
|
||||
"user_search_url": "https://api.github.com/search/users?q={query}{&page,per_page,sort,order}",
|
||||
})
|
||||
}
|
||||
158
internal/controller/copilot/router_register.go
Normal file
158
internal/controller/copilot/router_register.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"log"
|
||||
"os"
|
||||
"ripper/internal/middleware"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ClientType string
|
||||
CopilotProxyAll bool
|
||||
}
|
||||
|
||||
// loadConfig loads the configuration from environment variables.
|
||||
func loadConfig() (*Config, error) {
|
||||
proxyAll, err := strconv.ParseBool(os.Getenv("COPILOT_PROXY_ALL"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid boolean value for COPILOT_PROXY_ALL: %v", err)
|
||||
}
|
||||
|
||||
return &Config{
|
||||
ClientType: os.Getenv("COPILOT_CLIENT_TYPE"),
|
||||
CopilotProxyAll: proxyAll,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GinApi 注册路由
|
||||
func GinApi(g *gin.RouterGroup) {
|
||||
config, err := loadConfig()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// 基础路由
|
||||
setupBasicRoutes(g, config)
|
||||
|
||||
// 用户相关路由
|
||||
setupUserRoutes(g)
|
||||
|
||||
// Copilot相关路由
|
||||
setupCopilotRoutes(g, config)
|
||||
|
||||
// API v3相关路由
|
||||
setupV3Routes(g)
|
||||
}
|
||||
|
||||
// setupBasicRoutes 设置基础路由
|
||||
func setupBasicRoutes(g *gin.RouterGroup, config *Config) {
|
||||
g.GET("/models", createModelsHandler(config))
|
||||
g.GET("/_ping", GetPing)
|
||||
g.POST("/telemetry", PostTelemetry)
|
||||
g.GET("/agents", GetAgents)
|
||||
g.GET("/copilot_internal/user", GetCopilotInternalUser)
|
||||
}
|
||||
|
||||
// setupUserRoutes 设置用户相关路由
|
||||
func setupUserRoutes(g *gin.RouterGroup) {
|
||||
authMiddleware := middleware.AccessTokenCheckAuth()
|
||||
|
||||
userGroup := g.Group("")
|
||||
userGroup.Use(authMiddleware)
|
||||
{
|
||||
userGroup.GET("/user", GetLoginUser)
|
||||
userGroup.GET("/user/orgs", GetUserOrgs)
|
||||
userGroup.GET("/api/v3/user", GetLoginUser)
|
||||
userGroup.GET("/api/v3/user/orgs", GetUserOrgs)
|
||||
userGroup.GET("/teams/:teamID/memberships/:username", GetMembership)
|
||||
userGroup.POST("/chunks", HandleChunks)
|
||||
}
|
||||
}
|
||||
|
||||
// setupCopilotRoutes 设置Copilot相关路由
|
||||
func setupCopilotRoutes(g *gin.RouterGroup, config *Config) {
|
||||
tokenMiddleware := middleware.TokenCheckAuth()
|
||||
|
||||
// Copilot token endpoint
|
||||
g.GET("/copilot_internal/v2/token",
|
||||
middleware.AccessTokenCheckAuth(),
|
||||
createTokenHandler(config))
|
||||
|
||||
// Completions endpoints
|
||||
completionsGroup := g.Group("")
|
||||
completionsGroup.Use(tokenMiddleware)
|
||||
{
|
||||
completionsGroup.POST("/v1/engines/:model-name/completions", createCompletionsHandler(config))
|
||||
completionsGroup.POST("/v1/engines/copilot-codex", createCompletionsHandler(config))
|
||||
completionsGroup.POST("/chat/completions", createChatHandler(config))
|
||||
completionsGroup.POST("/agents/chat", createChatHandler(config))
|
||||
completionsGroup.POST("/v1/chat/completions", createChatHandler(config))
|
||||
completionsGroup.POST("/v1/engines/copilot-centralus-h100/speculation", createChatEditCompletionsHandler(config))
|
||||
completionsGroup.POST("/embeddings", HandleEmbeddings)
|
||||
}
|
||||
}
|
||||
|
||||
// setupV3Routes 设置API v3相关路由
|
||||
func setupV3Routes(g *gin.RouterGroup) {
|
||||
g.GET("/api/v3/meta", V3meta)
|
||||
g.GET("/api/v3/", Cliv3)
|
||||
g.GET("/", Cliv3)
|
||||
}
|
||||
|
||||
// 处理函数生成器
|
||||
func createTokenHandler(config *Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if config.ClientType == "github" && !config.CopilotProxyAll {
|
||||
GetCopilotInternalV2Token(c)
|
||||
} else {
|
||||
GetDisguiseCopilotInternalV2Token(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createCompletionsHandler 生成代码补全处理函数
|
||||
func createCompletionsHandler(config *Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if config.ClientType == "github" && config.CopilotProxyAll {
|
||||
CodexCompletions(c)
|
||||
} else {
|
||||
CodeCompletions(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createChatHandler 生成聊天补全处理函数
|
||||
func createChatHandler(config *Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if config.ClientType == "github" && config.CopilotProxyAll {
|
||||
ChatsCompletions(c)
|
||||
} else {
|
||||
ChatCompletions(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createChatEditCompletionsHandler 生成聊天编辑补全处理函数
|
||||
func createChatEditCompletionsHandler(config *Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if config.ClientType == "github" && config.CopilotProxyAll {
|
||||
ChatEditCompletions(c)
|
||||
} else {
|
||||
CodeCompletions(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createModelsHandler 生成模型处理函数
|
||||
func createModelsHandler(config *Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if config.ClientType == "github" && config.CopilotProxyAll {
|
||||
GetCopilotModels(c)
|
||||
} else {
|
||||
GetModels(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
20
internal/controller/copilot/telemetry.go
Normal file
20
internal/controller/copilot/telemetry.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// PostTelemetry 接收并处理来自GitHub Copilot的遥测数据
|
||||
func PostTelemetry(c *gin.Context) {
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
c.Header("x-github-request-id", requestID)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"itemsReceived": 0,
|
||||
"itemsAccepted": 0,
|
||||
"appId": nil,
|
||||
"errors": []string{},
|
||||
})
|
||||
}
|
||||
109
internal/controller/copilot/user.go
Normal file
109
internal/controller/copilot/user.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"ripper/internal/middleware"
|
||||
jwtpkg "ripper/pkg/jwt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetLoginUser 获取登录用户信息
|
||||
func GetLoginUser(ctx *gin.Context) {
|
||||
userDisplayName := "github"
|
||||
token, _ := jwtpkg.GetJwtProto(ctx, &middleware.UserLoad{})
|
||||
if token != nil && token.UserDisplayName != "" {
|
||||
userDisplayName = token.UserDisplayName
|
||||
}
|
||||
|
||||
ctx.Header("X-OAuth-Scopes", "gist, read:org, repo, user, workflow, write:public_key")
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
ctx.Header("x-github-request-id", requestID)
|
||||
ctx.JSON(http.StatusOK, gin.H{
|
||||
"login": userDisplayName,
|
||||
"id": 9919,
|
||||
"node_id": "DEyOk9yZ2FuaXphdGlvbjk5MTk=",
|
||||
"avatar_url": "https://avatars.githubusercontent.com/u/9919?v=4",
|
||||
"gravatar_id": "",
|
||||
"url": "https://api.github.com/users/github",
|
||||
"html_url": "https://github.com/github",
|
||||
"followers_url": "https://api.github.com/users/github/followers",
|
||||
"following_url": "https://api.github.com/users/github/following{/other_user}",
|
||||
"gists_url": "https://api.github.com/users/github/gists{/gist_id}",
|
||||
"starred_url": "https://api.github.com/users/github/starred{/owner}{/repo}",
|
||||
"subscriptions_url": "https://api.github.com/users/github/subscriptions",
|
||||
"organizations_url": "https://api.github.com/users/github/orgs",
|
||||
"repos_url": "https://api.github.com/users/github/repos",
|
||||
"events_url": "https://api.github.com/users/github/events{/privacy}",
|
||||
"received_events_url": "https://api.github.com/users/github/received_events",
|
||||
"type": "User",
|
||||
"site_admin": false,
|
||||
"name": "GitHub",
|
||||
"company": nil,
|
||||
"blog": "",
|
||||
"location": "San Francisco, CA",
|
||||
"email": nil,
|
||||
"hireable": nil,
|
||||
"bio": nil,
|
||||
"twitter_username": nil,
|
||||
"public_repos": 498,
|
||||
"public_gists": 0,
|
||||
"followers": 42848,
|
||||
"following": 0,
|
||||
"created_at": "2008-05-11T04:37:31Z",
|
||||
"updated_at": "2022-11-29T19:44:55Z",
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func GetUserOrgs(ctx *gin.Context) {
|
||||
ctx.Header("X-OAuth-Scopes", "gist, read:org, repo, user, workflow, write:public_key")
|
||||
ctx.JSON(http.StatusOK, []interface{}{})
|
||||
}
|
||||
|
||||
// generateTrackingID 生成模拟的 analytics_tracking_id
|
||||
func generateTrackingID() string {
|
||||
// 生成一个随机字符串并计算其 MD5
|
||||
randomStr := fmt.Sprintf("%d%d", time.Now().UnixNano(), rand.Int())
|
||||
hash := md5.Sum([]byte(randomStr))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// generateAssignedDate 生成模拟的 assigned_date
|
||||
func generateAssignedDate() string {
|
||||
// 生成最近30天内的随机时间
|
||||
now := time.Now()
|
||||
daysAgo := rand.Intn(30)
|
||||
randomTime := now.AddDate(0, 0, -daysAgo)
|
||||
|
||||
// 随机增加小时和分钟
|
||||
randomHour := rand.Intn(24)
|
||||
randomMinute := rand.Intn(60)
|
||||
randomTime = randomTime.Add(time.Duration(randomHour) * time.Hour)
|
||||
randomTime = randomTime.Add(time.Duration(randomMinute) * time.Minute)
|
||||
|
||||
// 返回格式化的时间字符串
|
||||
return randomTime.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// GetCopilotInternalUser 获取 Copilot 内部用户信息
|
||||
func GetCopilotInternalUser(ctx *gin.Context) {
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
ctx.Header("x-github-request-id", requestID)
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{
|
||||
"access_type_sku": "free_educational",
|
||||
"copilot_plan": "individual",
|
||||
"analytics_tracking_id": generateTrackingID(),
|
||||
"assigned_date": generateAssignedDate(),
|
||||
"can_signup_for_limited": false,
|
||||
"chat_enabled": true,
|
||||
"organization_login_list": []interface{}{},
|
||||
"organization_list": []interface{}{},
|
||||
})
|
||||
}
|
||||
78
internal/controller/copilot/utils.go
Normal file
78
internal/controller/copilot/utils.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"github.com/gofrs/uuid"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Pong struct {
|
||||
Now int `json:"now"`
|
||||
Status string `json:"status"`
|
||||
Ns1 string `json:"ns1"`
|
||||
}
|
||||
|
||||
// GetPing 模拟ping接口
|
||||
func GetPing(ctx *gin.Context) {
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
ctx.Header("x-github-request-id", requestID)
|
||||
|
||||
ctx.JSON(http.StatusOK, Pong{
|
||||
Now: time.Now().Second(),
|
||||
Status: "ok",
|
||||
Ns1: "200 OK",
|
||||
})
|
||||
}
|
||||
|
||||
// ModelsResponse 模型列表响应结构
|
||||
type ModelsResponse struct {
|
||||
Data []interface{} `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// GetModels 获取模型列表
|
||||
func GetModels(ctx *gin.Context) {
|
||||
// 从根目录下读取models.json文件
|
||||
jsonFile, err := os.Open(filepath.Join("models.json"))
|
||||
if err != nil {
|
||||
log.Printf("无法打开models.json文件: %v", err)
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "无法读取模型列表数据"})
|
||||
return
|
||||
}
|
||||
defer CloseIO(jsonFile)
|
||||
|
||||
// 解析JSON数据
|
||||
jsonData, err := io.ReadAll(jsonFile)
|
||||
if err != nil {
|
||||
log.Printf("读取models.json内容失败: %v", err)
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "无法读取模型列表数据"})
|
||||
return
|
||||
}
|
||||
|
||||
var modelsResponse ModelsResponse
|
||||
if err := json.Unmarshal(jsonData, &modelsResponse); err != nil {
|
||||
log.Printf("解析models.json失败: %v", err)
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "无法解析模型列表数据"})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回模型列表数据
|
||||
requestID := uuid.Must(uuid.NewV4()).String()
|
||||
ctx.Header("x-github-request-id", requestID)
|
||||
ctx.JSON(http.StatusOK, modelsResponse)
|
||||
}
|
||||
|
||||
func CloseIO(c io.Closer) {
|
||||
err := c.Close()
|
||||
if nil != err {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user