feat: refactor application to use Gin framework, update routing and middleware handling
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -15,8 +14,84 @@ import (
|
||||
"tts/internal/models"
|
||||
"tts/internal/tts"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// truncateForLog 截断文本用于日志显示,同时显示开头和结尾
|
||||
func truncateForLog(text string, maxLength int) string {
|
||||
// 先去除换行符
|
||||
text = strings.ReplaceAll(text, "\n", " ")
|
||||
text = strings.ReplaceAll(text, "\r", " ")
|
||||
|
||||
runes := []rune(text)
|
||||
if len(runes) <= maxLength {
|
||||
return text
|
||||
}
|
||||
// 计算开头和结尾各显示多少字符
|
||||
halfLength := maxLength / 2
|
||||
return string(runes[:halfLength]) + "..." + string(runes[len(runes)-halfLength:])
|
||||
}
|
||||
|
||||
// audioMerge 音频合并
|
||||
func audioMerge(audioSegments [][]byte) ([]byte, error) {
|
||||
if len(audioSegments) == 0 {
|
||||
return nil, fmt.Errorf("没有音频片段可合并")
|
||||
}
|
||||
|
||||
// 使用 ffmpeg 合并音频
|
||||
tempDir, err := os.MkdirTemp("", "audio_merge_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
listFile := filepath.Join(tempDir, "concat.txt")
|
||||
lf, err := os.Create(listFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, seg := range audioSegments {
|
||||
segFile := filepath.Join(tempDir, fmt.Sprintf("seg_%d.mp3", i))
|
||||
if err := os.WriteFile(segFile, seg, 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := lf.WriteString(fmt.Sprintf("file '%s'\n", segFile)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
lf.Close()
|
||||
|
||||
outputFile := filepath.Join(tempDir, "output.mp3")
|
||||
|
||||
cmd := exec.Command("ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", listFile, "-c", "copy", outputFile)
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mergedData, err := os.ReadFile(outputFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("使用ffmpeg合并完成,总大小: %s", formatFileSize(len(mergedData)))
|
||||
return mergedData, nil
|
||||
}
|
||||
|
||||
// formatFileSize 格式化文件大小
|
||||
func formatFileSize(size int) string {
|
||||
switch {
|
||||
case size < 1024:
|
||||
return fmt.Sprintf("%d B", size)
|
||||
case size < 1024*1024:
|
||||
return fmt.Sprintf("%.2f KB", float64(size)/1024.0)
|
||||
case size < 1024*1024*1024:
|
||||
return fmt.Sprintf("%.2f MB", float64(size)/(1024.0*1024.0))
|
||||
default:
|
||||
return fmt.Sprintf("%.2f GB", float64(size)/(1024.0*1024.0*1024.0))
|
||||
}
|
||||
}
|
||||
|
||||
// TTSHandler 处理TTS请求
|
||||
type TTSHandler struct {
|
||||
ttsService tts.Service
|
||||
@@ -32,13 +107,13 @@ func NewTTSHandler(service tts.Service, cfg *config.Config) *TTSHandler {
|
||||
}
|
||||
|
||||
// HandleOpenAITTS 处理OpenAI兼容的TTS请求
|
||||
func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *TTSHandler) HandleOpenAITTS(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 只支持POST请求
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "仅支持POST请求", http.StatusMethodNotAllowed)
|
||||
if c.Request.Method != http.MethodPost {
|
||||
c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持POST请求"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -50,8 +125,8 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
|
||||
Speed float64 `json:"speed"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&openaiReq); err != nil {
|
||||
http.Error(w, "无效的JSON请求: "+err.Error(), http.StatusBadRequest)
|
||||
if err := c.ShouldBindJSON(&openaiReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -60,7 +135,7 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// 检查必需字段
|
||||
if openaiReq.Input == "" {
|
||||
http.Error(w, "input字段不能为空", http.StatusBadRequest)
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "input字段不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -98,7 +173,7 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// 检查文本长度
|
||||
if len(req.Text) > h.config.TTS.MaxTextLength {
|
||||
http.Error(w, "文本长度超过限制", http.StatusBadRequest)
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "文本长度超过限制"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -107,25 +182,25 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
|
||||
if len(req.Text) > segmentThreshold && len(req.Text) <= h.config.TTS.MaxTextLength {
|
||||
log.Printf("文本长度 %d 超过阈值 %d,使用分段处理", len(req.Text), segmentThreshold)
|
||||
// 使用分段处理
|
||||
h.handleSegmentedTTS(w, r, req)
|
||||
h.handleSegmentedTTS(c, req)
|
||||
return
|
||||
}
|
||||
|
||||
// 非流式模式处理
|
||||
synthStart := time.Now()
|
||||
resp, err := h.ttsService.SynthesizeSpeech(r.Context(), req)
|
||||
resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), req)
|
||||
synthTime := time.Since(synthStart)
|
||||
log.Printf("TTS合成耗时: %v, 文本长度: %d", synthTime, len(req.Text))
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "语音合成失败: "+err.Error(), http.StatusInternalServerError)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应
|
||||
w.Header().Set("Content-Type", "audio/mpeg")
|
||||
c.Header("Content-Type", "audio/mpeg")
|
||||
writeStart := time.Now()
|
||||
w.Write(resp.AudioContent)
|
||||
c.Writer.Write(resp.AudioContent)
|
||||
writeTime := time.Since(writeStart)
|
||||
|
||||
// 记录总耗时
|
||||
@@ -135,61 +210,51 @@ func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// HandleTTS 处理TTS请求
|
||||
func (h *TTSHandler) HandleTTS(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *TTSHandler) HandleTTS(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 解析请求参数
|
||||
var req models.TTSRequest
|
||||
|
||||
switch r.Method {
|
||||
switch c.Request.Method {
|
||||
case http.MethodGet:
|
||||
// 从URL参数获取
|
||||
q := r.URL.Query()
|
||||
req = models.TTSRequest{
|
||||
Text: q.Get("t"),
|
||||
Voice: q.Get("v"),
|
||||
Rate: q.Get("r"),
|
||||
Pitch: q.Get("p"),
|
||||
Style: q.Get("s"),
|
||||
Text: c.Query("t"),
|
||||
Voice: c.Query("v"),
|
||||
Rate: c.Query("r"),
|
||||
Pitch: c.Query("p"),
|
||||
Style: c.Query("s"),
|
||||
}
|
||||
case http.MethodPost:
|
||||
// 从POST JSON体获取
|
||||
if r.Header.Get("Content-Type") == "application/json" {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if c.ContentType() == "application/json" {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
log.Printf("JSON解析错误: %v", err)
|
||||
http.Error(w, "无效的JSON请求", http.StatusBadRequest)
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 表单数据
|
||||
if err := r.ParseForm(); err != nil {
|
||||
if err := c.ShouldBind(&req); err != nil {
|
||||
log.Printf("表单解析错误: %v", err)
|
||||
http.Error(w, "无法解析表单数据", http.StatusBadRequest)
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无法解析表单数据"})
|
||||
return
|
||||
}
|
||||
req = models.TTSRequest{
|
||||
Text: r.FormValue("text"),
|
||||
Voice: r.FormValue("voice"),
|
||||
Rate: r.FormValue("rate"),
|
||||
Pitch: r.FormValue("pitch"),
|
||||
Style: r.FormValue("style"),
|
||||
}
|
||||
}
|
||||
default:
|
||||
log.Printf("不支持的HTTP方法: %s", r.Method)
|
||||
http.Error(w, "仅支持GET和POST请求", http.StatusMethodNotAllowed)
|
||||
log.Printf("不支持的HTTP方法: %s", c.Request.Method)
|
||||
c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持GET和POST请求"})
|
||||
return
|
||||
}
|
||||
|
||||
// 记录参数解析耗时
|
||||
parseTime := time.Since(startTime)
|
||||
log.Printf("请求参数解析耗时: %v", parseTime)
|
||||
|
||||
// 验证必要参数
|
||||
if req.Text == "" {
|
||||
log.Print("错误: 未提供文本参数")
|
||||
http.Error(w, "必须提供文本参数", http.StatusBadRequest)
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "必须提供文本参数"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -206,7 +271,7 @@ func (h *TTSHandler) HandleTTS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// 检查文本长度
|
||||
if len(req.Text) > h.config.TTS.MaxTextLength {
|
||||
http.Error(w, "文本长度超过限制", http.StatusBadRequest)
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "文本长度超过限制"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -215,24 +280,24 @@ func (h *TTSHandler) HandleTTS(w http.ResponseWriter, r *http.Request) {
|
||||
if len(req.Text) > segmentThreshold && len(req.Text) <= h.config.TTS.MaxTextLength {
|
||||
log.Printf("文本长度 %d 超过阈值 %d,使用分段处理", len(req.Text), segmentThreshold)
|
||||
// 如果文本长度超过阈值但小于最大限制,使用分段处理
|
||||
h.handleSegmentedTTS(w, r, req)
|
||||
h.handleSegmentedTTS(c, req)
|
||||
return
|
||||
}
|
||||
|
||||
synthStart := time.Now()
|
||||
resp, err := h.ttsService.SynthesizeSpeech(r.Context(), req)
|
||||
resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), req)
|
||||
synthTime := time.Since(synthStart)
|
||||
log.Printf("TTS合成耗时: %v, 文本长度: %d", synthTime, len(req.Text))
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "语音合成失败: "+err.Error(), http.StatusInternalServerError)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应
|
||||
w.Header().Set("Content-Type", "audio/mpeg")
|
||||
c.Header("Content-Type", "audio/mpeg")
|
||||
writeStart := time.Now()
|
||||
w.Write(resp.AudioContent)
|
||||
c.Writer.Write(resp.AudioContent)
|
||||
writeTime := time.Since(writeStart)
|
||||
|
||||
// 记录总耗时
|
||||
@@ -242,7 +307,7 @@ func (h *TTSHandler) HandleTTS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// handleSegmentedTTS 处理长文本的分段TTS请求
|
||||
func (h *TTSHandler) handleSegmentedTTS(w http.ResponseWriter, r *http.Request, req models.TTSRequest) {
|
||||
func (h *TTSHandler) handleSegmentedTTS(c *gin.Context, req models.TTSRequest) {
|
||||
segmentStart := time.Now() // 分段处理开始时间
|
||||
text := req.Text
|
||||
|
||||
@@ -296,7 +361,7 @@ func (h *TTSHandler) handleSegmentedTTS(w http.ResponseWriter, r *http.Request,
|
||||
segStart := time.Now()
|
||||
|
||||
// 合成该段音频
|
||||
resp, err := h.ttsService.SynthesizeSpeech(r.Context(), segReq)
|
||||
resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), segReq)
|
||||
|
||||
// 记录该段合成耗时
|
||||
segTime := time.Since(segStart)
|
||||
@@ -331,7 +396,7 @@ func (h *TTSHandler) handleSegmentedTTS(w http.ResponseWriter, r *http.Request,
|
||||
|
||||
// 检查是否有错误发生
|
||||
if err := <-errChan; err != nil {
|
||||
http.Error(w, "语音合成失败: "+err.Error(), http.StatusInternalServerError)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -345,16 +410,16 @@ func (h *TTSHandler) handleSegmentedTTS(w http.ResponseWriter, r *http.Request,
|
||||
|
||||
if err != nil {
|
||||
log.Printf("合并音频失败: %v", err)
|
||||
http.Error(w, "音频合并失败: "+err.Error(), http.StatusInternalServerError)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "音频合并失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应内容类型
|
||||
w.Header().Set("Content-Type", "audio/mpeg")
|
||||
c.Header("Content-Type", "audio/mpeg")
|
||||
|
||||
// 写入合并后的音频数据
|
||||
totalSize := len(audioData)
|
||||
if _, writeErr := w.Write(audioData); writeErr != nil {
|
||||
if _, writeErr := c.Writer.Write(audioData); writeErr != nil {
|
||||
log.Printf("写入响应失败: %v", writeErr)
|
||||
}
|
||||
|
||||
@@ -468,7 +533,6 @@ func splitTextBySentences(text string) []string {
|
||||
if currentMerged.Len() > 0 {
|
||||
mergedSentence := currentMerged.String()
|
||||
mergedSentences = append(mergedSentences, mergedSentence)
|
||||
log.Printf("添加最后剩余的合并句子,长度=%d", utf8.RuneCountInString(mergedSentence))
|
||||
}
|
||||
|
||||
return mergedSentences
|
||||
@@ -476,77 +540,3 @@ func splitTextBySentences(text string) []string {
|
||||
|
||||
return sentences
|
||||
}
|
||||
|
||||
// truncateForLog 截断文本用于日志显示,同时显示开头和结尾
|
||||
func truncateForLog(text string, maxLength int) string {
|
||||
// 先去除换行符
|
||||
text = strings.ReplaceAll(text, "\n", " ")
|
||||
text = strings.ReplaceAll(text, "\r", " ")
|
||||
|
||||
runes := []rune(text)
|
||||
if len(runes) <= maxLength {
|
||||
return text
|
||||
}
|
||||
// 计算开头和结尾各显示多少字符
|
||||
halfLength := maxLength / 2
|
||||
return string(runes[:halfLength]) + "..." + string(runes[len(runes)-halfLength:])
|
||||
}
|
||||
|
||||
// audioMerge 音频合并
|
||||
func audioMerge(audioSegments [][]byte) ([]byte, error) {
|
||||
if len(audioSegments) == 0 {
|
||||
return nil, fmt.Errorf("没有音频片段可合并")
|
||||
}
|
||||
|
||||
// 使用 ffmpeg 合并音频
|
||||
tempDir, err := os.MkdirTemp("", "audio_merge_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
listFile := filepath.Join(tempDir, "concat.txt")
|
||||
lf, err := os.Create(listFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, seg := range audioSegments {
|
||||
segFile := filepath.Join(tempDir, fmt.Sprintf("seg_%d.mp3", i))
|
||||
if err := os.WriteFile(segFile, seg, 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := lf.WriteString(fmt.Sprintf("file '%s'\n", segFile)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
lf.Close()
|
||||
|
||||
outputFile := filepath.Join(tempDir, "output.mp3")
|
||||
|
||||
cmd := exec.Command("ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", listFile, "-c", "copy", outputFile)
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mergedData, err := os.ReadFile(outputFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("使用ffmpeg合并完成,总大小: %s", formatFileSize(len(mergedData)))
|
||||
return mergedData, nil
|
||||
}
|
||||
|
||||
// formatFileSize 格式化文件大小
|
||||
func formatFileSize(size int) string {
|
||||
switch {
|
||||
case size < 1024:
|
||||
return fmt.Sprintf("%d B", size)
|
||||
case size < 1024*1024:
|
||||
return fmt.Sprintf("%.2f KB", float64(size)/1024.0)
|
||||
case size < 1024*1024*1024:
|
||||
return fmt.Sprintf("%.2f MB", float64(size)/(1024.0*1024.0))
|
||||
default:
|
||||
return fmt.Sprintf("%.2f GB", float64(size)/(1024.0*1024.0*1024.0))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user