This commit is contained in:
史悦
2025-08-13 19:03:20 +08:00
commit d62a2e9ed9
73 changed files with 7296 additions and 0 deletions

View File

@@ -0,0 +1,60 @@
package github_auth
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/gofrs/uuid"
"os"
"sort"
"strings"
)
func sha256Sign(data string) string {
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}
func GetAccessTokenT() string {
t, _ := uuid.NewV4()
return t.String()
}
func JsonMap2Token(data map[string]interface{}) string {
if len(data) == 0 {
return ""
}
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
sort.Strings(keys)
var sb strings.Builder
for i, key := range keys {
if i > 0 {
sb.WriteString(";")
}
sb.WriteString(key)
sb.WriteString("=")
sb.WriteString(fmt.Sprintf("%v", data[key]))
}
return sb.String()
}
func JsonMap2SignToken(data map[string]interface{}) string {
token := JsonMap2Token(data)
if token == "" {
return ""
}
sign := Token2Sign(token)
return token + ";8kp=1:" + sign
}
func Token2Sign(token string) string {
sign := sha256Sign(token + fmt.Sprintf(";salt=%s", os.Getenv("TOKEN_SALT")))
return sign
}

View File

@@ -0,0 +1,162 @@
package github_auth
import (
"encoding/json"
"fmt"
"github.com/gofrs/uuid"
"github.com/gomodule/redigo/redis"
"ripper/internal/cache"
"strings"
)
type ClientAuthInfo struct {
ClientId string `json:"client_id"`
DisplayUserName string `json:"display_user_name,omitempty"`
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
CardCode string `json:"card_code"`
}
type ClientOAuthInfo struct {
ClientId string `json:"client_id" form:"client_id"`
Code string `json:"code" form:"code"`
ClientSecret string `json:"client_secret" form:"client_secret"`
Scope string `json:"scope" form:"scope"`
}
// BindClientToCode 绑定客户端到代码
// clientId 客户端ID
// exp 过期时间
// return 用户代码, 设备代码, 错误
func BindClientToCode(clientId string, exp int) (string, string, error) {
genCode := func() string {
newUUID, _ := uuid.NewV4()
uuidStr := strings.Replace(newUUID.String(), "-", "", -1)
return uuidStr[:6]
}
formattedUUID := genCode()
rep := 0
redisKey := fmt.Sprintf("copilot.proxy.%s", formattedUUID)
repeat, _ := cache.Exist(redisKey)
for repeat {
if rep > 5 {
return "", "", fmt.Errorf("gen code error")
}
formattedUUID = genCode()
redisKey = fmt.Sprintf("copilot.proxy.%s", formattedUUID)
repeat, _ = cache.Exist(redisKey)
rep++
}
devId := GenDevicesCode(40)
authInfo := ClientAuthInfo{
ClientId: clientId,
DeviceCode: devId,
UserCode: formattedUUID,
}
authInfoData, _ := json.Marshal(authInfo)
err := cache.Set(redisKey, authInfoData, exp)
if err != nil {
return "", "", err
}
redisKey = fmt.Sprintf("copilot.proxy.map.%s", devId)
err = cache.Set(redisKey, formattedUUID, exp)
return formattedUUID, devId, err
}
// GetClientAuthInfoByDeviceCode 通过设备代码获取客户端授权信息
func GetClientAuthInfoByDeviceCode(deviceCode string) (*ClientAuthInfo, error) {
redisKey := fmt.Sprintf("copilot.proxy.map.%s", deviceCode)
userCode, err := cache.Get(redisKey)
if err != nil {
return nil, err
}
redisKey = fmt.Sprintf("copilot.proxy.%s", userCode)
authInfoData, err := redis.Bytes(cache.Get(redisKey))
if err != nil {
return nil, err
}
authInfo := &ClientAuthInfo{}
err = json.Unmarshal(authInfoData, &authInfo)
return authInfo, err
}
func GetOAuthCodeInfoByClientIdAndCode(clientId string, code string) (*ClientOAuthInfo, error) {
cacheKey := "oauth2_authorize_" + clientId
oauthCodeData, err := redis.Bytes(cache.Get(cacheKey))
if err != nil {
return nil, err
}
var oauthCode ClientOAuthInfo
err = json.Unmarshal(oauthCodeData, &oauthCode)
if err != nil {
return nil, err
}
if oauthCode.Code != code {
return nil, fmt.Errorf("invalid oauth code")
}
return &oauthCode, nil
}
func GetClientAuthInfo(code string) (ClientAuthInfo, error) {
redisKey := fmt.Sprintf("copilot.proxy.%s", code)
authInfoData, err := redis.Bytes(cache.Get(redisKey))
if err != nil {
return ClientAuthInfo{}, err
}
var authInfo ClientAuthInfo
err = json.Unmarshal(authInfoData, &authInfo)
return authInfo, err
}
// GenDevicesCode 生成设备代码
func GenDevicesCode(codeLen int) string {
var newUUID string
for len(newUUID) < 64 {
ud, _ := uuid.NewV4()
newUUID += strings.Replace(ud.String(), "-", "", -1)
}
return newUUID[:codeLen]
}
// UpdateClientAuthStatusByDeviceCode 更新客户端授权码通过设备代码
func UpdateClientAuthStatusByDeviceCode(deviceCode string, cardCode string, displayUserName string) error {
redisKey := fmt.Sprintf("copilot.proxy.map.%s", deviceCode)
uCode, err := cache.Get(redisKey)
if err != nil {
return err
}
redisKey = fmt.Sprintf("copilot.proxy.%s", uCode)
authInfoData, err := redis.Bytes(cache.Get(redisKey))
if err != nil {
return err
}
authInfo := &ClientAuthInfo{}
err = json.Unmarshal(authInfoData, &authInfo)
if err != nil {
return err
}
authInfo.CardCode = cardCode
if displayUserName != "" {
authInfo.DisplayUserName = displayUserName
}
authInfoData, _ = json.Marshal(authInfo)
err = cache.Set(redisKey, authInfoData, -1)
return err
}
func RemoveClientAuthInfoByDeviceCode(deviceCode string) error {
redisKey := fmt.Sprintf("copilot.proxy.map.%s", deviceCode)
uCode, err := cache.Get(redisKey)
if err != nil {
return err
}
redisKey = fmt.Sprintf("copilot.proxy.%s", uCode)
err = cache.Del(redisKey)
if err != nil {
return err
}
redisKey = fmt.Sprintf("copilot.proxy.map.%s", deviceCode)
err = cache.Del(redisKey)
return err
}

8
internal/cache/cacheable.go vendored Normal file
View File

@@ -0,0 +1,8 @@
package cache
type Cacheable interface {
Set(key string, value interface{}, ttl int) error
Get(key string) (interface{}, error)
Exist(key string) (bool, error)
Del(key string) error
}

100
internal/cache/memory.go vendored Normal file
View File

@@ -0,0 +1,100 @@
package cache
import (
"fmt"
"sync"
"time"
)
// MemoryMap 用于内存缓存
type MemoryMap struct {
cache map[string]interface{}
expirations map[string]int64
mu sync.Mutex
}
func NewMemoryMap() *MemoryMap {
m := &MemoryMap{}
m.init()
return m
}
// init 初始化 MemoryMap 的缓存
func (m *MemoryMap) init() {
m.cache = make(map[string]interface{})
m.expirations = make(map[string]int64)
}
func (m *MemoryMap) Get(key string) (interface{}, error) {
m.mu.Lock()
defer m.mu.Unlock()
expiration, exists := m.expirations[key]
currentTime := time.Now().UnixMilli()
if exists && currentTime > expiration {
// 键已过期,删除并返回 nil
fmt.Printf("Get: key=%s has expired, deleting...\n", key)
delete(m.cache, key)
delete(m.expirations, key)
return nil, nil
}
value, ok := m.cache[key]
if !ok {
return nil, nil
}
return value, nil
}
// Set 设置缓存中的值,并指定过期时间(秒)
func (m *MemoryMap) Set(key string, value interface{}, ttl int) error {
m.mu.Lock()
defer m.mu.Unlock()
m.cache[key] = value
if ttl == 0 {
// 默认半小时
ttl = 30 * 60
}
if ttl == -1 {
// -1 表示永久缓存,不设置过期时间
delete(m.expirations, key)
} else {
expiration := time.Now().UnixMilli() + int64(ttl*1000)
m.expirations[key] = expiration
}
return nil
}
func (m *MemoryMap) Exist(key string) (bool, error) {
m.mu.Lock()
defer m.mu.Unlock()
expiration, exists := m.expirations[key]
currentTime := time.Now().UnixMilli()
if exists && currentTime > expiration {
// 键已过期,删除并返回 nil
fmt.Printf("Get: key=%s has expired, deleting...\n", key)
delete(m.cache, key)
delete(m.expirations, key)
}
_, ok := m.cache[key]
return ok, nil
}
func (m *MemoryMap) Del(key string) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.cache[key]; !ok {
return nil
}
delete(m.cache, key)
delete(m.expirations, key)
return nil
}
// 编译时检查
var _ Cacheable = (*MemoryMap)(nil)

25
internal/cache/operation.go vendored Normal file
View File

@@ -0,0 +1,25 @@
package cache
var cache Cacheable
func init() {
cache = NewMemoryMap()
// 已废弃redis缓存实现
/*host := os.Getenv("REDIS_HOST")
port := os.Getenv("REDIS_PORT")
psw := os.Getenv("REDIS_PASSWORD")
cache = NewRedisInstance(host, port, psw)*/
}
func Set(key string, value interface{}, ttl int) error {
return cache.Set(key, value, ttl)
}
func Get(key string) (interface{}, error) {
return cache.Get(key)
}
func Exist(key string) (bool, error) {
return cache.Exist(key)
}
func Del(key string) error {
return cache.Del(key)
}

82
internal/cache/redis.go vendored Normal file
View File

@@ -0,0 +1,82 @@
package cache
import (
"fmt"
"github.com/gomodule/redigo/redis"
"strconv"
"time"
)
type Redis struct {
Host string
Port string
Psw string
Pool *redis.Pool
}
func NewRedisInstance(host string, port string, psw string) *Redis {
r := &Redis{Host: host, Port: port, Psw: psw}
r.init()
return r
}
func (r *Redis) Get(k string) (interface{}, error) {
return r.getConn().Do("get", k)
}
func (r *Redis) Set(k string, v interface{}, ttl int) error {
_, err := r.getConn().Do("set", k, v, "EX", ttl)
return err
}
func (r *Redis) Exist(k string) (bool, error) {
return redis.Bool(r.getConn().Do("EXISTS", k))
}
func (r *Redis) Del(k string) error {
_, err := r.getConn().Do("del", k)
return err
}
func (r *Redis) getConn() redis.Conn {
return r.Pool.Get()
}
func (r *Redis) init() {
r.Pool = &redis.Pool{
// Maximum number of connections allocated by the pool at a given time.
// When zero, there is no limit on the number of connections in the pool.
//最大活跃连接数0代表无限
MaxActive: 1000,
//最大闲置连接数
// Maximum number of idle connections in the pool.
MaxIdle: 50,
//闲置连接的超时时间
// Close connections after remaining idle for this duration. If the value
// is zero, then idle connections are not closed. Applications should set
// the timeout to a value less than the server's timeout.
IdleTimeout: time.Second * 100,
//定义拨号获得连接的函数
// Dial is an application supplied function for creating and configuring a
// connection.
//
// The connection returned from Dial must not be in a special state
// (subscribed to pubsub channel, transaction started, ...).
Dial: func() (redis.Conn, error) {
port, _ := strconv.Atoi(r.Port)
c, err := redis.Dial("tcp", fmt.Sprintf("%s:%d", r.Host, port))
if err != nil {
return nil, err
}
password := r.Psw
if password != "" {
if _, err := c.Do("AUTH", password); err != nil {
c.Close()
return nil, err
}
}
return c, err
},
}
}

