# Source path: social-app/backend/tests/unit/v1/agent/test_service.py
# Viewer snapshot metadata: T, 2026-03-19 00:52:12 +08:00, 385 lines, 12 KiB, Python

from __future__ import annotations
from datetime import date
from typing import cast
from urllib.parse import quote
from uuid import UUID
from ag_ui.core import RunAgentInput
from fastapi import HTTPException
import pytest
import v1.agent.service as agent_service_module
from core.auth.models import CurrentUser
from core.config.settings import config
from schemas.messages.chat_message import AgentChatMessageMetadata
from v1.agent.service import AgentService
class _FakeRepository:
    """In-memory repository double used by the ``AgentService`` tests.

    Only session ``...0001`` is known (and it is owned by ``...0001``);
    sessions created without an explicit id get ``...0999``. ``commit`` flips
    ``committed`` and every persisted user message is captured in
    ``persisted_user_messages`` so tests can inspect it.
    """

    def __init__(self) -> None:
        self.persisted_user_messages: list[dict[str, object]] = []
        self.committed = False

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owner of the single known session; 404 for anything else."""
        known_session = "00000000-0000-0000-0000-000000000001"
        if session_id != known_session:
            raise HTTPException(status_code=404, detail="Session not found")
        return "00000000-0000-0000-0000-000000000001"

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Echo a truthy caller-supplied session id, else return a fixed new id."""
        _ = user_id  # ownership is not simulated
        if session_id:
            return session_id
        return "00000000-0000-0000-0000-000000000999"

    async def commit(self) -> None:
        """Record that the service committed its unit of work."""
        self.committed = True

    async def rollback(self) -> None:
        """No-op rollback."""

    async def get_history_day(
        self, *, session_id: str, before: date | None
    ) -> dict[str, object] | None:
        """Report that no history exists (subclasses override to serve some)."""
        _ = (session_id, before)
        return None

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Report that the user has no previous session."""
        _ = user_id
        return None

    async def persist_user_message(
        self,
        *,
        session_id: str,
        content: str,
        metadata: AgentChatMessageMetadata | None,
    ) -> None:
        """Capture the persisted user message for later assertions."""
        record: dict[str, object] = dict(
            session_id=session_id,
            content=content,
            metadata=metadata,
        )
        self.persisted_user_messages.append(record)
class _FakeQueue:
def __init__(self) -> None:
self.commands: list[dict[str, object]] = []
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
del dedup_key
self.commands.append(command)
return "task-1"
class _FakeStream:
async def read(
self, *, session_id: str, last_event_id: str | None
) -> list[dict[str, object]]:
del session_id, last_event_id
return []
class _FakeAttachmentStorage:
    """In-memory stand-in for the attachment storage backend.

    Uploads echo the target path, downloads return empty bytes, and signed
    URLs point at the fixed host ``https://signed.example/``.
    ``parse_signed_url`` accepts URLs of the form
    ``.../storage/v1/object/sign/<bucket>/<path>[?query]``.
    """

    async def upload_bytes(
        self,
        *,
        bucket: str,
        path: str,
        content: bytes,
        content_type: str,
    ) -> str:
        """Pretend to upload ``content`` and return the object path unchanged."""
        del bucket, content, content_type
        return path

    async def download_bytes(self, *, bucket: str, path: str) -> bytes:
        """Return empty content for any object."""
        del bucket, path
        return b""

    async def create_signed_url(
        self,
        *,
        bucket: str,
        path: str,
        expires_in_seconds: int,
    ) -> str:
        """Return a deterministic fake signed URL for ``bucket``/``path``."""
        del expires_in_seconds
        return f"https://signed.example/{bucket}/{path}"

    def parse_signed_url(self, url: str) -> tuple[str, str]:
        """Split a storage signed URL into ``(bucket, path)``.

        Args:
            url: A URL like ``<host>/storage/v1/object/sign/<bucket>/<path>?token=...``.

        Returns:
            The bucket name and the object path (query string removed; the
            path is returned as it appears in the URL, without unquoting).

        Raises:
            RuntimeError: If the URL lacks exactly one sign marker, or the
                bucket or object path segment is missing or empty.
        """
        # Drop the query string (the signing token) first, so a "/" or the
        # marker inside the query can never leak into the bucket or path.
        without_query = url.split("?", 1)[0]
        parts = without_query.split("/storage/v1/object/sign/")
        if len(parts) != 2:
            raise RuntimeError("invalid")
        bucket, separator, path = parts[1].partition("/")
        # A missing object path would otherwise surface as a ValueError from
        # tuple unpacking; report every malformed URL the same way.
        if not separator or not bucket or not path:
            raise RuntimeError("invalid")
        return bucket, path
def _user() -> CurrentUser:
    """Return the authenticated test user (id ``...0001``, which owns the fixture session)."""
    owner_id = UUID("00000000-0000-0000-0000-000000000001")
    return CurrentUser(id=owner_id, email="user@example.com")
def _build_run_input(*, urls: list[str]) -> RunAgentInput:
    """Build a single-user-message ``RunAgentInput`` for the fixture thread.

    The message content is a ``"hello"`` text part followed by one
    ``image/png`` binary part per entry of ``urls``, in order.
    """
    text_part = {"type": "text", "text": "hello"}
    binary_parts = [
        {"type": "binary", "mimeType": "image/png", "url": url} for url in urls
    ]
    user_message = {
        "id": "u1",
        "role": "user",
        "content": [text_part, *binary_parts],
    }
    payload = {
        "threadId": "00000000-0000-0000-0000-000000000001",
        "runId": "run-1",
        "state": {},
        "messages": [user_message],
        "tools": [],
        "context": [],
        "forwardedProps": {},
    }
    return RunAgentInput.model_validate(payload)
@pytest.mark.asyncio
async def test_enqueue_run_rejects_non_project_host_signed_url(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A signed URL on a foreign host is rejected with INVALID_BINARY_URL_HOST."""
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    # Correct bucket and path shape, but the host is not the project's.
    foreign_url = (
        "https://evil.example.com/storage/v1/object/sign/"
        "agent-test-bucket/a.png?token=1"
    )
    run_input = _build_run_input(urls=[foreign_url])

    with pytest.raises(HTTPException) as exc_info:
        await service.enqueue_run(run_input=run_input, current_user=_user())

    error = exc_info.value
    assert error.status_code == 422
    assert error.detail == "INVALID_BINARY_URL_HOST"
