refactor: 移除 crewai agent 架构相关代码并更新 LLM 配置
This commit is contained in:
@@ -1,97 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from types import MethodType
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent.schemas import AgentChatRunRequest
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class _FakeAsyncSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
async def refresh(self, obj: object) -> None:
|
||||
if isinstance(obj, AgentChatSession) and obj.id is None:
|
||||
obj.id = uuid4()
|
||||
if isinstance(obj, AgentChatMessage) and obj.id is None:
|
||||
obj.id = uuid4()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_persists_messages_and_emits_ordered_events() -> None:
|
||||
fake_db = _FakeAsyncSession()
|
||||
service = AgentChatService(
|
||||
session=fake_db, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
async def _resolve_session(
|
||||
self: AgentChatService,
|
||||
*,
|
||||
session_id: object | None,
|
||||
user_id: UUID,
|
||||
first_message: str,
|
||||
now: datetime,
|
||||
) -> AgentChatSession:
|
||||
assert session_id is None
|
||||
assert first_message == "hello"
|
||||
return AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-000000000111"),
|
||||
user_id=user_id,
|
||||
title="hello",
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
last_activity_at=now,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=Decimal("0"),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
deleted_at=None,
|
||||
)
|
||||
|
||||
async def _next_seq_base(self: AgentChatService, session_id: object) -> int:
|
||||
assert session_id == UUID("00000000-0000-0000-0000-000000000111")
|
||||
return 2
|
||||
|
||||
service._resolve_session = MethodType(_resolve_session, service) # type: ignore[method-assign]
|
||||
service._next_seq_base = MethodType(_next_seq_base, service) # type: ignore[method-assign]
|
||||
|
||||
response = await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert fake_db.committed is True
|
||||
inserted_messages = [
|
||||
item for item in fake_db.added if isinstance(item, AgentChatMessage)
|
||||
]
|
||||
assert len(inserted_messages) == 2
|
||||
assert [msg.seq for msg in inserted_messages] == [3, 4]
|
||||
assert [msg.role for msg in inserted_messages] == [
|
||||
AgentChatMessageRole.USER,
|
||||
AgentChatMessageRole.ASSISTANT,
|
||||
]
|
||||
assert [event.type for event in response.events] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
@@ -1,78 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import (
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAgentChatService:
|
||||
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
|
||||
return AgentChatRunResponse(
|
||||
session_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
output=payload.message,
|
||||
events=[
|
||||
AgentChatEvent(
|
||||
type="run.started", run_id="00000000-0000-0000-0000-000000000001"
|
||||
),
|
||||
AgentChatEvent(
|
||||
type="message.delta", message_id="m1", delta=payload.message
|
||||
),
|
||||
AgentChatEvent(
|
||||
type="run.completed",
|
||||
run_id="00000000-0000-0000-0000-000000000001",
|
||||
output=payload.message,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _override_agent_chat_service(
|
||||
service: FakeAgentChatService,
|
||||
) -> Callable[[], AgentChatService]:
|
||||
def _get_service() -> AgentChatService:
|
||||
return service # type: ignore[return-value]
|
||||
|
||||
return _get_service
|
||||
|
||||
|
||||
def test_run_route_returns_response() -> None:
|
||||
app.dependency_overrides[get_agent_service] = _override_agent_chat_service(
|
||||
FakeAgentChatService()
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post("/api/v1/agent-chat", json={"message": "hello"})
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["output"] == "hello"
|
||||
assert [event["type"] for event in body["events"]] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_run_route_validates_payload() -> None:
|
||||
app.dependency_overrides[get_agent_service] = _override_agent_chat_service(
|
||||
FakeAgentChatService()
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post("/api/v1/agent-chat", json={"message": ""})
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
@@ -1,20 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from v1.agent.service import aggregate_session_cost
|
||||
|
||||
|
||||
def test_aggregate_session_cost_sums_non_negative_values() -> None:
|
||||
total = aggregate_session_cost([Decimal("0.010000"), Decimal("0.002500")])
|
||||
assert total == Decimal("0.012500")
|
||||
|
||||
|
||||
def test_aggregate_session_cost_rejects_negative_value() -> None:
|
||||
try:
|
||||
aggregate_session_cost([Decimal("-0.010000")])
|
||||
raised = False
|
||||
except ValueError:
|
||||
raised = True
|
||||
|
||||
assert raised is True
|
||||
@@ -1,42 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent.service import select_recent_session
|
||||
|
||||
|
||||
def test_select_recent_session_uses_last_activity_desc() -> None:
|
||||
sessions = [
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
|
||||
title="older",
|
||||
status=AgentChatSessionStatus.COMPLETED,
|
||||
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
|
||||
message_count=1,
|
||||
total_tokens=1,
|
||||
total_cost=Decimal("0"),
|
||||
),
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-000000000002"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
|
||||
title="newer",
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
last_activity_at=datetime(2026, 2, 25, 10, 0, tzinfo=timezone.utc),
|
||||
message_count=2,
|
||||
total_tokens=2,
|
||||
total_cost=Decimal("0"),
|
||||
),
|
||||
]
|
||||
|
||||
selected = select_recent_session(sessions)
|
||||
|
||||
assert selected is not None
|
||||
assert selected.id == UUID("00000000-0000-0000-0000-000000000002")
|
||||
|
||||
|
||||
def test_select_recent_session_returns_none_for_empty_collection() -> None:
|
||||
assert select_recent_session([]) is None
|
||||
@@ -1,82 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
class FakeAgentService:
|
||||
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
return None
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n'
|
||||
|
||||
async def stream_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n'
|
||||
|
||||
|
||||
def _get_test_user() -> CurrentUser:
|
||||
return CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> TestClient:
|
||||
app.dependency_overrides[get_current_user] = _get_test_user
|
||||
app.dependency_overrides[get_agent_service] = lambda: FakeAgentService()
|
||||
yield TestClient(app)
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestChatRoutes:
|
||||
def test_run_route_streams_sse_events(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs", json=payload)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
events = response.text.split("\n\n")
|
||||
assert 'data: {"type": "RUN_STARTED"' in events[0]
|
||||
assert 'data: {"type": "TEXT_MESSAGE_START"' in events[1]
|
||||
|
||||
def test_resume_route_streams_sse_events(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs/r1/resume", json=payload)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
events = response.text.split("\n\n")
|
||||
assert 'data: {"type": "RUN_STARTED"' in events[0]
|
||||
assert 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"' in events[2]
|
||||
@@ -1,144 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
class FakeAgentServiceWithInterrupt:
|
||||
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
return None
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Let me navigate"}\n\n'
|
||||
yield 'data: {"type": "TOOL_CALL", "toolName": "ui.navigate_to", "args": {"path": "/home"}}\n\n'
|
||||
yield (
|
||||
'data: {"type": "RUN_FINISHED", "runId": "'
|
||||
+ input_data.runId
|
||||
+ '", "outcome": "interrupt", "interrupt": {"id": "int-1", "reason": "frontend_tool", "payload": {"toolName": "ui.navigate_to", "args": {"path": "/home"}}}}\n\n'
|
||||
)
|
||||
|
||||
async def stream_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
if input_data.resume and input_data.resume.get("interruptId") == "int-1":
|
||||
payload = input_data.resume.get("payload", {})
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n'
|
||||
yield (
|
||||
'data: {"type": "TOOL_RESULT", "toolName": "ui.navigate_to", "result": '
|
||||
+ json.dumps(payload.get("result", {}))
|
||||
+ "}\n\n"
|
||||
)
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Navigation completed"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "' + run_id + '"}\n\n'
|
||||
else:
|
||||
yield (
|
||||
'data: {"type": "RUN_FINISHED", "runId": "'
|
||||
+ run_id
|
||||
+ '", "outcome": "error"}\n\n'
|
||||
)
|
||||
|
||||
|
||||
def _get_test_user() -> CurrentUser:
|
||||
return CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> TestClient:
|
||||
app.dependency_overrides[get_current_user] = _get_test_user
|
||||
app.dependency_overrides[get_agent_service] = (
|
||||
lambda: FakeAgentServiceWithInterrupt()
|
||||
)
|
||||
yield TestClient(app)
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestInterruptResumeFlow:
|
||||
def test_frontend_tool_interrupt_then_resume_with_result(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [{"role": "user", "content": "Navigate to home"}],
|
||||
"tools": [{"name": "ui.navigate_to", "execution_target": "frontend"}],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs", json=payload)
|
||||
assert response.status_code == 200
|
||||
|
||||
events = response.text.split("\n\n")
|
||||
interrupt_event = [e for e in events if '"outcome": "interrupt"' in e][0]
|
||||
assert '"id": "int-1"' in interrupt_event
|
||||
assert '"reason": "frontend_tool"' in interrupt_event
|
||||
|
||||
resume_payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {
|
||||
"interruptId": "int-1",
|
||||
"payload": {"result": {"success": True}},
|
||||
},
|
||||
}
|
||||
resume_response = client.post(
|
||||
"/api/v1/agent/runs/r1/resume", json=resume_payload
|
||||
)
|
||||
assert resume_response.status_code == 200
|
||||
|
||||
resume_events = resume_response.text.split("\n\n")
|
||||
tool_result_event = [e for e in resume_events if '"type": "TOOL_RESULT"' in e][
|
||||
0
|
||||
]
|
||||
assert '"toolName": "ui.navigate_to"' in tool_result_event
|
||||
assert '"success": true' in tool_result_event.lower()
|
||||
|
||||
def test_backend_tool_approval_rejected(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t2",
|
||||
"runId": "r2",
|
||||
"state": {},
|
||||
"messages": [{"role": "user", "content": "Transfer funds"}],
|
||||
"tools": [
|
||||
{
|
||||
"name": "srv.transfer_funds",
|
||||
"execution_target": "backend",
|
||||
"requires_approval": True,
|
||||
}
|
||||
],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs", json=payload)
|
||||
assert response.status_code == 200
|
||||
|
||||
resume_payload = {
|
||||
"threadId": "t2",
|
||||
"runId": "r2",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {
|
||||
"interruptId": "int-1",
|
||||
"payload": {"decision": "rejected", "reason": "User denied"},
|
||||
},
|
||||
}
|
||||
resume_response = client.post(
|
||||
"/api/v1/agent/runs/r2/resume", json=resume_payload
|
||||
)
|
||||
assert resume_response.status_code == 200
|
||||
Reference in New Issue
Block a user