from __future__ import annotations import json import time import uuid from typing import TypedDict import httpx import pytest class IdentityData(TypedDict): email: str code: str async def _create_email_session( client: httpx.AsyncClient, *, email: str, code: str, ) -> dict[str, object]: resp = await client.post( "/api/v1/auth/email-session", json={"email": email, "token": code}, ) resp.raise_for_status() return resp.json() async def _wait_terminal_event( client: httpx.AsyncClient, *, access_token: str, thread_id: str, run_id: str, timeout_s: int = 180, ) -> str: headers = {"Authorization": f"Bearer {access_token}"} params = {"runId": run_id, "idle_limit": 120} started = time.time() async with client.stream( "GET", f"/api/v1/agent/runs/{thread_id}/events", headers=headers, params=params, ) as resp: resp.raise_for_status() async for line in resp.aiter_lines(): if time.time() - started > timeout_s: raise TimeoutError("SSE timed out") if not line or not line.startswith("data: "): continue event = json.loads(line[6:]) event_type = event.get("type") if event_type in {"RUN_FINISHED", "RUN_ERROR"}: return str(event_type) raise RuntimeError("No terminal SSE event") def _build_run_payload( *, thread_id: str, run_id: str, runtime_mode: str, question: str, ) -> dict[str, object]: now = int(time.time() * 1000) return { "threadId": thread_id, "runId": run_id, "state": {}, "messages": [ { "id": f"msg_{run_id}_user_0", "role": "user", "content": question, } ], "tools": [], "context": [], "forwardedProps": { "runtime_mode": runtime_mode, "client_time": { "device_timezone": "Asia/Shanghai", "client_now_iso": "2026-04-10T12:00:00Z", "client_epoch_ms": now, }, "divinationPayload": { "divinationMethod": "自动起卦", "questionType": "运势", "question": question, "divinationTimeIso": "2026-04-10T12:00:00Z", "yaoLines": ["少阳", "少阴", "老阳", "少阳", "老阴", "少阴"], }, }, } @pytest.mark.asyncio async def test_follow_up_run_succeeds_and_limit_uses_assistant_count( api_client: httpx.AsyncClient, test_identity: IdentityData, db_cleanup: list[str], ) -> None: email = str(test_identity["email"]).strip().lower() db_cleanup.append(email) login = await _create_email_session( api_client, email=email, code=str(test_identity["code"]), ) token = str(login["access_token"]) headers = {"Authorization": f"Bearer {token}"} thread_id = str(uuid.uuid4()) first_run_id = f"run_chat_{int(time.time() * 1000)}" first_enqueue = await api_client.post( "/api/v1/agent/runs", headers=headers, json=_build_run_payload( thread_id=thread_id, run_id=first_run_id, runtime_mode="chat", question="这周适合推进新项目吗?", ), ) first_enqueue.raise_for_status() assert first_enqueue.status_code == 202 first_terminal = await _wait_terminal_event( api_client, access_token=token, thread_id=thread_id, run_id=first_run_id, ) assert first_terminal == "RUN_FINISHED" second_run_id = f"run_follow_up_{int(time.time() * 1000)}" second_enqueue = await api_client.post( "/api/v1/agent/runs", headers=headers, json=_build_run_payload( thread_id=thread_id, run_id=second_run_id, runtime_mode="follow_up", question="那我第一步应该先做什么?", ), ) second_enqueue.raise_for_status() assert second_enqueue.status_code == 202 second_terminal = await _wait_terminal_event( api_client, access_token=token, thread_id=thread_id, run_id=second_run_id, ) assert second_terminal == "RUN_FINISHED" history_resp = await api_client.get( "/api/v1/agent/history", headers=headers, params={"threadId": thread_id}, ) history_resp.raise_for_status() history_payload = history_resp.json() messages = history_payload.get("messages") assert isinstance(messages, list) assistant_messages = [ message for message in messages if isinstance(message, dict) and message.get("role") == "assistant" ] assert len(assistant_messages) == 2 third_run_id = f"run_follow_up_blocked_{int(time.time() * 1000)}" third_enqueue = await api_client.post( "/api/v1/agent/runs", headers=headers, json=_build_run_payload( thread_id=thread_id, run_id=third_run_id, runtime_mode="follow_up", question="还有哪些风险要特别注意?", ), ) assert third_enqueue.status_code == 409 error_payload = third_enqueue.json() assert error_payload.get("code") == "AGENT_SESSION_RUN_LIMIT_EXCEEDED" params = error_payload.get("params") assert isinstance(params, dict) assert params.get("maxRuns") == 2