feat: 初始提交

This commit is contained in:
anonymous
2025-10-21 09:38:26 +08:00
committed by t59688
parent 2965b8e28f
commit c9fc816fab
175 changed files with 23968 additions and 87 deletions

View File

View File

@@ -0,0 +1,27 @@
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import AdminSetting
from ..repositories.admin_setting_repository import AdminSettingRepository
class AdminSettingService:
"""管理员配置项服务,提供简单的 KV 操作。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = AdminSettingRepository(session)
async def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
value = await self.repo.get_value(key)
return value if value is not None else default
async def set(self, key: str, value: str) -> None:
record = await self.repo.get(key=key)
if record:
await self.repo.update_fields(record, value=value)
else:
setting = AdminSetting(key=key, value=value)
await self.repo.add(setting)
await self.session.commit()

View File

@@ -0,0 +1,389 @@
import asyncio
import logging
import random
import secrets
import string
import time
from typing import Dict, Optional
import httpx
from email.header import Header
from email.mime.text import MIMEText
from email.utils import formataddr, parseaddr
from fastapi import HTTPException, status
import smtplib
from ..core.config import settings
from ..core.security import create_access_token, hash_password, verify_password
from ..models import User
from ..repositories.system_config_repository import SystemConfigRepository
from ..repositories.user_repository import UserRepository
from ..schemas.user import AuthOptions, Token, UserCreate, UserInDB, UserRegistration
_VERIFICATION_CACHE: Dict[str, tuple[str, float]] = {}
_LAST_SEND_TIME: Dict[str, float] = {}
class AuthService:
"""认证与授权逻辑封装登录、注册、OAuth 对接等操作。"""
def __init__(self, session):
self.session = session
self.user_repo = UserRepository(session)
self.system_config_repo = SystemConfigRepository(session)
self._verification_cache = _VERIFICATION_CACHE
self._last_send_time = _LAST_SEND_TIME
# ------------------------------------------------------------------
# 用户登录 / 注册
# ------------------------------------------------------------------
async def authenticate_user(self, username: str, password: str) -> User:
user = await self.user_repo.get_by_username(username)
if not user or not verify_password(password, user.hashed_password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误")
return user
async def create_access_token(
self,
user: User | UserInDB,
*,
must_change_password: Optional[bool] = None,
) -> Token:
payload = {"is_admin": user.is_admin}
token = create_access_token(user.username, extra_claims=payload)
should_change = self.requires_password_reset(user) if must_change_password is None else must_change_password
return Token(access_token=token, must_change_password=should_change)
async def register_user(self, payload: UserRegistration) -> User:
if not await self.is_registration_enabled():
raise HTTPException(status_code=403, detail="当前暂未开放注册")
if await self.user_repo.get_by_username(payload.username):
raise HTTPException(status_code=400, detail="用户名已存在")
if payload.email and await self.user_repo.get_by_email(payload.email):
raise HTTPException(status_code=400, detail="邮箱已被使用")
if not self.verify_code(payload.email, payload.verification_code):
raise HTTPException(status_code=400, detail="验证码错误或已过期")
hashed_password = hash_password(payload.password)
user = User(
username=payload.username,
email=payload.email,
hashed_password=hashed_password,
)
self.session.add(user)
await self.session.commit()
return user
# ------------------------------------------------------------------
# 邮箱验证码逻辑
# ------------------------------------------------------------------
async def send_verification_code(self, email: str) -> None:
if not await self.is_registration_enabled():
raise HTTPException(status_code=403, detail="当前暂未开放注册")
now = time.time()
if email in self._last_send_time and now - self._last_send_time[email] < 60:
raise HTTPException(status_code=429, detail="请稍后再试1分钟内不可重复发送")
code = "".join(random.choices(string.digits, k=6))
self._verification_cache[email] = (code, now + 300)
self._last_send_time[email] = now
smtp_config = await self._load_smtp_config()
if not smtp_config:
raise HTTPException(status_code=500, detail="未配置邮件服务,请联系管理员")
await self._send_email(email, code, smtp_config)
def verify_code(self, email: str | None, code: str) -> bool:
if not email:
return False
cached = self._verification_cache.get(email)
if not cached:
return False
expected, expire_at = cached
if time.time() > expire_at:
self._verification_cache.pop(email, None)
return False
if code != expected:
return False
self._verification_cache.pop(email, None)
return True
async def _load_smtp_config(self) -> Optional[Dict[str, str]]:
keys = [
"smtp.server",
"smtp.port",
"smtp.username",
"smtp.password",
"smtp.from",
]
configs = {}
for key in keys:
config = await self.system_config_repo.get_by_key(key)
if config:
configs[key] = config.value
required_keys = {"smtp.server", "smtp.port", "smtp.username", "smtp.password", "smtp.from"}
if not required_keys.issubset(configs.keys()):
return None
return configs
async def _send_email(self, to_email: str, code: str, smtp_config: Dict[str, str]) -> None:
logger = logging.getLogger(__name__)
server = smtp_config["smtp.server"]
port = int(smtp_config.get("smtp.port", "465"))
username = smtp_config["smtp.username"]
password = smtp_config["smtp.password"]
from_value = smtp_config.get("smtp.from") or username
display_name, from_addr = parseaddr(from_value)
if not display_name and "@" not in from_value and "<" not in from_value and from_value.strip():
display_name = from_value.strip()
if not from_addr or "@" not in from_addr:
if from_addr and "@" not in from_addr:
logger.warning(
"发件邮箱缺少 @,已回退为登录账号",
extra={"original": from_addr},
)
from_addr = username
try:
from_addr.encode("ascii")
except UnicodeEncodeError:
logger.warning(
"发件邮箱包含非 ASCII 字符,已回退为登录账号",
extra={"original": from_addr},
)
from_addr = username
if display_name:
formatted_from = formataddr((Header(display_name, "utf-8").encode(), from_addr))
else:
formatted_from = from_addr
try:
to_email.encode("ascii")
except UnicodeEncodeError as exc: # noqa: BLE001
raise HTTPException(status_code=400, detail="邮箱地址包含不支持的字符") from exc
html_content = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta http-equiv=\"Content-Type\" content=\"text/html; charset=UTF-8\">
<meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">
<title>您的验证码</title>
<style>
body, table, td, a {{ -webkit-text-size-adjust: 100%; -ms-text-size-adjust: 100%; }}
table, td {{ mso-table-lspace: 0pt; mso-table-rspace: 0pt; }}
img {{ -ms-interpolation-mode: bicubic; }}
body {{ margin: 0; padding: 0; }}
table {{ border-collapse: collapse !important; }}
</style>
</head>
<body style=\"margin: 0; padding: 0; width: 100% !important; background-color: #f3f4f6;\">
<table width=\"100%\" border=\"0\" cellpadding=\"0\" cellspacing=\"0\" bgcolor=\"#f3f4f6\">
<tr>
<td align=\"center\" valign=\"top\" style=\"padding: 20px;\">
<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\" width=\"100%\" style=\"max-width: 512px; background-color: #ffffff; border-radius: 16px; overflow: hidden;\">
<tr>
<td align=\"center\" style=\"background-color: #2563eb; padding: 32px;\">
<h1 style=\"font-family: Arial, Helvetica, sans-serif; font-size: 30px; font-weight: bold; color: #ffffff; margin: 0;\">操作验证码</h1>
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 16px; color: #dbeafe; margin: 8px 0 0;\">请使用下方验证码完成操作。</p>
</td>
</tr>
<tr>
<td align=\"center\" style=\"padding: 32px 48px;\">
<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\" width=\"100%\">
<tr>
<td align=\"center\" style=\"background-color: #f3f4f6; border-radius: 12px; padding: 16px; margin: 24px 0;\">
<p style=\"font-family: 'Courier New', Courier, monospace; font-size: 48px; font-weight: bold; letter-spacing: 0.1em; color: #1d4ed8; margin: 0;\">
{code[:3]}{code[3:]}
</p>
</td>
</tr>
<tr>
<td align=\"center\" style=\"padding-top: 24px;\">
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 16px; color: #6b7280; margin: 0;\">
此验证码将在 <span style=\"font-weight: bold; color: #374151;\">5分钟</span> 内有效。
</p>
</td>
</tr>
<tr>
<td align=\"center\" style=\"padding-top: 32px; border-top: 1px solid #e5e7eb; margin-top: 32px;\">
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 14px; font-weight: bold; color: #ef4444; margin: 0;\">
为保障安全,请勿泄露此验证码。
</p>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td align=\"center\" style=\"background-color: #f9fafb; padding: 24px; border-top: 1px solid #e5e7eb;\">
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 14px; color: #6b7280; margin: 0;\">
如非本人操作,请忽略此邮件。
</p>
<p style=\"font-family: Arial, Helvetica, sans-serif; font-size: 12px; color: #9ca3af; margin: 8px 0 0;\">
&copy; {time.strftime('%Y')} 拯救小说家. All rights reserved.
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>
"""
message = MIMEText(html_content, "html", "utf-8")
message["Subject"] = Header("注册验证码", "utf-8").encode()
message["From"] = formatted_from
message["To"] = to_email
logger.info("准备发送验证码邮件", extra={"to": to_email, "server": server, "port": port})
def _send():
smtp: Optional[smtplib.SMTP] = None
try:
if port == 465:
smtp = smtplib.SMTP_SSL(server, port, timeout=10)
else:
smtp = smtplib.SMTP(server, port, timeout=10)
smtp.starttls()
if username and password:
smtp.login(username, password)
smtp.sendmail(from_addr, [to_email], message.as_string())
logger.info("验证码邮件发送成功", extra={"to": to_email})
except Exception as exc: # noqa: BLE001
logger.exception("验证码发送失败")
raise
finally:
if smtp is not None:
try:
smtp.quit()
except Exception: # noqa: BLE001
pass
try:
await asyncio.to_thread(_send)
except Exception as exc: # noqa: BLE001
raise HTTPException(status_code=500, detail="验证码发送失败,请检查邮件配置") from exc
# ------------------------------------------------------------------
# OAuth 对接示例(以 Linux.do 为例)
# ------------------------------------------------------------------
async def handle_linuxdo_callback(self, code: str) -> Token:
if not await self.is_linuxdo_login_enabled():
raise HTTPException(status_code=403, detail="未启用 Linux.do 登录")
client_id = await self._get_config_value("linuxdo.client_id")
client_secret = await self._get_config_value("linuxdo.client_secret")
redirect_uri = await self._get_config_value("linuxdo.redirect_uri")
token_url = await self._get_config_value("linuxdo.token_url")
user_info_url = await self._get_config_value("linuxdo.user_info_url")
if not all([client_id, client_secret, redirect_uri, token_url, user_info_url]):
raise HTTPException(status_code=500, detail="未正确配置 Linux.do OAuth 参数")
async with httpx.AsyncClient() as client:
token_response = await client.post(
token_url,
data={
"client_id": client_id,
"client_secret": client_secret,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
token_response.raise_for_status()
access_token = token_response.json().get("access_token")
if not access_token:
raise HTTPException(status_code=400, detail="授权失败,未获取到访问令牌")
user_info_response = await client.get(
user_info_url,
headers={"Authorization": f"Bearer {access_token}"},
)
user_info_response.raise_for_status()
data = user_info_response.json()
external_id = f"linuxdo:{data['id']}"
user = await self.user_repo.get_by_external_id(external_id)
if user is None:
placeholder_password = secrets.token_urlsafe(16)
user = User(
username=data["username"],
email=data.get("email"),
external_id=external_id,
hashed_password=hash_password(placeholder_password),
)
self.session.add(user)
await self.session.commit()
return await self.create_access_token(user)
async def _get_config_value(self, key: str) -> Optional[str]:
config = await self.system_config_repo.get_by_key(key)
return config.value if config else None
async def get_config_value(self, key: str) -> Optional[str]:
"""对外暴露的配置读取接口,便于路由层复用。"""
return await self._get_config_value(key)
@staticmethod
def _parse_bool(value: Optional[str], fallback: bool) -> bool:
if value is None:
return fallback
normalized = value.strip().lower()
return normalized in {"1", "true", "yes", "on"}
async def is_registration_enabled(self) -> bool:
value = await self._get_config_value("auth.allow_registration")
return self._parse_bool(value, fallback=settings.allow_registration)
async def is_linuxdo_login_enabled(self) -> bool:
value = await self._get_config_value("auth.linuxdo_enabled")
return self._parse_bool(value, fallback=settings.enable_linuxdo_login)
async def get_auth_options(self) -> AuthOptions:
"""聚合与认证相关的动态开关配置,便于前端一次性拉取。"""
allow_registration = await self.is_registration_enabled()
enable_linuxdo_login = await self.is_linuxdo_login_enabled()
return AuthOptions(
allow_registration=allow_registration,
enable_linuxdo_login=enable_linuxdo_login,
)
def requires_password_reset(self, user: User | UserInDB) -> bool:
if not user.is_admin:
return False
if user.username != settings.admin_default_username:
return False
hashed_password = getattr(user, "hashed_password", None)
if not hashed_password:
return False
return verify_password(settings.admin_default_password, hashed_password)
async def change_password(self, username: str, old_password: str, new_password: str) -> None:
user = await self.user_repo.get_by_username(username)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
if not verify_password(old_password, user.hashed_password):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前密码错误")
if verify_password(new_password, user.hashed_password):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="新密码不能与当前密码相同")
if username == settings.admin_default_username and new_password == settings.admin_default_password:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="新密码不能为默认密码")
user.hashed_password = hash_password(new_password)
await self.session.commit()

