314 lines
9.5 KiB
Python
314 lines
9.5 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import date
|
|
from typing import cast
|
|
from urllib.parse import quote
|
|
from uuid import UUID
|
|
|
|
from ag_ui.core import RunAgentInput
|
|
from fastapi import HTTPException
|
|
import pytest
|
|
|
|
import v1.agent.service as agent_service_module
|
|
from core.auth.models import CurrentUser
|
|
from core.config.settings import config
|
|
from schemas.messages.chat_message import AgentChatMessageMetadata
|
|
from v1.agent.service import AgentService
|
|
|
|
|
|
class _FakeRepository:
|
|
def __init__(self) -> None:
|
|
self.committed = False
|
|
self.persisted_user_messages: list[dict[str, object]] = []
|
|
|
|
async def get_session_owner(self, *, session_id: str) -> str:
|
|
if session_id == "00000000-0000-0000-0000-000000000001":
|
|
return "00000000-0000-0000-0000-000000000001"
|
|
raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
async def create_session_for_user(
|
|
self, *, user_id: str, session_id: str | None = None
|
|
) -> str:
|
|
del user_id
|
|
return session_id or "00000000-0000-0000-0000-000000000999"
|
|
|
|
async def commit(self) -> None:
|
|
self.committed = True
|
|
|
|
async def rollback(self) -> None:
|
|
return None
|
|
|
|
async def get_history_day(
|
|
self, *, session_id: str, before: date | None
|
|
) -> dict[str, object] | None:
|
|
del session_id, before
|
|
return None
|
|
|
|
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
|
del user_id
|
|
return None
|
|
|
|
async def persist_user_message(
|
|
self,
|
|
*,
|
|
session_id: str,
|
|
content: str,
|
|
metadata: AgentChatMessageMetadata | None,
|
|
) -> None:
|
|
self.persisted_user_messages.append(
|
|
{
|
|
"session_id": session_id,
|
|
"content": content,
|
|
"metadata": metadata,
|
|
}
|
|
)
|
|
|
|
|
|
class _FakeQueue:
|
|
def __init__(self) -> None:
|
|
self.commands: list[dict[str, object]] = []
|
|
|
|
async def enqueue(
|
|
self, *, command: dict[str, object], dedup_key: str | None
|
|
) -> str:
|
|
del dedup_key
|
|
self.commands.append(command)
|
|
return "task-1"
|
|
|
|
|
|
class _FakeStream:
|
|
async def read(
|
|
self, *, session_id: str, last_event_id: str | None
|
|
) -> list[dict[str, object]]:
|
|
del session_id, last_event_id
|
|
return []
|
|
|
|
|
|
class _FakeAttachmentStorage:
|
|
async def upload_bytes(
|
|
self,
|
|
*,
|
|
bucket: str,
|
|
path: str,
|
|
content: bytes,
|
|
content_type: str,
|
|
) -> str:
|
|
del bucket, content, content_type
|
|
return path
|
|
|
|
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
|
del bucket, path
|
|
return b""
|
|
|
|
async def create_signed_url(
|
|
self,
|
|
*,
|
|
bucket: str,
|
|
path: str,
|
|
expires_in_seconds: int,
|
|
) -> str:
|
|
del expires_in_seconds
|
|
return f"https://signed.example/{bucket}/{path}"
|
|
|
|
def parse_signed_url(self, url: str) -> tuple[str, str]:
|
|
parsed = url.split("/storage/v1/object/sign/")
|
|
if len(parsed) != 2:
|
|
raise RuntimeError("invalid")
|
|
bucket, path = parsed[1].split("/", 1)
|
|
path = path.split("?", 1)[0]
|
|
return bucket, path
|
|
|
|
|
|
def _user() -> CurrentUser:
    """Return the authenticated test user (id ``...0001``)."""
    user_id = UUID("00000000-0000-0000-0000-000000000001")
    return CurrentUser(id=user_id, email="user@example.com")
