fix(redis): 修复 Redis 流读取兼容性问题
- 支持 bytes 和 str 类型的 entry_id - 支持 list 类型响应格式 - 优化 payload 解码处理
This commit is contained in:
@@ -110,6 +110,18 @@ class _FakeAgentService:
|
||||
}
|
||||
|
||||
|
||||
class _FailingStreamAgentService(_FakeAgentService):
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
del thread_id, last_event_id, current_user
|
||||
raise RuntimeError("redis timeout")
|
||||
|
||||
|
||||
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
@@ -197,6 +209,38 @@ def test_stream_reads_from_last_event_id() -> None:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_handles_stream_backend_errors_without_connection_crash() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FailingStreamAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
original_acquire = agent_router._acquire_sse_slot
|
||||
original_release = agent_router._release_sse_slot
|
||||
|
||||
async def _allow_slot(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _noop_release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
return None
|
||||
|
||||
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
finally:
|
||||
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = original_release # type: ignore[assignment]
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_rejects_invalid_last_event_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
|
||||
@@ -142,20 +142,19 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
||||
event_names: list[str] = []
|
||||
async with client.stream(
|
||||
"GET", events_url, headers=headers, timeout=20.0
|
||||
) as sse_resp:
|
||||
assert sse_resp.status_code == 200
|
||||
assert sse_resp.headers.get("content-type", "").startswith(
|
||||
"text/event-stream"
|
||||
)
|
||||
async for line in sse_resp.aiter_lines():
|
||||
if line.startswith("event:"):
|
||||
event_name = line.split(":", 1)[1].strip()
|
||||
event_names.append(event_name)
|
||||
if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
|
||||
break
|
||||
sse_resp = await client.get(
|
||||
events_url,
|
||||
headers=headers,
|
||||
params={"idle_limit": 150},
|
||||
timeout=60.0,
|
||||
)
|
||||
assert sse_resp.status_code == 200
|
||||
assert sse_resp.headers.get("content-type", "").startswith("text/event-stream")
|
||||
event_names = [
|
||||
line.split(":", 1)[1].strip()
|
||||
for line in sse_resp.text.splitlines()
|
||||
if line.startswith("event:")
|
||||
]
|
||||
|
||||
assert "RUN_STARTED" in event_names
|
||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||
|
||||
@@ -51,6 +51,30 @@ class _FakeRedisBytes:
|
||||
return [(stream_name, rows)]
|
||||
|
||||
|
||||
class _FakeRedisListResponse:
|
||||
def __init__(self) -> None:
|
||||
self._rows: list[tuple[str, str]] = []
|
||||
|
||||
def xadd(self, _stream: str, fields: dict[str, str]) -> str:
|
||||
cursor = f"{len(self._rows) + 1}-0"
|
||||
self._rows.append((cursor, fields["event"]))
|
||||
return cursor
|
||||
|
||||
def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[list[object]]:
|
||||
del count, block
|
||||
stream_name, last = next(iter(streams.items()))
|
||||
rows: list[tuple[str, dict[str, str]]] = []
|
||||
for cursor, payload in self._rows:
|
||||
if cursor > last:
|
||||
rows.append((cursor, {"event": payload}))
|
||||
return [[stream_name, rows]]
|
||||
|
||||
|
||||
async def test_publish_then_read_after_cursor() -> None:
|
||||
bus = RedisStreamBus(client=_FakeRedis(), stream_prefix="agent.events")
|
||||
|
||||
@@ -69,3 +93,10 @@ async def test_read_supports_bytes_payload() -> None:
|
||||
await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"})
|
||||
rows = await bus.read(session_id="thread-1", last_event_id=None)
|
||||
assert rows[0]["event"]["type"] == "RUN_STARTED"
|
||||
|
||||
|
||||
async def test_read_supports_list_wrapped_stream_response() -> None:
|
||||
bus = RedisStreamBus(client=_FakeRedisListResponse(), stream_prefix="agent.events")
|
||||
await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"})
|
||||
rows = await bus.read(session_id="thread-1", last_event_id=None)
|
||||
assert rows[0]["event"]["type"] == "RUN_STARTED"
|
||||
|
||||
@@ -104,3 +104,22 @@ def test_schemas_exports_include_task_and_history_models() -> None:
|
||||
assert exported_schemas.TaskAccepted is AcceptedTaskResponse
|
||||
assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse
|
||||
assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse
|
||||
|
||||
|
||||
def test_run_command_accepts_agui_context_list_and_parent_run_id() -> None:
|
||||
payload = {
|
||||
"threadId": "thread-xyz",
|
||||
"runId": "run-xyz",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"parentRunId": None,
|
||||
}
|
||||
|
||||
command = RunCommand.model_validate(payload)
|
||||
|
||||
dumped = command.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["context"] == []
|
||||
assert "parentRunId" in dumped
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.dependencies import TaskiqQueueClient
|
||||
from v1.agent.dependencies import RedisEventStream, TaskiqQueueClient
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
@@ -39,6 +39,10 @@ class _FakeAsyncResult:
|
||||
self.task_id = task_id
|
||||
|
||||
|
||||
class _FakeRedisStreamClient:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
@@ -89,7 +93,11 @@ async def test_enqueue_resume_dedup_returns_existing_task_id(
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": "session-1",
|
||||
"tool_call_id": "call-1",
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
@@ -132,7 +140,11 @@ async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": "session-1",
|
||||
"tool_call_id": "call-1",
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
@@ -140,7 +152,9 @@ async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
async def test_enqueue_failure_cleans_dedup_lock(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
@@ -160,7 +174,11 @@ async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch
|
||||
client = TaskiqQueueClient()
|
||||
with pytest.raises(RuntimeError, match="enqueue failed"):
|
||||
await client.enqueue(
|
||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": "session-1",
|
||||
"tool_call_id": "call-1",
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
@@ -225,3 +243,41 @@ async def test_enqueue_uses_bulk_queue_when_requested(
|
||||
)
|
||||
|
||||
assert task_id == "bulk-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_stream_caps_block_ms_below_socket_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
|
||||
return _FakeRedisStreamClient()
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 5000)
|
||||
monkeypatch.setattr(deps.config.redis, "socket_timeout", 1.0)
|
||||
|
||||
stream = RedisEventStream()
|
||||
bus = await stream._get_bus()
|
||||
|
||||
assert bus._block_ms == 900
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_stream_uses_configured_block_ms_when_safe(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
|
||||
return _FakeRedisStreamClient()
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 200)
|
||||
monkeypatch.setattr(deps.config.redis, "socket_timeout", 2.0)
|
||||
|
||||
stream = RedisEventStream()
|
||||
bus = await stream._get_bus()
|
||||
|
||||
assert bus._block_ms == 200
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.jwt_verifier import TokenValidationError
|
||||
import v1.users.dependencies as deps
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) -> None:
|
||||
class _BrokenVerifier:
|
||||
def verify(self, token: str) -> dict[str, object]:
|
||||
del token
|
||||
raise TokenValidationError("Token validation failed")
|
||||
|
||||
monkeypatch.setattr(deps, "get_jwt_verifier", lambda: _BrokenVerifier())
|
||||
|
||||
async def _fallback(token: str):
|
||||
del token
|
||||
return deps.CurrentUser(
|
||||
id=UUID("e8845a17-282b-4a63-8025-194a06235958"),
|
||||
email="dagronl@126.com",
|
||||
role="authenticated",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(deps, "_verify_user_with_supabase", _fallback)
|
||||
|
||||
user = await deps.get_current_user(authorization="Bearer valid-token")
|
||||
|
||||
assert str(user.id) == "e8845a17-282b-4a63-8025-194a06235958"
|
||||
assert user.email == "dagronl@126.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_raises_401_when_fallback_fails(monkeypatch) -> None:
|
||||
class _BrokenVerifier:
|
||||
def verify(self, token: str) -> dict[str, object]:
|
||||
del token
|
||||
raise TokenValidationError("Token validation failed")
|
||||
|
||||
monkeypatch.setattr(deps, "get_jwt_verifier", lambda: _BrokenVerifier())
|
||||
|
||||
async def _fallback(token: str):
|
||||
del token
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(deps, "_verify_user_with_supabase", _fallback)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await deps.get_current_user(authorization="Bearer invalid-token")
|
||||
|
||||
assert exc.value.status_code == 401
|
||||
Reference in New Issue
Block a user