// Listing metadata: 673 lines, 20 KiB, Go.
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Config carries the runtime settings of the proxy.
type Config struct {
	// UpstreamURLBase is prepended to the incoming request path to form the
	// upstream URL.
	UpstreamURLBase string

	// MaxConsecutiveRetries caps the number of retry requests a single
	// streaming session may issue before an error event is sent instead.
	MaxConsecutiveRetries int

	// DebugMode turns on Logger.Debug output.
	DebugMode bool

	// RetryDelayMs is the pause, in milliseconds, taken before each retry
	// request.
	RetryDelayMs int

	// LogTruncationLimit is the limit handed to truncate when bodies are
	// logged or echoed back inside error details.
	LogTruncationLimit int

	// Port is used as the http.Server listen address (for example ":8080").
	Port string
}
|
|
|
|
// DefaultConfig returns default configuration
|
|
func DefaultConfig() *Config {
|
|
return &Config{
|
|
UpstreamURLBase: "https://api-proxy.me/gemini",
|
|
MaxConsecutiveRetries: 10,
|
|
DebugMode: true,
|
|
RetryDelayMs: 750,
|
|
LogTruncationLimit: 8000,
|
|
Port: ":8080",
|
|
}
|
|
}
|
|
|
|
// nonRetryableStatuses is the set of upstream HTTP status codes that end a
// streaming session immediately when a retry request returns one of them.
var nonRetryableStatuses = map[int]bool{
	http.StatusBadRequest:      true, // 400
	http.StatusUnauthorized:    true, // 401
	http.StatusForbidden:       true, // 403
	http.StatusNotFound:        true, // 404
	http.StatusTooManyRequests: true, // 429
}
|
|
|
|
// Logger writes level-tagged, timestamped messages through the standard log
// package.
type Logger struct {
	debug bool // when false, Debug is a no-op
}

// emit formats args and writes one log line tagged with level.
//
// Operands are joined the way fmt.Sprintln joins them: always separated by a
// single space. fmt.Sprint only inserts a space when neither neighbour is a
// string, so a call such as Error("Scanner error:", err) used to produce
// run-together output.
func (l *Logger) emit(level string, args []interface{}) {
	msg := strings.TrimSuffix(fmt.Sprintln(args...), "\n")
	log.Printf("[%s %s] %s", level, time.Now().Format(time.RFC3339), msg)
}

// Debug logs args at DEBUG level; it does nothing unless debug is enabled.
func (l *Logger) Debug(args ...interface{}) {
	if !l.debug {
		return
	}
	l.emit("DEBUG", args)
}

// Info logs args at INFO level.
func (l *Logger) Info(args ...interface{}) {
	l.emit("INFO", args)
}

// Error logs args at ERROR level.
func (l *Logger) Error(args ...interface{}) {
	l.emit("ERROR", args)
}
|
|
|
|
// truncate shortens s to at most limit bytes for logging.
//
// When s already fits it is returned unchanged. Otherwise the result is the
// kept prefix followed by "... [truncated N chars]", where N is the number of
// bytes removed. The cut never lands inside a multi-byte UTF-8 sequence, so
// valid input always yields valid output. A negative limit is treated as 0.
func truncate(s string, limit int) string {
	if limit < 0 {
		limit = 0
	}
	if len(s) <= limit {
		return s
	}

	cut := limit
	// Step back over UTF-8 continuation bytes (bit pattern 10xxxxxx) so the
	// prefix ends on a rune boundary. A rune has at most three continuation
	// bytes, which bounds the walk even for malformed input.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return fmt.Sprintf("%s... [truncated %d chars]", s[:cut], len(s)-cut)
}
|
|
|
|
// ErrorResponse is the JSON envelope used for error bodies: a single "error"
// object.
type ErrorResponse struct {
	Error ErrorDetail `json:"error"`
}