View File

@@ -0,0 +1,153 @@
package auth
import (
_ "embed"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"os"
"ripper/internal/app/github_auth"
"ripper/internal/middleware"
"ripper/internal/response"
jwtpkg "ripper/pkg/jwt"
"time"
)
type postLoginDeviceCodeRequest struct {
ClientId string `json:"client_id" form:"client_id"`
}
type postLoginDeviceCodeResponse struct {
DeviceCode string `json:"device_code"` // 设备代码
UserCode string `json:"user_code"` // 用户代码
VerificationUrl string `json:"verification_uri"` // 验证地址
ExpiresIn int `json:"expires_in"` // 过期时间
Interval int `json:"interval"` // 间隔时间
}
type loginDeviceRequestInfo struct {
Code string `json:"code"`
Authorization string `json:"authorization"`
DisplayUserName string `json:"displayUserName,omitempty"`
Password string `json:"password"`
}
func postLoginDeviceCode(ctx *gin.Context) {
cli := postLoginDeviceCodeRequest{}
if err := ctx.ShouldBind(&cli); err != nil {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: "Invalid client id.",
}, false)
return
}
if cli.ClientId == "" {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: "Client id is required.",
}, false)
return
}
uid, devid, err := github_auth.BindClientToCode(cli.ClientId, 1800)
if err != nil {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: err.Error(),
}, false)
return
}
ctx.JSON(http.StatusOK, postLoginDeviceCodeResponse{
DeviceCode: devid,
UserCode: uid,
VerificationUrl: fmt.Sprintf("%s/login/device?user_code=%s", os.Getenv("DEFAULT_BASE_URL"), uid),
ExpiresIn: 1800,
Interval: 5,
})
}
func postLoginOauthAccessToken(ctx *gin.Context) {
v, exists := ctx.Get("client_auth_info")
if !exists {
ctx.JSON(http.StatusOK, gin.H{
"error": "authorization_pending",
"error_description": "The authorization request is still pending.",
"error_uri": "https://docs.github.com/developers/apps/authorizing-oauth-apps#error-codes-for-the-device-flow",
})
return
}
cliAuthInfo := v.(*github_auth.ClientAuthInfo)
t := time.Now()
t.Add(24 * 3 * time.Hour)
u, err := github_auth.GetClientAuthInfo(cliAuthInfo.UserCode)
if err != nil {
ctx.JSON(http.StatusOK, gin.H{
"error": "access_denied",
"error_description": "You must make a new request for a device code.",
"error_uri": "https://docs.github.com/developers/apps/authorizing-oauth-apps#error-codes-for-the-device-flow",
})
return
}
tk, _ := jwtpkg.CreateToken(&middleware.UserLoad{
UserDisplayName: cliAuthInfo.DisplayUserName,
CardCode: u.CardCode,
Client: cliAuthInfo.ClientId,
RegisteredClaims: jwtpkg.CreateStandardClaims(t.Unix(), "user"),
})
_ = github_auth.RemoveClientAuthInfoByDeviceCode(cliAuthInfo.ClientId)
ctx.JSON(http.StatusOK, gin.H{
"access_token": tk,
"scope": "",
"token_type": "bearer",
})
}
func postLoginDevice(ctx *gin.Context) {
var info loginDeviceRequestInfo
if err := response.BindStruct(ctx, &info); err != nil {
response.FailJson(ctx, response.FailStruct{
Code: 422,
Msg: "请求参数错误",
}, false)
return
}
// 验证密码
loginPassword := os.Getenv("LOGIN_PASSWORD")
if loginPassword != "" && info.Password != loginPassword {
response.FailJson(ctx, response.FailStruct{
Code: 422,
Msg: "访问密码错误",
}, false)
return
}
// 检查code是否存在
authInfo, err := github_auth.GetClientAuthInfo(info.Code)
if err != nil {
response.FailJson(ctx, response.FailStruct{
Code: 422,
Msg: "授权码填写错误",
}, false)
return
}
err = github_auth.UpdateClientAuthStatusByDeviceCode(authInfo.DeviceCode, info.Authorization, info.DisplayUserName)
if err != nil {
response.FailJson(ctx, response.FailStruct{
Code: 500,
Msg: "系统异常, 请稍后再试",
}, false)
return
}
response.SuccessJson(ctx, "ok")
}
func getLoginDevice(ctx *gin.Context) {
ctx.Header("Content-Type", "text/html; charset=utf-8")
ctx.HTML(http.StatusOK, "code.html", gin.H{})
}
func getHelpPage(ctx *gin.Context) {
ctx.Header("Content-Type", "text/html; charset=utf-8")
ctx.HTML(http.StatusOK, "help.html", gin.H{})
}

View File

@@ -0,0 +1,105 @@
package auth
import (
"bytes"
"encoding/json"
"net/http"
"os"
"time"
"github.com/gin-gonic/gin"
"ripper/internal/response"
)
const (
clientID = "Iv1.b507a08c87ecfe98"
deviceCodeURL = "https://github.com/login/device/code"
tokenURL = "https://github.com/login/oauth/access_token"
)
type githubLoginDeviceRequest struct {
DeviceCode string `form:"device_code" json:"device_code" binding:"required"`
}
// getDeviceCode returns the device code for GitHub login.
func getDeviceCode(c *gin.Context) {
body := map[string]string{
"client_id": clientID,
}
result, err := makeRequest(c, http.MethodPost, deviceCodeURL, body)
if err != nil {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, result)
}
// getGhuToken returns the GitHub user token.
func getGhuToken(c *gin.Context) {
var params githubLoginDeviceRequest
if err := c.ShouldBind(&params); err != nil {
response.FailJson(c, response.FailStruct{
Code: -1,
Msg: "Invalid request: " + err.Error(),
}, false)
return
}
body := map[string]string{
"client_id": clientID,
"device_code": params.DeviceCode,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
}
result, err := makeRequest(c, http.MethodPost, tokenURL, body)
if err != nil {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, result)
}
// getGithubLoginDevice returns the login page for GitHub.
func getGithubLoginDevice(ctx *gin.Context) {
ctx.Header("Content-Type", "text/html; charset=utf-8")
ctx.HTML(http.StatusOK, "login.html", gin.H{})
}
// makeRequest makes a request to the given URL with the given method and body.
func makeRequest(c *gin.Context, method, url string, body map[string]string) (interface{}, error) {
jsonBody, err := json.Marshal(body)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(c, method, url, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, err
}
req.Header.Set("accept", "application/json")
req.Header.Set("content-type", "application/json")
req.Header.Set("editor-plugin-version", "copilot-intellij/1.5.21.6667")
req.Header.Set("copilot-language-server-version", "1.228.0")
req.Header.Set("user-agent", "GithubCopilot/1.228.0")
req.Header.Set("editor-version", "JetBrains-IU/242.21829.142")
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
client := &http.Client{Timeout: httpClientTimeout}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var result interface{}
err = json.NewDecoder(resp.Body).Decode(&result)
if err != nil {
return nil, err
}
return result, nil
}

View File

@@ -0,0 +1,100 @@
package auth
import (
"encoding/json"
"github.com/gin-gonic/gin"
"net/http"
"os"
"ripper/internal/app/github_auth"
"ripper/internal/cache"
"ripper/internal/middleware"
"ripper/internal/response"
jwtpkg "ripper/pkg/jwt"
"time"
)
type getLoginOauthAuthorizeRequest struct {
ClientId string `json:"client_id" form:"client_id"`
Prompt string `json:"prompt" form:"prompt"`
RedirectUri string `json:"redirect_uri" form:"redirect_uri"`
Scope string `json:"scope" form:"scope"`
State string `json:"state" form:"state"`
}
func getLoginOauthAuthorize(ctx *gin.Context) {
req := getLoginOauthAuthorizeRequest{}
err := ctx.BindQuery(&req)
if err != nil {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: "Invalid request.",
}, false)
return
}
vsCopilotClientId := os.Getenv("VS_COPILOT_CLIENT_ID")
if req.ClientId != vsCopilotClientId {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: "Invalid client id.",
}, false)
return
}
oauthCode := github_auth.GenDevicesCode(20)
cai := github_auth.ClientOAuthInfo{
ClientId: req.ClientId,
Code: oauthCode,
Scope: req.Scope,
}
cacheKey := "oauth2_authorize_" + req.ClientId
caiInfo, _ := json.Marshal(cai)
err = cache.Set(cacheKey, caiInfo, 300)
if err != nil {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: "Internal error.",
}, false)
return
}
// Redirect to the client's redirect_uri
browserSessionId := github_auth.GenDevicesCode(64)
ctx.Redirect(302, req.RedirectUri+"?browserSessionId="+browserSessionId+"&code="+oauthCode+"&state="+req.State)
}
func postLoginOauthAccessTokenForVs2022(ctx *gin.Context) {
v, exists := ctx.Get("client_auth_info")
if !exists {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: "Invalid client id.",
}, false)
return
}
cliAuthInfo := v.(*github_auth.ClientOAuthInfo)
t := time.Now()
t.Add(24 * 3 * time.Hour)
tk, _ := jwtpkg.CreateToken(&middleware.UserLoad{
CardCode: cliAuthInfo.Code,
Client: cliAuthInfo.ClientId,
RegisteredClaims: jwtpkg.CreateStandardClaims(t.Unix(), "user"),
})
ctx.JSON(http.StatusOK, gin.H{
"access_token": tk,
"scope": cliAuthInfo.Scope,
"token_type": "bearer",
})
}
func getSiteSha(ctx *gin.Context) {
ctx.Header("X-GitHub-Request-Id", "C0E1:6A1A:1A1F:2A1D:1A1F:1A1F:1A1F:1A1F")
ctx.JSON(http.StatusOK, gin.H{})
}
func getLoginConfig(ctx *gin.Context) {
loginPassword := os.Getenv("LOGIN_PASSWORD")
ctx.JSON(http.StatusOK, gin.H{
"is_login_password": loginPassword != "",
})
}

View File

@@ -0,0 +1,42 @@
package auth
import (
"github.com/gin-gonic/gin"
"ripper/internal/middleware"
"strings"
)
func GinApi(g *gin.RouterGroup) {
g.GET("/help", getHelpPage)
// 启动设备代码登录流程
g.POST("/login/device/code", postLoginDeviceCode)
g.POST("/login/device", postLoginDevice)
g.GET("/login/device", getLoginDevice)
g.POST("/login/oauth/access_token", func(ctx *gin.Context) {
if strings.Index(ctx.Request.UserAgent(), "VSTeamExplorer") != -1 {
middleware.AuthCodeFlowCheckAuth(ctx)
} else {
middleware.DeviceCodeCheckAuth(ctx)
}
}, func(ctx *gin.Context) {
if strings.Index(ctx.Request.UserAgent(), "VSTeamExplorer") != -1 {
postLoginOauthAccessTokenForVs2022(ctx)
} else {
postLoginOauthAccessToken(ctx)
}
})
// oauth2 登录
g.GET("/login/oauth/authorize", getLoginOauthAuthorize)
// enterprise 验证
g.GET("/site/sha", getSiteSha)
// 获取登录页面配置
g.GET("/login/config", getLoginConfig)
// GitHub模拟登录获取 ghu_token
g.GET("/github/login/device/code", getGithubLoginDevice)
g.POST("/github/login/device/code", getDeviceCode)
g.POST("/github/login/ghu-token", getGhuToken)
}

View File

