from __future__ import annotations

from datetime import datetime, timezone
from typing import TYPE_CHECKING, Protocol
from uuid import UUID

from sqlalchemy import select, update
from sqlalchemy.exc import SQLAlchemyError

from core.db.base_repository import BaseRepository
from core.logging import get_logger
from models.schedule_items import ScheduleItem
from models.schedule_subscriptions import ScheduleSubscription

# AsyncSession is referenced only in annotations, so it is imported for type
# checkers only and never at runtime.
if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession

logger = get_logger("v1.schedule_items.repository")


class ScheduleItemRepository(Protocol):
    """Structural interface for schedule-item persistence.

    Any object providing these coroutine methods satisfies the protocol; no
    inheritance is required. Methods taking ``owner_id`` address an item by
    the pair (item id, owner id); ``get_by_id`` takes the entity id alone.
    """

    async def get_by_id(self, entity_id: UUID) -> ScheduleItem | None:
        """Return the item identified by ``entity_id``, or ``None``."""
        ...

    async def get_by_item_id(
        self, item_id: UUID, owner_id: UUID
    ) -> ScheduleItem | None:
        """Return the item ``item_id`` belonging to ``owner_id``, or ``None``."""
        ...

    async def create(self, data: dict) -> ScheduleItem:
        """Persist a new item built from ``data`` and return it.

        NOTE(review): ``data`` is presumably a mapping of model field names
        to values -- confirm against callers.
        """
        ...

    async def update_by_item_id(
        self, item_id: UUID, owner_id: UUID, data: dict
    ) -> ScheduleItem | None:
        """Apply ``data`` to the item ``item_id`` of ``owner_id``.

        Returns the item, or ``None`` when no matching item exists.
        """
        ...

    async def delete_by_item_id(
        self, item_id: UUID, owner_id: UUID
    ) -> ScheduleItem | None:
        """Delete the item ``item_id`` of ``owner_id``.

        Returns the affected item, or ``None`` when no matching item exists.
        """
        ...

    async def list_by_date_range(
        self, owner_id: UUID, start_at: datetime, end_at: datetime
    ) -> list[ScheduleItem]:
        """Return the items of ``owner_id`` between ``start_at`` and ``end_at``."""
        ...

    async def create_subscription(self, data: dict) -> ScheduleSubscription:
        """Persist a new subscription built from ``data`` and return it."""
        ...
