"""Integration test: the register bonus is granted at most once per email.

The flow is register -> run an agent run -> delete the account -> register
again with the same email. The test then checks balances via the HTTP API
and the points/audit/claim tables directly via the database session.
"""

from __future__ import annotations

import asyncio
import json
import time
import uuid
from typing import TypedDict

import httpx
import pytest
from sqlalchemy import select

from core.config.settings import config
from core.db.session import AsyncSessionLocal
from models.points_audit_ledger import PointsAuditLedger
from models.points_ledger import PointsLedger
from models.register_bonus_claims import RegisterBonusClaims
from models.user_points import UserPoints

# SSE event types after which no further events for the run are awaited.
_TERMINAL_EVENT_TYPES = frozenset({"RUN_FINISHED", "RUN_ERROR"})
_SSE_DATA_PREFIX = "data: "


class IdentityData(TypedDict):
    """Shape of the ``test_identity`` fixture used to open an email session."""

    email: str
    # Sent as the ``token`` field of the email-session request.
    code: str


async def _create_email_session(
    client: httpx.AsyncClient,
    *,
    email: str,
    code: str,
) -> dict[str, object]:
    """Open an email session and return the decoded JSON response.

    Args:
        client: HTTP client pointed at the API under test.
        email: Email address to sign in / register with.
        code: Verification code, sent as ``token``.

    Returns:
        The response body (the test reads its ``user`` and ``access_token``).

    Raises:
        httpx.HTTPStatusError: If the endpoint returns a non-2xx status.
    """
    resp = await client.post(
        "/api/v1/auth/email-session",
        json={"email": email, "token": code},
    )
    resp.raise_for_status()
    return resp.json()


async def _get_balance(
    client: httpx.AsyncClient,
    headers: dict[str, str],
) -> dict[str, object]:
    """GET the points balance for the authenticated user and return its JSON.

    Raises:
        httpx.HTTPStatusError: If the endpoint returns a non-2xx status.
    """
    resp = await client.get("/api/v1/points/balance", headers=headers)
    resp.raise_for_status()
    return resp.json()