@@ -0,0 +1,29 @@
package copilot
import (
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"net/http"
)
// GetAgents 获取代理列表
func GetAgents(c *gin.Context) {
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
c.JSON(http.StatusOK, gin.H{
"agents": []gin.H{
{
"id": "github/copilot-workspace",
"name": "@workspace",
"description": "Ask questions and get answers about your codebase.",
"version": "1.0.0",
"publisher": "github",
"model": "gpt-4o-mini-2024-07-18",
"capabilities": "workspace",
"default_model": "gpt-4o-mini-2024-07-18",
"capabilities_model": "gpt-4o-mini-2024-07-18",
},
},
})
}

View File

@@ -0,0 +1,188 @@
package copilot
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"io"
"log"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
)
// ChatCompletions chat对话接口
func ChatCompletions(c *gin.Context) {
ctx := c.Request.Context()
// 添加响应头, 解决vscode校验github所属问题
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
body, err := io.ReadAll(c.Request.Body)
if nil != err {
c.AbortWithStatus(http.StatusBadRequest)
return
}
apiModelName := gjson.GetBytes(body, "model").String()
// 默认设置的对话模型
envModelName := os.Getenv("CHAT_API_MODEL_NAME")
// 默认设置的对话请求地址
chatAPIURL := os.Getenv("CHAT_API_BASE")
// 默认设置的对话模型key
apiKey := os.Getenv("CHAT_API_KEY")
// 轻量模型直接走代码补全接口, 节约成本
if strings.Contains(apiModelName, os.Getenv("LIGHTWEIGHT_MODEL")) {
envModelName = os.Getenv("CODEX_API_MODEL_NAME")
codexAPIURL := os.Getenv("CODEX_API_BASE")
parsedURL, err := url.Parse(codexAPIURL)
if err != nil {
fmt.Println("URL解析错误:", err)
return
}
chatAPIURL = "https://" + parsedURL.Hostname() + "/v1/chat/completions"
apiKey = os.Getenv("CODEX_API_KEY")
}
c.Header("Content-Type", "text/event-stream")
body, _ = sjson.SetBytes(body, "model", envModelName)
body, _ = sjson.SetBytes(body, "stream", true) // 强制流式输出
if !gjson.GetBytes(body, "function_call").Exists() {
messages := gjson.GetBytes(body, "messages").Array()
for i, msg := range messages {
toolCalls := msg.Get("tool_calls").Array()
if len(toolCalls) == 0 {
body, _ = sjson.DeleteBytes(body, fmt.Sprintf("messages.%d.tool_calls", i))
}
}
lastIndex := len(messages) - 1
chatLocale := os.Getenv("CHAT_LOCALE")
if chatLocale != "" && !strings.Contains(messages[lastIndex].Get("content").String(), "Respond in the following locale") {
body, _ = sjson.SetBytes(body, "messages."+strconv.Itoa(lastIndex)+".content", messages[lastIndex].Get("content").String()+"Respond in the following locale: "+chatLocale+".")
}
}
body, _ = sjson.DeleteBytes(body, "intent")
body, _ = sjson.DeleteBytes(body, "intent_threshold")
body, _ = sjson.DeleteBytes(body, "intent_content")
body, _ = sjson.DeleteBytes(body, "logprobs") // #IBZYCA
// 是否支持使用工具, 避免模型不支持相关功能报错
chatUseTools, _ := strconv.ParseBool(os.Getenv("CHAT_USE_TOOLS"))
if !chatUseTools {
body, _ = sjson.DeleteBytes(body, "tools")
body, _ = sjson.DeleteBytes(body, "tool_call")
body, _ = sjson.DeleteBytes(body, "functions")
body, _ = sjson.DeleteBytes(body, "function_call")
body, _ = sjson.DeleteBytes(body, "tool_choice")
}
ChatMaxTokens, _ := strconv.Atoi(os.Getenv("CHAT_MAX_TOKENS"))
if int(gjson.GetBytes(body, "max_tokens").Int()) > ChatMaxTokens {
body, _ = sjson.SetBytes(body, "max_tokens", ChatMaxTokens)
}
if gjson.GetBytes(body, "n").Int() > 1 {
body, _ = sjson.SetBytes(body, "n", 1)
}
messages := gjson.GetBytes(body, "messages").Array()
userAgent := c.GetHeader("User-Agent")
// 拦截处理vscode对话首次预处理请求, 减少等待时间
firstRole := gjson.GetBytes(body, "messages.0.role").String()
firstContent := gjson.GetBytes(body, "messages.0.content").String()
if strings.Contains(firstRole, "system") && strings.Contains(firstContent, "You are a helpful AI programming assistant to a user") &&
!strings.Contains(firstContent, "If you cannot choose just one category, or if none of the categories seem like they would provide the user with a better result, you must always respond with") &&
!gjson.GetBytes(body, "tool_choice").Exists() {
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
c.Writer.Flush()
return
}
// vs2022客户端的兼容处理
if strings.Contains(userAgent, "VSCopilotClient") {
lastMessage := messages[len(messages)-1]
messageRole := lastMessage.Get("role").String()
messageContent := lastMessage.Get("content").String()
if strings.Contains(firstRole, "system") && strings.Contains(firstContent, "You are an AI programming assistant") {
vs2022FirstChatTemplate(c)
return
}
if messageRole == "user" && messageContent == "Write a short one-sentence question that I can ask that naturally follows from the previous few questions and answers. It should not ask a question which is already answered in the conversation. It should be a question that you are capable of answering. Reply with only the text of the question and nothing else." {
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
c.Writer.Flush()
return
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatAPIURL, io.NopCloser(bytes.NewBuffer(body)))
if nil != err {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
client := &http.Client{
Timeout: httpClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
resp, err := client.Do(req)
if nil != err {
if errors.Is(err, context.Canceled) {
c.AbortWithStatus(http.StatusRequestTimeout)
return
}
log.Println("request conversation failed:", err.Error())
c.AbortWithStatus(http.StatusInternalServerError)
return
}
defer CloseIO(resp.Body)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Println("request completions failed:", string(body))
resp.Body = io.NopCloser(bytes.NewBuffer(body))
}
c.Status(resp.StatusCode)
_, _ = io.Copy(c.Writer, resp.Body)
}
// vs2022FirstChatTemplate is a template for the first chat completion response
func vs2022FirstChatTemplate(c *gin.Context) {
fixedOutput := `data: {"id":"f6202f6f-9d13-4518-b34f-65e945b0a1a2","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
data: {"id":"b2ab39cb-9a84-4006-b470-93a5965c6d69","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
data: {"id":"df5f9ce7-b653-4ffb-8d92-e21856ce1ffc","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":"Explain"},"finish_reason":null}]}
data: {"id":"fb58d66e-bb16-43f2-8470-2de0c8662533","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
data: {"id":"22ea16e2-766f-4b10-84d0-68399abc9181","object":"chat.completion.chunk","model":"gpt-4o-mini-2024-07-18","created":1734752124,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":"stop"}]}
data: [DONE]
`
_, _ = c.Writer.WriteString(fixedOutput)
c.Writer.Flush()
}

View File

@@ -0,0 +1,273 @@
package copilot
import (
"context"
"crypto/sha256"
"fmt"
"github.com/gofrs/uuid"
"net/http"
"os"
"strconv"
"strings"
"sync"
"github.com/gin-gonic/gin"
)
// 常量定义
const (
defaultDimensionSize = 1536 // 默认向量维度
markdownFilePrefix = "File: `%s`\n```shell\n"
markdownFileSuffix = "```"
)
// 获取向量维度大小
func getDimensionSize() int {
dimensionSize := defaultDimensionSize
if dimSizeStr := os.Getenv("EMBEDDING_DIMENSION_SIZE"); dimSizeStr != "" {
if dimSize, err := strconv.Atoi(dimSizeStr); err == nil {
dimensionSize = dimSize
}
}
return dimensionSize
}
// 计算块大小
func getChunkSize() int {
// 根据维度大小调整块大小这里设置为维度的1.5倍左右
return getDimensionSize() * 3 / 2
}
// ChunkRequest 表示分块请求
type ChunkRequest struct {
Content string `json:"content" binding:"required"`
Path string `json:"path" binding:"required"`
Embed bool `json:"embed"`
}
// Chunk 表示内容块
type Chunk struct {
Hash string `json:"hash"`
Text string `json:"text"`
Range Range `json:"range"`
Embedding Embedding `json:"embedding,omitempty"`
}
// Range 表示文本范围
type Range struct {
Start int `json:"start"`
End int `json:"end"`
}
// Embedding 表示向量嵌入
type Embedding struct {
Embedding []float32 `json:"embedding"`
}
// ChunkResponse 表示分块响应
type ChunkResponse struct {
Chunks []Chunk `json:"chunks"`
EmbeddingModel string `json:"embedding_model"`
}
// ChunkService 处理文本分块和嵌入的服务
type ChunkService struct {
embeddingClient *EmbeddingClient
modelName string
}
// NewChunkService 创建新的分块服务
func NewChunkService() (*ChunkService, error) {
client, err := NewEmbeddingClient(getDimensionSize())
if err != nil {
return nil, fmt.Errorf("failed to create embedding client: %w", err)
}
return &ChunkService{
embeddingClient: client,
modelName: os.Getenv("EMBEDDING_API_MODEL_NAME"),
}, nil
}
// HandleChunks 处理分块请求的HTTP处理器
func HandleChunks(c *gin.Context) {
var req ChunkRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
service, err := NewChunkService()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to initialize service: %v", err)})
return
}
chunks := service.SplitIntoChunks(req.Content, req.Path)
if req.Embed {
if err := service.GenerateEmbeddings(c.Request.Context(), chunks); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to generate embeddings: %v", err)})
return
}
}
resp := ChunkResponse{
Chunks: chunks,
EmbeddingModel: service.modelName,
}
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
c.JSON(http.StatusOK, resp)
}
// SplitIntoChunks 将内容分割成块
func (s *ChunkService) SplitIntoChunks(content, path string) []Chunk {
var chunks []Chunk
lines := strings.Split(content, "\n")
chunkSize := getChunkSize()
// 预分配切片容量,减少内存重新分配
estimatedChunks := len(content)/chunkSize + 1
chunks = make([]Chunk, 0, estimatedChunks)
var sb strings.Builder
start := 0
for _, line := range lines {
lineWithNewline := line + "\n"
// 如果当前块加上新行会超过chunkSize并且当前块不为空
if sb.Len()+len(lineWithNewline) > chunkSize && sb.Len() > 0 {
// 创建新的chunk
chunkText := sb.String()
chunk := s.createChunk(chunkText, path, start, start+len(chunkText))
chunks = append(chunks, chunk)
start += len(chunkText)
sb.Reset()
sb.WriteString(lineWithNewline)
} else {
sb.WriteString(lineWithNewline)
}
}
// 添加最后一个chunk
if sb.Len() > 0 {
chunkText := sb.String()
chunk := s.createChunk(chunkText, path, start, start+len(chunkText))
chunks = append(chunks, chunk)
}
return chunks
}
// createChunk 创建一个新的内容块
func (s *ChunkService) createChunk(text, path string, start, end int) Chunk {
// 计算文本的SHA-256哈希
hash := sha256.Sum256([]byte(text))
return Chunk{
Hash: fmt.Sprintf("%x", hash),
Text: fmt.Sprintf(markdownFilePrefix+"%s"+markdownFileSuffix, path, text),
Range: Range{
Start: start,
End: end,
},
Embedding: Embedding{
Embedding: make([]float32, 0), // 初始化为空切片
},
}
}
// GenerateEmbeddings 为所有块生成嵌入向量
func (s *ChunkService) GenerateEmbeddings(ctx context.Context, chunks []Chunk) error {
if len(chunks) == 0 {
return nil
}
// 对于少量块,直接串行处理
if len(chunks) <= 5 {
return s.generateEmbeddingsSerial(ctx, chunks)
}
// 对于大量块,使用并行处理
return s.generateEmbeddingsParallel(ctx, chunks)
}
// generateEmbeddingsSerial 串行生成嵌入向量
func (s *ChunkService) generateEmbeddingsSerial(ctx context.Context, chunks []Chunk) error {
for i := range chunks {
text := s.extractPlainText(chunks[i].Text)
embedding, err := s.embeddingClient.GetEmbedding(ctx, text)
if err != nil {
return fmt.Errorf("failed to generate embedding for chunk %d: %w", i, err)
}
chunks[i].Embedding.Embedding = embedding
}
return nil
}
// generateEmbeddingsParallel 并行生成嵌入向量
func (s *ChunkService) generateEmbeddingsParallel(ctx context.Context, chunks []Chunk) error {
var wg sync.WaitGroup
errChan := make(chan error, len(chunks))
// 限制并发数量,避免过多的并发请求
semaphore := make(chan struct{}, 10)
for i := range chunks {
wg.Add(1)
go func(idx int) {
defer wg.Done()
// 获取信号量
semaphore <- struct{}{}
defer func() { <-semaphore }()
text := s.extractPlainText(chunks[idx].Text)
embedding, err := s.embeddingClient.GetEmbedding(ctx, text)
if err != nil {
errChan <- fmt.Errorf("failed to generate embedding for chunk %d: %w", idx, err)
return
}
chunks[idx].Embedding.Embedding = embedding
}(i)
}
// 等待所有goroutine完成
wg.Wait()
close(errChan)
// 检查是否有错误
select {
case err := <-errChan:
if err != nil {
return err
}
default:
// 没有错误
}
return nil
}
// extractPlainText 从markdown格式的文本中提取纯文本
func (s *ChunkService) extractPlainText(text string) string {
// 移除第一行 File: 标记
if idx := strings.Index(text, "\n"); idx != -1 {
text = text[idx+1:]
}
// 移除 ```shell 和结尾的 ```
text = strings.TrimPrefix(text, "```shell\n")
text = strings.TrimSuffix(text, "```")
return text
}

