refactor: 重构 Agent 模块为 AgentScope,删除旧版 CrewAI/LiteLLM 实现
This commit is contained in:
@@ -1,703 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.infrastructure.persistence.session_repository import SessionRepository
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
from core.agent.infrastructure.storage.tool_result_storage import (
|
||||
create_tool_result_storage,
|
||||
)
|
||||
from services.base.supabase import supabase_service
|
||||
from core.db import AsyncSessionLocal, engine
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.profile import Profile
|
||||
from models.system_agents import SystemAgents
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_then_resume_persists_messages_and_session_state(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
call_count = {"n": 0}
|
||||
|
||||
def _fake_execute(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
system_prompt: str | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
) -> dict[str, object]:
|
||||
del self, user_input, system_prompt, tools
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return {
|
||||
"assistant_text": "请确认是否跳转。",
|
||||
"prompt_tokens": 11,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 18,
|
||||
"cost": 0.0025,
|
||||
"pending_front_tool": {
|
||||
"name": "front.navigate_to_route",
|
||||
"args": {"target": "/calendar/dayweek", "replace": False},
|
||||
"target": "frontend",
|
||||
},
|
||||
"agui_events": [],
|
||||
}
|
||||
return {
|
||||
"assistant_text": "已继续执行并完成。",
|
||||
"prompt_tokens": 3,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 5,
|
||||
"cost": 0.001,
|
||||
"pending_front_tool": None,
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
|
||||
_fake_execute,
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as lookup_session:
|
||||
existing_owner = await lookup_session.execute(
|
||||
select(AgentChatSession.user_id).limit(1)
|
||||
)
|
||||
owner_id = existing_owner.scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
pytest.skip("No existing session owner available in local database")
|
||||
factory_id = uuid.uuid4()
|
||||
session_uuid = uuid.uuid4()
|
||||
agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
llm_row = await seed_session.execute(select(Llm.id).limit(1))
|
||||
llm_id = llm_row.scalar_one_or_none()
|
||||
if llm_id is None:
|
||||
seed_session.add(
|
||||
LlmFactory(
|
||||
id=factory_id,
|
||||
name=f"dashscope-test-{uuid.uuid4().hex[:8]}",
|
||||
request_url="https://dashscope.example",
|
||||
)
|
||||
)
|
||||
llm_id = uuid.uuid4()
|
||||
seed_session.add(
|
||||
Llm(
|
||||
id=llm_id,
|
||||
factory_id=factory_id,
|
||||
model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
|
||||
)
|
||||
)
|
||||
seed_session.add(
|
||||
SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
|
||||
)
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.commit()
|
||||
|
||||
published: list[str] = []
|
||||
queued_commands: list[dict[str, object]] = []
|
||||
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
published.append(event_type)
|
||||
|
||||
async def _enqueue(command: dict[str, object]) -> str:
|
||||
queued_commands.append(command)
|
||||
return "task-followup-1"
|
||||
|
||||
try:
|
||||
run_input_payload = {
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-it-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "u1", "role": "user", "content": "帮我打开日历"},
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"name": "front.navigate_to_route",
|
||||
"description": "navigate route",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
run_result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"run_input": run_input_payload,
|
||||
},
|
||||
publish_event=_publish,
|
||||
enqueue_command=_enqueue,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
pending_tool_call_id = str(run_result["pending_tool_call_id"])
|
||||
state_snapshot = run_result["state_snapshot"]
|
||||
assert isinstance(state_snapshot, dict)
|
||||
pending_tool_nonce = state_snapshot["pending_tool_nonce"]
|
||||
assert isinstance(pending_tool_nonce, str)
|
||||
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"run_input": {
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-it-2",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"role": "tool",
|
||||
"toolCallId": pending_tool_call_id,
|
||||
"content": json.dumps(
|
||||
{
|
||||
"toolName": "front.navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": pending_tool_nonce,
|
||||
},
|
||||
"nonce": pending_tool_nonce,
|
||||
"result": {"ok": True},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
},
|
||||
publish_event=_publish,
|
||||
enqueue_command=_enqueue,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
|
||||
assert len(queued_commands) == 1
|
||||
await run_agent_task(
|
||||
queued_commands[0],
|
||||
publish_event=_publish,
|
||||
enqueue_command=_enqueue,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
|
||||
await engine.dispose()
|
||||
async with AsyncSessionLocal() as verify_session:
|
||||
db_session = await verify_session.get(AgentChatSession, session_uuid)
|
||||
assert db_session is not None
|
||||
assert db_session.status == AgentChatSessionStatus.COMPLETED
|
||||
assert db_session.message_count == 4
|
||||
assert db_session.total_tokens == 23
|
||||
assert db_session.total_cost == Decimal("0.003500")
|
||||
assert db_session.state_snapshot == {
|
||||
"status": "completed",
|
||||
"pending_tool_call_id": None,
|
||||
"pending_tool_name": None,
|
||||
"pending_tool_args_sha256": None,
|
||||
"pending_tool_nonce": None,
|
||||
}
|
||||
|
||||
rows = await verify_session.execute(
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
messages = list(rows.scalars().all())
|
||||
assert [item.role for item in messages] == [
|
||||
AgentChatMessageRole.USER,
|
||||
AgentChatMessageRole.ASSISTANT,
|
||||
AgentChatMessageRole.TOOL,
|
||||
AgentChatMessageRole.ASSISTANT,
|
||||
]
|
||||
assert messages[1].input_tokens == 11
|
||||
assert messages[1].output_tokens == 7
|
||||
assert messages[1].cost == Decimal("0.002500")
|
||||
assert messages[3].content == "已继续执行并完成。"
|
||||
|
||||
assert "RUN_STARTED" in published
|
||||
assert "RUN_FINISHED" in published
|
||||
assert "TEXT_MESSAGE_CONTENT" in published
|
||||
finally:
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_tool_result_offloads_to_supabase_storage_for_calendar_tool(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
call_count = {"n": 0}
|
||||
|
||||
def _fake_execute(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
system_prompt: str | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
) -> dict[str, object]:
|
||||
del self, user_input, system_prompt, tools
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return {
|
||||
"assistant_text": "我来创建日历事件,请稍候确认。",
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 6,
|
||||
"total_tokens": 16,
|
||||
"cost": 0.002,
|
||||
"pending_front_tool": {
|
||||
"name": "front.create_calendar_event",
|
||||
"args": {
|
||||
"title": "测试日程",
|
||||
"start": "2026-03-09T09:00:00+08:00",
|
||||
"end": "2026-03-09T10:00:00+08:00",
|
||||
},
|
||||
"target": "frontend",
|
||||
},
|
||||
"agui_events": [],
|
||||
}
|
||||
return {
|
||||
"assistant_text": "日历已创建。",
|
||||
"prompt_tokens": 2,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 4,
|
||||
"cost": 0.001,
|
||||
"pending_front_tool": None,
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
|
||||
_fake_execute,
|
||||
)
|
||||
|
||||
factory_id = uuid.uuid4()
|
||||
test_user_id: str | None = None
|
||||
test_user_email = f"agent-it-{uuid.uuid4().hex[:8]}@example.com"
|
||||
owner_id = uuid.uuid4()
|
||||
|
||||
initialized = await supabase_service.initialize()
|
||||
if not initialized:
|
||||
pytest.skip("Supabase service is unavailable")
|
||||
|
||||
admin_client = supabase_service.get_admin_client()
|
||||
tool_result_storage = create_tool_result_storage()
|
||||
assert tool_result_storage is not None
|
||||
created_user = admin_client.auth.admin.create_user(
|
||||
{
|
||||
"email": test_user_email,
|
||||
"password": "Passw0rd!123",
|
||||
"email_confirm": True,
|
||||
"user_metadata": {"source": "integration-test"},
|
||||
}
|
||||
)
|
||||
test_user_id = str(created_user.user.id)
|
||||
owner_id = uuid.UUID(test_user_id)
|
||||
|
||||
await engine.dispose()
|
||||
async with AsyncSessionLocal() as lookup_session:
|
||||
llm_row = await lookup_session.execute(select(Llm.id).limit(1))
|
||||
llm_id = llm_row.scalar_one_or_none()
|
||||
|
||||
if llm_id is None:
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
factory_row = await seed_session.execute(
|
||||
select(LlmFactory.id).where(LlmFactory.name == "dashscope").limit(1)
|
||||
)
|
||||
existing_factory_id = factory_row.scalar_one_or_none()
|
||||
if existing_factory_id is None:
|
||||
seed_session.add(
|
||||
LlmFactory(
|
||||
id=factory_id,
|
||||
name="dashscope",
|
||||
request_url="https://dashscope.example",
|
||||
)
|
||||
)
|
||||
await seed_session.commit()
|
||||
else:
|
||||
factory_id = existing_factory_id
|
||||
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
llm_id = uuid.uuid4()
|
||||
seed_session.add(
|
||||
Llm(
|
||||
id=llm_id,
|
||||
factory_id=factory_id,
|
||||
model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
|
||||
)
|
||||
)
|
||||
await seed_session.commit()
|
||||
|
||||
storage = admin_client.storage
|
||||
try:
|
||||
storage.get_bucket("private")
|
||||
except Exception:
|
||||
storage.create_bucket("private", "private", {"public": False})
|
||||
|
||||
session_uuid = uuid.uuid4()
|
||||
agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
|
||||
uploaded_path: str | None = None
|
||||
|
||||
try:
|
||||
probe_path = f"tool-results/probe/{uuid.uuid4().hex}.json"
|
||||
try:
|
||||
storage.from_("private").upload(probe_path, b"{}")
|
||||
storage.from_("private").remove([probe_path])
|
||||
except Exception:
|
||||
pytest.skip(
|
||||
"Supabase Storage upload API unavailable in current environment"
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
existing_profile = await seed_session.get(Profile, owner_id)
|
||||
if existing_profile is None:
|
||||
seed_session.add(
|
||||
Profile(
|
||||
id=owner_id,
|
||||
username=f"it_{uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
)
|
||||
seed_session.add(
|
||||
SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
|
||||
)
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.commit()
|
||||
|
||||
run_result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"run_input": {
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-storage-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": "帮我创建明天9点到10点的日历",
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"name": "front.create_calendar_event",
|
||||
"description": "Create calendar event",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
},
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(
|
||||
tool_result_storage=tool_result_storage,
|
||||
tool_result_bucket="private",
|
||||
tool_result_prefix="tool-results",
|
||||
),
|
||||
)
|
||||
pending_tool_call_id = str(run_result["pending_tool_call_id"])
|
||||
snapshot = run_result["state_snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
pending_tool_nonce = snapshot.get("pending_tool_nonce")
|
||||
assert isinstance(pending_tool_nonce, str)
|
||||
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"run_input": {
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-storage-2",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"role": "tool",
|
||||
"toolCallId": pending_tool_call_id,
|
||||
"content": json.dumps(
|
||||
{
|
||||
"toolName": "front.create_calendar_event",
|
||||
"toolArgs": {
|
||||
"title": "测试日程",
|
||||
"start": "2026-03-09T09:00:00+08:00",
|
||||
"end": "2026-03-09T10:00:00+08:00",
|
||||
"__nonce": pending_tool_nonce,
|
||||
},
|
||||
"nonce": pending_tool_nonce,
|
||||
"result": {
|
||||
"ok": True,
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
"data": {
|
||||
"id": "evt-test",
|
||||
"title": "测试日程",
|
||||
"description": "x" * 9000,
|
||||
},
|
||||
"actions": [
|
||||
{
|
||||
"type": "link",
|
||||
"label": "查看详情",
|
||||
"target": "/calendar/events/evt-test",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
},
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(
|
||||
tool_result_storage=tool_result_storage,
|
||||
tool_result_bucket="private",
|
||||
tool_result_prefix="tool-results",
|
||||
),
|
||||
)
|
||||
|
||||
await engine.dispose()
|
||||
async with AsyncSessionLocal() as verify_session:
|
||||
rows = await verify_session.execute(
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.role == AgentChatMessageRole.TOOL)
|
||||
.order_by(AgentChatMessage.seq.desc())
|
||||
)
|
||||
tool_message = rows.scalars().first()
|
||||
assert tool_message is not None
|
||||
metadata = tool_message.metadata_json or {}
|
||||
storage_bucket = metadata.get("storage_bucket")
|
||||
storage_path = metadata.get("storage_path")
|
||||
assert storage_bucket == "private"
|
||||
assert isinstance(storage_path, str)
|
||||
assert storage_path.startswith("tool-results/")
|
||||
uploaded_path = storage_path
|
||||
|
||||
downloaded = storage.from_("private").download(uploaded_path)
|
||||
if isinstance(downloaded, bytes):
|
||||
downloaded_payload = json.loads(downloaded.decode("utf-8"))
|
||||
else:
|
||||
downloaded_payload = json.loads(str(downloaded))
|
||||
|
||||
assert downloaded_payload["toolName"] == "front.create_calendar_event"
|
||||
result_payload = downloaded_payload["result"]
|
||||
assert result_payload["type"] == "calendar_card.v1"
|
||||
assert result_payload["data"]["id"] == "evt-test"
|
||||
finally:
|
||||
if uploaded_path:
|
||||
try:
|
||||
storage.from_("private").remove([uploaded_path])
|
||||
except Exception:
|
||||
pass
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
|
||||
)
|
||||
await cleanup_session.execute(delete(Profile).where(Profile.id == owner_id))
|
||||
await cleanup_session.execute(
|
||||
delete(Llm).where(Llm.factory_id == factory_id)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(LlmFactory).where(LlmFactory.id == factory_id)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
if test_user_id is not None:
|
||||
try:
|
||||
admin_client.auth.admin.delete_user(test_user_id)
|
||||
except Exception:
|
||||
pass
|
||||
await supabase_service.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_embeds_profile_settings_in_runtime_system_prompt(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
session_uuid = uuid.uuid4()
|
||||
agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
|
||||
original_profile: Profile | None = None
|
||||
|
||||
def _fake_execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
captured["user_input"] = user_input
|
||||
captured["system_prompt"] = system_prompt
|
||||
return {
|
||||
"assistant_text": "Mocked answer",
|
||||
"prompt_tokens": 11,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 18,
|
||||
"cost": 0.0025,
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
|
||||
_fake_execute,
|
||||
)
|
||||
|
||||
await engine.dispose()
|
||||
async with AsyncSessionLocal() as lookup_session:
|
||||
owner_row = await lookup_session.execute(select(Profile.id).limit(1))
|
||||
owner_id = owner_row.scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
pytest.skip("No profile owner available in local database")
|
||||
original_profile = await lookup_session.get(Profile, owner_id)
|
||||
llm_row = await lookup_session.execute(
|
||||
select(Llm.id, LlmFactory.name)
|
||||
.join(LlmFactory, LlmFactory.id == Llm.factory_id)
|
||||
.where(LlmFactory.name.in_(("dashscope", "deepseek", "moonshot")))
|
||||
.limit(1)
|
||||
)
|
||||
llm_record = llm_row.one_or_none()
|
||||
if llm_record is None:
|
||||
pytest.skip("No supported llm provider available in local database")
|
||||
llm_id = llm_record[0]
|
||||
|
||||
try:
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
seed_session.add(
|
||||
SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
|
||||
)
|
||||
profile = await seed_session.get(Profile, owner_id)
|
||||
assert profile is not None
|
||||
profile.username = "demo-user"
|
||||
profile.bio = "hello\nworld"
|
||||
profile.settings = {
|
||||
"preferences": {
|
||||
"interface_language": "zh-CN",
|
||||
"ai_language": "en-US",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"country": "CN",
|
||||
}
|
||||
}
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.commit()
|
||||
|
||||
result = await RunService().run(
|
||||
run_input=RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-it-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "u1", "role": "user", "content": "hello"},
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
assert result["persisted"] is True
|
||||
assert captured["user_input"] == "hello"
|
||||
system_prompt = captured["system_prompt"]
|
||||
assert isinstance(system_prompt, str)
|
||||
assert "# USER_PROFILE (JSON)" in system_prompt
|
||||
assert '"ai_language":"en-US"' in system_prompt
|
||||
assert '"timezone":"Asia/Shanghai"' in system_prompt
|
||||
assert '"country":"CN"' in system_prompt
|
||||
finally:
|
||||
await engine.dispose()
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
if original_profile is not None:
|
||||
profile = await cleanup_session.get(Profile, owner_id)
|
||||
if profile is not None:
|
||||
profile.username = original_profile.username
|
||||
profile.bio = original_profile.bio
|
||||
profile.settings = original_profile.settings
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_session_cascades_to_messages() -> None:
|
||||
session_uuid = uuid.uuid4()
|
||||
await engine.dispose()
|
||||
|
||||
async with AsyncSessionLocal() as lookup_session:
|
||||
owner = await lookup_session.execute(select(Profile.id).limit(1))
|
||||
owner_id = owner.scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
pytest.skip("No profile owner available in local database")
|
||||
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.flush()
|
||||
seed_session.add(
|
||||
AgentChatMessage(
|
||||
session_id=session_uuid,
|
||||
seq=1,
|
||||
role=AgentChatMessageRole.USER,
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
await seed_session.commit()
|
||||
|
||||
try:
|
||||
async with AsyncSessionLocal() as mutate_session:
|
||||
repo = SessionRepository(mutate_session)
|
||||
affected = await repo.soft_delete_session_with_messages(
|
||||
session_id=session_uuid
|
||||
)
|
||||
await mutate_session.commit()
|
||||
assert affected == 1
|
||||
|
||||
async with AsyncSessionLocal() as verify_session:
|
||||
db_session = await verify_session.get(AgentChatSession, session_uuid)
|
||||
assert db_session is not None
|
||||
assert db_session.deleted_at is not None
|
||||
rows = await verify_session.execute(
|
||||
select(AgentChatMessage).where(
|
||||
AgentChatMessage.session_id == session_uuid
|
||||
)
|
||||
)
|
||||
messages = list(rows.scalars().all())
|
||||
assert len(messages) == 1
|
||||
assert messages[0].deleted_at is not None
|
||||
finally:
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatMessage).where(
|
||||
AgentChatMessage.session_id == session_uuid
|
||||
)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
@@ -1,78 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.application.session_state_persistence import persist_tool_result_payload
|
||||
from core.agent.domain.tool_correlation import reconstruct_tool_call_result_event
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
|
||||
|
||||
class _FakeStorage:
|
||||
def __init__(self) -> None:
|
||||
self.writes: dict[str, dict[str, object]] = {}
|
||||
|
||||
async def upload_json(
|
||||
self, *, bucket: str, path: str, payload: dict[str, object]
|
||||
) -> str:
|
||||
self.writes[f"{bucket}/{path}"] = payload
|
||||
return "etag-1"
|
||||
|
||||
|
||||
async def test_closed_loop_run_flow_frontend_to_sse() -> None:
|
||||
thread_id = "00000000-0000-0000-0000-000000000001"
|
||||
published: list[str] = []
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, run_input: object) -> dict[str, object]:
|
||||
del run_input
|
||||
return {"threadId": thread_id, "runId": "run-1"}
|
||||
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
published.append(event_type)
|
||||
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"run_input": {
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
)
|
||||
|
||||
assert result["threadId"] == thread_id
|
||||
assert published[0] == "RUN_STARTED"
|
||||
assert published[-1] == "RUN_FINISHED"
|
||||
|
||||
|
||||
async def test_tool_result_full_payload_persist_and_reconstruct() -> None:
|
||||
storage = _FakeStorage()
|
||||
payload = {
|
||||
"schema": "ui.v1",
|
||||
"components": [{"type": "card", "title": "Weather"}],
|
||||
}
|
||||
|
||||
metadata = await persist_tool_result_payload(
|
||||
storage=storage,
|
||||
run_id="run-1",
|
||||
turn_id="turn-1",
|
||||
tool_call_id="call-1",
|
||||
tool_name="weather",
|
||||
payload=payload,
|
||||
bucket="private",
|
||||
path="tool-results/run-1/call-1.json",
|
||||
)
|
||||
|
||||
event = reconstruct_tool_call_result_event(metadata=metadata, payload=payload)
|
||||
|
||||
assert metadata["type"] == "tool_result"
|
||||
assert metadata["storage_bucket"] == "private"
|
||||
assert event["type"] == "TOOL_CALL_RESULT"
|
||||
assert event["data"]["schema"] == "ui.v1"
|
||||
@@ -7,8 +7,11 @@ from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
|
||||
from core.agentscope.schemas.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agentscope.schemas.user_context import (
|
||||
UserAgentContext,
|
||||
parse_profile_settings,
|
||||
)
|
||||
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from core.db.session import AsyncSessionLocal
|
||||
|
||||
@@ -69,6 +69,7 @@ def _override_current_user(user_id: UUID) -> Callable[[], CurrentUser]:
|
||||
def test_create_schedule_item_returns_201() -> None:
|
||||
item = ScheduleItemResponse(
|
||||
id=uuid4(),
|
||||
owner_id=uuid4(),
|
||||
title="Test Event",
|
||||
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
|
||||
timezone="UTC",
|
||||
@@ -76,6 +77,8 @@ def test_create_schedule_item_returns_201() -> None:
|
||||
source_type=ScheduleItemSourceType.MANUAL,
|
||||
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
permission=7,
|
||||
is_owner=True,
|
||||
)
|
||||
|
||||
app.dependency_overrides[get_schedule_item_service] = (
|
||||
@@ -99,6 +102,7 @@ def test_create_schedule_item_returns_201() -> None:
|
||||
def test_list_schedule_items_returns_200() -> None:
|
||||
item = ScheduleItemResponse(
|
||||
id=uuid4(),
|
||||
owner_id=uuid4(),
|
||||
title="Test Event",
|
||||
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
|
||||
timezone="UTC",
|
||||
@@ -106,6 +110,8 @@ def test_list_schedule_items_returns_200() -> None:
|
||||
source_type=ScheduleItemSourceType.MANUAL,
|
||||
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
permission=7,
|
||||
is_owner=True,
|
||||
)
|
||||
|
||||
app.dependency_overrides[get_schedule_item_service] = (
|
||||
@@ -131,6 +137,7 @@ def test_get_schedule_item_returns_200() -> None:
|
||||
item_id = uuid4()
|
||||
item = ScheduleItemResponse(
|
||||
id=item_id,
|
||||
owner_id=uuid4(),
|
||||
title="Test Event",
|
||||
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
|
||||
timezone="UTC",
|
||||
@@ -138,6 +145,8 @@ def test_get_schedule_item_returns_200() -> None:
|
||||
source_type=ScheduleItemSourceType.MANUAL,
|
||||
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
permission=7,
|
||||
is_owner=True,
|
||||
)
|
||||
|
||||
app.dependency_overrides[get_schedule_item_service] = (
|
||||
@@ -156,6 +165,7 @@ def test_update_schedule_item_returns_200() -> None:
|
||||
item_id = uuid4()
|
||||
item = ScheduleItemResponse(
|
||||
id=item_id,
|
||||
owner_id=uuid4(),
|
||||
title="Updated Event",
|
||||
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
|
||||
timezone="UTC",
|
||||
@@ -163,6 +173,8 @@ def test_update_schedule_item_returns_200() -> None:
|
||||
source_type=ScheduleItemSourceType.MANUAL,
|
||||
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
permission=7,
|
||||
is_owner=True,
|
||||
)
|
||||
|
||||
app.dependency_overrides[get_schedule_item_service] = (
|
||||
@@ -184,6 +196,7 @@ def test_delete_schedule_item_returns_204() -> None:
|
||||
item_id = uuid4()
|
||||
item = ScheduleItemResponse(
|
||||
id=item_id,
|
||||
owner_id=uuid4(),
|
||||
title="Test Event",
|
||||
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
|
||||
timezone="UTC",
|
||||
@@ -191,6 +204,8 @@ def test_delete_schedule_item_returns_204() -> None:
|
||||
source_type=ScheduleItemSourceType.MANUAL,
|
||||
created_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 2, 27, 10, 0, 0, tzinfo=timezone.utc),
|
||||
permission=7,
|
||||
is_owner=True,
|
||||
)
|
||||
|
||||
app.dependency_overrides[get_schedule_item_service] = (
|
||||
|
||||
@@ -354,6 +354,12 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
|
||||
async def _allow_transcribe(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _allow_transcribe)
|
||||
|
||||
async def mock_transcribe_file(file_path: str, filename: str) -> str:
|
||||
assert file_path.endswith(".wav")
|
||||
assert filename == "test.wav"
|
||||
@@ -391,6 +397,12 @@ def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
|
||||
|
||||
monkeypatch.setattr(agent_router, "_MAX_TRANSCRIBE_AUDIO_BYTES", 4)
|
||||
|
||||
async def _allow_transcribe(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _allow_transcribe)
|
||||
|
||||
client = TestClient(app)
|
||||
oversized = BytesIO(b"12345")
|
||||
oversized.name = "test.wav"
|
||||
@@ -407,11 +419,17 @@ def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_asr_transcribe_rejects_non_wav_audio() -> None:
|
||||
def test_asr_transcribe_rejects_non_wav_audio(monkeypatch) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
|
||||
async def _allow_transcribe(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _allow_transcribe)
|
||||
|
||||
client = TestClient(app)
|
||||
fake_mp3 = BytesIO(b"fake-mp3")
|
||||
fake_mp3.name = "test.mp3"
|
||||
@@ -428,11 +446,17 @@ def test_asr_transcribe_rejects_non_wav_audio() -> None:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_asr_transcribe_rejects_invalid_wav_payload() -> None:
|
||||
def test_asr_transcribe_rejects_invalid_wav_payload(monkeypatch) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
|
||||
async def _allow_transcribe(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _allow_transcribe)
|
||||
|
||||
client = TestClient(app)
|
||||
fake_payload = BytesIO(b"not-a-wav")
|
||||
fake_payload.name = "test.wav"
|
||||
@@ -447,3 +471,33 @@ def test_asr_transcribe_rejects_invalid_wav_payload() -> None:
|
||||
assert response.json()["detail"] == "Unsupported audio format"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_asr_transcribe_rejects_when_rate_limited_for_current_user(monkeypatch) -> None:
|
||||
known_user = CurrentUser(id=uuid4(), email="user@example.com")
|
||||
app.dependency_overrides[get_current_user] = lambda: known_user
|
||||
|
||||
captured_user_ids: list[str] = []
|
||||
|
||||
async def _deny_transcribe(*, user_id: str) -> bool:
|
||||
captured_user_ids.append(user_id)
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(agent_router, "_allow_transcribe_request", _deny_transcribe)
|
||||
|
||||
client = TestClient(app)
|
||||
wav_content = b"RIFF\x24\x80\x00\x00WAVEfmt "
|
||||
wav_file = BytesIO(wav_content)
|
||||
wav_file.name = "test.wav"
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/transcribe",
|
||||
files={"audio": ("test.wav", wav_file, "audio/wav")},
|
||||
)
|
||||
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "Too many transcribe requests"
|
||||
assert captured_user_ids == [str(known_user.id)]
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
Reference in New Issue
Block a user