142 lines
5.1 KiB
Python
142 lines
5.1 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
from typing import TYPE_CHECKING, Protocol
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select, update
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from core.db.base_repository import BaseRepository
|
|
from core.logging import get_logger
|
|
from models.schedule_items import ScheduleItem
|
|
from models.schedule_subscriptions import ScheduleSubscription
|
|
|
|
if TYPE_CHECKING:
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
# Module-level logger shared by the repository classes below. Call sites pass
# context as keyword arguments (e.g. item_id=..., owner_id=...) to
# logger.exception(); NOTE(review): that implies core.logging.get_logger
# returns a structured (structlog-style) logger — confirm.
logger = get_logger("v1.schedule_items.repository")
|
|
|
|
|
|
class ScheduleItemRepository(Protocol):
    """Structural (duck-typed) interface for schedule-item persistence.

    Any object exposing these async methods with compatible signatures
    satisfies the protocol; ``SQLAlchemyScheduleItemRepository`` in this
    module provides a matching implementation without inheriting from it.
    """

    # Lookups ---------------------------------------------------------------

    async def get_by_id(self, entity_id: UUID) -> ScheduleItem | None:
        """Return the item whose primary key is ``entity_id``, or ``None``."""
        ...

    async def get_by_item_id(
        self,
        item_id: UUID,
        owner_id: UUID,
    ) -> ScheduleItem | None:
        """Return item ``item_id`` if owned by ``owner_id``, else ``None``."""
        ...

    async def list_by_date_range(
        self,
        owner_id: UUID,
        start_at: datetime,
        end_at: datetime,
    ) -> list[ScheduleItem]:
        """Return ``owner_id``'s items scheduled between the two datetimes."""
        ...

    # Mutations -------------------------------------------------------------

    async def create(self, data: dict) -> ScheduleItem:
        """Persist a new item built from ``data`` and return it."""
        ...

    async def update_by_item_id(
        self,
        item_id: UUID,
        owner_id: UUID,
        data: dict,
    ) -> ScheduleItem | None:
        """Apply ``data`` to the owner's item; ``None`` if it was not found."""
        ...

    async def delete_by_item_id(
        self,
        item_id: UUID,
        owner_id: UUID,
    ) -> ScheduleItem | None:
        """Remove the owner's item and return it; ``None`` if not found.

        The SQLAlchemy implementation soft-deletes by stamping ``deleted_at``.
        """
        ...

    async def create_subscription(self, data: dict) -> ScheduleSubscription:
        """Persist a new subscription built from ``data`` and return it."""
        ...
