refactor: 重构 schemas 结构,统一枚举定义
This commit is contained in:
@@ -1,283 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.dependencies import RedisEventStream, TaskiqQueueClient
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, str] = {}
|
||||
self.delete_calls: list[str] = []
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
*,
|
||||
nx: bool = False,
|
||||
ex: int | None = None,
|
||||
) -> bool:
|
||||
del ex
|
||||
if nx and key in self.store:
|
||||
return False
|
||||
self.store[key] = value
|
||||
return True
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self.store.get(key)
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
self.delete_calls.append(key)
|
||||
existed = 1 if key in self.store else 0
|
||||
self.store.pop(key, None)
|
||||
return existed
|
||||
|
||||
|
||||
class _FakeAsyncResult:
|
||||
def __init__(self, task_id: str) -> None:
|
||||
self.task_id = task_id
|
||||
|
||||
|
||||
class _FakeRedisStreamClient:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
resolved_client = {"value": False}
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
assert payload["command"] == "run"
|
||||
return _FakeAsyncResult("task-123")
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
resolved_client["value"] = True
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
|
||||
|
||||
assert resolved_client["value"] is True
|
||||
assert task_id == "task-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_resume_dedup_returns_existing_task_id(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
resolved_client = {"value": False}
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
return _FakeAsyncResult("new-task-id")
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
resolved_client["value"] = True
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": "session-1",
|
||||
"tool_call_id": "call-1",
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert resolved_client["value"] is True
|
||||
assert task_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
|
||||
attempts = {"count": 0}
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_get(key: str) -> str | None:
|
||||
attempts["count"] += 1
|
||||
if attempts["count"] > 1:
|
||||
fake_redis.store[key] = "existing-task-id"
|
||||
return fake_redis.store.get(key)
|
||||
|
||||
async def _fake_sleep(_: float) -> None:
|
||||
return None
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
raise AssertionError("should not enqueue when dedup task id appears")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(fake_redis, "get", _fake_get)
|
||||
monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": "session-1",
|
||||
"tool_call_id": "call-1",
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert task_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_failure_cleans_dedup_lock(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
raise RuntimeError("enqueue failed")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
with pytest.raises(RuntimeError, match="enqueue failed"):
|
||||
await client.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": "session-1",
|
||||
"tool_call_id": "call-1",
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert redis_key in fake_redis.delete_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_uses_critical_queue_when_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
raise AssertionError("default queue should not be selected")
|
||||
|
||||
async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
return _FakeAsyncResult("critical-task-id")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
|
||||
monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "run", "queue": "critical"},
|
||||
dedup_key=None,
|
||||
)
|
||||
|
||||
assert task_id == "critical-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_uses_bulk_queue_when_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
raise AssertionError("default queue should not be selected")
|
||||
|
||||
async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
return _FakeAsyncResult("bulk-task-id")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
|
||||
monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "run", "queue": "bulk"},
|
||||
dedup_key=None,
|
||||
)
|
||||
|
||||
assert task_id == "bulk-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_stream_caps_block_ms_below_socket_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
|
||||
return _FakeRedisStreamClient()
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 5000)
|
||||
monkeypatch.setattr(deps.config.redis, "socket_timeout", 1.0)
|
||||
|
||||
stream = RedisEventStream()
|
||||
bus = await stream._get_bus()
|
||||
|
||||
assert bus._block_ms == 900
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_stream_uses_configured_block_ms_when_safe(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
|
||||
return _FakeRedisStreamClient()
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 200)
|
||||
monkeypatch.setattr(deps.config.redis, "socket_timeout", 2.0)
|
||||
|
||||
stream = RedisEventStream()
|
||||
bus = await stream._get_bus()
|
||||
|
||||
assert bus._block_ms == 200
|
||||
@@ -12,7 +12,7 @@ import pytest
|
||||
import v1.agent.service as agent_service_module
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from schemas.messages.chat_message import AgentChatMessageMetadata
|
||||
from schemas.domain.chat_message import AgentChatMessageMetadata
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user