# File-viewer header captured with the source (not part of the module): 81 lines, 2.7 KiB, Python.
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from uuid import UUID, uuid4
|
|
|
|
import pytest
|
|
|
|
from core.auth.models import CurrentUser
|
|
from models.agent_chat_session import (
|
|
AgentChatSession,
|
|
AgentChatSessionStatus,
|
|
SessionType,
|
|
)
|
|
from v1.agent.service import AgentChatService
|
|
from v1.agent.tool_registry import validate_tool_spec
|
|
|
|
|
|
class TestAgentSecurityRules:
    """Security-oriented tests for tool-spec validation and resume decisions."""

    def test_tool_name_must_be_allowlisted(self):
        """Specs named ``ui.*`` / ``srv.*`` must validate without raising."""
        validate_tool_spec({"name": "ui.navigate_to", "execution_target": "frontend"})
        validate_tool_spec({"name": "srv.search_docs", "execution_target": "backend"})

    def test_tool_name_rejected_if_not_in_namespace(self):
        """A tool name in an unknown namespace must raise ``ValueError``."""
        # pytest.raises fails the test by itself when nothing is raised, so no
        # manual try/except/else bookkeeping is needed.
        with pytest.raises(ValueError):
            validate_tool_spec(
                {"name": "malicious.tool", "execution_target": "frontend"}
            )

    @pytest.mark.asyncio
    async def test_frontend_result_fails_when_interrupt_mismatch(self):
        """Resuming with a different interrupt id must not be applied.

        A pending tool call is registered under interrupt ``"int-1"``; the
        resume decision is then submitted for ``"int-other"`` and the result
        is expected to report ``applied is False``.
        """
        # The same id is used for the session's user and the current user.
        # NOTE(review): presumably so any ownership check in the service
        # passes and only the interrupt id differs -- confirm against service.
        user_id = UUID("00000000-0000-0000-0000-000000000001")

        session = AgentChatSession(
            id=uuid4(),
            user_id=user_id,
            session_type=SessionType.CHAT,
            status=AgentChatSessionStatus.RUNNING,
        )

        class _Result:
            """Stand-in for a query result that wraps one optional session."""

            def __init__(self, session_obj: AgentChatSession | None) -> None:
                self._session_obj = session_obj

            def scalar_one_or_none(self) -> AgentChatSession | None:
                return self._session_obj

        class FakeAsyncSession:
            """In-memory double that answers every query with one fixed session."""

            def __init__(self, session_obj: AgentChatSession) -> None:
                self._session_obj = session_obj

            async def execute(self, stmt: object):
                # The statement is ignored; the stored session is always returned.
                return _Result(self._session_obj)

            async def scalar(self, stmt: object) -> AgentChatSession | None:
                return self._session_obj

        service = AgentChatService(
            session=FakeAsyncSession(session),  # type: ignore[arg-type]
            current_user=CurrentUser(id=user_id),
        )

        await service.set_pending_tool_call(
            session_id=session.id,
            interrupt_id="int-1",
            tool_name="srv.transfer_funds",
            tool_args={"to": "u2", "amount": 100},
            expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
            thread_id="t1",
            run_id="r1",
        )

        # Deliberately resume with an interrupt id that was never registered.
        result = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-other",
            decision={"decision": "approved"},
        )

        assert result.applied is False
|