feat: 初始提交
This commit is contained in:
0
backend/app/repositories/__init__.py
Normal file
0
backend/app/repositories/__init__.py
Normal file
15
backend/app/repositories/admin_setting_repository.py
Normal file
15
backend/app/repositories/admin_setting_repository.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import AdminSetting
|
||||
|
||||
|
||||
class AdminSettingRepository(BaseRepository[AdminSetting]):
|
||||
model = AdminSetting
|
||||
|
||||
async def get_value(self, key: str) -> Optional[str]:
|
||||
result = await self.session.execute(select(AdminSetting).where(AdminSetting.key == key))
|
||||
record = result.scalars().first()
|
||||
return record.value if record else None
|
||||
44
backend/app/repositories/base.py
Normal file
44
backend/app/repositories/base.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Any, Generic, Iterable, Optional, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import InstrumentedAttribute
|
||||
|
||||
ModelType = TypeVar("ModelType")
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
"""通用仓储基类,封装常见的增删改查操作。"""
|
||||
|
||||
model: type[ModelType]
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get(self, **filters: Any) -> Optional[ModelType]:
|
||||
stmt = select(self.model).filter_by(**filters)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def list(self, *, filters: Optional[dict[str, Any]] = None) -> Iterable[ModelType]:
|
||||
stmt = select(self.model)
|
||||
if filters:
|
||||
stmt = stmt.filter_by(**filters)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def add(self, instance: ModelType) -> ModelType:
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, instance: ModelType) -> None:
|
||||
await self.session.delete(instance)
|
||||
|
||||
async def update_fields(self, instance: ModelType, **values: Any) -> ModelType:
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
continue
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
14
backend/app/repositories/llm_config_repository.py
Normal file
14
backend/app/repositories/llm_config_repository.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import LLMConfig
|
||||
|
||||
|
||||
class LLMConfigRepository(BaseRepository[LLMConfig]):
|
||||
model = LLMConfig
|
||||
|
||||
async def get_by_user(self, user_id: int) -> Optional[LLMConfig]:
|
||||
result = await self.session.execute(select(LLMConfig).where(LLMConfig.user_id == user_id))
|
||||
return result.scalars().first()
|
||||
55
backend/app/repositories/novel_repository.py
Normal file
55
backend/app/repositories/novel_repository.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import Chapter, NovelProject
|
||||
|
||||
|
||||
class NovelRepository(BaseRepository[NovelProject]):
|
||||
model = NovelProject
|
||||
|
||||
async def get_by_id(self, project_id: str) -> Optional[NovelProject]:
|
||||
stmt = (
|
||||
select(NovelProject)
|
||||
.where(NovelProject.id == project_id)
|
||||
.options(
|
||||
selectinload(NovelProject.blueprint),
|
||||
selectinload(NovelProject.characters),
|
||||
selectinload(NovelProject.relationships_),
|
||||
selectinload(NovelProject.outlines),
|
||||
selectinload(NovelProject.conversations),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.versions),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.evaluations),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.selected_version),
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def list_by_user(self, user_id: int) -> Iterable[NovelProject]:
|
||||
result = await self.session.execute(
|
||||
select(NovelProject)
|
||||
.where(NovelProject.user_id == user_id)
|
||||
.order_by(NovelProject.updated_at.desc())
|
||||
.options(
|
||||
selectinload(NovelProject.blueprint),
|
||||
selectinload(NovelProject.outlines),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.selected_version),
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def list_all(self) -> Iterable[NovelProject]:
|
||||
result = await self.session.execute(
|
||||
select(NovelProject)
|
||||
.order_by(NovelProject.updated_at.desc())
|
||||
.options(
|
||||
selectinload(NovelProject.owner),
|
||||
selectinload(NovelProject.blueprint),
|
||||
selectinload(NovelProject.outlines),
|
||||
selectinload(NovelProject.chapters).selectinload(Chapter.selected_version),
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
19
backend/app/repositories/prompt_repository.py
Normal file
19
backend/app/repositories/prompt_repository.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import Prompt
|
||||
|
||||
|
||||
class PromptRepository(BaseRepository[Prompt]):
|
||||
model = Prompt
|
||||
|
||||
async def get_by_name(self, name: str) -> Optional[Prompt]:
|
||||
result = await self.session.execute(select(Prompt).where(Prompt.name == name))
|
||||
return result.scalars().first()
|
||||
|
||||
async def list_all(self) -> Iterable[Prompt]:
|
||||
result = await self.session.execute(select(Prompt).order_by(Prompt.name))
|
||||
return result.scalars().all()
|
||||
18
backend/app/repositories/system_config_repository.py
Normal file
18
backend/app/repositories/system_config_repository.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import SystemConfig
|
||||
|
||||
|
||||
class SystemConfigRepository(BaseRepository[SystemConfig]):
|
||||
model = SystemConfig
|
||||
|
||||
async def get_by_key(self, key: str) -> Optional[SystemConfig]:
|
||||
result = await self.session.execute(select(SystemConfig).where(SystemConfig.key == key))
|
||||
return result.scalars().first()
|
||||
|
||||
async def list_all(self) -> Iterable[SystemConfig]:
|
||||
result = await self.session.execute(select(SystemConfig).order_by(SystemConfig.key))
|
||||
return result.scalars().all()
|
||||
19
backend/app/repositories/update_log_repository.py
Normal file
19
backend/app/repositories/update_log_repository.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import UpdateLog
|
||||
|
||||
|
||||
class UpdateLogRepository(BaseRepository[UpdateLog]):
|
||||
model = UpdateLog
|
||||
|
||||
async def list(self) -> Iterable[UpdateLog]:
|
||||
result = await self.session.execute(select(UpdateLog).order_by(UpdateLog.created_at.desc()))
|
||||
return result.scalars().all()
|
||||
|
||||
async def list_latest(self, limit: int = 5) -> Iterable[UpdateLog]:
|
||||
stmt = select(UpdateLog).order_by(UpdateLog.created_at.desc()).limit(limit)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
19
backend/app/repositories/usage_metric_repository.py
Normal file
19
backend/app/repositories/usage_metric_repository.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import UsageMetric
|
||||
|
||||
|
||||
class UsageMetricRepository(BaseRepository[UsageMetric]):
|
||||
model = UsageMetric
|
||||
|
||||
async def get_or_create(self, key: str) -> UsageMetric:
|
||||
result = await self.session.execute(select(UsageMetric).where(UsageMetric.key == key))
|
||||
instance = result.scalars().first()
|
||||
if instance is None:
|
||||
instance = UsageMetric(key=key, value=0)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
62
backend/app/repositories/user_repository.py
Normal file
62
backend/app/repositories/user_repository.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from datetime import date
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .base import BaseRepository
|
||||
from ..models import User, UserDailyRequest
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User]):
|
||||
model = User
|
||||
|
||||
async def get_by_username(self, username: str) -> Optional[User]:
|
||||
stmt = select(User).where(User.username == username)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[User]:
|
||||
stmt = select(User).where(User.email == email)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_by_external_id(self, external_id: str) -> Optional[User]:
|
||||
stmt = select(User).where(User.external_id == external_id)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def list_all(self) -> Iterable[User]:
|
||||
result = await self.session.execute(select(User))
|
||||
return result.scalars().all()
|
||||
|
||||
async def increment_daily_request(self, user_id: int) -> None:
|
||||
today = date.today()
|
||||
stmt = select(UserDailyRequest).where(
|
||||
UserDailyRequest.user_id == user_id,
|
||||
UserDailyRequest.request_date == today,
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
record = result.scalars().first()
|
||||
|
||||
if record is None:
|
||||
record = UserDailyRequest(user_id=user_id, request_date=today, request_count=1)
|
||||
self.session.add(record)
|
||||
else:
|
||||
record.request_count += 1
|
||||
await self.session.flush()
|
||||
|
||||
async def get_daily_request(self, user_id: int) -> int:
|
||||
today = date.today()
|
||||
stmt = select(UserDailyRequest.request_count).where(
|
||||
UserDailyRequest.user_id == user_id,
|
||||
UserDailyRequest.request_date == today,
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
value = result.scalars().first()
|
||||
return value or 0
|
||||
|
||||
async def count_users(self) -> int:
|
||||
stmt = select(func.count(User.id))
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one()
|
||||
Reference in New Issue
Block a user