feat: 重构项目以符合 Go 规范,添加 OpenAI 接口适配,优化长文本朗读功能(切割后合并)

This commit is contained in:
王锦强
2025-03-09 13:02:28 +08:00
parent 539f6d9ef5
commit 8f2fd68ebe
31 changed files with 2487 additions and 647 deletions

View File

@@ -0,0 +1,290 @@
package microsoft
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"html"
"io"
"log"
"net/http"
"strings"
"sync"
"time"
"tts/internal/config"
"tts/internal/models"
"tts/internal/utils"
)
// userAgent is the User-Agent header value sent with synthesis requests.
const userAgent = "okhttp/4.5.0"

// Service URL templates. The single %s verb is filled with the "r" value of
// the endpoint info (presumably the service region — confirm against
// utils.GetEndpoint).
const (
	ttsEndpoint    = "https://%s.tts.speech.microsoft.com/cognitiveservices/v1"
	voicesEndpoint = "https://%s.tts.speech.microsoft.com/cognitiveservices/voices/list"
)

// ssmlTemplate is the SSML document posted to the synthesis endpoint. Its
// format verbs are, in order: xml:lang, voice name, prosody rate, prosody
// pitch, and the text to speak. The literal %% after rate and pitch emits a
// percent sign, so callers pass bare numbers for those two values.
const ssmlTemplate = `<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xmlns:mstts="http://www.w3.org/2001/mstts" xml:lang='%s'>
<voice name='%s'>
<mstts:express-as style="general" styledegree="1.0" role="default">
<prosody rate='%s%%' pitch='%s%%' volume="medium">
%s
</prosody>
</mstts:express-as>
</voice>
</speak>`
// Client is the client implementation for the Microsoft TTS API.
//
// It carries the configured synthesis defaults, the shared HTTP client, and
// two independently locked caches: the voice list and the endpoint info used
// for authentication.
type Client struct {
	// Fallbacks used when a request leaves voice, rate or pitch empty.
	defaultVoice string
	defaultRate  string
	defaultPitch string

	// defaultFormat is sent as the X-Microsoft-OutputFormat request header.
	defaultFormat string

	// maxTextLength is the upper bound that len(req.Text) is checked against.
	maxTextLength int

	// httpClient performs every outgoing request.
	httpClient *http.Client

	// Voice list cache. voicesCacheMu guards voicesCache and voicesCacheExpiry.
	voicesCacheMu     sync.RWMutex
	voicesCache       []models.Voice
	voicesCacheExpiry time.Time

	// Endpoint and authentication info. endpointMu guards endpoint and
	// endpointExpiry.
	endpointMu     sync.RWMutex
	endpoint       map[string]interface{}
	endpointExpiry time.Time
}
// HandleOpenAITTS is a placeholder HTTP handler; judging by its name it is
// meant to serve OpenAI-style TTS requests, but no implementation exists yet.
//
// Rather than panicking inside the serving goroutine, it tells the caller the
// feature is unavailable by replying with 501 Not Implemented.
func (c *Client) HandleOpenAITTS(w http.ResponseWriter, r *http.Request) {
	// TODO: implement me.
	http.Error(w, "OpenAI TTS handler is not implemented", http.StatusNotImplemented)
}
// NewClient creates a new Microsoft TTS client from the TTS section of cfg.
//
// The HTTP client timeout is cfg.TTS.RequestTimeout interpreted as seconds.
// Both caches start empty: their expiry fields are left at the zero
// time.Time, which the cache checks treat as "not populated".
func NewClient(cfg *config.Config) *Client {
	requestTimeout := time.Duration(cfg.TTS.RequestTimeout) * time.Second

	return &Client{
		defaultVoice:  cfg.TTS.DefaultVoice,
		defaultRate:   cfg.TTS.DefaultRate,
		defaultPitch:  cfg.TTS.DefaultPitch,
		defaultFormat: cfg.TTS.DefaultFormat,
		maxTextLength: cfg.TTS.MaxTextLength,
		httpClient:    &http.Client{Timeout: requestTimeout},
	}
}
// getEndpoint returns the endpoint/authentication info, reusing the cached
// copy while it is still valid and fetching a new one otherwise.
//
// ctx is currently unused because utils.GetEndpoint takes no context.
// Concurrent callers that all miss the cache each perform their own fetch;
// the last one to finish wins.
func (c *Client) getEndpoint(ctx context.Context) (map[string]interface{}, error) {
	// Snapshot the cached state under the read lock, then decide outside it.
	c.endpointMu.RLock()
	cached, expiry := c.endpoint, c.endpointExpiry
	c.endpointMu.RUnlock()

	if cached != nil && !expiry.IsZero() && time.Now().Before(expiry) {
		return cached, nil
	}

	// Cache miss or expired: fetch fresh endpoint info.
	fresh, err := utils.GetEndpoint()
	if err != nil {
		return nil, err
	}

	// Per the original author's note the token is usually valid for one hour,
	// so it is refreshed early, after 45 minutes.
	c.endpointMu.Lock()
	c.endpoint, c.endpointExpiry = fresh, time.Now().Add(45*time.Minute)
	c.endpointMu.Unlock()

	return fresh, nil
}
// ListVoices returns the available voices, optionally restricted to those
// whose Locale starts with locale (an empty locale returns every voice).
//
// Results are served from an in-memory cache for one hour. On a cache miss
// the list is fetched from the voices endpoint, converted to models.Voice and
// cached. It returns an error when the endpoint info cannot be obtained or
// lacks a string token, when the HTTP request fails, when the service answers
// with a non-200 status, or when the response body cannot be decoded.
//
// When locale is empty the returned slice is the cached slice itself, so
// callers must treat it as read-only.
func (c *Client) ListVoices(ctx context.Context, locale string) ([]models.Voice, error) {
	// Serve from the cache while it is populated and unexpired.
	c.voicesCacheMu.RLock()
	cached, expiry := c.voicesCache, c.voicesCacheExpiry
	c.voicesCacheMu.RUnlock()
	if len(cached) > 0 && !expiry.IsZero() && time.Now().Before(expiry) {
		return filterVoicesByLocale(cached, locale), nil
	}

	// Cache is invalid: fetch the list from the API.
	endpoint, err := c.getEndpoint(ctx)
	if err != nil {
		return nil, err
	}
	// A checked assertion turns malformed endpoint info into an error instead
	// of a runtime panic.
	token, ok := endpoint["t"].(string)
	if !ok {
		return nil, errors.New("endpoint info is missing a string token")
	}

	url := fmt.Sprintf(voicesEndpoint, endpoint["r"])
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", token)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message, so
		// a read failure here is deliberately ignored.
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API error: %s, status: %d", string(body), resp.StatusCode)
	}

	var msVoices []MicrosoftVoice
	if err := json.NewDecoder(resp.Body).Decode(&msVoices); err != nil {
		return nil, fmt.Errorf("decoding voices list: %w", err)
	}

	// Convert to the service-independent model.
	voices := make([]models.Voice, len(msVoices))
	for i, v := range msVoices {
		voices[i] = models.Voice{
			Name:            v.Name,
			DisplayName:     v.DisplayName,
			LocalName:       v.LocalName,
			ShortName:       v.ShortName,
			Gender:          v.Gender,
			Locale:          v.Locale,
			LocaleName:      v.LocaleName,
			StyleList:       v.StyleList,
			SampleRateHertz: v.SampleRateHertz, // already a string; no conversion needed
		}
	}

	// Refresh the cache; entries stay valid for one hour.
	c.voicesCacheMu.Lock()
	c.voicesCache = voices
	c.voicesCacheExpiry = time.Now().Add(1 * time.Hour)
	c.voicesCacheMu.Unlock()

	return filterVoicesByLocale(voices, locale), nil
}

