feat: 初始提交
This commit is contained in:
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
27
backend/app/services/admin_setting_service.py
Normal file
27
backend/app/services/admin_setting_service.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import AdminSetting
|
||||
from ..repositories.admin_setting_repository import AdminSettingRepository
|
||||
|
||||
|
||||
class AdminSettingService:
|
||||
"""管理员配置项服务,提供简单的 KV 操作。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = AdminSettingRepository(session)
|
||||
|
||||
async def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
value = await self.repo.get_value(key)
|
||||
return value if value is not None else default
|
||||
|
||||
async def set(self, key: str, value: str) -> None:
|
||||
record = await self.repo.get(key=key)
|
||||
if record:
|
||||
await self.repo.update_fields(record, value=value)
|
||||
else:
|
||||
setting = AdminSetting(key=key, value=value)
|
||||
await self.repo.add(setting)
|
||||
await self.session.commit()
|
||||
389
backend/app/services/auth_service.py
Normal file
389
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,389 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
from email.header import Header
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formataddr, parseaddr
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
import smtplib
|
||||
|
||||
from ..core.config import settings
|
||||
from ..core.security import create_access_token, hash_password, verify_password
|
||||
from ..models import User
|
||||
from ..repositories.system_config_repository import SystemConfigRepository
|
||||
from ..repositories.user_repository import UserRepository
|
||||
from ..schemas.user import AuthOptions, Token, UserCreate, UserInDB, UserRegistration
|
||||
|
||||
|
||||
_VERIFICATION_CACHE: Dict[str, tuple[str, float]] = {}
|
||||
_LAST_SEND_TIME: Dict[str, float] = {}
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""认证与授权逻辑,封装登录、注册、OAuth 对接等操作。"""
|
||||
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
self.system_config_repo = SystemConfigRepository(session)
|
||||
self._verification_cache = _VERIFICATION_CACHE
|
||||
self._last_send_time = _LAST_SEND_TIME
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 用户登录 / 注册
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def authenticate_user(self, username: str, password: str) -> User:
|
||||
user = await self.user_repo.get_by_username(username)
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误")
|
||||
return user
|
||||
|
||||
async def create_access_token(
|
||||
self,
|
||||
user: User | UserInDB,
|
||||
*,
|
||||
must_change_password: Optional[bool] = None,
|
||||
) -> Token:
|
||||
payload = {"is_admin": user.is_admin}
|
||||
token = create_access_token(user.username, extra_claims=payload)
|
||||
should_change = self.requires_password_reset(user) if must_change_password is None else must_change_password
|
||||
return Token(access_token=token, must_change_password=should_change)
|
||||
|
||||
async def register_user(self, payload: UserRegistration) -> User:
|
||||
if not await self.is_registration_enabled():
|
||||
raise HTTPException(status_code=403, detail="当前暂未开放注册")
|
||||
if await self.user_repo.get_by_username(payload.username):
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
if payload.email and await self.user_repo.get_by_email(payload.email):
|
||||
raise HTTPException(status_code=400, detail="邮箱已被使用")
|
||||
|
||||
if not self.verify_code(payload.email, payload.verification_code):
|
||||
raise HTTPException(status_code=400, detail="验证码错误或已过期")
|
||||
|
||||
hashed_password = hash_password(payload.password)
|
||||
user = User(
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
hashed_password=hashed_password,
|
||||
)
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
return user
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 邮箱验证码逻辑
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_verification_code(self, email: str) -> None:
|
||||
if not await self.is_registration_enabled():
|
||||
raise HTTPException(status_code=403, detail="当前暂未开放注册")
|
||||
now = time.time()
|
||||
if email in self._last_send_time and now - self._last_send_time[email] < 60:
|
||||
raise HTTPException(status_code=429, detail="请稍后再试,1分钟内不可重复发送")
|
||||
|
||||
code = "".join(random.choices(string.digits, k=6))
|
||||
self._verification_cache[email] = (code, now + 300)
|
||||
self._last_send_time[email] = now
|
||||
|
||||
smtp_config = await self._load_smtp_config()
|
||||
if not smtp_config:
|
||||
raise HTTPException(status_code=500, detail="未配置邮件服务,请联系管理员")
|
||||
|
||||
await self._send_email(email, code, smtp_config)
|
||||
|
||||
def verify_code(self, email: str | None, code: str) -> bool:
|
||||
if not email:
|
||||
return False
|
||||
cached = self._verification_cache.get(email)
|
||||
if not cached:
|
||||
return False
|
||||
expected, expire_at = cached
|
||||
if time.time() > expire_at:
|
||||
self._verification_cache.pop(email, None)
|
||||
return False
|
||||
if code != expected:
|
||||
return False
|
||||
self._verification_cache.pop(email, None)
|
||||
return True
|
||||
|
||||
async def _load_smtp_config(self) -> Optional[Dict[str, str]]:
|
||||
keys = [
|
||||
"smtp.server",
|
||||
"smtp.port",
|
||||
"smtp.username",
|
||||
"smtp.password",
|
||||
"smtp.from",
|
||||
]
|
||||
configs = {}
|
||||
for key in keys:
|
||||
config = await self.system_config_repo.get_by_key(key)
|
||||
if config:
|
||||
configs[key] = config.value
|
||||
|
||||
required_keys = {"smtp.server", "smtp.port", "smtp.username", "smtp.password", "smtp.from"}
|
||||
if not required_keys.issubset(configs.keys()):
|
||||
return None
|
||||
|
||||
return configs
|
||||
|
||||
async def _send_email(self, to_email: str, code: str, smtp_config: Dict[str, str]) -> None:
|
||||
logger = logging.getLogger(__name__)
|
||||
server = smtp_config["smtp.server"]
|
||||
port = int(smtp_config.get("smtp.port", "465"))
|
||||
username = smtp_config["smtp.username"]
|
||||
password = smtp_config["smtp.password"]
|
||||
from_value = smtp_config.get("smtp.from") or username
|
||||
display_name, from_addr = parseaddr(from_value)
|
||||
if not display_name and "@" not in from_value and "<" not in from_value and from_value.strip():
|
||||
display_name = from_value.strip()
|
||||
if not from_addr or "@" not in from_addr:
|
||||
if from_addr and "@" not in from_addr:
|
||||
logger.warning(
|
||||
"发件邮箱缺少 @,已回退为登录账号",
|
||||
extra={"original": from_addr},
|
||||
)
|
||||
from_addr = username
|
||||
try:
|
||||
from_addr.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
logger.warning(
|
||||
"发件邮箱包含非 ASCII 字符,已回退为登录账号",
|
||||
extra={"original": from_addr},
|
||||
)
|
||||
from_addr = username
|
||||
if display_name:
|
||||
formatted_from = formataddr((Header(display_name, "utf-8").encode(), from_addr))
|
||||
else:
|
||||
formatted_from = from_addr
|
||||
|
||||
try:
|
||||
to_email.encode("ascii")
|
||||
except UnicodeEncodeError as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=400, detail="邮箱地址包含不支持的字符") from exc
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta http-equiv=\"Content-Type\" content=\"text/html; charset=UTF-8\">
|
||||
<meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">
|
||||
<title>您的验证码</title>
|
||||
<style>
|
||||
body, table, td, a {{ -webkit-text-size-adjust: 100%; -ms-text-size-adjust: 100%; }}
|
||||
table, td {{ mso-table-lspace: 0pt; mso-table-rspace: 0pt; }}
|
||||
img {{ -ms-interpolation-mode: bicubic; }}
|
||||
body {{ margin: 0; padding: 0; }}
|
||||
table {{ border-collapse: collapse !important; }}
|
||||
</style>
|
||||
</head>
|
||||
<body style=\"margin: 0; padding: 0; width: 100% !important; background-color: #f3f4f6;\">
|
||||
<table width=\"100%\" border=\"0\" cellpadding=\"0\" cellspacing=\"0\" bgcolor=\"#f3f4f6\">
|
||||
<tr>
|
||||
<td align=\"center\" valign=\"top\" style=\"padding: 20px;\">
|
||||
<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\" width=\"100%\" style=\"max-width: 512px; background-color: #ffffff; border-radius: 16px; overflow: hidden;\">
|
||||
<tr>
|
||||
<td align=\"center\" style=\"background-color: #2563eb; padding: 32px;\">
|
||||
<h1 style=\"font-family: Arial, Helvetica, sans-serif; font-size: 30px; font-weight: bold; color: #ffffff; margin: 0;\">操作验证码</h1>
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 16px; color: #dbeafe; margin: 8px 0 0;\">请使用下方验证码完成操作。</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=\"center\" style=\"padding: 32px 48px;\">
|
||||
<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\" width=\"100%\">
|
||||
<tr>
|
||||
<td align=\"center\" style=\"background-color: #f3f4f6; border-radius: 12px; padding: 16px; margin: 24px 0;\">
|
||||
<p style=\"font-family: 'Courier New', Courier, monospace; font-size: 48px; font-weight: bold; letter-spacing: 0.1em; color: #1d4ed8; margin: 0;\">
|
||||
{code[:3]}{code[3:]}
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=\"center\" style=\"padding-top: 24px;\">
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 16px; color: #6b7280; margin: 0;\">
|
||||
此验证码将在 <span style=\"font-weight: bold; color: #374151;\">5分钟</span> 内有效。
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=\"center\" style=\"padding-top: 32px; border-top: 1px solid #e5e7eb; margin-top: 32px;\">
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 14px; font-weight: bold; color: #ef4444; margin: 0;\">
|
||||
为保障安全,请勿泄露此验证码。
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=\"center\" style=\"background-color: #f9fafb; padding: 24px; border-top: 1px solid #e5e7eb;\">
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 14px; color: #6b7280; margin: 0;\">
|
||||
如非本人操作,请忽略此邮件。
|
||||
</p>
|
||||
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 12px; color: #9ca3af; margin: 8px 0 0;\">
|
||||
© {time.strftime('%Y')} 拯救小说家. All rights reserved.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
message = MIMEText(html_content, "html", "utf-8")
|
||||
message["Subject"] = Header("注册验证码", "utf-8").encode()
|
||||
message["From"] = formatted_from
|
||||
message["To"] = to_email
|
||||
|
||||
logger.info("准备发送验证码邮件", extra={"to": to_email, "server": server, "port": port})
|
||||
|
||||
def _send():
|
||||
smtp: Optional[smtplib.SMTP] = None
|
||||
try:
|
||||
if port == 465:
|
||||
smtp = smtplib.SMTP_SSL(server, port, timeout=10)
|
||||
else:
|
||||
smtp = smtplib.SMTP(server, port, timeout=10)
|
||||
smtp.starttls()
|
||||
if username and password:
|
||||
smtp.login(username, password)
|
||||
smtp.sendmail(from_addr, [to_email], message.as_string())
|
||||
logger.info("验证码邮件发送成功", extra={"to": to_email})
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("验证码发送失败")
|
||||
raise
|
||||
finally:
|
||||
if smtp is not None:
|
||||
try:
|
||||
smtp.quit()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_send)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail="验证码发送失败,请检查邮件配置") from exc
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# OAuth 对接示例(以 Linux.do 为例)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def handle_linuxdo_callback(self, code: str) -> Token:
|
||||
if not await self.is_linuxdo_login_enabled():
|
||||
raise HTTPException(status_code=403, detail="未启用 Linux.do 登录")
|
||||
client_id = await self._get_config_value("linuxdo.client_id")
|
||||
client_secret = await self._get_config_value("linuxdo.client_secret")
|
||||
redirect_uri = await self._get_config_value("linuxdo.redirect_uri")
|
||||
token_url = await self._get_config_value("linuxdo.token_url")
|
||||
user_info_url = await self._get_config_value("linuxdo.user_info_url")
|
||||
|
||||
if not all([client_id, client_secret, redirect_uri, token_url, user_info_url]):
|
||||
raise HTTPException(status_code=500, detail="未正确配置 Linux.do OAuth 参数")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_response = await client.post(
|
||||
token_url,
|
||||
data={
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
token_response.raise_for_status()
|
||||
access_token = token_response.json().get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=400, detail="授权失败,未获取到访问令牌")
|
||||
|
||||
user_info_response = await client.get(
|
||||
user_info_url,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
user_info_response.raise_for_status()
|
||||
data = user_info_response.json()
|
||||
|
||||
external_id = f"linuxdo:{data['id']}"
|
||||
user = await self.user_repo.get_by_external_id(external_id)
|
||||
if user is None:
|
||||
placeholder_password = secrets.token_urlsafe(16)
|
||||
user = User(
|
||||
username=data["username"],
|
||||
email=data.get("email"),
|
||||
external_id=external_id,
|
||||
hashed_password=hash_password(placeholder_password),
|
||||
)
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
|
||||
return await self.create_access_token(user)
|
||||
|
||||
async def _get_config_value(self, key: str) -> Optional[str]:
|
||||
config = await self.system_config_repo.get_by_key(key)
|
||||
return config.value if config else None
|
||||
|
||||
async def get_config_value(self, key: str) -> Optional[str]:
|
||||
"""对外暴露的配置读取接口,便于路由层复用。"""
|
||||
return await self._get_config_value(key)
|
||||
|
||||
@staticmethod
|
||||
def _parse_bool(value: Optional[str], fallback: bool) -> bool:
|
||||
if value is None:
|
||||
return fallback
|
||||
normalized = value.strip().lower()
|
||||
return normalized in {"1", "true", "yes", "on"}
|
||||
|
||||
async def is_registration_enabled(self) -> bool:
|
||||
value = await self._get_config_value("auth.allow_registration")
|
||||
return self._parse_bool(value, fallback=settings.allow_registration)
|
||||
|
||||
async def is_linuxdo_login_enabled(self) -> bool:
|
||||
value = await self._get_config_value("auth.linuxdo_enabled")
|
||||
return self._parse_bool(value, fallback=settings.enable_linuxdo_login)
|
||||
|
||||
async def get_auth_options(self) -> AuthOptions:
|
||||
"""聚合与认证相关的动态开关配置,便于前端一次性拉取。"""
|
||||
|
||||
allow_registration = await self.is_registration_enabled()
|
||||
enable_linuxdo_login = await self.is_linuxdo_login_enabled()
|
||||
return AuthOptions(
|
||||
allow_registration=allow_registration,
|
||||
enable_linuxdo_login=enable_linuxdo_login,
|
||||
)
|
||||
|
||||
def requires_password_reset(self, user: User | UserInDB) -> bool:
|
||||
if not user.is_admin:
|
||||
return False
|
||||
if user.username != settings.admin_default_username:
|
||||
return False
|
||||
hashed_password = getattr(user, "hashed_password", None)
|
||||
if not hashed_password:
|
||||
return False
|
||||
return verify_password(settings.admin_default_password, hashed_password)
|
||||
|
||||
async def change_password(self, username: str, old_password: str, new_password: str) -> None:
|
||||
user = await self.user_repo.get_by_username(username)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
|
||||
|
||||
if not verify_password(old_password, user.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前密码错误")
|
||||
|
||||
if verify_password(new_password, user.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="新密码不能与当前密码相同")
|
||||
|
||||
if username == settings.admin_default_username and new_password == settings.admin_default_password:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="新密码不能为默认密码")
|
||||
|
||||
user.hashed_password = hash_password(new_password)
|
||||
await self.session.commit()
|
||||
109
backend/app/services/chapter_context_service.py
Normal file
109
backend/app/services/chapter_context_service.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
章节上下文组装服务:负责调用向量库检索上下文,并对结果做基础格式化。
|
||||
|
||||
所有关键步骤均包含中文注释,方便团队理解 RAG 流程。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from ..core.config import settings
|
||||
from ..services.llm_service import LLMService
|
||||
from .vector_store_service import RetrievedChunk, RetrievedSummary, VectorStoreService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChapterRAGContext:
|
||||
"""封装检索得到的上下文结果。"""
|
||||
|
||||
query: str
|
||||
chunks: List[RetrievedChunk]
|
||||
summaries: List[RetrievedSummary]
|
||||
|
||||
def chunk_texts(self) -> List[str]:
|
||||
"""将检索到的 chunk 转换成带序号的 Markdown 段落。"""
|
||||
lines = []
|
||||
for idx, chunk in enumerate(self.chunks, start=1):
|
||||
title = chunk.chapter_title or f"第{chunk.chapter_number}章"
|
||||
lines.append(
|
||||
f"### Chunk {idx}(来源:{title})\n{chunk.content.strip()}"
|
||||
)
|
||||
return lines
|
||||
|
||||
def summary_lines(self) -> List[str]:
|
||||
"""整理章节摘要,方便直接插入 Prompt。"""
|
||||
lines = []
|
||||
for summary in self.summaries:
|
||||
lines.append(
|
||||
f"- 第{summary.chapter_number}章 - {summary.title}:{summary.summary.strip()}"
|
||||
)
|
||||
return lines
|
||||
|
||||
|
||||
class ChapterContextService:
|
||||
"""章节上下文服务,整合查询、格式化与容错逻辑。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
llm_service: LLMService,
|
||||
vector_store: Optional[VectorStoreService] = None,
|
||||
) -> None:
|
||||
self._llm_service = llm_service
|
||||
self._vector_store = vector_store
|
||||
|
||||
async def retrieve_for_generation(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
query_text: str,
|
||||
user_id: int,
|
||||
top_k_chunks: Optional[int] = None,
|
||||
top_k_summaries: Optional[int] = None,
|
||||
) -> ChapterRAGContext:
|
||||
"""根据章节摘要构造检索向量,并返回 RAG 上下文。"""
|
||||
query = self._normalize(query_text)
|
||||
if not settings.vector_store_enabled or not self._vector_store:
|
||||
logger.debug("向量库未启用或初始化失败,跳过检索: project=%s", project_id)
|
||||
return ChapterRAGContext(query=query, chunks=[], summaries=[])
|
||||
|
||||
embedding_model = None if settings.embedding_provider == "ollama" else settings.embedding_model
|
||||
embedding = await self._llm_service.get_embedding(query, user_id=user_id, model=embedding_model)
|
||||
if not embedding:
|
||||
logger.warning("检索查询向量生成失败: project=%s chapter_query=%s", project_id, query)
|
||||
return ChapterRAGContext(query=query, chunks=[], summaries=[])
|
||||
|
||||
chunks = await self._vector_store.query_chunks(
|
||||
project_id=project_id,
|
||||
embedding=embedding,
|
||||
top_k=top_k_chunks,
|
||||
)
|
||||
summaries = await self._vector_store.query_summaries(
|
||||
project_id=project_id,
|
||||
embedding=embedding,
|
||||
top_k=top_k_summaries,
|
||||
)
|
||||
logger.info(
|
||||
"章节上下文检索完成: project=%s chunks=%d summaries=%d query_preview=%s",
|
||||
project_id,
|
||||
len(chunks),
|
||||
len(summaries),
|
||||
query[:80],
|
||||
)
|
||||
return ChapterRAGContext(query=query, chunks=chunks, summaries=summaries)
|
||||
|
||||
@staticmethod
|
||||
def _normalize(text: str) -> str:
|
||||
"""统一压缩空白字符,避免影响检索效果。"""
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChapterContextService",
|
||||
"ChapterRAGContext",
|
||||
]
|
||||
262
backend/app/services/chapter_ingest_service.py
Normal file
262
backend/app/services/chapter_ingest_service.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
章节向量入库服务:在章节确认后负责切分文本、生成嵌入并写入向量库。
|
||||
|
||||
全部注释使用中文,方便团队成员阅读理解。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
from ..core.config import settings
|
||||
from ..services.llm_service import LLMService
|
||||
from ..services.vector_store_service import VectorStoreService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try: # noqa: SIM105 - 提示缺少可选依赖
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
except ImportError: # pragma: no cover - 未安装时会走后备方案
|
||||
RecursiveCharacterTextSplitter = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class ChapterIngestionService:
|
||||
"""封装章节内容与摘要的向量化与入库流程。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
llm_service: LLMService,
|
||||
vector_store: Optional[VectorStoreService] = None,
|
||||
) -> None:
|
||||
self._llm_service = llm_service
|
||||
self._vector_store = vector_store or VectorStoreService()
|
||||
self._text_splitter = self._init_text_splitter()
|
||||
|
||||
async def ingest_chapter(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
title: str,
|
||||
content: str,
|
||||
summary: Optional[str],
|
||||
user_id: int,
|
||||
) -> None:
|
||||
"""将章节正文与摘要写入向量库,供后续 RAG 检索使用。"""
|
||||
if not settings.vector_store_enabled:
|
||||
logger.debug("向量库未启用,跳过章节向量写入: project=%s chapter=%s", project_id, chapter_number)
|
||||
return
|
||||
if not content.strip():
|
||||
logger.debug("章节正文为空,跳过向量写入: project=%s chapter=%s", project_id, chapter_number)
|
||||
return
|
||||
|
||||
chunks = self._split_into_chunks(content)
|
||||
if not chunks:
|
||||
logger.debug("章节正文切分后为空,跳过向量写入: project=%s chapter=%s", project_id, chapter_number)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"开始写入章节向量: project=%s chapter=%s chunks=%d",
|
||||
project_id,
|
||||
chapter_number,
|
||||
len(chunks),
|
||||
)
|
||||
await self._vector_store.delete_by_chapters(project_id, [chapter_number])
|
||||
|
||||
chunk_records = []
|
||||
for index, chunk_text in enumerate(chunks):
|
||||
embedding = await self._llm_service.get_embedding(
|
||||
chunk_text,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not embedding:
|
||||
logger.warning(
|
||||
"生成章节片段向量失败,已跳过: project=%s chapter=%s chunk=%s",
|
||||
project_id,
|
||||
chapter_number,
|
||||
index,
|
||||
)
|
||||
continue
|
||||
record_id = f"{project_id}:{chapter_number}:{index}"
|
||||
chunk_records.append(
|
||||
{
|
||||
"id": record_id,
|
||||
"project_id": project_id,
|
||||
"chapter_number": chapter_number,
|
||||
"chunk_index": index,
|
||||
"chapter_title": title,
|
||||
"content": chunk_text,
|
||||
"embedding": embedding,
|
||||
"metadata": {
|
||||
"chunk_id": record_id,
|
||||
"length": len(chunk_text),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if chunk_records:
|
||||
await self._vector_store.upsert_chunks(records=chunk_records)
|
||||
logger.info(
|
||||
"章节正文向量写入完成: project=%s chapter=%s 成功片段=%d",
|
||||
project_id,
|
||||
chapter_number,
|
||||
len(chunk_records),
|
||||
)
|
||||
|
||||
if summary:
|
||||
cleaned_summary = summary.strip()
|
||||
if cleaned_summary:
|
||||
summary_embedding = await self._llm_service.get_embedding(
|
||||
cleaned_summary,
|
||||
user_id=user_id,
|
||||
)
|
||||
if summary_embedding:
|
||||
summary_id = f"{project_id}:{chapter_number}:summary"
|
||||
await self._vector_store.upsert_summaries(
|
||||
records=[
|
||||
{
|
||||
"id": summary_id,
|
||||
"project_id": project_id,
|
||||
"chapter_number": chapter_number,
|
||||
"title": title,
|
||||
"summary": cleaned_summary,
|
||||
"embedding": summary_embedding,
|
||||
}
|
||||
]
|
||||
)
|
||||
logger.info(
|
||||
"章节摘要向量写入完成: project=%s chapter=%s",
|
||||
project_id,
|
||||
chapter_number,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"生成章节摘要向量失败,已跳过: project=%s chapter=%s",
|
||||
project_id,
|
||||
chapter_number,
|
||||
)
|
||||
|
||||
async def delete_chapters(self, project_id: str, chapter_numbers: Sequence[int]) -> None:
|
||||
"""从向量库中删除指定章节的所有片段与摘要。"""
|
||||
if not settings.vector_store_enabled or not chapter_numbers:
|
||||
return
|
||||
logger.info(
|
||||
"准备删除章节向量: project=%s chapters=%s",
|
||||
project_id,
|
||||
list(chapter_numbers),
|
||||
)
|
||||
await self._vector_store.delete_by_chapters(project_id, list(chapter_numbers))
|
||||
|
||||
def _split_into_chunks(self, text: str) -> List[str]:
|
||||
"""按照配置的 chunk 大小与重叠度切分章节正文。"""
|
||||
normalized = text.strip()
|
||||
if not normalized:
|
||||
return []
|
||||
|
||||
if self._text_splitter:
|
||||
parts = [segment.strip() for segment in self._text_splitter.split_text(normalized)]
|
||||
filtered = [part for part in parts if part]
|
||||
if filtered:
|
||||
logger.debug(
|
||||
"使用 LangChain 文本切分器完成分段: count=%d chunk_size=%d overlap=%d",
|
||||
len(filtered),
|
||||
settings.vector_chunk_size,
|
||||
settings.vector_chunk_overlap,
|
||||
)
|
||||
return filtered
|
||||
|
||||
return self._legacy_split(normalized)
|
||||
|
||||
@staticmethod
|
||||
def _find_split_offset(segment: str) -> Optional[int]:
|
||||
"""在片段内部寻找更自然的分割点,优先换行,其次常见标点。"""
|
||||
candidates: Dict[str, int] = {}
|
||||
newline_pos = segment.rfind("\n\n")
|
||||
if newline_pos == -1:
|
||||
newline_pos = segment.rfind("\n")
|
||||
if newline_pos > 0:
|
||||
candidates["newline"] = newline_pos
|
||||
|
||||
punctuation_marks = ["。", "!", "?", "!", "?", ".", ";", ";"]
|
||||
for mark in punctuation_marks:
|
||||
idx = segment.rfind(mark)
|
||||
if idx > 0:
|
||||
candidates.setdefault("punctuation", idx + len(mark))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# 选择最接近末尾但又不过短的分割点
|
||||
best_offset = max(candidates.values())
|
||||
if best_offset < len(segment) * 0.4:
|
||||
return None
|
||||
return best_offset
|
||||
|
||||
def _init_text_splitter(self) -> Optional["RecursiveCharacterTextSplitter"]:
|
||||
"""初始化 LangChain 文本切分器,可根据配置动态调整。"""
|
||||
if RecursiveCharacterTextSplitter is None:
|
||||
logger.warning("未安装 langchain-text-splitters,章节切分将回退至内置策略。")
|
||||
return None
|
||||
|
||||
chunk_size = settings.vector_chunk_size
|
||||
overlap = min(settings.vector_chunk_overlap, chunk_size // 2)
|
||||
separators = [
|
||||
"\n\n",
|
||||
"\n",
|
||||
"。", "!", "?",
|
||||
"!", "?", ";", ";",
|
||||
",", ",",
|
||||
" ",
|
||||
]
|
||||
splitter = RecursiveCharacterTextSplitter(
|
||||
separators=separators,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=overlap,
|
||||
keep_separator=False,
|
||||
strip_whitespace=True,
|
||||
)
|
||||
logger.info(
|
||||
"已初始化 LangChain 文本切分器: chunk_size=%d overlap=%d",
|
||||
chunk_size,
|
||||
overlap,
|
||||
)
|
||||
return splitter
|
||||
|
||||
def _legacy_split(self, text: str) -> List[str]:
|
||||
"""内置切分策略,作为 LangChain 缺失时的后备方案。"""
|
||||
chunk_size = settings.vector_chunk_size
|
||||
overlap = min(settings.vector_chunk_overlap, chunk_size // 2)
|
||||
|
||||
chunks: List[str] = []
|
||||
start = 0
|
||||
total_length = len(text)
|
||||
|
||||
while start < total_length:
|
||||
end = min(total_length, start + chunk_size)
|
||||
segment = text[start:end]
|
||||
|
||||
split_offset = self._find_split_offset(segment)
|
||||
if split_offset is not None and start + split_offset < total_length:
|
||||
end = start + split_offset
|
||||
segment = text[start:end]
|
||||
|
||||
chunk_text = segment.strip()
|
||||
if chunk_text:
|
||||
chunks.append(chunk_text)
|
||||
|
||||
if end >= total_length:
|
||||
break
|
||||
start = max(0, end - overlap)
|
||||
|
||||
logger.debug(
|
||||
"使用内置策略完成章节切分: count=%d chunk_size=%d overlap=%d",
|
||||
len(chunks),
|
||||
chunk_size,
|
||||
overlap,
|
||||
)
|
||||
return chunks
|
||||
|
||||
|
||||
__all__ = ["ChapterIngestionService"]
|
||||
49
backend/app/services/config_service.py
Normal file
49
backend/app/services/config_service.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..repositories.system_config_repository import SystemConfigRepository
|
||||
from ..models import SystemConfig
|
||||
from ..schemas.config import SystemConfigCreate, SystemConfigRead, SystemConfigUpdate
|
||||
|
||||
|
||||
class ConfigService:
|
||||
"""系统配置服务:提供 CRUD 接口,并负责转换 Pydantic 模型。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = SystemConfigRepository(session)
|
||||
|
||||
async def list_configs(self) -> list[SystemConfigRead]:
|
||||
configs = await self.repo.list_all()
|
||||
return [SystemConfigRead.model_validate(cfg) for cfg in configs]
|
||||
|
||||
async def get_config(self, key: str) -> Optional[SystemConfigRead]:
|
||||
config = await self.repo.get_by_key(key)
|
||||
return SystemConfigRead.model_validate(config) if config else None
|
||||
|
||||
async def upsert_config(self, payload: SystemConfigCreate) -> SystemConfigRead:
|
||||
instance = await self.repo.get_by_key(payload.key)
|
||||
if instance:
|
||||
await self.repo.update_fields(instance, value=payload.value, description=payload.description)
|
||||
else:
|
||||
instance = SystemConfig(**payload.model_dump())
|
||||
await self.repo.add(instance)
|
||||
await self.session.commit()
|
||||
return SystemConfigRead.model_validate(instance)
|
||||
|
||||
async def patch_config(self, key: str, payload: SystemConfigUpdate) -> Optional[SystemConfigRead]:
|
||||
instance = await self.repo.get_by_key(key)
|
||||
if not instance:
|
||||
return None
|
||||
await self.repo.update_fields(instance, **payload.model_dump(exclude_unset=True))
|
||||
await self.session.commit()
|
||||
return SystemConfigRead.model_validate(instance)
|
||||
|
||||
async def remove_config(self, key: str) -> bool:
|
||||
instance = await self.repo.get_by_key(key)
|
||||
if not instance:
|
||||
return False
|
||||
await self.repo.delete(instance)
|
||||
await self.session.commit()
|
||||
return True
|
||||
41
backend/app/services/llm_config_service.py
Normal file
41
backend/app/services/llm_config_service.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import LLMConfig
|
||||
from ..repositories.llm_config_repository import LLMConfigRepository
|
||||
from ..schemas.llm_config import LLMConfigCreate, LLMConfigRead
|
||||
|
||||
|
||||
class LLMConfigService:
|
||||
"""用户自定义 LLM 配置服务。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = LLMConfigRepository(session)
|
||||
|
||||
async def upsert_config(self, user_id: int, payload: LLMConfigCreate) -> LLMConfigRead:
|
||||
instance = await self.repo.get_by_user(user_id)
|
||||
data = payload.model_dump(exclude_unset=True)
|
||||
if "llm_provider_url" in data and data["llm_provider_url"] is not None:
|
||||
# HttpUrl 类型在 sqlite 中无法直接写入,需要提前转为字符串
|
||||
data["llm_provider_url"] = str(data["llm_provider_url"])
|
||||
if instance:
|
||||
await self.repo.update_fields(instance, **data)
|
||||
else:
|
||||
instance = LLMConfig(user_id=user_id, **data)
|
||||
await self.repo.add(instance)
|
||||
await self.session.commit()
|
||||
return LLMConfigRead.model_validate(instance)
|
||||
|
||||
async def get_config(self, user_id: int) -> Optional[LLMConfigRead]:
|
||||
instance = await self.repo.get_by_user(user_id)
|
||||
return LLMConfigRead.model_validate(instance) if instance else None
|
||||
|
||||
async def delete_config(self, user_id: int) -> bool:
|
||||
instance = await self.repo.get_by_user(user_id)
|
||||
if not instance:
|
||||
return False
|
||||
await self.repo.delete(instance)
|
||||
await self.session.commit()
|
||||
return True
|
||||
306
backend/app/services/llm_service.py
Normal file
306
backend/app/services/llm_service.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, status
|
||||
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, InternalServerError
|
||||
|
||||
from ..core.config import settings
|
||||
from ..repositories.llm_config_repository import LLMConfigRepository
|
||||
from ..repositories.system_config_repository import SystemConfigRepository
|
||||
from ..repositories.user_repository import UserRepository
|
||||
from ..services.admin_setting_service import AdminSettingService
|
||||
from ..services.prompt_service import PromptService
|
||||
from ..services.usage_service import UsageService
|
||||
from ..utils.llm_tool import ChatMessage, LLMClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try: # pragma: no cover - 运行环境未安装时兼容
|
||||
from ollama import AsyncClient as OllamaAsyncClient
|
||||
except ImportError: # pragma: no cover - Ollama 为可选依赖
|
||||
OllamaAsyncClient = None
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""封装与大模型交互的所有逻辑,包括配额控制与配置选择。"""
|
||||
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
self.llm_repo = LLMConfigRepository(session)
|
||||
self.system_config_repo = SystemConfigRepository(session)
|
||||
self.user_repo = UserRepository(session)
|
||||
self.admin_setting_service = AdminSettingService(session)
|
||||
self.usage_service = UsageService(session)
|
||||
self._embedding_dimensions: Dict[str, int] = {}
|
||||
|
||||
async def get_llm_response(
|
||||
self,
|
||||
system_prompt: str,
|
||||
conversation_history: List[Dict[str, str]],
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
user_id: Optional[int] = None,
|
||||
timeout: float = 300.0,
|
||||
response_format: Optional[str] = "json_object",
|
||||
) -> str:
|
||||
messages = [{"role": "system", "content": system_prompt}, *conversation_history]
|
||||
return await self._stream_and_collect(
|
||||
messages,
|
||||
temperature=temperature,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
async def get_summary(
|
||||
self,
|
||||
chapter_content: str,
|
||||
*,
|
||||
temperature: float = 0.2,
|
||||
user_id: Optional[int] = None,
|
||||
timeout: float = 180.0,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> str:
|
||||
if not system_prompt:
|
||||
prompt_service = PromptService(self.session)
|
||||
system_prompt = await prompt_service.get_prompt("extraction")
|
||||
if not system_prompt:
|
||||
raise HTTPException(status_code=500, detail="未配置摘要提示词")
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": chapter_content},
|
||||
]
|
||||
return await self._stream_and_collect(messages, temperature=temperature, user_id=user_id, timeout=timeout)
|
||||
|
||||
async def _stream_and_collect(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
*,
|
||||
temperature: float,
|
||||
user_id: Optional[int],
|
||||
timeout: float,
|
||||
response_format: Optional[str] = None,
|
||||
) -> str:
|
||||
config = await self._resolve_llm_config(user_id)
|
||||
client = LLMClient(api_key=config["api_key"], base_url=config.get("base_url"))
|
||||
|
||||
chat_messages = [ChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
|
||||
|
||||
full_response = ""
|
||||
finish_reason = None
|
||||
|
||||
logger.info(
|
||||
"Streaming LLM response: model=%s user_id=%s messages=%d",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
len(messages),
|
||||
)
|
||||
|
||||
try:
|
||||
async for part in client.stream_chat(
|
||||
messages=chat_messages,
|
||||
model=config.get("model"),
|
||||
temperature=temperature,
|
||||
timeout=int(timeout),
|
||||
response_format=response_format,
|
||||
):
|
||||
if part.get("content"):
|
||||
full_response += part["content"]
|
||||
if part.get("finish_reason"):
|
||||
finish_reason = part["finish_reason"]
|
||||
except InternalServerError as exc:
|
||||
detail = "AI 服务内部错误,请稍后重试"
|
||||
response = getattr(exc, "response", None)
|
||||
if response is not None:
|
||||
try:
|
||||
payload = response.json()
|
||||
error_data = payload.get("error", {}) if isinstance(payload, dict) else {}
|
||||
detail = error_data.get("message_zh") or error_data.get("message") or detail
|
||||
except Exception:
|
||||
detail = str(exc) or detail
|
||||
else:
|
||||
detail = str(exc) or detail
|
||||
logger.error(
|
||||
"LLM stream internal error: model=%s user_id=%s detail=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
detail,
|
||||
exc_info=exc,
|
||||
)
|
||||
raise HTTPException(status_code=503, detail=detail)
|
||||
except (httpx.RemoteProtocolError, httpx.ReadTimeout, APIConnectionError, APITimeoutError) as exc:
|
||||
if isinstance(exc, httpx.RemoteProtocolError):
|
||||
detail = "AI 服务连接被意外中断,请稍后重试"
|
||||
elif isinstance(exc, (httpx.ReadTimeout, APITimeoutError)):
|
||||
detail = "AI 服务响应超时,请稍后重试"
|
||||
else:
|
||||
detail = "无法连接到 AI 服务,请稍后重试"
|
||||
logger.error(
|
||||
"LLM stream failed: model=%s user_id=%s detail=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
detail,
|
||||
exc_info=exc,
|
||||
)
|
||||
raise HTTPException(status_code=503, detail=detail) from exc
|
||||
|
||||
logger.debug(
|
||||
"LLM response collected: model=%s user_id=%s finish_reason=%s preview=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
finish_reason,
|
||||
full_response[:500],
|
||||
)
|
||||
|
||||
if finish_reason == "length":
|
||||
logger.warning(
|
||||
"LLM response truncated: model=%s user_id=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="AI 响应被截断,请缩短输入或调整参数")
|
||||
|
||||
if not full_response:
|
||||
logger.error(
|
||||
"LLM returned empty response: model=%s user_id=%s",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="AI 未返回有效内容")
|
||||
|
||||
await self.usage_service.increment("api_request_count")
|
||||
logger.info(
|
||||
"LLM response success: model=%s user_id=%s chars=%d",
|
||||
config.get("model"),
|
||||
user_id,
|
||||
len(full_response),
|
||||
)
|
||||
return full_response
|
||||
|
||||
async def _resolve_llm_config(self, user_id: Optional[int]) -> Dict[str, Optional[str]]:
|
||||
if user_id:
|
||||
config = await self.llm_repo.get_by_user(user_id)
|
||||
if config and config.llm_provider_api_key:
|
||||
return {
|
||||
"api_key": config.llm_provider_api_key,
|
||||
"base_url": config.llm_provider_url,
|
||||
"model": config.llm_provider_model,
|
||||
}
|
||||
|
||||
# 检查每日使用次数限制
|
||||
if user_id:
|
||||
await self._enforce_daily_limit(user_id)
|
||||
|
||||
api_key = await self._get_config_value("llm.api_key")
|
||||
base_url = await self._get_config_value("llm.base_url")
|
||||
model = await self._get_config_value("llm.model")
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=500, detail="未配置默认 LLM API Key")
|
||||
|
||||
return {"api_key": api_key, "base_url": base_url, "model": model}
|
||||
|
||||
async def get_embedding(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
user_id: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> List[float]:
|
||||
"""生成文本向量,用于章节 RAG 检索,支持 openai 与 ollama 双提供方。"""
|
||||
provider = settings.embedding_provider
|
||||
target_model = model or (
|
||||
settings.ollama_embedding_model if provider == "ollama" else settings.embedding_model
|
||||
)
|
||||
|
||||
if provider == "ollama":
|
||||
if OllamaAsyncClient is None:
|
||||
logger.error("未安装 ollama 依赖,无法调用本地嵌入模型。")
|
||||
raise HTTPException(status_code=500, detail="缺少 Ollama 依赖,请先安装 ollama 包。")
|
||||
|
||||
base_url_any = settings.ollama_embedding_base_url or settings.embedding_base_url
|
||||
base_url = str(base_url_any) if base_url_any else None
|
||||
client = OllamaAsyncClient(host=base_url)
|
||||
try:
|
||||
response = await client.embeddings(model=target_model, prompt=text)
|
||||
except Exception as exc: # pragma: no cover - 本地服务调用失败
|
||||
logger.warning(
|
||||
"Ollama 嵌入请求失败: model=%s error=%s",
|
||||
target_model,
|
||||
exc,
|
||||
)
|
||||
return []
|
||||
embedding: Optional[List[float]]
|
||||
if isinstance(response, dict):
|
||||
embedding = response.get("embedding")
|
||||
else:
|
||||
embedding = getattr(response, "embedding", None)
|
||||
if not embedding:
|
||||
logger.warning("Ollama 返回空向量: model=%s", target_model)
|
||||
return []
|
||||
if not isinstance(embedding, list):
|
||||
embedding = list(embedding)
|
||||
else:
|
||||
config = await self._resolve_llm_config(user_id)
|
||||
api_key = settings.embedding_api_key or config["api_key"]
|
||||
base_url_setting = settings.embedding_base_url or config.get("base_url")
|
||||
base_url = str(base_url_setting) if base_url_setting else None
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
try:
|
||||
response = await client.embeddings.create(
|
||||
input=text,
|
||||
model=target_model,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 网络或鉴权失败
|
||||
logger.warning(
|
||||
"OpenAI 嵌入请求失败: model=%s user_id=%s error=%s",
|
||||
target_model,
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
return []
|
||||
if not response.data:
|
||||
logger.warning("OpenAI 嵌入请求返回空数据: model=%s user_id=%s", target_model, user_id)
|
||||
return []
|
||||
embedding = response.data[0].embedding
|
||||
|
||||
if not isinstance(embedding, list):
|
||||
embedding = list(embedding)
|
||||
|
||||
dimension = len(embedding)
|
||||
if not dimension and settings.embedding_model_vector_size:
|
||||
dimension = settings.embedding_model_vector_size
|
||||
if dimension:
|
||||
self._embedding_dimensions[target_model] = dimension
|
||||
return embedding
|
||||
|
||||
def get_embedding_dimension(self, model: Optional[str] = None) -> Optional[int]:
|
||||
"""获取嵌入向量维度,优先返回缓存结果,其次读取配置。"""
|
||||
target_model = model or (
|
||||
settings.ollama_embedding_model if settings.embedding_provider == "ollama" else settings.embedding_model
|
||||
)
|
||||
if target_model in self._embedding_dimensions:
|
||||
return self._embedding_dimensions[target_model]
|
||||
return settings.embedding_model_vector_size
|
||||
|
||||
async def _enforce_daily_limit(self, user_id: int) -> None:
|
||||
limit_str = await self.admin_setting_service.get("daily_request_limit", "100")
|
||||
limit = int(limit_str or 10)
|
||||
used = await self.user_repo.get_daily_request(user_id)
|
||||
if used >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="今日请求次数已达上限,请明日再试或设置自定义 API Key。",
|
||||
)
|
||||
await self.user_repo.increment_daily_request(user_id)
|
||||
await self.session.commit()
|
||||
|
||||
async def _get_config_value(self, key: str) -> Optional[str]:
|
||||
record = await self.system_config_repo.get_by_key(key)
|
||||
if record:
|
||||
return record.value
|
||||
# 兼容环境变量,首次迁移时无需立即写入数据库
|
||||
env_key = key.upper().replace(".", "_")
|
||||
return os.getenv(env_key)
|
||||
700
backend/app/services/novel_service.py
Normal file
700
backend/app/services/novel_service.py
Normal file
@@ -0,0 +1,700 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
_PREFERRED_CONTENT_KEYS: tuple[str, ...] = (
|
||||
"content",
|
||||
"chapter_content",
|
||||
"chapter_text",
|
||||
"full_content",
|
||||
"text",
|
||||
"body",
|
||||
"story",
|
||||
"chapter",
|
||||
"real_summary",
|
||||
"summary",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_version_content(raw_content: Any, metadata: Any) -> str:
|
||||
text = _coerce_text(metadata)
|
||||
if not text:
|
||||
text = _coerce_text(raw_content)
|
||||
return text or ""
|
||||
|
||||
|
||||
def _coerce_text(value: Any) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return _clean_string(value)
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
if isinstance(value, dict):
|
||||
for key in _PREFERRED_CONTENT_KEYS:
|
||||
if key in value and value[key]:
|
||||
nested = _coerce_text(value[key])
|
||||
if nested:
|
||||
return nested
|
||||
return _clean_string(json.dumps(value, ensure_ascii=False))
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
parts = [text for text in (_coerce_text(item) for item in value) if text]
|
||||
if parts:
|
||||
return "\n".join(parts)
|
||||
return None
|
||||
return _clean_string(str(value))
|
||||
|
||||
|
||||
def _clean_string(text: str) -> str:
|
||||
stripped = text.strip()
|
||||
if not stripped:
|
||||
return stripped
|
||||
if stripped.startswith("{") and stripped.endswith("}"):
|
||||
try:
|
||||
parsed = json.loads(stripped)
|
||||
coerced = _coerce_text(parsed)
|
||||
if coerced:
|
||||
return coerced
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if stripped.startswith('"') and stripped.endswith('"') and len(stripped) >= 2:
|
||||
stripped = stripped[1:-1]
|
||||
return (
|
||||
stripped.replace("\\n", "\n")
|
||||
.replace("\\t", "\t")
|
||||
.replace('\\"', '"')
|
||||
.replace("\\\\", "\\")
|
||||
)
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import (
|
||||
BlueprintCharacter,
|
||||
BlueprintRelationship,
|
||||
Chapter,
|
||||
ChapterEvaluation,
|
||||
ChapterOutline,
|
||||
ChapterVersion,
|
||||
NovelBlueprint,
|
||||
NovelConversation,
|
||||
NovelProject,
|
||||
)
|
||||
from ..repositories.novel_repository import NovelRepository
|
||||
from ..schemas.admin import AdminNovelSummary
|
||||
from ..schemas.novel import (
|
||||
Blueprint,
|
||||
Chapter as ChapterSchema,
|
||||
ChapterGenerationStatus,
|
||||
ChapterOutline as ChapterOutlineSchema,
|
||||
NovelProject as NovelProjectSchema,
|
||||
NovelProjectSummary,
|
||||
NovelSectionResponse,
|
||||
NovelSectionType,
|
||||
)
|
||||
|
||||
|
||||
class NovelService:
|
||||
"""小说项目服务,基于拆表后的结构提供聚合与业务操作。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = NovelRepository(session)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 项目与摘要
|
||||
# ------------------------------------------------------------------
|
||||
async def create_project(self, user_id: int, title: str, initial_prompt: str) -> NovelProject:
|
||||
project = NovelProject(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
initial_prompt=initial_prompt,
|
||||
)
|
||||
blueprint = NovelBlueprint(project=project)
|
||||
self.session.add_all([project, blueprint])
|
||||
await self.session.commit()
|
||||
await self.session.refresh(project)
|
||||
return project
|
||||
|
||||
async def ensure_project_owner(self, project_id: str, user_id: int) -> NovelProject:
|
||||
project = await self.repo.get_by_id(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
|
||||
if project.user_id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问该项目")
|
||||
return project
|
||||
|
||||
async def get_project_schema(self, project_id: str, user_id: int) -> NovelProjectSchema:
|
||||
project = await self.ensure_project_owner(project_id, user_id)
|
||||
return await self._serialize_project(project)
|
||||
|
||||
async def get_section_data(
|
||||
self,
|
||||
project_id: str,
|
||||
user_id: int,
|
||||
section: NovelSectionType,
|
||||
) -> NovelSectionResponse:
|
||||
project = await self.ensure_project_owner(project_id, user_id)
|
||||
return self._build_section_response(project, section)
|
||||
|
||||
async def get_chapter_schema(
|
||||
self,
|
||||
project_id: str,
|
||||
user_id: int,
|
||||
chapter_number: int,
|
||||
) -> ChapterSchema:
|
||||
project = await self.ensure_project_owner(project_id, user_id)
|
||||
return self._build_chapter_schema(project, chapter_number)
|
||||
|
||||
async def list_projects_for_user(self, user_id: int) -> List[NovelProjectSummary]:
|
||||
projects = await self.repo.list_by_user(user_id)
|
||||
summaries: List[NovelProjectSummary] = []
|
||||
for project in projects:
|
||||
blueprint = project.blueprint
|
||||
genre = blueprint.genre if blueprint and blueprint.genre else "未知"
|
||||
outlines = project.outlines
|
||||
chapters = project.chapters
|
||||
total = len(outlines) or len(chapters)
|
||||
completed = sum(1 for chapter in chapters if chapter.selected_version_id)
|
||||
summaries.append(
|
||||
NovelProjectSummary(
|
||||
id=project.id,
|
||||
title=project.title,
|
||||
genre=genre,
|
||||
last_edited=project.updated_at.isoformat() if project.updated_at else "未知",
|
||||
completed_chapters=completed,
|
||||
total_chapters=total,
|
||||
)
|
||||
)
|
||||
return summaries
|
||||
|
||||
async def list_projects_for_admin(self) -> List[AdminNovelSummary]:
|
||||
projects = await self.repo.list_all()
|
||||
summaries: List[AdminNovelSummary] = []
|
||||
for project in projects:
|
||||
blueprint = project.blueprint
|
||||
genre = blueprint.genre if blueprint and blueprint.genre else "未知"
|
||||
outlines = project.outlines
|
||||
chapters = project.chapters
|
||||
total = len(outlines) or len(chapters)
|
||||
completed = sum(1 for chapter in chapters if chapter.selected_version_id)
|
||||
owner = project.owner
|
||||
summaries.append(
|
||||
AdminNovelSummary(
|
||||
id=project.id,
|
||||
title=project.title,
|
||||
owner_id=owner.id if owner else 0,
|
||||
owner_username=owner.username if owner else "未知",
|
||||
genre=genre,
|
||||
last_edited=project.updated_at.isoformat() if project.updated_at else "",
|
||||
completed_chapters=completed,
|
||||
total_chapters=total,
|
||||
)
|
||||
)
|
||||
return summaries
|
||||
|
||||
async def delete_projects(self, project_ids: List[str], user_id: int) -> None:
|
||||
for pid in project_ids:
|
||||
project = await self.ensure_project_owner(pid, user_id)
|
||||
await self.repo.delete(project)
|
||||
await self.session.commit()
|
||||
|
||||
async def count_projects(self) -> int:
|
||||
result = await self.session.execute(select(func.count(NovelProject.id)))
|
||||
return result.scalar_one()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 对话管理
|
||||
# ------------------------------------------------------------------
|
||||
async def list_conversations(self, project_id: str) -> List[NovelConversation]:
|
||||
stmt = (
|
||||
select(NovelConversation)
|
||||
.where(NovelConversation.project_id == project_id)
|
||||
.order_by(NovelConversation.seq.asc())
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars())
|
||||
|
||||
async def append_conversation(self, project_id: str, role: str, content: str, metadata: Optional[Dict] = None) -> None:
|
||||
result = await self.session.execute(
|
||||
select(func.max(NovelConversation.seq)).where(NovelConversation.project_id == project_id)
|
||||
)
|
||||
current_max = result.scalar()
|
||||
next_seq = (current_max or 0) + 1
|
||||
convo = NovelConversation(
|
||||
project_id=project_id,
|
||||
seq=next_seq,
|
||||
role=role,
|
||||
content=content,
|
||||
metadata=metadata,
|
||||
)
|
||||
self.session.add(convo)
|
||||
await self.session.commit()
|
||||
await self._touch_project(project_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 蓝图管理
|
||||
# ------------------------------------------------------------------
|
||||
async def replace_blueprint(self, project_id: str, blueprint: Blueprint) -> None:
|
||||
record = await self.session.get(NovelBlueprint, project_id)
|
||||
if not record:
|
||||
record = NovelBlueprint(project_id=project_id)
|
||||
self.session.add(record)
|
||||
record.title = blueprint.title
|
||||
record.target_audience = blueprint.target_audience
|
||||
record.genre = blueprint.genre
|
||||
record.style = blueprint.style
|
||||
record.tone = blueprint.tone
|
||||
record.one_sentence_summary = blueprint.one_sentence_summary
|
||||
record.full_synopsis = blueprint.full_synopsis
|
||||
record.world_setting = blueprint.world_setting
|
||||
|
||||
await self.session.execute(delete(BlueprintCharacter).where(BlueprintCharacter.project_id == project_id))
|
||||
for index, data in enumerate(blueprint.characters):
|
||||
self.session.add(
|
||||
BlueprintCharacter(
|
||||
project_id=project_id,
|
||||
name=data.get("name", ""),
|
||||
identity=data.get("identity"),
|
||||
personality=data.get("personality"),
|
||||
goals=data.get("goals"),
|
||||
abilities=data.get("abilities"),
|
||||
relationship_to_protagonist=data.get("relationship_to_protagonist"),
|
||||
extra={k: v for k, v in data.items() if k not in {
|
||||
"name",
|
||||
"identity",
|
||||
"personality",
|
||||
"goals",
|
||||
"abilities",
|
||||
"relationship_to_protagonist",
|
||||
}},
|
||||
position=index,
|
||||
)
|
||||
)
|
||||
|
||||
await self.session.execute(delete(BlueprintRelationship).where(BlueprintRelationship.project_id == project_id))
|
||||
for index, relation in enumerate(blueprint.relationships):
|
||||
self.session.add(
|
||||
BlueprintRelationship(
|
||||
project_id=project_id,
|
||||
character_from=relation.character_from,
|
||||
character_to=relation.character_to,
|
||||
description=relation.description,
|
||||
position=index,
|
||||
)
|
||||
)
|
||||
|
||||
await self.session.execute(delete(ChapterOutline).where(ChapterOutline.project_id == project_id))
|
||||
for outline in blueprint.chapter_outline:
|
||||
self.session.add(
|
||||
ChapterOutline(
|
||||
project_id=project_id,
|
||||
chapter_number=outline.chapter_number,
|
||||
title=outline.title,
|
||||
summary=outline.summary,
|
||||
)
|
||||
)
|
||||
|
||||
await self.session.commit()
|
||||
await self._touch_project(project_id)
|
||||
|
||||
async def patch_blueprint(self, project_id: str, patch: Dict) -> None:
|
||||
blueprint = await self.session.get(NovelBlueprint, project_id)
|
||||
if not blueprint:
|
||||
blueprint = NovelBlueprint(project_id=project_id)
|
||||
self.session.add(blueprint)
|
||||
|
||||
if "one_sentence_summary" in patch:
|
||||
blueprint.one_sentence_summary = patch["one_sentence_summary"]
|
||||
if "full_synopsis" in patch:
|
||||
blueprint.full_synopsis = patch["full_synopsis"]
|
||||
if "world_setting" in patch and patch["world_setting"] is not None:
|
||||
existing = blueprint.world_setting or {}
|
||||
existing.update(patch["world_setting"])
|
||||
blueprint.world_setting = existing
|
||||
if "characters" in patch and patch["characters"] is not None:
|
||||
await self.session.execute(delete(BlueprintCharacter).where(BlueprintCharacter.project_id == project_id))
|
||||
for index, data in enumerate(patch["characters"]):
|
||||
self.session.add(
|
||||
BlueprintCharacter(
|
||||
project_id=project_id,
|
||||
name=data.get("name", ""),
|
||||
identity=data.get("identity"),
|
||||
personality=data.get("personality"),
|
||||
goals=data.get("goals"),
|
||||
abilities=data.get("abilities"),
|
||||
relationship_to_protagonist=data.get("relationship_to_protagonist"),
|
||||
extra={k: v for k, v in data.items() if k not in {
|
||||
"name",
|
||||
"identity",
|
||||
"personality",
|
||||
"goals",
|
||||
"abilities",
|
||||
"relationship_to_protagonist",
|
||||
}},
|
||||
position=index,
|
||||
)
|
||||
)
|
||||
if "relationships" in patch and patch["relationships"] is not None:
|
||||
await self.session.execute(delete(BlueprintRelationship).where(BlueprintRelationship.project_id == project_id))
|
||||
for index, relation in enumerate(patch["relationships"]):
|
||||
self.session.add(
|
||||
BlueprintRelationship(
|
||||
project_id=project_id,
|
||||
character_from=relation.get("character_from"),
|
||||
character_to=relation.get("character_to"),
|
||||
description=relation.get("description"),
|
||||
position=index,
|
||||
)
|
||||
)
|
||||
if "chapter_outline" in patch and patch["chapter_outline"] is not None:
|
||||
await self.session.execute(delete(ChapterOutline).where(ChapterOutline.project_id == project_id))
|
||||
for outline in patch["chapter_outline"]:
|
||||
self.session.add(
|
||||
ChapterOutline(
|
||||
project_id=project_id,
|
||||
chapter_number=outline.get("chapter_number"),
|
||||
title=outline.get("title", ""),
|
||||
summary=outline.get("summary"),
|
||||
)
|
||||
)
|
||||
await self.session.commit()
|
||||
await self._touch_project(project_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 章节与版本
|
||||
# ------------------------------------------------------------------
|
||||
async def get_outline(self, project_id: str, chapter_number: int) -> Optional[ChapterOutline]:
|
||||
stmt = (
|
||||
select(ChapterOutline)
|
||||
.where(
|
||||
ChapterOutline.project_id == project_id,
|
||||
ChapterOutline.chapter_number == chapter_number,
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_or_create_chapter(self, project_id: str, chapter_number: int) -> Chapter:
|
||||
stmt = (
|
||||
select(Chapter)
|
||||
.where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.chapter_number == chapter_number,
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
chapter = result.scalars().first()
|
||||
if chapter:
|
||||
return chapter
|
||||
chapter = Chapter(project_id=project_id, chapter_number=chapter_number)
|
||||
self.session.add(chapter)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(chapter)
|
||||
return chapter
|
||||
|
||||
async def replace_chapter_versions(self, chapter: Chapter, contents: List[str], metadata: Optional[List[Dict]] = None) -> List[ChapterVersion]:
|
||||
await self.session.execute(delete(ChapterVersion).where(ChapterVersion.chapter_id == chapter.id))
|
||||
versions: List[ChapterVersion] = []
|
||||
for index, content in enumerate(contents):
|
||||
extra = metadata[index] if metadata and index < len(metadata) else None
|
||||
text_content = _normalize_version_content(content, extra)
|
||||
version = ChapterVersion(
|
||||
chapter_id=chapter.id,
|
||||
content=text_content,
|
||||
metadata=None,
|
||||
version_label=f"v{index+1}",
|
||||
)
|
||||
self.session.add(version)
|
||||
versions.append(version)
|
||||
chapter.status = ChapterGenerationStatus.WAITING_FOR_CONFIRM.value
|
||||
await self.session.commit()
|
||||
await self.session.refresh(chapter)
|
||||
await self._touch_project(chapter.project_id)
|
||||
return versions
|
||||
|
||||
async def select_chapter_version(self, chapter: Chapter, version_index: int) -> ChapterVersion:
|
||||
versions = sorted(chapter.versions, key=lambda item: item.created_at)
|
||||
if not versions or version_index < 0 or version_index >= len(versions):
|
||||
raise HTTPException(status_code=400, detail="版本索引无效")
|
||||
selected = versions[version_index]
|
||||
chapter.selected_version_id = selected.id
|
||||
chapter.status = ChapterGenerationStatus.SUCCESSFUL.value
|
||||
chapter.word_count = len(selected.content or "")
|
||||
await self.session.commit()
|
||||
await self.session.refresh(chapter)
|
||||
await self._touch_project(chapter.project_id)
|
||||
return selected
|
||||
|
||||
async def add_chapter_evaluation(self, chapter: Chapter, version: Optional[ChapterVersion], feedback: str, decision: Optional[str] = None) -> None:
|
||||
evaluation = ChapterEvaluation(
|
||||
chapter_id=chapter.id,
|
||||
version_id=version.id if version else None,
|
||||
feedback=feedback,
|
||||
decision=decision,
|
||||
)
|
||||
self.session.add(evaluation)
|
||||
chapter.status = ChapterGenerationStatus.WAITING_FOR_CONFIRM.value
|
||||
await self.session.commit()
|
||||
await self.session.refresh(chapter)
|
||||
await self._touch_project(chapter.project_id)
|
||||
|
||||
async def delete_chapters(self, project_id: str, chapter_numbers: Iterable[int]) -> None:
|
||||
await self.session.execute(
|
||||
delete(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.chapter_number.in_(list(chapter_numbers)),
|
||||
)
|
||||
)
|
||||
await self.session.execute(
|
||||
delete(ChapterOutline).where(
|
||||
ChapterOutline.project_id == project_id,
|
||||
ChapterOutline.chapter_number.in_(list(chapter_numbers)),
|
||||
)
|
||||
)
|
||||
await self.session.commit()
|
||||
await self._touch_project(project_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 序列化辅助
|
||||
# ------------------------------------------------------------------
|
||||
async def get_project_schema_for_admin(self, project_id: str) -> NovelProjectSchema:
|
||||
project = await self.repo.get_by_id(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
|
||||
return await self._serialize_project(project)
|
||||
|
||||
async def get_section_data_for_admin(
|
||||
self,
|
||||
project_id: str,
|
||||
section: NovelSectionType,
|
||||
) -> NovelSectionResponse:
|
||||
project = await self.repo.get_by_id(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
|
||||
return self._build_section_response(project, section)
|
||||
|
||||
async def get_chapter_schema_for_admin(
|
||||
self,
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
) -> ChapterSchema:
|
||||
project = await self.repo.get_by_id(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
|
||||
return self._build_chapter_schema(project, chapter_number)
|
||||
|
||||
async def _serialize_project(self, project: NovelProject) -> NovelProjectSchema:
|
||||
conversations = [
|
||||
{"role": convo.role, "content": convo.content}
|
||||
for convo in sorted(project.conversations, key=lambda c: c.seq)
|
||||
]
|
||||
|
||||
blueprint_schema = self._build_blueprint_schema(project)
|
||||
|
||||
outlines_map = {outline.chapter_number: outline for outline in project.outlines}
|
||||
chapters_map = {chapter.chapter_number: chapter for chapter in project.chapters}
|
||||
chapter_numbers = sorted(set(outlines_map.keys()) | set(chapters_map.keys()))
|
||||
chapters_schema: List[ChapterSchema] = [
|
||||
self._build_chapter_schema(
|
||||
project,
|
||||
number,
|
||||
outlines_map=outlines_map,
|
||||
chapters_map=chapters_map,
|
||||
)
|
||||
for number in chapter_numbers
|
||||
]
|
||||
|
||||
return NovelProjectSchema(
|
||||
id=project.id,
|
||||
user_id=project.user_id,
|
||||
title=project.title,
|
||||
initial_prompt=project.initial_prompt or "",
|
||||
conversation_history=conversations,
|
||||
blueprint=blueprint_schema,
|
||||
chapters=chapters_schema,
|
||||
)
|
||||
|
||||
async def _touch_project(self, project_id: str) -> None:
|
||||
await self.session.execute(
|
||||
update(NovelProject)
|
||||
.where(NovelProject.id == project_id)
|
||||
.values(updated_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await self.session.commit()
|
||||
|
||||
def _build_blueprint_schema(self, project: NovelProject) -> Blueprint:
|
||||
blueprint_obj = project.blueprint
|
||||
if blueprint_obj:
|
||||
return Blueprint(
|
||||
title=blueprint_obj.title or "",
|
||||
target_audience=blueprint_obj.target_audience or "",
|
||||
genre=blueprint_obj.genre or "",
|
||||
style=blueprint_obj.style or "",
|
||||
tone=blueprint_obj.tone or "",
|
||||
one_sentence_summary=blueprint_obj.one_sentence_summary or "",
|
||||
full_synopsis=blueprint_obj.full_synopsis or "",
|
||||
world_setting=blueprint_obj.world_setting or {},
|
||||
characters=[
|
||||
{
|
||||
"name": character.name,
|
||||
"identity": character.identity,
|
||||
"personality": character.personality,
|
||||
"goals": character.goals,
|
||||
"abilities": character.abilities,
|
||||
"relationship_to_protagonist": character.relationship_to_protagonist,
|
||||
**(character.extra or {}),
|
||||
}
|
||||
for character in sorted(project.characters, key=lambda c: c.position)
|
||||
],
|
||||
relationships=[
|
||||
{
|
||||
"character_from": relation.character_from,
|
||||
"character_to": relation.character_to,
|
||||
"description": relation.description or "",
|
||||
"relationship_type": getattr(relation, "relationship_type", None),
|
||||
}
|
||||
for relation in sorted(project.relationships_, key=lambda r: r.position)
|
||||
],
|
||||
chapter_outline=[
|
||||
ChapterOutlineSchema(
|
||||
chapter_number=outline.chapter_number,
|
||||
title=outline.title,
|
||||
summary=outline.summary or "",
|
||||
)
|
||||
for outline in sorted(project.outlines, key=lambda o: o.chapter_number)
|
||||
],
|
||||
)
|
||||
return Blueprint(
|
||||
title="",
|
||||
target_audience="",
|
||||
genre="",
|
||||
style="",
|
||||
tone="",
|
||||
one_sentence_summary="",
|
||||
full_synopsis="",
|
||||
world_setting={},
|
||||
characters=[],
|
||||
relationships=[],
|
||||
chapter_outline=[],
|
||||
)
|
||||
|
||||
def _build_section_response(
|
||||
self,
|
||||
project: NovelProject,
|
||||
section: NovelSectionType,
|
||||
) -> NovelSectionResponse:
|
||||
blueprint = self._build_blueprint_schema(project)
|
||||
|
||||
if section == NovelSectionType.OVERVIEW:
|
||||
data = {
|
||||
"title": project.title,
|
||||
"initial_prompt": project.initial_prompt or "",
|
||||
"status": project.status,
|
||||
"one_sentence_summary": blueprint.one_sentence_summary,
|
||||
"target_audience": blueprint.target_audience,
|
||||
"genre": blueprint.genre,
|
||||
"style": blueprint.style,
|
||||
"tone": blueprint.tone,
|
||||
"full_synopsis": blueprint.full_synopsis,
|
||||
"updated_at": project.updated_at.isoformat() if project.updated_at else None,
|
||||
}
|
||||
elif section == NovelSectionType.WORLD_SETTING:
|
||||
data = {
|
||||
"world_setting": blueprint.world_setting or {},
|
||||
}
|
||||
elif section == NovelSectionType.CHARACTERS:
|
||||
data = {
|
||||
"characters": blueprint.characters,
|
||||
}
|
||||
elif section == NovelSectionType.RELATIONSHIPS:
|
||||
data = {
|
||||
"relationships": blueprint.relationships,
|
||||
}
|
||||
elif section == NovelSectionType.CHAPTER_OUTLINE:
|
||||
data = {
|
||||
"chapter_outline": [outline.model_dump() for outline in blueprint.chapter_outline],
|
||||
}
|
||||
elif section == NovelSectionType.CHAPTERS:
|
||||
outlines_map = {outline.chapter_number: outline for outline in project.outlines}
|
||||
chapters_map = {chapter.chapter_number: chapter for chapter in project.chapters}
|
||||
chapter_numbers = sorted(set(outlines_map.keys()) | set(chapters_map.keys()))
|
||||
# 章节列表只返回元数据,不包含完整内容
|
||||
chapters = [
|
||||
self._build_chapter_schema(
|
||||
project,
|
||||
number,
|
||||
outlines_map=outlines_map,
|
||||
chapters_map=chapters_map,
|
||||
include_content=False,
|
||||
).model_dump()
|
||||
for number in chapter_numbers
|
||||
]
|
||||
data = {
|
||||
"chapters": chapters,
|
||||
"total": len(chapters),
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="未知的章节类型")
|
||||
|
||||
return NovelSectionResponse(section=section, data=data)
|
||||
|
||||
def _build_chapter_schema(
|
||||
self,
|
||||
project: NovelProject,
|
||||
chapter_number: int,
|
||||
*,
|
||||
outlines_map: Optional[Dict[int, ChapterOutline]] = None,
|
||||
chapters_map: Optional[Dict[int, Chapter]] = None,
|
||||
include_content: bool = True,
|
||||
) -> ChapterSchema:
|
||||
outlines = outlines_map or {outline.chapter_number: outline for outline in project.outlines}
|
||||
chapters = chapters_map or {chapter.chapter_number: chapter for chapter in project.chapters}
|
||||
outline = outlines.get(chapter_number)
|
||||
chapter = chapters.get(chapter_number)
|
||||
|
||||
if not outline and not chapter:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="章节不存在")
|
||||
|
||||
title = outline.title if outline else f"第{chapter_number}章"
|
||||
summary = outline.summary if outline else ""
|
||||
real_summary = chapter.real_summary if chapter else None
|
||||
content = None
|
||||
versions: Optional[List[str]] = None
|
||||
evaluation_text: Optional[str] = None
|
||||
status_value = ChapterGenerationStatus.NOT_GENERATED.value
|
||||
word_count = 0
|
||||
|
||||
if chapter:
|
||||
status_value = chapter.status or ChapterGenerationStatus.NOT_GENERATED.value
|
||||
word_count = chapter.word_count or 0
|
||||
|
||||
# 只有在 include_content=True 时才包含完整内容
|
||||
if include_content:
|
||||
if chapter.selected_version:
|
||||
content = chapter.selected_version.content
|
||||
if chapter.versions:
|
||||
versions = [
|
||||
v.content
|
||||
for v in sorted(chapter.versions, key=lambda item: item.created_at)
|
||||
]
|
||||
if chapter.evaluations:
|
||||
latest = sorted(chapter.evaluations, key=lambda item: item.created_at)[-1]
|
||||
evaluation_text = latest.feedback or latest.decision
|
||||
|
||||
return ChapterSchema(
|
||||
chapter_number=chapter_number,
|
||||
title=title,
|
||||
summary=summary,
|
||||
real_summary=real_summary,
|
||||
content=content,
|
||||
versions=versions,
|
||||
evaluation=evaluation_text,
|
||||
generation_status=ChapterGenerationStatus(status_value),
|
||||
word_count=word_count,
|
||||
)
|
||||
96
backend/app/services/prompt_service.py
Normal file
96
backend/app/services/prompt_service.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import Prompt
|
||||
from ..repositories.prompt_repository import PromptRepository
|
||||
from ..schemas.prompt import PromptCreate, PromptRead, PromptUpdate
|
||||
|
||||
_CACHE: Dict[str, PromptRead] = {}
|
||||
_LOCK = asyncio.Lock()
|
||||
_LOADED = False
|
||||
|
||||
|
||||
class PromptService:
|
||||
"""提示词服务,提供缓存加速与 CRUD 能力。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = PromptRepository(session)
|
||||
|
||||
async def preload(self) -> None:
|
||||
global _CACHE, _LOADED
|
||||
prompts = await self.repo.list_all()
|
||||
async with _LOCK:
|
||||
_CACHE = {item.name: PromptRead.model_validate(item) for item in prompts}
|
||||
_LOADED = True
|
||||
|
||||
async def get_prompt(self, name: str) -> Optional[str]:
|
||||
global _LOADED
|
||||
async with _LOCK:
|
||||
if not _LOADED:
|
||||
prompts = await self.repo.list_all()
|
||||
_CACHE.update({item.name: PromptRead.model_validate(item) for item in prompts})
|
||||
_LOADED = True
|
||||
cached = _CACHE.get(name)
|
||||
if cached:
|
||||
return cached.content
|
||||
|
||||
prompt = await self.repo.get_by_name(name)
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
prompt_read = PromptRead.model_validate(prompt)
|
||||
async with _LOCK:
|
||||
_CACHE[name] = prompt_read
|
||||
return prompt_read.content
|
||||
|
||||
async def list_prompts(self) -> list[PromptRead]:
|
||||
prompts = await self.repo.list_all()
|
||||
return [PromptRead.model_validate(item) for item in prompts]
|
||||
|
||||
async def get_prompt_by_id(self, prompt_id: int) -> Optional[PromptRead]:
|
||||
instance = await self.repo.get(id=prompt_id)
|
||||
if not instance:
|
||||
return None
|
||||
return PromptRead.model_validate(instance)
|
||||
|
||||
async def create_prompt(self, payload: PromptCreate) -> PromptRead:
|
||||
data = payload.model_dump()
|
||||
tags = data.get("tags")
|
||||
if tags is not None:
|
||||
data["tags"] = ",".join(tags)
|
||||
prompt = Prompt(**data)
|
||||
await self.repo.add(prompt)
|
||||
await self.session.commit()
|
||||
prompt_read = PromptRead.model_validate(prompt)
|
||||
async with _LOCK:
|
||||
_CACHE[prompt_read.name] = prompt_read
|
||||
global _LOADED
|
||||
_LOADED = True
|
||||
return prompt_read
|
||||
|
||||
async def update_prompt(self, prompt_id: int, payload: PromptUpdate) -> Optional[PromptRead]:
|
||||
instance = await self.repo.get(id=prompt_id)
|
||||
if not instance:
|
||||
return None
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
if "tags" in update_data and update_data["tags"] is not None:
|
||||
update_data["tags"] = ",".join(update_data["tags"])
|
||||
await self.repo.update_fields(instance, **update_data)
|
||||
await self.session.commit()
|
||||
prompt_read = PromptRead.model_validate(instance)
|
||||
async with _LOCK:
|
||||
_CACHE[prompt_read.name] = prompt_read
|
||||
return prompt_read
|
||||
|
||||
async def delete_prompt(self, prompt_id: int) -> bool:
|
||||
instance = await self.repo.get(id=prompt_id)
|
||||
if not instance:
|
||||
return False
|
||||
await self.repo.delete(instance)
|
||||
await self.session.commit()
|
||||
async with _LOCK:
|
||||
_CACHE.pop(instance.name, None)
|
||||
return True
|
||||
60
backend/app/services/update_log_service.py
Normal file
60
backend/app/services/update_log_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import UpdateLog
|
||||
from ..repositories.update_log_repository import UpdateLogRepository
|
||||
|
||||
|
||||
class UpdateLogService:
|
||||
"""更新日志服务,提供增删改查能力,并保证置顶唯一。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = UpdateLogRepository(session)
|
||||
|
||||
async def list_logs(self, limit: Optional[int] = None) -> List[UpdateLog]:
|
||||
if limit is None:
|
||||
return list(await self.repo.list())
|
||||
return list(await self.repo.list_latest(limit))
|
||||
|
||||
async def create_log(self, content: str, creator: str | None = None, *, is_pinned: bool = False) -> UpdateLog:
|
||||
if is_pinned:
|
||||
await self._clear_pinned()
|
||||
log = UpdateLog(content=content, created_by=creator, is_pinned=is_pinned)
|
||||
await self.repo.add(log)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log)
|
||||
return log
|
||||
|
||||
async def update_log(self, log_id: int, *, content: Optional[str] = None, is_pinned: Optional[bool] = None) -> UpdateLog:
|
||||
log = await self.repo.get(id=log_id)
|
||||
if not log:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="更新记录不存在")
|
||||
|
||||
updates = {}
|
||||
if content is not None:
|
||||
updates["content"] = content
|
||||
if is_pinned is not None:
|
||||
if is_pinned:
|
||||
await self._clear_pinned()
|
||||
updates["is_pinned"] = is_pinned
|
||||
|
||||
if updates:
|
||||
await self.repo.update_fields(log, **updates)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log)
|
||||
|
||||
return log
|
||||
|
||||
async def delete_log(self, log_id: int) -> None:
|
||||
log = await self.repo.get(id=log_id)
|
||||
if not log:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="更新记录不存在")
|
||||
await self.repo.delete(log)
|
||||
await self.session.commit()
|
||||
|
||||
async def _clear_pinned(self) -> None:
|
||||
await self.session.execute(update(UpdateLog).values(is_pinned=False))
|
||||
21
backend/app/services/usage_service.py
Normal file
21
backend/app/services/usage_service.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..repositories.usage_metric_repository import UsageMetricRepository
|
||||
|
||||
|
||||
class UsageService:
|
||||
"""通用计数服务,目前用于统计 API 请求次数等。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = UsageMetricRepository(session)
|
||||
|
||||
async def increment(self, key: str) -> None:
|
||||
counter = await self.repo.get_or_create(key)
|
||||
counter.value += 1
|
||||
await self.session.commit()
|
||||
|
||||
async def get_value(self, key: str) -> int:
|
||||
counter = await self.repo.get_or_create(key)
|
||||
await self.session.commit()
|
||||
return counter.value
|
||||
62
backend/app/services/user_service.py
Normal file
62
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..core.security import hash_password
|
||||
from ..models import User
|
||||
from ..repositories.user_repository import UserRepository
|
||||
from ..schemas.user import UserCreate, UserInDB
|
||||
|
||||
|
||||
class UserService:
|
||||
"""用户领域服务,负责注册、查询与配额统计。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = UserRepository(session)
|
||||
|
||||
async def create_user(self, payload: UserCreate, *, external_id: str | None = None) -> UserInDB:
|
||||
hashed_password = hash_password(payload.password)
|
||||
user = User(
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
hashed_password=hashed_password,
|
||||
external_id=external_id,
|
||||
)
|
||||
|
||||
self.session.add(user)
|
||||
try:
|
||||
await self.session.commit()
|
||||
except IntegrityError as exc:
|
||||
await self.session.rollback()
|
||||
raise ValueError("用户名或邮箱已存在") from exc
|
||||
|
||||
return UserInDB.model_validate(user)
|
||||
|
||||
async def get_by_username(self, username: str) -> Optional[UserInDB]:
|
||||
user = await self.repo.get_by_username(username)
|
||||
return UserInDB.model_validate(user) if user else None
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[UserInDB]:
|
||||
user = await self.repo.get_by_email(email)
|
||||
return UserInDB.model_validate(user) if user else None
|
||||
|
||||
async def get_by_external_id(self, external_id: str) -> Optional[UserInDB]:
|
||||
user = await self.repo.get_by_external_id(external_id)
|
||||
return UserInDB.model_validate(user) if user else None
|
||||
|
||||
async def get_user(self, user_id: int) -> Optional[UserInDB]:
|
||||
user = await self.repo.get(id=user_id)
|
||||
return UserInDB.model_validate(user) if user else None
|
||||
|
||||
async def list_users(self) -> list[UserInDB]:
|
||||
users = await self.repo.list_all()
|
||||
return [UserInDB.model_validate(item) for item in users]
|
||||
|
||||
async def increment_daily_request(self, user_id: int) -> None:
|
||||
await self.repo.increment_daily_request(user_id)
|
||||
await self.session.commit()
|
||||
|
||||
async def get_daily_request(self, user_id: int) -> int:
|
||||
return await self.repo.get_daily_request(user_id)
|
||||
544
backend/app/services/vector_store_service.py
Normal file
544
backend/app/services/vector_store_service.py
Normal file
@@ -0,0 +1,544 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
基于 libsql 的向量检索服务,封装章节内容的存储与查询。
|
||||
|
||||
本文件中的注释均使用中文,便于团队成员快速理解 RAG 相关逻辑。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from array import array
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence
|
||||
|
||||
from ..core.config import settings
|
||||
|
||||
try: # noqa: SIM105 - 明确区分依赖缺失的情况
|
||||
import libsql_client
|
||||
except ImportError: # pragma: no cover - 在未安装依赖时提供友好提示
|
||||
libsql_client = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedChunk:
|
||||
"""向量检索得到的剧情片段。"""
|
||||
|
||||
content: str
|
||||
chapter_number: int
|
||||
chapter_title: Optional[str]
|
||||
score: float
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedSummary:
|
||||
"""向量检索得到的章节摘要。"""
|
||||
|
||||
chapter_number: int
|
||||
title: str
|
||||
summary: str
|
||||
score: float
|
||||
|
||||
|
||||
class VectorStoreService:
|
||||
"""libsql 向量库操作工具,确保不同小说项目的数据隔离。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
if not settings.vector_store_enabled:
|
||||
logger.warning("未开启向量库配置,RAG 检索将被跳过。")
|
||||
self._client = None
|
||||
self._schema_ready = True
|
||||
return
|
||||
|
||||
if libsql_client is None: # pragma: no cover - 运行环境缺少依赖
|
||||
raise RuntimeError("缺少 libsql-client 依赖,请先在环境中安装。")
|
||||
|
||||
url = settings.vector_db_url
|
||||
if url and url.startswith("file:"):
|
||||
path_part = url.split("file:", 1)[1]
|
||||
resolved = Path(path_part).expanduser().resolve()
|
||||
resolved.parent.mkdir(parents=True, exist_ok=True)
|
||||
url = f"file:{resolved}"
|
||||
logger.info("向量库使用本地文件: %s", resolved)
|
||||
|
||||
try:
|
||||
logger.info("初始化 libsql 客户端: url=%s", url)
|
||||
self._client = libsql_client.create_client(
|
||||
url=url,
|
||||
auth_token=settings.vector_db_auth_token,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 连接异常仅打印日志
|
||||
logger.error("初始化 libsql 客户端失败: %s", exc)
|
||||
self._client = None
|
||||
self._schema_ready = True
|
||||
else:
|
||||
self._schema_ready = False
|
||||
logger.info("libsql 客户端初始化成功,等待建表。")
|
||||
|
||||
async def ensure_schema(self) -> None:
|
||||
"""初始化向量表结构,保证系统首次运行即可使用。"""
|
||||
if not self._client or self._schema_ready:
|
||||
return
|
||||
|
||||
statements = [
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_chunks (
|
||||
id TEXT PRIMARY KEY,
|
||||
project_id TEXT NOT NULL,
|
||||
chapter_number INTEGER NOT NULL,
|
||||
chunk_index INTEGER NOT NULL,
|
||||
chapter_title TEXT,
|
||||
content TEXT NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at INTEGER DEFAULT (unixepoch())
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_rag_chunks_project
|
||||
ON rag_chunks(project_id, chapter_number)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_summaries (
|
||||
id TEXT PRIMARY KEY,
|
||||
project_id TEXT NOT NULL,
|
||||
chapter_number INTEGER NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
summary TEXT NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
created_at INTEGER DEFAULT (unixepoch())
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_rag_summaries_project
|
||||
ON rag_summaries(project_id, chapter_number)
|
||||
""",
|
||||
]
|
||||
|
||||
try:
|
||||
for sql in statements:
|
||||
await self._client.execute(sql) # type: ignore[union-attr]
|
||||
logger.info("已确保向量库表结构存在。")
|
||||
except Exception as exc: # pragma: no cover - 初始化失败时记录日志
|
||||
logger.error("创建向量库表结构失败: %s", exc)
|
||||
else:
|
||||
self._schema_ready = True
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
embedding: Sequence[float],
|
||||
top_k: Optional[int] = None,
|
||||
) -> List[RetrievedChunk]:
|
||||
"""根据查询向量检索剧情片段,结果已按相似度排序。"""
|
||||
if not self._client or not embedding:
|
||||
return []
|
||||
|
||||
await self.ensure_schema()
|
||||
top_k = top_k or settings.vector_top_k_chunks
|
||||
if top_k <= 0:
|
||||
return []
|
||||
|
||||
blob = self._to_f32_blob(embedding)
|
||||
sql = """
|
||||
SELECT
|
||||
content,
|
||||
chapter_number,
|
||||
chapter_title,
|
||||
COALESCE(metadata, '{}') AS metadata,
|
||||
vector_distance_cosine(embedding, :query) AS distance
|
||||
FROM rag_chunks
|
||||
WHERE project_id = :project_id
|
||||
ORDER BY distance ASC
|
||||
LIMIT :limit
|
||||
"""
|
||||
try:
|
||||
result = await self._client.execute( # type: ignore[union-attr]
|
||||
sql,
|
||||
{
|
||||
"project_id": project_id,
|
||||
"query": blob,
|
||||
"limit": top_k,
|
||||
},
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 查询异常时仅记录
|
||||
if "no such function: vector_distance_cosine" in str(exc).lower():
|
||||
logger.warning("向量库缺少 vector_distance_cosine 函数,回退至应用层相似度计算。")
|
||||
return await self._query_chunks_with_python_similarity(
|
||||
project_id=project_id,
|
||||
embedding=embedding,
|
||||
top_k=top_k,
|
||||
)
|
||||
logger.warning("向量检索剧情片段失败: %s", exc)
|
||||
return []
|
||||
|
||||
items: List[RetrievedChunk] = []
|
||||
for row in self._iter_rows(result):
|
||||
items.append(
|
||||
RetrievedChunk(
|
||||
content=row.get("content", ""),
|
||||
chapter_number=row.get("chapter_number", 0),
|
||||
chapter_title=row.get("chapter_title"),
|
||||
score=row.get("distance", 0.0),
|
||||
metadata=self._parse_metadata(row.get("metadata")),
|
||||
)
|
||||
)
|
||||
return items
|
||||
|
||||
async def query_summaries(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
embedding: Sequence[float],
|
||||
top_k: Optional[int] = None,
|
||||
) -> List[RetrievedSummary]:
|
||||
"""根据查询向量检索章节摘要列表。"""
|
||||
if not self._client or not embedding:
|
||||
return []
|
||||
|
||||
await self.ensure_schema()
|
||||
top_k = top_k or settings.vector_top_k_summaries
|
||||
if top_k <= 0:
|
||||
return []
|
||||
|
||||
blob = self._to_f32_blob(embedding)
|
||||
sql = """
|
||||
SELECT
|
||||
chapter_number,
|
||||
title,
|
||||
summary,
|
||||
vector_distance_cosine(embedding, :query) AS distance
|
||||
FROM rag_summaries
|
||||
WHERE project_id = :project_id
|
||||
ORDER BY distance ASC
|
||||
LIMIT :limit
|
||||
"""
|
||||
try:
|
||||
result = await self._client.execute( # type: ignore[union-attr]
|
||||
sql,
|
||||
{
|
||||
"project_id": project_id,
|
||||
"query": blob,
|
||||
"limit": top_k,
|
||||
},
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 查询异常时仅记录
|
||||
if "no such function: vector_distance_cosine" in str(exc).lower():
|
||||
logger.warning("向量库缺少 vector_distance_cosine 函数,回退至应用层相似度计算。")
|
||||
return await self._query_summaries_with_python_similarity(
|
||||
project_id=project_id,
|
||||
embedding=embedding,
|
||||
top_k=top_k,
|
||||
)
|
||||
logger.warning("向量检索章节摘要失败: %s", exc)
|
||||
return []
|
||||
|
||||
items: List[RetrievedSummary] = []
|
||||
for row in self._iter_rows(result):
|
||||
items.append(
|
||||
RetrievedSummary(
|
||||
chapter_number=row.get("chapter_number", 0),
|
||||
title=row.get("title", ""),
|
||||
summary=row.get("summary", ""),
|
||||
score=row.get("distance", 0.0),
|
||||
)
|
||||
)
|
||||
return items
|
||||
|
||||
async def upsert_chunks(
|
||||
self,
|
||||
*,
|
||||
records: Iterable[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""批量写入章节片段,供后续检索使用。"""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
await self.ensure_schema()
|
||||
sql = """
|
||||
INSERT INTO rag_chunks (
|
||||
id,
|
||||
project_id,
|
||||
chapter_number,
|
||||
chunk_index,
|
||||
chapter_title,
|
||||
content,
|
||||
embedding,
|
||||
metadata
|
||||
) VALUES (
|
||||
:id,
|
||||
:project_id,
|
||||
:chapter_number,
|
||||
:chunk_index,
|
||||
:chapter_title,
|
||||
:content,
|
||||
:embedding,
|
||||
:metadata
|
||||
)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
content=excluded.content,
|
||||
embedding=excluded.embedding,
|
||||
metadata=excluded.metadata,
|
||||
chapter_title=excluded.chapter_title
|
||||
"""
|
||||
payload = []
|
||||
for item in records:
|
||||
embedding = item.get("embedding", [])
|
||||
payload.append(
|
||||
{
|
||||
**item,
|
||||
"embedding": self._to_f32_blob(embedding),
|
||||
"metadata": json.dumps(item.get("metadata") or {}, ensure_ascii=False),
|
||||
}
|
||||
)
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
for item in payload:
|
||||
try:
|
||||
await self._client.execute(sql, item) # type: ignore[union-attr]
|
||||
except Exception as exc: # pragma: no cover - 单条写入失败时记录日志
|
||||
logger.error("写入 rag_chunks 失败: %s", exc)
|
||||
else:
|
||||
logger.debug(
|
||||
"已写入章节片段: project=%s chapter=%s chunk=%s",
|
||||
item.get("project_id"),
|
||||
item.get("chapter_number"),
|
||||
item.get("chunk_index"),
|
||||
)
|
||||
|
||||
async def upsert_summaries(
|
||||
self,
|
||||
*,
|
||||
records: Iterable[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""同步章节摘要向量,供摘要层检索使用。"""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
await self.ensure_schema()
|
||||
sql = """
|
||||
INSERT INTO rag_summaries (
|
||||
id,
|
||||
project_id,
|
||||
chapter_number,
|
||||
title,
|
||||
summary,
|
||||
embedding
|
||||
) VALUES (
|
||||
:id,
|
||||
:project_id,
|
||||
:chapter_number,
|
||||
:title,
|
||||
:summary,
|
||||
:embedding
|
||||
)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
summary=excluded.summary,
|
||||
embedding=excluded.embedding,
|
||||
title=excluded.title
|
||||
"""
|
||||
|
||||
payload = []
|
||||
for item in records:
|
||||
embedding = item.get("embedding", [])
|
||||
payload.append(
|
||||
{
|
||||
**item,
|
||||
"embedding": self._to_f32_blob(embedding),
|
||||
}
|
||||
)
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
for item in payload:
|
||||
try:
|
||||
await self._client.execute(sql, item) # type: ignore[union-attr]
|
||||
except Exception as exc: # pragma: no cover - 单条写入失败时记录日志
|
||||
logger.error("写入 rag_summaries 失败: %s", exc)
|
||||
else:
|
||||
logger.debug(
|
||||
"已写入章节摘要: project=%s chapter=%s",
|
||||
item.get("project_id"),
|
||||
item.get("chapter_number"),
|
||||
)
|
||||
|
||||
async def delete_by_chapters(self, project_id: str, chapter_numbers: Sequence[int]) -> None:
|
||||
"""根据章节编号批量删除对应的上下文数据。"""
|
||||
if not self._client or not chapter_numbers:
|
||||
return
|
||||
|
||||
await self.ensure_schema()
|
||||
placeholders = ",".join(":chapter_" + str(idx) for idx in range(len(chapter_numbers)))
|
||||
params = {
|
||||
"project_id": project_id,
|
||||
**{f"chapter_{idx}": number for idx, number in enumerate(chapter_numbers)},
|
||||
}
|
||||
chunk_sql = f"""
|
||||
DELETE FROM rag_chunks
|
||||
WHERE project_id = :project_id
|
||||
AND chapter_number IN ({placeholders})
|
||||
"""
|
||||
summary_sql = f"""
|
||||
DELETE FROM rag_summaries
|
||||
WHERE project_id = :project_id
|
||||
AND chapter_number IN ({placeholders})
|
||||
"""
|
||||
try:
|
||||
await self._client.execute(chunk_sql, params) # type: ignore[union-attr]
|
||||
await self._client.execute(summary_sql, params) # type: ignore[union-attr]
|
||||
logger.info(
|
||||
"已删除章节向量: project=%s chapters=%s",
|
||||
project_id,
|
||||
list(chapter_numbers),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 删除失败时记录日志
|
||||
logger.error("删除章节向量失败: project=%s chapters=%s error=%s", project_id, chapter_numbers, exc)
|
||||
|
||||
@staticmethod
|
||||
def _to_f32_blob(embedding: Sequence[float]) -> bytes:
|
||||
"""将向量浮点列表编码为 libsql 可识别的 float32 二进制。"""
|
||||
return array("f", embedding).tobytes()
|
||||
|
||||
@staticmethod
|
||||
def _from_f32_blob(blob: Any) -> List[float]:
|
||||
"""将数据库中的 BLOB 解码为浮点列表。"""
|
||||
if not blob:
|
||||
return []
|
||||
if isinstance(blob, memoryview):
|
||||
blob = blob.tobytes()
|
||||
data = array("f")
|
||||
data.frombytes(bytes(blob))
|
||||
return list(data)
|
||||
|
||||
@staticmethod
|
||||
def _cosine_distance(vec_a: Sequence[float], vec_b: Sequence[float]) -> float:
|
||||
"""计算余弦距离(1 - similarity),避免除零。"""
|
||||
if not vec_a or not vec_b:
|
||||
return 1.0
|
||||
dot = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||
norm_a = math.sqrt(sum(a * a for a in vec_a))
|
||||
norm_b = math.sqrt(sum(b * b for b in vec_b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 1.0
|
||||
similarity = dot / (norm_a * norm_b)
|
||||
return 1.0 - similarity
|
||||
|
||||
async def _query_chunks_with_python_similarity(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
embedding: Sequence[float],
|
||||
top_k: int,
|
||||
) -> List[RetrievedChunk]:
|
||||
sql = """
|
||||
SELECT
|
||||
content,
|
||||
chapter_number,
|
||||
chapter_title,
|
||||
COALESCE(metadata, '{}') AS metadata,
|
||||
embedding
|
||||
FROM rag_chunks
|
||||
WHERE project_id = :project_id
|
||||
"""
|
||||
result = await self._client.execute(sql, {"project_id": project_id}) # type: ignore[union-attr]
|
||||
scored: List[RetrievedChunk] = []
|
||||
for row in self._iter_rows(result):
|
||||
stored_embedding = self._from_f32_blob(row.get("embedding"))
|
||||
distance = self._cosine_distance(embedding, stored_embedding)
|
||||
scored.append(
|
||||
RetrievedChunk(
|
||||
content=row.get("content", ""),
|
||||
chapter_number=row.get("chapter_number", 0),
|
||||
chapter_title=row.get("chapter_title"),
|
||||
score=distance,
|
||||
metadata=self._parse_metadata(row.get("metadata")),
|
||||
)
|
||||
)
|
||||
scored.sort(key=lambda item: item.score)
|
||||
return scored[:top_k]
|
||||
|
||||
async def _query_summaries_with_python_similarity(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
embedding: Sequence[float],
|
||||
top_k: int,
|
||||
) -> List[RetrievedSummary]:
|
||||
sql = """
|
||||
SELECT
|
||||
chapter_number,
|
||||
title,
|
||||
summary,
|
||||
embedding
|
||||
FROM rag_summaries
|
||||
WHERE project_id = :project_id
|
||||
"""
|
||||
result = await self._client.execute(sql, {"project_id": project_id}) # type: ignore[union-attr]
|
||||
scored: List[RetrievedSummary] = []
|
||||
for row in self._iter_rows(result):
|
||||
stored_embedding = self._from_f32_blob(row.get("embedding"))
|
||||
distance = self._cosine_distance(embedding, stored_embedding)
|
||||
scored.append(
|
||||
RetrievedSummary(
|
||||
chapter_number=row.get("chapter_number", 0),
|
||||
title=row.get("title", ""),
|
||||
summary=row.get("summary", ""),
|
||||
score=distance,
|
||||
)
|
||||
)
|
||||
scored.sort(key=lambda item: item.score)
|
||||
return scored[:top_k]
|
||||
|
||||
@staticmethod
|
||||
def _parse_metadata(raw: Any) -> Dict[str, Any]:
|
||||
"""解析存储的 JSON 文本,确保输出为 dict。"""
|
||||
if not raw:
|
||||
return {}
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
raw = raw.decode("utf-8")
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _iter_rows(result: Any) -> Iterable[Dict[str, Any]]:
|
||||
"""统一处理 libsql 返回的行数据,确保以 dict 形式迭代。"""
|
||||
rows = getattr(result, "rows", None)
|
||||
if rows is None:
|
||||
rows = result
|
||||
if not rows:
|
||||
return []
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for row in rows:
|
||||
if isinstance(row, dict):
|
||||
normalized.append(row)
|
||||
elif hasattr(row, "_asdict"):
|
||||
normalized.append(row._asdict()) # type: ignore[attr-defined]
|
||||
else:
|
||||
try:
|
||||
normalized.append(dict(row))
|
||||
except Exception: # pragma: no cover - 无法转换时跳过
|
||||
continue
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VectorStoreService",
|
||||
"RetrievedChunk",
|
||||
"RetrievedSummary",
|
||||
]
|
||||
Reference in New Issue
Block a user