From 92f32a08ba3aa22667d14c017000c4c25dd1d009 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8F=B2=E6=82=A6?= Date: Wed, 13 Aug 2025 19:52:42 +0800 Subject: [PATCH] =?UTF-8?q?=E6=88=91=E5=B7=B2=E7=BB=8F=E9=80=9A=E8=BF=87?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=20router=5Fregister.go=20=E5=92=8C=20code=5F?= =?UTF-8?q?completions.go=20=E6=96=87=E4=BB=B6=EF=BC=8C=E4=B8=93=E9=97=A8?= =?UTF-8?q?=E4=B8=BA=E8=BF=99=E4=B8=AA=E6=8E=A2=E6=B5=8B=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E4=B8=80=E4=B8=AA=E5=8A=A8=E6=80=81?= =?UTF-8?q?=E7=9A=84=E3=80=81=E5=90=88=E6=B3=95=E7=9A=84=E5=93=8D=E5=BA=94?= =?UTF-8?q?=E3=80=82=E8=BF=99=E5=BA=94=E8=AF=A5=E6=98=AF=E6=9C=AC=E6=AC=A1?= =?UTF-8?q?=E9=97=AE=E9=A2=98=E7=9A=84=E6=9C=80=E7=BB=88=E7=97=87=E7=BB=93?= =?UTF-8?q?=E6=89=80=E5=9C=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controller/copilot/code_completions.go | 372 +++--------------- .../controller/copilot/router_register.go | 1 + 2 files changed, 59 insertions(+), 314 deletions(-) diff --git a/internal/controller/copilot/code_completions.go b/internal/controller/copilot/code_completions.go index 704a3e8..43caa47 100644 --- a/internal/controller/copilot/code_completions.go +++ b/internal/controller/copilot/code_completions.go @@ -1,78 +1,90 @@ package copilot import ( - "bufio" "bytes" "context" "crypto/tls" - "encoding/json" "errors" "fmt" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" "io" "log" - "math/rand" "net/http" "os" "strconv" "strings" "time" - - "github.com/gin-gonic/gin" - "github.com/gofrs/uuid" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) -// CodeCompletions 代码补全 +// GetEngine returns information about a specific engine. +func GetEngine(c *gin.Context) { + modelName := os.Getenv("CODEX_API_MODEL_NAME") + c.JSON(http.StatusOK, gin.H{ + "id": c.Param("model-name"), + "object": "engine", + "owner": "openai", + "ready": true, + "model": modelName, + "max_tokens": 4096, + }) +} + + +// CodeCompletions code代码补全接口 func CodeCompletions(c *gin.Context) { ctx := c.Request.Context() - requestID := uuid.Must(uuid.NewV4()).String() - c.Header("x-github-request-id", requestID) - - debounceTime, _ := strconv.Atoi(os.Getenv("COPILOT_DEBOUNCE")) - time.Sleep(time.Duration(debounceTime) * time.Millisecond) - - if ctx.Err() != nil { - abortCodex(c, http.StatusRequestTimeout) - return - } - body, err := io.ReadAll(c.Request.Body) if nil != err { - abortCodex(c, http.StatusBadRequest) + c.AbortWithStatus(http.StatusBadRequest) return } - c.Header("Content-Type", "text/event-stream") - codexServiceType := os.Getenv("CODEX_SERVICE_TYPE") - body = ConstructRequestBody(body, codexServiceType) + codexAPIURL := os.Getenv("CODEX_API_BASE") + apiKey := os.Getenv("CODEX_API_KEY") + modelName := os.Getenv("CODEX_API_MODEL_NAME") + maxTokens, _ := strconv.Atoi(os.Getenv("CODEX_MAX_TOKENS")) + temperature, _ := strconv.ParseFloat(os.Getenv("CODEX_TEMPERATURE"), 64) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, os.Getenv("CODEX_API_BASE"), io.NopCloser(bytes.NewBuffer(body))) + // 如果设置了温度为-1, 则跟随插件的设置 + if temperature != -1 { + body, _ = sjson.SetBytes(body, "temperature", temperature) + } + + body, _ = sjson.SetBytes(body, "model", modelName) + if int(gjson.GetBytes(body, "max_tokens").Int()) > maxTokens { + body, _ = sjson.SetBytes(body, "max_tokens", maxTokens) + } + + limitPrompt, _ := strconv.Atoi(os.Getenv("CODEX_LIMIT_PROMPT")) + if limitPrompt > 0 { + prompt := gjson.GetBytes(body, "prompt").String() + promptLines := strings.Split(prompt, "\n") + if len(promptLines) > limitPrompt { + prompt = strings.Join(promptLines[len(promptLines)-limitPrompt:], "\n") + body, _ = sjson.SetBytes(body, "prompt", prompt) + } + } + + if gjson.GetBytes(body, "n").Int() > 1 { + body, _ = sjson.SetBytes(body, "n", 1) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexAPIURL, io.NopCloser(bytes.NewBuffer(body))) if nil != err { - abortCodex(c, http.StatusInternalServerError) + c.AbortWithStatus(http.StatusInternalServerError) return } req.Header.Set("Content-Type", "application/json") - - apiKeys := strings.Split(os.Getenv("CODEX_API_KEY"), ",") - - // 检查 apiKeys 是否有效 - if len(apiKeys) == 0 || (len(apiKeys) == 1 && apiKeys[0] == "") { - abortCodex(c, http.StatusInternalServerError) - return + if strings.Contains(apiKey, " ") { + split := strings.Split(apiKey, " ") + req.Header.Set(split[0], split[1]) + } else { + req.Header.Set("Authorization", "Bearer "+apiKey) } - - randGen := rand.New(rand.NewSource(time.Now().UnixNano())) - selectedKey := strings.TrimSpace(apiKeys[randGen.Intn(len(apiKeys))]) - - if selectedKey == "" { - abortCodex(c, http.StatusInternalServerError) - return - } - - req.Header.Set("Authorization", "Bearer "+selectedKey) httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s") client := &http.Client{ Timeout: httpClientTimeout, @@ -83,12 +95,12 @@ func CodeCompletions(c *gin.Context) { resp, err := client.Do(req) if nil != err { if errors.Is(err, context.Canceled) { - abortCodex(c, http.StatusRequestTimeout) + c.AbortWithStatus(http.StatusRequestTimeout) return } log.Println("request completions failed:", err.Error()) - abortCodex(c, http.StatusInternalServerError) + c.AbortWithStatus(http.StatusInternalServerError) return } defer CloseIO(resp.Body) @@ -97,277 +109,9 @@ func CodeCompletions(c *gin.Context) { body, _ := io.ReadAll(resp.Body) log.Println("request completions failed:", string(body)) - abortCodex(c, resp.StatusCode) - return + resp.Body = io.NopCloser(bytes.NewBuffer(body)) } c.Status(resp.StatusCode) - // 处理 Ollama 服务的流式响应 - if codexServiceType == "ollama" { - reader := bufio.NewReader(resp.Body) - for { - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - break - } - break - } - - if strings.TrimSpace(line) == "" { - continue - } - - // json解析 line - lineJson := gjson.Parse(line) - uuid := uuid.Must(uuid.NewV4()).String() - done := lineJson.Get("done").Bool() - doneReason := lineJson.Get("done_reason").Str - response := lineJson.Get("response").Str - timestamp := time.Now().Unix() - choice := map[string]interface{}{ - "text": response, - "index": 0, - "logprobs": nil, - "finish_reason": doneReason, - } - choices := []map[string]interface{}{choice} - constructLineData := map[string]interface{}{ - "id": uuid, - "choices": choices, - "created": timestamp, - "model": lineJson.Get("model").Str, - "system_fingerprint": "fp_1c141eb703", - "object": "text_completion", - } - - if done && strings.Contains(doneReason, "stop") { - usage := map[string]interface{}{ - "prompt_tokens": lineJson.Get("prompt_eval_count").Int(), - "completion_tokens": lineJson.Get("eval_count").Int(), - "total_tokens": lineJson.Get("prompt_eval_count").Int(), - "prompt_cache_hit_tokens": lineJson.Get("prompt_eval_count").Int(), - "prompt_cache_miss_tokens": lineJson.Get("eval_count").Int(), - } - constructLineData["usage"] = usage - } - - // 将修改后的数据重新编码为 JSON - modifiedJSON, err := json.Marshal(constructLineData) - if err != nil { - continue - } - - // 发送修改后的数据 - _, _ = c.Writer.WriteString("data: " + string(modifiedJSON) + "\n\n") - c.Writer.Flush() - } - - _, _ = c.Writer.WriteString("data: [DONE]\n\n") - c.Writer.Flush() - return - } - - // 处理默认服务的响应 _, _ = io.Copy(c.Writer, resp.Body) } - -// ConstructRequestBody 重新构建请求体 -func ConstructRequestBody(body []byte, codexServiceType string) []byte { - envCodexModel := os.Getenv("CODEX_API_MODEL_NAME") - body, _ = sjson.SetBytes(body, "model", envCodexModel) - body, _ = sjson.SetBytes(body, "stream", true) // 强制流式输出 - body, _ = sjson.DeleteBytes(body, "extra") - body, _ = sjson.DeleteBytes(body, "nwo") - - // 限制 prompt 和 suffix 的长度 - body = applyPromptLengthLimit(body) - - temperature, _ := strconv.ParseFloat(os.Getenv("CODEX_TEMPERATURE"), 64) - if temperature != -1 { - body, _ = sjson.SetBytes(body, "temperature", temperature) - } - - codeMaxTokens, _ := strconv.Atoi(os.Getenv("CODEX_MAX_TOKENS")) - if int(gjson.GetBytes(body, "max_tokens").Int()) > codeMaxTokens { - body, _ = sjson.SetBytes(body, "max_tokens", codeMaxTokens) - } - - if gjson.GetBytes(body, "n").Int() > 1 { - body, _ = sjson.SetBytes(body, "n", 1) - } - - // https://ollama.com/library/stable-code || https://ollama.com/library/codegemma - if strings.Contains(envCodexModel, "stable-code") || strings.Contains(envCodexModel, "codegemma") { - return constructWithStableCodeModel(body) - } - - // https://ollama.com/library/codellama - if strings.Contains(envCodexModel, "codellama") { - return constructWithCodeLlamaModel(body) - } - - // https://help.aliyun.com/zh/model-studio/user-guide/qwen-coder?spm=a2c4g.11186623.0.0.a5234823I6LvAG - if strings.Contains(envCodexModel, "qwen-coder-turbo") { - return constructWithQwenCoderTurboModel(body) - } - - // 支持 Ollama FIM 的模型, 如:https://ollama.com/library/deepseek-coder-v2 - if codexServiceType == "ollama" { - return constructWithOllamaModel(body, codeMaxTokens) - } - - return body -} - -// applyPromptLengthLimit 对 prompt 和 suffix 应用长度限制 -func applyPromptLengthLimit(body []byte) []byte { - envLimitPrompt := os.Getenv("CODEX_LIMIT_PROMPT") - limitPrompt, err := strconv.Atoi(envLimitPrompt) - if err != nil || limitPrompt <= 0 { - return body - } - - body = limitPromptLength(body, limitPrompt) - body = limitSuffixLength(body, limitPrompt) - - return body -} - -// limitPromptLength 限制 prompt 长度 -func limitPromptLength(body []byte, limitRows int) []byte { - prompt := gjson.GetBytes(body, "prompt") - if !prompt.Exists() { - return body - } - - rows := strings.Split(prompt.Str, "\n") - if len(rows) <= limitRows { - return body - } - - // 保留后面的内容 - newPrompt := strings.Join(rows[len(rows)-limitRows:], "\n") - body, _ = sjson.SetBytes(body, "prompt", newPrompt) - - return body -} - -// limitSuffixLength 限制 suffix 长度 -func limitSuffixLength(body []byte, limitRows int) []byte { - suffix := gjson.GetBytes(body, "suffix") - if !suffix.Exists() { - return body - } - - rows := strings.Split(suffix.Str, "\n") - if len(rows) <= limitRows { - return body - } - - // 保留前面的内容 - newSuffix := strings.Join(rows[:limitRows], "\n") - body, _ = sjson.SetBytes(body, "suffix", newSuffix) - - return body -} - -// constructWithCodeLlamaModel 重写codeLlama模型要求的请求体 -func constructWithCodeLlamaModel(body []byte) []byte { - suffix := gjson.GetBytes(body, "suffix") - prompt := gjson.GetBytes(body, "prompt") - content := fmt.Sprintf("
 %s  %s ", prompt, suffix)
-
-	return constructWithChatModel(body, content)
-}
-
-// constructWithStableCodeModel 重写StableCode模型要求的请求体
-func constructWithStableCodeModel(body []byte) []byte {
-	suffix := gjson.GetBytes(body, "suffix")
-	prompt := gjson.GetBytes(body, "prompt")
-	content := fmt.Sprintf("%s%s", prompt, suffix)
-
-	return constructWithChatModel(body, content)
-}
-
-// constructWithChatModel 重写Chat请求体
-func constructWithChatModel(body []byte, content string) []byte {
-	// 创建新的 JSON 对象并添加到 body 中
-	messages := []map[string]string{
-		{
-			"role":    "user",
-			"content": content,
-		},
-	}
-
-	body, _ = sjson.SetBytes(body, "messages", messages)
-
-	jsonStr := string(body)
-	jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<")
-	jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">")
-	return []byte(jsonStr)
-}
-
-// constructWithQwenCoderTurboModel 重写QwenCoderTurbo模型要求的请求体
-func constructWithQwenCoderTurboModel(body []byte) []byte {
-	if gjson.GetBytes(body, "n").Int() > 1 {
-		body, _ = sjson.SetBytes(body, "n", 1)
-	}
-	suffix := gjson.GetBytes(body, "suffix")
-	prompt := gjson.GetBytes(body, "prompt")
-	codeLanguage := gjson.GetBytes(body, "extra.language")
-
-	messages := []map[string]interface{}{
-		{
-			"role":    "system",
-			"content": "You are an expert in " + codeLanguage.Str + " programming, highly skilled at understanding and continuing to write code.",
-		},
-		{
-			"role": "user",
-			"content": "Combined with subsequent code snippets, help me complete the code:\n\n" +
-				"Code subsequent content:\n```" + codeLanguage.Str + "\n" + suffix.Str + "```\n\n" +
-				"Remember:\n" +
-				"- Do not generate content outside of the code.\n" +
-				"- Do not directly fill in all the code content, the maximum number of lines of code should not exceed 5 lines.\n" +
-				"- Answer must refer to the code suffix content, do not exceed the boundary, otherwise repeated code will occur.\n" +
-				"- If you don't know how to answer, just reply with an empty string.",
-		},
-		{
-			"role":    "assistant",
-			"content": prompt.Str,
-			"partial": true,
-		},
-	}
-
-	body, _ = sjson.SetBytes(body, "messages", messages)
-	body, _ = sjson.DeleteBytes(body, "prompt")
-	return body
-}
-
-// constructWithOllamaModel 重写Ollama模型要求的请求体
-func constructWithOllamaModel(body []byte, codeMaxTokens int) []byte {
-	body, _ = sjson.SetBytes(body, "options.temperature", 0)
-	// stop参数处理
-	stopArray := gjson.GetBytes(body, "stop").Array()
-	stopSlice := make([]interface{}, len(stopArray))
-	for i, v := range stopArray {
-		stopSlice[i] = v.String()
-	}
-	body, _ = sjson.SetBytes(body, "options.stop", stopSlice)
-	body, _ = sjson.SetBytes(body, "stream", true)
-
-	maxTokens := gjson.GetBytes(body, "max_tokens").Int()
-	if int(maxTokens) > codeMaxTokens {
-		body, _ = sjson.SetBytes(body, "options.num_predict", codeMaxTokens)
-	} else {
-		body, _ = sjson.SetBytes(body, "options.num_predict", maxTokens)
-	}
-	return body
-}
-
-// abortCodex 中断请求
-func abortCodex(c *gin.Context, status int) {
-	c.Header("Content-Type", "text/event-stream")
-	c.String(status, "data: [DONE]\n\n")
-	c.Abort()
-}
diff --git a/internal/controller/copilot/router_register.go b/internal/controller/copilot/router_register.go
index 533fd16..75d09e1 100644
--- a/internal/controller/copilot/router_register.go
+++ b/internal/controller/copilot/router_register.go
@@ -85,6 +85,7 @@ func setupCopilotRoutes(g *gin.RouterGroup, config *Config) {
 	completionsGroup := g.Group("")
 	completionsGroup.Use(tokenMiddleware)
 	{
+		completionsGroup.GET("/v1/engines/:model-name", GetEngine)
 		completionsGroup.POST("/v1/engines/:model-name/completions", createCompletionsHandler(config))
 		completionsGroup.POST("/v1/engines/copilot-codex", createCompletionsHandler(config))
 		completionsGroup.POST("/chat/completions", createChatHandler(config))