# Source listing metadata: 307 lines, 12 KiB, Python
import logging
import os
from typing import Any, Dict, List, Optional

import httpx
from fastapi import HTTPException, status
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, InternalServerError

from ..core.config import settings
from ..repositories.llm_config_repository import LLMConfigRepository
from ..repositories.system_config_repository import SystemConfigRepository
from ..repositories.user_repository import UserRepository
from ..services.admin_setting_service import AdminSettingService
from ..services.prompt_service import PromptService
from ..services.usage_service import UsageService
from ..utils.llm_tool import ChatMessage, LLMClient

# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)

# Ollama is an optional dependency: when the package is missing the name is
# bound to None so callers can detect that and report a clear error.
try:  # pragma: no cover - keeps the module importable when ollama is not installed
    from ollama import AsyncClient as OllamaAsyncClient
except ImportError:  # pragma: no cover - Ollama is an optional dependency
    OllamaAsyncClient = None


class LLMService:
    """Encapsulate all interaction with large language models.

    Covers streamed chat completions, summaries and embeddings, together with
    per-user quota control and selection of the LLM configuration to use.
    """

    # Daily quota applied when the admin setting is missing, empty or invalid.
    _DEFAULT_DAILY_REQUEST_LIMIT = 100

    def __init__(self, session):
        """Create the repositories and services this class uses from ``session``."""
        self.session = session
        self.llm_repo = LLMConfigRepository(session)
        self.system_config_repo = SystemConfigRepository(session)
        self.user_repo = UserRepository(session)
        self.admin_setting_service = AdminSettingService(session)
        self.usage_service = UsageService(session)
        # Embedding dimensions observed so far, keyed by model name.
        self._embedding_dimensions: Dict[str, int] = {}

    async def get_llm_response(
        self,
        system_prompt: str,
        conversation_history: List[Dict[str, str]],
        *,
        temperature: float = 0.7,
        user_id: Optional[int] = None,
        timeout: float = 300.0,
        response_format: Optional[str] = "json_object",
    ) -> str:
        """Return the model's reply to ``conversation_history``.

        The system prompt is prepended as the first message and the streamed
        reply is collected into a single string.

        Raises:
            HTTPException: propagated from ``_stream_and_collect``.
        """
        messages = [{"role": "system", "content": system_prompt}, *conversation_history]
        return await self._stream_and_collect(
            messages,
            temperature=temperature,
            user_id=user_id,
            timeout=timeout,
            response_format=response_format,
        )

    async def get_summary(
        self,
        chapter_content: str,
        *,
        temperature: float = 0.2,
        user_id: Optional[int] = None,
        timeout: float = 180.0,
        system_prompt: Optional[str] = None,
    ) -> str:
        """Return the model's output for ``chapter_content`` under a summary prompt.

        When ``system_prompt`` is not supplied, the prompt named ``"extraction"``
        is loaded through ``PromptService``.

        Raises:
            HTTPException: 500 when no prompt is available; others are
                propagated from ``_stream_and_collect``.
        """
        if not system_prompt:
            prompt_service = PromptService(self.session)
            system_prompt = await prompt_service.get_prompt("extraction")
        if not system_prompt:
            raise HTTPException(status_code=500, detail="未配置摘要提示词")
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": chapter_content},
        ]
        return await self._stream_and_collect(
            messages,
            temperature=temperature,
            user_id=user_id,
            timeout=timeout,
        )

    @staticmethod
    def _extract_error_detail(exc: Exception, fallback: str) -> str:
        """Return a user-facing message for ``exc``.

        Prefers ``error.message_zh`` then ``error.message`` from the JSON body of
        ``exc.response``; otherwise uses ``str(exc)``, and finally ``fallback``.
        """
        response = getattr(exc, "response", None)
        if response is None:
            return str(exc) or fallback
        try:
            payload = response.json()
            error_data = payload.get("error", {}) if isinstance(payload, dict) else {}
            return error_data.get("message_zh") or error_data.get("message") or fallback
        except Exception:
            # Deliberate best-effort: an unreadable or oddly shaped error body
            # must not mask the original failure.
            return str(exc) or fallback

    async def _stream_and_collect(
        self,
        messages: List[Dict[str, str]],
        *,
        temperature: float,
        user_id: Optional[int],
        timeout: float,
        response_format: Optional[str] = None,
    ) -> str:
        """Stream a chat completion and return the concatenated content.

        Raises:
            HTTPException: 503 on upstream internal, connection or timeout
                errors; 500 when the reply was truncated
                (``finish_reason == "length"``) or empty. Errors from
                ``_resolve_llm_config`` are propagated.
        """
        config = await self._resolve_llm_config(user_id)
        client = LLMClient(api_key=config["api_key"], base_url=config.get("base_url"))

        chat_messages = [ChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]

        # Collect chunks and join once instead of repeated string concatenation.
        chunks: List[str] = []
        finish_reason = None

        logger.info(
            "Streaming LLM response: model=%s user_id=%s messages=%d",
            config.get("model"),
            user_id,
            len(messages),
        )

        try:
            async for part in client.stream_chat(
                messages=chat_messages,
                model=config.get("model"),
                temperature=temperature,
                # NOTE: int() drops any fractional part of the timeout.
                timeout=int(timeout),
                response_format=response_format,
            ):
                if part.get("content"):
                    chunks.append(part["content"])
                if part.get("finish_reason"):
                    finish_reason = part["finish_reason"]
        except InternalServerError as exc:
            detail = self._extract_error_detail(exc, "AI 服务内部错误,请稍后重试")
            logger.error(
                "LLM stream internal error: model=%s user_id=%s detail=%s",
                config.get("model"),
                user_id,
                detail,
                exc_info=exc,
            )
            raise HTTPException(status_code=503, detail=detail) from exc
        except (httpx.RemoteProtocolError, httpx.ReadTimeout, APIConnectionError, APITimeoutError) as exc:
            if isinstance(exc, httpx.RemoteProtocolError):
                detail = "AI 服务连接被意外中断,请稍后重试"
            elif isinstance(exc, (httpx.ReadTimeout, APITimeoutError)):
                detail = "AI 服务响应超时,请稍后重试"
            else:
                detail = "无法连接到 AI 服务,请稍后重试"
            logger.error(
                "LLM stream failed: model=%s user_id=%s detail=%s",
                config.get("model"),
                user_id,
                detail,
                exc_info=exc,
            )
            raise HTTPException(status_code=503, detail=detail) from exc

        full_response = "".join(chunks)

        logger.debug(
            "LLM response collected: model=%s user_id=%s finish_reason=%s preview=%s",
            config.get("model"),
            user_id,
            finish_reason,
            full_response[:500],
        )

        if finish_reason == "length":
            logger.warning(
                "LLM response truncated: model=%s user_id=%s",
                config.get("model"),
                user_id,
            )
            raise HTTPException(status_code=500, detail="AI 响应被截断,请缩短输入或调整参数")

        if not full_response:
            logger.error(
                "LLM returned empty response: model=%s user_id=%s",
                config.get("model"),
                user_id,
            )
            raise HTTPException(status_code=500, detail="AI 未返回有效内容")

        await self.usage_service.increment("api_request_count")
        logger.info(
            "LLM response success: model=%s user_id=%s chars=%d",
            config.get("model"),
            user_id,
            len(full_response),
        )
        return full_response

    async def _resolve_llm_config(self, user_id: Optional[int]) -> Dict[str, Optional[str]]:
        """Return the ``api_key`` / ``base_url`` / ``model`` to use for ``user_id``.

        A user who has stored an API key of their own gets that configuration
        and skips the daily limit. Any other user is checked against the daily
        limit and then gets the default configuration.

        Raises:
            HTTPException: 429 from the daily-limit check, or 500 when no
                default API key is configured.
        """
        if user_id:
            config = await self.llm_repo.get_by_user(user_id)
            if config and config.llm_provider_api_key:
                return {
                    "api_key": config.llm_provider_api_key,
                    "base_url": config.llm_provider_url,
                    "model": config.llm_provider_model,
                }
            # Enforce the daily usage limit for users on the default key.
            await self._enforce_daily_limit(user_id)

        api_key = await self._get_config_value("llm.api_key")
        base_url = await self._get_config_value("llm.base_url")
        model = await self._get_config_value("llm.model")

        if not api_key:
            raise HTTPException(status_code=500, detail="未配置默认 LLM API Key")

        return {"api_key": api_key, "base_url": base_url, "model": model}

    async def _embed_with_ollama(self, text: str, target_model: Optional[str]) -> Optional[List[float]]:
        """Request an embedding through the Ollama client.

        Returns ``None`` when the request fails or yields no vector.

        Raises:
            HTTPException: 500 when the ``ollama`` package is not installed.
        """
        if OllamaAsyncClient is None:
            logger.error("未安装 ollama 依赖,无法调用本地嵌入模型。")
            raise HTTPException(status_code=500, detail="缺少 Ollama 依赖,请先安装 ollama 包。")

        base_url_any = settings.ollama_embedding_base_url or settings.embedding_base_url
        base_url = str(base_url_any) if base_url_any else None
        client = OllamaAsyncClient(host=base_url)
        try:
            response = await client.embeddings(model=target_model, prompt=text)
        except Exception as exc:  # pragma: no cover - local service call failed
            logger.warning(
                "Ollama 嵌入请求失败: model=%s error=%s",
                target_model,
                exc,
            )
            return None

        # The response is read either as a mapping or as an attribute-style object.
        if isinstance(response, dict):
            embedding = response.get("embedding")
        else:
            embedding = getattr(response, "embedding", None)
        if not embedding:
            logger.warning("Ollama 返回空向量: model=%s", target_model)
            return None
        return embedding

    async def _embed_with_openai(
        self,
        text: str,
        target_model: Optional[str],
        user_id: Optional[int],
    ) -> Optional[List[float]]:
        """Request an embedding through the OpenAI client.

        The API key and base URL come from the embedding settings, falling back
        to the configuration resolved for ``user_id``. Resolving that
        configuration may itself apply the daily limit. Returns ``None`` when
        the request fails or returns no data.
        """
        config = await self._resolve_llm_config(user_id)
        api_key = settings.embedding_api_key or config["api_key"]
        base_url_setting = settings.embedding_base_url or config.get("base_url")
        base_url = str(base_url_setting) if base_url_setting else None
        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        try:
            response = await client.embeddings.create(
                input=text,
                model=target_model,
            )
        except Exception as exc:  # pragma: no cover - network or authentication failure
            logger.warning(
                "OpenAI 嵌入请求失败: model=%s user_id=%s error=%s",
                target_model,
                user_id,
                exc,
            )
            return None
        if not response.data:
            logger.warning("OpenAI 嵌入请求返回空数据: model=%s user_id=%s", target_model, user_id)
            return None
        return response.data[0].embedding

    async def get_embedding(
        self,
        text: str,
        *,
        user_id: Optional[int] = None,
        model: Optional[str] = None,
    ) -> List[float]:
        """Generate a text embedding for chapter RAG retrieval.

        Supports both the ``openai`` and ``ollama`` providers, chosen by
        ``settings.embedding_provider``. Returns an empty list when the
        provider call fails or yields no vector. The dimension of a returned
        vector is cached per model.

        Raises:
            HTTPException: 500 when the Ollama provider is selected but the
                ``ollama`` package is missing; errors from resolving the LLM
                configuration are propagated for the OpenAI provider.
        """
        provider = settings.embedding_provider
        target_model = model or (
            settings.ollama_embedding_model if provider == "ollama" else settings.embedding_model
        )

        if provider == "ollama":
            embedding = await self._embed_with_ollama(text, target_model)
        else:
            embedding = await self._embed_with_openai(text, target_model, user_id)

        if embedding is None:
            return []
        if not isinstance(embedding, list):
            embedding = list(embedding)

        # Fall back to the configured size when the vector itself is empty.
        dimension = len(embedding) or settings.embedding_model_vector_size
        if dimension:
            self._embedding_dimensions[target_model] = dimension
        return embedding

    def get_embedding_dimension(self, model: Optional[str] = None) -> Optional[int]:
        """Return the embedding dimension, preferring the cache over configuration."""
        target_model = model or (
            settings.ollama_embedding_model if settings.embedding_provider == "ollama" else settings.embedding_model
        )
        if target_model in self._embedding_dimensions:
            return self._embedding_dimensions[target_model]
        return settings.embedding_model_vector_size

    async def _enforce_daily_limit(self, user_id: int) -> None:
        """Count one request against ``user_id``'s daily quota.

        The limit is read from the ``daily_request_limit`` admin setting. A
        missing, empty or non-numeric value falls back to
        ``_DEFAULT_DAILY_REQUEST_LIMIT``. The incremented counter is committed
        on the session.

        Raises:
            HTTPException: 429 when the user has already reached the limit.
        """
        default_limit = self._DEFAULT_DAILY_REQUEST_LIMIT
        raw_limit = await self.admin_setting_service.get("daily_request_limit", str(default_limit))
        try:
            limit = int(raw_limit) if raw_limit else default_limit
        except (TypeError, ValueError):
            # A malformed admin setting should not turn every request into a 500.
            logger.warning(
                "Invalid daily_request_limit setting %r; falling back to %d",
                raw_limit,
                default_limit,
            )
            limit = default_limit

        used = await self.user_repo.get_daily_request(user_id)
        if used >= limit:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="今日请求次数已达上限,请明日再试或设置自定义 API Key。",
            )
        await self.user_repo.increment_daily_request(user_id)
        await self.session.commit()

    async def _get_config_value(self, key: str) -> Optional[str]:
        """Return the system config value for ``key``, else the matching env var.

        The environment variable name is ``key`` upper-cased with dots replaced
        by underscores (``llm.api_key`` -> ``LLM_API_KEY``).
        """
        record = await self.system_config_repo.get_by_key(key)
        if record:
            return record.value
        # Environment-variable fallback, so a first migration does not have to
        # write the values into the database right away.
        env_key = key.upper().replace(".", "_")
        return os.getenv(env_key)