97 lines
3.4 KiB
Python
97 lines
3.4 KiB
Python
import asyncio
|
|
from typing import Dict, Optional
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from ..models import Prompt
|
|
from ..repositories.prompt_repository import PromptRepository
|
|
from ..schemas.prompt import PromptCreate, PromptRead, PromptUpdate
|
|
|
|
# Module-level state shared by every PromptService instance in this process.

# Prompt name -> validated PromptRead snapshot.
_CACHE: Dict[str, PromptRead] = dict()

# Held by PromptService whenever it mutates _CACHE or _LOADED.
# NOTE(review): created at import time. On Python < 3.10 asyncio.Lock binds to
# the event loop current at construction, which breaks if the app later runs
# on a different loop — confirm the minimum supported Python version.
_LOCK: asyncio.Lock = asyncio.Lock()

# When True, PromptService.get_prompt skips its one-time bulk load of all prompts.
_LOADED: bool = False


class PromptService:
    """Prompt service providing cached lookups and CRUD operations.

    Name lookups go through the module-level ``_CACHE`` (keyed by prompt
    name and shared by every instance) and fall back to the repository on a
    miss. Write operations commit through ``self.session`` and then update
    the matching cache entry. Every mutation of ``_CACHE`` / ``_LOADED``
    happens while holding ``_LOCK``.
    """

    def __init__(self, session: AsyncSession):
        """Bind the service to ``session`` and build its repository.

        Args:
            session: Async SQLAlchemy session used for queries and commits.
        """
        self.session = session
        self.repo = PromptRepository(session)

    @staticmethod
    def _serialize_tags(data: dict) -> dict:
        """Join a ``tags`` list into a comma-separated string, in place.

        ``data`` is left untouched when ``tags`` is absent or ``None``.
        Returns the same dict so callers can chain the call.
        """
        tags = data.get("tags")
        if tags is not None:
            data["tags"] = ",".join(tags)
        return data

    async def preload(self) -> None:
        """Replace the cache with every prompt currently in the repository.

        The query and validation run before the lock is taken, so the lock is
        held only for the swap. Afterwards ``_LOADED`` is True, so
        ``get_prompt`` will not repeat the bulk load.
        """
        global _CACHE, _LOADED
        prompts = await self.repo.list_all()
        fresh = {item.name: PromptRead.model_validate(item) for item in prompts}
        async with _LOCK:
            _CACHE = fresh
            _LOADED = True

    async def get_prompt(self, name: str) -> Optional[str]:
        """Return the content of the prompt named ``name``.

        If the cache has not been bulk-loaded yet, all prompts are loaded
        while holding ``_LOCK``. Concurrent first callers therefore wait for
        one load instead of each querying. A cache miss after that falls
        back to ``repo.get_by_name`` and caches the result.

        Args:
            name: Prompt name to look up.

        Returns:
            The prompt's ``content``, or ``None`` if no prompt has that name.
        """
        global _LOADED
        async with _LOCK:
            if not _LOADED:
                prompts = await self.repo.list_all()
                _CACHE.update(
                    {item.name: PromptRead.model_validate(item) for item in prompts}
                )
                _LOADED = True
            cached = _CACHE.get(name)
        if cached is not None:
            return cached.content

        prompt = await self.repo.get_by_name(name)
        if not prompt:
            return None
        prompt_read = PromptRead.model_validate(prompt)
        async with _LOCK:
            _CACHE[name] = prompt_read
        return prompt_read.content

    async def list_prompts(self) -> list[PromptRead]:
        """Return every prompt from the repository (bypasses the cache)."""
        prompts = await self.repo.list_all()
        return [PromptRead.model_validate(item) for item in prompts]

    async def get_prompt_by_id(self, prompt_id: int) -> Optional[PromptRead]:
        """Return the prompt with primary key ``prompt_id``, or ``None`` if absent."""
        instance = await self.repo.get(id=prompt_id)
        if not instance:
            return None
        return PromptRead.model_validate(instance)

    async def create_prompt(self, payload: PromptCreate) -> PromptRead:
        """Persist a new prompt, cache it, and return its read model.

        Only the new entry is cached. ``_LOADED`` is deliberately left as-is,
        because one cached row says nothing about the rest of the table.
        ``get_prompt`` therefore still performs its bulk load if it has not
        happened yet.
        """
        data = self._serialize_tags(payload.model_dump())
        prompt = Prompt(**data)
        await self.repo.add(prompt)
        await self.session.commit()
        # NOTE(review): ORM attributes are read after commit(). With AsyncSession's
        # default expire_on_commit=True this would trigger implicit IO; presumably
        # the session factory disables it — confirm.
        prompt_read = PromptRead.model_validate(prompt)
        async with _LOCK:
            _CACHE[prompt_read.name] = prompt_read
        return prompt_read

    async def update_prompt(self, prompt_id: int, payload: PromptUpdate) -> Optional[PromptRead]:
        """Apply the fields explicitly set on ``payload`` to prompt ``prompt_id``.

        If the update renames the prompt, the cache entry under the old name
        is evicted. Otherwise ``get_prompt(old_name)`` would keep serving this
        prompt's content.

        Returns:
            The updated prompt, or ``None`` if no prompt has that id.
        """
        instance = await self.repo.get(id=prompt_id)
        if not instance:
            return None
        # Capture before mutation so a rename can evict the stale cache key.
        old_name = instance.name
        update_data = self._serialize_tags(payload.model_dump(exclude_unset=True))
        await self.repo.update_fields(instance, **update_data)
        await self.session.commit()
        # NOTE(review): same post-commit attribute access as create_prompt — confirm
        # expire_on_commit is disabled for this session.
        prompt_read = PromptRead.model_validate(instance)
        async with _LOCK:
            if old_name != prompt_read.name:
                _CACHE.pop(old_name, None)
            _CACHE[prompt_read.name] = prompt_read
        return prompt_read

    async def delete_prompt(self, prompt_id: int) -> bool:
        """Delete prompt ``prompt_id`` and drop it from the cache.

        Returns:
            True if a prompt was deleted, False if no prompt had that id.
        """
        instance = await self.repo.get(id=prompt_id)
        if not instance:
            return False
        # Read the name while the instance is still live, rather than touching
        # the ORM object after it has been deleted and committed.
        name = instance.name
        await self.repo.delete(instance)
        await self.session.commit()
        async with _LOCK:
            _CACHE.pop(name, None)
        return True