feat: 初始提交
This commit is contained in:
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
12
backend/app/api/routers/__init__.py
Normal file
12
backend/app/api/routers/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from . import admin, auth, llm_config, novels, updates, writer
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
api_router.include_router(auth.router)
|
||||
api_router.include_router(novels.router)
|
||||
api_router.include_router(writer.router)
|
||||
api_router.include_router(admin.router)
|
||||
api_router.include_router(updates.router)
|
||||
api_router.include_router(llm_config.router)
|
||||
340
backend/app/api/routers/admin.py
Normal file
340
backend/app/api/routers/admin.py
Normal file
@@ -0,0 +1,340 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ...core.dependencies import get_current_admin
|
||||
from ...db.session import get_session
|
||||
from ...models import NovelProject, UsageMetric, User
|
||||
from ...schemas.admin import (
|
||||
AdminNovelSummary,
|
||||
DailyRequestLimit,
|
||||
Statistics,
|
||||
UpdateLogCreate,
|
||||
UpdateLogRead,
|
||||
UpdateLogUpdate,
|
||||
)
|
||||
from ...schemas.config import SystemConfigCreate, SystemConfigRead, SystemConfigUpdate
|
||||
from ...schemas.prompt import PromptCreate, PromptRead, PromptUpdate
|
||||
from ...schemas.novel import (
|
||||
Chapter as ChapterSchema,
|
||||
NovelProject as NovelProjectSchema,
|
||||
NovelSectionResponse,
|
||||
NovelSectionType,
|
||||
)
|
||||
from ...schemas.user import PasswordChangeRequest, User as UserSchema
|
||||
from ...services.auth_service import AuthService
|
||||
from ...services.admin_setting_service import AdminSettingService
|
||||
from ...services.config_service import ConfigService
|
||||
from ...services.novel_service import NovelService
|
||||
from ...services.prompt_service import PromptService
|
||||
from ...services.update_log_service import UpdateLogService
|
||||
from ...services.user_service import UserService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
|
||||
def get_prompt_service(session: AsyncSession = Depends(get_session)) -> PromptService:
|
||||
return PromptService(session)
|
||||
|
||||
|
||||
def get_update_log_service(session: AsyncSession = Depends(get_session)) -> UpdateLogService:
|
||||
return UpdateLogService(session)
|
||||
|
||||
|
||||
def get_admin_setting_service(session: AsyncSession = Depends(get_session)) -> AdminSettingService:
|
||||
return AdminSettingService(session)
|
||||
|
||||
|
||||
def get_config_service(session: AsyncSession = Depends(get_session)) -> ConfigService:
|
||||
return ConfigService(session)
|
||||
|
||||
|
||||
def get_novel_service(session: AsyncSession = Depends(get_session)) -> NovelService:
|
||||
return NovelService(session)
|
||||
|
||||
|
||||
def get_user_service(session: AsyncSession = Depends(get_session)) -> UserService:
|
||||
return UserService(session)
|
||||
|
||||
|
||||
def get_auth_service(session: AsyncSession = Depends(get_session)) -> AuthService:
|
||||
return AuthService(session)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=Statistics)
|
||||
async def read_statistics(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> Statistics:
|
||||
novel_count = await session.scalar(select(func.count(NovelProject.id))) or 0
|
||||
user_count = await session.scalar(select(func.count(User.id))) or 0
|
||||
usage = await session.get(UsageMetric, "api_request_count")
|
||||
api_request_count = usage.value if usage else 0
|
||||
logger.info("管理员获取统计数据:小说=%s,用户=%s,请求=%s", novel_count, user_count, api_request_count)
|
||||
return Statistics(novel_count=novel_count, user_count=user_count, api_request_count=api_request_count)
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[UserSchema])
|
||||
async def list_users(
|
||||
service: UserService = Depends(get_user_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> List[UserSchema]:
|
||||
users = await service.list_users()
|
||||
logger.info("管理员请求用户列表,共 %s 条", len(users))
|
||||
return [UserSchema.model_validate(user) for user in users]
|
||||
|
||||
|
||||
@router.get("/novel-projects", response_model=List[AdminNovelSummary])
|
||||
async def list_novel_projects(
|
||||
service: NovelService = Depends(get_novel_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> List[AdminNovelSummary]:
|
||||
projects = await service.list_projects_for_admin()
|
||||
logger.info("管理员查看项目列表,共 %s 个", len(projects))
|
||||
return projects
|
||||
|
||||
|
||||
@router.get("/novel-projects/{project_id}", response_model=NovelProjectSchema)
|
||||
async def get_novel_project(
|
||||
project_id: str,
|
||||
service: NovelService = Depends(get_novel_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> NovelProjectSchema:
|
||||
logger.info("管理员查看项目详情:%s", project_id)
|
||||
return await service.get_project_schema_for_admin(project_id)
|
||||
|
||||
|
||||
@router.get("/novel-projects/{project_id}/sections/{section}", response_model=NovelSectionResponse)
|
||||
async def get_novel_project_section(
|
||||
project_id: str,
|
||||
section: NovelSectionType,
|
||||
service: NovelService = Depends(get_novel_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> NovelSectionResponse:
|
||||
logger.info("管理员查看项目 %s 的 %s 区段", project_id, section)
|
||||
return await service.get_section_data_for_admin(project_id, section)
|
||||
|
||||
|
||||
@router.get("/novel-projects/{project_id}/chapters/{chapter_number}", response_model=ChapterSchema)
|
||||
async def get_novel_project_chapter(
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
service: NovelService = Depends(get_novel_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> ChapterSchema:
|
||||
logger.info("管理员查看项目 %s 第 %s 章详情", project_id, chapter_number)
|
||||
return await service.get_chapter_schema_for_admin(project_id, chapter_number)
|
||||
|
||||
|
||||
@router.get("/prompts", response_model=List[PromptRead])
|
||||
async def list_prompts(
|
||||
service: PromptService = Depends(get_prompt_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> List[PromptRead]:
|
||||
prompts = await service.list_prompts()
|
||||
logger.info("管理员请求提示词列表,共 %s 条", len(prompts))
|
||||
return prompts
|
||||
|
||||
|
||||
@router.post("/prompts", response_model=PromptRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_prompt(
|
||||
payload: PromptCreate,
|
||||
service: PromptService = Depends(get_prompt_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> PromptRead:
|
||||
prompt = await service.create_prompt(payload)
|
||||
logger.info("管理员创建提示词:%s", prompt.id)
|
||||
return prompt
|
||||
|
||||
|
||||
@router.get("/prompts/{prompt_id}", response_model=PromptRead)
|
||||
async def get_prompt(
|
||||
prompt_id: int,
|
||||
service: PromptService = Depends(get_prompt_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> PromptRead:
|
||||
prompt = await service.get_prompt_by_id(prompt_id)
|
||||
if not prompt:
|
||||
logger.warning("提示词 %s 不存在", prompt_id)
|
||||
raise HTTPException(status_code=404, detail="提示词不存在")
|
||||
logger.info("管理员获取提示词:%s", prompt_id)
|
||||
return prompt
|
||||
|
||||
|
||||
@router.patch("/prompts/{prompt_id}", response_model=PromptRead)
|
||||
async def update_prompt(
|
||||
prompt_id: int,
|
||||
payload: PromptUpdate,
|
||||
service: PromptService = Depends(get_prompt_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> PromptRead:
|
||||
result = await service.update_prompt(prompt_id, payload)
|
||||
if not result:
|
||||
logger.warning("提示词 %s 不存在,无法更新", prompt_id)
|
||||
raise HTTPException(status_code=404, detail="提示词不存在")
|
||||
logger.info("管理员更新提示词:%s", prompt_id)
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/prompts/{prompt_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_prompt(
|
||||
prompt_id: int,
|
||||
service: PromptService = Depends(get_prompt_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> None:
|
||||
deleted = await service.delete_prompt(prompt_id)
|
||||
if not deleted:
|
||||
logger.warning("提示词 %s 不存在,无法删除", prompt_id)
|
||||
raise HTTPException(status_code=404, detail="提示词不存在")
|
||||
logger.info("管理员删除提示词:%s", prompt_id)
|
||||
|
||||
|
||||
@router.get("/update-logs", response_model=List[UpdateLogRead])
|
||||
async def list_update_logs(
|
||||
service: UpdateLogService = Depends(get_update_log_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> List[UpdateLogRead]:
|
||||
logs = await service.list_logs()
|
||||
logger.info("管理员查看更新日志列表,共 %s 条", len(logs))
|
||||
return [UpdateLogRead.model_validate(log) for log in logs]
|
||||
|
||||
|
||||
@router.post("/update-logs", response_model=UpdateLogRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_update_log(
|
||||
payload: UpdateLogCreate,
|
||||
service: UpdateLogService = Depends(get_update_log_service),
|
||||
current_admin=Depends(get_current_admin),
|
||||
) -> UpdateLogRead:
|
||||
log = await service.create_log(
|
||||
payload.content,
|
||||
creator=current_admin.username,
|
||||
is_pinned=payload.is_pinned or False,
|
||||
)
|
||||
logger.info("管理员 %s 创建更新日志:%s", current_admin.username, log.id)
|
||||
return UpdateLogRead.model_validate(log)
|
||||
|
||||
|
||||
@router.delete("/update-logs/{log_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_update_log(
|
||||
log_id: int,
|
||||
service: UpdateLogService = Depends(get_update_log_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> None:
|
||||
await service.delete_log(log_id)
|
||||
logger.info("管理员删除更新日志:%s", log_id)
|
||||
|
||||
|
||||
@router.patch("/update-logs/{log_id}", response_model=UpdateLogRead)
|
||||
async def update_update_log(
|
||||
log_id: int,
|
||||
payload: UpdateLogUpdate,
|
||||
service: UpdateLogService = Depends(get_update_log_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> UpdateLogRead:
|
||||
log = await service.update_log(
|
||||
log_id,
|
||||
content=payload.content,
|
||||
is_pinned=payload.is_pinned,
|
||||
)
|
||||
logger.info("管理员更新日志 %s", log_id)
|
||||
return UpdateLogRead.model_validate(log)
|
||||
|
||||
|
||||
@router.get("/settings/daily-request-limit", response_model=DailyRequestLimit)
|
||||
async def get_daily_limit(
|
||||
service: AdminSettingService = Depends(get_admin_setting_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> DailyRequestLimit:
|
||||
value = await service.get("daily_request_limit", "100")
|
||||
logger.info("管理员查询每日请求上限:%s", value)
|
||||
return DailyRequestLimit(limit=int(value or 100))
|
||||
|
||||
|
||||
@router.put("/settings/daily-request-limit", response_model=DailyRequestLimit)
|
||||
async def update_daily_limit(
|
||||
payload: DailyRequestLimit,
|
||||
service: AdminSettingService = Depends(get_admin_setting_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> DailyRequestLimit:
|
||||
await service.set("daily_request_limit", str(payload.limit))
|
||||
logger.info("管理员设置每日请求上限为 %s", payload.limit)
|
||||
return payload
|
||||
|
||||
|
||||
@router.get("/system-configs", response_model=List[SystemConfigRead])
|
||||
async def list_system_configs(
|
||||
service: ConfigService = Depends(get_config_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> List[SystemConfigRead]:
|
||||
configs = await service.list_configs()
|
||||
logger.info("管理员获取系统配置,共 %s 条", len(configs))
|
||||
return configs
|
||||
|
||||
|
||||
@router.get("/system-configs/{key}", response_model=SystemConfigRead)
|
||||
async def get_system_config(
|
||||
key: str,
|
||||
service: ConfigService = Depends(get_config_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> SystemConfigRead:
|
||||
config = await service.get_config(key)
|
||||
if not config:
|
||||
logger.warning("系统配置 %s 不存在", key)
|
||||
raise HTTPException(status_code=404, detail="配置项不存在")
|
||||
logger.info("管理员查询系统配置:%s", key)
|
||||
return config
|
||||
|
||||
|
||||
@router.put("/system-configs/{key}", response_model=SystemConfigRead)
|
||||
async def upsert_system_config(
|
||||
key: str,
|
||||
payload: SystemConfigCreate,
|
||||
service: ConfigService = Depends(get_config_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> SystemConfigRead:
|
||||
logger.info("管理员写入系统配置:%s", key)
|
||||
return await service.upsert_config(
|
||||
SystemConfigCreate(key=key, value=payload.value, description=payload.description)
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/system-configs/{key}", response_model=SystemConfigRead)
|
||||
async def patch_system_config(
|
||||
key: str,
|
||||
payload: SystemConfigUpdate,
|
||||
service: ConfigService = Depends(get_config_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> SystemConfigRead:
|
||||
config = await service.patch_config(key, payload)
|
||||
if not config:
|
||||
logger.warning("系统配置 %s 不存在,无法更新", key)
|
||||
raise HTTPException(status_code=404, detail="配置项不存在")
|
||||
logger.info("管理员部分更新系统配置:%s", key)
|
||||
return config
|
||||
|
||||
|
||||
@router.delete("/system-configs/{key}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_system_config(
|
||||
key: str,
|
||||
service: ConfigService = Depends(get_config_service),
|
||||
_: None = Depends(get_current_admin),
|
||||
) -> None:
|
||||
deleted = await service.remove_config(key)
|
||||
if not deleted:
|
||||
logger.warning("系统配置 %s 不存在,无法删除", key)
|
||||
raise HTTPException(status_code=404, detail="配置项不存在")
|
||||
logger.info("管理员删除系统配置:%s", key)
|
||||
|
||||
|
||||
@router.post("/password", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def change_password(
|
||||
payload: PasswordChangeRequest,
|
||||
current_admin=Depends(get_current_admin),
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> None:
|
||||
await service.change_password(current_admin.username, payload.old_password, payload.new_password)
|
||||
logger.info("管理员 %s 修改密码", current_admin.username)
|
||||
106
backend/app/api/routers/auth.py
Normal file
106
backend/app/api/routers/auth.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ...core.config import settings
|
||||
from ...core.dependencies import get_current_user
|
||||
from ...db.session import get_session
|
||||
from ...schemas.user import AuthOptions, Token, User, UserInDB, UserRegistration
|
||||
from ...services.auth_service import AuthService
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Authentication"])
|
||||
|
||||
|
||||
def get_auth_service(session: AsyncSession = Depends(get_session)) -> AuthService:
|
||||
return AuthService(session)
|
||||
|
||||
|
||||
@router.post("/send-code", status_code=204)
|
||||
async def send_verification_code(email: str, service: AuthService = Depends(get_auth_service)):
|
||||
await service.send_verification_code(email)
|
||||
logger.info("向 %s 发送验证码", email)
|
||||
|
||||
|
||||
@router.get("/options", response_model=AuthOptions)
|
||||
async def read_auth_options(service: AuthService = Depends(get_auth_service)):
|
||||
"""读取认证功能开关,供前端动态渲染。"""
|
||||
options = await service.get_auth_options()
|
||||
return options
|
||||
|
||||
|
||||
@router.post("/users", response_model=User, status_code=status.HTTP_201_CREATED)
|
||||
async def register_user(payload: UserRegistration, service: AuthService = Depends(get_auth_service)):
|
||||
user = await service.register_user(payload)
|
||||
logger.info("注册新用户:%s", user.username)
|
||||
return User.model_validate(user)
|
||||
|
||||
|
||||
@router.post("/token", response_model=Token)
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends(), service: AuthService = Depends(get_auth_service)):
|
||||
user = await service.authenticate_user(form_data.username, form_data.password)
|
||||
must_change_password = service.requires_password_reset(user)
|
||||
token = await service.create_access_token(user, must_change_password=must_change_password)
|
||||
logger.info("用户 %s 登录成功,需改密=%s", form_data.username, must_change_password)
|
||||
return token
|
||||
|
||||
|
||||
@router.get("/users/me", response_model=User)
|
||||
async def read_current_user(current_user: UserInDB = Depends(get_current_user)):
|
||||
logger.debug("读取当前用户:%s", current_user.username)
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/linuxdo/login")
|
||||
async def login_with_linuxdo(service: AuthService = Depends(get_auth_service)):
|
||||
if not await service.is_linuxdo_login_enabled():
|
||||
logger.warning("Linux.do 登录未启用")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未启用 Linux.do 登录")
|
||||
client_id = await service.get_config_value("linuxdo.client_id")
|
||||
redirect_uri = await service.get_config_value("linuxdo.redirect_uri")
|
||||
auth_url = await service.get_config_value("linuxdo.auth_url")
|
||||
if not all([client_id, redirect_uri, auth_url]):
|
||||
logger.error("Linux.do OAuth 参数未配置完整")
|
||||
raise HTTPException(status_code=500, detail="未配置 Linux.do OAuth 参数")
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": "user",
|
||||
}
|
||||
query = "&".join(f"{k}={v}" for k, v in params.items())
|
||||
logger.info("跳转 Linux.do 授权,client_id=%s", client_id)
|
||||
return RedirectResponse(url=f"{auth_url}?{query}")
|
||||
|
||||
|
||||
@router.get("/linuxdo/register", response_class=HTMLResponse)
|
||||
async def register_with_linuxdo(code: str, service: AuthService = Depends(get_auth_service)):
|
||||
token = await service.handle_linuxdo_callback(code)
|
||||
logger.info("Linux.do 授权回调成功")
|
||||
token_json = token.model_dump_json()
|
||||
html_content = f"""<!DOCTYPE html>
|
||||
<html lang=\"zh-CN\">
|
||||
<head><meta charset=\"UTF-8\"><title>正在跳转</title></head>
|
||||
<body>
|
||||
<p>正在跳转,请稍候...</p>
|
||||
<script>
|
||||
(function() {{
|
||||
const token = JSON.parse('{token_json}');
|
||||
try {{
|
||||
window.localStorage.setItem('token', token.access_token);
|
||||
}} catch (err) {{
|
||||
console.error('无法写入本地存储', err);
|
||||
}}
|
||||
window.location.replace('/');
|
||||
}})();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
return HTMLResponse(content=html_content)
|
||||
54
backend/app/api/routers/llm_config.py
Normal file
54
backend/app/api/routers/llm_config.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ...core.dependencies import get_current_user
|
||||
from ...db.session import get_session
|
||||
from ...schemas.llm_config import LLMConfigCreate, LLMConfigRead
|
||||
from ...schemas.user import UserInDB
|
||||
from ...services.llm_config_service import LLMConfigService
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/llm-config", tags=["LLM Configuration"])
|
||||
|
||||
|
||||
def get_llm_config_service(session: AsyncSession = Depends(get_session)) -> LLMConfigService:
|
||||
return LLMConfigService(session)
|
||||
|
||||
|
||||
@router.get("", response_model=LLMConfigRead)
|
||||
async def read_llm_config(
|
||||
service: LLMConfigService = Depends(get_llm_config_service),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> LLMConfigRead:
|
||||
config = await service.get_config(current_user.id)
|
||||
if not config:
|
||||
logger.warning("用户 %s 尚未设置 LLM 配置", current_user.id)
|
||||
raise HTTPException(status_code=404, detail="尚未设置自定义配置")
|
||||
logger.info("用户 %s 获取 LLM 配置", current_user.id)
|
||||
return config
|
||||
|
||||
|
||||
@router.put("", response_model=LLMConfigRead)
|
||||
async def upsert_llm_config(
|
||||
payload: LLMConfigCreate,
|
||||
service: LLMConfigService = Depends(get_llm_config_service),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> LLMConfigRead:
|
||||
logger.info("用户 %s 更新 LLM 配置", current_user.id)
|
||||
return await service.upsert_config(current_user.id, payload)
|
||||
|
||||
|
||||
@router.delete("", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_llm_config(
|
||||
service: LLMConfigService = Depends(get_llm_config_service),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> None:
|
||||
deleted = await service.delete_config(current_user.id)
|
||||
if not deleted:
|
||||
logger.warning("用户 %s 删除 LLM 配置失败,未找到记录", current_user.id)
|
||||
raise HTTPException(status_code=404, detail="未找到配置")
|
||||
logger.info("用户 %s 删除 LLM 配置", current_user.id)
|
||||
301
backend/app/api/routers/novels.py
Normal file
301
backend/app/api/routers/novels.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ...core.dependencies import get_current_user
|
||||
from ...db.session import get_session
|
||||
from ...schemas.novel import (
|
||||
Blueprint,
|
||||
BlueprintGenerationResponse,
|
||||
BlueprintPatch,
|
||||
Chapter as ChapterSchema,
|
||||
ConverseRequest,
|
||||
ConverseResponse,
|
||||
NovelProject as NovelProjectSchema,
|
||||
NovelProjectSummary,
|
||||
NovelSectionResponse,
|
||||
NovelSectionType,
|
||||
)
|
||||
from ...schemas.user import UserInDB
|
||||
from ...services.llm_service import LLMService
|
||||
from ...services.novel_service import NovelService
|
||||
from ...services.prompt_service import PromptService
|
||||
from ...utils.json_utils import remove_think_tags, unwrap_markdown_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/novels", tags=["Novels"])
|
||||
|
||||
JSON_RESPONSE_INSTRUCTION = """
|
||||
IMPORTANT: 你的回复必须是合法的 JSON 对象,并严格包含以下字段:
|
||||
{
|
||||
"ai_message": "string",
|
||||
"ui_control": {
|
||||
"type": "single_choice | text_input | info_display",
|
||||
"options": [
|
||||
{"id": "option_1", "label": "string"}
|
||||
],
|
||||
"placeholder": "string"
|
||||
},
|
||||
"conversation_state": {},
|
||||
"is_complete": false
|
||||
}
|
||||
不要输出额外的文本或解释。
|
||||
"""
|
||||
|
||||
|
||||
def _ensure_prompt(prompt: str | None, name: str) -> str:
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=500, detail=f"未配置名为 {name} 的提示词,请联系管理员")
|
||||
return prompt
|
||||
|
||||
|
||||
@router.post("", response_model=NovelProjectSchema, status_code=status.HTTP_201_CREATED)
|
||||
async def create_novel(
|
||||
title: str = Body(...),
|
||||
initial_prompt: str = Body(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
"""为当前用户创建一个新的小说项目。"""
|
||||
novel_service = NovelService(session)
|
||||
project = await novel_service.create_project(current_user.id, title, initial_prompt)
|
||||
logger.info("用户 %s 创建项目 %s", current_user.id, project.id)
|
||||
return await novel_service.get_project_schema(project.id, current_user.id)
|
||||
|
||||
|
||||
@router.get("", response_model=List[NovelProjectSummary])
|
||||
async def list_novels(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> List[NovelProjectSummary]:
|
||||
"""列出用户的全部小说项目摘要信息。"""
|
||||
novel_service = NovelService(session)
|
||||
projects = await novel_service.list_projects_for_user(current_user.id)
|
||||
logger.info("用户 %s 获取项目列表,共 %s 个", current_user.id, len(projects))
|
||||
return projects
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=NovelProjectSchema)
|
||||
async def get_novel(
|
||||
project_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
novel_service = NovelService(session)
|
||||
logger.info("用户 %s 查询项目 %s", current_user.id, project_id)
|
||||
return await novel_service.get_project_schema(project_id, current_user.id)
|
||||
|
||||
|
||||
@router.get("/{project_id}/sections/{section}", response_model=NovelSectionResponse)
|
||||
async def get_novel_section(
|
||||
project_id: str,
|
||||
section: NovelSectionType,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelSectionResponse:
|
||||
novel_service = NovelService(session)
|
||||
logger.info("用户 %s 获取项目 %s 的 %s 区段", current_user.id, project_id, section)
|
||||
return await novel_service.get_section_data(project_id, current_user.id, section)
|
||||
|
||||
|
||||
@router.get("/{project_id}/chapters/{chapter_number}", response_model=ChapterSchema)
|
||||
async def get_chapter(
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> ChapterSchema:
|
||||
novel_service = NovelService(session)
|
||||
logger.info("用户 %s 获取项目 %s 第 %s 章", current_user.id, project_id, chapter_number)
|
||||
return await novel_service.get_chapter_schema(project_id, current_user.id, chapter_number)
|
||||
|
||||
|
||||
@router.delete("", status_code=status.HTTP_200_OK)
|
||||
async def delete_novels(
|
||||
project_ids: List[str] = Body(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> Dict[str, str]:
|
||||
novel_service = NovelService(session)
|
||||
await novel_service.delete_projects(project_ids, current_user.id)
|
||||
logger.info("用户 %s 删除项目 %s", current_user.id, project_ids)
|
||||
return {"status": "success", "message": f"成功删除 {len(project_ids)} 个项目"}
|
||||
|
||||
|
||||
@router.post("/{project_id}/concept/converse", response_model=ConverseResponse)
|
||||
async def converse_with_concept(
|
||||
project_id: str,
|
||||
request: ConverseRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> ConverseResponse:
|
||||
"""与概念设计师(LLM)进行对话,引导蓝图筹备。"""
|
||||
novel_service = NovelService(session)
|
||||
prompt_service = PromptService(session)
|
||||
llm_service = LLMService(session)
|
||||
|
||||
project = await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
|
||||
history_records = await novel_service.list_conversations(project_id)
|
||||
logger.info(
|
||||
"项目 %s 概念对话请求,用户 %s,历史记录 %s 条",
|
||||
project_id,
|
||||
current_user.id,
|
||||
len(history_records),
|
||||
)
|
||||
conversation_history = [
|
||||
{"role": record.role, "content": record.content}
|
||||
for record in history_records
|
||||
]
|
||||
user_content = json.dumps(request.user_input, ensure_ascii=False)
|
||||
conversation_history.append({"role": "user", "content": user_content})
|
||||
|
||||
system_prompt = _ensure_prompt(await prompt_service.get_prompt("concept"), "concept")
|
||||
system_prompt = f"{system_prompt}\n{JSON_RESPONSE_INSTRUCTION}"
|
||||
|
||||
llm_response = await llm_service.get_llm_response(
|
||||
system_prompt=system_prompt,
|
||||
conversation_history=conversation_history,
|
||||
temperature=0.8,
|
||||
user_id=current_user.id,
|
||||
timeout=240.0,
|
||||
)
|
||||
llm_response = remove_think_tags(llm_response)
|
||||
|
||||
try:
|
||||
normalized = unwrap_markdown_json(llm_response)
|
||||
parsed = json.loads(normalized)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.exception(
|
||||
"Failed to parse concept converse response: project_id=%s user_id=%s normalized=%s",
|
||||
project_id,
|
||||
current_user.id,
|
||||
normalized,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="AI 返回内容不是有效的 JSON") from exc
|
||||
|
||||
await novel_service.append_conversation(project_id, "user", user_content)
|
||||
await novel_service.append_conversation(project_id, "assistant", normalized)
|
||||
|
||||
logger.info("项目 %s 概念对话完成,is_complete=%s", project_id, parsed.get("is_complete"))
|
||||
|
||||
if parsed.get("is_complete"):
|
||||
parsed["ready_for_blueprint"] = True
|
||||
|
||||
parsed.setdefault("conversation_state", parsed.get("conversation_state", {}))
|
||||
return ConverseResponse(**parsed)
|
||||
|
||||
|
||||
@router.post("/{project_id}/blueprint/generate", response_model=BlueprintGenerationResponse)
|
||||
async def generate_blueprint(
|
||||
project_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> BlueprintGenerationResponse:
|
||||
"""根据完整对话生成可执行的小说蓝图。"""
|
||||
novel_service = NovelService(session)
|
||||
prompt_service = PromptService(session)
|
||||
llm_service = LLMService(session)
|
||||
|
||||
project = await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
logger.info("项目 %s 开始生成蓝图", project_id)
|
||||
|
||||
history_records = await novel_service.list_conversations(project_id)
|
||||
if not history_records:
|
||||
raise HTTPException(status_code=400, detail="缺少对话历史,无法生成蓝图")
|
||||
|
||||
formatted_history: List[Dict[str, str]] = []
|
||||
for record in history_records:
|
||||
role = record.role
|
||||
content = record.content
|
||||
if not role or not content:
|
||||
continue
|
||||
try:
|
||||
normalized = unwrap_markdown_json(content)
|
||||
data = json.loads(normalized)
|
||||
if role == "user":
|
||||
user_value = data.get("value", data)
|
||||
if isinstance(user_value, str):
|
||||
formatted_history.append({"role": "user", "content": user_value})
|
||||
elif role == "assistant":
|
||||
ai_message = data.get("ai_message") if isinstance(data, dict) else None
|
||||
if ai_message:
|
||||
formatted_history.append({"role": "assistant", "content": ai_message})
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
continue
|
||||
|
||||
if not formatted_history:
|
||||
raise HTTPException(status_code=400, detail="无法从历史对话中提取内容")
|
||||
|
||||
system_prompt = _ensure_prompt(await prompt_service.get_prompt("screenwriting"), "screenwriting")
|
||||
blueprint_raw = await llm_service.get_llm_response(
|
||||
system_prompt=system_prompt,
|
||||
conversation_history=formatted_history,
|
||||
temperature=0.3,
|
||||
user_id=current_user.id,
|
||||
timeout=480.0,
|
||||
)
|
||||
blueprint_raw = remove_think_tags(blueprint_raw)
|
||||
|
||||
blueprint_normalized = unwrap_markdown_json(blueprint_raw)
|
||||
try:
|
||||
blueprint_data = json.loads(blueprint_normalized)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(status_code=500, detail="蓝图生成失败,请稍后重试") from exc
|
||||
|
||||
blueprint = Blueprint(**blueprint_data)
|
||||
await novel_service.replace_blueprint(project_id, blueprint)
|
||||
if blueprint.title:
|
||||
project.title = blueprint.title
|
||||
project.status = "blueprint_ready"
|
||||
await session.commit()
|
||||
logger.info("项目 %s 更新标题为 %s,并标记为 blueprint_ready", project_id, blueprint.title)
|
||||
|
||||
ai_message = (
|
||||
"太棒了!我已经根据我们的对话整理出完整的小说蓝图。请确认是否进入写作阶段,或提出修改意见。"
|
||||
)
|
||||
return BlueprintGenerationResponse(blueprint=blueprint, ai_message=ai_message)
|
||||
|
||||
|
||||
@router.post("/{project_id}/blueprint/save", response_model=NovelProjectSchema)
|
||||
async def save_blueprint(
|
||||
project_id: str,
|
||||
blueprint_data: Blueprint | None = Body(None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
"""保存蓝图信息,可用于手动覆盖自动生成结果。"""
|
||||
novel_service = NovelService(session)
|
||||
project = await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
|
||||
if blueprint_data:
|
||||
await novel_service.replace_blueprint(project_id, blueprint_data)
|
||||
if blueprint_data.title:
|
||||
project.title = blueprint_data.title
|
||||
await session.commit()
|
||||
logger.info("项目 %s 手动保存蓝图", project_id)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="缺少蓝图数据")
|
||||
|
||||
return await novel_service.get_project_schema(project_id, current_user.id)
|
||||
|
||||
|
||||
@router.patch("/{project_id}/blueprint", response_model=NovelProjectSchema)
|
||||
async def patch_blueprint(
|
||||
project_id: str,
|
||||
payload: BlueprintPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
"""局部更新蓝图字段,对世界观或角色做微调。"""
|
||||
novel_service = NovelService(session)
|
||||
project = await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
await novel_service.patch_blueprint(project_id, update_data)
|
||||
logger.info("项目 %s 局部更新蓝图字段:%s", project_id, list(update_data.keys()))
|
||||
return await novel_service.get_project_schema(project_id, current_user.id)
|
||||
22
backend/app/api/routers/updates.py
Normal file
22
backend/app/api/routers/updates.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ...db.session import get_session
|
||||
from ...schemas.admin import UpdateLogRead
|
||||
from ...services.update_log_service import UpdateLogService
|
||||
|
||||
router = APIRouter(prefix="/api/updates", tags=["Updates"])
|
||||
|
||||
|
||||
def get_update_log_service(session: AsyncSession = Depends(get_session)) -> UpdateLogService:
|
||||
return UpdateLogService(session)
|
||||
|
||||
|
||||
@router.get("/latest", response_model=List[UpdateLogRead])
|
||||
async def read_latest_updates(
|
||||
service: UpdateLogService = Depends(get_update_log_service),
|
||||
) -> List[UpdateLogRead]:
|
||||
logs = await service.list_logs(limit=5)
|
||||
return [UpdateLogRead.model_validate(log) for log in logs]
|
||||
613
backend/app/api/routers/writer.py
Normal file
613
backend/app/api/routers/writer.py
Normal file
@@ -0,0 +1,613 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ...core.config import settings
|
||||
from ...core.dependencies import get_current_user
|
||||
from ...db.session import get_session
|
||||
from ...models.novel import Chapter, ChapterOutline
|
||||
from ...schemas.novel import (
|
||||
DeleteChapterRequest,
|
||||
EditChapterRequest,
|
||||
EvaluateChapterRequest,
|
||||
GenerateChapterRequest,
|
||||
GenerateOutlineRequest,
|
||||
NovelProject as NovelProjectSchema,
|
||||
SelectVersionRequest,
|
||||
UpdateChapterOutlineRequest,
|
||||
)
|
||||
from ...schemas.user import UserInDB
|
||||
from ...services.chapter_context_service import ChapterContextService
|
||||
from ...services.chapter_ingest_service import ChapterIngestionService
|
||||
from ...services.llm_service import LLMService
|
||||
from ...services.novel_service import NovelService
|
||||
from ...services.prompt_service import PromptService
|
||||
from ...services.vector_store_service import VectorStoreService
|
||||
from ...utils.json_utils import remove_think_tags, unwrap_markdown_json
|
||||
from ...repositories.system_config_repository import SystemConfigRepository
|
||||
|
||||
router = APIRouter(prefix="/api/writer", tags=["Writer"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _load_project_schema(service: NovelService, project_id: str, user_id: int) -> NovelProjectSchema:
|
||||
return await service.get_project_schema(project_id, user_id)
|
||||
|
||||
|
||||
def _extract_tail_excerpt(text: Optional[str], limit: int = 500) -> str:
|
||||
"""截取章节结尾文本,默认保留 500 字。"""
|
||||
if not text:
|
||||
return ""
|
||||
stripped = text.strip()
|
||||
if len(stripped) <= limit:
|
||||
return stripped
|
||||
return stripped[-limit:]
|
||||
|
||||
|
||||
@router.post("/novels/{project_id}/chapters/generate", response_model=NovelProjectSchema)
|
||||
async def generate_chapter(
|
||||
project_id: str,
|
||||
request: GenerateChapterRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
novel_service = NovelService(session)
|
||||
prompt_service = PromptService(session)
|
||||
llm_service = LLMService(session)
|
||||
|
||||
project = await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
logger.info("用户 %s 开始为项目 %s 生成第 %s 章", current_user.id, project_id, request.chapter_number)
|
||||
outline = await novel_service.get_outline(project_id, request.chapter_number)
|
||||
if not outline:
|
||||
logger.warning("项目 %s 未找到第 %s 章纲要,生成流程终止", project_id, request.chapter_number)
|
||||
raise HTTPException(status_code=404, detail="蓝图中未找到对应章节纲要")
|
||||
|
||||
chapter = await novel_service.get_or_create_chapter(project_id, request.chapter_number)
|
||||
chapter.real_summary = None
|
||||
chapter.selected_version_id = None
|
||||
chapter.status = "generating"
|
||||
await session.commit()
|
||||
|
||||
outlines_map = {item.chapter_number: item for item in project.outlines}
|
||||
# 收集所有可用的历史章节摘要,便于在 Prompt 中提供前情背景
|
||||
completed_chapters = []
|
||||
latest_prev_number = -1
|
||||
previous_summary_text = ""
|
||||
previous_tail_excerpt = ""
|
||||
for existing in project.chapters:
|
||||
if existing.chapter_number >= request.chapter_number:
|
||||
continue
|
||||
if existing.selected_version is None or not existing.selected_version.content:
|
||||
continue
|
||||
if not existing.real_summary:
|
||||
summary = await llm_service.get_summary(
|
||||
existing.selected_version.content,
|
||||
temperature=0.15,
|
||||
user_id=current_user.id,
|
||||
timeout=180.0,
|
||||
)
|
||||
existing.real_summary = remove_think_tags(summary)
|
||||
await session.commit()
|
||||
completed_chapters.append(
|
||||
{
|
||||
"chapter_number": existing.chapter_number,
|
||||
"title": outlines_map.get(existing.chapter_number).title if outlines_map.get(existing.chapter_number) else f"第{existing.chapter_number}章",
|
||||
"summary": existing.real_summary,
|
||||
}
|
||||
)
|
||||
if existing.chapter_number > latest_prev_number:
|
||||
latest_prev_number = existing.chapter_number
|
||||
previous_summary_text = existing.real_summary or ""
|
||||
previous_tail_excerpt = _extract_tail_excerpt(existing.selected_version.content)
|
||||
|
||||
project_schema = await novel_service._serialize_project(project)
|
||||
blueprint_dict = project_schema.blueprint.model_dump()
|
||||
|
||||
if "relationships" in blueprint_dict and blueprint_dict["relationships"]:
|
||||
for relation in blueprint_dict["relationships"]:
|
||||
if "character_from" in relation:
|
||||
relation["from"] = relation.pop("character_from")
|
||||
if "character_to" in relation:
|
||||
relation["to"] = relation.pop("character_to")
|
||||
|
||||
# 蓝图中禁止携带章节级别的细节信息,避免重复传输大段场景或对话内容
|
||||
banned_blueprint_keys = {
|
||||
"chapter_outline",
|
||||
"chapter_summaries",
|
||||
"chapter_details",
|
||||
"chapter_dialogues",
|
||||
"chapter_events",
|
||||
"conversation_history",
|
||||
"character_timelines",
|
||||
}
|
||||
for key in banned_blueprint_keys:
|
||||
if key in blueprint_dict:
|
||||
blueprint_dict.pop(key, None)
|
||||
|
||||
writer_prompt = await prompt_service.get_prompt("writing")
|
||||
if not writer_prompt:
|
||||
raise HTTPException(status_code=500, detail="缺少写作提示词")
|
||||
|
||||
# 初始化向量检索服务,若未配置则自动降级为纯提示词生成
|
||||
vector_store: Optional[VectorStoreService]
|
||||
if not settings.vector_store_enabled:
|
||||
vector_store = None
|
||||
else:
|
||||
try:
|
||||
vector_store = VectorStoreService()
|
||||
except RuntimeError as exc:
|
||||
logger.warning("向量库初始化失败,RAG 检索被禁用: %s", exc)
|
||||
vector_store = None
|
||||
context_service = ChapterContextService(llm_service=llm_service, vector_store=vector_store)
|
||||
|
||||
outline_title = outline.title or f"第{outline.chapter_number}章"
|
||||
outline_summary = outline.summary or "暂无摘要"
|
||||
query_parts = [outline_title, outline_summary]
|
||||
if request.writing_notes:
|
||||
query_parts.append(request.writing_notes)
|
||||
rag_query = "\n".join(part for part in query_parts if part)
|
||||
rag_context = await context_service.retrieve_for_generation(
|
||||
project_id=project_id,
|
||||
query_text=rag_query or outline.title or outline.summary or "",
|
||||
user_id=current_user.id,
|
||||
)
|
||||
chunk_count = len(rag_context.chunks) if rag_context and rag_context.chunks else 0
|
||||
summary_count = len(rag_context.summaries) if rag_context and rag_context.summaries else 0
|
||||
logger.info(
|
||||
"项目 %s 第 %s 章检索到 %s 个剧情片段和 %s 条摘要",
|
||||
project_id,
|
||||
request.chapter_number,
|
||||
chunk_count,
|
||||
summary_count,
|
||||
)
|
||||
# print("rag_context:",rag_context)
|
||||
# 将蓝图、前情、RAG 检索结果拼装成结构化段落,供模型理解
|
||||
blueprint_text = json.dumps(blueprint_dict, ensure_ascii=False, indent=2)
|
||||
completed_lines = [
|
||||
f"- 第{item['chapter_number']}章 - {item['title']}:{item['summary']}"
|
||||
for item in completed_chapters
|
||||
]
|
||||
previous_summary_text = previous_summary_text or "暂无可用摘要"
|
||||
previous_tail_excerpt = previous_tail_excerpt or "暂无上一章结尾内容"
|
||||
completed_section = "\n".join(completed_lines) if completed_lines else "暂无前情摘要"
|
||||
rag_chunks_text = "\n\n".join(rag_context.chunk_texts()) if rag_context.chunks else "未检索到章节片段"
|
||||
rag_summaries_text = "\n".join(rag_context.summary_lines()) if rag_context.summaries else "未检索到章节摘要"
|
||||
writing_notes = request.writing_notes or "无额外写作指令"
|
||||
|
||||
prompt_sections = [
|
||||
("[世界蓝图](JSON)", blueprint_text),
|
||||
# ("[前情摘要]", completed_section),
|
||||
("[上一章摘要]", previous_summary_text),
|
||||
("[上一章结尾]", previous_tail_excerpt),
|
||||
("[检索到的剧情上下文](Markdown)", rag_chunks_text),
|
||||
("[检索到的章节摘要]", rag_summaries_text),
|
||||
(
|
||||
"[当前章节目标]",
|
||||
f"标题:{outline_title}\n摘要:{outline_summary}\n写作要求:{writing_notes}",
|
||||
),
|
||||
]
|
||||
prompt_input = "\n\n".join(f"{title}\n{content}" for title, content in prompt_sections if content)
|
||||
logger.debug("章节写作提示词:%s\n%s", writer_prompt, prompt_input)
|
||||
async def _generate_single_version(idx: int) -> Dict:
|
||||
try:
|
||||
response = await llm_service.get_llm_response(
|
||||
system_prompt=writer_prompt,
|
||||
conversation_history=[{"role": "user", "content": prompt_input}],
|
||||
temperature=0.9,
|
||||
user_id=current_user.id,
|
||||
timeout=600.0,
|
||||
)
|
||||
cleaned = remove_think_tags(response)
|
||||
normalized = unwrap_markdown_json(cleaned)
|
||||
try:
|
||||
return json.loads(normalized)
|
||||
except json.JSONDecodeError:
|
||||
return {"content": normalized}
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"项目 %s 生成第 %s 章第 %s 个版本时发生异常: %s",
|
||||
project_id,
|
||||
request.chapter_number,
|
||||
idx + 1,
|
||||
exc,
|
||||
)
|
||||
return {"content": f"生成失败: {exc}"}
|
||||
|
||||
version_count = await _resolve_version_count(session)
|
||||
logger.info(
|
||||
"项目 %s 第 %s 章计划生成 %s 个版本",
|
||||
project_id,
|
||||
request.chapter_number,
|
||||
version_count,
|
||||
)
|
||||
raw_versions = []
|
||||
for idx in range(version_count):
|
||||
raw_versions.append(await _generate_single_version(idx))
|
||||
contents: List[str] = []
|
||||
metadata: List[Dict] = []
|
||||
for variant in raw_versions:
|
||||
if isinstance(variant, dict):
|
||||
if "content" in variant and isinstance(variant["content"], str):
|
||||
contents.append(variant["content"])
|
||||
elif "chapter_content" in variant:
|
||||
contents.append(str(variant["chapter_content"]))
|
||||
else:
|
||||
contents.append(json.dumps(variant, ensure_ascii=False))
|
||||
metadata.append(variant)
|
||||
else:
|
||||
contents.append(str(variant))
|
||||
metadata.append({"raw": variant})
|
||||
|
||||
await novel_service.replace_chapter_versions(chapter, contents, metadata)
|
||||
logger.info(
|
||||
"项目 %s 第 %s 章生成完成,已写入 %s 个版本",
|
||||
project_id,
|
||||
request.chapter_number,
|
||||
len(contents),
|
||||
)
|
||||
return await _load_project_schema(novel_service, project_id, current_user.id)
|
||||
|
||||
|
||||
async def _resolve_version_count(session: AsyncSession) -> int:
|
||||
repo = SystemConfigRepository(session)
|
||||
record = await repo.get_by_key("writer.chapter_versions")
|
||||
if record:
|
||||
try:
|
||||
value = int(record.value)
|
||||
if value > 0:
|
||||
return value
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
env_value = os.getenv("WRITER_CHAPTER_VERSION_COUNT")
|
||||
if env_value:
|
||||
try:
|
||||
value = int(env_value)
|
||||
if value > 0:
|
||||
return value
|
||||
except ValueError:
|
||||
pass
|
||||
return 3
|
||||
|
||||
|
||||
@router.post("/novels/{project_id}/chapters/select", response_model=NovelProjectSchema)
|
||||
async def select_chapter_version(
|
||||
project_id: str,
|
||||
request: SelectVersionRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
novel_service = NovelService(session)
|
||||
llm_service = LLMService(session)
|
||||
|
||||
project = await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
chapter = next((ch for ch in project.chapters if ch.chapter_number == request.chapter_number), None)
|
||||
if not chapter:
|
||||
logger.warning("项目 %s 未找到第 %s 章,无法选择版本", project_id, request.chapter_number)
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
selected = await novel_service.select_chapter_version(chapter, request.version_index)
|
||||
logger.info(
|
||||
"用户 %s 选择了项目 %s 第 %s 章的第 %s 个版本",
|
||||
current_user.id,
|
||||
project_id,
|
||||
request.chapter_number,
|
||||
request.version_index,
|
||||
)
|
||||
if selected and selected.content:
|
||||
summary = await llm_service.get_summary(
|
||||
selected.content,
|
||||
temperature=0.15,
|
||||
user_id=current_user.id,
|
||||
timeout=180.0,
|
||||
)
|
||||
chapter.real_summary = remove_think_tags(summary)
|
||||
await session.commit()
|
||||
|
||||
# 选定版本后同步向量库,确保后续章节可检索到最新内容
|
||||
vector_store: Optional[VectorStoreService]
|
||||
if not settings.vector_store_enabled:
|
||||
vector_store = None
|
||||
else:
|
||||
try:
|
||||
vector_store = VectorStoreService()
|
||||
except RuntimeError as exc:
|
||||
logger.warning("向量库初始化失败,跳过章节向量同步: %s", exc)
|
||||
vector_store = None
|
||||
|
||||
if vector_store:
|
||||
ingestion_service = ChapterIngestionService(llm_service=llm_service, vector_store=vector_store)
|
||||
outline = next((item for item in project.outlines if item.chapter_number == chapter.chapter_number), None)
|
||||
chapter_title = outline.title if outline and outline.title else f"第{chapter.chapter_number}章"
|
||||
await ingestion_service.ingest_chapter(
|
||||
project_id=project_id,
|
||||
chapter_number=chapter.chapter_number,
|
||||
title=chapter_title,
|
||||
content=selected.content,
|
||||
summary=chapter.real_summary,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
logger.info(
|
||||
"项目 %s 第 %s 章已同步至向量库",
|
||||
project_id,
|
||||
chapter.chapter_number,
|
||||
)
|
||||
|
||||
return await _load_project_schema(novel_service, project_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/novels/{project_id}/chapters/evaluate", response_model=NovelProjectSchema)
|
||||
async def evaluate_chapter(
|
||||
project_id: str,
|
||||
request: EvaluateChapterRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
novel_service = NovelService(session)
|
||||
prompt_service = PromptService(session)
|
||||
llm_service = LLMService(session)
|
||||
|
||||
project = await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
chapter = next((ch for ch in project.chapters if ch.chapter_number == request.chapter_number), None)
|
||||
if not chapter:
|
||||
logger.warning("项目 %s 未找到第 %s 章,无法执行评估", project_id, request.chapter_number)
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
if not chapter.versions:
|
||||
logger.warning("项目 %s 第 %s 章无可评估版本", project_id, request.chapter_number)
|
||||
raise HTTPException(status_code=400, detail="无可评估的章节版本")
|
||||
|
||||
evaluator_prompt = await prompt_service.get_prompt("evaluation")
|
||||
if not evaluator_prompt:
|
||||
logger.error("缺少评估提示词,项目 %s 第 %s 章评估失败", project_id, request.chapter_number)
|
||||
raise HTTPException(status_code=500, detail="缺少评估提示词")
|
||||
|
||||
project_schema = await novel_service._serialize_project(project)
|
||||
blueprint_dict = project_schema.blueprint.model_dump()
|
||||
|
||||
versions_to_evaluate = [
|
||||
{"version_id": idx + 1, "content": version.content}
|
||||
for idx, version in enumerate(sorted(chapter.versions, key=lambda item: item.created_at))
|
||||
]
|
||||
# print("blueprint_dict:",blueprint_dict)
|
||||
evaluator_payload = {
|
||||
"novel_blueprint": blueprint_dict,
|
||||
"content_to_evaluate": {
|
||||
"chapter_number": chapter.chapter_number,
|
||||
"versions": versions_to_evaluate,
|
||||
},
|
||||
}
|
||||
|
||||
evaluation_raw = await llm_service.get_llm_response(
|
||||
system_prompt=evaluator_prompt,
|
||||
conversation_history=[{"role": "user", "content": json.dumps(evaluator_payload, ensure_ascii=False)}],
|
||||
temperature=0.3,
|
||||
user_id=current_user.id,
|
||||
timeout=360.0,
|
||||
)
|
||||
evaluation_clean = remove_think_tags(evaluation_raw)
|
||||
await novel_service.add_chapter_evaluation(chapter, None, evaluation_clean)
|
||||
logger.info("项目 %s 第 %s 章评估完成", project_id, request.chapter_number)
|
||||
|
||||
return await _load_project_schema(novel_service, project_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/novels/{project_id}/chapters/outline", response_model=NovelProjectSchema)
|
||||
async def generate_chapter_outline(
|
||||
project_id: str,
|
||||
request: GenerateOutlineRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
novel_service = NovelService(session)
|
||||
prompt_service = PromptService(session)
|
||||
llm_service = LLMService(session)
|
||||
|
||||
await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
logger.info(
|
||||
"用户 %s 请求生成项目 %s 的章节大纲,起始章节 %s,数量 %s",
|
||||
current_user.id,
|
||||
project_id,
|
||||
request.start_chapter,
|
||||
request.num_chapters,
|
||||
)
|
||||
outline_prompt = await prompt_service.get_prompt("outline")
|
||||
if not outline_prompt:
|
||||
logger.error("缺少大纲提示词,项目 %s 大纲生成失败", project_id)
|
||||
raise HTTPException(status_code=500, detail="缺少大纲提示词")
|
||||
|
||||
project_schema = await novel_service.get_project_schema(project_id, current_user.id)
|
||||
blueprint_dict = project_schema.blueprint.model_dump()
|
||||
|
||||
payload = {
|
||||
"novel_blueprint": blueprint_dict,
|
||||
"wait_to_generate": {
|
||||
"start_chapter": request.start_chapter,
|
||||
"num_chapters": request.num_chapters,
|
||||
},
|
||||
}
|
||||
|
||||
response = await llm_service.get_llm_response(
|
||||
system_prompt=outline_prompt,
|
||||
conversation_history=[{"role": "user", "content": json.dumps(payload, ensure_ascii=False)}],
|
||||
temperature=0.7,
|
||||
user_id=current_user.id,
|
||||
timeout=360.0,
|
||||
)
|
||||
normalized = unwrap_markdown_json(remove_think_tags(response))
|
||||
try:
|
||||
data = json.loads(normalized)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(status_code=500, detail="章节大纲生成失败") from exc
|
||||
|
||||
new_outlines = data.get("chapters", [])
|
||||
for item in new_outlines:
|
||||
stmt = (
|
||||
select(ChapterOutline)
|
||||
.where(
|
||||
ChapterOutline.project_id == project_id,
|
||||
ChapterOutline.chapter_number == item.get("chapter_number"),
|
||||
)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
record = result.scalars().first()
|
||||
if record:
|
||||
record.title = item.get("title", record.title)
|
||||
record.summary = item.get("summary", record.summary)
|
||||
else:
|
||||
session.add(
|
||||
ChapterOutline(
|
||||
project_id=project_id,
|
||||
chapter_number=item.get("chapter_number"),
|
||||
title=item.get("title", ""),
|
||||
summary=item.get("summary"),
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
logger.info("项目 %s 章节大纲生成完成", project_id)
|
||||
|
||||
return await novel_service.get_project_schema(project_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/novels/{project_id}/chapters/update-outline", response_model=NovelProjectSchema)
|
||||
async def update_chapter_outline(
|
||||
project_id: str,
|
||||
request: UpdateChapterOutlineRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
novel_service = NovelService(session)
|
||||
await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
logger.info(
|
||||
"用户 %s 更新项目 %s 第 %s 章大纲",
|
||||
current_user.id,
|
||||
project_id,
|
||||
request.chapter_number,
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(ChapterOutline)
|
||||
.where(
|
||||
ChapterOutline.project_id == project_id,
|
||||
ChapterOutline.chapter_number == request.chapter_number,
|
||||
)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
outline = result.scalars().first()
|
||||
if not outline:
|
||||
outline = ChapterOutline(
|
||||
project_id=project_id,
|
||||
chapter_number=request.chapter_number,
|
||||
)
|
||||
session.add(outline)
|
||||
|
||||
outline.title = request.title
|
||||
outline.summary = request.summary
|
||||
await session.commit()
|
||||
logger.info("项目 %s 第 %s 章大纲已更新", project_id, request.chapter_number)
|
||||
|
||||
return await novel_service.get_project_schema(project_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/novels/{project_id}/chapters/delete", response_model=NovelProjectSchema)
|
||||
async def delete_chapters(
|
||||
project_id: str,
|
||||
request: DeleteChapterRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
if not request.chapter_numbers:
|
||||
logger.warning("项目 %s 未提供要删除的章节号", project_id)
|
||||
raise HTTPException(status_code=400, detail="请提供要删除的章节号")
|
||||
novel_service = NovelService(session)
|
||||
llm_service = LLMService(session)
|
||||
await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
logger.info(
|
||||
"用户 %s 删除项目 %s 的章节 %s",
|
||||
current_user.id,
|
||||
project_id,
|
||||
request.chapter_numbers,
|
||||
)
|
||||
await novel_service.delete_chapters(project_id, request.chapter_numbers)
|
||||
|
||||
# 删除章节时同步清理向量库,避免过时内容被检索
|
||||
vector_store: Optional[VectorStoreService]
|
||||
if not settings.vector_store_enabled:
|
||||
vector_store = None
|
||||
else:
|
||||
try:
|
||||
vector_store = VectorStoreService()
|
||||
except RuntimeError as exc:
|
||||
logger.warning("向量库初始化失败,跳过章节向量删除: %s", exc)
|
||||
vector_store = None
|
||||
|
||||
if vector_store:
|
||||
ingestion_service = ChapterIngestionService(llm_service=llm_service, vector_store=vector_store)
|
||||
await ingestion_service.delete_chapters(project_id, request.chapter_numbers)
|
||||
logger.info(
|
||||
"项目 %s 已从向量库移除章节 %s",
|
||||
project_id,
|
||||
request.chapter_numbers,
|
||||
)
|
||||
|
||||
return await novel_service.get_project_schema(project_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/novels/{project_id}/chapters/edit", response_model=NovelProjectSchema)
|
||||
async def edit_chapter(
|
||||
project_id: str,
|
||||
request: EditChapterRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> NovelProjectSchema:
|
||||
novel_service = NovelService(session)
|
||||
llm_service = LLMService(session)
|
||||
|
||||
project = await novel_service.ensure_project_owner(project_id, current_user.id)
|
||||
chapter = next((ch for ch in project.chapters if ch.chapter_number == request.chapter_number), None)
|
||||
if not chapter or chapter.selected_version is None:
|
||||
logger.warning("项目 %s 第 %s 章尚未生成或未选择版本,无法编辑", project_id, request.chapter_number)
|
||||
raise HTTPException(status_code=404, detail="章节尚未生成或未选择版本")
|
||||
|
||||
chapter.selected_version.content = request.content
|
||||
chapter.word_count = len(request.content)
|
||||
logger.info("用户 %s 更新了项目 %s 第 %s 章内容", current_user.id, project_id, request.chapter_number)
|
||||
|
||||
if request.content.strip():
|
||||
summary = await llm_service.get_summary(
|
||||
request.content,
|
||||
temperature=0.15,
|
||||
user_id=current_user.id,
|
||||
timeout=180.0,
|
||||
)
|
||||
chapter.real_summary = remove_think_tags(summary)
|
||||
await session.commit()
|
||||
|
||||
vector_store: Optional[VectorStoreService]
|
||||
if not settings.vector_store_enabled:
|
||||
vector_store = None
|
||||
else:
|
||||
try:
|
||||
vector_store = VectorStoreService()
|
||||
except RuntimeError as exc:
|
||||
logger.warning("向量库初始化失败,跳过章节向量更新: %s", exc)
|
||||
vector_store = None
|
||||
|
||||
if vector_store and chapter.selected_version and chapter.selected_version.content:
|
||||
ingestion_service = ChapterIngestionService(llm_service=llm_service, vector_store=vector_store)
|
||||
outline = next((item for item in project.outlines if item.chapter_number == chapter.chapter_number), None)
|
||||
chapter_title = outline.title if outline and outline.title else f"第{chapter.chapter_number}章"
|
||||
await ingestion_service.ingest_chapter(
|
||||
project_id=project_id,
|
||||
chapter_number=chapter.chapter_number,
|
||||
title=chapter_title,
|
||||
content=chapter.selected_version.content,
|
||||
summary=chapter.real_summary,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
logger.info("项目 %s 第 %s 章更新内容已同步至向量库", project_id, chapter.chapter_number)
|
||||
|
||||
return await novel_service.get_project_schema(project_id, current_user.id)
|
||||
0
backend/app/core/__init__.py
Normal file
0
backend/app/core/__init__.py
Normal file
261
backend/app/core/config.py
Normal file
261
backend/app/core/config.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import AliasChoices, AnyUrl, Field, HttpUrl, validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from sqlalchemy.engine import URL, make_url
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用全局配置,所有可调参数集中于此,统一加载自环境变量。"""
|
||||
|
||||
# -------------------- 基础应用配置 --------------------
|
||||
app_name: str = Field(default="AI Novel Generator API", description="FastAPI 文档标题")
|
||||
environment: str = Field(default="development", description="当前环境标识")
|
||||
debug: bool = Field(default=True, description="是否开启调试模式")
|
||||
allow_registration: bool = Field(
|
||||
default=True,
|
||||
env="ALLOW_USER_REGISTRATION",
|
||||
description="是否允许用户自助注册",
|
||||
)
|
||||
logging_level: str = Field(
|
||||
default="INFO",
|
||||
env="LOGGING_LEVEL",
|
||||
description="应用日志级别",
|
||||
)
|
||||
enable_linuxdo_login: bool = Field(
|
||||
default=False,
|
||||
env="ENABLE_LINUXDO_LOGIN",
|
||||
description="是否启用 Linux.do OAuth 登录",
|
||||
)
|
||||
|
||||
# -------------------- 安全相关配置 --------------------
|
||||
secret_key: str = Field(..., env="SECRET_KEY", description="JWT 加密密钥")
|
||||
jwt_algorithm: str = Field(default="HS256", env="JWT_ALGORITHM", description="JWT 加密算法")
|
||||
access_token_expire_minutes: int = Field(
|
||||
default=60 * 24 * 7,
|
||||
env="ACCESS_TOKEN_EXPIRE_MINUTES",
|
||||
description="访问令牌过期时间,单位分钟"
|
||||
)
|
||||
|
||||
# -------------------- 数据库配置 --------------------
|
||||
database_url: Optional[str] = Field(
|
||||
default=None,
|
||||
env="DATABASE_URL",
|
||||
description="完整的数据库连接串,填入后覆盖下方数据库配置"
|
||||
)
|
||||
db_provider: str = Field(
|
||||
default="mysql",
|
||||
env="DB_PROVIDER",
|
||||
description="数据库类型,仅支持 mysql 或 sqlite"
|
||||
)
|
||||
mysql_host: str = Field(default="localhost", env="MYSQL_HOST", description="MySQL 主机名")
|
||||
mysql_port: int = Field(default=3306, env="MYSQL_PORT", description="MySQL 端口")
|
||||
mysql_user: str = Field(default="root", env="MYSQL_USER", description="MySQL 用户名")
|
||||
mysql_password: str = Field(default="", env="MYSQL_PASSWORD", description="MySQL 密码")
|
||||
mysql_database: str = Field(default="arboris", env="MYSQL_DATABASE", description="MySQL 数据库名称")
|
||||
|
||||
# -------------------- 管理员初始化配置 --------------------
|
||||
admin_default_username: str = Field(default="admin", env="ADMIN_DEFAULT_USERNAME", description="默认管理员用户名")
|
||||
admin_default_password: str = Field(default="ChangeMe123!", env="ADMIN_DEFAULT_PASSWORD", description="默认管理员密码")
|
||||
admin_default_email: Optional[str] = Field(default=None, env="ADMIN_DEFAULT_EMAIL", description="默认管理员邮箱")
|
||||
|
||||
# -------------------- LLM 相关配置 --------------------
|
||||
openai_api_key: Optional[str] = Field(default=None, env="OPENAI_API_KEY", description="默认的 LLM API Key")
|
||||
openai_base_url: Optional[HttpUrl] = Field(
|
||||
default=None,
|
||||
env="OPENAI_API_BASE_URL",
|
||||
validation_alias=AliasChoices("OPENAI_API_BASE_URL", "OPENAI_BASE_URL"),
|
||||
description="LLM API Base URL",
|
||||
)
|
||||
openai_model_name: str = Field(default="gpt-4o-mini", env="OPENAI_MODEL_NAME", description="默认 LLM 模型名称")
|
||||
writer_chapter_versions: int = Field(
|
||||
default=2,
|
||||
ge=1,
|
||||
env="WRITER_CHAPTER_VERSION_COUNT",
|
||||
validation_alias=AliasChoices("WRITER_CHAPTER_VERSION_COUNT", "WRITER_CHAPTER_VERSIONS"),
|
||||
description="每次生成章节的候选版本数量",
|
||||
)
|
||||
embedding_provider: str = Field(
|
||||
default="openai",
|
||||
env="EMBEDDING_PROVIDER",
|
||||
description="嵌入模型提供方,支持 openai 或 ollama",
|
||||
)
|
||||
embedding_base_url: Optional[AnyUrl] = Field(
|
||||
default=None,
|
||||
env="EMBEDDING_BASE_URL",
|
||||
description="嵌入模型使用的 Base URL",
|
||||
)
|
||||
embedding_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
env="EMBEDDING_API_KEY",
|
||||
description="嵌入模型专用 API Key",
|
||||
)
|
||||
embedding_model: str = Field(
|
||||
default="text-embedding-3-large",
|
||||
env="EMBEDDING_MODEL",
|
||||
validation_alias=AliasChoices("EMBEDDING_MODEL", "VECTOR_EMBEDDING_MODEL"),
|
||||
description="默认的嵌入模型名称",
|
||||
)
|
||||
embedding_model_vector_size: Optional[int] = Field(
|
||||
default=None,
|
||||
env="EMBEDDING_MODEL_VECTOR_SIZE",
|
||||
description="嵌入向量维度,未配置时将自动检测",
|
||||
)
|
||||
ollama_embedding_base_url: Optional[AnyUrl] = Field(
|
||||
default=None,
|
||||
env="OLLAMA_EMBEDDING_BASE_URL",
|
||||
description="Ollama 嵌入模型服务地址",
|
||||
)
|
||||
ollama_embedding_model: str = Field(
|
||||
default="nomic-embed-text:latest",
|
||||
env="OLLAMA_EMBEDDING_MODEL",
|
||||
description="Ollama 嵌入模型名称",
|
||||
)
|
||||
vector_db_url: Optional[str] = Field(
|
||||
default=None,
|
||||
env="VECTOR_DB_URL",
|
||||
description="libsql 向量库连接地址",
|
||||
)
|
||||
vector_db_auth_token: Optional[str] = Field(
|
||||
default=None,
|
||||
env="VECTOR_DB_AUTH_TOKEN",
|
||||
description="libsql 访问令牌",
|
||||
)
|
||||
vector_top_k_chunks: int = Field(
|
||||
default=5,
|
||||
ge=0,
|
||||
env="VECTOR_TOP_K_CHUNKS",
|
||||
description="剧情 chunk 检索条数",
|
||||
)
|
||||
vector_top_k_summaries: int = Field(
|
||||
default=3,
|
||||
ge=0,
|
||||
env="VECTOR_TOP_K_SUMMARIES",
|
||||
description="章节摘要检索条数",
|
||||
)
|
||||
vector_chunk_size: int = Field(
|
||||
default=480,
|
||||
ge=128,
|
||||
env="VECTOR_CHUNK_SIZE",
|
||||
description="章节分块的目标字数",
|
||||
)
|
||||
vector_chunk_overlap: int = Field(
|
||||
default=120,
|
||||
ge=0,
|
||||
env="VECTOR_CHUNK_OVERLAP",
|
||||
description="章节分块重叠字数",
|
||||
)
|
||||
|
||||
# -------------------- Linux.do OAuth 配置 --------------------
|
||||
linuxdo_client_id: Optional[str] = Field(default=None, env="LINUXDO_CLIENT_ID", description="Linux.do OAuth Client ID")
|
||||
linuxdo_client_secret: Optional[str] = Field(
|
||||
default=None, env="LINUXDO_CLIENT_SECRET", description="Linux.do OAuth Client Secret"
|
||||
)
|
||||
linuxdo_redirect_uri: Optional[HttpUrl] = Field(
|
||||
default=None, env="LINUXDO_REDIRECT_URI", description="Linux.do OAuth 回调地址"
|
||||
)
|
||||
linuxdo_auth_url: Optional[HttpUrl] = Field(
|
||||
default=None, env="LINUXDO_AUTH_URL", description="Linux.do OAuth 授权地址"
|
||||
)
|
||||
linuxdo_token_url: Optional[HttpUrl] = Field(
|
||||
default=None, env="LINUXDO_TOKEN_URL", description="Linux.do OAuth Token 获取地址"
|
||||
)
|
||||
linuxdo_user_info_url: Optional[HttpUrl] = Field(
|
||||
default=None, env="LINUXDO_USER_INFO_URL", description="Linux.do 用户信息接口地址"
|
||||
)
|
||||
|
||||
# -------------------- 邮件配置 --------------------
|
||||
smtp_server: Optional[str] = Field(default=None, env="SMTP_SERVER", description="SMTP 服务地址")
|
||||
smtp_port: int = Field(default=587, env="SMTP_PORT", description="SMTP 服务端口")
|
||||
smtp_username: Optional[str] = Field(default=None, env="SMTP_USERNAME", description="SMTP 登录用户名")
|
||||
smtp_password: Optional[str] = Field(default=None, env="SMTP_PASSWORD", description="SMTP 登录密码")
|
||||
email_from: Optional[str] = Field(default=None, env="EMAIL_FROM", description="邮件发送方显示名或邮箱")
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=("new-backend/.env", ".env", "backend/.env"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore"
|
||||
)
|
||||
|
||||
@validator("database_url", pre=True, always=True)
|
||||
def _normalize_database_url(cls, value: Optional[str]) -> Optional[str]:
|
||||
"""当环境变量中提供 DATABASE_URL 时,原样返回,便于自定义。"""
|
||||
return value.strip() if isinstance(value, str) and value.strip() else value
|
||||
|
||||
@validator("db_provider", pre=True)
|
||||
def _normalize_db_provider(cls, value: Optional[str]) -> str:
|
||||
"""统一数据库类型大小写,并限制为受支持的驱动。"""
|
||||
candidate = (value or "mysql").strip().lower()
|
||||
if candidate not in {"mysql", "sqlite"}:
|
||||
raise ValueError("DB_PROVIDER 仅支持 mysql 或 sqlite")
|
||||
return candidate
|
||||
@validator("embedding_provider", pre=True)
|
||||
def _normalize_embedding_provider(cls, value: Optional[str]) -> str:
|
||||
"""限制嵌入模型提供方的取值范围。"""
|
||||
candidate = (value or "openai").strip().lower()
|
||||
if candidate not in {"openai", "ollama"}:
|
||||
raise ValueError("EMBEDDING_PROVIDER 仅支持 openai 或 ollama")
|
||||
return candidate
|
||||
|
||||
@validator("logging_level", pre=True)
|
||||
def _normalize_logging_level(cls, value: Optional[str]) -> str:
|
||||
"""规范日志级别配置。"""
|
||||
candidate = (value or "INFO").strip().upper()
|
||||
valid_levels = {"CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"}
|
||||
if candidate not in valid_levels:
|
||||
raise ValueError("LOGGING_LEVEL 仅支持 CRITICAL/ERROR/WARNING/INFO/DEBUG/NOTSET")
|
||||
return candidate
|
||||
|
||||
@property
|
||||
def sqlalchemy_database_uri(self) -> str:
|
||||
"""生成 SQLAlchemy 兼容的异步连接串,数据库类型由 DB_PROVIDER 控制。"""
|
||||
if self.database_url:
|
||||
url = make_url(self.database_url)
|
||||
database = (url.database or "").strip("/")
|
||||
normalized = URL.create(
|
||||
drivername=url.drivername,
|
||||
username=url.username,
|
||||
password=url.password,
|
||||
host=url.host,
|
||||
port=url.port,
|
||||
database=database or None,
|
||||
query=url.query,
|
||||
)
|
||||
return normalized.render_as_string(hide_password=False)
|
||||
|
||||
if self.db_provider == "sqlite":
|
||||
# SQLite 固定使用 storage/arboris.db,并转换为绝对路径以避免运行目录差异
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
db_path = (project_root / "storage" / "arboris.db").resolve()
|
||||
return f"sqlite+aiosqlite:///{db_path}"
|
||||
|
||||
# MySQL 分支:统一对密码进行 URL 编码,避免特殊字符破坏连接串
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
encoded_password = quote_plus(self.mysql_password)
|
||||
database = (self.mysql_database or "").strip("/")
|
||||
return (
|
||||
f"mysql+asyncmy://{self.mysql_user}:{encoded_password}"
|
||||
f"@{self.mysql_host}:{self.mysql_port}/{database}"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_sqlite_backend(self) -> bool:
|
||||
"""辅助属性:判断当前连接串是否指向 SQLite,用于差异化初始化流程。"""
|
||||
return make_url(self.sqlalchemy_database_uri).get_backend_name() == "sqlite"
|
||||
|
||||
@property
|
||||
def vector_store_enabled(self) -> bool:
|
||||
"""是否已经配置向量库,用于在业务逻辑中快速判断。"""
|
||||
return bool(self.vector_db_url)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
"""使用 LRU 缓存确保配置只初始化一次,减少 IO 与解析开销。"""
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
33
backend/app/core/dependencies.py
Normal file
33
backend/app/core/dependencies.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..core.security import decode_access_token
|
||||
from ..db.session import get_session
|
||||
from ..repositories.user_repository import UserRepository
|
||||
from ..schemas.user import UserInDB
|
||||
from ..services.auth_service import AuthService
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> UserInDB:
|
||||
payload = decode_access_token(token)
|
||||
username = payload["sub"]
|
||||
repo = UserRepository(session)
|
||||
user = await repo.get_by_username(username)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或已被禁用")
|
||||
service = AuthService(session)
|
||||
schema = UserInDB.model_validate(user)
|
||||
schema.must_change_password = service.requires_password_reset(user)
|
||||
return schema
|
||||
|
||||
|
||||
async def get_current_admin(current_user: UserInDB = Depends(get_current_user)) -> UserInDB:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
|
||||
return current_user
|
||||
58
backend/app/core/security.py
Normal file
58
backend/app/core/security.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from .config import settings
|
||||
|
||||
# 统一的密码哈希上下文,后续如需切换算法只需在此维护
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""对用户密码进行哈希处理,任何时候都不要存储明文密码。"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证明文密码是否匹配哈希值。"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: str,
|
||||
*,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
extra_claims: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""生成 JWT 访问令牌,默认过期时间读取自配置。"""
|
||||
if expires_delta is None:
|
||||
expires_delta = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
|
||||
now = datetime.utcnow()
|
||||
expire = now + expires_delta
|
||||
|
||||
to_encode: Dict[str, Any] = {"sub": subject, "iat": now, "exp": expire}
|
||||
if extra_claims:
|
||||
to_encode.update(extra_claims)
|
||||
|
||||
return jwt.encode(to_encode, settings.secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> Dict[str, Any]:
|
||||
"""解析并校验 JWT,失败时抛出 401 异常。"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的凭证",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=[settings.jwt_algorithm])
|
||||
except JWTError as exc:
|
||||
raise credentials_exception from exc
|
||||
|
||||
if "sub" not in payload:
|
||||
raise credentials_exception
|
||||
return payload
|
||||
0
backend/app/db/__init__.py
Normal file
0
backend/app/db/__init__.py
Normal file
9
backend/app/db/base.py
Normal file
9
backend/app/db/base.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from sqlalchemy.orm import DeclarativeBase, declared_attr
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""SQLAlchemy 基类,自动根据类名生成表名。"""
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(cls) -> str: # type: ignore[override]
|
||||
return cls.__name__.lower()
|
||||
122
backend/app/db/init_db.py
Normal file
122
backend/app/db/init_db.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import logging
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.engine import URL, make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from ..core.config import settings
|
||||
from ..core.security import hash_password
|
||||
from ..models import Prompt, SystemConfig, User
|
||||
from .base import Base
|
||||
from .system_config_defaults import SYSTEM_CONFIG_DEFAULTS
|
||||
from .session import AsyncSessionLocal, engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""初始化数据库结构并确保默认管理员存在。"""
|
||||
|
||||
await _ensure_database_exists()
|
||||
|
||||
# ---- 第一步:创建所有表结构 ----
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("数据库表结构已初始化")
|
||||
|
||||
# ---- 第二步:确保管理员账号至少存在一个 ----
|
||||
async with AsyncSessionLocal() as session:
|
||||
admin_exists = await session.execute(select(User).where(User.is_admin.is_(True)))
|
||||
if not admin_exists.scalars().first():
|
||||
logger.warning("未检测到管理员账号,正在创建默认管理员 ...")
|
||||
admin_user = User(
|
||||
username=settings.admin_default_username,
|
||||
email=settings.admin_default_email,
|
||||
hashed_password=hash_password(settings.admin_default_password),
|
||||
is_admin=True,
|
||||
)
|
||||
|
||||
session.add(admin_user)
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info("默认管理员创建完成:%s", settings.admin_default_username)
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
logger.exception("默认管理员创建失败,可能是并发启动导致,请检查数据库状态")
|
||||
|
||||
# ---- 第三步:同步系统配置到数据库 ----
|
||||
for entry in SYSTEM_CONFIG_DEFAULTS:
|
||||
value = entry.value_getter(settings)
|
||||
if value is None:
|
||||
continue
|
||||
existing = await session.get(SystemConfig, entry.key)
|
||||
if existing:
|
||||
if entry.description and existing.description != entry.description:
|
||||
existing.description = entry.description
|
||||
continue
|
||||
session.add(
|
||||
SystemConfig(
|
||||
key=entry.key,
|
||||
value=value,
|
||||
description=entry.description,
|
||||
)
|
||||
)
|
||||
|
||||
await _ensure_default_prompts(session)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _ensure_database_exists() -> None:
|
||||
"""在首次连接前确认数据库存在,针对不同驱动做最小化准备工作。"""
|
||||
url = make_url(settings.sqlalchemy_database_uri)
|
||||
|
||||
if url.get_backend_name() == "sqlite":
|
||||
# SQLite 采用文件数据库,确保父目录存在即可,无需额外建库语句
|
||||
db_path = Path(url.database or "").expanduser()
|
||||
if not db_path.is_absolute():
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
db_path = (project_root / db_path).resolve()
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return
|
||||
|
||||
database = (url.database or "").strip("/")
|
||||
if not database:
|
||||
return
|
||||
|
||||
admin_url = URL.create(
|
||||
drivername=url.drivername,
|
||||
username=url.username,
|
||||
password=url.password,
|
||||
host=url.host,
|
||||
port=url.port,
|
||||
database=None,
|
||||
query=url.query,
|
||||
)
|
||||
|
||||
admin_engine = create_async_engine(
|
||||
admin_url.render_as_string(hide_password=False),
|
||||
isolation_level="AUTOCOMMIT",
|
||||
)
|
||||
async with admin_engine.begin() as conn:
|
||||
await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{database}`"))
|
||||
await admin_engine.dispose()
|
||||
|
||||
|
||||
async def _ensure_default_prompts(session: AsyncSession) -> None:
|
||||
prompts_dir = Path(__file__).resolve().parents[2] / "prompts"
|
||||
if not prompts_dir.is_dir():
|
||||
return
|
||||
|
||||
result = await session.execute(select(Prompt.name))
|
||||
existing_names = set(result.scalars().all())
|
||||
|
||||
for prompt_file in sorted(prompts_dir.glob("*.md")):
|
||||
name = prompt_file.stem
|
||||
if name in existing_names:
|
||||
continue
|
||||
content = prompt_file.read_text(encoding="utf-8")
|
||||
session.add(Prompt(name=name, content=content))
|
||||
30
backend/app/db/session.py
Normal file
30
backend/app/db/session.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from ..core.config import settings
|
||||
|
||||
# 根据不同数据库驱动调整连接池参数,确保在多数据库环境下表现稳定
|
||||
engine_kwargs = {"echo": settings.debug}
|
||||
if settings.is_sqlite_backend:
|
||||
# SQLite 场景下禁用连接池并放宽线程检查,避免多协程读写冲突
|
||||
engine_kwargs.update(
|
||||
pool_pre_ping=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=NullPool,
|
||||
)
|
||||
else:
|
||||
# MySQL 场景保持健康检查与连接复用,适用于生产环境的长连接需求
|
||||
engine_kwargs.update(pool_pre_ping=True, pool_recycle=3600)
|
||||
|
||||
engine = create_async_engine(settings.sqlalchemy_database_uri, **engine_kwargs)
|
||||
|
||||
# 统一的 Session 工厂,禁用 expire_on_commit 方便返回模型对象
|
||||
AsyncSessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI 依赖项:提供一个作用域内共享的数据库会话。"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
yield session
|
||||
110
backend/app/db/system_config_defaults.py
Normal file
110
backend/app/db/system_config_defaults.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
from ..core.config import Settings
|
||||
|
||||
|
||||
def _to_optional_str(value: Optional[object]) -> Optional[str]:
|
||||
return str(value) if value is not None else None
|
||||
|
||||
|
||||
def _bool_to_text(value: bool) -> str:
|
||||
return "true" if value else "false"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SystemConfigDefault:
|
||||
key: str
|
||||
value_getter: Callable[[Settings], Optional[str]]
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
SYSTEM_CONFIG_DEFAULTS: list[SystemConfigDefault] = [
|
||||
SystemConfigDefault(
|
||||
key="llm.api_key",
|
||||
value_getter=lambda config: config.openai_api_key,
|
||||
description="默认 LLM API Key,用于后台调用大模型。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="llm.base_url",
|
||||
value_getter=lambda config: _to_optional_str(config.openai_base_url),
|
||||
description="默认大模型 API Base URL。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="llm.model",
|
||||
value_getter=lambda config: config.openai_model_name,
|
||||
description="默认 LLM 模型名称。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.server",
|
||||
value_getter=lambda config: config.smtp_server,
|
||||
description="用于发送邮件验证码的 SMTP 服务器地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.port",
|
||||
value_getter=lambda config: _to_optional_str(config.smtp_port),
|
||||
description="SMTP 服务端口。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.username",
|
||||
value_getter=lambda config: config.smtp_username,
|
||||
description="SMTP 登录用户名。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.password",
|
||||
value_getter=lambda config: config.smtp_password,
|
||||
description="SMTP 登录密码。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="smtp.from",
|
||||
value_getter=lambda config: config.email_from,
|
||||
description="邮件显示的发件人名称或邮箱。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="auth.allow_registration",
|
||||
value_getter=lambda config: _bool_to_text(config.allow_registration),
|
||||
description="是否允许用户自助注册。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="auth.linuxdo_enabled",
|
||||
value_getter=lambda config: _bool_to_text(config.enable_linuxdo_login),
|
||||
description="是否启用 Linux.do OAuth 登录。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.client_id",
|
||||
value_getter=lambda config: config.linuxdo_client_id,
|
||||
description="Linux.do OAuth Client ID。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.client_secret",
|
||||
value_getter=lambda config: config.linuxdo_client_secret,
|
||||
description="Linux.do OAuth Client Secret。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.redirect_uri",
|
||||
value_getter=lambda config: _to_optional_str(config.linuxdo_redirect_uri),
|
||||
description="Linux.do OAuth 回调地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.auth_url",
|
||||
value_getter=lambda config: _to_optional_str(config.linuxdo_auth_url),
|
||||
description="Linux.do OAuth 授权地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.token_url",
|
||||
value_getter=lambda config: _to_optional_str(config.linuxdo_token_url),
|
||||
description="Linux.do OAuth Token 获取地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="linuxdo.user_info_url",
|
||||
value_getter=lambda config: _to_optional_str(config.linuxdo_user_info_url),
|
||||
description="Linux.do 用户信息接口地址。",
|
||||
),
|
||||
SystemConfigDefault(
|
||||
key="writer.chapter_versions",
|
||||
value_getter=lambda config: _to_optional_str(config.writer_chapter_versions),
|
||||
description="每次生成章节的候选版本数量。",
|
||||
),
|
||||
]
|
||||
105
backend/app/main.py
Normal file
105
backend/app/main.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""FastAPI 应用入口,负责装配路由、依赖与生命周期管理。"""
|
||||
|
||||
import logging
|
||||
from logging.config import dictConfig
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .core.config import settings
|
||||
from .db.init_db import init_db
|
||||
from .services.prompt_service import PromptService
|
||||
from .db.session import AsyncSessionLocal
|
||||
from .api.routers import api_router
|
||||
|
||||
|
||||
dictConfig(
|
||||
{
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": "%(asctime)s [%(levelname)s] %(name)s - %(message)s",
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "default",
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
"backend": {
|
||||
"level": settings.logging_level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"app": {
|
||||
"level": settings.logging_level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"backend.app": {
|
||||
"level": settings.logging_level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"backend.api": {
|
||||
"level": settings.logging_level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"backend.services": {
|
||||
"level": settings.logging_level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
"root": {
|
||||
"level": "WARNING",
|
||||
"handlers": ["console"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# 应用启动时初始化数据库,并预热提示词缓存
|
||||
await init_db()
|
||||
async with AsyncSessionLocal() as session:
|
||||
prompt_service = PromptService(session)
|
||||
await prompt_service.preload()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
debug=settings.debug,
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS 配置,生产环境建议改为具体域名
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(api_router)
|
||||
|
||||
|
||||
# 健康检查接口(用于 Docker 健康检查和监控)
|
||||
@app.get("/health", tags=["Health"])
|
||||
@app.get("/api/health", tags=["Health"])
|
||||
async def health_check():
|
||||
"""健康检查接口,返回应用状态。"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"app": settings.app_name,
|
||||
"version": "1.0.0",
|
||||
}
|
||||
41
backend/app/models/__init__.py
Normal file
41
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""集中导出 ORM 模型,确保 SQLAlchemy 元数据在初始化时被正确加载。"""
|
||||
|
||||
from .admin_setting import AdminSetting
|
||||
from .llm_config import LLMConfig
|
||||
from .novel import (
|
||||
BlueprintCharacter,
|
||||
BlueprintRelationship,
|
||||
Chapter,
|
||||
ChapterEvaluation,
|
||||
ChapterOutline,
|
||||
ChapterVersion,
|
||||
NovelBlueprint,
|
||||
NovelConversation,
|
||||
NovelProject,
|
||||
)
|
||||
from .prompt import Prompt
|
||||
from .update_log import UpdateLog
|
||||
from .usage_metric import UsageMetric
|
||||
from .user import User
|
||||
from .user_daily_request import UserDailyRequest
|
||||
from .system_config import SystemConfig
|
||||
|
||||
__all__ = [
|
||||
"AdminSetting",
|
||||
"LLMConfig",
|
||||
"NovelConversation",
|
||||
"NovelBlueprint",
|
||||
"BlueprintCharacter",
|
||||
"BlueprintRelationship",
|
||||
"ChapterOutline",
|
||||
"Chapter",
|
||||
"ChapterVersion",
|
||||
"ChapterEvaluation",
|
||||
"NovelProject",
|
||||
"Prompt",
|
||||
"UpdateLog",
|
||||
"UsageMetric",
|
||||
"User",
|
||||
"UserDailyRequest",
|
||||
"SystemConfig",
|
||||
]
|
||||
13
backend/app/models/admin_setting.py
Normal file
13
backend/app/models/admin_setting.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from ..db.base import Base
|
||||
|
||||
|
||||
class AdminSetting(Base):
|
||||
"""后台配置项,采用简单的 KV 结构。"""
|
||||
|
||||
__tablename__ = "admin_settings"
|
||||
|
||||
key: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
value: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
17
backend/app/models/llm_config.py
Normal file
17
backend/app/models/llm_config.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from sqlalchemy import ForeignKey, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from ..db.base import Base
|
||||
|
||||
|
||||
class LLMConfig(Base):
|
||||
"""用户自定义的 LLM 接入配置。"""
|
||||
|
||||
__tablename__ = "llm_configs"
|
||||
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), primary_key=True)
|
||||
llm_provider_url: Mapped[str | None] = mapped_column(Text())
|
||||
llm_provider_api_key: Mapped[str | None] = mapped_column(Text())
|
||||
llm_provider_model: Mapped[str | None] = mapped_column(Text())
|
||||
|
||||
user: Mapped["User"] = relationship("User", back_populates="llm_config")
|
||||
225
backend/app/models/novel.py
Normal file
225
backend/app/models/novel.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, DateTime, Float, ForeignKey, Integer, String, Text, func
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from ..db.base import Base
|
||||
|
||||
# 自定义列类型:兼容跨数据库环境
|
||||
BIGINT_PK_TYPE = BigInteger().with_variant(Integer, "sqlite")
|
||||
LONG_TEXT_TYPE = Text().with_variant(LONGTEXT, "mysql")
|
||||
|
||||
|
||||
class _MetadataAccessor:
|
||||
"""Descriptor 用于将 `metadata` 访问重定向到 `metadata_`,且保持 Base.metadata 可用。"""
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return Base.metadata
|
||||
return instance.metadata_
|
||||
|
||||
def __set__(self, instance, value):
|
||||
instance.metadata_ = value
|
||||
|
||||
|
||||
class NovelProject(Base):
|
||||
"""小说项目主表,仅存放轻量级元数据。"""
|
||||
|
||||
__tablename__ = "novel_projects"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
initial_prompt: Mapped[Optional[str]] = mapped_column(Text)
|
||||
status: Mapped[str] = mapped_column(String(32), default="draft")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
|
||||
owner: Mapped["User"] = relationship("User", back_populates="novel_projects")
|
||||
blueprint: Mapped[Optional["NovelBlueprint"]] = relationship(
|
||||
back_populates="project", cascade="all, delete-orphan", uselist=False
|
||||
)
|
||||
conversations: Mapped[list["NovelConversation"]] = relationship(
|
||||
back_populates="project", cascade="all, delete-orphan", order_by="NovelConversation.seq"
|
||||
)
|
||||
characters: Mapped[list["BlueprintCharacter"]] = relationship(
|
||||
back_populates="project", cascade="all, delete-orphan", order_by="BlueprintCharacter.position"
|
||||
)
|
||||
relationships_: Mapped[list["BlueprintRelationship"]] = relationship(
|
||||
back_populates="project", cascade="all, delete-orphan", order_by="BlueprintRelationship.position"
|
||||
)
|
||||
outlines: Mapped[list["ChapterOutline"]] = relationship(
|
||||
back_populates="project", cascade="all, delete-orphan", order_by="ChapterOutline.chapter_number"
|
||||
)
|
||||
chapters: Mapped[list["Chapter"]] = relationship(
|
||||
back_populates="project", cascade="all, delete-orphan", order_by="Chapter.chapter_number"
|
||||
)
|
||||
|
||||
|
||||
class NovelConversation(Base):
|
||||
"""对话记录表,存储概念阶段的连续对话。"""
|
||||
|
||||
__tablename__ = "novel_conversations"
|
||||
|
||||
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
|
||||
seq: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
content: Mapped[str] = mapped_column(LONG_TEXT_TYPE, nullable=False)
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column("metadata", JSON)
|
||||
metadata = _MetadataAccessor()
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
project: Mapped[NovelProject] = relationship(back_populates="conversations")
|
||||
|
||||
|
||||
class NovelBlueprint(Base):
|
||||
"""蓝图主体信息(标题、风格等)。"""
|
||||
|
||||
__tablename__ = "novel_blueprints"
|
||||
|
||||
project_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("novel_projects.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
title: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
target_audience: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
genre: Mapped[Optional[str]] = mapped_column(String(128))
|
||||
style: Mapped[Optional[str]] = mapped_column(String(128))
|
||||
tone: Mapped[Optional[str]] = mapped_column(String(128))
|
||||
one_sentence_summary: Mapped[Optional[str]] = mapped_column(Text)
|
||||
full_synopsis: Mapped[Optional[str]] = mapped_column(LONG_TEXT_TYPE)
|
||||
world_setting: Mapped[Optional[dict]] = mapped_column(JSON, default=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
|
||||
project: Mapped[NovelProject] = relationship(back_populates="blueprint")
|
||||
|
||||
|
||||
class BlueprintCharacter(Base):
|
||||
"""蓝图角色信息。"""
|
||||
|
||||
__tablename__ = "blueprint_characters"
|
||||
|
||||
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
identity: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
personality: Mapped[Optional[str]] = mapped_column(Text)
|
||||
goals: Mapped[Optional[str]] = mapped_column(Text)
|
||||
abilities: Mapped[Optional[str]] = mapped_column(Text)
|
||||
relationship_to_protagonist: Mapped[Optional[str]] = mapped_column(Text)
|
||||
extra: Mapped[Optional[dict]] = mapped_column(JSON)
|
||||
position: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
project: Mapped[NovelProject] = relationship(back_populates="characters")
|
||||
|
||||
|
||||
class BlueprintRelationship(Base):
|
||||
"""角色之间的关系。"""
|
||||
|
||||
__tablename__ = "blueprint_relationships"
|
||||
|
||||
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
|
||||
character_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
character_to: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text)
|
||||
position: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
project: Mapped[NovelProject] = relationship(back_populates="relationships_")
|
||||
|
||||
|
||||
class ChapterOutline(Base):
|
||||
"""章节纲要。"""
|
||||
|
||||
__tablename__ = "chapter_outlines"
|
||||
|
||||
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
|
||||
chapter_number: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
summary: Mapped[Optional[str]] = mapped_column(Text)
|
||||
|
||||
project: Mapped[NovelProject] = relationship(back_populates="outlines")
|
||||
|
||||
|
||||
class Chapter(Base):
|
||||
"""章节正文状态,指向选中的版本。"""
|
||||
|
||||
__tablename__ = "chapters"
|
||||
|
||||
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("novel_projects.id", ondelete="CASCADE"), nullable=False)
|
||||
chapter_number: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
real_summary: Mapped[Optional[str]] = mapped_column(Text)
|
||||
status: Mapped[str] = mapped_column(String(32), default="not_generated")
|
||||
word_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
selected_version_id: Mapped[Optional[int]] = mapped_column(
|
||||
ForeignKey("chapter_versions.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
|
||||
project: Mapped[NovelProject] = relationship(back_populates="chapters")
|
||||
versions: Mapped[list["ChapterVersion"]] = relationship(
|
||||
"ChapterVersion",
|
||||
back_populates="chapter",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="ChapterVersion.created_at",
|
||||
primaryjoin="Chapter.id == ChapterVersion.chapter_id",
|
||||
foreign_keys="[ChapterVersion.chapter_id]",
|
||||
)
|
||||
selected_version: Mapped[Optional["ChapterVersion"]] = relationship(
|
||||
"ChapterVersion",
|
||||
foreign_keys=[selected_version_id],
|
||||
primaryjoin="Chapter.selected_version_id == ChapterVersion.id",
|
||||
post_update=True,
|
||||
)
|
||||
evaluations: Mapped[list["ChapterEvaluation"]] = relationship(
|
||||
back_populates="chapter", cascade="all, delete-orphan", order_by="ChapterEvaluation.created_at"
|
||||
)
|
||||
|
||||
|
||||
class ChapterVersion(Base):
|
||||
"""章节生成的不同版本文本。"""
|
||||
|
||||
__tablename__ = "chapter_versions"
|
||||
|
||||
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
|
||||
chapter_id: Mapped[int] = mapped_column(ForeignKey("chapters.id", ondelete="CASCADE"), nullable=False)
|
||||
version_label: Mapped[Optional[str]] = mapped_column(String(64))
|
||||
provider: Mapped[Optional[str]] = mapped_column(String(64))
|
||||
content: Mapped[str] = mapped_column(LONG_TEXT_TYPE, nullable=False)
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column("metadata", JSON)
|
||||
metadata = _MetadataAccessor()
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
chapter: Mapped[Chapter] = relationship(
|
||||
"Chapter",
|
||||
back_populates="versions",
|
||||
foreign_keys=[chapter_id],
|
||||
)
|
||||
evaluations: Mapped[list["ChapterEvaluation"]] = relationship(
|
||||
back_populates="version", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class ChapterEvaluation(Base):
|
||||
"""章节评估记录。"""
|
||||
|
||||
__tablename__ = "chapter_evaluations"
|
||||
|
||||
id: Mapped[int] = mapped_column(BIGINT_PK_TYPE, primary_key=True, autoincrement=True)
|
||||
chapter_id: Mapped[int] = mapped_column(ForeignKey("chapters.id", ondelete="CASCADE"), nullable=False)
|
||||
version_id: Mapped[Optional[int]] = mapped_column(ForeignKey("chapter_versions.id", ondelete="CASCADE"))
|
||||
decision: Mapped[Optional[str]] = mapped_column(String(32))
|
||||
feedback: Mapped[Optional[str]] = mapped_column(Text)
|
||||
score: Mapped[Optional[float]] = mapped_column(Float)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
chapter: Mapped[Chapter] = relationship(back_populates="evaluations")
|
||||
version: Mapped[Optional[ChapterVersion]] = relationship(back_populates="evaluations")
|
||||
25
backend/app/models/prompt.py
Normal file
25
backend/app/models/prompt.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import DateTime, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from ..db.base import Base
|
||||
|
||||
|
||||
class Prompt(Base):
|
||||
"""提示词表,支持后台 CRUD 操作。"""
|
||||
|
||||
__tablename__ = "prompts"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True)
|
||||
title: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
tags: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
14
backend/app/models/system_config.py
Normal file
14
backend/app/models/system_config.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from ..db.base import Base
|
||||
|
||||
|
||||
class SystemConfig(Base):
|
||||
"""系统级配置项,例如默认 LLM API Key、模型名称等。"""
|
||||
|
||||
__tablename__ = "system_configs"
|
||||
|
||||
key: Mapped[str] = mapped_column(String(100), primary_key=True)
|
||||
value: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(String(255))
|
||||
18
backend/app/models/update_log.py
Normal file
18
backend/app/models/update_log.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from ..db.base import Base
|
||||
|
||||
|
||||
class UpdateLog(Base):
|
||||
"""更新日志表,供公告与后台管理使用。"""
|
||||
|
||||
__tablename__ = "update_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
created_by: Mapped[str | None] = mapped_column(String(64))
|
||||
is_pinned: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
13
backend/app/models/usage_metric.py
Normal file
13
backend/app/models/usage_metric.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy import Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from ..db.base import Base
|
||||
|
||||
|
||||
class UsageMetric(Base):
|
||||
"""通用计数器表,目前用于记录 API 请求次数等统计数据。"""
|
||||
|
||||
__tablename__ = "usage_metrics"
|
||||
|
||||
key: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
value: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
31
backend/app/models/user.py
Normal file
31
backend/app/models/user.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from ..db.base import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""用户主表,记录账号及权限信息。"""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
email: Mapped[Optional[str]] = mapped_column(String(128), unique=True)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
external_id: Mapped[Optional[str]] = mapped_column(String(255), unique=True)
|
||||
is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
|
||||
# 关系映射
|
||||
novel_projects: Mapped[list["NovelProject"]] = relationship("NovelProject", back_populates="owner")
|
||||
llm_config: Mapped[Optional["LLMConfig"]] = relationship("LLMConfig", back_populates="user", uselist=False)
|
||||
18
backend/app/models/user_daily_request.py
Normal file
18
backend/app/models/user_daily_request.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy import Date, ForeignKey, Integer, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from ..db.base import Base
|
||||
|
||||
|
||||
class UserDailyRequest(Base):
|
||||
"""记录每位用户每日使用次数的限流表。"""
|
||||
|
||||
__tablename__ = "user_daily_requests"
|
||||
__table_args__ = (UniqueConstraint("user_id", "request_date", name="uq_user_daily"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
request_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
request_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
0
backend/app/repositories/__init__.py
Normal file
0
backend/app/repositories/__init__.py
Normal file
15
backend/app/repositories/admin_setting_repository.py
Normal file
15
backend/app/repositories/admin_setting_repository.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import AdminSetting
|
||||
|
||||
|
||||
class AdminSettingRepository(BaseRepository[AdminSetting]):
|
||||
model = AdminSetting
|
||||
|
||||
async def get_value(self, key: str) -> Optional[str]:
|
||||
result = await self.session.execute(select(AdminSetting).where(AdminSetting.key == key))
|
||||
record = result.scalars().first()
|
||||
return record.value if record else None
|
||||
44
backend/app/repositories/base.py
Normal file
44
backend/app/repositories/base.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Any, Generic, Iterable, Optional, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import InstrumentedAttribute
|
||||
|
||||
ModelType = TypeVar("ModelType")
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
"""通用仓储基类,封装常见的增删改查操作。"""
|
||||
|
||||
model: type[ModelType]
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get(self, **filters: Any) -> Optional[ModelType]:
|
||||
stmt = select(self.model).filter_by(**filters)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def list(self, *, filters: Optional[dict[str, Any]] = None) -> Iterable[ModelType]:
|
||||
stmt = select(self.model)
|
||||
if filters:
|
||||
stmt = stmt.filter_by(**filters)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def add(self, instance: ModelType) -> ModelType:
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, instance: ModelType) -> None:
|
||||
await self.session.delete(instance)
|
||||
|
||||
async def update_fields(self, instance: ModelType, **values: Any) -> ModelType:
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
continue
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
14
backend/app/repositories/llm_config_repository.py
Normal file
14
backend/app/repositories/llm_config_repository.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import LLMConfig
|
||||
|
||||
|
||||
class LLMConfigRepository(BaseRepository[LLMConfig]):
|
||||
model = LLMConfig
|
||||
|
||||
async def get_by_user(self, user_id: int) -> Optional[LLMConfig]:
|
||||
result = await self.session.execute(select(LLMConfig).where(LLMConfig.user_id == user_id))
|
||||
return result.scalars().first()
|
||||
55
backend/app/repositories/novel_repository.py
Normal file
55
backend/app/repositories/novel_repository.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import Chapter, NovelProject
|
||||
|
||||
|
||||
class NovelRepository(BaseRepository[NovelProject]):
|
||||
model = NovelProject
|
||||
|
||||
async def get_by_id(self, project_id: str) -> Optional[NovelProject]:
|
||||
stmt = (
|
||||
select(NovelProject)
|
||||
.where(NovelProject.id == project_id)
|
||||
.options(
|
||||
selectinload(NovelProject.blueprint),
|
||||
selectinload(NovelProject.characters),
|
||||
selectinload(NovelProject.relationships_),
|
||||
selectinload(NovelProject.outlines),
|
||||
selectinload(NovelProject.conversations),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.versions),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.evaluations),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.selected_version),
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def list_by_user(self, user_id: int) -> Iterable[NovelProject]:
|
||||
result = await self.session.execute(
|
||||
select(NovelProject)
|
||||
.where(NovelProject.user_id == user_id)
|
||||
.order_by(NovelProject.updated_at.desc())
|
||||
.options(
|
||||
selectinload(NovelProject.blueprint),
|
||||
selectinload(NovelProject.outlines),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.selected_version),
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def list_all(self) -> Iterable[NovelProject]:
|
||||
result = await self.session.execute(
|
||||
select(NovelProject)
|
||||
.order_by(NovelProject.updated_at.desc())
|
||||
.options(
|
||||
selectinload(NovelProject.owner),
|
||||
selectinload(NovelProject.blueprint),
|
||||
selectinload(NovelProject.outlines),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.selected_version),
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
19
backend/app/repositories/prompt_repository.py
Normal file
19
backend/app/repositories/prompt_repository.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import Prompt
|
||||
|
||||
|
||||
class PromptRepository(BaseRepository[Prompt]):
|
||||
model = Prompt
|
||||
|
||||
async def get_by_name(self, name: str) -> Optional[Prompt]:
|
||||
result = await self.session.execute(select(Prompt).where(Prompt.name == name))
|
||||
return result.scalars().first()
|
||||
|
||||
async def list_all(self) -> Iterable[Prompt]:
|
||||
result = await self.session.execute(select(Prompt).order_by(Prompt.name))
|
||||
return result.scalars().all()
|
||||
18
backend/app/repositories/system_config_repository.py
Normal file
18
backend/app/repositories/system_config_repository.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import SystemConfig
|
||||
|
||||
|
||||
class SystemConfigRepository(BaseRepository[SystemConfig]):
|
||||
model = SystemConfig
|
||||
|
||||
async def get_by_key(self, key: str) -> Optional[SystemConfig]:
|
||||
result = await self.session.execute(select(SystemConfig).where(SystemConfig.key == key))
|
||||
return result.scalars().first()
|
||||
|
||||
async def list_all(self) -> Iterable[SystemConfig]:
|
||||
result = await self.session.execute(select(SystemConfig).order_by(SystemConfig.key))
|
||||
return result.scalars().all()
|
||||
19
backend/app/repositories/update_log_repository.py
Normal file
19
backend/app/repositories/update_log_repository.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import UpdateLog
|
||||
|
||||
|
||||
class UpdateLogRepository(BaseRepository[UpdateLog]):
|
||||
model = UpdateLog
|
||||
|
||||
async def list(self) -> Iterable[UpdateLog]:
|
||||
result = await self.session.execute(select(UpdateLog).order_by(UpdateLog.created_at.desc()))
|
||||
return result.scalars().all()
|
||||
|
||||
async def list_latest(self, limit: int = 5) -> Iterable[UpdateLog]:
|
||||
stmt = select(UpdateLog).order_by(UpdateLog.created_at.desc()).limit(limit)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
19
backend/app/repositories/usage_metric_repository.py
Normal file
19
backend/app/repositories/usage_metric_repository.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import UsageMetric
|
||||
|
||||
|
||||
class UsageMetricRepository(BaseRepository[UsageMetric]):
|
||||
model = UsageMetric
|
||||
|
||||
async def get_or_create(self, key: str) -> UsageMetric:
|
||||
result = await self.session.execute(select(UsageMetric).where(UsageMetric.key == key))
|
||||
instance = result.scalars().first()
|
||||
if instance is None:
|
||||
instance = UsageMetric(key=key, value=0)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
62
backend/app/repositories/user_repository.py
Normal file
62
backend/app/repositories/user_repository.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from datetime import date
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import User, UserDailyRequest
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User]):
|
||||
model = User
|
||||
|
||||
async def get_by_username(self, username: str) -> Optional[User]:
|
||||
stmt = select(User).where(User.username == username)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[User]:
|
||||
stmt = select(User).where(User.email == email)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_by_external_id(self, external_id: str) -> Optional[User]:
|
||||
stmt = select(User).where(User.external_id == external_id)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def list_all(self) -> Iterable[User]:
|
||||
result = await self.session.execute(select(User))
|
||||
return result.scalars().all()
|
||||
|
||||
async def increment_daily_request(self, user_id: int) -> None:
|
||||
today = date.today()
|
||||
stmt = select(UserDailyRequest).where(
|
||||
UserDailyRequest.user_id == user_id,
|
||||
UserDailyRequest.request_date == today,
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
record = result.scalars().first()
|
||||
|
||||
if record is None:
|
||||
record = UserDailyRequest(user_id=user_id, request_date=today, request_count=1)
|
||||
self.session.add(record)
|
||||
else:
|
||||
record.request_count += 1
|
||||
await self.session.flush()
|
||||
|
||||
async def get_daily_request(self, user_id: int) -> int:
|
||||
today = date.today()
|
||||
stmt = select(UserDailyRequest.request_count).where(
|
||||
UserDailyRequest.user_id == user_id,
|
||||
UserDailyRequest.request_date == today,
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
value = result.scalars().first()
|
||||
return value or 0
|
||||
|
||||
async def count_users(self) -> int:
|
||||
stmt = select(func.count(User.id))
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one()
|
||||
0
backend/app/schemas/__init__.py
Normal file
0
backend/app/schemas/__init__.py
Normal file
49
backend/app/schemas/admin.py
Normal file
49
backend/app/schemas/admin.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Statistics(BaseModel):
|
||||
novel_count: int
|
||||
user_count: int
|
||||
api_request_count: int
|
||||
|
||||
|
||||
class DailyRequestLimit(BaseModel):
|
||||
limit: int = Field(..., ge=0, description="匿名用户每日可用次数")
|
||||
|
||||
|
||||
class UpdateLogRead(BaseModel):
|
||||
id: int
|
||||
content: str
|
||||
created_at: datetime
|
||||
created_by: Optional[str] = None
|
||||
is_pinned: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class UpdateLogBase(BaseModel):
|
||||
content: Optional[str] = None
|
||||
is_pinned: Optional[bool] = None
|
||||
|
||||
|
||||
class UpdateLogCreate(UpdateLogBase):
|
||||
content: str
|
||||
|
||||
|
||||
class UpdateLogUpdate(UpdateLogBase):
|
||||
pass
|
||||
|
||||
|
||||
class AdminNovelSummary(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
owner_id: int
|
||||
owner_username: str
|
||||
genre: str
|
||||
last_edited: str
|
||||
completed_chapters: int
|
||||
total_chapters: int
|
||||
23
backend/app/schemas/config.py
Normal file
23
backend/app/schemas/config.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SystemConfigBase(BaseModel):
|
||||
key: str = Field(..., description="配置键,需全局唯一")
|
||||
value: str = Field(..., description="配置值,统一存储为字符串")
|
||||
description: Optional[str] = Field(default=None, description="配置用途说明")
|
||||
|
||||
|
||||
class SystemConfigCreate(SystemConfigBase):
|
||||
pass
|
||||
|
||||
|
||||
class SystemConfigUpdate(BaseModel):
|
||||
value: Optional[str] = Field(default=None)
|
||||
description: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class SystemConfigRead(SystemConfigBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
20
backend/app/schemas/llm_config.py
Normal file
20
backend/app/schemas/llm_config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, HttpUrl, Field
|
||||
|
||||
|
||||
class LLMConfigBase(BaseModel):
|
||||
llm_provider_url: Optional[HttpUrl] = Field(default=None, description="自定义 LLM 服务地址")
|
||||
llm_provider_api_key: Optional[str] = Field(default=None, description="自定义 LLM API Key")
|
||||
llm_provider_model: Optional[str] = Field(default=None, description="自定义模型名称")
|
||||
|
||||
|
||||
class LLMConfigCreate(LLMConfigBase):
|
||||
pass
|
||||
|
||||
|
||||
class LLMConfigRead(LLMConfigBase):
|
||||
user_id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
170
backend/app/schemas/novel.py
Normal file
170
backend/app/schemas/novel.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChoiceOption(BaseModel):
|
||||
"""前端选择项描述,用于动态 UI 控件。"""
|
||||
|
||||
id: str
|
||||
label: str
|
||||
|
||||
|
||||
class UIControl(BaseModel):
|
||||
"""描述前端应渲染的组件类型与配置。"""
|
||||
|
||||
type: str = Field(..., description="控件类型,如 single_choice/text_input")
|
||||
options: Optional[List[ChoiceOption]] = Field(default=None, description="可选项列表")
|
||||
placeholder: Optional[str] = Field(default=None, description="输入提示文案")
|
||||
|
||||
|
||||
class ConverseResponse(BaseModel):
|
||||
"""概念对话接口的统一返回体。"""
|
||||
|
||||
ai_message: str
|
||||
ui_control: UIControl
|
||||
conversation_state: Dict[str, Any]
|
||||
is_complete: bool = False
|
||||
ready_for_blueprint: Optional[bool] = None
|
||||
|
||||
|
||||
class ConverseRequest(BaseModel):
|
||||
"""概念对话接口的请求体。"""
|
||||
|
||||
user_input: Dict[str, Any]
|
||||
conversation_state: Dict[str, Any]
|
||||
|
||||
|
||||
class ChapterGenerationStatus(str, Enum):
|
||||
NOT_GENERATED = "not_generated"
|
||||
GENERATING = "generating"
|
||||
EVALUATING = "evaluating"
|
||||
SELECTING = "selecting"
|
||||
FAILED = "failed"
|
||||
EVALUATION_FAILED = "evaluation_failed"
|
||||
WAITING_FOR_CONFIRM = "waiting_for_confirm"
|
||||
SUCCESSFUL = "successful"
|
||||
|
||||
|
||||
class ChapterOutline(BaseModel):
|
||||
chapter_number: int
|
||||
title: str
|
||||
summary: str
|
||||
|
||||
|
||||
class Chapter(ChapterOutline):
|
||||
real_summary: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
versions: Optional[List[str]] = None
|
||||
evaluation: Optional[str] = None
|
||||
generation_status: ChapterGenerationStatus = ChapterGenerationStatus.NOT_GENERATED
|
||||
|
||||
|
||||
class Relationship(BaseModel):
|
||||
character_from: str
|
||||
character_to: str
|
||||
description: str
|
||||
|
||||
|
||||
class Blueprint(BaseModel):
|
||||
title: str
|
||||
target_audience: str = ""
|
||||
genre: str = ""
|
||||
style: str = ""
|
||||
tone: str = ""
|
||||
one_sentence_summary: str = ""
|
||||
full_synopsis: str = ""
|
||||
world_setting: Dict[str, Any] = {}
|
||||
characters: List[Dict[str, Any]] = []
|
||||
relationships: List[Relationship] = []
|
||||
chapter_outline: List[ChapterOutline] = []
|
||||
|
||||
|
||||
class NovelProject(BaseModel):
|
||||
id: str
|
||||
user_id: int
|
||||
title: str
|
||||
initial_prompt: str
|
||||
conversation_history: List[Dict[str, Any]] = []
|
||||
blueprint: Optional[Blueprint] = None
|
||||
chapters: List[Chapter] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class NovelProjectSummary(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
genre: str
|
||||
last_edited: str
|
||||
completed_chapters: int
|
||||
total_chapters: int
|
||||
|
||||
|
||||
class BlueprintGenerationResponse(BaseModel):
|
||||
blueprint: Blueprint
|
||||
ai_message: str
|
||||
|
||||
|
||||
class ChapterGenerationResponse(BaseModel):
|
||||
ai_message: str
|
||||
chapter_versions: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class NovelSectionType(str, Enum):
|
||||
OVERVIEW = "overview"
|
||||
WORLD_SETTING = "world_setting"
|
||||
CHARACTERS = "characters"
|
||||
RELATIONSHIPS = "relationships"
|
||||
CHAPTER_OUTLINE = "chapter_outline"
|
||||
CHAPTERS = "chapters"
|
||||
|
||||
|
||||
class NovelSectionResponse(BaseModel):
|
||||
section: NovelSectionType
|
||||
data: Dict[str, Any]
|
||||
|
||||
|
||||
class GenerateChapterRequest(BaseModel):
|
||||
chapter_number: int
|
||||
writing_notes: Optional[str] = Field(default=None, description="章节额外写作指令")
|
||||
|
||||
|
||||
class SelectVersionRequest(BaseModel):
|
||||
chapter_number: int
|
||||
version_index: int
|
||||
|
||||
|
||||
class EvaluateChapterRequest(BaseModel):
|
||||
chapter_number: int
|
||||
|
||||
|
||||
class UpdateChapterOutlineRequest(BaseModel):
|
||||
chapter_number: int
|
||||
title: str
|
||||
summary: str
|
||||
|
||||
|
||||
class DeleteChapterRequest(BaseModel):
|
||||
chapter_numbers: List[int]
|
||||
|
||||
|
||||
class GenerateOutlineRequest(BaseModel):
|
||||
start_chapter: int
|
||||
num_chapters: int
|
||||
|
||||
|
||||
class BlueprintPatch(BaseModel):
|
||||
one_sentence_summary: Optional[str] = None
|
||||
full_synopsis: Optional[str] = None
|
||||
world_setting: Optional[Dict[str, Any]] = None
|
||||
characters: Optional[List[Dict[str, Any]]] = None
|
||||
relationships: Optional[List[Relationship]] = None
|
||||
chapter_outline: Optional[List[ChapterOutline]] = None
|
||||
|
||||
|
||||
class EditChapterRequest(BaseModel):
|
||||
chapter_number: int
|
||||
content: str
|
||||
56
backend/app/schemas/prompt.py
Normal file
56
backend/app/schemas/prompt.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PromptBase(BaseModel):
|
||||
"""Prompt 基础模型。"""
|
||||
|
||||
name: str = Field(..., description="唯一标识,用于代码引用")
|
||||
title: Optional[str] = Field(default=None, description="可读标题")
|
||||
content: str = Field(..., description="提示词具体内容")
|
||||
tags: Optional[List[str]] = Field(default=None, description="标签集合")
|
||||
|
||||
|
||||
class PromptCreate(PromptBase):
|
||||
"""创建 Prompt 时使用的模型。"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PromptUpdate(BaseModel):
|
||||
"""更新 Prompt 时使用的模型。"""
|
||||
|
||||
title: Optional[str] = Field(default=None)
|
||||
content: Optional[str] = Field(default=None)
|
||||
tags: Optional[List[str]] = Field(default=None)
|
||||
|
||||
|
||||
class PromptRead(PromptBase):
|
||||
"""对外暴露的 Prompt 数据结构。"""
|
||||
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, obj: Any, *args: Any, **kwargs: Any) -> "PromptRead": # type: ignore[override]
|
||||
"""在转换 ORM 模型时,将字符串标签拆分为列表。"""
|
||||
if hasattr(obj, "id") and hasattr(obj, "name"):
|
||||
raw_tags = getattr(obj, "tags", None)
|
||||
if isinstance(raw_tags, str):
|
||||
processed = [tag for tag in raw_tags.split(",") if tag]
|
||||
elif isinstance(raw_tags, list):
|
||||
processed = raw_tags
|
||||
else:
|
||||
processed = None
|
||||
data = {
|
||||
"id": getattr(obj, "id"),
|
||||
"name": getattr(obj, "name"),
|
||||
"title": getattr(obj, "title", None),
|
||||
"content": getattr(obj, "content", None),
|
||||
"tags": processed,
|
||||
}
|
||||
return super().model_validate(data, *args, **kwargs)
|
||||
return super().model_validate(obj, *args, **kwargs)
|
||||
74
backend/app/schemas/user.py
Normal file
74
backend/app/schemas/user.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
"""用户基础数据结构,供多处复用。"""
|
||||
|
||||
username: str = Field(..., description="用户名")
|
||||
email: Optional[EmailStr] = Field(default=None, description="邮箱,可选")
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
"""注册时使用的模型。"""
|
||||
|
||||
password: str = Field(..., min_length=6, description="明文密码")
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""用户信息修改模型。"""
|
||||
|
||||
email: Optional[EmailStr] = Field(default=None, description="邮箱")
|
||||
password: Optional[str] = Field(default=None, min_length=6, description="新密码")
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
"""对外暴露的用户信息。"""
|
||||
|
||||
id: int = Field(..., description="用户主键")
|
||||
is_admin: bool = Field(default=False, description="是否为管理员")
|
||||
must_change_password: bool = Field(default=False, description="是否需要强制修改密码")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class UserInDB(User):
|
||||
"""数据库内部使用的模型,包含哈希后的密码。"""
|
||||
|
||||
hashed_password: str
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
"""登录成功后返回的访问令牌。"""
|
||||
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
must_change_password: bool = Field(default=False, description="是否需要强制修改密码")
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT 负载信息。"""
|
||||
|
||||
sub: str
|
||||
is_admin: bool = False
|
||||
|
||||
|
||||
class UserRegistration(UserCreate):
|
||||
"""注册接口需要的字段,包含邮箱验证码。"""
|
||||
|
||||
verification_code: str = Field(..., min_length=4, max_length=10, description="邮箱验证码")
|
||||
|
||||
|
||||
class PasswordChangeRequest(BaseModel):
|
||||
"""管理员修改密码请求模型。"""
|
||||
|
||||
old_password: str = Field(..., min_length=6, description="当前密码")
|
||||
new_password: str = Field(..., min_length=8, description="新密码")
|
||||
|
||||
|
||||
class AuthOptions(BaseModel):
|
||||
"""认证相关开关信息,供前端动态控制功能。"""
|
||||
|
||||
allow_registration: bool = Field(..., description="是否允许开放用户注册")
|
||||
enable_linuxdo_login: bool = Field(..., description="是否启用 Linux.do 登录")
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
27
backend/app/services/admin_setting_service.py
Normal file
27
backend/app/services/admin_setting_service.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import AdminSetting
|
||||
from ..repositories.admin_setting_repository import AdminSettingRepository
|
||||
|
||||
|
||||
class AdminSettingService:
|
||||
"""管理员配置项服务,提供简单的 KV 操作。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = AdminSettingRepository(session)
|
||||
|
||||
async def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
value = await self.repo.get_value(key)
|
||||
return value if value is not None else default
|
||||
|
||||
async def set(self, key: str, value: str) -> None:
|
||||
record = await self.repo.get(key=key)
|
||||
if record:
|
||||
await self.repo.update_fields(record, value=value)
|
||||
else:
|
||||
setting = AdminSetting(key=key, value=value)
|
||||
await self.repo.add(setting)
|
||||
await self.session.commit()
|
||||
389
backend/app/services/auth_service.py
Normal file
389
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,389 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
from email.header import Header
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formataddr, parseaddr
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
import smtplib
|
||||
|
||||
from ..core.config import settings
|
||||
from ..core.security import create_access_token, hash_password, verify_password
|
||||
from ..models import User
|
||||
from ..repositories.system_config_repository import SystemConfigRepository
|
||||
from ..repositories.user_repository import UserRepository
|
||||
from ..schemas.user import AuthOptions, Token, UserCreate, UserInDB, UserRegistration
|
||||
|
||||
|
||||
_VERIFICATION_CACHE: Dict[str, tuple[str, float]] = {}
|
||||
_LAST_SEND_TIME: Dict[str, float] = {}
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""认证与授权逻辑,封装登录、注册、OAuth 对接等操作。"""
|
||||
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
self.system_config_repo = SystemConfigRepository(session)
|
||||
self._verification_cache = _VERIFICATION_CACHE
|
||||
self._last_send_time = _LAST_SEND_TIME
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 用户登录 / 注册
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def authenticate_user(self, username: str, password: str) -> User:
|
||||
user = await self.user_repo.get_by_username(username)
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误")
|
||||
return user
|
||||
|
||||
async def create_access_token(
|
||||
self,
|
||||
user: User | UserInDB,
|
||||
*,
|
||||
must_change_password: Optional[bool] = None,
|
||||
) -> Token:
|
||||
payload = {"is_admin": user.is_admin}
|
||||
token = create_access_token(user.username, extra_claims=payload)
|
||||
should_change = self.requires_password_reset(user) if must_change_password is None else must_change_password
|
||||
return Token(access_token=token, must_change_password=should_change)
|
||||
|
||||
async def register_user(self, payload: UserRegistration) -> User:
|
||||
if not await self.is_registration_enabled():
|
||||
raise HTTPException(status_code=403, detail="当前暂未开放注册")
|
||||
if await self.user_repo.get_by_username(payload.username):
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
if payload.email and await self.user_repo.get_by_email(payload.email):
|
||||
raise HTTPException(status_code=400, detail="邮箱已被使用")
|
||||
|
||||
if not self.verify_code(payload.email, payload.verification_code):
|
||||
raise HTTPException(status_code=400, detail="验证码错误或已过期")
|
||||
|
||||
hashed_password = hash_password(payload.password)
|
||||
user = User(
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
hashed_password=hashed_password,
|
||||
)
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
return user
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 邮箱验证码逻辑
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_verification_code(self, email: str) -> None:
|
||||
if not await self.is_registration_enabled():
|
||||
raise HTTPException(status_code=403, detail="当前暂未开放注册")
|
||||
now = time.time()
|
||||
if email in self._last_send_time and now - self._last_send_time[email] < 60:
|
||||
raise HTTPException(status_code=429, detail="请稍后再试,1分钟内不可重复发送")
|
||||
|
||||
code = "".join(random.choices(string.digits, k=6))
|
||||
self._verification_cache[email] = (code, now + 300)
|
||||
self._last_send_time[email] = now
|
||||
|
||||
smtp_config = await self._load_smtp_config()
|
||||
if not smtp_config:
|
||||
raise HTTPException(status_code=500, detail="未配置邮件服务,请联系管理员")
|
||||
|
||||
await self._send_email(email, code, smtp_config)
|
||||
|
||||
def verify_code(self, email: str | None, code: str) -> bool:
|
||||
if not email:
|
||||
return False
|
||||
cached = self._verification_cache.get(email)
|
||||
if not cached:
|
||||
return False
|
||||
expected, expire_at = cached
|
||||
if time.time() > expire_at:
|
||||
self._verification_cache.pop(email, None)
|
||||
return False
|
||||
if code != expected:
|
||||
return False
|
||||
self._verification_cache.pop(email, None)
|
||||
return True
|
||||
|
||||
async def _load_smtp_config(self) -> Optional[Dict[str, str]]:
|
||||
keys = [
|
||||
"smtp.server",
|
||||
"smtp.port",
|
||||
"smtp.username",
|
||||
"smtp.password",
|
||||
"smtp.from",
|
||||
]
|
||||
configs = {}
|
||||
for key in keys:
|
||||
config = await self.system_config_repo.get_by_key(key)
|
||||
if config:
|
||||
configs[key] = config.value
|
||||
|
||||
required_keys = {"smtp.server", "smtp.port", "smtp.username", "smtp.password", "smtp.from"}
|
||||
if not required_keys.issubset(configs.keys()):
|
||||
return None
|
||||
|
||||
return configs
|
||||
|
||||
async def _send_email(self, to_email: str, code: str, smtp_config: Dict[str, str]) -> None:
|
||||
logger = logging.getLogger(__name__)
|
||||
server = smtp_config["smtp.server"]
|
||||
port = int(smtp_config.get("smtp.port", "465"))
|
||||
username = smtp_config["smtp.username"]
|
||||
password = smtp_config["smtp.password"]
|
||||
from_value = smtp_config.get("smtp.from") or username
|
||||
display_name, from_addr = parseaddr(from_value)
|
||||
if not display_name and "@" not in from_value and "<" not in from_value and from_value.strip():
|
||||
display_name = from_value.strip()
|
||||
if not from_addr or "@" not in from_addr:
|
||||
if from_addr and "@" not in from_addr:
|
||||
logger.warning(
|
||||
"发件邮箱缺少 @,已回退为登录账号",
|
||||
extra={"original": from_addr},
|
||||
)
|
||||
from_addr = username
|
||||
try:
|
||||
from_addr.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
logger.warning(
|
||||
"发件邮箱包含非 ASCII 字符,已回退为登录账号",
|
||||
extra={"original": from_addr},
|
||||
)
|
||||
from_addr = username
|
||||
if display_name:
|
||||
formatted_from = formataddr((Header(display_name, "utf-8").encode(), from_addr))
|
||||
else:
|
||||
formatted_from = from_addr
|
||||
|
||||
try:
|
||||
to_email.encode("ascii")
|
||||
except UnicodeEncodeError as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=400, detail="邮箱地址包含不支持的字符") from exc
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta http-equiv=\"Content-Type\" content=\"text/html; charset=UTF-8\">
|
||||
<meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">
|
||||
<title>您的验证码</title>
|
||||
<style>
|
||||
body, table, td, a {{ -webkit-text-size-adjust: 100%; -ms-text-size-adjust: 100%; }}
|
||||
table, td {{ mso-table-lspace: 0pt; mso-table-rspace: 0pt; }}
|
||||
img {{ -ms-interpolation-mode: bicubic; }}
|
||||
body {{ margin: 0; padding: 0; }}
|
||||
table {{ border-collapse: collapse !important; }}
|
||||
</style>
|
||||
</head>
|
||||
<body style=\"margin: 0; padding: 0; width: 100% !important; background-color: #f3f4f6;\">
|
||||
<table width=\"100%\" border=\"0\" cellpadding=\"0\" cellspacing=\"0\" bgcolor=\"#f3f4f6\">
|
||||
<tr>
|
||||
<td align=\"center\" valign=\"top\" style=\"padding: 20px;\">
|
||||
<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\" width=\"100%\" style=\"max-width: 512px; background-color: #ffffff; border-radius: 16px; overflow: hidden;\">
|
||||
<tr>
|
||||
<td align=\"center\" style=\"background-color: #2563eb; padding: 32px;\">
|
||||
<h1 style=\"font-family: Arial, Helvetica, sans-serif; font-size: 30px; font-weight: bold; color: #ffffff; margin: 0;\">操作验证码</h1>
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 16px; color: #dbeafe; margin: 8px 0 0;\">请使用下方验证码完成操作。</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=\"center\" style=\"padding: 32px 48px;\">
|
||||
<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\" width=\"100%\">
|
||||
<tr>
|
||||
<td align=\"center\" style=\"background-color: #f3f4f6; border-radius: 12px; padding: 16px; margin: 24px 0;\">
|
||||
<p style=\"font-family: 'Courier New', Courier, monospace; font-size: 48px; font-weight: bold; letter-spacing: 0.1em; color: #1d4ed8; margin: 0;\">
|
||||
{code[:3]}{code[3:]}
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=\"center\" style=\"padding-top: 24px;\">
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 16px; color: #6b7280; margin: 0;\">
|
||||
此验证码将在 <span style=\"font-weight: bold; color: #374151;\">5分钟</span> 内有效。
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=\"center\" style=\"padding-top: 32px; border-top: 1px solid #e5e7eb; margin-top: 32px;\">
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 14px; font-weight: bold; color: #ef4444; margin: 0;\">
|
||||
为保障安全,请勿泄露此验证码。
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=\"center\" style=\"background-color: #f9fafb; padding: 24px; border-top: 1px solid #e5e7eb;\">
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 14px; color: #6b7280; margin: 0;\">
|
||||
如非本人操作,请忽略此邮件。
|
||||
</p>
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 12px; color: #9ca3af; margin: 8px 0 0;\">
|
||||
© {time.strftime('%Y')} 拯救小说家. All rights reserved.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
message = MIMEText(html_content, "html", "utf-8")
|
||||
message["Subject"] = Header("注册验证码", "utf-8").encode()
|
||||
message["From"] = formatted_from
|
||||
message["To"] = to_email
|
||||
|
||||
logger.info("准备发送验证码邮件", extra={"to": to_email, "server": server, "port": port})
|
||||
|
||||
def _send():
|
||||
smtp: Optional[smtplib.SMTP] = None
|
||||
try:
|
||||
if port == 465:
|
||||
smtp = smtplib.SMTP_SSL(server, port, timeout=10)
|
||||
else:
|
||||
smtp = smtplib.SMTP(server, port, timeout=10)
|
||||
smtp.starttls()
|
||||
if username and password:
|
||||
smtp.login(username, password)
|
||||
smtp.sendmail(from_addr, [to_email], message.as_string())
|
||||
logger.info("验证码邮件发送成功", extra={"to": to_email})
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("验证码发送失败")
|
||||
raise
|
||||
finally:
|
||||
if smtp is not None:
|
||||
try:
|
||||
smtp.quit()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_send)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail="验证码发送失败,请检查邮件配置") from exc
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# OAuth 对接示例(以 Linux.do 为例)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def handle_linuxdo_callback(self, code: str) -> Token:
|
||||
if not await self.is_linuxdo_login_enabled():
|
||||
raise HTTPException(status_code=403, detail="未启用 Linux.do 登录")
|
||||
client_id = await self._get_config_value("linuxdo.client_id")
|
||||
client_secret = await self._get_config_value("linuxdo.client_secret")
|
||||
redirect_uri = await self._get_config_value("linuxdo.redirect_uri")
|
||||
token_url = await self._get_config_value("linuxdo.token_url")
|
||||
user_info_url = await self._get_config_value("linuxdo.user_info_url")
|
||||
|
||||
if not all([client_id, client_secret, redirect_uri, token_url, user_info_url]):
|
||||
raise HTTPException(status_code=500, detail="未正确配置 Linux.do OAuth 参数")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_response = await client.post(
|
||||
token_url,
|
||||
data={
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
token_response.raise_for_status()
|
||||
access_token = token_response.json().get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=400, detail="授权失败,未获取到访问令牌")
|
||||
|
||||
user_info_response = await client.get(
|
||||
user_info_url,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
user_info_response.raise_for_status()
|
||||
data = user_info_response.json()
|
||||
|
||||
external_id = f"linuxdo:{data['id']}"
|
||||
user = await self.user_repo.get_by_external_id(external_id)
|
||||
if user is None:
|
||||
placeholder_password = secrets.token_urlsafe(16)
|
||||
user = User(
|
||||
username=data["username"],
|
||||
email=data.get("email"),
|
||||
external_id=external_id,
|
||||
hashed_password=hash_password(placeholder_password),
|
||||
)
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
|
||||
return await self.create_access_token(user)
|
||||
|
||||
async def _get_config_value(self, key: str) -> Optional[str]:
|
||||
config = await self.system_config_repo.get_by_key(key)
|
||||
return config.value if config else None
|
||||
|
||||
async def get_config_value(self, key: str) -> Optional[str]:
|
||||
"""对外暴露的配置读取接口,便于路由层复用。"""
|
||||
return await self._get_config_value(key)
|
||||
|
||||
@staticmethod
|
||||
def _parse_bool(value: Optional[str], fallback: bool) -> bool:
|
||||
if value is None:
|
||||
return fallback
|
||||
normalized = value.strip().lower()
|
||||
return normalized in {"1", "true", "yes", "on"}
|
||||
|
||||
async def is_registration_enabled(self) -> bool:
|
||||
value = await self._get_config_value("auth.allow_registration")
|
||||
return self._parse_bool(value, fallback=settings.allow_registration)
|
||||
|
||||
async def is_linuxdo_login_enabled(self) -> bool:
|
||||
value = await self._get_config_value("auth.linuxdo_enabled")
|
||||
return self._parse_bool(value, fallback=settings.enable_linuxdo_login)
|
||||
|
||||
async def get_auth_options(self) -> AuthOptions:
|
||||
"""聚合与认证相关的动态开关配置,便于前端一次性拉取。"""
|
||||
|
||||
allow_registration = await self.is_registration_enabled()
|
||||
enable_linuxdo_login = await self.is_linuxdo_login_enabled()
|
||||
return AuthOptions(
|
||||
allow_registration=allow_registration,
|
||||
enable_linuxdo_login=enable_linuxdo_login,
|
||||
)
|
||||
|
||||
def requires_password_reset(self, user: User | UserInDB) -> bool:
|
||||
if not user.is_admin:
|
||||
return False
|
||||
if user.username != settings.admin_default_username:
|
||||
return False
|
||||
hashed_password = getattr(user, "hashed_password", None)
|
||||
if not hashed_password:
|
||||
return False
|
||||
return verify_password(settings.admin_default_password, hashed_password)
|
||||
|
||||
async def change_password(self, username: str, old_password: str, new_password: str) -> None:
|
||||
user = await self.user_repo.get_by_username(username)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
|
||||
|
||||
if not verify_password(old_password, user.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前密码错误")
|
||||
|
||||
if verify_password(new_password, user.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="新密码不能与当前密码相同")
|
||||
|
||||
if username == settings.admin_default_username and new_password == settings.admin_default_password:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="新密码不能为默认密码")
|
||||
|
||||
user.hashed_password = hash_password(new_password)
|
||||
await self.session.commit()
|
||||
109
backend/app/services/chapter_context_service.py
Normal file
109
backend/app/services/chapter_context_service.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
章节上下文组装服务:负责调用向量库检索上下文,并对结果做基础格式化。
|
||||
|
||||
所有关键步骤均包含中文注释,方便团队理解 RAG 流程。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from ..core.config import settings
|
||||
from ..services.llm_service import LLMService
|
||||
from .vector_store_service import RetrievedChunk, RetrievedSummary, VectorStoreService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChapterRAGContext:
|
||||
"""封装检索得到的上下文结果。"""
|
||||
|
||||
query: str
|
||||
chunks: List[RetrievedChunk]
|
||||
summaries: List[RetrievedSummary]
|
||||
|
||||
def chunk_texts(self) -> List[str]:
|
||||
"""将检索到的 chunk 转换成带序号的 Markdown 段落。"""
|
||||
lines = []
|
||||
for idx, chunk in enumerate(self.chunks, start=1):
|
||||
title = chunk.chapter_title or f"第{chunk.chapter_number}章"
|
||||
lines.append(
|
||||
f"### Chunk {idx}(来源:{title})\n{chunk.content.strip()}"
|
||||
)
|
||||
return lines
|
||||
|
||||
def summary_lines(self) -> List[str]:
|
||||
"""整理章节摘要,方便直接插入 Prompt。"""
|
||||
lines = []
|
||||
for summary in self.summaries:
|
||||
lines.append(
|
||||
f"- 第{summary.chapter_number}章 - {summary.title}:{summary.summary.strip()}"
|
||||
)
|
||||
return lines
|
||||
|
||||
|
||||
class ChapterContextService:
|
||||
"""章节上下文服务,整合查询、格式化与容错逻辑。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
llm_service: LLMService,
|
||||
vector_store: Optional[VectorStoreService] = None,
|
||||
) -> None:
|
||||
self._llm_service = llm_service
|
||||
self._vector_store = vector_store
|
||||
|
||||
async def retrieve_for_generation(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
query_text: str,
|
||||
user_id: int,
|
||||
top_k_chunks: Optional[int] = None,
|
||||
top_k_summaries: Optional[int] = None,
|
||||
) -> ChapterRAGContext:
|
||||
"""根据章节摘要构造检索向量,并返回 RAG 上下文。"""
|
||||
query = self._normalize(query_text)
|
||||
if not settings.vector_store_enabled or not self._vector_store:
|
||||
logger.debug("向量库未启用或初始化失败,跳过检索: project=%s", project_id)
|
||||
return ChapterRAGContext(query=query, chunks=[], summaries=[])
|
||||
|
||||
embedding_model = None if settings.embedding_provider == "ollama" else settings.embedding_model
|
||||
embedding = await self._llm_service.get_embedding(query, user_id=user_id, model=embedding_model)
|
||||
if not embedding:
|
||||
logger.warning("检索查询向量生成失败: project=%s chapter_query=%s", project_id, query)
|
||||
return ChapterRAGContext(query=query, chunks=[], summaries=[])
|
||||
|
||||
chunks = await self._vector_store.query_chunks(
|
||||
project_id=project_id,
|
||||
embedding=embedding,
|
||||
top_k=top_k_chunks,
|
||||
)
|
||||
summaries = await self._vector_store.query_summaries(
|
||||
project_id=project_id,
|
||||
embedding=embedding,
|
||||
top_k=top_k_summaries,
|
||||
)
|
||||
logger.info(
|
||||
"章节上下文检索完成: project=%s chunks=%d summaries=%d query_preview=%s",
|
||||
project_id,
|
||||
len(chunks),
|
||||
len(summaries),
|
||||
query[:80],
|
||||
)
|
||||
return ChapterRAGContext(query=query, chunks=chunks, summaries=summaries)
|
||||
|
||||
@staticmethod
|
||||
def _normalize(text: str) -> str:
|
||||
"""统一压缩空白字符,避免影响检索效果。"""
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChapterContextService",
|
||||
"ChapterRAGContext",
|
||||
]
|
||||
262
backend/app/services/chapter_ingest_service.py
Normal file
262
backend/app/services/chapter_ingest_service.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
章节向量入库服务:在章节确认后负责切分文本、生成嵌入并写入向量库。
|
||||
|
||||
全部注释使用中文,方便团队成员阅读理解。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
from ..core.config import settings
|
||||
from ..services.llm_service import LLMService
|
||||
from ..services.vector_store_service import VectorStoreService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try: # noqa: SIM105 - 提示缺少可选依赖
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
except ImportError: # pragma: no cover - 未安装时会走后备方案
|
||||
RecursiveCharacterTextSplitter = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class ChapterIngestionService:
|
||||
"""封装章节内容与摘要的向量化与入库流程。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
llm_service: LLMService,
|
||||
vector_store: Optional[VectorStoreService] = None,
|
||||
) -> None:
|
||||
self._llm_service = llm_service
|
||||
self._vector_store = vector_store or VectorStoreService()
|
||||
self._text_splitter = self._init_text_splitter()
|
||||
|
||||
async def ingest_chapter(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
title: str,
|
||||
content: str,
|
||||
summary: Optional[str],
|
||||
user_id: int,
|
||||
) -> None:
|
||||
"""将章节正文与摘要写入向量库,供后续 RAG 检索使用。"""
|
||||
if not settings.vector_store_enabled:
|
||||
logger.debug("向量库未启用,跳过章节向量写入: project=%s chapter=%s", project_id, chapter_number)
|
||||
return
|
||||
if not content.strip():
|
||||
logger.debug("章节正文为空,跳过向量写入: project=%s chapter=%s", project_id, chapter_number)
|
||||
return
|
||||
|
||||
chunks = self._split_into_chunks(content)
|
||||
if not chunks:
|
||||
logger.debug("章节正文切分后为空,跳过向量写入: project=%s chapter=%s", project_id, chapter_number)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"开始写入章节向量: project=%s chapter=%s chunks=%d",
|
||||
project_id,
|
||||
chapter_number,
|
||||
len(chunks),
|
||||
)
|
||||
await self._vector_store.delete_by_chapters(project_id, [chapter_number])
|
||||
|
||||
chunk_records = []
|
||||
for index, chunk_text in enumerate(chunks):
|
||||
embedding = await self._llm_service.get_embedding(
|
||||
chunk_text,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not embedding:
|
||||
logger.warning(
|
||||
"生成章节片段向量失败,已跳过: project=%s chapter=%s chunk=%s",
|
||||
project_id,
|
||||
chapter_number,
|
||||
index,
|
||||
)
|
||||
continue
|
||||
record_id = f"{project_id}:{chapter_number}:{index}"
|
||||
chunk_records.append(
|
||||
{
|
||||
"id": record_id,
|
||||
"project_id": project_id,
|
||||
"chapter_number": chapter_number,
|
||||
"chunk_index": index,
|
||||
"chapter_title": title,
|
||||
"content": chunk_text,
|
||||
"embedding": embedding,
|
||||
"metadata": {
|
||||
"chunk_id": record_id,
|
||||
"length": len(chunk_text),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if chunk_records:
|
||||
await self._vector_store.upsert_chunks(records=chunk_records)
|
||||
logger.info(
|
||||
"章节正文向量写入完成: project=%s chapter=%s 成功片段=%d",
|
||||
project_id,
|
||||
chapter_number,
|
||||
len(chunk_records),
|
||||
)
|
||||
|
||||
if summary:
|
||||
cleaned_summary = summary.strip()
|
||||
if cleaned_summary:
|
||||
summary_embedding = await self._llm_service.get_embedding(
|
||||
cleaned_summary,
|
||||
user_id=user_id,
|
||||
)
|
||||
if summary_embedding:
|
||||
summary_id = f"{project_id}:{chapter_number}:summary"
|
||||
await self._vector_store.upsert_summaries(
|
||||
records=[
|
||||
{
|
||||
"id": summary_id,
|
||||
"project_id": project_id,
|
||||
"chapter_number": chapter_number,
|
||||
"title": title,
|
||||
"summary": cleaned_summary,
|
||||
"embedding": summary_embedding,
|
||||
}
|
||||
]
|
||||
)
|
||||
logger.info(
|
||||
"章节摘要向量写入完成: project=%s chapter=%s",
|
||||
project_id,
|
||||
chapter_number,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"生成章节摘要向量失败,已跳过: project=%s chapter=%s",
|
||||
project_id,
|
||||
chapter_number,
|
||||
)
|
||||
|
||||
async def delete_chapters(self, project_id: str, chapter_numbers: Sequence[int]) -> None:
|
||||
"""从向量库中删除指定章节的所有片段与摘要。"""
|
||||
if not settings.vector_store_enabled or not chapter_numbers:
|
||||
return
|
||||
logger.info(
|
||||
"准备删除章节向量: project=%s chapters=%s",
|
||||
project_id,
|
||||
list(chapter_numbers),
|
||||
)
|
||||
await self._vector_store.delete_by_chapters(project_id, list(chapter_numbers))
|
||||
|
||||
def _split_into_chunks(self, text: str) -> List[str]:
|
||||
"""按照配置的 chunk 大小与重叠度切分章节正文。"""
|
||||
normalized = text.strip()
|
||||
if not normalized:
|
||||
return []
|
||||
|
||||
if self._text_splitter:
|
||||
parts = [segment.strip() for segment in self._text_splitter.split_text(normalized)]
|
||||
filtered = [part for part in parts if part]
|
||||
if filtered:
|
||||
logger.debug(
|
||||
"使用 LangChain 文本切分器完成分段: count=%d chunk_size=%d overlap=%d",
|
||||
len(filtered),
|
||||
settings.vector_chunk_size,
|
||||
settings.vector_chunk_overlap,
|
||||
)
|
||||
return filtered
|
||||
|
||||
return self._legacy_split(normalized)
|
||||
|
||||
@staticmethod
|
||||
def _find_split_offset(segment: str) -> Optional[int]:
|
||||
"""在片段内部寻找更自然的分割点,优先换行,其次常见标点。"""
|
||||
candidates: Dict[str, int] = {}
|
||||
newline_pos = segment.rfind("\n\n")
|
||||
if newline_pos == -1:
|
||||
newline_pos = segment.rfind("\n")
|
||||
if newline_pos > 0:
|
||||
candidates["newline"] = newline_pos
|
||||
|
||||
punctuation_marks = ["。", "!", "?", "!", "?", ".", ";", ";"]
|
||||
for mark in punctuation_marks:
|
||||
idx = segment.rfind(mark)
|
||||
if idx > 0:
|
||||
candidates.setdefault("punctuation", idx + len(mark))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# 选择最接近末尾但又不过短的分割点
|
||||
best_offset = max(candidates.values())
|
||||
if best_offset < len(segment) * 0.4:
|
||||
return None
|
||||
return best_offset
|
||||
|
||||
def _init_text_splitter(self) -> Optional["RecursiveCharacterTextSplitter"]:
|
||||
"""初始化 LangChain 文本切分器,可根据配置动态调整。"""
|
||||
if RecursiveCharacterTextSplitter is None:
|
||||
logger.warning("未安装 langchain-text-splitters,章节切分将回退至内置策略。")
|
||||
return None
|
||||
|
||||
chunk_size = settings.vector_chunk_size
|
||||
overlap = min(settings.vector_chunk_overlap, chunk_size // 2)
|
||||
separators = [
|
||||
"\n\n",
|
||||
"\n",
|
||||
"。", "!", "?",
|
||||
"!", "?", ";", ";",
|
||||
",", ",",
|
||||
" ",
|
||||
]
|
||||
splitter = RecursiveCharacterTextSplitter(
|
||||
separators=separators,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=overlap,
|
||||
keep_separator=False,
|
||||
strip_whitespace=True,
|
||||
)
|
||||
logger.info(
|
||||
"已初始化 LangChain 文本切分器: chunk_size=%d overlap=%d",
|
||||
chunk_size,
|
||||
overlap,
|
||||
)
|
||||
return splitter
|
||||
|
||||
def _legacy_split(self, text: str) -> List[str]:
|
||||
"""内置切分策略,作为 LangChain 缺失时的后备方案。"""
|
||||
chunk_size = settings.vector_chunk_size
|
||||
overlap = min(settings.vector_chunk_overlap, chunk_size // 2)
|
||||
|
||||
chunks: List[str] = []
|
||||
start = 0
|
||||
total_length = len(text)
|
||||
|
||||
while start < total_length:
|
||||
end = min(total_length, start + chunk_size)
|
||||
segment = text[start:end]
|
||||
|
||||
split_offset = self._find_split_offset(segment)
|
||||
if split_offset is not None and start + split_offset < total_length:
|
||||
end = start + split_offset
|
||||
segment = text[start:end]
|
||||
|
||||
chunk_text = segment.strip()
|
||||
if chunk_text:
|
||||
chunks.append(chunk_text)
|
||||
|
||||
if end >= total_length:
|
||||
break
|
||||
start = max(0, end - overlap)
|
||||
|
||||
logger.debug(
|
||||
"使用内置策略完成章节切分: count=%d chunk_size=%d overlap=%d",
|
||||
len(chunks),
|
||||
chunk_size,
|
||||
overlap,
|
||||
)
|
||||
return chunks
|
||||
|
||||
|
||||
__all__ = ["ChapterIngestionService"]
|
||||
49
backend/app/services/config_service.py
Normal file
49
backend/app/services/config_service.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..repositories.system_config_repository import SystemConfigRepository
|
||||
from ..models import SystemConfig
|
||||
from ..schemas.config import SystemConfigCreate, SystemConfigRead, SystemConfigUpdate
|
||||
|
||||
|
||||
class ConfigService:
|
||||
"""系统配置服务:提供 CRUD 接口,并负责转换 Pydantic 模型。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = SystemConfigRepository(session)
|
||||
|
||||
async def list_configs(self) -> list[SystemConfigRead]:
|
||||
configs = await self.repo.list_all()
|
||||
return [SystemConfigRead.model_validate(cfg) for cfg in configs]
|
||||
|
||||
async def get_config(self, key: str) -> Optional[SystemConfigRead]:
|
||||
config = await self.repo.get_by_key(key)
|
||||
return SystemConfigRead.model_validate(config) if config else None
|
||||
|
||||
async def upsert_config(self, payload: SystemConfigCreate) -> SystemConfigRead:
|
||||
instance = await self.repo.get_by_key(payload.key)
|
||||
if instance:
|
||||
await self.repo.update_fields(instance, value=payload.value, description=payload.description)
|
||||
else:
|
||||
instance = SystemConfig(**payload.model_dump())
|
||||
await self.repo.add(instance)
|
||||
await self.session.commit()
|
||||
return SystemConfigRead.model_validate(instance)
|
||||
|
||||
async def patch_config(self, key: str, payload: SystemConfigUpdate) -> Optional[SystemConfigRead]:
|
||||
instance = await self.repo.get_by_key(key)
|
||||
if not instance:
|
||||
return None
|
||||
await self.repo.update_fields(instance, **payload.model_dump(exclude_unset=True))
|
||||
await self.session.commit()
|
||||
return SystemConfigRead.model_validate(instance)
|
||||
|
||||
async def remove_config(self, key: str) -> bool:
|
||||
instance = await self.repo.get_by_key(key)
|
||||
if not instance:
|
||||
return False
|
||||
await self.repo.delete(instance)
|
||||
await self.session.commit()
|
||||
return True
|
||||
41
backend/app/services/llm_config_service.py
Normal file
41
backend/app/services/llm_config_service.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import LLMConfig
|
||||
from ..repositories.llm_config_repository import LLMConfigRepository
|
||||
from ..schemas.llm_config import LLMConfigCreate, LLMConfigRead
|
||||
|
||||
|
||||
class LLMConfigService:
|
||||
"""用户自定义 LLM 配置服务。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = LLMConfigRepository(session)
|
||||
|
||||
async def upsert_config(self, user_id: int, payload: LLMConfigCreate) -> LLMConfigRead:
|
||||
instance = await self.repo.get_by_user(user_id)
|
||||
data = payload.model_dump(exclude_unset=True)
|
||||
if "llm_provider_url" in data and data["llm_provider_url"] is not None:
|
||||
# HttpUrl 类型在 sqlite 中无法直接写入,需要提前转为字符串
|
||||
data["llm_provider_url"] = str(data["llm_provider_url"])
|
||||
if instance:
|
||||
await self.repo.update_fields(instance, **data)
|
||||
else:
|
||||
instance = LLMConfig(user_id=user_id, **data)
|
||||
await self.repo.add(instance)
|
||||
await self.session.commit()
|
||||
return LLMConfigRead.model_validate(instance)
|
||||
|
||||
async def get_config(self, user_id: int) -> Optional[LLMConfigRead]:
|
||||
instance = await self.repo.get_by_user(user_id)
|
||||
return LLMConfigRead.model_validate(instance) if instance else None
|
||||
|
||||
async def delete_config(self, user_id: int) -> bool:
|
||||
instance = await self.repo.get_by_user(user_id)
|
||||
if not instance:
|
||||
return False
|
||||
await self.repo.delete(instance)
|
||||
await self.session.commit()
|
||||
return True
|
||||
306
backend/app/services/llm_service.py
Normal file
306
backend/app/services/llm_service.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, status
|
||||
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, InternalServerError
|
||||
|
||||
from ..core.config import settings
|
||||
from ..repositories.llm_config_repository import LLMConfigRepository
|
||||
from ..repositories.system_config_repository import SystemConfigRepository
|
||||
from ..repositories.user_repository import UserRepository
|
||||
from ..services.admin_setting_service import AdminSettingService
|
||||
from ..services.prompt_service import PromptService
|
||||
from ..services.usage_service import UsageService
|
||||
from ..utils.llm_tool import ChatMessage, LLMClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try: # pragma: no cover - 运行环境未安装时兼容
|
||||
from ollama import AsyncClient as OllamaAsyncClient
|
||||
except ImportError: # pragma: no cover - Ollama 为可选依赖
|
||||
OllamaAsyncClient = None
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""封装与大模型交互的所有逻辑,包括配额控制与配置选择。"""
|
||||
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
self.llm_repo = LLMConfigRepository(session)
|
||||
self.system_config_repo = SystemConfigRepository(session)
|
||||
self.user_repo = UserRepository(session)
|
||||
self.admin_setting_service = AdminSettingService(session)
|
||||
self.usage_service = UsageService(session)
|
||||
self._embedding_dimensions: Dict[str, int] = {}
|
||||
|
||||
async def get_llm_response(
|
||||
self,
|
||||
system_prompt: str,
|
||||
conversation_history: List[Dict[str, str]],
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
user_id: Optional[int] = None,
|
||||
timeout: float = 300.0,
|
||||
response_format: Optional[str] = "json_object",
|
||||
) -> str:
|
||||
messages = [{"role": "system", "content": system_prompt}, *conversation_history]
|
||||
return await self._stream_and_collect(
|
||||
messages,
|
||||
temperature=temperature,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
async def get_summary(
|
||||
self,
|
||||
chapter_content: str,
|
||||
*,
|
||||
temperature: float = 0.2,
|
||||
user_id: Optional[int] = None,
|
||||
timeout: float = 180.0,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> str:
|
||||
if not system_prompt:
|
||||
prompt_service = PromptService(self.session)
|
||||
system_prompt = await prompt_service.get_prompt("extraction")
|
||||
if not system_prompt:
|
||||
raise HTTPException(status_code=500, detail="未配置摘要提示词")
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": chapter_content},
|
||||
]
|
||||
return await self._stream_and_collect(messages, temperature=temperature, user_id=user_id, timeout=timeout)
|
||||
|
||||
async def _stream_and_collect(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
*,
|
||||
temperature: float,
|
||||
user_id: Optional[int],
|
||||
timeout: float,
|
||||
response_format: Optional[str] = None,
|
||||
) -> str:
|
||||
config = await self._resolve_llm_config(user_id)
|
||||
client = LLMClient(api_key=config["api_key"], base_url=config.get("base_url"))
|
||||
|
||||
chat_messages = [ChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
|
||||
|
||||
full_response = ""
|
||||
finish_reason = None
|
||||
|
||||
logger.info(
|
||||
"Streaming LLM response: model=%s user_id=%s messages=%d",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
len(messages),
|
||||
)
|
||||
|
||||
try:
|
||||
async for part in client.stream_chat(
|
||||
messages=chat_messages,
|
||||
model=config.get("model"),
|
||||
temperature=temperature,
|
||||
timeout=int(timeout),
|
||||
response_format=response_format,
|
||||
):
|
||||
if part.get("content"):
|
||||
full_response += part["content"]
|
||||
if part.get("finish_reason"):
|
||||
finish_reason = part["finish_reason"]
|
||||
except InternalServerError as exc:
|
||||
detail = "AI 服务内部错误,请稍后重试"
|
||||
response = getattr(exc, "response", None)
|
||||
if response is not None:
|
||||
try:
|
||||
payload = response.json()
|
||||
error_data = payload.get("error", {}) if isinstance(payload, dict) else {}
|
||||
detail = error_data.get("message_zh") or error_data.get("message") or detail
|
||||
except Exception:
|
||||
detail = str(exc) or detail
|
||||
else:
|
||||
detail = str(exc) or detail
|
||||
logger.error(
|
||||
"LLM stream internal error: model=%s user_id=%s detail=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
detail,
|
||||
exc_info=exc,
|
||||
)
|
||||
raise HTTPException(status_code=503, detail=detail)
|
||||
except (httpx.RemoteProtocolError, httpx.ReadTimeout, APIConnectionError, APITimeoutError) as exc:
|
||||
if isinstance(exc, httpx.RemoteProtocolError):
|
||||
detail = "AI 服务连接被意外中断,请稍后重试"
|
||||
elif isinstance(exc, (httpx.ReadTimeout, APITimeoutError)):
|
||||
detail = "AI 服务响应超时,请稍后重试"
|
||||
else:
|
||||
detail = "无法连接到 AI 服务,请稍后重试"
|
||||
logger.error(
|
||||
"LLM stream failed: model=%s user_id=%s detail=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
detail,
|
||||
exc_info=exc,
|
||||
)
|
||||
raise HTTPException(status_code=503, detail=detail) from exc
|
||||
|
||||
logger.debug(
|
||||
"LLM response collected: model=%s user_id=%s finish_reason=%s preview=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
finish_reason,
|
||||
full_response[:500],
|
||||
)
|
||||
|
||||
if finish_reason == "length":
|
||||
logger.warning(
|
||||
"LLM response truncated: model=%s user_id=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="AI 响应被截断,请缩短输入或调整参数")
|
||||
|
||||
if not full_response:
|
||||
logger.error(
|
||||
"LLM returned empty response: model=%s user_id=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="AI 未返回有效内容")
|
||||
|
||||
await self.usage_service.increment("api_request_count")
|
||||
logger.info(
|
||||
"LLM response success: model=%s user_id=%s chars=%d",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
len(full_response),
|
||||
)
|
||||
return full_response
|
||||
|
||||
async def _resolve_llm_config(self, user_id: Optional[int]) -> Dict[str, Optional[str]]:
|
||||
if user_id:
|
||||
config = await self.llm_repo.get_by_user(user_id)
|
||||
if config and config.llm_provider_api_key:
|
||||
return {
|
||||
"api_key": config.llm_provider_api_key,
|
||||
"base_url": config.llm_provider_url,
|
||||
"model": config.llm_provider_model,
|
||||
}
|
||||
|
||||
# 检查每日使用次数限制
|
||||
if user_id:
|
||||
await self._enforce_daily_limit(user_id)
|
||||
|
||||
api_key = await self._get_config_value("llm.api_key")
|
||||
base_url = await self._get_config_value("llm.base_url")
|
||||
model = await self._get_config_value("llm.model")
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=500, detail="未配置默认 LLM API Key")
|
||||
|
||||
return {"api_key": api_key, "base_url": base_url, "model": model}
|
||||
|
||||
async def get_embedding(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
user_id: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> List[float]:
|
||||
"""生成文本向量,用于章节 RAG 检索,支持 openai 与 ollama 双提供方。"""
|
||||
provider = settings.embedding_provider
|
||||
target_model = model or (
|
||||
settings.ollama_embedding_model if provider == "ollama" else settings.embedding_model
|
||||
)
|
||||
|
||||
if provider == "ollama":
|
||||
if OllamaAsyncClient is None:
|
||||
logger.error("未安装 ollama 依赖,无法调用本地嵌入模型。")
|
||||
raise HTTPException(status_code=500, detail="缺少 Ollama 依赖,请先安装 ollama 包。")
|
||||
|
||||
base_url_any = settings.ollama_embedding_base_url or settings.embedding_base_url
|
||||
base_url = str(base_url_any) if base_url_any else None
|
||||
client = OllamaAsyncClient(host=base_url)
|
||||
try:
|
||||
response = await client.embeddings(model=target_model, prompt=text)
|
||||
except Exception as exc: # pragma: no cover - 本地服务调用失败
|
||||
logger.warning(
|
||||
"Ollama 嵌入请求失败: model=%s error=%s",
|
||||
target_model,
|
||||
exc,
|
||||
)
|
||||
return []
|
||||
embedding: Optional[List[float]]
|
||||
if isinstance(response, dict):
|
||||
embedding = response.get("embedding")
|
||||
else:
|
||||
embedding = getattr(response, "embedding", None)
|
||||
if not embedding:
|
||||
logger.warning("Ollama 返回空向量: model=%s", target_model)
|
||||
return []
|
||||
if not isinstance(embedding, list):
|
||||
embedding = list(embedding)
|
||||
else:
|
||||
config = await self._resolve_llm_config(user_id)
|
||||
api_key = settings.embedding_api_key or config["api_key"]
|
||||
base_url_setting = settings.embedding_base_url or config.get("base_url")
|
||||
base_url = str(base_url_setting) if base_url_setting else None
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
try:
|
||||
response = await client.embeddings.create(
|
||||
input=text,
|
||||
model=target_model,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 网络或鉴权失败
|
||||
logger.warning(
|
||||
"OpenAI 嵌入请求失败: model=%s user_id=%s error=%s",
|
||||
target_model,
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
return []
|
||||
if not response.data:
|
||||
logger.warning("OpenAI 嵌入请求返回空数据: model=%s user_id=%s", target_model, user_id)
|
||||
return []
|
||||
embedding = response.data[0].embedding
|
||||
|
||||
if not isinstance(embedding, list):
|
||||
embedding = list(embedding)
|
||||
|
||||
dimension = len(embedding)
|
||||
if not dimension and settings.embedding_model_vector_size:
|
||||
dimension = settings.embedding_model_vector_size
|
||||
if dimension:
|
||||
self._embedding_dimensions[target_model] = dimension
|
||||
return embedding
|
||||
|
||||
def get_embedding_dimension(self, model: Optional[str] = None) -> Optional[int]:
|
||||
"""获取嵌入向量维度,优先返回缓存结果,其次读取配置。"""
|
||||
target_model = model or (
|
||||
settings.ollama_embedding_model if settings.embedding_provider == "ollama" else settings.embedding_model
|
||||
)
|
||||
if target_model in self._embedding_dimensions:
|
||||
return self._embedding_dimensions[target_model]
|
||||
return settings.embedding_model_vector_size
|
||||
|
||||
async def _enforce_daily_limit(self, user_id: int) -> None:
|
||||
limit_str = await self.admin_setting_service.get("daily_request_limit", "100")
|
||||
limit = int(limit_str or 10)
|
||||
used = await self.user_repo.get_daily_request(user_id)
|
||||
if used >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="今日请求次数已达上限,请明日再试或设置自定义 API Key。",
|
||||
)
|
||||
await self.user_repo.increment_daily_request(user_id)
|
||||
await self.session.commit()
|
||||
|
||||
async def _get_config_value(self, key: str) -> Optional[str]:
|
||||
record = await self.system_config_repo.get_by_key(key)
|
||||
if record:
|
||||
return record.value
|
||||
# 兼容环境变量,首次迁移时无需立即写入数据库
|
||||
env_key = key.upper().replace(".", "_")
|
||||
return os.getenv(env_key)
|
||||
700
backend/app/services/novel_service.py
Normal file
700
backend/app/services/novel_service.py
Normal file
@@ -0,0 +1,700 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
_PREFERRED_CONTENT_KEYS: tuple[str, ...] = (
|
||||
"content",
|
||||
"chapter_content",
|
||||
"chapter_text",
|
||||
"full_content",
|
||||
"text",
|
||||
"body",
|
||||
"story",
|
||||
"chapter",
|
||||
"real_summary",
|
||||
"summary",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_version_content(raw_content: Any, metadata: Any) -> str:
|
||||
text = _coerce_text(metadata)
|
||||
if not text:
|
||||
text = _coerce_text(raw_content)
|
||||
return text or ""
|
||||
|
||||
|
||||
def _coerce_text(value: Any) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return _clean_string(value)
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
if isinstance(value, dict):
|
||||
for key in _PREFERRED_CONTENT_KEYS:
|
||||
if key in value and value[key]:
|
||||
nested = _coerce_text(value[key])
|
||||
if nested:
|
||||
return nested
|
||||
return _clean_string(json.dumps(value, ensure_ascii=False))
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
parts = [text for text in (_coerce_text(item) for item in value) if text]
|
||||
if parts:
|
||||
return "\n".join(parts)
|
||||
return None
|
||||
return _clean_string(str(value))
|
||||
|
||||
|
||||
def _clean_string(text: str) -> str:
|
||||
stripped = text.strip()
|
||||
if not stripped:
|
||||
return stripped
|
||||
if stripped.startswith("{") and stripped.endswith("}"):
|
||||
try:
|
||||
parsed = json.loads(stripped)
|
||||
coerced = _coerce_text(parsed)
|
||||
if coerced:
|
||||
return coerced
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if stripped.startswith('"') and stripped.endswith('"') and len(stripped) >= 2:
|
||||
stripped = stripped[1:-1]
|
||||
return (
|
||||
stripped.replace("\\n", "\n")
|
||||
.replace("\\t", "\t")
|
||||
.replace('\\"', '"')
|
||||
.replace("\\\\", "\\")
|
||||
)
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import (
|
||||
BlueprintCharacter,
|
||||
BlueprintRelationship,
|
||||
Chapter,
|
||||
ChapterEvaluation,
|
||||
ChapterOutline,
|
||||
ChapterVersion,
|
||||
NovelBlueprint,
|
||||
NovelConversation,
|
||||
NovelProject,
|
||||
)
|
||||
from ..repositories.novel_repository import NovelRepository
|
||||
from ..schemas.admin import AdminNovelSummary
|
||||
from ..schemas.novel import (
|
||||
Blueprint,
|
||||
Chapter as ChapterSchema,
|
||||
ChapterGenerationStatus,
|
||||
ChapterOutline as ChapterOutlineSchema,
|
||||
NovelProject as NovelProjectSchema,
|
||||
NovelProjectSummary,
|
||||
NovelSectionResponse,
|
||||
NovelSectionType,
|
||||
)
|
||||
|
||||
|
||||
class NovelService:
|
||||
"""小说项目服务,基于拆表后的结构提供聚合与业务操作。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = NovelRepository(session)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 项目与摘要
|
||||
# ------------------------------------------------------------------
|
||||
async def create_project(self, user_id: int, title: str, initial_prompt: str) -> NovelProject:
|
||||
project = NovelProject(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
initial_prompt=initial_prompt,
|
||||
)
|
||||
blueprint = NovelBlueprint(project=project)
|
||||
self.session.add_all([project, blueprint])
|
||||
await self.session.commit()
|
||||
await self.session.refresh(project)
|
||||
return project
|
||||
|
||||
async def ensure_project_owner(self, project_id: str, user_id: int) -> NovelProject:
|
||||
project = await self.repo.get_by_id(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
|
||||
if project.user_id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问该项目")
|
||||
return project
|
||||
|
||||
async def get_project_schema(self, project_id: str, user_id: int) -> NovelProjectSchema:
|
||||
project = await self.ensure_project_owner(project_id, user_id)
|
||||
return await self._serialize_project(project)
|
||||
|
||||
async def get_section_data(
|
||||
self,
|
||||
project_id: str,
|
||||
user_id: int,
|
||||
section: NovelSectionType,
|
||||
) -> NovelSectionResponse:
|
||||
project = await self.ensure_project_owner(project_id, user_id)
|
||||
return self._build_section_response(project, section)
|
||||
|
||||
async def get_chapter_schema(
|
||||
self,
|
||||
project_id: str,
|
||||
user_id: int,
|
||||
chapter_number: int,
|
||||
) -> ChapterSchema:
|
||||
project = await self.ensure_project_owner(project_id, user_id)
|
||||
return self._build_chapter_schema(project, chapter_number)
|
||||
|
||||
async def list_projects_for_user(self, user_id: int) -> List[NovelProjectSummary]:
|
||||
projects = await self.repo.list_by_user(user_id)
|
||||
summaries: List[NovelProjectSummary] = []
|
||||
for project in projects:
|
||||
blueprint = project.blueprint
|
||||
genre = blueprint.genre if blueprint and blueprint.genre else "未知"
|
||||
outlines = project.outlines
|
||||
chapters = project.chapters
|
||||
total = len(outlines) or len(chapters)
|
||||
completed = sum(1 for chapter in chapters if chapter.selected_version_id)
|
||||
summaries.append(
|
||||
NovelProjectSummary(
|
||||
id=project.id,
|
||||
title=project.title,
|
||||
genre=genre,
|
||||
last_edited=project.updated_at.isoformat() if project.updated_at else "未知",
|
||||
completed_chapters=completed,
|
||||
total_chapters=total,
|
||||
)
|
||||
)
|
||||
return summaries
|
||||
|
||||
async def list_projects_for_admin(self) -> List[AdminNovelSummary]:
|
||||
projects = await self.repo.list_all()
|
||||
summaries: List[AdminNovelSummary] = []
|
||||
for project in projects:
|
||||
blueprint = project.blueprint
|
||||
genre = blueprint.genre if blueprint and blueprint.genre else "未知"
|
||||
outlines = project.outlines
|
||||
chapters = project.chapters
|
||||
total = len(outlines) or len(chapters)
|
||||
completed = sum(1 for chapter in chapters if chapter.selected_version_id)
|
||||
owner = project.owner
|
||||
summaries.append(
|
||||
AdminNovelSummary(
|
||||
id=project.id,
|
||||
title=project.title,
|
||||
owner_id=owner.id if owner else 0,
|
||||
owner_username=owner.username if owner else "未知",
|
||||
genre=genre,
|
||||
last_edited=project.updated_at.isoformat() if project.updated_at else "",
|
||||
completed_chapters=completed,
|
||||
total_chapters=total,
|
||||
)
|
||||
)
|
||||
return summaries
|
||||
|
||||
async def delete_projects(self, project_ids: List[str], user_id: int) -> None:
|
||||
for pid in project_ids:
|
||||
project = await self.ensure_project_owner(pid, user_id)
|
||||
await self.repo.delete(project)
|
||||
await self.session.commit()
|
||||
|
||||
async def count_projects(self) -> int:
|
||||
result = await self.session.execute(select(func.count(NovelProject.id)))
|
||||
return result.scalar_one()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 对话管理
|
||||
# ------------------------------------------------------------------
|
||||
async def list_conversations(self, project_id: str) -> List[NovelConversation]:
|
||||
stmt = (
|
||||
select(NovelConversation)
|
||||
.where(NovelConversation.project_id == project_id)
|
||||
.order_by(NovelConversation.seq.asc())
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars())
|
||||
|
||||
async def append_conversation(self, project_id: str, role: str, content: str, metadata: Optional[Dict] = None) -> None:
|
||||
result = await self.session.execute(
|
||||
select(func.max(NovelConversation.seq)).where(NovelConversation.project_id == project_id)
|
||||
)
|
||||
current_max = result.scalar()
|
||||
next_seq = (current_max or 0) + 1
|
||||
convo = NovelConversation(
|
||||
project_id=project_id,
|
||||
seq=next_seq,
|
||||
role=role,
|
||||
content=content,
|
||||
metadata=metadata,
|
||||
)
|
||||
self.session.add(convo)
|
||||
await self.session.commit()
|
||||
await self._touch_project(project_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 蓝图管理
|
||||
# ------------------------------------------------------------------
|
||||
async def replace_blueprint(self, project_id: str, blueprint: Blueprint) -> None:
|
||||
record = await self.session.get(NovelBlueprint, project_id)
|
||||
if not record:
|
||||
record = NovelBlueprint(project_id=project_id)
|
||||
self.session.add(record)
|
||||
record.title = blueprint.title
|
||||
record.target_audience = blueprint.target_audience
|
||||
record.genre = blueprint.genre
|
||||
record.style = blueprint.style
|
||||
record.tone = blueprint.tone
|
||||
record.one_sentence_summary = blueprint.one_sentence_summary
|
||||
record.full_synopsis = blueprint.full_synopsis
|
||||
record.world_setting = blueprint.world_setting
|
||||
|
||||
await self.session.execute(delete(BlueprintCharacter).where(BlueprintCharacter.project_id == project_id))
|
||||
for index, data in enumerate(blueprint.characters):
|
||||
self.session.add(
|
||||
BlueprintCharacter(
|
||||
project_id=project_id,
|
||||
name=data.get("name", ""),
|
||||
identity=data.get("identity"),
|
||||
personality=data.get("personality"),
|
||||
goals=data.get("goals"),
|
||||
abilities=data.get("abilities"),
|
||||
relationship_to_protagonist=data.get("relationship_to_protagonist"),
|
||||
extra={k: v for k, v in data.items() if k not in {
|
||||
"name",
|
||||
"identity",
|
||||
"personality",
|
||||
"goals",
|
||||
"abilities",
|
||||
"relationship_to_protagonist",
|
||||
}},
|
||||
position=index,
|
||||
)
|
||||
)
|
||||
|
||||
await self.session.execute(delete(BlueprintRelationship).where(BlueprintRelationship.project_id == project_id))
|
||||
for index, relation in enumerate(blueprint.relationships):
|
||||
self.session.add(
|
||||
BlueprintRelationship(
|
||||
project_id=project_id,
|
||||
character_from=relation.character_from,
|
||||
character_to=relation.character_to,
|
||||
description=relation.description,
|
||||
position=index,
|
||||
)
|
||||
)
|
||||
|
||||
await self.session.execute(delete(ChapterOutline).where(ChapterOutline.project_id == project_id))
|
||||
for outline in blueprint.chapter_outline:
|
||||
self.session.add(
|
||||
ChapterOutline(
|
||||
project_id=project_id,
|
||||
chapter_number=outline.chapter_number,
|
||||
title=outline.title,
|
||||
summary=outline.summary,
|
||||
)
|
||||
)
|
||||
|
||||
await self.session.commit()
|
||||
await self._touch_project(project_id)
|
||||
|
||||
async def patch_blueprint(self, project_id: str, patch: Dict) -> None:
|
||||
blueprint = await self.session.get(NovelBlueprint, project_id)
|
||||
if not blueprint:
|
||||
blueprint = NovelBlueprint(project_id=project_id)
|
||||
self.session.add(blueprint)
|
||||
|
||||
if "one_sentence_summary" in patch:
|
||||
blueprint.one_sentence_summary = patch["one_sentence_summary"]
|
||||
if "full_synopsis" in patch:
|
||||
blueprint.full_synopsis = patch["full_synopsis"]
|
||||
if "world_setting" in patch and patch["world_setting"] is not None:
|
||||
existing = blueprint.world_setting or {}
|
||||
existing.update(patch["world_setting"])
|
||||
blueprint.world_setting = existing
|
||||
if "characters" in patch and patch["characters"] is not None:
|
||||
await self.session.execute(delete(BlueprintCharacter).where(BlueprintCharacter.project_id == project_id))
|
||||
for index, data in enumerate(patch["characters"]):
|
||||
self.session.add(
|
||||
BlueprintCharacter(
|
||||
project_id=project_id,
|
||||
name=data.get("name", ""),
|
||||
identity=data.get("identity"),
|
||||
personality=data.get("personality"),
|
||||
goals=data.get("goals"),
|
||||
abilities=data.get("abilities"),
|
||||
relationship_to_protagonist=data.get("relationship_to_protagonist"),
|
||||
extra={k: v for k, v in data.items() if k not in {
|
||||
"name",
|
||||
"identity",
|
||||
"personality",
|
||||
"goals",
|
||||
"abilities",
|
||||
"relationship_to_protagonist",
|
||||
}},
|
||||
position=index,
|
||||
)
|
||||
)
|
||||
if "relationships" in patch and patch["relationships"] is not None:
|
||||
await self.session.execute(delete(BlueprintRelationship).where(BlueprintRelationship.project_id == project_id))
|
||||
for index, relation in enumerate(patch["relationships"]):
|
||||
self.session.add(
|
||||
BlueprintRelationship(
|
||||
project_id=project_id,
|
||||
character_from=relation.get("character_from"),
|
||||
character_to=relation.get("character_to"),
|
||||
description=relation.get("description"),
|
||||
position=index,
|
||||
)
|
||||
)
|
||||
if "chapter_outline" in patch and patch["chapter_outline"] is not None:
|
||||
await self.session.execute(delete(ChapterOutline).where(ChapterOutline.project_id == project_id))
|
||||
for outline in patch["chapter_outline"]:
|
||||
self.session.add(
|
||||
ChapterOutline(
|
||||
project_id=project_id,
|
||||
chapter_number=outline.get("chapter_number"),
|
||||
title=outline.get("title", ""),
|
||||
summary=outline.get("summary"),
|
||||
)
|
||||
)
|
||||
await self.session.commit()
|
||||
await self._touch_project(project_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 章节与版本
|
||||
# ------------------------------------------------------------------
|
||||
async def get_outline(self, project_id: str, chapter_number: int) -> Optional[ChapterOutline]:
|
||||
stmt = (
|
||||
select(ChapterOutline)
|
||||
.where(
|
||||
ChapterOutline.project_id == project_id,
|
||||
ChapterOutline.chapter_number == chapter_number,
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_or_create_chapter(self, project_id: str, chapter_number: int) -> Chapter:
|
||||
stmt = (
|
||||
select(Chapter)
|
||||
.where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.chapter_number == chapter_number,
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
chapter = result.scalars().first()
|
||||
if chapter:
|
||||
return chapter
|
||||
chapter = Chapter(project_id=project_id, chapter_number=chapter_number)
|
||||
self.session.add(chapter)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(chapter)
|
||||
return chapter
|
||||
|
||||
async def replace_chapter_versions(self, chapter: Chapter, contents: List[str], metadata: Optional[List[Dict]] = None) -> List[ChapterVersion]:
|
||||
await self.session.execute(delete(ChapterVersion).where(ChapterVersion.chapter_id == chapter.id))
|
||||
versions: List[ChapterVersion] = []
|
||||
for index, content in enumerate(contents):
|
||||
extra = metadata[index] if metadata and index < len(metadata) else None
|
||||
text_content = _normalize_version_content(content, extra)
|
||||
version = ChapterVersion(
|
||||
chapter_id=chapter.id,
|
||||
content=text_content,
|
||||
metadata=None,
|
||||
version_label=f"v{index+1}",
|
||||
)
|
||||
self.session.add(version)
|
||||
versions.append(version)
|
||||
chapter.status = ChapterGenerationStatus.WAITING_FOR_CONFIRM.value
|
||||
await self.session.commit()
|
||||
await self.session.refresh(chapter)
|
||||
await self._touch_project(chapter.project_id)
|
||||
return versions
|
||||
|
||||
async def select_chapter_version(self, chapter: Chapter, version_index: int) -> ChapterVersion:
|
||||
versions = sorted(chapter.versions, key=lambda item: item.created_at)
|
||||
if not versions or version_index < 0 or version_index >= len(versions):
|
||||
raise HTTPException(status_code=400, detail="版本索引无效")
|
||||
selected = versions[version_index]
|
||||
chapter.selected_version_id = selected.id
|
||||
chapter.status = ChapterGenerationStatus.SUCCESSFUL.value
|
||||
chapter.word_count = len(selected.content or "")
|
||||
await self.session.commit()
|
||||
await self.session.refresh(chapter)
|
||||
await self._touch_project(chapter.project_id)
|
||||
return selected
|
||||
|
||||
async def add_chapter_evaluation(self, chapter: Chapter, version: Optional[ChapterVersion], feedback: str, decision: Optional[str] = None) -> None:
|
||||
evaluation = ChapterEvaluation(
|
||||
chapter_id=chapter.id,
|
||||
version_id=version.id if version else None,
|
||||
feedback=feedback,
|
||||
decision=decision,
|
||||
)
|
||||
self.session.add(evaluation)
|
||||
chapter.status = ChapterGenerationStatus.WAITING_FOR_CONFIRM.value
|
||||
await self.session.commit()
|
||||
await self.session.refresh(chapter)
|
||||
await self._touch_project(chapter.project_id)
|
||||
|
||||
async def delete_chapters(self, project_id: str, chapter_numbers: Iterable[int]) -> None:
|
||||
await self.session.execute(
|
||||
delete(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.chapter_number.in_(list(chapter_numbers)),
|
||||
)
|
||||
)
|
||||
await self.session.execute(
|
||||
delete(ChapterOutline).where(
|
||||
ChapterOutline.project_id == project_id,
|
||||
ChapterOutline.chapter_number.in_(list(chapter_numbers)),
|
||||
)
|
||||
)
|
||||
await self.session.commit()
|
||||
await self._touch_project(project_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 序列化辅助
|
||||
# ------------------------------------------------------------------
|
||||
async def get_project_schema_for_admin(self, project_id: str) -> NovelProjectSchema:
|
||||
project = await self.repo.get_by_id(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
|
||||
return await self._serialize_project(project)
|
||||
|
||||
async def get_section_data_for_admin(
|
||||
self,
|
||||
project_id: str,
|
||||
section: NovelSectionType,
|
||||
) -> NovelSectionResponse:
|
||||
project = await self.repo.get_by_id(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
|
||||
return self._build_section_response(project, section)
|
||||
|
||||
async def get_chapter_schema_for_admin(
|
||||
self,
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
) -> ChapterSchema:
|
||||
project = await self.repo.get_by_id(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
|
||||
return self._build_chapter_schema(project, chapter_number)
|
||||
|
||||
async def _serialize_project(self, project: NovelProject) -> NovelProjectSchema:
|
||||
conversations = [
|
||||
{"role": convo.role, "content": convo.content}
|
||||
for convo in sorted(project.conversations, key=lambda c: c.seq)
|
||||
]
|
||||
|
||||
blueprint_schema = self._build_blueprint_schema(project)
|
||||
|
||||
outlines_map = {outline.chapter_number: outline for outline in project.outlines}
|
||||
chapters_map = {chapter.chapter_number: chapter for chapter in project.chapters}
|
||||
chapter_numbers = sorted(set(outlines_map.keys()) | set(chapters_map.keys()))
|
||||
chapters_schema: List[ChapterSchema] = [
|
||||
self._build_chapter_schema(
|
||||
project,
|
||||
number,
|
||||
outlines_map=outlines_map,
|
||||
chapters_map=chapters_map,
|
||||
)
|
||||
for number in chapter_numbers
|
||||
]
|
||||
|
||||
return NovelProjectSchema(
|
||||
id=project.id,
|
||||
user_id=project.user_id,
|
||||
title=project.title,
|
||||
initial_prompt=project.initial_prompt or "",
|
||||
conversation_history=conversations,
|
||||
blueprint=blueprint_schema,
|
||||
chapters=chapters_schema,
|
||||
)
|
||||
|
||||
async def _touch_project(self, project_id: str) -> None:
|
||||
await self.session.execute(
|
||||
update(NovelProject)
|
||||
.where(NovelProject.id == project_id)
|
||||
.values(updated_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await self.session.commit()
|
||||
|
||||
def _build_blueprint_schema(self, project: NovelProject) -> Blueprint:
|
||||
blueprint_obj = project.blueprint
|
||||
if blueprint_obj:
|
||||
return Blueprint(
|
||||
title=blueprint_obj.title or "",
|
||||
target_audience=blueprint_obj.target_audience or "",
|
||||
genre=blueprint_obj.genre or "",
|
||||
style=blueprint_obj.style or "",
|
||||
tone=blueprint_obj.tone or "",
|
||||
one_sentence_summary=blueprint_obj.one_sentence_summary or "",
|
||||
full_synopsis=blueprint_obj.full_synopsis or "",
|
||||
world_setting=blueprint_obj.world_setting or {},
|
||||
characters=[
|
||||
{
|
||||
"name": character.name,
|
||||
"identity": character.identity,
|
||||
"personality": character.personality,
|
||||
"goals": character.goals,
|
||||
"abilities": character.abilities,
|
||||
"relationship_to_protagonist": character.relationship_to_protagonist,
|
||||
**(character.extra or {}),
|
||||
}
|
||||
for character in sorted(project.characters, key=lambda c: c.position)
|
||||
],
|
||||
relationships=[
|
||||
{
|
||||
"character_from": relation.character_from,
|
||||
"character_to": relation.character_to,
|
||||
"description": relation.description or "",
|
||||
"relationship_type": getattr(relation, "relationship_type", None),
|
||||
}
|
||||
for relation in sorted(project.relationships_, key=lambda r: r.position)
|
||||
],
|
||||
chapter_outline=[
|
||||
ChapterOutlineSchema(
|
||||
chapter_number=outline.chapter_number,
|
||||
title=outline.title,
|
||||
summary=outline.summary or "",
|
||||
)
|
||||
for outline in sorted(project.outlines, key=lambda o: o.chapter_number)
|
||||
],
|
||||
)
|
||||
return Blueprint(
|
||||
title="",
|
||||
target_audience="",
|
||||
genre="",
|
||||
style="",
|
||||
tone="",
|
||||
one_sentence_summary="",
|
||||
full_synopsis="",
|
||||
world_setting={},
|
||||
characters=[],
|
||||
relationships=[],
|
||||
chapter_outline=[],
|
||||
)
|
||||
|
||||
def _build_section_response(
|
||||
self,
|
||||
project: NovelProject,
|
||||
section: NovelSectionType,
|
||||
) -> NovelSectionResponse:
|
||||
blueprint = self._build_blueprint_schema(project)
|
||||
|
||||
if section == NovelSectionType.OVERVIEW:
|
||||
data = {
|
||||
"title": project.title,
|
||||
"initial_prompt": project.initial_prompt or "",
|
||||
"status": project.status,
|
||||
"one_sentence_summary": blueprint.one_sentence_summary,
|
||||
"target_audience": blueprint.target_audience,
|
||||
"genre": blueprint.genre,
|
||||
"style": blueprint.style,
|
||||
"tone": blueprint.tone,
|
||||
"full_synopsis": blueprint.full_synopsis,
|
||||
"updated_at": project.updated_at.isoformat() if project.updated_at else None,
|
||||
}
|
||||
elif section == NovelSectionType.WORLD_SETTING:
|
||||
data = {
|
||||
"world_setting": blueprint.world_setting or {},
|
||||
}
|
||||
elif section == NovelSectionType.CHARACTERS:
|
||||
data = {
|
||||
"characters": blueprint.characters,
|
||||
}
|
||||
elif section == NovelSectionType.RELATIONSHIPS:
|
||||
data = {
|
||||
"relationships": blueprint.relationships,
|
||||
}
|
||||
elif section == NovelSectionType.CHAPTER_OUTLINE:
|
||||
data = {
|
||||
"chapter_outline": [outline.model_dump() for outline in blueprint.chapter_outline],
|
||||
}
|
||||
elif section == NovelSectionType.CHAPTERS:
|
||||
outlines_map = {outline.chapter_number: outline for outline in project.outlines}
|
||||
chapters_map = {chapter.chapter_number: chapter for chapter in project.chapters}
|
||||
chapter_numbers = sorted(set(outlines_map.keys()) | set(chapters_map.keys()))
|
||||
# 章节列表只返回元数据,不包含完整内容
|
||||
chapters = [
|
||||
self._build_chapter_schema(
|
||||
project,
|
||||
number,
|
||||
outlines_map=outlines_map,
|
||||
chapters_map=chapters_map,
|
||||
include_content=False,
|
||||
).model_dump()
|
||||
for number in chapter_numbers
|
||||
]
|
||||
data = {
|
||||
"chapters": chapters,
|
||||
"total": len(chapters),
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="未知的章节类型")
|
||||
|
||||
return NovelSectionResponse(section=section, data=data)
|
||||
|
||||
def _build_chapter_schema(
|
||||
self,
|
||||
project: NovelProject,
|
||||
chapter_number: int,
|
||||
*,
|
||||
outlines_map: Optional[Dict[int, ChapterOutline]] = None,
|
||||
chapters_map: Optional[Dict[int, Chapter]] = None,
|
||||
include_content: bool = True,
|
||||
) -> ChapterSchema:
|
||||
outlines = outlines_map or {outline.chapter_number: outline for outline in project.outlines}
|
||||
chapters = chapters_map or {chapter.chapter_number: chapter for chapter in project.chapters}
|
||||
outline = outlines.get(chapter_number)
|
||||
chapter = chapters.get(chapter_number)
|
||||
|
||||
if not outline and not chapter:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="章节不存在")
|
||||
|
||||
title = outline.title if outline else f"第{chapter_number}章"
|
||||
summary = outline.summary if outline else ""
|
||||
real_summary = chapter.real_summary if chapter else None
|
||||
content = None
|
||||
versions: Optional[List[str]] = None
|
||||
evaluation_text: Optional[str] = None
|
||||
status_value = ChapterGenerationStatus.NOT_GENERATED.value
|
||||
word_count = 0
|
||||
|
||||
if chapter:
|
||||
status_value = chapter.status or ChapterGenerationStatus.NOT_GENERATED.value
|
||||
word_count = chapter.word_count or 0
|
||||
|
||||
# 只有在 include_content=True 时才包含完整内容
|
||||
if include_content:
|
||||
if chapter.selected_version:
|
||||
content = chapter.selected_version.content
|
||||
if chapter.versions:
|
||||
versions = [
|
||||
v.content
|
||||
for v in sorted(chapter.versions, key=lambda item: item.created_at)
|
||||
]
|
||||
if chapter.evaluations:
|
||||
latest = sorted(chapter.evaluations, key=lambda item: item.created_at)[-1]
|
||||
evaluation_text = latest.feedback or latest.decision
|
||||
|
||||
return ChapterSchema(
|
||||
chapter_number=chapter_number,
|
||||
title=title,
|
||||
summary=summary,
|
||||
real_summary=real_summary,
|
||||
content=content,
|
||||
versions=versions,
|
||||
evaluation=evaluation_text,
|
||||
generation_status=ChapterGenerationStatus(status_value),
|
||||
word_count=word_count,
|
||||
)
|
||||
96
backend/app/services/prompt_service.py
Normal file
96
backend/app/services/prompt_service.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import Prompt
|
||||
from ..repositories.prompt_repository import PromptRepository
|
||||
from ..schemas.prompt import PromptCreate, PromptRead, PromptUpdate
|
||||
|
||||
_CACHE: Dict[str, PromptRead] = {}
|
||||
_LOCK = asyncio.Lock()
|
||||
_LOADED = False
|
||||
|
||||
|
||||
class PromptService:
|
||||
"""提示词服务,提供缓存加速与 CRUD 能力。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = PromptRepository(session)
|
||||
|
||||
async def preload(self) -> None:
|
||||
global _CACHE, _LOADED
|
||||
prompts = await self.repo.list_all()
|
||||
async with _LOCK:
|
||||
_CACHE = {item.name: PromptRead.model_validate(item) for item in prompts}
|
||||
_LOADED = True
|
||||
|
||||
async def get_prompt(self, name: str) -> Optional[str]:
|
||||
global _LOADED
|
||||
async with _LOCK:
|
||||
if not _LOADED:
|
||||
prompts = await self.repo.list_all()
|
||||
_CACHE.update({item.name: PromptRead.model_validate(item) for item in prompts})
|
||||
_LOADED = True
|
||||
cached = _CACHE.get(name)
|
||||
if cached:
|
||||
return cached.content
|
||||
|
||||
prompt = await self.repo.get_by_name(name)
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
prompt_read = PromptRead.model_validate(prompt)
|
||||
async with _LOCK:
|
||||
_CACHE[name] = prompt_read
|
||||
return prompt_read.content
|
||||
|
||||
async def list_prompts(self) -> list[PromptRead]:
|
||||
prompts = await self.repo.list_all()
|
||||
return [PromptRead.model_validate(item) for item in prompts]
|
||||
|
||||
async def get_prompt_by_id(self, prompt_id: int) -> Optional[PromptRead]:
|
||||
instance = await self.repo.get(id=prompt_id)
|
||||
if not instance:
|
||||
return None
|
||||
return PromptRead.model_validate(instance)
|
||||
|
||||
async def create_prompt(self, payload: PromptCreate) -> PromptRead:
|
||||
data = payload.model_dump()
|
||||
tags = data.get("tags")
|
||||
if tags is not None:
|
||||
data["tags"] = ",".join(tags)
|
||||
prompt = Prompt(**data)
|
||||
await self.repo.add(prompt)
|
||||
await self.session.commit()
|
||||
prompt_read = PromptRead.model_validate(prompt)
|
||||
async with _LOCK:
|
||||
_CACHE[prompt_read.name] = prompt_read
|
||||
global _LOADED
|
||||
_LOADED = True
|
||||
return prompt_read
|
||||
|
||||
async def update_prompt(self, prompt_id: int, payload: PromptUpdate) -> Optional[PromptRead]:
|
||||
instance = await self.repo.get(id=prompt_id)
|
||||
if not instance:
|
||||
return None
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
if "tags" in update_data and update_data["tags"] is not None:
|
||||
update_data["tags"] = ",".join(update_data["tags"])
|
||||
await self.repo.update_fields(instance, **update_data)
|
||||
await self.session.commit()
|
||||
prompt_read = PromptRead.model_validate(instance)
|
||||
async with _LOCK:
|
||||
_CACHE[prompt_read.name] = prompt_read
|
||||
return prompt_read
|
||||
|
||||
async def delete_prompt(self, prompt_id: int) -> bool:
|
||||
instance = await self.repo.get(id=prompt_id)
|
||||
if not instance:
|
||||
return False
|
||||
await self.repo.delete(instance)
|
||||
await self.session.commit()
|
||||
async with _LOCK:
|
||||
_CACHE.pop(instance.name, None)
|
||||
return True
|
||||
60
backend/app/services/update_log_service.py
Normal file
60
backend/app/services/update_log_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import UpdateLog
|
||||
from ..repositories.update_log_repository import UpdateLogRepository
|
||||
|
||||
|
||||
class UpdateLogService:
|
||||
"""更新日志服务,提供增删改查能力,并保证置顶唯一。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = UpdateLogRepository(session)
|
||||
|
||||
async def list_logs(self, limit: Optional[int] = None) -> List[UpdateLog]:
|
||||
if limit is None:
|
||||
return list(await self.repo.list())
|
||||
return list(await self.repo.list_latest(limit))
|
||||
|
||||
async def create_log(self, content: str, creator: str | None = None, *, is_pinned: bool = False) -> UpdateLog:
|
||||
if is_pinned:
|
||||
await self._clear_pinned()
|
||||
log = UpdateLog(content=content, created_by=creator, is_pinned=is_pinned)
|
||||
await self.repo.add(log)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log)
|
||||
return log
|
||||
|
||||
async def update_log(self, log_id: int, *, content: Optional[str] = None, is_pinned: Optional[bool] = None) -> UpdateLog:
|
||||
log = await self.repo.get(id=log_id)
|
||||
if not log:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="更新记录不存在")
|
||||
|
||||
updates = {}
|
||||
if content is not None:
|
||||
updates["content"] = content
|
||||
if is_pinned is not None:
|
||||
if is_pinned:
|
||||
await self._clear_pinned()
|
||||
updates["is_pinned"] = is_pinned
|
||||
|
||||
if updates:
|
||||
await self.repo.update_fields(log, **updates)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log)
|
||||
|
||||
return log
|
||||
|
||||
async def delete_log(self, log_id: int) -> None:
|
||||
log = await self.repo.get(id=log_id)
|
||||
if not log:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="更新记录不存在")
|
||||
await self.repo.delete(log)
|
||||
await self.session.commit()
|
||||
|
||||
async def _clear_pinned(self) -> None:
|
||||
await self.session.execute(update(UpdateLog).values(is_pinned=False))
|
||||
21
backend/app/services/usage_service.py
Normal file
21
backend/app/services/usage_service.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..repositories.usage_metric_repository import UsageMetricRepository
|
||||
|
||||
|
||||
class UsageService:
|
||||
"""通用计数服务,目前用于统计 API 请求次数等。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = UsageMetricRepository(session)
|
||||
|
||||
async def increment(self, key: str) -> None:
|
||||
counter = await self.repo.get_or_create(key)
|
||||
counter.value += 1
|
||||
await self.session.commit()
|
||||
|
||||
async def get_value(self, key: str) -> int:
|
||||
counter = await self.repo.get_or_create(key)
|
||||
await self.session.commit()
|
||||
return counter.value
|
||||
62
backend/app/services/user_service.py
Normal file
62
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..core.security import hash_password
|
||||
from ..models import User
|
||||
from ..repositories.user_repository import UserRepository
|
||||
from ..schemas.user import UserCreate, UserInDB
|
||||
|
||||
|
||||
class UserService:
|
||||
"""用户领域服务,负责注册、查询与配额统计。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = UserRepository(session)
|
||||
|
||||
async def create_user(self, payload: UserCreate, *, external_id: str | None = None) -> UserInDB:
|
||||
hashed_password = hash_password(payload.password)
|
||||
user = User(
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
hashed_password=hashed_password,
|
||||
external_id=external_id,
|
||||
)
|
||||
|
||||
self.session.add(user)
|
||||
try:
|
||||
await self.session.commit()
|
||||
except IntegrityError as exc:
|
||||
await self.session.rollback()
|
||||
raise ValueError("用户名或邮箱已存在") from exc
|
||||
|
||||
return UserInDB.model_validate(user)
|
||||
|
||||
async def get_by_username(self, username: str) -> Optional[UserInDB]:
|
||||
user = await self.repo.get_by_username(username)
|
||||
return UserInDB.model_validate(user) if user else None
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[UserInDB]:
|
||||
user = await self.repo.get_by_email(email)
|
||||
return UserInDB.model_validate(user) if user else None
|
||||
|
||||
async def get_by_external_id(self, external_id: str) -> Optional[UserInDB]:
|
||||
user = await self.repo.get_by_external_id(external_id)
|
||||
return UserInDB.model_validate(user) if user else None
|
||||
|
||||
async def get_user(self, user_id: int) -> Optional[UserInDB]:
|
||||
user = await self.repo.get(id=user_id)
|
||||
return UserInDB.model_validate(user) if user else None
|
||||
|
||||
async def list_users(self) -> list[UserInDB]:
|
||||
users = await self.repo.list_all()
|
||||
return [UserInDB.model_validate(item) for item in users]
|
||||
|
||||
async def increment_daily_request(self, user_id: int) -> None:
|
||||
await self.repo.increment_daily_request(user_id)
|
||||
await self.session.commit()
|
||||
|
||||
async def get_daily_request(self, user_id: int) -> int:
|
||||
return await self.repo.get_daily_request(user_id)
|
||||
544
backend/app/services/vector_store_service.py
Normal file
544
backend/app/services/vector_store_service.py
Normal file
@@ -0,0 +1,544 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
基于 libsql 的向量检索服务,封装章节内容的存储与查询。
|
||||
|
||||
本文件中的注释均使用中文,便于团队成员快速理解 RAG 相关逻辑。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from array import array
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence
|
||||
|
||||
from ..core.config import settings
|
||||
|
||||
try: # noqa: SIM105 - 明确区分依赖缺失的情况
|
||||
import libsql_client
|
||||
except ImportError: # pragma: no cover - 在未安装依赖时提供友好提示
|
||||
libsql_client = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedChunk:
|
||||
"""向量检索得到的剧情片段。"""
|
||||
|
||||
content: str
|
||||
chapter_number: int
|
||||
chapter_title: Optional[str]
|
||||
score: float
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedSummary:
|
||||
"""向量检索得到的章节摘要。"""
|
||||
|
||||
chapter_number: int
|
||||
title: str
|
||||
summary: str
|
||||
score: float
|
||||
|
||||
|
||||
class VectorStoreService:
|
||||
"""libsql 向量库操作工具,确保不同小说项目的数据隔离。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
if not settings.vector_store_enabled:
|
||||
logger.warning("未开启向量库配置,RAG 检索将被跳过。")
|
||||
self._client = None
|
||||
self._schema_ready = True
|
||||
return
|
||||
|
||||
if libsql_client is None: # pragma: no cover - 运行环境缺少依赖
|
||||
raise RuntimeError("缺少 libsql-client 依赖,请先在环境中安装。")
|
||||
|
||||
url = settings.vector_db_url
|
||||
if url and url.startswith("file:"):
|
||||
path_part = url.split("file:", 1)[1]
|
||||
resolved = Path(path_part).expanduser().resolve()
|
||||
resolved.parent.mkdir(parents=True, exist_ok=True)
|
||||
url = f"file:{resolved}"
|
||||
logger.info("向量库使用本地文件: %s", resolved)
|
||||
|
||||
try:
|
||||
logger.info("初始化 libsql 客户端: url=%s", url)
|
||||
self._client = libsql_client.create_client(
|
||||
url=url,
|
||||
auth_token=settings.vector_db_auth_token,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 连接异常仅打印日志
|
||||
logger.error("初始化 libsql 客户端失败: %s", exc)
|
||||
self._client = None
|
||||
self._schema_ready = True
|
||||
else:
|
||||
self._schema_ready = False
|
||||
logger.info("libsql 客户端初始化成功,等待建表。")
|
||||
|
||||
async def ensure_schema(self) -> None:
|
||||
"""初始化向量表结构,保证系统首次运行即可使用。"""
|
||||
if not self._client or self._schema_ready:
|
||||
return
|
||||
|
||||
statements = [
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_chunks (
|
||||
id TEXT PRIMARY KEY,
|
||||
project_id TEXT NOT NULL,
|
||||
chapter_number INTEGER NOT NULL,
|
||||
chunk_index INTEGER NOT NULL,
|
||||
chapter_title TEXT,
|
||||
content TEXT NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at INTEGER DEFAULT (unixepoch())
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_rag_chunks_project
|
||||
ON rag_chunks(project_id, chapter_number)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_summaries (
|
||||
id TEXT PRIMARY KEY,
|
||||
project_id TEXT NOT NULL,
|
||||
chapter_number INTEGER NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
summary TEXT NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
created_at INTEGER DEFAULT (unixepoch())
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_rag_summaries_project
|
||||
ON rag_summaries(project_id, chapter_number)
|
||||
""",
|
||||
]
|
||||
|
||||
try:
|
||||
for sql in statements:
|
||||
await self._client.execute(sql) # type: ignore[union-attr]
|
||||
logger.info("已确保向量库表结构存在。")
|
||||
except Exception as exc: # pragma: no cover - 初始化失败时记录日志
|
||||
logger.error("创建向量库表结构失败: %s", exc)
|
||||
else:
|
||||
self._schema_ready = True
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
embedding: Sequence[float],
|
||||
top_k: Optional[int] = None,
|
||||
) -> List[RetrievedChunk]:
|
||||
"""根据查询向量检索剧情片段,结果已按相似度排序。"""
|
||||
if not self._client or not embedding:
|
||||
return []
|
||||
|
||||
await self.ensure_schema()
|
||||
top_k = top_k or settings.vector_top_k_chunks
|
||||
if top_k <= 0:
|
||||
return []
|
||||
|
||||
blob = self._to_f32_blob(embedding)
|
||||
sql = """
|
||||
SELECT
|
||||
content,
|
||||
chapter_number,
|
||||
chapter_title,
|
||||
COALESCE(metadata, '{}') AS metadata,
|
||||
vector_distance_cosine(embedding, :query) AS distance
|
||||
FROM rag_chunks
|
||||
WHERE project_id = :project_id
|
||||
ORDER BY distance ASC
|
||||
LIMIT :limit
|
||||
"""
|
||||
try:
|
||||
result = await self._client.execute( # type: ignore[union-attr]
|
||||
sql,
|
||||
{
|
||||
"project_id": project_id,
|
||||
"query": blob,
|
||||
"limit": top_k,
|
||||
},
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 查询异常时仅记录
|
||||
if "no such function: vector_distance_cosine" in str(exc).lower():
|
||||
logger.warning("向量库缺少 vector_distance_cosine 函数,回退至应用层相似度计算。")
|
||||
return await self._query_chunks_with_python_similarity(
|
||||
project_id=project_id,
|
||||
embedding=embedding,
|
||||
top_k=top_k,
|
||||
)
|
||||
logger.warning("向量检索剧情片段失败: %s", exc)
|
||||
return []
|
||||
|
||||
items: List[RetrievedChunk] = []
|
||||
for row in self._iter_rows(result):
|
||||
items.append(
|
||||
RetrievedChunk(
|
||||
content=row.get("content", ""),
|
||||
chapter_number=row.get("chapter_number", 0),
|
||||
chapter_title=row.get("chapter_title"),
|
||||
score=row.get("distance", 0.0),
|
||||
metadata=self._parse_metadata(row.get("metadata")),
|
||||
)
|
||||
)
|
||||
return items
|
||||
|
||||
async def query_summaries(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
embedding: Sequence[float],
|
||||
top_k: Optional[int] = None,
|
||||
) -> List[RetrievedSummary]:
|
||||
"""根据查询向量检索章节摘要列表。"""
|
||||
if not self._client or not embedding:
|
||||
return []
|
||||
|
||||
await self.ensure_schema()
|
||||
top_k = top_k or settings.vector_top_k_summaries
|
||||
if top_k <= 0:
|
||||
return []
|
||||
|
||||
blob = self._to_f32_blob(embedding)
|
||||
sql = """
|
||||
SELECT
|
||||
chapter_number,
|
||||
title,
|
||||
summary,
|
||||
vector_distance_cosine(embedding, :query) AS distance
|
||||
FROM rag_summaries
|
||||
WHERE project_id = :project_id
|
||||
ORDER BY distance ASC
|
||||
LIMIT :limit
|
||||
"""
|
||||
try:
|
||||
result = await self._client.execute( # type: ignore[union-attr]
|
||||
sql,
|
||||
{
|
||||
"project_id": project_id,
|
||||
"query": blob,
|
||||
"limit": top_k,
|
||||
},
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 查询异常时仅记录
|
||||
if "no such function: vector_distance_cosine" in str(exc).lower():
|
||||
logger.warning("向量库缺少 vector_distance_cosine 函数,回退至应用层相似度计算。")
|
||||
return await self._query_summaries_with_python_similarity(
|
||||
project_id=project_id,
|
||||
embedding=embedding,
|
||||
top_k=top_k,
|
||||
)
|
||||
logger.warning("向量检索章节摘要失败: %s", exc)
|
||||
return []
|
||||
|
||||
items: List[RetrievedSummary] = []
|
||||
for row in self._iter_rows(result):
|
||||
items.append(
|
||||
RetrievedSummary(
|
||||
chapter_number=row.get("chapter_number", 0),
|
||||
title=row.get("title", ""),
|
||||
summary=row.get("summary", ""),
|
||||
score=row.get("distance", 0.0),
|
||||
)
|
||||
)
|
||||
return items
|
||||
|
||||
async def upsert_chunks(
|
||||
self,
|
||||
*,
|
||||
records: Iterable[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""批量写入章节片段,供后续检索使用。"""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
await self.ensure_schema()
|
||||
sql = """
|
||||
INSERT INTO rag_chunks (
|
||||
id,
|
||||
project_id,
|
||||
chapter_number,
|
||||
chunk_index,
|
||||
chapter_title,
|
||||
content,
|
||||
embedding,
|
||||
metadata
|
||||
) VALUES (
|
||||
:id,
|
||||
:project_id,
|
||||
:chapter_number,
|
||||
:chunk_index,
|
||||
:chapter_title,
|
||||
:content,
|
||||
:embedding,
|
||||
:metadata
|
||||
)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
content=excluded.content,
|
||||
embedding=excluded.embedding,
|
||||
metadata=excluded.metadata,
|
||||
chapter_title=excluded.chapter_title
|
||||
"""
|
||||
payload = []
|
||||
for item in records:
|
||||
embedding = item.get("embedding", [])
|
||||
payload.append(
|
||||
{
|
||||
**item,
|
||||
"embedding": self._to_f32_blob(embedding),
|
||||
"metadata": json.dumps(item.get("metadata") or {}, ensure_ascii=False),
|
||||
}
|
||||
)
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
for item in payload:
|
||||
try:
|
||||
await self._client.execute(sql, item) # type: ignore[union-attr]
|
||||
except Exception as exc: # pragma: no cover - 单条写入失败时记录日志
|
||||
logger.error("写入 rag_chunks 失败: %s", exc)
|
||||
else:
|
||||
logger.debug(
|
||||
"已写入章节片段: project=%s chapter=%s chunk=%s",
|
||||
item.get("project_id"),
|
||||
item.get("chapter_number"),
|
||||
item.get("chunk_index"),
|
||||
)
|
||||
|
||||
async def upsert_summaries(
|
||||
self,
|
||||
*,
|
||||
records: Iterable[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""同步章节摘要向量,供摘要层检索使用。"""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
await self.ensure_schema()
|
||||
sql = """
|
||||
INSERT INTO rag_summaries (
|
||||
id,
|
||||
project_id,
|
||||
chapter_number,
|
||||
title,
|
||||
summary,
|
||||
embedding
|
||||
) VALUES (
|
||||
:id,
|
||||
:project_id,
|
||||
:chapter_number,
|
||||
:title,
|
||||
:summary,
|
||||
:embedding
|
||||
)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
summary=excluded.summary,
|
||||
embedding=excluded.embedding,
|
||||
title=excluded.title
|
||||
"""
|
||||
|
||||
payload = []
|
||||
for item in records:
|
||||
embedding = item.get("embedding", [])
|
||||
payload.append(
|
||||
{
|
||||
**item,
|
||||
"embedding": self._to_f32_blob(embedding),
|
||||
}
|
||||
)
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
for item in payload:
|
||||
try:
|
||||
await self._client.execute(sql, item) # type: ignore[union-attr]
|
||||
except Exception as exc: # pragma: no cover - 单条写入失败时记录日志
|
||||
logger.error("写入 rag_summaries 失败: %s", exc)
|
||||
else:
|
||||
logger.debug(
|
||||
"已写入章节摘要: project=%s chapter=%s",
|
||||
item.get("project_id"),
|
||||
item.get("chapter_number"),
|
||||
)
|
||||
|
||||
async def delete_by_chapters(self, project_id: str, chapter_numbers: Sequence[int]) -> None:
|
||||
"""根据章节编号批量删除对应的上下文数据。"""
|
||||
if not self._client or not chapter_numbers:
|
||||
return
|
||||
|
||||
await self.ensure_schema()
|
||||
placeholders = ",".join(":chapter_" + str(idx) for idx in range(len(chapter_numbers)))
|
||||
params = {
|
||||
"project_id": project_id,
|
||||
**{f"chapter_{idx}": number for idx, number in enumerate(chapter_numbers)},
|
||||
}
|
||||
chunk_sql = f"""
|
||||
DELETE FROM rag_chunks
|
||||
WHERE project_id = :project_id
|
||||
AND chapter_number IN ({placeholders})
|
||||
"""
|
||||
summary_sql = f"""
|
||||
DELETE FROM rag_summaries
|
||||
WHERE project_id = :project_id
|
||||
AND chapter_number IN ({placeholders})
|
||||
"""
|
||||
try:
|
||||
await self._client.execute(chunk_sql, params) # type: ignore[union-attr]
|
||||
await self._client.execute(summary_sql, params) # type: ignore[union-attr]
|
||||
logger.info(
|
||||
"已删除章节向量: project=%s chapters=%s",
|
||||
project_id,
|
||||
list(chapter_numbers),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 删除失败时记录日志
|
||||
logger.error("删除章节向量失败: project=%s chapters=%s error=%s", project_id, chapter_numbers, exc)
|
||||
|
||||
@staticmethod
|
||||
def _to_f32_blob(embedding: Sequence[float]) -> bytes:
|
||||
"""将向量浮点列表编码为 libsql 可识别的 float32 二进制。"""
|
||||
return array("f", embedding).tobytes()
|
||||
|
||||
@staticmethod
|
||||
def _from_f32_blob(blob: Any) -> List[float]:
|
||||
"""将数据库中的 BLOB 解码为浮点列表。"""
|
||||
if not blob:
|
||||
return []
|
||||
if isinstance(blob, memoryview):
|
||||
blob = blob.tobytes()
|
||||
data = array("f")
|
||||
data.frombytes(bytes(blob))
|
||||
return list(data)
|
||||
|
||||
@staticmethod
|
||||
def _cosine_distance(vec_a: Sequence[float], vec_b: Sequence[float]) -> float:
|
||||
"""计算余弦距离(1 - similarity),避免除零。"""
|
||||
if not vec_a or not vec_b:
|
||||
return 1.0
|
||||
dot = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||
norm_a = math.sqrt(sum(a * a for a in vec_a))
|
||||
norm_b = math.sqrt(sum(b * b for b in vec_b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 1.0
|
||||
similarity = dot / (norm_a * norm_b)
|
||||
return 1.0 - similarity
|
||||
|
||||
async def _query_chunks_with_python_similarity(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
embedding: Sequence[float],
|
||||
top_k: int,
|
||||
) -> List[RetrievedChunk]:
|
||||
sql = """
|
||||
SELECT
|
||||
content,
|
||||
chapter_number,
|
||||
chapter_title,
|
||||
COALESCE(metadata, '{}') AS metadata,
|
||||
embedding
|
||||
FROM rag_chunks
|
||||
WHERE project_id = :project_id
|
||||
"""
|
||||
result = await self._client.execute(sql, {"project_id": project_id}) # type: ignore[union-attr]
|
||||
scored: List[RetrievedChunk] = []
|
||||
for row in self._iter_rows(result):
|
||||
stored_embedding = self._from_f32_blob(row.get("embedding"))
|
||||
distance = self._cosine_distance(embedding, stored_embedding)
|
||||
scored.append(
|
||||
RetrievedChunk(
|
||||
content=row.get("content", ""),
|
||||
chapter_number=row.get("chapter_number", 0),
|
||||
chapter_title=row.get("chapter_title"),
|
||||
score=distance,
|
||||
metadata=self._parse_metadata(row.get("metadata")),
|
||||
)
|
||||
)
|
||||
scored.sort(key=lambda item: item.score)
|
||||
return scored[:top_k]
|
||||
|
||||
async def _query_summaries_with_python_similarity(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
embedding: Sequence[float],
|
||||
top_k: int,
|
||||
) -> List[RetrievedSummary]:
|
||||
sql = """
|
||||
SELECT
|
||||
chapter_number,
|
||||
title,
|
||||
summary,
|
||||
embedding
|
||||
FROM rag_summaries
|
||||
WHERE project_id = :project_id
|
||||
"""
|
||||
result = await self._client.execute(sql, {"project_id": project_id}) # type: ignore[union-attr]
|
||||
scored: List[RetrievedSummary] = []
|
||||
for row in self._iter_rows(result):
|
||||
stored_embedding = self._from_f32_blob(row.get("embedding"))
|
||||
distance = self._cosine_distance(embedding, stored_embedding)
|
||||
scored.append(
|
||||
RetrievedSummary(
|
||||
chapter_number=row.get("chapter_number", 0),
|
||||
title=row.get("title", ""),
|
||||
summary=row.get("summary", ""),
|
||||
score=distance,
|
||||
)
|
||||
)
|
||||
scored.sort(key=lambda item: item.score)
|
||||
return scored[:top_k]
|
||||
|
||||
@staticmethod
|
||||
def _parse_metadata(raw: Any) -> Dict[str, Any]:
|
||||
"""解析存储的 JSON 文本,确保输出为 dict。"""
|
||||
if not raw:
|
||||
return {}
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
raw = raw.decode("utf-8")
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _iter_rows(result: Any) -> Iterable[Dict[str, Any]]:
|
||||
"""统一处理 libsql 返回的行数据,确保以 dict 形式迭代。"""
|
||||
rows = getattr(result, "rows", None)
|
||||
if rows is None:
|
||||
rows = result
|
||||
if not rows:
|
||||
return []
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for row in rows:
|
||||
if isinstance(row, dict):
|
||||
normalized.append(row)
|
||||
elif hasattr(row, "_asdict"):
|
||||
normalized.append(row._asdict()) # type: ignore[attr-defined]
|
||||
else:
|
||||
try:
|
||||
normalized.append(dict(row))
|
||||
except Exception: # pragma: no cover - 无法转换时跳过
|
||||
continue
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VectorStoreService",
|
||||
"RetrievedChunk",
|
||||
"RetrievedSummary",
|
||||
]
|
||||
0
backend/app/utils/__init__.py
Normal file
0
backend/app/utils/__init__.py
Normal file
81
backend/app/utils/json_utils.py
Normal file
81
backend/app/utils/json_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_think_tags(raw_text: str) -> str:
|
||||
"""移除 <think></think> 标签,避免污染结果。"""
|
||||
if not raw_text:
|
||||
return raw_text
|
||||
return re.sub(r"<think>.*?</think>", "", raw_text, flags=re.DOTALL).strip()
|
||||
|
||||
|
||||
def unwrap_markdown_json(raw_text: str) -> str:
|
||||
"""从 Markdown 或普通文本中提取 JSON 字符串。"""
|
||||
if not raw_text:
|
||||
return raw_text
|
||||
|
||||
trimmed = raw_text.strip()
|
||||
|
||||
fence_match = re.search(r"```(?:json|JSON)?\s*(.*?)\s*```", trimmed, re.DOTALL)
|
||||
if fence_match:
|
||||
candidate = fence_match.group(1).strip()
|
||||
if candidate:
|
||||
return candidate
|
||||
|
||||
json_start_candidates = [idx for idx in (trimmed.find("{"), trimmed.find("[")) if idx != -1]
|
||||
if json_start_candidates:
|
||||
start_idx = min(json_start_candidates)
|
||||
closing_brace = trimmed.rfind("}")
|
||||
closing_bracket = trimmed.rfind("]")
|
||||
end_idx = max(closing_brace, closing_bracket)
|
||||
if end_idx != -1 and end_idx > start_idx:
|
||||
candidate = trimmed[start_idx : end_idx + 1].strip()
|
||||
if candidate:
|
||||
return candidate
|
||||
|
||||
return trimmed
|
||||
|
||||
|
||||
def sanitize_json_like_text(raw_text: str) -> str:
|
||||
"""对可能含有未转义换行/引号的 JSON 文本进行清洗。"""
|
||||
if not raw_text:
|
||||
return raw_text
|
||||
|
||||
result = []
|
||||
in_string = False
|
||||
escape_next = False
|
||||
length = len(raw_text)
|
||||
i = 0
|
||||
while i < length:
|
||||
ch = raw_text[i]
|
||||
if in_string:
|
||||
if escape_next:
|
||||
result.append(ch)
|
||||
escape_next = False
|
||||
elif ch == "\\":
|
||||
result.append(ch)
|
||||
escape_next = True
|
||||
elif ch == '"':
|
||||
j = i + 1
|
||||
while j < length and raw_text[j] in " \t\r\n":
|
||||
j += 1
|
||||
|
||||
if j >= length or raw_text[j] in "}]" or raw_text[j] == ",":
|
||||
in_string = False
|
||||
result.append(ch)
|
||||
else:
|
||||
result.extend(["\\", '"'])
|
||||
elif ch == "\n":
|
||||
result.extend(["\\", "n"])
|
||||
elif ch == "\r":
|
||||
result.extend(["\\", "r"])
|
||||
elif ch == "\t":
|
||||
result.extend(["\\", "t"])
|
||||
else:
|
||||
result.append(ch)
|
||||
else:
|
||||
if ch == '"':
|
||||
in_string = True
|
||||
result.append(ch)
|
||||
i += 1
|
||||
|
||||
return "".join(result)
|
||||
65
backend/app/utils/llm_tool.py
Normal file
65
backend/app/utils/llm_tool.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""OpenAI 兼容型 LLM 工具封装,保持与旧项目一致的接口体验。"""
|
||||
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
role: str
|
||||
content: str
|
||||
|
||||
def to_dict(self) -> Dict[str, str]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""异步流式调用封装,兼容 OpenAI SDK。"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
|
||||
key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||
if not key:
|
||||
raise ValueError("缺少 OPENAI_API_KEY 配置,请在数据库或环境变量中补全。")
|
||||
|
||||
self._client = AsyncOpenAI(api_key=key, base_url=base_url or os.environ.get("OPENAI_API_BASE"))
|
||||
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
model: Optional[str] = None,
|
||||
response_format: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
timeout: int = 120,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[Dict[str, str], None]:
|
||||
payload = {
|
||||
"model": model or os.environ.get("MODEL", "gpt-3.5-turbo"),
|
||||
"messages": [msg.to_dict() for msg in messages],
|
||||
"stream": True,
|
||||
"timeout": timeout,
|
||||
**kwargs,
|
||||
}
|
||||
if response_format:
|
||||
payload["response_format"] = {"type": response_format}
|
||||
if temperature is not None:
|
||||
payload["temperature"] = temperature
|
||||
if top_p is not None:
|
||||
payload["top_p"] = top_p
|
||||
if max_tokens is not None:
|
||||
payload["max_tokens"] = max_tokens
|
||||
|
||||
stream = await self._client.chat.completions.create(**payload)
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
choice = chunk.choices[0]
|
||||
yield {
|
||||
"content": choice.delta.content,
|
||||
"finish_reason": choice.finish_reason,
|
||||
}
|
||||
Reference in New Issue
Block a user