refactor: 重构聊天模块支持 SSE 断线重连及用户上下文隔离
This commit is contained in:
@@ -154,6 +154,65 @@ class _TerminalStreamAgentService(_FakeAgentService):
|
||||
return []
|
||||
|
||||
|
||||
class _MixedRunStreamAgentService(_FakeAgentService):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.stream_calls = 0
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
del thread_id, last_event_id, current_user
|
||||
self.stream_calls += 1
|
||||
if self.stream_calls == 1:
|
||||
return [
|
||||
{
|
||||
"id": "11-0",
|
||||
"event": {
|
||||
"type": "RUN_FINISHED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-old",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "12-0",
|
||||
"event": {
|
||||
"type": "RUN_STARTED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
},
|
||||
},
|
||||
]
|
||||
if self.stream_calls == 2:
|
||||
return [
|
||||
{
|
||||
"id": "13-0",
|
||||
"event": {
|
||||
"type": "STEP_STARTED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"stepName": "router",
|
||||
},
|
||||
}
|
||||
]
|
||||
if self.stream_calls == 3:
|
||||
return [
|
||||
{
|
||||
"id": "14-0",
|
||||
"event": {
|
||||
"type": "RUN_FINISHED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
},
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
@@ -168,7 +227,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
"forwardedProps": {"runtime_mode": "chat"},
|
||||
},
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
@@ -185,7 +244,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
"forwardedProps": {"runtime_mode": "chat"},
|
||||
},
|
||||
)
|
||||
assert authorized.status_code == 202
|
||||
@@ -219,7 +278,7 @@ def test_stream_reads_from_last_event_id() -> None:
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1",
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1&idle_limit=1",
|
||||
headers={"Last-Event-ID": "1-0"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -255,7 +314,7 @@ def test_stream_handles_stream_backend_errors_without_connection_crash() -> None
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1"
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1&idle_limit=1"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
@@ -288,7 +347,7 @@ def test_stream_stops_after_terminal_run_event() -> None:
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=3"
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1&idle_limit=3"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
@@ -309,7 +368,7 @@ def test_stream_rejects_invalid_last_event_id() -> None:
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events",
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1",
|
||||
headers={"Last-Event-ID": "bad-id"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -320,6 +379,68 @@ def test_stream_rejects_invalid_last_event_id() -> None:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_filters_non_target_run_and_waits_target_terminal() -> None:
|
||||
service = _MixedRunStreamAgentService()
|
||||
app.dependency_overrides[get_agent_service] = lambda: service
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
original_acquire = agent_router._acquire_sse_slot
|
||||
original_release = agent_router._release_sse_slot
|
||||
|
||||
async def _allow_slot(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _noop_release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
return None
|
||||
|
||||
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1&idle_limit=3"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
assert '"runId":"run-old"' not in response.text
|
||||
assert '"runId":"run-1"' in response.text
|
||||
assert "event: RUN_STARTED" in response.text
|
||||
assert "event: STEP_STARTED" in response.text
|
||||
assert "event: RUN_FINISHED" in response.text
|
||||
assert service.stream_calls == 3
|
||||
finally:
|
||||
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = original_release # type: ignore[assignment]
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_rejects_invalid_or_missing_run_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
invalid = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=bad%20id"
|
||||
)
|
||||
assert invalid.status_code == 422
|
||||
assert invalid.json()["code"] == "AGENT_INVALID_RUN_ID"
|
||||
|
||||
missing = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events"
|
||||
)
|
||||
assert missing.status_code == 422
|
||||
assert missing.json()["code"] == "AGENT_INVALID_RUN_ID"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_rejects_when_sse_connection_limit_exceeded() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
@@ -336,7 +457,7 @@ def test_stream_rejects_when_sse_connection_limit_exceeded() -> None:
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events"
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?runId=run-1"
|
||||
)
|
||||
assert response.status_code == 429
|
||||
payload = response.json()
|
||||
@@ -587,7 +708,7 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
|
||||
return "这是测试转写结果"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"v1.agent.service.asr_service.transcribe_file",
|
||||
"v1.agent.router.asr_service.transcribe_file",
|
||||
mock_transcribe_file,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user