// ErrorDetail is the payload of an ErrorResponse.
type ErrorDetail struct {
	// Code is the numeric status reported to the client.
	Code int `json:"code"`
	// Message is the human-readable description.
	Message string `json:"message"`
	// Status is the symbolic status name, e.g. "NOT_FOUND".
	Status string `json:"status"`
	// Details holds optional extra objects; it is omitted from JSON when empty.
	Details []interface{} `json:"details,omitempty"`
}
|
|
|
|
// Request-side structures for the Gemini API. Only the fields this proxy
// inspects or rewrites are declared.

// GeminiRequest is the request body as seen by the proxy: the conversation
// contents.
type GeminiRequest struct {
	Contents []Content `json:"contents"`
}

// Content is one conversation turn.
type Content struct {
	// Role names the author of the turn; the proxy uses "user" and "model".
	Role string `json:"role"`
	// Parts are the pieces that make up the turn.
	Parts []Part `json:"parts"`
}

// Part is a single piece of a Content.
type Part struct {
	// Text is the textual payload, if any.
	Text string `json:"text,omitempty"`
	// Thought is the "thought" flag of the part; the stream processor treats
	// text parts with this flag set as not being final-answer content.
	Thought bool `json:"thought,omitempty"`
	// FunctionCall is kept undecoded; only its presence is checked.
	FunctionCall interface{} `json:"functionCall,omitempty"`
	// ToolCode is kept undecoded; only its presence is checked.
	ToolCode interface{} `json:"toolCode,omitempty"`
}
|
|
|
|
type GeminiResponse struct {
|
|
Candidates []Candidate `json:"candidates,omitempty"`
|
|
Error *ErrorDetail `json:"error,omitempty"`
|
|
}
|
|
|
|
type Candidate struct {
|
|
Content *Content `json:"content,omitempty"`
|
|
FinishReason string `json:"finishReason,omitempty"`
|
|
}
|
|
|
|
// ProxyServer represents the main server
|
|
type ProxyServer struct {
|
|
config *Config
|
|
logger *Logger
|
|
client *http.Client
|
|
}
|
|
|
|
func NewProxyServer(config *Config) *ProxyServer {
|
|
return &ProxyServer{
|
|
config: config,
|
|
logger: &Logger{debug: config.DebugMode},
|
|
client: &http.Client{Timeout: 30 * time.Second},
|
|
}
|
|
}
|
|
|
|
func (s *ProxyServer) statusToGoogleStatus(code int) string {
|
|
switch code {
|
|
case 400:
|
|
return "INVALID_ARGUMENT"
|
|
case 401:
|
|
return "UNAUTHENTICATED"
|
|
case 403:
|
|
return "PERMISSION_DENIED"
|
|
case 404:
|
|
return "NOT_FOUND"
|
|
case 429:
|
|
return "RESOURCE_EXHAUSTED"
|
|
case 500:
|
|
return "INTERNAL"
|
|
case 503:
|
|
return "UNAVAILABLE"
|
|
case 504:
|
|
return "DEADLINE_EXCEEDED"
|
|
default:
|
|
return "UNKNOWN"
|
|
}
|
|
}
|
|
|
|
func (s *ProxyServer) setCORSHeaders(w http.ResponseWriter) {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Goog-Api-Key")
|
|
}
|
|
|
|
func (s *ProxyServer) handleOPTIONS(w http.ResponseWriter, r *http.Request) {
|
|
s.setCORSHeaders(w)
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
func (s *ProxyServer) jsonError(w http.ResponseWriter, status int, message string, details interface{}) {
|
|
s.setCORSHeaders(w)
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.WriteHeader(status)
|
|
|
|
errorResp := ErrorResponse{
|
|
Error: ErrorDetail{
|
|
Code: status,
|
|
Message: message,
|
|
Status: s.statusToGoogleStatus(status),
|
|
},
|
|
}
|
|
|
|
if details != nil {
|
|
errorResp.Error.Details = []interface{}{details}
|
|
}
|
|
|
|
json.NewEncoder(w).Encode(errorResp)
|
|
}
|
|
|
|
func (s *ProxyServer) buildUpstreamHeaders(r *http.Request) http.Header {
|
|
headers := make(http.Header)
|
|
copyHeader := func(key string) {
|
|
if val := r.Header.Get(key); val != "" {
|
|
headers.Set(key, val)
|
|
}
|
|
}
|
|
|
|
copyHeader("Authorization")
|
|
copyHeader("X-Goog-Api-Key")
|
|
copyHeader("Content-Type")
|
|
copyHeader("Accept")
|
|
|
|
return headers
|
|
}
|
|
|
|
func (s *ProxyServer) standardizeError(resp *http.Response) (*ErrorResponse, error) {
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
s.logger.Error("Failed to read error response body:", err)
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
s.logger.Error("Upstream error body:", truncate(string(body), s.config.LogTruncationLimit))
|
|
|
|
var errorResp ErrorResponse
|
|
if err := json.Unmarshal(body, &errorResp); err != nil {
|
|
// Create standardized error if upstream doesn't provide one
|
|
message := "Request failed"
|
|
if resp.StatusCode == 429 {
|
|
message = "Resource has been exhausted (e.g. check quota)."
|
|
} else if resp.Status != "" {
|
|
message = resp.Status
|
|
}
|
|
|
|
errorResp = ErrorResponse{
|
|
Error: ErrorDetail{
|
|
Code: resp.StatusCode,
|
|
Message: message,
|
|
Status: s.statusToGoogleStatus(resp.StatusCode),
|
|
},
|
|
}
|
|
|
|
if len(body) > 0 {
|
|
errorResp.Error.Details = []interface{}{
|
|
map[string]interface{}{
|
|
"@type": "proxy.upstream",
|
|
"upstream_error": truncate(string(body), s.config.LogTruncationLimit),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
// Ensure status field is set
|
|
if errorResp.Error.Status == "" {
|
|
errorResp.Error.Status = s.statusToGoogleStatus(errorResp.Error.Code)
|
|
}
|
|
|
|
return &errorResp, nil
|
|
}
|
|
|
|
func (s *ProxyServer) buildRetryRequestBody(originalBody *GeminiRequest, accumulatedText string) *GeminiRequest {
|
|
s.logger.Debug(fmt.Sprintf("Building retry request. Accumulated text length: %d", len(accumulatedText)))
|
|
s.logger.Debug("Accumulated text preview:", truncate(accumulatedText, 500))
|
|
|
|
retryBody := &GeminiRequest{
|
|
Contents: make([]Content, len(originalBody.Contents)),
|
|
}
|
|
copy(retryBody.Contents, originalBody.Contents)
|
|
|
|
// Find last user message index
|
|
lastUserIndex := -1
|
|
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
|
|
if retryBody.Contents[i].Role == "user" {
|
|
lastUserIndex = i
|
|
break
|
|
}
|
|
}
|
|
|
|
// Insert model response and continuation prompt
|
|
history := []Content{
|
|
{
|
|
Role: "model",
|
|
Parts: []Part{{Text: accumulatedText}},
|
|
},
|
|
{
|
|
Role: "user",
|
|
Parts: []Part{{
|
|
Text: "Continue exactly where you left off, providing the final answer without repeating the previous thinking steps.",
|
|
}},
|
|
},
|
|
}
|
|
|
|
if lastUserIndex != -1 {
|
|
// Insert after the last user message
|
|
newContents := make([]Content, 0, len(retryBody.Contents)+len(history))
|
|
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
|
|
newContents = append(newContents, history...)
|
|
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
|
|
retryBody.Contents = newContents
|
|
} else {
|
|
retryBody.Contents = append(retryBody.Contents, history...)
|
|
}
|
|
|
|
bodyJSON, _ := json.Marshal(retryBody)
|
|
s.logger.Debug("Constructed retry request body:", truncate(string(bodyJSON), s.config.LogTruncationLimit))
|
|
|
|
return retryBody
|
|
}
|
|
|
|
func (s *ProxyServer) processStreamAndRetryInternally(ctx context.Context, w http.ResponseWriter, initialResp *http.Response, originalRequestBody *GeminiRequest, upstreamURL string, originalHeaders http.Header) error {
|
|
var accumulatedText string
|
|
var consecutiveRetryCount int
|
|
currentResp := initialResp
|
|
sessionStartTime := time.Now()
|
|
|
|
s.logger.Info(fmt.Sprintf("Starting stream processing session. Max retries: %d", s.config.MaxConsecutiveRetries))
|
|
|
|
for {
|
|
var interruptionReason string
|
|
streamStartTime := time.Now()
|
|
linesInThisStream := 0
|
|
textInThisStream := ""
|
|
reasoningStepDetected := false
|
|
hasReceivedFinalAnswerContent := false
|
|
|
|
s.logger.Info(fmt.Sprintf("=== Starting stream attempt %d/%d ===", consecutiveRetryCount+1, s.config.MaxConsecutiveRetries+1))
|
|
|
|
scanner := bufio.NewScanner(currentResp.Body)
|
|
finishReasonArrived := false
|
|
|
|
for scanner.Scan() {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
|
|
line := scanner.Text()
|
|
linesInThisStream++
|
|
|
|
// Write line to client
|
|
fmt.Fprintf(w, "%s\n\n", line)
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
|
|
s.logger.Debug(fmt.Sprintf("SSE Line %d: %s", linesInThisStream, truncate(line, 500)))
|
|
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
|
|
var payload GeminiResponse
|
|
if err := json.Unmarshal([]byte(line[6:]), &payload); err != nil {
|
|
s.logger.Debug("Ignoring non-JSON data line.")
|
|
continue
|
|
}
|
|
|
|
if len(payload.Candidates) == 0 {
|
|
continue
|
|
}
|
|
|
|
candidate := payload.Candidates[0]
|
|
|
|
// Process content parts
|
|
if candidate.Content != nil && candidate.Content.Parts != nil {
|
|
for _, part := range candidate.Content.Parts {
|
|
if part.Text != "" {
|
|
accumulatedText += part.Text
|
|
textInThisStream += part.Text
|
|
|
|
// Check if this is final answer content (not a thought)
|
|
if !part.Thought {
|
|
hasReceivedFinalAnswerContent = true
|
|
s.logger.Debug("Received final answer content (non-thought part).")
|
|
} else {
|
|
s.logger.Debug("Received 'thought' content part.")
|
|
}
|
|
} else if part.FunctionCall != nil || part.ToolCode != nil {
|
|
reasoningStepDetected = true
|
|
partJSON, _ := json.Marshal(part)
|
|
s.logger.Info("Reasoning step detected (tool/function call):", truncate(string(partJSON), s.config.LogTruncationLimit))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Process finish reason
|
|
if candidate.FinishReason != "" {
|
|
finishReasonArrived = true
|
|
s.logger.Info("Finish reason received:", candidate.FinishReason)
|
|
|
|
switch candidate.FinishReason {
|
|
case "STOP":
|
|
if hasReceivedFinalAnswerContent {
|
|
sessionDuration := time.Since(sessionStartTime)
|
|
s.logger.Info(fmt.Sprintf("=== STREAM COMPLETED SUCCESSFULLY (Reason: STOP after receiving final answer) ==="))
|
|
s.logger.Info(fmt.Sprintf(" - Total session duration: %v, Retries: %d", sessionDuration, consecutiveRetryCount))
|
|
currentResp.Body.Close()
|
|
return nil
|
|
} else {
|
|
s.logger.Error("Stream finished with STOP but no final answer content was received. This is a failure.")
|
|
interruptionReason = "STOP_WITHOUT_ANSWER"
|
|
}
|
|
case "MAX_TOKENS", "TOOL_CODE", "SAFETY", "RECITATION":
|
|
s.logger.Info("Stream terminated with reason:", candidate.FinishReason, ". Closing stream.")
|
|
currentResp.Body.Close()
|
|
return nil
|
|
default:
|
|
s.logger.Error("Abnormal/unknown finish reason:", candidate.FinishReason)
|
|
interruptionReason = "FINISH_ABNORMAL"
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
currentResp.Body.Close()
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
s.logger.Error("Scanner error:", err)
|
|
interruptionReason = "SCAN_ERROR"
|
|
}
|
|
|
|
if !finishReasonArrived && interruptionReason == "" {
|
|
s.logger.Error("Stream ended prematurely without a finish reason (DROP).")
|
|
if reasoningStepDetected {
|
|
interruptionReason = "DROP_DURING_REASONING"
|
|
} else {
|
|
interruptionReason = "DROP"
|
|
}
|
|
}
|
|
|
|
streamDuration := time.Since(streamStartTime)
|
|
s.logger.Info(fmt.Sprintf("Stream attempt %d summary: Duration: %v, Lines: %d, Chars: %d, Total Chars: %d",
|
|
consecutiveRetryCount+1, streamDuration, linesInThisStream, len(textInThisStream), len(accumulatedText)))
|
|
|
|
if interruptionReason == "" {
|
|
s.logger.Info("Stream finished without interruption. Closing.")
|
|
return nil
|
|
}
|
|
|
|
s.logger.Error("=== STREAM INTERRUPTED (Reason:", interruptionReason, ") ===")
|
|
|
|
if consecutiveRetryCount >= s.config.MaxConsecutiveRetries {
|
|
s.logger.Error("Retry limit exceeded. Sending final error to client.")
|
|
errorPayload := map[string]interface{}{
|
|
"error": map[string]interface{}{
|
|
"code": 504,
|
|
"status": "DEADLINE_EXCEEDED",
|
|
"message": fmt.Sprintf("Proxy retry limit (%d) exceeded. Last interruption: %s.", s.config.MaxConsecutiveRetries, interruptionReason),
|
|
},
|
|
}
|
|
errorJSON, _ := json.Marshal(errorPayload)
|
|
fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(errorJSON))
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
consecutiveRetryCount++
|
|
s.logger.Info(fmt.Sprintf("Proceeding to retry attempt %d...", consecutiveRetryCount))
|
|
|
|
// Wait before retry
|
|
if s.config.RetryDelayMs > 0 {
|
|
s.logger.Debug(fmt.Sprintf("Waiting %dms before retrying...", s.config.RetryDelayMs))
|
|
time.Sleep(time.Duration(s.config.RetryDelayMs) * time.Millisecond)
|
|
}
|
|
|
|
// Build retry request
|
|
retryBody := s.buildRetryRequestBody(originalRequestBody, accumulatedText)
|
|
retryBodyJSON, _ := json.Marshal(retryBody)
|
|
|
|
s.logger.Debug("Making retry request to:", upstreamURL)
|
|
retryReq, err := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(retryBodyJSON))
|
|
if err != nil {
|
|
s.logger.Error("Failed to create retry request:", err)
|
|
continue
|
|
}
|
|
|
|
retryReq.Header = s.buildUpstreamHeaders(&http.Request{Header: originalHeaders})
|
|
retryReq.Header.Set("Content-Type", "application/json")
|
|
|
|
retryResp, err := s.client.Do(retryReq)
|
|
if err != nil {
|
|
s.logger.Error("Retry request failed:", err)
|
|
continue
|
|
}
|
|
|
|
s.logger.Info(fmt.Sprintf("Retry request completed. Status: %d %s", retryResp.StatusCode, retryResp.Status))
|
|
|
|
if nonRetryableStatuses[retryResp.StatusCode] {
|
|
s.logger.Error(fmt.Sprintf("FATAL: Received non-retryable status %d during retry.", retryResp.StatusCode))
|
|
errorResp, err := s.standardizeError(retryResp)
|
|
if err == nil {
|
|
errorJSON, _ := json.Marshal(errorResp)
|
|
fmt.Fprintf(w, "event: error\ndata: %s\n\n", string(errorJSON))
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if retryResp.StatusCode != 200 {
|
|
retryResp.Body.Close()
|
|
s.logger.Error(fmt.Sprintf("Upstream server error on retry: %d", retryResp.StatusCode))
|
|
continue
|
|
}
|
|
|
|
s.logger.Info("✓ Retry successful. Got new stream.")
|
|
currentResp = retryResp
|
|
}
|
|
}
|
|
|
|
func (s *ProxyServer) handleStreamingPost(w http.ResponseWriter, r *http.Request) {
|
|
upstreamURL := s.config.UpstreamURLBase + r.URL.Path
|
|
if r.URL.RawQuery != "" {
|
|
upstreamURL += "?" + r.URL.RawQuery
|
|
}
|
|
|
|
s.logger.Info(fmt.Sprintf("=== NEW STREAMING REQUEST: %s %s ===", r.Method, r.URL.String()))
|
|
|
|
// Read and parse request body
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
s.jsonError(w, http.StatusBadRequest, "Failed to read request body", err.Error())
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
|
|
s.logger.Info(fmt.Sprintf("Request body (raw, %d bytes): %s", len(body), truncate(string(body), s.config.LogTruncationLimit)))
|
|
|
|
var originalRequestBody GeminiRequest
|
|
if err := json.Unmarshal(body, &originalRequestBody); err != nil {
|
|
s.logger.Error("Failed to parse request body:", err)
|
|
s.jsonError(w, http.StatusBadRequest, "Invalid JSON in request body", err.Error())
|
|
return
|
|
}
|
|
|
|
if len(originalRequestBody.Contents) > 0 {
|
|
s.logger.Info(fmt.Sprintf("Request contains %d messages:", len(originalRequestBody.Contents)))
|
|
for i, m := range originalRequestBody.Contents {
|
|
var partsText []string
|
|
for _, p := range m.Parts {
|
|
if p.Text != "" {
|
|
partsText = append(partsText, p.Text)
|
|
} else {
|
|
partsText = append(partsText, "[non-text part]")
|
|
}
|
|
}
|
|
s.logger.Info(fmt.Sprintf(" [%d] role=%s, text: %s", i, m.Role, truncate(strings.Join(partsText, "\n"), 1000)))
|
|
}
|
|
}
|
|
|
|
s.logger.Info("=== MAKING INITIAL REQUEST TO UPSTREAM ===")
|
|
t0 := time.Now()
|
|
|
|
req, err := http.NewRequestWithContext(r.Context(), "POST", upstreamURL, bytes.NewReader(body))
|
|
if err != nil {
|
|
s.jsonError(w, http.StatusInternalServerError, "Failed to create upstream request", err.Error())
|
|
return
|
|
}
|
|
|
|
req.Header = s.buildUpstreamHeaders(r)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
initialResponse, err := s.client.Do(req)
|
|
if err != nil {
|
|
s.logger.Error("Initial request failed:", err)
|
|
s.jsonError(w, http.StatusBadGateway, "Failed to connect to upstream", err.Error())
|
|
return
|
|
}
|
|
|
|
s.logger.Info(fmt.Sprintf("Initial upstream response received in %v. Status: %d", time.Since(t0), initialResponse.StatusCode))
|
|
|
|
if initialResponse.StatusCode != 200 {
|
|
s.logger.Error(fmt.Sprintf("Initial request failed with status %d.", initialResponse.StatusCode))
|
|
errorResp, err := s.standardizeError(initialResponse)
|
|
if err != nil {
|
|
s.jsonError(w, http.StatusBadGateway, "Failed to process upstream error", err.Error())
|
|
return
|
|
}
|
|
|
|
s.setCORSHeaders(w)
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.WriteHeader(initialResponse.StatusCode)
|
|
json.NewEncoder(w).Encode(errorResp)
|
|
return
|
|
}
|
|
|
|
s.logger.Info("✓ Initial request successful. Starting stream processing.")
|
|
|
|
// Set up streaming response headers
|
|
s.setCORSHeaders(w)
|
|
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.WriteHeader(200)
|
|
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
|
|
err = s.processStreamAndRetryInternally(r.Context(), w, initialResponse, &originalRequestBody, upstreamURL, r.Header)
|
|
if err != nil {
|
|
s.logger.Error("!!! UNHANDLED CRITICAL EXCEPTION IN STREAM PROCESSOR !!!", err)
|
|
}
|
|
}
|
|
|
|
func (s *ProxyServer) handleNonStreaming(w http.ResponseWriter, r *http.Request) {
|
|
upstreamURL := s.config.UpstreamURLBase + r.URL.Path
|
|
if r.URL.RawQuery != "" {
|
|
upstreamURL += "?" + r.URL.RawQuery
|
|
}
|
|
|
|
s.logger.Info(fmt.Sprintf("=== NEW NON-STREAMING REQUEST: %s %s ===", r.Method, r.URL.String()))
|
|
|
|
req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, r.Body)
|
|
if err != nil {
|
|
s.jsonError(w, http.StatusInternalServerError, "Failed to create upstream request", err.Error())
|
|
return
|
|
}
|
|
|
|
req.Header = s.buildUpstreamHeaders(r)
|
|
|
|
resp, err := s.client.Do(req)
|
|
if err != nil {
|
|
s.logger.Error("Upstream request failed:", err)
|
|
s.jsonError(w, http.StatusBadGateway, "Failed to connect to upstream", err.Error())
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 {
|
|
errorResp, err := s.standardizeError(resp)
|
|
if err != nil {
|
|
s.jsonError(w, http.StatusBadGateway, "Failed to process upstream error", err.Error())
|
|
return
|
|
}
|
|
|
|
s.setCORSHeaders(w)
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.WriteHeader(resp.StatusCode)
|
|
json.NewEncoder(w).Encode(errorResp)
|
|
return
|
|
}
|
|
|
|
// Copy headers from upstream response
|
|
s.setCORSHeaders(w)
|
|
for key, values := range resp.Header {
|
|
if key != "Access-Control-Allow-Origin" {
|
|
for _, value := range values {
|
|
w.Header().Add(key, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
w.WriteHeader(resp.StatusCode)
|
|
io.Copy(w, resp.Body)
|
|
}
|
|
|
|
func (s *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == "OPTIONS" {
|
|
s.handleOPTIONS(w, r)
|
|
return
|
|
}
|
|
|
|
// Check if this is a streaming request
|
|
isStream := r.URL.Query().Get("alt") == "sse"
|
|
|
|
if r.Method == "POST" && isStream {
|
|
s.handleStreamingPost(w, r)
|
|
} else {
|
|
s.handleNonStreaming(w, r)
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
config := DefaultConfig()
|
|
|
|
// Override config from environment variables if needed
|
|
// You can add environment variable parsing here
|
|
|
|
server := NewProxyServer(config)
|
|
|
|
s := &http.Server{
|
|
Addr: config.Port,
|
|
Handler: server,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 300 * time.Second, // Longer for streaming
|
|
IdleTimeout: 120 * time.Second,
|
|
}
|
|
|
|
log.Printf("Starting Gemini API Proxy server on %s", config.Port)
|
|
log.Fatal(s.ListenAndServe())
|
|
} |