feat: 重构项目以符合 Go 规范,添加 OpenAI 接口适配,优化长文本朗读功能(切割后合并)
This commit is contained in:
129
internal/config/config.go
Normal file
129
internal/config/config.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config 包含应用程序的所有配置
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
TTS TTSConfig `yaml:"tts"`
|
||||
}
|
||||
|
||||
// ServerConfig describes how the HTTP server listens and where it is mounted.
type ServerConfig struct {
	// Port is the TCP port the server listens on.
	Port int `yaml:"port"`
	// ReadTimeout is the request read timeout, in seconds.
	ReadTimeout int `yaml:"read_timeout"`
	// WriteTimeout is the response write timeout, in seconds.
	WriteTimeout int `yaml:"write_timeout"`
	// BasePath is an optional URL prefix under which all routes are served.
	BasePath string `yaml:"base_path"`
}
|
||||
|
||||
// TTSConfig carries the settings for the Microsoft TTS backend and for the
// long-text segmentation performed by the HTTP handlers.
type TTSConfig struct {
	// APIKey is the credential for the TTS API.
	APIKey string `yaml:"api_key"`
	// Region is the service region name.
	Region string `yaml:"region"`

	// DefaultVoice, DefaultRate, DefaultPitch and DefaultFormat are used when a
	// request does not specify the corresponding value.
	DefaultVoice  string `yaml:"default_voice"`
	DefaultRate   string `yaml:"default_rate"`
	DefaultPitch  string `yaml:"default_pitch"`
	DefaultFormat string `yaml:"default_format"`

	// MaxTextLength is the largest request text the handlers accept.
	MaxTextLength int `yaml:"max_text_length"`
	// RequestTimeout is the upstream request timeout, in seconds.
	RequestTimeout int `yaml:"request_timeout"`
	// MaxConcurrent bounds how many segments are synthesized in parallel.
	MaxConcurrent int `yaml:"max_concurrent"`
	// SegmentThreshold is the text length above which segmentation is used.
	SegmentThreshold int `yaml:"segment_threshold"`
	// MinSentenceLength and MaxSentenceLength bound the size of one segment.
	MinSentenceLength int `yaml:"min_sentence_length"`
	MaxSentenceLength int `yaml:"max_sentence_length"`

	// VoiceMapping maps OpenAI voice names to Azure voice names.
	VoiceMapping map[string]string `yaml:"voice_mapping"`
}
|
||||
|
||||
var (
|
||||
config Config
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// Load 从指定路径加载配置文件
|
||||
func Load(configPath string) (*Config, error) {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
// 设置默认配置
|
||||
setDefaults()
|
||||
|
||||
// 从配置文件加载
|
||||
if configPath != "" {
|
||||
err = loadFromFile(configPath)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("加载配置文件失败: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 从环境变量覆盖
|
||||
overrideFromEnv()
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// 设置默认配置值
|
||||
func setDefaults() {
|
||||
config = Config{
|
||||
Server: ServerConfig{
|
||||
Port: 8080,
|
||||
ReadTimeout: 30,
|
||||
WriteTimeout: 30,
|
||||
BasePath: "",
|
||||
},
|
||||
TTS: TTSConfig{
|
||||
DefaultVoice: "zh-CN-XiaoxiaoNeural",
|
||||
DefaultRate: "0%",
|
||||
DefaultPitch: "0%",
|
||||
DefaultFormat: "audio-24khz-48kbitrate-mono-mp3",
|
||||
MaxTextLength: 5000,
|
||||
RequestTimeout: 30,
|
||||
MaxConcurrent: 10,
|
||||
SegmentThreshold: 500,
|
||||
MinSentenceLength: 200,
|
||||
MaxSentenceLength: 300,
|
||||
VoiceMapping: make(map[string]string),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 从配置文件加载配置
|
||||
func loadFromFile(path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return yaml.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
// 从环境变量中覆盖配置
|
||||
func overrideFromEnv() {
|
||||
if port := os.Getenv("TTS_SERVER_PORT"); port != "" {
|
||||
fmt.Sscanf(port, "%d", &config.Server.Port)
|
||||
}
|
||||
|
||||
if apiKey := os.Getenv("TTS_API_KEY"); apiKey != "" {
|
||||
config.TTS.APIKey = apiKey
|
||||
}
|
||||
|
||||
if region := os.Getenv("TTS_API_REGION"); region != "" {
|
||||
config.TTS.Region = region
|
||||
}
|
||||
|
||||
// 可以添加更多环境变量覆盖
|
||||
}
|
||||
|
||||
// Get 返回已加载的配置
|
||||
func Get() *Config {
|
||||
return &config
|
||||
}
|
||||
76
internal/http/handlers/pages.go
Normal file
76
internal/http/handlers/pages.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"tts/internal/config"
|
||||
)
|
||||
|
||||
// PagesHandler 处理页面请求
|
||||
type PagesHandler struct {
|
||||
templates *template.Template
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewPagesHandler 创建一个新的页面处理器
|
||||
func NewPagesHandler(templatesDir string, cfg *config.Config) (*PagesHandler, error) {
|
||||
// 解析所有模板文件
|
||||
templates, err := template.ParseGlob(filepath.Join(templatesDir, "*.html"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PagesHandler{
|
||||
templates: templates,
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HandleIndex 处理首页请求
|
||||
func (h *PagesHandler) HandleIndex(w http.ResponseWriter, r *http.Request) {
|
||||
// 如果不是根路径,返回404
|
||||
if r.URL.Path != "/" && r.URL.Path != "/index.html" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 准备模板数据
|
||||
data := map[string]interface{}{
|
||||
"BasePath": h.config.Server.BasePath,
|
||||
"DefaultVoice": h.config.TTS.DefaultVoice,
|
||||
"DefaultRate": h.config.TTS.DefaultRate,
|
||||
"DefaultPitch": h.config.TTS.DefaultPitch,
|
||||
}
|
||||
|
||||
// 设置内容类型
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
|
||||
// 渲染模板
|
||||
if err := h.templates.ExecuteTemplate(w, "index.html", data); err != nil {
|
||||
http.Error(w, "模板渲染失败: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAPIDoc 处理API文档请求
|
||||
func (h *PagesHandler) HandleAPIDoc(w http.ResponseWriter, r *http.Request) {
|
||||
// 准备模板数据
|
||||
data := map[string]interface{}{
|
||||
"BasePath": h.config.Server.BasePath,
|
||||
"DefaultVoice": h.config.TTS.DefaultVoice,
|
||||
"DefaultRate": h.config.TTS.DefaultRate,
|
||||
"DefaultPitch": h.config.TTS.DefaultPitch,
|
||||
"DefaultFormat": h.config.TTS.DefaultFormat,
|
||||
}
|
||||
|
||||
// 设置内容类型
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
|
||||
// 渲染模板
|
||||
if err := h.templates.ExecuteTemplate(w, "api-doc.html", data); err != nil {
|
||||
http.Error(w, "模板渲染失败: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
553
internal/http/handlers/tts.go
Normal file
553
internal/http/handlers/tts.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"tts/internal/config"
|
||||
"tts/internal/models"
|
||||
"tts/internal/tts"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// TTSHandler 处理TTS请求
|
||||
type TTSHandler struct {
|
||||
ttsService tts.Service
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewTTSHandler 创建一个新的TTS处理器
|
||||
func NewTTSHandler(service tts.Service, cfg *config.Config) *TTSHandler {
|
||||
return &TTSHandler{
|
||||
ttsService: service,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleOpenAITTS 处理OpenAI兼容的TTS请求
|
||||
func (h *TTSHandler) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
|
||||
// 记录请求开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 只支持POST请求
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "仅支持POST请求", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
var openaiReq struct {
|
||||
Model string `json:"model"`
|
||||
Input string `json:"input"`
|
||||
Voice string `json:"voice"`
|
||||
Speed float64 `json:"speed"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&openaiReq); err != nil {
|
||||
http.Error(w, "无效的JSON请求: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录解析时间
|
||||
parseTime := time.Since(startTime)
|
||||
|
||||
// 检查必需字段
|
||||
if openaiReq.Input == "" {
|
||||
http.Error(w, "input字段不能为空", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 映射OpenAI声音到Microsoft声音
|
||||
msVoice := h.config.TTS.DefaultVoice
|
||||
if openaiReq.Voice != "" {
|
||||
// 检查是否有配置映射
|
||||
if mappedVoice, exists := h.config.TTS.VoiceMapping[openaiReq.Voice]; exists {
|
||||
msVoice = mappedVoice
|
||||
}
|
||||
}
|
||||
|
||||
// 转换速度参数到微软格式
|
||||
msRate := h.config.TTS.DefaultRate
|
||||
if openaiReq.Speed != 0 {
|
||||
// OpenAI速度转换为微软速度格式
|
||||
// OpenAI: 0.5(慢速), 1.0(正常), 2.0(快速)
|
||||
// 微软: "-50%"(慢), "+0%"(中), "+100%"(快)
|
||||
speedPercentage := (openaiReq.Speed - 1.0) * 100
|
||||
if speedPercentage >= 0 {
|
||||
msRate = fmt.Sprintf("+%.0f", speedPercentage)
|
||||
} else {
|
||||
msRate = fmt.Sprintf("%.0f", speedPercentage)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建内部TTS请求
|
||||
req := models.TTSRequest{
|
||||
Text: openaiReq.Input,
|
||||
Voice: msVoice,
|
||||
Rate: msRate,
|
||||
Pitch: h.config.TTS.DefaultPitch,
|
||||
}
|
||||
|
||||
log.Printf("OpenAI TTS请求: model=%s, voice=%s → %s, speed=%.2f → %s, 文本长度=%d",
|
||||
openaiReq.Model, openaiReq.Voice, msVoice, openaiReq.Speed, msRate, len(req.Text))
|
||||
|
||||
// 检查文本长度
|
||||
if len(req.Text) > h.config.TTS.MaxTextLength {
|
||||
http.Error(w, "文本长度超过限制", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否需要分段处理
|
||||
segmentThreshold := h.config.TTS.SegmentThreshold
|
||||
if len(req.Text) > segmentThreshold && len(req.Text) <= h.config.TTS.MaxTextLength {
|
||||
log.Printf("文本长度 %d 超过阈值 %d,使用分段处理", len(req.Text), segmentThreshold)
|
||||
// 使用分段处理
|
||||
h.handleSegmentedTTS(w, r, req)
|
||||
return
|
||||
}
|
||||
|
||||
// 非流式模式处理
|
||||
synthStart := time.Now()
|
||||
resp, err := h.ttsService.SynthesizeSpeech(r.Context(), req)
|
||||
synthTime := time.Since(synthStart)
|
||||
log.Printf("TTS合成耗时: %v, 文本长度: %d", synthTime, len(req.Text))
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "语音合成失败: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应
|
||||
w.Header().Set("Content-Type", "audio/mpeg")
|
||||
writeStart := time.Now()
|
||||
w.Write(resp.AudioContent)
|
||||
writeTime := time.Since(writeStart)
|
||||
|
||||
// 记录总耗时
|
||||
totalTime := time.Since(startTime)
|
||||
log.Printf("OpenAI TTS请求总耗时: %v (解析: %v, 合成: %v, 写入: %v), 音频大小: %s",
|
||||
totalTime, parseTime, synthTime, writeTime, formatFileSize(len(resp.AudioContent)))
|
||||
}
|
||||
|
||||
// HandleTTS 处理TTS请求
|
||||
func (h *TTSHandler) HandleTTS(w http.ResponseWriter, r *http.Request) {
|
||||
// 记录请求开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 解析请求参数
|
||||
var req models.TTSRequest
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// 从URL参数获取
|
||||
q := r.URL.Query()
|
||||
req = models.TTSRequest{
|
||||
Text: q.Get("t"),
|
||||
Voice: q.Get("v"),
|
||||
Rate: q.Get("r"),
|
||||
Pitch: q.Get("p"),
|
||||
}
|
||||
case http.MethodPost:
|
||||
// 从POST JSON体获取
|
||||
if r.Header.Get("Content-Type") == "application/json" {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
log.Printf("JSON解析错误: %v", err)
|
||||
http.Error(w, "无效的JSON请求", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 表单数据
|
||||
if err := r.ParseForm(); err != nil {
|
||||
log.Printf("表单解析错误: %v", err)
|
||||
http.Error(w, "无法解析表单数据", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
req = models.TTSRequest{
|
||||
Text: r.FormValue("text"),
|
||||
Voice: r.FormValue("voice"),
|
||||
Rate: r.FormValue("rate"),
|
||||
Pitch: r.FormValue("pitch"),
|
||||
}
|
||||
}
|
||||
default:
|
||||
log.Printf("不支持的HTTP方法: %s", r.Method)
|
||||
http.Error(w, "仅支持GET和POST请求", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录参数解析耗时
|
||||
parseTime := time.Since(startTime)
|
||||
log.Printf("请求参数解析耗时: %v", parseTime)
|
||||
|
||||
// 验证必要参数
|
||||
if req.Text == "" {
|
||||
log.Print("错误: 未提供文本参数")
|
||||
http.Error(w, "必须提供文本参数", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用默认值填充空白参数
|
||||
if req.Voice == "" {
|
||||
req.Voice = h.config.TTS.DefaultVoice
|
||||
}
|
||||
if req.Rate == "" {
|
||||
req.Rate = h.config.TTS.DefaultRate
|
||||
}
|
||||
if req.Pitch == "" {
|
||||
req.Pitch = h.config.TTS.DefaultPitch
|
||||
}
|
||||
|
||||
// 检查文本长度
|
||||
if len(req.Text) > h.config.TTS.MaxTextLength {
|
||||
http.Error(w, "文本长度超过限制", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否需要分段处理
|
||||
segmentThreshold := h.config.TTS.SegmentThreshold
|
||||
if len(req.Text) > segmentThreshold && len(req.Text) <= h.config.TTS.MaxTextLength {
|
||||
log.Printf("文本长度 %d 超过阈值 %d,使用分段处理", len(req.Text), segmentThreshold)
|
||||
// 如果文本长度超过阈值但小于最大限制,使用分段处理
|
||||
h.handleSegmentedTTS(w, r, req)
|
||||
return
|
||||
}
|
||||
|
||||
// 非流式模式处理(保持原有逻辑)
|
||||
synthStart := time.Now()
|
||||
resp, err := h.ttsService.SynthesizeSpeech(r.Context(), req)
|
||||
synthTime := time.Since(synthStart)
|
||||
log.Printf("TTS合成耗时: %v, 文本长度: %d", synthTime, len(req.Text))
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "语音合成失败: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应
|
||||
w.Header().Set("Content-Type", "audio/mpeg")
|
||||
writeStart := time.Now()
|
||||
w.Write(resp.AudioContent)
|
||||
writeTime := time.Since(writeStart)
|
||||
|
||||
// 记录总耗时
|
||||
totalTime := time.Since(startTime)
|
||||
log.Printf("TTS请求总耗时: %v (解析: %v, 合成: %v, 写入: %v), 音频大小: %s",
|
||||
totalTime, parseTime, synthTime, writeTime, formatFileSize(len(resp.AudioContent)))
|
||||
}
|
||||
|
||||
// handleSegmentedTTS 处理长文本的分段TTS请求
|
||||
func (h *TTSHandler) handleSegmentedTTS(w http.ResponseWriter, r *http.Request, req models.TTSRequest) {
|
||||
segmentStart := time.Now() // 分段处理开始时间
|
||||
text := req.Text
|
||||
|
||||
// 开始计时:分割文本
|
||||
splitStart := time.Now()
|
||||
// 按句子分段处理
|
||||
sentences := splitTextBySentences(text)
|
||||
segmentCount := len(sentences)
|
||||
splitTime := time.Since(splitStart)
|
||||
|
||||
log.Printf("分割文本耗时: %v, 文本总长度: %d, 分段数: %d, 平均句子长度: %.2f",
|
||||
splitTime, len(text), segmentCount, float64(len(text))/float64(segmentCount))
|
||||
|
||||
// 创建用于存储每段音频的切片
|
||||
results := make([][]byte, segmentCount)
|
||||
errChan := make(chan error, segmentCount)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// 限制并发数量,避免创建过多goroutine
|
||||
maxConcurrent := h.config.TTS.MaxConcurrent
|
||||
semaphore := make(chan struct{}, maxConcurrent)
|
||||
|
||||
// 用于记录每个分段处理的时间
|
||||
segmentTimes := make([]time.Duration, segmentCount)
|
||||
|
||||
// 合成阶段开始时间
|
||||
synthesisStart := time.Now()
|
||||
|
||||
// 并发处理每一个句子
|
||||
for i := 0; i < segmentCount; i++ {
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{} // 获取信号量
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }() // 释放信号量
|
||||
|
||||
// 创建该句的请求
|
||||
segReq := models.TTSRequest{
|
||||
Text: sentences[index],
|
||||
Voice: req.Voice,
|
||||
Rate: req.Rate,
|
||||
Pitch: req.Pitch,
|
||||
}
|
||||
|
||||
log.Printf("开始处理句子 #%d: 长度=%d, 内容='%s'",
|
||||
index+1,
|
||||
utf8.RuneCountInString(sentences[index]),
|
||||
truncateForLog(sentences[index], 20))
|
||||
|
||||
// 记录该段合成开始时间
|
||||
segStart := time.Now()
|
||||
|
||||
// 合成该段音频
|
||||
resp, err := h.ttsService.SynthesizeSpeech(r.Context(), segReq)
|
||||
|
||||
// 记录该段合成耗时
|
||||
segTime := time.Since(segStart)
|
||||
segmentTimes[index] = segTime
|
||||
|
||||
if err != nil {
|
||||
log.Printf("句子 #%d 合成失败,耗时: %v, 错误: %v", index+1, segTime, err)
|
||||
select {
|
||||
case errChan <- fmt.Errorf("句子 %d 合成失败: %w", index+1, err):
|
||||
default:
|
||||
// 已经有错误了,忽略
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("句子 #%d 合成成功:长度=%d, 耗时=%v, 音频大小=%s",
|
||||
index+1, utf8.RuneCountInString(sentences[index]), segTime, formatFileSize(len(resp.AudioContent)))
|
||||
|
||||
// 存储该段结果
|
||||
results[index] = resp.AudioContent
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 等待所有goroutine完成
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// 记录所有分段合成总耗时
|
||||
synthesisTime := time.Since(synthesisStart)
|
||||
log.Printf("所有分段合成总耗时: %v, 平均每段耗时: %v",
|
||||
synthesisTime, synthesisTime/time.Duration(segmentCount))
|
||||
|
||||
// 检查是否有错误发生
|
||||
if err := <-errChan; err != nil {
|
||||
http.Error(w, "语音合成失败: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录写入开始时间
|
||||
writeStart := time.Now()
|
||||
|
||||
var audioData []byte
|
||||
var err error
|
||||
|
||||
audioData, err = audioMerge(results)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("合并音频失败: %v", err)
|
||||
http.Error(w, "音频合并失败: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应内容类型
|
||||
w.Header().Set("Content-Type", "audio/mpeg")
|
||||
|
||||
// 写入合并后的音频数据
|
||||
totalSize := len(audioData)
|
||||
if _, writeErr := w.Write(audioData); writeErr != nil {
|
||||
log.Printf("写入响应失败: %v", writeErr)
|
||||
}
|
||||
|
||||
// 记录写入耗时
|
||||
writeTime := time.Since(writeStart)
|
||||
|
||||
// 记录总耗时
|
||||
totalTime := time.Since(segmentStart)
|
||||
log.Printf("分段TTS请求总耗时: %v (分割: %v, 合成: %v, 写入: %v), 总音频大小: %s",
|
||||
totalTime, splitTime, synthesisTime, writeTime, formatFileSize(totalSize))
|
||||
}
|
||||
|
||||
// splitTextBySentences 将文本按句子分割
|
||||
func splitTextBySentences(text string) []string {
|
||||
// 定义句子结束的标点符号
|
||||
sentenceEnders := []string{"。", "!", "?", "…", ".", "!", "?", "…", "\n"}
|
||||
|
||||
// 如果文本过短,直接作为一个句子返回
|
||||
if utf8.RuneCountInString(text) < 100 {
|
||||
return []string{text}
|
||||
}
|
||||
|
||||
var sentences []string
|
||||
var currentSentence strings.Builder
|
||||
maxSentenceLength := config.Get().TTS.MaxSentenceLength // 设置单个句子的最大长度,避免过长句子
|
||||
runeCount := 0 // 当前句子的实际字符数量
|
||||
|
||||
for _, char := range text {
|
||||
currentSentence.WriteRune(char)
|
||||
runeCount++
|
||||
|
||||
// 检查是否到达句子结束标点
|
||||
lastChar := string(char)
|
||||
isSentenceEnder := false
|
||||
for _, ender := range sentenceEnders {
|
||||
if lastChar == ender {
|
||||
isSentenceEnder = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 判断是否结束一个句子 - 使用字符数量而非字节长度
|
||||
if isSentenceEnder || runeCount >= maxSentenceLength {
|
||||
// 添加当前句子到结果中
|
||||
sentence := currentSentence.String()
|
||||
if len(sentence) > 0 {
|
||||
sentences = append(sentences, sentence)
|
||||
}
|
||||
currentSentence.Reset() // 重置构建器
|
||||
runeCount = 0 // 重置字符计数器
|
||||
}
|
||||
}
|
||||
|
||||
// 处理可能的最后一个句子
|
||||
if currentSentence.Len() > 0 {
|
||||
lastSentence := currentSentence.String()
|
||||
sentences = append(sentences, lastSentence)
|
||||
}
|
||||
|
||||
// 合并过短的句子
|
||||
minSentenceLength := config.Get().TTS.MinSentenceLength // 设置最小句子长度阈值
|
||||
|
||||
if len(sentences) > 1 {
|
||||
mergedSentences := []string{}
|
||||
var currentMerged strings.Builder
|
||||
currentMergedLength := 0
|
||||
|
||||
for i, sentence := range sentences {
|
||||
sentenceLength := utf8.RuneCountInString(sentence)
|
||||
|
||||
// 如果当前句子太短,且不是最后一个,考虑合并
|
||||
if sentenceLength < minSentenceLength && i < len(sentences)-1 {
|
||||
// 检查合并后是否会超过最大长度
|
||||
if currentMergedLength+sentenceLength > maxSentenceLength {
|
||||
// 合并后会超长,先保存当前内容
|
||||
if currentMerged.Len() > 0 {
|
||||
mergedSentences = append(mergedSentences, currentMerged.String())
|
||||
currentMerged.Reset()
|
||||
currentMergedLength = 0
|
||||
}
|
||||
}
|
||||
|
||||
// 当前句子过短,添加到合并缓冲区
|
||||
currentMerged.WriteString(sentence)
|
||||
currentMergedLength += sentenceLength
|
||||
} else {
|
||||
// 句子足够长或是最后一句
|
||||
if currentMerged.Len() > 0 {
|
||||
// 检查合并后是否会超过最大长度
|
||||
if currentMergedLength+sentenceLength <= maxSentenceLength {
|
||||
// 有待合并的内容,将当前句子也合并进去
|
||||
currentMerged.WriteString(sentence)
|
||||
mergedSentence := currentMerged.String()
|
||||
mergedSentences = append(mergedSentences, mergedSentence)
|
||||
} else {
|
||||
// 合并后会超长,分别添加
|
||||
mergedSentence := currentMerged.String()
|
||||
mergedSentences = append(mergedSentences, mergedSentence)
|
||||
mergedSentences = append(mergedSentences, sentence)
|
||||
}
|
||||
currentMerged.Reset()
|
||||
currentMergedLength = 0
|
||||
} else {
|
||||
// 没有待合并内容,直接添加当前句子
|
||||
mergedSentences = append(mergedSentences, sentence)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理可能剩余的合并内容
|
||||
if currentMerged.Len() > 0 {
|
||||
mergedSentence := currentMerged.String()
|
||||
mergedSentences = append(mergedSentences, mergedSentence)
|
||||
log.Printf("添加最后剩余的合并句子,长度=%d", utf8.RuneCountInString(mergedSentence))
|
||||
}
|
||||
|
||||
return mergedSentences
|
||||
}
|
||||
|
||||
return sentences
|
||||
}
|
||||
|
||||
// truncateForLog shortens text for log output.
//
// Line breaks (\n and \r) are replaced by spaces first. If the result has at
// most maxLength characters it is returned as is; otherwise the first and the
// last maxLength/2 characters are kept and joined with "...". Lengths are
// counted in runes, so multi-byte characters are never cut in half. A
// negative maxLength is treated as 0.
func truncateForLog(text string, maxLength int) string {
	text = strings.ReplaceAll(text, "\n", " ")
	text = strings.ReplaceAll(text, "\r", " ")

	// A negative limit would yield a negative half length and panic on slicing.
	if maxLength < 0 {
		maxLength = 0
	}

	runes := []rune(text)
	if len(runes) <= maxLength {
		return text
	}

	halfLength := maxLength / 2
	return string(runes[:halfLength]) + "..." + string(runes[len(runes)-halfLength:])
}
|
||||
|
||||
// audioMerge 音频合并
|
||||
func audioMerge(audioSegments [][]byte) ([]byte, error) {
|
||||
if len(audioSegments) == 0 {
|
||||
return nil, fmt.Errorf("没有音频片段可合并")
|
||||
}
|
||||
|
||||
// 使用 ffmpeg 合并音频
|
||||
tempDir, err := os.MkdirTemp("", "audio_merge_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
listFile := filepath.Join(tempDir, "concat.txt")
|
||||
lf, err := os.Create(listFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, seg := range audioSegments {
|
||||
segFile := filepath.Join(tempDir, fmt.Sprintf("seg_%d.mp3", i))
|
||||
if err := os.WriteFile(segFile, seg, 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := lf.WriteString(fmt.Sprintf("file '%s'\n", segFile)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
lf.Close()
|
||||
|
||||
outputFile := filepath.Join(tempDir, "output.mp3")
|
||||
|
||||
cmd := exec.Command("ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", listFile, "-c", "copy", outputFile)
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mergedData, err := os.ReadFile(outputFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("使用ffmpeg合并完成,总大小: %s", formatFileSize(len(mergedData)))
|
||||
return mergedData, nil
|
||||
}
|
||||
|
||||
// formatFileSize renders a byte count for humans: plain bytes below 1024,
// otherwise KB, MB or GB (powers of 1024) with two decimals.
func formatFileSize(size int) string {
	const (
		kib = 1024
		mib = kib * 1024
		gib = mib * 1024
	)

	bytes := float64(size)
	if size < kib {
		return fmt.Sprintf("%d B", size)
	}
	if size < mib {
		return fmt.Sprintf("%.2f KB", bytes/kib)
	}
	if size < gib {
		return fmt.Sprintf("%.2f MB", bytes/mib)
	}
	return fmt.Sprintf("%.2f GB", bytes/gib)
}
|
||||
41
internal/http/handlers/voices.go
Normal file
41
internal/http/handlers/voices.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"tts/internal/tts"
|
||||
)
|
||||
|
||||
// VoicesHandler 处理语音列表请求
|
||||
type VoicesHandler struct {
|
||||
ttsService tts.Service
|
||||
}
|
||||
|
||||
// NewVoicesHandler 创建一个新的语音列表处理器
|
||||
func NewVoicesHandler(service tts.Service) *VoicesHandler {
|
||||
return &VoicesHandler{
|
||||
ttsService: service,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleVoices 处理语音列表请求
|
||||
func (h *VoicesHandler) HandleVoices(w http.ResponseWriter, r *http.Request) {
|
||||
// 从查询参数中获取语言筛选
|
||||
locale := r.URL.Query().Get("locale")
|
||||
|
||||
// 获取语音列表
|
||||
voices, err := h.ttsService.ListVoices(r.Context(), locale)
|
||||
if err != nil {
|
||||
http.Error(w, "获取语音列表失败: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置内容类型
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// 编码为JSON并返回
|
||||
if err := json.NewEncoder(w).Encode(voices); err != nil {
|
||||
http.Error(w, "JSON编码失败", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
22
internal/http/middleware/cors.go
Normal file
22
internal/http/middleware/cors.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
// CORS is HTTP middleware that adds permissive cross-origin headers to every
// response: any origin, the methods GET/POST/OPTIONS and the request headers
// Content-Type and Authorization. OPTIONS (preflight) requests are answered
// with 200 directly and never reach next.
func CORS(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()
		headers.Set("Access-Control-Allow-Origin", "*")
		headers.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		headers.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}

		// Preflight: headers only, no further processing.
		w.WriteHeader(http.StatusOK)
	}
	return http.HandlerFunc(serve)
}
|
||||
46
internal/http/middleware/logger.go
Normal file
46
internal/http/middleware/logger.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logger is HTTP middleware that writes one log line per request after the
// wrapped handler returns: method, request URI, remote address, response
// status and elapsed time.
func Logger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()

		// Wrap the writer so the status code can be observed. 200 is what
		// net/http sends when the handler never sets a status.
		wrapper := &responseWriterWrapper{
			ResponseWriter: w,
			statusCode:     http.StatusOK,
		}

		next.ServeHTTP(wrapper, r)

		duration := time.Since(start)
		log.Printf(
			"[%s] %s %s %d %s",
			r.Method,
			r.RequestURI,
			r.RemoteAddr,
			wrapper.statusCode,
			duration,
		)
	})
}

// responseWriterWrapper wraps an http.ResponseWriter and remembers the status
// code that was actually sent to the client.
type responseWriterWrapper struct {
	http.ResponseWriter
	statusCode int
	// wroteHeader is set once the final status is fixed, either by an
	// explicit WriteHeader or implicitly by the first Write.
	wroteHeader bool
}

// WriteHeader records statusCode if no final status has been sent yet and
// forwards the call. Later calls are still forwarded, so the underlying
// writer behaves as before, but they no longer change the recorded status.
func (w *responseWriterWrapper) WriteHeader(statusCode int) {
	if !w.wroteHeader {
		w.statusCode = statusCode
		// 1xx responses other than 101 are provisional: a final status follows.
		w.wroteHeader = statusCode < 100 || statusCode >= 200 || statusCode == http.StatusSwitchingProtocols
	}
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards to the wrapped writer. The first Write without a preceding
// WriteHeader fixes the status at the value recorded so far (200 by default).
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
	w.wroteHeader = true
	return w.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional interfaces (such as flushing) implemented by it.
func (w *responseWriterWrapper) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
|
||||
83
internal/http/server/app.go
Normal file
83
internal/http/server/app.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
"tts/internal/config"
|
||||
)
|
||||
|
||||
// App 表示整个TTS应用程序
|
||||
type App struct {
|
||||
server *Server
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApp 创建一个新的应用程序实例
|
||||
func NewApp(configPath string) (*App, error) {
|
||||
// 加载配置
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 初始化服务
|
||||
ttsService, err := InitializeServices(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化服务失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置路由
|
||||
handler, err := SetupRoutes(cfg, ttsService)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("设置路由失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建HTTP服务器
|
||||
server := New(cfg, handler)
|
||||
|
||||
return &App{
|
||||
server: server,
|
||||
cfg: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start 启动应用程序
|
||||
func (a *App) Start() error {
|
||||
// 创建一个错误通道
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// 创建一个退出信号通道
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// 在一个goroutine中启动服务器
|
||||
go func() {
|
||||
log.Printf("启动TTS服务,监听端口 %d...\n", a.cfg.Server.Port)
|
||||
errChan <- a.server.Start()
|
||||
}()
|
||||
|
||||
// 等待退出信号或错误
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return err
|
||||
case <-quit:
|
||||
log.Println("接收到退出信号,正在优雅关闭...")
|
||||
|
||||
// 创建一个超时上下文用于优雅关闭
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 尝试优雅关闭服务器
|
||||
if err := a.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("服务器关闭出错: %w", err)
|
||||
}
|
||||
|
||||
log.Println("服务器已优雅关闭")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
65
internal/http/server/routes.go
Normal file
65
internal/http/server/routes.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"tts/internal/config"
|
||||
"tts/internal/http/handlers"
|
||||
"tts/internal/http/middleware"
|
||||
"tts/internal/tts"
|
||||
"tts/internal/tts/microsoft"
|
||||
)
|
||||
|
||||
// SetupRoutes 配置所有API路由
|
||||
func SetupRoutes(cfg *config.Config, ttsService tts.Service) (http.Handler, error) {
|
||||
// 创建一个新的路由多路复用器
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// 创建处理器
|
||||
ttsHandler := handlers.NewTTSHandler(ttsService, cfg)
|
||||
voicesHandler := handlers.NewVoicesHandler(ttsService)
|
||||
|
||||
// 创建页面处理器
|
||||
pagesHandler, err := handlers.NewPagesHandler("./web/templates", cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置主页路由
|
||||
mux.HandleFunc("/", pagesHandler.HandleIndex)
|
||||
|
||||
// 设置API文档路由
|
||||
mux.HandleFunc("/api-doc", pagesHandler.HandleAPIDoc)
|
||||
|
||||
// 设置TTS API路由
|
||||
mux.HandleFunc("/tts", ttsHandler.HandleTTS)
|
||||
|
||||
// 设置语音列表API路由
|
||||
mux.HandleFunc("/voices", voicesHandler.HandleVoices)
|
||||
|
||||
mux.HandleFunc("/v1/audio/speech", ttsHandler.HandleOpenAITTS)
|
||||
mux.HandleFunc("/audio/speech", ttsHandler.HandleOpenAITTS)
|
||||
|
||||
// 设置静态文件服务
|
||||
fs := http.FileServer(http.Dir("./web/static"))
|
||||
mux.Handle("/static/", http.StripPrefix("/static/", fs))
|
||||
|
||||
// 应用基础路径前缀
|
||||
var handler http.Handler = mux
|
||||
if cfg.Server.BasePath != "" {
|
||||
handler = http.StripPrefix(cfg.Server.BasePath, mux)
|
||||
}
|
||||
|
||||
// 应用中间件
|
||||
handler = middleware.Logger(handler) // 日志中间件
|
||||
handler = middleware.CORS(handler) // CORS中间件
|
||||
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// InitializeServices 初始化所有服务
|
||||
func InitializeServices(cfg *config.Config) (tts.Service, error) {
|
||||
// 创建Microsoft TTS客户端
|
||||
ttsClient := microsoft.NewClient(cfg)
|
||||
|
||||
return ttsClient, nil
|
||||
}
|
||||
45
internal/http/server/server.go
Normal file
45
internal/http/server/server.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"tts/internal/config"
|
||||
)
|
||||
|
||||
// Server is a thin wrapper around http.Server.
type Server struct {
	// server is the underlying standard-library HTTP server.
	server *http.Server
	// basePath is the configured URL prefix the application is served under.
	basePath string
}
|
||||
|
||||
// New 创建新的HTTP服务器
|
||||
func New(cfg *config.Config, handler http.Handler) *Server {
|
||||
// 创建HTTP服务器
|
||||
httpServer := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
|
||||
Handler: handler,
|
||||
ReadTimeout: time.Duration(cfg.Server.ReadTimeout) * time.Second,
|
||||
WriteTimeout: time.Duration(cfg.Server.WriteTimeout) * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
return &Server{
|
||||
server: httpServer,
|
||||
basePath: cfg.Server.BasePath,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动HTTP服务器
|
||||
func (s *Server) Start() error {
|
||||
fmt.Printf("服务启动在 %s\n", s.server.Addr)
|
||||
return s.server.ListenAndServe()
|
||||
}
|
||||
|
||||
// Shutdown 优雅关闭服务器
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
fmt.Println("正在关闭HTTP服务器...")
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
16
internal/models/tts.go
Normal file
16
internal/models/tts.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package models
|
||||
|
||||
// TTSRequest is a single speech-synthesis request.
type TTSRequest struct {
	// Text is the text to convert to speech.
	Text string `json:"text"`
	// Voice is the identifier of the voice to use.
	Voice string `json:"voice"`
	// Rate is the speaking rate (-100% to +100%).
	Rate string `json:"rate"`
	// Pitch is the voice pitch (-100% to +100%).
	Pitch string `json:"pitch"`
}
|
||||
|
||||
// TTSResponse is the result of a speech-synthesis request.
type TTSResponse struct {
	// AudioContent is the synthesized audio data.
	AudioContent []byte `json:"audio_content"`
	// ContentType is the MIME type of AudioContent.
	ContentType string `json:"content_type"`
	// CacheHit tells whether the result was served from a cache.
	CacheHit bool `json:"cache_hit"`
}
|
||||
14
internal/models/voice.go
Normal file
14
internal/models/voice.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package models
|
||||
|
||||
// Voice describes one voice offered for speech synthesis.
type Voice struct {
	// Name is the unique identifier of the voice.
	Name string `json:"name"`
	// DisplayName is the name shown to users.
	DisplayName string `json:"display_name"`
	// LocalName is the localized name.
	LocalName string `json:"local_name"`
	// ShortName is the short identifier, e.g. zh-CN-XiaoxiaoNeural.
	ShortName string `json:"short_name"`
	// Gender is "Female" or "Male".
	Gender string `json:"gender"`
	// Locale is the language region code, e.g. zh-CN.
	Locale string `json:"locale"`
	// LocaleName is the display name of the locale.
	LocaleName string `json:"locale_name"`
	// StyleList lists the supported speaking styles; omitted from JSON when empty.
	StyleList []string `json:"style_list,omitempty"`
	// SampleRateHertz is the sample rate.
	SampleRateHertz string `json:"sample_rate_hertz"`
}
|
||||
290
internal/tts/microsoft/client.go
Normal file
290
internal/tts/microsoft/client.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package microsoft
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tts/internal/config"
|
||||
"tts/internal/models"
|
||||
"tts/internal/utils"
|
||||
)
|
||||
|
||||
// Constants used when talking to the Microsoft speech endpoints.
const (
	// userAgent is sent as the User-Agent header on synthesis requests.
	userAgent = "okhttp/4.5.0"
	// voicesEndpoint is the voice list URL; %s is filled with the endpoint's "r" value.
	voicesEndpoint = "https://%s.tts.speech.microsoft.com/cognitiveservices/voices/list"
	// ttsEndpoint is the synthesis URL; %s is filled with the endpoint's "r" value.
	ttsEndpoint = "https://%s.tts.speech.microsoft.com/cognitiveservices/v1"
	// ssmlTemplate is the SSML document posted for synthesis. Its five %s verbs
	// are filled, in order, with the locale, voice name, rate, pitch and the
	// escaped text; "%%" emits a literal percent sign after rate and pitch.
	ssmlTemplate = `<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xmlns:mstts="http://www.w3.org/2001/mstts" xml:lang='%s'>
<voice name='%s'>
<mstts:express-as style="general" styledegree="1.0" role="default">
<prosody rate='%s%%' pitch='%s%%' volume="medium">
%s
</prosody>
</mstts:express-as>
</voice>
</speak>`
)
|
||||
|
||||
// Client is the client implementation for the Microsoft TTS API.
type Client struct {
	defaultVoice      string         // voice used when a request leaves Voice empty
	defaultRate       string         // rate used when a request leaves Rate empty
	defaultPitch      string         // pitch used when a request leaves Pitch empty
	defaultFormat     string         // sent as the X-Microsoft-OutputFormat header
	maxTextLength     int            // upper bound checked against len(req.Text)
	httpClient        *http.Client   // client used for voice list and synthesis calls
	voicesCache       []models.Voice // last fetched voice list
	voicesCacheMu     sync.RWMutex   // guards voicesCache and voicesCacheExpiry
	voicesCacheExpiry time.Time      // zero value means nothing is cached yet

	// Endpoint and authentication information.
	endpoint       map[string]interface{} // last endpoint response; "r" and "t" keys are read by callers
	endpointMu     sync.RWMutex           // guards endpoint and endpointExpiry
	endpointExpiry time.Time              // zero value means no endpoint is cached yet
}
|
||||
|
||||
func (c *Client) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// NewClient 创建一个新的Microsoft TTS客户端
|
||||
func NewClient(cfg *config.Config) *Client {
|
||||
client := &Client{
|
||||
defaultVoice: cfg.TTS.DefaultVoice,
|
||||
defaultRate: cfg.TTS.DefaultRate,
|
||||
defaultPitch: cfg.TTS.DefaultPitch,
|
||||
defaultFormat: cfg.TTS.DefaultFormat,
|
||||
maxTextLength: cfg.TTS.MaxTextLength,
|
||||
httpClient: &http.Client{
|
||||
Timeout: time.Duration(cfg.TTS.RequestTimeout) * time.Second,
|
||||
},
|
||||
voicesCacheExpiry: time.Time{}, // 初始时缓存为空
|
||||
endpointExpiry: time.Time{}, // 初始时端点为空
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// getEndpoint 获取或刷新认证端点
|
||||
func (c *Client) getEndpoint(ctx context.Context) (map[string]interface{}, error) {
|
||||
c.endpointMu.RLock()
|
||||
if !c.endpointExpiry.IsZero() && time.Now().Before(c.endpointExpiry) && c.endpoint != nil {
|
||||
endpoint := c.endpoint
|
||||
c.endpointMu.RUnlock()
|
||||
return endpoint, nil
|
||||
}
|
||||
c.endpointMu.RUnlock()
|
||||
|
||||
// 获取新的端点信息
|
||||
endpoint, err := utils.GetEndpoint()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
c.endpointMu.Lock()
|
||||
c.endpoint = endpoint
|
||||
c.endpointExpiry = time.Now().Add(45 * time.Minute) // 令牌有效期通常是1小时,提前刷新
|
||||
c.endpointMu.Unlock()
|
||||
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
// ListVoices 获取可用的语音列表
|
||||
func (c *Client) ListVoices(ctx context.Context, locale string) ([]models.Voice, error) {
|
||||
// 检查缓存是否有效
|
||||
c.voicesCacheMu.RLock()
|
||||
if !c.voicesCacheExpiry.IsZero() && time.Now().Before(c.voicesCacheExpiry) && len(c.voicesCache) > 0 {
|
||||
voices := c.voicesCache
|
||||
c.voicesCacheMu.RUnlock()
|
||||
|
||||
// 如果指定了locale,则过滤结果
|
||||
if locale != "" {
|
||||
var filtered []models.Voice
|
||||
for _, voice := range voices {
|
||||
if strings.HasPrefix(voice.Locale, locale) {
|
||||
filtered = append(filtered, voice)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
return voices, nil
|
||||
}
|
||||
c.voicesCacheMu.RUnlock()
|
||||
|
||||
// 缓存无效,需要从API获取
|
||||
endpoint, err := c.getEndpoint(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(voicesEndpoint, endpoint["r"])
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用新的认证方式
|
||||
req.Header.Set("Authorization", endpoint["t"].(string))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API error: %s, status: %d", string(body), resp.StatusCode)
|
||||
}
|
||||
|
||||
var msVoices []MicrosoftVoice
|
||||
if err := json.NewDecoder(resp.Body).Decode(&msVoices); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为通用模型
|
||||
voices := make([]models.Voice, len(msVoices))
|
||||
for i, v := range msVoices {
|
||||
voices[i] = models.Voice{
|
||||
Name: v.Name,
|
||||
DisplayName: v.DisplayName,
|
||||
LocalName: v.LocalName,
|
||||
ShortName: v.ShortName,
|
||||
Gender: v.Gender,
|
||||
Locale: v.Locale,
|
||||
LocaleName: v.LocaleName,
|
||||
StyleList: v.StyleList,
|
||||
SampleRateHertz: v.SampleRateHertz, // 直接使用字符串,无需转换
|
||||
}
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
c.voicesCacheMu.Lock()
|
||||
c.voicesCache = voices
|
||||
c.voicesCacheExpiry = time.Now().Add(1 * time.Hour) // 缓存1小时
|
||||
c.voicesCacheMu.Unlock()
|
||||
|
||||
// 如果指定了locale,则过滤结果
|
||||
if locale != "" {
|
||||
var filtered []models.Voice
|
||||
for _, voice := range voices {
|
||||
if strings.HasPrefix(voice.Locale, locale) {
|
||||
filtered = append(filtered, voice)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
return voices, nil
|
||||
}
|
||||
|
||||
// SynthesizeSpeech 将文本转换为语音
|
||||
func (c *Client) SynthesizeSpeech(ctx context.Context, req models.TTSRequest) (*models.TTSResponse, error) {
|
||||
resp, err := c.createTTSRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取音频数据
|
||||
audio, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.TTSResponse{
|
||||
AudioContent: audio,
|
||||
ContentType: "audio/mpeg",
|
||||
CacheHit: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createTTSRequest 创建并执行TTS请求,返回HTTP响应
|
||||
func (c *Client) createTTSRequest(ctx context.Context, req models.TTSRequest) (*http.Response, error) {
|
||||
// 参数验证
|
||||
if req.Text == "" {
|
||||
return nil, errors.New("文本不能为空")
|
||||
}
|
||||
|
||||
if len(req.Text) > c.maxTextLength {
|
||||
return nil, fmt.Errorf("文本长度超过限制 (%d > %d)", len(req.Text), c.maxTextLength)
|
||||
}
|
||||
|
||||
// 使用默认值填充空白参数
|
||||
voice := req.Voice
|
||||
if voice == "" {
|
||||
voice = c.defaultVoice
|
||||
}
|
||||
|
||||
rate := req.Rate
|
||||
if rate == "" {
|
||||
rate = c.defaultRate
|
||||
}
|
||||
|
||||
pitch := req.Pitch
|
||||
if pitch == "" {
|
||||
pitch = c.defaultPitch
|
||||
}
|
||||
|
||||
// 提取语言
|
||||
locale := "zh-CN" // 默认
|
||||
parts := strings.Split(voice, "-")
|
||||
if len(parts) >= 2 {
|
||||
locale = parts[0] + "-" + parts[1]
|
||||
}
|
||||
|
||||
// 对文本进行HTML转义,防止XML解析错误
|
||||
|
||||
escapedText := html.EscapeString(req.Text)
|
||||
|
||||
// 准备SSML内容
|
||||
ssml := fmt.Sprintf(ssmlTemplate, locale, voice, rate, pitch, escapedText)
|
||||
|
||||
// 获取端点信息
|
||||
endpoint, err := c.getEndpoint(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 准备请求
|
||||
url := fmt.Sprintf(ttsEndpoint, endpoint["r"])
|
||||
reqBody := bytes.NewBufferString(ssml)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Authorization", endpoint["t"].(string))
|
||||
httpReq.Header.Set("Content-Type", "application/ssml+xml")
|
||||
httpReq.Header.Set("X-Microsoft-OutputFormat", c.defaultFormat)
|
||||
httpReq.Header.Set("User-Agent", userAgent)
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// 获取响应体以便调试
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
log.Printf("TTS API错误: %s, 状态码: %d", string(body), resp.StatusCode)
|
||||
return nil, fmt.Errorf("TTS API错误: %s, 状态码: %d", string(body), resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
45
internal/tts/microsoft/models.go
Normal file
45
internal/tts/microsoft/models.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package microsoft
|
||||
|
||||
// MicrosoftVoice describes one voice as reported by the Microsoft TTS service.
// The JSON tags use the capitalized field names of the service response.
type MicrosoftVoice struct {
	Name            string   `json:"Name"`
	DisplayName     string   `json:"DisplayName"`
	LocalName       string   `json:"LocalName"`
	ShortName       string   `json:"ShortName"`
	Gender          string   `json:"Gender"`
	Locale          string   `json:"Locale"`
	LocaleName      string   `json:"LocaleName"`
	StyleList       []string `json:"StyleList,omitempty"`
	SampleRateHertz string   `json:"SampleRateHertz"`
	VoiceType       string   `json:"VoiceType"`
	Status          string   `json:"Status"`
}
|
||||
|
||||
// SSMLRequest describes an SSML request sent to the Microsoft TTS service.
type SSMLRequest struct {
	XMLHeader string
	Voice     string
	Language  string
	Rate      string
	Pitch     string
	Text      string
}
|
||||
|
||||
// FormatContentTypeMap maps an audio output format name to its MIME type.
// NOTE(review): the keys look like X-Microsoft-OutputFormat values; confirm against callers.
var FormatContentTypeMap = map[string]string{
	"raw-16khz-16bit-mono-pcm":         "audio/pcm",
	"raw-8khz-8bit-mono-mulaw":         "audio/basic",
	"riff-8khz-8bit-mono-alaw":         "audio/alaw",
	"riff-8khz-8bit-mono-mulaw":        "audio/mulaw",
	"riff-16khz-16bit-mono-pcm":        "audio/wav",
	"audio-16khz-128kbitrate-mono-mp3": "audio/mp3",
	"audio-16khz-64kbitrate-mono-mp3":  "audio/mp3",
	"audio-16khz-32kbitrate-mono-mp3":  "audio/mp3",
	"raw-24khz-16bit-mono-pcm":         "audio/pcm",
	"riff-24khz-16bit-mono-pcm":        "audio/wav",
	"audio-24khz-160kbitrate-mono-mp3": "audio/mp3",
	"audio-24khz-96kbitrate-mono-mp3":  "audio/mp3",
	"audio-24khz-48kbitrate-mono-mp3":  "audio/mp3",
	"ogg-24khz-16bit-mono-opus":        "audio/ogg",
	"webm-24khz-16bit-mono-opus":       "audio/webm",
}
|
||||
15
internal/tts/service.go
Normal file
15
internal/tts/service.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package tts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"tts/internal/models"
|
||||
)
|
||||
|
||||
// Service defines the interface of a TTS service.
type Service interface {
	// ListVoices returns the available voices for the given locale.
	ListVoices(ctx context.Context, locale string) ([]models.Voice, error)

	// SynthesizeSpeech converts the text in req to speech.
	SynthesizeSpeech(ctx context.Context, req models.TTSRequest) (*models.TTSResponse, error)
}
|
||||
87
internal/utils/utils.go
Normal file
87
internal/utils/utils.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
log = logrus.New()
|
||||
client = &http.Client{}
|
||||
)
|
||||
|
||||
// Fixed values sent with, or used to sign, the endpoint request.
// NOTE(review): identifiers and the signing key are hard-coded in source; confirm this is intended.
const (
	// endpointURL is the URL that GetEndpoint posts to and Sign signs.
	endpointURL = "https://dev.microsofttranslator.com/apps/endpoint?api-version=1.0"
	// userAgent is sent as the User-Agent header.
	userAgent = "okhttp/4.5.0"
	// clientVersion is sent as the X-ClientVersion header.
	clientVersion = "4.0.530a 5fe1dc6c"
	// userId is sent as the X-UserId header.
	userId = "0f04d16a175c411e"
	// homeGeographicRegion is sent as the X-HomeGeographicRegion header.
	homeGeographicRegion = "zh-Hans-CN"
	// clientTraceId is sent as the X-ClientTraceId header.
	clientTraceId = "aab069b9-70a7-4844-a734-96cd78d94be9"
	// voiceDecodeKey is the base64-encoded HMAC-SHA256 key used by Sign.
	voiceDecodeKey = "oik6PdDdMnOXemTbwvMn9de/h9lFnfBaCWbGMMZqqoSaQaqUOqjVGm5NqsmjcBI1x+sS9ugjB55HEJWRiFXYFw=="
)
|
||||
|
||||
// GetEndpoint 获取语音合成服务的端点信息
|
||||
func GetEndpoint() (map[string]interface{}, error) {
|
||||
signature := Sign(endpointURL)
|
||||
headers := map[string]string{
|
||||
"Accept-Language": "zh-Hans",
|
||||
"X-ClientVersion": clientVersion,
|
||||
"X-UserId": userId,
|
||||
"X-HomeGeographicRegion": homeGeographicRegion,
|
||||
"X-ClientTraceId": clientTraceId,
|
||||
"X-MT-Signature": signature,
|
||||
"User-Agent": userAgent,
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Content-Length": "0",
|
||||
"Accept-Encoding": "gzip",
|
||||
}
|
||||
req, err := http.NewRequest("POST", endpointURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Error("failed to do request: ", err)
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Sign 生成签名
|
||||
func Sign(urlStr string) string {
|
||||
u := strings.Split(urlStr, "://")[1]
|
||||
encodedUrl := url.QueryEscape(u)
|
||||
uuidStr := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
formattedDate := strings.ToLower(time.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")) + "gmt"
|
||||
bytesToSign := fmt.Sprintf("MSTranslatorAndroidApp%s%s%s", encodedUrl, formattedDate, uuidStr)
|
||||
bytesToSign = strings.ToLower(bytesToSign)
|
||||
decode, _ := base64.StdEncoding.DecodeString(voiceDecodeKey)
|
||||
hash := hmac.New(sha256.New, decode)
|
||||
hash.Write([]byte(bytesToSign))
|
||||
secretKey := hash.Sum(nil)
|
||||
signBase64 := base64.StdEncoding.EncodeToString(secretKey)
|
||||
return fmt.Sprintf("MSTranslatorAndroidApp::%s::%s::%s", signBase64, formattedDate, uuidStr)
|
||||
}
|
||||
Reference in New Issue
Block a user