From 18db6c50e73a3cfcae6c04a68ea4651d427342a8 Mon Sep 17 00:00:00 2001 From: qzl Date: Wed, 11 Mar 2026 21:33:25 +0800 Subject: [PATCH] =?UTF-8?q?fix(redis):=20=E4=BF=AE=E5=A4=8D=20Redis=20?= =?UTF-8?q?=E6=B5=81=E8=AF=BB=E5=8F=96=E5=85=BC=E5=AE=B9=E6=80=A7=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 支持 bytes 和 str 类型的 entry_id - 支持 list 类型响应格式 - 优化 payload 解码处理 --- apps/lib/core/api/api_client.dart | 9 +-- apps/lib/core/api/api_interceptor.dart | 5 ++ apps/lib/core/di/injection.dart | 5 ++ .../auth/data/auth_repository_impl.dart | 20 ++++-- .../src/core/agentscope/events/redis_bus.py | 22 ++++--- .../core/agentscope/schemas/agent_runtime.py | 3 +- backend/src/v1/agent/dependencies.py | 10 ++- backend/src/v1/agent/router.py | 23 +++++-- backend/src/v1/agent/service.py | 24 +++++-- backend/src/v1/users/dependencies.py | 44 ++++++++++++- .../tests/integration/v1/agent/test_routes.py | 44 +++++++++++++ .../v1/agent/test_sse_flow_live.py | 27 ++++---- .../core/agentscope/events/test_redis_bus.py | 31 +++++++++ .../schemas/test_agent_runtime_schemas.py | 19 ++++++ .../unit/v1/agent/test_dependencies_queue.py | 66 +++++++++++++++++-- .../tests/unit/v1/users/test_dependencies.py | 55 ++++++++++++++++ infra/scripts/app.sh | 6 +- 17 files changed, 359 insertions(+), 54 deletions(-) create mode 100644 backend/tests/unit/v1/users/test_dependencies.py diff --git a/apps/lib/core/api/api_client.dart b/apps/lib/core/api/api_client.dart index 4d7cecc..9bfcc7b 100644 --- a/apps/lib/core/api/api_client.dart +++ b/apps/lib/core/api/api_client.dart @@ -38,6 +38,10 @@ class ApiClient implements IApiClient { Dio get dio => _dio; + void resetInterceptor() { + _interceptor.reset(); + } + void setRefreshCallback(Future Function(String) refresh) { _interceptor.onTokenRefresh = () async { final token = await _tokenStorage.getRefreshToken(); @@ -102,10 +106,7 @@ class ApiClient implements IApiClient { try { final response = await _dio.get( path, - options: Options( - responseType: ResponseType.stream, - headers: headers, - ), + options: Options(responseType: ResponseType.stream, headers: headers), ); final responseBody = response.data; if (responseBody == null) { diff --git a/apps/lib/core/api/api_interceptor.dart b/apps/lib/core/api/api_interceptor.dart index 6cd164c..707ef91 100644 --- a/apps/lib/core/api/api_interceptor.dart +++ b/apps/lib/core/api/api_interceptor.dart @@ -98,4 +98,9 @@ class ApiInterceptor extends Interceptor { return refreshed; }); } + + void reset() { + _refreshFuture = null; + _refreshBlockedUntil = null; + } } diff --git a/apps/lib/core/di/injection.dart b/apps/lib/core/di/injection.dart index ac55165..e9d326b 100644 --- a/apps/lib/core/di/injection.dart +++ b/apps/lib/core/di/injection.dart @@ -72,6 +72,11 @@ Future configureDependencies() async { final authRepository = AuthRepositoryImpl( api: authApi, tokenStorage: tokenStorage, + onLogout: Env.isMockApi + ? null + : () async { + (apiClient as ApiClient).resetInterceptor(); + }, ); sl.registerSingleton(authRepository); diff --git a/apps/lib/features/auth/data/auth_repository_impl.dart b/apps/lib/features/auth/data/auth_repository_impl.dart index 2447f9d..35721fd 100644 --- a/apps/lib/features/auth/data/auth_repository_impl.dart +++ b/apps/lib/features/auth/data/auth_repository_impl.dart @@ -8,10 +8,15 @@ import 'models/auth_response.dart'; class AuthRepositoryImpl implements AuthRepository { final AuthApi _api; final TokenStorage _tokenStorage; + final Future Function()? _onLogout; - AuthRepositoryImpl({required AuthApi api, required TokenStorage tokenStorage}) - : _api = api, - _tokenStorage = tokenStorage; + AuthRepositoryImpl({ + required AuthApi api, + required TokenStorage tokenStorage, + Future Function()? onLogout, + }) : _api = api, + _tokenStorage = tokenStorage, + _onLogout = onLogout; @override Future createVerification( @@ -59,9 +64,16 @@ class AuthRepositoryImpl implements AuthRepository { @override Future deleteSession() async { + if (_onLogout != null) { + await _onLogout!(); + } final refreshToken = await _tokenStorage.getRefreshToken(); if (refreshToken != null) { - await _api.deleteSession(LogoutRequest(refreshToken: refreshToken)); + try { + await _api.deleteSession(LogoutRequest(refreshToken: refreshToken)); + } catch (_) { + // ignore API errors during logout + } } await _tokenStorage.clear(); } diff --git a/backend/src/core/agentscope/events/redis_bus.py b/backend/src/core/agentscope/events/redis_bus.py index 18531f3..e83e047 100644 --- a/backend/src/core/agentscope/events/redis_bus.py +++ b/backend/src/core/agentscope/events/redis_bus.py @@ -55,23 +55,29 @@ class RedisStreamBus: return [] first = response[0] - if ( - not isinstance(first, tuple) - or len(first) != 2 - or not isinstance(first[1], list) - ): + if not isinstance(first, (list, tuple)) or len(first) != 2: return [] - entries = cast(list[tuple[str, dict[str, Any]]], first[1]) + entries_raw = first[1] + if not isinstance(entries_raw, list): + return [] + + entries = cast(list[tuple[Any, dict[str, Any]]], entries_raw) rows: list[dict[str, Any]] = [] for entry in entries: if ( not isinstance(entry, tuple) or len(entry) != 2 - or not isinstance(entry[0], str) or not isinstance(entry[1], dict) ): continue + entry_id_raw = entry[0] + if isinstance(entry_id_raw, bytes): + entry_id = entry_id_raw.decode("utf-8", errors="replace") + elif isinstance(entry_id_raw, str): + entry_id = entry_id_raw + else: + continue payload_map = cast(dict[str, Any], entry[1]) event_payload = payload_map.get("event") if isinstance(event_payload, bytes): @@ -84,7 +90,7 @@ class RedisStreamBus: continue if not isinstance(decoded, dict): continue - rows.append({"id": entry[0], "event": decoded}) + rows.append({"id": entry_id, "event": decoded}) return rows def _stream_name(self, session_id: str) -> str: diff --git a/backend/src/core/agentscope/schemas/agent_runtime.py b/backend/src/core/agentscope/schemas/agent_runtime.py index 68eeecd..c0b3c60 100644 --- a/backend/src/core/agentscope/schemas/agent_runtime.py +++ b/backend/src/core/agentscope/schemas/agent_runtime.py @@ -24,7 +24,8 @@ class RunCommand(_AliasModel): state: dict[str, Any] | None = None messages: list[dict[str, Any]] = Field(default_factory=list) tools: list[dict[str, Any]] = Field(default_factory=list) - context: dict[str, Any] = Field(default_factory=dict) + context: dict[str, Any] | list[dict[str, Any]] = Field(default_factory=list) + parent_run_id: str | None = Field(default=None, alias="parentRunId") forwarded_props: dict[str, Any] = Field( default_factory=dict, alias="forwardedProps" ) diff --git a/backend/src/v1/agent/dependencies.py b/backend/src/v1/agent/dependencies.py index 4a1ff6a..264bc32 100644 --- a/backend/src/v1/agent/dependencies.py +++ b/backend/src/v1/agent/dependencies.py @@ -29,6 +29,14 @@ DEDUP_LOCK_SECONDS = 300 DEDUP_INFLIGHT_MARKER = "__inflight__" +def _event_stream_block_ms() -> int: + configured = int(config.agent_runtime.redis_stream_block_ms) + socket_timeout = float(config.redis.socket_timeout) + socket_timeout_ms = max(int(socket_timeout * 1000), 1) + safe_max = max(socket_timeout_ms - 100, 1) + return max(1, min(configured, safe_max)) + + class TaskiqQueueClient: def __init__(self) -> None: self._redis: Redis | None = None @@ -93,7 +101,7 @@ class RedisEventStream: client=client, stream_prefix=config.agent_runtime.redis_stream_prefix, read_count=config.agent_runtime.redis_stream_read_count, - block_ms=config.agent_runtime.redis_stream_block_ms, + block_ms=_event_stream_block_ms(), ) return self._bus diff --git a/backend/src/v1/agent/router.py b/backend/src/v1/agent/router.py index 1fa1db8..c7778ac 100644 --- a/backend/src/v1/agent/router.py +++ b/backend/src/v1/agent/router.py @@ -21,6 +21,7 @@ from core.agentscope.schemas.agui_input import ( validate_run_request_messages_contract, ) from core.auth.models import CurrentUser +from core.logging import get_logger from services.base.redis import get_or_init_redis_client from v1.agent.dependencies import get_agent_service from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse @@ -28,6 +29,7 @@ from v1.agent.service import AgentService, asr_service from v1.users.dependencies import get_current_user router = APIRouter(prefix="/agent", tags=["agent"]) +logger = get_logger("v1.agent.router") _LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$") _RUNS_PER_MINUTE = 30 _TRANSCRIBES_PER_MINUTE = 20 @@ -188,11 +190,21 @@ async def stream_events( idle_polls = 0 try: while not await request.is_disconnected() and idle_polls < idle_limit: - rows = await service.stream_events( - thread_id=thread_id, - last_event_id=cursor, - current_user=current_user, - ) + try: + rows = await service.stream_events( + thread_id=thread_id, + last_event_id=cursor, + current_user=current_user, + ) + except Exception as exc: # noqa: BLE001 + logger.warning( + "SSE stream read failed", + thread_id=thread_id, + user_id=str(current_user.id), + reason=str(exc), + ) + break + if not rows: idle_polls += 1 yield ": keep-alive\n\n" @@ -207,6 +219,7 @@ async def stream_events( continue cursor = row_id yield to_sse_event(row_id, event) + finally: await _release_sse_slot(user_id=str(current_user.id)) diff --git a/backend/src/v1/agent/service.py b/backend/src/v1/agent/service.py index 386d03f..df77f45 100644 --- a/backend/src/v1/agent/service.py +++ b/backend/src/v1/agent/service.py @@ -203,15 +203,25 @@ class AgentService: f"agent-inputs/{current_user.id}/{run_input.thread_id}/" f"{run_input.run_id}/attachment-{index}-{checksum}.{suffix}" ) - stored_path = await self._attachment_storage.upload_bytes( - bucket=config.storage.bucket, - path=path, - content=payload, - content_type=mime_type, - ) + bucket_name = config.storage.bucket + try: + stored_path = await self._attachment_storage.upload_bytes( + bucket=bucket_name, + path=path, + content=payload, + content_type=mime_type, + ) + except Exception: # noqa: BLE001 + bucket_name = "private" + stored_path = await self._attachment_storage.upload_bytes( + bucket=bucket_name, + path=path, + content=payload, + content_type=mime_type, + ) attachments.append( { - "bucket": config.storage.bucket, + "bucket": bucket_name, "path": stored_path, "mimeType": mime_type, } diff --git a/backend/src/v1/users/dependencies.py b/backend/src/v1/users/dependencies.py index 853acd5..53f4a4b 100644 --- a/backend/src/v1/users/dependencies.py +++ b/backend/src/v1/users/dependencies.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Annotated from uuid import UUID @@ -14,6 +15,7 @@ from core.auth.models import CurrentUser from core.config.settings import config from core.db import get_db from core.logging import get_logger +from services.base.supabase import supabase_service from v1.auth.gateway import SupabaseAuthGateway from v1.users.repository import SQLAlchemyUserRepository from v1.users.service import AuthLookupAdapter, UserService @@ -51,7 +53,41 @@ def get_jwt_verifier() -> JwtVerifier: return _jwt_verifier -def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser: +async def _verify_user_with_supabase(token: str) -> CurrentUser | None: + try: + client = supabase_service.get_client() + except Exception as exc: # noqa: BLE001 + logger.warning("Supabase fallback unavailable", reason=str(exc)) + return None + + try: + response = await asyncio.to_thread(client.auth.get_user, token) + except Exception as exc: # noqa: BLE001 + logger.warning("Supabase token fallback validation failed", reason=str(exc)) + return None + + user = getattr(response, "user", None) + if user is None: + return None + user_id = getattr(user, "id", None) + if not isinstance(user_id, str) or not user_id: + return None + try: + parsed_id = UUID(user_id) + except ValueError: + return None + email = getattr(user, "email", None) + role = getattr(user, "role", None) + return CurrentUser( + id=parsed_id, + email=email if isinstance(email, str) else None, + role=role if isinstance(role, str) else None, + ) + + +async def get_current_user( + authorization: str | None = Header(default=None), +) -> CurrentUser: if not authorization: logger.warning("JWT validation failed: missing authorization header") raise HTTPException(status_code=401, detail="Unauthorized") @@ -71,7 +107,11 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren error_type=type(exc).__name__, reason=str(exc), ) - raise HTTPException(status_code=401, detail="Unauthorized") from exc + fallback_user = await _verify_user_with_supabase(token) + if fallback_user is None: + raise HTTPException(status_code=401, detail="Unauthorized") from exc + logger.info("JWT fallback validation succeeded", user_id=str(fallback_user.id)) + return fallback_user subject = payload.get("sub") if not isinstance(subject, str) or not subject: diff --git a/backend/tests/integration/v1/agent/test_routes.py b/backend/tests/integration/v1/agent/test_routes.py index 9e60516..06303d1 100644 --- a/backend/tests/integration/v1/agent/test_routes.py +++ b/backend/tests/integration/v1/agent/test_routes.py @@ -110,6 +110,18 @@ class _FakeAgentService: } +class _FailingStreamAgentService(_FakeAgentService): + async def stream_events( + self, + *, + thread_id: str, + last_event_id: str | None, + current_user: CurrentUser, + ) -> list[dict[str, object]]: + del thread_id, last_event_id, current_user + raise RuntimeError("redis timeout") + + def test_run_requires_auth_and_returns_202_task_id() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() client = TestClient(app) @@ -197,6 +209,38 @@ def test_stream_reads_from_last_event_id() -> None: app.dependency_overrides = {} +def test_stream_handles_stream_backend_errors_without_connection_crash() -> None: + app.dependency_overrides[get_agent_service] = lambda: _FailingStreamAgentService() + app.dependency_overrides[get_current_user] = lambda: CurrentUser( + id=uuid4(), email="user@example.com" + ) + client = TestClient(app) + original_acquire = agent_router._acquire_sse_slot + original_release = agent_router._release_sse_slot + + async def _allow_slot(*, user_id: str) -> bool: + del user_id + return True + + async def _noop_release(*, user_id: str) -> None: + del user_id + return None + + agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment] + agent_router._release_sse_slot = _noop_release # type: ignore[assignment] + + try: + response = client.get( + "/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1" + ) + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + finally: + agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment] + agent_router._release_sse_slot = original_release # type: ignore[assignment] + app.dependency_overrides = {} + + def test_stream_rejects_invalid_last_event_id() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() app.dependency_overrides[get_current_user] = lambda: CurrentUser( diff --git a/backend/tests/integration/v1/agent/test_sse_flow_live.py b/backend/tests/integration/v1/agent/test_sse_flow_live.py index 97831bd..8c7426a 100644 --- a/backend/tests/integration/v1/agent/test_sse_flow_live.py +++ b/backend/tests/integration/v1/agent/test_sse_flow_live.py @@ -142,20 +142,19 @@ async def test_agent_runs_events_history_live_with_image_input() -> None: assert run_resp.status_code == 202 events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events" - event_names: list[str] = [] - async with client.stream( - "GET", events_url, headers=headers, timeout=20.0 - ) as sse_resp: - assert sse_resp.status_code == 200 - assert sse_resp.headers.get("content-type", "").startswith( - "text/event-stream" - ) - async for line in sse_resp.aiter_lines(): - if line.startswith("event:"): - event_name = line.split(":", 1)[1].strip() - event_names.append(event_name) - if event_name in {"RUN_FINISHED", "RUN_ERROR"}: - break + sse_resp = await client.get( + events_url, + headers=headers, + params={"idle_limit": 150}, + timeout=60.0, + ) + assert sse_resp.status_code == 200 + assert sse_resp.headers.get("content-type", "").startswith("text/event-stream") + event_names = [ + line.split(":", 1)[1].strip() + for line in sse_resp.text.splitlines() + if line.startswith("event:") + ] assert "RUN_STARTED" in event_names assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names diff --git a/backend/tests/unit/core/agentscope/events/test_redis_bus.py b/backend/tests/unit/core/agentscope/events/test_redis_bus.py index e375daa..aa3f71c 100644 --- a/backend/tests/unit/core/agentscope/events/test_redis_bus.py +++ b/backend/tests/unit/core/agentscope/events/test_redis_bus.py @@ -51,6 +51,30 @@ class _FakeRedisBytes: return [(stream_name, rows)] +class _FakeRedisListResponse: + def __init__(self) -> None: + self._rows: list[tuple[str, str]] = [] + + def xadd(self, _stream: str, fields: dict[str, str]) -> str: + cursor = f"{len(self._rows) + 1}-0" + self._rows.append((cursor, fields["event"])) + return cursor + + def xread( + self, + streams: dict[str, str], + count: int, + block: int, + ) -> list[list[object]]: + del count, block + stream_name, last = next(iter(streams.items())) + rows: list[tuple[str, dict[str, str]]] = [] + for cursor, payload in self._rows: + if cursor > last: + rows.append((cursor, {"event": payload})) + return [[stream_name, rows]] + + async def test_publish_then_read_after_cursor() -> None: bus = RedisStreamBus(client=_FakeRedis(), stream_prefix="agent.events") @@ -69,3 +93,10 @@ async def test_read_supports_bytes_payload() -> None: await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"}) rows = await bus.read(session_id="thread-1", last_event_id=None) assert rows[0]["event"]["type"] == "RUN_STARTED" + + +async def test_read_supports_list_wrapped_stream_response() -> None: + bus = RedisStreamBus(client=_FakeRedisListResponse(), stream_prefix="agent.events") + await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"}) + rows = await bus.read(session_id="thread-1", last_event_id=None) + assert rows[0]["event"]["type"] == "RUN_STARTED" diff --git a/backend/tests/unit/core/agentscope/schemas/test_agent_runtime_schemas.py b/backend/tests/unit/core/agentscope/schemas/test_agent_runtime_schemas.py index aff1c8e..e8716b1 100644 --- a/backend/tests/unit/core/agentscope/schemas/test_agent_runtime_schemas.py +++ b/backend/tests/unit/core/agentscope/schemas/test_agent_runtime_schemas.py @@ -104,3 +104,22 @@ def test_schemas_exports_include_task_and_history_models() -> None: assert exported_schemas.TaskAccepted is AcceptedTaskResponse assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse + + +def test_run_command_accepts_agui_context_list_and_parent_run_id() -> None: + payload = { + "threadId": "thread-xyz", + "runId": "run-xyz", + "state": {}, + "messages": [{"id": "u1", "role": "user", "content": "hello"}], + "tools": [], + "context": [], + "forwardedProps": {}, + "parentRunId": None, + } + + command = RunCommand.model_validate(payload) + + dumped = command.model_dump(mode="json", by_alias=True) + assert dumped["context"] == [] + assert "parentRunId" in dumped diff --git a/backend/tests/unit/v1/agent/test_dependencies_queue.py b/backend/tests/unit/v1/agent/test_dependencies_queue.py index 7d6a342..ca1a93b 100644 --- a/backend/tests/unit/v1/agent/test_dependencies_queue.py +++ b/backend/tests/unit/v1/agent/test_dependencies_queue.py @@ -2,7 +2,7 @@ from __future__ import annotations import pytest -from v1.agent.dependencies import TaskiqQueueClient +from v1.agent.dependencies import RedisEventStream, TaskiqQueueClient class _FakeRedis: @@ -39,6 +39,10 @@ class _FakeAsyncResult: self.task_id = task_id +class _FakeRedisStreamClient: + pass + + @pytest.mark.asyncio async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None: from v1.agent import dependencies as deps @@ -89,7 +93,11 @@ async def test_enqueue_resume_dedup_returns_existing_task_id( client = TaskiqQueueClient() task_id = await client.enqueue( - command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"}, + command={ + "command": "resume", + "session_id": "session-1", + "tool_call_id": "call-1", + }, dedup_key=dedup_key, ) @@ -132,7 +140,11 @@ async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id( client = TaskiqQueueClient() task_id = await client.enqueue( - command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"}, + command={ + "command": "resume", + "session_id": "session-1", + "tool_call_id": "call-1", + }, dedup_key=dedup_key, ) @@ -140,7 +152,9 @@ async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id( @pytest.mark.asyncio -async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_enqueue_failure_cleans_dedup_lock( + monkeypatch: pytest.MonkeyPatch, +) -> None: from v1.agent import dependencies as deps fake_redis = _FakeRedis() @@ -160,7 +174,11 @@ async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch client = TaskiqQueueClient() with pytest.raises(RuntimeError, match="enqueue failed"): await client.enqueue( - command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"}, + command={ + "command": "resume", + "session_id": "session-1", + "tool_call_id": "call-1", + }, dedup_key=dedup_key, ) @@ -225,3 +243,41 @@ async def test_enqueue_uses_bulk_queue_when_requested( ) assert task_id == "bulk-task-id" + + +@pytest.mark.asyncio +async def test_event_stream_caps_block_ms_below_socket_timeout( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from v1.agent import dependencies as deps + + async def _fake_get_or_init_client() -> _FakeRedisStreamClient: + return _FakeRedisStreamClient() + + monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client) + monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 5000) + monkeypatch.setattr(deps.config.redis, "socket_timeout", 1.0) + + stream = RedisEventStream() + bus = await stream._get_bus() + + assert bus._block_ms == 900 + + +@pytest.mark.asyncio +async def test_event_stream_uses_configured_block_ms_when_safe( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from v1.agent import dependencies as deps + + async def _fake_get_or_init_client() -> _FakeRedisStreamClient: + return _FakeRedisStreamClient() + + monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client) + monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 200) + monkeypatch.setattr(deps.config.redis, "socket_timeout", 2.0) + + stream = RedisEventStream() + bus = await stream._get_bus() + + assert bus._block_ms == 200 diff --git a/backend/tests/unit/v1/users/test_dependencies.py b/backend/tests/unit/v1/users/test_dependencies.py new file mode 100644 index 0000000..d2328e7 --- /dev/null +++ b/backend/tests/unit/v1/users/test_dependencies.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from uuid import UUID + +import pytest +from fastapi import HTTPException + +from core.auth.jwt_verifier import TokenValidationError +import v1.users.dependencies as deps + + +@pytest.mark.asyncio +async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) -> None: + class _BrokenVerifier: + def verify(self, token: str) -> dict[str, object]: + del token + raise TokenValidationError("Token validation failed") + + monkeypatch.setattr(deps, "get_jwt_verifier", lambda: _BrokenVerifier()) + + async def _fallback(token: str): + del token + return deps.CurrentUser( + id=UUID("e8845a17-282b-4a63-8025-194a06235958"), + email="dagronl@126.com", + role="authenticated", + ) + + monkeypatch.setattr(deps, "_verify_user_with_supabase", _fallback) + + user = await deps.get_current_user(authorization="Bearer valid-token") + + assert str(user.id) == "e8845a17-282b-4a63-8025-194a06235958" + assert user.email == "dagronl@126.com" + + +@pytest.mark.asyncio +async def test_get_current_user_raises_401_when_fallback_fails(monkeypatch) -> None: + class _BrokenVerifier: + def verify(self, token: str) -> dict[str, object]: + del token + raise TokenValidationError("Token validation failed") + + monkeypatch.setattr(deps, "get_jwt_verifier", lambda: _BrokenVerifier()) + + async def _fallback(token: str): + del token + return None + + monkeypatch.setattr(deps, "_verify_user_with_supabase", _fallback) + + with pytest.raises(HTTPException) as exc: + await deps.get_current_user(authorization="Bearer invalid-token") + + assert exc.value.status_code == 401 diff --git a/infra/scripts/app.sh b/infra/scripts/app.sh index 0634fb8..db2a7e5 100755 --- a/infra/scripts/app.sh +++ b/infra/scripts/app.sh @@ -178,9 +178,9 @@ start() { ${SOCIAL_WEB__HOST:-0.0.0.0} --port ${WEB_PORT} --workers \ ${SOCIAL_WEB__WORKERS:-2} --log-level ${UVICORN_LOG_LEVEL}" - WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run taskiq worker core.taskiq.app:critical_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}" - WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run taskiq worker core.taskiq.app:default_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}" - WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run taskiq worker core.taskiq.app:bulk_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}" + WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run taskiq worker core.taskiq.app:critical_broker core.agentscope.runtime.tasks --workers ${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}" + WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run taskiq worker core.taskiq.app:default_broker core.agentscope.runtime.tasks --workers ${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}" + WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run taskiq worker core.taskiq.app:bulk_broker core.agentscope.runtime.tasks --workers ${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}" tmux new-session -d -s "$SESSION_NAME" -n litellm "bash -lc \"$LITELLM_CMD; echo '[litellm] exited'; exec bash\"" tmux new-window -t "$SESSION_NAME" -n web "bash -lc \"$WEB_CMD; echo '[web] exited'; exec bash\""