feat: refactor application to use Gin framework, update routing and middleware handling

This commit is contained in:
王锦强
2025-03-15 12:39:17 +08:00
parent cab289dabf
commit 6fa5c1f467
13 changed files with 314 additions and 279 deletions

View File

@@ -2,9 +2,9 @@ package handlers
import (
"html/template"
"net/http"
"path/filepath"
"github.com/gin-gonic/gin"
"tts/internal/config"
)
@@ -29,13 +29,7 @@ func NewPagesHandler(templatesDir string, cfg *config.Config) (*PagesHandler, er
}
// HandleIndex 处理首页请求
func (h *PagesHandler) HandleIndex(w http.ResponseWriter, r *http.Request) {
// 如果不是根路径返回404
if r.URL.Path != "/" && r.URL.Path != "/index.html" {
http.NotFound(w, r)
return
}
func (h *PagesHandler) HandleIndex(c *gin.Context) {
// 准备模板数据
data := map[string]interface{}{
"BasePath": h.config.Server.BasePath,
@@ -45,17 +39,17 @@ func (h *PagesHandler) HandleIndex(w http.ResponseWriter, r *http.Request) {
}
// 设置内容类型
w.Header().Set("Content-Type", "text/html; charset=utf-8")
c.Header("Content-Type", "text/html; charset=utf-8")
// 渲染模板
if err := h.templates.ExecuteTemplate(w, "index.html", data); err != nil {
http.Error(w, "模板渲染失败: "+err.Error(), http.StatusInternalServerError)
if err := h.templates.ExecuteTemplate(c.Writer, "index.html", data); err != nil {
c.AbortWithStatusJSON(500, gin.H{"error": "模板渲染失败: " + err.Error()})
return
}
}
// HandleAPIDoc 处理API文档请求
func (h *PagesHandler) HandleAPIDoc(w http.ResponseWriter, r *http.Request) {
func (h *PagesHandler) HandleAPIDoc(c *gin.Context) {
// 准备模板数据
data := map[string]interface{}{
"BasePath": h.config.Server.BasePath,
@@ -66,11 +60,11 @@ func (h *PagesHandler) HandleAPIDoc(w http.ResponseWriter, r *http.Request) {
}
// 设置内容类型
w.Header().Set("Content-Type", "text/html; charset=utf-8")
c.Header("Content-Type", "text/html; charset=utf-8")
// 渲染模板
if err := h.templates.ExecuteTemplate(w, "api-doc.html", data); err != nil {
http.Error(w, "模板渲染失败: "+err.Error(), http.StatusInternalServerError)
if err := h.templates.ExecuteTemplate(c.Writer, "api-doc.html", data); err != nil {
c.AbortWithStatusJSON(500, gin.H{"error": "模板渲染失败: " + err.Error()})
return
}
}

View File

@@ -1,7 +1,6 @@
package handlers
import (
"encoding/json"
"fmt"
"log"
"net/http"
@@ -15,8 +14,84 @@ import (
"tts/internal/models"
"tts/internal/tts"
"unicode/utf8"
"github.com/gin-gonic/gin"
)
// truncateForLog 截断文本用于日志显示,同时显示开头和结尾
func truncateForLog(text string, maxLength int) string {
// 先去除换行符
text = strings.ReplaceAll(text, "\n", " ")
text = strings.ReplaceAll(text, "\r", " ")
runes := []rune(text)
if len(runes) <= maxLength {
return text
}
// 计算开头和结尾各显示多少字符
halfLength := maxLength / 2
return string(runes[:halfLength]) + "..." + string(runes[len(runes)-halfLength:])
}
// audioMerge 音频合并
func audioMerge(audioSegments [][]byte) ([]byte, error) {
if len(audioSegments) == 0 {
return nil, fmt.Errorf("没有音频片段可合并")
}
// 使用 ffmpeg 合并音频
tempDir, err := os.MkdirTemp("", "audio_merge_")
if err != nil {
return nil, err
}
defer os.RemoveAll(tempDir)
listFile := filepath.Join(tempDir, "concat.txt")
lf, err := os.Create(listFile)
if err != nil {
return nil, err
}
for i, seg := range audioSegments {
segFile := filepath.Join(tempDir, fmt.Sprintf("seg_%d.mp3", i))
if err := os.WriteFile(segFile, seg, 0644); err != nil {
return nil, err
}
if _, err := lf.WriteString(fmt.Sprintf("file '%s'\n", segFile)); err != nil {
return nil, err
}
}
lf.Close()
outputFile := filepath.Join(tempDir, "output.mp3")
cmd := exec.Command("ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", listFile, "-c", "copy", outputFile)
if err := cmd.Run(); err != nil {
return nil, err
}
mergedData, err := os.ReadFile(outputFile)
if err != nil {
return nil, err
}
log.Printf("使用ffmpeg合并完成总大小: %s", formatFileSize(len(mergedData)))
return mergedData, nil
}
// formatFileSize 格式化文件大小
func formatFileSize(size int) string {
switch {
case size < 1024:
return fmt.Sprintf("%d B", size)
case size < 1024*1024:
return fmt.Sprintf("%.2f KB", float64(size)/1024.0)
case size < 1024*1024*1024:
return fmt.Sprintf("%.2f MB", float64(size)/(1024.0*1024.0))
default:
return fmt.Sprintf("%.2f GB", float64(size)/(1024.0*1024.0*1024.0))
}
}
// TTSHandler 处理TTS请求
type TTSHandler struct {
ttsService tts.Service
@@ -32,13 +107,13 @@ func NewTTSHandler(service tts.Service, cfg *config.Config) *TTSHandler {
}
// HandleOpenAITTS 处理OpenAI兼容的TTS请求
func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
func (h *TTSHandler) HandleOpenAITTS(c *gin.Context) {
// 记录请求开始时间
startTime := time.Now()
// 只支持POST请求
if r.Method != http.MethodPost {
http.Error(w, "仅支持POST请求", http.StatusMethodNotAllowed)
if c.Request.Method != http.MethodPost {
c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持POST请求"})
return
}
@@ -50,8 +125,8 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
Speed float64 `json:"speed"`
}
if err := json.NewDecoder(r.Body).Decode(&openaiReq); err != nil {
http.Error(w, "无效的JSON请求: "+err.Error(), http.StatusBadRequest)
if err := c.ShouldBindJSON(&openaiReq); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求: " + err.Error()})
return
}
@@ -60,7 +135,7 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
// 检查必需字段
if openaiReq.Input == "" {
http.Error(w, "input字段不能为空", http.StatusBadRequest)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "input字段不能为空"})
return
}
@@ -98,7 +173,7 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
// 检查文本长度
if len(req.Text) > h.config.TTS.MaxTextLength {
http.Error(w, "文本长度超过限制", http.StatusBadRequest)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "文本长度超过限制"})
return
}
@@ -107,25 +182,25 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
if len(req.Text) > segmentThreshold && len(req.Text) <= h.config.TTS.MaxTextLength {
log.Printf("文本长度 %d 超过阈值 %d使用分段处理", len(req.Text), segmentThreshold)
// 使用分段处理
h.handleSegmentedTTS(w, r, req)
h.handleSegmentedTTS(c, req)
return
}
// 非流式模式处理
synthStart := time.Now()
resp, err := h.ttsService.SynthesizeSpeech(r.Context(), req)
resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), req)
synthTime := time.Since(synthStart)
log.Printf("TTS合成耗时: %v, 文本长度: %d", synthTime, len(req.Text))
if err != nil {
http.Error(w, "语音合成失败: "+err.Error(), http.StatusInternalServerError)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()})
return
}
// 设置响应
w.Header().Set("Content-Type", "audio/mpeg")
c.Header("Content-Type", "audio/mpeg")
writeStart := time.Now()
w.Write(resp.AudioContent)
c.Writer.Write(resp.AudioContent)
writeTime := time.Since(writeStart)
// 记录总耗时
@@ -135,61 +210,51 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
}
// HandleTTS 处理TTS请求
func (h *TTSHandler) HandleTTS(w http.ResponseWriter, r *http.Request) {
func (h *TTSHandler) HandleTTS(c *gin.Context) {
// 记录请求开始时间
startTime := time.Now()
// 解析请求参数
var req models.TTSRequest
switch r.Method {
switch c.Request.Method {
case http.MethodGet:
// 从URL参数获取
q := r.URL.Query()
req = models.TTSRequest{
Text: q.Get("t"),
Voice: q.Get("v"),
Rate: q.Get("r"),
Pitch: q.Get("p"),
Style: q.Get("s"),
Text: c.Query("t"),
Voice: c.Query("v"),
Rate: c.Query("r"),
Pitch: c.Query("p"),
Style: c.Query("s"),
}
case http.MethodPost:
// 从POST JSON体获取
if r.Header.Get("Content-Type") == "application/json" {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if c.ContentType() == "application/json" {
if err := c.ShouldBindJSON(&req); err != nil {
log.Printf("JSON解析错误: %v", err)
http.Error(w, "无效的JSON请求", http.StatusBadRequest)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求"})
return
}
} else {
// 表单数据
if err := r.ParseForm(); err != nil {
if err := c.ShouldBind(&req); err != nil {
log.Printf("表单解析错误: %v", err)
http.Error(w, "无法解析表单数据", http.StatusBadRequest)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无法解析表单数据"})
return
}
req = models.TTSRequest{
Text: r.FormValue("text"),
Voice: r.FormValue("voice"),
Rate: r.FormValue("rate"),
Pitch: r.FormValue("pitch"),
Style: r.FormValue("style"),
}
}
default:
log.Printf("不支持的HTTP方法: %s", r.Method)
http.Error(w, "仅支持GET和POST请求", http.StatusMethodNotAllowed)
log.Printf("不支持的HTTP方法: %s", c.Request.Method)
c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持GET和POST请求"})
return
}
// 记录参数解析耗时
parseTime := time.Since(startTime)
log.Printf("请求参数解析耗时: %v", parseTime)
// 验证必要参数
if req.Text == "" {
log.Print("错误: 未提供文本参数")
http.Error(w, "必须提供文本参数", http.StatusBadRequest)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "必须提供文本参数"})
return
}
@@ -206,7 +271,7 @@ func (h *TTSHandler) HandleTTS(w http.ResponseWriter, r *http.Request) {
// 检查文本长度
if len(req.Text) > h.config.TTS.MaxTextLength {
http.Error(w, "文本长度超过限制", http.StatusBadRequest)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "文本长度超过限制"})
return
}
@@ -215,24 +280,24 @@ func (h *TTSHandler) HandleTTS(w http.ResponseWriter, r *http.Request) {
if len(req.Text) > segmentThreshold && len(req.Text) <= h.config.TTS.MaxTextLength {
log.Printf("文本长度 %d 超过阈值 %d使用分段处理", len(req.Text), segmentThreshold)
// 如果文本长度超过阈值但小于最大限制,使用分段处理
h.handleSegmentedTTS(w, r, req)
h.handleSegmentedTTS(c, req)
return
}
synthStart := time.Now()
resp, err := h.ttsService.SynthesizeSpeech(r.Context(), req)
resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), req)
synthTime := time.Since(synthStart)
log.Printf("TTS合成耗时: %v, 文本长度: %d", synthTime, len(req.Text))
if err != nil {
http.Error(w, "语音合成失败: "+err.Error(), http.StatusInternalServerError)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()})
return
}
// 设置响应
w.Header().Set("Content-Type", "audio/mpeg")
c.Header("Content-Type", "audio/mpeg")
writeStart := time.Now()
w.Write(resp.AudioContent)
c.Writer.Write(resp.AudioContent)
writeTime := time.Since(writeStart)
// 记录总耗时
@@ -242,7 +307,7 @@ func (h *TTSHandler) HandleTTS(w http.ResponseWriter, r *http.Request) {
}
// handleSegmentedTTS 处理长文本的分段TTS请求
func (h *TTSHandler) handleSegmentedTTS(w http.ResponseWriter, r *http.Request, req models.TTSRequest) {
func (h *TTSHandler) handleSegmentedTTS(c *gin.Context, req models.TTSRequest) {
segmentStart := time.Now() // 分段处理开始时间
text := req.Text
@@ -296,7 +361,7 @@ func (h *TTSHandler) handleSegmentedTTS(w http.ResponseWriter, r *http.Request,
segStart := time.Now()
// 合成该段音频
resp, err := h.ttsService.SynthesizeSpeech(r.Context(), segReq)
resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), segReq)
// 记录该段合成耗时
segTime := time.Since(segStart)
@@ -331,7 +396,7 @@ func (h *TTSHandler) handleSegmentedTTS(w http.ResponseWriter, r *http.Request,
// 检查是否有错误发生
if err := <-errChan; err != nil {
http.Error(w, "语音合成失败: "+err.Error(), http.StatusInternalServerError)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()})
return
}
@@ -345,16 +410,16 @@ func (h *TTSHandler) handleSegmentedTTS(w http.ResponseWriter, r *http.Request,
if err != nil {
log.Printf("合并音频失败: %v", err)
http.Error(w, "音频合并失败: "+err.Error(), http.StatusInternalServerError)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "音频合并失败: " + err.Error()})
return
}
// 设置响应内容类型
w.Header().Set("Content-Type", "audio/mpeg")
c.Header("Content-Type", "audio/mpeg")
// 写入合并后的音频数据
totalSize := len(audioData)
if _, writeErr := w.Write(audioData); writeErr != nil {
if _, writeErr := c.Writer.Write(audioData); writeErr != nil {
log.Printf("写入响应失败: %v", writeErr)
}
@@ -468,7 +533,6 @@ func splitTextBySentences(text string) []string {
if currentMerged.Len() > 0 {
mergedSentence := currentMerged.String()
mergedSentences = append(mergedSentences, mergedSentence)
log.Printf("添加最后剩余的合并句子,长度=%d", utf8.RuneCountInString(mergedSentence))
}
return mergedSentences
@@ -476,77 +540,3 @@ func splitTextBySentences(text string) []string {
return sentences
}
// truncateForLog 截断文本用于日志显示,同时显示开头和结尾
func truncateForLog(text string, maxLength int) string {
// 先去除换行符
text = strings.ReplaceAll(text, "\n", " ")
text = strings.ReplaceAll(text, "\r", " ")
runes := []rune(text)
if len(runes) <= maxLength {
return text
}
// 计算开头和结尾各显示多少字符
halfLength := maxLength / 2
return string(runes[:halfLength]) + "..." + string(runes[len(runes)-halfLength:])
}
// audioMerge 音频合并
func audioMerge(audioSegments [][]byte) ([]byte, error) {
if len(audioSegments) == 0 {
return nil, fmt.Errorf("没有音频片段可合并")
}
// 使用 ffmpeg 合并音频
tempDir, err := os.MkdirTemp("", "audio_merge_")
if err != nil {
return nil, err
}
defer os.RemoveAll(tempDir)
listFile := filepath.Join(tempDir, "concat.txt")
lf, err := os.Create(listFile)
if err != nil {
return nil, err
}
for i, seg := range audioSegments {
segFile := filepath.Join(tempDir, fmt.Sprintf("seg_%d.mp3", i))
if err := os.WriteFile(segFile, seg, 0644); err != nil {
return nil, err
}
if _, err := lf.WriteString(fmt.Sprintf("file '%s'\n", segFile)); err != nil {
return nil, err
}
}
lf.Close()
outputFile := filepath.Join(tempDir, "output.mp3")
cmd := exec.Command("ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", listFile, "-c", "copy", outputFile)
if err := cmd.Run(); err != nil {
return nil, err
}
mergedData, err := os.ReadFile(outputFile)
if err != nil {
return nil, err
}
log.Printf("使用ffmpeg合并完成总大小: %s", formatFileSize(len(mergedData)))
return mergedData, nil
}
// formatFileSize 格式化文件大小
func formatFileSize(size int) string {
switch {
case size < 1024:
return fmt.Sprintf("%d B", size)
case size < 1024*1024:
return fmt.Sprintf("%.2f KB", float64(size)/1024.0)
case size < 1024*1024*1024:
return fmt.Sprintf("%.2f MB", float64(size)/(1024.0*1024.0))
default:
return fmt.Sprintf("%.2f GB", float64(size)/(1024.0*1024.0*1024.0))
}
}

View File

@@ -1,9 +1,10 @@
package handlers
import (
"encoding/json"
"net/http"
"tts/internal/tts"
"github.com/gin-gonic/gin"
)
// VoicesHandler 处理语音列表请求
@@ -19,23 +20,17 @@ func NewVoicesHandler(service tts.Service) *VoicesHandler {
}
// HandleVoices 处理语音列表请求
func (h *VoicesHandler) HandleVoices(w http.ResponseWriter, r *http.Request) {
func (h *VoicesHandler) HandleVoices(c *gin.Context) {
// 从查询参数中获取语言筛选
locale := r.URL.Query().Get("locale")
locale := c.Query("locale")
// 获取语音列表
voices, err := h.ttsService.ListVoices(r.Context(), locale)
voices, err := h.ttsService.ListVoices(c.Request.Context(), locale)
if err != nil {
http.Error(w, "获取语音列表失败: "+err.Error(), http.StatusInternalServerError)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "获取语音列表失败: " + err.Error()})
return
}
// 设置内容类型
w.Header().Set("Content-Type", "application/json")
// 编码为JSON并返回
if err := json.NewEncoder(w).Encode(voices); err != nil {
http.Error(w, "JSON编码失败", http.StatusInternalServerError)
return
}
// 返回JSON响应
c.JSON(http.StatusOK, voices)
}

View File

@@ -1,58 +1,58 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// OpenAIAuth 中间件验证 OpenAI API 请求的令牌
func OpenAIAuth(apiToken string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func OpenAIAuth(apiToken string) gin.HandlerFunc {
return func(c *gin.Context) {
// 如果没有配置令牌,跳过验证
if apiToken == "" {
next.ServeHTTP(w, r)
c.Next()
return
}
// 获取请求头中的 Authorization
authHeader := r.Header.Get("Authorization")
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
http.Error(w, "未提供授权令牌", http.StatusUnauthorized)
c.AbortWithStatusJSON(401, gin.H{"error": "未提供授权令牌"})
return
}
// 验证格式是否为 "Bearer {token}"
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
http.Error(w, "授权格式无效", http.StatusUnauthorized)
c.AbortWithStatusJSON(401, gin.H{"error": "授权格式无效"})
return
}
// 验证令牌是否正确
if parts[1] != apiToken {
http.Error(w, "令牌无效", http.StatusUnauthorized)
c.AbortWithStatusJSON(401, gin.H{"error": "令牌无效"})
return
}
// 令牌验证通过,继续处理请求
next.ServeHTTP(w, r)
})
c.Next()
}
}
// TTSAuth 是用于验证 TTS API 接口的中间件
func TTSAuth(apiKey string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func TTSAuth(apiKey string) gin.HandlerFunc {
return func(c *gin.Context) {
// 从查询参数中获取 api_key
queryKey := r.URL.Query().Get("api_key")
queryKey := c.Query("api_key")
// 如果 apiKey 配置为空字符串,表示不需要验证
if apiKey != "" && queryKey != apiKey {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("未授权访问: 无效的 API 密钥"))
c.AbortWithStatusJSON(401, gin.H{"error": "未授权访问: 无效的 API 密钥"})
return
}
// 验证通过,继续处理请求
next.ServeHTTP(w, r)
})
c.Next()
}
}

View File

@@ -1,22 +1,22 @@
package middleware
import "net/http"
import "github.com/gin-gonic/gin"
// CORS 处理跨域资源共享
func CORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
// 设置CORS响应头
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
// 如果是预检请求直接返回200
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(200)
return
}
// 继续下一个处理器
next.ServeHTTP(w, r)
})
c.Next()
}
}

View File

@@ -2,45 +2,27 @@ package middleware
import (
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
// Logger 是一个HTTP中间件记录请求的详细信息
func Logger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
// 包装ResponseWriter以捕获状态码
wrapper := &responseWriterWrapper{
ResponseWriter: w,
statusCode: http.StatusOK,
}
// 调用下一个处理器
next.ServeHTTP(wrapper, r)
// 处理请求
c.Next()
// 记录请求信息
duration := time.Since(start)
log.Printf(
"[%s] %s %s %d %s",
r.Method,
r.RequestURI,
r.RemoteAddr,
wrapper.statusCode,
log.Printf("[%s] %s %s %d %s",
c.Request.Method,
c.Request.URL.Path,
c.ClientIP(),
c.Writer.Status(),
duration,
)
})
}
// responseWriterWrapper 包装http.ResponseWriter以捕获状态码
type responseWriterWrapper struct {
http.ResponseWriter
statusCode int
}
// WriteHeader 捕获状态码
func (w *responseWriterWrapper) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
}