View File

@@ -0,0 +1,109 @@
from __future__ import annotations
"""
章节上下文组装服务:负责调用向量库检索上下文,并对结果做基础格式化。
所有关键步骤均包含中文注释,方便团队理解 RAG 流程。
"""
import logging
from dataclasses import dataclass
from typing import List, Optional
from ..core.config import settings
from ..services.llm_service import LLMService
from .vector_store_service import RetrievedChunk, RetrievedSummary, VectorStoreService
logger = logging.getLogger(__name__)
@dataclass
class ChapterRAGContext:
"""封装检索得到的上下文结果。"""
query: str
chunks: List[RetrievedChunk]
summaries: List[RetrievedSummary]
def chunk_texts(self) -> List[str]:
"""将检索到的 chunk 转换成带序号的 Markdown 段落。"""
lines = []
for idx, chunk in enumerate(self.chunks, start=1):
title = chunk.chapter_title or f"{chunk.chapter_number}"
lines.append(
f"### Chunk {idx}(来源:{title})\n{chunk.content.strip()}"
)
return lines
def summary_lines(self) -> List[str]:
"""整理章节摘要,方便直接插入 Prompt。"""
lines = []
for summary in self.summaries:
lines.append(
f"- 第{summary.chapter_number}章 - {summary.title}:{summary.summary.strip()}"
)
return lines
class ChapterContextService:
"""章节上下文服务,整合查询、格式化与容错逻辑。"""
def __init__(
self,
*,
llm_service: LLMService,
vector_store: Optional[VectorStoreService] = None,
) -> None:
self._llm_service = llm_service
self._vector_store = vector_store
async def retrieve_for_generation(
self,
*,
project_id: str,
query_text: str,
user_id: int,
top_k_chunks: Optional[int] = None,
top_k_summaries: Optional[int] = None,
) -> ChapterRAGContext:
"""根据章节摘要构造检索向量,并返回 RAG 上下文。"""
query = self._normalize(query_text)
if not settings.vector_store_enabled or not self._vector_store:
logger.debug("向量库未启用或初始化失败,跳过检索: project=%s", project_id)
return ChapterRAGContext(query=query, chunks=[], summaries=[])
embedding_model = None if settings.embedding_provider == "ollama" else settings.embedding_model
embedding = await self._llm_service.get_embedding(query, user_id=user_id, model=embedding_model)
if not embedding:
logger.warning("检索查询向量生成失败: project=%s chapter_query=%s", project_id, query)
return ChapterRAGContext(query=query, chunks=[], summaries=[])
chunks = await self._vector_store.query_chunks(
project_id=project_id,
embedding=embedding,
top_k=top_k_chunks,
)
summaries = await self._vector_store.query_summaries(
project_id=project_id,
embedding=embedding,
top_k=top_k_summaries,
)
logger.info(
"章节上下文检索完成: project=%s chunks=%d summaries=%d query_preview=%s",
project_id,
len(chunks),
len(summaries),
query[:80],
)
return ChapterRAGContext(query=query, chunks=chunks, summaries=summaries)
@staticmethod
def _normalize(text: str) -> str:
"""统一压缩空白字符,避免影响检索效果。"""
return " ".join(text.split())
__all__ = [
"ChapterContextService",
"ChapterRAGContext",
]

View File

