fix(redis): 修复 Redis 流读取兼容性问题

- 支持 bytes 和 str 类型的 entry_id
- 支持 list 类型响应格式
- 优化 payload 解码处理
This commit is contained in:
qzl
2026-03-11 21:33:25 +08:00
parent e4f69a64bd
commit 18db6c50e7
17 changed files with 359 additions and 54 deletions
@@ -51,6 +51,30 @@ class _FakeRedisBytes:
return [(stream_name, rows)]
class _FakeRedisListResponse:
def __init__(self) -> None:
self._rows: list[tuple[str, str]] = []
def xadd(self, _stream: str, fields: dict[str, str]) -> str:
cursor = f"{len(self._rows) + 1}-0"
self._rows.append((cursor, fields["event"]))
return cursor
def xread(
self,
streams: dict[str, str],
count: int,
block: int,
) -> list[list[object]]:
del count, block
stream_name, last = next(iter(streams.items()))
rows: list[tuple[str, dict[str, str]]] = []
for cursor, payload in self._rows:
if cursor > last:
rows.append((cursor, {"event": payload}))
return [[stream_name, rows]]
async def test_publish_then_read_after_cursor() -> None:
bus = RedisStreamBus(client=_FakeRedis(), stream_prefix="agent.events")
@@ -69,3 +93,10 @@ async def test_read_supports_bytes_payload() -> None:
await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"})
rows = await bus.read(session_id="thread-1", last_event_id=None)
assert rows[0]["event"]["type"] == "RUN_STARTED"
async def test_read_supports_list_wrapped_stream_response() -> None:
bus = RedisStreamBus(client=_FakeRedisListResponse(), stream_prefix="agent.events")
await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"})
rows = await bus.read(session_id="thread-1", last_event_id=None)
assert rows[0]["event"]["type"] == "RUN_STARTED"
@@ -104,3 +104,22 @@ def test_schemas_exports_include_task_and_history_models() -> None:
assert exported_schemas.TaskAccepted is AcceptedTaskResponse
assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse
assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse
def test_run_command_accepts_agui_context_list_and_parent_run_id() -> None:
payload = {
"threadId": "thread-xyz",
"runId": "run-xyz",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
"parentRunId": None,
}
command = RunCommand.model_validate(payload)
dumped = command.model_dump(mode="json", by_alias=True)
assert dumped["context"] == []
assert "parentRunId" in dumped
@@ -2,7 +2,7 @@ from __future__ import annotations
import pytest
from v1.agent.dependencies import TaskiqQueueClient
from v1.agent.dependencies import RedisEventStream, TaskiqQueueClient
class _FakeRedis:
@@ -39,6 +39,10 @@ class _FakeAsyncResult:
self.task_id = task_id
class _FakeRedisStreamClient:
pass
@pytest.mark.asyncio
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
@@ -89,7 +93,11 @@ async def test_enqueue_resume_dedup_returns_existing_task_id(
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
command={
"command": "resume",
"session_id": "session-1",
"tool_call_id": "call-1",
},
dedup_key=dedup_key,
)
@@ -132,7 +140,11 @@ async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
command={
"command": "resume",
"session_id": "session-1",
"tool_call_id": "call-1",
},
dedup_key=dedup_key,
)
@@ -140,7 +152,9 @@ async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
@pytest.mark.asyncio
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
async def test_enqueue_failure_cleans_dedup_lock(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
@@ -160,7 +174,11 @@ async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch
client = TaskiqQueueClient()
with pytest.raises(RuntimeError, match="enqueue failed"):
await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
command={
"command": "resume",
"session_id": "session-1",
"tool_call_id": "call-1",
},
dedup_key=dedup_key,
)
@@ -225,3 +243,41 @@ async def test_enqueue_uses_bulk_queue_when_requested(
)
assert task_id == "bulk-task-id"
@pytest.mark.asyncio
async def test_event_stream_caps_block_ms_below_socket_timeout(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
return _FakeRedisStreamClient()
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 5000)
monkeypatch.setattr(deps.config.redis, "socket_timeout", 1.0)
stream = RedisEventStream()
bus = await stream._get_bus()
assert bus._block_ms == 900
@pytest.mark.asyncio
async def test_event_stream_uses_configured_block_ms_when_safe(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
return _FakeRedisStreamClient()
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 200)
monkeypatch.setattr(deps.config.redis, "socket_timeout", 2.0)
stream = RedisEventStream()
bus = await stream._get_bus()
assert bus._block_ms == 200
@@ -0,0 +1,55 @@
from __future__ import annotations
from uuid import UUID
import pytest
from fastapi import HTTPException
from core.auth.jwt_verifier import TokenValidationError
import v1.users.dependencies as deps
@pytest.mark.asyncio
async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) -> None:
class _BrokenVerifier:
def verify(self, token: str) -> dict[str, object]:
del token
raise TokenValidationError("Token validation failed")
monkeypatch.setattr(deps, "get_jwt_verifier", lambda: _BrokenVerifier())
async def _fallback(token: str):
del token
return deps.CurrentUser(
id=UUID("e8845a17-282b-4a63-8025-194a06235958"),
email="dagronl@126.com",
role="authenticated",
)
monkeypatch.setattr(deps, "_verify_user_with_supabase", _fallback)
user = await deps.get_current_user(authorization="Bearer valid-token")
assert str(user.id) == "e8845a17-282b-4a63-8025-194a06235958"
assert user.email == "dagronl@126.com"
@pytest.mark.asyncio
async def test_get_current_user_raises_401_when_fallback_fails(monkeypatch) -> None:
class _BrokenVerifier:
def verify(self, token: str) -> dict[str, object]:
del token
raise TokenValidationError("Token validation failed")
monkeypatch.setattr(deps, "get_jwt_verifier", lambda: _BrokenVerifier())
async def _fallback(token: str):
del token
return None
monkeypatch.setattr(deps, "_verify_user_with_supabase", _fallback)
with pytest.raises(HTTPException) as exc:
await deps.get_current_user(authorization="Bearer invalid-token")
assert exc.value.status_code == 401