From 9100930d3a1fe981cf92d57a9e96083492c97c79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E9=94=A6=E5=BC=BA?= <1061669148@qq.com> Date: Sun, 16 Mar 2025 20:24:04 +0800 Subject: [PATCH] feat: enhance TTS request handling by refactoring methods, adding OpenAI request support, and improving text segmentation --- go.mod | 3 +- go.sum | 13 +- internal/http/handlers/tts.go | 612 +++++++++++++++------------------- internal/models/tts.go | 8 + internal/utils/utils.go | 60 ++++ 5 files changed, 348 insertions(+), 348 deletions(-) diff --git a/go.mod b/go.mod index e774043..d915753 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.0 toolchain go1.24.0 require ( + github.com/gin-gonic/gin v1.10.0 github.com/google/uuid v1.6.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/viper v1.19.0 @@ -14,11 +15,9 @@ require ( github.com/bytedance/sonic v1.13.1 // indirect github.com/bytedance/sonic/loader v0.2.4 // indirect github.com/cloudwego/base64x v0.1.5 // indirect - github.com/cloudwego/iasm v0.2.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/gin-contrib/sse v1.0.0 // indirect - github.com/gin-gonic/gin v1.10.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.25.0 // indirect diff --git a/go.sum b/go.sum index 91c2d96..31d3ee7 100644 --- a/go.sum +++ b/go.sum @@ -5,7 +5,6 @@ github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCN github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -21,6 +20,8 @@ github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E github.com/gin-contrib/sse v1.0.0/go.mod h1:zNuFdwarAygJBht0NTKiSi3jRf6RbqeILZ9Sp6Slhe0= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= @@ -59,8 +60,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= -github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -94,8 +93,7 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= @@ -117,12 +115,8 @@ golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= @@ -136,4 +130,3 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= -rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/http/handlers/tts.go b/internal/http/handlers/tts.go index 6a095a9..d7ed79c 100644 --- a/internal/http/handlers/tts.go +++ b/internal/http/handlers/tts.go @@ -13,6 +13,7 @@ import ( "tts/internal/config" "tts/internal/models" "tts/internal/tts" + "tts/internal/utils" "unicode/utf8" "github.com/gin-gonic/gin" @@ -106,151 +107,8 @@ func NewTTSHandler(service tts.Service, cfg *config.Config) *TTSHandler { } } -// HandleOpenAITTS 处理OpenAI兼容的TTS请求 -func (h *TTSHandler) HandleOpenAITTS(c *gin.Context) { - // 记录请求开始时间 - startTime := time.Now() - - // 只支持POST请求 - if c.Request.Method != http.MethodPost { - c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持POST请求"}) - return - } - - // 解析请求 - var openaiReq struct { - Model string `json:"model"` - Input string `json:"input"` - Voice string `json:"voice"` - Speed float64 `json:"speed"` - } - - if err := c.ShouldBindJSON(&openaiReq); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求: " + err.Error()}) - return - } - - // 记录解析时间 - parseTime := time.Since(startTime) - - // 检查必需字段 - if openaiReq.Input == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "input字段不能为空"}) - return - } - - // 映射OpenAI声音到Microsoft声音 - msVoice := openaiReq.Voice - if openaiReq.Voice != "" && h.config.TTS.VoiceMapping[openaiReq.Voice] != "" { - msVoice = h.config.TTS.VoiceMapping[openaiReq.Voice] - } - - // 转换速度参数到微软格式 - msRate := h.config.TTS.DefaultRate - if openaiReq.Speed != 0 { - // OpenAI速度转换为微软速度格式 - // OpenAI: 0.5(慢速), 1.0(正常), 2.0(快速) - // 微软: "-50%"(慢), "+0%"(中), "+100%"(快) - speedPercentage := (openaiReq.Speed - 1.0) * 100 - if speedPercentage >= 0 { - msRate = fmt.Sprintf("+%.0f", speedPercentage) - } else { - msRate = fmt.Sprintf("%.0f", speedPercentage) - } - } - - // 创建内部TTS请求 - req := models.TTSRequest{ - Text: openaiReq.Input, - Voice: msVoice, - Rate: msRate, - Pitch: h.config.TTS.DefaultPitch, - Style: openaiReq.Model, - } - - log.Printf("OpenAI TTS请求: model=%s, voice=%s → %s, speed=%.2f → %s, 文本长度=%d", - openaiReq.Model, openaiReq.Voice, msVoice, openaiReq.Speed, msRate, len(req.Text)) - - // 检查文本长度 - if len(req.Text) > h.config.TTS.MaxTextLength { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "文本长度超过限制"}) - return - } - - // 检查是否需要分段处理 - segmentThreshold := h.config.TTS.SegmentThreshold - if len(req.Text) > segmentThreshold && len(req.Text) <= h.config.TTS.MaxTextLength { - log.Printf("文本长度 %d 超过阈值 %d,使用分段处理", len(req.Text), segmentThreshold) - // 使用分段处理 - h.handleSegmentedTTS(c, req) - return - } - - // 非流式模式处理 - synthStart := time.Now() - resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), req) - synthTime := time.Since(synthStart) - log.Printf("TTS合成耗时: %v, 文本长度: %d", synthTime, len(req.Text)) - - if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()}) - return - } - - // 设置响应 - c.Header("Content-Type", "audio/mpeg") - writeStart := time.Now() - c.Writer.Write(resp.AudioContent) - writeTime := time.Since(writeStart) - - // 记录总耗时 - totalTime := time.Since(startTime) - log.Printf("OpenAI TTS请求总耗时: %v (解析: %v, 合成: %v, 写入: %v), 音频大小: %s", - totalTime, parseTime, synthTime, writeTime, formatFileSize(len(resp.AudioContent))) -} - -// HandleTTS 处理TTS请求 -func (h *TTSHandler) HandleTTS(c *gin.Context) { - // 记录请求开始时间 - startTime := time.Now() - - // 解析请求参数 - var req models.TTSRequest - - switch c.Request.Method { - case http.MethodGet: - // 从URL参数获取 - req = models.TTSRequest{ - Text: c.Query("t"), - Voice: c.Query("v"), - Rate: c.Query("r"), - Pitch: c.Query("p"), - Style: c.Query("s"), - } - case http.MethodPost: - // 从POST JSON体获取 - if c.ContentType() == "application/json" { - if err := c.ShouldBindJSON(&req); err != nil { - log.Printf("JSON解析错误: %v", err) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求"}) - return - } - } else { - // 表单数据 - if err := c.ShouldBind(&req); err != nil { - log.Printf("表单解析错误: %v", err) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无法解析表单数据"}) - return - } - } - default: - log.Printf("不支持的HTTP方法: %s", c.Request.Method) - c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持GET和POST请求"}) - return - } - - parseTime := time.Since(startTime) - +// processTTSRequest 处理TTS请求的核心逻辑 +func (h *TTSHandler) processTTSRequest(c *gin.Context, req models.TTSRequest, startTime time.Time, parseTime time.Duration, requestType string) { // 验证必要参数 if req.Text == "" { log.Print("错误: 未提供文本参数") @@ -259,6 +117,51 @@ func (h *TTSHandler) HandleTTS(c *gin.Context) { } // 使用默认值填充空白参数 + h.fillDefaultValues(&req) + + // 检查文本长度 + reqTextLength := utf8.RuneCountInString(req.Text) + if reqTextLength > h.config.TTS.MaxTextLength { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "文本长度超过限制"}) + return + } + + // 检查是否需要分段处理 + segmentThreshold := h.config.TTS.SegmentThreshold + if reqTextLength > segmentThreshold && reqTextLength <= h.config.TTS.MaxTextLength { + log.Printf("文本长度 %d 超过阈值 %d,使用分段处理", reqTextLength, segmentThreshold) + h.handleSegmentedTTS(c, req) + return + } + + synthStart := time.Now() + resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), req) + synthTime := time.Since(synthStart) + log.Printf("TTS合成耗时: %v, 文本长度: %d", synthTime, reqTextLength) + + if err != nil { + log.Printf("TTS合成失败: %v", err) + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()}) + return + } + + // 设置响应 + c.Header("Content-Type", "audio/mpeg") + writeStart := time.Now() + if _, err := c.Writer.Write(resp.AudioContent); err != nil { + log.Printf("写入响应失败: %v", err) + return + } + writeTime := time.Since(writeStart) + + // 记录总耗时 + totalTime := time.Since(startTime) + log.Printf("%s请求总耗时: %v (解析: %v, 合成: %v, 写入: %v), 音频大小: %s", + requestType, totalTime, parseTime, synthTime, writeTime, formatFileSize(len(resp.AudioContent))) +} + +// fillDefaultValues 填充默认值 +func (h *TTSHandler) fillDefaultValues(req *models.TTSRequest) { if req.Voice == "" { req.Voice = h.config.TTS.DefaultVoice } @@ -268,81 +171,182 @@ func (h *TTSHandler) HandleTTS(c *gin.Context) { if req.Pitch == "" { req.Pitch = h.config.TTS.DefaultPitch } - - // 检查文本长度 - if len(req.Text) > h.config.TTS.MaxTextLength { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "文本长度超过限制"}) - return - } - - // 检查是否需要分段处理 - segmentThreshold := h.config.TTS.SegmentThreshold - if len(req.Text) > segmentThreshold && len(req.Text) <= h.config.TTS.MaxTextLength { - log.Printf("文本长度 %d 超过阈值 %d,使用分段处理", len(req.Text), segmentThreshold) - // 如果文本长度超过阈值但小于最大限制,使用分段处理 - h.handleSegmentedTTS(c, req) - return - } - - synthStart := time.Now() - resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), req) - synthTime := time.Since(synthStart) - log.Printf("TTS合成耗时: %v, 文本长度: %d", synthTime, len(req.Text)) - - if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()}) - return - } - - // 设置响应 - c.Header("Content-Type", "audio/mpeg") - writeStart := time.Now() - c.Writer.Write(resp.AudioContent) - writeTime := time.Since(writeStart) - - // 记录总耗时 - totalTime := time.Since(startTime) - log.Printf("TTS请求总耗时: %v (解析: %v, 合成: %v, 写入: %v), 音频大小: %s", - totalTime, parseTime, synthTime, writeTime, formatFileSize(len(resp.AudioContent))) } -// handleSegmentedTTS 处理长文本的分段TTS请求 +// HandleTTS 处理TTS请求 +func (h *TTSHandler) HandleTTS(c *gin.Context) { + switch c.Request.Method { + case http.MethodGet: + h.HandleTTSGet(c) + case http.MethodPost: + h.HandleTTSPost(c) + default: + c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持GET和POST请求"}) + } +} + +// HandleTTSGet 处理GET方式的TTS请求 +func (h *TTSHandler) HandleTTSGet(c *gin.Context) { + startTime := time.Now() + + // 从URL参数获取 + req := models.TTSRequest{ + Text: c.Query("t"), + Voice: c.Query("v"), + Rate: c.Query("r"), + Pitch: c.Query("p"), + Style: c.Query("s"), + } + + parseTime := time.Since(startTime) + h.processTTSRequest(c, req, startTime, parseTime, "TTS GET") +} + +// HandleTTSPost 处理POST方式的TTS请求 +func (h *TTSHandler) HandleTTSPost(c *gin.Context) { + startTime := time.Now() + + // 从POST JSON体或表单数据获取 + var req models.TTSRequest + var err error + + if c.ContentType() == "application/json" { + err = c.ShouldBindJSON(&req) + if err != nil { + log.Printf("JSON解析错误: %v", err) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求"}) + return + } + } else { + err = c.ShouldBind(&req) + if err != nil { + log.Printf("表单解析错误: %v", err) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无法解析表单数据"}) + return + } + } + + parseTime := time.Since(startTime) + h.processTTSRequest(c, req, startTime, parseTime, "TTS POST") +} + +// HandleOpenAITTS 处理OpenAI兼容的TTS请求 +func (h *TTSHandler) HandleOpenAITTS(c *gin.Context) { + startTime := time.Now() + + // 只支持POST请求 + if c.Request.Method != http.MethodPost { + c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持POST请求"}) + return + } + + // 解析请求 + var openaiReq models.OpenAIRequest + if err := c.ShouldBindJSON(&openaiReq); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求: " + err.Error()}) + return + } + + parseTime := time.Since(startTime) + + // 检查必需字段 + if openaiReq.Input == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "input字段不能为空"}) + return + } + + // 创建内部TTS请求 + req := h.convertOpenAIRequest(openaiReq) + + log.Printf("OpenAI TTS请求: model=%s, voice=%s → %s, speed=%.2f → %s, 文本长度=%d", + openaiReq.Model, openaiReq.Voice, req.Voice, openaiReq.Speed, req.Rate, utf8.RuneCountInString(req.Text)) + + h.processTTSRequest(c, req, startTime, parseTime, "OpenAI TTS") +} + +// convertOpenAIRequest 将OpenAI请求转换为内部请求格式 +func (h *TTSHandler) convertOpenAIRequest(openaiReq models.OpenAIRequest) models.TTSRequest { + // 映射OpenAI声音到Microsoft声音 + msVoice := openaiReq.Voice + if openaiReq.Voice != "" && h.config.TTS.VoiceMapping[openaiReq.Voice] != "" { + msVoice = h.config.TTS.VoiceMapping[openaiReq.Voice] + } + + // 转换速度参数到微软格式 + msRate := h.config.TTS.DefaultRate + if openaiReq.Speed != 0 { + speedPercentage := (openaiReq.Speed - 1.0) * 100 + if speedPercentage >= 0 { + msRate = fmt.Sprintf("+%.0f", speedPercentage) + } else { + msRate = fmt.Sprintf("%.0f", speedPercentage) + } + } + + return models.TTSRequest{ + Text: openaiReq.Input, + Voice: msVoice, + Rate: msRate, + Pitch: h.config.TTS.DefaultPitch, + Style: openaiReq.Model, + } +} + +// Add this struct to store synthesis results +type sentenceSynthesisResult struct { + index int + length int + audioSize int + content string + duration time.Duration +} + +// Modify the handleSegmentedTTS function to collect and display results in a table func (h *TTSHandler) handleSegmentedTTS(c *gin.Context, req models.TTSRequest) { - segmentStart := time.Now() // 分段处理开始时间 + segmentStart := time.Now() text := req.Text // 开始计时:分割文本 splitStart := time.Now() - // 按句子分段处理 sentences := splitTextBySentences(text) segmentCount := len(sentences) splitTime := time.Since(splitStart) log.Printf("分割文本耗时: %v, 文本总长度: %d, 分段数: %d, 平均句子长度: %.2f", - splitTime, len(text), segmentCount, float64(len(text))/float64(segmentCount)) + splitTime, utf8.RuneCountInString(text), segmentCount, float64(utf8.RuneCountInString(text))/float64(segmentCount)) // 创建用于存储每段音频的切片 results := make([][]byte, segmentCount) - errChan := make(chan error, segmentCount) - var wg sync.WaitGroup + // 创建用于收集合成结果信息的切片 + synthResults := make([]sentenceSynthesisResult, segmentCount) - // 限制并发数量,避免创建过多goroutine + errChan := make(chan error, 1) + var wg sync.WaitGroup + var synthMutex sync.Mutex + + // 限制并发数量 maxConcurrent := h.config.TTS.MaxConcurrent semaphore := make(chan struct{}, maxConcurrent) - // 用于记录每个分段处理的时间 - segmentTimes := make([]time.Duration, segmentCount) - // 合成阶段开始时间 synthesisStart := time.Now() // 并发处理每一个句子 for i := 0; i < segmentCount; i++ { wg.Add(1) - semaphore <- struct{}{} // 获取信号量 go func(index int) { defer wg.Done() - defer func() { <-semaphore }() // 释放信号量 + + select { + case semaphore <- struct{}{}: // 获取信号量 + defer func() { <-semaphore }() // 释放信号量 + case <-c.Request.Context().Done(): + select { + case errChan <- c.Request.Context().Err(): + default: + } + return + } // 创建该句的请求 segReq := models.TTSRequest{ @@ -350,193 +354,129 @@ func (h *TTSHandler) handleSegmentedTTS(c *gin.Context, req models.TTSRequest) { Voice: req.Voice, Rate: req.Rate, Pitch: req.Pitch, + Style: req.Style, } - log.Printf("开始处理句子 #%d: 长度=%d, 内容='%s'", - index+1, - utf8.RuneCountInString(sentences[index]), - truncateForLog(sentences[index], 20)) - - // 记录该段合成开始时间 - segStart := time.Now() - + startTime := time.Now() // 合成该段音频 resp, err := h.ttsService.SynthesizeSpeech(c.Request.Context(), segReq) - - // 记录该段合成耗时 - segTime := time.Since(segStart) - segmentTimes[index] = segTime + synthDuration := time.Since(startTime) if err != nil { - log.Printf("句子 #%d 合成失败,耗时: %v, 错误: %v", index+1, segTime, err) select { case errChan <- fmt.Errorf("句子 %d 合成失败: %w", index+1, err): default: - // 已经有错误了,忽略 } return } - log.Printf("句子 #%d 合成成功:长度=%d, 耗时=%v, 音频大小=%s", - index+1, utf8.RuneCountInString(sentences[index]), segTime, formatFileSize(len(resp.AudioContent))) + // 收集合成结果信息,而不是立即打印 + result := sentenceSynthesisResult{ + index: index, + length: utf8.RuneCountInString(sentences[index]), + audioSize: len(resp.AudioContent), + content: truncateForLog(sentences[index], 20), + duration: synthDuration, + } - // 存储该段结果 + synthMutex.Lock() + synthResults[index] = result results[index] = resp.AudioContent + synthMutex.Unlock() }(i) } - // 等待所有goroutine完成 - wg.Wait() - close(errChan) + // 等待所有goroutine完成或出错 + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() - // 记录所有分段合成总耗时 + select { + case <-done: + // 所有goroutine正常完成 + case err := <-errChan: + // 发生错误 + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + case <-c.Request.Context().Done(): + // 请求被取消 + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "请求被取消"}) + return + } + + // 打印表格格式的合成结果 + log.Println("句子合成结果表:") + log.Println("-------------------------------------------------------------") + log.Println("序号 | 长度 | 音频大小 | 耗时 | 内容") + log.Println("-------------------------------------------------------------") + for i := 0; i < segmentCount; i++ { + result := synthResults[i] + log.Printf("#%-3d | %4d | %12s | %10v | %s", + i+1, + result.length, + formatFileSize(result.audioSize), + result.duration.Round(time.Millisecond), + result.content) + } + log.Println("-------------------------------------------------------------") + + // 记录合成总耗时 synthesisTime := time.Since(synthesisStart) log.Printf("所有分段合成总耗时: %v, 平均每段耗时: %v", synthesisTime, synthesisTime/time.Duration(segmentCount)) - // 检查是否有错误发生 - if err := <-errChan; err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + err.Error()}) - return - } - - // 记录写入开始时间 + // 合并音频 writeStart := time.Now() - - var audioData []byte - var err error - - audioData, err = audioMerge(results) - + audioData, err := audioMerge(results) if err != nil { log.Printf("合并音频失败: %v", err) c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "音频合并失败: " + err.Error()}) return } - // 设置响应内容类型 + // 设置响应内容类型并写入数据 c.Header("Content-Type", "audio/mpeg") - - // 写入合并后的音频数据 - totalSize := len(audioData) - if _, writeErr := c.Writer.Write(audioData); writeErr != nil { - log.Printf("写入响应失败: %v", writeErr) + if _, err := c.Writer.Write(audioData); err != nil { + log.Printf("写入响应失败: %v", err) + return } - // 记录写入耗时 + // 记录写入耗时和总耗时 writeTime := time.Since(writeStart) - - // 记录总耗时 totalTime := time.Since(segmentStart) log.Printf("分段TTS请求总耗时: %v (分割: %v, 合成: %v, 写入: %v), 总音频大小: %s", - totalTime, splitTime, synthesisTime, writeTime, formatFileSize(totalSize)) + totalTime, splitTime, synthesisTime, writeTime, formatFileSize(len(audioData))) +} + +// sentenceEnders 定义句子结束的标点符号 +var sentenceEnders = map[rune]bool{ + '。': true, + '!': true, + '?': true, + '…': true, + '.': true, + '!': true, + '?': true, + '\n': true, } // splitTextBySentences 将文本按句子分割 func splitTextBySentences(text string) []string { - // 定义句子结束的标点符号 - sentenceEnders := []string{"。", "!", "?", "…", ".", "!", "?", "…", "\n"} - // 如果文本过短,直接作为一个句子返回 if utf8.RuneCountInString(text) < 100 { return []string{text} } - var sentences []string - var currentSentence strings.Builder - maxSentenceLength := config.Get().TTS.MaxSentenceLength // 设置单个句子的最大长度,避免过长句子 - runeCount := 0 // 当前句子的实际字符数量 + cfg := config.Get().TTS + maxLen := cfg.MaxSentenceLength + minLen := cfg.MinSentenceLength - for _, char := range text { - currentSentence.WriteRune(char) - runeCount++ - - // 检查是否到达句子结束标点 - lastChar := string(char) - isSentenceEnder := false - for _, ender := range sentenceEnders { - if lastChar == ender { - isSentenceEnder = true - break - } - } - - // 判断是否结束一个句子 - 使用字符数量而非字节长度 - if isSentenceEnder || runeCount >= maxSentenceLength { - // 添加当前句子到结果中 - sentence := currentSentence.String() - if len(sentence) > 0 { - sentences = append(sentences, sentence) - } - currentSentence.Reset() // 重置构建器 - runeCount = 0 // 重置字符计数器 - } - } - - // 处理可能的最后一个句子 - if currentSentence.Len() > 0 { - lastSentence := currentSentence.String() - sentences = append(sentences, lastSentence) - } - - // 合并过短的句子 - minSentenceLength := config.Get().TTS.MinSentenceLength // 设置最小句子长度阈值 - - if len(sentences) > 1 { - mergedSentences := []string{} - var currentMerged strings.Builder - currentMergedLength := 0 - - for i, sentence := range sentences { - sentenceLength := utf8.RuneCountInString(sentence) - - // 如果当前句子太短,且不是最后一个,考虑合并 - if sentenceLength < minSentenceLength && i < len(sentences)-1 { - // 检查合并后是否会超过最大长度 - if currentMergedLength+sentenceLength > maxSentenceLength { - // 合并后会超长,先保存当前内容 - if currentMerged.Len() > 0 { - mergedSentences = append(mergedSentences, currentMerged.String()) - currentMerged.Reset() - currentMergedLength = 0 - } - } - - // 当前句子过短,添加到合并缓冲区 - currentMerged.WriteString(sentence) - currentMergedLength += sentenceLength - } else { - // 句子足够长或是最后一句 - if currentMerged.Len() > 0 { - // 检查合并后是否会超过最大长度 - if currentMergedLength+sentenceLength <= maxSentenceLength { - // 有待合并的内容,将当前句子也合并进去 - currentMerged.WriteString(sentence) - mergedSentence := currentMerged.String() - mergedSentences = append(mergedSentences, mergedSentence) - } else { - // 合并后会超长,分别添加 - mergedSentence := currentMerged.String() - mergedSentences = append(mergedSentences, mergedSentence) - mergedSentences = append(mergedSentences, sentence) - } - currentMerged.Reset() - currentMergedLength = 0 - } else { - // 没有待合并内容,直接添加当前句子 - mergedSentences = append(mergedSentences, sentence) - } - } - } - - // 处理可能剩余的合并内容 - if currentMerged.Len() > 0 { - mergedSentence := currentMerged.String() - mergedSentences = append(mergedSentences, mergedSentence) - } - - return mergedSentences - } - - return sentences + // 第一次分割:按标点和长度限制分割 + sentences := utils.SplitAndFilterEmptyLines(text) + // 第二次处理:合并过短的句子 + shortSentences := utils.MergeStringsWithLimit(sentences, minLen, maxLen) + log.Printf("分割后的句子数: %d → %d", len(sentences), len(shortSentences)) + return shortSentences } diff --git a/internal/models/tts.go b/internal/models/tts.go index 622e897..7b42496 100644 --- a/internal/models/tts.go +++ b/internal/models/tts.go @@ -15,3 +15,11 @@ type TTSResponse struct { ContentType string `json:"content_type"` // MIME类型 CacheHit bool `json:"cache_hit"` // 是否命中缓存 } + +// OpenAIRequest OpenAI TTS请求结构体 +type OpenAIRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Voice string `json:"voice"` + Speed float64 `json:"speed"` +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 30b7e64..7979612 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -12,6 +12,7 @@ import ( "net/url" "strings" "time" + "unicode/utf8" "github.com/google/uuid" "github.com/sirupsen/logrus" @@ -97,3 +98,62 @@ func Sign(urlStr string) string { signBase64 := base64.StdEncoding.EncodeToString(secretKey) return fmt.Sprintf("MSTranslatorAndroidApp::%s::%s::%s", signBase64, formattedDate, uuidStr) } + +// SplitAndFilterEmptyLines 拆分文本并过滤掉空行 +func SplitAndFilterEmptyLines(text string) []string { + // 按换行符拆分 + lines := strings.Split(text, "\n") + var result []string + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} + +// MergeStringsWithLimit 会将字符串切片依次累加,直到总长度 ≥ minLen。 +// 但如果再合并下一段后会超过 maxLen,则提前结束本段合并,放入结果。 +// 然后继续新的一段合并。 +func MergeStringsWithLimit(strs []string, minLen int, maxLen int) []string { + var result []string + + for i := 0; i < len(strs); { + // 如果已经没有更多段落,直接退出 + if i >= len(strs) { + break + } + + // 从当前段开始合并 + currentBuilder := strings.Builder{} + currentBuilder.WriteString(strs[i]) + i++ + + for i < len(strs) { + currentLen := utf8.RuneCountInString(currentBuilder.String()) + // 如果当前已达(或超过) minLen,先行结束本段合并 + if currentLen >= minLen { + break + } + + // 检查添加下一个段落后是否会超过 1.2 × minLen + nextLen := utf8.RuneCountInString(strs[i]) + if currentLen+nextLen > int(float64(minLen)*1.2) { + // 加上下一个会超标,则结束合并 + break + } + + // 如果未超标,则继续合并这个段 + currentBuilder.WriteString("\n") + currentBuilder.WriteString(strs[i]) + i++ + } + + // 本段合并结束,加入结果 + result = append(result, currentBuilder.String()) + } + + return result +}