feat: 初始提交

This commit is contained in:
anonymous
2025-10-21 09:38:26 +08:00
committed by t59688
parent 2965b8e28f
commit c9fc816fab
175 changed files with 23968 additions and 87 deletions

0
backend/app/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,12 @@
from fastapi import APIRouter
from . import admin, auth, llm_config, novels, updates, writer
api_router = APIRouter()
api_router.include_router(auth.router)
api_router.include_router(novels.router)
api_router.include_router(writer.router)
api_router.include_router(admin.router)
api_router.include_router(updates.router)
api_router.include_router(llm_config.router)

View File

@@ -0,0 +1,340 @@
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from ...core.dependencies import get_current_admin
from ...db.session import get_session
from ...models import NovelProject, UsageMetric, User
from ...schemas.admin import (
AdminNovelSummary,
DailyRequestLimit,
Statistics,
UpdateLogCreate,
UpdateLogRead,
UpdateLogUpdate,
)
from ...schemas.config import SystemConfigCreate, SystemConfigRead, SystemConfigUpdate
from ...schemas.prompt import PromptCreate, PromptRead, PromptUpdate
from ...schemas.novel import (
Chapter as ChapterSchema,
NovelProject as NovelProjectSchema,
NovelSectionResponse,
NovelSectionType,
)
from ...schemas.user import PasswordChangeRequest, User as UserSchema
from ...services.auth_service import AuthService
from ...services.admin_setting_service import AdminSettingService
from ...services.config_service import ConfigService
from ...services.novel_service import NovelService
from ...services.prompt_service import PromptService
from ...services.update_log_service import UpdateLogService
from ...services.user_service import UserService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/admin", tags=["Admin"])
def get_prompt_service(session: AsyncSession = Depends(get_session)) -> PromptService:
return PromptService(session)
def get_update_log_service(session: AsyncSession = Depends(get_session)) -> UpdateLogService:
return UpdateLogService(session)
def get_admin_setting_service(session: AsyncSession = Depends(get_session)) -> AdminSettingService:
return AdminSettingService(session)
def get_config_service(session: AsyncSession = Depends(get_session)) -> ConfigService:
return ConfigService(session)
def get_novel_service(session: AsyncSession = Depends(get_session)) -> NovelService:
return NovelService(session)
def get_user_service(session: AsyncSession = Depends(get_session)) -> UserService:
return UserService(session)
def get_auth_service(session: AsyncSession = Depends(get_session)) -> AuthService:
return AuthService(session)
@router.get("/stats", response_model=Statistics)
async def read_statistics(
session: AsyncSession = Depends(get_session),
_: None = Depends(get_current_admin),
) -> Statistics:
novel_count = await session.scalar(select(func.count(NovelProject.id))) or 0
user_count = await session.scalar(select(func.count(User.id))) or 0
usage = await session.get(UsageMetric, "api_request_count")
api_request_count = usage.value if usage else 0
logger.info("管理员获取统计数据:小说=%s,用户=%s,请求=%s", novel_count, user_count, api_request_count)
return Statistics(novel_count=novel_count, user_count=user_count, api_request_count=api_request_count)
@router.get("/users", response_model=List[UserSchema])
async def list_users(
service: UserService = Depends(get_user_service),
_: None = Depends(get_current_admin),
) -> List[UserSchema]:
users = await service.list_users()
logger.info("管理员请求用户列表,共 %s", len(users))
return [UserSchema.model_validate(user) for user in users]
@router.get("/novel-projects", response_model=List[AdminNovelSummary])
async def list_novel_projects(
service: NovelService = Depends(get_novel_service),
_: None = Depends(get_current_admin),
) -> List[AdminNovelSummary]:
projects = await service.list_projects_for_admin()
logger.info("管理员查看项目列表,共 %s", len(projects))
return projects
@router.get("/novel-projects/{project_id}", response_model=NovelProjectSchema)
async def get_novel_project(
project_id: str,
service: NovelService = Depends(get_novel_service),
_: None = Depends(get_current_admin),
) -> NovelProjectSchema:
logger.info("管理员查看项目详情:%s", project_id)
return await service.get_project_schema_for_admin(project_id)
@router.get("/novel-projects/{project_id}/sections/{section}", response_model=NovelSectionResponse)
async def get_novel_project_section(
project_id: str,
section: NovelSectionType,
service: NovelService = Depends(get_novel_service),
_: None = Depends(get_current_admin),
) -> NovelSectionResponse:
logger.info("管理员查看项目 %s%s 区段", project_id, section)
return await service.get_section_data_for_admin(project_id, section)
@router.get("/novel-projects/{project_id}/chapters/{chapter_number}", response_model=ChapterSchema)
async def get_novel_project_chapter(
project_id: str,
chapter_number: int,
service: NovelService = Depends(get_novel_service),
_: None = Depends(get_current_admin),
) -> ChapterSchema:
logger.info("管理员查看项目 %s%s 章详情", project_id, chapter_number)
return await service.get_chapter_schema_for_admin(project_id, chapter_number)
@router.get("/prompts", response_model=List[PromptRead])
async def list_prompts(
service: PromptService = Depends(get_prompt_service),
_: None = Depends(get_current_admin),
) -> List[PromptRead]:
prompts = await service.list_prompts()
logger.info("管理员请求提示词列表,共 %s", len(prompts))
return prompts
@router.post("/prompts", response_model=PromptRead, status_code=status.HTTP_201_CREATED)
async def create_prompt(
payload: PromptCreate,
service: PromptService = Depends(get_prompt_service),
_: None = Depends(get_current_admin),
) -> PromptRead:
prompt = await service.create_prompt(payload)
logger.info("管理员创建提示词:%s", prompt.id)
return prompt
@router.get("/prompts/{prompt_id}", response_model=PromptRead)
async def get_prompt(
prompt_id: int,
service: PromptService = Depends(get_prompt_service),
_: None = Depends(get_current_admin),
) -> PromptRead:
prompt = await service.get_prompt_by_id(prompt_id)
if not prompt:
logger.warning("提示词 %s 不存在", prompt_id)
raise HTTPException(status_code=404, detail="提示词不存在")
logger.info("管理员获取提示词:%s", prompt_id)
return prompt
@router.patch("/prompts/{prompt_id}", response_model=PromptRead)
async def update_prompt(
prompt_id: int,
payload: PromptUpdate,
service: PromptService = Depends(get_prompt_service),
_: None = Depends(get_current_admin),
) -> PromptRead:
result = await service.update_prompt(prompt_id, payload)
if not result:
logger.warning("提示词 %s 不存在,无法更新", prompt_id)
raise HTTPException(status_code=404, detail="提示词不存在")
logger.info("管理员更新提示词:%s", prompt_id)
return result
@router.delete("/prompts/{prompt_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_prompt(
prompt_id: int,
service: PromptService = Depends(get_prompt_service),
_: None = Depends(get_current_admin),
) -> None:
deleted = await service.delete_prompt(prompt_id)
if not deleted:
logger.warning("提示词 %s 不存在,无法删除", prompt_id)
raise HTTPException(status_code=404, detail="提示词不存在")
logger.info("管理员删除提示词:%s", prompt_id)
@router.get("/update-logs", response_model=List[UpdateLogRead])
async def list_update_logs(
service: UpdateLogService = Depends(get_update_log_service),
_: None = Depends(get_current_admin),
) -> List[UpdateLogRead]:
logs = await service.list_logs()
logger.info("管理员查看更新日志列表,共 %s", len(logs))
return [UpdateLogRead.model_validate(log) for log in logs]
@router.post("/update-logs", response_model=UpdateLogRead, status_code=status.HTTP_201_CREATED)
async def create_update_log(
payload: UpdateLogCreate,
service: UpdateLogService = Depends(get_update_log_service),
current_admin=Depends(get_current_admin),
) -> UpdateLogRead:
log = await service.create_log(
payload.content,
creator=current_admin.username,
is_pinned=payload.is_pinned or False,
)
logger.info("管理员 %s 创建更新日志:%s", current_admin.username, log.id)
return UpdateLogRead.model_validate(log)
@router.delete("/update-logs/{log_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_update_log(
log_id: int,
service: UpdateLogService = Depends(get_update_log_service),
_: None = Depends(get_current_admin),
) -> None:
await service.delete_log(log_id)
logger.info("管理员删除更新日志:%s", log_id)
@router.patch("/update-logs/{log_id}", response_model=UpdateLogRead)
async def update_update_log(
log_id: int,
payload: UpdateLogUpdate,
service: UpdateLogService = Depends(get_update_log_service),
_: None = Depends(get_current_admin),
) -> UpdateLogRead:
log = await service.update_log(
log_id,
content=payload.content,
is_pinned=payload.is_pinned,
)
logger.info("管理员更新日志 %s", log_id)
return UpdateLogRead.model_validate(log)
@router.get("/settings/daily-request-limit", response_model=DailyRequestLimit)
async def get_daily_limit(
service: AdminSettingService = Depends(get_admin_setting_service),
_: None = Depends(get_current_admin),
) -> DailyRequestLimit:
value = await service.get("daily_request_limit", "100")
logger.info("管理员查询每日请求上限:%s", value)
return DailyRequestLimit(limit=int(value or 100))
@router.put("/settings/daily-request-limit", response_model=DailyRequestLimit)
async def update_daily_limit(
payload: DailyRequestLimit,
service: AdminSettingService = Depends(get_admin_setting_service),
_: None = Depends(get_current_admin),
) -> DailyRequestLimit:
await service.set("daily_request_limit", str(payload.limit))
logger.info("管理员设置每日请求上限为 %s", payload.limit)
return payload
@router.get("/system-configs", response_model=List[SystemConfigRead])
async def list_system_configs(
service: ConfigService = Depends(get_config_service),
_: None = Depends(get_current_admin),
) -> List[SystemConfigRead]:
configs = await service.list_configs()
logger.info("管理员获取系统配置,共 %s", len(configs))
return configs
@router.get("/system-configs/{key}", response_model=SystemConfigRead)
async def get_system_config(
key: str,
service: ConfigService = Depends(get_config_service),
_: None = Depends(get_current_admin),
) -> SystemConfigRead:
config = await service.get_config(key)
if not config:
logger.warning("系统配置 %s 不存在", key)
raise HTTPException(status_code=404, detail="配置项不存在")
logger.info("管理员查询系统配置:%s", key)
return config
@router.put("/system-configs/{key}", response_model=SystemConfigRead)
async def upsert_system_config(
key: str,
payload: SystemConfigCreate,
service: ConfigService = Depends(get_config_service),
_: None = Depends(get_current_admin),
) -> SystemConfigRead:
logger.info("管理员写入系统配置:%s", key)
return await service.upsert_config(
SystemConfigCreate(key=key, value=payload.value, description=payload.description)
)
@router.patch("/system-configs/{key}", response_model=SystemConfigRead)
async def patch_system_config(
key: str,
payload: SystemConfigUpdate,
service: ConfigService = Depends(get_config_service),
_: None = Depends(get_current_admin),
) -> SystemConfigRead:
config = await service.patch_config(key, payload)
if not config:
logger.warning("系统配置 %s 不存在,无法更新", key)
raise HTTPException(status_code=404, detail="配置项不存在")
logger.info("管理员部分更新系统配置:%s", key)
return config
@router.delete("/system-configs/{key}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_system_config(
key: str,
service: ConfigService = Depends(get_config_service),
_: None = Depends(get_current_admin),
) -> None:
deleted = await service.remove_config(key)
if not deleted:
logger.warning("系统配置 %s 不存在,无法删除", key)
raise HTTPException(status_code=404, detail="配置项不存在")
logger.info("管理员删除系统配置:%s", key)
@router.post("/password", status_code=status.HTTP_204_NO_CONTENT)
async def change_password(
payload: PasswordChangeRequest,
current_admin=Depends(get_current_admin),
service: AuthService = Depends(get_auth_service),
) -> None:
await service.change_password(current_admin.username, payload.old_password, payload.new_password)
logger.info("管理员 %s 修改密码", current_admin.username)

View File

@@ -0,0 +1,106 @@
import logging
from datetime import timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from ...core.config import settings
from ...core.dependencies import get_current_user
from ...db.session import get_session
from ...schemas.user import AuthOptions, Token, User, UserInDB, UserRegistration
from ...services.auth_service import AuthService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/auth", tags=["Authentication"])
def get_auth_service(session: AsyncSession = Depends(get_session)) -> AuthService:
return AuthService(session)
@router.post("/send-code", status_code=204)
async def send_verification_code(email: str, service: AuthService = Depends(get_auth_service)):
await service.send_verification_code(email)
logger.info("%s 发送验证码", email)
@router.get("/options", response_model=AuthOptions)
async def read_auth_options(service: AuthService = Depends(get_auth_service)):
"""读取认证功能开关,供前端动态渲染。"""
options = await service.get_auth_options()
return options
@router.post("/users", response_model=User, status_code=status.HTTP_201_CREATED)
async def register_user(payload: UserRegistration, service: AuthService = Depends(get_auth_service)):
user = await service.register_user(payload)
logger.info("注册新用户:%s", user.username)
return User.model_validate(user)
@router.post("/token", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), service: AuthService = Depends(get_auth_service)):
user = await service.authenticate_user(form_data.username, form_data.password)
must_change_password = service.requires_password_reset(user)
token = await service.create_access_token(user, must_change_password=must_change_password)
logger.info("用户 %s 登录成功,需改密=%s", form_data.username, must_change_password)
return token
@router.get("/users/me", response_model=User)
async def read_current_user(current_user: UserInDB = Depends(get_current_user)):
logger.debug("读取当前用户:%s", current_user.username)
return current_user
@router.get("/linuxdo/login")
async def login_with_linuxdo(service: AuthService = Depends(get_auth_service)):
if not await service.is_linuxdo_login_enabled():
logger.warning("Linux.do 登录未启用")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未启用 Linux.do 登录")
client_id = await service.get_config_value("linuxdo.client_id")
redirect_uri = await service.get_config_value("linuxdo.redirect_uri")
auth_url = await service.get_config_value("linuxdo.auth_url")
if not all([client_id, redirect_uri, auth_url]):
logger.error("Linux.do OAuth 参数未配置完整")
raise HTTPException(status_code=500, detail="未配置 Linux.do OAuth 参数")
params = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"scope": "user",
}
query = "&".join(f"{k}={v}" for k, v in params.items())
logger.info("跳转 Linux.do 授权client_id=%s", client_id)
return RedirectResponse(url=f"{auth_url}?{query}")
@router.get("/linuxdo/register", response_class=HTMLResponse)
async def register_with_linuxdo(code: str, service: AuthService = Depends(get_auth_service)):
token = await service.handle_linuxdo_callback(code)
logger.info("Linux.do 授权回调成功")
token_json = token.model_dump_json()
html_content = f"""<!DOCTYPE html>
<html lang=\"zh-CN\">
<head><meta charset=\"UTF-8\"><title>正在跳转</title></head>
<body>
<p>正在跳转,请稍候...</p>
<script>
(function() {{
const token = JSON.parse('{token_json}');
try {{
window.localStorage.setItem('token', token.access_token);
}} catch (err) {{
console.error('无法写入本地存储', err);
}}
window.location.replace('/');
}})();
</script>
</body>
</html>"""
return HTMLResponse(content=html_content)

View File

@@ -0,0 +1,54 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from ...core.dependencies import get_current_user
from ...db.session import get_session
from ...schemas.llm_config import LLMConfigCreate, LLMConfigRead
from ...schemas.user import UserInDB
from ...services.llm_config_service import LLMConfigService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/llm-config", tags=["LLM Configuration"])
def get_llm_config_service(session: AsyncSession = Depends(get_session)) -> LLMConfigService:
return LLMConfigService(session)
@router.get("", response_model=LLMConfigRead)
async def read_llm_config(
service: LLMConfigService = Depends(get_llm_config_service),
current_user: UserInDB = Depends(get_current_user),
) -> LLMConfigRead:
config = await service.get_config(current_user.id)
if not config:
logger.warning("用户 %s 尚未设置 LLM 配置", current_user.id)
raise HTTPException(status_code=404, detail="尚未设置自定义配置")
logger.info("用户 %s 获取 LLM 配置", current_user.id)
return config
@router.put("", response_model=LLMConfigRead)
async def upsert_llm_config(
payload: LLMConfigCreate,
service: LLMConfigService = Depends(get_llm_config_service),
current_user: UserInDB = Depends(get_current_user),
) -> LLMConfigRead:
logger.info("用户 %s 更新 LLM 配置", current_user.id)
return await service.upsert_config(current_user.id, payload)
@router.delete("", status_code=status.HTTP_204_NO_CONTENT)
async def delete_llm_config(
service: LLMConfigService = Depends(get_llm_config_service),
current_user: UserInDB = Depends(get_current_user),
) -> None:
deleted = await service.delete_config(current_user.id)
if not deleted:
logger.warning("用户 %s 删除 LLM 配置失败,未找到记录", current_user.id)
raise HTTPException(status_code=404, detail="未找到配置")
logger.info("用户 %s 删除 LLM 配置", current_user.id)

View File

