"""Unit tests for ``AgentService`` and ``AsrService`` built on in-memory fakes.

The fakes below stand in for the repository, task queue, event stream and
attachment-storage collaborators that ``AgentService`` is constructed with,
so the service logic can be exercised without real infrastructure.
"""

from __future__ import annotations

from datetime import date
from types import SimpleNamespace
from uuid import UUID

from ag_ui.core import RunAgentInput
from fastapi import HTTPException
import pytest
from sqlalchemy.exc import IntegrityError

from core.auth.models import CurrentUser
from core.config.settings import config
import v1.agent.service as agent_service_module
from v1.agent.service import AgentService, AsrService

# Fixture identifiers. The existing thread happens to reuse the same UUID
# string as the user; the fakes only ever compare against these literals.
_USER_ID = "00000000-0000-0000-0000-000000000001"
_THREAD_ID = "00000000-0000-0000-0000-000000000001"
# A thread id unknown to ``_FakeRepository.get_session_owner``; it is also the
# id the fake repository hands out when asked to create a session without one.
_NEW_THREAD_ID = "00000000-0000-0000-0000-000000000999"
_MESSAGE_ID = "00000000-0000-0000-0000-000000000010"
# Storage path prefix used by the attachment fixtures for _USER_ID/_THREAD_ID.
_ATTACHMENT_PREFIX = f"agent-inputs/{_USER_ID}/{_THREAD_ID}"
# Bucket name patched into ``config.storage.bucket`` by the attachment tests.
_TEST_BUCKET = "agent-test-bucket"


class _FakeRepository:
    """In-memory repository fake that records the calls the service makes.

    Only ``_THREAD_ID`` exists as a session (owned by ``_USER_ID``); any other
    id makes ``get_session_owner`` raise a 404 ``HTTPException``.
    """

    def __init__(self) -> None:
        self.committed = False
        self.rolled_back = False
        self.deleted_session_id: str | None = None
        self.created_with_session_id: str | None = None
        self.persisted_user_messages: list[dict[str, object]] = []

    async def get_session_owner(self, *, session_id: str) -> str:
        if session_id == _THREAD_ID:
            return _USER_ID
        raise HTTPException(status_code=404, detail="Session not found")

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        del user_id
        self.created_with_session_id = session_id
        return session_id or _NEW_THREAD_ID

    async def commit(self) -> None:
        self.committed = True

    async def rollback(self) -> None:
        self.rolled_back = True

    async def delete_session(self, *, session_id: str) -> None:
        self.deleted_session_id = session_id

    async def get_history_day(
        self, *, session_id: str, before: date | None
    ) -> dict[str, object] | None:
        # Exactly one day of history exists (2026-03-06); asking for history
        # before or on that day returns nothing.
        del session_id
        if before is not None and before <= date(2026, 3, 6):
            return None
        return {
            "day": "2026-03-06",
            "hasMore": False,
            "messages": [{"id": "m1", "role": "assistant", "content": "hello"}],
        }

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        del user_id
        return _THREAD_ID

    async def persist_user_message(
        self,
        *,
        session_id: str,
        run_id: str,
        content_text: str,
        metadata: dict[str, object] | None,
    ) -> None:
        self.persisted_user_messages.append(
            {
                "session_id": session_id,
                "run_id": run_id,
                "content_text": content_text,
                "metadata": metadata,
            }
        )

    async def get_message_attachment_reference(
        self,
        *,
        session_id: str,
        message_id: str,
        attachment_index: int,
    ) -> dict[str, str] | None:
        # Every message has exactly one attachment, at index 0.
        del session_id, message_id
        if attachment_index != 0:
            return None
        return {
            "bucket": config.storage.bucket,
            "path": f"{_ATTACHMENT_PREFIX}/run-1/attachment-0-a.png",
            "mimeType": "image/png",
        }


class _FakeQueue:
    """Queue fake that records enqueued commands and always returns ``task-1``."""

    def __init__(self) -> None:
        self.commands: list[dict[str, object]] = []

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        self.commands.append(command)
        del dedup_key
        return "task-1"