@@ -0,0 +1,262 @@
from __future__ import annotations
"""
章节向量入库服务:在章节确认后负责切分文本、生成嵌入并写入向量库。
全部注释使用中文,方便团队成员阅读理解。
"""
import logging
from typing import Dict, List, Optional, Sequence
from ..core.config import settings
from ..services.llm_service import LLMService
from ..services.vector_store_service import VectorStoreService
logger = logging.getLogger(__name__)
try: # noqa: SIM105 - 提示缺少可选依赖
from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError: # pragma: no cover - 未安装时会走后备方案
RecursiveCharacterTextSplitter = None # type: ignore[assignment]
class ChapterIngestionService:
"""封装章节内容与摘要的向量化与入库流程。"""
def __init__(
self,
*,
llm_service: LLMService,
vector_store: Optional[VectorStoreService] = None,
) -> None:
self._llm_service = llm_service
self._vector_store = vector_store or VectorStoreService()
self._text_splitter = self._init_text_splitter()
async def ingest_chapter(
self,
*,
project_id: str,
chapter_number: int,
title: str,
content: str,
summary: Optional[str],
user_id: int,
) -> None:
"""将章节正文与摘要写入向量库,供后续 RAG 检索使用。"""
if not settings.vector_store_enabled:
logger.debug("向量库未启用,跳过章节向量写入: project=%s chapter=%s", project_id, chapter_number)
return
if not content.strip():
logger.debug("章节正文为空,跳过向量写入: project=%s chapter=%s", project_id, chapter_number)
return
chunks = self._split_into_chunks(content)
if not chunks:
logger.debug("章节正文切分后为空,跳过向量写入: project=%s chapter=%s", project_id, chapter_number)
return
logger.info(
"开始写入章节向量: project=%s chapter=%s chunks=%d",
project_id,
chapter_number,
len(chunks),
)
await self._vector_store.delete_by_chapters(project_id, [chapter_number])
chunk_records = []
for index, chunk_text in enumerate(chunks):
embedding = await self._llm_service.get_embedding(
chunk_text,
user_id=user_id,
)
if not embedding:
logger.warning(
"生成章节片段向量失败,已跳过: project=%s chapter=%s chunk=%s",
project_id,
chapter_number,
index,
)
continue
record_id = f"{project_id}:{chapter_number}:{index}"
chunk_records.append(
{
"id": record_id,
"project_id": project_id,
"chapter_number": chapter_number,
"chunk_index": index,
"chapter_title": title,
"content": chunk_text,
"embedding": embedding,
"metadata": {
"chunk_id": record_id,
"length": len(chunk_text),
},
}
)
if chunk_records:
await self._vector_store.upsert_chunks(records=chunk_records)
logger.info(
"章节正文向量写入完成: project=%s chapter=%s 成功片段=%d",
project_id,
chapter_number,
len(chunk_records),
)
if summary:
cleaned_summary = summary.strip()
if cleaned_summary:
summary_embedding = await self._llm_service.get_embedding(
cleaned_summary,
user_id=user_id,
)
if summary_embedding:
summary_id = f"{project_id}:{chapter_number}:summary"
await self._vector_store.upsert_summaries(
records=[
{
"id": summary_id,
"project_id": project_id,
"chapter_number": chapter_number,
"title": title,
"summary": cleaned_summary,
"embedding": summary_embedding,
}
]
)
logger.info(
"章节摘要向量写入完成: project=%s chapter=%s",
project_id,
chapter_number,
)
else:
logger.warning(
"生成章节摘要向量失败,已跳过: project=%s chapter=%s",
project_id,
chapter_number,
)
async def delete_chapters(self, project_id: str, chapter_numbers: Sequence[int]) -> None:
"""从向量库中删除指定章节的所有片段与摘要。"""
if not settings.vector_store_enabled or not chapter_numbers:
return
logger.info(
"准备删除章节向量: project=%s chapters=%s",
project_id,
list(chapter_numbers),
)
await self._vector_store.delete_by_chapters(project_id, list(chapter_numbers))
def _split_into_chunks(self, text: str) -> List[str]:
"""按照配置的 chunk 大小与重叠度切分章节正文。"""
normalized = text.strip()
if not normalized:
return []
if self._text_splitter:
parts = [segment.strip() for segment in self._text_splitter.split_text(normalized)]
filtered = [part for part in parts if part]
if filtered:
logger.debug(
"使用 LangChain 文本切分器完成分段: count=%d chunk_size=%d overlap=%d",
len(filtered),
settings.vector_chunk_size,
settings.vector_chunk_overlap,
)
return filtered
return self._legacy_split(normalized)
@staticmethod
def _find_split_offset(segment: str) -> Optional[int]:
"""在片段内部寻找更自然的分割点,优先换行,其次常见标点。"""
candidates: Dict[str, int] = {}
newline_pos = segment.rfind("\n\n")
if newline_pos == -1:
newline_pos = segment.rfind("\n")
if newline_pos > 0:
candidates["newline"] = newline_pos
punctuation_marks = ["", "", "", "!", "?", ".", ";", ""]
for mark in punctuation_marks:
idx = segment.rfind(mark)
if idx > 0:
candidates.setdefault("punctuation", idx + len(mark))
if not candidates:
return None
# 选择最接近末尾但又不过短的分割点
best_offset = max(candidates.values())
if best_offset < len(segment) * 0.4:
return None
return best_offset
def _init_text_splitter(self) -> Optional["RecursiveCharacterTextSplitter"]:
"""初始化 LangChain 文本切分器,可根据配置动态调整。"""
if RecursiveCharacterTextSplitter is None:
logger.warning("未安装 langchain-text-splitters章节切分将回退至内置策略。")
return None
chunk_size = settings.vector_chunk_size
overlap = min(settings.vector_chunk_overlap, chunk_size // 2)
separators = [
"\n\n",
"\n",
"", "", "",
"!", "?", "", ";",
"", ",",
" ",
]
splitter = RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=chunk_size,
chunk_overlap=overlap,
keep_separator=False,
strip_whitespace=True,
)
logger.info(
"已初始化 LangChain 文本切分器: chunk_size=%d overlap=%d",
chunk_size,
overlap,
)
return splitter
def _legacy_split(self, text: str) -> List[str]:
"""内置切分策略,作为 LangChain 缺失时的后备方案。"""
chunk_size = settings.vector_chunk_size
overlap = min(settings.vector_chunk_overlap, chunk_size // 2)
chunks: List[str] = []
start = 0
total_length = len(text)
while start < total_length:
end = min(total_length, start + chunk_size)
segment = text[start:end]
split_offset = self._find_split_offset(segment)
if split_offset is not None and start + split_offset < total_length:
end = start + split_offset
segment = text[start:end]
chunk_text = segment.strip()
if chunk_text:
chunks.append(chunk_text)
if end >= total_length:
break
start = max(0, end - overlap)
logger.debug(
"使用内置策略完成章节切分: count=%d chunk_size=%d overlap=%d",
len(chunks),
chunk_size,
overlap,
)
return chunks
__all__ = ["ChapterIngestionService"]

View File

@@ -0,0 +1,49 @@
from typing import Iterable, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from ..repositories.system_config_repository import SystemConfigRepository
from ..models import SystemConfig
from ..schemas.config import SystemConfigCreate, SystemConfigRead, SystemConfigUpdate
class ConfigService:
"""系统配置服务:提供 CRUD 接口,并负责转换 Pydantic 模型。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = SystemConfigRepository(session)
async def list_configs(self) -> list[SystemConfigRead]:
configs = await self.repo.list_all()
return [SystemConfigRead.model_validate(cfg) for cfg in configs]
async def get_config(self, key: str) -> Optional[SystemConfigRead]:
config = await self.repo.get_by_key(key)
return SystemConfigRead.model_validate(config) if config else None
async def upsert_config(self, payload: SystemConfigCreate) -> SystemConfigRead:
instance = await self.repo.get_by_key(payload.key)
if instance:
await self.repo.update_fields(instance, value=payload.value, description=payload.description)
else:
instance = SystemConfig(**payload.model_dump())
await self.repo.add(instance)
await self.session.commit()
return SystemConfigRead.model_validate(instance)
async def patch_config(self, key: str, payload: SystemConfigUpdate) -> Optional[SystemConfigRead]:
instance = await self.repo.get_by_key(key)
if not instance:
return None
await self.repo.update_fields(instance, **payload.model_dump(exclude_unset=True))
await self.session.commit()
return SystemConfigRead.model_validate(instance)
async def remove_config(self, key: str) -> bool:
instance = await self.repo.get_by_key(key)
if not instance:
return False
await self.repo.delete(instance)
await self.session.commit()
return True

View File

@@ -0,0 +1,41 @@
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import LLMConfig
from ..repositories.llm_config_repository import LLMConfigRepository
from ..schemas.llm_config import LLMConfigCreate, LLMConfigRead
class LLMConfigService:
"""用户自定义 LLM 配置服务。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = LLMConfigRepository(session)
async def upsert_config(self, user_id: int, payload: LLMConfigCreate) -> LLMConfigRead:
instance = await self.repo.get_by_user(user_id)
data = payload.model_dump(exclude_unset=True)
if "llm_provider_url" in data and data["llm_provider_url"] is not None:
# HttpUrl 类型在 sqlite 中无法直接写入,需要提前转为字符串
data["llm_provider_url"] = str(data["llm_provider_url"])
if instance:
await self.repo.update_fields(instance, **data)
else:
instance = LLMConfig(user_id=user_id, **data)
await self.repo.add(instance)
await self.session.commit()
return LLMConfigRead.model_validate(instance)
async def get_config(self, user_id: int) -> Optional[LLMConfigRead]:
instance = await self.repo.get_by_user(user_id)
return LLMConfigRead.model_validate(instance) if instance else None
async def delete_config(self, user_id: int) -> bool:
instance = await self.repo.get_by_user(user_id)
if not instance:
return False
await self.repo.delete(instance)
await self.session.commit()
return True

View File

@@ -0,0 +1,306 @@
import logging
import os
from typing import Any, Dict, List, Optional
import httpx
from fastapi import HTTPException, status
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, InternalServerError
from ..core.config import settings
from ..repositories.llm_config_repository import LLMConfigRepository
from ..repositories.system_config_repository import SystemConfigRepository
from ..repositories.user_repository import UserRepository
from ..services.admin_setting_service import AdminSettingService
from ..services.prompt_service import PromptService
from ..services.usage_service import UsageService
from ..utils.llm_tool import ChatMessage, LLMClient
logger = logging.getLogger(__name__)
try: # pragma: no cover - 运行环境未安装时兼容
from ollama import AsyncClient as OllamaAsyncClient
except ImportError: # pragma: no cover - Ollama 为可选依赖
OllamaAsyncClient = None
class LLMService:
"""封装与大模型交互的所有逻辑,包括配额控制与配置选择。"""
def __init__(self, session):
self.session = session
self.llm_repo = LLMConfigRepository(session)
self.system_config_repo = SystemConfigRepository(session)
self.user_repo = UserRepository(session)
self.admin_setting_service = AdminSettingService(session)
self.usage_service = UsageService(session)
self._embedding_dimensions: Dict[str, int] = {}
async def get_llm_response(
self,
system_prompt: str,
conversation_history: List[Dict[str, str]],
*,
temperature: float = 0.7,
user_id: Optional[int] = None,
timeout: float = 300.0,
response_format: Optional[str] = "json_object",
) -> str:
messages = [{"role": "system", "content": system_prompt}, *conversation_history]
return await self._stream_and_collect(
messages,
temperature=temperature,
user_id=user_id,
timeout=timeout,
response_format=response_format,
)
async def get_summary(
self,
chapter_content: str,
*,
temperature: float = 0.2,
user_id: Optional[int] = None,
timeout: float = 180.0,
system_prompt: Optional[str] = None,
) -> str:
if not system_prompt:
prompt_service = PromptService(self.session)
system_prompt = await prompt_service.get_prompt("extraction")
if not system_prompt:
raise HTTPException(status_code=500, detail="未配置摘要提示词")
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": chapter_content},
]
return await self._stream_and_collect(messages, temperature=temperature, user_id=user_id, timeout=timeout)
async def _stream_and_collect(
self,
messages: List[Dict[str, str]],
*,
temperature: float,
user_id: Optional[int],
timeout: float,
response_format: Optional[str] = None,
) -> str:
config = await self._resolve_llm_config(user_id)
client = LLMClient(api_key=config["api_key"], base_url=config.get("base_url"))
chat_messages = [ChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
full_response = ""
finish_reason = None
logger.info(
"Streaming LLM response: model=%s user_id=%s messages=%d",
config.get("model"),
user_id,
len(messages),
)
try:
async for part in client.stream_chat(
messages=chat_messages,
model=config.get("model"),
temperature=temperature,
timeout=int(timeout),
response_format=response_format,
):
if part.get("content"):
full_response += part["content"]
if part.get("finish_reason"):
finish_reason = part["finish_reason"]
except InternalServerError as exc:
detail = "AI 服务内部错误,请稍后重试"
response = getattr(exc, "response", None)
if response is not None:
try:
payload = response.json()
error_data = payload.get("error", {}) if isinstance(payload, dict) else {}
detail = error_data.get("message_zh") or error_data.get("message") or detail
except Exception:
detail = str(exc) or detail
else:
detail = str(exc) or detail
logger.error(
"LLM stream internal error: model=%s user_id=%s detail=%s",
config.get("model"),
user_id,
detail,
exc_info=exc,
)
raise HTTPException(status_code=503, detail=detail)
except (httpx.RemoteProtocolError, httpx.ReadTimeout, APIConnectionError, APITimeoutError) as exc:
if isinstance(exc, httpx.RemoteProtocolError):
detail = "AI 服务连接被意外中断,请稍后重试"
elif isinstance(exc, (httpx.ReadTimeout, APITimeoutError)):
detail = "AI 服务响应超时,请稍后重试"
else:
detail = "无法连接到 AI 服务,请稍后重试"
logger.error(
"LLM stream failed: model=%s user_id=%s detail=%s",
config.get("model"),
user_id,
detail,
exc_info=exc,
)
raise HTTPException(status_code=503, detail=detail) from exc
logger.debug(
"LLM response collected: model=%s user_id=%s finish_reason=%s preview=%s",
config.get("model"),
user_id,
finish_reason,
full_response[:500],
)
if finish_reason == "length":
logger.warning(
"LLM response truncated: model=%s user_id=%s",
config.get("model"),
user_id,
)
raise HTTPException(status_code=500, detail="AI 响应被截断,请缩短输入或调整参数")
if not full_response:
logger.error(
"LLM returned empty response: model=%s user_id=%s",
config.get("model"),
user_id,
)
raise HTTPException(status_code=500, detail="AI 未返回有效内容")
await self.usage_service.increment("api_request_count")
logger.info(
"LLM response success: model=%s user_id=%s chars=%d",
config.get("model"),
user_id,
len(full_response),
)
return full_response
async def _resolve_llm_config(self, user_id: Optional[int]) -> Dict[str, Optional[str]]:
if user_id:
config = await self.llm_repo.get_by_user(user_id)
if config and config.llm_provider_api_key:
return {
"api_key": config.llm_provider_api_key,
"base_url": config.llm_provider_url,
"model": config.llm_provider_model,
}
# 检查每日使用次数限制
if user_id:
await self._enforce_daily_limit(user_id)
api_key = await self._get_config_value("llm.api_key")
base_url = await self._get_config_value("llm.base_url")
model = await self._get_config_value("llm.model")
if not api_key:
raise HTTPException(status_code=500, detail="未配置默认 LLM API Key")
return {"api_key": api_key, "base_url": base_url, "model": model}
async def get_embedding(
self,
text: str,
*,
user_id: Optional[int] = None,
model: Optional[str] = None,
) -> List[float]:
"""生成文本向量,用于章节 RAG 检索,支持 openai 与 ollama 双提供方。"""
provider = settings.embedding_provider
target_model = model or (
settings.ollama_embedding_model if provider == "ollama" else settings.embedding_model
)
if provider == "ollama":
if OllamaAsyncClient is None:
logger.error("未安装 ollama 依赖,无法调用本地嵌入模型。")
raise HTTPException(status_code=500, detail="缺少 Ollama 依赖,请先安装 ollama 包。")
base_url_any = settings.ollama_embedding_base_url or settings.embedding_base_url
base_url = str(base_url_any) if base_url_any else None
client = OllamaAsyncClient(host=base_url)
try:
response = await client.embeddings(model=target_model, prompt=text)
except Exception as exc: # pragma: no cover - 本地服务调用失败
logger.warning(
"Ollama 嵌入请求失败: model=%s error=%s",
target_model,
exc,
)
return []
embedding: Optional[List[float]]
if isinstance(response, dict):
embedding = response.get("embedding")
else:
embedding = getattr(response, "embedding", None)
if not embedding:
logger.warning("Ollama 返回空向量: model=%s", target_model)
return []
if not isinstance(embedding, list):
embedding = list(embedding)
else:
config = await self._resolve_llm_config(user_id)
api_key = settings.embedding_api_key or config["api_key"]
base_url_setting = settings.embedding_base_url or config.get("base_url")
base_url = str(base_url_setting) if base_url_setting else None
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
try:
response = await client.embeddings.create(
input=text,
model=target_model,
)
except Exception as exc: # pragma: no cover - 网络或鉴权失败
logger.warning(
"OpenAI 嵌入请求失败: model=%s user_id=%s error=%s",
target_model,
user_id,
exc,
)
return []
if not response.data:
logger.warning("OpenAI 嵌入请求返回空数据: model=%s user_id=%s", target_model, user_id)
return []
embedding = response.data[0].embedding
if not isinstance(embedding, list):
embedding = list(embedding)
dimension = len(embedding)
if not dimension and settings.embedding_model_vector_size:
dimension = settings.embedding_model_vector_size
if dimension:
self._embedding_dimensions[target_model] = dimension
return embedding
def get_embedding_dimension(self, model: Optional[str] = None) -> Optional[int]:
"""获取嵌入向量维度,优先返回缓存结果,其次读取配置。"""
target_model = model or (
settings.ollama_embedding_model if settings.embedding_provider == "ollama" else settings.embedding_model
)
if target_model in self._embedding_dimensions:
return self._embedding_dimensions[target_model]
return settings.embedding_model_vector_size
async def _enforce_daily_limit(self, user_id: int) -> None:
limit_str = await self.admin_setting_service.get("daily_request_limit", "100")
limit = int(limit_str or 10)
used = await self.user_repo.get_daily_request(user_id)
if used >= limit:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="今日请求次数已达上限,请明日再试或设置自定义 API Key。",
)
await self.user_repo.increment_daily_request(user_id)
await self.session.commit()
async def _get_config_value(self, key: str) -> Optional[str]:
record = await self.system_config_repo.get_by_key(key)
if record:
return record.value
# 兼容环境变量,首次迁移时无需立即写入数据库
env_key = key.upper().replace(".", "_")
return os.getenv(env_key)