View File

@@ -0,0 +1,373 @@
package copilot
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// CodeCompletions 代码补全
func CodeCompletions(c *gin.Context) {
ctx := c.Request.Context()
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
debounceTime, _ := strconv.Atoi(os.Getenv("COPILOT_DEBOUNCE"))
time.Sleep(time.Duration(debounceTime) * time.Millisecond)
if ctx.Err() != nil {
abortCodex(c, http.StatusRequestTimeout)
return
}
body, err := io.ReadAll(c.Request.Body)
if nil != err {
abortCodex(c, http.StatusBadRequest)
return
}
c.Header("Content-Type", "text/event-stream")
codexServiceType := os.Getenv("CODEX_SERVICE_TYPE")
body = ConstructRequestBody(body, codexServiceType)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, os.Getenv("CODEX_API_BASE"), io.NopCloser(bytes.NewBuffer(body)))
if nil != err {
abortCodex(c, http.StatusInternalServerError)
return
}
req.Header.Set("Content-Type", "application/json")
apiKeys := strings.Split(os.Getenv("CODEX_API_KEY"), ",")
// 检查 apiKeys 是否有效
if len(apiKeys) == 0 || (len(apiKeys) == 1 && apiKeys[0] == "") {
abortCodex(c, http.StatusInternalServerError)
return
}
randGen := rand.New(rand.NewSource(time.Now().UnixNano()))
selectedKey := strings.TrimSpace(apiKeys[randGen.Intn(len(apiKeys))])
if selectedKey == "" {
abortCodex(c, http.StatusInternalServerError)
return
}
req.Header.Set("Authorization", "Bearer "+selectedKey)
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
client := &http.Client{
Timeout: httpClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
resp, err := client.Do(req)
if nil != err {
if errors.Is(err, context.Canceled) {
abortCodex(c, http.StatusRequestTimeout)
return
}
log.Println("request completions failed:", err.Error())
abortCodex(c, http.StatusInternalServerError)
return
}
defer CloseIO(resp.Body)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Println("request completions failed:", string(body))
abortCodex(c, resp.StatusCode)
return
}
c.Status(resp.StatusCode)
// 处理 Ollama 服务的流式响应
if codexServiceType == "ollama" {
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
break
}
break
}
if strings.TrimSpace(line) == "" {
continue
}
// json解析 line
lineJson := gjson.Parse(line)
uuid := uuid.Must(uuid.NewV4()).String()
done := lineJson.Get("done").Bool()
doneReason := lineJson.Get("done_reason").Str
response := lineJson.Get("response").Str
timestamp := time.Now().Unix()
choice := map[string]interface{}{
"text": response,
"index": 0,
"logprobs": nil,
"finish_reason": doneReason,
}
choices := []map[string]interface{}{choice}
constructLineData := map[string]interface{}{
"id": uuid,
"choices": choices,
"created": timestamp,
"model": lineJson.Get("model").Str,
"system_fingerprint": "fp_1c141eb703",
"object": "text_completion",
}
if done && strings.Contains(doneReason, "stop") {
usage := map[string]interface{}{
"prompt_tokens": lineJson.Get("prompt_eval_count").Int(),
"completion_tokens": lineJson.Get("eval_count").Int(),
"total_tokens": lineJson.Get("prompt_eval_count").Int(),
"prompt_cache_hit_tokens": lineJson.Get("prompt_eval_count").Int(),
"prompt_cache_miss_tokens": lineJson.Get("eval_count").Int(),
}
constructLineData["usage"] = usage
}
// 将修改后的数据重新编码为 JSON
modifiedJSON, err := json.Marshal(constructLineData)
if err != nil {
continue
}
// 发送修改后的数据
_, _ = c.Writer.WriteString("data: " + string(modifiedJSON) + "\n\n")
c.Writer.Flush()
}
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
c.Writer.Flush()
return
}
// 处理默认服务的响应
_, _ = io.Copy(c.Writer, resp.Body)
}
// ConstructRequestBody 重新构建请求体
func ConstructRequestBody(body []byte, codexServiceType string) []byte {
envCodexModel := os.Getenv("CODEX_API_MODEL_NAME")
body, _ = sjson.SetBytes(body, "model", envCodexModel)
body, _ = sjson.SetBytes(body, "stream", true) // 强制流式输出
body, _ = sjson.DeleteBytes(body, "extra")
body, _ = sjson.DeleteBytes(body, "nwo")
// 限制 prompt 和 suffix 的长度
body = applyPromptLengthLimit(body)
temperature, _ := strconv.ParseFloat(os.Getenv("CODEX_TEMPERATURE"), 64)
if temperature != -1 {
body, _ = sjson.SetBytes(body, "temperature", temperature)
}
codeMaxTokens, _ := strconv.Atoi(os.Getenv("CODEX_MAX_TOKENS"))
if int(gjson.GetBytes(body, "max_tokens").Int()) > codeMaxTokens {
body, _ = sjson.SetBytes(body, "max_tokens", codeMaxTokens)
}
if gjson.GetBytes(body, "n").Int() > 1 {
body, _ = sjson.SetBytes(body, "n", 1)
}
// https://ollama.com/library/stable-code || https://ollama.com/library/codegemma
if strings.Contains(envCodexModel, "stable-code") || strings.Contains(envCodexModel, "codegemma") {
return constructWithStableCodeModel(body)
}
// https://ollama.com/library/codellama
if strings.Contains(envCodexModel, "codellama") {
return constructWithCodeLlamaModel(body)
}
// https://help.aliyun.com/zh/model-studio/user-guide/qwen-coder?spm=a2c4g.11186623.0.0.a5234823I6LvAG
if strings.Contains(envCodexModel, "qwen-coder-turbo") {
return constructWithQwenCoderTurboModel(body)
}
// 支持 Ollama FIM 的模型, 如:https://ollama.com/library/deepseek-coder-v2
if codexServiceType == "ollama" {
return constructWithOllamaModel(body, codeMaxTokens)
}
return body
}
// applyPromptLengthLimit 对 prompt 和 suffix 应用长度限制
func applyPromptLengthLimit(body []byte) []byte {
envLimitPrompt := os.Getenv("CODEX_LIMIT_PROMPT")
limitPrompt, err := strconv.Atoi(envLimitPrompt)
if err != nil || limitPrompt <= 0 {
return body
}
body = limitPromptLength(body, limitPrompt)
body = limitSuffixLength(body, limitPrompt)
return body
}
// limitPromptLength 限制 prompt 长度
func limitPromptLength(body []byte, limitRows int) []byte {
prompt := gjson.GetBytes(body, "prompt")
if !prompt.Exists() {
return body
}
rows := strings.Split(prompt.Str, "\n")
if len(rows) <= limitRows {
return body
}
// 保留后面的内容
newPrompt := strings.Join(rows[len(rows)-limitRows:], "\n")
body, _ = sjson.SetBytes(body, "prompt", newPrompt)
return body
}
// limitSuffixLength 限制 suffix 长度
func limitSuffixLength(body []byte, limitRows int) []byte {
suffix := gjson.GetBytes(body, "suffix")
if !suffix.Exists() {
return body
}
rows := strings.Split(suffix.Str, "\n")
if len(rows) <= limitRows {
return body
}
// 保留前面的内容
newSuffix := strings.Join(rows[:limitRows], "\n")
body, _ = sjson.SetBytes(body, "suffix", newSuffix)
return body
}
// constructWithCodeLlamaModel 重写codeLlama模型要求的请求体
func constructWithCodeLlamaModel(body []byte) []byte {
suffix := gjson.GetBytes(body, "suffix")
prompt := gjson.GetBytes(body, "prompt")
content := fmt.Sprintf("<PRE> %s <SUF> %s <MID>", prompt, suffix)
return constructWithChatModel(body, content)
}
// constructWithStableCodeModel 重写StableCode模型要求的请求体
func constructWithStableCodeModel(body []byte) []byte {
suffix := gjson.GetBytes(body, "suffix")
prompt := gjson.GetBytes(body, "prompt")
content := fmt.Sprintf("<fim_prefix>%s<fim_suffix>%s<fim_middle>", prompt, suffix)
return constructWithChatModel(body, content)
}
// constructWithChatModel 重写Chat请求体
func constructWithChatModel(body []byte, content string) []byte {
// 创建新的 JSON 对象并添加到 body 中
messages := []map[string]string{
{
"role": "user",
"content": content,
},
}
body, _ = sjson.SetBytes(body, "messages", messages)
jsonStr := string(body)
jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<")
jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">")
return []byte(jsonStr)
}
// constructWithQwenCoderTurboModel 重写QwenCoderTurbo模型要求的请求体
func constructWithQwenCoderTurboModel(body []byte) []byte {
if gjson.GetBytes(body, "n").Int() > 1 {
body, _ = sjson.SetBytes(body, "n", 1)
}
suffix := gjson.GetBytes(body, "suffix")
prompt := gjson.GetBytes(body, "prompt")
codeLanguage := gjson.GetBytes(body, "extra.language")
messages := []map[string]interface{}{
{
"role": "system",
"content": "You are an expert in " + codeLanguage.Str + " programming, highly skilled at understanding and continuing to write code.",
},
{
"role": "user",
"content": "Combined with subsequent code snippets, help me complete the code:\n\n" +
"Code subsequent content:\n```" + codeLanguage.Str + "\n" + suffix.Str + "```\n\n" +
"Remember:\n" +
"- Do not generate content outside of the code.\n" +
"- Do not directly fill in all the code content, the maximum number of lines of code should not exceed 5 lines.\n" +
"- Answer must refer to the code suffix content, do not exceed the boundary, otherwise repeated code will occur.\n" +
"- If you don't know how to answer, just reply with an empty string.",
},
{
"role": "assistant",
"content": prompt.Str,
"partial": true,
},
}
body, _ = sjson.SetBytes(body, "messages", messages)
body, _ = sjson.DeleteBytes(body, "prompt")
return body
}
// constructWithOllamaModel 重写Ollama模型要求的请求体
func constructWithOllamaModel(body []byte, codeMaxTokens int) []byte {
body, _ = sjson.SetBytes(body, "options.temperature", 0)
// stop参数处理
stopArray := gjson.GetBytes(body, "stop").Array()
stopSlice := make([]interface{}, len(stopArray))
for i, v := range stopArray {
stopSlice[i] = v.String()
}
body, _ = sjson.SetBytes(body, "options.stop", stopSlice)
body, _ = sjson.SetBytes(body, "stream", true)
maxTokens := gjson.GetBytes(body, "max_tokens").Int()
if int(maxTokens) > codeMaxTokens {
body, _ = sjson.SetBytes(body, "options.num_predict", codeMaxTokens)
} else {
body, _ = sjson.SetBytes(body, "options.num_predict", maxTokens)
}
return body
}
// abortCodex 中断请求
func abortCodex(c *gin.Context, status int) {
c.Header("Content-Type", "text/event-stream")
c.String(status, "data: [DONE]\n\n")
c.Abort()
}