async def _wait_terminal_event(
    client: httpx.AsyncClient,
    *,
    access_token: str,
    thread_id: str,
    run_id: str,
    timeout_s: int = 180,
) -> str:
    """Stream a run's SSE events until a terminal event arrives.

    Args:
        client: HTTP client pointed at the API under test.
        access_token: Bearer token of the run's owner.
        thread_id: Thread the run belongs to (part of the URL path).
        run_id: Run to follow, sent as the ``runId`` query parameter.
        timeout_s: Overall deadline in seconds, covering connection setup and
            streaming.

    Returns:
        The terminal event type, ``"RUN_FINISHED"`` or ``"RUN_ERROR"``.

    Raises:
        TimeoutError: If no terminal event arrives within ``timeout_s``. This
            fires even if the server sends nothing at all.
        RuntimeError: If the stream ends without a terminal event.
        httpx.HTTPStatusError: If the events endpoint returns a non-2xx status.
    """
    headers = {"Authorization": f"Bearer {access_token}"}
    params = {"runId": run_id, "idle_limit": 120}

    async def _consume() -> str:
        async with client.stream(
            "GET",
            f"/api/v1/agent/runs/{thread_id}/events",
            headers=headers,
            params=params,
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                # Only SSE ``data:`` lines carry event payloads; skip blanks,
                # comments/heartbeats and other fields.
                if not line.startswith(_SSE_DATA_PREFIX):
                    continue
                event = json.loads(line[len(_SSE_DATA_PREFIX):])
                event_type = event.get("type")
                if event_type in _TERMINAL_EVENT_TYPES:
                    return str(event_type)
        raise RuntimeError("No terminal SSE event")

    # Enforce the deadline with the event loop's monotonic clock instead of
    # checking wall-clock time per received line, which could never fire on
    # a stalled, silent stream.
    try:
        return await asyncio.wait_for(_consume(), timeout=timeout_s)
    except asyncio.TimeoutError:
        raise TimeoutError("SSE timed out") from None


def _build_run_payload(*, thread_id: str, run_id: str) -> dict[str, object]:
    """Build the enqueue body for a single-message chat run.

    The user message (also repeated as the divination question) asks whether
    today is suitable for making an important decision. ``client_epoch_ms``
    is the current time, while ``client_now_iso`` and the divination time are
    fixed strings.
    """
    now = int(time.time() * 1000)
    return {
        "threadId": thread_id,
        "runId": run_id,
        "state": {},
        "messages": [
            {
                "id": f"msg_{run_id}_user_0",
                "role": "user",
                "content": "今天适合做重要决策吗?",
            }
        ],
        "tools": [],
        "context": [],
        "forwardedProps": {
            "runtime_mode": "chat",
            "client_time": {
                "device_timezone": "Asia/Shanghai",
                "client_now_iso": "2026-04-10T12:00:00Z",
                "client_epoch_ms": now,
            },
            "divinationPayload": {
                "divinationMethod": "自动起卦",
                "questionType": "运势",
                "question": "今天适合做重要决策吗?",
                "divinationTimeIso": "2026-04-10T12:00:00Z",
                "yaoLines": ["少阳", "少阴", "老阳", "少阳", "老阴", "少阴"],
            },
        },
    }


@pytest.mark.asyncio
async def test_register_run_delete_reregister_keeps_bonus_single_use(
    api_client: httpx.AsyncClient,
    test_identity: IdentityData,
    db_cleanup: list[str],
) -> None:
    """Re-registering a deleted account's email must not re-grant the bonus.

    Checks that:
      * the first registration's balance equals the configured register bonus;
      * after one agent run the balance is ``max(bonus - runCost, 0)``;
      * re-registering the same email yields a new user id with zero balance
        and zero ``lifetime_earned``;
      * the deleted user's ``PointsLedger`` rows are gone while the run's
        ``PointsAuditLedger`` rows remain;
      * exactly one ``RegisterBonusClaims`` row exists for the email.
    """
    email = str(test_identity["email"]).strip().lower()
    db_cleanup.append(email)
    bonus = int(config.points_policy.register_bonus_points)
    code = str(test_identity["code"])

    # First registration: balance should be exactly the register bonus.
    first = await _create_email_session(api_client, email=email, code=code)
    user1 = first.get("user")
    assert isinstance(user1, dict)
    user1_id = str(user1["id"])
    token1 = str(first["access_token"])
    headers1 = {"Authorization": f"Bearer {token1}"}

    before_data = await _get_balance(api_client, headers1)
    assert int(before_data["balance"]) == bonus

    # Spend points with one agent run and wait for it to terminate.
    thread_id = str(uuid.uuid4())
    run_id = f"run_{int(time.time() * 1000)}"
    enqueue = await api_client.post(
        "/api/v1/agent/runs",
        headers=headers1,
        json=_build_run_payload(thread_id=thread_id, run_id=run_id),
    )
    enqueue.raise_for_status()
    assert enqueue.status_code == 202

    terminal = await _wait_terminal_event(
        api_client,
        access_token=token1,
        thread_id=thread_id,
        run_id=run_id,
    )
    assert terminal in _TERMINAL_EVENT_TYPES

    after_data = await _get_balance(api_client, headers1)
    assert int(after_data["balance"]) == max(bonus - int(after_data["runCost"]), 0)

    # Delete the account, then register again with the same email.
    delete_resp = await api_client.delete("/api/v1/users/me", headers=headers1)
    assert delete_resp.status_code == 204

    second = await _create_email_session(api_client, email=email, code=code)
    user2 = second.get("user")
    assert isinstance(user2, dict)
    user2_id = str(user2["id"])
    token2 = str(second["access_token"])
    assert user1_id != user2_id
    headers2 = {"Authorization": f"Bearer {token2}"}

    # The re-registered account must not receive the bonus again.
    re_data = await _get_balance(api_client, headers2)
    assert int(re_data["balance"]) == 0

    user1_uuid = uuid.UUID(user1_id)
    user2_uuid = uuid.UUID(user2_id)
    async with AsyncSessionLocal() as session:
        points2 = (
            await session.execute(
                select(UserPoints).where(UserPoints.user_id == user2_uuid)
            )
        ).scalar_one()
        assert int(points2.lifetime_earned) == 0

        # The deleted user's points ledger rows are removed...
        run_ledger_rows = list(
            (
                await session.execute(
                    select(PointsLedger)
                    .where(PointsLedger.user_id == user1_uuid)
                    .order_by(PointsLedger.created_at.desc())
                )
            ).scalars()
        )
        assert run_ledger_rows == []

        # ...but the audit trail for the run survives, keyed by a user id
        # snapshot rather than a live user reference.
        run_audit_rows = list(
            (
                await session.execute(
                    select(PointsAuditLedger)
                    .where(
                        PointsAuditLedger.user_id_snapshot == user1_uuid,
                        PointsAuditLedger.run_id == run_id,
                    )
                    .order_by(PointsAuditLedger.created_at.desc())
                )
            ).scalars()
        )
        assert run_audit_rows
        assert run_audit_rows[0].run_id == run_id
        assert run_audit_rows[0].billed_to in {"user", "platform"}

        # One bonus claim per email, across both registrations.
        claim_rows = list(
            (
                await session.execute(
                    select(RegisterBonusClaims).where(
                        RegisterBonusClaims.user_email_snapshot == email
                    )
                )
            ).scalars()
        )
        assert len(claim_rows) == 1