View File

@@ -0,0 +1,700 @@
from __future__ import annotations
import json
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Iterable, List, Optional
_PREFERRED_CONTENT_KEYS: tuple[str, ...] = (
"content",
"chapter_content",
"chapter_text",
"full_content",
"text",
"body",
"story",
"chapter",
"real_summary",
"summary",
)
def _normalize_version_content(raw_content: Any, metadata: Any) -> str:
text = _coerce_text(metadata)
if not text:
text = _coerce_text(raw_content)
return text or ""
def _coerce_text(value: Any) -> Optional[str]:
if value is None:
return None
if isinstance(value, str):
return _clean_string(value)
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, dict):
for key in _PREFERRED_CONTENT_KEYS:
if key in value and value[key]:
nested = _coerce_text(value[key])
if nested:
return nested
return _clean_string(json.dumps(value, ensure_ascii=False))
if isinstance(value, (list, tuple, set)):
parts = [text for text in (_coerce_text(item) for item in value) if text]
if parts:
return "\n".join(parts)
return None
return _clean_string(str(value))
def _clean_string(text: str) -> str:
stripped = text.strip()
if not stripped:
return stripped
if stripped.startswith("{") and stripped.endswith("}"):
try:
parsed = json.loads(stripped)
coerced = _coerce_text(parsed)
if coerced:
return coerced
except json.JSONDecodeError:
pass
if stripped.startswith('"') and stripped.endswith('"') and len(stripped) >= 2:
stripped = stripped[1:-1]
return (
stripped.replace("\\n", "\n")
.replace("\\t", "\t")
.replace('\\"', '"')
.replace("\\\\", "\\")
)
from fastapi import HTTPException, status
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import (
BlueprintCharacter,
BlueprintRelationship,
Chapter,
ChapterEvaluation,
ChapterOutline,
ChapterVersion,
NovelBlueprint,
NovelConversation,
NovelProject,
)
from ..repositories.novel_repository import NovelRepository
from ..schemas.admin import AdminNovelSummary
from ..schemas.novel import (
Blueprint,
Chapter as ChapterSchema,
ChapterGenerationStatus,
ChapterOutline as ChapterOutlineSchema,
NovelProject as NovelProjectSchema,
NovelProjectSummary,
NovelSectionResponse,
NovelSectionType,
)
class NovelService:
"""小说项目服务,基于拆表后的结构提供聚合与业务操作。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = NovelRepository(session)
# ------------------------------------------------------------------
# 项目与摘要
# ------------------------------------------------------------------
async def create_project(self, user_id: int, title: str, initial_prompt: str) -> NovelProject:
project = NovelProject(
id=str(uuid.uuid4()),
user_id=user_id,
title=title,
initial_prompt=initial_prompt,
)
blueprint = NovelBlueprint(project=project)
self.session.add_all([project, blueprint])
await self.session.commit()
await self.session.refresh(project)
return project
async def ensure_project_owner(self, project_id: str, user_id: int) -> NovelProject:
project = await self.repo.get_by_id(project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
if project.user_id != user_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问该项目")
return project
async def get_project_schema(self, project_id: str, user_id: int) -> NovelProjectSchema:
project = await self.ensure_project_owner(project_id, user_id)
return await self._serialize_project(project)
async def get_section_data(
self,
project_id: str,
user_id: int,
section: NovelSectionType,
) -> NovelSectionResponse:
project = await self.ensure_project_owner(project_id, user_id)
return self._build_section_response(project, section)
async def get_chapter_schema(
self,
project_id: str,
user_id: int,
chapter_number: int,
) -> ChapterSchema:
project = await self.ensure_project_owner(project_id, user_id)
return self._build_chapter_schema(project, chapter_number)
async def list_projects_for_user(self, user_id: int) -> List[NovelProjectSummary]:
projects = await self.repo.list_by_user(user_id)
summaries: List[NovelProjectSummary] = []
for project in projects:
blueprint = project.blueprint
genre = blueprint.genre if blueprint and blueprint.genre else "未知"
outlines = project.outlines
chapters = project.chapters
total = len(outlines) or len(chapters)
completed = sum(1 for chapter in chapters if chapter.selected_version_id)
summaries.append(
NovelProjectSummary(
id=project.id,
title=project.title,
genre=genre,
last_edited=project.updated_at.isoformat() if project.updated_at else "未知",
completed_chapters=completed,
total_chapters=total,
)
)
return summaries
async def list_projects_for_admin(self) -> List[AdminNovelSummary]:
projects = await self.repo.list_all()
summaries: List[AdminNovelSummary] = []
for project in projects:
blueprint = project.blueprint
genre = blueprint.genre if blueprint and blueprint.genre else "未知"
outlines = project.outlines
chapters = project.chapters
total = len(outlines) or len(chapters)
completed = sum(1 for chapter in chapters if chapter.selected_version_id)
owner = project.owner
summaries.append(
AdminNovelSummary(
id=project.id,
title=project.title,
owner_id=owner.id if owner else 0,
owner_username=owner.username if owner else "未知",
genre=genre,
last_edited=project.updated_at.isoformat() if project.updated_at else "",
completed_chapters=completed,
total_chapters=total,
)
)
return summaries
async def delete_projects(self, project_ids: List[str], user_id: int) -> None:
for pid in project_ids:
project = await self.ensure_project_owner(pid, user_id)
await self.repo.delete(project)
await self.session.commit()
async def count_projects(self) -> int:
result = await self.session.execute(select(func.count(NovelProject.id)))
return result.scalar_one()
# ------------------------------------------------------------------
# 对话管理
# ------------------------------------------------------------------
async def list_conversations(self, project_id: str) -> List[NovelConversation]:
stmt = (
select(NovelConversation)
.where(NovelConversation.project_id == project_id)
.order_by(NovelConversation.seq.asc())
)
result = await self.session.execute(stmt)
return list(result.scalars())
async def append_conversation(self, project_id: str, role: str, content: str, metadata: Optional[Dict] = None) -> None:
result = await self.session.execute(
select(func.max(NovelConversation.seq)).where(NovelConversation.project_id == project_id)
)
current_max = result.scalar()
next_seq = (current_max or 0) + 1
convo = NovelConversation(
project_id=project_id,
seq=next_seq,
role=role,
content=content,
metadata=metadata,
)
self.session.add(convo)
await self.session.commit()
await self._touch_project(project_id)
# ------------------------------------------------------------------
# 蓝图管理
# ------------------------------------------------------------------
async def replace_blueprint(self, project_id: str, blueprint: Blueprint) -> None:
record = await self.session.get(NovelBlueprint, project_id)
if not record:
record = NovelBlueprint(project_id=project_id)
self.session.add(record)
record.title = blueprint.title
record.target_audience = blueprint.target_audience
record.genre = blueprint.genre
record.style = blueprint.style
record.tone = blueprint.tone
record.one_sentence_summary = blueprint.one_sentence_summary
record.full_synopsis = blueprint.full_synopsis
record.world_setting = blueprint.world_setting
await self.session.execute(delete(BlueprintCharacter).where(BlueprintCharacter.project_id == project_id))
for index, data in enumerate(blueprint.characters):
self.session.add(
BlueprintCharacter(
project_id=project_id,
name=data.get("name", ""),
identity=data.get("identity"),
personality=data.get("personality"),
goals=data.get("goals"),
abilities=data.get("abilities"),
relationship_to_protagonist=data.get("relationship_to_protagonist"),
extra={k: v for k, v in data.items() if k not in {
"name",
"identity",
"personality",
"goals",
"abilities",
"relationship_to_protagonist",
}},
position=index,
)
)
await self.session.execute(delete(BlueprintRelationship).where(BlueprintRelationship.project_id == project_id))
for index, relation in enumerate(blueprint.relationships):
self.session.add(
BlueprintRelationship(
project_id=project_id,
character_from=relation.character_from,
character_to=relation.character_to,
description=relation.description,
position=index,
)
)
await self.session.execute(delete(ChapterOutline).where(ChapterOutline.project_id == project_id))
for outline in blueprint.chapter_outline:
self.session.add(
ChapterOutline(
project_id=project_id,
chapter_number=outline.chapter_number,
title=outline.title,
summary=outline.summary,
)
)
await self.session.commit()
await self._touch_project(project_id)
async def patch_blueprint(self, project_id: str, patch: Dict) -> None:
blueprint = await self.session.get(NovelBlueprint, project_id)
if not blueprint:
blueprint = NovelBlueprint(project_id=project_id)
self.session.add(blueprint)
if "one_sentence_summary" in patch:
blueprint.one_sentence_summary = patch["one_sentence_summary"]
if "full_synopsis" in patch:
blueprint.full_synopsis = patch["full_synopsis"]
if "world_setting" in patch and patch["world_setting"] is not None:
existing = blueprint.world_setting or {}
existing.update(patch["world_setting"])
blueprint.world_setting = existing
if "characters" in patch and patch["characters"] is not None:
await self.session.execute(delete(BlueprintCharacter).where(BlueprintCharacter.project_id == project_id))
for index, data in enumerate(patch["characters"]):
self.session.add(
BlueprintCharacter(
project_id=project_id,
name=data.get("name", ""),
identity=data.get("identity"),
personality=data.get("personality"),
goals=data.get("goals"),
abilities=data.get("abilities"),
relationship_to_protagonist=data.get("relationship_to_protagonist"),
extra={k: v for k, v in data.items() if k not in {
"name",
"identity",
"personality",
"goals",
"abilities",
"relationship_to_protagonist",
}},
position=index,
)
)
if "relationships" in patch and patch["relationships"] is not None:
await self.session.execute(delete(BlueprintRelationship).where(BlueprintRelationship.project_id == project_id))
for index, relation in enumerate(patch["relationships"]):
self.session.add(
BlueprintRelationship(
project_id=project_id,
character_from=relation.get("character_from"),
character_to=relation.get("character_to"),
description=relation.get("description"),
position=index,
)
)
if "chapter_outline" in patch and patch["chapter_outline"] is not None:
await self.session.execute(delete(ChapterOutline).where(ChapterOutline.project_id == project_id))
for outline in patch["chapter_outline"]:
self.session.add(
ChapterOutline(
project_id=project_id,
chapter_number=outline.get("chapter_number"),
title=outline.get("title", ""),
summary=outline.get("summary"),
)
)
await self.session.commit()
await self._touch_project(project_id)
# ------------------------------------------------------------------
# 章节与版本
# ------------------------------------------------------------------
async def get_outline(self, project_id: str, chapter_number: int) -> Optional[ChapterOutline]:
stmt = (
select(ChapterOutline)
.where(
ChapterOutline.project_id == project_id,
ChapterOutline.chapter_number == chapter_number,
)
)
result = await self.session.execute(stmt)
return result.scalars().first()
async def get_or_create_chapter(self, project_id: str, chapter_number: int) -> Chapter:
stmt = (
select(Chapter)
.where(
Chapter.project_id == project_id,
Chapter.chapter_number == chapter_number,
)
)
result = await self.session.execute(stmt)
chapter = result.scalars().first()
if chapter:
return chapter
chapter = Chapter(project_id=project_id, chapter_number=chapter_number)
self.session.add(chapter)
await self.session.commit()
await self.session.refresh(chapter)
return chapter
async def replace_chapter_versions(self, chapter: Chapter, contents: List[str], metadata: Optional[List[Dict]] = None) -> List[ChapterVersion]:
await self.session.execute(delete(ChapterVersion).where(ChapterVersion.chapter_id == chapter.id))
versions: List[ChapterVersion] = []
for index, content in enumerate(contents):
extra = metadata[index] if metadata and index < len(metadata) else None
text_content = _normalize_version_content(content, extra)
version = ChapterVersion(
chapter_id=chapter.id,
content=text_content,
metadata=None,
version_label=f"v{index+1}",
)
self.session.add(version)
versions.append(version)
chapter.status = ChapterGenerationStatus.WAITING_FOR_CONFIRM.value
await self.session.commit()
await self.session.refresh(chapter)
await self._touch_project(chapter.project_id)
return versions
async def select_chapter_version(self, chapter: Chapter, version_index: int) -> ChapterVersion:
versions = sorted(chapter.versions, key=lambda item: item.created_at)
if not versions or version_index < 0 or version_index >= len(versions):
raise HTTPException(status_code=400, detail="版本索引无效")
selected = versions[version_index]
chapter.selected_version_id = selected.id
chapter.status = ChapterGenerationStatus.SUCCESSFUL.value
chapter.word_count = len(selected.content or "")
await self.session.commit()
await self.session.refresh(chapter)
await self._touch_project(chapter.project_id)
return selected
async def add_chapter_evaluation(self, chapter: Chapter, version: Optional[ChapterVersion], feedback: str, decision: Optional[str] = None) -> None:
evaluation = ChapterEvaluation(
chapter_id=chapter.id,
version_id=version.id if version else None,
feedback=feedback,
decision=decision,
)
self.session.add(evaluation)
chapter.status = ChapterGenerationStatus.WAITING_FOR_CONFIRM.value
await self.session.commit()
await self.session.refresh(chapter)
await self._touch_project(chapter.project_id)
async def delete_chapters(self, project_id: str, chapter_numbers: Iterable[int]) -> None:
await self.session.execute(
delete(Chapter).where(
Chapter.project_id == project_id,
Chapter.chapter_number.in_(list(chapter_numbers)),
)
)
await self.session.execute(
delete(ChapterOutline).where(
ChapterOutline.project_id == project_id,
ChapterOutline.chapter_number.in_(list(chapter_numbers)),
)
)
await self.session.commit()
await self._touch_project(project_id)
# ------------------------------------------------------------------
# 序列化辅助
# ------------------------------------------------------------------
async def get_project_schema_for_admin(self, project_id: str) -> NovelProjectSchema:
project = await self.repo.get_by_id(project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
return await self._serialize_project(project)
async def get_section_data_for_admin(
self,
project_id: str,
section: NovelSectionType,
) -> NovelSectionResponse:
project = await self.repo.get_by_id(project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
return self._build_section_response(project, section)
async def get_chapter_schema_for_admin(
self,
project_id: str,
chapter_number: int,
) -> ChapterSchema:
project = await self.repo.get_by_id(project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
return self._build_chapter_schema(project, chapter_number)
async def _serialize_project(self, project: NovelProject) -> NovelProjectSchema:
conversations = [
{"role": convo.role, "content": convo.content}
for convo in sorted(project.conversations, key=lambda c: c.seq)
]
blueprint_schema = self._build_blueprint_schema(project)
outlines_map = {outline.chapter_number: outline for outline in project.outlines}
chapters_map = {chapter.chapter_number: chapter for chapter in project.chapters}
chapter_numbers = sorted(set(outlines_map.keys()) | set(chapters_map.keys()))
chapters_schema: List[ChapterSchema] = [
self._build_chapter_schema(
project,
number,
outlines_map=outlines_map,
chapters_map=chapters_map,
)
for number in chapter_numbers
]
return NovelProjectSchema(
id=project.id,
user_id=project.user_id,
title=project.title,
initial_prompt=project.initial_prompt or "",
conversation_history=conversations,
blueprint=blueprint_schema,
chapters=chapters_schema,
)
async def _touch_project(self, project_id: str) -> None:
await self.session.execute(
update(NovelProject)
.where(NovelProject.id == project_id)
.values(updated_at=datetime.now(timezone.utc))
)
await self.session.commit()
def _build_blueprint_schema(self, project: NovelProject) -> Blueprint:
blueprint_obj = project.blueprint
if blueprint_obj:
return Blueprint(
title=blueprint_obj.title or "",
target_audience=blueprint_obj.target_audience or "",
genre=blueprint_obj.genre or "",
style=blueprint_obj.style or "",
tone=blueprint_obj.tone or "",
one_sentence_summary=blueprint_obj.one_sentence_summary or "",
full_synopsis=blueprint_obj.full_synopsis or "",
world_setting=blueprint_obj.world_setting or {},
characters=[
{
"name": character.name,
"identity": character.identity,
"personality": character.personality,
"goals": character.goals,
"abilities": character.abilities,
"relationship_to_protagonist": character.relationship_to_protagonist,
**(character.extra or {}),
}
for character in sorted(project.characters, key=lambda c: c.position)
],
relationships=[
{
"character_from": relation.character_from,
"character_to": relation.character_to,
"description": relation.description or "",
"relationship_type": getattr(relation, "relationship_type", None),
}
for relation in sorted(project.relationships_, key=lambda r: r.position)
],
chapter_outline=[
ChapterOutlineSchema(
chapter_number=outline.chapter_number,
title=outline.title,
summary=outline.summary or "",
)
for outline in sorted(project.outlines, key=lambda o: o.chapter_number)
],
)
return Blueprint(
title="",
target_audience="",
genre="",
style="",
tone="",
one_sentence_summary="",
full_synopsis="",
world_setting={},
characters=[],
relationships=[],
chapter_outline=[],
)
def _build_section_response(
self,
project: NovelProject,
section: NovelSectionType,
) -> NovelSectionResponse:
blueprint = self._build_blueprint_schema(project)
if section == NovelSectionType.OVERVIEW:
data = {
"title": project.title,
"initial_prompt": project.initial_prompt or "",
"status": project.status,
"one_sentence_summary": blueprint.one_sentence_summary,
"target_audience": blueprint.target_audience,
"genre": blueprint.genre,
"style": blueprint.style,
"tone": blueprint.tone,
"full_synopsis": blueprint.full_synopsis,
"updated_at": project.updated_at.isoformat() if project.updated_at else None,
}
elif section == NovelSectionType.WORLD_SETTING:
data = {
"world_setting": blueprint.world_setting or {},
}
elif section == NovelSectionType.CHARACTERS:
data = {
"characters": blueprint.characters,
}
elif section == NovelSectionType.RELATIONSHIPS:
data = {
"relationships": blueprint.relationships,
}
elif section == NovelSectionType.CHAPTER_OUTLINE:
data = {
"chapter_outline": [outline.model_dump() for outline in blueprint.chapter_outline],
}
elif section == NovelSectionType.CHAPTERS:
outlines_map = {outline.chapter_number: outline for outline in project.outlines}
chapters_map = {chapter.chapter_number: chapter for chapter in project.chapters}
chapter_numbers = sorted(set(outlines_map.keys()) | set(chapters_map.keys()))
# 章节列表只返回元数据,不包含完整内容
chapters = [
self._build_chapter_schema(
project,
number,
outlines_map=outlines_map,
chapters_map=chapters_map,
include_content=False,
).model_dump()
for number in chapter_numbers
]
data = {
"chapters": chapters,
"total": len(chapters),
}
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="未知的章节类型")
return NovelSectionResponse(section=section, data=data)
def _build_chapter_schema(
self,
project: NovelProject,
chapter_number: int,
*,
outlines_map: Optional[Dict[int, ChapterOutline]] = None,
chapters_map: Optional[Dict[int, Chapter]] = None,
include_content: bool = True,
) -> ChapterSchema:
outlines = outlines_map or {outline.chapter_number: outline for outline in project.outlines}
chapters = chapters_map or {chapter.chapter_number: chapter for chapter in project.chapters}
outline = outlines.get(chapter_number)
chapter = chapters.get(chapter_number)
if not outline and not chapter:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="章节不存在")
title = outline.title if outline else f"{chapter_number}"
summary = outline.summary if outline else ""
real_summary = chapter.real_summary if chapter else None
content = None
versions: Optional[List[str]] = None
evaluation_text: Optional[str] = None
status_value = ChapterGenerationStatus.NOT_GENERATED.value
word_count = 0
if chapter:
status_value = chapter.status or ChapterGenerationStatus.NOT_GENERATED.value
word_count = chapter.word_count or 0
# 只有在 include_content=True 时才包含完整内容
if include_content:
if chapter.selected_version:
content = chapter.selected_version.content
if chapter.versions:
versions = [
v.content
for v in sorted(chapter.versions, key=lambda item: item.created_at)
]
if chapter.evaluations:
latest = sorted(chapter.evaluations, key=lambda item: item.created_at)[-1]
evaluation_text = latest.feedback or latest.decision
return ChapterSchema(
chapter_number=chapter_number,
title=title,
summary=summary,
real_summary=real_summary,
content=content,
versions=versions,
evaluation=evaluation_text,
generation_status=ChapterGenerationStatus(status_value),
word_count=word_count,
)

View File

@@ -0,0 +1,96 @@
import asyncio
from typing import Dict, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Prompt
from ..repositories.prompt_repository import PromptRepository
from ..schemas.prompt import PromptCreate, PromptRead, PromptUpdate
_CACHE: Dict[str, PromptRead] = {}
_LOCK = asyncio.Lock()
_LOADED = False
class PromptService:
"""提示词服务,提供缓存加速与 CRUD 能力。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = PromptRepository(session)
async def preload(self) -> None:
global _CACHE, _LOADED
prompts = await self.repo.list_all()
async with _LOCK:
_CACHE = {item.name: PromptRead.model_validate(item) for item in prompts}
_LOADED = True
async def get_prompt(self, name: str) -> Optional[str]:
global _LOADED
async with _LOCK:
if not _LOADED:
prompts = await self.repo.list_all()
_CACHE.update({item.name: PromptRead.model_validate(item) for item in prompts})
_LOADED = True
cached = _CACHE.get(name)
if cached:
return cached.content
prompt = await self.repo.get_by_name(name)
if not prompt:
return None
prompt_read = PromptRead.model_validate(prompt)
async with _LOCK:
_CACHE[name] = prompt_read
return prompt_read.content
async def list_prompts(self) -> list[PromptRead]:
prompts = await self.repo.list_all()
return [PromptRead.model_validate(item) for item in prompts]
async def get_prompt_by_id(self, prompt_id: int) -> Optional[PromptRead]:
instance = await self.repo.get(id=prompt_id)
if not instance:
return None
return PromptRead.model_validate(instance)
async def create_prompt(self, payload: PromptCreate) -> PromptRead:
data = payload.model_dump()
tags = data.get("tags")
if tags is not None:
data["tags"] = ",".join(tags)
prompt = Prompt(**data)
await self.repo.add(prompt)
await self.session.commit()
prompt_read = PromptRead.model_validate(prompt)
async with _LOCK:
_CACHE[prompt_read.name] = prompt_read
global _LOADED
_LOADED = True
return prompt_read
async def update_prompt(self, prompt_id: int, payload: PromptUpdate) -> Optional[PromptRead]:
instance = await self.repo.get(id=prompt_id)
if not instance:
return None
update_data = payload.model_dump(exclude_unset=True)
if "tags" in update_data and update_data["tags"] is not None:
update_data["tags"] = ",".join(update_data["tags"])
await self.repo.update_fields(instance, **update_data)
await self.session.commit()
prompt_read = PromptRead.model_validate(instance)
async with _LOCK:
_CACHE[prompt_read.name] = prompt_read
return prompt_read
async def delete_prompt(self, prompt_id: int) -> bool:
instance = await self.repo.get(id=prompt_id)
if not instance:
return False
await self.repo.delete(instance)
await self.session.commit()
async with _LOCK:
_CACHE.pop(instance.name, None)
return True