View File

@@ -0,0 +1,167 @@
package copilot
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"sync"
"time"
)
// 常量定义
const (
defaultTimeout = 30 * time.Second
contentTypeJSON = "application/json"
)
// EmbeddingRequest 表示向嵌入API发送的请求
type EmbeddingRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
Dimensions int `json:"dimensions"`
}
// EmbeddingResponse 表示从嵌入API接收的响应
type EmbeddingResponse struct {
Data []EmbeddingData `json:"data"`
Model string `json:"model"`
Object string `json:"object"`
Usage Usage `json:"usage"`
}
// EmbeddingData 表示单个嵌入数据
type EmbeddingData struct {
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
Object string `json:"object"`
}
// Usage 表示API使用情况
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
}
// 移除未使用的类型
// Parameters 和 EmbeddingsRequest, EmbeddingsResponse 已被移除
// EmbeddingClient 封装了与嵌入API交互的功能
type EmbeddingClient struct {
apiURL string
apiKey string
model string
dimensions int
httpClient *http.Client
clientMutex sync.RWMutex
}
// NewEmbeddingClient 创建一个新的嵌入客户端
func NewEmbeddingClient(dimensions int) (*EmbeddingClient, error) {
apiURL := os.Getenv("EMBEDDING_API_BASE")
apiKey := os.Getenv("EMBEDDING_API_KEY")
if apiURL == "" || apiKey == "" {
return nil, fmt.Errorf("EMBEDDING_API_BASE or EMBEDDING_API_KEY environment variable not set")
}
if os.Getenv("EMBEDDING_API_MODEL_NAME") == "" {
return nil, fmt.Errorf("EMBEDDING_API_MODEL_NAME environment variable not set")
}
// 解析超时时间,如果未设置或解析失败则使用默认值
timeout := defaultTimeout
if timeoutStr := os.Getenv("HTTP_CLIENT_TIMEOUT"); timeoutStr != "" {
if parsedTimeout, err := time.ParseDuration(timeoutStr + "s"); err == nil {
timeout = parsedTimeout
}
}
client := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
return &EmbeddingClient{
apiURL: apiURL,
apiKey: apiKey,
model: os.Getenv("EMBEDDING_API_MODEL_NAME"),
dimensions: dimensions,
httpClient: client,
}, nil
}
// SetModel 设置嵌入模型
func (c *EmbeddingClient) SetModel(model string) {
c.clientMutex.Lock()
defer c.clientMutex.Unlock()
c.model = model
}
// GetEmbedding 获取单个文本的嵌入
func (c *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
resp, err := c.GetEmbeddings(ctx, []string{text})
if err != nil {
return nil, err
}
if len(resp.Data) == 0 {
return nil, fmt.Errorf("no embeddings returned")
}
return resp.Data[0].Embedding, nil
}
// GetEmbeddings 批量获取多个文本的嵌入
func (c *EmbeddingClient) GetEmbeddings(ctx context.Context, texts []string) (*EmbeddingResponse, error) {
c.clientMutex.RLock()
dimensions := c.dimensions
c.clientMutex.RUnlock()
reqBody := EmbeddingRequest{
Model: os.Getenv("EMBEDDING_API_MODEL_NAME"),
Input: texts,
Dimensions: dimensions,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %v", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}
req.Header.Set("Content-Type", contentTypeJSON)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
var embeddingResp EmbeddingResponse
if err := json.Unmarshal(body, &embeddingResp); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %v", err)
}
return &embeddingResp, nil
}

View File

