Initial commit

This commit is contained in:
史悦
2025-08-13 10:09:22 +08:00
commit 452fa4c2f4
3 changed files with 713 additions and 0 deletions

34
Dockerfile Normal file
View File

@@ -0,0 +1,34 @@
# Stage 1: Build the Go binary
FROM golang:1.22-alpine AS builder
# Set the Current Working Directory inside the container
WORKDIR /app
# Copy go mod and sum files
# COPY go.mod go.sum ./
# RUN go mod download
# Copy the source code into the container
COPY main.go .
# Build the Go app
# CGO_ENABLED=0 is needed for a static build
# GOOS=linux is to specify the target OS
# -a installs all packages to be rebuilt
# -installsuffix cgo is used with CGO_ENABLED=0
# -o main specifies the output file name
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o main .
# Stage 2: Create the final, minimal image
FROM alpine:latest
WORKDIR /root/
# Copy the Pre-built binary file from the previous stage
COPY --from=builder /app/main .
# Expose port 8080 to the outside world
EXPOSE 8080
# Command to run the executable
CMD ["./main"]

6
compose.yaml Normal file
View File

@@ -0,0 +1,6 @@
services:
gemini-proxy:
build: .
ports:
- "9090:8080"
restart: always

673
main.go Normal file
View File

@@ -0,0 +1,673 @@
package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
)
// Config holds configuration settings
type Config struct {
UpstreamURLBase string
MaxConsecutiveRetries int
DebugMode bool
RetryDelayMs int
LogTruncationLimit int
Port string
}
// DefaultConfig returns default configuration
func DefaultConfig() *Config {
return &Config{
UpstreamURLBase: "https://api-proxy.me/gemini",
MaxConsecutiveRetries: 10,
DebugMode: true,
RetryDelayMs: 750,
LogTruncationLimit: 8000,
Port: ":8080",
}
}
var nonRetryableStatuses = map[int]bool{
400: true, 401: true, 403: true, 404: true, 429: true,
}
// Logger wrapper for different log levels
type Logger struct {
debug bool
}
func (l *Logger) Debug(args ...interface{}) {
if l.debug {
log.Printf("[DEBUG %s] %v", time.Now().Format(time.RFC3339), fmt.Sprint(args...))
}
}
func (l *Logger) Info(args ...interface{}) {
log.Printf("[INFO %s] %v", time.Now().Format(time.RFC3339), fmt.Sprint(args...))
}
func (l *Logger) Error(args ...interface{}) {
log.Printf("[ERROR %s] %v", time.Now().Format(time.RFC3339), fmt.Sprint(args...))
}
func truncate(s string, limit int) string {
if len(s) <= limit {
return s
}
return fmt.Sprintf("%s... [truncated %d chars]", s[:limit], len(s)-limit)
}
// Error response structure
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
type ErrorDetail struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
Details []interface{} `json:"details,omitempty"`
}
// Request/Response structures for Gemini API
type GeminiRequest struct {
Contents []Content `json:"contents"`
}
type Content struct {
Role string `json:"role"`
Parts []Part `json:"parts"`
}
type Part struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
FunctionCall interface{} `json:"functionCall,omitempty"`
ToolCode interface{} `json:"toolCode,omitempty"`
}
type GeminiResponse struct {
Candidates []Candidate `json:"candidates,omitempty"`
Error *ErrorDetail `json:"error,omitempty"`
}
type Candidate struct {
Content *Content `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
}
// ProxyServer represents the main server
type ProxyServer struct {
config *Config
logger *Logger
client *http.Client
}
func NewProxyServer(config *Config) *ProxyServer {
return &ProxyServer{
config: config,
logger: &Logger{debug: config.DebugMode},
client: &http.Client{Timeout: 30 * time.Second},
}
}
func (s *ProxyServer) statusToGoogleStatus(code int) string {
switch code {
case 400:
return "INVALID_ARGUMENT"
case 401:
return "UNAUTHENTICATED"
case 403:
return "PERMISSION_DENIED"
case 404:
return "NOT_FOUND"
case 429:
return "RESOURCE_EXHAUSTED"
case 500:
return "INTERNAL"
case 503:
return "UNAVAILABLE"
case 504:
return "DEADLINE_EXCEEDED"
default:
return "UNKNOWN"
}
}
func (s *ProxyServer) setCORSHeaders(w http.ResponseWriter) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Goog-Api-Key")
}
func (s *ProxyServer) handleOPTIONS(w http.ResponseWriter, r *http.Request) {
s.setCORSHeaders(w)
w.WriteHeader(http.StatusOK)
}
func (s *ProxyServer) jsonError(w http.ResponseWriter, status int, message string, details interface{}) {
s.setCORSHeaders(w)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(status)
errorResp := ErrorResponse{
Error: ErrorDetail{
Code: status,
Message: message,
Status: s.statusToGoogleStatus(status),
},
}
if details != nil {
errorResp.Error.Details = []interface{}{details}
}
json.NewEncoder(w).Encode(errorResp)
}
func (s *ProxyServer) buildUpstreamHeaders(r *http.Request) http.Header {
headers := make(http.Header)
copyHeader := func(key string) {
if val := r.Header.Get(key); val != "" {
headers.Set(key, val)
}
}
copyHeader("Authorization")
copyHeader("X-Goog-Api-Key")
copyHeader("Content-Type")
copyHeader("Accept")
return headers
}
func (s *ProxyServer) standardizeError(resp *http.Response) (*ErrorResponse, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
s.logger.Error("Failed to read error response body:", err)
return nil, err
}
defer resp.Body.Close()
s.logger.Error("Upstream error body:", truncate(string(body), s.config.LogTruncationLimit))
var errorResp ErrorResponse
if err := json.Unmarshal(body, &errorResp); err != nil {
// Create standardized error if upstream doesn't provide one
message := "Request failed"
if resp.StatusCode == 429 {
message = "Resource has been exhausted (e.g. check quota)."
} else if resp.Status != "" {
message = resp.Status
}
errorResp = ErrorResponse{
Error: ErrorDetail{
Code: resp.StatusCode,
Message: message,
Status: s.statusToGoogleStatus(resp.StatusCode),
},
}
if len(body) > 0 {
errorResp.Error.Details = []interface{}{
map[string]interface{}{
"@type": "proxy.upstream",
"upstream_error": truncate(string(body), s.config.LogTruncationLimit),
},
}
}
}
// Ensure status field is set
if errorResp.Error.Status == "" {
errorResp.Error.Status = s.statusToGoogleStatus(errorResp.Error.Code)
}
return &errorResp, nil
}
func (s *ProxyServer) buildRetryRequestBody(originalBody *GeminiRequest, accumulatedText string) *GeminiRequest {
s.logger.Debug(fmt.Sprintf("Building retry request. Accumulated text length: %d", len(accumulatedText)))
s.logger.Debug("Accumulated text preview:", truncate(accumulatedText, 500))
retryBody := &GeminiRequest{
Contents: make([]Content, len(originalBody.Contents)),
}
copy(retryBody.Contents, originalBody.Contents)
// Find last user message index
lastUserIndex := -1
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
if retryBody.Contents[i].Role == "user" {
lastUserIndex = i
break
}
}
// Insert model response and continuation prompt
history := []Content{
{
Role: "model",
Parts: []Part{{Text: accumulatedText}},
},
{
Role: "user",
Parts: []Part{{
Text: "Continue exactly where you left off, providing the final answer without repeating the previous thinking steps.",
}},
},
}
if lastUserIndex != -1 {
// Insert after the last user message
newContents := make([]Content, 0, len(retryBody.Contents)+len(history))
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
newContents = append(newContents, history...)
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
retryBody.Contents = newContents
} else {
retryBody.Contents = append(retryBody.Contents, history...)
}
bodyJSON, _ := json.Marshal(retryBody)
s.logger.Debug("Constructed retry request body:", truncate(string(bodyJSON), s.config.LogTruncationLimit))
return retryBody
}
func (s *ProxyServer) processStreamAndRetryInternally(ctx context.Context, w http.ResponseWriter, initialResp *http.Response, originalRequestBody *GeminiRequest, upstreamURL string, originalHeaders http.Header) error {
var accumulatedText string
var consecutiveRetryCount int
currentResp := initialResp
sessionStartTime := time.Now()
s.logger.Info(fmt.Sprintf("Starting stream processing session. Max retries: %d", s.config.MaxConsecutiveRetries))
for {
var interruptionReason string
streamStartTime := time.Now()
linesInThisStream := 0
textInThisStream := ""
reasoningStepDetected := false
hasReceivedFinalAnswerContent := false
s.logger.Info(fmt.Sprintf("=== Starting stream attempt %d/%d ===", consecutiveRetryCount+1, s.config.MaxConsecutiveRetries+1))
scanner := bufio.NewScanner(currentResp.Body)
finishReasonArrived := false
for scanner.Scan() {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
line := scanner.Text()
linesInThisStream++
// Write line to client
fmt.Fprintf(w, "%s\n\n", line)
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
s.logger.Debug(fmt.Sprintf("SSE Line %d: %s", linesInThisStream, truncate(line, 500)))
if !strings.HasPrefix(line, "data: ") {
continue
}
var payload GeminiResponse
if err := json.Unmarshal([]byte(line[6:]), &payload); err != nil {
s.logger.Debug("Ignoring non-JSON data line.")
continue
}
if len(payload.Candidates) == 0 {
continue
}
candidate := payload.Candidates[0]
// Process content parts
if candidate.Content != nil && candidate.Content.Parts != nil {
for _, part := range candidate.Content.Parts {
if part.Text != "" {
accumulatedText += part.Text
textInThisStream += part.Text
// Check if this is final answer content (not a thought)
if !part.Thought {
hasReceivedFinalAnswerContent = true
s.logger.Debug("Received final answer content (non-thought part).")
} else {
s.logger.Debug("Received 'thought' content part.")
}
} else if part.FunctionCall != nil || part.ToolCode != nil {
reasoningStepDetected = true
partJSON, _ := json.Marshal(part)
s.logger.Info("Reasoning step detected (tool/function call):", truncate(string(partJSON), s.config.LogTruncationLimit))
}
}
}
// Process finish reason
if candidate.FinishReason != "" {
finishReasonArrived = true
s.logger.Info("Finish reason received:", candidate.FinishReason)
switch candidate.FinishReason {
case "STOP":
if hasReceivedFinalAnswerContent {
sessionDuration := time.Since(sessionStartTime)
s.logger.Info(fmt.Sprintf("=== STREAM COMPLETED SUCCESSFULLY (Reason: STOP after receiving final answer) ==="))
s.logger.Info(fmt.Sprintf(" - Total session duration: %v, Retries: %d", sessionDuration, consecutiveRetryCount))
currentResp.Body.Close()
return nil
} else {
s.logger.Error("Stream finished with STOP but no final answer content was received. This is a failure.")
interruptionReason = "STOP_WITHOUT_ANSWER"
}
case "MAX_TOKENS", "TOOL_CODE", "SAFETY", "RECITATION":
s.logger.Info("Stream terminated with reason:", candidate.FinishReason, ". Closing stream.")
currentResp.Body.Close()
return nil
default:
s.logger.Error("Abnormal/unknown finish reason:", candidate.FinishReason)
interruptionReason = "FINISH_ABNORMAL"
}
break
}
}
currentResp.Body.Close()
if err := scanner.Err(); err != nil {
s.logger.Error("Scanner error:", err)
interruptionReason = "SCAN_ERROR"
}
if !finishReasonArrived && interruptionReason == "" {
s.logger.Error("Stream ended prematurely without a finish reason (DROP).")
if reasoningStepDetected {
interruptionReason = "DROP_DURING_REASONING"
} else {
interruptionReason = "DROP"
}
}
streamDuration := time.Since(streamStartTime)
s.logger.Info(fmt.Sprintf("Stream attempt %d summary: Duration: %v, Lines: %d, Chars: %d, Total Chars: %d",
consecutiveRetryCount+1, streamDuration, linesInThisStream, len(textInThisStream), len(accumulatedText)))
if interruptionReason == "" {
s.logger.Info("Stream finished without interruption. Closing.")
return nil
}
s.logger.Error("=== STREAM INTERRUPTED (Reason:", interruptionReason, ") ===")
if consecutiveRetryCount >= s.config.MaxConsecutiveRetries {
s.logger.Error("Retry limit exceeded. Sending final error to client.")
errorPayload := map[string]interface{}{
"error": map[string]interface{}{
"code": 504,
"status": "DEADLINE_EXCEEDED",
"message": fmt.Sprintf("Proxy retry limit (%d) exceeded. Last interruption: %s.", s.config.MaxConsecutiveRetries, interruptionReason),
},
}
errorJSON, _ := json.Marshal(errorPayload)
fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(errorJSON))
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
return nil
}
consecutiveRetryCount++
s.logger.Info(fmt.Sprintf("Proceeding to retry attempt %d...", consecutiveRetryCount))
// Wait before retry
if s.config.RetryDelayMs > 0 {
s.logger.Debug(fmt.Sprintf("Waiting %dms before retrying...", s.config.RetryDelayMs))
time.Sleep(time.Duration(s.config.RetryDelayMs) * time.Millisecond)
}
// Build retry request
retryBody := s.buildRetryRequestBody(originalRequestBody, accumulatedText)
retryBodyJSON, _ := json.Marshal(retryBody)
s.logger.Debug("Making retry request to:", upstreamURL)
retryReq, err := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(retryBodyJSON))
if err != nil {
s.logger.Error("Failed to create retry request:", err)
continue
}
retryReq.Header = s.buildUpstreamHeaders(&http.Request{Header: originalHeaders})
retryReq.Header.Set("Content-Type", "application/json")
retryResp, err := s.client.Do(retryReq)
if err != nil {
s.logger.Error("Retry request failed:", err)
continue
}
s.logger.Info(fmt.Sprintf("Retry request completed. Status: %d %s", retryResp.StatusCode, retryResp.Status))
if nonRetryableStatuses[retryResp.StatusCode] {
s.logger.Error(fmt.Sprintf("FATAL: Received non-retryable status %d during retry.", retryResp.StatusCode))
errorResp, err := s.standardizeError(retryResp)
if err == nil {
errorJSON, _ := json.Marshal(errorResp)
fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(errorJSON))
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}
return nil
}
if retryResp.StatusCode != 200 {
retryResp.Body.Close()
s.logger.Error(fmt.Sprintf("Upstream server error on retry: %d", retryResp.StatusCode))
continue
}
s.logger.Info("✓ Retry successful. Got new stream.")
currentResp = retryResp
}
}
func (s *ProxyServer) handleStreamingPost(w http.ResponseWriter, r *http.Request) {
upstreamURL := s.config.UpstreamURLBase + r.URL.Path
if r.URL.RawQuery != "" {
upstreamURL += "?" + r.URL.RawQuery
}
s.logger.Info(fmt.Sprintf("=== NEW STREAMING REQUEST: %s %s ===", r.Method, r.URL.String()))
// Read and parse request body
body, err := io.ReadAll(r.Body)
if err != nil {
s.jsonError(w, http.StatusBadRequest, "Failed to read request body", err.Error())
return
}
defer r.Body.Close()
s.logger.Info(fmt.Sprintf("Request body (raw, %d bytes): %s", len(body), truncate(string(body), s.config.LogTruncationLimit)))
var originalRequestBody GeminiRequest
if err := json.Unmarshal(body, &originalRequestBody); err != nil {
s.logger.Error("Failed to parse request body:", err)
s.jsonError(w, http.StatusBadRequest, "Invalid JSON in request body", err.Error())
return
}
if len(originalRequestBody.Contents) > 0 {
s.logger.Info(fmt.Sprintf("Request contains %d messages:", len(originalRequestBody.Contents)))
for i, m := range originalRequestBody.Contents {
var partsText []string
for _, p := range m.Parts {
if p.Text != "" {
partsText = append(partsText, p.Text)
} else {
partsText = append(partsText, "[non-text part]")
}
}
s.logger.Info(fmt.Sprintf(" [%d] role=%s, text: %s", i, m.Role, truncate(strings.Join(partsText, "\n"), 1000)))
}
}
s.logger.Info("=== MAKING INITIAL REQUEST TO UPSTREAM ===")
t0 := time.Now()
req, err := http.NewRequestWithContext(r.Context(), "POST", upstreamURL, bytes.NewReader(body))
if err != nil {
s.jsonError(w, http.StatusInternalServerError, "Failed to create upstream request", err.Error())
return
}
req.Header = s.buildUpstreamHeaders(r)
req.Header.Set("Content-Type", "application/json")
initialResponse, err := s.client.Do(req)
if err != nil {
s.logger.Error("Initial request failed:", err)
s.jsonError(w, http.StatusBadGateway, "Failed to connect to upstream", err.Error())
return
}
s.logger.Info(fmt.Sprintf("Initial upstream response received in %v. Status: %d", time.Since(t0), initialResponse.StatusCode))
if initialResponse.StatusCode != 200 {
s.logger.Error(fmt.Sprintf("Initial request failed with status %d.", initialResponse.StatusCode))
errorResp, err := s.standardizeError(initialResponse)
if err != nil {
s.jsonError(w, http.StatusBadGateway, "Failed to process upstream error", err.Error())
return
}
s.setCORSHeaders(w)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(initialResponse.StatusCode)
json.NewEncoder(w).Encode(errorResp)
return
}
s.logger.Info("✓ Initial request successful. Starting stream processing.")
// Set up streaming response headers
s.setCORSHeaders(w)
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(200)
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
err = s.processStreamAndRetryInternally(r.Context(), w, initialResponse, &originalRequestBody, upstreamURL, r.Header)
if err != nil {
s.logger.Error("!!! UNHANDLED CRITICAL EXCEPTION IN STREAM PROCESSOR !!!", err)
}
}
func (s *ProxyServer) handleNonStreaming(w http.ResponseWriter, r *http.Request) {
upstreamURL := s.config.UpstreamURLBase + r.URL.Path
if r.URL.RawQuery != "" {
upstreamURL += "?" + r.URL.RawQuery
}
s.logger.Info(fmt.Sprintf("=== NEW NON-STREAMING REQUEST: %s %s ===", r.Method, r.URL.String()))
req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, r.Body)
if err != nil {
s.jsonError(w, http.StatusInternalServerError, "Failed to create upstream request", err.Error())
return
}
req.Header = s.buildUpstreamHeaders(r)
resp, err := s.client.Do(req)
if err != nil {
s.logger.Error("Upstream request failed:", err)
s.jsonError(w, http.StatusBadGateway, "Failed to connect to upstream", err.Error())
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
errorResp, err := s.standardizeError(resp)
if err != nil {
s.jsonError(w, http.StatusBadGateway, "Failed to process upstream error", err.Error())
return
}
s.setCORSHeaders(w)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(resp.StatusCode)
json.NewEncoder(w).Encode(errorResp)
return
}
// Copy headers from upstream response
s.setCORSHeaders(w)
for key, values := range resp.Header {
if key != "Access-Control-Allow-Origin" {
for _, value := range values {
w.Header().Add(key, value)
}
}
}
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
}
func (s *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "OPTIONS" {
s.handleOPTIONS(w, r)
return
}
// Check if this is a streaming request
isStream := r.URL.Query().Get("alt") == "sse"
if r.Method == "POST" && isStream {
s.handleStreamingPost(w, r)
} else {
s.handleNonStreaming(w, r)
}
}
func main() {
config := DefaultConfig()
// Override config from environment variables if needed
// You can add environment variable parsing here
server := NewProxyServer(config)
s := &http.Server{
Addr: config.Port,
Handler: server,
ReadTimeout: 30 * time.Second,
WriteTimeout: 300 * time.Second, // Longer for streaming
IdleTimeout: 120 * time.Second,
}
log.Printf("Starting Gemini API Proxy server on %s", config.Port)
log.Fatal(s.ListenAndServe())
}