View File

@@ -0,0 +1,60 @@
from typing import List, Optional
from fastapi import HTTPException, status
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import UpdateLog
from ..repositories.update_log_repository import UpdateLogRepository
class UpdateLogService:
"""更新日志服务,提供增删改查能力,并保证置顶唯一。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = UpdateLogRepository(session)
async def list_logs(self, limit: Optional[int] = None) -> List[UpdateLog]:
if limit is None:
return list(await self.repo.list())
return list(await self.repo.list_latest(limit))
async def create_log(self, content: str, creator: str | None = None, *, is_pinned: bool = False) -> UpdateLog:
if is_pinned:
await self._clear_pinned()
log = UpdateLog(content=content, created_by=creator, is_pinned=is_pinned)
await self.repo.add(log)
await self.session.commit()
await self.session.refresh(log)
return log
async def update_log(self, log_id: int, *, content: Optional[str] = None, is_pinned: Optional[bool] = None) -> UpdateLog:
log = await self.repo.get(id=log_id)
if not log:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="更新记录不存在")
updates = {}
if content is not None:
updates["content"] = content
if is_pinned is not None:
if is_pinned:
await self._clear_pinned()
updates["is_pinned"] = is_pinned
if updates:
await self.repo.update_fields(log, **updates)
await self.session.commit()
await self.session.refresh(log)
return log
async def delete_log(self, log_id: int) -> None:
log = await self.repo.get(id=log_id)
if not log:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="更新记录不存在")
await self.repo.delete(log)
await self.session.commit()
async def _clear_pinned(self) -> None:
await self.session.execute(update(UpdateLog).values(is_pinned=False))

View File

@@ -0,0 +1,21 @@
from sqlalchemy.ext.asyncio import AsyncSession
from ..repositories.usage_metric_repository import UsageMetricRepository
class UsageService:
"""通用计数服务,目前用于统计 API 请求次数等。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = UsageMetricRepository(session)
async def increment(self, key: str) -> None:
counter = await self.repo.get_or_create(key)
counter.value += 1
await self.session.commit()
async def get_value(self, key: str) -> int:
counter = await self.repo.get_or_create(key)
await self.session.commit()
return counter.value

