feat: 初始提交

This commit is contained in:
anonymous
2025-10-21 09:38:26 +08:00
committed by t59688
parent 2965b8e28f
commit c9fc816fab
175 changed files with 23968 additions and 87 deletions

View File

9
backend/app/db/base.py Normal file
View File

@@ -0,0 +1,9 @@
from sqlalchemy.orm import DeclarativeBase, declared_attr
class Base(DeclarativeBase):
"""SQLAlchemy 基类,自动根据类名生成表名。"""
@declared_attr.directive
def __tablename__(cls) -> str: # type: ignore[override]
return cls.__name__.lower()

122
backend/app/db/init_db.py Normal file
View File

@@ -0,0 +1,122 @@
import logging
from pathlib import Path
from sqlalchemy import select, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.engine import URL, make_url
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from ..core.config import settings
from ..core.security import hash_password
from ..models import Prompt, SystemConfig, User
from .base import Base
from .system_config_defaults import SYSTEM_CONFIG_DEFAULTS
from .session import AsyncSessionLocal, engine
logger = logging.getLogger(__name__)
async def init_db() -> None:
"""初始化数据库结构并确保默认管理员存在。"""
await _ensure_database_exists()
# ---- 第一步:创建所有表结构 ----
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("数据库表结构已初始化")
# ---- 第二步:确保管理员账号至少存在一个 ----
async with AsyncSessionLocal() as session:
admin_exists = await session.execute(select(User).where(User.is_admin.is_(True)))
if not admin_exists.scalars().first():
logger.warning("未检测到管理员账号,正在创建默认管理员 ...")
admin_user = User(
username=settings.admin_default_username,
email=settings.admin_default_email,
hashed_password=hash_password(settings.admin_default_password),
is_admin=True,
)
session.add(admin_user)
try:
await session.commit()
logger.info("默认管理员创建完成:%s", settings.admin_default_username)
except IntegrityError:
await session.rollback()
logger.exception("默认管理员创建失败,可能是并发启动导致,请检查数据库状态")
# ---- 第三步:同步系统配置到数据库 ----
for entry in SYSTEM_CONFIG_DEFAULTS:
value = entry.value_getter(settings)
if value is None:
continue
existing = await session.get(SystemConfig, entry.key)
if existing:
if entry.description and existing.description != entry.description:
existing.description = entry.description
continue
session.add(
SystemConfig(
key=entry.key,
value=value,
description=entry.description,
)
)
await _ensure_default_prompts(session)
await session.commit()
async def _ensure_database_exists() -> None:
"""在首次连接前确认数据库存在,针对不同驱动做最小化准备工作。"""
url = make_url(settings.sqlalchemy_database_uri)
if url.get_backend_name() == "sqlite":
# SQLite 采用文件数据库,确保父目录存在即可,无需额外建库语句
db_path = Path(url.database or "").expanduser()
if not db_path.is_absolute():
project_root = Path(__file__).resolve().parents[2]
db_path = (project_root / db_path).resolve()
db_path.parent.mkdir(parents=True, exist_ok=True)
return
database = (url.database or "").strip("/")
if not database:
return
admin_url = URL.create(
drivername=url.drivername,
username=url.username,
password=url.password,
host=url.host,
port=url.port,
database=None,
query=url.query,
)
admin_engine = create_async_engine(
admin_url.render_as_string(hide_password=False),
isolation_level="AUTOCOMMIT",
)
async with admin_engine.begin() as conn:
await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{database}`"))
await admin_engine.dispose()
async def _ensure_default_prompts(session: AsyncSession) -> None:
prompts_dir = Path(__file__).resolve().parents[2] / "prompts"
if not prompts_dir.is_dir():
return
result = await session.execute(select(Prompt.name))
existing_names = set(result.scalars().all())
for prompt_file in sorted(prompts_dir.glob("*.md")):
name = prompt_file.stem
if name in existing_names:
continue
content = prompt_file.read_text(encoding="utf-8")
session.add(Prompt(name=name, content=content))

30
backend/app/db/session.py Normal file
View File

@@ -0,0 +1,30 @@
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from ..core.config import settings
# 根据不同数据库驱动调整连接池参数,确保在多数据库环境下表现稳定
engine_kwargs = {"echo": settings.debug}
if settings.is_sqlite_backend:
# SQLite 场景下禁用连接池并放宽线程检查,避免多协程读写冲突
engine_kwargs.update(
pool_pre_ping=False,
connect_args={"check_same_thread": False},
poolclass=NullPool,
)
else:
# MySQL 场景保持健康检查与连接复用,适用于生产环境的长连接需求
engine_kwargs.update(pool_pre_ping=True, pool_recycle=3600)
engine = create_async_engine(settings.sqlalchemy_database_uri, **engine_kwargs)
# 统一的 Session 工厂,禁用 expire_on_commit 方便返回模型对象
AsyncSessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI 依赖项:提供一个作用域内共享的数据库会话。"""
async with AsyncSessionLocal() as session:
yield session

View File

@@ -0,0 +1,110 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Optional
from ..core.config import Settings
def _to_optional_str(value: Optional[object]) -> Optional[str]:
return str(value) if value is not None else None
def _bool_to_text(value: bool) -> str:
return "true" if value else "false"
@dataclass(frozen=True)
class SystemConfigDefault:
key: str
value_getter: Callable[[Settings], Optional[str]]
description: Optional[str] = None
SYSTEM_CONFIG_DEFAULTS: list[SystemConfigDefault] = [
SystemConfigDefault(
key="llm.api_key",
value_getter=lambda config: config.openai_api_key,
description="默认 LLM API Key用于后台调用大模型。",
),
SystemConfigDefault(
key="llm.base_url",
value_getter=lambda config: _to_optional_str(config.openai_base_url),
description="默认大模型 API Base URL。",
),
SystemConfigDefault(
key="llm.model",
value_getter=lambda config: config.openai_model_name,
description="默认 LLM 模型名称。",
),
SystemConfigDefault(
key="smtp.server",
value_getter=lambda config: config.smtp_server,
description="用于发送邮件验证码的 SMTP 服务器地址。",
),
SystemConfigDefault(
key="smtp.port",
value_getter=lambda config: _to_optional_str(config.smtp_port),
description="SMTP 服务端口。",
),
SystemConfigDefault(
key="smtp.username",
value_getter=lambda config: config.smtp_username,
description="SMTP 登录用户名。",
),
SystemConfigDefault(
key="smtp.password",
value_getter=lambda config: config.smtp_password,
description="SMTP 登录密码。",
),
SystemConfigDefault(
key="smtp.from",
value_getter=lambda config: config.email_from,
description="邮件显示的发件人名称或邮箱。",
),
SystemConfigDefault(
key="auth.allow_registration",
value_getter=lambda config: _bool_to_text(config.allow_registration),
description="是否允许用户自助注册。",
),
SystemConfigDefault(
key="auth.linuxdo_enabled",
value_getter=lambda config: _bool_to_text(config.enable_linuxdo_login),
description="是否启用 Linux.do OAuth 登录。",
),
SystemConfigDefault(
key="linuxdo.client_id",
value_getter=lambda config: config.linuxdo_client_id,
description="Linux.do OAuth Client ID。",
),
SystemConfigDefault(
key="linuxdo.client_secret",
value_getter=lambda config: config.linuxdo_client_secret,
description="Linux.do OAuth Client Secret。",
),
SystemConfigDefault(
key="linuxdo.redirect_uri",
value_getter=lambda config: _to_optional_str(config.linuxdo_redirect_uri),
description="Linux.do OAuth 回调地址。",
),
SystemConfigDefault(
key="linuxdo.auth_url",
value_getter=lambda config: _to_optional_str(config.linuxdo_auth_url),
description="Linux.do OAuth 授权地址。",
),
SystemConfigDefault(
key="linuxdo.token_url",
value_getter=lambda config: _to_optional_str(config.linuxdo_token_url),
description="Linux.do OAuth Token 获取地址。",
),
SystemConfigDefault(
key="linuxdo.user_info_url",
value_getter=lambda config: _to_optional_str(config.linuxdo_user_info_url),
description="Linux.do 用户信息接口地址。",
),
SystemConfigDefault(
key="writer.chapter_versions",
value_getter=lambda config: _to_optional_str(config.writer_chapter_versions),
description="每次生成章节的候选版本数量。",
),
]