feat: 使用 viper 绑定配置,open ai api 添加认证配置
This commit is contained in:
@@ -2,41 +2,46 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Config 包含应用程序的所有配置
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
TTS TTSConfig `yaml:"tts"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
TTS TTSConfig `mapstructure:"tts"`
|
||||
OpenAI OpenAIConfig `mapstructure:"openai"`
|
||||
}
|
||||
|
||||
// OpenAIConfig 包含OpenAI API配置
|
||||
type OpenAIConfig struct {
|
||||
ApiKey string `mapstructure:"api_key"`
|
||||
}
|
||||
|
||||
// ServerConfig 包含HTTP服务器配置
|
||||
type ServerConfig struct {
|
||||
Port int `yaml:"port"`
|
||||
ReadTimeout int `yaml:"read_timeout"` // 单位:秒
|
||||
WriteTimeout int `yaml:"write_timeout"` // 单位:秒
|
||||
BasePath string `yaml:"base_path"`
|
||||
Port int `mapstructure:"port"`
|
||||
ReadTimeout int `mapstructure:"read_timeout"`
|
||||
WriteTimeout int `mapstructure:"write_timeout"`
|
||||
BasePath string `mapstructure:"base_path"`
|
||||
}
|
||||
|
||||
// TTSConfig 包含Microsoft TTS API配置
|
||||
type TTSConfig struct {
|
||||
APIKey string `yaml:"api_key"`
|
||||
Region string `yaml:"region"`
|
||||
DefaultVoice string `yaml:"default_voice"`
|
||||
DefaultRate string `yaml:"default_rate"`
|
||||
DefaultPitch string `yaml:"default_pitch"`
|
||||
DefaultFormat string `yaml:"default_format"`
|
||||
MaxTextLength int `yaml:"max_text_length"`
|
||||
RequestTimeout int `yaml:"request_timeout"` // 单位:秒
|
||||
MaxConcurrent int `yaml:"max_concurrent"`
|
||||
SegmentThreshold int `yaml:"segment_threshold"`
|
||||
MinSentenceLength int `yaml:"min_sentence_length"`
|
||||
MaxSentenceLength int `yaml:"max_sentence_length"`
|
||||
VoiceMapping map[string]string `yaml:"voice_mapping"` // OpenAI声音到Azure声音的映射
|
||||
Region string `mapstructure:"region"`
|
||||
DefaultVoice string `mapstructure:"default_voice"`
|
||||
DefaultRate string `mapstructure:"default_rate"`
|
||||
DefaultPitch string `mapstructure:"default_pitch"`
|
||||
DefaultFormat string `mapstructure:"default_format"`
|
||||
MaxTextLength int `mapstructure:"max_text_length"`
|
||||
RequestTimeout int `mapstructure:"request_timeout"`
|
||||
MaxConcurrent int `mapstructure:"max_concurrent"`
|
||||
SegmentThreshold int `mapstructure:"segment_threshold"`
|
||||
MinSentenceLength int `mapstructure:"min_sentence_length"`
|
||||
MaxSentenceLength int `mapstructure:"max_sentence_length"`
|
||||
VoiceMapping map[string]string `mapstructure:"voice_mapping"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -48,20 +53,28 @@ var (
|
||||
func Load(configPath string) (*Config, error) {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
// 设置默认配置
|
||||
setDefaults()
|
||||
v := viper.New()
|
||||
|
||||
// 配置 Viper
|
||||
v.SetConfigName("config")
|
||||
v.SetConfigType("yaml")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
v.AutomaticEnv() // 自动绑定环境变量
|
||||
|
||||
// 从配置文件加载
|
||||
if configPath != "" {
|
||||
err = loadFromFile(configPath)
|
||||
if err != nil {
|
||||
v.SetConfigFile(configPath)
|
||||
if err = v.ReadInConfig(); err != nil {
|
||||
err = fmt.Errorf("加载配置文件失败: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 从环境变量覆盖
|
||||
overrideFromEnv()
|
||||
// 将配置绑定到结构体
|
||||
if err = v.Unmarshal(&config); err != nil {
|
||||
err = fmt.Errorf("解析配置失败: %w", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -71,58 +84,6 @@ func Load(configPath string) (*Config, error) {
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// 设置默认配置值
|
||||
func setDefaults() {
|
||||
config = Config{
|
||||
Server: ServerConfig{
|
||||
Port: 8080,
|
||||
ReadTimeout: 30,
|
||||
WriteTimeout: 30,
|
||||
BasePath: "",
|
||||
},
|
||||
TTS: TTSConfig{
|
||||
DefaultVoice: "zh-CN-XiaoxiaoNeural",
|
||||
DefaultRate: "0%",
|
||||
DefaultPitch: "0%",
|
||||
DefaultFormat: "audio-24khz-48kbitrate-mono-mp3",
|
||||
MaxTextLength: 5000,
|
||||
RequestTimeout: 30,
|
||||
MaxConcurrent: 10,
|
||||
SegmentThreshold: 500,
|
||||
MinSentenceLength: 200,
|
||||
MaxSentenceLength: 300,
|
||||
VoiceMapping: make(map[string]string),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 从配置文件加载配置
|
||||
func loadFromFile(path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return yaml.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
// 从环境变量中覆盖配置
|
||||
func overrideFromEnv() {
|
||||
if port := os.Getenv("TTS_SERVER_PORT"); port != "" {
|
||||
fmt.Sscanf(port, "%d", &config.Server.Port)
|
||||
}
|
||||
|
||||
if apiKey := os.Getenv("TTS_API_KEY"); apiKey != "" {
|
||||
config.TTS.APIKey = apiKey
|
||||
}
|
||||
|
||||
if region := os.Getenv("TTS_API_REGION"); region != "" {
|
||||
config.TTS.Region = region
|
||||
}
|
||||
|
||||
// 可以添加更多环境变量覆盖
|
||||
}
|
||||
|
||||
// Get 返回已加载的配置
|
||||
func Get() *Config {
|
||||
return &config
|
||||
|
||||
Reference in New Issue
Block a user