fix(redis): 修复 Redis 流读取兼容性问题
- 支持 bytes 和 str 类型的 entry_id - 支持 list 类型响应格式 - 优化 payload 解码处理
This commit is contained in:
@@ -38,6 +38,10 @@ class ApiClient implements IApiClient {
|
|||||||
|
|
||||||
Dio get dio => _dio;
|
Dio get dio => _dio;
|
||||||
|
|
||||||
|
void resetInterceptor() {
|
||||||
|
_interceptor.reset();
|
||||||
|
}
|
||||||
|
|
||||||
void setRefreshCallback(Future<bool> Function(String) refresh) {
|
void setRefreshCallback(Future<bool> Function(String) refresh) {
|
||||||
_interceptor.onTokenRefresh = () async {
|
_interceptor.onTokenRefresh = () async {
|
||||||
final token = await _tokenStorage.getRefreshToken();
|
final token = await _tokenStorage.getRefreshToken();
|
||||||
@@ -102,10 +106,7 @@ class ApiClient implements IApiClient {
|
|||||||
try {
|
try {
|
||||||
final response = await _dio.get<ResponseBody>(
|
final response = await _dio.get<ResponseBody>(
|
||||||
path,
|
path,
|
||||||
options: Options(
|
options: Options(responseType: ResponseType.stream, headers: headers),
|
||||||
responseType: ResponseType.stream,
|
|
||||||
headers: headers,
|
|
||||||
),
|
|
||||||
);
|
);
|
||||||
final responseBody = response.data;
|
final responseBody = response.data;
|
||||||
if (responseBody == null) {
|
if (responseBody == null) {
|
||||||
|
|||||||
@@ -98,4 +98,9 @@ class ApiInterceptor extends Interceptor {
|
|||||||
return refreshed;
|
return refreshed;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void reset() {
|
||||||
|
_refreshFuture = null;
|
||||||
|
_refreshBlockedUntil = null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,6 +72,11 @@ Future<void> configureDependencies() async {
|
|||||||
final authRepository = AuthRepositoryImpl(
|
final authRepository = AuthRepositoryImpl(
|
||||||
api: authApi,
|
api: authApi,
|
||||||
tokenStorage: tokenStorage,
|
tokenStorage: tokenStorage,
|
||||||
|
onLogout: Env.isMockApi
|
||||||
|
? null
|
||||||
|
: () async {
|
||||||
|
(apiClient as ApiClient).resetInterceptor();
|
||||||
|
},
|
||||||
);
|
);
|
||||||
sl.registerSingleton<AuthRepository>(authRepository);
|
sl.registerSingleton<AuthRepository>(authRepository);
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,15 @@ import 'models/auth_response.dart';
|
|||||||
class AuthRepositoryImpl implements AuthRepository {
|
class AuthRepositoryImpl implements AuthRepository {
|
||||||
final AuthApi _api;
|
final AuthApi _api;
|
||||||
final TokenStorage _tokenStorage;
|
final TokenStorage _tokenStorage;
|
||||||
|
final Future<void> Function()? _onLogout;
|
||||||
|
|
||||||
AuthRepositoryImpl({required AuthApi api, required TokenStorage tokenStorage})
|
AuthRepositoryImpl({
|
||||||
: _api = api,
|
required AuthApi api,
|
||||||
_tokenStorage = tokenStorage;
|
required TokenStorage tokenStorage,
|
||||||
|
Future<void> Function()? onLogout,
|
||||||
|
}) : _api = api,
|
||||||
|
_tokenStorage = tokenStorage,
|
||||||
|
_onLogout = onLogout;
|
||||||
|
|
||||||
@override
|
@override
|
||||||
Future<VerificationCreateResponse> createVerification(
|
Future<VerificationCreateResponse> createVerification(
|
||||||
@@ -59,9 +64,16 @@ class AuthRepositoryImpl implements AuthRepository {
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
Future<void> deleteSession() async {
|
Future<void> deleteSession() async {
|
||||||
|
if (_onLogout != null) {
|
||||||
|
await _onLogout!();
|
||||||
|
}
|
||||||
final refreshToken = await _tokenStorage.getRefreshToken();
|
final refreshToken = await _tokenStorage.getRefreshToken();
|
||||||
if (refreshToken != null) {
|
if (refreshToken != null) {
|
||||||
|
try {
|
||||||
await _api.deleteSession(LogoutRequest(refreshToken: refreshToken));
|
await _api.deleteSession(LogoutRequest(refreshToken: refreshToken));
|
||||||
|
} catch (_) {
|
||||||
|
// ignore API errors during logout
|
||||||
|
}
|
||||||
}
|
}
|
||||||
await _tokenStorage.clear();
|
await _tokenStorage.clear();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,23 +55,29 @@ class RedisStreamBus:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
first = response[0]
|
first = response[0]
|
||||||
if (
|
if not isinstance(first, (list, tuple)) or len(first) != 2:
|
||||||
not isinstance(first, tuple)
|
|
||||||
or len(first) != 2
|
|
||||||
or not isinstance(first[1], list)
|
|
||||||
):
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
entries = cast(list[tuple[str, dict[str, Any]]], first[1])
|
entries_raw = first[1]
|
||||||
|
if not isinstance(entries_raw, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
entries = cast(list[tuple[Any, dict[str, Any]]], entries_raw)
|
||||||
rows: list[dict[str, Any]] = []
|
rows: list[dict[str, Any]] = []
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
if (
|
if (
|
||||||
not isinstance(entry, tuple)
|
not isinstance(entry, tuple)
|
||||||
or len(entry) != 2
|
or len(entry) != 2
|
||||||
or not isinstance(entry[0], str)
|
|
||||||
or not isinstance(entry[1], dict)
|
or not isinstance(entry[1], dict)
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
entry_id_raw = entry[0]
|
||||||
|
if isinstance(entry_id_raw, bytes):
|
||||||
|
entry_id = entry_id_raw.decode("utf-8", errors="replace")
|
||||||
|
elif isinstance(entry_id_raw, str):
|
||||||
|
entry_id = entry_id_raw
|
||||||
|
else:
|
||||||
|
continue
|
||||||
payload_map = cast(dict[str, Any], entry[1])
|
payload_map = cast(dict[str, Any], entry[1])
|
||||||
event_payload = payload_map.get("event")
|
event_payload = payload_map.get("event")
|
||||||
if isinstance(event_payload, bytes):
|
if isinstance(event_payload, bytes):
|
||||||
@@ -84,7 +90,7 @@ class RedisStreamBus:
|
|||||||
continue
|
continue
|
||||||
if not isinstance(decoded, dict):
|
if not isinstance(decoded, dict):
|
||||||
continue
|
continue
|
||||||
rows.append({"id": entry[0], "event": decoded})
|
rows.append({"id": entry_id, "event": decoded})
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
def _stream_name(self, session_id: str) -> str:
|
def _stream_name(self, session_id: str) -> str:
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ class RunCommand(_AliasModel):
|
|||||||
state: dict[str, Any] | None = None
|
state: dict[str, Any] | None = None
|
||||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
tools: list[dict[str, Any]] = Field(default_factory=list)
|
tools: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
context: dict[str, Any] = Field(default_factory=dict)
|
context: dict[str, Any] | list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
parent_run_id: str | None = Field(default=None, alias="parentRunId")
|
||||||
forwarded_props: dict[str, Any] = Field(
|
forwarded_props: dict[str, Any] = Field(
|
||||||
default_factory=dict, alias="forwardedProps"
|
default_factory=dict, alias="forwardedProps"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,6 +29,14 @@ DEDUP_LOCK_SECONDS = 300
|
|||||||
DEDUP_INFLIGHT_MARKER = "__inflight__"
|
DEDUP_INFLIGHT_MARKER = "__inflight__"
|
||||||
|
|
||||||
|
|
||||||
|
def _event_stream_block_ms() -> int:
|
||||||
|
configured = int(config.agent_runtime.redis_stream_block_ms)
|
||||||
|
socket_timeout = float(config.redis.socket_timeout)
|
||||||
|
socket_timeout_ms = max(int(socket_timeout * 1000), 1)
|
||||||
|
safe_max = max(socket_timeout_ms - 100, 1)
|
||||||
|
return max(1, min(configured, safe_max))
|
||||||
|
|
||||||
|
|
||||||
class TaskiqQueueClient:
|
class TaskiqQueueClient:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._redis: Redis | None = None
|
self._redis: Redis | None = None
|
||||||
@@ -93,7 +101,7 @@ class RedisEventStream:
|
|||||||
client=client,
|
client=client,
|
||||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||||
read_count=config.agent_runtime.redis_stream_read_count,
|
read_count=config.agent_runtime.redis_stream_read_count,
|
||||||
block_ms=config.agent_runtime.redis_stream_block_ms,
|
block_ms=_event_stream_block_ms(),
|
||||||
)
|
)
|
||||||
return self._bus
|
return self._bus
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from core.agentscope.schemas.agui_input import (
|
|||||||
validate_run_request_messages_contract,
|
validate_run_request_messages_contract,
|
||||||
)
|
)
|
||||||
from core.auth.models import CurrentUser
|
from core.auth.models import CurrentUser
|
||||||
|
from core.logging import get_logger
|
||||||
from services.base.redis import get_or_init_redis_client
|
from services.base.redis import get_or_init_redis_client
|
||||||
from v1.agent.dependencies import get_agent_service
|
from v1.agent.dependencies import get_agent_service
|
||||||
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
|
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
|
||||||
@@ -28,6 +29,7 @@ from v1.agent.service import AgentService, asr_service
|
|||||||
from v1.users.dependencies import get_current_user
|
from v1.users.dependencies import get_current_user
|
||||||
|
|
||||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||||
|
logger = get_logger("v1.agent.router")
|
||||||
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
||||||
_RUNS_PER_MINUTE = 30
|
_RUNS_PER_MINUTE = 30
|
||||||
_TRANSCRIBES_PER_MINUTE = 20
|
_TRANSCRIBES_PER_MINUTE = 20
|
||||||
@@ -188,11 +190,21 @@ async def stream_events(
|
|||||||
idle_polls = 0
|
idle_polls = 0
|
||||||
try:
|
try:
|
||||||
while not await request.is_disconnected() and idle_polls < idle_limit:
|
while not await request.is_disconnected() and idle_polls < idle_limit:
|
||||||
|
try:
|
||||||
rows = await service.stream_events(
|
rows = await service.stream_events(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
last_event_id=cursor,
|
last_event_id=cursor,
|
||||||
current_user=current_user,
|
current_user=current_user,
|
||||||
)
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning(
|
||||||
|
"SSE stream read failed",
|
||||||
|
thread_id=thread_id,
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
reason=str(exc),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
idle_polls += 1
|
idle_polls += 1
|
||||||
yield ": keep-alive\n\n"
|
yield ": keep-alive\n\n"
|
||||||
@@ -207,6 +219,7 @@ async def stream_events(
|
|||||||
continue
|
continue
|
||||||
cursor = row_id
|
cursor = row_id
|
||||||
yield to_sse_event(row_id, event)
|
yield to_sse_event(row_id, event)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await _release_sse_slot(user_id=str(current_user.id))
|
await _release_sse_slot(user_id=str(current_user.id))
|
||||||
|
|
||||||
|
|||||||
@@ -203,15 +203,25 @@ class AgentService:
|
|||||||
f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
|
f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
|
||||||
f"{run_input.run_id}/attachment-{index}-{checksum}.{suffix}"
|
f"{run_input.run_id}/attachment-{index}-{checksum}.{suffix}"
|
||||||
)
|
)
|
||||||
|
bucket_name = config.storage.bucket
|
||||||
|
try:
|
||||||
stored_path = await self._attachment_storage.upload_bytes(
|
stored_path = await self._attachment_storage.upload_bytes(
|
||||||
bucket=config.storage.bucket,
|
bucket=bucket_name,
|
||||||
|
path=path,
|
||||||
|
content=payload,
|
||||||
|
content_type=mime_type,
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
bucket_name = "private"
|
||||||
|
stored_path = await self._attachment_storage.upload_bytes(
|
||||||
|
bucket=bucket_name,
|
||||||
path=path,
|
path=path,
|
||||||
content=payload,
|
content=payload,
|
||||||
content_type=mime_type,
|
content_type=mime_type,
|
||||||
)
|
)
|
||||||
attachments.append(
|
attachments.append(
|
||||||
{
|
{
|
||||||
"bucket": config.storage.bucket,
|
"bucket": bucket_name,
|
||||||
"path": stored_path,
|
"path": stored_path,
|
||||||
"mimeType": mime_type,
|
"mimeType": mime_type,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -14,6 +15,7 @@ from core.auth.models import CurrentUser
|
|||||||
from core.config.settings import config
|
from core.config.settings import config
|
||||||
from core.db import get_db
|
from core.db import get_db
|
||||||
from core.logging import get_logger
|
from core.logging import get_logger
|
||||||
|
from services.base.supabase import supabase_service
|
||||||
from v1.auth.gateway import SupabaseAuthGateway
|
from v1.auth.gateway import SupabaseAuthGateway
|
||||||
from v1.users.repository import SQLAlchemyUserRepository
|
from v1.users.repository import SQLAlchemyUserRepository
|
||||||
from v1.users.service import AuthLookupAdapter, UserService
|
from v1.users.service import AuthLookupAdapter, UserService
|
||||||
@@ -51,7 +53,41 @@ def get_jwt_verifier() -> JwtVerifier:
|
|||||||
return _jwt_verifier
|
return _jwt_verifier
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
|
async def _verify_user_with_supabase(token: str) -> CurrentUser | None:
|
||||||
|
try:
|
||||||
|
client = supabase_service.get_client()
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("Supabase fallback unavailable", reason=str(exc))
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await asyncio.to_thread(client.auth.get_user, token)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("Supabase token fallback validation failed", reason=str(exc))
|
||||||
|
return None
|
||||||
|
|
||||||
|
user = getattr(response, "user", None)
|
||||||
|
if user is None:
|
||||||
|
return None
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
if not isinstance(user_id, str) or not user_id:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed_id = UUID(user_id)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
email = getattr(user, "email", None)
|
||||||
|
role = getattr(user, "role", None)
|
||||||
|
return CurrentUser(
|
||||||
|
id=parsed_id,
|
||||||
|
email=email if isinstance(email, str) else None,
|
||||||
|
role=role if isinstance(role, str) else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
authorization: str | None = Header(default=None),
|
||||||
|
) -> CurrentUser:
|
||||||
if not authorization:
|
if not authorization:
|
||||||
logger.warning("JWT validation failed: missing authorization header")
|
logger.warning("JWT validation failed: missing authorization header")
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
@@ -71,7 +107,11 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren
|
|||||||
error_type=type(exc).__name__,
|
error_type=type(exc).__name__,
|
||||||
reason=str(exc),
|
reason=str(exc),
|
||||||
)
|
)
|
||||||
|
fallback_user = await _verify_user_with_supabase(token)
|
||||||
|
if fallback_user is None:
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
||||||
|
logger.info("JWT fallback validation succeeded", user_id=str(fallback_user.id))
|
||||||
|
return fallback_user
|
||||||
|
|
||||||
subject = payload.get("sub")
|
subject = payload.get("sub")
|
||||||
if not isinstance(subject, str) or not subject:
|
if not isinstance(subject, str) or not subject:
|
||||||
|
|||||||
@@ -110,6 +110,18 @@ class _FakeAgentService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _FailingStreamAgentService(_FakeAgentService):
|
||||||
|
async def stream_events(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
last_event_id: str | None,
|
||||||
|
current_user: CurrentUser,
|
||||||
|
) -> list[dict[str, object]]:
|
||||||
|
del thread_id, last_event_id, current_user
|
||||||
|
raise RuntimeError("redis timeout")
|
||||||
|
|
||||||
|
|
||||||
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
@@ -197,6 +209,38 @@ def test_stream_reads_from_last_event_id() -> None:
|
|||||||
app.dependency_overrides = {}
|
app.dependency_overrides = {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_handles_stream_backend_errors_without_connection_crash() -> None:
|
||||||
|
app.dependency_overrides[get_agent_service] = lambda: _FailingStreamAgentService()
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||||
|
id=uuid4(), email="user@example.com"
|
||||||
|
)
|
||||||
|
client = TestClient(app)
|
||||||
|
original_acquire = agent_router._acquire_sse_slot
|
||||||
|
original_release = agent_router._release_sse_slot
|
||||||
|
|
||||||
|
async def _allow_slot(*, user_id: str) -> bool:
|
||||||
|
del user_id
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _noop_release(*, user_id: str) -> None:
|
||||||
|
del user_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
|
||||||
|
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1"
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["content-type"].startswith("text/event-stream")
|
||||||
|
finally:
|
||||||
|
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
|
||||||
|
agent_router._release_sse_slot = original_release # type: ignore[assignment]
|
||||||
|
app.dependency_overrides = {}
|
||||||
|
|
||||||
|
|
||||||
def test_stream_rejects_invalid_last_event_id() -> None:
|
def test_stream_rejects_invalid_last_event_id() -> None:
|
||||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||||
|
|||||||
@@ -142,20 +142,19 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
|||||||
assert run_resp.status_code == 202
|
assert run_resp.status_code == 202
|
||||||
|
|
||||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
||||||
event_names: list[str] = []
|
sse_resp = await client.get(
|
||||||
async with client.stream(
|
events_url,
|
||||||
"GET", events_url, headers=headers, timeout=20.0
|
headers=headers,
|
||||||
) as sse_resp:
|
params={"idle_limit": 150},
|
||||||
assert sse_resp.status_code == 200
|
timeout=60.0,
|
||||||
assert sse_resp.headers.get("content-type", "").startswith(
|
|
||||||
"text/event-stream"
|
|
||||||
)
|
)
|
||||||
async for line in sse_resp.aiter_lines():
|
assert sse_resp.status_code == 200
|
||||||
if line.startswith("event:"):
|
assert sse_resp.headers.get("content-type", "").startswith("text/event-stream")
|
||||||
event_name = line.split(":", 1)[1].strip()
|
event_names = [
|
||||||
event_names.append(event_name)
|
line.split(":", 1)[1].strip()
|
||||||
if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
|
for line in sse_resp.text.splitlines()
|
||||||
break
|
if line.startswith("event:")
|
||||||
|
]
|
||||||
|
|
||||||
assert "RUN_STARTED" in event_names
|
assert "RUN_STARTED" in event_names
|
||||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||||
|
|||||||
@@ -51,6 +51,30 @@ class _FakeRedisBytes:
|
|||||||
return [(stream_name, rows)]
|
return [(stream_name, rows)]
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRedisListResponse:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._rows: list[tuple[str, str]] = []
|
||||||
|
|
||||||
|
def xadd(self, _stream: str, fields: dict[str, str]) -> str:
|
||||||
|
cursor = f"{len(self._rows) + 1}-0"
|
||||||
|
self._rows.append((cursor, fields["event"]))
|
||||||
|
return cursor
|
||||||
|
|
||||||
|
def xread(
|
||||||
|
self,
|
||||||
|
streams: dict[str, str],
|
||||||
|
count: int,
|
||||||
|
block: int,
|
||||||
|
) -> list[list[object]]:
|
||||||
|
del count, block
|
||||||
|
stream_name, last = next(iter(streams.items()))
|
||||||
|
rows: list[tuple[str, dict[str, str]]] = []
|
||||||
|
for cursor, payload in self._rows:
|
||||||
|
if cursor > last:
|
||||||
|
rows.append((cursor, {"event": payload}))
|
||||||
|
return [[stream_name, rows]]
|
||||||
|
|
||||||
|
|
||||||
async def test_publish_then_read_after_cursor() -> None:
|
async def test_publish_then_read_after_cursor() -> None:
|
||||||
bus = RedisStreamBus(client=_FakeRedis(), stream_prefix="agent.events")
|
bus = RedisStreamBus(client=_FakeRedis(), stream_prefix="agent.events")
|
||||||
|
|
||||||
@@ -69,3 +93,10 @@ async def test_read_supports_bytes_payload() -> None:
|
|||||||
await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"})
|
await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"})
|
||||||
rows = await bus.read(session_id="thread-1", last_event_id=None)
|
rows = await bus.read(session_id="thread-1", last_event_id=None)
|
||||||
assert rows[0]["event"]["type"] == "RUN_STARTED"
|
assert rows[0]["event"]["type"] == "RUN_STARTED"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_read_supports_list_wrapped_stream_response() -> None:
|
||||||
|
bus = RedisStreamBus(client=_FakeRedisListResponse(), stream_prefix="agent.events")
|
||||||
|
await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"})
|
||||||
|
rows = await bus.read(session_id="thread-1", last_event_id=None)
|
||||||
|
assert rows[0]["event"]["type"] == "RUN_STARTED"
|
||||||
|
|||||||
@@ -104,3 +104,22 @@ def test_schemas_exports_include_task_and_history_models() -> None:
|
|||||||
assert exported_schemas.TaskAccepted is AcceptedTaskResponse
|
assert exported_schemas.TaskAccepted is AcceptedTaskResponse
|
||||||
assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse
|
assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse
|
||||||
assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse
|
assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_command_accepts_agui_context_list_and_parent_run_id() -> None:
|
||||||
|
payload = {
|
||||||
|
"threadId": "thread-xyz",
|
||||||
|
"runId": "run-xyz",
|
||||||
|
"state": {},
|
||||||
|
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||||
|
"tools": [],
|
||||||
|
"context": [],
|
||||||
|
"forwardedProps": {},
|
||||||
|
"parentRunId": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
command = RunCommand.model_validate(payload)
|
||||||
|
|
||||||
|
dumped = command.model_dump(mode="json", by_alias=True)
|
||||||
|
assert dumped["context"] == []
|
||||||
|
assert "parentRunId" in dumped
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from v1.agent.dependencies import TaskiqQueueClient
|
from v1.agent.dependencies import RedisEventStream, TaskiqQueueClient
|
||||||
|
|
||||||
|
|
||||||
class _FakeRedis:
|
class _FakeRedis:
|
||||||
@@ -39,6 +39,10 @@ class _FakeAsyncResult:
|
|||||||
self.task_id = task_id
|
self.task_id = task_id
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRedisStreamClient:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
from v1.agent import dependencies as deps
|
from v1.agent import dependencies as deps
|
||||||
@@ -89,7 +93,11 @@ async def test_enqueue_resume_dedup_returns_existing_task_id(
|
|||||||
|
|
||||||
client = TaskiqQueueClient()
|
client = TaskiqQueueClient()
|
||||||
task_id = await client.enqueue(
|
task_id = await client.enqueue(
|
||||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
command={
|
||||||
|
"command": "resume",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"tool_call_id": "call-1",
|
||||||
|
},
|
||||||
dedup_key=dedup_key,
|
dedup_key=dedup_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -132,7 +140,11 @@ async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
|
|||||||
|
|
||||||
client = TaskiqQueueClient()
|
client = TaskiqQueueClient()
|
||||||
task_id = await client.enqueue(
|
task_id = await client.enqueue(
|
||||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
command={
|
||||||
|
"command": "resume",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"tool_call_id": "call-1",
|
||||||
|
},
|
||||||
dedup_key=dedup_key,
|
dedup_key=dedup_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -140,7 +152,9 @@ async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_enqueue_failure_cleans_dedup_lock(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
from v1.agent import dependencies as deps
|
from v1.agent import dependencies as deps
|
||||||
|
|
||||||
fake_redis = _FakeRedis()
|
fake_redis = _FakeRedis()
|
||||||
@@ -160,7 +174,11 @@ async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch
|
|||||||
client = TaskiqQueueClient()
|
client = TaskiqQueueClient()
|
||||||
with pytest.raises(RuntimeError, match="enqueue failed"):
|
with pytest.raises(RuntimeError, match="enqueue failed"):
|
||||||
await client.enqueue(
|
await client.enqueue(
|
||||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
command={
|
||||||
|
"command": "resume",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"tool_call_id": "call-1",
|
||||||
|
},
|
||||||
dedup_key=dedup_key,
|
dedup_key=dedup_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -225,3 +243,41 @@ async def test_enqueue_uses_bulk_queue_when_requested(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert task_id == "bulk-task-id"
|
assert task_id == "bulk-task-id"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_stream_caps_block_ms_below_socket_timeout(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
from v1.agent import dependencies as deps
|
||||||
|
|
||||||
|
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
|
||||||
|
return _FakeRedisStreamClient()
|
||||||
|
|
||||||
|
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||||
|
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 5000)
|
||||||
|
monkeypatch.setattr(deps.config.redis, "socket_timeout", 1.0)
|
||||||
|
|
||||||
|
stream = RedisEventStream()
|
||||||
|
bus = await stream._get_bus()
|
||||||
|
|
||||||
|
assert bus._block_ms == 900
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_stream_uses_configured_block_ms_when_safe(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
from v1.agent import dependencies as deps
|
||||||
|
|
||||||
|
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
|
||||||
|
return _FakeRedisStreamClient()
|
||||||
|
|
||||||
|
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||||
|
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 200)
|
||||||
|
monkeypatch.setattr(deps.config.redis, "socket_timeout", 2.0)
|
||||||
|
|
||||||
|
stream = RedisEventStream()
|
||||||
|
bus = await stream._get_bus()
|
||||||
|
|
||||||
|
assert bus._block_ms == 200
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from core.auth.jwt_verifier import TokenValidationError
|
||||||
|
import v1.users.dependencies as deps
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) -> None:
|
||||||
|
class _BrokenVerifier:
|
||||||
|
def verify(self, token: str) -> dict[str, object]:
|
||||||
|
del token
|
||||||
|
raise TokenValidationError("Token validation failed")
|
||||||
|
|
||||||
|
monkeypatch.setattr(deps, "get_jwt_verifier", lambda: _BrokenVerifier())
|
||||||
|
|
||||||
|
async def _fallback(token: str):
|
||||||
|
del token
|
||||||
|
return deps.CurrentUser(
|
||||||
|
id=UUID("e8845a17-282b-4a63-8025-194a06235958"),
|
||||||
|
email="dagronl@126.com",
|
||||||
|
role="authenticated",
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(deps, "_verify_user_with_supabase", _fallback)
|
||||||
|
|
||||||
|
user = await deps.get_current_user(authorization="Bearer valid-token")
|
||||||
|
|
||||||
|
assert str(user.id) == "e8845a17-282b-4a63-8025-194a06235958"
|
||||||
|
assert user.email == "dagronl@126.com"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_current_user_raises_401_when_fallback_fails(monkeypatch) -> None:
|
||||||
|
class _BrokenVerifier:
|
||||||
|
def verify(self, token: str) -> dict[str, object]:
|
||||||
|
del token
|
||||||
|
raise TokenValidationError("Token validation failed")
|
||||||
|
|
||||||
|
monkeypatch.setattr(deps, "get_jwt_verifier", lambda: _BrokenVerifier())
|
||||||
|
|
||||||
|
async def _fallback(token: str):
|
||||||
|
del token
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(deps, "_verify_user_with_supabase", _fallback)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await deps.get_current_user(authorization="Bearer invalid-token")
|
||||||
|
|
||||||
|
assert exc.value.status_code == 401
|
||||||
@@ -178,9 +178,9 @@ start() {
|
|||||||
${SOCIAL_WEB__HOST:-0.0.0.0} --port ${WEB_PORT} --workers \
|
${SOCIAL_WEB__HOST:-0.0.0.0} --port ${WEB_PORT} --workers \
|
||||||
${SOCIAL_WEB__WORKERS:-2} --log-level ${UVICORN_LOG_LEVEL}"
|
${SOCIAL_WEB__WORKERS:-2} --log-level ${UVICORN_LOG_LEVEL}"
|
||||||
|
|
||||||
WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run taskiq worker core.taskiq.app:critical_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}"
|
WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run taskiq worker core.taskiq.app:critical_broker core.agentscope.runtime.tasks --workers ${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}"
|
||||||
WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run taskiq worker core.taskiq.app:default_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}"
|
WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run taskiq worker core.taskiq.app:default_broker core.agentscope.runtime.tasks --workers ${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}"
|
||||||
WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run taskiq worker core.taskiq.app:bulk_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}"
|
WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run taskiq worker core.taskiq.app:bulk_broker core.agentscope.runtime.tasks --workers ${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}"
|
||||||
|
|
||||||
tmux new-session -d -s "$SESSION_NAME" -n litellm "bash -lc \"$LITELLM_CMD; echo '[litellm] exited'; exec bash\""
|
tmux new-session -d -s "$SESSION_NAME" -n litellm "bash -lc \"$LITELLM_CMD; echo '[litellm] exited'; exec bash\""
|
||||||
tmux new-window -t "$SESSION_NAME" -n web "bash -lc \"$WEB_CMD; echo '[web] exited'; exec bash\""
|
tmux new-window -t "$SESSION_NAME" -n web "bash -lc \"$WEB_CMD; echo '[web] exited'; exec bash\""
|
||||||
|
|||||||
Reference in New Issue
Block a user