View File

@@ -1,18 +1,19 @@
package routes
import (
"net/http"
"tts/internal/config"
"tts/internal/http/handlers"
"tts/internal/http/middleware"
"tts/internal/tts"
"tts/internal/tts/microsoft"
"github.com/gin-gonic/gin"
)
// SetupRoutes 配置所有API路由
func SetupRoutes(cfg *config.Config, ttsService tts.Service) (http.Handler, error) {
// 创建一个新的路由多路复用器
mux := http.NewServeMux()
func SetupRoutes(cfg *config.Config, ttsService tts.Service) (*gin.Engine, error) {
// 创建Gin路由
router := gin.New()
// 创建处理器
ttsHandler := handlers.NewTTSHandler(ttsService, cfg)
@@ -24,43 +25,41 @@ func SetupRoutes(cfg *config.Config, ttsService tts.Service) (http.Handler, erro
return nil, err
}
// 设置主页路由
mux.HandleFunc("/", pagesHandler.HandleIndex)
// 设置API文档路由
mux.HandleFunc("/api-doc", pagesHandler.HandleAPIDoc)
// 设置TTS API路由 - 添加认证中间件
ttsHandlerFunc := http.HandlerFunc(ttsHandler.HandleTTS)
authenticatedTTSHandler := middleware.TTSAuth(cfg.TTS.ApiKey, ttsHandlerFunc)
mux.Handle("/tts", authenticatedTTSHandler)
// 设置语音列表API路由
mux.HandleFunc("/voices", voicesHandler.HandleVoices)
// 创建OpenAI兼容接口的处理器添加验证中间件
openAIHandler := http.HandlerFunc(ttsHandler.HandleOpenAITTS)
authenticatedHandler := middleware.OpenAIAuth(cfg.OpenAI.ApiKey, openAIHandler)
// 应用OpenAI兼容的路由
mux.Handle("/v1/audio/speech", authenticatedHandler)
mux.Handle("/audio/speech", authenticatedHandler)
// 设置静态文件服务
fs := http.FileServer(http.Dir("./web/static"))
mux.Handle("/static/", http.StripPrefix("/static/", fs))
// 应用中间件
router.Use(middleware.Logger()) // 日志中间件
router.Use(middleware.CORS()) // CORS中间件
// 应用基础路径前缀
var handler http.Handler = mux
var baseRouter gin.IRoutes
if cfg.Server.BasePath != "" {
handler = http.StripPrefix(cfg.Server.BasePath, mux)
baseRouter = router.Group(cfg.Server.BasePath)
} else {
baseRouter = router
}
// 应用中间件
handler = middleware.Logger(handler) // 日志中间件
handler = middleware.CORS(handler) // CORS中间件
// 设置静态文件服务
baseRouter.Static("/static", "./web/static")
return handler, nil
// 设置主页路由
baseRouter.GET("/", pagesHandler.HandleIndex)
// 设置API文档路由
baseRouter.GET("/api-doc", pagesHandler.HandleAPIDoc)
// 设置TTS API路由 - 添加认证中间件
baseRouter.POST("/tts", middleware.TTSAuth(cfg.TTS.ApiKey), ttsHandler.HandleTTS)
baseRouter.GET("/tts", middleware.TTSAuth(cfg.TTS.ApiKey), ttsHandler.HandleTTS)
// 设置语音列表API路由
baseRouter.GET("/voices", voicesHandler.HandleVoices)
// 设置OpenAI兼容接口的处理器添加验证中间件
openAIHandler := middleware.OpenAIAuth(cfg.OpenAI.ApiKey)
baseRouter.POST("/v1/audio/speech", openAIHandler, ttsHandler.HandleOpenAITTS)
baseRouter.POST("/audio/speech", openAIHandler, ttsHandler.HandleOpenAITTS)
return router, nil
}
// InitializeServices 初始化所有服务

View File

@@ -32,14 +32,14 @@ func NewApp(configPath string) (*App, error) {
return nil, fmt.Errorf("初始化服务失败: %w", err)
}
// 设置路由
handler, err := routes.SetupRoutes(cfg, ttsService)
// 设置Gin路由
router, err := routes.SetupRoutes(cfg, ttsService)
if err != nil {
return nil, fmt.Errorf("设置路由失败: %w", err)
}
// 创建HTTP服务器
server := New(cfg, handler)
server := New(cfg, router)
return &App{
server: server,

View File

@@ -3,43 +3,36 @@ package server
import (
"context"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"tts/internal/config"
)
// Server 封装HTTP服务器
type Server struct {
server *http.Server
router *gin.Engine
basePath string
port int
}
// New 创建新的HTTP服务器
func New(cfg *config.Config, handler http.Handler) *Server {
// 创建HTTP服务器
httpServer := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
Handler: handler,
ReadTimeout: time.Duration(cfg.Server.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(cfg.Server.WriteTimeout) * time.Second,
IdleTimeout: 120 * time.Second,
}
func New(cfg *config.Config, router *gin.Engine) *Server {
return &Server{
server: httpServer,
router: router,
basePath: cfg.Server.BasePath,
port: cfg.Server.Port,
}
}
// Start 启动HTTP服务器
func (s *Server) Start() error {
fmt.Printf("服务启动在 %s\n", s.server.Addr)
return s.server.ListenAndServe()
addr := fmt.Sprintf(":%d", s.port)
return s.router.Run(addr)
}
// Shutdown 优雅关闭服务器
func (s *Server) Shutdown(ctx context.Context) error {
fmt.Println("正在关闭HTTP服务器...")
return s.server.Shutdown(ctx)
// Gin 本身没有提供 Shutdown 方法,需要手动实现
// 这里可以添加自定义的关闭逻辑
return nil
}