commit 452fa4c2f4877e82e97d192f9649fcb19090f02b Author: 史悦 Date: Wed Aug 13 10:09:22 2025 +0800 Initial commit diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..a845364 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +# Stage 1: Build the Go binary +FROM golang:1.22-alpine AS builder + +# Set the Current Working Directory inside the container +WORKDIR /app + +# Copy go mod and sum files +# COPY go.mod go.sum ./ +# RUN go mod download + +# Copy the source code into the container +COPY main.go . + +# Build the Go app +# CGO_ENABLED=0 is needed for a static build +# GOOS=linux is to specify the target OS +# -a installs all packages to be rebuilt +# -installsuffix cgo is used with CGO_ENABLED=0 +# -o main specifies the output file name +RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o main . + +# Stage 2: Create the final, minimal image +FROM alpine:latest + +WORKDIR /root/ + +# Copy the Pre-built binary file from the previous stage +COPY --from=builder /app/main . + +# Expose port 8080 to the outside world +EXPOSE 8080 + +# Command to run the executable +CMD ["./main"] \ No newline at end of file diff --git a/compose.yaml b/compose.yaml new file mode 100644 index 0000000..7c7ce41 --- /dev/null +++ b/compose.yaml @@ -0,0 +1,6 @@ +services: + gemini-proxy: + build: . + ports: + - "9090:8080" + restart: always \ No newline at end of file diff --git a/main.go b/main.go new file mode 100644 index 0000000..c914caa --- /dev/null +++ b/main.go @@ -0,0 +1,673 @@ +package main + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "time" +) + +// Config holds configuration settings +type Config struct { + UpstreamURLBase string + MaxConsecutiveRetries int + DebugMode bool + RetryDelayMs int + LogTruncationLimit int + Port string +} + +// DefaultConfig returns default configuration +func DefaultConfig() *Config { + return &Config{ + UpstreamURLBase: "https://api-proxy.me/gemini", + MaxConsecutiveRetries: 10, + DebugMode: true, + RetryDelayMs: 750, + LogTruncationLimit: 8000, + Port: ":8080", + } +} + +var nonRetryableStatuses = map[int]bool{ + 400: true, 401: true, 403: true, 404: true, 429: true, +} + +// Logger wrapper for different log levels +type Logger struct { + debug bool +} + +func (l *Logger) Debug(args ...interface{}) { + if l.debug { + log.Printf("[DEBUG %s] %v", time.Now().Format(time.RFC3339), fmt.Sprint(args...)) + } +} + +func (l *Logger) Info(args ...interface{}) { + log.Printf("[INFO %s] %v", time.Now().Format(time.RFC3339), fmt.Sprint(args...)) +} + +func (l *Logger) Error(args ...interface{}) { + log.Printf("[ERROR %s] %v", time.Now().Format(time.RFC3339), fmt.Sprint(args...)) +} + +func truncate(s string, limit int) string { + if len(s) <= limit { + return s + } + return fmt.Sprintf("%s... [truncated %d chars]", s[:limit], len(s)-limit) +} + +// Error response structure +type ErrorResponse struct { + Error ErrorDetail `json:"error"` +} + +type ErrorDetail struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + Details []interface{} `json:"details,omitempty"` +} + +// Request/Response structures for Gemini API +type GeminiRequest struct { + Contents []Content `json:"contents"` +} + +type Content struct { + Role string `json:"role"` + Parts []Part `json:"parts"` +} + +type Part struct { + Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` + FunctionCall interface{} `json:"functionCall,omitempty"` + ToolCode interface{} `json:"toolCode,omitempty"` +} + +type GeminiResponse struct { + Candidates []Candidate `json:"candidates,omitempty"` + Error *ErrorDetail `json:"error,omitempty"` +} + +type Candidate struct { + Content *Content `json:"content,omitempty"` + FinishReason string `json:"finishReason,omitempty"` +} + +// ProxyServer represents the main server +type ProxyServer struct { + config *Config + logger *Logger + client *http.Client +} + +func NewProxyServer(config *Config) *ProxyServer { + return &ProxyServer{ + config: config, + logger: &Logger{debug: config.DebugMode}, + client: &http.Client{Timeout: 30 * time.Second}, + } +} + +func (s *ProxyServer) statusToGoogleStatus(code int) string { + switch code { + case 400: + return "INVALID_ARGUMENT" + case 401: + return "UNAUTHENTICATED" + case 403: + return "PERMISSION_DENIED" + case 404: + return "NOT_FOUND" + case 429: + return "RESOURCE_EXHAUSTED" + case 500: + return "INTERNAL" + case 503: + return "UNAVAILABLE" + case 504: + return "DEADLINE_EXCEEDED" + default: + return "UNKNOWN" + } +} + +func (s *ProxyServer) setCORSHeaders(w http.ResponseWriter) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Goog-Api-Key") +} + +func (s *ProxyServer) handleOPTIONS(w http.ResponseWriter, r *http.Request) { + s.setCORSHeaders(w) + w.WriteHeader(http.StatusOK) +} + +func (s *ProxyServer) jsonError(w http.ResponseWriter, status int, message string, details interface{}) { + s.setCORSHeaders(w) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + + errorResp := ErrorResponse{ + Error: ErrorDetail{ + Code: status, + Message: message, + Status: s.statusToGoogleStatus(status), + }, + } + + if details != nil { + errorResp.Error.Details = []interface{}{details} + } + + json.NewEncoder(w).Encode(errorResp) +} + +func (s *ProxyServer) buildUpstreamHeaders(r *http.Request) http.Header { + headers := make(http.Header) + copyHeader := func(key string) { + if val := r.Header.Get(key); val != "" { + headers.Set(key, val) + } + } + + copyHeader("Authorization") + copyHeader("X-Goog-Api-Key") + copyHeader("Content-Type") + copyHeader("Accept") + + return headers +} + +func (s *ProxyServer) standardizeError(resp *http.Response) (*ErrorResponse, error) { + body, err := io.ReadAll(resp.Body) + if err != nil { + s.logger.Error("Failed to read error response body:", err) + return nil, err + } + defer resp.Body.Close() + + s.logger.Error("Upstream error body:", truncate(string(body), s.config.LogTruncationLimit)) + + var errorResp ErrorResponse + if err := json.Unmarshal(body, &errorResp); err != nil { + // Create standardized error if upstream doesn't provide one + message := "Request failed" + if resp.StatusCode == 429 { + message = "Resource has been exhausted (e.g. check quota)." + } else if resp.Status != "" { + message = resp.Status + } + + errorResp = ErrorResponse{ + Error: ErrorDetail{ + Code: resp.StatusCode, + Message: message, + Status: s.statusToGoogleStatus(resp.StatusCode), + }, + } + + if len(body) > 0 { + errorResp.Error.Details = []interface{}{ + map[string]interface{}{ + "@type": "proxy.upstream", + "upstream_error": truncate(string(body), s.config.LogTruncationLimit), + }, + } + } + } + + // Ensure status field is set + if errorResp.Error.Status == "" { + errorResp.Error.Status = s.statusToGoogleStatus(errorResp.Error.Code) + } + + return &errorResp, nil +} + +func (s *ProxyServer) buildRetryRequestBody(originalBody *GeminiRequest, accumulatedText string) *GeminiRequest { + s.logger.Debug(fmt.Sprintf("Building retry request. Accumulated text length: %d", len(accumulatedText))) + s.logger.Debug("Accumulated text preview:", truncate(accumulatedText, 500)) + + retryBody := &GeminiRequest{ + Contents: make([]Content, len(originalBody.Contents)), + } + copy(retryBody.Contents, originalBody.Contents) + + // Find last user message index + lastUserIndex := -1 + for i := len(retryBody.Contents) - 1; i >= 0; i-- { + if retryBody.Contents[i].Role == "user" { + lastUserIndex = i + break + } + } + + // Insert model response and continuation prompt + history := []Content{ + { + Role: "model", + Parts: []Part{{Text: accumulatedText}}, + }, + { + Role: "user", + Parts: []Part{{ + Text: "Continue exactly where you left off, providing the final answer without repeating the previous thinking steps.", + }}, + }, + } + + if lastUserIndex != -1 { + // Insert after the last user message + newContents := make([]Content, 0, len(retryBody.Contents)+len(history)) + newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...) + newContents = append(newContents, history...) + newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...) + retryBody.Contents = newContents + } else { + retryBody.Contents = append(retryBody.Contents, history...) + } + + bodyJSON, _ := json.Marshal(retryBody) + s.logger.Debug("Constructed retry request body:", truncate(string(bodyJSON), s.config.LogTruncationLimit)) + + return retryBody +} + +func (s *ProxyServer) processStreamAndRetryInternally(ctx context.Context, w http.ResponseWriter, initialResp *http.Response, originalRequestBody *GeminiRequest, upstreamURL string, originalHeaders http.Header) error { + var accumulatedText string + var consecutiveRetryCount int + currentResp := initialResp + sessionStartTime := time.Now() + + s.logger.Info(fmt.Sprintf("Starting stream processing session. Max retries: %d", s.config.MaxConsecutiveRetries)) + + for { + var interruptionReason string + streamStartTime := time.Now() + linesInThisStream := 0 + textInThisStream := "" + reasoningStepDetected := false + hasReceivedFinalAnswerContent := false + + s.logger.Info(fmt.Sprintf("=== Starting stream attempt %d/%d ===", consecutiveRetryCount+1, s.config.MaxConsecutiveRetries+1)) + + scanner := bufio.NewScanner(currentResp.Body) + finishReasonArrived := false + + for scanner.Scan() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + line := scanner.Text() + linesInThisStream++ + + // Write line to client + fmt.Fprintf(w, "%s\n\n", line) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + + s.logger.Debug(fmt.Sprintf("SSE Line %d: %s", linesInThisStream, truncate(line, 500))) + + if !strings.HasPrefix(line, "data: ") { + continue + } + + var payload GeminiResponse + if err := json.Unmarshal([]byte(line[6:]), &payload); err != nil { + s.logger.Debug("Ignoring non-JSON data line.") + continue + } + + if len(payload.Candidates) == 0 { + continue + } + + candidate := payload.Candidates[0] + + // Process content parts + if candidate.Content != nil && candidate.Content.Parts != nil { + for _, part := range candidate.Content.Parts { + if part.Text != "" { + accumulatedText += part.Text + textInThisStream += part.Text + + // Check if this is final answer content (not a thought) + if !part.Thought { + hasReceivedFinalAnswerContent = true + s.logger.Debug("Received final answer content (non-thought part).") + } else { + s.logger.Debug("Received 'thought' content part.") + } + } else if part.FunctionCall != nil || part.ToolCode != nil { + reasoningStepDetected = true + partJSON, _ := json.Marshal(part) + s.logger.Info("Reasoning step detected (tool/function call):", truncate(string(partJSON), s.config.LogTruncationLimit)) + } + } + } + + // Process finish reason + if candidate.FinishReason != "" { + finishReasonArrived = true + s.logger.Info("Finish reason received:", candidate.FinishReason) + + switch candidate.FinishReason { + case "STOP": + if hasReceivedFinalAnswerContent { + sessionDuration := time.Since(sessionStartTime) + s.logger.Info(fmt.Sprintf("=== STREAM COMPLETED SUCCESSFULLY (Reason: STOP after receiving final answer) ===")) + s.logger.Info(fmt.Sprintf(" - Total session duration: %v, Retries: %d", sessionDuration, consecutiveRetryCount)) + currentResp.Body.Close() + return nil + } else { + s.logger.Error("Stream finished with STOP but no final answer content was received. This is a failure.") + interruptionReason = "STOP_WITHOUT_ANSWER" + } + case "MAX_TOKENS", "TOOL_CODE", "SAFETY", "RECITATION": + s.logger.Info("Stream terminated with reason:", candidate.FinishReason, ". Closing stream.") + currentResp.Body.Close() + return nil + default: + s.logger.Error("Abnormal/unknown finish reason:", candidate.FinishReason) + interruptionReason = "FINISH_ABNORMAL" + } + break + } + } + + currentResp.Body.Close() + + if err := scanner.Err(); err != nil { + s.logger.Error("Scanner error:", err) + interruptionReason = "SCAN_ERROR" + } + + if !finishReasonArrived && interruptionReason == "" { + s.logger.Error("Stream ended prematurely without a finish reason (DROP).") + if reasoningStepDetected { + interruptionReason = "DROP_DURING_REASONING" + } else { + interruptionReason = "DROP" + } + } + + streamDuration := time.Since(streamStartTime) + s.logger.Info(fmt.Sprintf("Stream attempt %d summary: Duration: %v, Lines: %d, Chars: %d, Total Chars: %d", + consecutiveRetryCount+1, streamDuration, linesInThisStream, len(textInThisStream), len(accumulatedText))) + + if interruptionReason == "" { + s.logger.Info("Stream finished without interruption. Closing.") + return nil + } + + s.logger.Error("=== STREAM INTERRUPTED (Reason:", interruptionReason, ") ===") + + if consecutiveRetryCount >= s.config.MaxConsecutiveRetries { + s.logger.Error("Retry limit exceeded. Sending final error to client.") + errorPayload := map[string]interface{}{ + "error": map[string]interface{}{ + "code": 504, + "status": "DEADLINE_EXCEEDED", + "message": fmt.Sprintf("Proxy retry limit (%d) exceeded. Last interruption: %s.", s.config.MaxConsecutiveRetries, interruptionReason), + }, + } + errorJSON, _ := json.Marshal(errorPayload) + fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(errorJSON)) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + return nil + } + + consecutiveRetryCount++ + s.logger.Info(fmt.Sprintf("Proceeding to retry attempt %d...", consecutiveRetryCount)) + + // Wait before retry + if s.config.RetryDelayMs > 0 { + s.logger.Debug(fmt.Sprintf("Waiting %dms before retrying...", s.config.RetryDelayMs)) + time.Sleep(time.Duration(s.config.RetryDelayMs) * time.Millisecond) + } + + // Build retry request + retryBody := s.buildRetryRequestBody(originalRequestBody, accumulatedText) + retryBodyJSON, _ := json.Marshal(retryBody) + + s.logger.Debug("Making retry request to:", upstreamURL) + retryReq, err := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(retryBodyJSON)) + if err != nil { + s.logger.Error("Failed to create retry request:", err) + continue + } + + retryReq.Header = s.buildUpstreamHeaders(&http.Request{Header: originalHeaders}) + retryReq.Header.Set("Content-Type", "application/json") + + retryResp, err := s.client.Do(retryReq) + if err != nil { + s.logger.Error("Retry request failed:", err) + continue + } + + s.logger.Info(fmt.Sprintf("Retry request completed. Status: %d %s", retryResp.StatusCode, retryResp.Status)) + + if nonRetryableStatuses[retryResp.StatusCode] { + s.logger.Error(fmt.Sprintf("FATAL: Received non-retryable status %d during retry.", retryResp.StatusCode)) + errorResp, err := s.standardizeError(retryResp) + if err == nil { + errorJSON, _ := json.Marshal(errorResp) + fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(errorJSON)) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } + return nil + } + + if retryResp.StatusCode != 200 { + retryResp.Body.Close() + s.logger.Error(fmt.Sprintf("Upstream server error on retry: %d", retryResp.StatusCode)) + continue + } + + s.logger.Info("✓ Retry successful. Got new stream.") + currentResp = retryResp + } +} + +func (s *ProxyServer) handleStreamingPost(w http.ResponseWriter, r *http.Request) { + upstreamURL := s.config.UpstreamURLBase + r.URL.Path + if r.URL.RawQuery != "" { + upstreamURL += "?" + r.URL.RawQuery + } + + s.logger.Info(fmt.Sprintf("=== NEW STREAMING REQUEST: %s %s ===", r.Method, r.URL.String())) + + // Read and parse request body + body, err := io.ReadAll(r.Body) + if err != nil { + s.jsonError(w, http.StatusBadRequest, "Failed to read request body", err.Error()) + return + } + defer r.Body.Close() + + s.logger.Info(fmt.Sprintf("Request body (raw, %d bytes): %s", len(body), truncate(string(body), s.config.LogTruncationLimit))) + + var originalRequestBody GeminiRequest + if err := json.Unmarshal(body, &originalRequestBody); err != nil { + s.logger.Error("Failed to parse request body:", err) + s.jsonError(w, http.StatusBadRequest, "Invalid JSON in request body", err.Error()) + return + } + + if len(originalRequestBody.Contents) > 0 { + s.logger.Info(fmt.Sprintf("Request contains %d messages:", len(originalRequestBody.Contents))) + for i, m := range originalRequestBody.Contents { + var partsText []string + for _, p := range m.Parts { + if p.Text != "" { + partsText = append(partsText, p.Text) + } else { + partsText = append(partsText, "[non-text part]") + } + } + s.logger.Info(fmt.Sprintf(" [%d] role=%s, text: %s", i, m.Role, truncate(strings.Join(partsText, "\n"), 1000))) + } + } + + s.logger.Info("=== MAKING INITIAL REQUEST TO UPSTREAM ===") + t0 := time.Now() + + req, err := http.NewRequestWithContext(r.Context(), "POST", upstreamURL, bytes.NewReader(body)) + if err != nil { + s.jsonError(w, http.StatusInternalServerError, "Failed to create upstream request", err.Error()) + return + } + + req.Header = s.buildUpstreamHeaders(r) + req.Header.Set("Content-Type", "application/json") + + initialResponse, err := s.client.Do(req) + if err != nil { + s.logger.Error("Initial request failed:", err) + s.jsonError(w, http.StatusBadGateway, "Failed to connect to upstream", err.Error()) + return + } + + s.logger.Info(fmt.Sprintf("Initial upstream response received in %v. Status: %d", time.Since(t0), initialResponse.StatusCode)) + + if initialResponse.StatusCode != 200 { + s.logger.Error(fmt.Sprintf("Initial request failed with status %d.", initialResponse.StatusCode)) + errorResp, err := s.standardizeError(initialResponse) + if err != nil { + s.jsonError(w, http.StatusBadGateway, "Failed to process upstream error", err.Error()) + return + } + + s.setCORSHeaders(w) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(initialResponse.StatusCode) + json.NewEncoder(w).Encode(errorResp) + return + } + + s.logger.Info("✓ Initial request successful. Starting stream processing.") + + // Set up streaming response headers + s.setCORSHeaders(w) + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(200) + + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + + err = s.processStreamAndRetryInternally(r.Context(), w, initialResponse, &originalRequestBody, upstreamURL, r.Header) + if err != nil { + s.logger.Error("!!! UNHANDLED CRITICAL EXCEPTION IN STREAM PROCESSOR !!!", err) + } +} + +func (s *ProxyServer) handleNonStreaming(w http.ResponseWriter, r *http.Request) { + upstreamURL := s.config.UpstreamURLBase + r.URL.Path + if r.URL.RawQuery != "" { + upstreamURL += "?" + r.URL.RawQuery + } + + s.logger.Info(fmt.Sprintf("=== NEW NON-STREAMING REQUEST: %s %s ===", r.Method, r.URL.String())) + + req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, r.Body) + if err != nil { + s.jsonError(w, http.StatusInternalServerError, "Failed to create upstream request", err.Error()) + return + } + + req.Header = s.buildUpstreamHeaders(r) + + resp, err := s.client.Do(req) + if err != nil { + s.logger.Error("Upstream request failed:", err) + s.jsonError(w, http.StatusBadGateway, "Failed to connect to upstream", err.Error()) + return + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + errorResp, err := s.standardizeError(resp) + if err != nil { + s.jsonError(w, http.StatusBadGateway, "Failed to process upstream error", err.Error()) + return + } + + s.setCORSHeaders(w) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(resp.StatusCode) + json.NewEncoder(w).Encode(errorResp) + return + } + + // Copy headers from upstream response + s.setCORSHeaders(w) + for key, values := range resp.Header { + if key != "Access-Control-Allow-Origin" { + for _, value := range values { + w.Header().Add(key, value) + } + } + } + + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) +} + +func (s *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == "OPTIONS" { + s.handleOPTIONS(w, r) + return + } + + // Check if this is a streaming request + isStream := r.URL.Query().Get("alt") == "sse" + + if r.Method == "POST" && isStream { + s.handleStreamingPost(w, r) + } else { + s.handleNonStreaming(w, r) + } +} + +func main() { + config := DefaultConfig() + + // Override config from environment variables if needed + // You can add environment variable parsing here + + server := NewProxyServer(config) + + s := &http.Server{ + Addr: config.Port, + Handler: server, + ReadTimeout: 30 * time.Second, + WriteTimeout: 300 * time.Second, // Longer for streaming + IdleTimeout: 120 * time.Second, + } + + log.Printf("Starting Gemini API Proxy server on %s", config.Port) + log.Fatal(s.ListenAndServe()) +} \ No newline at end of file