Initial commit
This commit is contained in:
34
Dockerfile
Normal file
34
Dockerfile
Normal file
@@ -0,0 +1,34 @@
|
||||
# Stage 1: Build the Go binary
|
||||
FROM golang:1.22-alpine AS builder
|
||||
|
||||
# Set the Current Working Directory inside the container
|
||||
WORKDIR /app
|
||||
|
||||
# Copy go mod and sum files
|
||||
# COPY go.mod go.sum ./
|
||||
# RUN go mod download
|
||||
|
||||
# Copy the source code into the container
|
||||
COPY main.go .
|
||||
|
||||
# Build the Go app
|
||||
# CGO_ENABLED=0 is needed for a static build
|
||||
# GOOS=linux is to specify the target OS
|
||||
# -a installs all packages to be rebuilt
|
||||
# -installsuffix cgo is used with CGO_ENABLED=0
|
||||
# -o main specifies the output file name
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o main .
|
||||
|
||||
# Stage 2: Create the final, minimal image
|
||||
FROM alpine:latest
|
||||
|
||||
WORKDIR /root/
|
||||
|
||||
# Copy the Pre-built binary file from the previous stage
|
||||
COPY --from=builder /app/main .
|
||||
|
||||
# Expose port 8080 to the outside world
|
||||
EXPOSE 8080
|
||||
|
||||
# Command to run the executable
|
||||
CMD ["./main"]
|
||||
6
compose.yaml
Normal file
6
compose.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
services:
|
||||
gemini-proxy:
|
||||
build: .
|
||||
ports:
|
||||
- "9090:8080"
|
||||
restart: always
|
||||
673
main.go
Normal file
673
main.go
Normal file
@@ -0,0 +1,673 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds configuration settings
|
||||
type Config struct {
|
||||
UpstreamURLBase string
|
||||
MaxConsecutiveRetries int
|
||||
DebugMode bool
|
||||
RetryDelayMs int
|
||||
LogTruncationLimit int
|
||||
Port string
|
||||
}
|
||||
|
||||
// DefaultConfig returns default configuration
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
UpstreamURLBase: "https://api-proxy.me/gemini",
|
||||
MaxConsecutiveRetries: 10,
|
||||
DebugMode: true,
|
||||
RetryDelayMs: 750,
|
||||
LogTruncationLimit: 8000,
|
||||
Port: ":8080",
|
||||
}
|
||||
}
|
||||
|
||||
var nonRetryableStatuses = map[int]bool{
|
||||
400: true, 401: true, 403: true, 404: true, 429: true,
|
||||
}
|
||||
|
||||
// Logger wrapper for different log levels
|
||||
type Logger struct {
|
||||
debug bool
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(args ...interface{}) {
|
||||
if l.debug {
|
||||
log.Printf("[DEBUG %s] %v", time.Now().Format(time.RFC3339), fmt.Sprint(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Info(args ...interface{}) {
|
||||
log.Printf("[INFO %s] %v", time.Now().Format(time.RFC3339), fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
func (l *Logger) Error(args ...interface{}) {
|
||||
log.Printf("[ERROR %s] %v", time.Now().Format(time.RFC3339), fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
func truncate(s string, limit int) string {
|
||||
if len(s) <= limit {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%s... [truncated %d chars]", s[:limit], len(s)-limit)
|
||||
}
|
||||
|
||||
// Error response structure
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
type ErrorDetail struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
Details []interface{} `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// Request/Response structures for Gemini API
|
||||
type GeminiRequest struct {
|
||||
Contents []Content `json:"contents"`
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []Part `json:"parts"`
|
||||
}
|
||||
|
||||
type Part struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
FunctionCall interface{} `json:"functionCall,omitempty"`
|
||||
ToolCode interface{} `json:"toolCode,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiResponse struct {
|
||||
Candidates []Candidate `json:"candidates,omitempty"`
|
||||
Error *ErrorDetail `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type Candidate struct {
|
||||
Content *Content `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
}
|
||||
|
||||
// ProxyServer represents the main server
|
||||
type ProxyServer struct {
|
||||
config *Config
|
||||
logger *Logger
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewProxyServer(config *Config) *ProxyServer {
|
||||
return &ProxyServer{
|
||||
config: config,
|
||||
logger: &Logger{debug: config.DebugMode},
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServer) statusToGoogleStatus(code int) string {
|
||||
switch code {
|
||||
case 400:
|
||||
return "INVALID_ARGUMENT"
|
||||
case 401:
|
||||
return "UNAUTHENTICATED"
|
||||
case 403:
|
||||
return "PERMISSION_DENIED"
|
||||
case 404:
|
||||
return "NOT_FOUND"
|
||||
case 429:
|
||||
return "RESOURCE_EXHAUSTED"
|
||||
case 500:
|
||||
return "INTERNAL"
|
||||
case 503:
|
||||
return "UNAVAILABLE"
|
||||
case 504:
|
||||
return "DEADLINE_EXCEEDED"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServer) setCORSHeaders(w http.ResponseWriter) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Goog-Api-Key")
|
||||
}
|
||||
|
||||
func (s *ProxyServer) handleOPTIONS(w http.ResponseWriter, r *http.Request) {
|
||||
s.setCORSHeaders(w)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (s *ProxyServer) jsonError(w http.ResponseWriter, status int, message string, details interface{}) {
|
||||
s.setCORSHeaders(w)
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(status)
|
||||
|
||||
errorResp := ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Code: status,
|
||||
Message: message,
|
||||
Status: s.statusToGoogleStatus(status),
|
||||
},
|
||||
}
|
||||
|
||||
if details != nil {
|
||||
errorResp.Error.Details = []interface{}{details}
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(errorResp)
|
||||
}
|
||||
|
||||
func (s *ProxyServer) buildUpstreamHeaders(r *http.Request) http.Header {
|
||||
headers := make(http.Header)
|
||||
copyHeader := func(key string) {
|
||||
if val := r.Header.Get(key); val != "" {
|
||||
headers.Set(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
copyHeader("Authorization")
|
||||
copyHeader("X-Goog-Api-Key")
|
||||
copyHeader("Content-Type")
|
||||
copyHeader("Accept")
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func (s *ProxyServer) standardizeError(resp *http.Response) (*ErrorResponse, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to read error response body:", err)
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
s.logger.Error("Upstream error body:", truncate(string(body), s.config.LogTruncationLimit))
|
||||
|
||||
var errorResp ErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResp); err != nil {
|
||||
// Create standardized error if upstream doesn't provide one
|
||||
message := "Request failed"
|
||||
if resp.StatusCode == 429 {
|
||||
message = "Resource has been exhausted (e.g. check quota)."
|
||||
} else if resp.Status != "" {
|
||||
message = resp.Status
|
||||
}
|
||||
|
||||
errorResp = ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Code: resp.StatusCode,
|
||||
Message: message,
|
||||
Status: s.statusToGoogleStatus(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
|
||||
if len(body) > 0 {
|
||||
errorResp.Error.Details = []interface{}{
|
||||
map[string]interface{}{
|
||||
"@type": "proxy.upstream",
|
||||
"upstream_error": truncate(string(body), s.config.LogTruncationLimit),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure status field is set
|
||||
if errorResp.Error.Status == "" {
|
||||
errorResp.Error.Status = s.statusToGoogleStatus(errorResp.Error.Code)
|
||||
}
|
||||
|
||||
return &errorResp, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServer) buildRetryRequestBody(originalBody *GeminiRequest, accumulatedText string) *GeminiRequest {
|
||||
s.logger.Debug(fmt.Sprintf("Building retry request. Accumulated text length: %d", len(accumulatedText)))
|
||||
s.logger.Debug("Accumulated text preview:", truncate(accumulatedText, 500))
|
||||
|
||||
retryBody := &GeminiRequest{
|
||||
Contents: make([]Content, len(originalBody.Contents)),
|
||||
}
|
||||
copy(retryBody.Contents, originalBody.Contents)
|
||||
|
||||
// Find last user message index
|
||||
lastUserIndex := -1
|
||||
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
|
||||
if retryBody.Contents[i].Role == "user" {
|
||||
lastUserIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Insert model response and continuation prompt
|
||||
history := []Content{
|
||||
{
|
||||
Role: "model",
|
||||
Parts: []Part{{Text: accumulatedText}},
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Parts: []Part{{
|
||||
Text: "Continue exactly where you left off, providing the final answer without repeating the previous thinking steps.",
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
if lastUserIndex != -1 {
|
||||
// Insert after the last user message
|
||||
newContents := make([]Content, 0, len(retryBody.Contents)+len(history))
|
||||
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
|
||||
newContents = append(newContents, history...)
|
||||
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
|
||||
retryBody.Contents = newContents
|
||||
} else {
|
||||
retryBody.Contents = append(retryBody.Contents, history...)
|
||||
}
|
||||
|
||||
bodyJSON, _ := json.Marshal(retryBody)
|
||||
s.logger.Debug("Constructed retry request body:", truncate(string(bodyJSON), s.config.LogTruncationLimit))
|
||||
|
||||
return retryBody
|
||||
}
|
||||
|
||||
func (s *ProxyServer) processStreamAndRetryInternally(ctx context.Context, w http.ResponseWriter, initialResp *http.Response, originalRequestBody *GeminiRequest, upstreamURL string, originalHeaders http.Header) error {
|
||||
var accumulatedText string
|
||||
var consecutiveRetryCount int
|
||||
currentResp := initialResp
|
||||
sessionStartTime := time.Now()
|
||||
|
||||
s.logger.Info(fmt.Sprintf("Starting stream processing session. Max retries: %d", s.config.MaxConsecutiveRetries))
|
||||
|
||||
for {
|
||||
var interruptionReason string
|
||||
streamStartTime := time.Now()
|
||||
linesInThisStream := 0
|
||||
textInThisStream := ""
|
||||
reasoningStepDetected := false
|
||||
hasReceivedFinalAnswerContent := false
|
||||
|
||||
s.logger.Info(fmt.Sprintf("=== Starting stream attempt %d/%d ===", consecutiveRetryCount+1, s.config.MaxConsecutiveRetries+1))
|
||||
|
||||
scanner := bufio.NewScanner(currentResp.Body)
|
||||
finishReasonArrived := false
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
line := scanner.Text()
|
||||
linesInThisStream++
|
||||
|
||||
// Write line to client
|
||||
fmt.Fprintf(w, "%s\n\n", line)
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
s.logger.Debug(fmt.Sprintf("SSE Line %d: %s", linesInThisStream, truncate(line, 500)))
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
var payload GeminiResponse
|
||||
if err := json.Unmarshal([]byte(line[6:]), &payload); err != nil {
|
||||
s.logger.Debug("Ignoring non-JSON data line.")
|
||||
continue
|
||||
}
|
||||
|
||||
if len(payload.Candidates) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
candidate := payload.Candidates[0]
|
||||
|
||||
// Process content parts
|
||||
if candidate.Content != nil && candidate.Content.Parts != nil {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
accumulatedText += part.Text
|
||||
textInThisStream += part.Text
|
||||
|
||||
// Check if this is final answer content (not a thought)
|
||||
if !part.Thought {
|
||||
hasReceivedFinalAnswerContent = true
|
||||
s.logger.Debug("Received final answer content (non-thought part).")
|
||||
} else {
|
||||
s.logger.Debug("Received 'thought' content part.")
|
||||
}
|
||||
} else if part.FunctionCall != nil || part.ToolCode != nil {
|
||||
reasoningStepDetected = true
|
||||
partJSON, _ := json.Marshal(part)
|
||||
s.logger.Info("Reasoning step detected (tool/function call):", truncate(string(partJSON), s.config.LogTruncationLimit))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process finish reason
|
||||
if candidate.FinishReason != "" {
|
||||
finishReasonArrived = true
|
||||
s.logger.Info("Finish reason received:", candidate.FinishReason)
|
||||
|
||||
switch candidate.FinishReason {
|
||||
case "STOP":
|
||||
if hasReceivedFinalAnswerContent {
|
||||
sessionDuration := time.Since(sessionStartTime)
|
||||
s.logger.Info(fmt.Sprintf("=== STREAM COMPLETED SUCCESSFULLY (Reason: STOP after receiving final answer) ==="))
|
||||
s.logger.Info(fmt.Sprintf(" - Total session duration: %v, Retries: %d", sessionDuration, consecutiveRetryCount))
|
||||
currentResp.Body.Close()
|
||||
return nil
|
||||
} else {
|
||||
s.logger.Error("Stream finished with STOP but no final answer content was received. This is a failure.")
|
||||
interruptionReason = "STOP_WITHOUT_ANSWER"
|
||||
}
|
||||
case "MAX_TOKENS", "TOOL_CODE", "SAFETY", "RECITATION":
|
||||
s.logger.Info("Stream terminated with reason:", candidate.FinishReason, ". Closing stream.")
|
||||
currentResp.Body.Close()
|
||||
return nil
|
||||
default:
|
||||
s.logger.Error("Abnormal/unknown finish reason:", candidate.FinishReason)
|
||||
interruptionReason = "FINISH_ABNORMAL"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
currentResp.Body.Close()
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
s.logger.Error("Scanner error:", err)
|
||||
interruptionReason = "SCAN_ERROR"
|
||||
}
|
||||
|
||||
if !finishReasonArrived && interruptionReason == "" {
|
||||
s.logger.Error("Stream ended prematurely without a finish reason (DROP).")
|
||||
if reasoningStepDetected {
|
||||
interruptionReason = "DROP_DURING_REASONING"
|
||||
} else {
|
||||
interruptionReason = "DROP"
|
||||
}
|
||||
}
|
||||
|
||||
streamDuration := time.Since(streamStartTime)
|
||||
s.logger.Info(fmt.Sprintf("Stream attempt %d summary: Duration: %v, Lines: %d, Chars: %d, Total Chars: %d",
|
||||
consecutiveRetryCount+1, streamDuration, linesInThisStream, len(textInThisStream), len(accumulatedText)))
|
||||
|
||||
if interruptionReason == "" {
|
||||
s.logger.Info("Stream finished without interruption. Closing.")
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Error("=== STREAM INTERRUPTED (Reason:", interruptionReason, ") ===")
|
||||
|
||||
if consecutiveRetryCount >= s.config.MaxConsecutiveRetries {
|
||||
s.logger.Error("Retry limit exceeded. Sending final error to client.")
|
||||
errorPayload := map[string]interface{}{
|
||||
"error": map[string]interface{}{
|
||||
"code": 504,
|
||||
"status": "DEADLINE_EXCEEDED",
|
||||
"message": fmt.Sprintf("Proxy retry limit (%d) exceeded. Last interruption: %s.", s.config.MaxConsecutiveRetries, interruptionReason),
|
||||
},
|
||||
}
|
||||
errorJSON, _ := json.Marshal(errorPayload)
|
||||
fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(errorJSON))
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
consecutiveRetryCount++
|
||||
s.logger.Info(fmt.Sprintf("Proceeding to retry attempt %d...", consecutiveRetryCount))
|
||||
|
||||
// Wait before retry
|
||||
if s.config.RetryDelayMs > 0 {
|
||||
s.logger.Debug(fmt.Sprintf("Waiting %dms before retrying...", s.config.RetryDelayMs))
|
||||
time.Sleep(time.Duration(s.config.RetryDelayMs) * time.Millisecond)
|
||||
}
|
||||
|
||||
// Build retry request
|
||||
retryBody := s.buildRetryRequestBody(originalRequestBody, accumulatedText)
|
||||
retryBodyJSON, _ := json.Marshal(retryBody)
|
||||
|
||||
s.logger.Debug("Making retry request to:", upstreamURL)
|
||||
retryReq, err := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(retryBodyJSON))
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to create retry request:", err)
|
||||
continue
|
||||
}
|
||||
|
||||
retryReq.Header = s.buildUpstreamHeaders(&http.Request{Header: originalHeaders})
|
||||
retryReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
retryResp, err := s.client.Do(retryReq)
|
||||
if err != nil {
|
||||
s.logger.Error("Retry request failed:", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(fmt.Sprintf("Retry request completed. Status: %d %s", retryResp.StatusCode, retryResp.Status))
|
||||
|
||||
if nonRetryableStatuses[retryResp.StatusCode] {
|
||||
s.logger.Error(fmt.Sprintf("FATAL: Received non-retryable status %d during retry.", retryResp.StatusCode))
|
||||
errorResp, err := s.standardizeError(retryResp)
|
||||
if err == nil {
|
||||
errorJSON, _ := json.Marshal(errorResp)
|
||||
fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(errorJSON))
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if retryResp.StatusCode != 200 {
|
||||
retryResp.Body.Close()
|
||||
s.logger.Error(fmt.Sprintf("Upstream server error on retry: %d", retryResp.StatusCode))
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info("✓ Retry successful. Got new stream.")
|
||||
currentResp = retryResp
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServer) handleStreamingPost(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamURL := s.config.UpstreamURLBase + r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
upstreamURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
|
||||
s.logger.Info(fmt.Sprintf("=== NEW STREAMING REQUEST: %s %s ===", r.Method, r.URL.String()))
|
||||
|
||||
// Read and parse request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.jsonError(w, http.StatusBadRequest, "Failed to read request body", err.Error())
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
s.logger.Info(fmt.Sprintf("Request body (raw, %d bytes): %s", len(body), truncate(string(body), s.config.LogTruncationLimit)))
|
||||
|
||||
var originalRequestBody GeminiRequest
|
||||
if err := json.Unmarshal(body, &originalRequestBody); err != nil {
|
||||
s.logger.Error("Failed to parse request body:", err)
|
||||
s.jsonError(w, http.StatusBadRequest, "Invalid JSON in request body", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(originalRequestBody.Contents) > 0 {
|
||||
s.logger.Info(fmt.Sprintf("Request contains %d messages:", len(originalRequestBody.Contents)))
|
||||
for i, m := range originalRequestBody.Contents {
|
||||
var partsText []string
|
||||
for _, p := range m.Parts {
|
||||
if p.Text != "" {
|
||||
partsText = append(partsText, p.Text)
|
||||
} else {
|
||||
partsText = append(partsText, "[non-text part]")
|
||||
}
|
||||
}
|
||||
s.logger.Info(fmt.Sprintf(" [%d] role=%s, text: %s", i, m.Role, truncate(strings.Join(partsText, "\n"), 1000)))
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("=== MAKING INITIAL REQUEST TO UPSTREAM ===")
|
||||
t0 := time.Now()
|
||||
|
||||
req, err := http.NewRequestWithContext(r.Context(), "POST", upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
s.jsonError(w, http.StatusInternalServerError, "Failed to create upstream request", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
req.Header = s.buildUpstreamHeaders(r)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
initialResponse, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
s.logger.Error("Initial request failed:", err)
|
||||
s.jsonError(w, http.StatusBadGateway, "Failed to connect to upstream", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info(fmt.Sprintf("Initial upstream response received in %v. Status: %d", time.Since(t0), initialResponse.StatusCode))
|
||||
|
||||
if initialResponse.StatusCode != 200 {
|
||||
s.logger.Error(fmt.Sprintf("Initial request failed with status %d.", initialResponse.StatusCode))
|
||||
errorResp, err := s.standardizeError(initialResponse)
|
||||
if err != nil {
|
||||
s.jsonError(w, http.StatusBadGateway, "Failed to process upstream error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.setCORSHeaders(w)
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(initialResponse.StatusCode)
|
||||
json.NewEncoder(w).Encode(errorResp)
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("✓ Initial request successful. Starting stream processing.")
|
||||
|
||||
// Set up streaming response headers
|
||||
s.setCORSHeaders(w)
|
||||
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(200)
|
||||
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
err = s.processStreamAndRetryInternally(r.Context(), w, initialResponse, &originalRequestBody, upstreamURL, r.Header)
|
||||
if err != nil {
|
||||
s.logger.Error("!!! UNHANDLED CRITICAL EXCEPTION IN STREAM PROCESSOR !!!", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServer) handleNonStreaming(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamURL := s.config.UpstreamURLBase + r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
upstreamURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
|
||||
s.logger.Info(fmt.Sprintf("=== NEW NON-STREAMING REQUEST: %s %s ===", r.Method, r.URL.String()))
|
||||
|
||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, r.Body)
|
||||
if err != nil {
|
||||
s.jsonError(w, http.StatusInternalServerError, "Failed to create upstream request", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
req.Header = s.buildUpstreamHeaders(r)
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
s.logger.Error("Upstream request failed:", err)
|
||||
s.jsonError(w, http.StatusBadGateway, "Failed to connect to upstream", err.Error())
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
errorResp, err := s.standardizeError(resp)
|
||||
if err != nil {
|
||||
s.jsonError(w, http.StatusBadGateway, "Failed to process upstream error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.setCORSHeaders(w)
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
json.NewEncoder(w).Encode(errorResp)
|
||||
return
|
||||
}
|
||||
|
||||
// Copy headers from upstream response
|
||||
s.setCORSHeaders(w)
|
||||
for key, values := range resp.Header {
|
||||
if key != "Access-Control-Allow-Origin" {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func (s *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "OPTIONS" {
|
||||
s.handleOPTIONS(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a streaming request
|
||||
isStream := r.URL.Query().Get("alt") == "sse"
|
||||
|
||||
if r.Method == "POST" && isStream {
|
||||
s.handleStreamingPost(w, r)
|
||||
} else {
|
||||
s.handleNonStreaming(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
config := DefaultConfig()
|
||||
|
||||
// Override config from environment variables if needed
|
||||
// You can add environment variable parsing here
|
||||
|
||||
server := NewProxyServer(config)
|
||||
|
||||
s := &http.Server{
|
||||
Addr: config.Port,
|
||||
Handler: server,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 300 * time.Second, // Longer for streaming
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
log.Printf("Starting Gemini API Proxy server on %s", config.Port)
|
||||
log.Fatal(s.ListenAndServe())
|
||||
}
|
||||
Reference in New Issue
Block a user