feat: 初始提交
This commit is contained in:
96
backend/app/services/prompt_service.py
Normal file
96
backend/app/services/prompt_service.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import Prompt
|
||||
from ..repositories.prompt_repository import PromptRepository
|
||||
from ..schemas.prompt import PromptCreate, PromptRead, PromptUpdate
|
||||
|
||||
_CACHE: Dict[str, PromptRead] = {}
|
||||
_LOCK = asyncio.Lock()
|
||||
_LOADED = False
|
||||
|
||||
|
||||
class PromptService:
|
||||
"""提示词服务,提供缓存加速与 CRUD 能力。"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.repo = PromptRepository(session)
|
||||
|
||||
async def preload(self) -> None:
|
||||
global _CACHE, _LOADED
|
||||
prompts = await self.repo.list_all()
|
||||
async with _LOCK:
|
||||
_CACHE = {item.name: PromptRead.model_validate(item) for item in prompts}
|
||||
_LOADED = True
|
||||
|
||||
async def get_prompt(self, name: str) -> Optional[str]:
|
||||
global _LOADED
|
||||
async with _LOCK:
|
||||
if not _LOADED:
|
||||
prompts = await self.repo.list_all()
|
||||
_CACHE.update({item.name: PromptRead.model_validate(item) for item in prompts})
|
||||
_LOADED = True
|
||||
cached = _CACHE.get(name)
|
||||
if cached:
|
||||
return cached.content
|
||||
|
||||
prompt = await self.repo.get_by_name(name)
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
prompt_read = PromptRead.model_validate(prompt)
|
||||
async with _LOCK:
|
||||
_CACHE[name] = prompt_read
|
||||
return prompt_read.content
|
||||
|
||||
async def list_prompts(self) -> list[PromptRead]:
|
||||
prompts = await self.repo.list_all()
|
||||
return [PromptRead.model_validate(item) for item in prompts]
|
||||
|
||||
async def get_prompt_by_id(self, prompt_id: int) -> Optional[PromptRead]:
|
||||
instance = await self.repo.get(id=prompt_id)
|
||||
if not instance:
|
||||
return None
|
||||
return PromptRead.model_validate(instance)
|
||||
|
||||
async def create_prompt(self, payload: PromptCreate) -> PromptRead:
|
||||
data = payload.model_dump()
|
||||
tags = data.get("tags")
|
||||
if tags is not None:
|
||||
data["tags"] = ",".join(tags)
|
||||
prompt = Prompt(**data)
|
||||
await self.repo.add(prompt)
|
||||
await self.session.commit()
|
||||
prompt_read = PromptRead.model_validate(prompt)
|
||||
async with _LOCK:
|
||||
_CACHE[prompt_read.name] = prompt_read
|
||||
global _LOADED
|
||||
_LOADED = True
|
||||
return prompt_read
|
||||
|
||||
async def update_prompt(self, prompt_id: int, payload: PromptUpdate) -> Optional[PromptRead]:
|
||||
instance = await self.repo.get(id=prompt_id)
|
||||
if not instance:
|
||||
return None
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
if "tags" in update_data and update_data["tags"] is not None:
|
||||
update_data["tags"] = ",".join(update_data["tags"])
|
||||
await self.repo.update_fields(instance, **update_data)
|
||||
await self.session.commit()
|
||||
prompt_read = PromptRead.model_validate(instance)
|
||||
async with _LOCK:
|
||||
_CACHE[prompt_read.name] = prompt_read
|
||||
return prompt_read
|
||||
|
||||
async def delete_prompt(self, prompt_id: int) -> bool:
|
||||
instance = await self.repo.get(id=prompt_id)
|
||||
if not instance:
|
||||
return False
|
||||
await self.repo.delete(instance)
|
||||
await self.session.commit()
|
||||
async with _LOCK:
|
||||
_CACHE.pop(instance.name, None)
|
||||
return True
|
||||
Reference in New Issue
Block a user