Files
Microsoft-tts/internal/http/handlers/tts.go

471 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handlers
import (
"fmt"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
"tts/internal/config"
"tts/internal/models"
"tts/internal/tts"
"tts/internal/utils"
"unicode/utf8"
"github.com/gin-gonic/gin"
)
// truncateForLog flattens text onto one line and shortens it for log output.
//
// Every "\n" and "\r" is replaced by a space. If the flattened text has more
// than maxLength runes, the first and the last maxLength/2 runes are kept and
// joined with "..."; otherwise the flattened text is returned as is. Lengths
// are counted in runes, not bytes, so multi-byte characters are never split.
func truncateForLog(text string, maxLength int) string {
	flat := strings.NewReplacer("\n", " ", "\r", " ").Replace(text)

	chars := []rune(flat)
	total := len(chars)
	if total <= maxLength {
		return flat
	}

	// Show the same number of runes from the start and from the end.
	keep := maxLength / 2
	head := string(chars[:keep])
	tail := string(chars[total-keep:])
	return head + "..." + tail
}
// audioMerge concatenates MP3 segments into a single MP3 by running ffmpeg's
// concat demuxer in stream-copy mode ("-c copy").
//
// Each segment is written to its own file inside a fresh temporary directory,
// a concat list naming those files in order is generated, and ffmpeg is run
// over that list. The temporary directory is removed before returning.
//
// An error is returned when audioSegments is empty, when a temporary file
// cannot be written, when ffmpeg cannot be started or exits with a non-zero
// status (the error then carries the tail of ffmpeg's output), or when the
// merged file cannot be read back. The ffmpeg binary is looked up via PATH.
func audioMerge(audioSegments [][]byte) ([]byte, error) {
	if len(audioSegments) == 0 {
		return nil, fmt.Errorf("没有音频片段可合并")
	}

	tempDir, err := os.MkdirTemp("", "audio_merge_")
	if err != nil {
		return nil, fmt.Errorf("创建临时目录失败: %w", err)
	}
	defer os.RemoveAll(tempDir)

	// The concat list is assembled in memory and written in one call, so no
	// file handle stays open across the error returns below.
	var list strings.Builder
	for i, seg := range audioSegments {
		segFile := filepath.Join(tempDir, fmt.Sprintf("seg_%d.mp3", i))
		if err := os.WriteFile(segFile, seg, 0644); err != nil {
			return nil, fmt.Errorf("写入音频片段 %d 失败: %w", i+1, err)
		}
		// Inside single quotes ffmpeg treats every character literally, so a
		// quote in the path has to be written as '\'' (close, escape, reopen).
		quoted := strings.ReplaceAll(segFile, "'", `'\''`)
		fmt.Fprintf(&list, "file '%s'\n", quoted)
	}

	listFile := filepath.Join(tempDir, "concat.txt")
	if err := os.WriteFile(listFile, []byte(list.String()), 0644); err != nil {
		return nil, fmt.Errorf("写入合并列表失败: %w", err)
	}

	outputFile := filepath.Join(tempDir, "output.mp3")
	cmd := exec.Command("ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", listFile, "-c", "copy", outputFile)
	if out, err := cmd.CombinedOutput(); err != nil {
		// Keep only the tail of the output: the actual failure is reported
		// last, after the version banner and the input description.
		const maxDiagnostic = 2048
		if len(out) > maxDiagnostic {
			out = out[len(out)-maxDiagnostic:]
		}
		return nil, fmt.Errorf("ffmpeg合并失败: %w: %s", err, strings.TrimSpace(string(out)))
	}

	mergedData, err := os.ReadFile(outputFile)
	if err != nil {
		return nil, fmt.Errorf("读取合并结果失败: %w", err)
	}
	log.Printf("使用ffmpeg合并完成总大小: %s", formatFileSize(len(mergedData)))
	return mergedData, nil
}
// formatFileSize renders a byte count as a human-readable string.
//
// Counts below 1024 are printed as whole bytes ("B"). Larger counts are
// divided by the matching power of 1024 and printed with two decimals using
// the suffix "KB", "MB" or "GB"; anything of 1024^3 bytes or more uses "GB".
func formatFileSize(size int) string {
	const (
		kib = 1024
		mib = 1024 * kib
		gib = 1024 * mib
	)

	if size < kib {
		return fmt.Sprintf("%d B", size)
	}
	if size < mib {
		return fmt.Sprintf("%.2f KB", float64(size)/kib)
	}
	if size < gib {
		return fmt.Sprintf("%.2f MB", float64(size)/mib)
	}
	return fmt.Sprintf("%.2f GB", float64(size)/gib)
}
// TTSHandler serves the text-to-speech HTTP endpoints. It bundles the
// synthesis service that requests are delegated to with the application
// configuration that supplies defaults and limits.
type TTSHandler struct {
// ttsService performs the speech synthesis (SynthesizeSpeech).
ttsService tts.Service
// config is the application configuration; the handlers read its TTS section.
config *config.Config
}
// NewTTSHandler builds a TTSHandler that delegates synthesis to service and
// reads its defaults and limits from cfg. Both arguments are stored as given;
// neither is copied or validated.
func NewTTSHandler(service tts.Service, cfg *config.Config) *TTSHandler {
	handler := new(TTSHandler)
	handler.ttsService = service
	handler.config = cfg
	return handler
}
// processTTSRequest validates a TTS request, synthesizes it and writes the
// audio to the response. It is the shared back end of the GET, POST and
// OpenAI-compatible entry points.
//
// Steps, in order:
//   - an empty Text is rejected with 400;
//   - empty Voice/Rate/Pitch are filled from configuration;
//   - a text longer than TTS.MaxTextLength runes is rejected with 400;
//   - a text longer than TTS.SegmentThreshold runes is handed to
//     handleSegmentedTTS, which produces the whole response itself;
//   - anything else is synthesized in one call and written as audio/mpeg.
//
// startTime and parseTime come from the caller and only feed the final timing
// log line; requestType is the label printed at the start of that line.
func (h *TTSHandler) processTTSRequest(c *gin.Context, req models.TTSRequest, startTime time.Time, parseTime time.Duration, requestType string) {
	if req.Text == "" {
		log.Print("错误: 未提供文本参数")
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "必须提供文本参数"})
		return
	}

	h.fillDefaultValues(&req)

	// Lengths are compared in runes so multi-byte text is not over-counted.
	textLen := utf8.RuneCountInString(req.Text)
	switch {
	case textLen > h.config.TTS.MaxTextLength:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "文本长度超过限制"})
		return
	case textLen > h.config.TTS.SegmentThreshold:
		// The case above already guarantees textLen <= MaxTextLength here.
		threshold := h.config.TTS.SegmentThreshold
		log.Printf("文本长度 %d 超过阈值 %d使用分段处理", textLen, threshold)
		h.handleSegmentedTTS(c, req)
		return
	}

	// Single-shot synthesis.
	synthBegin := time.Now()
	resp, synthErr := h.ttsService.SynthesizeSpeech(c.Request.Context(), req)
	synthElapsed := time.Since(synthBegin)
	log.Printf("TTS合成耗时: %v, 文本长度: %d", synthElapsed, textLen)
	if synthErr != nil {
		log.Printf("TTS合成失败: %v", synthErr)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "语音合成失败: " + synthErr.Error()})
		return
	}

	c.Header("Content-Type", "audio/mpeg")
	writeBegin := time.Now()
	_, writeErr := c.Writer.Write(resp.AudioContent)
	if writeErr != nil {
		// The status line may already be on the wire, so only log the failure.
		log.Printf("写入响应失败: %v", writeErr)
		return
	}
	writeElapsed := time.Since(writeBegin)

	log.Printf("%s请求总耗时: %v (解析: %v, 合成: %v, 写入: %v), 音频大小: %s",
		requestType, time.Since(startTime), parseTime, synthElapsed, writeElapsed, formatFileSize(len(resp.AudioContent)))
}
// fillDefaultValues replaces the empty Voice, Rate and Pitch fields of req
// with the configured TTS defaults (DefaultVoice, DefaultRate, DefaultPitch).
// Fields that already hold a value are left untouched, and no other field of
// req is modified.
func (h *TTSHandler) fillDefaultValues(req *models.TTSRequest) {
	if len(req.Voice) == 0 {
		req.Voice = h.config.TTS.DefaultVoice
	}

	if len(req.Rate) == 0 {
		req.Rate = h.config.TTS.DefaultRate
	}

	if len(req.Pitch) == 0 {
		req.Pitch = h.config.TTS.DefaultPitch
	}
}
// HandleTTS dispatches a TTS request by HTTP method: GET is served by
// HandleTTSGet, POST by HandleTTSPost, and any other method is rejected with
// 405 Method Not Allowed and a JSON error body.
func (h *TTSHandler) HandleTTS(c *gin.Context) {
	method := c.Request.Method

	if method == http.MethodGet {
		h.HandleTTSGet(c)
		return
	}
	if method == http.MethodPost {
		h.HandleTTSPost(c)
		return
	}

	c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持GET和POST请求"})
}
// HandleTTSGet serves a TTS request whose parameters arrive in the query
// string: t (text), v (voice), r (rate), p (pitch) and s (style). The values
// are passed on unvalidated to processTTSRequest together with the time spent
// reading them.
func (h *TTSHandler) HandleTTSGet(c *gin.Context) {
	begin := time.Now()

	var req models.TTSRequest
	req.Text = c.Query("t")
	req.Voice = c.Query("v")
	req.Rate = c.Query("r")
	req.Pitch = c.Query("p")
	req.Style = c.Query("s")

	parseElapsed := time.Since(begin)
	h.processTTSRequest(c, req, begin, parseElapsed, "TTS GET")
}
// HandleTTSPost serves a TTS request whose parameters arrive in the POST
// body. A body with content type application/json is bound with
// ShouldBindJSON; any other content type is handed to gin's ShouldBind
// (intended for form data, per the original comment). A binding failure is
// logged and answered with 400; otherwise the request goes to
// processTTSRequest together with the time spent parsing it.
func (h *TTSHandler) HandleTTSPost(c *gin.Context) {
	begin := time.Now()

	var req models.TTSRequest
	switch c.ContentType() {
	case "application/json":
		if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
			log.Printf("JSON解析错误: %v", bindErr)
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求"})
			return
		}
	default:
		if bindErr := c.ShouldBind(&req); bindErr != nil {
			log.Printf("表单解析错误: %v", bindErr)
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无法解析表单数据"})
			return
		}
	}

	parseElapsed := time.Since(begin)
	h.processTTSRequest(c, req, begin, parseElapsed, "TTS POST")
}
// HandleOpenAITTS serves the OpenAI-compatible TTS endpoint.
//
// Only POST is accepted (405 otherwise). The JSON body is bound into a
// models.OpenAIRequest; a binding error or an empty input field is answered
// with 400. A valid request is translated with convertOpenAIRequest, the
// mapping is logged, and the result is processed by processTTSRequest.
func (h *TTSHandler) HandleOpenAITTS(c *gin.Context) {
	begin := time.Now()

	if method := c.Request.Method; method != http.MethodPost {
		c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "仅支持POST请求"})
		return
	}

	var payload models.OpenAIRequest
	bindErr := c.ShouldBindJSON(&payload)
	if bindErr != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "无效的JSON请求: " + bindErr.Error()})
		return
	}
	// Parsing time is taken before the field check, as the timing log expects.
	parseElapsed := time.Since(begin)

	if len(payload.Input) == 0 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "input字段不能为空"})
		return
	}

	req := h.convertOpenAIRequest(payload)
	log.Printf("OpenAI TTS请求: model=%s, voice=%s → %s, speed=%.2f → %s, 文本长度=%d",
		payload.Model, payload.Voice, req.Voice, payload.Speed, req.Rate, utf8.RuneCountInString(req.Text))

	h.processTTSRequest(c, req, begin, parseElapsed, "OpenAI TTS")
}
// convertOpenAIRequest translates an OpenAI-style request into the internal
// TTSRequest.
//
//   - Text is the OpenAI input.
//   - Voice is looked up in TTS.VoiceMapping; when the voice is empty or has
//     no non-empty mapping, it is passed through unchanged.
//   - Rate is TTS.DefaultRate when speed is 0 (unset). Otherwise speed is
//     turned into a signed whole-number offset from 1.0, e.g. 1.5 -> "+50"
//     and 0.8 -> "-20".
//   - Pitch is always TTS.DefaultPitch.
//   - Style carries the OpenAI model field.
func (h *TTSHandler) convertOpenAIRequest(openaiReq models.OpenAIRequest) models.TTSRequest {
	converted := models.TTSRequest{
		Text:  openaiReq.Input,
		Voice: openaiReq.Voice,
		Rate:  h.config.TTS.DefaultRate,
		Pitch: h.config.TTS.DefaultPitch,
		Style: openaiReq.Model,
	}

	if openaiReq.Voice != "" {
		if mapped := h.config.TTS.VoiceMapping[openaiReq.Voice]; mapped != "" {
			converted.Voice = mapped
		}
	}

	if openaiReq.Speed != 0 {
		offset := (openaiReq.Speed - 1.0) * 100
		rate := fmt.Sprintf("%.0f", offset)
		if offset >= 0 {
			// Negative values already carry their sign; add it for the rest.
			rate = "+" + rate
		}
		converted.Rate = rate
	}

	return converted
}
// sentenceSynthesisResult records the outcome of synthesizing one sentence in
// handleSegmentedTTS. The values are collected per sentence and printed
// together as a table once all sentences are done.
type sentenceSynthesisResult struct {
// index is the zero-based position of the sentence in the split text.
index int
// length is the sentence length in runes.
length int
// audioSize is the size of the synthesized audio in bytes.
audioSize int
// content is the sentence text shortened with truncateForLog for display.
content string
// duration is the wall-clock time of the synthesis call for this sentence.
duration time.Duration
}
// handleSegmentedTTS synthesizes a long text in pieces: it splits the text
// into sentences, synthesizes them concurrently, merges the audio with
// audioMerge and writes the result to the response as audio/mpeg.
//
// At most TTS.MaxConcurrent synthesis calls run at once; a configured value
// below 1 is treated as 1. The handler waits for every worker goroutine
// before it responds, so no worker touches the request after the handler has
// returned and a failure can never be missed. The first failure prevents
// workers that have not started yet from calling the service; calls already
// in flight are allowed to finish (they receive the request context and may
// stop early on cancellation).
//
// Responses: 400 when the split yields no sentence, 500 when a sentence
// fails, the request is cancelled or the merge fails, otherwise the merged
// audio. A table with per-sentence statistics and the stage timings is logged
// on success.
func (h *TTSHandler) handleSegmentedTTS(c *gin.Context, req models.TTSRequest) {
	segmentStart := time.Now()
	text := req.Text
	textLen := utf8.RuneCountInString(text)

	// Stage 1: split the text.
	splitStart := time.Now()
	sentences := splitTextBySentences(text)
	segmentCount := len(sentences)
	splitTime := time.Since(splitStart)

	// Nothing to synthesize; also protects the per-segment averages below
	// from dividing by zero.
	if segmentCount == 0 {
		log.Printf("分割文本后没有可合成的句子, 文本总长度: %d", textLen)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "文本中没有可合成的内容"})
		return
	}
	log.Printf("分割文本耗时: %v, 文本总长度: %d, 分段数: %d, 平均句子长度: %.2f",
		splitTime, textLen, segmentCount, float64(textLen)/float64(segmentCount))

	// Each worker writes only its own index of these slices, and they are
	// read only after wg.Wait, so no lock is needed.
	results := make([][]byte, segmentCount)
	synthResults := make([]sentenceSynthesisResult, segmentCount)

	// A capacity of 0 would block every worker forever and a negative one
	// would panic in make, so fall back to sequential processing.
	maxConcurrent := h.config.TTS.MaxConcurrent
	if maxConcurrent < 1 {
		maxConcurrent = 1
	}
	semaphore := make(chan struct{}, maxConcurrent)

	// Resolve the request context once, before any goroutine starts.
	ctx := c.Request.Context()

	var (
		wg       sync.WaitGroup
		failOnce sync.Once
		firstErr error
	)
	// abort is closed by the first failure; firstErr is written exactly once
	// inside failOnce and read only after wg.Wait.
	abort := make(chan struct{})
	fail := func(err error) {
		failOnce.Do(func() {
			firstErr = err
			close(abort)
		})
	}

	// Stage 2: synthesize all sentences concurrently.
	synthesisStart := time.Now()
	for i := 0; i < segmentCount; i++ {
		wg.Add(1)
		go func(index int) {
			defer wg.Done()

			select {
			case semaphore <- struct{}{}: // acquire a slot
				defer func() { <-semaphore }() // release it
			case <-abort:
				return
			case <-ctx.Done():
				fail(ctx.Err())
				return
			}
			// select picks at random among ready cases, so a slot may have
			// been acquired although another worker has already failed.
			select {
			case <-abort:
				return
			default:
			}

			segReq := models.TTSRequest{
				Text:  sentences[index],
				Voice: req.Voice,
				Rate:  req.Rate,
				Pitch: req.Pitch,
				Style: req.Style,
			}
			startTime := time.Now()
			resp, err := h.ttsService.SynthesizeSpeech(ctx, segReq)
			synthDuration := time.Since(startTime)
			if err != nil {
				fail(fmt.Errorf("句子 %d 合成失败: %w", index+1, err))
				return
			}

			// Collect the statistics instead of logging them out of order.
			synthResults[index] = sentenceSynthesisResult{
				index:     index,
				length:    utf8.RuneCountInString(sentences[index]),
				audioSize: len(resp.AudioContent),
				content:   truncateForLog(sentences[index], 20),
				duration:  synthDuration,
			}
			results[index] = resp.AudioContent
		}(i)
	}

	// Wait for every worker before deciding the outcome.
	wg.Wait()
	synthesisTime := time.Since(synthesisStart)

	if err := ctx.Err(); err != nil {
		log.Printf("分段合成中止, 请求被取消: %v", err)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "请求被取消"})
		return
	}
	if firstErr != nil {
		log.Printf("分段合成失败: %v", firstErr)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": firstErr.Error()})
		return
	}

	// Per-sentence statistics as a table.
	log.Println("句子合成结果表:")
	log.Println("-------------------------------------------------------------")
	log.Println("序号 | 长度 | 音频大小 | 耗时 | 内容")
	log.Println("-------------------------------------------------------------")
	for i := 0; i < segmentCount; i++ {
		result := synthResults[i]
		log.Printf("#%-3d | %4d | %12s | %10v | %s",
			i+1,
			result.length,
			formatFileSize(result.audioSize),
			result.duration.Round(time.Millisecond),
			result.content)
	}
	log.Println("-------------------------------------------------------------")
	log.Printf("所有分段合成总耗时: %v, 平均每段耗时: %v",
		synthesisTime, synthesisTime/time.Duration(segmentCount))

	// Stage 3: merge and write. The write timer includes the merge.
	writeStart := time.Now()
	audioData, err := audioMerge(results)
	if err != nil {
		log.Printf("合并音频失败: %v", err)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "音频合并失败: " + err.Error()})
		return
	}

	c.Header("Content-Type", "audio/mpeg")
	if _, err := c.Writer.Write(audioData); err != nil {
		log.Printf("写入响应失败: %v", err)
		return
	}

	writeTime := time.Since(writeStart)
	totalTime := time.Since(segmentStart)
	log.Printf("分段TTS请求总耗时: %v (分割: %v, 合成: %v, 写入: %v), 总音频大小: %s",
		totalTime, splitTime, synthesisTime, writeTime, formatFileSize(len(audioData)))
}
// splitTextBySentences breaks text into the chunks that are synthesized
// separately.
//
// A text shorter than 100 runes is returned as a single chunk. A longer text
// goes through utils.SplitAndFilterEmptyLines and the pieces are then passed
// to utils.MergeStringsWithLimit with the configured TTS.MinSentenceLength
// and TTS.MaxSentenceLength (per the original comments: split on punctuation
// and length first, then merge sentences that are too short). The piece count
// before and after the second step is logged.
//
// The limits are read from the global config.Get(), not from a handler.
func splitTextBySentences(text string) []string {
	// Texts below this rune count are never split.
	const singleChunkBelow = 100
	if utf8.RuneCountInString(text) < singleChunkBelow {
		return []string{text}
	}

	ttsCfg := config.Get().TTS

	pieces := utils.SplitAndFilterEmptyLines(text)
	merged := utils.MergeStringsWithLimit(pieces, ttsCfg.MinSentenceLength, ttsCfg.MaxSentenceLength)

	log.Printf("分割后的句子数: %d → %d", len(pieces), len(merged))
	return merged
}