127 lines
4.1 KiB
Python
127 lines
4.1 KiB
Python
from __future__ import annotations

from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4

import pytest
from fastapi import HTTPException

from core.auth.models import CurrentUser
from models.agent_chat_session import (
    AgentChatSession,
    AgentChatSessionStatus,
    SessionType,
)
from v1.agent.schemas import RunAgentInput
from v1.agent.service import AgentChatService
|
|
|
|
|
|
class FakeAsyncSession:
    """In-memory stand-in for the database session passed to ``AgentChatService``.

    Only ``execute``, ``scalar`` and ``commit`` are provided. Statements are
    NOT interpreted: both query methods ignore the statement they receive and
    yield the first seeded session, or ``None`` when none were seeded.
    NOTE(review): the real session is presumably an async SQLAlchemy session
    (hence the ``type: ignore[arg-type]`` at call sites) — confirm.
    """

    class _Result:
        """Query-result wrapper exposing ``scalar_one_or_none``."""

        def __init__(self, session_obj: AgentChatSession | None) -> None:
            self._session_obj = session_obj

        def scalar_one_or_none(self) -> AgentChatSession | None:
            """Return the wrapped session (possibly ``None``)."""
            return self._session_obj

    def __init__(self, sessions: list[AgentChatSession]) -> None:
        # Keyed by id; dict insertion order decides which session is "first".
        self._sessions = {session.id: session for session in sessions}
        # Set by ``commit`` so tests can assert that a commit was requested.
        self.commit_called = False

    def _first_session(self) -> AgentChatSession | None:
        """Return the first seeded session, or ``None`` if there are none."""
        return next(iter(self._sessions.values()), None)

    async def execute(self, stmt: object) -> FakeAsyncSession._Result:
        """Ignore *stmt* and wrap the first seeded session in a result object."""
        return self._Result(self._first_session())

    async def scalar(self, stmt: object) -> AgentChatSession | None:
        """Ignore *stmt* and return the first seeded session directly."""
        return self._first_session()

    async def commit(self) -> None:
        """Record that a commit was requested; nothing is persisted."""
        self.commit_called = True
|
|
|
|
|
|
def _build_input(
    run_id: str,
    *,
    thread_id: str = "t1",
    interrupt_id: str = "int-1",
    decision: str = "approved",
) -> RunAgentInput:
    """Build a resume-carrying ``RunAgentInput`` for the tests in this module.

    State, messages, tools, context and forwarded props are all empty; the
    ``resume`` section answers interrupt *interrupt_id* with *decision*.

    Args:
        run_id: Value placed in ``runId``.
        thread_id: Value placed in ``threadId`` (defaults to ``"t1"``, the
            thread id used in the fixtures' ``run_context``).
        interrupt_id: Interrupt being answered (defaults to ``"int-1"``, the
            fixtures' pending ``interrupt_id``).
        decision: Decision sent in the resume payload.

    Returns:
        The validated ``RunAgentInput``.
    """
    # Keys are the camelCase wire names expected by ``model_validate``.
    payload = {
        "threadId": thread_id,
        "runId": run_id,
        "state": {},
        "messages": [],
        "tools": [],
        "context": [],
        "forwardedProps": {},
        "resume": {"interruptId": interrupt_id, "payload": {"decision": decision}},
    }
    return RunAgentInput.model_validate(payload)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_resume_rejects_non_owner_session() -> None:
    """A user who does not own the session cannot resume it (HTTP 404).

    NOTE(review): 404 rather than 403 presumably avoids revealing that the
    session exists; confirm against ``AgentChatService.prepare_resume``.
    """
    owner_id = uuid4()
    intruder_id = UUID("00000000-0000-0000-0000-000000000001")
    now = datetime.now(timezone.utc)
    session = AgentChatSession(
        id=uuid4(),
        user_id=owner_id,
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
        state_snapshot={
            "version": 2,
            "pending_tool_call": {
                "interrupt_id": "int-1",
                "tool_name": "srv.transfer_funds",
                "tool_args": {"to": "u2", "amount": 100},
                "status": "PENDING_APPROVAL",
                # Keep the approval window open so ownership is the only
                # reason to reject. An expiry of "now" is already in the past
                # by the time the service looks at it and could produce a 410
                # instead, depending on the order of the service's checks.
                "expires_at": (now + timedelta(hours=1)).isoformat(),
                "decision": None,
                "result": None,
                "updated_at": now.isoformat(),
            },
            "run_context": {"thread_id": "t1", "run_id": str(uuid4())},
        },
    )
    service = AgentChatService(
        session=FakeAsyncSession([session]),  # type: ignore[arg-type]
        current_user=CurrentUser(id=intruder_id),
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.prepare_resume(str(session.id), _build_input(str(session.id)))

    assert exc_info.value.status_code == 404
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_prepare_resume_commits_expired_state_before_410() -> None:
    """Resuming an expired approval fails with 410 after committing state.

    The owner resumes a session whose pending tool call expired long ago; the
    service must reject with HTTP 410 and must have called ``commit`` first.
    """
    owner_id = UUID("00000000-0000-0000-0000-000000000001")

    pending_tool_call = {
        "interrupt_id": "int-1",
        "tool_name": "srv.transfer_funds",
        "tool_args": {"to": "u2", "amount": 100},
        "status": "PENDING_APPROVAL",
        # Fixed date far in the past: the approval window has long closed.
        "expires_at": "2000-01-01T00:00:00+00:00",
        "decision": None,
        "result": None,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    snapshot = {
        "version": 2,
        "pending_tool_call": pending_tool_call,
        "run_context": {"thread_id": "t1", "run_id": str(uuid4())},
    }
    chat_session = AgentChatSession(
        id=uuid4(),
        user_id=owner_id,
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
        state_snapshot=snapshot,
    )

    fake_db = FakeAsyncSession([chat_session])
    service = AgentChatService(
        session=fake_db,  # type: ignore[arg-type]
        current_user=CurrentUser(id=owner_id),
    )
    session_id = str(chat_session.id)

    with pytest.raises(HTTPException) as exc_info:
        await service.prepare_resume(session_id, _build_input(session_id))

    assert exc_info.value.status_code == 410
    # The expired state must be persisted even though the request fails.
    assert fake_db.commit_called is True
|