|
|
|
|
|
|
def _build_run_input(*, urls: list[str]) -> RunAgentInput:
    """Build a ``RunAgentInput`` holding one user message on thread ``...0001``.

    The message content is a ``"hello"`` text part followed by one
    ``image/png`` binary part per entry of *urls*, in the given order.
    """
    parts: list[dict[str, str]] = [{"type": "text", "text": "hello"}]
    parts.extend(
        {"type": "binary", "mimeType": "image/png", "url": url} for url in urls
    )
    user_message = {"id": "u1", "role": "user", "content": parts}
    payload = {
        "threadId": "00000000-0000-0000-0000-000000000001",
        "runId": "run-1",
        "state": {},
        "messages": [user_message],
        "tools": [],
        "context": [],
        "forwardedProps": {},
    }
    return RunAgentInput.model_validate(payload)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enqueue_run_rejects_non_project_host_signed_url(monkeypatch) -> None:
    """A signed URL served from a foreign host is rejected with 422.

    The bucket and object path look legitimate, but the host is not the
    project host. The service must therefore fail with
    ``INVALID_BINARY_URL_HOST``.
    """
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )
    foreign_url = (
        "https://evil.example.com/storage/v1/object/sign/"
        "agent-test-bucket/a.png?token=1"
    )
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    run_input = _build_run_input(urls=[foreign_url])

    with pytest.raises(HTTPException) as exc_info:
        await service.enqueue_run(run_input=run_input, current_user=_user())

    error = exc_info.value
    assert error.status_code == 422
    assert error.detail == "INVALID_BINARY_URL_HOST"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enqueue_run_persists_attachment_and_queue_without_user_token(
    monkeypatch,
) -> None:
    """Two project-hosted image URLs are persisted as message attachments.

    The queued command carries the serialized run input but no
    ``user_token``.
    """
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )
    repository = _FakeRepository()
    queue = _FakeQueue()
    service = AgentService(
        repository=repository,
        queue=queue,
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    base_url = str(config.supabase.url).rstrip("/")
    # Objects live under agent-inputs/<user_id>/<thread_id>/uploads/.
    upload_prefix = (
        "agent-inputs/00000000-0000-0000-0000-000000000001/"
        "00000000-0000-0000-0000-000000000001/uploads/"
    )
    signed_urls = [
        f"{base_url}/storage/v1/object/sign/agent-test-bucket/"
        f"{quote(upload_prefix + filename)}?token=1"
        for filename in ("a.png", "b.png")
    ]
    run_input = _build_run_input(urls=signed_urls)

    accepted = await service.enqueue_run(run_input=run_input, current_user=_user())

    assert accepted.task_id == "task-1"

    # Fail with a readable message rather than an IndexError if nothing
    # was persisted.
    assert repository.persisted_user_messages, "user message was not persisted"
    persisted = repository.persisted_user_messages[0]
    metadata = cast(AgentChatMessageMetadata | None, persisted["metadata"])
    assert metadata is not None
    attachments = metadata.user_message_attachments
    assert attachments is not None
    assert len(attachments) == 2
    assert all(attachment.bucket == "agent-test-bucket" for attachment in attachments)

    assert queue.commands, "no command was enqueued"
    command = queue.commands[0]
    assert "user_token" not in command
    # A distinct name, so the RunAgentInput bound above is not rebound to a dict.
    queued_run_input = command["run_input"]
    assert isinstance(queued_run_input, dict)
    assert queued_run_input["threadId"] == "00000000-0000-0000-0000-000000000001"
    assert queued_run_input["runId"] == "run-1"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_attachment_signed_url_returns_url(monkeypatch) -> None:
    """The user gets a signed URL for a path under their own upload prefix.

    The response echoes the bucket and path alongside the URL.
    """
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    owned_path = (
        "agent-inputs/00000000-0000-0000-0000-000000000001/thread-x/uploads/a.png"
    )

    payload = await service.create_attachment_signed_url(
        bucket="agent-test-bucket", path=owned_path, current_user=_user()
    )

    assert payload["bucket"] == "agent-test-bucket"
    assert payload["path"].endswith("/a.png")
    assert payload["url"].startswith("https://signed.example/")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_attachment_signed_url_rejects_out_of_scope_path(
    monkeypatch,
) -> None:
    """Signing a path under another user's prefix is rejected with 422."""
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    foreign_path = "agent-inputs/other-user/thread-x/uploads/a.png"

    with pytest.raises(HTTPException) as exc_info:
        await service.create_attachment_signed_url(
            bucket="agent-test-bucket", path=foreign_path, current_user=_user()
        )

    assert exc_info.value.status_code == 422
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enqueue_run_rejects_too_many_attachments(monkeypatch) -> None:
    """A run carrying four valid attachments exceeds the cap and is rejected."""
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    base_url = str(config.supabase.url).rstrip("/")
    # Every URL is individually valid, so only the attachment count can fail.
    upload_prefix = (
        "agent-inputs/00000000-0000-0000-0000-000000000001/"
        "00000000-0000-0000-0000-000000000001/uploads/"
    )
    run_input = _build_run_input(
        urls=[
            f"{base_url}/storage/v1/object/sign/agent-test-bucket/"
            f"{quote(upload_prefix + filename)}?token=1"
            for filename in ("a.png", "b.png", "c.png", "d.png")
        ]
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.enqueue_run(run_input=run_input, current_user=_user())

    error = exc_info.value
    assert error.status_code == 422
    assert error.detail == "Too many attachments"
|