class _FailingQueue:
    """Queue fake whose ``enqueue`` always raises ``RuntimeError``."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        del command, dedup_key
        raise RuntimeError("enqueue failed")


class _FakeStream:
    """Event-stream fake returning a single ``RUN_STARTED`` event."""

    async def read(
        self, *, session_id: str, last_event_id: str | None
    ) -> list[dict[str, object]]:
        del session_id
        return [
            {"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
        ]


class _FakeAttachmentStorage:
    """Attachment-storage fake that records every call in ``calls``.

    Downloads always return ``b"png-bytes"``; signed URLs are deterministic.
    """

    def __init__(self) -> None:
        self.calls: list[dict[str, object]] = []

    async def upload_bytes(
        self,
        *,
        bucket: str,
        path: str,
        content: bytes,
        content_type: str,
    ) -> str:
        self.calls.append(
            {
                "bucket": bucket,
                "path": path,
                "content": content,
                "content_type": content_type,
            }
        )
        return path

    async def download_bytes(self, *, bucket: str, path: str) -> bytes:
        self.calls.append(
            {
                "bucket": bucket,
                "path": path,
                "download": True,
            }
        )
        return b"png-bytes"

    async def create_signed_url(
        self,
        *,
        bucket: str,
        path: str,
        expires_in_seconds: int,
    ) -> str:
        self.calls.append(
            {
                "bucket": bucket,
                "path": path,
                "signed": True,
                "expires_in_seconds": expires_in_seconds,
            }
        )
        return f"https://signed.example/{path}?exp={expires_in_seconds}"


class _AlwaysFailAttachmentStorage:
    """Attachment-storage fake whose every operation raises ``RuntimeError``."""

    async def upload_bytes(
        self,
        *,
        bucket: str,
        path: str,
        content: bytes,
        content_type: str,
    ) -> str:
        del bucket, path, content, content_type
        raise RuntimeError("upload failed")

    async def download_bytes(self, *, bucket: str, path: str) -> bytes:
        del bucket, path
        raise RuntimeError("download failed")

    async def create_signed_url(
        self,
        *,
        bucket: str,
        path: str,
        expires_in_seconds: int,
    ) -> str:
        del bucket, path, expires_in_seconds
        raise RuntimeError("sign failed")


def _user() -> CurrentUser:
    """Return the authenticated user that owns ``_THREAD_ID``."""
    return CurrentUser(
        id=UUID(_USER_ID),
        email="user@example.com",
    )


def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
    """Build a minimal run input with a single plain-text user message."""
    return RunAgentInput.model_validate(
        {
            "threadId": thread_id,
            "runId": run_id,
            "state": {},
            "messages": [{"id": "u1", "role": "user", "content": "hello"}],
            "tools": [],
            "context": [],
            "forwardedProps": {},
        }
    )


def _build_attachment_run_input(
    *,
    run_id: str,
    text: str,
    filename: str,
    mime_type: str,
    bucket: str,
) -> RunAgentInput:
    """Build a run input on ``_THREAD_ID`` whose user message has one attachment.

    The same file is referenced twice: as a ``binary`` content part with the
    URL ``https://signed.example/<filename>``, and as a storage reference
    (bucket, ``_ATTACHMENT_PREFIX/<filename>``, mime type) under
    ``forwardedProps.attachments``.
    """
    return RunAgentInput.model_validate(
        {
            "threadId": _THREAD_ID,
            "runId": run_id,
            "state": {},
            "messages": [
                {
                    "id": "u1",
                    "role": "user",
                    "content": [
                        {"type": "text", "text": text},
                        {
                            "type": "binary",
                            "mimeType": mime_type,
                            "url": f"https://signed.example/{filename}",
                        },
                    ],
                }
            ],
            "tools": [],
            "context": [],
            "forwardedProps": {
                "attachments": [
                    {
                        "bucket": bucket,
                        "path": f"{_ATTACHMENT_PREFIX}/{filename}",
                        "mimeType": mime_type,
                    }
                ]
            },
        }
    )


def _patch_recognition(monkeypatch: pytest.MonkeyPatch, result: object) -> None:
    """Stub the ASR ``Recognition`` client and API-key lookup.

    The stub's ``call`` returns *result* unchanged, so each test controls the
    exact response shape ``AsrService`` has to parse.
    """

    class _FakeRecognition:
        def __init__(self, **kwargs: object) -> None:
            del kwargs

        def call(self, *, file: str) -> object:
            del file
            return result

    monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
    monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")


async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
    """Resuming the same run twice yields the same task id."""
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )
    user = _user()
    run_input = _build_run_input(thread_id=_THREAD_ID, run_id="run-1")

    first = await service.enqueue_resume(
        thread_id=_THREAD_ID,
        run_input=run_input,
        current_user=user,
    )
    second = await service.enqueue_resume(
        thread_id=_THREAD_ID,
        run_input=run_input,
        current_user=user,
    )

    # NOTE(review): _FakeQueue.enqueue always returns "task-1", so this holds
    # even without any deduplication; consider also asserting on the dedup
    # key or on how many commands were enqueued.
    assert first.task_id == second.task_id


async def test_enqueue_run_creates_missing_thread_session() -> None:
    """An unknown thread id is created, committed and queued without a token."""
    repository = _FakeRepository()
    queue = _FakeQueue()
    service = AgentService(
        repository=repository,
        queue=queue,
        stream=_FakeStream(),
    )
    run_input = _build_run_input(thread_id=_NEW_THREAD_ID, run_id="run-1")

    accepted = await service.enqueue_run(
        run_input=run_input,
        current_user=_user(),
    )

    assert accepted.thread_id == _NEW_THREAD_ID
    assert accepted.run_id == "run-1"
    assert accepted.created is True
    assert repository.created_with_session_id == _NEW_THREAD_ID
    assert repository.committed is True
    assert queue.commands[0]["user_token"] is None


async def test_enqueue_run_uses_explicit_user_token() -> None:
    """An explicit ``Bearer`` token is queued with the prefix stripped."""
    repository = _FakeRepository()
    queue = _FakeQueue()
    service = AgentService(
        repository=repository,
        queue=queue,
        stream=_FakeStream(),
    )
    run_input = _build_run_input(thread_id=_THREAD_ID, run_id="run-1")

    await service.enqueue_run(
        run_input=run_input,
        current_user=_user(),
        user_token="Bearer access-token-1",
    )

    assert queue.commands
    assert queue.commands[0]["user_token"] == "access-token-1"


async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
    """A queue failure propagates without deleting the newly created session."""
    repository = _FakeRepository()
    service = AgentService(
        repository=repository,
        queue=_FailingQueue(),
        stream=_FakeStream(),
    )
    run_input = _build_run_input(thread_id=_NEW_THREAD_ID, run_id="run-1")

    with pytest.raises(RuntimeError) as exc_info:
        await service.enqueue_run(
            run_input=run_input,
            current_user=_user(),
        )

    # Exact comparison (not pytest's regex ``match``) mirrors the original check.
    assert str(exc_info.value) == "enqueue failed"
    assert repository.deleted_session_id is None


async def test_enqueue_run_handles_session_create_race() -> None:
    """A concurrent insert (IntegrityError) is rolled back and the session reused."""

    class _RaceRepository(_FakeRepository):
        """Session is missing on first lookup; creating it hits a duplicate key."""

        def __init__(self) -> None:
            super().__init__()
            self.create_calls = 0

        async def get_session_owner(self, *, session_id: str) -> str:
            del session_id
            # Before the failed insert the session looks missing; afterwards it
            # exists, as if another request had created it concurrently.
            if self.create_calls == 0:
                raise HTTPException(status_code=404, detail="Session not found")
            return _USER_ID

        async def create_session_for_user(
            self, *, user_id: str, session_id: str | None = None
        ) -> str:
            del user_id, session_id
            self.create_calls += 1
            raise IntegrityError("insert", {}, Exception("duplicate key"))

    repository = _RaceRepository()
    service = AgentService(
        repository=repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )
    run_input = _build_run_input(thread_id=_NEW_THREAD_ID, run_id="run-1")

    accepted = await service.enqueue_run(
        run_input=run_input,
        current_user=_user(),
    )

    assert accepted.created is False
    assert repository.rolled_back is True


async def test_enqueue_run_uses_forwarded_attachments_and_injects_metadata(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Forwarded attachments are downloaded once and persisted as metadata."""
    monkeypatch.setattr(agent_service_module.config.storage, "bucket", _TEST_BUCKET)
    repository = _FakeRepository()
    attachment_storage = _FakeAttachmentStorage()
    service = AgentService(
        repository=repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=attachment_storage,
    )
    run_input = _build_attachment_run_input(
        run_id="run-with-image",
        text="帮我看下这张图",
        filename="upload.png",
        mime_type="image/png",
        bucket=_TEST_BUCKET,
    )

    accepted = await service.enqueue_run(run_input=run_input, current_user=_user())

    assert accepted.task_id == "task-1"
    assert len(attachment_storage.calls) == 1
    download = attachment_storage.calls[0]
    assert download["bucket"] == _TEST_BUCKET
    assert download["download"] is True

    assert repository.persisted_user_messages
    persisted = repository.persisted_user_messages[0]
    assert persisted["session_id"] == _THREAD_ID
    assert persisted["run_id"] == "run-with-image"
    metadata = persisted["metadata"]
    assert isinstance(metadata, dict)
    attachments = metadata.get("attachments")
    assert isinstance(attachments, list)
    assert attachments and isinstance(attachments[0], dict)
    assert attachments[0]["bucket"] == _TEST_BUCKET
    assert isinstance(attachments[0]["path"], str)