// filterVoicesByLocale returns the voices whose Locale has locale as a
// prefix. An empty locale returns voices unchanged; no match yields nil.
func filterVoicesByLocale(voices []models.Voice, locale string) []models.Voice {
	if locale == "" {
		return voices
	}
	var filtered []models.Voice
	for _, voice := range voices {
		if strings.HasPrefix(voice.Locale, locale) {
			filtered = append(filtered, voice)
		}
	}
	return filtered
}
// SynthesizeSpeech converts the text in req to speech and returns the whole
// audio payload in memory.
//
// The response ContentType follows the output format that was requested from
// the service (c.defaultFormat) via FormatContentTypeMap, instead of always
// claiming MP3. MP3 formats and formats missing from the map keep the
// previous value "audio/mpeg". CacheHit is always false here.
//
// It returns any error from createTTSRequest or from reading the body.
func (c *Client) SynthesizeSpeech(ctx context.Context, req models.TTSRequest) (*models.TTSResponse, error) {
	resp, err := c.createTTSRequest(ctx, req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Read the audio data.
	audio, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// The map labels MP3 as "audio/mp3"; keep the "audio/mpeg" value this
	// method has always returned for MP3 so existing consumers see no change.
	contentType := "audio/mpeg"
	if mapped, ok := FormatContentTypeMap[c.defaultFormat]; ok && mapped != "audio/mp3" {
		contentType = mapped
	}

	return &models.TTSResponse{
		AudioContent: audio,
		ContentType:  contentType,
		CacheHit:     false,
	}, nil
}
// createTTSRequest builds the SSML for req, posts it to the synthesis
// endpoint and returns the raw HTTP response.
//
// Empty Voice, Rate and Pitch fall back to the client's defaults. The locale
// is taken from the first two "-"-separated parts of the voice name and
// defaults to "zh-CN" when the name has fewer than two parts.
//
// On success the caller owns the response and must close resp.Body. It
// returns an error when the text is empty or too long, when the endpoint info
// cannot be obtained or lacks a string token, when the request fails, or when
// the service answers with a non-200 status (the body is then consumed,
// closed and included in the error).
func (c *Client) createTTSRequest(ctx context.Context, req models.TTSRequest) (*http.Response, error) {
	// Validate parameters.
	if req.Text == "" {
		return nil, errors.New("文本不能为空")
	}
	// NOTE(review): len counts bytes, not characters, so multi-byte text hits
	// the limit sooner — confirm which unit maxTextLength is meant to use.
	if len(req.Text) > c.maxTextLength {
		return nil, fmt.Errorf("文本长度超过限制 (%d > %d)", len(req.Text), c.maxTextLength)
	}

	// Fill blank parameters with the defaults.
	voice := req.Voice
	if voice == "" {
		voice = c.defaultVoice
	}
	rate := req.Rate
	if rate == "" {
		rate = c.defaultRate
	}
	pitch := req.Pitch
	if pitch == "" {
		pitch = c.defaultPitch
	}

	// Derive the language from the voice name.
	locale := "zh-CN" // default
	if parts := strings.Split(voice, "-"); len(parts) >= 2 {
		locale = parts[0] + "-" + parts[1]
	}

	// Escape every interpolated value, not only the text: voice, rate and
	// pitch come from the request and land inside quoted XML attributes, so a
	// stray quote or angle bracket would break or alter the document.
	// html.EscapeString covers <, >, &, ' and ".
	ssml := fmt.Sprintf(ssmlTemplate,
		html.EscapeString(locale),
		html.EscapeString(voice),
		html.EscapeString(rate),
		html.EscapeString(pitch),
		html.EscapeString(req.Text),
	)

	// Obtain the endpoint info.
	endpoint, err := c.getEndpoint(ctx)
	if err != nil {
		return nil, err
	}
	// A checked assertion turns malformed endpoint info into an error instead
	// of a runtime panic.
	token, ok := endpoint["t"].(string)
	if !ok {
		return nil, errors.New("端点信息缺少有效的认证令牌")
	}

	// Prepare the request.
	url := fmt.Sprintf(ttsEndpoint, endpoint["r"])
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(ssml))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Authorization", token)
	httpReq.Header.Set("Content-Type", "application/ssml+xml")
	httpReq.Header.Set("X-Microsoft-OutputFormat", c.defaultFormat)
	httpReq.Header.Set("User-Agent", userAgent)

	// Send the request.
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used for diagnostics, so a read
		// failure here is deliberately ignored.
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		log.Printf("TTS API错误: %s, 状态码: %d", string(body), resp.StatusCode)
		return nil, fmt.Errorf("TTS API错误: %s, 状态码: %d", string(body), resp.StatusCode)
	}

	return resp, nil
}

