Files
arboris-novel/backend/app/services/prompt_service.py
2025-10-21 09:51:27 +08:00

97 lines
3.4 KiB
Python

import asyncio
from typing import Dict, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Prompt
from ..repositories.prompt_repository import PromptRepository
from ..schemas.prompt import PromptCreate, PromptRead, PromptUpdate
# Process-wide prompt cache shared by all PromptService instances, keyed by prompt name.
_CACHE: Dict[str, PromptRead] = dict()
# asyncio lock serializing updates to _CACHE and _LOADED across coroutines.
_LOCK: asyncio.Lock = asyncio.Lock()
# Marks the cache as loaded so get_prompt skips its bulk load from the repository.
_LOADED: bool = False
class PromptService:
    """Prompt service offering cached name lookups plus CRUD operations.

    Cached prompts live in the module-level ``_CACHE`` dict (keyed by prompt
    name) and are shared by every ``PromptService`` instance in the process;
    writes to the cache and to ``_LOADED`` are serialized with ``_LOCK``.
    """

    def __init__(self, session: AsyncSession):
        self.session = session
        self.repo = PromptRepository(session)

    @staticmethod
    def _join_tags(data: dict) -> None:
        """Turn a ``tags`` list in ``data`` into a comma-separated string, in place.

        ``data`` is left untouched when ``tags`` is absent or ``None``.
        """
        tags = data.get("tags")
        if tags is not None:
            data["tags"] = ",".join(tags)

    async def preload(self) -> None:
        """Replace the cache contents with every prompt from the repository."""
        global _LOADED
        prompts = await self.repo.list_all()
        fresh = {item.name: PromptRead.model_validate(item) for item in prompts}
        async with _LOCK:
            # Update the shared dict in place instead of rebinding the global.
            _CACHE.clear()
            _CACHE.update(fresh)
            _LOADED = True

    async def get_prompt(self, name: str) -> Optional[str]:
        """Return the content of the prompt called ``name``, or ``None`` if none exists.

        If the cache is not yet marked loaded, every prompt is bulk-loaded while
        holding the lock, so concurrent first callers trigger a single load.
        A cache miss falls back to one repository lookup whose result is cached.
        """
        global _LOADED
        async with _LOCK:
            if not _LOADED:
                prompts = await self.repo.list_all()
                _CACHE.update({item.name: PromptRead.model_validate(item) for item in prompts})
                _LOADED = True
            cached = _CACHE.get(name)
        if cached is not None:
            return cached.content
        prompt = await self.repo.get_by_name(name)
        if prompt is None:
            return None
        prompt_read = PromptRead.model_validate(prompt)
        async with _LOCK:
            _CACHE[name] = prompt_read
        return prompt_read.content

    async def list_prompts(self) -> list[PromptRead]:
        """Return every prompt, read directly from the repository (cache bypassed)."""
        prompts = await self.repo.list_all()
        return [PromptRead.model_validate(item) for item in prompts]

    async def get_prompt_by_id(self, prompt_id: int) -> Optional[PromptRead]:
        """Return the prompt with id ``prompt_id``, or ``None`` if it does not exist."""
        instance = await self.repo.get(id=prompt_id)
        if instance is None:
            return None
        return PromptRead.model_validate(instance)

    async def create_prompt(self, payload: PromptCreate) -> PromptRead:
        """Persist a new prompt, commit, cache it and return its read model."""
        data = payload.model_dump()
        self._join_tags(data)
        prompt = Prompt(**data)
        await self.repo.add(prompt)
        await self.session.commit()
        prompt_read = PromptRead.model_validate(prompt)
        async with _LOCK:
            # Only this one prompt is cached here. _LOADED is deliberately not
            # set: a cache that was never bulk-loaded must still get its full
            # load on the next get_prompt call.
            _CACHE[prompt_read.name] = prompt_read
        return prompt_read

    async def update_prompt(self, prompt_id: int, payload: PromptUpdate) -> Optional[PromptRead]:
        """Apply the explicitly-set fields of ``payload`` to prompt ``prompt_id``.

        Returns the updated read model, or ``None`` if no such prompt exists.
        The cache is refreshed, including eviction of the old key on rename.
        """
        instance = await self.repo.get(id=prompt_id)
        if instance is None:
            return None
        # Remember the current name: if the update renames the prompt, the entry
        # under the old name must be evicted, otherwise get_prompt(old_name)
        # would keep serving stale content.
        previous_name = instance.name
        update_data = payload.model_dump(exclude_unset=True)
        self._join_tags(update_data)
        await self.repo.update_fields(instance, **update_data)
        await self.session.commit()
        prompt_read = PromptRead.model_validate(instance)
        async with _LOCK:
            if previous_name != prompt_read.name:
                _CACHE.pop(previous_name, None)
            _CACHE[prompt_read.name] = prompt_read
        return prompt_read

    async def delete_prompt(self, prompt_id: int) -> bool:
        """Delete prompt ``prompt_id`` and evict it from the cache.

        Returns ``True`` if a prompt was deleted, ``False`` if none existed.
        """
        instance = await self.repo.get(id=prompt_id)
        if instance is None:
            return False
        # Capture the cache key before the delete/commit so it does not depend on
        # attribute state of an instance that is deleted (and possibly
        # expired/detached) after the commit.
        name = instance.name
        await self.repo.delete(instance)
        await self.session.commit()
        async with _LOCK:
            _CACHE.pop(name, None)
        return True