View File

@@ -0,0 +1,62 @@
from typing import Optional
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from ..core.security import hash_password
from ..models import User
from ..repositories.user_repository import UserRepository
from ..schemas.user import UserCreate, UserInDB
class UserService:
"""用户领域服务,负责注册、查询与配额统计。"""
def __init__(self, session: AsyncSession):
self.session = session
self.repo = UserRepository(session)
async def create_user(self, payload: UserCreate, *, external_id: str | None = None) -> UserInDB:
hashed_password = hash_password(payload.password)
user = User(
username=payload.username,
email=payload.email,
hashed_password=hashed_password,
external_id=external_id,
)
self.session.add(user)
try:
await self.session.commit()
except IntegrityError as exc:
await self.session.rollback()
raise ValueError("用户名或邮箱已存在") from exc
return UserInDB.model_validate(user)
async def get_by_username(self, username: str) -> Optional[UserInDB]:
user = await self.repo.get_by_username(username)
return UserInDB.model_validate(user) if user else None
async def get_by_email(self, email: str) -> Optional[UserInDB]:
user = await self.repo.get_by_email(email)
return UserInDB.model_validate(user) if user else None
async def get_by_external_id(self, external_id: str) -> Optional[UserInDB]:
user = await self.repo.get_by_external_id(external_id)
return UserInDB.model_validate(user) if user else None
async def get_user(self, user_id: int) -> Optional[UserInDB]:
user = await self.repo.get(id=user_id)
return UserInDB.model_validate(user) if user else None
async def list_users(self) -> list[UserInDB]:
users = await self.repo.list_all()
return [UserInDB.model_validate(item) for item in users]
async def increment_daily_request(self, user_id: int) -> None:
await self.repo.increment_daily_request(user_id)
await self.session.commit()
async def get_daily_request(self, user_id: int) -> int:
return await self.repo.get_daily_request(user_id)