View File

@@ -0,0 +1,45 @@
package microsoft
// MicrosoftVoice represents one voice as described by the Microsoft TTS
// service. It is the decoding target for the voices-list response; JSON keys
// match the field names exactly.
type MicrosoftVoice struct {
	// Identification and display names.
	Name        string `json:"Name"`
	DisplayName string `json:"DisplayName"`
	LocalName   string `json:"LocalName"`
	ShortName   string `json:"ShortName"`

	Gender string `json:"Gender"`

	// Locale is the field ListVoices prefix-matches when filtering.
	Locale     string `json:"Locale"`
	LocaleName string `json:"LocaleName"`

	// StyleList is dropped from encoded output when empty.
	StyleList []string `json:"StyleList,omitempty"`

	// SampleRateHertz is kept as a string, as delivered by the service.
	SampleRateHertz string `json:"SampleRateHertz"`

	// VoiceType and Status are decoded but not copied into models.Voice.
	VoiceType string `json:"VoiceType"`
	Status    string `json:"Status"`
}
// SSMLRequest represents an SSML request sent to the Microsoft TTS service.
//
// NOTE(review): nothing in the visible code constructs or reads this type;
// the field meanings below are taken from their names only — confirm against
// callers.
type SSMLRequest struct {
	XMLHeader string // header placed ahead of the SSML document (presumably)
	Voice     string // voice name
	Language  string // language / locale of the document
	Rate      string // speaking rate
	Pitch     string // speaking pitch
	Text      string // text to synthesize
}
// FormatContentTypeMap maps an audio output format name to the MIME type
// used for it. Formats that are absent have no entry, so a lookup yields the
// empty string with ok == false.
var FormatContentTypeMap = map[string]string{
	// Headerless ("raw") formats.
	"raw-8khz-8bit-mono-mulaw": "audio/basic",
	"raw-16khz-16bit-mono-pcm": "audio/pcm",
	"raw-24khz-16bit-mono-pcm": "audio/pcm",

	// RIFF-wrapped formats.
	"riff-8khz-8bit-mono-alaw":  "audio/alaw",
	"riff-8khz-8bit-mono-mulaw": "audio/mulaw",
	"riff-16khz-16bit-mono-pcm": "audio/wav",
	"riff-24khz-16bit-mono-pcm": "audio/wav",

	// MP3 at 16 kHz.
	"audio-16khz-32kbitrate-mono-mp3":  "audio/mp3",
	"audio-16khz-64kbitrate-mono-mp3":  "audio/mp3",
	"audio-16khz-128kbitrate-mono-mp3": "audio/mp3",

	// MP3 at 24 kHz.
	"audio-24khz-48kbitrate-mono-mp3":  "audio/mp3",
	"audio-24khz-96kbitrate-mono-mp3":  "audio/mp3",
	"audio-24khz-160kbitrate-mono-mp3": "audio/mp3",

	// Opus in Ogg and WebM containers.
	"ogg-24khz-16bit-mono-opus":  "audio/ogg",
	"webm-24khz-16bit-mono-opus": "audio/webm",
}