class SQLAlchemyScheduleItemRepository(BaseRepository[ScheduleItem]):
    """SQLAlchemy implementation of schedule-item persistence.

    Owner-scoped methods only ever match rows whose ``deleted_at`` is NULL,
    i.e. soft-deleted items are invisible to lookups, updates and deletes.
    Every database failure (``SQLAlchemyError``) is logged with context and
    re-raised unchanged; this class never commits, it only flushes.

    NOTE(review): ``self._session`` is assumed to be set by
    ``BaseRepository.__init__`` -- confirm against the base class.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to ``session`` for the ``ScheduleItem`` model."""
        super().__init__(session, ScheduleItem)

    async def get_by_item_id(
        self, item_id: UUID, owner_id: UUID
    ) -> ScheduleItem | None:
        """Return the live item ``item_id`` owned by ``owner_id``, or ``None``.

        Raises:
            SQLAlchemyError: if the query fails (logged, then re-raised).
        """
        try:
            stmt = (
                select(ScheduleItem)
                .where(ScheduleItem.id == item_id)
                .where(ScheduleItem.owner_id == owner_id)
                .where(ScheduleItem.deleted_at.is_(None))
            )
            result = await self._session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError:
            logger.exception(
                "Schedule item lookup failed",
                item_id=str(item_id),
                owner_id=str(owner_id),
            )
            raise

    async def create(self, data: dict) -> ScheduleItem:
        """Build a ``ScheduleItem`` from ``data``, add it and flush.

        The flush sends the INSERT inside the current transaction without
        committing it.

        Raises:
            SQLAlchemyError: if the flush fails (logged, then re-raised).
        """
        try:
            item = ScheduleItem(**data)
            self._session.add(item)
            await self._session.flush()
            return item
        except SQLAlchemyError:
            logger.exception("Schedule item creation failed")
            raise

    async def update_by_item_id(
        self, item_id: UUID, owner_id: UUID, data: dict
    ) -> ScheduleItem | None:
        """Apply ``data`` to the live item ``item_id`` owned by ``owner_id``.

        Returns the updated item, or ``None`` when no live row matches. An
        empty ``data`` performs no UPDATE and returns the current item.

        Raises:
            SQLAlchemyError: if the statement fails (logged, then re-raised).
        """
        if not data:
            # Nothing to SET: skip the UPDATE and just return the current row.
            return await self.get_by_item_id(item_id, owner_id)
        try:
            # No existence pre-check is needed: UPDATE ... RETURNING yields no
            # row when nothing matches, so scalar_one_or_none() returns None
            # in a single round trip.
            stmt = (
                update(ScheduleItem)
                .where(ScheduleItem.id == item_id)
                .where(ScheduleItem.owner_id == owner_id)
                .where(ScheduleItem.deleted_at.is_(None))
                .values(**data)
                .returning(ScheduleItem)
            )
            result = await self._session.execute(stmt)
            await self._session.flush()
            return result.scalar_one_or_none()
        except SQLAlchemyError:
            logger.exception(
                "Schedule item update failed",
                item_id=str(item_id),
                owner_id=str(owner_id),
            )
            raise

    async def delete_by_item_id(
        self, item_id: UUID, owner_id: UUID
    ) -> ScheduleItem | None:
        """Soft-delete the live item ``item_id`` owned by ``owner_id``.

        Sets ``deleted_at`` to the current UTC time instead of removing the
        row. Returns the affected item, or ``None`` when no live row matches
        (including when the item was already soft-deleted).

        Raises:
            SQLAlchemyError: if the statement fails (logged, then re-raised).
        """
        try:
            stmt = (
                update(ScheduleItem)
                .where(ScheduleItem.id == item_id)
                .where(ScheduleItem.owner_id == owner_id)
                .where(ScheduleItem.deleted_at.is_(None))
                .values(deleted_at=datetime.now(timezone.utc))
                .returning(ScheduleItem)
            )
            result = await self._session.execute(stmt)
            await self._session.flush()
            return result.scalar_one_or_none()
        except SQLAlchemyError:
            logger.exception(
                "Schedule item delete failed",
                item_id=str(item_id),
                owner_id=str(owner_id),
            )
            raise

    async def list_by_date_range(
        self, owner_id: UUID, start_at: datetime, end_at: datetime
    ) -> list[ScheduleItem]:
        """Return the live items of ``owner_id`` that start within the range.

        Only the item's ``start_at`` column is compared; both bounds are
        inclusive. Results are ordered by ``start_at`` ascending.

        Raises:
            SQLAlchemyError: if the query fails (logged, then re-raised).
        """
        try:
            stmt = (
                select(ScheduleItem)
                .where(ScheduleItem.owner_id == owner_id)
                .where(ScheduleItem.deleted_at.is_(None))
                .where(ScheduleItem.start_at >= start_at)
                .where(ScheduleItem.start_at <= end_at)
                .order_by(ScheduleItem.start_at.asc())
            )
            result = await self._session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError:
            logger.exception("Schedule item list failed", owner_id=str(owner_id))
            raise

    async def create_subscription(self, data: dict) -> ScheduleSubscription:
        """Build a ``ScheduleSubscription`` from ``data``, add it and flush.

        Raises:
            SQLAlchemyError: if the flush fails (logged, then re-raised).
        """
        try:
            sub = ScheduleSubscription(**data)
            self._session.add(sub)
            await self._session.flush()
            return sub
        except SQLAlchemyError:
            # Same log-and-re-raise handling as the schedule-item write paths.
            logger.exception("Schedule subscription creation failed")
            raise

    async def get_by_id(self, entity_id: UUID) -> ScheduleItem | None:
        """Return the item for ``entity_id`` via ``BaseRepository.get_by_id``.

        NOTE(review): no owner or ``deleted_at`` filter is applied here; any
        such filtering depends on the base class -- confirm before exposing
        this to owner-scoped callers.
        """
        return await super().get_by_id(entity_id)