@pytest.mark.asyncio
async def test_enqueue_run_persists_attachment_and_queue_without_user_token(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Attachments are persisted with the user message; the queued command has no user token.

    Two signed URLs on the project host, with paths under
    ``agent-inputs/<test user id>/...``, are submitted. The persisted message
    metadata must list both attachments in the configured bucket. The
    command handed to the queue must not carry ``user_token``, and must
    echo the thread and run ids.
    """
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )
    repository = _FakeRepository()
    queue = _FakeQueue()
    service = AgentService(
        repository=repository,
        queue=queue,
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    base_url = str(config.supabase.url).rstrip("/")
    safe_path = quote(
        "agent-inputs/00000000-0000-0000-0000-000000000001/"
        "00000000-0000-0000-0000-000000000001/uploads/a.png"
    )
    safe_path_two = quote(
        "agent-inputs/00000000-0000-0000-0000-000000000001/"
        "00000000-0000-0000-0000-000000000001/uploads/b.png"
    )
    run_input = _build_run_input(
        urls=[
            f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1",
            f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path_two}?token=1",
        ]
    )

    accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
    assert accepted.task_id == "task-1"

    persisted = repository.persisted_user_messages[0]
    metadata = cast(AgentChatMessageMetadata | None, persisted["metadata"])
    assert metadata is not None
    attachments = metadata.user_message_attachments
    assert attachments is not None
    assert len(attachments) == 2
    # Check every attachment (not just the first) landed in the configured bucket.
    assert all(attachment.bucket == "agent-test-bucket" for attachment in attachments)

    command = queue.commands[0]
    assert "user_token" not in command
    # Use a distinct name: rebinding ``run_input`` (a RunAgentInput) to this
    # ``object`` value is an incompatible assignment for type checkers.
    queued_run_input = command["run_input"]
    assert isinstance(queued_run_input, dict)
    assert queued_run_input["threadId"] == "00000000-0000-0000-0000-000000000001"
    assert queued_run_input["runId"] == "run-1"
@pytest.mark.asyncio
async def test_create_attachment_signed_url_returns_url(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Signing an object under the user's own prefix returns bucket, path and URL."""
    bucket = "agent-test-bucket"
    monkeypatch.setattr(agent_service_module.config.storage, "bucket", bucket)
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    in_scope_path = (
        "agent-inputs/00000000-0000-0000-0000-000000000001/thread-x/uploads/a.png"
    )

    payload = await service.create_attachment_signed_url(
        bucket=bucket,
        path=in_scope_path,
        current_user=_user(),
    )

    assert payload["bucket"] == bucket
    assert payload["path"].endswith("/a.png")
    assert payload["url"].startswith("https://signed.example/")
@pytest.mark.asyncio
async def test_create_attachment_signed_url_rejects_out_of_scope_path(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Signing an object under another user's prefix is rejected with a 422."""
    bucket = "agent-test-bucket"
    monkeypatch.setattr(agent_service_module.config.storage, "bucket", bucket)
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    foreign_path = "agent-inputs/other-user/thread-x/uploads/a.png"

    with pytest.raises(HTTPException) as exc_info:
        await service.create_attachment_signed_url(
            bucket=bucket,
            path=foreign_path,
            current_user=_user(),
        )

    assert exc_info.value.status_code == 422
@pytest.mark.asyncio
async def test_enqueue_run_rejects_too_many_attachments(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Submitting four image attachments in one message is rejected with a 422."""
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    base_url = str(config.supabase.url).rstrip("/")
    upload_dir = (
        "agent-inputs/00000000-0000-0000-0000-000000000001/"
        "00000000-0000-0000-0000-000000000001/uploads/"
    )
    sign_prefix = f"{base_url}/storage/v1/object/sign/agent-test-bucket/"
    # Four otherwise-valid attachments; presumably one above the service's
    # per-message limit (NOTE(review): confirm against the configured maximum).
    urls = [
        f"{sign_prefix}{quote(upload_dir + file_name)}?token=1"
        for file_name in ("a.png", "b.png", "c.png", "d.png")
    ]
    run_input = _build_run_input(urls=urls)

    with pytest.raises(HTTPException) as exc_info:
        await service.enqueue_run(run_input=run_input, current_user=_user())

    error = exc_info.value
    assert error.status_code == 422
    assert error.detail == "Too many attachments"
@pytest.mark.asyncio
async def test_get_history_snapshot_filters_out_tool_messages() -> None:
    """History snapshots keep user/assistant turns and drop ``tool`` rows."""

    class _HistoryRepository(_FakeRepository):
        """Repository double serving one fixed day: user -> tool -> assistant."""

        async def get_history_day(
            self, *, session_id: str, before: date | None
        ) -> dict[str, object] | None:
            _ = (session_id, before)
            # User asks (in Chinese) to check today's schedule.
            user_turn = {
                "id": "00000000-0000-0000-0000-000000000111",
                "seq": 1,
                "role": "user",
                "content": "帮我查一下今天日程",
                "metadata": None,
                "timestamp": "2026-03-17T09:00:00Z",
            }
            # Tool result ("fetched the schedule list, 3 entries in total");
            # this row is expected to be filtered out of the snapshot.
            tool_turn = {
                "id": "00000000-0000-0000-0000-000000000112",
                "seq": 2,
                "role": "tool",
                "content": "已获取日程列表,共 3 条",
                "metadata": {
                    "run_id": "run-1",
                    "tool_agent_output": {
                        "tool_name": "calendar_read",
                        "tool_call_id": "call-1",
                        "status": "success",
                        "result": "status=success total=3 returned=3",
                    },
                },
                "timestamp": "2026-03-17T09:00:01Z",
            }
            # Assistant answer ("there are 3 events today").
            assistant_turn = {
                "id": "00000000-0000-0000-0000-000000000113",
                "seq": 3,
                "role": "assistant",
                "content": "今天共有 3 条日程。",
                "metadata": {
                    "run_id": "run-1",
                    "agent_output": {
                        "status": "success",
                        "answer": "今天共有 3 条日程。",
                        "key_points": [],
                        "result_type": "summary",
                        "suggested_actions": [],
                    },
                },
                "timestamp": "2026-03-17T09:00:02Z",
            }
            return {
                "day": "2026-03-17",
                "hasMore": False,
                "messages": [user_turn, tool_turn, assistant_turn],
            }

    service = AgentService(
        repository=_HistoryRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )

    snapshot = await service.get_history_snapshot(
        thread_id="00000000-0000-0000-0000-000000000001",
        before=None,
        current_user=_user(),
    )

    visible_roles = [message.role for message in snapshot.messages]
    assert visible_roles == ["user", "assistant"]