feat: 初始提交
This commit is contained in:
0
backend/app/db/__init__.py
Normal file
0
backend/app/db/__init__.py
Normal file
9
backend/app/db/base.py
Normal file
9
backend/app/db/base.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from sqlalchemy.orm import DeclarativeBase, declared_attr
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""SQLAlchemy 基类,自动根据类名生成表名。"""
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(cls) -> str: # type: ignore[override]
|
||||
return cls.__name__.lower()
|
||||
122
backend/app/db/init_db.py
Normal file
122
backend/app/db/init_db.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import logging
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.engine import URL, make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from ..core.config import settings
|
||||
from ..core.security import hash_password
|
||||
from ..models import Prompt, SystemConfig, User
|
||||
from .base import Base
|
||||
from .system_config_defaults import SYSTEM_CONFIG_DEFAULTS
|
||||
from .session import AsyncSessionLocal, engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""初始化数据库结构并确保默认管理员存在。"""
|
||||
|
||||
await _ensure_database_exists()
|
||||
|
||||
# ---- 第一步:创建所有表结构 ----
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("数据库表结构已初始化")
|
||||
|
||||
# ---- 第二步:确保管理员账号至少存在一个 ----
|
||||
async with AsyncSessionLocal() as session:
|
||||
admin_exists = await session.execute(select(User).where(User.is_admin.is_(True)))
|
||||
if not admin_exists.scalars().first():
|
||||
logger.warning("未检测到管理员账号,正在创建默认管理员 ...")
|
||||
admin_user = User(
|
||||
username=settings.admin_default_username,
|
||||
email=settings.admin_default_email,
|
||||
hashed_password=hash_password(settings.admin_default_password),
|
||||
is_admin=True,
|
||||
)
|
||||
|
||||
session.add(admin_user)
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info("默认管理员创建完成:%s", settings.admin_default_username)
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
logger.exception("默认管理员创建失败,可能是并发启动导致,请检查数据库状态")
|
||||
|
||||
# ---- 第三步:同步系统配置到数据库 ----
|
||||
for entry in SYSTEM_CONFIG_DEFAULTS:
|
||||
value = entry.value_getter(settings)
|
||||
if value is None:
|
||||
continue
|
||||
existing = await session.get(SystemConfig, entry.key)
|
||||
if existing:
|
||||
if entry.description and existing.description != entry.description:
|
||||
existing.description = entry.description
|
||||
continue
|
||||
session.add(
|
||||
SystemConfig(
|
||||
key=entry.key,
|
||||
value=value,
|
||||
description=entry.description,
|
||||
)
|
||||
)
|
||||
|
||||
await _ensure_default_prompts(session)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _ensure_database_exists() -> None:
|
||||
"""在首次连接前确认数据库存在,针对不同驱动做最小化准备工作。"""
|
||||
url = make_url(settings.sqlalchemy_database_uri)
|
||||
|
||||
if url.get_backend_name() == "sqlite":
|
||||
# SQLite 采用文件数据库,确保父目录存在即可,无需额外建库语句
|
||||
db_path = Path(url.database or "").expanduser()
|
||||
if not db_path.is_absolute():
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
db_path = (project_root / db_path).resolve()
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return
|
||||
|
||||
database = (url.database or "").strip("/")
|
||||
if not database:
|
||||
return
|
||||
|
||||
admin_url = URL.create(
|
||||
drivername=url.drivername,
|
||||
username=url.username,
|
||||
password=url.password,
|
||||
host=url.host,
|
||||
port=url.port,
|
||||
database=None,
|
||||
query=url.query,
|
||||
)
|
||||
|
||||
admin_engine = create_async_engine(
|
||||
admin_url.render_as_string(hide_password=False),
|
||||
isolation_level="AUTOCOMMIT",
|
||||
)
|
||||
async with admin_engine.begin() as conn:
|
||||
await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{database}`"))
|
||||
await admin_engine.dispose()
|
||||
|
||||
|
||||
async def _ensure_default_prompts(session: AsyncSession) -> None:
|
||||
prompts_dir = Path(__file__).resolve().parents[2] / "prompts"
|
||||
if not prompts_dir.is_dir():
|
||||
return
|
||||
|
||||
result = await session.execute(select(Prompt.name))
|
||||
existing_names = set(result.scalars().all())
|
||||
|
||||
for prompt_file in sorted(prompts_dir.glob("*.md")):
|
||||
name = prompt_file.stem
|
||||
if name in existing_names:
|
||||
continue
|
||||
content = prompt_file.read_text(encoding="utf-8")
|
||||
session.add(Prompt(name=name, content=content))
|
||||
30
backend/app/db/session.py
Normal file
30
backend/app/db/session.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from ..core.config import settings
|
||||
|
||||
# 根据不同数据库驱动调整连接池参数,确保在多数据库环境下表现稳定
|
||||
engine_kwargs = {"echo": settings.debug}
|
||||
if settings.is_sqlite_backend:
|
||||
# SQLite 场景下禁用连接池并放宽线程检查,避免多协程读写冲突
|
||||
engine_kwargs.update(
|
||||
pool_pre_ping=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=NullPool,
|
||||
)
|
||||
else:
|
||||
# MySQL 场景保持健康检查与连接复用,适用于生产环境的长连接需求
|
||||
engine_kwargs.update(pool_pre_ping=True, pool_recycle=3600)
|
||||
|
||||
engine = create_async_engine(settings.sqlalchemy_database_uri, **engine_kwargs)
|
||||
|
||||
# 统一的 Session 工厂,禁用 expire_on_commit 方便返回模型对象
|
||||
AsyncSessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI 依赖项:提供一个作用域内共享的数据库会话。"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
yield session
|
||||
110
backend/app/db/system_config_defaults.py
Normal file
110
backend/app/db/system_config_defaults.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
from ..core.config import Settings
|
||||
|
||||
|
||||
def _to_optional_str(value: Optional[object]) -> Optional[str]:
|
||||
return str(value) if value is not None else None
|
||||
|
||||
|
||||
def _bool_to_text(value: bool) -> str:
|
||||
return "true" if value else "false"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SystemConfigDefault:
|
||||
key: str
|
||||
value_getter: Callable[[Settings], Optional[str]]
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
SYSTEM_CONFIG_DEFAULTS: list[SystemConfigDefault] = [
|
||||
SystemConfigDefault(
|
||||
key="llm.api_key",
|
||||
value_getter=lambda config: config.openai_api_key,
|
||||
description="默认 LLM API Key,用于后台调用大模型。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="llm.base_url",
|
||||
value_getter=lambda config: _to_optional_str(config.openai_base_url),
|
||||
description="默认大模型 API Base URL。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="llm.model",
|
||||
value_getter=lambda config: config.openai_model_name,
|
||||
description="默认 LLM 模型名称。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.server",
|
||||
value_getter=lambda config: config.smtp_server,
|
||||
description="用于发送邮件验证码的 SMTP 服务器地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.port",
|
||||
value_getter=lambda config: _to_optional_str(config.smtp_port),
|
||||
description="SMTP 服务端口。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.username",
|
||||
value_getter=lambda config: config.smtp_username,
|
||||
description="SMTP 登录用户名。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.password",
|
||||
value_getter=lambda config: config.smtp_password,
|
||||
description="SMTP 登录密码。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.from",
|
||||
value_getter=lambda config: config.email_from,
|
||||
description="邮件显示的发件人名称或邮箱。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="auth.allow_registration",
|
||||
value_getter=lambda config: _bool_to_text(config.allow_registration),
|
||||
description="是否允许用户自助注册。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="auth.linuxdo_enabled",
|
||||
value_getter=lambda config: _bool_to_text(config.enable_linuxdo_login),
|
||||
description="是否启用 Linux.do OAuth 登录。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.client_id",
|
||||
value_getter=lambda config: config.linuxdo_client_id,
|
||||
description="Linux.do OAuth Client ID。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.client_secret",
|
||||
value_getter=lambda config: config.linuxdo_client_secret,
|
||||
description="Linux.do OAuth Client Secret。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.redirect_uri",
|
||||
value_getter=lambda config: _to_optional_str(config.linuxdo_redirect_uri),
|
||||
description="Linux.do OAuth 回调地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.auth_url",
|
||||
value_getter=lambda config: _to_optional_str(config.linuxdo_auth_url),
|
||||
description="Linux.do OAuth 授权地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.token_url",
|
||||
value_getter=lambda config: _to_optional_str(config.linuxdo_token_url),
|
||||
description="Linux.do OAuth Token 获取地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.user_info_url",
|
||||
value_getter=lambda config: _to_optional_str(config.linuxdo_user_info_url),
|
||||
description="Linux.do 用户信息接口地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="writer.chapter_versions",
|
||||
value_getter=lambda config: _to_optional_str(config.writer_chapter_versions),
|
||||
description="每次生成章节的候选版本数量。",
|
||||
),
|
||||
]
|
||||
Reference in New Issue
Block a user