refactor: 重构 Agent 模块为 AgentScope,删除旧版 CrewAI/LiteLLM 实现

This commit is contained in:
qzl
2026-03-11 20:51:56 +08:00
parent 177ed616bf
commit 145e3dc615
149 changed files with 5120 additions and 11356 deletions
@@ -1,703 +0,0 @@
from __future__ import annotations
import json
import uuid
from decimal import Decimal
import pytest
from ag_ui.core import RunAgentInput
from sqlalchemy import delete, select
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.persistence.session_repository import SessionRepository
from core.agent.infrastructure.queue.tasks import run_agent_task
from core.agent.infrastructure.storage.tool_result_storage import (
create_tool_result_storage,
)
from services.base.supabase import supabase_service
from core.db import AsyncSessionLocal, engine
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.profile import Profile
from models.system_agents import SystemAgents
@pytest.mark.asyncio
async def test_run_then_resume_persists_messages_and_session_state(
monkeypatch: pytest.MonkeyPatch,
) -> None:
call_count = {"n": 0}
def _fake_execute(
self,
*,
user_input: str,
system_prompt: str | None = None,
tools: list[dict[str, object]] | None = None,
) -> dict[str, object]:
del self, user_input, system_prompt, tools
call_count["n"] += 1
if call_count["n"] == 1:
return {
"assistant_text": "请确认是否跳转。",
"prompt_tokens": 11,
"completion_tokens": 7,
"total_tokens": 18,
"cost": 0.0025,
"pending_front_tool": {
"name": "front.navigate_to_route",
"args": {"target": "/calendar/dayweek", "replace": False},
"target": "frontend",
},
"agui_events": [],
}
return {
"assistant_text": "已继续执行并完成。",
"prompt_tokens": 3,
"completion_tokens": 2,
"total_tokens": 5,
"cost": 0.001,
"pending_front_tool": None,
"agui_events": [],
}
monkeypatch.setattr(
"core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
_fake_execute,
)
async with AsyncSessionLocal() as lookup_session:
existing_owner = await lookup_session.execute(
select(AgentChatSession.user_id).limit(1)
)
owner_id = existing_owner.scalar_one_or_none()
if owner_id is None:
pytest.skip("No existing session owner available in local database")
factory_id = uuid.uuid4()
session_uuid = uuid.uuid4()
agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
async with AsyncSessionLocal() as seed_session:
llm_row = await seed_session.execute(select(Llm.id).limit(1))
llm_id = llm_row.scalar_one_or_none()
if llm_id is None:
seed_session.add(
LlmFactory(
id=factory_id,
name=f"dashscope-test-{uuid.uuid4().hex[:8]}",
request_url="https://dashscope.example",
)
)
llm_id = uuid.uuid4()
seed_session.add(
Llm(
id=llm_id,
factory_id=factory_id,
model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
)
)
seed_session.add(
SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
)
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
await seed_session.commit()
published: list[str] = []
queued_commands: list[dict[str, object]] = []
async def _publish(event: dict[str, object]) -> None:
event_type = event.get("type")
if isinstance(event_type, str):
published.append(event_type)
async def _enqueue(command: dict[str, object]) -> str:
queued_commands.append(command)
return "task-followup-1"
try:
run_input_payload = {
"threadId": str(session_uuid),
"runId": "run-it-1",
"state": {},
"messages": [
{"id": "u1", "role": "user", "content": "帮我打开日历"},
],
"tools": [
{
"name": "front.navigate_to_route",
"description": "navigate route",
"parameters": {"type": "object"},
}
],
"context": [],
"forwardedProps": {},
}
run_result = await run_agent_task(
{
"command": "run",
"run_input": run_input_payload,
},
publish_event=_publish,
enqueue_command=_enqueue,
run_service=RunService(),
resume_service=ResumeService(),
)
pending_tool_call_id = str(run_result["pending_tool_call_id"])
state_snapshot = run_result["state_snapshot"]
assert isinstance(state_snapshot, dict)
pending_tool_nonce = state_snapshot["pending_tool_nonce"]
assert isinstance(pending_tool_nonce, str)
await run_agent_task(
{
"command": "resume",
"run_input": {
"threadId": str(session_uuid),
"runId": "run-it-2",
"state": {},
"messages": [
{
"id": "tool-1",
"role": "tool",
"toolCallId": pending_tool_call_id,
"content": json.dumps(
{
"toolName": "front.navigate_to_route",
"toolArgs": {
"target": "/calendar/dayweek",
"replace": False,
"__nonce": pending_tool_nonce,
},
"nonce": pending_tool_nonce,
"result": {"ok": True},
},
ensure_ascii=True,
separators=(",", ":"),
),
}
],
"tools": [],
"context": [],
"forwardedProps": {},
},
},
publish_event=_publish,
enqueue_command=_enqueue,
run_service=RunService(),
resume_service=ResumeService(),
)
assert len(queued_commands) == 1
await run_agent_task(
queued_commands[0],
publish_event=_publish,
enqueue_command=_enqueue,
run_service=RunService(),
resume_service=ResumeService(),
)
await engine.dispose()
async with AsyncSessionLocal() as verify_session:
db_session = await verify_session.get(AgentChatSession, session_uuid)
assert db_session is not None
assert db_session.status == AgentChatSessionStatus.COMPLETED
assert db_session.message_count == 4
assert db_session.total_tokens == 23
assert db_session.total_cost == Decimal("0.003500")
assert db_session.state_snapshot == {
"status": "completed",
"pending_tool_call_id": None,
"pending_tool_name": None,
"pending_tool_args_sha256": None,
"pending_tool_nonce": None,
}
rows = await verify_session.execute(
select(AgentChatMessage)
.where(AgentChatMessage.session_id == session_uuid)
.order_by(AgentChatMessage.seq.asc())
)
messages = list(rows.scalars().all())
assert [item.role for item in messages] == [
AgentChatMessageRole.USER,
AgentChatMessageRole.ASSISTANT,
AgentChatMessageRole.TOOL,
AgentChatMessageRole.ASSISTANT,
]
assert messages[1].input_tokens == 11
assert messages[1].output_tokens == 7
assert messages[1].cost == Decimal("0.002500")
assert messages[3].content == "已继续执行并完成。"
assert "RUN_STARTED" in published
assert "RUN_FINISHED" in published
assert "TEXT_MESSAGE_CONTENT" in published
finally:
async with AsyncSessionLocal() as cleanup_session:
await cleanup_session.execute(
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
)
await cleanup_session.execute(
delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
)
await cleanup_session.commit()
@pytest.mark.asyncio
async def test_resume_tool_result_offloads_to_supabase_storage_for_calendar_tool(
monkeypatch: pytest.MonkeyPatch,
) -> None:
call_count = {"n": 0}
def _fake_execute(
self,
*,
user_input: str,
system_prompt: str | None = None,
tools: list[dict[str, object]] | None = None,
) -> dict[str, object]:
del self, user_input, system_prompt, tools
call_count["n"] += 1
if call_count["n"] == 1:
return {
"assistant_text": "我来创建日历事件,请稍候确认。",
"prompt_tokens": 10,
"completion_tokens": 6,
"total_tokens": 16,
"cost": 0.002,
"pending_front_tool": {
"name": "front.create_calendar_event",
"args": {
"title": "测试日程",
"start": "2026-03-09T09:00:00+08:00",
"end": "2026-03-09T10:00:00+08:00",
},
"target": "frontend",
},
"agui_events": [],
}
return {
"assistant_text": "日历已创建。",
"prompt_tokens": 2,
"completion_tokens": 2,
"total_tokens": 4,
"cost": 0.001,
"pending_front_tool": None,
"agui_events": [],
}
monkeypatch.setattr(
"core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
_fake_execute,
)
factory_id = uuid.uuid4()
test_user_id: str | None = None
test_user_email = f"agent-it-{uuid.uuid4().hex[:8]}@example.com"
owner_id = uuid.uuid4()
initialized = await supabase_service.initialize()
if not initialized:
pytest.skip("Supabase service is unavailable")
admin_client = supabase_service.get_admin_client()
tool_result_storage = create_tool_result_storage()
assert tool_result_storage is not None
created_user = admin_client.auth.admin.create_user(
{
"email": test_user_email,
"password": "Passw0rd!123",
"email_confirm": True,
"user_metadata": {"source": "integration-test"},
}
)
test_user_id = str(created_user.user.id)
owner_id = uuid.UUID(test_user_id)
await engine.dispose()
async with AsyncSessionLocal() as lookup_session:
llm_row = await lookup_session.execute(select(Llm.id).limit(1))
llm_id = llm_row.scalar_one_or_none()
if llm_id is None:
async with AsyncSessionLocal() as seed_session:
factory_row = await seed_session.execute(
select(LlmFactory.id).where(LlmFactory.name == "dashscope").limit(1)
)
existing_factory_id = factory_row.scalar_one_or_none()
if existing_factory_id is None:
seed_session.add(
LlmFactory(
id=factory_id,
name="dashscope",
request_url="https://dashscope.example",
)
)
await seed_session.commit()
else:
factory_id = existing_factory_id
async with AsyncSessionLocal() as seed_session:
llm_id = uuid.uuid4()
seed_session.add(
Llm(
id=llm_id,
factory_id=factory_id,
model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
)
)
await seed_session.commit()
storage = admin_client.storage
try:
storage.get_bucket("private")
except Exception:
storage.create_bucket("private", "private", {"public": False})
session_uuid = uuid.uuid4()
agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
uploaded_path: str | None = None
try:
probe_path = f"tool-results/probe/{uuid.uuid4().hex}.json"
try:
storage.from_("private").upload(probe_path, b"{}")
storage.from_("private").remove([probe_path])
except Exception:
pytest.skip(
"Supabase Storage upload API unavailable in current environment"
)
async with AsyncSessionLocal() as seed_session:
existing_profile = await seed_session.get(Profile, owner_id)
if existing_profile is None:
seed_session.add(
Profile(
id=owner_id,
username=f"it_{uuid.uuid4().hex[:8]}",
)
)
seed_session.add(
SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
)
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
await seed_session.commit()
run_result = await run_agent_task(
{
"command": "run",
"run_input": {
"threadId": str(session_uuid),
"runId": "run-storage-1",
"state": {},
"messages": [
{
"id": "u1",
"role": "user",
"content": "帮我创建明天9点到10点的日历",
}
],
"tools": [
{
"name": "front.create_calendar_event",
"description": "Create calendar event",
"parameters": {"type": "object"},
}
],
"context": [],
"forwardedProps": {},
},
},
run_service=RunService(),
resume_service=ResumeService(
tool_result_storage=tool_result_storage,
tool_result_bucket="private",
tool_result_prefix="tool-results",
),
)
pending_tool_call_id = str(run_result["pending_tool_call_id"])
snapshot = run_result["state_snapshot"]
assert isinstance(snapshot, dict)
pending_tool_nonce = snapshot.get("pending_tool_nonce")
assert isinstance(pending_tool_nonce, str)
await run_agent_task(
{
"command": "resume",
"run_input": {
"threadId": str(session_uuid),
"runId": "run-storage-2",
"state": {},
"messages": [
{
"id": "tool-1",
"role": "tool",
"toolCallId": pending_tool_call_id,
"content": json.dumps(
{
"toolName": "front.create_calendar_event",
"toolArgs": {
"title": "测试日程",
"start": "2026-03-09T09:00:00+08:00",
"end": "2026-03-09T10:00:00+08:00",
"__nonce": pending_tool_nonce,
},
"nonce": pending_tool_nonce,
"result": {
"ok": True,
"type": "calendar_card.v1",
"version": "v1",
"data": {
"id": "evt-test",
"title": "测试日程",
"description": "x" * 9000,
},
"actions": [
{
"type": "link",
"label": "查看详情",
"target": "/calendar/events/evt-test",
}
],
},
},
ensure_ascii=True,
separators=(",", ":"),
),
}
],
"tools": [],
"context": [],
"forwardedProps": {},
},
},
run_service=RunService(),
resume_service=ResumeService(
tool_result_storage=tool_result_storage,
tool_result_bucket="private",
tool_result_prefix="tool-results",
),
)
await engine.dispose()
async with AsyncSessionLocal() as verify_session:
rows = await verify_session.execute(
select(AgentChatMessage)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.role == AgentChatMessageRole.TOOL)
.order_by(AgentChatMessage.seq.desc())
)
tool_message = rows.scalars().first()
assert tool_message is not None
metadata = tool_message.metadata_json or {}
storage_bucket = metadata.get("storage_bucket")
storage_path = metadata.get("storage_path")
assert storage_bucket == "private"
assert isinstance(storage_path, str)
assert storage_path.startswith("tool-results/")
uploaded_path = storage_path
downloaded = storage.from_("private").download(uploaded_path)
if isinstance(downloaded, bytes):
downloaded_payload = json.loads(downloaded.decode("utf-8"))
else:
downloaded_payload = json.loads(str(downloaded))
assert downloaded_payload["toolName"] == "front.create_calendar_event"
result_payload = downloaded_payload["result"]
assert result_payload["type"] == "calendar_card.v1"
assert result_payload["data"]["id"] == "evt-test"
finally:
if uploaded_path:
try:
storage.from_("private").remove([uploaded_path])
except Exception:
pass
async with AsyncSessionLocal() as cleanup_session:
await cleanup_session.execute(
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
)
await cleanup_session.execute(
delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
)
await cleanup_session.execute(delete(Profile).where(Profile.id == owner_id))
await cleanup_session.execute(
delete(Llm).where(Llm.factory_id == factory_id)
)
await cleanup_session.execute(
delete(LlmFactory).where(LlmFactory.id == factory_id)
)
await cleanup_session.commit()
if test_user_id is not None:
try:
admin_client.auth.admin.delete_user(test_user_id)
except Exception:
pass
await supabase_service.close()
@pytest.mark.asyncio
async def test_run_service_embeds_profile_settings_in_runtime_system_prompt(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, object] = {}
session_uuid = uuid.uuid4()
agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
original_profile: Profile | None = None
def _fake_execute(self, *, user_input: str, system_prompt: str | None = None):
captured["user_input"] = user_input
captured["system_prompt"] = system_prompt
return {
"assistant_text": "Mocked answer",
"prompt_tokens": 11,
"completion_tokens": 7,
"total_tokens": 18,
"cost": 0.0025,
"agui_events": [],
}
monkeypatch.setattr(
"core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
_fake_execute,
)
await engine.dispose()
async with AsyncSessionLocal() as lookup_session:
owner_row = await lookup_session.execute(select(Profile.id).limit(1))
owner_id = owner_row.scalar_one_or_none()
if owner_id is None:
pytest.skip("No profile owner available in local database")
original_profile = await lookup_session.get(Profile, owner_id)
llm_row = await lookup_session.execute(
select(Llm.id, LlmFactory.name)
.join(LlmFactory, LlmFactory.id == Llm.factory_id)
.where(LlmFactory.name.in_(("dashscope", "deepseek", "moonshot")))
.limit(1)
)
llm_record = llm_row.one_or_none()
if llm_record is None:
pytest.skip("No supported llm provider available in local database")
llm_id = llm_record[0]
try:
async with AsyncSessionLocal() as seed_session:
seed_session.add(
SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
)
profile = await seed_session.get(Profile, owner_id)
assert profile is not None
profile.username = "demo-user"
profile.bio = "hello\nworld"
profile.settings = {
"preferences": {
"interface_language": "zh-CN",
"ai_language": "en-US",
"timezone": "Asia/Shanghai",
"country": "CN",
}
}
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
await seed_session.commit()
result = await RunService().run(
run_input=RunAgentInput.model_validate(
{
"threadId": str(session_uuid),
"runId": "run-it-1",
"state": {},
"messages": [
{"id": "u1", "role": "user", "content": "hello"},
],
"tools": [],
"context": [],
"forwardedProps": {},
}
)
)
assert result["persisted"] is True
assert captured["user_input"] == "hello"
system_prompt = captured["system_prompt"]
assert isinstance(system_prompt, str)
assert "# USER_PROFILE (JSON)" in system_prompt
assert '"ai_language":"en-US"' in system_prompt
assert '"timezone":"Asia/Shanghai"' in system_prompt
assert '"country":"CN"' in system_prompt
finally:
await engine.dispose()
async with AsyncSessionLocal() as cleanup_session:
if original_profile is not None:
profile = await cleanup_session.get(Profile, owner_id)
if profile is not None:
profile.username = original_profile.username
profile.bio = original_profile.bio
profile.settings = original_profile.settings
await cleanup_session.execute(
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
)
await cleanup_session.execute(
delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
)
await cleanup_session.commit()
@pytest.mark.asyncio
async def test_soft_delete_session_cascades_to_messages() -> None:
session_uuid = uuid.uuid4()
await engine.dispose()
async with AsyncSessionLocal() as lookup_session:
owner = await lookup_session.execute(select(Profile.id).limit(1))
owner_id = owner.scalar_one_or_none()
if owner_id is None:
pytest.skip("No profile owner available in local database")
async with AsyncSessionLocal() as seed_session:
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
await seed_session.flush()
seed_session.add(
AgentChatMessage(
session_id=session_uuid,
seq=1,
role=AgentChatMessageRole.USER,
content="hello",
)
)
await seed_session.commit()
try:
async with AsyncSessionLocal() as mutate_session:
repo = SessionRepository(mutate_session)
affected = await repo.soft_delete_session_with_messages(
session_id=session_uuid
)
await mutate_session.commit()
assert affected == 1
async with AsyncSessionLocal() as verify_session:
db_session = await verify_session.get(AgentChatSession, session_uuid)
assert db_session is not None
assert db_session.deleted_at is not None
rows = await verify_session.execute(
select(AgentChatMessage).where(
AgentChatMessage.session_id == session_uuid
)
)
messages = list(rows.scalars().all())
assert len(messages) == 1
assert messages[0].deleted_at is not None
finally:
async with AsyncSessionLocal() as cleanup_session:
await cleanup_session.execute(
delete(AgentChatMessage).where(
AgentChatMessage.session_id == session_uuid
)
)
await cleanup_session.execute(
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
)
await cleanup_session.commit()
@@ -1,78 +0,0 @@
from __future__ import annotations
from core.agent.application.session_state_persistence import persist_tool_result_payload
from core.agent.domain.tool_correlation import reconstruct_tool_call_result_event
from core.agent.infrastructure.queue.tasks import run_agent_task
class _FakeStorage:
def __init__(self) -> None:
self.writes: dict[str, dict[str, object]] = {}
async def upload_json(
self, *, bucket: str, path: str, payload: dict[str, object]
) -> str:
self.writes[f"{bucket}/{path}"] = payload
return "etag-1"
async def test_closed_loop_run_flow_frontend_to_sse() -> None:
thread_id = "00000000-0000-0000-0000-000000000001"
published: list[str] = []
class _FakeRunService:
async def run(self, *, run_input: object) -> dict[str, object]:
del run_input
return {"threadId": thread_id, "runId": "run-1"}
async def _publish(event: dict[str, object]) -> None:
event_type = event.get("type")
if isinstance(event_type, str):
published.append(event_type)
result = await run_agent_task(
{
"command": "run",
"run_input": {
"threadId": thread_id,
"runId": "run-1",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
},
},
publish_event=_publish,
run_service=_FakeRunService(),
)
assert result["threadId"] == thread_id
assert published[0] == "RUN_STARTED"
assert published[-1] == "RUN_FINISHED"
async def test_tool_result_full_payload_persist_and_reconstruct() -> None:
storage = _FakeStorage()
payload = {
"schema": "ui.v1",
"components": [{"type": "card", "title": "Weather"}],
}
metadata = await persist_tool_result_payload(
storage=storage,
run_id="run-1",
turn_id="turn-1",
tool_call_id="call-1",
tool_name="weather",
payload=payload,
bucket="private",
path="tool-results/run-1/call-1.json",
)
event = reconstruct_tool_call_result_event(metadata=metadata, payload=payload)
assert metadata["type"] == "tool_result"
assert metadata["storage_bucket"] == "private"
assert event["type"] == "TOOL_CALL_RESULT"
assert event["data"]["schema"] == "ui.v1"
@@ -7,8 +7,11 @@ from uuid import UUID, uuid4
import pytest
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
from core.agentscope.schemas.system_agent_config import SystemAgentLLMConfig
from core.agentscope.schemas.user_context import (
UserAgentContext,
parse_profile_settings,
)
from core.agentscope.runtime.config_loader import RuntimeStageConfig
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
from core.db.session import AsyncSessionLocal
@@ -69,6 +69,7 @@ def _override_current_user(user_id: UUID) -> Callable[[], CurrentUser]:
def test_create_schedule_item_returns_201() -> None:
item = ScheduleItemResponse(
id=uuid4(),
owner_id=uuid4(),
title="Test Event",
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
timezone="UTC",
@@ -76,6 +77,8 @@ def test_create_schedule_item_returns_201() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
is_owner=True,
)
app.dependency_overrides[get_schedule_item_service] = (
@@ -99,6 +102,7 @@ def test_create_schedule_item_returns_201() -> None:
def test_list_schedule_items_returns_200() -> None:
item = ScheduleItemResponse(
id=uuid4(),
owner_id=uuid4(),
title="Test Event",
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
timezone="UTC",
@@ -106,6 +110,8 @@ def test_list_schedule_items_returns_200() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
is_owner=True,
)
app.dependency_overrides[get_schedule_item_service] = (
@@ -131,6 +137,7 @@ def test_get_schedule_item_returns_200() -> None:
item_id = uuid4()
item = ScheduleItemResponse(
id=item_id,
owner_id=uuid4(),
title="Test Event",
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
timezone="UTC",
@@ -138,6 +145,8 @@ def test_get_schedule_item_returns_200() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
is_owner=True,
)
app.dependency_overrides[get_schedule_item_service] = (
@@ -156,6 +165,7 @@ def test_update_schedule_item_returns_200() -> None:
item_id = uuid4()
item = ScheduleItemResponse(
id=item_id,
owner_id=uuid4(),
title="Updated Event",
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
timezone="UTC",
@@ -163,6 +173,8 @@ def test_update_schedule_item_returns_200() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
is_owner=True,
)
app.dependency_overrides[get_schedule_item_service] = (
@@ -184,6 +196,7 @@ def test_delete_schedule_item_returns_204() -> None:
item_id = uuid4()
item = ScheduleItemResponse(
id=item_id,
owner_id=uuid4(),
title="Test Event",
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
timezone="UTC",
@@ -191,6 +204,8 @@ def test_delete_schedule_item_returns_204() -> None:
source_type=ScheduleItemSourceType.MANUAL,
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
permission=7,
is_owner=True,
)
app.dependency_overrides[get_schedule_item_service] = (
@@ -354,6 +354,12 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
id=uuid4(), email="user@example.com"
)
async def _allow_transcribe(*, user_id: str) -> bool:
del user_id
return True
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _allow_transcribe)
async def mock_transcribe_file(file_path: str, filename: str) -> str:
assert file_path.endswith(".wav")
assert filename == "test.wav"
@@ -391,6 +397,12 @@ def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
monkeypatch.setattr(agent_router, "_MAX_TRANSCRIBE_AUDIO_BYTES", 4)
async def _allow_transcribe(*, user_id: str) -> bool:
del user_id
return True
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _allow_transcribe)
client = TestClient(app)
oversized = BytesIO(b"12345")
oversized.name = "test.wav"
@@ -407,11 +419,17 @@ def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
app.dependency_overrides = {}
def test_asr_transcribe_rejects_non_wav_audio() -> None:
def test_asr_transcribe_rejects_non_wav_audio(monkeypatch) -> None:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
)
async def _allow_transcribe(*, user_id: str) -> bool:
del user_id
return True
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _allow_transcribe)
client = TestClient(app)
fake_mp3 = BytesIO(b"fake-mp3")
fake_mp3.name = "test.mp3"
@@ -428,11 +446,17 @@ def test_asr_transcribe_rejects_non_wav_audio() -> None:
app.dependency_overrides = {}
def test_asr_transcribe_rejects_invalid_wav_payload() -> None:
def test_asr_transcribe_rejects_invalid_wav_payload(monkeypatch) -> None:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
)
async def _allow_transcribe(*, user_id: str) -> bool:
del user_id
return True
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _allow_transcribe)
client = TestClient(app)
fake_payload = BytesIO(b"not-a-wav")
fake_payload.name = "test.wav"
@@ -447,3 +471,33 @@ def test_asr_transcribe_rejects_invalid_wav_payload() -> None:
assert response.json()["detail"] == "Unsupported audio format"
finally:
app.dependency_overrides = {}
def test_asr_transcribe_rejects_when_rate_limited_for_current_user(monkeypatch) -> None:
known_user = CurrentUser(id=uuid4(), email="user@example.com")
app.dependency_overrides[get_current_user] = lambda: known_user
captured_user_ids: list[str] = []
async def _deny_transcribe(*, user_id: str) -> bool:
captured_user_ids.append(user_id)
return False
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _deny_transcribe)
client = TestClient(app)
wav_content = b"RIFF\x24\x80\x00\x00WAVEfmt "
wav_file = BytesIO(wav_content)
wav_file.name = "test.wav"
try:
response = client.post(
"/api/v1/agent/transcribe",
files={"audio": ("test.wav", wav_file, "audio/wav")},
)
assert response.status_code == 429
assert response.json()["detail"] == "Too many transcribe requests"
assert captured_user_ids == [str(known_user.id)]
finally:
app.dependency_overrides = {}