from functools import lru_cache from pathlib import Path from typing import Optional from pydantic import AliasChoices, AnyUrl, Field, HttpUrl, validator from pydantic_settings import BaseSettings, SettingsConfigDict from sqlalchemy.engine import URL, make_url class Settings(BaseSettings): """应用全局配置,所有可调参数集中于此,统一加载自环境变量。""" # -------------------- 基础应用配置 -------------------- app_name: str = Field(default="AI Novel Generator API", description="FastAPI 文档标题") environment: str = Field(default="development", description="当前环境标识") debug: bool = Field(default=True, description="是否开启调试模式") allow_registration: bool = Field( default=True, env="ALLOW_USER_REGISTRATION", description="是否允许用户自助注册", ) logging_level: str = Field( default="INFO", env="LOGGING_LEVEL", description="应用日志级别", ) enable_linuxdo_login: bool = Field( default=False, env="ENABLE_LINUXDO_LOGIN", description="是否启用 Linux.do OAuth 登录", ) # -------------------- 安全相关配置 -------------------- secret_key: str = Field(..., env="SECRET_KEY", description="JWT 加密密钥") jwt_algorithm: str = Field(default="HS256", env="JWT_ALGORITHM", description="JWT 加密算法") access_token_expire_minutes: int = Field( default=60 * 24 * 7, env="ACCESS_TOKEN_EXPIRE_MINUTES", description="访问令牌过期时间,单位分钟" ) # -------------------- 数据库配置 -------------------- database_url: Optional[str] = Field( default=None, env="DATABASE_URL", description="完整的数据库连接串,填入后覆盖下方数据库配置" ) db_provider: str = Field( default="mysql", env="DB_PROVIDER", description="数据库类型,仅支持 mysql 或 sqlite" ) mysql_host: str = Field(default="localhost", env="MYSQL_HOST", description="MySQL 主机名") mysql_port: int = Field(default=3306, env="MYSQL_PORT", description="MySQL 端口") mysql_user: str = Field(default="root", env="MYSQL_USER", description="MySQL 用户名") mysql_password: str = Field(default="", env="MYSQL_PASSWORD", description="MySQL 密码") mysql_database: str = Field(default="arboris", env="MYSQL_DATABASE", description="MySQL 数据库名称") # -------------------- 管理员初始化配置 -------------------- admin_default_username: str = Field(default="admin", env="ADMIN_DEFAULT_USERNAME", description="默认管理员用户名") admin_default_password: str = Field(default="ChangeMe123!", env="ADMIN_DEFAULT_PASSWORD", description="默认管理员密码") admin_default_email: Optional[str] = Field(default=None, env="ADMIN_DEFAULT_EMAIL", description="默认管理员邮箱") # -------------------- LLM 相关配置 -------------------- openai_api_key: Optional[str] = Field(default=None, env="OPENAI_API_KEY", description="默认的 LLM API Key") openai_base_url: Optional[HttpUrl] = Field( default=None, env="OPENAI_API_BASE_URL", validation_alias=AliasChoices("OPENAI_API_BASE_URL", "OPENAI_BASE_URL"), description="LLM API Base URL", ) openai_model_name: str = Field(default="gpt-4o-mini", env="OPENAI_MODEL_NAME", description="默认 LLM 模型名称") writer_chapter_versions: int = Field( default=2, ge=1, env="WRITER_CHAPTER_VERSION_COUNT", validation_alias=AliasChoices("WRITER_CHAPTER_VERSION_COUNT", "WRITER_CHAPTER_VERSIONS"), description="每次生成章节的候选版本数量", ) embedding_provider: str = Field( default="openai", env="EMBEDDING_PROVIDER", description="嵌入模型提供方,支持 openai 或 ollama", ) embedding_base_url: Optional[AnyUrl] = Field( default=None, env="EMBEDDING_BASE_URL", description="嵌入模型使用的 Base URL", ) embedding_api_key: Optional[str] = Field( default=None, env="EMBEDDING_API_KEY", description="嵌入模型专用 API Key", ) embedding_model: str = Field( default="text-embedding-3-large", env="EMBEDDING_MODEL", validation_alias=AliasChoices("EMBEDDING_MODEL", "VECTOR_EMBEDDING_MODEL"), description="默认的嵌入模型名称", ) embedding_model_vector_size: Optional[int] = Field( default=None, env="EMBEDDING_MODEL_VECTOR_SIZE", description="嵌入向量维度,未配置时将自动检测", ) ollama_embedding_base_url: Optional[AnyUrl] = Field( default=None, env="OLLAMA_EMBEDDING_BASE_URL", description="Ollama 嵌入模型服务地址", ) ollama_embedding_model: str = Field( default="nomic-embed-text:latest", env="OLLAMA_EMBEDDING_MODEL", description="Ollama 嵌入模型名称", ) vector_db_url: Optional[str] = Field( default=None, env="VECTOR_DB_URL", description="libsql 向量库连接地址", ) vector_db_auth_token: Optional[str] = Field( default=None, env="VECTOR_DB_AUTH_TOKEN", description="libsql 访问令牌", ) vector_top_k_chunks: int = Field( default=5, ge=0, env="VECTOR_TOP_K_CHUNKS", description="剧情 chunk 检索条数", ) vector_top_k_summaries: int = Field( default=3, ge=0, env="VECTOR_TOP_K_SUMMARIES", description="章节摘要检索条数", ) vector_chunk_size: int = Field( default=480, ge=128, env="VECTOR_CHUNK_SIZE", description="章节分块的目标字数", ) vector_chunk_overlap: int = Field( default=120, ge=0, env="VECTOR_CHUNK_OVERLAP", description="章节分块重叠字数", ) # -------------------- Linux.do OAuth 配置 -------------------- linuxdo_client_id: Optional[str] = Field(default=None, env="LINUXDO_CLIENT_ID", description="Linux.do OAuth Client ID") linuxdo_client_secret: Optional[str] = Field( default=None, env="LINUXDO_CLIENT_SECRET", description="Linux.do OAuth Client Secret" ) linuxdo_redirect_uri: Optional[HttpUrl] = Field( default=None, env="LINUXDO_REDIRECT_URI", description="Linux.do OAuth 回调地址" ) linuxdo_auth_url: Optional[HttpUrl] = Field( default=None, env="LINUXDO_AUTH_URL", description="Linux.do OAuth 授权地址" ) linuxdo_token_url: Optional[HttpUrl] = Field( default=None, env="LINUXDO_TOKEN_URL", description="Linux.do OAuth Token 获取地址" ) linuxdo_user_info_url: Optional[HttpUrl] = Field( default=None, env="LINUXDO_USER_INFO_URL", description="Linux.do 用户信息接口地址" ) # -------------------- 邮件配置 -------------------- smtp_server: Optional[str] = Field(default=None, env="SMTP_SERVER", description="SMTP 服务地址") smtp_port: int = Field(default=587, env="SMTP_PORT", description="SMTP 服务端口") smtp_username: Optional[str] = Field(default=None, env="SMTP_USERNAME", description="SMTP 登录用户名") smtp_password: Optional[str] = Field(default=None, env="SMTP_PASSWORD", description="SMTP 登录密码") email_from: Optional[str] = Field(default=None, env="EMAIL_FROM", description="邮件发送方显示名或邮箱") model_config = SettingsConfigDict( env_file=("new-backend/.env", ".env", "backend/.env"), env_file_encoding="utf-8", extra="ignore" ) @validator("database_url", pre=True, always=True) def _normalize_database_url(cls, value: Optional[str]) -> Optional[str]: """当环境变量中提供 DATABASE_URL 时,原样返回,便于自定义。""" return value.strip() if isinstance(value, str) and value.strip() else value @validator("db_provider", pre=True) def _normalize_db_provider(cls, value: Optional[str]) -> str: """统一数据库类型大小写,并限制为受支持的驱动。""" candidate = (value or "mysql").strip().lower() if candidate not in {"mysql", "sqlite"}: raise ValueError("DB_PROVIDER 仅支持 mysql 或 sqlite") return candidate @validator("embedding_provider", pre=True) def _normalize_embedding_provider(cls, value: Optional[str]) -> str: """限制嵌入模型提供方的取值范围。""" candidate = (value or "openai").strip().lower() if candidate not in {"openai", "ollama"}: raise ValueError("EMBEDDING_PROVIDER 仅支持 openai 或 ollama") return candidate @validator("logging_level", pre=True) def _normalize_logging_level(cls, value: Optional[str]) -> str: """规范日志级别配置。""" candidate = (value or "INFO").strip().upper() valid_levels = {"CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"} if candidate not in valid_levels: raise ValueError("LOGGING_LEVEL 仅支持 CRITICAL/ERROR/WARNING/INFO/DEBUG/NOTSET") return candidate @property def sqlalchemy_database_uri(self) -> str: """生成 SQLAlchemy 兼容的异步连接串,数据库类型由 DB_PROVIDER 控制。""" if self.database_url: url = make_url(self.database_url) database = (url.database or "").strip("/") normalized = URL.create( drivername=url.drivername, username=url.username, password=url.password, host=url.host, port=url.port, database=database or None, query=url.query, ) return normalized.render_as_string(hide_password=False) if self.db_provider == "sqlite": # SQLite 固定使用 storage/arboris.db,并转换为绝对路径以避免运行目录差异 project_root = Path(__file__).resolve().parents[2] db_path = (project_root / "storage" / "arboris.db").resolve() return f"sqlite+aiosqlite:///{db_path}" # MySQL 分支:统一对密码进行 URL 编码,避免特殊字符破坏连接串 from urllib.parse import quote_plus encoded_password = quote_plus(self.mysql_password) database = (self.mysql_database or "").strip("/") return ( f"mysql+asyncmy://{self.mysql_user}:{encoded_password}" f"@{self.mysql_host}:{self.mysql_port}/{database}" ) @property def is_sqlite_backend(self) -> bool: """辅助属性:判断当前连接串是否指向 SQLite,用于差异化初始化流程。""" return make_url(self.sqlalchemy_database_uri).get_backend_name() == "sqlite" @property def vector_store_enabled(self) -> bool: """是否已经配置向量库,用于在业务逻辑中快速判断。""" return bool(self.vector_db_url) @lru_cache def get_settings() -> Settings: """使用 LRU 缓存确保配置只初始化一次,减少 IO 与解析开销。""" return Settings() settings = get_settings()