fix(redis): 修复 Redis 流读取兼容性问题

- 支持 bytes 和 str 类型的 entry_id
- 支持 list 类型响应格式
- 优化 payload 解码处理
This commit is contained in:
qzl
2026-03-11 21:33:25 +08:00
parent e4f69a64bd
commit 18db6c50e7
17 changed files with 359 additions and 54 deletions
@@ -55,23 +55,29 @@ class RedisStreamBus:
return []
first = response[0]
if (
not isinstance(first, tuple)
or len(first) != 2
or not isinstance(first[1], list)
):
if not isinstance(first, (list, tuple)) or len(first) != 2:
return []
entries = cast(list[tuple[str, dict[str, Any]]], first[1])
entries_raw = first[1]
if not isinstance(entries_raw, list):
return []
entries = cast(list[tuple[Any, dict[str, Any]]], entries_raw)
rows: list[dict[str, Any]] = []
for entry in entries:
if (
not isinstance(entry, tuple)
or len(entry) != 2
or not isinstance(entry[0], str)
or not isinstance(entry[1], dict)
):
continue
entry_id_raw = entry[0]
if isinstance(entry_id_raw, bytes):
entry_id = entry_id_raw.decode("utf-8", errors="replace")
elif isinstance(entry_id_raw, str):
entry_id = entry_id_raw
else:
continue
payload_map = cast(dict[str, Any], entry[1])
event_payload = payload_map.get("event")
if isinstance(event_payload, bytes):
@@ -84,7 +90,7 @@ class RedisStreamBus:
continue
if not isinstance(decoded, dict):
continue
rows.append({"id": entry[0], "event": decoded})
rows.append({"id": entry_id, "event": decoded})
return rows
def _stream_name(self, session_id: str) -> str:
@@ -24,7 +24,8 @@ class RunCommand(_AliasModel):
state: dict[str, Any] | None = None
messages: list[dict[str, Any]] = Field(default_factory=list)
tools: list[dict[str, Any]] = Field(default_factory=list)
context: dict[str, Any] = Field(default_factory=dict)
context: dict[str, Any] | list[dict[str, Any]] = Field(default_factory=list)
parent_run_id: str | None = Field(default=None, alias="parentRunId")
forwarded_props: dict[str, Any] = Field(
default_factory=dict, alias="forwardedProps"
)
+9 -1
View File
@@ -29,6 +29,14 @@ DEDUP_LOCK_SECONDS = 300
DEDUP_INFLIGHT_MARKER = "__inflight__"
def _event_stream_block_ms() -> int:
configured = int(config.agent_runtime.redis_stream_block_ms)
socket_timeout = float(config.redis.socket_timeout)
socket_timeout_ms = max(int(socket_timeout * 1000), 1)
safe_max = max(socket_timeout_ms - 100, 1)
return max(1, min(configured, safe_max))
class TaskiqQueueClient:
def __init__(self) -> None:
self._redis: Redis | None = None
@@ -93,7 +101,7 @@ class RedisEventStream:
client=client,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
block_ms=_event_stream_block_ms(),
)
return self._bus
+18 -5
View File
@@ -21,6 +21,7 @@ from core.agentscope.schemas.agui_input import (
validate_run_request_messages_contract,
)
from core.auth.models import CurrentUser
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
@@ -28,6 +29,7 @@ from v1.agent.service import AgentService, asr_service
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
logger = get_logger("v1.agent.router")
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
_RUNS_PER_MINUTE = 30
_TRANSCRIBES_PER_MINUTE = 20
@@ -188,11 +190,21 @@ async def stream_events(
idle_polls = 0
try:
while not await request.is_disconnected() and idle_polls < idle_limit:
rows = await service.stream_events(
thread_id=thread_id,
last_event_id=cursor,
current_user=current_user,
)
try:
rows = await service.stream_events(
thread_id=thread_id,
last_event_id=cursor,
current_user=current_user,
)
except Exception as exc: # noqa: BLE001
logger.warning(
"SSE stream read failed",
thread_id=thread_id,
user_id=str(current_user.id),
reason=str(exc),
)
break
if not rows:
idle_polls += 1
yield ": keep-alive\n\n"
@@ -207,6 +219,7 @@ async def stream_events(
continue
cursor = row_id
yield to_sse_event(row_id, event)
finally:
await _release_sse_slot(user_id=str(current_user.id))
+17 -7
View File
@@ -203,15 +203,25 @@ class AgentService:
f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
f"{run_input.run_id}/attachment-{index}-{checksum}.{suffix}"
)
stored_path = await self._attachment_storage.upload_bytes(
bucket=config.storage.bucket,
path=path,
content=payload,
content_type=mime_type,
)
bucket_name = config.storage.bucket
try:
stored_path = await self._attachment_storage.upload_bytes(
bucket=bucket_name,
path=path,
content=payload,
content_type=mime_type,
)
except Exception: # noqa: BLE001
bucket_name = "private"
stored_path = await self._attachment_storage.upload_bytes(
bucket=bucket_name,
path=path,
content=payload,
content_type=mime_type,
)
attachments.append(
{
"bucket": config.storage.bucket,
"bucket": bucket_name,
"path": stored_path,
"mimeType": mime_type,
}
+42 -2
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from typing import Annotated
from uuid import UUID
@@ -14,6 +15,7 @@ from core.auth.models import CurrentUser
from core.config.settings import config
from core.db import get_db
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.gateway import SupabaseAuthGateway
from v1.users.repository import SQLAlchemyUserRepository
from v1.users.service import AuthLookupAdapter, UserService
@@ -51,7 +53,41 @@ def get_jwt_verifier() -> JwtVerifier:
return _jwt_verifier
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
async def _verify_user_with_supabase(token: str) -> CurrentUser | None:
try:
client = supabase_service.get_client()
except Exception as exc: # noqa: BLE001
logger.warning("Supabase fallback unavailable", reason=str(exc))
return None
try:
response = await asyncio.to_thread(client.auth.get_user, token)
except Exception as exc: # noqa: BLE001
logger.warning("Supabase token fallback validation failed", reason=str(exc))
return None
user = getattr(response, "user", None)
if user is None:
return None
user_id = getattr(user, "id", None)
if not isinstance(user_id, str) or not user_id:
return None
try:
parsed_id = UUID(user_id)
except ValueError:
return None
email = getattr(user, "email", None)
role = getattr(user, "role", None)
return CurrentUser(
id=parsed_id,
email=email if isinstance(email, str) else None,
role=role if isinstance(role, str) else None,
)
async def get_current_user(
authorization: str | None = Header(default=None),
) -> CurrentUser:
if not authorization:
logger.warning("JWT validation failed: missing authorization header")
raise HTTPException(status_code=401, detail="Unauthorized")
@@ -71,7 +107,11 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren
error_type=type(exc).__name__,
reason=str(exc),
)
raise HTTPException(status_code=401, detail="Unauthorized") from exc
fallback_user = await _verify_user_with_supabase(token)
if fallback_user is None:
raise HTTPException(status_code=401, detail="Unauthorized") from exc
logger.info("JWT fallback validation succeeded", user_id=str(fallback_user.id))
return fallback_user
subject = payload.get("sub")
if not isinstance(subject, str) or not subject: