feat: 初始提交
This commit is contained in:
122
backend/app/db/init_db.py
Normal file
122
backend/app/db/init_db.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import logging
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.engine import URL, make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from ..core.config import settings
|
||||
from ..core.security import hash_password
|
||||
from ..models import Prompt, SystemConfig, User
|
||||
from .base import Base
|
||||
from .system_config_defaults import SYSTEM_CONFIG_DEFAULTS
|
||||
from .session import AsyncSessionLocal, engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""初始化数据库结构并确保默认管理员存在。"""
|
||||
|
||||
await _ensure_database_exists()
|
||||
|
||||
# ---- 第一步:创建所有表结构 ----
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("数据库表结构已初始化")
|
||||
|
||||
# ---- 第二步:确保管理员账号至少存在一个 ----
|
||||
async with AsyncSessionLocal() as session:
|
||||
admin_exists = await session.execute(select(User).where(User.is_admin.is_(True)))
|
||||
if not admin_exists.scalars().first():
|
||||
logger.warning("未检测到管理员账号,正在创建默认管理员 ...")
|
||||
admin_user = User(
|
||||
username=settings.admin_default_username,
|
||||
email=settings.admin_default_email,
|
||||
hashed_password=hash_password(settings.admin_default_password),
|
||||
is_admin=True,
|
||||
)
|
||||
|
||||
session.add(admin_user)
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info("默认管理员创建完成:%s", settings.admin_default_username)
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
logger.exception("默认管理员创建失败,可能是并发启动导致,请检查数据库状态")
|
||||
|
||||
# ---- 第三步:同步系统配置到数据库 ----
|
||||
for entry in SYSTEM_CONFIG_DEFAULTS:
|
||||
value = entry.value_getter(settings)
|
||||
if value is None:
|
||||
continue
|
||||
existing = await session.get(SystemConfig, entry.key)
|
||||
if existing:
|
||||
if entry.description and existing.description != entry.description:
|
||||
existing.description = entry.description
|
||||
continue
|
||||
session.add(
|
||||
SystemConfig(
|
||||
key=entry.key,
|
||||
value=value,
|
||||
description=entry.description,
|
||||
)
|
||||
)
|
||||
|
||||
await _ensure_default_prompts(session)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _ensure_database_exists() -> None:
|
||||
"""在首次连接前确认数据库存在,针对不同驱动做最小化准备工作。"""
|
||||
url = make_url(settings.sqlalchemy_database_uri)
|
||||
|
||||
if url.get_backend_name() == "sqlite":
|
||||
# SQLite 采用文件数据库,确保父目录存在即可,无需额外建库语句
|
||||
db_path = Path(url.database or "").expanduser()
|
||||
if not db_path.is_absolute():
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
db_path = (project_root / db_path).resolve()
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return
|
||||
|
||||
database = (url.database or "").strip("/")
|
||||
if not database:
|
||||
return
|
||||
|
||||
admin_url = URL.create(
|
||||
drivername=url.drivername,
|
||||
username=url.username,
|
||||
password=url.password,
|
||||
host=url.host,
|
||||
port=url.port,
|
||||
database=None,
|
||||
query=url.query,
|
||||
)
|
||||
|
||||
admin_engine = create_async_engine(
|
||||
admin_url.render_as_string(hide_password=False),
|
||||
isolation_level="AUTOCOMMIT",
|
||||
)
|
||||
async with admin_engine.begin() as conn:
|
||||
await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{database}`"))
|
||||
await admin_engine.dispose()
|
||||
|
||||
|
||||
async def _ensure_default_prompts(session: AsyncSession) -> None:
|
||||
prompts_dir = Path(__file__).resolve().parents[2] / "prompts"
|
||||
if not prompts_dir.is_dir():
|
||||
return
|
||||
|
||||
result = await session.execute(select(Prompt.name))
|
||||
existing_names = set(result.scalars().all())
|
||||
|
||||
for prompt_file in sorted(prompts_dir.glob("*.md")):
|
||||
name = prompt_file.stem
|
||||
if name in existing_names:
|
||||
continue
|
||||
content = prompt_file.read_text(encoding="utf-8")
|
||||
session.add(Prompt(name=name, content=content))
|
||||
Reference in New Issue
Block a user