@@ -0,0 +1,301 @@
import json
import logging
from typing import Dict, List
from fastapi import APIRouter, Body, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from ...core.dependencies import get_current_user
from ...db.session import get_session
from ...schemas.novel import (
Blueprint,
BlueprintGenerationResponse,
BlueprintPatch,
Chapter as ChapterSchema,
ConverseRequest,
ConverseResponse,
NovelProject as NovelProjectSchema,
NovelProjectSummary,
NovelSectionResponse,
NovelSectionType,
)
from ...schemas.user import UserInDB
from ...services.llm_service import LLMService
from ...services.novel_service import NovelService
from ...services.prompt_service import PromptService
from ...utils.json_utils import remove_think_tags, unwrap_markdown_json
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/novels", tags=["Novels"])
JSON_RESPONSE_INSTRUCTION = """
IMPORTANT: 你的回复必须是合法的 JSON 对象,并严格包含以下字段:
{
"ai_message": "string",
"ui_control": {
"type": "single_choice | text_input | info_display",
"options": [
{"id": "option_1", "label": "string"}
],
"placeholder": "string"
},
"conversation_state": {},
"is_complete": false
}
不要输出额外的文本或解释。
"""
def _ensure_prompt(prompt: str | None, name: str) -> str:
if not prompt:
raise HTTPException(status_code=500, detail=f"未配置名为 {name} 的提示词,请联系管理员")
return prompt
@router.post("", response_model=NovelProjectSchema, status_code=status.HTTP_201_CREATED)
async def create_novel(
title: str = Body(...),
initial_prompt: str = Body(...),
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
"""为当前用户创建一个新的小说项目。"""
novel_service = NovelService(session)
project = await novel_service.create_project(current_user.id, title, initial_prompt)
logger.info("用户 %s 创建项目 %s", current_user.id, project.id)
return await novel_service.get_project_schema(project.id, current_user.id)
@router.get("", response_model=List[NovelProjectSummary])
async def list_novels(
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> List[NovelProjectSummary]:
"""列出用户的全部小说项目摘要信息。"""
novel_service = NovelService(session)
projects = await novel_service.list_projects_for_user(current_user.id)
logger.info("用户 %s 获取项目列表,共 %s", current_user.id, len(projects))
return projects
@router.get("/{project_id}", response_model=NovelProjectSchema)
async def get_novel(
project_id: str,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
novel_service = NovelService(session)
logger.info("用户 %s 查询项目 %s", current_user.id, project_id)
return await novel_service.get_project_schema(project_id, current_user.id)
@router.get("/{project_id}/sections/{section}", response_model=NovelSectionResponse)
async def get_novel_section(
project_id: str,
section: NovelSectionType,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelSectionResponse:
novel_service = NovelService(session)
logger.info("用户 %s 获取项目 %s%s 区段", current_user.id, project_id, section)
return await novel_service.get_section_data(project_id, current_user.id, section)
@router.get("/{project_id}/chapters/{chapter_number}", response_model=ChapterSchema)
async def get_chapter(
project_id: str,
chapter_number: int,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> ChapterSchema:
novel_service = NovelService(session)
logger.info("用户 %s 获取项目 %s%s", current_user.id, project_id, chapter_number)
return await novel_service.get_chapter_schema(project_id, current_user.id, chapter_number)
@router.delete("", status_code=status.HTTP_200_OK)
async def delete_novels(
project_ids: List[str] = Body(...),
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> Dict[str, str]:
novel_service = NovelService(session)
await novel_service.delete_projects(project_ids, current_user.id)
logger.info("用户 %s 删除项目 %s", current_user.id, project_ids)
return {"status": "success", "message": f"成功删除 {len(project_ids)} 个项目"}
@router.post("/{project_id}/concept/converse", response_model=ConverseResponse)
async def converse_with_concept(
project_id: str,
request: ConverseRequest,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> ConverseResponse:
"""与概念设计师LLM进行对话引导蓝图筹备。"""
novel_service = NovelService(session)
prompt_service = PromptService(session)
llm_service = LLMService(session)
project = await novel_service.ensure_project_owner(project_id, current_user.id)
history_records = await novel_service.list_conversations(project_id)
logger.info(
"项目 %s 概念对话请求,用户 %s,历史记录 %s",
project_id,
current_user.id,
len(history_records),
)
conversation_history = [
{"role": record.role, "content": record.content}
for record in history_records
]
user_content = json.dumps(request.user_input, ensure_ascii=False)
conversation_history.append({"role": "user", "content": user_content})
system_prompt = _ensure_prompt(await prompt_service.get_prompt("concept"), "concept")
system_prompt = f"{system_prompt}\n{JSON_RESPONSE_INSTRUCTION}"
llm_response = await llm_service.get_llm_response(
system_prompt=system_prompt,
conversation_history=conversation_history,
temperature=0.8,
user_id=current_user.id,
timeout=240.0,
)
llm_response = remove_think_tags(llm_response)
try:
normalized = unwrap_markdown_json(llm_response)
parsed = json.loads(normalized)
except json.JSONDecodeError as exc:
logger.exception(
"Failed to parse concept converse response: project_id=%s user_id=%s normalized=%s",
project_id,
current_user.id,
normalized,
)
raise HTTPException(status_code=500, detail="AI 返回内容不是有效的 JSON") from exc
await novel_service.append_conversation(project_id, "user", user_content)
await novel_service.append_conversation(project_id, "assistant", normalized)
logger.info("项目 %s 概念对话完成is_complete=%s", project_id, parsed.get("is_complete"))
if parsed.get("is_complete"):
parsed["ready_for_blueprint"] = True
parsed.setdefault("conversation_state", parsed.get("conversation_state", {}))
return ConverseResponse(**parsed)
@router.post("/{project_id}/blueprint/generate", response_model=BlueprintGenerationResponse)
async def generate_blueprint(
project_id: str,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> BlueprintGenerationResponse:
"""根据完整对话生成可执行的小说蓝图。"""
novel_service = NovelService(session)
prompt_service = PromptService(session)
llm_service = LLMService(session)
project = await novel_service.ensure_project_owner(project_id, current_user.id)
logger.info("项目 %s 开始生成蓝图", project_id)
history_records = await novel_service.list_conversations(project_id)
if not history_records:
raise HTTPException(status_code=400, detail="缺少对话历史,无法生成蓝图")
formatted_history: List[Dict[str, str]] = []
for record in history_records:
role = record.role
content = record.content
if not role or not content:
continue
try:
normalized = unwrap_markdown_json(content)
data = json.loads(normalized)
if role == "user":
user_value = data.get("value", data)
if isinstance(user_value, str):
formatted_history.append({"role": "user", "content": user_value})
elif role == "assistant":
ai_message = data.get("ai_message") if isinstance(data, dict) else None
if ai_message:
formatted_history.append({"role": "assistant", "content": ai_message})
except (json.JSONDecodeError, AttributeError):
continue
if not formatted_history:
raise HTTPException(status_code=400, detail="无法从历史对话中提取内容")
system_prompt = _ensure_prompt(await prompt_service.get_prompt("screenwriting"), "screenwriting")
blueprint_raw = await llm_service.get_llm_response(
system_prompt=system_prompt,
conversation_history=formatted_history,
temperature=0.3,
user_id=current_user.id,
timeout=480.0,
)
blueprint_raw = remove_think_tags(blueprint_raw)
blueprint_normalized = unwrap_markdown_json(blueprint_raw)
try:
blueprint_data = json.loads(blueprint_normalized)
except json.JSONDecodeError as exc:
raise HTTPException(status_code=500, detail="蓝图生成失败,请稍后重试") from exc
blueprint = Blueprint(**blueprint_data)
await novel_service.replace_blueprint(project_id, blueprint)
if blueprint.title:
project.title = blueprint.title
project.status = "blueprint_ready"
await session.commit()
logger.info("项目 %s 更新标题为 %s,并标记为 blueprint_ready", project_id, blueprint.title)
ai_message = (
"太棒了!我已经根据我们的对话整理出完整的小说蓝图。请确认是否进入写作阶段,或提出修改意见。"
)
return BlueprintGenerationResponse(blueprint=blueprint, ai_message=ai_message)
@router.post("/{project_id}/blueprint/save", response_model=NovelProjectSchema)
async def save_blueprint(
project_id: str,
blueprint_data: Blueprint | None = Body(None),
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
"""保存蓝图信息,可用于手动覆盖自动生成结果。"""
novel_service = NovelService(session)
project = await novel_service.ensure_project_owner(project_id, current_user.id)
if blueprint_data:
await novel_service.replace_blueprint(project_id, blueprint_data)
if blueprint_data.title:
project.title = blueprint_data.title
await session.commit()
logger.info("项目 %s 手动保存蓝图", project_id)
else:
raise HTTPException(status_code=400, detail="缺少蓝图数据")
return await novel_service.get_project_schema(project_id, current_user.id)
@router.patch("/{project_id}/blueprint", response_model=NovelProjectSchema)
async def patch_blueprint(
project_id: str,
payload: BlueprintPatch,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
"""局部更新蓝图字段,对世界观或角色做微调。"""
novel_service = NovelService(session)
project = await novel_service.ensure_project_owner(project_id, current_user.id)
update_data = payload.model_dump(exclude_unset=True)
await novel_service.patch_blueprint(project_id, update_data)
logger.info("项目 %s 局部更新蓝图字段:%s", project_id, list(update_data.keys()))
return await novel_service.get_project_schema(project_id, current_user.id)

View File

@@ -0,0 +1,22 @@
from typing import List
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from ...db.session import get_session
from ...schemas.admin import UpdateLogRead
from ...services.update_log_service import UpdateLogService
router = APIRouter(prefix="/api/updates", tags=["Updates"])
def get_update_log_service(session: AsyncSession = Depends(get_session)) -> UpdateLogService:
return UpdateLogService(session)
@router.get("/latest", response_model=List[UpdateLogRead])
async def read_latest_updates(
service: UpdateLogService = Depends(get_update_log_service),
) -> List[UpdateLogRead]:
logs = await service.list_logs(limit=5)
return [UpdateLogRead.model_validate(log) for log in logs]

View File

@@ -0,0 +1,613 @@
import json
import logging
import os
from typing import Dict, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ...core.config import settings
from ...core.dependencies import get_current_user
from ...db.session import get_session
from ...models.novel import Chapter, ChapterOutline
from ...schemas.novel import (
DeleteChapterRequest,
EditChapterRequest,
EvaluateChapterRequest,
GenerateChapterRequest,
GenerateOutlineRequest,
NovelProject as NovelProjectSchema,
SelectVersionRequest,
UpdateChapterOutlineRequest,
)
from ...schemas.user import UserInDB
from ...services.chapter_context_service import ChapterContextService
from ...services.chapter_ingest_service import ChapterIngestionService
from ...services.llm_service import LLMService
from ...services.novel_service import NovelService
from ...services.prompt_service import PromptService
from ...services.vector_store_service import VectorStoreService
from ...utils.json_utils import remove_think_tags, unwrap_markdown_json
from ...repositories.system_config_repository import SystemConfigRepository
router = APIRouter(prefix="/api/writer", tags=["Writer"])
logger = logging.getLogger(__name__)
async def _load_project_schema(service: NovelService, project_id: str, user_id: int) -> NovelProjectSchema:
return await service.get_project_schema(project_id, user_id)
def _extract_tail_excerpt(text: Optional[str], limit: int = 500) -> str:
"""截取章节结尾文本,默认保留 500 字。"""
if not text:
return ""
stripped = text.strip()
if len(stripped) <= limit:
return stripped
return stripped[-limit:]
@router.post("/novels/{project_id}/chapters/generate", response_model=NovelProjectSchema)
async def generate_chapter(
project_id: str,
request: GenerateChapterRequest,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
novel_service = NovelService(session)
prompt_service = PromptService(session)
llm_service = LLMService(session)
project = await novel_service.ensure_project_owner(project_id, current_user.id)
logger.info("用户 %s 开始为项目 %s 生成第 %s", current_user.id, project_id, request.chapter_number)
outline = await novel_service.get_outline(project_id, request.chapter_number)
if not outline:
logger.warning("项目 %s 未找到第 %s 章纲要,生成流程终止", project_id, request.chapter_number)
raise HTTPException(status_code=404, detail="蓝图中未找到对应章节纲要")
chapter = await novel_service.get_or_create_chapter(project_id, request.chapter_number)
chapter.real_summary = None
chapter.selected_version_id = None
chapter.status = "generating"
await session.commit()
outlines_map = {item.chapter_number: item for item in project.outlines}
# 收集所有可用的历史章节摘要,便于在 Prompt 中提供前情背景
completed_chapters = []
latest_prev_number = -1
previous_summary_text = ""
previous_tail_excerpt = ""
for existing in project.chapters:
if existing.chapter_number >= request.chapter_number:
continue
if existing.selected_version is None or not existing.selected_version.content:
continue
if not existing.real_summary:
summary = await llm_service.get_summary(
existing.selected_version.content,
temperature=0.15,
user_id=current_user.id,
timeout=180.0,
)
existing.real_summary = remove_think_tags(summary)
await session.commit()
completed_chapters.append(
{
"chapter_number": existing.chapter_number,
"title": outlines_map.get(existing.chapter_number).title if outlines_map.get(existing.chapter_number) else f"{existing.chapter_number}",
"summary": existing.real_summary,
}
)
if existing.chapter_number > latest_prev_number:
latest_prev_number = existing.chapter_number
previous_summary_text = existing.real_summary or ""
previous_tail_excerpt = _extract_tail_excerpt(existing.selected_version.content)
project_schema = await novel_service._serialize_project(project)
blueprint_dict = project_schema.blueprint.model_dump()
if "relationships" in blueprint_dict and blueprint_dict["relationships"]:
for relation in blueprint_dict["relationships"]:
if "character_from" in relation:
relation["from"] = relation.pop("character_from")
if "character_to" in relation:
relation["to"] = relation.pop("character_to")
# 蓝图中禁止携带章节级别的细节信息,避免重复传输大段场景或对话内容
banned_blueprint_keys = {
"chapter_outline",
"chapter_summaries",
"chapter_details",
"chapter_dialogues",
"chapter_events",
"conversation_history",
"character_timelines",
}
for key in banned_blueprint_keys:
if key in blueprint_dict:
blueprint_dict.pop(key, None)
writer_prompt = await prompt_service.get_prompt("writing")
if not writer_prompt:
raise HTTPException(status_code=500, detail="缺少写作提示词")
# 初始化向量检索服务,若未配置则自动降级为纯提示词生成
vector_store: Optional[VectorStoreService]
if not settings.vector_store_enabled:
vector_store = None
else:
try:
vector_store = VectorStoreService()
except RuntimeError as exc:
logger.warning("向量库初始化失败RAG 检索被禁用: %s", exc)
vector_store = None
context_service = ChapterContextService(llm_service=llm_service, vector_store=vector_store)
outline_title = outline.title or f"{outline.chapter_number}"
outline_summary = outline.summary or "暂无摘要"
query_parts = [outline_title, outline_summary]
if request.writing_notes:
query_parts.append(request.writing_notes)
rag_query = "\n".join(part for part in query_parts if part)
rag_context = await context_service.retrieve_for_generation(
project_id=project_id,
query_text=rag_query or outline.title or outline.summary or "",
user_id=current_user.id,
)
chunk_count = len(rag_context.chunks) if rag_context and rag_context.chunks else 0
summary_count = len(rag_context.summaries) if rag_context and rag_context.summaries else 0
logger.info(
"项目 %s%s 章检索到 %s 个剧情片段和 %s 条摘要",
project_id,
request.chapter_number,
chunk_count,
summary_count,
)
# print("rag_context:",rag_context)
# 将蓝图、前情、RAG 检索结果拼装成结构化段落,供模型理解
blueprint_text = json.dumps(blueprint_dict, ensure_ascii=False, indent=2)
completed_lines = [
f"- 第{item['chapter_number']}章 - {item['title']}:{item['summary']}"
for item in completed_chapters
]
previous_summary_text = previous_summary_text or "暂无可用摘要"
previous_tail_excerpt = previous_tail_excerpt or "暂无上一章结尾内容"
completed_section = "\n".join(completed_lines) if completed_lines else "暂无前情摘要"
rag_chunks_text = "\n\n".join(rag_context.chunk_texts()) if rag_context.chunks else "未检索到章节片段"
rag_summaries_text = "\n".join(rag_context.summary_lines()) if rag_context.summaries else "未检索到章节摘要"
writing_notes = request.writing_notes or "无额外写作指令"
prompt_sections = [
("[世界蓝图](JSON)", blueprint_text),
# ("[前情摘要]", completed_section),
("[上一章摘要]", previous_summary_text),
("[上一章结尾]", previous_tail_excerpt),
("[检索到的剧情上下文](Markdown)", rag_chunks_text),
("[检索到的章节摘要]", rag_summaries_text),
(
"[当前章节目标]",
f"标题:{outline_title}\n摘要:{outline_summary}\n写作要求:{writing_notes}",
),
]
prompt_input = "\n\n".join(f"{title}\n{content}" for title, content in prompt_sections if content)
logger.debug("章节写作提示词:%s\n%s", writer_prompt, prompt_input)
async def _generate_single_version(idx: int) -> Dict:
try:
response = await llm_service.get_llm_response(
system_prompt=writer_prompt,
conversation_history=[{"role": "user", "content": prompt_input}],
temperature=0.9,
user_id=current_user.id,
timeout=600.0,
)
cleaned = remove_think_tags(response)
normalized = unwrap_markdown_json(cleaned)
try:
return json.loads(normalized)
except json.JSONDecodeError:
return {"content": normalized}
except Exception as exc:
logger.exception(
"项目 %s 生成第 %s 章第 %s 个版本时发生异常: %s",
project_id,
request.chapter_number,
idx + 1,
exc,
)
return {"content": f"生成失败: {exc}"}
version_count = await _resolve_version_count(session)
logger.info(
"项目 %s%s 章计划生成 %s 个版本",
project_id,
request.chapter_number,
version_count,
)
raw_versions = []
for idx in range(version_count):
raw_versions.append(await _generate_single_version(idx))
contents: List[str] = []
metadata: List[Dict] = []
for variant in raw_versions:
if isinstance(variant, dict):
if "content" in variant and isinstance(variant["content"], str):
contents.append(variant["content"])
elif "chapter_content" in variant:
contents.append(str(variant["chapter_content"]))
else:
contents.append(json.dumps(variant, ensure_ascii=False))
metadata.append(variant)
else:
contents.append(str(variant))
metadata.append({"raw": variant})
await novel_service.replace_chapter_versions(chapter, contents, metadata)
logger.info(
"项目 %s%s 章生成完成,已写入 %s 个版本",
project_id,
request.chapter_number,
len(contents),
)
return await _load_project_schema(novel_service, project_id, current_user.id)
async def _resolve_version_count(session: AsyncSession) -> int:
repo = SystemConfigRepository(session)
record = await repo.get_by_key("writer.chapter_versions")
if record:
try:
value = int(record.value)
if value > 0:
return value
except (TypeError, ValueError):
pass
env_value = os.getenv("WRITER_CHAPTER_VERSION_COUNT")
if env_value:
try:
value = int(env_value)
if value > 0:
return value
except ValueError:
pass
return 3
@router.post("/novels/{project_id}/chapters/select", response_model=NovelProjectSchema)
async def select_chapter_version(
project_id: str,
request: SelectVersionRequest,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
novel_service = NovelService(session)
llm_service = LLMService(session)
project = await novel_service.ensure_project_owner(project_id, current_user.id)
chapter = next((ch for ch in project.chapters if ch.chapter_number == request.chapter_number), None)
if not chapter:
logger.warning("项目 %s 未找到第 %s 章,无法选择版本", project_id, request.chapter_number)
raise HTTPException(status_code=404, detail="章节不存在")
selected = await novel_service.select_chapter_version(chapter, request.version_index)
logger.info(
"用户 %s 选择了项目 %s%s 章的第 %s 个版本",
current_user.id,
project_id,
request.chapter_number,
request.version_index,
)
if selected and selected.content:
summary = await llm_service.get_summary(
selected.content,
temperature=0.15,
user_id=current_user.id,
timeout=180.0,
)
chapter.real_summary = remove_think_tags(summary)
await session.commit()
# 选定版本后同步向量库,确保后续章节可检索到最新内容
vector_store: Optional[VectorStoreService]
if not settings.vector_store_enabled:
vector_store = None
else:
try:
vector_store = VectorStoreService()
except RuntimeError as exc:
logger.warning("向量库初始化失败,跳过章节向量同步: %s", exc)
vector_store = None
if vector_store:
ingestion_service = ChapterIngestionService(llm_service=llm_service, vector_store=vector_store)
outline = next((item for item in project.outlines if item.chapter_number == chapter.chapter_number), None)
chapter_title = outline.title if outline and outline.title else f"{chapter.chapter_number}"
await ingestion_service.ingest_chapter(
project_id=project_id,
chapter_number=chapter.chapter_number,
title=chapter_title,
content=selected.content,
summary=chapter.real_summary,
user_id=current_user.id,
)
logger.info(
"项目 %s%s 章已同步至向量库",
project_id,
chapter.chapter_number,
)
return await _load_project_schema(novel_service, project_id, current_user.id)
@router.post("/novels/{project_id}/chapters/evaluate", response_model=NovelProjectSchema)
async def evaluate_chapter(
project_id: str,
request: EvaluateChapterRequest,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
novel_service = NovelService(session)
prompt_service = PromptService(session)
llm_service = LLMService(session)
project = await novel_service.ensure_project_owner(project_id, current_user.id)
chapter = next((ch for ch in project.chapters if ch.chapter_number == request.chapter_number), None)
if not chapter:
logger.warning("项目 %s 未找到第 %s 章,无法执行评估", project_id, request.chapter_number)
raise HTTPException(status_code=404, detail="章节不存在")
if not chapter.versions:
logger.warning("项目 %s%s 章无可评估版本", project_id, request.chapter_number)
raise HTTPException(status_code=400, detail="无可评估的章节版本")
evaluator_prompt = await prompt_service.get_prompt("evaluation")
if not evaluator_prompt:
logger.error("缺少评估提示词,项目 %s%s 章评估失败", project_id, request.chapter_number)
raise HTTPException(status_code=500, detail="缺少评估提示词")
project_schema = await novel_service._serialize_project(project)
blueprint_dict = project_schema.blueprint.model_dump()
versions_to_evaluate = [
{"version_id": idx + 1, "content": version.content}
for idx, version in enumerate(sorted(chapter.versions, key=lambda item: item.created_at))
]
# print("blueprint_dict:",blueprint_dict)
evaluator_payload = {
"novel_blueprint": blueprint_dict,
"content_to_evaluate": {
"chapter_number": chapter.chapter_number,
"versions": versions_to_evaluate,
},
}
evaluation_raw = await llm_service.get_llm_response(
system_prompt=evaluator_prompt,
conversation_history=[{"role": "user", "content": json.dumps(evaluator_payload, ensure_ascii=False)}],
temperature=0.3,
user_id=current_user.id,
timeout=360.0,
)
evaluation_clean = remove_think_tags(evaluation_raw)
await novel_service.add_chapter_evaluation(chapter, None, evaluation_clean)
logger.info("项目 %s%s 章评估完成", project_id, request.chapter_number)
return await _load_project_schema(novel_service, project_id, current_user.id)
@router.post("/novels/{project_id}/chapters/outline", response_model=NovelProjectSchema)
async def generate_chapter_outline(
project_id: str,
request: GenerateOutlineRequest,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
novel_service = NovelService(session)
prompt_service = PromptService(session)
llm_service = LLMService(session)
await novel_service.ensure_project_owner(project_id, current_user.id)
logger.info(
"用户 %s 请求生成项目 %s 的章节大纲,起始章节 %s,数量 %s",
current_user.id,
project_id,
request.start_chapter,
request.num_chapters,
)
outline_prompt = await prompt_service.get_prompt("outline")
if not outline_prompt:
logger.error("缺少大纲提示词,项目 %s 大纲生成失败", project_id)
raise HTTPException(status_code=500, detail="缺少大纲提示词")
project_schema = await novel_service.get_project_schema(project_id, current_user.id)
blueprint_dict = project_schema.blueprint.model_dump()
payload = {
"novel_blueprint": blueprint_dict,
"wait_to_generate": {
"start_chapter": request.start_chapter,
"num_chapters": request.num_chapters,
},
}
response = await llm_service.get_llm_response(
system_prompt=outline_prompt,
conversation_history=[{"role": "user", "content": json.dumps(payload, ensure_ascii=False)}],
temperature=0.7,
user_id=current_user.id,
timeout=360.0,
)
normalized = unwrap_markdown_json(remove_think_tags(response))
try:
data = json.loads(normalized)
except json.JSONDecodeError as exc:
raise HTTPException(status_code=500, detail="章节大纲生成失败") from exc
new_outlines = data.get("chapters", [])
for item in new_outlines:
stmt = (
select(ChapterOutline)
.where(
ChapterOutline.project_id == project_id,
ChapterOutline.chapter_number == item.get("chapter_number"),
)
)
result = await session.execute(stmt)
record = result.scalars().first()
if record:
record.title = item.get("title", record.title)
record.summary = item.get("summary", record.summary)
else:
session.add(
ChapterOutline(
project_id=project_id,
chapter_number=item.get("chapter_number"),
title=item.get("title", ""),
summary=item.get("summary"),
)
)
await session.commit()
logger.info("项目 %s 章节大纲生成完成", project_id)
return await novel_service.get_project_schema(project_id, current_user.id)
@router.post("/novels/{project_id}/chapters/update-outline", response_model=NovelProjectSchema)
async def update_chapter_outline(
project_id: str,
request: UpdateChapterOutlineRequest,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
novel_service = NovelService(session)
await novel_service.ensure_project_owner(project_id, current_user.id)
logger.info(
"用户 %s 更新项目 %s%s 章大纲",
current_user.id,
project_id,
request.chapter_number,
)
stmt = (
select(ChapterOutline)
.where(
ChapterOutline.project_id == project_id,
ChapterOutline.chapter_number == request.chapter_number,
)
)
result = await session.execute(stmt)
outline = result.scalars().first()
if not outline:
outline = ChapterOutline(
project_id=project_id,
chapter_number=request.chapter_number,
)
session.add(outline)
outline.title = request.title
outline.summary = request.summary
await session.commit()
logger.info("项目 %s%s 章大纲已更新", project_id, request.chapter_number)
return await novel_service.get_project_schema(project_id, current_user.id)
@router.post("/novels/{project_id}/chapters/delete", response_model=NovelProjectSchema)
async def delete_chapters(
project_id: str,
request: DeleteChapterRequest,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
if not request.chapter_numbers:
logger.warning("项目 %s 未提供要删除的章节号", project_id)
raise HTTPException(status_code=400, detail="请提供要删除的章节号")
novel_service = NovelService(session)
llm_service = LLMService(session)
await novel_service.ensure_project_owner(project_id, current_user.id)
logger.info(
"用户 %s 删除项目 %s 的章节 %s",
current_user.id,
project_id,
request.chapter_numbers,
)
await novel_service.delete_chapters(project_id, request.chapter_numbers)
# 删除章节时同步清理向量库,避免过时内容被检索
vector_store: Optional[VectorStoreService]
if not settings.vector_store_enabled:
vector_store = None
else:
try:
vector_store = VectorStoreService()
except RuntimeError as exc:
logger.warning("向量库初始化失败,跳过章节向量删除: %s", exc)
vector_store = None
if vector_store:
ingestion_service = ChapterIngestionService(llm_service=llm_service, vector_store=vector_store)
await ingestion_service.delete_chapters(project_id, request.chapter_numbers)
logger.info(
"项目 %s 已从向量库移除章节 %s",
project_id,
request.chapter_numbers,
)
return await novel_service.get_project_schema(project_id, current_user.id)
@router.post("/novels/{project_id}/chapters/edit", response_model=NovelProjectSchema)
async def edit_chapter(
project_id: str,
request: EditChapterRequest,
session: AsyncSession = Depends(get_session),
current_user: UserInDB = Depends(get_current_user),
) -> NovelProjectSchema:
novel_service = NovelService(session)
llm_service = LLMService(session)
project = await novel_service.ensure_project_owner(project_id, current_user.id)
chapter = next((ch for ch in project.chapters if ch.chapter_number == request.chapter_number), None)
if not chapter or chapter.selected_version is None:
logger.warning("项目 %s%s 章尚未生成或未选择版本,无法编辑", project_id, request.chapter_number)
raise HTTPException(status_code=404, detail="章节尚未生成或未选择版本")
chapter.selected_version.content = request.content
chapter.word_count = len(request.content)
logger.info("用户 %s 更新了项目 %s%s 章内容", current_user.id, project_id, request.chapter_number)
if request.content.strip():
summary = await llm_service.get_summary(
request.content,
temperature=0.15,
user_id=current_user.id,
timeout=180.0,
)
chapter.real_summary = remove_think_tags(summary)
await session.commit()
vector_store: Optional[VectorStoreService]
if not settings.vector_store_enabled:
vector_store = None
else:
try:
vector_store = VectorStoreService()
except RuntimeError as exc:
logger.warning("向量库初始化失败,跳过章节向量更新: %s", exc)
vector_store = None
if vector_store and chapter.selected_version and chapter.selected_version.content:
ingestion_service = ChapterIngestionService(llm_service=llm_service, vector_store=vector_store)
outline = next((item for item in project.outlines if item.chapter_number == chapter.chapter_number), None)
chapter_title = outline.title if outline and outline.title else f"{chapter.chapter_number}"
await ingestion_service.ingest_chapter(
project_id=project_id,
chapter_number=chapter.chapter_number,
title=chapter_title,
content=chapter.selected_version.content,
summary=chapter.real_summary,
user_id=current_user.id,
)
logger.info("项目 %s%s 章更新内容已同步至向量库", project_id, chapter.chapter_number)
return await novel_service.get_project_schema(project_id, current_user.id)

View File

261
backend/app/core/config.py Normal file
View File

@@ -0,0 +1,261 @@
from functools import lru_cache
from pathlib import Path
from typing import Optional
from pydantic import AliasChoices, AnyUrl, Field, HttpUrl, validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from sqlalchemy.engine import URL, make_url
class Settings(BaseSettings):
"""应用全局配置,所有可调参数集中于此,统一加载自环境变量。"""
# -------------------- 基础应用配置 --------------------
app_name: str = Field(default="AI Novel Generator API", description="FastAPI 文档标题")
environment: str = Field(default="development", description="当前环境标识")
debug: bool = Field(default=True, description="是否开启调试模式")
allow_registration: bool = Field(
default=True,
env="ALLOW_USER_REGISTRATION",
description="是否允许用户自助注册",
)
logging_level: str = Field(
default="INFO",
env="LOGGING_LEVEL",
description="应用日志级别",
)
enable_linuxdo_login: bool = Field(
default=False,
env="ENABLE_LINUXDO_LOGIN",
description="是否启用 Linux.do OAuth 登录",
)
# -------------------- 安全相关配置 --------------------
secret_key: str = Field(..., env="SECRET_KEY", description="JWT 加密密钥")
jwt_algorithm: str = Field(default="HS256", env="JWT_ALGORITHM", description="JWT 加密算法")
access_token_expire_minutes: int = Field(
default=60 * 24 * 7,
env="ACCESS_TOKEN_EXPIRE_MINUTES",
description="访问令牌过期时间,单位分钟"
)
# -------------------- 数据库配置 --------------------
database_url: Optional[str] = Field(
default=None,
env="DATABASE_URL",
description="完整的数据库连接串,填入后覆盖下方数据库配置"
)
db_provider: str = Field(
default="mysql",
env="DB_PROVIDER",
description="数据库类型,仅支持 mysql 或 sqlite"
)
mysql_host: str = Field(default="localhost", env="MYSQL_HOST", description="MySQL 主机名")
mysql_port: int = Field(default=3306, env="MYSQL_PORT", description="MySQL 端口")
mysql_user: str = Field(default="root", env="MYSQL_USER", description="MySQL 用户名")
mysql_password: str = Field(default="", env="MYSQL_PASSWORD", description="MySQL 密码")
mysql_database: str = Field(default="arboris", env="MYSQL_DATABASE", description="MySQL 数据库名称")
# -------------------- 管理员初始化配置 --------------------
admin_default_username: str = Field(default="admin", env="ADMIN_DEFAULT_USERNAME", description="默认管理员用户名")
admin_default_password: str = Field(default="ChangeMe123!", env="ADMIN_DEFAULT_PASSWORD", description="默认管理员密码")
admin_default_email: Optional[str] = Field(default=None, env="ADMIN_DEFAULT_EMAIL", description="默认管理员邮箱")
# -------------------- LLM 相关配置 --------------------
openai_api_key: Optional[str] = Field(default=None, env="OPENAI_API_KEY", description="默认的 LLM API Key")
openai_base_url: Optional[HttpUrl] = Field(
default=None,
env="OPENAI_API_BASE_URL",
validation_alias=AliasChoices("OPENAI_API_BASE_URL", "OPENAI_BASE_URL"),
description="LLM API Base URL",
)
openai_model_name: str = Field(default="gpt-4o-mini", env="OPENAI_MODEL_NAME", description="默认 LLM 模型名称")
writer_chapter_versions: int = Field(
default=2,
ge=1,
env="WRITER_CHAPTER_VERSION_COUNT",
validation_alias=AliasChoices("WRITER_CHAPTER_VERSION_COUNT", "WRITER_CHAPTER_VERSIONS"),
description="每次生成章节的候选版本数量",
)
embedding_provider: str = Field(
default="openai",
env="EMBEDDING_PROVIDER",
description="嵌入模型提供方,支持 openai 或 ollama",
)
embedding_base_url: Optional[AnyUrl] = Field(
default=None,
env="EMBEDDING_BASE_URL",
description="嵌入模型使用的 Base URL",
)
embedding_api_key: Optional[str] = Field(
default=None,
env="EMBEDDING_API_KEY",
description="嵌入模型专用 API Key",
)
embedding_model: str = Field(
default="text-embedding-3-large",
env="EMBEDDING_MODEL",
validation_alias=AliasChoices("EMBEDDING_MODEL", "VECTOR_EMBEDDING_MODEL"),
description="默认的嵌入模型名称",
)
embedding_model_vector_size: Optional[int] = Field(
default=None,
env="EMBEDDING_MODEL_VECTOR_SIZE",
description="嵌入向量维度,未配置时将自动检测",
)
ollama_embedding_base_url: Optional[AnyUrl] = Field(
default=None,
env="OLLAMA_EMBEDDING_BASE_URL",
description="Ollama 嵌入模型服务地址",
)
ollama_embedding_model: str = Field(
default="nomic-embed-text:latest",
env="OLLAMA_EMBEDDING_MODEL",
description="Ollama 嵌入模型名称",
)
vector_db_url: Optional[str] = Field(
default=None,
env="VECTOR_DB_URL",
description="libsql 向量库连接地址",
)
vector_db_auth_token: Optional[str] = Field(
default=None,
env="VECTOR_DB_AUTH_TOKEN",
description="libsql 访问令牌",
)
vector_top_k_chunks: int = Field(
default=5,
ge=0,
env="VECTOR_TOP_K_CHUNKS",
description="剧情 chunk 检索条数",
)
vector_top_k_summaries: int = Field(
default=3,
ge=0,
env="VECTOR_TOP_K_SUMMARIES",
description="章节摘要检索条数",
)
vector_chunk_size: int = Field(
default=480,
ge=128,
env="VECTOR_CHUNK_SIZE",
description="章节分块的目标字数",
)
vector_chunk_overlap: int = Field(
default=120,
ge=0,
env="VECTOR_CHUNK_OVERLAP",
description="章节分块重叠字数",
)
# -------------------- Linux.do OAuth 配置 --------------------
linuxdo_client_id: Optional[str] = Field(default=None, env="LINUXDO_CLIENT_ID", description="Linux.do OAuth Client ID")
linuxdo_client_secret: Optional[str] = Field(
default=None, env="LINUXDO_CLIENT_SECRET", description="Linux.do OAuth Client Secret"
)
linuxdo_redirect_uri: Optional[HttpUrl] = Field(
default=None, env="LINUXDO_REDIRECT_URI", description="Linux.do OAuth 回调地址"
)
linuxdo_auth_url: Optional[HttpUrl] = Field(
default=None, env="LINUXDO_AUTH_URL", description="Linux.do OAuth 授权地址"
)
linuxdo_token_url: Optional[HttpUrl] = Field(
default=None, env="LINUXDO_TOKEN_URL", description="Linux.do OAuth Token 获取地址"
)
linuxdo_user_info_url: Optional[HttpUrl] = Field(
default=None, env="LINUXDO_USER_INFO_URL", description="Linux.do 用户信息接口地址"
)
# -------------------- 邮件配置 --------------------
smtp_server: Optional[str] = Field(default=None, env="SMTP_SERVER", description="SMTP 服务地址")
smtp_port: int = Field(default=587, env="SMTP_PORT", description="SMTP 服务端口")
smtp_username: Optional[str] = Field(default=None, env="SMTP_USERNAME", description="SMTP 登录用户名")
smtp_password: Optional[str] = Field(default=None, env="SMTP_PASSWORD", description="SMTP 登录密码")
email_from: Optional[str] = Field(default=None, env="EMAIL_FROM", description="邮件发送方显示名或邮箱")
model_config = SettingsConfigDict(
env_file=("new-backend/.env", ".env", "backend/.env"),
env_file_encoding="utf-8",
extra="ignore"
)
@validator("database_url", pre=True, always=True)
def _normalize_database_url(cls, value: Optional[str]) -> Optional[str]:
"""当环境变量中提供 DATABASE_URL 时,原样返回,便于自定义。"""
return value.strip() if isinstance(value, str) and value.strip() else value
@validator("db_provider", pre=True)
def _normalize_db_provider(cls, value: Optional[str]) -> str:
"""统一数据库类型大小写,并限制为受支持的驱动。"""
candidate = (value or "mysql").strip().lower()
if candidate not in {"mysql", "sqlite"}:
raise ValueError("DB_PROVIDER 仅支持 mysql 或 sqlite")
return candidate
@validator("embedding_provider", pre=True)
def _normalize_embedding_provider(cls, value: Optional[str]) -> str:
"""限制嵌入模型提供方的取值范围。"""
candidate = (value or "openai").strip().lower()
if candidate not in {"openai", "ollama"}:
raise ValueError("EMBEDDING_PROVIDER 仅支持 openai 或 ollama")
return candidate
@validator("logging_level", pre=True)
def _normalize_logging_level(cls, value: Optional[str]) -> str:
"""规范日志级别配置。"""
candidate = (value or "INFO").strip().upper()
valid_levels = {"CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"}
if candidate not in valid_levels:
raise ValueError("LOGGING_LEVEL 仅支持 CRITICAL/ERROR/WARNING/INFO/DEBUG/NOTSET")
return candidate
@property
def sqlalchemy_database_uri(self) -> str:
"""生成 SQLAlchemy 兼容的异步连接串,数据库类型由 DB_PROVIDER 控制。"""
if self.database_url:
url = make_url(self.database_url)
database = (url.database or "").strip("/")
normalized = URL.create(
drivername=url.drivername,
username=url.username,
password=url.password,
host=url.host,
port=url.port,
database=database or None,
query=url.query,
)
return normalized.render_as_string(hide_password=False)
if self.db_provider == "sqlite":
# SQLite 固定使用 storage/arboris.db并转换为绝对路径以避免运行目录差异
project_root = Path(__file__).resolve().parents[2]
db_path = (project_root / "storage" / "arboris.db").resolve()
return f"sqlite+aiosqlite:///{db_path}"
# MySQL 分支:统一对密码进行 URL 编码,避免特殊字符破坏连接串
from urllib.parse import quote_plus
encoded_password = quote_plus(self.mysql_password)
database = (self.mysql_database or "").strip("/")
return (
f"mysql+asyncmy://{self.mysql_user}:{encoded_password}"
f"@{self.mysql_host}:{self.mysql_port}/{database}"
)
@property
def is_sqlite_backend(self) -> bool:
"""辅助属性:判断当前连接串是否指向 SQLite用于差异化初始化流程。"""
return make_url(self.sqlalchemy_database_uri).get_backend_name() == "sqlite"
@property
def vector_store_enabled(self) -> bool:
"""是否已经配置向量库,用于在业务逻辑中快速判断。"""
return bool(self.vector_db_url)
@lru_cache
def get_settings() -> Settings:
"""使用 LRU 缓存确保配置只初始化一次,减少 IO 与解析开销。"""
return Settings()
settings = get_settings()

View File

@@ -0,0 +1,33 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from ..core.security import decode_access_token
from ..db.session import get_session
from ..repositories.user_repository import UserRepository
from ..schemas.user import UserInDB
from ..services.auth_service import AuthService
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
async def get_current_user(
token: str = Depends(oauth2_scheme),
session: AsyncSession = Depends(get_session),
) -> UserInDB:
payload = decode_access_token(token)
username = payload["sub"]
repo = UserRepository(session)
user = await repo.get_by_username(username)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或已被禁用")
service = AuthService(session)
schema = UserInDB.model_validate(user)
schema.must_change_password = service.requires_password_reset(user)
return schema
async def get_current_admin(current_user: UserInDB = Depends(get_current_user)) -> UserInDB:
if not current_user.is_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
return current_user

View File

@@ -0,0 +1,58 @@
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from fastapi import HTTPException, status
from jose import JWTError, jwt
from passlib.context import CryptContext
from .config import settings
# 统一的密码哈希上下文,后续如需切换算法只需在此维护
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def hash_password(password: str) -> str:
"""对用户密码进行哈希处理,任何时候都不要存储明文密码。"""
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证明文密码是否匹配哈希值。"""
return pwd_context.verify(plain_password, hashed_password)
def create_access_token(
subject: str,
*,
expires_delta: Optional[timedelta] = None,
extra_claims: Optional[Dict[str, Any]] = None,
) -> str:
"""生成 JWT 访问令牌,默认过期时间读取自配置。"""
if expires_delta is None:
expires_delta = timedelta(minutes=settings.access_token_expire_minutes)
now = datetime.utcnow()
expire = now + expires_delta
to_encode: Dict[str, Any] = {"sub": subject, "iat": now, "exp": expire}
if extra_claims:
to_encode.update(extra_claims)
return jwt.encode(to_encode, settings.secret_key, algorithm=settings.jwt_algorithm)
def decode_access_token(token: str) -> Dict[str, Any]:
"""解析并校验 JWT失败时抛出 401 异常。"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的凭证",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.secret_key, algorithms=[settings.jwt_algorithm])
except JWTError as exc:
raise credentials_exception from exc
if "sub" not in payload:
raise credentials_exception
return payload

View File

9
backend/app/db/base.py Normal file
View File

@@ -0,0 +1,9 @@
from sqlalchemy.orm import DeclarativeBase, declared_attr
class Base(DeclarativeBase):
"""SQLAlchemy 基类,自动根据类名生成表名。"""
@declared_attr.directive
def __tablename__(cls) -> str: # type: ignore[override]
return cls.__name__.lower()

122
backend/app/db/init_db.py Normal file
View File

@@ -0,0 +1,122 @@
import logging
from pathlib import Path
from sqlalchemy import select, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.engine import URL, make_url
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from ..core.config import settings
from ..core.security import hash_password
from ..models import Prompt, SystemConfig, User
from .base import Base
from .system_config_defaults import SYSTEM_CONFIG_DEFAULTS
from .session import AsyncSessionLocal, engine
logger = logging.getLogger(__name__)
async def init_db() -> None:
"""初始化数据库结构并确保默认管理员存在。"""
await _ensure_database_exists()
# ---- 第一步:创建所有表结构 ----
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("数据库表结构已初始化")
# ---- 第二步:确保管理员账号至少存在一个 ----
async with AsyncSessionLocal() as session:
admin_exists = await session.execute(select(User).where(User.is_admin.is_(True)))
if not admin_exists.scalars().first():
logger.warning("未检测到管理员账号,正在创建默认管理员 ...")
admin_user = User(
username=settings.admin_default_username,
email=settings.admin_default_email,
hashed_password=hash_password(settings.admin_default_password),
is_admin=True,
)
session.add(admin_user)
try:
await session.commit()
logger.info("默认管理员创建完成:%s", settings.admin_default_username)
except IntegrityError:
await session.rollback()
logger.exception("默认管理员创建失败,可能是并发启动导致,请检查数据库状态")
# ---- 第三步:同步系统配置到数据库 ----
for entry in SYSTEM_CONFIG_DEFAULTS:
value = entry.value_getter(settings)
if value is None:
continue
existing = await session.get(SystemConfig, entry.key)
if existing:
if entry.description and existing.description != entry.description:
existing.description = entry.description
continue
session.add(
SystemConfig(
key=entry.key,
value=value,
description=entry.description,
)
)
await _ensure_default_prompts(session)
await session.commit()
async def _ensure_database_exists() -> None:
"""在首次连接前确认数据库存在,针对不同驱动做最小化准备工作。"""
url = make_url(settings.sqlalchemy_database_uri)
if url.get_backend_name() == "sqlite":
# SQLite 采用文件数据库,确保父目录存在即可,无需额外建库语句
db_path = Path(url.database or "").expanduser()
if not db_path.is_absolute():
project_root = Path(__file__).resolve().parents[2]
db_path = (project_root / db_path).resolve()
db_path.parent.mkdir(parents=True, exist_ok=True)
return
database = (url.database or "").strip("/")
if not database:
return
admin_url = URL.create(
drivername=url.drivername,
username=url.username,
password=url.password,
host=url.host,
port=url.port,
database=None,
query=url.query,
)
admin_engine = create_async_engine(
admin_url.render_as_string(hide_password=False),
isolation_level="AUTOCOMMIT",
)
async with admin_engine.begin() as conn:
await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{database}`"))
await admin_engine.dispose()
async def _ensure_default_prompts(session: AsyncSession) -> None:
prompts_dir = Path(__file__).resolve().parents[2] / "prompts"
if not prompts_dir.is_dir():
return
result = await session.execute(select(Prompt.name))
existing_names = set(result.scalars().all())
for prompt_file in sorted(prompts_dir.glob("*.md")):
name = prompt_file.stem
if name in existing_names:
continue
content = prompt_file.read_text(encoding="utf-8")
session.add(Prompt(name=name, content=content))

30
backend/app/db/session.py Normal file
View File

@@ -0,0 +1,30 @@
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from ..core.config import settings
# 根据不同数据库驱动调整连接池参数,确保在多数据库环境下表现稳定
engine_kwargs = {"echo": settings.debug}
if settings.is_sqlite_backend:
# SQLite 场景下禁用连接池并放宽线程检查,避免多协程读写冲突
engine_kwargs.update(
pool_pre_ping=False,
connect_args={"check_same_thread": False},
poolclass=NullPool,
)
else:
# MySQL 场景保持健康检查与连接复用,适用于生产环境的长连接需求
engine_kwargs.update(pool_pre_ping=True, pool_recycle=3600)
engine = create_async_engine(settings.sqlalchemy_database_uri, **engine_kwargs)
# 统一的 Session 工厂,禁用 expire_on_commit 方便返回模型对象
AsyncSessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI 依赖项:提供一个作用域内共享的数据库会话。"""
async with AsyncSessionLocal() as session:
yield session

View File

@@ -0,0 +1,110 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Optional
from ..core.config import Settings
def _to_optional_str(value: Optional[object]) -> Optional[str]:
return str(value) if value is not None else None
def _bool_to_text(value: bool) -> str:
return "true" if value else "false"
@dataclass(frozen=True)
class SystemConfigDefault:
key: str
value_getter: Callable[[Settings], Optional[str]]
description: Optional[str] = None
SYSTEM_CONFIG_DEFAULTS: list[SystemConfigDefault] = [
SystemConfigDefault(
key="llm.api_key",
value_getter=lambda config: config.openai_api_key,
description="默认 LLM API Key用于后台调用大模型。",
),
SystemConfigDefault(
key="llm.base_url",
value_getter=lambda config: _to_optional_str(config.openai_base_url),
description="默认大模型 API Base URL。",
),
SystemConfigDefault(
key="llm.model",
value_getter=lambda config: config.openai_model_name,
description="默认 LLM 模型名称。",
),
SystemConfigDefault(
key="smtp.server",
value_getter=lambda config: config.smtp_server,
description="用于发送邮件验证码的 SMTP 服务器地址。",
),
SystemConfigDefault(
key="smtp.port",
value_getter=lambda config: _to_optional_str(config.smtp_port),
description="SMTP 服务端口。",
),
SystemConfigDefault(
key="smtp.username",
value_getter=lambda config: config.smtp_username,
description="SMTP 登录用户名。",
),
SystemConfigDefault(
key="smtp.password",
value_getter=lambda config: config.smtp_password,
description="SMTP 登录密码。",
),
SystemConfigDefault(
key="smtp.from",
value_getter=lambda config: config.email_from,
description="邮件显示的发件人名称或邮箱。",
),
SystemConfigDefault(
key="auth.allow_registration",
value_getter=lambda config: _bool_to_text(config.allow_registration),
description="是否允许用户自助注册。",
),
SystemConfigDefault(
key="auth.linuxdo_enabled",
value_getter=lambda config: _bool_to_text(config.enable_linuxdo_login),
description="是否启用 Linux.do OAuth 登录。",
),
SystemConfigDefault(
key="linuxdo.client_id",
value_getter=lambda config: config.linuxdo_client_id,
description="Linux.do OAuth Client ID。",
),
SystemConfigDefault(
key="linuxdo.client_secret",
value_getter=lambda config: config.linuxdo_client_secret,
description="Linux.do OAuth Client Secret。",
),
SystemConfigDefault(
key="linuxdo.redirect_uri",
value_getter=lambda config: _to_optional_str(config.linuxdo_redirect_uri),
description="Linux.do OAuth 回调地址。",
),
SystemConfigDefault(
key="linuxdo.auth_url",
value_getter=lambda config: _to_optional_str(config.linuxdo_auth_url),
description="Linux.do OAuth 授权地址。",
),
SystemConfigDefault(
key="linuxdo.token_url",
value_getter=lambda config: _to_optional_str(config.linuxdo_token_url),
description="Linux.do OAuth Token 获取地址。",
),
SystemConfigDefault(
key="linuxdo.user_info_url",
value_getter=lambda config: _to_optional_str(config.linuxdo_user_info_url),
description="Linux.do 用户信息接口地址。",
),
SystemConfigDefault(
key="writer.chapter_versions",
value_getter=lambda config: _to_optional_str(config.writer_chapter_versions),
description="每次生成章节的候选版本数量。",
),
]

105
backend/app/main.py Normal file
View File

@@ -0,0 +1,105 @@
"""FastAPI 应用入口,负责装配路由、依赖与生命周期管理。"""
import logging
from logging.config import dictConfig
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .core.config import settings
from .db.init_db import init_db
from .services.prompt_service import PromptService
from .db.session import AsyncSessionLocal
from .api.routers import api_router
dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s [%(levelname)s] %(name)s - %(message)s",
}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "default",
}
},
"loggers": {
"backend": {
"level": settings.logging_level,
"handlers": ["console"],
"propagate": False,
},
"app": {
"level": settings.logging_level,
"handlers": ["console"],
"propagate": False,
},
"backend.app": {
"level": settings.logging_level,
"handlers": ["console"],
"propagate": False,
},
"backend.api": {
"level": settings.logging_level,
"handlers": ["console"],
"propagate": False,
},
"backend.services": {
"level": settings.logging_level,
"handlers": ["console"],
"propagate": False,
},
},
"root": {
"level": "WARNING",
"handlers": ["console"],
},
}
)
@asynccontextmanager
async def lifespan(app: FastAPI):
# 应用启动时初始化数据库,并预热提示词缓存
await init_db()
async with AsyncSessionLocal() as session:
prompt_service = PromptService(session)
await prompt_service.preload()
yield
app = FastAPI(
title=settings.app_name,
debug=settings.debug,
version="1.0.0",
lifespan=lifespan,
)
# CORS 配置,生产环境建议改为具体域名
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(api_router)
# 健康检查接口(用于 Docker 健康检查和监控)
@app.get("/health", tags=["Health"])
@app.get("/api/health", tags=["Health"])
async def health_check():
"""健康检查接口,返回应用状态。"""
return {
"status": "healthy",
"app": settings.app_name,
"version": "1.0.0",
}

View File

@@ -0,0 +1,41 @@
"""集中导出 ORM 模型,确保 SQLAlchemy 元数据在初始化时被正确加载。"""
from .admin_setting import AdminSetting
from .llm_config import LLMConfig
from .novel import (
BlueprintCharacter,
BlueprintRelationship,
Chapter,
ChapterEvaluation,
ChapterOutline,
ChapterVersion,
NovelBlueprint,
NovelConversation,
NovelProject,
)
from .prompt import Prompt
from .update_log import UpdateLog
from .usage_metric import UsageMetric
from .user import User
from .user_daily_request import UserDailyRequest
from .system_config import SystemConfig
__all__ = [
"AdminSetting",
"LLMConfig",
"NovelConversation",
"NovelBlueprint",
"BlueprintCharacter",
"BlueprintRelationship",
"ChapterOutline",
"Chapter",
"ChapterVersion",
"ChapterEvaluation",
"NovelProject",
"Prompt",
"UpdateLog",
"UsageMetric",
"User",
"UserDailyRequest",
"SystemConfig",
]

View File

@@ -0,0 +1,13 @@
from sqlalchemy import String, Text
from sqlalchemy.orm import Mapped, mapped_column
from ..db.base import Base
class AdminSetting(Base):
"""后台配置项,采用简单的 KV 结构。"""
__tablename__ = "admin_settings"
key: Mapped[str] = mapped_column(String(64), primary_key=True)
value: Mapped[str] = mapped_column(Text, nullable=False)

View File

@@ -0,0 +1,17 @@
from sqlalchemy import ForeignKey, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..db.base import Base
class LLMConfig(Base):
"""用户自定义的 LLM 接入配置。"""
__tablename__ = "llm_configs"
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), primary_key=True)
llm_provider_url: Mapped[str | None] = mapped_column(Text())
llm_provider_api_key: Mapped[str | None] = mapped_column(Text())
llm_provider_model: Mapped[str | None] = mapped_column(Text())
user: Mapped["User"] = relationship("User", back_populates="llm_config")

225
backend/app/models/novel.py Normal file
View File

@@ -0,0 +1,225 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from sqlalchemy import JSON, BigInteger, DateTime, Float, ForeignKey, Integer, String, Text, func
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..db.base import Base
# 自定义列类型:兼容跨数据库环境
BIGINT_PK_TYPE = BigInteger().with_variant(Integer, "sqlite")
LONG_TEXT_TYPE = Text().with_variant(LONGTEXT, "mysql")
class _MetadataAccessor:
"""Descriptor 用于将 `metadata` 访问重定向到 `metadata_`,且保持 Base.metadata 可用。"""
def __get__(self, instance, owner):
if instance is None:
return Base.metadata
return instance.metadata_
def __set__(self, instance, value):
instance.metadata_ = value
class NovelProject(Base):
"""小说项目主表,仅存放轻量级元数据。"""
__tablename__ = "novel_projects"
id: Mapped[str] = mapped_column(String(36), primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
title: Mapped[str] = mapped_column(String(255), nullable=False)
initial_prompt: Mapped[Optional[str]] = mapped_column(Text)
status: Mapped[str] = mapped_column(String(32), default="draft")
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
owner: Mapped["User"] = relationship("User", back_populates="novel_projects")
blueprint: Mapped[Optional["NovelBlueprint"]] = relationship(
back_populates="project", cascade="all, delete-orphan", uselist=False
)
conversations: Mapped[list["NovelConversation"]] = relationship(
back_populates="project", cascade="all, delete-orphan", order_by="NovelConversation.seq"
)
characters: Mapped[list["BlueprintCharacter"]] = relationship(
back_populates="project", cascade="all, delete-orphan", order_by="BlueprintCharacter.position"
)
relationships_: Mapped[list["BlueprintRelationship"]] = relationship(
back_populates="project", cascade="all, delete-orphan", order_by="BlueprintRelationship.position"
)
outlines: Mapped[list["ChapterOutline"]] = relationship(
back_populates="project", cascade="all, delete-orphan", order_by="ChapterOutline.chapter_number"
)
chapters: Mapped[list["Chapter"]] = relationship(
back_populates="project", cascade="all, delete-orphan", order_by="Chapter.chapter_number"
)
class NovelConversation(Base):
"""对话记录表,存储概念阶段的连续对话。"""
__tablename__ = "novel_conversations"
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
seq: Mapped[int] = mapped_column(Integer, nullable=False)
role: Mapped[str] = mapped_column(String(32), nullable=False)
content: Mapped[str] = mapped_column(LONG_TEXT_TYPE, nullable=False)
metadata_: Mapped[Optional[dict]] = mapped_column("metadata", JSON)
metadata = _MetadataAccessor()
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
project: Mapped[NovelProject] = relationship(back_populates="conversations")
class NovelBlueprint(Base):
"""蓝图主体信息(标题、风格等)。"""
__tablename__ = "novel_blueprints"
project_id: Mapped[str] = mapped_column(
ForeignKey("novel_projects.id", ondelete="CASCADE"), primary_key=True
)
title: Mapped[Optional[str]] = mapped_column(String(255))
target_audience: Mapped[Optional[str]] = mapped_column(String(255))
genre: Mapped[Optional[str]] = mapped_column(String(128))
style: Mapped[Optional[str]] = mapped_column(String(128))
tone: Mapped[Optional[str]] = mapped_column(String(128))
one_sentence_summary: Mapped[Optional[str]] = mapped_column(Text)
full_synopsis: Mapped[Optional[str]] = mapped_column(LONG_TEXT_TYPE)
world_setting: Mapped[Optional[dict]] = mapped_column(JSON, default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
project: Mapped[NovelProject] = relationship(back_populates="blueprint")
class BlueprintCharacter(Base):
"""蓝图角色信息。"""
__tablename__ = "blueprint_characters"
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
identity: Mapped[Optional[str]] = mapped_column(String(255))
personality: Mapped[Optional[str]] = mapped_column(Text)
goals: Mapped[Optional[str]] = mapped_column(Text)
abilities: Mapped[Optional[str]] = mapped_column(Text)
relationship_to_protagonist: Mapped[Optional[str]] = mapped_column(Text)
extra: Mapped[Optional[dict]] = mapped_column(JSON)
position: Mapped[int] = mapped_column(Integer, default=0)
project: Mapped[NovelProject] = relationship(back_populates="characters")
class BlueprintRelationship(Base):
"""角色之间的关系。"""
__tablename__ = "blueprint_relationships"
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
character_from: Mapped[str] = mapped_column(String(255), nullable=False)
character_to: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text)
position: Mapped[int] = mapped_column(Integer, default=0)
project: Mapped[NovelProject] = relationship(back_populates="relationships_")
class ChapterOutline(Base):
"""章节纲要。"""
__tablename__ = "chapter_outlines"
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
chapter_number: Mapped[int] = mapped_column(Integer, nullable=False)
title: Mapped[str] = mapped_column(String(255), nullable=False)
summary: Mapped[Optional[str]] = mapped_column(Text)
project: Mapped[NovelProject] = relationship(back_populates="outlines")
class Chapter(Base):
"""章节正文状态,指向选中的版本。"""
__tablename__ = "chapters"
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
chapter_number: Mapped[int] = mapped_column(Integer, nullable=False)
real_summary: Mapped[Optional[str]] = mapped_column(Text)
status: Mapped[str] = mapped_column(String(32), default="not_generated")
word_count: Mapped[int] = mapped_column(Integer, default=0)
selected_version_id: Mapped[Optional[int]] = mapped_column(
ForeignKey("chapter_versions.id", ondelete="SET NULL"), nullable=True
)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
project: Mapped[NovelProject] = relationship(back_populates="chapters")
versions: Mapped[list["ChapterVersion"]] = relationship(
"ChapterVersion",
back_populates="chapter",
cascade="all, delete-orphan",
order_by="ChapterVersion.created_at",
primaryjoin="Chapter.id == ChapterVersion.chapter_id",
foreign_keys="[ChapterVersion.chapter_id]",
)
selected_version: Mapped[Optional["ChapterVersion"]] = relationship(
"ChapterVersion",
foreign_keys=[selected_version_id],
primaryjoin="Chapter.selected_version_id == ChapterVersion.id",
post_update=True,
)
evaluations: Mapped[list["ChapterEvaluation"]] = relationship(
back_populates="chapter", cascade="all, delete-orphan", order_by="ChapterEvaluation.created_at"
)
class ChapterVersion(Base):
"""章节生成的不同版本文本。"""
__tablename__ = "chapter_versions"
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
chapter_id: Mapped[int] = mapped_column(ForeignKey("chapters.id", ondelete="CASCADE"), nullable=False)
version_label: Mapped[Optional[str]] = mapped_column(String(64))
provider: Mapped[Optional[str]] = mapped_column(String(64))
content: Mapped[str] = mapped_column(LONG_TEXT_TYPE, nullable=False)
metadata_: Mapped[Optional[dict]] = mapped_column("metadata", JSON)
metadata = _MetadataAccessor()
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
chapter: Mapped[Chapter] = relationship(
"Chapter",
back_populates="versions",
foreign_keys=[chapter_id],
)
evaluations: Mapped[list["ChapterEvaluation"]] = relationship(
back_populates="version", cascade="all, delete-orphan"
)
class ChapterEvaluation(Base):
"""章节评估记录。"""
__tablename__ = "chapter_evaluations"
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
chapter_id: Mapped[int] = mapped_column(ForeignKey("chapters.id", ondelete="CASCADE"), nullable=False)
version_id: Mapped[Optional[int]] = mapped_column(ForeignKey("chapter_versions.id", ondelete="CASCADE"))
decision: Mapped[Optional[str]] = mapped_column(String(32))
feedback: Mapped[Optional[str]] = mapped_column(Text)
score: Mapped[Optional[float]] = mapped_column(Float)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
chapter: Mapped[Chapter] = relationship(back_populates="evaluations")
version: Mapped[Optional[ChapterVersion]] = relationship(back_populates="evaluations")

View File

@@ -0,0 +1,25 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import DateTime, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from ..db.base import Base
class Prompt(Base):
"""提示词表,支持后台 CRUD 操作。"""
__tablename__ = "prompts"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True)
title: Mapped[Optional[str]] = mapped_column(String(255))
content: Mapped[str] = mapped_column(Text, nullable=False)
tags: Mapped[Optional[str]] = mapped_column(String(255))
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())

View File

@@ -0,0 +1,14 @@
from sqlalchemy import String, Text
from sqlalchemy.orm import Mapped, mapped_column
from ..db.base import Base
class SystemConfig(Base):
"""系统级配置项,例如默认 LLM API Key、模型名称等。"""
__tablename__ = "system_configs"
key: Mapped[str] = mapped_column(String(100), primary_key=True)
value: Mapped[str] = mapped_column(Text, nullable=False)
description: Mapped[str | None] = mapped_column(String(255))

View File

@@ -0,0 +1,18 @@
from datetime import datetime
from sqlalchemy import Boolean, DateTime, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from ..db.base import Base
class UpdateLog(Base):
"""更新日志表,供公告与后台管理使用。"""
__tablename__ = "update_logs"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
content: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
created_by: Mapped[str | None] = mapped_column(String(64))
is_pinned: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)

View File

@@ -0,0 +1,13 @@
from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column
from ..db.base import Base
class UsageMetric(Base):
"""通用计数器表,目前用于记录 API 请求次数等统计数据。"""
__tablename__ = "usage_metrics"
key: Mapped[str] = mapped_column(String(64), primary_key=True)
value: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

View File

@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..db.base import Base
class User(Base):
"""用户主表,记录账号及权限信息。"""
__tablename__ = "users"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
email: Mapped[Optional[str]] = mapped_column(String(128), unique=True)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
external_id: Mapped[Optional[str]] = mapped_column(String(255), unique=True)
is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
# 关系映射
novel_projects: Mapped[list["NovelProject"]] = relationship("NovelProject", back_populates="owner")
llm_config: Mapped[Optional["LLMConfig"]] = relationship("LLMConfig", back_populates="user", uselist=False)

View File

@@ -0,0 +1,18 @@
from datetime import date
from sqlalchemy import Date, ForeignKey, Integer, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from ..db.base import Base
class UserDailyRequest(Base):
"""记录每位用户每日使用次数的限流表。"""
__tablename__ = "user_daily_requests"
__table_args__ = (UniqueConstraint("user_id", "request_date", name="uq_user_daily"),)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
request_date: Mapped[date] = mapped_column(Date, nullable=False)
request_count: Mapped[int] = mapped_column(Integer, default=0)

View File

View File

@@ -0,0 +1,15 @@
from typing import Optional
from sqlalchemy import select
from .base import BaseRepository
from ..models import AdminSetting
class AdminSettingRepository(BaseRepository[AdminSetting]):
model = AdminSetting
async def get_value(self, key: str) -> Optional[str]:
result = await self.session.execute(select(AdminSetting).where(AdminSetting.key == key))
record = result.scalars().first()
return record.value if record else None

View File

@@ -0,0 +1,44 @@
from typing import Any, Generic, Iterable, Optional, TypeVar
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
ModelType = TypeVar("ModelType")
class BaseRepository(Generic[ModelType]):
"""通用仓储基类,封装常见的增删改查操作。"""
model: type[ModelType]
def __init__(self, session: AsyncSession):
self.session = session
async def get(self, **filters: Any) -> Optional[ModelType]:
stmt = select(self.model).filter_by(**filters)
result = await self.session.execute(stmt)
return result.scalars().first()
async def list(self, *, filters: Optional[dict[str, Any]] = None) -> Iterable[ModelType]:
stmt = select(self.model)
if filters:
stmt = stmt.filter_by(**filters)
result = await self.session.execute(stmt)
return result.scalars().all()
async def add(self, instance: ModelType) -> ModelType:
self.session.add(instance)
await self.session.flush()
return instance
async def delete(self, instance: ModelType) -> None:
await self.session.delete(instance)
async def update_fields(self, instance: ModelType, **values: Any) -> ModelType:
for key, value in values.items():
if value is None:
continue
setattr(instance, key, value)
await self.session.flush()
return instance

View File

@@ -0,0 +1,14 @@
from typing import Optional
from sqlalchemy import select
from .base import BaseRepository
from ..models import LLMConfig
class LLMConfigRepository(BaseRepository[LLMConfig]):
model = LLMConfig
async def get_by_user(self, user_id: int) -> Optional[LLMConfig]:
result = await self.session.execute(select(LLMConfig).where(LLMConfig.user_id == user_id))
return result.scalars().first()

View File

@@ -0,0 +1,55 @@
from typing import Iterable, Optional
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from .base import BaseRepository
from ..models import Chapter, NovelProject
class NovelRepository(BaseRepository[NovelProject]):
model = NovelProject
async def get_by_id(self, project_id: str) -> Optional[NovelProject]:
stmt = (
select(NovelProject)
.where(NovelProject.id == project_id)
.options(
selectinload(NovelProject.blueprint),
selectinload(NovelProject.characters),
selectinload(NovelProject.relationships_),
selectinload(NovelProject.outlines),
selectinload(NovelProject.conversations),
selectinload(NovelProject.chapters).selectinload(Chapter.versions),
selectinload(NovelProject.chapters).selectinload(Chapter.evaluations),
selectinload(NovelProject.chapters).selectinload(Chapter.selected_version),
)
)
result = await self.session.execute(stmt)
return result.scalars().first()
async def list_by_user(self, user_id: int) -> Iterable[NovelProject]:
result = await self.session.execute(
select(NovelProject)
.where(NovelProject.user_id == user_id)
.order_by(NovelProject.updated_at.desc())
.options(
selectinload(NovelProject.blueprint),
selectinload(NovelProject.outlines),
selectinload(NovelProject.chapters).selectinload(Chapter.selected_version),
)
)
return result.scalars().all()
async def list_all(self) -> Iterable[NovelProject]:
result = await self.session.execute(
select(NovelProject)
.order_by(NovelProject.updated_at.desc())
.options(
selectinload(NovelProject.owner),
selectinload(NovelProject.blueprint),
selectinload(NovelProject.outlines),
selectinload(NovelProject.chapters).selectinload(Chapter.selected_version),
)
)
return result.scalars().all()

View File

@@ -0,0 +1,19 @@
from typing import Iterable, Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from .base import BaseRepository
from ..models import Prompt
class PromptRepository(BaseRepository[Prompt]):
model = Prompt
async def get_by_name(self, name: str) -> Optional[Prompt]:
result = await self.session.execute(select(Prompt).where(Prompt.name == name))
return result.scalars().first()
async def list_all(self) -> Iterable[Prompt]:
result = await self.session.execute(select(Prompt).order_by(Prompt.name))
return result.scalars().all()

View File

@@ -0,0 +1,18 @@
from typing import Iterable, Optional
from sqlalchemy import select
from .base import BaseRepository
from ..models import SystemConfig
class SystemConfigRepository(BaseRepository[SystemConfig]):
model = SystemConfig
async def get_by_key(self, key: str) -> Optional[SystemConfig]:
result = await self.session.execute(select(SystemConfig).where(SystemConfig.key == key))
return result.scalars().first()
async def list_all(self) -> Iterable[SystemConfig]:
result = await self.session.execute(select(SystemConfig).order_by(SystemConfig.key))
return result.scalars().all()

View File

@@ -0,0 +1,19 @@
from typing import Iterable
from sqlalchemy import select
from .base import BaseRepository
from ..models import UpdateLog
class UpdateLogRepository(BaseRepository[UpdateLog]):
model = UpdateLog
async def list(self) -> Iterable[UpdateLog]:
result = await self.session.execute(select(UpdateLog).order_by(UpdateLog.created_at.desc()))
return result.scalars().all()
async def list_latest(self, limit: int = 5) -> Iterable[UpdateLog]:
stmt = select(UpdateLog).order_by(UpdateLog.created_at.desc()).limit(limit)
result = await self.session.execute(stmt)
return result.scalars().all()

View File

@@ -0,0 +1,19 @@
from typing import Optional
from sqlalchemy import select
from .base import BaseRepository
from ..models import UsageMetric
class UsageMetricRepository(BaseRepository[UsageMetric]):
model = UsageMetric
async def get_or_create(self, key: str) -> UsageMetric:
result = await self.session.execute(select(UsageMetric).where(UsageMetric.key == key))
instance = result.scalars().first()
if instance is None:
instance = UsageMetric(key=key, value=0)
self.session.add(instance)
await self.session.flush()
return instance

View File

@@ -0,0 +1,62 @@
from datetime import date
from typing import Iterable, Optional
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from .base import BaseRepository
from ..models import User, UserDailyRequest
class UserRepository(BaseRepository[User]):
model = User
async def get_by_username(self, username: str) -> Optional[User]:
stmt = select(User).where(User.username == username)
result = await self.session.execute(stmt)
return result.scalars().first()
async def get_by_email(self, email: str) -> Optional[User]:
stmt = select(User).where(User.email == email)
result = await self.session.execute(stmt)
return result.scalars().first()
async def get_by_external_id(self, external_id: str) -> Optional[User]:
stmt = select(User).where(User.external_id == external_id)
result = await self.session.execute(stmt)
return result.scalars().first()
async def list_all(self) -> Iterable[User]:
result = await self.session.execute(select(User))
return result.scalars().all()
async def increment_daily_request(self, user_id: int) -> None:
today = date.today()
stmt = select(UserDailyRequest).where(
UserDailyRequest.user_id == user_id,
UserDailyRequest.request_date == today,
)
result = await self.session.execute(stmt)
record = result.scalars().first()
if record is None:
record = UserDailyRequest(user_id=user_id, request_date=today, request_count=1)
self.session.add(record)
else:
record.request_count += 1
await self.session.flush()
async def get_daily_request(self, user_id: int) -> int:
today = date.today()
stmt = select(UserDailyRequest.request_count).where(
UserDailyRequest.user_id == user_id,
UserDailyRequest.request_date == today,
)
result = await self.session.execute(stmt)
value = result.scalars().first()
return value or 0
async def count_users(self) -> int:
stmt = select(func.count(User.id))
result = await self.session.execute(stmt)
return result.scalar_one()

View File

View File

@@ -0,0 +1,49 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
class Statistics(BaseModel):
novel_count: int
user_count: int
api_request_count: int
class DailyRequestLimit(BaseModel):
limit: int = Field(..., ge=0, description="匿名用户每日可用次数")
class UpdateLogRead(BaseModel):
id: int
content: str
created_at: datetime
created_by: Optional[str] = None
is_pinned: bool
class Config:
from_attributes = True
class UpdateLogBase(BaseModel):
content: Optional[str] = None
is_pinned: Optional[bool] = None
class UpdateLogCreate(UpdateLogBase):
content: str
class UpdateLogUpdate(UpdateLogBase):
pass
class AdminNovelSummary(BaseModel):
id: str
title: str
owner_id: int
owner_username: str
genre: str
last_edited: str
completed_chapters: int
total_chapters: int

View File

@@ -0,0 +1,23 @@
from typing import Optional
from pydantic import BaseModel, Field
class SystemConfigBase(BaseModel):
key: str = Field(..., description="配置键,需全局唯一")
value: str = Field(..., description="配置值,统一存储为字符串")
description: Optional[str] = Field(default=None, description="配置用途说明")
class SystemConfigCreate(SystemConfigBase):
pass
class SystemConfigUpdate(BaseModel):
value: Optional[str] = Field(default=None)
description: Optional[str] = Field(default=None)
class SystemConfigRead(SystemConfigBase):
class Config:
from_attributes = True

View File

@@ -0,0 +1,20 @@
from typing import Optional
from pydantic import BaseModel, HttpUrl, Field
class LLMConfigBase(BaseModel):
llm_provider_url: Optional[HttpUrl] = Field(default=None, description="自定义 LLM 服务地址")
llm_provider_api_key: Optional[str] = Field(default=None, description="自定义 LLM API Key")
llm_provider_model: Optional[str] = Field(default=None, description="自定义模型名称")
class LLMConfigCreate(LLMConfigBase):
pass
class LLMConfigRead(LLMConfigBase):
user_id: int
class Config:
from_attributes = True

View File

@@ -0,0 +1,170 @@
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class ChoiceOption(BaseModel):
"""前端选择项描述,用于动态 UI 控件。"""
id: str
label: str
class UIControl(BaseModel):
"""描述前端应渲染的组件类型与配置。"""
type: str = Field(..., description="控件类型,如 single_choice/text_input")
options: Optional[List[ChoiceOption]] = Field(default=None, description="可选项列表")
placeholder: Optional[str] = Field(default=None, description="输入提示文案")
class ConverseResponse(BaseModel):
"""概念对话接口的统一返回体。"""
ai_message: str
ui_control: UIControl
conversation_state: Dict[str, Any]
is_complete: bool = False
ready_for_blueprint: Optional[bool] = None
class ConverseRequest(BaseModel):
"""概念对话接口的请求体。"""
user_input: Dict[str, Any]
conversation_state: Dict[str, Any]
class ChapterGenerationStatus(str, Enum):
NOT_GENERATED = "not_generated"
GENERATING = "generating"
EVALUATING = "evaluating"
SELECTING = "selecting"
FAILED = "failed"
EVALUATION_FAILED = "evaluation_failed"
WAITING_FOR_CONFIRM = "waiting_for_confirm"
SUCCESSFUL = "successful"
class ChapterOutline(BaseModel):
chapter_number: int
title: str
summary: str
class Chapter(ChapterOutline):
real_summary: Optional[str] = None
content: Optional[str] = None
versions: Optional[List[str]] = None
evaluation: Optional[str] = None
generation_status: ChapterGenerationStatus = ChapterGenerationStatus.NOT_GENERATED
class Relationship(BaseModel):
character_from: str
character_to: str
description: str
class Blueprint(BaseModel):
title: str
target_audience: str = ""
genre: str = ""
style: str = ""
tone: str = ""
one_sentence_summary: str = ""
full_synopsis: str = ""
world_setting: Dict[str, Any] = {}
characters: List[Dict[str, Any]] = []
relationships: List[Relationship] = []
chapter_outline: List[ChapterOutline] = []
class NovelProject(BaseModel):
id: str
user_id: int
title: str
initial_prompt: str
conversation_history: List[Dict[str, Any]] = []
blueprint: Optional[Blueprint] = None
chapters: List[Chapter] = []
class Config:
from_attributes = True
class NovelProjectSummary(BaseModel):
id: str
title: str
genre: str
last_edited: str
completed_chapters: int
total_chapters: int
class BlueprintGenerationResponse(BaseModel):
blueprint: Blueprint
ai_message: str
class ChapterGenerationResponse(BaseModel):
ai_message: str
chapter_versions: List[Dict[str, Any]]
class NovelSectionType(str, Enum):
OVERVIEW = "overview"
WORLD_SETTING = "world_setting"
CHARACTERS = "characters"
RELATIONSHIPS = "relationships"
CHAPTER_OUTLINE = "chapter_outline"
CHAPTERS = "chapters"
class NovelSectionResponse(BaseModel):
section: NovelSectionType
data: Dict[str, Any]
class GenerateChapterRequest(BaseModel):
chapter_number: int
writing_notes: Optional[str] = Field(default=None, description="章节额外写作指令")
class SelectVersionRequest(BaseModel):
chapter_number: int
version_index: int
class EvaluateChapterRequest(BaseModel):
chapter_number: int
class UpdateChapterOutlineRequest(BaseModel):
chapter_number: int
title: str
summary: str
class DeleteChapterRequest(BaseModel):
chapter_numbers: List[int]
class GenerateOutlineRequest(BaseModel):
start_chapter: int
num_chapters: int
class BlueprintPatch(BaseModel):
one_sentence_summary: Optional[str] = None
full_synopsis: Optional[str] = None
world_setting: Optional[Dict[str, Any]] = None
characters: Optional[List[Dict[str, Any]]] = None
relationships: Optional[List[Relationship]] = None
chapter_outline: Optional[List[ChapterOutline]] = None
class EditChapterRequest(BaseModel):
chapter_number: int
content: str

View File

@@ -0,0 +1,56 @@
from typing import Any, List, Optional
from pydantic import BaseModel, Field
class PromptBase(BaseModel):
"""Prompt 基础模型。"""
name: str = Field(..., description="唯一标识,用于代码引用")
title: Optional[str] = Field(default=None, description="可读标题")
content: str = Field(..., description="提示词具体内容")
tags: Optional[List[str]] = Field(default=None, description="标签集合")
class PromptCreate(PromptBase):
"""创建 Prompt 时使用的模型。"""
pass
class PromptUpdate(BaseModel):
"""更新 Prompt 时使用的模型。"""
title: Optional[str] = Field(default=None)
content: Optional[str] = Field(default=None)
tags: Optional[List[str]] = Field(default=None)
class PromptRead(PromptBase):
"""对外暴露的 Prompt 数据结构。"""
id: int
class Config:
from_attributes = True
@classmethod
def model_validate(cls, obj: Any, *args: Any, **kwargs: Any) -> "PromptRead": # type: ignore[override]
"""在转换 ORM 模型时,将字符串标签拆分为列表。"""
if hasattr(obj, "id") and hasattr(obj, "name"):
raw_tags = getattr(obj, "tags", None)
if isinstance(raw_tags, str):
processed = [tag for tag in raw_tags.split(",") if tag]
elif isinstance(raw_tags, list):
processed = raw_tags
else:
processed = None
data = {
"id": getattr(obj, "id"),
"name": getattr(obj, "name"),
"title": getattr(obj, "title", None),
"content": getattr(obj, "content", None),
"tags": processed,
}
return super().model_validate(data, *args, **kwargs)
return super().model_validate(obj, *args, **kwargs)

View File

@@ -0,0 +1,74 @@
from pydantic import BaseModel, EmailStr, Field
from typing import Optional
class UserBase(BaseModel):
"""用户基础数据结构,供多处复用。"""
username: str = Field(..., description="用户名")
email: Optional[EmailStr] = Field(default=None, description="邮箱,可选")
class UserCreate(UserBase):
"""注册时使用的模型。"""
password: str = Field(..., min_length=6, description="明文密码")
class UserUpdate(BaseModel):
"""用户信息修改模型。"""
email: Optional[EmailStr] = Field(default=None, description="邮箱")
password: Optional[str] = Field(default=None, min_length=6, description="新密码")
class User(UserBase):
"""对外暴露的用户信息。"""
id: int = Field(..., description="用户主键")
is_admin: bool = Field(default=False, description="是否为管理员")
must_change_password: bool = Field(default=False, description="是否需要强制修改密码")
class Config:
from_attributes = True
class UserInDB(User):
"""数据库内部使用的模型,包含哈希后的密码。"""
hashed_password: str
class Token(BaseModel):
"""登录成功后返回的访问令牌。"""
access_token: str
token_type: str = "bearer"
must_change_password: bool = Field(default=False, description="是否需要强制修改密码")
class TokenPayload(BaseModel):
"""JWT 负载信息。"""
sub: str
is_admin: bool = False
class UserRegistration(UserCreate):
"""注册接口需要的字段,包含邮箱验证码。"""
verification_code: str = Field(..., min_length=4, max_length=10, description="邮箱验证码")
class PasswordChangeRequest(BaseModel):
"""管理员修改密码请求模型。"""
old_password: str = Field(..., min_length=6, description="当前密码")
new_password: str = Field(..., min_length=8, description="新密码")
class AuthOptions(BaseModel):
"""认证相关开关信息,供前端动态控制功能。"""
allow_registration: bool = Field(..., description="是否允许开放用户注册")
enable_linuxdo_login: bool = Field(..., description="是否启用 Linux.do 登录")

View File

View File

@@ -0,0 +1,27 @@
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import AdminSetting
from ..repositories.admin_setting_repository import AdminSettingRepository
class AdminSettingService:
"""管理员配置项服务,提供简单的 KV 操作。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = AdminSettingRepository(session)
async def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
value = await self.repo.get_value(key)
return value if value is not None else default
async def set(self, key: str, value: str) -> None:
record = await self.repo.get(key=key)
if record:
await self.repo.update_fields(record, value=value)
else:
setting = AdminSetting(key=key, value=value)
await self.repo.add(setting)
await self.session.commit()

View File

@@ -0,0 +1,389 @@
import asyncio
import logging
import random
import secrets
import string
import time
from typing import Dict, Optional
import httpx
from email.header import Header
from email.mime.text import MIMEText
from email.utils import formataddr, parseaddr
from fastapi import HTTPException, status
import smtplib
from ..core.config import settings
from ..core.security import create_access_token, hash_password, verify_password
from ..models import User
from ..repositories.system_config_repository import SystemConfigRepository
from ..repositories.user_repository import UserRepository
from ..schemas.user import AuthOptions, Token, UserCreate, UserInDB, UserRegistration
_VERIFICATION_CACHE: Dict[str, tuple[str, float]] = {}
_LAST_SEND_TIME: Dict[str, float] = {}
class AuthService:
"""认证与授权逻辑封装登录、注册、OAuth 对接等操作。"""
def __init__(self, session):
self.session = session
self.user_repo = UserRepository(session)
self.system_config_repo = SystemConfigRepository(session)
self._verification_cache = _VERIFICATION_CACHE
self._last_send_time = _LAST_SEND_TIME
# ------------------------------------------------------------------
# 用户登录 / 注册
# ------------------------------------------------------------------
async def authenticate_user(self, username: str, password: str) -> User:
user = await self.user_repo.get_by_username(username)
if not user or not verify_password(password, user.hashed_password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误")
return user
async def create_access_token(
self,
user: User | UserInDB,
*,
must_change_password: Optional[bool] = None,
) -> Token:
payload = {"is_admin": user.is_admin}
token = create_access_token(user.username, extra_claims=payload)
should_change = self.requires_password_reset(user) if must_change_password is None else must_change_password
return Token(access_token=token, must_change_password=should_change)
async def register_user(self, payload: UserRegistration) -> User:
if not await self.is_registration_enabled():
raise HTTPException(status_code=403, detail="当前暂未开放注册")
if await self.user_repo.get_by_username(payload.username):
raise HTTPException(status_code=400, detail="用户名已存在")
if payload.email and await self.user_repo.get_by_email(payload.email):
raise HTTPException(status_code=400, detail="邮箱已被使用")
if not self.verify_code(payload.email, payload.verification_code):
raise HTTPException(status_code=400, detail="验证码错误或已过期")
hashed_password = hash_password(payload.password)
user = User(
username=payload.username,
email=payload.email,
hashed_password=hashed_password,
)
self.session.add(user)
await self.session.commit()
return user
# ------------------------------------------------------------------
# 邮箱验证码逻辑
# ------------------------------------------------------------------
async def send_verification_code(self, email: str) -> None:
if not await self.is_registration_enabled():
raise HTTPException(status_code=403, detail="当前暂未开放注册")
now = time.time()
if email in self._last_send_time and now - self._last_send_time[email] < 60:
raise HTTPException(status_code=429, detail="请稍后再试1分钟内不可重复发送")
code = "".join(random.choices(string.digits, k=6))
self._verification_cache[email] = (code, now + 300)
self._last_send_time[email] = now
smtp_config = await self._load_smtp_config()
if not smtp_config:
raise HTTPException(status_code=500, detail="未配置邮件服务,请联系管理员")
await self._send_email(email, code, smtp_config)
def verify_code(self, email: str | None, code: str) -> bool:
if not email:
return False
cached = self._verification_cache.get(email)
if not cached:
return False
expected, expire_at = cached
if time.time() > expire_at:
self._verification_cache.pop(email, None)
return False
if code != expected:
return False
self._verification_cache.pop(email, None)
return True
async def _load_smtp_config(self) -> Optional[Dict[str, str]]:
keys = [
"smtp.server",
"smtp.port",
"smtp.username",
"smtp.password",
"smtp.from",
]
configs = {}
for key in keys:
config = await self.system_config_repo.get_by_key(key)
if config:
configs[key] = config.value
required_keys = {"smtp.server", "smtp.port", "smtp.username", "smtp.password", "smtp.from"}
if not required_keys.issubset(configs.keys()):
return None
return configs
async def _send_email(self, to_email: str, code: str, smtp_config: Dict[str, str]) -> None:
logger = logging.getLogger(__name__)
server = smtp_config["smtp.server"]
port = int(smtp_config.get("smtp.port", "465"))
username = smtp_config["smtp.username"]
password = smtp_config["smtp.password"]
from_value = smtp_config.get("smtp.from") or username
display_name, from_addr = parseaddr(from_value)
if not display_name and "@" not in from_value and "<" not in from_value and from_value.strip():
display_name = from_value.strip()
if not from_addr or "@" not in from_addr:
if from_addr and "@" not in from_addr:
logger.warning(
"发件邮箱缺少 @,已回退为登录账号",
extra={"original": from_addr},
)
from_addr = username
try:
from_addr.encode("ascii")
except UnicodeEncodeError:
logger.warning(
"发件邮箱包含非 ASCII 字符,已回退为登录账号",
extra={"original": from_addr},
)
from_addr = username
if display_name:
formatted_from = formataddr((Header(display_name, "utf-8").encode(), from_addr))
else:
formatted_from = from_addr
try:
to_email.encode("ascii")
except UnicodeEncodeError as exc: # noqa: BLE001
raise HTTPException(status_code=400, detail="邮箱地址包含不支持的字符") from exc
html_content = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta http-equiv=\"Content-Type\" content=\"text/html; charset=UTF-8\">
<meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">
<title>您的验证码</title>
<style>
body, table, td, a {{ -webkit-text-size-adjust: 100%; -ms-text-size-adjust: 100%; }}
table, td {{ mso-table-lspace: 0pt; mso-table-rspace: 0pt; }}
img {{ -ms-interpolation-mode: bicubic; }}
body {{ margin: 0; padding: 0; }}
table {{ border-collapse: collapse !important; }}
</style>
</head>
<body style=\"margin: 0; padding: 0; width: 100% !important; background-color: #f3f4f6;\">
<table width=\"100%\" border=\"0\" cellpadding=\"0\" cellspacing=\"0\" bgcolor=\"#f3f4f6\">
<tr>
<td align=\"center\" valign=\"top\" style=\"padding: 20px;\">
<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\" width=\"100%\" style=\"max-width: 512px; background-color: #ffffff; border-radius: 16px; overflow: hidden;\">
<tr>
<td align=\"center\" style=\"background-color: #2563eb; padding: 32px;\">
<h1 style=\"font-family: Arial, Helvetica, sans-serif; font-size: 30px; font-weight: bold; color: #ffffff; margin: 0;\">操作验证码</h1>
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 16px; color: #dbeafe; margin: 8px 0 0;\">请使用下方验证码完成操作。</p>
</td>
</tr>
<tr>
<td align=\"center\" style=\"padding: 32px 48px;\">
<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\" width=\"100%\">
<tr>
<td align=\"center\" style=\"background-color: #f3f4f6; border-radius: 12px; padding: 16px; margin: 24px 0;\">
<p style=\"font-family: 'Courier New', Courier, monospace; font-size: 48px; font-weight: bold; letter-spacing: 0.1em; color: #1d4ed8; margin: 0;\">
{code[:3]}{code[3:]}
</p>
</td>
</tr>
<tr>
<td align=\"center\" style=\"padding-top: 24px;\">
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 16px; color: #6b7280; margin: 0;\">
此验证码将在 <span style=\"font-weight: bold; color: #374151;\">5分钟</span> 内有效。
</p>
</td>
</tr>
<tr>
<td align=\"center\" style=\"padding-top: 32px; border-top: 1px solid #e5e7eb; margin-top: 32px;\">
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 14px; font-weight: bold; color: #ef4444; margin: 0;\">
为保障安全,请勿泄露此验证码。
</p>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td align=\"center\" style=\"background-color: #f9fafb; padding: 24px; border-top: 1px solid #e5e7eb;\">
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 14px; color: #6b7280; margin: 0;\">
如非本人操作,请忽略此邮件。
</p>
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 12px; color: #9ca3af; margin: 8px 0 0;\">
&copy; {time.strftime('%Y')} 拯救小说家. All rights reserved.
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>
"""
message = MIMEText(html_content, "html", "utf-8")
message["Subject"] = Header("注册验证码", "utf-8").encode()
message["From"] = formatted_from
message["To"] = to_email
logger.info("准备发送验证码邮件", extra={"to": to_email, "server": server, "port": port})
def _send():
smtp: Optional[smtplib.SMTP] = None
try:
if port == 465:
smtp = smtplib.SMTP_SSL(server, port, timeout=10)
else:
smtp = smtplib.SMTP(server, port, timeout=10)
smtp.starttls()
if username and password:
smtp.login(username, password)
smtp.sendmail(from_addr, [to_email], message.as_string())
logger.info("验证码邮件发送成功", extra={"to": to_email})
except Exception as exc: # noqa: BLE001
logger.exception("验证码发送失败")
raise
finally:
if smtp is not None:
try:
smtp.quit()
except Exception: # noqa: BLE001
pass
try:
await asyncio.to_thread(_send)
except Exception as exc: # noqa: BLE001
raise HTTPException(status_code=500, detail="验证码发送失败,请检查邮件配置") from exc
# ------------------------------------------------------------------
# OAuth 对接示例(以 Linux.do 为例)
# ------------------------------------------------------------------
async def handle_linuxdo_callback(self, code: str) -> Token:
if not await self.is_linuxdo_login_enabled():
raise HTTPException(status_code=403, detail="未启用 Linux.do 登录")
client_id = await self._get_config_value("linuxdo.client_id")
client_secret = await self._get_config_value("linuxdo.client_secret")
redirect_uri = await self._get_config_value("linuxdo.redirect_uri")
token_url = await self._get_config_value("linuxdo.token_url")
user_info_url = await self._get_config_value("linuxdo.user_info_url")
if not all([client_id, client_secret, redirect_uri, token_url, user_info_url]):
raise HTTPException(status_code=500, detail="未正确配置 Linux.do OAuth 参数")
async with httpx.AsyncClient() as client:
token_response = await client.post(
token_url,
data={
"client_id": client_id,
"client_secret": client_secret,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
token_response.raise_for_status()
access_token = token_response.json().get("access_token")
if not access_token:
raise HTTPException(status_code=400, detail="授权失败,未获取到访问令牌")
user_info_response = await client.get(
user_info_url,
headers={"Authorization": f"Bearer {access_token}"},
)
user_info_response.raise_for_status()
data = user_info_response.json()
external_id = f"linuxdo:{data['id']}"
user = await self.user_repo.get_by_external_id(external_id)
if user is None:
placeholder_password = secrets.token_urlsafe(16)
user = User(
username=data["username"],
email=data.get("email"),
external_id=external_id,
hashed_password=hash_password(placeholder_password),
)
self.session.add(user)
await self.session.commit()
return await self.create_access_token(user)
async def _get_config_value(self, key: str) -> Optional[str]:
config = await self.system_config_repo.get_by_key(key)
return config.value if config else None
async def get_config_value(self, key: str) -> Optional[str]:
"""对外暴露的配置读取接口,便于路由层复用。"""
return await self._get_config_value(key)
@staticmethod
def _parse_bool(value: Optional[str], fallback: bool) -> bool:
if value is None:
return fallback
normalized = value.strip().lower()
return normalized in {"1", "true", "yes", "on"}
async def is_registration_enabled(self) -> bool:
value = await self._get_config_value("auth.allow_registration")
return self._parse_bool(value, fallback=settings.allow_registration)
async def is_linuxdo_login_enabled(self) -> bool:
value = await self._get_config_value("auth.linuxdo_enabled")
return self._parse_bool(value, fallback=settings.enable_linuxdo_login)
async def get_auth_options(self) -> AuthOptions:
"""聚合与认证相关的动态开关配置,便于前端一次性拉取。"""
allow_registration = await self.is_registration_enabled()
enable_linuxdo_login = await self.is_linuxdo_login_enabled()
return AuthOptions(
allow_registration=allow_registration,
enable_linuxdo_login=enable_linuxdo_login,
)
def requires_password_reset(self, user: User | UserInDB) -> bool:
if not user.is_admin:
return False
if user.username != settings.admin_default_username:
return False
hashed_password = getattr(user, "hashed_password", None)
if not hashed_password:
return False
return verify_password(settings.admin_default_password, hashed_password)
async def change_password(self, username: str, old_password: str, new_password: str) -> None:
user = await self.user_repo.get_by_username(username)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
if not verify_password(old_password, user.hashed_password):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前密码错误")
if verify_password(new_password, user.hashed_password):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="新密码不能与当前密码相同")
if username == settings.admin_default_username and new_password == settings.admin_default_password:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="新密码不能为默认密码")
user.hashed_password = hash_password(new_password)
await self.session.commit()

View File

@@ -0,0 +1,109 @@
from __future__ import annotations
"""
章节上下文组装服务:负责调用向量库检索上下文,并对结果做基础格式化。
所有关键步骤均包含中文注释,方便团队理解 RAG 流程。
"""
import logging
from dataclasses import dataclass
from typing import List, Optional
from ..core.config import settings
from ..services.llm_service import LLMService
from .vector_store_service import RetrievedChunk, RetrievedSummary, VectorStoreService
logger = logging.getLogger(__name__)
@dataclass
class ChapterRAGContext:
"""封装检索得到的上下文结果。"""
query: str
chunks: List[RetrievedChunk]
summaries: List[RetrievedSummary]
def chunk_texts(self) -> List[str]:
"""将检索到的 chunk 转换成带序号的 Markdown 段落。"""
lines = []
for idx, chunk in enumerate(self.chunks, start=1):
title = chunk.chapter_title or f"{chunk.chapter_number}"
lines.append(
f"### Chunk {idx}(来源:{title})\n{chunk.content.strip()}"
)
return lines
def summary_lines(self) -> List[str]:
"""整理章节摘要,方便直接插入 Prompt。"""
lines = []
for summary in self.summaries:
lines.append(
f"- 第{summary.chapter_number}章 - {summary.title}:{summary.summary.strip()}"
)
return lines
class ChapterContextService:
"""章节上下文服务,整合查询、格式化与容错逻辑。"""
def __init__(
self,
*,
llm_service: LLMService,
vector_store: Optional[VectorStoreService] = None,
) -> None:
self._llm_service = llm_service
self._vector_store = vector_store
async def retrieve_for_generation(
self,
*,
project_id: str,
query_text: str,
user_id: int,
top_k_chunks: Optional[int] = None,
top_k_summaries: Optional[int] = None,
) -> ChapterRAGContext:
"""根据章节摘要构造检索向量,并返回 RAG 上下文。"""
query = self._normalize(query_text)
if not settings.vector_store_enabled or not self._vector_store:
logger.debug("向量库未启用或初始化失败,跳过检索: project=%s", project_id)
return ChapterRAGContext(query=query, chunks=[], summaries=[])
embedding_model = None if settings.embedding_provider == "ollama" else settings.embedding_model
embedding = await self._llm_service.get_embedding(query, user_id=user_id, model=embedding_model)
if not embedding:
logger.warning("检索查询向量生成失败: project=%s chapter_query=%s", project_id, query)
return ChapterRAGContext(query=query, chunks=[], summaries=[])
chunks = await self._vector_store.query_chunks(
project_id=project_id,
embedding=embedding,
top_k=top_k_chunks,
)
summaries = await self._vector_store.query_summaries(
project_id=project_id,
embedding=embedding,
top_k=top_k_summaries,
)
logger.info(
"章节上下文检索完成: project=%s chunks=%d summaries=%d query_preview=%s",
project_id,
len(chunks),
len(summaries),
query[:80],
)
return ChapterRAGContext(query=query, chunks=chunks, summaries=summaries)
@staticmethod
def _normalize(text: str) -> str:
"""统一压缩空白字符,避免影响检索效果。"""
return " ".join(text.split())
__all__ = [
"ChapterContextService",
"ChapterRAGContext",
]

View File

@@ -0,0 +1,262 @@
from __future__ import annotations
"""
章节向量入库服务:在章节确认后负责切分文本、生成嵌入并写入向量库。
全部注释使用中文,方便团队成员阅读理解。
"""
import logging
from typing import Dict, List, Optional, Sequence
from ..core.config import settings
from ..services.llm_service import LLMService
from ..services.vector_store_service import VectorStoreService
logger = logging.getLogger(__name__)
try: # noqa: SIM105 - 提示缺少可选依赖
from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError: # pragma: no cover - 未安装时会走后备方案
RecursiveCharacterTextSplitter = None # type: ignore[assignment]
class ChapterIngestionService:
"""封装章节内容与摘要的向量化与入库流程。"""
def __init__(
self,
*,
llm_service: LLMService,
vector_store: Optional[VectorStoreService] = None,
) -> None:
self._llm_service = llm_service
self._vector_store = vector_store or VectorStoreService()
self._text_splitter = self._init_text_splitter()
async def ingest_chapter(
self,
*,
project_id: str,
chapter_number: int,
title: str,
content: str,
summary: Optional[str],
user_id: int,
) -> None:
"""将章节正文与摘要写入向量库,供后续 RAG 检索使用。"""
if not settings.vector_store_enabled:
logger.debug("向量库未启用,跳过章节向量写入: project=%s chapter=%s", project_id, chapter_number)
return
if not content.strip():
logger.debug("章节正文为空,跳过向量写入: project=%s chapter=%s", project_id, chapter_number)
return
chunks = self._split_into_chunks(content)
if not chunks:
logger.debug("章节正文切分后为空,跳过向量写入: project=%s chapter=%s", project_id, chapter_number)
return
logger.info(
"开始写入章节向量: project=%s chapter=%s chunks=%d",
project_id,
chapter_number,
len(chunks),
)
await self._vector_store.delete_by_chapters(project_id, [chapter_number])
chunk_records = []
for index, chunk_text in enumerate(chunks):
embedding = await self._llm_service.get_embedding(
chunk_text,
user_id=user_id,
)
if not embedding:
logger.warning(
"生成章节片段向量失败,已跳过: project=%s chapter=%s chunk=%s",
project_id,
chapter_number,
index,
)
continue
record_id = f"{project_id}:{chapter_number}:{index}"
chunk_records.append(
{
"id": record_id,
"project_id": project_id,
"chapter_number": chapter_number,
"chunk_index": index,
"chapter_title": title,
"content": chunk_text,
"embedding": embedding,
"metadata": {
"chunk_id": record_id,
"length": len(chunk_text),
},
}
)
if chunk_records:
await self._vector_store.upsert_chunks(records=chunk_records)
logger.info(
"章节正文向量写入完成: project=%s chapter=%s 成功片段=%d",
project_id,
chapter_number,
len(chunk_records),
)
if summary:
cleaned_summary = summary.strip()
if cleaned_summary:
summary_embedding = await self._llm_service.get_embedding(
cleaned_summary,
user_id=user_id,
)
if summary_embedding:
summary_id = f"{project_id}:{chapter_number}:summary"
await self._vector_store.upsert_summaries(
records=[
{
"id": summary_id,
"project_id": project_id,
"chapter_number": chapter_number,
"title": title,
"summary": cleaned_summary,
"embedding": summary_embedding,
}
]
)
logger.info(
"章节摘要向量写入完成: project=%s chapter=%s",
project_id,
chapter_number,
)
else:
logger.warning(
"生成章节摘要向量失败,已跳过: project=%s chapter=%s",
project_id,
chapter_number,
)
async def delete_chapters(self, project_id: str, chapter_numbers: Sequence[int]) -> None:
"""从向量库中删除指定章节的所有片段与摘要。"""
if not settings.vector_store_enabled or not chapter_numbers:
return
logger.info(
"准备删除章节向量: project=%s chapters=%s",
project_id,
list(chapter_numbers),
)
await self._vector_store.delete_by_chapters(project_id, list(chapter_numbers))
def _split_into_chunks(self, text: str) -> List[str]:
"""按照配置的 chunk 大小与重叠度切分章节正文。"""
normalized = text.strip()
if not normalized:
return []
if self._text_splitter:
parts = [segment.strip() for segment in self._text_splitter.split_text(normalized)]
filtered = [part for part in parts if part]
if filtered:
logger.debug(
"使用 LangChain 文本切分器完成分段: count=%d chunk_size=%d overlap=%d",
len(filtered),
settings.vector_chunk_size,
settings.vector_chunk_overlap,
)
return filtered
return self._legacy_split(normalized)
@staticmethod
def _find_split_offset(segment: str) -> Optional[int]:
"""在片段内部寻找更自然的分割点,优先换行,其次常见标点。"""
candidates: Dict[str, int] = {}
newline_pos = segment.rfind("\n\n")
if newline_pos == -1:
newline_pos = segment.rfind("\n")
if newline_pos > 0:
candidates["newline"] = newline_pos
punctuation_marks = ["", "", "", "!", "?", ".", ";", ""]
for mark in punctuation_marks:
idx = segment.rfind(mark)
if idx > 0:
candidates.setdefault("punctuation", idx + len(mark))
if not candidates:
return None
# 选择最接近末尾但又不过短的分割点
best_offset = max(candidates.values())
if best_offset < len(segment) * 0.4:
return None
return best_offset
def _init_text_splitter(self) -> Optional["RecursiveCharacterTextSplitter"]:
"""初始化 LangChain 文本切分器,可根据配置动态调整。"""
if RecursiveCharacterTextSplitter is None:
logger.warning("未安装 langchain-text-splitters章节切分将回退至内置策略。")
return None
chunk_size = settings.vector_chunk_size
overlap = min(settings.vector_chunk_overlap, chunk_size // 2)
separators = [
"\n\n",
"\n",
"", "", "",
"!", "?", "", ";",
"", ",",
" ",
]
splitter = RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=chunk_size,
chunk_overlap=overlap,
keep_separator=False,
strip_whitespace=True,
)
logger.info(
"已初始化 LangChain 文本切分器: chunk_size=%d overlap=%d",
chunk_size,
overlap,
)
return splitter
def _legacy_split(self, text: str) -> List[str]:
"""内置切分策略,作为 LangChain 缺失时的后备方案。"""
chunk_size = settings.vector_chunk_size
overlap = min(settings.vector_chunk_overlap, chunk_size // 2)
chunks: List[str] = []
start = 0
total_length = len(text)
while start < total_length:
end = min(total_length, start + chunk_size)
segment = text[start:end]
split_offset = self._find_split_offset(segment)
if split_offset is not None and start + split_offset < total_length:
end = start + split_offset
segment = text[start:end]
chunk_text = segment.strip()
if chunk_text:
chunks.append(chunk_text)
if end >= total_length:
break
start = max(0, end - overlap)
logger.debug(
"使用内置策略完成章节切分: count=%d chunk_size=%d overlap=%d",
len(chunks),
chunk_size,
overlap,
)
return chunks
__all__ = ["ChapterIngestionService"]

View File

@@ -0,0 +1,49 @@
from typing import Iterable, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from ..repositories.system_config_repository import SystemConfigRepository
from ..models import SystemConfig
from ..schemas.config import SystemConfigCreate, SystemConfigRead, SystemConfigUpdate
class ConfigService:
"""系统配置服务:提供 CRUD 接口,并负责转换 Pydantic 模型。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = SystemConfigRepository(session)
async def list_configs(self) -> list[SystemConfigRead]:
configs = await self.repo.list_all()
return [SystemConfigRead.model_validate(cfg) for cfg in configs]
async def get_config(self, key: str) -> Optional[SystemConfigRead]:
config = await self.repo.get_by_key(key)
return SystemConfigRead.model_validate(config) if config else None
async def upsert_config(self, payload: SystemConfigCreate) -> SystemConfigRead:
instance = await self.repo.get_by_key(payload.key)
if instance:
await self.repo.update_fields(instance, value=payload.value, description=payload.description)
else:
instance = SystemConfig(**payload.model_dump())
await self.repo.add(instance)
await self.session.commit()
return SystemConfigRead.model_validate(instance)
async def patch_config(self, key: str, payload: SystemConfigUpdate) -> Optional[SystemConfigRead]:
instance = await self.repo.get_by_key(key)
if not instance:
return None
await self.repo.update_fields(instance, **payload.model_dump(exclude_unset=True))
await self.session.commit()
return SystemConfigRead.model_validate(instance)
async def remove_config(self, key: str) -> bool:
instance = await self.repo.get_by_key(key)
if not instance:
return False
await self.repo.delete(instance)
await self.session.commit()
return True

View File

@@ -0,0 +1,41 @@
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import LLMConfig
from ..repositories.llm_config_repository import LLMConfigRepository
from ..schemas.llm_config import LLMConfigCreate, LLMConfigRead
class LLMConfigService:
"""用户自定义 LLM 配置服务。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = LLMConfigRepository(session)
async def upsert_config(self, user_id: int, payload: LLMConfigCreate) -> LLMConfigRead:
instance = await self.repo.get_by_user(user_id)
data = payload.model_dump(exclude_unset=True)
if "llm_provider_url" in data and data["llm_provider_url"] is not None:
# HttpUrl 类型在 sqlite 中无法直接写入,需要提前转为字符串
data["llm_provider_url"] = str(data["llm_provider_url"])
if instance:
await self.repo.update_fields(instance, **data)
else:
instance = LLMConfig(user_id=user_id, **data)
await self.repo.add(instance)
await self.session.commit()
return LLMConfigRead.model_validate(instance)
async def get_config(self, user_id: int) -> Optional[LLMConfigRead]:
instance = await self.repo.get_by_user(user_id)
return LLMConfigRead.model_validate(instance) if instance else None
async def delete_config(self, user_id: int) -> bool:
instance = await self.repo.get_by_user(user_id)
if not instance:
return False
await self.repo.delete(instance)
await self.session.commit()
return True

View File

@@ -0,0 +1,306 @@
import logging
import os
from typing import Any, Dict, List, Optional
import httpx
from fastapi import HTTPException, status
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, InternalServerError
from ..core.config import settings
from ..repositories.llm_config_repository import LLMConfigRepository
from ..repositories.system_config_repository import SystemConfigRepository
from ..repositories.user_repository import UserRepository
from ..services.admin_setting_service import AdminSettingService
from ..services.prompt_service import PromptService
from ..services.usage_service import UsageService
from ..utils.llm_tool import ChatMessage, LLMClient
logger = logging.getLogger(__name__)
try: # pragma: no cover - 运行环境未安装时兼容
from ollama import AsyncClient as OllamaAsyncClient
except ImportError: # pragma: no cover - Ollama 为可选依赖
OllamaAsyncClient = None
class LLMService:
"""封装与大模型交互的所有逻辑,包括配额控制与配置选择。"""
def __init__(self, session):
self.session = session
self.llm_repo = LLMConfigRepository(session)
self.system_config_repo = SystemConfigRepository(session)
self.user_repo = UserRepository(session)
self.admin_setting_service = AdminSettingService(session)
self.usage_service = UsageService(session)
self._embedding_dimensions: Dict[str, int] = {}
async def get_llm_response(
self,
system_prompt: str,
conversation_history: List[Dict[str, str]],
*,
temperature: float = 0.7,
user_id: Optional[int] = None,
timeout: float = 300.0,
response_format: Optional[str] = "json_object",
) -> str:
messages = [{"role": "system", "content": system_prompt}, *conversation_history]
return await self._stream_and_collect(
messages,
temperature=temperature,
user_id=user_id,
timeout=timeout,
response_format=response_format,
)
async def get_summary(
self,
chapter_content: str,
*,
temperature: float = 0.2,
user_id: Optional[int] = None,
timeout: float = 180.0,
system_prompt: Optional[str] = None,
) -> str:
if not system_prompt:
prompt_service = PromptService(self.session)
system_prompt = await prompt_service.get_prompt("extraction")
if not system_prompt:
raise HTTPException(status_code=500, detail="未配置摘要提示词")
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": chapter_content},
]
return await self._stream_and_collect(messages, temperature=temperature, user_id=user_id, timeout=timeout)
async def _stream_and_collect(
self,
messages: List[Dict[str, str]],
*,
temperature: float,
user_id: Optional[int],
timeout: float,
response_format: Optional[str] = None,
) -> str:
config = await self._resolve_llm_config(user_id)
client = LLMClient(api_key=config["api_key"], base_url=config.get("base_url"))
chat_messages = [ChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
full_response = ""
finish_reason = None
logger.info(
"Streaming LLM response: model=%s user_id=%s messages=%d",
config.get("model"),
user_id,
len(messages),
)
try:
async for part in client.stream_chat(
messages=chat_messages,
model=config.get("model"),
temperature=temperature,
timeout=int(timeout),
response_format=response_format,
):
if part.get("content"):
full_response += part["content"]
if part.get("finish_reason"):
finish_reason = part["finish_reason"]
except InternalServerError as exc:
detail = "AI 服务内部错误,请稍后重试"
response = getattr(exc, "response", None)
if response is not None:
try:
payload = response.json()
error_data = payload.get("error", {}) if isinstance(payload, dict) else {}
detail = error_data.get("message_zh") or error_data.get("message") or detail
except Exception:
detail = str(exc) or detail
else:
detail = str(exc) or detail
logger.error(
"LLM stream internal error: model=%s user_id=%s detail=%s",
config.get("model"),
user_id,
detail,
exc_info=exc,
)
raise HTTPException(status_code=503, detail=detail)
except (httpx.RemoteProtocolError, httpx.ReadTimeout, APIConnectionError, APITimeoutError) as exc:
if isinstance(exc, httpx.RemoteProtocolError):
detail = "AI 服务连接被意外中断,请稍后重试"
elif isinstance(exc, (httpx.ReadTimeout, APITimeoutError)):
detail = "AI 服务响应超时,请稍后重试"
else:
detail = "无法连接到 AI 服务,请稍后重试"
logger.error(
"LLM stream failed: model=%s user_id=%s detail=%s",
config.get("model"),
user_id,
detail,
exc_info=exc,
)
raise HTTPException(status_code=503, detail=detail) from exc
logger.debug(
"LLM response collected: model=%s user_id=%s finish_reason=%s preview=%s",
config.get("model"),
user_id,
finish_reason,
full_response[:500],
)
if finish_reason == "length":
logger.warning(
"LLM response truncated: model=%s user_id=%s",
config.get("model"),
user_id,
)
raise HTTPException(status_code=500, detail="AI 响应被截断,请缩短输入或调整参数")
if not full_response:
logger.error(
"LLM returned empty response: model=%s user_id=%s",
config.get("model"),
user_id,
)
raise HTTPException(status_code=500, detail="AI 未返回有效内容")
await self.usage_service.increment("api_request_count")
logger.info(
"LLM response success: model=%s user_id=%s chars=%d",
config.get("model"),
user_id,
len(full_response),
)
return full_response
async def _resolve_llm_config(self, user_id: Optional[int]) -> Dict[str, Optional[str]]:
if user_id:
config = await self.llm_repo.get_by_user(user_id)
if config and config.llm_provider_api_key:
return {
"api_key": config.llm_provider_api_key,
"base_url": config.llm_provider_url,
"model": config.llm_provider_model,
}
# 检查每日使用次数限制
if user_id:
await self._enforce_daily_limit(user_id)
api_key = await self._get_config_value("llm.api_key")
base_url = await self._get_config_value("llm.base_url")
model = await self._get_config_value("llm.model")
if not api_key:
raise HTTPException(status_code=500, detail="未配置默认 LLM API Key")
return {"api_key": api_key, "base_url": base_url, "model": model}
async def get_embedding(
self,
text: str,
*,
user_id: Optional[int] = None,
model: Optional[str] = None,
) -> List[float]:
"""生成文本向量,用于章节 RAG 检索,支持 openai 与 ollama 双提供方。"""
provider = settings.embedding_provider
target_model = model or (
settings.ollama_embedding_model if provider == "ollama" else settings.embedding_model
)
if provider == "ollama":
if OllamaAsyncClient is None:
logger.error("未安装 ollama 依赖,无法调用本地嵌入模型。")
raise HTTPException(status_code=500, detail="缺少 Ollama 依赖,请先安装 ollama 包。")
base_url_any = settings.ollama_embedding_base_url or settings.embedding_base_url
base_url = str(base_url_any) if base_url_any else None
client = OllamaAsyncClient(host=base_url)
try:
response = await client.embeddings(model=target_model, prompt=text)
except Exception as exc: # pragma: no cover - 本地服务调用失败
logger.warning(
"Ollama 嵌入请求失败: model=%s error=%s",
target_model,
exc,
)
return []
embedding: Optional[List[float]]
if isinstance(response, dict):
embedding = response.get("embedding")
else:
embedding = getattr(response, "embedding", None)
if not embedding:
logger.warning("Ollama 返回空向量: model=%s", target_model)
return []
if not isinstance(embedding, list):
embedding = list(embedding)
else:
config = await self._resolve_llm_config(user_id)
api_key = settings.embedding_api_key or config["api_key"]
base_url_setting = settings.embedding_base_url or config.get("base_url")
base_url = str(base_url_setting) if base_url_setting else None
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
try:
response = await client.embeddings.create(
input=text,
model=target_model,
)
except Exception as exc: # pragma: no cover - 网络或鉴权失败
logger.warning(
"OpenAI 嵌入请求失败: model=%s user_id=%s error=%s",
target_model,
user_id,
exc,
)
return []
if not response.data:
logger.warning("OpenAI 嵌入请求返回空数据: model=%s user_id=%s", target_model, user_id)
return []
embedding = response.data[0].embedding
if not isinstance(embedding, list):
embedding = list(embedding)
dimension = len(embedding)
if not dimension and settings.embedding_model_vector_size:
dimension = settings.embedding_model_vector_size
if dimension:
self._embedding_dimensions[target_model] = dimension
return embedding
def get_embedding_dimension(self, model: Optional[str] = None) -> Optional[int]:
"""获取嵌入向量维度,优先返回缓存结果,其次读取配置。"""
target_model = model or (
settings.ollama_embedding_model if settings.embedding_provider == "ollama" else settings.embedding_model
)
if target_model in self._embedding_dimensions:
return self._embedding_dimensions[target_model]
return settings.embedding_model_vector_size
async def _enforce_daily_limit(self, user_id: int) -> None:
limit_str = await self.admin_setting_service.get("daily_request_limit", "100")
limit = int(limit_str or 10)
used = await self.user_repo.get_daily_request(user_id)
if used >= limit:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="今日请求次数已达上限,请明日再试或设置自定义 API Key。",
)
await self.user_repo.increment_daily_request(user_id)
await self.session.commit()
async def _get_config_value(self, key: str) -> Optional[str]:
record = await self.system_config_repo.get_by_key(key)
if record:
return record.value
# 兼容环境变量,首次迁移时无需立即写入数据库
env_key = key.upper().replace(".", "_")
return os.getenv(env_key)

View File

@@ -0,0 +1,700 @@
from __future__ import annotations
import json
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Iterable, List, Optional
_PREFERRED_CONTENT_KEYS: tuple[str, ...] = (
"content",
"chapter_content",
"chapter_text",
"full_content",
"text",
"body",
"story",
"chapter",
"real_summary",
"summary",
)
def _normalize_version_content(raw_content: Any, metadata: Any) -> str:
text = _coerce_text(metadata)
if not text:
text = _coerce_text(raw_content)
return text or ""
def _coerce_text(value: Any) -> Optional[str]:
if value is None:
return None
if isinstance(value, str):
return _clean_string(value)
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, dict):
for key in _PREFERRED_CONTENT_KEYS:
if key in value and value[key]:
nested = _coerce_text(value[key])
if nested:
return nested
return _clean_string(json.dumps(value, ensure_ascii=False))
if isinstance(value, (list, tuple, set)):
parts = [text for text in (_coerce_text(item) for item in value) if text]
if parts:
return "\n".join(parts)
return None
return _clean_string(str(value))
def _clean_string(text: str) -> str:
stripped = text.strip()
if not stripped:
return stripped
if stripped.startswith("{") and stripped.endswith("}"):
try:
parsed = json.loads(stripped)
coerced = _coerce_text(parsed)
if coerced:
return coerced
except json.JSONDecodeError:
pass
if stripped.startswith('"') and stripped.endswith('"') and len(stripped) >= 2:
stripped = stripped[1:-1]
return (
stripped.replace("\\n", "\n")
.replace("\\t", "\t")
.replace('\\"', '"')
.replace("\\\\", "\\")
)
from fastapi import HTTPException, status
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import (
BlueprintCharacter,
BlueprintRelationship,
Chapter,
ChapterEvaluation,
ChapterOutline,
ChapterVersion,
NovelBlueprint,
NovelConversation,
NovelProject,
)
from ..repositories.novel_repository import NovelRepository
from ..schemas.admin import AdminNovelSummary
from ..schemas.novel import (
Blueprint,
Chapter as ChapterSchema,
ChapterGenerationStatus,
ChapterOutline as ChapterOutlineSchema,
NovelProject as NovelProjectSchema,
NovelProjectSummary,
NovelSectionResponse,
NovelSectionType,
)
class NovelService:
"""小说项目服务,基于拆表后的结构提供聚合与业务操作。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = NovelRepository(session)
# ------------------------------------------------------------------
# 项目与摘要
# ------------------------------------------------------------------
async def create_project(self, user_id: int, title: str, initial_prompt: str) -> NovelProject:
project = NovelProject(
id=str(uuid.uuid4()),
user_id=user_id,
title=title,
initial_prompt=initial_prompt,
)
blueprint = NovelBlueprint(project=project)
self.session.add_all([project, blueprint])
await self.session.commit()
await self.session.refresh(project)
return project
async def ensure_project_owner(self, project_id: str, user_id: int) -> NovelProject:
project = await self.repo.get_by_id(project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
if project.user_id != user_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问该项目")
return project
async def get_project_schema(self, project_id: str, user_id: int) -> NovelProjectSchema:
project = await self.ensure_project_owner(project_id, user_id)
return await self._serialize_project(project)
async def get_section_data(
self,
project_id: str,
user_id: int,
section: NovelSectionType,
) -> NovelSectionResponse:
project = await self.ensure_project_owner(project_id, user_id)
return self._build_section_response(project, section)
async def get_chapter_schema(
self,
project_id: str,
user_id: int,
chapter_number: int,
) -> ChapterSchema:
project = await self.ensure_project_owner(project_id, user_id)
return self._build_chapter_schema(project, chapter_number)
async def list_projects_for_user(self, user_id: int) -> List[NovelProjectSummary]:
projects = await self.repo.list_by_user(user_id)
summaries: List[NovelProjectSummary] = []
for project in projects:
blueprint = project.blueprint
genre = blueprint.genre if blueprint and blueprint.genre else "未知"
outlines = project.outlines
chapters = project.chapters
total = len(outlines) or len(chapters)
completed = sum(1 for chapter in chapters if chapter.selected_version_id)
summaries.append(
NovelProjectSummary(
id=project.id,
title=project.title,
genre=genre,
last_edited=project.updated_at.isoformat() if project.updated_at else "未知",
completed_chapters=completed,
total_chapters=total,
)
)
return summaries
async def list_projects_for_admin(self) -> List[AdminNovelSummary]:
projects = await self.repo.list_all()
summaries: List[AdminNovelSummary] = []
for project in projects:
blueprint = project.blueprint
genre = blueprint.genre if blueprint and blueprint.genre else "未知"
outlines = project.outlines
chapters = project.chapters
total = len(outlines) or len(chapters)
completed = sum(1 for chapter in chapters if chapter.selected_version_id)
owner = project.owner
summaries.append(
AdminNovelSummary(
id=project.id,
title=project.title,
owner_id=owner.id if owner else 0,
owner_username=owner.username if owner else "未知",
genre=genre,
last_edited=project.updated_at.isoformat() if project.updated_at else "",
completed_chapters=completed,
total_chapters=total,
)
)
return summaries
async def delete_projects(self, project_ids: List[str], user_id: int) -> None:
for pid in project_ids:
project = await self.ensure_project_owner(pid, user_id)
await self.repo.delete(project)
await self.session.commit()
async def count_projects(self) -> int:
result = await self.session.execute(select(func.count(NovelProject.id)))
return result.scalar_one()
# ------------------------------------------------------------------
# 对话管理
# ------------------------------------------------------------------
async def list_conversations(self, project_id: str) -> List[NovelConversation]:
stmt = (
select(NovelConversation)
.where(NovelConversation.project_id == project_id)
.order_by(NovelConversation.seq.asc())
)
result = await self.session.execute(stmt)
return list(result.scalars())
async def append_conversation(self, project_id: str, role: str, content: str, metadata: Optional[Dict] = None) -> None:
result = await self.session.execute(
select(func.max(NovelConversation.seq)).where(NovelConversation.project_id == project_id)
)
current_max = result.scalar()
next_seq = (current_max or 0) + 1
convo = NovelConversation(
project_id=project_id,
seq=next_seq,
role=role,
content=content,
metadata=metadata,
)
self.session.add(convo)
await self.session.commit()
await self._touch_project(project_id)
# ------------------------------------------------------------------
# 蓝图管理
# ------------------------------------------------------------------
async def replace_blueprint(self, project_id: str, blueprint: Blueprint) -> None:
record = await self.session.get(NovelBlueprint, project_id)
if not record:
record = NovelBlueprint(project_id=project_id)
self.session.add(record)
record.title = blueprint.title
record.target_audience = blueprint.target_audience
record.genre = blueprint.genre
record.style = blueprint.style
record.tone = blueprint.tone
record.one_sentence_summary = blueprint.one_sentence_summary
record.full_synopsis = blueprint.full_synopsis
record.world_setting = blueprint.world_setting
await self.session.execute(delete(BlueprintCharacter).where(BlueprintCharacter.project_id == project_id))
for index, data in enumerate(blueprint.characters):
self.session.add(
BlueprintCharacter(
project_id=project_id,
name=data.get("name", ""),
identity=data.get("identity"),
personality=data.get("personality"),
goals=data.get("goals"),
abilities=data.get("abilities"),
relationship_to_protagonist=data.get("relationship_to_protagonist"),
extra={k: v for k, v in data.items() if k not in {
"name",
"identity",
"personality",
"goals",
"abilities",
"relationship_to_protagonist",
}},
position=index,
)
)
await self.session.execute(delete(BlueprintRelationship).where(BlueprintRelationship.project_id == project_id))
for index, relation in enumerate(blueprint.relationships):
self.session.add(
BlueprintRelationship(
project_id=project_id,
character_from=relation.character_from,
character_to=relation.character_to,
description=relation.description,
position=index,
)
)
await self.session.execute(delete(ChapterOutline).where(ChapterOutline.project_id == project_id))
for outline in blueprint.chapter_outline:
self.session.add(
ChapterOutline(
project_id=project_id,
chapter_number=outline.chapter_number,
title=outline.title,
summary=outline.summary,
)
)
await self.session.commit()
await self._touch_project(project_id)
async def patch_blueprint(self, project_id: str, patch: Dict) -> None:
blueprint = await self.session.get(NovelBlueprint, project_id)
if not blueprint:
blueprint = NovelBlueprint(project_id=project_id)
self.session.add(blueprint)
if "one_sentence_summary" in patch:
blueprint.one_sentence_summary = patch["one_sentence_summary"]
if "full_synopsis" in patch:
blueprint.full_synopsis = patch["full_synopsis"]
if "world_setting" in patch and patch["world_setting"] is not None:
existing = blueprint.world_setting or {}
existing.update(patch["world_setting"])
blueprint.world_setting = existing
if "characters" in patch and patch["characters"] is not None:
await self.session.execute(delete(BlueprintCharacter).where(BlueprintCharacter.project_id == project_id))
for index, data in enumerate(patch["characters"]):
self.session.add(
BlueprintCharacter(
project_id=project_id,
name=data.get("name", ""),
identity=data.get("identity"),
personality=data.get("personality"),
goals=data.get("goals"),
abilities=data.get("abilities"),
relationship_to_protagonist=data.get("relationship_to_protagonist"),
extra={k: v for k, v in data.items() if k not in {
"name",
"identity",
"personality",
"goals",
"abilities",
"relationship_to_protagonist",
}},
position=index,
)
)
if "relationships" in patch and patch["relationships"] is not None:
await self.session.execute(delete(BlueprintRelationship).where(BlueprintRelationship.project_id == project_id))
for index, relation in enumerate(patch["relationships"]):
self.session.add(
BlueprintRelationship(
project_id=project_id,
character_from=relation.get("character_from"),
character_to=relation.get("character_to"),
description=relation.get("description"),
position=index,
)
)
if "chapter_outline" in patch and patch["chapter_outline"] is not None:
await self.session.execute(delete(ChapterOutline).where(ChapterOutline.project_id == project_id))
for outline in patch["chapter_outline"]:
self.session.add(
ChapterOutline(
project_id=project_id,
chapter_number=outline.get("chapter_number"),
title=outline.get("title", ""),
summary=outline.get("summary"),
)
)
await self.session.commit()
await self._touch_project(project_id)
# ------------------------------------------------------------------
# 章节与版本
# ------------------------------------------------------------------
async def get_outline(self, project_id: str, chapter_number: int) -> Optional[ChapterOutline]:
stmt = (
select(ChapterOutline)
.where(
ChapterOutline.project_id == project_id,
ChapterOutline.chapter_number == chapter_number,
)
)
result = await self.session.execute(stmt)
return result.scalars().first()
async def get_or_create_chapter(self, project_id: str, chapter_number: int) -> Chapter:
stmt = (
select(Chapter)
.where(
Chapter.project_id == project_id,
Chapter.chapter_number == chapter_number,
)
)
result = await self.session.execute(stmt)
chapter = result.scalars().first()
if chapter:
return chapter
chapter = Chapter(project_id=project_id, chapter_number=chapter_number)
self.session.add(chapter)
await self.session.commit()
await self.session.refresh(chapter)
return chapter
async def replace_chapter_versions(self, chapter: Chapter, contents: List[str], metadata: Optional[List[Dict]] = None) -> List[ChapterVersion]:
await self.session.execute(delete(ChapterVersion).where(ChapterVersion.chapter_id == chapter.id))
versions: List[ChapterVersion] = []
for index, content in enumerate(contents):
extra = metadata[index] if metadata and index < len(metadata) else None
text_content = _normalize_version_content(content, extra)
version = ChapterVersion(
chapter_id=chapter.id,
content=text_content,
metadata=None,
version_label=f"v{index+1}",
)
self.session.add(version)
versions.append(version)
chapter.status = ChapterGenerationStatus.WAITING_FOR_CONFIRM.value
await self.session.commit()
await self.session.refresh(chapter)
await self._touch_project(chapter.project_id)
return versions
async def select_chapter_version(self, chapter: Chapter, version_index: int) -> ChapterVersion:
versions = sorted(chapter.versions, key=lambda item: item.created_at)
if not versions or version_index < 0 or version_index >= len(versions):
raise HTTPException(status_code=400, detail="版本索引无效")
selected = versions[version_index]
chapter.selected_version_id = selected.id
chapter.status = ChapterGenerationStatus.SUCCESSFUL.value
chapter.word_count = len(selected.content or "")
await self.session.commit()
await self.session.refresh(chapter)
await self._touch_project(chapter.project_id)
return selected
async def add_chapter_evaluation(self, chapter: Chapter, version: Optional[ChapterVersion], feedback: str, decision: Optional[str] = None) -> None:
evaluation = ChapterEvaluation(
chapter_id=chapter.id,
version_id=version.id if version else None,
feedback=feedback,
decision=decision,
)
self.session.add(evaluation)
chapter.status = ChapterGenerationStatus.WAITING_FOR_CONFIRM.value
await self.session.commit()
await self.session.refresh(chapter)
await self._touch_project(chapter.project_id)
async def delete_chapters(self, project_id: str, chapter_numbers: Iterable[int]) -> None:
await self.session.execute(
delete(Chapter).where(
Chapter.project_id == project_id,
Chapter.chapter_number.in_(list(chapter_numbers)),
)
)
await self.session.execute(
delete(ChapterOutline).where(
ChapterOutline.project_id == project_id,
ChapterOutline.chapter_number.in_(list(chapter_numbers)),
)
)
await self.session.commit()
await self._touch_project(project_id)
# ------------------------------------------------------------------
# 序列化辅助
# ------------------------------------------------------------------
async def get_project_schema_for_admin(self, project_id: str) -> NovelProjectSchema:
project = await self.repo.get_by_id(project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
return await self._serialize_project(project)
async def get_section_data_for_admin(
self,
project_id: str,
section: NovelSectionType,
) -> NovelSectionResponse:
project = await self.repo.get_by_id(project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
return self._build_section_response(project, section)
async def get_chapter_schema_for_admin(
self,
project_id: str,
chapter_number: int,
) -> ChapterSchema:
project = await self.repo.get_by_id(project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
return self._build_chapter_schema(project, chapter_number)
async def _serialize_project(self, project: NovelProject) -> NovelProjectSchema:
conversations = [
{"role": convo.role, "content": convo.content}
for convo in sorted(project.conversations, key=lambda c: c.seq)
]
blueprint_schema = self._build_blueprint_schema(project)
outlines_map = {outline.chapter_number: outline for outline in project.outlines}
chapters_map = {chapter.chapter_number: chapter for chapter in project.chapters}
chapter_numbers = sorted(set(outlines_map.keys()) | set(chapters_map.keys()))
chapters_schema: List[ChapterSchema] = [
self._build_chapter_schema(
project,
number,
outlines_map=outlines_map,
chapters_map=chapters_map,
)
for number in chapter_numbers
]
return NovelProjectSchema(
id=project.id,
user_id=project.user_id,
title=project.title,
initial_prompt=project.initial_prompt or "",
conversation_history=conversations,
blueprint=blueprint_schema,
chapters=chapters_schema,
)
async def _touch_project(self, project_id: str) -> None:
await self.session.execute(
update(NovelProject)
.where(NovelProject.id == project_id)
.values(updated_at=datetime.now(timezone.utc))
)
await self.session.commit()
def _build_blueprint_schema(self, project: NovelProject) -> Blueprint:
blueprint_obj = project.blueprint
if blueprint_obj:
return Blueprint(
title=blueprint_obj.title or "",
target_audience=blueprint_obj.target_audience or "",
genre=blueprint_obj.genre or "",
style=blueprint_obj.style or "",
tone=blueprint_obj.tone or "",
one_sentence_summary=blueprint_obj.one_sentence_summary or "",
full_synopsis=blueprint_obj.full_synopsis or "",
world_setting=blueprint_obj.world_setting or {},
characters=[
{
"name": character.name,
"identity": character.identity,
"personality": character.personality,
"goals": character.goals,
"abilities": character.abilities,
"relationship_to_protagonist": character.relationship_to_protagonist,
**(character.extra or {}),
}
for character in sorted(project.characters, key=lambda c: c.position)
],
relationships=[
{
"character_from": relation.character_from,
"character_to": relation.character_to,
"description": relation.description or "",
"relationship_type": getattr(relation, "relationship_type", None),
}
for relation in sorted(project.relationships_, key=lambda r: r.position)
],
chapter_outline=[
ChapterOutlineSchema(
chapter_number=outline.chapter_number,
title=outline.title,
summary=outline.summary or "",
)
for outline in sorted(project.outlines, key=lambda o: o.chapter_number)
],
)
return Blueprint(
title="",
target_audience="",
genre="",
style="",
tone="",
one_sentence_summary="",
full_synopsis="",
world_setting={},
characters=[],
relationships=[],
chapter_outline=[],
)
def _build_section_response(
self,
project: NovelProject,
section: NovelSectionType,
) -> NovelSectionResponse:
blueprint = self._build_blueprint_schema(project)
if section == NovelSectionType.OVERVIEW:
data = {
"title": project.title,
"initial_prompt": project.initial_prompt or "",
"status": project.status,
"one_sentence_summary": blueprint.one_sentence_summary,
"target_audience": blueprint.target_audience,
"genre": blueprint.genre,
"style": blueprint.style,
"tone": blueprint.tone,
"full_synopsis": blueprint.full_synopsis,
"updated_at": project.updated_at.isoformat() if project.updated_at else None,
}
elif section == NovelSectionType.WORLD_SETTING:
data = {
"world_setting": blueprint.world_setting or {},
}
elif section == NovelSectionType.CHARACTERS:
data = {
"characters": blueprint.characters,
}
elif section == NovelSectionType.RELATIONSHIPS:
data = {
"relationships": blueprint.relationships,
}
elif section == NovelSectionType.CHAPTER_OUTLINE:
data = {
"chapter_outline": [outline.model_dump() for outline in blueprint.chapter_outline],
}
elif section == NovelSectionType.CHAPTERS:
outlines_map = {outline.chapter_number: outline for outline in project.outlines}
chapters_map = {chapter.chapter_number: chapter for chapter in project.chapters}
chapter_numbers = sorted(set(outlines_map.keys()) | set(chapters_map.keys()))
# 章节列表只返回元数据,不包含完整内容
chapters = [
self._build_chapter_schema(
project,
number,
outlines_map=outlines_map,
chapters_map=chapters_map,
include_content=False,
).model_dump()
for number in chapter_numbers
]
data = {
"chapters": chapters,
"total": len(chapters),
}
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="未知的章节类型")
return NovelSectionResponse(section=section, data=data)
def _build_chapter_schema(
self,
project: NovelProject,
chapter_number: int,
*,
outlines_map: Optional[Dict[int, ChapterOutline]] = None,
chapters_map: Optional[Dict[int, Chapter]] = None,
include_content: bool = True,
) -> ChapterSchema:
outlines = outlines_map or {outline.chapter_number: outline for outline in project.outlines}
chapters = chapters_map or {chapter.chapter_number: chapter for chapter in project.chapters}
outline = outlines.get(chapter_number)
chapter = chapters.get(chapter_number)
if not outline and not chapter:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="章节不存在")
title = outline.title if outline else f"{chapter_number}"
summary = outline.summary if outline else ""
real_summary = chapter.real_summary if chapter else None
content = None
versions: Optional[List[str]] = None
evaluation_text: Optional[str] = None
status_value = ChapterGenerationStatus.NOT_GENERATED.value
word_count = 0
if chapter:
status_value = chapter.status or ChapterGenerationStatus.NOT_GENERATED.value
word_count = chapter.word_count or 0
# 只有在 include_content=True 时才包含完整内容
if include_content:
if chapter.selected_version:
content = chapter.selected_version.content
if chapter.versions:
versions = [
v.content
for v in sorted(chapter.versions, key=lambda item: item.created_at)
]
if chapter.evaluations:
latest = sorted(chapter.evaluations, key=lambda item: item.created_at)[-1]
evaluation_text = latest.feedback or latest.decision
return ChapterSchema(
chapter_number=chapter_number,
title=title,
summary=summary,
real_summary=real_summary,
content=content,
versions=versions,
evaluation=evaluation_text,
generation_status=ChapterGenerationStatus(status_value),
word_count=word_count,
)

View File

@@ -0,0 +1,96 @@
import asyncio
from typing import Dict, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Prompt
from ..repositories.prompt_repository import PromptRepository
from ..schemas.prompt import PromptCreate, PromptRead, PromptUpdate
_CACHE: Dict[str, PromptRead] = {}
_LOCK = asyncio.Lock()
_LOADED = False
class PromptService:
"""提示词服务,提供缓存加速与 CRUD 能力。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = PromptRepository(session)
async def preload(self) -> None:
global _CACHE, _LOADED
prompts = await self.repo.list_all()
async with _LOCK:
_CACHE = {item.name: PromptRead.model_validate(item) for item in prompts}
_LOADED = True
async def get_prompt(self, name: str) -> Optional[str]:
global _LOADED
async with _LOCK:
if not _LOADED:
prompts = await self.repo.list_all()
_CACHE.update({item.name: PromptRead.model_validate(item) for item in prompts})
_LOADED = True
cached = _CACHE.get(name)
if cached:
return cached.content
prompt = await self.repo.get_by_name(name)
if not prompt:
return None
prompt_read = PromptRead.model_validate(prompt)
async with _LOCK:
_CACHE[name] = prompt_read
return prompt_read.content
async def list_prompts(self) -> list[PromptRead]:
prompts = await self.repo.list_all()
return [PromptRead.model_validate(item) for item in prompts]
async def get_prompt_by_id(self, prompt_id: int) -> Optional[PromptRead]:
instance = await self.repo.get(id=prompt_id)
if not instance:
return None
return PromptRead.model_validate(instance)
async def create_prompt(self, payload: PromptCreate) -> PromptRead:
data = payload.model_dump()
tags = data.get("tags")
if tags is not None:
data["tags"] = ",".join(tags)
prompt = Prompt(**data)
await self.repo.add(prompt)
await self.session.commit()
prompt_read = PromptRead.model_validate(prompt)
async with _LOCK:
_CACHE[prompt_read.name] = prompt_read
global _LOADED
_LOADED = True
return prompt_read
async def update_prompt(self, prompt_id: int, payload: PromptUpdate) -> Optional[PromptRead]:
instance = await self.repo.get(id=prompt_id)
if not instance:
return None
update_data = payload.model_dump(exclude_unset=True)
if "tags" in update_data and update_data["tags"] is not None:
update_data["tags"] = ",".join(update_data["tags"])
await self.repo.update_fields(instance, **update_data)
await self.session.commit()
prompt_read = PromptRead.model_validate(instance)
async with _LOCK:
_CACHE[prompt_read.name] = prompt_read
return prompt_read
async def delete_prompt(self, prompt_id: int) -> bool:
instance = await self.repo.get(id=prompt_id)
if not instance:
return False
await self.repo.delete(instance)
await self.session.commit()
async with _LOCK:
_CACHE.pop(instance.name, None)
return True

View File

@@ -0,0 +1,60 @@
from typing import List, Optional
from fastapi import HTTPException, status
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import UpdateLog
from ..repositories.update_log_repository import UpdateLogRepository
class UpdateLogService:
"""更新日志服务,提供增删改查能力,并保证置顶唯一。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = UpdateLogRepository(session)
async def list_logs(self, limit: Optional[int] = None) -> List[UpdateLog]:
if limit is None:
return list(await self.repo.list())
return list(await self.repo.list_latest(limit))
async def create_log(self, content: str, creator: str | None = None, *, is_pinned: bool = False) -> UpdateLog:
if is_pinned:
await self._clear_pinned()
log = UpdateLog(content=content, created_by=creator, is_pinned=is_pinned)
await self.repo.add(log)
await self.session.commit()
await self.session.refresh(log)
return log
async def update_log(self, log_id: int, *, content: Optional[str] = None, is_pinned: Optional[bool] = None) -> UpdateLog:
log = await self.repo.get(id=log_id)
if not log:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="更新记录不存在")
updates = {}
if content is not None:
updates["content"] = content
if is_pinned is not None:
if is_pinned:
await self._clear_pinned()
updates["is_pinned"] = is_pinned
if updates:
await self.repo.update_fields(log, **updates)
await self.session.commit()
await self.session.refresh(log)
return log
async def delete_log(self, log_id: int) -> None:
log = await self.repo.get(id=log_id)
if not log:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="更新记录不存在")
await self.repo.delete(log)
await self.session.commit()
async def _clear_pinned(self) -> None:
await self.session.execute(update(UpdateLog).values(is_pinned=False))

View File

@@ -0,0 +1,21 @@
from sqlalchemy.ext.asyncio import AsyncSession
from ..repositories.usage_metric_repository import UsageMetricRepository
class UsageService:
"""通用计数服务,目前用于统计 API 请求次数等。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = UsageMetricRepository(session)
async def increment(self, key: str) -> None:
counter = await self.repo.get_or_create(key)
counter.value += 1
await self.session.commit()
async def get_value(self, key: str) -> int:
counter = await self.repo.get_or_create(key)
await self.session.commit()
return counter.value

View File

@@ -0,0 +1,62 @@
from typing import Optional
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from ..core.security import hash_password
from ..models import User
from ..repositories.user_repository import UserRepository
from ..schemas.user import UserCreate, UserInDB
class UserService:
"""用户领域服务,负责注册、查询与配额统计。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = UserRepository(session)
async def create_user(self, payload: UserCreate, *, external_id: str | None = None) -> UserInDB:
hashed_password = hash_password(payload.password)
user = User(
username=payload.username,
email=payload.email,
hashed_password=hashed_password,
external_id=external_id,
)
self.session.add(user)
try:
await self.session.commit()
except IntegrityError as exc:
await self.session.rollback()
raise ValueError("用户名或邮箱已存在") from exc
return UserInDB.model_validate(user)
async def get_by_username(self, username: str) -> Optional[UserInDB]:
user = await self.repo.get_by_username(username)
return UserInDB.model_validate(user) if user else None
async def get_by_email(self, email: str) -> Optional[UserInDB]:
user = await self.repo.get_by_email(email)
return UserInDB.model_validate(user) if user else None
async def get_by_external_id(self, external_id: str) -> Optional[UserInDB]:
user = await self.repo.get_by_external_id(external_id)
return UserInDB.model_validate(user) if user else None
async def get_user(self, user_id: int) -> Optional[UserInDB]:
user = await self.repo.get(id=user_id)
return UserInDB.model_validate(user) if user else None
async def list_users(self) -> list[UserInDB]:
users = await self.repo.list_all()
return [UserInDB.model_validate(item) for item in users]
async def increment_daily_request(self, user_id: int) -> None:
await self.repo.increment_daily_request(user_id)
await self.session.commit()
async def get_daily_request(self, user_id: int) -> int:
return await self.repo.get_daily_request(user_id)

View File

@@ -0,0 +1,544 @@
from __future__ import annotations
"""
基于 libsql 的向量检索服务,封装章节内容的存储与查询。
本文件中的注释均使用中文,便于团队成员快速理解 RAG 相关逻辑。
"""
import json
import logging
import math
from array import array
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence
from ..core.config import settings
try: # noqa: SIM105 - 明确区分依赖缺失的情况
import libsql_client
except ImportError: # pragma: no cover - 在未安装依赖时提供友好提示
libsql_client = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
@dataclass
class RetrievedChunk:
"""向量检索得到的剧情片段。"""
content: str
chapter_number: int
chapter_title: Optional[str]
score: float
metadata: Dict[str, Any]
@dataclass
class RetrievedSummary:
"""向量检索得到的章节摘要。"""
chapter_number: int
title: str
summary: str
score: float
class VectorStoreService:
"""libsql 向量库操作工具,确保不同小说项目的数据隔离。"""
def __init__(self) -> None:
if not settings.vector_store_enabled:
logger.warning("未开启向量库配置RAG 检索将被跳过。")
self._client = None
self._schema_ready = True
return
if libsql_client is None: # pragma: no cover - 运行环境缺少依赖
raise RuntimeError("缺少 libsql-client 依赖,请先在环境中安装。")
url = settings.vector_db_url
if url and url.startswith("file:"):
path_part = url.split("file:", 1)[1]
resolved = Path(path_part).expanduser().resolve()
resolved.parent.mkdir(parents=True, exist_ok=True)
url = f"file:{resolved}"
logger.info("向量库使用本地文件: %s", resolved)
try:
logger.info("初始化 libsql 客户端: url=%s", url)
self._client = libsql_client.create_client(
url=url,
auth_token=settings.vector_db_auth_token,
)
except Exception as exc: # pragma: no cover - 连接异常仅打印日志
logger.error("初始化 libsql 客户端失败: %s", exc)
self._client = None
self._schema_ready = True
else:
self._schema_ready = False
logger.info("libsql 客户端初始化成功,等待建表。")
async def ensure_schema(self) -> None:
"""初始化向量表结构,保证系统首次运行即可使用。"""
if not self._client or self._schema_ready:
return
statements = [
"""
CREATE TABLE IF NOT EXISTS rag_chunks (
id TEXT PRIMARY KEY,
project_id TEXT NOT NULL,
chapter_number INTEGER NOT NULL,
chunk_index INTEGER NOT NULL,
chapter_title TEXT,
content TEXT NOT NULL,
embedding BLOB NOT NULL,
metadata TEXT,
created_at INTEGER DEFAULT (unixepoch())
)
""",
"""
CREATE INDEX IF NOT EXISTS idx_rag_chunks_project
ON rag_chunks(project_id, chapter_number)
""",
"""
CREATE TABLE IF NOT EXISTS rag_summaries (
id TEXT PRIMARY KEY,
project_id TEXT NOT NULL,
chapter_number INTEGER NOT NULL,
title TEXT NOT NULL,
summary TEXT NOT NULL,
embedding BLOB NOT NULL,
created_at INTEGER DEFAULT (unixepoch())
)
""",
"""
CREATE INDEX IF NOT EXISTS idx_rag_summaries_project
ON rag_summaries(project_id, chapter_number)
""",
]
try:
for sql in statements:
await self._client.execute(sql) # type: ignore[union-attr]
logger.info("已确保向量库表结构存在。")
except Exception as exc: # pragma: no cover - 初始化失败时记录日志
logger.error("创建向量库表结构失败: %s", exc)
else:
self._schema_ready = True
async def query_chunks(
self,
*,
project_id: str,
embedding: Sequence[float],
top_k: Optional[int] = None,
) -> List[RetrievedChunk]:
"""根据查询向量检索剧情片段,结果已按相似度排序。"""
if not self._client or not embedding:
return []
await self.ensure_schema()
top_k = top_k or settings.vector_top_k_chunks
if top_k <= 0:
return []
blob = self._to_f32_blob(embedding)
sql = """
SELECT
content,
chapter_number,
chapter_title,
COALESCE(metadata, '{}') AS metadata,
vector_distance_cosine(embedding, :query) AS distance
FROM rag_chunks
WHERE project_id = :project_id
ORDER BY distance ASC
LIMIT :limit
"""
try:
result = await self._client.execute( # type: ignore[union-attr]
sql,
{
"project_id": project_id,
"query": blob,
"limit": top_k,
},
)
except Exception as exc: # pragma: no cover - 查询异常时仅记录
if "no such function: vector_distance_cosine" in str(exc).lower():
logger.warning("向量库缺少 vector_distance_cosine 函数,回退至应用层相似度计算。")
return await self._query_chunks_with_python_similarity(
project_id=project_id,
embedding=embedding,
top_k=top_k,
)
logger.warning("向量检索剧情片段失败: %s", exc)
return []
items: List[RetrievedChunk] = []
for row in self._iter_rows(result):
items.append(
RetrievedChunk(
content=row.get("content", ""),
chapter_number=row.get("chapter_number", 0),
chapter_title=row.get("chapter_title"),
score=row.get("distance", 0.0),
metadata=self._parse_metadata(row.get("metadata")),
)
)
return items
async def query_summaries(
self,
*,
project_id: str,
embedding: Sequence[float],
top_k: Optional[int] = None,
) -> List[RetrievedSummary]:
"""根据查询向量检索章节摘要列表。"""
if not self._client or not embedding:
return []
await self.ensure_schema()
top_k = top_k or settings.vector_top_k_summaries
if top_k <= 0:
return []
blob = self._to_f32_blob(embedding)
sql = """
SELECT
chapter_number,
title,
summary,
vector_distance_cosine(embedding, :query) AS distance
FROM rag_summaries
WHERE project_id = :project_id
ORDER BY distance ASC
LIMIT :limit
"""
try:
result = await self._client.execute( # type: ignore[union-attr]
sql,
{
"project_id": project_id,
"query": blob,
"limit": top_k,
},
)
except Exception as exc: # pragma: no cover - 查询异常时仅记录
if "no such function: vector_distance_cosine" in str(exc).lower():
logger.warning("向量库缺少 vector_distance_cosine 函数,回退至应用层相似度计算。")
return await self._query_summaries_with_python_similarity(
project_id=project_id,
embedding=embedding,
top_k=top_k,
)
logger.warning("向量检索章节摘要失败: %s", exc)
return []
items: List[RetrievedSummary] = []
for row in self._iter_rows(result):
items.append(
RetrievedSummary(
chapter_number=row.get("chapter_number", 0),
title=row.get("title", ""),
summary=row.get("summary", ""),
score=row.get("distance", 0.0),
)
)
return items
async def upsert_chunks(
self,
*,
records: Iterable[Dict[str, Any]],
) -> None:
"""批量写入章节片段,供后续检索使用。"""
if not self._client:
return
await self.ensure_schema()
sql = """
INSERT INTO rag_chunks (
id,
project_id,
chapter_number,
chunk_index,
chapter_title,
content,
embedding,
metadata
) VALUES (
:id,
:project_id,
:chapter_number,
:chunk_index,
:chapter_title,
:content,
:embedding,
:metadata
)
ON CONFLICT(id) DO UPDATE SET
content=excluded.content,
embedding=excluded.embedding,
metadata=excluded.metadata,
chapter_title=excluded.chapter_title
"""
payload = []
for item in records:
embedding = item.get("embedding", [])
payload.append(
{
**item,
"embedding": self._to_f32_blob(embedding),
"metadata": json.dumps(item.get("metadata") or {}, ensure_ascii=False),
}
)
if not payload:
return
for item in payload:
try:
await self._client.execute(sql, item) # type: ignore[union-attr]
except Exception as exc: # pragma: no cover - 单条写入失败时记录日志
logger.error("写入 rag_chunks 失败: %s", exc)
else:
logger.debug(
"已写入章节片段: project=%s chapter=%s chunk=%s",
item.get("project_id"),
item.get("chapter_number"),
item.get("chunk_index"),
)
async def upsert_summaries(
self,
*,
records: Iterable[Dict[str, Any]],
) -> None:
"""同步章节摘要向量,供摘要层检索使用。"""
if not self._client:
return
await self.ensure_schema()
sql = """
INSERT INTO rag_summaries (
id,
project_id,
chapter_number,
title,
summary,
embedding
) VALUES (
:id,
:project_id,
:chapter_number,
:title,
:summary,
:embedding
)
ON CONFLICT(id) DO UPDATE SET
summary=excluded.summary,
embedding=excluded.embedding,
title=excluded.title
"""
payload = []
for item in records:
embedding = item.get("embedding", [])
payload.append(
{
**item,
"embedding": self._to_f32_blob(embedding),
}
)
if not payload:
return
for item in payload:
try:
await self._client.execute(sql, item) # type: ignore[union-attr]
except Exception as exc: # pragma: no cover - 单条写入失败时记录日志
logger.error("写入 rag_summaries 失败: %s", exc)
else:
logger.debug(
"已写入章节摘要: project=%s chapter=%s",
item.get("project_id"),
item.get("chapter_number"),
)
async def delete_by_chapters(self, project_id: str, chapter_numbers: Sequence[int]) -> None:
"""根据章节编号批量删除对应的上下文数据。"""
if not self._client or not chapter_numbers:
return
await self.ensure_schema()
placeholders = ",".join(":chapter_" + str(idx) for idx in range(len(chapter_numbers)))
params = {
"project_id": project_id,
**{f"chapter_{idx}": number for idx, number in enumerate(chapter_numbers)},
}
chunk_sql = f"""
DELETE FROM rag_chunks
WHERE project_id = :project_id
AND chapter_number IN ({placeholders})
"""
summary_sql = f"""
DELETE FROM rag_summaries
WHERE project_id = :project_id
AND chapter_number IN ({placeholders})
"""
try:
await self._client.execute(chunk_sql, params) # type: ignore[union-attr]
await self._client.execute(summary_sql, params) # type: ignore[union-attr]
logger.info(
"已删除章节向量: project=%s chapters=%s",
project_id,
list(chapter_numbers),
)
except Exception as exc: # pragma: no cover - 删除失败时记录日志
logger.error("删除章节向量失败: project=%s chapters=%s error=%s", project_id, chapter_numbers, exc)
@staticmethod
def _to_f32_blob(embedding: Sequence[float]) -> bytes:
"""将向量浮点列表编码为 libsql 可识别的 float32 二进制。"""
return array("f", embedding).tobytes()
@staticmethod
def _from_f32_blob(blob: Any) -> List[float]:
"""将数据库中的 BLOB 解码为浮点列表。"""
if not blob:
return []
if isinstance(blob, memoryview):
blob = blob.tobytes()
data = array("f")
data.frombytes(bytes(blob))
return list(data)
@staticmethod
def _cosine_distance(vec_a: Sequence[float], vec_b: Sequence[float]) -> float:
"""计算余弦距离1 - similarity避免除零。"""
if not vec_a or not vec_b:
return 1.0
dot = sum(a * b for a, b in zip(vec_a, vec_b))
norm_a = math.sqrt(sum(a * a for a in vec_a))
norm_b = math.sqrt(sum(b * b for b in vec_b))
if norm_a == 0 or norm_b == 0:
return 1.0
similarity = dot / (norm_a * norm_b)
return 1.0 - similarity
async def _query_chunks_with_python_similarity(
self,
*,
project_id: str,
embedding: Sequence[float],
top_k: int,
) -> List[RetrievedChunk]:
sql = """
SELECT
content,
chapter_number,
chapter_title,
COALESCE(metadata, '{}') AS metadata,
embedding
FROM rag_chunks
WHERE project_id = :project_id
"""
result = await self._client.execute(sql, {"project_id": project_id}) # type: ignore[union-attr]
scored: List[RetrievedChunk] = []
for row in self._iter_rows(result):
stored_embedding = self._from_f32_blob(row.get("embedding"))
distance = self._cosine_distance(embedding, stored_embedding)
scored.append(
RetrievedChunk(
content=row.get("content", ""),
chapter_number=row.get("chapter_number", 0),
chapter_title=row.get("chapter_title"),
score=distance,
metadata=self._parse_metadata(row.get("metadata")),
)
)
scored.sort(key=lambda item: item.score)
return scored[:top_k]
async def _query_summaries_with_python_similarity(
self,
*,
project_id: str,
embedding: Sequence[float],
top_k: int,
) -> List[RetrievedSummary]:
sql = """
SELECT
chapter_number,
title,
summary,
embedding
FROM rag_summaries
WHERE project_id = :project_id
"""
result = await self._client.execute(sql, {"project_id": project_id}) # type: ignore[union-attr]
scored: List[RetrievedSummary] = []
for row in self._iter_rows(result):
stored_embedding = self._from_f32_blob(row.get("embedding"))
distance = self._cosine_distance(embedding, stored_embedding)
scored.append(
RetrievedSummary(
chapter_number=row.get("chapter_number", 0),
title=row.get("title", ""),
summary=row.get("summary", ""),
score=distance,
)
)
scored.sort(key=lambda item: item.score)
return scored[:top_k]
@staticmethod
def _parse_metadata(raw: Any) -> Dict[str, Any]:
"""解析存储的 JSON 文本,确保输出为 dict。"""
if not raw:
return {}
if isinstance(raw, dict):
return raw
if isinstance(raw, (bytes, bytearray)):
raw = raw.decode("utf-8")
if isinstance(raw, str):
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else {}
except json.JSONDecodeError:
return {}
return {}
@staticmethod
def _iter_rows(result: Any) -> Iterable[Dict[str, Any]]:
"""统一处理 libsql 返回的行数据,确保以 dict 形式迭代。"""
rows = getattr(result, "rows", None)
if rows is None:
rows = result
if not rows:
return []
normalized: List[Dict[str, Any]] = []
for row in rows:
if isinstance(row, dict):
normalized.append(row)
elif hasattr(row, "_asdict"):
normalized.append(row._asdict()) # type: ignore[attr-defined]
else:
try:
normalized.append(dict(row))
except Exception: # pragma: no cover - 无法转换时跳过
continue
return normalized
__all__ = [
"VectorStoreService",
"RetrievedChunk",
"RetrievedSummary",
]

View File

View File

@@ -0,0 +1,81 @@
import re
def remove_think_tags(raw_text: str) -> str:
"""移除 <think></think> 标签,避免污染结果。"""
if not raw_text:
return raw_text
return re.sub(r"<think>.*?</think>", "", raw_text, flags=re.DOTALL).strip()
def unwrap_markdown_json(raw_text: str) -> str:
"""从 Markdown 或普通文本中提取 JSON 字符串。"""
if not raw_text:
return raw_text
trimmed = raw_text.strip()
fence_match = re.search(r"```(?:json|JSON)?\s*(.*?)\s*```", trimmed, re.DOTALL)
if fence_match:
candidate = fence_match.group(1).strip()
if candidate:
return candidate
json_start_candidates = [idx for idx in (trimmed.find("{"), trimmed.find("[")) if idx != -1]
if json_start_candidates:
start_idx = min(json_start_candidates)
closing_brace = trimmed.rfind("}")
closing_bracket = trimmed.rfind("]")
end_idx = max(closing_brace, closing_bracket)
if end_idx != -1 and end_idx > start_idx:
candidate = trimmed[start_idx : end_idx + 1].strip()
if candidate:
return candidate
return trimmed
def sanitize_json_like_text(raw_text: str) -> str:
"""对可能含有未转义换行/引号的 JSON 文本进行清洗。"""
if not raw_text:
return raw_text
result = []
in_string = False
escape_next = False
length = len(raw_text)
i = 0
while i < length:
ch = raw_text[i]
if in_string:
if escape_next:
result.append(ch)
escape_next = False
elif ch == "\\":
result.append(ch)
escape_next = True
elif ch == '"':
j = i + 1
while j < length and raw_text[j] in " \t\r\n":
j += 1
if j >= length or raw_text[j] in "}]" or raw_text[j] == ",":
in_string = False
result.append(ch)
else:
result.extend(["\\", '"'])
elif ch == "\n":
result.extend(["\\", "n"])
elif ch == "\r":
result.extend(["\\", "r"])
elif ch == "\t":
result.extend(["\\", "t"])
else:
result.append(ch)
else:
if ch == '"':
in_string = True
result.append(ch)
i += 1
return "".join(result)

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
"""OpenAI 兼容型 LLM 工具封装,保持与旧项目一致的接口体验。"""
import os
from dataclasses import asdict, dataclass
from typing import AsyncGenerator, Dict, List, Optional
from openai import AsyncOpenAI
@dataclass
class ChatMessage:
role: str
content: str
def to_dict(self) -> Dict[str, str]:
return asdict(self)
class LLMClient:
"""异步流式调用封装,兼容 OpenAI SDK。"""
def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
key = api_key or os.environ.get("OPENAI_API_KEY")
if not key:
raise ValueError("缺少 OPENAI_API_KEY 配置,请在数据库或环境变量中补全。")
self._client = AsyncOpenAI(api_key=key, base_url=base_url or os.environ.get("OPENAI_API_BASE"))
async def stream_chat(
self,
messages: List[ChatMessage],
model: Optional[str] = None,
response_format: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
timeout: int = 120,
**kwargs,
) -> AsyncGenerator[Dict[str, str], None]:
payload = {
"model": model or os.environ.get("MODEL", "gpt-3.5-turbo"),
"messages": [msg.to_dict() for msg in messages],
"stream": True,
"timeout": timeout,
**kwargs,
}
if response_format:
payload["response_format"] = {"type": response_format}
if temperature is not None:
payload["temperature"] = temperature
if top_p is not None:
payload["top_p"] = top_p
if max_tokens is not None:
payload["max_tokens"] = max_tokens
stream = await self._client.chat.completions.create(**payload)
async for chunk in stream:
if not chunk.choices:
continue
choice = chunk.choices[0]
yield {
"content": choice.delta.content,
"finish_reason": choice.finish_reason,
}