from __future__ import annotations from decimal import Decimal from uuid import UUID from sqlalchemy.dialects.postgresql import insert from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from models.agent_chat_message import AgentChatMessage from models.points_audit_ledger import PointsAuditLedger from models.points_ledger import PointsLedger from models.profile import Profile from models.register_bonus_claims import RegisterBonusClaims from models.user_points import UserPoints from schemas.domain.points import ( AppendAuditLedgerCommand, ApplyPointsChangeCommand, PointsChargeSnapshot, ) from schemas.enums import AgentChatMessageRole class PointsRepository: def __init__(self, session: AsyncSession) -> None: self._session = session async def get_or_create_user_points_for_update( self, *, user_id: UUID ) -> UserPoints: insert_stmt = ( insert(UserPoints) .values(user_id=user_id) .on_conflict_do_nothing(index_elements=[UserPoints.user_id]) ) await self._session.execute(insert_stmt) stmt = select(UserPoints).where(UserPoints.user_id == user_id).with_for_update() return (await self._session.execute(stmt)).scalar_one() async def has_ledger_event(self, *, user_id: UUID, event_id: str) -> bool: stmt = select(PointsLedger.id).where( PointsLedger.user_id == user_id, PointsLedger.event_id == event_id, ) row = (await self._session.execute(stmt)).scalar_one_or_none() return row is not None async def append_ledger( self, *, command: ApplyPointsChangeCommand, balance_after: int, ) -> None: entry = PointsLedger( user_id=command.user_id, direction=command.direction, amount=command.amount, balance_after=balance_after, change_type=command.change_type.value, biz_type=command.biz_type.value if command.biz_type is not None else None, biz_id=command.biz_id, event_id=command.event_id, operator_id=command.operator_id, metadata_json=command.metadata.model_dump(mode="json", exclude_none=True), ) self._session.add(entry) await self._session.flush() async def has_audit_event(self, *, event_id: str) -> bool: stmt = select(PointsAuditLedger.id).where( PointsAuditLedger.event_id == event_id ) row = (await self._session.execute(stmt)).scalar_one_or_none() return row is not None async def append_audit_ledger(self, *, command: AppendAuditLedgerCommand) -> None: entry = PointsAuditLedger( event_id=command.event_id, user_id_snapshot=command.user_id_snapshot, user_email_snapshot=command.user_email_snapshot, change_type=command.change_type.value, biz_type=command.biz_type.value if command.biz_type is not None else None, biz_id=command.biz_id, direction=command.direction, amount=command.amount, balance_after=command.balance_after, billed_to=command.billed_to, run_id=command.run_id, request_id=command.request_id, input_tokens=command.input_tokens, output_tokens=command.output_tokens, cost=command.cost, metadata_json=command.metadata.model_dump(mode="json", exclude_none=True), ) self._session.add(entry) await self._session.flush() async def get_run_usage_snapshot( self, *, session_id: UUID, run_id: str, ) -> PointsChargeSnapshot | None: stmt = ( select(AgentChatMessage) .where( AgentChatMessage.session_id == session_id, AgentChatMessage.role == AgentChatMessageRole.ASSISTANT, AgentChatMessage.deleted_at.is_(None), ) .order_by(AgentChatMessage.seq.desc()) .limit(20) ) messages = list((await self._session.execute(stmt)).scalars().all()) message = None for candidate in messages: metadata = candidate.metadata_json or {} if metadata.get("run_id") == run_id: message = candidate break if message is None: return None cost_value = message.cost if message.cost is not None else Decimal("0") return PointsChargeSnapshot( message_id=message.id, message_seq=max(int(message.seq), 1), model_code=(message.model_code or "agent_run").strip() or "agent_run", input_tokens=max(int(message.input_tokens), 0), output_tokens=max(int(message.output_tokens), 0), cost=Decimal(str(cost_value)), ) async def get_user_points(self, *, user_id: UUID) -> UserPoints: insert_stmt = ( insert(UserPoints) .values(user_id=user_id) .on_conflict_do_nothing(index_elements=[UserPoints.user_id]) ) await self._session.execute(insert_stmt) stmt = select(UserPoints).where(UserPoints.user_id == user_id) return (await self._session.execute(stmt)).scalar_one() async def claim_register_bonus( self, *, email_hash: str, user_email_snapshot: str, first_user_id_snapshot: UUID, grant_event_id: str, ) -> bool: stmt = ( insert(RegisterBonusClaims) .values( email_hash=email_hash, user_email_snapshot=user_email_snapshot, first_user_id_snapshot=first_user_id_snapshot, grant_event_id=grant_event_id, ) .on_conflict_do_nothing(index_elements=[RegisterBonusClaims.email_hash]) .returning(RegisterBonusClaims.id) ) inserted_id = (await self._session.execute(stmt)).scalar_one_or_none() return inserted_id is not None async def get_register_bonus_claim( self, *, email_hash: str, ) -> RegisterBonusClaims | None: stmt = ( select(RegisterBonusClaims) .where(RegisterBonusClaims.email_hash == email_hash) .limit(1) ) return (await self._session.execute(stmt)).scalar_one_or_none() async def update_register_bonus_balance_snapshot( self, *, email_hash: str, balance_snapshot: int, ) -> bool: claim = await self.get_register_bonus_claim(email_hash=email_hash) if claim is None: return False claim.balance_snapshot = int(balance_snapshot) await self._session.flush() return True async def has_purchased_starter_pack( self, *, email_hash: str, ) -> bool: claim = await self.get_register_bonus_claim(email_hash=email_hash) if claim is None: return False return bool(claim.has_purchased_starter_pack) async def get_profile_settings( self, *, user_id: UUID, ) -> dict[str, object] | None: stmt = select(Profile.settings).where(Profile.id == user_id).limit(1) row = (await self._session.execute(stmt)).scalar_one_or_none() return row