refactor: 重构聊天模块支持 SSE 断线重连及用户上下文隔离

This commit is contained in:
zl-q
2026-03-30 09:06:10 +08:00
parent 1aac62f39e
commit 4285b4ec80
28 changed files with 1624 additions and 658 deletions
@@ -77,7 +77,7 @@ def test_create_schedule_item_returns_201() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
permission=15,
is_owner=True,
)
@@ -110,7 +110,7 @@ def test_list_schedule_items_returns_200() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
permission=15,
is_owner=True,
)
@@ -145,7 +145,7 @@ def test_get_schedule_item_returns_200() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
permission=15,
is_owner=True,
)
@@ -173,7 +173,7 @@ def test_update_schedule_item_returns_200() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
permission=15,
is_owner=True,
)
@@ -204,7 +204,7 @@ def test_delete_schedule_item_returns_204() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
permission=15,
is_owner=True,
)
@@ -154,6 +154,65 @@ class _TerminalStreamAgentService(_FakeAgentService):
return []
class _MixedRunStreamAgentService(_FakeAgentService):
def __init__(self) -> None:
super().__init__()
self.stream_calls = 0
async def stream_events(
self,
*,
thread_id: str,
last_event_id: str | None,
current_user: CurrentUser,
) -> list[dict[str, object]]:
del thread_id, last_event_id, current_user
self.stream_calls += 1
if self.stream_calls == 1:
return [
{
"id": "11-0",
"event": {
"type": "RUN_FINISHED",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-old",
},
},
{
"id": "12-0",
"event": {
"type": "RUN_STARTED",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
},
},
]
if self.stream_calls == 2:
return [
{
"id": "13-0",
"event": {
"type": "STEP_STARTED",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
"stepName": "router",
},
}
]
if self.stream_calls == 3:
return [
{
"id": "14-0",
"event": {
"type": "RUN_FINISHED",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
},
}
]
return []
def test_run_requires_auth_and_returns_202_task_id() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
client = TestClient(app)
@@ -168,7 +227,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {"agent_type": "worker"},
"forwardedProps": {"runtime_mode": "chat"},
},
)
assert unauthorized.status_code == 401
@@ -185,7 +244,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {"agent_type": "worker"},
"forwardedProps": {"runtime_mode": "chat"},
},
)
assert authorized.status_code == 202
@@ -219,7 +278,7 @@ def test_stream_reads_from_last_event_id() -> None:
try:
response = client.get(
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1",
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1&idle_limit=1",
headers={"Last-Event-ID": "1-0"},
)
assert response.status_code == 200
@@ -255,7 +314,7 @@ def test_stream_handles_stream_backend_errors_without_connection_crash() -> None
try:
response = client.get(
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1"
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1&idle_limit=1"
)
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/event-stream")
@@ -288,7 +347,7 @@ def test_stream_stops_after_terminal_run_event() -> None:
try:
response = client.get(
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=3"
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1&idle_limit=3"
)
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/event-stream")
@@ -309,7 +368,7 @@ def test_stream_rejects_invalid_last_event_id() -> None:
try:
response = client.get(
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events",
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1",
headers={"Last-Event-ID": "bad-id"},
)
assert response.status_code == 422
@@ -320,6 +379,68 @@ def test_stream_rejects_invalid_last_event_id() -> None:
app.dependency_overrides = {}
def test_stream_filters_non_target_run_and_waits_target_terminal() -> None:
service = _MixedRunStreamAgentService()
app.dependency_overrides[get_agent_service] = lambda: service
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
original_acquire = agent_router._acquire_sse_slot
original_release = agent_router._release_sse_slot
async def _allow_slot(*, user_id: str) -> bool:
del user_id
return True
async def _noop_release(*, user_id: str) -> None:
del user_id
return None
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
try:
response = client.get(
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1&idle_limit=3"
)
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/event-stream")
assert '"runId":"run-old"' not in response.text
assert '"runId":"run-1"' in response.text
assert "event: RUN_STARTED" in response.text
assert "event: STEP_STARTED" in response.text
assert "event: RUN_FINISHED" in response.text
assert service.stream_calls == 3
finally:
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
agent_router._release_sse_slot = original_release # type: ignore[assignment]
app.dependency_overrides = {}
def test_stream_rejects_invalid_or_missing_run_id() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
try:
invalid = client.get(
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=bad%20id"
)
assert invalid.status_code == 422
assert invalid.json()["code"] == "AGENT_INVALID_RUN_ID"
missing = client.get(
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events"
)
assert missing.status_code == 422
assert missing.json()["code"] == "AGENT_INVALID_RUN_ID"
finally:
app.dependency_overrides = {}
def test_stream_rejects_when_sse_connection_limit_exceeded() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
@@ -336,7 +457,7 @@ def test_stream_rejects_when_sse_connection_limit_exceeded() -> None:
try:
response = client.get(
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events"
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1"
)
assert response.status_code == 429
payload = response.json()
@@ -587,7 +708,7 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
return "这是测试转写结果"
monkeypatch.setattr(
"v1.agent.service.asr_service.transcribe_file",
"v1.agent.router.asr_service.transcribe_file",
mock_transcribe_file,
)