@@ -0,0 +1,58 @@
package copilot
import (
"github.com/gofrs/uuid"
"net/http"
"os"
"strconv"
"github.com/gin-gonic/gin"
)
// EmbeddingsAPIRequest 表示嵌入API的请求结构
type EmbeddingsAPIRequest struct {
Input []string `json:"input" binding:"required"`
Model string `json:"model,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
}
// HandleEmbeddings 处理嵌入请求的HTTP处理器
func HandleEmbeddings(c *gin.Context) {
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
var req EmbeddingsAPIRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 从环境变量获取维度大小默认为1536
dimensionSize := 1536
if dimSizeStr := os.Getenv("EMBEDDING_DIMENSION_SIZE"); dimSizeStr != "" {
if dimSize, err := strconv.Atoi(dimSizeStr); err == nil {
dimensionSize = dimSize
}
}
// 创建嵌入客户端
client, err := NewEmbeddingClient(dimensionSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 如果请求中指定了模型,则使用请求中的模型
if req.Model != "" {
client.SetModel(req.Model)
}
// 获取嵌入,使用请求上下文以支持取消操作
resp, err := client.GetEmbeddings(c.Request.Context(), req.Input)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, resp)
}

View File

@@ -0,0 +1,147 @@
package copilot
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"log"
"math/rand"
"net/http"
"os"
"ripper/internal/app/github_auth"
"ripper/internal/cache"
"strconv"
"strings"
"time"
)
// GetDisguiseCopilotInternalV2Token 返回伪装的token
func GetDisguiseCopilotInternalV2Token(ctx *gin.Context) {
requestID := uuid.Must(uuid.NewV4()).String()
ctx.Header("x-github-request-id", requestID)
trackingId, _ := uuid.NewV4()
now := time.Now().Unix()
dcAt, _ := strconv.Atoi(os.Getenv("DISGUISE_COPILOT_TOKEN_EXPIRES_AT"))
expiresAt := now + int64(dcAt)
sku := "copilot_for_business_seat"
copilotToken := github_auth.JsonMap2SignToken(map[string]interface{}{
"tid": trackingId,
"exp": expiresAt,
"sku": sku,
"st": "dotcom",
"chat": 1,
"u": "github",
})
endpoints := make(map[string]interface{})
endpoints["api"] = os.Getenv("API_BASE_URL")
endpoints["origin-tracker"] = "https://origin-tracker.individual.githubcopilot.com"
endpoints["proxy"] = os.Getenv("PROXY_BASE_URL")
endpoints["telemetry"] = os.Getenv("TELEMETRY_BASE_URL")
gout := gin.H{
"annotations_enabled": true,
"chat_enabled": true,
"chat_jetbrains_enabled": true,
"code_quote_enabled": true,
"code_review_enabled": false,
"codesearch": true,
"copilot_ide_agent_chat_gpt4_small_prompt": false,
"copilotignore_enabled": false,
"endpoints": endpoints,
"expires_at": expiresAt,
"individual": true,
"nes_enabled": false,
"prompt_8k": true,
"public_suggestions": "disabled",
"refresh_in": 1500,
"sku": sku,
"snippy_load_test_enabled": false,
"telemetry": "disabled",
"token": copilotToken,
"tracking_id": trackingId,
"intellij_editor_fetcher": false,
"vsc_electron_fetcher": false,
"vs_editor_fetcher": false,
"vsc_panel_v2": false,
"xcode": true,
"xcode_chat": true,
"limited_user_quotas": nil,
"limited_user_reset_date": nil,
"vsc_electron_fetcher_v2": false,
}
ctx.JSON(http.StatusOK, gout)
}
// GetCopilotInternalV2Token 获取github copilot官方token
func GetCopilotInternalV2Token(c *gin.Context) {
ghuTokens := strings.Split(os.Getenv("COPILOT_GHU_TOKEN"), ",")
if len(ghuTokens) == 0 {
return
}
rand.Seed(time.Now().UnixNano())
ghu := ghuTokens[rand.Intn(len(ghuTokens))]
if ghu == "" {
log.Println("ghu token is empty")
c.JSON(http.StatusUnprocessableEntity, gin.H{
"message": "ghu token is empty",
})
return
}
cacheKey := "copilot_internal_v2_token"
token, err := cache.Get(cacheKey)
if err != nil {
log.Println(err.Error())
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
cache.Del(cacheKey)
return
}
if token != nil {
c.JSON(http.StatusOK, token)
return
}
url := "https://api.github.com/copilot_internal/v2/token"
req, err := http.NewRequestWithContext(c, "GET", url, nil)
if err != nil {
log.Println(err.Error())
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
req.Header.Set("authorization", "token "+ghu)
req.Header.Set("editor-plugin-version", "copilot-intellij/1.5.21.6667")
req.Header.Set("editor-version", "JetBrains-IU/242.21829.142")
req.Header.Set("user-agent", "GithubCopilot/1.228.0")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
log.Println(err.Error())
c.JSON(resp.StatusCode, gin.H{"error": err.Error()})
return
}
if resp.StatusCode != 200 {
errorMsg := "获取 Token 失败, 当前 ghu_token 账户可能并未订阅 github copilot 服务!" + ghu
c.JSON(resp.StatusCode, gin.H{"error": errorMsg})
log.Println(errorMsg)
return
}
defer resp.Body.Close()
var result interface{}
err = json.NewDecoder(resp.Body).Decode(&result)
if err != nil {
log.Println(err.Error())
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
return
}
cache.Set(cacheKey, result, 1500)
c.JSON(resp.StatusCode, result)
}

View File

@@ -0,0 +1,373 @@
package copilot
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"github.com/gofrs/uuid"
"io"
"io/ioutil"
"log"
"math/rand"
"net/http"
"os"
"ripper/internal/cache"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// CodexCompletions 全代理GitHub的代码补全接口
func CodexCompletions(c *gin.Context) {
ctx := c.Request.Context()
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
urlModelName := c.Param("model-name")
debounceTime, _ := strconv.Atoi(os.Getenv("COPILOT_DEBOUNCE"))
time.Sleep(time.Duration(debounceTime) * time.Millisecond)
if ctx.Err() != nil {
abortCodex(c, http.StatusRequestTimeout)
return
}
body, err := io.ReadAll(c.Request.Body)
if nil != err {
abortCodex(c, http.StatusBadRequest)
return
}
copilotAccountType := os.Getenv("COPILOT_ACCOUNT_TYPE")
url := "https://proxy." + copilotAccountType + ".githubcopilot.com/v1/engines/" + urlModelName + "/completions"
req, err := http.NewRequestWithContext(c, "POST", url, bytes.NewBuffer(body))
if nil != err {
abortCodex(c, http.StatusInternalServerError)
return
}
// 合并请求头
if err := mergeHeaders(c.Request.Header, req); err != nil {
log.Println(err)
abortCodex(c, http.StatusInternalServerError)
return
}
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
client := &http.Client{
Timeout: httpClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
resp, err := client.Do(req)
if nil != err {
if errors.Is(err, context.Canceled) {
abortCodex(c, http.StatusRequestTimeout)
return
}
log.Println("request completions failed:", err.Error())
abortCodex(c, http.StatusInternalServerError)
return
}
defer CloseIO(resp.Body)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Println("请求GitHub官方补全接口失败:", string(body))
abortCodex(c, resp.StatusCode)
return
}
c.Status(resp.StatusCode)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
_, _ = io.Copy(c.Writer, resp.Body)
}
// ChatsCompletions 全代理GitHub的聊天补全接口
func ChatsCompletions(c *gin.Context) {
ctx := c.Request.Context()
if ctx.Err() != nil {
abortCodex(c, http.StatusRequestTimeout)
return
}
body, err := io.ReadAll(c.Request.Body)
if nil != err {
abortCodex(c, http.StatusBadRequest)
return
}
copilotAccountType := os.Getenv("COPILOT_ACCOUNT_TYPE")
url := "https://api." + copilotAccountType + ".githubcopilot.com/chat/completions"
req, err := http.NewRequestWithContext(c, "POST", url, bytes.NewBuffer(body))
if nil != err {
abortCodex(c, http.StatusInternalServerError)
return
}
// 合并请求头
if err := mergeHeaders(c.Request.Header, req); err != nil {
log.Println(err)
abortCodex(c, http.StatusInternalServerError)
return
}
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
client := &http.Client{
Timeout: httpClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
resp, err := client.Do(req)
if nil != err {
if errors.Is(err, context.Canceled) {
abortCodex(c, http.StatusRequestTimeout)
return
}
log.Println("request completions failed:", err.Error())
abortCodex(c, http.StatusInternalServerError)
return
}
defer CloseIO(resp.Body)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Println("请求GitHub官方对话接口失败:", string(body))
abortCodex(c, resp.StatusCode)
return
}
c.Status(resp.StatusCode)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
_, _ = io.Copy(c.Writer, resp.Body)
}
// ChatEditCompletions 聊天编辑补全接口
func ChatEditCompletions(c *gin.Context) {
ctx := c.Request.Context()
if ctx.Err() != nil {
abortCodex(c, http.StatusRequestTimeout)
return
}
body, err := io.ReadAll(c.Request.Body)
if nil != err {
abortCodex(c, http.StatusBadRequest)
return
}
copilotAccountType := os.Getenv("COPILOT_ACCOUNT_TYPE")
url := "https://proxy." + copilotAccountType + ".githubcopilot.com/v1/engines/copilot-centralus-h100/speculation"
req, err := http.NewRequestWithContext(c, "POST", url, bytes.NewBuffer(body))
if nil != err {
abortCodex(c, http.StatusInternalServerError)
return
}
// 合并请求头
if err := mergeHeaders(c.Request.Header, req); err != nil {
log.Println(err)
abortCodex(c, http.StatusInternalServerError)
return
}
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
client := &http.Client{
Timeout: httpClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
resp, err := client.Do(req)
if nil != err {
if errors.Is(err, context.Canceled) {
abortCodex(c, http.StatusRequestTimeout)
return
}
log.Println("request failed:", err.Error())
abortCodex(c, http.StatusInternalServerError)
return
}
defer CloseIO(resp.Body)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Println("请求 Chat 编辑接口失败:", string(body))
abortCodex(c, resp.StatusCode)
return
}
c.Status(resp.StatusCode)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
_, _ = io.Copy(c.Writer, resp.Body)
}
// getAuthToken 获取GitHub Copilot的临时Token
func getAuthToken() (string, error) {
ghuTokens := strings.Split(os.Getenv("COPILOT_GHU_TOKEN"), ",")
if len(ghuTokens) == 0 {
return "", fmt.Errorf("COPILOT_GHU_TOKEN environment variable is empty or malformed")
}
rand.Seed(time.Now().UnixNano())
ghu := ghuTokens[rand.Intn(len(ghuTokens))]
cacheKey := "github:copilot_internal_v2_token:" + ghu
token, err := cache.Get(cacheKey)
if err != nil {
cache.Del(cacheKey)
return "", err
}
if token != nil {
return token.(string), nil
}
url := "https://api.github.com/copilot_internal/v2/token"
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
client := &http.Client{
Timeout: httpClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", err
}
req.Header.Set("authorization", "token "+ghu)
req.Header.Set("host", "api.github.com")
req.Header.Set("accept", "*/*")
req.Header.Set("editor-plugin-version", "copilot-intellij/1.5.21.6667")
req.Header.Set("copilot-language-server-version", "1.228.0")
req.Header.Set("user-agent", "GithubCopilot/1.228.0")
req.Header.Set("editor-version", "JetBrains-IU/242.21829.142")
res, err := client.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return "", fmt.Errorf("获取 Token 失败" + ghu)
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return "", err
}
// 解析json
var result map[string]interface{}
err = json.Unmarshal(body, &result)
if err != nil {
return "", err
}
newToken := result["token"].(string)
err = cache.Set(cacheKey, newToken, 1500)
if err != nil {
return "", err
}
return newToken, nil
}
// mergeHeaders 合并请求头,固定请求头会覆盖原有请求头
func mergeHeaders(originalHeader http.Header, req *http.Request) error {
// 复制原始请求头
for key, values := range originalHeader {
for _, value := range values {
req.Header.Add(key, value)
}
}
// 获取token
token, err := getAuthToken()
if err != nil {
return fmt.Errorf("获取GitHub Copilot的临时Token失败: %w", err)
}
// 固定请求头
fixedHeaders := map[string]string{
"authorization": "Bearer " + token,
"editor-plugin-version": "copilot-intellij/1.5.21.6667",
"copilot-language-server-version": "1.228.0",
"user-agent": "GithubCopilot/1.228.0",
"editor-version": "JetBrains-IU/242.21829.142",
}
// 设置固定请求头(覆盖原有的)
for key, value := range fixedHeaders {
req.Header.Set(key, value)
}
return nil
}
// GetCopilotModels 获取GitHub Copilot的模型列表
func GetCopilotModels(c *gin.Context) {
copilotAccountType := os.Getenv("COPILOT_ACCOUNT_TYPE")
url := "https://api." + copilotAccountType + ".githubcopilot.com/models"
req, err := http.NewRequestWithContext(c, "GET", url, nil)
if nil != err {
abortCodex(c, http.StatusInternalServerError)
return
}
// 合并请求头
if err := mergeHeaders(c.Request.Header, req); err != nil {
log.Println(err)
abortCodex(c, http.StatusInternalServerError)
return
}
httpClientTimeout, _ := time.ParseDuration(os.Getenv("HTTP_CLIENT_TIMEOUT") + "s")
client := &http.Client{
Timeout: httpClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
resp, err := client.Do(req)
if nil != err {
if errors.Is(err, context.Canceled) {
abortCodex(c, http.StatusRequestTimeout)
return
}
log.Println("获取模型列表失败:", err.Error())
abortCodex(c, http.StatusInternalServerError)
return
}
defer CloseIO(resp.Body)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Println("请求GitHub Copilot模型列表失败:", string(body))
abortCodex(c, resp.StatusCode)
return
}
// 转发原始响应
c.Status(resp.StatusCode)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
_, _ = io.Copy(c.Writer, resp.Body)
}

View File

@@ -0,0 +1,24 @@
package copilot
import (
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"net/http"
)
// GetMembership 获取团队成员信息
func GetMembership(c *gin.Context) {
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
teamID := c.Param("teamID")
username := c.Param("username")
c.JSON(http.StatusOK, gin.H{
"message": "Not Found",
"documentation_url": "https://docs.github.com/rest/teams/members#get-team-membership-for-a-user-legacy",
"status": "404",
"teamID": teamID,
"username": username,
})
}

View File

@@ -0,0 +1,49 @@
package copilot
import (
"github.com/gin-gonic/gin"
"net/http"
)
func V3meta(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
}
func Cliv3(c *gin.Context) {
c.Header("X-OAuth-Scopes", "gist, read:org, repo, user, workflow, write:public_key")
c.JSON(http.StatusOK, gin.H{
"current_user_url": "https://api.github.com/user",
"current_user_authorizations_html_url": "https://github.com/settings/connections/applications{/client_id}",
"authorizations_url": "https://api.github.com/authorizations",
"code_search_url": "https://api.github.com/search/code?q={query}{&page,per_page,sort,order}",
"commit_search_url": "https://api.github.com/search/commits?q={query}{&page,per_page,sort,order}",
"emails_url": "https://api.github.com/user/emails",
"emojis_url": "https://api.github.com/emojis",
"events_url": "https://api.github.com/events",
"feeds_url": "https://api.github.com/feeds",
"followers_url": "https://api.github.com/user/followers",
"following_url": "https://api.github.com/user/following{/target}",
"gists_url": "https://api.github.com/gists{/gist_id}",
"hub_url": "https://api.github.com/hub",
"issue_search_url": "https://api.github.com/search/issues?q={query}{&page,per_page,sort,order}",
"issues_url": "https://api.github.com/issues",
"keys_url": "https://api.github.com/user/keys",
"label_search_url": "https://api.github.com/search/labels?q={query}&repository_id={repository_id}{&page,per_page}",
"notifications_url": "https://api.github.com/notifications",
"organization_url": "https://api.github.com/orgs/{org}",
"organization_repositories_url": "https://api.github.com/orgs/{org}/repos{?type,page,per_page,sort}",
"organization_teams_url": "https://api.github.com/orgs/{org}/teams",
"public_gists_url": "https://api.github.com/gists/public",
"rate_limit_url": "https://api.github.com/rate_limit",
"repository_url": "https://api.github.com/repos/{owner}/{repo}",
"repository_search_url": "https://api.github.com/search/repositories?q={query}{&page,per_page,sort,order}",
"current_user_repositories_url": "https://api.github.com/user/repos{?type,page,per_page,sort}",
"starred_url": "https://api.github.com/user/starred{/owner}{/repo}",
"starred_gists_url": "https://api.github.com/gists/starred",
"topic_search_url": "https://api.github.com/search/topics?q={query}{&page,per_page}",
"user_url": "https://api.github.com/users/{user}",
"user_organizations_url": "https://api.github.com/user/orgs",
"user_repositories_url": "https://api.github.com/users/{user}/repos{?type,page,per_page,sort}",
"user_search_url": "https://api.github.com/search/users?q={query}{&page,per_page,sort,order}",
})
}

View File

@@ -0,0 +1,158 @@
package copilot
import (
"fmt"
"github.com/gin-gonic/gin"
"log"
"os"
"ripper/internal/middleware"
"strconv"
)
type Config struct {
ClientType string
CopilotProxyAll bool
}
// loadConfig loads the configuration from environment variables.
func loadConfig() (*Config, error) {
proxyAll, err := strconv.ParseBool(os.Getenv("COPILOT_PROXY_ALL"))
if err != nil {
return nil, fmt.Errorf("invalid boolean value for COPILOT_PROXY_ALL: %v", err)
}
return &Config{
ClientType: os.Getenv("COPILOT_CLIENT_TYPE"),
CopilotProxyAll: proxyAll,
}, nil
}
// GinApi 注册路由
func GinApi(g *gin.RouterGroup) {
config, err := loadConfig()
if err != nil {
log.Fatal(err)
}
// 基础路由
setupBasicRoutes(g, config)
// 用户相关路由
setupUserRoutes(g)
// Copilot相关路由
setupCopilotRoutes(g, config)
// API v3相关路由
setupV3Routes(g)
}
// setupBasicRoutes 设置基础路由
func setupBasicRoutes(g *gin.RouterGroup, config *Config) {
g.GET("/models", createModelsHandler(config))
g.GET("/_ping", GetPing)
g.POST("/telemetry", PostTelemetry)
g.GET("/agents", GetAgents)
g.GET("/copilot_internal/user", GetCopilotInternalUser)
}
// setupUserRoutes 设置用户相关路由
func setupUserRoutes(g *gin.RouterGroup) {
authMiddleware := middleware.AccessTokenCheckAuth()
userGroup := g.Group("")
userGroup.Use(authMiddleware)
{
userGroup.GET("/user", GetLoginUser)
userGroup.GET("/user/orgs", GetUserOrgs)
userGroup.GET("/api/v3/user", GetLoginUser)
userGroup.GET("/api/v3/user/orgs", GetUserOrgs)
userGroup.GET("/teams/:teamID/memberships/:username", GetMembership)
userGroup.POST("/chunks", HandleChunks)
}
}
// setupCopilotRoutes 设置Copilot相关路由
func setupCopilotRoutes(g *gin.RouterGroup, config *Config) {
tokenMiddleware := middleware.TokenCheckAuth()
// Copilot token endpoint
g.GET("/copilot_internal/v2/token",
middleware.AccessTokenCheckAuth(),
createTokenHandler(config))
// Completions endpoints
completionsGroup := g.Group("")
completionsGroup.Use(tokenMiddleware)
{
completionsGroup.POST("/v1/engines/:model-name/completions", createCompletionsHandler(config))
completionsGroup.POST("/v1/engines/copilot-codex", createCompletionsHandler(config))
completionsGroup.POST("/chat/completions", createChatHandler(config))
completionsGroup.POST("/agents/chat", createChatHandler(config))
completionsGroup.POST("/v1/chat/completions", createChatHandler(config))
completionsGroup.POST("/v1/engines/copilot-centralus-h100/speculation", createChatEditCompletionsHandler(config))
completionsGroup.POST("/embeddings", HandleEmbeddings)
}
}
// setupV3Routes 设置API v3相关路由
func setupV3Routes(g *gin.RouterGroup) {
g.GET("/api/v3/meta", V3meta)
g.GET("/api/v3/", Cliv3)
g.GET("/", Cliv3)
}
// 处理函数生成器
func createTokenHandler(config *Config) gin.HandlerFunc {
return func(c *gin.Context) {
if config.ClientType == "github" && !config.CopilotProxyAll {
GetCopilotInternalV2Token(c)
} else {
GetDisguiseCopilotInternalV2Token(c)
}
}
}
// createCompletionsHandler 生成代码补全处理函数
func createCompletionsHandler(config *Config) gin.HandlerFunc {
return func(c *gin.Context) {
if config.ClientType == "github" && config.CopilotProxyAll {
CodexCompletions(c)
} else {
CodeCompletions(c)
}
}
}
// createChatHandler 生成聊天补全处理函数
func createChatHandler(config *Config) gin.HandlerFunc {
return func(c *gin.Context) {
if config.ClientType == "github" && config.CopilotProxyAll {
ChatsCompletions(c)
} else {
ChatCompletions(c)
}
}
}
// createChatEditCompletionsHandler 生成聊天编辑补全处理函数
func createChatEditCompletionsHandler(config *Config) gin.HandlerFunc {
return func(c *gin.Context) {
if config.ClientType == "github" && config.CopilotProxyAll {
ChatEditCompletions(c)
} else {
CodeCompletions(c)
}
}
}
// createModelsHandler 生成模型处理函数
func createModelsHandler(config *Config) gin.HandlerFunc {
return func(c *gin.Context) {
if config.ClientType == "github" && config.CopilotProxyAll {
GetCopilotModels(c)
} else {
GetModels(c)
}
}
}

View File

@@ -0,0 +1,20 @@
package copilot
import (
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"net/http"
)
// PostTelemetry 接收并处理来自GitHub Copilot的遥测数据
func PostTelemetry(c *gin.Context) {
requestID := uuid.Must(uuid.NewV4()).String()
c.Header("x-github-request-id", requestID)
c.JSON(http.StatusOK, gin.H{
"itemsReceived": 0,
"itemsAccepted": 0,
"appId": nil,
"errors": []string{},
})
}

View File

@@ -0,0 +1,109 @@
package copilot
import (
"crypto/md5"
"encoding/hex"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"math/rand"
"net/http"
"ripper/internal/middleware"
jwtpkg "ripper/pkg/jwt"
"time"
)
// GetLoginUser 获取登录用户信息
func GetLoginUser(ctx *gin.Context) {
userDisplayName := "github"
token, _ := jwtpkg.GetJwtProto(ctx, &middleware.UserLoad{})
if token != nil && token.UserDisplayName != "" {
userDisplayName = token.UserDisplayName
}
ctx.Header("X-OAuth-Scopes", "gist, read:org, repo, user, workflow, write:public_key")
requestID := uuid.Must(uuid.NewV4()).String()
ctx.Header("x-github-request-id", requestID)
ctx.JSON(http.StatusOK, gin.H{
"login": userDisplayName,
"id": 9919,
"node_id": "DEyOk9yZ2FuaXphdGlvbjk5MTk=",
"avatar_url": "https://avatars.githubusercontent.com/u/9919?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/github",
"html_url": "https://github.com/github",
"followers_url": "https://api.github.com/users/github/followers",
"following_url": "https://api.github.com/users/github/following{/other_user}",
"gists_url": "https://api.github.com/users/github/gists{/gist_id}",
"starred_url": "https://api.github.com/users/github/starred{/owner}{/repo}",
"subscriptions_url": "https://api.github.com/users/github/subscriptions",
"organizations_url": "https://api.github.com/users/github/orgs",
"repos_url": "https://api.github.com/users/github/repos",
"events_url": "https://api.github.com/users/github/events{/privacy}",
"received_events_url": "https://api.github.com/users/github/received_events",
"type": "User",
"site_admin": false,
"name": "GitHub",
"company": nil,
"blog": "",
"location": "San Francisco, CA",
"email": nil,
"hireable": nil,
"bio": nil,
"twitter_username": nil,
"public_repos": 498,
"public_gists": 0,
"followers": 42848,
"following": 0,
"created_at": "2008-05-11T04:37:31Z",
"updated_at": "2022-11-29T19:44:55Z",
})
}
func GetUserOrgs(ctx *gin.Context) {
ctx.Header("X-OAuth-Scopes", "gist, read:org, repo, user, workflow, write:public_key")
ctx.JSON(http.StatusOK, []interface{}{})
}
// generateTrackingID 生成模拟的 analytics_tracking_id
func generateTrackingID() string {
// 生成一个随机字符串并计算其 MD5
randomStr := fmt.Sprintf("%d%d", time.Now().UnixNano(), rand.Int())
hash := md5.Sum([]byte(randomStr))
return hex.EncodeToString(hash[:])
}
// generateAssignedDate 生成模拟的 assigned_date
func generateAssignedDate() string {
// 生成最近30天内的随机时间
now := time.Now()
daysAgo := rand.Intn(30)
randomTime := now.AddDate(0, 0, -daysAgo)
// 随机增加小时和分钟
randomHour := rand.Intn(24)
randomMinute := rand.Intn(60)
randomTime = randomTime.Add(time.Duration(randomHour) * time.Hour)
randomTime = randomTime.Add(time.Duration(randomMinute) * time.Minute)
// 返回格式化的时间字符串
return randomTime.Format(time.RFC3339)
}
// GetCopilotInternalUser 获取 Copilot 内部用户信息
func GetCopilotInternalUser(ctx *gin.Context) {
requestID := uuid.Must(uuid.NewV4()).String()
ctx.Header("x-github-request-id", requestID)
ctx.JSON(http.StatusOK, gin.H{
"access_type_sku": "free_educational",
"copilot_plan": "individual",
"analytics_tracking_id": generateTrackingID(),
"assigned_date": generateAssignedDate(),
"can_signup_for_limited": false,
"chat_enabled": true,
"organization_login_list": []interface{}{},
"organization_list": []interface{}{},
})
}

View File

@@ -0,0 +1,78 @@
package copilot
import (
_ "embed"
"encoding/json"
"github.com/gofrs/uuid"
"io"
"log"
"net/http"
"os"
"path/filepath"
"time"
"github.com/gin-gonic/gin"
)
type Pong struct {
Now int `json:"now"`
Status string `json:"status"`
Ns1 string `json:"ns1"`
}
// GetPing 模拟ping接口
func GetPing(ctx *gin.Context) {
requestID := uuid.Must(uuid.NewV4()).String()
ctx.Header("x-github-request-id", requestID)
ctx.JSON(http.StatusOK, Pong{
Now: time.Now().Second(),
Status: "ok",
Ns1: "200 OK",
})
}
// ModelsResponse 模型列表响应结构
type ModelsResponse struct {
Data []interface{} `json:"data"`
Object string `json:"object"`
}
// GetModels 获取模型列表
func GetModels(ctx *gin.Context) {
// 从根目录下读取models.json文件
jsonFile, err := os.Open(filepath.Join("models.json"))
if err != nil {
log.Printf("无法打开models.json文件: %v", err)
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "无法读取模型列表数据"})
return
}
defer CloseIO(jsonFile)
// 解析JSON数据
jsonData, err := io.ReadAll(jsonFile)
if err != nil {
log.Printf("读取models.json内容失败: %v", err)
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "无法读取模型列表数据"})
return
}
var modelsResponse ModelsResponse
if err := json.Unmarshal(jsonData, &modelsResponse); err != nil {
log.Printf("解析models.json失败: %v", err)
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "无法解析模型列表数据"})
return
}
// 返回模型列表数据
requestID := uuid.Must(uuid.NewV4()).String()
ctx.Header("x-github-request-id", requestID)
ctx.JSON(http.StatusOK, modelsResponse)
}
func CloseIO(c io.Closer) {
err := c.Close()
if nil != err {
log.Println(err)
}
}

206
internal/middleware/auth.go Normal file
View File

@@ -0,0 +1,206 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"os"
"ripper/internal/app/github_auth"
"ripper/internal/response"
jwtpkg "ripper/pkg/jwt"
"strconv"
"strings"
"time"
)
type OAuthCheck struct {
ClientId string `json:"client_id" form:"client_id"`
DeviceCode string `json:"device_code" form:"device_code"`
GrantType string `json:"grant_type" form:"grant_type"`
}
func DeviceCodeCheckAuth(ctx *gin.Context) {
checkInfo := &OAuthCheck{}
if err := ctx.ShouldBind(&checkInfo); err != nil {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: "Invalid client id.",
}, false)
ctx.Abort()
return
}
info, _ := github_auth.GetClientAuthInfoByDeviceCode(checkInfo.DeviceCode)
if info.CardCode == "" {
ctx.JSON(http.StatusOK, gin.H{
"error": "authorization_pending",
"error_description": "The authorization request is still pending.",
"error_uri": "https://docs.github.com/developers/apps/authorizing-oauth-apps#error-codes-for-the-device-flow",
})
ctx.Abort()
return
}
ctx.Set("client_auth_info", info)
ctx.Next()
}
func AuthCodeFlowCheckAuth(ctx *gin.Context) {
checkInfoClient := &github_auth.ClientOAuthInfo{}
err := ctx.Bind(&checkInfoClient)
if err != nil {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: "Invalid client id.",
}, false)
ctx.Abort()
return
}
oauthCodeInfo, err := github_auth.GetOAuthCodeInfoByClientIdAndCode(checkInfoClient.ClientId, checkInfoClient.Code)
if err != nil {
response.FailJson(ctx, response.FailStruct{
Code: -1,
Msg: "Invalid client id.",
}, false)
ctx.Abort()
return
}
ctx.Set("client_auth_info", oauthCodeInfo)
ctx.Next()
}
func AccessTokenCheckAuth() gin.HandlerFunc {
return func(c *gin.Context) {
token := c.Request.Header.Get("Authorization")
if token == "" {
response.FailJsonAndStatusCode(c, http.StatusForbidden, response.NoAccess, false)
c.Abort()
return
}
last := strings.Index(token, " ")
if len(token) < last || last == -1 {
response.FailJsonAndStatusCode(c, http.StatusForbidden, response.TokenWrongful, false)
c.Abort()
return
}
token = token[last+1:]
chk, jwter, err := jwtpkg.CheckToken(token, &UserLoad{}, "user")
if err != nil {
errmsg := response.TokenWrongful
errmsg.Msg = "令牌验证错误"
response.FailJsonAndStatusCode(c, http.StatusForbidden, errmsg, true, err.Error())
c.Abort()
return
}
if !chk {
response.FailJsonAndStatusCode(c, http.StatusForbidden, response.NoAccess, true, "破损令牌")
c.Abort()
return
}
chs := true
issuerStr := ""
issuerStr, err = jwter.GetIssuer()
if err != nil {
chs = false
c.Abort()
return
}
if "user" != issuerStr && issuerStr != "" {
chs = false
c.Abort()
return
}
if !chs {
errmsg := response.TokenWrongful
errmsg.Msg = "签名错误"
response.FailJsonAndStatusCode(c, http.StatusForbidden, errmsg, true, err.Error())
c.Abort()
return
}
c.Set("token", jwter)
c.Set("tokenStr", token)
c.Set("token.issuer", issuerStr)
c.Next()
}
}
func TokenCheckAuth() gin.HandlerFunc {
return func(c *gin.Context) {
clientType := os.Getenv("COPILOT_CLIENT_TYPE")
copilotProxyAll, err := strconv.ParseBool(os.Getenv("COPILOT_PROXY_ALL"))
if clientType == "github" && !copilotProxyAll {
c.Next()
return
}
token := c.Request.Header.Get("Authorization")
if token == "" {
response.FailJsonAndStatusCode(c, http.StatusUnauthorized, response.TokenWrongful, false)
c.Abort()
return
}
last := strings.Index(token, " ")
if len(token) < last || last == -1 {
response.FailJsonAndStatusCode(c, http.StatusUnauthorized, response.TokenWrongful, false)
c.Abort()
return
}
token = token[last+1:]
parsedToken := parseAuthorizationToken(token)
// 校验exp是否过期
expired, err := isExpired(parsedToken["exp"])
if err != nil {
response.FailJsonAndStatusCode(c, http.StatusUnauthorized, response.TokenWrongful, false)
c.Abort()
return
} else {
if expired {
response.FailJsonAndStatusCode(c, http.StatusUnauthorized, response.TokenOverdue, false)
c.Abort()
return
}
}
rawToken := github_auth.JsonMap2Token(map[string]interface{}{
"tid": parsedToken["tid"],
"exp": parsedToken["exp"],
"sku": parsedToken["sku"],
"st": parsedToken["st"],
"chat": parsedToken["chat"],
"u": parsedToken["u"],
})
sign := "1:" + github_auth.Token2Sign(rawToken)
if sign != parsedToken["8kp"] {
response.FailJsonAndStatusCode(c, http.StatusUnauthorized, response.TokenWrongful, false)
c.Abort()
return
}
c.Next()
}
}
func parseAuthorizationToken(token string) map[string]string {
result := make(map[string]string)
pairs := strings.Split(token, ";")
for _, pair := range pairs {
kv := strings.SplitN(pair, "=", 2)
if len(kv) == 2 {
key := kv[0]
value := kv[1]
if key == "tid" || key == "exp" || key == "sku" || key == "st" || key == "8kp" || key == "chat" || key == "u" {
result[key] = value
}
}
}
return result
}
func isExpired(expStr string) (bool, error) {
exp, err := strconv.ParseInt(expStr, 10, 64)
if err != nil {
return false, fmt.Errorf("invalid exp timestamp: %v", err)
}
now := time.Now().Unix()
return now > exp, nil
}

View File

@@ -0,0 +1,30 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
//Cors 跨域中间件
func Cors() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("origin") //请求头部
if len(origin) == 0 {
origin = c.Request.Header.Get("Origin")
}
//接收客户端发送的origin (重要!)
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
//允许客户端传递校验信息比如 cookie (重要)
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
//服务器支持的所有跨域请求的方法
c.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST, PUT, DELETE, UPDATE")
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
// 设置预验请求有效期为 86400 秒
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(200)
return
}
c.Next()
}
}

102
internal/middleware/jwt.go Normal file
View File

@@ -0,0 +1,102 @@
package middleware
import (
"github.com/gin-gonic/gin"
"ripper/internal/response"
jwtpkg "ripper/pkg/jwt"
)
// JWTCheck 检查是否登陆
// 检查完毕会将jwt结构体写入到Context
// 适用于同时用于公开与鉴权的路由
func JWTCheck(c *gin.Context, model jwtpkg.LoadModel, issure ...string) (bool, error) {
token := c.Request.Header.Get("Authorization")
if token == "" {
return false, nil
}
if len(token) < 8 {
return false, nil
}
token = token[7:]
chk, jwter, err := jwtpkg.CheckToken(token, model, "")
if err != nil {
return false, err
}
chs := true
for _, v := range issure {
jwt, err := jwter.GetIssuer()
if err != nil {
chs = false
break
}
if v != jwt {
chs = false
break
}
}
if !chs {
return false, nil
}
if !chk {
return false, nil
}
c.Set("token", jwter)
c.Next()
return true, nil
}
// JWTAuth 为JWT中间件客户端下需要在header带上Authorization: Bearer <token>
// issure 为可选验证签名,支持多参选择
func JWTAuth(model jwtpkg.LoadModel, issure ...string) gin.HandlerFunc {
return func(c *gin.Context) {
token := c.Request.Header.Get("Authorization")
if token == "" {
response.FailJson(c, response.NoAccess, false)
c.Abort()
return
}
if len(token) < 8 {
response.FailJson(c, response.TokenWrongful, false)
c.Abort()
return
}
token = token[7:]
chk, jwter, err := jwtpkg.CheckToken(token, model, "")
if err != nil {
errmsg := response.TokenWrongful
errmsg.Msg = "令牌验证错误"
response.FailJson(c, errmsg, true, err.Error())
c.Abort()
return
}
if !chk {
response.FailJson(c, response.NoAccess, true, "破损令牌")
c.Abort()
return
}
chs := true
issuerStr := ""
for _, v := range issure {
issuerStr, err = jwter.GetIssuer()
if err != nil {
chs = false
break
}
if v != issuerStr {
chs = false
break
}
}
if !chs {
errmsg := response.TokenWrongful
errmsg.Msg = "签名错误"
response.FailJson(c, errmsg, true, err.Error())
c.Abort()
return
}
c.Set("token", jwter)
c.Set("tokenStr", token)
c.Set("token.issuer", issuerStr)
c.Next()
}
}

View File

@@ -0,0 +1,23 @@
package middleware
import (
"github.com/golang-jwt/jwt/v5"
jwtPkg "ripper/pkg/jwt"
)
type AdminLoad struct {
Username string `json:"username"`
}
type UserLoad struct {
UserDisplayName string `json:"userDisplayName,omitempty"`
CardCode string `json:"token"`
Client string `json:"client"`
jwt.RegisteredClaims
}
func NewUserLoad(ID uint, ExpiresAt int64, Issuer string) *UserLoad {
return &UserLoad{
RegisteredClaims: jwtPkg.CreateStandardClaims(ExpiresAt, Issuer),
}
}

View File

@@ -0,0 +1 @@
package middleware

View File

@@ -0,0 +1,72 @@
package response
/*
code约定:
code代表错误无错误始终为0
200 OK - [GET]:服务器成功返回用户请求的数据;
201 CREATED - [POST/PUT/PATCH]:用户新建或修改数据成功;
202 Accepted - [*]:表示一个请求已经进入后台排队(异步任务);
204 NO CONTENT - [DELETE]:用户删除数据成功;
400 INVALID REQUEST - [POST/PUT/PATCH]:用户发出的请求有错误,服务器没有进行新建或修改数据的操作;
401 Unauthorized - [*]:表示用户没有权限(令牌、用户名、密码错误);
403 Forbidden - [*] 表示用户得到授权与401错误相对但是访问是被禁止的
404 NOT FOUND - [*]:用户发出的请求针对的是不存在的记录,服务器没有进行操作;
406 Not Acceptable - [GET]:用户请求的格式不可得;
410 Gone -[GET]:用户请求的资源被永久删除,且不会再得到的;
422 Unprocesable entity - [POST/PUT/PATCH] 当创建一个对象时,发生一个验证错误;
500 INTERNAL SERVER ERROR - [*]:服务器发生错误,用户将无法判断发出的请求是否成功。
*/
var (
NoAccess = FailStruct{
Code: 401,
Msg: "无权访问",
}
TokenWrongful = FailStruct{
Code: 401,
Msg: "Token不合法",
}
TokenOverdue = FailStruct{
Code: 401,
Msg: "Token过期",
}
NoIntactParameters = FailStruct{
Code: -10001,
Msg: "参数提交不完整,请重试",
}
UserError = FailStruct{
Code: 10001,
Msg: "帐号或密码错误",
}
SignError = FailStruct{
Code: 10002,
Msg: "",
}
CaptchaError = FailStruct{
Code: 10003,
Msg: "生成验证码错误",
}
CaptchaVefError = FailStruct{
Code: 10004,
Msg: "验证码错误",
}
WechatLoginError = FailStruct{
Code: 10005,
Msg: "微信登陆错误",
}
)
type FailStruct struct {
Code int
Msg string
}
type Message struct {
ErrCode int `json:"error"`
Data interface{} `json:"data"`
Msg string `json:"message"`
}
type Token struct {
Token string `json:"token"`
}

View File

@@ -0,0 +1,62 @@
package response
import (
"errors"
"github.com/gin-gonic/gin"
"net/http"
)
/*
HTTP状态码约定
服务器访问正常始终200,错误交给code
*/
func BindStruct(c *gin.Context, bind interface{}) error {
if err := c.ShouldBindJSON(bind); err != nil {
FailJson(c, NoIntactParameters, false, "结构体绑定错误")
return errors.New("BindError")
}
return nil
}
func SuccessJson(c *gin.Context, msg string, data ...interface{}) {
var tmps interface{}
if len(data) > 0 {
tmps = data[0]
}
c.JSON(http.StatusOK, Message{
ErrCode: 0,
Data: tmps,
Msg: msg,
})
}
func FailJson(c *gin.Context, load FailStruct, WriteLog bool, logMsh ...string) {
if WriteLog {
var werrmsg string
for _, v := range logMsh {
werrmsg += v + "\n"
}
}
c.JSON(http.StatusOK, Message{
ErrCode: load.Code,
Msg: load.Msg,
})
}
func FailJsonAndStatusCode(c *gin.Context, code int, load FailStruct, WriteLog bool, logMsh ...string) {
if WriteLog {
var werrmsg string
for _, v := range logMsh {
werrmsg += v + "\n"
}
}
c.JSON(code, Message{
ErrCode: load.Code,
Msg: load.Msg,
})
}
func SuccessByte(c *gin.Context, data []byte) {
c.Writer.Write(data)
}

View File

@@ -0,0 +1,25 @@
package router
import (
"github.com/gin-gonic/gin"
"html/template"
authApi "ripper/internal/controller/auth"
"ripper/internal/controller/copilot"
"ripper/internal/middleware"
"ripper/static"
)
func NewHTTPRouter(r *gin.Engine) {
rootRouter := r.Group("/")
tmpl := template.Must(template.New("").ParseFS(static.Public, "public/*.html"))
r.SetHTMLTemplate(tmpl)
apiRouter := r.Group("/api")
rootRouter.Use(middleware.Cors())
apiRouter.Use(middleware.Cors())
authApi.GinApi(rootRouter)
copilot.GinApi(rootRouter)
}