View File

@@ -0,0 +1,544 @@
from __future__ import annotations
"""
基于 libsql 的向量检索服务,封装章节内容的存储与查询。
本文件中的注释均使用中文,便于团队成员快速理解 RAG 相关逻辑。
"""
import json
import logging
import math
from array import array
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence
from ..core.config import settings
try: # noqa: SIM105 - 明确区分依赖缺失的情况
import libsql_client
except ImportError: # pragma: no cover - 在未安装依赖时提供友好提示
libsql_client = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
@dataclass
class RetrievedChunk:
"""向量检索得到的剧情片段。"""
content: str
chapter_number: int
chapter_title: Optional[str]
score: float
metadata: Dict[str, Any]
@dataclass
class RetrievedSummary:
"""向量检索得到的章节摘要。"""
chapter_number: int
title: str
summary: str
score: float
class VectorStoreService:
"""libsql 向量库操作工具,确保不同小说项目的数据隔离。"""
def __init__(self) -> None:
if not settings.vector_store_enabled:
logger.warning("未开启向量库配置RAG 检索将被跳过。")
self._client = None
self._schema_ready = True
return
if libsql_client is None: # pragma: no cover - 运行环境缺少依赖
raise RuntimeError("缺少 libsql-client 依赖,请先在环境中安装。")
url = settings.vector_db_url
if url and url.startswith("file:"):
path_part = url.split("file:", 1)[1]
resolved = Path(path_part).expanduser().resolve()
resolved.parent.mkdir(parents=True, exist_ok=True)
url = f"file:{resolved}"
logger.info("向量库使用本地文件: %s", resolved)
try:
logger.info("初始化 libsql 客户端: url=%s", url)
self._client = libsql_client.create_client(
url=url,
auth_token=settings.vector_db_auth_token,
)
except Exception as exc: # pragma: no cover - 连接异常仅打印日志
logger.error("初始化 libsql 客户端失败: %s", exc)
self._client = None
self._schema_ready = True
else:
self._schema_ready = False
logger.info("libsql 客户端初始化成功,等待建表。")
async def ensure_schema(self) -> None:
"""初始化向量表结构,保证系统首次运行即可使用。"""
if not self._client or self._schema_ready:
return
statements = [
"""
CREATE TABLE IF NOT EXISTS rag_chunks (
id TEXT PRIMARY KEY,
project_id TEXT NOT NULL,
chapter_number INTEGER NOT NULL,
chunk_index INTEGER NOT NULL,
chapter_title TEXT,
content TEXT NOT NULL,
embedding BLOB NOT NULL,
metadata TEXT,
created_at INTEGER DEFAULT (unixepoch())
)
""",
"""
CREATE INDEX IF NOT EXISTS idx_rag_chunks_project
ON rag_chunks(project_id, chapter_number)
""",
"""
CREATE TABLE IF NOT EXISTS rag_summaries (
id TEXT PRIMARY KEY,
project_id TEXT NOT NULL,
chapter_number INTEGER NOT NULL,
title TEXT NOT NULL,
summary TEXT NOT NULL,
embedding BLOB NOT NULL,
created_at INTEGER DEFAULT (unixepoch())
)
""",
"""
CREATE INDEX IF NOT EXISTS idx_rag_summaries_project
ON rag_summaries(project_id, chapter_number)
""",
]
try:
for sql in statements:
await self._client.execute(sql) # type: ignore[union-attr]
logger.info("已确保向量库表结构存在。")
except Exception as exc: # pragma: no cover - 初始化失败时记录日志
logger.error("创建向量库表结构失败: %s", exc)
else:
self._schema_ready = True
async def query_chunks(
self,
*,
project_id: str,
embedding: Sequence[float],
top_k: Optional[int] = None,
) -> List[RetrievedChunk]:
"""根据查询向量检索剧情片段,结果已按相似度排序。"""
if not self._client or not embedding:
return []
await self.ensure_schema()
top_k = top_k or settings.vector_top_k_chunks
if top_k <= 0:
return []
blob = self._to_f32_blob(embedding)
sql = """
SELECT
content,
chapter_number,
chapter_title,
COALESCE(metadata, '{}') AS metadata,
vector_distance_cosine(embedding, :query) AS distance
FROM rag_chunks
WHERE project_id = :project_id
ORDER BY distance ASC
LIMIT :limit
"""
try:
result = await self._client.execute( # type: ignore[union-attr]
sql,
{
"project_id": project_id,
"query": blob,
"limit": top_k,
},
)
except Exception as exc: # pragma: no cover - 查询异常时仅记录
if "no such function: vector_distance_cosine" in str(exc).lower():
logger.warning("向量库缺少 vector_distance_cosine 函数,回退至应用层相似度计算。")
return await self._query_chunks_with_python_similarity(
project_id=project_id,
embedding=embedding,
top_k=top_k,
)
logger.warning("向量检索剧情片段失败: %s", exc)
return []
items: List[RetrievedChunk] = []
for row in self._iter_rows(result):
items.append(
RetrievedChunk(
content=row.get("content", ""),
chapter_number=row.get("chapter_number", 0),
chapter_title=row.get("chapter_title"),
score=row.get("distance", 0.0),
metadata=self._parse_metadata(row.get("metadata")),
)
)
return items
async def query_summaries(
self,
*,
project_id: str,
embedding: Sequence[float],
top_k: Optional[int] = None,
) -> List[RetrievedSummary]:
"""根据查询向量检索章节摘要列表。"""
if not self._client or not embedding:
return []
await self.ensure_schema()
top_k = top_k or settings.vector_top_k_summaries
if top_k <= 0:
return []
blob = self._to_f32_blob(embedding)
sql = """
SELECT
chapter_number,
title,
summary,
vector_distance_cosine(embedding, :query) AS distance
FROM rag_summaries
WHERE project_id = :project_id
ORDER BY distance ASC
LIMIT :limit
"""
try:
result = await self._client.execute( # type: ignore[union-attr]
sql,
{
"project_id": project_id,
"query": blob,
"limit": top_k,
},
)
except Exception as exc: # pragma: no cover - 查询异常时仅记录
if "no such function: vector_distance_cosine" in str(exc).lower():
logger.warning("向量库缺少 vector_distance_cosine 函数,回退至应用层相似度计算。")
return await self._query_summaries_with_python_similarity(
project_id=project_id,
embedding=embedding,
top_k=top_k,
)
logger.warning("向量检索章节摘要失败: %s", exc)
return []
items: List[RetrievedSummary] = []
for row in self._iter_rows(result):
items.append(
RetrievedSummary(
chapter_number=row.get("chapter_number", 0),
title=row.get("title", ""),
summary=row.get("summary", ""),
score=row.get("distance", 0.0),
)
)
return items
async def upsert_chunks(
self,
*,
records: Iterable[Dict[str, Any]],
) -> None:
"""批量写入章节片段,供后续检索使用。"""
if not self._client:
return
await self.ensure_schema()
sql = """
INSERT INTO rag_chunks (
id,
project_id,
chapter_number,
chunk_index,
chapter_title,
content,
embedding,
metadata
) VALUES (
:id,
:project_id,
:chapter_number,
:chunk_index,
:chapter_title,
:content,
:embedding,
:metadata
)
ON CONFLICT(id) DO UPDATE SET
content=excluded.content,
embedding=excluded.embedding,
metadata=excluded.metadata,
chapter_title=excluded.chapter_title
"""
payload = []
for item in records:
embedding = item.get("embedding", [])
payload.append(
{
**item,
"embedding": self._to_f32_blob(embedding),
"metadata": json.dumps(item.get("metadata") or {}, ensure_ascii=False),
}
)
if not payload:
return
for item in payload:
try:
await self._client.execute(sql, item) # type: ignore[union-attr]
except Exception as exc: # pragma: no cover - 单条写入失败时记录日志
logger.error("写入 rag_chunks 失败: %s", exc)
else:
logger.debug(
"已写入章节片段: project=%s chapter=%s chunk=%s",
item.get("project_id"),
item.get("chapter_number"),
item.get("chunk_index"),
)
async def upsert_summaries(
self,
*,
records: Iterable[Dict[str, Any]],
) -> None:
"""同步章节摘要向量,供摘要层检索使用。"""
if not self._client:
return
await self.ensure_schema()
sql = """
INSERT INTO rag_summaries (
id,
project_id,
chapter_number,
title,
summary,
embedding
) VALUES (
:id,
:project_id,
:chapter_number,
:title,
:summary,
:embedding
)
ON CONFLICT(id) DO UPDATE SET
summary=excluded.summary,
embedding=excluded.embedding,
title=excluded.title
"""
payload = []
for item in records:
embedding = item.get("embedding", [])
payload.append(
{
**item,
"embedding": self._to_f32_blob(embedding),
}
)
if not payload:
return
for item in payload:
try:
await self._client.execute(sql, item) # type: ignore[union-attr]
except Exception as exc: # pragma: no cover - 单条写入失败时记录日志
logger.error("写入 rag_summaries 失败: %s", exc)
else:
logger.debug(
"已写入章节摘要: project=%s chapter=%s",
item.get("project_id"),
item.get("chapter_number"),
)
async def delete_by_chapters(self, project_id: str, chapter_numbers: Sequence[int]) -> None:
"""根据章节编号批量删除对应的上下文数据。"""
if not self._client or not chapter_numbers:
return
await self.ensure_schema()
placeholders = ",".join(":chapter_" + str(idx) for idx in range(len(chapter_numbers)))
params = {
"project_id": project_id,
**{f"chapter_{idx}": number for idx, number in enumerate(chapter_numbers)},
}
chunk_sql = f"""
DELETE FROM rag_chunks
WHERE project_id = :project_id
AND chapter_number IN ({placeholders})
"""
summary_sql = f"""
DELETE FROM rag_summaries
WHERE project_id = :project_id
AND chapter_number IN ({placeholders})
"""
try:
await self._client.execute(chunk_sql, params) # type: ignore[union-attr]
await self._client.execute(summary_sql, params) # type: ignore[union-attr]
logger.info(
"已删除章节向量: project=%s chapters=%s",
project_id,
list(chapter_numbers),
)
except Exception as exc: # pragma: no cover - 删除失败时记录日志
logger.error("删除章节向量失败: project=%s chapters=%s error=%s", project_id, chapter_numbers, exc)
@staticmethod
def _to_f32_blob(embedding: Sequence[float]) -> bytes:
"""将向量浮点列表编码为 libsql 可识别的 float32 二进制。"""
return array("f", embedding).tobytes()
@staticmethod
def _from_f32_blob(blob: Any) -> List[float]:
"""将数据库中的 BLOB 解码为浮点列表。"""
if not blob:
return []
if isinstance(blob, memoryview):
blob = blob.tobytes()
data = array("f")
data.frombytes(bytes(blob))
return list(data)
@staticmethod
def _cosine_distance(vec_a: Sequence[float], vec_b: Sequence[float]) -> float:
"""计算余弦距离1 - similarity避免除零。"""
if not vec_a or not vec_b:
return 1.0
dot = sum(a * b for a, b in zip(vec_a, vec_b))
norm_a = math.sqrt(sum(a * a for a in vec_a))
norm_b = math.sqrt(sum(b * b for b in vec_b))
if norm_a == 0 or norm_b == 0:
return 1.0
similarity = dot / (norm_a * norm_b)
return 1.0 - similarity
async def _query_chunks_with_python_similarity(
self,
*,
project_id: str,
embedding: Sequence[float],
top_k: int,
) -> List[RetrievedChunk]:
sql = """
SELECT
content,
chapter_number,
chapter_title,
COALESCE(metadata, '{}') AS metadata,
embedding
FROM rag_chunks
WHERE project_id = :project_id
"""
result = await self._client.execute(sql, {"project_id": project_id}) # type: ignore[union-attr]
scored: List[RetrievedChunk] = []
for row in self._iter_rows(result):
stored_embedding = self._from_f32_blob(row.get("embedding"))
distance = self._cosine_distance(embedding, stored_embedding)
scored.append(
RetrievedChunk(
content=row.get("content", ""),
chapter_number=row.get("chapter_number", 0),
chapter_title=row.get("chapter_title"),
score=distance,
metadata=self._parse_metadata(row.get("metadata")),
)
)
scored.sort(key=lambda item: item.score)
return scored[:top_k]
async def _query_summaries_with_python_similarity(
self,
*,
project_id: str,
embedding: Sequence[float],
top_k: int,
) -> List[RetrievedSummary]:
sql = """
SELECT
chapter_number,
title,
summary,
embedding
FROM rag_summaries
WHERE project_id = :project_id
"""
result = await self._client.execute(sql, {"project_id": project_id}) # type: ignore[union-attr]
scored: List[RetrievedSummary] = []
for row in self._iter_rows(result):
stored_embedding = self._from_f32_blob(row.get("embedding"))
distance = self._cosine_distance(embedding, stored_embedding)
scored.append(
RetrievedSummary(
chapter_number=row.get("chapter_number", 0),
title=row.get("title", ""),
summary=row.get("summary", ""),
score=distance,
)
)
scored.sort(key=lambda item: item.score)
return scored[:top_k]
@staticmethod
def _parse_metadata(raw: Any) -> Dict[str, Any]:
"""解析存储的 JSON 文本,确保输出为 dict。"""
if not raw:
return {}
if isinstance(raw, dict):
return raw
if isinstance(raw, (bytes, bytearray)):
raw = raw.decode("utf-8")
if isinstance(raw, str):
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else {}
except json.JSONDecodeError:
return {}
return {}
@staticmethod
def _iter_rows(result: Any) -> Iterable[Dict[str, Any]]:
"""统一处理 libsql 返回的行数据,确保以 dict 形式迭代。"""
rows = getattr(result, "rows", None)
if rows is None:
rows = result
if not rows:
return []
normalized: List[Dict[str, Any]] = []
for row in rows:
if isinstance(row, dict):
normalized.append(row)
elif hasattr(row, "_asdict"):
normalized.append(row._asdict()) # type: ignore[attr-defined]
else:
try:
normalized.append(dict(row))
except Exception: # pragma: no cover - 无法转换时跳过
continue
return normalized
__all__ = [
"VectorStoreService",
"RetrievedChunk",
"RetrievedSummary",
]