refactor: 简化 AgentScope 运行时模块与事件处理
- 移除冗余的 user_token 参数传递 - 重构 tool.result 事件使用 ToolAgentOutput 模型 - 重构 text.end 事件使用 WorkerAgentOutput 模型 - 简化 store 模块的 tool result 处理逻辑 - 更新 router/service 适配新事件结构 - 清理废弃的测试文件与设计文档 - 新增 AgentRuns 多模态存储设计文档
This commit is contained in:
@@ -6,7 +6,6 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import config
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from v1.agent.repository import AgentRepository
|
||||
|
||||
@@ -36,243 +35,27 @@ class _FakeSession:
|
||||
self.flushed = True
|
||||
|
||||
|
||||
class _FakeToolResultStorage:
|
||||
def __init__(self, payload: dict[str, object] | None) -> None:
|
||||
self._payload = payload
|
||||
|
||||
async def read_json(self, *, bucket: str, path: str) -> dict[str, object] | None:
|
||||
del bucket, path
|
||||
return self._payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_hydrates_content_from_object_storage() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"toolName": "front.navigate_to_route",
|
||||
"result": {"ok": True, "applied": True, "content": "已跳转"},
|
||||
}
|
||||
),
|
||||
)
|
||||
async def test_snapshot_message_returns_raw_db_columns() -> None:
|
||||
repository = AgentRepository(session=SimpleNamespace()) # type: ignore[arg-type]
|
||||
now = datetime.now(timezone.utc)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
seq=7,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content='{"offloaded":true}',
|
||||
metadata_json={
|
||||
"tool_call_id": "call-1",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/run-1/call-1.json",
|
||||
},
|
||||
metadata_json={"tool_call_id": "call-1"},
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["toolCallId"] == "call-1"
|
||||
assert payload["content"] == "已跳转"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_hydrates_ui_from_ui_schema_field() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"toolName": "calendar_write",
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True, "operation": "create"},
|
||||
"actions": [],
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="已创建日程:项目评审(明天 10:00)",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-3",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/run-1/call-3.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["toolCallId"] == "call-3"
|
||||
assert payload["content"] == "已创建日程:项目评审(明天 10:00)"
|
||||
ui = payload.get("ui")
|
||||
assert isinstance(ui, dict)
|
||||
assert ui["type"] == "calendar_operation.v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_keeps_inline_content_when_storage_payload_missing() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(None),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="inline-tool-content",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-2",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/run-1/call-2.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["toolCallId"] == "call-2"
|
||||
assert payload["content"] == "inline-tool-content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_skips_storage_when_path_not_matching_session() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True},
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="summary",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-x",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/foreign-session/call-y.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["content"] == "summary"
|
||||
assert "ui" not in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_rejects_path_traversal() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True},
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="summary",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-z",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/ok/../../evil/call-z.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["content"] == "summary"
|
||||
assert "ui" not in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_supports_legacy_storage_path() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True},
|
||||
"actions": [],
|
||||
},
|
||||
"content": "legacy content",
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content='{"offloaded":true}',
|
||||
metadata_json={
|
||||
"tool_call_id": "call-legacy",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/old-run/call-legacy.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["content"] == "legacy content"
|
||||
ui = payload.get("ui")
|
||||
assert isinstance(ui, dict)
|
||||
assert ui["type"] == "calendar_operation.v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_message_snapshot_includes_renderable_attachments() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.USER,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="请分析这张图",
|
||||
metadata_json={
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-chat-attachments",
|
||||
"path": "agent-inputs/u1/t1/r1/m1/att-1.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["role"] == "user"
|
||||
assert payload["content"] == "请分析这张图"
|
||||
attachments = payload.get("attachments")
|
||||
assert isinstance(attachments, list)
|
||||
assert len(attachments) == 1
|
||||
first = attachments[0]
|
||||
assert isinstance(first, dict)
|
||||
assert first["mimeType"] == "image/png"
|
||||
assert isinstance(first.get("previewPath"), str)
|
||||
assert payload["seq"] == 7
|
||||
assert payload["role"] == "tool"
|
||||
assert payload["content"] == '{"offloaded":true}'
|
||||
assert payload["metadata"] == {"tool_call_id": "call-1"}
|
||||
assert "timestamp" in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -318,32 +101,3 @@ async def test_persist_user_message_keeps_existing_session_title() -> None:
|
||||
|
||||
assert session_row.title == "已有标题"
|
||||
assert session_row.message_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_message_attachment_reference_returns_item() -> None:
|
||||
session_id = str(uuid4())
|
||||
message_id = str(uuid4())
|
||||
message = SimpleNamespace(
|
||||
metadata_json={
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "bucket-test",
|
||||
"path": "agent-inputs/u/t/r/a.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
fake_session = _FakeSession(message)
|
||||
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
|
||||
|
||||
ref = await repository.get_message_attachment_reference(
|
||||
session_id=session_id,
|
||||
message_id=message_id,
|
||||
attachment_index=0,
|
||||
)
|
||||
|
||||
assert ref is not None
|
||||
assert ref["bucket"] == "bucket-test"
|
||||
assert ref["mimeType"] == "image/png"
|
||||
|
||||
@@ -12,48 +12,6 @@ from core.auth.models import CurrentUser
|
||||
from v1.agent import router as agent_router
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_run_request_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._allow_run_request(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_sse_slot_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._acquire_sse_slot(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_transcribe_request_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._allow_transcribe_request(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
|
||||
|
||||
def _resume_input_with_tool_message() -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
@@ -82,13 +40,7 @@ async def test_enqueue_resume_rejects_without_tool_contract() -> None:
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-resume-invalid",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": "continue",
|
||||
}
|
||||
],
|
||||
"messages": [{"id": "u1", "role": "user", "content": "continue"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
@@ -109,10 +61,6 @@ async def test_enqueue_resume_rejects_without_tool_contract() -> None:
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert (
|
||||
exc_info.value.detail
|
||||
== "RunAgentInput.messages requires a tool message with toolCallId for resume"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -141,7 +89,6 @@ async def test_enqueue_resume_rejects_when_rate_limited(
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 429
|
||||
assert exc_info.value.detail == "Too many run requests"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -173,96 +120,4 @@ async def test_enqueue_resume_accepts_valid_tool_contract(
|
||||
)
|
||||
|
||||
assert result.task_id == "task-resume-1"
|
||||
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
|
||||
assert result.run_id == "run-resume-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_retries_on_redis_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _acquire(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
|
||||
monkeypatch.setattr(agent_router, "_acquire_sse_slot", _acquire)
|
||||
monkeypatch.setattr(agent_router, "_release_sse_slot", _release)
|
||||
|
||||
class _Request:
|
||||
async def is_disconnected(self) -> bool:
|
||||
return False
|
||||
|
||||
class _Service:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def stream_events(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
raise RuntimeError("Timeout reading from localhost:6379")
|
||||
if self.calls == 2:
|
||||
return [{"id": "1-0", "event": {"type": "RUN_FINISHED"}}]
|
||||
return []
|
||||
|
||||
response = await agent_router.stream_events(
|
||||
request=cast(Any, _Request()),
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
last_event_id=None,
|
||||
idle_limit=2,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
async for chunk in response.body_iterator:
|
||||
chunks.append(str(chunk))
|
||||
if any("RUN_FINISHED" in item for item in chunks):
|
||||
break
|
||||
|
||||
merged = "".join(chunks)
|
||||
assert "event: RUN_FINISHED" in merged
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_attachment_preview_rejects_negative_index() -> None:
|
||||
class _Service:
|
||||
async def get_attachment_preview(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
raise AssertionError("get_attachment_preview should not be called")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await agent_router.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=-1,
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_attachment_preview_returns_streaming_response() -> None:
|
||||
class _Service:
|
||||
async def get_attachment_preview(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
return b"png-bytes", "image/png"
|
||||
|
||||
response = await agent_router.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=0,
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
)
|
||||
chunks: list[bytes] = []
|
||||
async for chunk in response.body_iterator:
|
||||
chunks.append(cast(bytes, chunk))
|
||||
|
||||
assert response.media_type == "image/png"
|
||||
assert b"".join(chunks) == b"png-bytes"
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import HTTPException
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
import v1.agent.service as agent_service_module
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
import v1.agent.service as agent_service_module
|
||||
from v1.agent.service import AgentService, AsrService
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self) -> None:
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
self.deleted_session_id: str | None = None
|
||||
self.created_with_session_id: str | None = None
|
||||
self.persisted_user_messages: list[dict[str, object]] = []
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
@@ -31,33 +28,23 @@ class _FakeRepository:
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id
|
||||
self.created_with_session_id = session_id
|
||||
return session_id or "00000000-0000-0000-0000-000000000999"
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
async def delete_session(self, *, session_id: str) -> None:
|
||||
self.deleted_session_id = session_id
|
||||
return None
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
) -> dict[str, object] | None:
|
||||
del session_id
|
||||
if before is not None and before <= date(2026, 3, 6):
|
||||
return None
|
||||
return {
|
||||
"day": "2026-03-06",
|
||||
"hasMore": False,
|
||||
"messages": [{"id": "m1", "role": "assistant", "content": "hello"}],
|
||||
}
|
||||
del session_id, before
|
||||
return None
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
del user_id
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
return None
|
||||
|
||||
async def persist_user_message(
|
||||
self,
|
||||
@@ -76,22 +63,6 @@ class _FakeRepository:
|
||||
}
|
||||
)
|
||||
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None:
|
||||
del session_id, message_id
|
||||
if attachment_index != 0:
|
||||
return None
|
||||
return {
|
||||
"bucket": config.storage.bucket,
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/run-1/attachment-0-a.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
def __init__(self) -> None:
|
||||
@@ -100,33 +71,20 @@ class _FakeQueue:
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
self.commands.append(command)
|
||||
del dedup_key
|
||||
self.commands.append(command)
|
||||
return "task-1"
|
||||
|
||||
|
||||
class _FailingQueue:
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
del command, dedup_key
|
||||
raise RuntimeError("enqueue failed")
|
||||
|
||||
|
||||
class _FakeStream:
|
||||
async def read(
|
||||
self, *, session_id: str, last_event_id: str | None
|
||||
) -> list[dict[str, object]]:
|
||||
del session_id
|
||||
return [
|
||||
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
|
||||
]
|
||||
del session_id, last_event_id
|
||||
return []
|
||||
|
||||
|
||||
class _FakeAttachmentStorage:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict[str, object]] = []
|
||||
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
@@ -135,65 +93,12 @@ class _FakeAttachmentStorage:
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str:
|
||||
self.calls.append(
|
||||
{
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"content": content,
|
||||
"content_type": content_type,
|
||||
}
|
||||
)
|
||||
del bucket, content, content_type
|
||||
return path
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
self.calls.append(
|
||||
{
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"download": True,
|
||||
}
|
||||
)
|
||||
return b"png-bytes"
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str:
|
||||
self.calls.append(
|
||||
{
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"signed": True,
|
||||
"expires_in_seconds": expires_in_seconds,
|
||||
}
|
||||
)
|
||||
return f"https://signed.example/{path}?exp={expires_in_seconds}"
|
||||
|
||||
def parse_signed_url(self, url: str) -> tuple[str, str]:
|
||||
if url.startswith("https://signed.example/"):
|
||||
path = url.replace("https://signed.example/", "").split("?")[0]
|
||||
return "agent-test-bucket", path
|
||||
raise RuntimeError("Invalid signed URL")
|
||||
|
||||
|
||||
class _AlwaysFailAttachmentStorage:
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str:
|
||||
del bucket, path, content, content_type
|
||||
raise RuntimeError("upload failed")
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
del bucket, path
|
||||
raise RuntimeError("download failed")
|
||||
return b""
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
@@ -202,12 +107,16 @@ class _AlwaysFailAttachmentStorage:
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str:
|
||||
del bucket, path, expires_in_seconds
|
||||
raise RuntimeError("sign failed")
|
||||
del expires_in_seconds
|
||||
return f"https://signed.example/{bucket}/{path}"
|
||||
|
||||
def parse_signed_url(self, url: str) -> tuple[str, str]:
|
||||
del url
|
||||
raise RuntimeError("parse failed")
|
||||
parsed = url.split("/storage/v1/object/sign/")
|
||||
if len(parsed) != 2:
|
||||
raise RuntimeError("invalid")
|
||||
bucket, path = parsed[1].split("/", 1)
|
||||
path = path.split("?", 1)[0]
|
||||
return bucket, path
|
||||
|
||||
|
||||
def _user() -> CurrentUser:
|
||||
@@ -217,13 +126,22 @@ def _user() -> CurrentUser:
|
||||
)
|
||||
|
||||
|
||||
def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
|
||||
def _build_run_input(*, url: str) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{"type": "binary", "mimeType": "image/png", "url": url},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
@@ -231,454 +149,69 @@ def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
|
||||
)
|
||||
|
||||
|
||||
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_run_rejects_non_project_host_signed_url(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
user = _user()
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
first = await service.enqueue_resume(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
second = await service.enqueue_resume(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
assert first.task_id == second.task_id
|
||||
|
||||
|
||||
async def test_enqueue_run_creates_missing_thread_session() -> None:
|
||||
repository = _FakeRepository()
|
||||
queue = _FakeQueue()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.thread_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert accepted.run_id == "run-1"
|
||||
assert accepted.created is True
|
||||
assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert repository.committed is True
|
||||
assert queue.commands[0]["user_token"] is None
|
||||
|
||||
|
||||
async def test_enqueue_run_uses_explicit_user_token() -> None:
|
||||
repository = _FakeRepository()
|
||||
queue = _FakeQueue()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
user_token="Bearer access-token-1",
|
||||
)
|
||||
|
||||
assert queue.commands
|
||||
assert queue.commands[0]["user_token"] == "access-token-1"
|
||||
|
||||
|
||||
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FailingQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
try:
|
||||
await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
raise AssertionError("expected RuntimeError")
|
||||
except RuntimeError as exc:
|
||||
assert str(exc) == "enqueue failed"
|
||||
|
||||
assert repository.deleted_session_id is None
|
||||
|
||||
|
||||
async def test_enqueue_run_handles_session_create_race() -> None:
|
||||
class _RaceRepository(_FakeRepository):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.create_calls = 0
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
if self.create_calls == 0:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id, session_id
|
||||
self.create_calls += 1
|
||||
raise IntegrityError("insert", {}, Exception("duplicate key"))
|
||||
|
||||
repository = _RaceRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.created is False
|
||||
assert repository.rolled_back is True
|
||||
|
||||
|
||||
async def test_enqueue_run_parses_signed_url_and_injects_metadata(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-image",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "帮我看下这张图"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/agent-inputs/u/t/r/file.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert accepted.task_id == "task-1"
|
||||
assert repository.persisted_user_messages
|
||||
persisted = repository.persisted_user_messages[0]
|
||||
assert persisted["session_id"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert persisted["run_id"] == "run-with-image"
|
||||
metadata = persisted["metadata"]
|
||||
assert isinstance(metadata, dict)
|
||||
attachments = metadata.get("user_message_attachments")
|
||||
assert isinstance(attachments, dict)
|
||||
assert attachments["bucket"] == "agent-test-bucket"
|
||||
assert attachments["path"] == "agent-inputs/u/t/r/file.png"
|
||||
assert attachments["mime_type"] == "image/png"
|
||||
|
||||
|
||||
async def test_enqueue_run_with_invalid_signed_url_still_succeeds(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-invalid-url",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "帮我看下这张图"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "invalid-url-format",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert accepted.task_id == "task-1"
|
||||
assert repository.persisted_user_messages
|
||||
persisted = repository.persisted_user_messages[0]
|
||||
metadata = persisted["metadata"]
|
||||
assert metadata is None
|
||||
|
||||
|
||||
async def test_enqueue_run_rejects_unsupported_attachment_type(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-bad-image",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请看附件"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/gif",
|
||||
"url": "https://signed.example/upload.gif",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-test-bucket",
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.gif",
|
||||
"mimeType": "image/gif",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
url="https://evil.example.com/storage/v1/object/sign/agent-test-bucket/a.png?token=1"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Unsupported attachment type"
|
||||
assert attachment_storage.calls == []
|
||||
assert exc_info.value.detail == "INVALID_BINARY_URL_HOST"
|
||||
|
||||
|
||||
async def test_enqueue_run_rejects_attachment_too_large(
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_run_persists_attachment_and_queue_without_user_token(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(agent_service_module, "_MAX_ATTACHMENT_BYTES", 4)
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-big-image",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请看附件"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/upload.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-test-bucket",
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert exc_info.value.status_code == 413
|
||||
assert exc_info.value.detail == "Attachment too large"
|
||||
assert len(attachment_storage.calls) == 1
|
||||
assert attachment_storage.calls[0]["download"] is True
|
||||
|
||||
|
||||
async def test_enqueue_run_accepts_binary_url_and_persists_metadata() -> None:
|
||||
repository = _FakeRepository()
|
||||
queue = _FakeQueue()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-binary-url",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请分析"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/upload-1.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": config.storage.bucket,
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload-1.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
base_url = str(config.supabase.url).rstrip("/")
|
||||
safe_path = quote(
|
||||
"agent-inputs/00000000-0000-0000-0000-000000000001/"
|
||||
"00000000-0000-0000-0000-000000000001/uploads/a.png"
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
url=f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert accepted.task_id == "task-1"
|
||||
persisted = repository.persisted_user_messages[-1]
|
||||
persisted = repository.persisted_user_messages[0]
|
||||
metadata = persisted["metadata"]
|
||||
assert isinstance(metadata, dict)
|
||||
attachments = metadata.get("attachments")
|
||||
assert isinstance(attachments, list)
|
||||
assert attachments[0]["path"].endswith("upload-1.png")
|
||||
queue_input = queue.commands[-1]["run_input"]
|
||||
assert isinstance(queue_input, dict)
|
||||
content = queue_input["messages"][0]["content"]
|
||||
assert isinstance(content, list)
|
||||
assert content[1]["type"] == "binary"
|
||||
assert content[1]["url"] == "https://signed.example/upload-1.png"
|
||||
attachment = metadata["user_message_attachments"]
|
||||
assert attachment["bucket"] == "agent-test-bucket"
|
||||
command = queue.commands[0]
|
||||
assert "user_token" not in command
|
||||
|
||||
|
||||
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_attachment_signed_url_returns_url(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
|
||||
event = await service.get_history_snapshot(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
before=date(2026, 3, 7),
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
snapshot = event["snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
assert snapshot["scope"] == "history_day"
|
||||
assert snapshot["day"] == "2026-03-06"
|
||||
assert snapshot["messages"][0]["id"] == "m1"
|
||||
|
||||
|
||||
async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
event = await service.get_user_history_snapshot(
|
||||
current_user=_user(),
|
||||
thread_id=None,
|
||||
before=None,
|
||||
)
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
async def test_get_attachment_preview_returns_payload_and_mime() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
@@ -686,120 +219,36 @@ async def test_get_attachment_preview_returns_payload_and_mime() -> None:
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
|
||||
payload, mime_type = await service.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=0,
|
||||
payload = await service.create_attachment_signed_url(
|
||||
bucket="agent-test-bucket",
|
||||
path="agent-inputs/00000000-0000-0000-0000-000000000001/thread-x/uploads/a.png",
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert payload == b"png-bytes"
|
||||
assert mime_type == "image/png"
|
||||
assert payload["bucket"] == "agent-test-bucket"
|
||||
assert payload["path"].endswith("/a.png")
|
||||
assert payload["url"].startswith("https://signed.example/")
|
||||
|
||||
|
||||
async def test_get_attachment_preview_rejects_invalid_path() -> None:
|
||||
class _BadPathRepository(_FakeRepository):
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None:
|
||||
del session_id, message_id, attachment_index
|
||||
return {
|
||||
"bucket": "bucket-test",
|
||||
"path": "agent-inputs/other-user/other-thread/run-1/a.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_attachment_signed_url_rejects_out_of_scope_path(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
service = AgentService(
|
||||
repository=_BadPathRepository(),
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=0,
|
||||
await service.create_attachment_signed_url(
|
||||
bucket="agent-test-bucket",
|
||||
path="agent-inputs/other-user/thread-x/uploads/a.png",
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None:
|
||||
result = SimpleNamespace(
|
||||
status_code=200,
|
||||
message="ok",
|
||||
output={"sentence": {"text": "你好,世界"}},
|
||||
request_id="req-test",
|
||||
)
|
||||
|
||||
class _FakeRecognition:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
def call(self, *, file: str):
|
||||
del file
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||
service = AsrService()
|
||||
|
||||
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||
|
||||
assert transcript == "你好,世界"
|
||||
|
||||
|
||||
async def test_asr_service_parses_sentence_when_result_is_dict(monkeypatch) -> None:
|
||||
result = {
|
||||
"status_code": 200,
|
||||
"message": "ok",
|
||||
"output": {"sentence": {"text": "字典结果"}},
|
||||
"request_id": "req-dict",
|
||||
}
|
||||
|
||||
class _FakeRecognition:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
def call(self, *, file: str):
|
||||
del file
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||
service = AsrService()
|
||||
|
||||
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||
|
||||
assert transcript == "字典结果"
|
||||
|
||||
|
||||
async def test_asr_service_returns_empty_when_sentence_missing(monkeypatch) -> None:
|
||||
result = {
|
||||
"status_code": 200,
|
||||
"message": "ok",
|
||||
"output": {},
|
||||
}
|
||||
|
||||
class _FakeRecognition:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
def call(self, *, file: str):
|
||||
del file
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||
service = AsrService()
|
||||
|
||||
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||
|
||||
assert transcript == ""
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
Reference in New Issue
Block a user