|
|
|
|
|
|
class SQLAlchemyScheduleItemRepository(BaseRepository[ScheduleItem]):
    """SQLAlchemy-backed repository for ``ScheduleItem`` rows.

    The ``*_by_item_id`` and ``list_by_date_range`` methods are scoped to an
    owner and ignore soft-deleted rows (``deleted_at IS NOT NULL``). Write
    methods ``flush`` the session but never commit; NOTE(review): commit is
    presumably handled by whoever owns the session — confirm.

    Every database failure is logged via ``logger.exception`` and re-raised
    unchanged as the original ``SQLAlchemyError``.
    """

    def __init__(self, session: AsyncSession) -> None:
        # BaseRepository is given the model class; methods below use
        # ``self._session``, which presumably BaseRepository sets.
        super().__init__(session, ScheduleItem)

    async def get_by_item_id(
        self, item_id: UUID, owner_id: UUID
    ) -> ScheduleItem | None:
        """Return the live item ``item_id`` owned by ``owner_id``, or ``None``.

        Raises:
            SQLAlchemyError: Logged with item/owner context and re-raised.
        """
        try:
            stmt = (
                select(ScheduleItem)
                .where(ScheduleItem.id == item_id)
                .where(ScheduleItem.owner_id == owner_id)
                .where(ScheduleItem.deleted_at.is_(None))
            )
            result = await self._session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError:
            logger.exception(
                "Schedule item lookup failed",
                item_id=str(item_id),
                owner_id=str(owner_id),
            )
            raise

    async def create(self, data: dict) -> ScheduleItem:
        """Add a new ``ScheduleItem(**data)`` to the session and flush it.

        Flushing sends the INSERT so database-generated values are available
        on the returned object; the transaction is not committed here.

        Raises:
            SQLAlchemyError: Logged and re-raised.
        """
        try:
            item = ScheduleItem(**data)
            self._session.add(item)
            await self._session.flush()
            return item
        except SQLAlchemyError:
            logger.exception("Schedule item creation failed")
            raise

    async def update_by_item_id(
        self, item_id: UUID, owner_id: UUID, data: dict
    ) -> ScheduleItem | None:
        """Apply ``data`` to a live item owned by ``owner_id``.

        Args:
            item_id: Primary key of the item to update.
            owner_id: Owner the item must belong to.
            data: Column-name -> new-value mapping passed verbatim to
                ``UPDATE ... SET``. It is not filtered here, so callers must
                not include columns meant to be immutable (e.g. ``id``,
                ``owner_id``, ``deleted_at``) unless that is intended.

        Returns:
            The updated item, or ``None`` if no live item with this id is
            owned by ``owner_id``. An empty ``data`` performs no write and
            returns the current item (or ``None``).

        Raises:
            SQLAlchemyError: Logged with item/owner context and re-raised.
        """
        if not data:
            # Nothing to write: keep the "return current state" contract.
            return await self.get_by_item_id(item_id, owner_id)
        try:
            # The WHERE clause already enforces existence, ownership and
            # not-deleted, so no separate SELECT is needed first: when nothing
            # matches, RETURNING yields no row and the result is None.
            stmt = (
                update(ScheduleItem)
                .where(ScheduleItem.id == item_id)
                .where(ScheduleItem.owner_id == owner_id)
                .where(ScheduleItem.deleted_at.is_(None))
                .values(**data)
                .returning(ScheduleItem)
            )
            result = await self._session.execute(stmt)
            await self._session.flush()
            return result.scalar_one_or_none()
        except SQLAlchemyError:
            logger.exception(
                "Schedule item update failed",
                item_id=str(item_id),
                owner_id=str(owner_id),
            )
            raise

    async def delete_by_item_id(
        self, item_id: UUID, owner_id: UUID
    ) -> ScheduleItem | None:
        """Soft-delete a live item by stamping ``deleted_at`` with UTC now.

        Returns:
            The soft-deleted item, or ``None`` if no live item with this id is
            owned by ``owner_id``. Because of the ``deleted_at IS NULL``
            filter, deleting the same item again returns ``None``.

        Raises:
            SQLAlchemyError: Logged with item/owner context and re-raised.
        """
        try:
            stmt = (
                update(ScheduleItem)
                .where(ScheduleItem.id == item_id)
                .where(ScheduleItem.owner_id == owner_id)
                .where(ScheduleItem.deleted_at.is_(None))
                .values(deleted_at=datetime.now(timezone.utc))
                .returning(ScheduleItem)
            )
            result = await self._session.execute(stmt)
            await self._session.flush()
            return result.scalar_one_or_none()
        except SQLAlchemyError:
            logger.exception(
                "Schedule item delete failed",
                item_id=str(item_id),
                owner_id=str(owner_id),
            )
            raise

    async def list_by_date_range(
        self, owner_id: UUID, start_at: datetime, end_at: datetime
    ) -> list[ScheduleItem]:
        """Return the owner's live items with ``start_at`` in the given window.

        Both bounds are inclusive and only the item's ``start_at`` column is
        compared. Results are ordered by ``start_at`` ascending. A window with
        ``start_at > end_at`` simply matches nothing.

        Raises:
            SQLAlchemyError: Logged and re-raised.
        """
        try:
            stmt = (
                select(ScheduleItem)
                .where(ScheduleItem.owner_id == owner_id)
                .where(ScheduleItem.deleted_at.is_(None))
                .where(ScheduleItem.start_at >= start_at)
                .where(ScheduleItem.start_at <= end_at)
                .order_by(ScheduleItem.start_at.asc())
            )
            result = await self._session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError:
            logger.exception("Schedule item list failed", owner_id=str(owner_id))
            raise

    async def create_subscription(self, data: dict) -> ScheduleSubscription:
        """Add a new ``ScheduleSubscription(**data)`` and flush it.

        Raises:
            SQLAlchemyError: Logged and re-raised, matching the other writes.
        """
        try:
            subscription = ScheduleSubscription(**data)
            self._session.add(subscription)
            await self._session.flush()
            return subscription
        except SQLAlchemyError:
            logger.exception("Schedule subscription creation failed")
            raise

    async def get_by_id(self, entity_id: UUID) -> ScheduleItem | None:
        """Delegate to ``BaseRepository.get_by_id``.

        NOTE(review): unlike ``get_by_item_id`` this is not owner-scoped.
        Whether it excludes soft-deleted rows depends on BaseRepository —
        confirm.
        """
        return await super().get_by_id(entity_id)
|