async def test_enqueue_run_raises_when_attachment_download_fails_without_fallback(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A storage download failure surfaces as 502 and nothing is persisted."""
    monkeypatch.setattr(agent_service_module.config.storage, "bucket", _TEST_BUCKET)
    repository = _FakeRepository()
    service = AgentService(
        repository=repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_AlwaysFailAttachmentStorage(),
    )
    run_input = _build_attachment_run_input(
        run_id="run-with-image-fail",
        text="帮我看下这张图",
        filename="upload.png",
        mime_type="image/png",
        bucket=_TEST_BUCKET,
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.enqueue_run(run_input=run_input, current_user=_user())

    assert exc_info.value.status_code == 502
    assert exc_info.value.detail == "Failed to fetch attachment"
    assert repository.persisted_user_messages == []


async def test_enqueue_run_rejects_unsupported_attachment_type(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A GIF attachment is rejected with 422 before storage is touched."""
    monkeypatch.setattr(agent_service_module.config.storage, "bucket", _TEST_BUCKET)
    repository = _FakeRepository()
    attachment_storage = _FakeAttachmentStorage()
    service = AgentService(
        repository=repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=attachment_storage,
    )
    run_input = _build_attachment_run_input(
        run_id="run-with-bad-image",
        text="请看附件",
        filename="upload.gif",
        mime_type="image/gif",
        bucket=_TEST_BUCKET,
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.enqueue_run(run_input=run_input, current_user=_user())

    assert exc_info.value.status_code == 422
    assert exc_info.value.detail == "Unsupported attachment type"
    assert attachment_storage.calls == []


async def test_enqueue_run_rejects_attachment_too_large(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Downloaded bytes exceeding ``_MAX_ATTACHMENT_BYTES`` are rejected with 413."""
    # The fake download returns 9 bytes, comfortably over this 4-byte limit.
    monkeypatch.setattr(agent_service_module, "_MAX_ATTACHMENT_BYTES", 4)
    monkeypatch.setattr(agent_service_module.config.storage, "bucket", _TEST_BUCKET)
    repository = _FakeRepository()
    attachment_storage = _FakeAttachmentStorage()
    service = AgentService(
        repository=repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=attachment_storage,
    )
    run_input = _build_attachment_run_input(
        run_id="run-with-big-image",
        text="请看附件",
        filename="upload.png",
        mime_type="image/png",
        bucket=_TEST_BUCKET,
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.enqueue_run(run_input=run_input, current_user=_user())

    assert exc_info.value.status_code == 413
    assert exc_info.value.detail == "Attachment too large"
    assert len(attachment_storage.calls) == 1
    assert attachment_storage.calls[0]["download"] is True


async def test_enqueue_run_accepts_binary_url_and_persists_metadata() -> None:
    """The binary URL is forwarded untouched and the storage path persisted."""
    repository = _FakeRepository()
    queue = _FakeQueue()
    attachment_storage = _FakeAttachmentStorage()
    service = AgentService(
        repository=repository,
        queue=queue,
        stream=_FakeStream(),
        attachment_storage=attachment_storage,
    )
    run_input = _build_attachment_run_input(
        run_id="run-with-binary-url",
        text="请分析",
        filename="upload-1.png",
        mime_type="image/png",
        bucket=config.storage.bucket,
    )

    accepted = await service.enqueue_run(run_input=run_input, current_user=_user())

    assert accepted.task_id == "task-1"
    persisted = repository.persisted_user_messages[-1]
    metadata = persisted["metadata"]
    assert isinstance(metadata, dict)
    attachments = metadata.get("attachments")
    assert isinstance(attachments, list)
    assert attachments[0]["path"].endswith("upload-1.png")

    queue_input = queue.commands[-1]["run_input"]
    assert isinstance(queue_input, dict)
    content = queue_input["messages"][0]["content"]
    assert isinstance(content, list)
    assert content[1]["type"] == "binary"
    assert content[1]["url"] == "https://signed.example/upload-1.png"


async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
    """A history day is returned wrapped in a ``STATE_SNAPSHOT`` event."""
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    event = await service.get_history_snapshot(
        thread_id=_THREAD_ID,
        before=date(2026, 3, 7),
        current_user=_user(),
    )

    assert event["type"] == "STATE_SNAPSHOT"
    assert event["threadId"] == _THREAD_ID
    snapshot = event["snapshot"]
    assert isinstance(snapshot, dict)
    assert snapshot["scope"] == "history_day"
    assert snapshot["day"] == "2026-03-06"
    assert snapshot["messages"][0]["id"] == "m1"


async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> None:
    """Without a thread id, the user's latest session is used."""
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    event = await service.get_user_history_snapshot(
        current_user=_user(),
        thread_id=None,
        before=None,
    )

    assert event["type"] == "STATE_SNAPSHOT"
    assert event["threadId"] == _THREAD_ID


async def test_get_attachment_preview_returns_payload_and_mime() -> None:
    """A valid attachment reference yields the stored bytes and mime type."""
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )

    payload, mime_type = await service.get_attachment_preview(
        thread_id=_THREAD_ID,
        message_id=_MESSAGE_ID,
        attachment_index=0,
        current_user=_user(),
    )

    assert payload == b"png-bytes"
    assert mime_type == "image/png"


async def test_get_attachment_preview_rejects_invalid_path() -> None:
    """An attachment path outside the caller's user/thread prefix is a 403."""

    class _BadPathRepository(_FakeRepository):
        async def get_message_attachment_reference(
            self,
            *,
            session_id: str,
            message_id: str,
            attachment_index: int,
        ) -> dict[str, str] | None:
            del session_id, message_id, attachment_index
            # Use the configured bucket (as the happy-path fake does) so the
            # path is the only invalid part of the reference.
            return {
                "bucket": config.storage.bucket,
                "path": "agent-inputs/other-user/other-thread/run-1/a.png",
                "mimeType": "image/png",
            }

    service = AgentService(
        repository=_BadPathRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.get_attachment_preview(
            thread_id=_THREAD_ID,
            message_id=_MESSAGE_ID,
            attachment_index=0,
            current_user=_user(),
        )

    assert exc_info.value.status_code == 403


async def test_asr_service_parses_dict_output_sentence(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An attribute-style result with a dict ``output`` is parsed."""
    result = SimpleNamespace(
        status_code=200,
        message="ok",
        output={"sentence": {"text": "你好,世界"}},
        request_id="req-test",
    )
    _patch_recognition(monkeypatch, result)

    service = AsrService()
    transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")

    assert transcript == "你好,世界"


async def test_asr_service_parses_sentence_when_result_is_dict(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A plain-dict recognition result is parsed."""
    result = {
        "status_code": 200,
        "message": "ok",
        "output": {"sentence": {"text": "字典结果"}},
        "request_id": "req-dict",
    }
    _patch_recognition(monkeypatch, result)

    service = AsrService()
    transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")

    assert transcript == "字典结果"


async def test_asr_service_returns_empty_when_sentence_missing(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A result without a ``sentence`` yields an empty transcript."""
    result = {
        "status_code": 200,
        "message": "ok",
        "output": {},
    }
    _patch_recognition(monkeypatch, result)

    service = AsrService()
    transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")

    assert transcript == ""