feat(agent): 增强多模态链路与工具调用能力
This commit is contained in:
@@ -0,0 +1,20 @@
|
|||||||
|
{
|
||||||
|
"$schema": "https://opencode.ai/config.json",
|
||||||
|
"mcp": {
|
||||||
|
"supabase": {
|
||||||
|
"type": "local",
|
||||||
|
"enabled": true,
|
||||||
|
"command": [
|
||||||
|
"npx",
|
||||||
|
"-y",
|
||||||
|
"@aliyun-rds/supabase-mcp-server",
|
||||||
|
"--supabase-url",
|
||||||
|
"http://47.112.66.83",
|
||||||
|
"--supabase-anon-key",
|
||||||
|
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJvbGUiOiJhbm9uIiwiaWF0IjoxNzczMDI3NDE5LCJleHAiOjEzMjgzNjY3NDE5fQ.NVXDla5_nYPdcJk_81fc3k1UrnNTrNne_trMqt6Hg4g",
|
||||||
|
"--supabase-service-role-key",
|
||||||
|
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJvbGUiOiJzZXJ2aWNlX3JvbGUiLCJpYXQiOjE3NzMwMjc0MTksImV4cCI6MTMyODM2Njc0MTl9.RzQBia-3QcjupsHnqaxgDWB7wnY9R7Ms9R8pMokyvLY"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,10 +23,12 @@ class MessageRepository:
|
|||||||
role: AgentChatMessageRole,
|
role: AgentChatMessageRole,
|
||||||
content: str,
|
content: str,
|
||||||
model_code: str | None = None,
|
model_code: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
metadata: dict[str, object] | None = None,
|
metadata: dict[str, object] | None = None,
|
||||||
input_tokens: int = 0,
|
input_tokens: int = 0,
|
||||||
output_tokens: int = 0,
|
output_tokens: int = 0,
|
||||||
cost: Decimal = Decimal("0"),
|
cost: Decimal = Decimal("0"),
|
||||||
|
latency_ms: int | None = None,
|
||||||
) -> AgentChatMessage:
|
) -> AgentChatMessage:
|
||||||
message = AgentChatMessage(
|
message = AgentChatMessage(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -34,10 +36,12 @@ class MessageRepository:
|
|||||||
role=role,
|
role=role,
|
||||||
content=content,
|
content=content,
|
||||||
model_code=model_code,
|
model_code=model_code,
|
||||||
|
tool_name=tool_name,
|
||||||
metadata_json=metadata,
|
metadata_json=metadata,
|
||||||
input_tokens=input_tokens,
|
input_tokens=input_tokens,
|
||||||
output_tokens=output_tokens,
|
output_tokens=output_tokens,
|
||||||
cost=cost,
|
cost=cost,
|
||||||
|
latency_ms=latency_ms,
|
||||||
)
|
)
|
||||||
self._session.add(message)
|
self._session.add(message)
|
||||||
await self._session.flush()
|
await self._session.flush()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from decimal import Decimal, InvalidOperation
|
from decimal import Decimal, InvalidOperation
|
||||||
from typing import Any, Callable, Protocol
|
from typing import Any, Callable, Protocol
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@@ -24,6 +25,7 @@ class SqlAlchemyEventStore:
|
|||||||
def __init__(self, *, session_factory: Any) -> None:
|
def __init__(self, *, session_factory: Any) -> None:
|
||||||
self._session_factory = session_factory
|
self._session_factory = session_factory
|
||||||
self._message_buffers: dict[tuple[str, str], str] = {}
|
self._message_buffers: dict[tuple[str, str], str] = {}
|
||||||
|
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||||
|
|
||||||
async def persist(self, event: dict[str, Any]) -> None:
|
async def persist(self, event: dict[str, Any]) -> None:
|
||||||
event_type = str(event.get("type", "")).strip().upper()
|
event_type = str(event.get("type", "")).strip().upper()
|
||||||
@@ -48,6 +50,10 @@ class SqlAlchemyEventStore:
|
|||||||
self._buffer_text_delta(session_key=session_key, event=event)
|
self._buffer_text_delta(session_key=session_key, event=event)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if event_type == "TEXT_MESSAGE_START":
|
||||||
|
self._buffer_text_context(session_key=session_key, event=event)
|
||||||
|
return
|
||||||
|
|
||||||
if event_type == "RUN_STARTED":
|
if event_type == "RUN_STARTED":
|
||||||
await self._update_session_state(
|
await self._update_session_state(
|
||||||
session_repo=session_repo,
|
session_repo=session_repo,
|
||||||
@@ -72,7 +78,15 @@ class SqlAlchemyEventStore:
|
|||||||
)
|
)
|
||||||
self._clear_session_buffers(session_key=session_key)
|
self._clear_session_buffers(session_key=session_key)
|
||||||
elif event_type == "TEXT_MESSAGE_END":
|
elif event_type == "TEXT_MESSAGE_END":
|
||||||
await self._persist_assistant_message(
|
await self._persist_text_message(
|
||||||
|
event=event,
|
||||||
|
session_id=session_id,
|
||||||
|
chat_session=chat_session,
|
||||||
|
session_repo=session_repo,
|
||||||
|
message_repo=message_repo,
|
||||||
|
)
|
||||||
|
elif event_type == "TOOL_CALL_RESULT":
|
||||||
|
await self._persist_tool_call_result(
|
||||||
event=event,
|
event=event,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
chat_session=chat_session,
|
chat_session=chat_session,
|
||||||
@@ -97,8 +111,28 @@ class SqlAlchemyEventStore:
|
|||||||
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
|
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
|
||||||
for key in stale_keys:
|
for key in stale_keys:
|
||||||
self._message_buffers.pop(key, None)
|
self._message_buffers.pop(key, None)
|
||||||
|
stale_context_keys = [k for k in self._message_contexts if k[0] == session_key]
|
||||||
|
for key in stale_context_keys:
|
||||||
|
self._message_contexts.pop(key, None)
|
||||||
|
|
||||||
async def _persist_assistant_message(
|
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||||
|
message_id = event.get("messageId")
|
||||||
|
if not isinstance(message_id, str) or not message_id:
|
||||||
|
return
|
||||||
|
key = (session_key, message_id)
|
||||||
|
role = event.get("role")
|
||||||
|
stage = event.get("stage")
|
||||||
|
tool_name = event.get("toolName")
|
||||||
|
context: dict[str, object] = {}
|
||||||
|
if isinstance(role, str) and role:
|
||||||
|
context["role"] = role
|
||||||
|
if isinstance(stage, str) and stage:
|
||||||
|
context["stage"] = stage
|
||||||
|
if isinstance(tool_name, str) and tool_name:
|
||||||
|
context["tool_name"] = tool_name
|
||||||
|
self._message_contexts[key] = context
|
||||||
|
|
||||||
|
async def _persist_text_message(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
event: dict[str, Any],
|
event: dict[str, Any],
|
||||||
@@ -114,6 +148,8 @@ class SqlAlchemyEventStore:
|
|||||||
if not content:
|
if not content:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
context = self._message_contexts.get(key, {})
|
||||||
|
|
||||||
input_tokens = self._to_int(event.get("inputTokens"))
|
input_tokens = self._to_int(event.get("inputTokens"))
|
||||||
output_tokens = self._to_int(event.get("outputTokens"))
|
output_tokens = self._to_int(event.get("outputTokens"))
|
||||||
token_delta = input_tokens + output_tokens
|
token_delta = input_tokens + output_tokens
|
||||||
@@ -127,6 +163,20 @@ class SqlAlchemyEventStore:
|
|||||||
metadata["run_id"] = run_id
|
metadata["run_id"] = run_id
|
||||||
if latency_ms is not None:
|
if latency_ms is not None:
|
||||||
metadata["latency_ms"] = latency_ms
|
metadata["latency_ms"] = latency_ms
|
||||||
|
stage = event.get("stage")
|
||||||
|
if not isinstance(stage, str):
|
||||||
|
stage = context.get("stage")
|
||||||
|
if isinstance(stage, str) and stage:
|
||||||
|
metadata["stage"] = stage
|
||||||
|
|
||||||
|
role_value = context.get("role")
|
||||||
|
if not isinstance(role_value, str):
|
||||||
|
role_value = "assistant"
|
||||||
|
role = self._resolve_role(role_value)
|
||||||
|
tool_name = context.get("tool_name")
|
||||||
|
tool_name_value = (
|
||||||
|
tool_name if isinstance(tool_name, str) and tool_name else None
|
||||||
|
)
|
||||||
|
|
||||||
locked_session = await session_repo.lock_session_for_update(
|
locked_session = await session_repo.lock_session_for_update(
|
||||||
session_id=session_id
|
session_id=session_id
|
||||||
@@ -137,13 +187,15 @@ class SqlAlchemyEventStore:
|
|||||||
await message_repo.append_message(
|
await message_repo.append_message(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
seq=seq,
|
seq=seq,
|
||||||
role=AgentChatMessageRole.ASSISTANT,
|
role=role,
|
||||||
content=content,
|
content=content,
|
||||||
model_code=model_code if isinstance(model_code, str) else None,
|
model_code=model_code if isinstance(model_code, str) else None,
|
||||||
|
tool_name=tool_name_value,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
input_tokens=input_tokens,
|
input_tokens=input_tokens,
|
||||||
output_tokens=output_tokens,
|
output_tokens=output_tokens,
|
||||||
cost=cost,
|
cost=cost,
|
||||||
|
latency_ms=latency_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||||
@@ -161,6 +213,74 @@ class SqlAlchemyEventStore:
|
|||||||
cost_delta=cost,
|
cost_delta=cost,
|
||||||
)
|
)
|
||||||
self._message_buffers.pop(key, None)
|
self._message_buffers.pop(key, None)
|
||||||
|
self._message_contexts.pop(key, None)
|
||||||
|
|
||||||
|
async def _persist_tool_call_result(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
event: dict[str, Any],
|
||||||
|
session_id: UUID,
|
||||||
|
chat_session: Any,
|
||||||
|
session_repo: SessionRepository,
|
||||||
|
message_repo: MessageRepository,
|
||||||
|
) -> None:
|
||||||
|
tool_name = event.get("toolName")
|
||||||
|
if not isinstance(tool_name, str) or not tool_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"args": event.get("args"),
|
||||||
|
"result": event.get("result"),
|
||||||
|
"error": event.get("error"),
|
||||||
|
"call_id": event.get("callId"),
|
||||||
|
}
|
||||||
|
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||||||
|
metadata: dict[str, object] = {"tool_name": tool_name}
|
||||||
|
run_id = event.get("runId")
|
||||||
|
if isinstance(run_id, str) and run_id:
|
||||||
|
metadata["run_id"] = run_id
|
||||||
|
stage = event.get("stage")
|
||||||
|
if isinstance(stage, str) and stage:
|
||||||
|
metadata["stage"] = stage
|
||||||
|
task_id = event.get("taskId")
|
||||||
|
if isinstance(task_id, str) and task_id:
|
||||||
|
metadata["task_id"] = task_id
|
||||||
|
|
||||||
|
locked_session = await session_repo.lock_session_for_update(
|
||||||
|
session_id=session_id
|
||||||
|
)
|
||||||
|
if locked_session is None:
|
||||||
|
return
|
||||||
|
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||||
|
await message_repo.append_message(
|
||||||
|
session_id=session_id,
|
||||||
|
seq=seq,
|
||||||
|
role=AgentChatMessageRole.TOOL,
|
||||||
|
content=content,
|
||||||
|
tool_name=tool_name,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||||
|
status = (
|
||||||
|
current_status
|
||||||
|
if isinstance(current_status, AgentChatSessionStatus)
|
||||||
|
else AgentChatSessionStatus.RUNNING
|
||||||
|
)
|
||||||
|
await self._update_session_state(
|
||||||
|
session_repo=session_repo,
|
||||||
|
chat_session=chat_session,
|
||||||
|
status=status,
|
||||||
|
message_delta=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_role(self, value: str) -> AgentChatMessageRole:
|
||||||
|
normalized = value.strip().lower()
|
||||||
|
if normalized == AgentChatMessageRole.SYSTEM.value:
|
||||||
|
return AgentChatMessageRole.SYSTEM
|
||||||
|
if normalized == AgentChatMessageRole.TOOL.value:
|
||||||
|
return AgentChatMessageRole.TOOL
|
||||||
|
return AgentChatMessageRole.ASSISTANT
|
||||||
|
|
||||||
async def _update_session_state(
|
async def _update_session_state(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -38,7 +38,38 @@ def _schema_json(model: type[Any]) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
|
def build_intent_user_prompt(
|
||||||
|
*, user_input: str | list[dict[str, Any]]
|
||||||
|
) -> str | list[dict[str, Any]]:
|
||||||
|
if isinstance(user_input, list):
|
||||||
|
instruction_text = "\n\n".join(
|
||||||
|
[
|
||||||
|
INTENT_TASK_INSTRUCTION,
|
||||||
|
"[Output Schema]",
|
||||||
|
_schema_json(IntentOutput),
|
||||||
|
"[User Input]",
|
||||||
|
"Use the following multimodal blocks as the latest user input.",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
blocks = [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": instruction_text,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
user_blocks = _latest_user_content_blocks(user_input)
|
||||||
|
if not user_blocks:
|
||||||
|
user_blocks = [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(
|
||||||
|
user_input, ensure_ascii=True, separators=(",", ":")
|
||||||
|
),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
blocks.extend(user_blocks)
|
||||||
|
return blocks
|
||||||
|
|
||||||
normalized_input = (
|
normalized_input = (
|
||||||
user_input
|
user_input
|
||||||
if isinstance(user_input, str)
|
if isinstance(user_input, str)
|
||||||
@@ -55,6 +86,101 @@ def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _latest_user_content_blocks(
|
||||||
|
user_input: list[dict[str, Any]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
for message in reversed(user_input):
|
||||||
|
if not isinstance(message, dict):
|
||||||
|
continue
|
||||||
|
if message.get("role") != "user":
|
||||||
|
continue
|
||||||
|
content = message.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
text = content.strip()
|
||||||
|
return [{"type": "text", "text": text}] if text else []
|
||||||
|
if not isinstance(content, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
blocks: list[dict[str, Any]] = []
|
||||||
|
for item in content:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
item_type = item.get("type")
|
||||||
|
if item_type == "text":
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str) and text.strip():
|
||||||
|
blocks.append({"type": "text", "text": text})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if item_type == "binary":
|
||||||
|
source_block = _binary_source_block(item)
|
||||||
|
if source_block is not None:
|
||||||
|
blocks.append(source_block)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if item_type == "image":
|
||||||
|
source_block = _image_source_block(item)
|
||||||
|
if source_block is not None:
|
||||||
|
blocks.append(source_block)
|
||||||
|
|
||||||
|
return blocks
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
mime_type = item.get("mimeType")
|
||||||
|
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
|
||||||
|
if not media_type.startswith("image/"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
source_url = item.get("url")
|
||||||
|
if isinstance(source_url, str) and source_url:
|
||||||
|
return {"type": "image", "source": {"type": "url", "url": source_url}}
|
||||||
|
|
||||||
|
source_data = item.get("data")
|
||||||
|
if isinstance(source_data, str) and source_data:
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": media_type,
|
||||||
|
"data": source_data,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _image_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
source = item.get("source")
|
||||||
|
if not isinstance(source, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
source_type = source.get("type")
|
||||||
|
if source_type == "url":
|
||||||
|
source_url = source.get("value") or source.get("url")
|
||||||
|
if isinstance(source_url, str) and source_url:
|
||||||
|
return {"type": "image", "source": {"type": "url", "url": source_url}}
|
||||||
|
|
||||||
|
if source_type in {"data", "base64"}:
|
||||||
|
source_data = source.get("value") or source.get("data")
|
||||||
|
if isinstance(source_data, str) and source_data:
|
||||||
|
mime_type = source.get("mimeType") or source.get("media_type")
|
||||||
|
media_type = (
|
||||||
|
mime_type if isinstance(mime_type, str) and mime_type else "image/png"
|
||||||
|
)
|
||||||
|
if not media_type.startswith("image/"):
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": media_type,
|
||||||
|
"data": source_data,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def build_execution_user_prompt(
|
def build_execution_user_prompt(
|
||||||
*,
|
*,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -153,6 +154,44 @@ class AgentRouteRuntime:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self._emit_stage_text(
|
||||||
|
thread_id=command.thread_id,
|
||||||
|
run_id=command.run_id,
|
||||||
|
stage_name="intent",
|
||||||
|
message_id=f"intent-{command.run_id}",
|
||||||
|
text=_intent_text_payload(result.intent),
|
||||||
|
response_metadata=result.intent.response_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.intent.route == "DIRECT_RESPONSE" and result.execution is None:
|
||||||
|
await self._pipeline.emit(
|
||||||
|
session_id=command.thread_id,
|
||||||
|
event={
|
||||||
|
"type": "run.finished",
|
||||||
|
"threadId": command.thread_id,
|
||||||
|
"runId": command.run_id,
|
||||||
|
"data": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if result.execution is not None:
|
||||||
|
for index, task in enumerate(result.execution.task_results, start=1):
|
||||||
|
await self._emit_stage_text(
|
||||||
|
thread_id=command.thread_id,
|
||||||
|
run_id=command.run_id,
|
||||||
|
stage_name="execution",
|
||||||
|
message_id=f"execution-{command.run_id}-{index}",
|
||||||
|
text=task.execution_summary,
|
||||||
|
response_metadata=task.response_metadata,
|
||||||
|
)
|
||||||
|
await self._emit_tool_result_events(
|
||||||
|
thread_id=command.thread_id,
|
||||||
|
run_id=command.run_id,
|
||||||
|
task_id=task.task_id,
|
||||||
|
tool_calls=_task_tool_calls(task),
|
||||||
|
)
|
||||||
|
|
||||||
await self._pipeline.emit(
|
await self._pipeline.emit(
|
||||||
session_id=command.thread_id,
|
session_id=command.thread_id,
|
||||||
event={
|
event={
|
||||||
@@ -164,35 +203,18 @@ class AgentRouteRuntime:
|
|||||||
)
|
)
|
||||||
|
|
||||||
report_message_id = f"assistant-{command.run_id}"
|
report_message_id = f"assistant-{command.run_id}"
|
||||||
await self._pipeline.emit(
|
response_metadata = (
|
||||||
session_id=command.thread_id,
|
result.report.response_metadata
|
||||||
event={
|
if isinstance(result.report.response_metadata, dict)
|
||||||
"type": "text.start",
|
else {}
|
||||||
"threadId": command.thread_id,
|
|
||||||
"runId": command.run_id,
|
|
||||||
"data": {"messageId": report_message_id, "role": "assistant"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
await self._pipeline.emit(
|
await self._emit_stage_text(
|
||||||
session_id=command.thread_id,
|
thread_id=command.thread_id,
|
||||||
event={
|
run_id=command.run_id,
|
||||||
"type": "text.delta",
|
stage_name="report",
|
||||||
"threadId": command.thread_id,
|
message_id=report_message_id,
|
||||||
"runId": command.run_id,
|
text=result.report.assistant_text,
|
||||||
"data": {
|
response_metadata=response_metadata,
|
||||||
"messageId": report_message_id,
|
|
||||||
"delta": result.report.assistant_text,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await self._pipeline.emit(
|
|
||||||
session_id=command.thread_id,
|
|
||||||
event={
|
|
||||||
"type": "text.end",
|
|
||||||
"threadId": command.thread_id,
|
|
||||||
"runId": command.run_id,
|
|
||||||
"data": {"messageId": report_message_id},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
await self._pipeline.emit(
|
await self._pipeline.emit(
|
||||||
session_id=command.thread_id,
|
session_id=command.thread_id,
|
||||||
@@ -213,3 +235,178 @@ class AgentRouteRuntime:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def _emit_stage_text(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
stage_name: str,
|
||||||
|
message_id: str,
|
||||||
|
text: str,
|
||||||
|
response_metadata: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await self._pipeline.emit(
|
||||||
|
session_id=thread_id,
|
||||||
|
event={
|
||||||
|
"type": "text.start",
|
||||||
|
"threadId": thread_id,
|
||||||
|
"runId": run_id,
|
||||||
|
"data": {
|
||||||
|
"messageId": message_id,
|
||||||
|
"role": "assistant",
|
||||||
|
"stage": stage_name,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await self._pipeline.emit(
|
||||||
|
session_id=thread_id,
|
||||||
|
event={
|
||||||
|
"type": "text.delta",
|
||||||
|
"threadId": thread_id,
|
||||||
|
"runId": run_id,
|
||||||
|
"data": {"messageId": message_id, "delta": text},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await self._pipeline.emit(
|
||||||
|
session_id=thread_id,
|
||||||
|
event={
|
||||||
|
"type": "text.end",
|
||||||
|
"threadId": thread_id,
|
||||||
|
"runId": run_id,
|
||||||
|
"data": {
|
||||||
|
"messageId": message_id,
|
||||||
|
"stage": stage_name,
|
||||||
|
**_text_end_telemetry_payload(response_metadata),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _emit_tool_result_events(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
task_id: str,
|
||||||
|
tool_calls: list[dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
for index, tool_call in enumerate(tool_calls, start=1):
|
||||||
|
tool_name = tool_call.get("tool_name")
|
||||||
|
if not isinstance(tool_name, str) or not tool_name:
|
||||||
|
continue
|
||||||
|
await self._pipeline.emit(
|
||||||
|
session_id=thread_id,
|
||||||
|
event={
|
||||||
|
"type": "tool.result",
|
||||||
|
"threadId": thread_id,
|
||||||
|
"runId": run_id,
|
||||||
|
"data": {
|
||||||
|
"callId": f"{run_id}-{task_id}-{index}",
|
||||||
|
"stage": "execution",
|
||||||
|
"taskId": task_id,
|
||||||
|
"toolName": tool_name,
|
||||||
|
"args": tool_call.get("args", {}),
|
||||||
|
"result": tool_call.get("result"),
|
||||||
|
"error": tool_call.get("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
payload: dict[str, Any] = {}
|
||||||
|
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
|
||||||
|
if model is not None:
|
||||||
|
payload["model"] = model
|
||||||
|
|
||||||
|
input_tokens = _first_number(metadata, keys=("inputTokens", "input_tokens"))
|
||||||
|
if input_tokens is not None:
|
||||||
|
payload["inputTokens"] = input_tokens
|
||||||
|
|
||||||
|
output_tokens = _first_number(metadata, keys=("outputTokens", "output_tokens"))
|
||||||
|
if output_tokens is not None:
|
||||||
|
payload["outputTokens"] = output_tokens
|
||||||
|
|
||||||
|
latency_ms = _first_number(metadata, keys=("latencyMs", "latency_ms"))
|
||||||
|
if latency_ms is not None:
|
||||||
|
payload["latencyMs"] = latency_ms
|
||||||
|
|
||||||
|
cost = _first_number(metadata, keys=("cost", "total_cost"), allow_float=True)
|
||||||
|
if cost is not None:
|
||||||
|
payload["cost"] = cost
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _intent_text_payload(intent: Any) -> str:
|
||||||
|
direct_response = getattr(intent, "direct_response", None)
|
||||||
|
if isinstance(direct_response, str) and direct_response.strip():
|
||||||
|
return direct_response
|
||||||
|
return json.dumps(intent.model_dump(mode="json"), ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _task_tool_calls(task: Any) -> list[dict[str, Any]]:
|
||||||
|
normalized: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
tool_calls = getattr(task, "tool_calls", None)
|
||||||
|
if isinstance(tool_calls, list):
|
||||||
|
for item in tool_calls:
|
||||||
|
if hasattr(item, "model_dump"):
|
||||||
|
dumped = item.model_dump(mode="json")
|
||||||
|
if isinstance(dumped, dict):
|
||||||
|
normalized.append(dumped)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
normalized.append(item)
|
||||||
|
|
||||||
|
if normalized:
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
execution_data = getattr(task, "execution_data", None)
|
||||||
|
if not isinstance(execution_data, dict):
|
||||||
|
return []
|
||||||
|
fallback_calls = execution_data.get("tool_calls")
|
||||||
|
if not isinstance(fallback_calls, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
for item in fallback_calls:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
normalized.append(item)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _first_non_empty_str(
|
||||||
|
metadata: dict[str, Any], *, keys: tuple[str, ...]
|
||||||
|
) -> str | None:
|
||||||
|
for key in keys:
|
||||||
|
value = metadata.get(key)
|
||||||
|
if isinstance(value, str) and value.strip():
|
||||||
|
return value.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _first_number(
|
||||||
|
metadata: dict[str, Any],
|
||||||
|
*,
|
||||||
|
keys: tuple[str, ...],
|
||||||
|
allow_float: bool = False,
|
||||||
|
) -> int | float | None:
|
||||||
|
for key in keys:
|
||||||
|
value = metadata.get(key)
|
||||||
|
if isinstance(value, bool):
|
||||||
|
continue
|
||||||
|
if isinstance(value, int):
|
||||||
|
if value < 0:
|
||||||
|
continue
|
||||||
|
return value
|
||||||
|
if isinstance(value, float):
|
||||||
|
if value < 0:
|
||||||
|
continue
|
||||||
|
return value if allow_float else int(value)
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
parsed = float(value) if allow_float else int(value)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
if parsed >= 0:
|
||||||
|
return parsed
|
||||||
|
return None
|
||||||
|
|||||||
@@ -110,6 +110,19 @@ class AgentScopeRuntimeOrchestrator:
|
|||||||
)
|
)
|
||||||
intent_output = IntentOutput.model_validate(intent_payload)
|
intent_output = IntentOutput.model_validate(intent_payload)
|
||||||
|
|
||||||
|
if intent_output.route == "DIRECT_RESPONSE":
|
||||||
|
assistant_text = (
|
||||||
|
intent_output.direct_response or intent_output.intent_summary
|
||||||
|
)
|
||||||
|
return RuntimeOutput(
|
||||||
|
intent=intent_output,
|
||||||
|
execution=None,
|
||||||
|
report=ReportOutput(
|
||||||
|
assistant_text=assistant_text,
|
||||||
|
response_metadata=dict(intent_output.response_metadata),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
execution_output: ExecutionBatchOutput | None = None
|
execution_output: ExecutionBatchOutput | None = None
|
||||||
if intent_output.route == "TASK_EXECUTION":
|
if intent_output.route == "TASK_EXECUTION":
|
||||||
execution_toolkit = build_stage_toolkit(
|
execution_toolkit = build_stage_toolkit(
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from time import perf_counter
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
||||||
@@ -14,7 +16,8 @@ def _to_litellm_model(*, provider_name: str, model_code: str) -> str:
|
|||||||
normalized_model = model_code.strip()
|
normalized_model = model_code.strip()
|
||||||
if "/" in normalized_model:
|
if "/" in normalized_model:
|
||||||
return normalized_model
|
return normalized_model
|
||||||
return f"{provider_name.strip().lower()}/{normalized_model}"
|
del provider_name
|
||||||
|
return normalized_model
|
||||||
|
|
||||||
|
|
||||||
def _parse_json_text(raw_text: str) -> dict[str, Any]:
|
def _parse_json_text(raw_text: str) -> dict[str, Any]:
|
||||||
@@ -30,6 +33,11 @@ def _parse_json_text(raw_text: str) -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
class AgentScopeReActRunner:
|
class AgentScopeReActRunner:
|
||||||
|
def _build_litellm_service(self) -> Any:
|
||||||
|
from services.litellm.service import LiteLLMService
|
||||||
|
|
||||||
|
return LiteLLMService()
|
||||||
|
|
||||||
def _build_model(self, *, stage_config: RuntimeStageConfig) -> Any:
|
def _build_model(self, *, stage_config: RuntimeStageConfig) -> Any:
|
||||||
from agentscope.model import OpenAIChatModel
|
from agentscope.model import OpenAIChatModel
|
||||||
from agentscope.types import JSONSerializableObject
|
from agentscope.types import JSONSerializableObject
|
||||||
@@ -61,9 +69,16 @@ class AgentScopeReActRunner:
|
|||||||
stage_config: RuntimeStageConfig,
|
stage_config: RuntimeStageConfig,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
user_prompt: str | list[dict[str, Any]],
|
||||||
toolkit: Any | None,
|
toolkit: Any | None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
|
if stage_config.stage == "report" and toolkit is None:
|
||||||
|
return await self._run_report_stage_direct(
|
||||||
|
stage_config=stage_config,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
from agentscope.agent import ReActAgent
|
from agentscope.agent import ReActAgent
|
||||||
from agentscope.formatter import OpenAIChatFormatter
|
from agentscope.formatter import OpenAIChatFormatter
|
||||||
from agentscope.memory import InMemoryMemory
|
from agentscope.memory import InMemoryMemory
|
||||||
@@ -79,9 +94,19 @@ class AgentScopeReActRunner:
|
|||||||
max_iters=6,
|
max_iters=6,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
response = await agent(Msg(name="user", content=user_prompt, role="user"))
|
started_at = perf_counter()
|
||||||
|
response = await agent(
|
||||||
|
Msg(name="user", content=cast(Any, user_prompt), role="user")
|
||||||
|
)
|
||||||
|
latency_ms = int(round((perf_counter() - started_at) * 1000))
|
||||||
text_content = response.get_text_content() or "{}"
|
text_content = response.get_text_content() or "{}"
|
||||||
return _parse_json_text(text_content)
|
payload = _parse_json_text(text_content)
|
||||||
|
return _merge_stage_response_metadata(
|
||||||
|
payload=payload,
|
||||||
|
stage_config=stage_config,
|
||||||
|
response=response,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
)
|
||||||
except json.JSONDecodeError as exc:
|
except json.JSONDecodeError as exc:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"agentscope stage output is not valid json",
|
"agentscope stage output is not valid json",
|
||||||
@@ -96,3 +121,234 @@ class AgentScopeReActRunner:
|
|||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
)
|
)
|
||||||
raise RuntimeError("agent execution failed") from exc
|
raise RuntimeError("agent execution failed") from exc
|
||||||
|
|
||||||
|
async def _run_report_stage_direct(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
stage_config: RuntimeStageConfig,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str | list[dict[str, Any]],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
try:
|
||||||
|
service = self._build_litellm_service()
|
||||||
|
started_at = perf_counter()
|
||||||
|
response_with_cost = await asyncio.to_thread(
|
||||||
|
service.run_completion_with_cost,
|
||||||
|
model=_to_litellm_model(
|
||||||
|
provider_name=stage_config.provider_name,
|
||||||
|
model_code=stage_config.model_code,
|
||||||
|
),
|
||||||
|
messages=_report_messages(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
),
|
||||||
|
temperature=stage_config.llm_config.temperature,
|
||||||
|
max_tokens=stage_config.llm_config.max_tokens,
|
||||||
|
timeout=stage_config.llm_config.timeout_seconds,
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
)
|
||||||
|
latency_ms = int(round((perf_counter() - started_at) * 1000))
|
||||||
|
|
||||||
|
text_content = _chat_response_text(response_with_cost.response)
|
||||||
|
payload = _parse_json_text(text_content)
|
||||||
|
return _merge_report_response_metadata(
|
||||||
|
payload=payload,
|
||||||
|
stage_config=stage_config,
|
||||||
|
response_with_cost=response_with_cost,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
logger.exception(
|
||||||
|
"agentscope stage output is not valid json",
|
||||||
|
stage=stage_config.stage,
|
||||||
|
agent_name="report-agent",
|
||||||
|
)
|
||||||
|
raise RuntimeError("agent output format invalid") from exc
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception(
|
||||||
|
"agentscope stage execution failed",
|
||||||
|
stage=stage_config.stage,
|
||||||
|
agent_name="report-agent",
|
||||||
|
)
|
||||||
|
raise RuntimeError("agent execution failed") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _chat_response_text(response: Any) -> str:
|
||||||
|
content = _read_value(response, "content")
|
||||||
|
if isinstance(content, str) and content.strip():
|
||||||
|
return content
|
||||||
|
|
||||||
|
if not isinstance(content, list):
|
||||||
|
return _fallback_choice_content(response)
|
||||||
|
|
||||||
|
text_parts: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if not isinstance(block, dict):
|
||||||
|
continue
|
||||||
|
if block.get("type") != "text":
|
||||||
|
continue
|
||||||
|
text = block.get("text")
|
||||||
|
if isinstance(text, str) and text:
|
||||||
|
text_parts.append(text)
|
||||||
|
if text_parts:
|
||||||
|
return "".join(text_parts)
|
||||||
|
return _fallback_choice_content(response)
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_choice_content(response: Any) -> str:
|
||||||
|
choices = _read_value(response, "choices")
|
||||||
|
if not isinstance(choices, list) or not choices:
|
||||||
|
return "{}"
|
||||||
|
|
||||||
|
first_choice = choices[0]
|
||||||
|
message = getattr(first_choice, "message", None)
|
||||||
|
if message is None and isinstance(first_choice, dict):
|
||||||
|
message = first_choice.get("message")
|
||||||
|
|
||||||
|
if isinstance(message, dict):
|
||||||
|
content = message.get("content")
|
||||||
|
return content if isinstance(content, str) and content else "{}"
|
||||||
|
|
||||||
|
content = _read_value(message, "content")
|
||||||
|
return content if isinstance(content, str) and content else "{}"
|
||||||
|
|
||||||
|
|
||||||
|
def _read_value(source: Any, key: str) -> Any:
|
||||||
|
if isinstance(source, dict):
|
||||||
|
return source.get(key)
|
||||||
|
return getattr(source, key, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _report_messages(
|
||||||
|
*, system_prompt: str, user_prompt: str | list[dict[str, Any]]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_stage_response_metadata(
|
||||||
|
*,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
stage_config: RuntimeStageConfig,
|
||||||
|
response: Any,
|
||||||
|
latency_ms: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
result = dict(payload)
|
||||||
|
existing = result.get("response_metadata")
|
||||||
|
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
|
||||||
|
metadata.setdefault("model", stage_config.model_code)
|
||||||
|
|
||||||
|
usage = _read_value(response, "usage")
|
||||||
|
prompt_tokens = _to_non_negative_int(
|
||||||
|
_read_value(usage, "prompt_tokens") or _read_value(usage, "input_tokens")
|
||||||
|
)
|
||||||
|
completion_tokens = _to_non_negative_int(
|
||||||
|
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
|
||||||
|
)
|
||||||
|
cost = _to_non_negative_float(
|
||||||
|
_read_value(usage, "cost")
|
||||||
|
or _read_value(_read_value(usage, "metadata"), "cost")
|
||||||
|
)
|
||||||
|
resolved_model = _read_value(response, "model")
|
||||||
|
if cost is None and prompt_tokens is not None and completion_tokens is not None:
|
||||||
|
estimated_cost = _estimate_cost_by_pricing(
|
||||||
|
model=resolved_model
|
||||||
|
if isinstance(resolved_model, str)
|
||||||
|
else stage_config.model_code,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
if estimated_cost is not None:
|
||||||
|
cost = estimated_cost
|
||||||
|
|
||||||
|
if prompt_tokens is not None:
|
||||||
|
metadata["inputTokens"] = prompt_tokens
|
||||||
|
if completion_tokens is not None:
|
||||||
|
metadata["outputTokens"] = completion_tokens
|
||||||
|
if cost is not None:
|
||||||
|
metadata["cost"] = cost
|
||||||
|
if latency_ms >= 0:
|
||||||
|
metadata["latencyMs"] = latency_ms
|
||||||
|
|
||||||
|
result["response_metadata"] = metadata
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_report_response_metadata(
|
||||||
|
*,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
stage_config: RuntimeStageConfig,
|
||||||
|
response_with_cost: Any,
|
||||||
|
latency_ms: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
result = dict(payload)
|
||||||
|
existing = result.get("response_metadata")
|
||||||
|
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
|
||||||
|
usage = _read_value(response_with_cost, "usage")
|
||||||
|
response = _read_value(response_with_cost, "response")
|
||||||
|
|
||||||
|
resolved_model = _read_value(response, "model")
|
||||||
|
if isinstance(resolved_model, str) and resolved_model.strip():
|
||||||
|
metadata["model"] = resolved_model.strip()
|
||||||
|
else:
|
||||||
|
metadata.setdefault("model", stage_config.model_code)
|
||||||
|
|
||||||
|
input_tokens = _to_non_negative_int(_read_value(usage, "prompt_tokens"))
|
||||||
|
output_tokens = _to_non_negative_int(_read_value(usage, "completion_tokens"))
|
||||||
|
cost = _to_non_negative_float(_read_value(usage, "cost"))
|
||||||
|
if input_tokens is not None:
|
||||||
|
metadata["inputTokens"] = input_tokens
|
||||||
|
if output_tokens is not None:
|
||||||
|
metadata["outputTokens"] = output_tokens
|
||||||
|
if cost is not None:
|
||||||
|
metadata["cost"] = cost
|
||||||
|
if latency_ms >= 0:
|
||||||
|
metadata["latencyMs"] = latency_ms
|
||||||
|
|
||||||
|
result["response_metadata"] = metadata
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _to_non_negative_int(value: Any) -> int | None:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return None
|
||||||
|
if not isinstance(value, (int, float, str)):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
return parsed if parsed >= 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
def _to_non_negative_float(value: Any) -> float | None:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return None
|
||||||
|
if not isinstance(value, (int, float, str)):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = float(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
return parsed if parsed >= 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_cost_by_pricing(
|
||||||
|
*, model: str, prompt_tokens: int, completion_tokens: int
|
||||||
|
) -> float | None:
|
||||||
|
normalized_model = model.strip()
|
||||||
|
if not normalized_model:
|
||||||
|
return None
|
||||||
|
from services.litellm.service import LiteLLMService
|
||||||
|
|
||||||
|
service = LiteLLMService()
|
||||||
|
try:
|
||||||
|
return service.calculate_cost(
|
||||||
|
model=normalized_model,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.agentscope.events import (
|
from core.agentscope.events import (
|
||||||
AgentScopeAgUiCodec,
|
AgentScopeAgUiCodec,
|
||||||
AgentScopeEventPipeline,
|
AgentScopeEventPipeline,
|
||||||
@@ -18,6 +21,7 @@ from core.config.settings import config
|
|||||||
from core.db.session import AsyncSessionLocal
|
from core.db.session import AsyncSessionLocal
|
||||||
from core.logging import get_logger
|
from core.logging import get_logger
|
||||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||||
|
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||||
from services.base.redis import get_or_init_redis_client
|
from services.base.redis import get_or_init_redis_client
|
||||||
|
|
||||||
logger = get_logger("core.agentscope.runtime.tasks")
|
logger = get_logger("core.agentscope.runtime.tasks")
|
||||||
@@ -76,6 +80,56 @@ def _extract_user_token(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_recent_context_messages(
|
||||||
|
*,
|
||||||
|
session: Any,
|
||||||
|
thread_id: str,
|
||||||
|
current_run_id: str,
|
||||||
|
max_messages: int = 20,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
try:
|
||||||
|
session_uuid = UUID(thread_id)
|
||||||
|
except ValueError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
utc_now = datetime.now(timezone.utc)
|
||||||
|
start_of_today = utc_now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
start_of_yesterday = start_of_today - timedelta(days=1)
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(AgentChatMessage)
|
||||||
|
.where(AgentChatMessage.session_id == session_uuid)
|
||||||
|
.where(AgentChatMessage.deleted_at.is_(None))
|
||||||
|
.where(AgentChatMessage.created_at >= start_of_yesterday)
|
||||||
|
.order_by(AgentChatMessage.seq.asc())
|
||||||
|
)
|
||||||
|
rows = (await session.execute(stmt)).scalars().all()
|
||||||
|
|
||||||
|
normalized: list[dict[str, Any]] = []
|
||||||
|
for row in rows:
|
||||||
|
metadata = row.metadata_json if isinstance(row.metadata_json, dict) else {}
|
||||||
|
if metadata.get("run_id") == current_run_id:
|
||||||
|
continue
|
||||||
|
role = (
|
||||||
|
row.role.value
|
||||||
|
if isinstance(row.role, AgentChatMessageRole)
|
||||||
|
else str(row.role)
|
||||||
|
)
|
||||||
|
if role not in {"user", "assistant"}:
|
||||||
|
continue
|
||||||
|
normalized.append(
|
||||||
|
{
|
||||||
|
"id": str(row.id),
|
||||||
|
"role": role,
|
||||||
|
"content": row.content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(normalized) <= max_messages:
|
||||||
|
return normalized
|
||||||
|
return normalized[-max_messages:]
|
||||||
|
|
||||||
|
|
||||||
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||||
command_type = str(command.get("command", "run")).strip().lower()
|
command_type = str(command.get("command", "run")).strip().lower()
|
||||||
raw_run_input = command.get("run_input")
|
raw_run_input = command.get("run_input")
|
||||||
@@ -117,6 +171,21 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncSessionLocal() as session:
|
async with AsyncSessionLocal() as session:
|
||||||
|
if command_type == "run":
|
||||||
|
context_messages = await _build_recent_context_messages(
|
||||||
|
session=session,
|
||||||
|
thread_id=parsed_run_input.thread_id,
|
||||||
|
current_run_id=parsed_run_input.run_id,
|
||||||
|
)
|
||||||
|
parsed_run_input = parsed_run_input.model_copy(
|
||||||
|
update={
|
||||||
|
"messages": [
|
||||||
|
*context_messages,
|
||||||
|
*parsed_run_input.messages,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if command_type == "resume":
|
if command_type == "resume":
|
||||||
await runtime.resume(
|
await runtime.resume(
|
||||||
command=parsed_run_input,
|
command=parsed_run_input,
|
||||||
|
|||||||
@@ -5,12 +5,21 @@ from typing import Any, Literal
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionToolCall(BaseModel):
|
||||||
|
tool_name: str = Field(min_length=1)
|
||||||
|
args: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
result: Any | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ExecutionTaskOutput(BaseModel):
|
class ExecutionTaskOutput(BaseModel):
|
||||||
task_id: str = Field(min_length=1)
|
task_id: str = Field(min_length=1)
|
||||||
status: Literal["SUCCESS", "PARTIAL", "FAILED"]
|
status: Literal["SUCCESS", "PARTIAL", "FAILED"]
|
||||||
execution_summary: str = Field(min_length=1)
|
execution_summary: str = Field(min_length=1)
|
||||||
execution_data: dict[str, Any] = Field(default_factory=dict)
|
execution_data: dict[str, Any] = Field(default_factory=dict)
|
||||||
user_feedback_needs: list[str] = Field(default_factory=list)
|
user_feedback_needs: list[str] = Field(default_factory=list)
|
||||||
|
response_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
tool_calls: list[ExecutionToolCall] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ExecutionBatchOutput(BaseModel):
|
class ExecutionBatchOutput(BaseModel):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
@@ -17,6 +18,7 @@ class IntentOutput(BaseModel):
|
|||||||
direct_response: str | None = None
|
direct_response: str | None = None
|
||||||
tasks: list[IntentTask] = Field(default_factory=list)
|
tasks: list[IntentTask] = Field(default_factory=list)
|
||||||
complexity: Literal["simple", "complex"]
|
complexity: Literal["simple", "complex"]
|
||||||
|
response_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_route(self) -> "IntentOutput":
|
def validate_route(self) -> "IntentOutput":
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
|
from core.agentscope.tools.custom.calendar import (
|
||||||
|
calendar_read,
|
||||||
|
calendar_write,
|
||||||
|
user_resolve,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["calendar_read", "calendar_write"]
|
__all__ = ["calendar_read", "calendar_write", "user_resolve"]
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
|
|||||||
from core.agentscope.tools.custom.calendar_backend_ops import (
|
from core.agentscope.tools.custom.calendar_backend_ops import (
|
||||||
_execute_list_calendar_events,
|
_execute_list_calendar_events,
|
||||||
_execute_mutate_calendar_event,
|
_execute_mutate_calendar_event,
|
||||||
|
_execute_resolve_user_identity,
|
||||||
)
|
)
|
||||||
from core.config.settings import config
|
from core.config.settings import config
|
||||||
from core.agentscope.tools.response import build_tool_response
|
from core.agentscope.tools.response import build_tool_response
|
||||||
@@ -150,6 +151,30 @@ async def calendar_write(
|
|||||||
bool,
|
bool,
|
||||||
Field(description="Whether to use the replace strategy for conflicts."),
|
Field(description="Whether to use the replace strategy for conflicts."),
|
||||||
] = False,
|
] = False,
|
||||||
|
invite_user_emails: Annotated[
|
||||||
|
list[str] | None,
|
||||||
|
Field(description="Optional invite targets by email."),
|
||||||
|
] = None,
|
||||||
|
invite_user_names: Annotated[
|
||||||
|
list[str] | None,
|
||||||
|
Field(description="Optional invite targets by username."),
|
||||||
|
] = None,
|
||||||
|
invite_user_ids: Annotated[
|
||||||
|
list[str] | None,
|
||||||
|
Field(description="Optional invite targets by user ID (UUID string)."),
|
||||||
|
] = None,
|
||||||
|
invite_permission_view: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(description="Invite permission: view."),
|
||||||
|
] = True,
|
||||||
|
invite_permission_edit: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(description="Invite permission: edit."),
|
||||||
|
] = False,
|
||||||
|
invite_permission_invite: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(description="Invite permission: invite others."),
|
||||||
|
] = False,
|
||||||
session: Any = None,
|
session: Any = None,
|
||||||
owner_id: Any = None,
|
owner_id: Any = None,
|
||||||
user_token: str | None = None,
|
user_token: str | None = None,
|
||||||
@@ -240,6 +265,15 @@ async def calendar_write(
|
|||||||
tool_args["reminderMinutes"] = reminder_minutes
|
tool_args["reminderMinutes"] = reminder_minutes
|
||||||
if status is not None:
|
if status is not None:
|
||||||
tool_args["status"] = status
|
tool_args["status"] = status
|
||||||
|
if invite_user_emails is not None:
|
||||||
|
tool_args["inviteUserEmails"] = invite_user_emails
|
||||||
|
if invite_user_names is not None:
|
||||||
|
tool_args["inviteUserNames"] = invite_user_names
|
||||||
|
if invite_user_ids is not None:
|
||||||
|
tool_args["inviteUserIds"] = invite_user_ids
|
||||||
|
tool_args["invitePermissionView"] = invite_permission_view
|
||||||
|
tool_args["invitePermissionEdit"] = invite_permission_edit
|
||||||
|
tool_args["invitePermissionInvite"] = invite_permission_invite
|
||||||
|
|
||||||
result = await _execute_mutate_calendar_event(
|
result = await _execute_mutate_calendar_event(
|
||||||
session=cast(Any, session),
|
session=cast(Any, session),
|
||||||
@@ -247,3 +281,34 @@ async def calendar_write(
|
|||||||
tool_args=tool_args,
|
tool_args=tool_args,
|
||||||
)
|
)
|
||||||
return build_tool_response(result)
|
return build_tool_response(result)
|
||||||
|
|
||||||
|
|
||||||
|
async def user_resolve(
|
||||||
|
user_email: Annotated[
|
||||||
|
str | None,
|
||||||
|
Field(description="User email to resolve user ID."),
|
||||||
|
] = None,
|
||||||
|
user_name: Annotated[
|
||||||
|
str | None,
|
||||||
|
Field(description="Username to resolve user ID."),
|
||||||
|
] = None,
|
||||||
|
session: Any = None,
|
||||||
|
owner_id: Any = None,
|
||||||
|
user_token: str | None = None,
|
||||||
|
) -> Any:
|
||||||
|
if session is None or owner_id is None:
|
||||||
|
raise ValueError("user.resolve missing runtime preset arguments")
|
||||||
|
if not isinstance(user_token, str) or not user_token.strip():
|
||||||
|
return build_tool_response(_unauthorized_response())
|
||||||
|
if not _verify_user_token(user_token=user_token, owner_id=cast(UUID, owner_id)):
|
||||||
|
return build_tool_response(_unauthorized_response())
|
||||||
|
|
||||||
|
result = await _execute_resolve_user_identity(
|
||||||
|
session=cast(Any, session),
|
||||||
|
owner_id=cast(UUID, owner_id),
|
||||||
|
tool_args={
|
||||||
|
"userEmail": user_email,
|
||||||
|
"userName": user_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return build_tool_response(result)
|
||||||
|
|||||||
@@ -4,13 +4,20 @@ import re
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from core.auth.models import CurrentUser
|
from core.auth.models import CurrentUser
|
||||||
|
from services.base.supabase import supabase_service
|
||||||
|
from models.profile import Profile
|
||||||
|
from v1.auth.gateway import SupabaseAuthGateway
|
||||||
|
from v1.inbox_messages.repository import SQLAlchemyInboxMessageRepository
|
||||||
from v1.schedule_items.repository import SQLAlchemyScheduleItemRepository
|
from v1.schedule_items.repository import SQLAlchemyScheduleItemRepository
|
||||||
from v1.schedule_items.schemas import (
|
from v1.schedule_items.schemas import (
|
||||||
ScheduleItemCreateRequest,
|
ScheduleItemCreateRequest,
|
||||||
ScheduleItemMetadata,
|
ScheduleItemMetadata,
|
||||||
|
ScheduleItemShareRequest,
|
||||||
ScheduleItemStatus,
|
ScheduleItemStatus,
|
||||||
ScheduleItemUpdateRequest,
|
ScheduleItemUpdateRequest,
|
||||||
)
|
)
|
||||||
@@ -72,9 +79,196 @@ def _service(session: AsyncSession, owner_id: UUID) -> ScheduleItemService:
|
|||||||
repository=SQLAlchemyScheduleItemRepository(session),
|
repository=SQLAlchemyScheduleItemRepository(session),
|
||||||
session=session,
|
session=session,
|
||||||
current_user=CurrentUser(id=owner_id),
|
current_user=CurrentUser(id=owner_id),
|
||||||
|
inbox_repository=SQLAlchemyInboxMessageRepository(session),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_string_list(value: object, *, field_name: str) -> list[str]:
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValueError(f"{field_name} must be a list of strings")
|
||||||
|
parsed: list[str] = []
|
||||||
|
for item in value:
|
||||||
|
if not isinstance(item, str) or not item.strip():
|
||||||
|
raise ValueError(f"{field_name} must be a list of non-empty strings")
|
||||||
|
parsed.append(item.strip())
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _list_auth_users() -> list[object]:
|
||||||
|
admin_client = supabase_service.get_admin_client()
|
||||||
|
users: list[object] = []
|
||||||
|
page = 1
|
||||||
|
while page <= 100:
|
||||||
|
response = admin_client.auth.admin.list_users(page=page, per_page=100)
|
||||||
|
batch = (
|
||||||
|
list(response)
|
||||||
|
if isinstance(response, list)
|
||||||
|
else list(getattr(response, "users", []))
|
||||||
|
)
|
||||||
|
users.extend(batch)
|
||||||
|
if len(batch) < 100:
|
||||||
|
break
|
||||||
|
page += 1
|
||||||
|
return users
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_profile_username(*, session: AsyncSession, user_id: UUID) -> str | None:
|
||||||
|
stmt = select(Profile.username).where(Profile.id == user_id)
|
||||||
|
return (await session.execute(stmt)).scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_profile_by_username(
|
||||||
|
*, session: AsyncSession, username: str
|
||||||
|
) -> Profile | None:
|
||||||
|
stmt = (
|
||||||
|
select(Profile)
|
||||||
|
.where(Profile.username == username)
|
||||||
|
.where(Profile.deleted_at.is_(None))
|
||||||
|
)
|
||||||
|
return (await session.execute(stmt)).scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
def _find_auth_email_by_user_id(*, users: list[object], user_id: UUID) -> str | None:
|
||||||
|
target = str(user_id)
|
||||||
|
for user in users:
|
||||||
|
if str(getattr(user, "id", "")) == target:
|
||||||
|
email = getattr(user, "email", None)
|
||||||
|
if isinstance(email, str) and email.strip():
|
||||||
|
return email.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_identity(
|
||||||
|
*,
|
||||||
|
session: AsyncSession,
|
||||||
|
user_email: str | None,
|
||||||
|
user_name: str | None,
|
||||||
|
) -> dict[str, object]:
|
||||||
|
email = user_email.strip().lower() if isinstance(user_email, str) else ""
|
||||||
|
name = user_name.strip() if isinstance(user_name, str) else ""
|
||||||
|
if bool(email) == bool(name):
|
||||||
|
raise ValueError("provide exactly one of user_email or user_name")
|
||||||
|
|
||||||
|
if email:
|
||||||
|
auth_gateway = SupabaseAuthGateway()
|
||||||
|
user = await auth_gateway.get_user_by_email(email)
|
||||||
|
user_id = UUID(user.id)
|
||||||
|
username = await _get_profile_username(session=session, user_id=user_id)
|
||||||
|
return {
|
||||||
|
"userId": str(user_id),
|
||||||
|
"email": user.email,
|
||||||
|
"username": username,
|
||||||
|
"matchedBy": "email",
|
||||||
|
}
|
||||||
|
|
||||||
|
profile = await _get_profile_by_username(session=session, username=name)
|
||||||
|
if profile is None:
|
||||||
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
users = _list_auth_users()
|
||||||
|
email_value = _find_auth_email_by_user_id(users=users, user_id=profile.id)
|
||||||
|
return {
|
||||||
|
"userId": str(profile.id),
|
||||||
|
"email": email_value,
|
||||||
|
"username": profile.username,
|
||||||
|
"matchedBy": "username",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _invite_permission(tool_args: dict[str, object]) -> dict[str, bool]:
|
||||||
|
return {
|
||||||
|
"permission_view": bool(tool_args.get("invitePermissionView", True)),
|
||||||
|
"permission_edit": bool(tool_args.get("invitePermissionEdit", False)),
|
||||||
|
"permission_invite": bool(tool_args.get("invitePermissionInvite", False)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _share_event_with_invitees(
|
||||||
|
*,
|
||||||
|
session: AsyncSession,
|
||||||
|
owner_id: UUID,
|
||||||
|
event_id: UUID,
|
||||||
|
tool_args: dict[str, object],
|
||||||
|
) -> dict[str, object] | None:
|
||||||
|
email_targets = _parse_string_list(
|
||||||
|
tool_args.get("inviteUserEmails"),
|
||||||
|
field_name="inviteUserEmails",
|
||||||
|
)
|
||||||
|
name_targets = _parse_string_list(
|
||||||
|
tool_args.get("inviteUserNames"),
|
||||||
|
field_name="inviteUserNames",
|
||||||
|
)
|
||||||
|
id_targets = _parse_string_list(
|
||||||
|
tool_args.get("inviteUserIds"),
|
||||||
|
field_name="inviteUserIds",
|
||||||
|
)
|
||||||
|
if not email_targets and not name_targets and not id_targets:
|
||||||
|
return None
|
||||||
|
|
||||||
|
users = _list_auth_users() if id_targets else []
|
||||||
|
emails = {item.lower() for item in email_targets}
|
||||||
|
for user_id_raw in id_targets:
|
||||||
|
try:
|
||||||
|
user_id = UUID(user_id_raw)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError("inviteUserIds must contain valid UUID strings") from exc
|
||||||
|
resolved_email = _find_auth_email_by_user_id(users=users, user_id=user_id)
|
||||||
|
if resolved_email is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Invite user email not found")
|
||||||
|
emails.add(resolved_email.lower())
|
||||||
|
for username in name_targets:
|
||||||
|
resolved = await _resolve_identity(
|
||||||
|
session=session,
|
||||||
|
user_email=None,
|
||||||
|
user_name=username,
|
||||||
|
)
|
||||||
|
resolved_email = resolved.get("email")
|
||||||
|
if not isinstance(resolved_email, str) or not resolved_email:
|
||||||
|
raise HTTPException(status_code=404, detail="Invite user email not found")
|
||||||
|
emails.add(resolved_email.lower())
|
||||||
|
|
||||||
|
service = _service(session, owner_id)
|
||||||
|
permission = _invite_permission(tool_args)
|
||||||
|
invited: list[str] = []
|
||||||
|
for email in sorted(emails):
|
||||||
|
request = ScheduleItemShareRequest(email=email, **permission)
|
||||||
|
await service.share(event_id, request)
|
||||||
|
invited.append(email)
|
||||||
|
return {
|
||||||
|
"count": len(invited),
|
||||||
|
"emails": invited,
|
||||||
|
"permission": permission,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_resolve_user_identity(
|
||||||
|
*,
|
||||||
|
session: AsyncSession,
|
||||||
|
owner_id: UUID,
|
||||||
|
tool_args: dict[str, object],
|
||||||
|
) -> dict[str, object]:
|
||||||
|
del owner_id
|
||||||
|
user_email_raw = tool_args.get("userEmail")
|
||||||
|
user_name_raw = tool_args.get("userName")
|
||||||
|
user_email = user_email_raw if isinstance(user_email_raw, str) else None
|
||||||
|
user_name = user_name_raw if isinstance(user_name_raw, str) else None
|
||||||
|
resolved = await _resolve_identity(
|
||||||
|
session=session,
|
||||||
|
user_email=user_email,
|
||||||
|
user_name=user_name,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"type": "user_lookup.v1",
|
||||||
|
"version": "v1",
|
||||||
|
"data": {
|
||||||
|
"ok": True,
|
||||||
|
**resolved,
|
||||||
|
},
|
||||||
|
"actions": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _resolve_metadata(tool_args: dict[str, object]) -> ScheduleItemMetadata:
|
def _resolve_metadata(tool_args: dict[str, object]) -> ScheduleItemMetadata:
|
||||||
location = tool_args.get("location")
|
location = tool_args.get("location")
|
||||||
location_value = location.strip() if isinstance(location, str) else None
|
location_value = location.strip() if isinstance(location, str) else None
|
||||||
@@ -185,6 +379,12 @@ async def _execute_create(
|
|||||||
)
|
)
|
||||||
event_data = _event_payload(created)
|
event_data = _event_payload(created)
|
||||||
event_id = str(event_data["id"])
|
event_id = str(event_data["id"])
|
||||||
|
invite_result = await _share_event_with_invitees(
|
||||||
|
session=service._session,
|
||||||
|
owner_id=service.require_user_id(),
|
||||||
|
event_id=UUID(event_id),
|
||||||
|
tool_args=tool_args,
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"type": "calendar_card.v1",
|
"type": "calendar_card.v1",
|
||||||
"version": "v1",
|
"version": "v1",
|
||||||
@@ -193,12 +393,13 @@ async def _execute_create(
|
|||||||
"sourceType": "agent_generated",
|
"sourceType": "agent_generated",
|
||||||
"ok": True,
|
"ok": True,
|
||||||
"message": "日程已创建",
|
"message": "日程已创建",
|
||||||
|
"inviteResult": invite_result,
|
||||||
},
|
},
|
||||||
"actions": [
|
"actions": [
|
||||||
{
|
{
|
||||||
"type": "link",
|
"type": "link",
|
||||||
"label": "查看详情",
|
"label": "查看详情",
|
||||||
"target": f"/calendar/events/{event_id}",
|
"target": f"/schedule-items/{event_id}",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -274,6 +475,12 @@ async def _execute_update(
|
|||||||
ScheduleItemUpdateRequest.model_validate(update_data),
|
ScheduleItemUpdateRequest.model_validate(update_data),
|
||||||
)
|
)
|
||||||
event_data = _event_payload(updated)
|
event_data = _event_payload(updated)
|
||||||
|
invite_result = await _share_event_with_invitees(
|
||||||
|
session=service._session,
|
||||||
|
owner_id=service.require_user_id(),
|
||||||
|
event_id=UUID(str(event_data["id"])),
|
||||||
|
tool_args=tool_args,
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"type": "calendar_card.v1",
|
"type": "calendar_card.v1",
|
||||||
"version": "v1",
|
"version": "v1",
|
||||||
@@ -282,12 +489,13 @@ async def _execute_update(
|
|||||||
"sourceType": "agent_generated",
|
"sourceType": "agent_generated",
|
||||||
"ok": True,
|
"ok": True,
|
||||||
"message": "日程已更新",
|
"message": "日程已更新",
|
||||||
|
"inviteResult": invite_result,
|
||||||
},
|
},
|
||||||
"actions": [
|
"actions": [
|
||||||
{
|
{
|
||||||
"type": "link",
|
"type": "link",
|
||||||
"label": "查看详情",
|
"label": "查看详情",
|
||||||
"target": f"/calendar/events/{event_data['id']}",
|
"target": f"/schedule-items/{event_data['id']}",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
|
|
||||||
TOOL_APPROVAL_REQUIRED: dict[str, bool] = {
|
TOOL_APPROVAL_REQUIRED: dict[str, bool] = {
|
||||||
"calendar.read": False,
|
"calendar_read": False,
|
||||||
"calendar.write": False,
|
"calendar_write": False,
|
||||||
|
"user_resolve": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,11 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
|
from core.agentscope.tools.custom.calendar import (
|
||||||
|
calendar_read,
|
||||||
|
calendar_write,
|
||||||
|
user_resolve,
|
||||||
|
)
|
||||||
from core.agentscope.tools.hitl_middleware import register_tool_middlewares
|
from core.agentscope.tools.hitl_middleware import register_tool_middlewares
|
||||||
from core.agentscope.tools.tool_meta import TOOL_META
|
from core.agentscope.tools.tool_meta import TOOL_META
|
||||||
|
|
||||||
@@ -25,10 +29,12 @@ class ToolGroup:
|
|||||||
|
|
||||||
|
|
||||||
TOOL_GROUPS: dict[str, ToolGroup] = {
|
TOOL_GROUPS: dict[str, ToolGroup] = {
|
||||||
"intent": ToolGroup(stage="intent", tool_names=frozenset({"calendar.read"})),
|
"intent": ToolGroup(
|
||||||
|
stage="intent", tool_names=frozenset({"calendar_read", "user_resolve"})
|
||||||
|
),
|
||||||
"execution": ToolGroup(
|
"execution": ToolGroup(
|
||||||
stage="execution",
|
stage="execution",
|
||||||
tool_names=frozenset({"calendar.read", "calendar.write"}),
|
tool_names=frozenset({"calendar_read", "calendar_write", "user_resolve"}),
|
||||||
),
|
),
|
||||||
"report": ToolGroup(stage="report", tool_names=frozenset()),
|
"report": ToolGroup(stage="report", tool_names=frozenset()),
|
||||||
}
|
}
|
||||||
@@ -49,7 +55,7 @@ def _load_custom_tool_bindings(
|
|||||||
) -> list[CustomToolBinding]:
|
) -> list[CustomToolBinding]:
|
||||||
return [
|
return [
|
||||||
CustomToolBinding(
|
CustomToolBinding(
|
||||||
name="calendar.read",
|
name="calendar_read",
|
||||||
func=calendar_read,
|
func=calendar_read,
|
||||||
preset_kwargs={
|
preset_kwargs={
|
||||||
"session": session,
|
"session": session,
|
||||||
@@ -58,7 +64,7 @@ def _load_custom_tool_bindings(
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
CustomToolBinding(
|
CustomToolBinding(
|
||||||
name="calendar.write",
|
name="calendar_write",
|
||||||
func=calendar_write,
|
func=calendar_write,
|
||||||
preset_kwargs={
|
preset_kwargs={
|
||||||
"session": session,
|
"session": session,
|
||||||
@@ -66,6 +72,15 @@ def _load_custom_tool_bindings(
|
|||||||
"user_token": user_token or "",
|
"user_token": user_token or "",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
CustomToolBinding(
|
||||||
|
name="user_resolve",
|
||||||
|
func=user_resolve,
|
||||||
|
preset_kwargs={
|
||||||
|
"session": session,
|
||||||
|
"owner_id": owner_id,
|
||||||
|
"user_token": user_token or "",
|
||||||
|
},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -126,21 +126,26 @@ class LiteLLMService:
|
|||||||
temperature: float | None = None,
|
temperature: float | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
|
response_format: dict[str, Any] | None = None,
|
||||||
completion_fn: Callable[..., dict[str, Any]] | None = None,
|
completion_fn: Callable[..., dict[str, Any]] | None = None,
|
||||||
) -> LiteLLMResponseWithCost:
|
) -> LiteLLMResponseWithCost:
|
||||||
caller = completion_fn or completion
|
caller = completion_fn or completion
|
||||||
request_model = model if model.startswith("openai/") else f"openai/{model}"
|
request_model = model if model.startswith("openai/") else f"openai/{model}"
|
||||||
|
|
||||||
response_any = caller(
|
request_kwargs: dict[str, Any] = {
|
||||||
model=request_model,
|
"model": request_model,
|
||||||
api_key=self.proxy_api_key,
|
"api_key": self.proxy_api_key,
|
||||||
api_base=self.proxy_base_url,
|
"api_base": self.proxy_base_url,
|
||||||
messages=messages,
|
"messages": messages,
|
||||||
temperature=temperature,
|
"temperature": temperature,
|
||||||
max_tokens=max_tokens,
|
"max_tokens": max_tokens,
|
||||||
timeout=timeout,
|
"timeout": timeout,
|
||||||
stream=False,
|
"stream": False,
|
||||||
)
|
}
|
||||||
|
if response_format is not None:
|
||||||
|
request_kwargs["response_format"] = response_format
|
||||||
|
|
||||||
|
response_any = caller(**request_kwargs)
|
||||||
response = self._normalize_response(response_any)
|
response = self._normalize_response(response_any)
|
||||||
|
|
||||||
usage_raw = response.get("usage")
|
usage_raw = response.get("usage")
|
||||||
|
|||||||
@@ -107,6 +107,10 @@ class AgentRepository:
|
|||||||
raise HTTPException(status_code=404, detail="Session not found")
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
|
|
||||||
next_seq = int(session_row.message_count or 0) + 1
|
next_seq = int(session_row.message_count or 0) + 1
|
||||||
|
if not _has_title(session_row.title):
|
||||||
|
session_title = _derive_session_title(content_text)
|
||||||
|
if session_title is not None:
|
||||||
|
session_row.title = session_title
|
||||||
payload_metadata = dict(metadata or {})
|
payload_metadata = dict(metadata or {})
|
||||||
payload_metadata["run_id"] = run_id
|
payload_metadata["run_id"] = run_id
|
||||||
message = AgentChatMessage(
|
message = AgentChatMessage(
|
||||||
@@ -264,3 +268,14 @@ class AgentRepository:
|
|||||||
if rendered:
|
if rendered:
|
||||||
payload["attachments"] = rendered
|
payload["attachments"] = rendered
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _has_title(title: object) -> bool:
|
||||||
|
return isinstance(title, str) and bool(title.strip())
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_session_title(content_text: str) -> str | None:
|
||||||
|
normalized = " ".join(content_text.split())
|
||||||
|
if not normalized:
|
||||||
|
return None
|
||||||
|
return normalized[:80]
|
||||||
|
|||||||
@@ -203,6 +203,11 @@ async def stream_events(
|
|||||||
user_id=str(current_user.id),
|
user_id=str(current_user.id),
|
||||||
reason=str(exc),
|
reason=str(exc),
|
||||||
)
|
)
|
||||||
|
if "Timeout reading from" in str(exc):
|
||||||
|
idle_polls += 1
|
||||||
|
yield ": keep-alive\n\n"
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
|
|||||||
@@ -212,12 +212,19 @@ class AgentService:
|
|||||||
content_type=mime_type,
|
content_type=mime_type,
|
||||||
)
|
)
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
bucket_name = "private"
|
logger.exception(
|
||||||
stored_path = await self._attachment_storage.upload_bytes(
|
"Attachment upload failed",
|
||||||
bucket=bucket_name,
|
extra={
|
||||||
path=path,
|
"bucket": bucket_name,
|
||||||
content=payload,
|
"path": path,
|
||||||
content_type=mime_type,
|
"mime_type": mime_type,
|
||||||
|
"thread_id": run_input.thread_id,
|
||||||
|
"run_id": run_input.run_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail="Failed to upload attachment",
|
||||||
)
|
)
|
||||||
attachments.append(
|
attachments.append(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -12,6 +14,9 @@ from models.agent_chat_message import AgentChatMessage
|
|||||||
from models.agent_chat_session import AgentChatSession
|
from models.agent_chat_session import AgentChatSession
|
||||||
|
|
||||||
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
|
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
|
||||||
|
FIXTURE_IMAGE_PATH = (
|
||||||
|
Path(__file__).resolve().parents[3] / "fixtures" / "images" / "calendar_text_cn.png"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _live_access_token(client: httpx.AsyncClient) -> str:
|
async def _live_access_token(client: httpx.AsyncClient) -> str:
|
||||||
@@ -108,6 +113,8 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
|||||||
if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
|
if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
|
||||||
pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
|
pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
|
||||||
|
|
||||||
|
image_data = base64.b64encode(FIXTURE_IMAGE_PATH.read_bytes()).decode("ascii")
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
token = await _live_access_token(client)
|
token = await _live_access_token(client)
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
@@ -128,7 +135,7 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
|||||||
{"type": "text", "text": "请描述图片里的内容"},
|
{"type": "text", "text": "请描述图片里的内容"},
|
||||||
{
|
{
|
||||||
"type": "binary",
|
"type": "binary",
|
||||||
"data": "aGVsbG8=",
|
"data": image_data,
|
||||||
"mimeType": "image/png",
|
"mimeType": "image/png",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -142,19 +149,20 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
|||||||
assert run_resp.status_code == 202
|
assert run_resp.status_code == 202
|
||||||
|
|
||||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
||||||
sse_resp = await client.get(
|
event_names: list[str] = []
|
||||||
events_url,
|
async with client.stream(
|
||||||
headers=headers,
|
"GET", events_url, headers=headers, timeout=90.0
|
||||||
params={"idle_limit": 150},
|
) as sse_resp:
|
||||||
timeout=60.0,
|
assert sse_resp.status_code == 200
|
||||||
)
|
assert sse_resp.headers.get("content-type", "").startswith(
|
||||||
assert sse_resp.status_code == 200
|
"text/event-stream"
|
||||||
assert sse_resp.headers.get("content-type", "").startswith("text/event-stream")
|
)
|
||||||
event_names = [
|
async for line in sse_resp.aiter_lines():
|
||||||
line.split(":", 1)[1].strip()
|
if line.startswith("event:"):
|
||||||
for line in sse_resp.text.splitlines()
|
event_name = line.split(":", 1)[1].strip()
|
||||||
if line.startswith("event:")
|
event_names.append(event_name)
|
||||||
]
|
if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
|
||||||
|
break
|
||||||
|
|
||||||
assert "RUN_STARTED" in event_names
|
assert "RUN_STARTED" in event_names
|
||||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||||
@@ -194,7 +202,14 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
|||||||
)
|
)
|
||||||
all_messages = list(rows.scalars().all())
|
all_messages = list(rows.scalars().all())
|
||||||
assert all_messages
|
assert all_messages
|
||||||
user_rows = [row for row in all_messages if str(row.role) == "user"]
|
user_rows = [
|
||||||
|
row
|
||||||
|
for row in all_messages
|
||||||
|
if (
|
||||||
|
getattr(row.role, "value", row.role) == "user"
|
||||||
|
or str(getattr(row.role, "value", row.role)) == "user"
|
||||||
|
)
|
||||||
|
]
|
||||||
assert user_rows
|
assert user_rows
|
||||||
metadata = user_rows[0].metadata_json or {}
|
metadata = user_rows[0].metadata_json or {}
|
||||||
attachments = metadata.get("attachments")
|
attachments = metadata.get("attachments")
|
||||||
|
|||||||
@@ -99,6 +99,16 @@ async def test_store_persists_assistant_message_and_aggregates(
|
|||||||
|
|
||||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||||
|
|
||||||
|
await store.persist(
|
||||||
|
{
|
||||||
|
"type": "TEXT_MESSAGE_START",
|
||||||
|
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||||
|
"runId": "run-1",
|
||||||
|
"messageId": "assistant-run-1",
|
||||||
|
"role": "assistant",
|
||||||
|
"stage": "report",
|
||||||
|
}
|
||||||
|
)
|
||||||
await store.persist(
|
await store.persist(
|
||||||
{
|
{
|
||||||
"type": "TEXT_MESSAGE_CONTENT",
|
"type": "TEXT_MESSAGE_CONTENT",
|
||||||
@@ -128,6 +138,8 @@ async def test_store_persists_assistant_message_and_aggregates(
|
|||||||
assert append_kwargs["output_tokens"] == 5
|
assert append_kwargs["output_tokens"] == 5
|
||||||
assert append_kwargs["cost"] == Decimal("0.123")
|
assert append_kwargs["cost"] == Decimal("0.123")
|
||||||
assert append_kwargs["metadata"]["latency_ms"] == 250
|
assert append_kwargs["metadata"]["latency_ms"] == 250
|
||||||
|
assert append_kwargs["metadata"]["stage"] == "report"
|
||||||
|
assert append_kwargs["latency_ms"] == 250
|
||||||
assert captured["message_delta"] == 1
|
assert captured["message_delta"] == 1
|
||||||
assert captured["token_delta"] == 8
|
assert captured["token_delta"] == 8
|
||||||
assert captured["cost_delta"] == Decimal("0.123")
|
assert captured["cost_delta"] == Decimal("0.123")
|
||||||
@@ -255,6 +267,60 @@ async def test_store_clears_buffer_on_run_finished(
|
|||||||
assert "append_kwargs" not in captured
|
assert "append_kwargs" not in captured
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_persists_tool_call_result_as_tool_message(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2)
|
||||||
|
|
||||||
|
class _FakeSessionRepository:
|
||||||
|
def __init__(self, session: object) -> None:
|
||||||
|
del session
|
||||||
|
|
||||||
|
async def get_session(self, *, session_id): # noqa: ANN001
|
||||||
|
del session_id
|
||||||
|
return fake_chat_session
|
||||||
|
|
||||||
|
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||||
|
del session_id
|
||||||
|
return fake_chat_session
|
||||||
|
|
||||||
|
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||||
|
captured.update(kwargs)
|
||||||
|
|
||||||
|
class _FakeMessageRepository:
|
||||||
|
def __init__(self, session: object) -> None:
|
||||||
|
del session
|
||||||
|
|
||||||
|
async def append_message(self, **kwargs): # noqa: ANN003
|
||||||
|
captured["append_kwargs"] = kwargs
|
||||||
|
|
||||||
|
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||||
|
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||||
|
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||||
|
|
||||||
|
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||||
|
await store.persist(
|
||||||
|
{
|
||||||
|
"type": "TOOL_CALL_RESULT",
|
||||||
|
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||||
|
"runId": "run-1",
|
||||||
|
"toolName": "calendar_write",
|
||||||
|
"taskId": "t1",
|
||||||
|
"stage": "execution",
|
||||||
|
"args": {"title": "A"},
|
||||||
|
"result": {"event_id": "evt-1"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||||
|
assert getattr(append_kwargs["role"], "value", None) == "tool"
|
||||||
|
assert append_kwargs["tool_name"] == "calendar_write"
|
||||||
|
assert append_kwargs["metadata"]["task_id"] == "t1"
|
||||||
|
assert captured["message_delta"] == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_drops_buffer_when_session_missing(
|
async def test_store_drops_buffer_when_session_missing(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
|||||||
@@ -13,8 +13,9 @@ from core.agentscope.schemas.user_context import (
|
|||||||
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
|
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
|
||||||
from core.agentscope.schemas import ReportOutput, RuntimeOutput
|
from core.agentscope.schemas import ReportOutput, RuntimeOutput
|
||||||
from core.agentscope.schemas.agent_runtime import RunCommand
|
from core.agentscope.schemas.agent_runtime import RunCommand
|
||||||
from core.agentscope.schemas.execution import ExecutionBatchOutput
|
from core.agentscope.schemas.execution import ExecutionBatchOutput, ExecutionTaskOutput
|
||||||
from core.agentscope.schemas.intent import IntentOutput
|
from core.agentscope.schemas.execution import ExecutionToolCall
|
||||||
|
from core.agentscope.schemas.intent import IntentOutput, IntentTask
|
||||||
|
|
||||||
|
|
||||||
def _user_context() -> UserAgentContext:
|
def _user_context() -> UserAgentContext:
|
||||||
@@ -50,20 +51,43 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
|||||||
async def run(self, **_: object) -> RuntimeOutput:
|
async def run(self, **_: object) -> RuntimeOutput:
|
||||||
return RuntimeOutput(
|
return RuntimeOutput(
|
||||||
intent=IntentOutput(
|
intent=IntentOutput(
|
||||||
route="DIRECT_RESPONSE",
|
route="TASK_EXECUTION",
|
||||||
intent_summary="summary",
|
intent_summary="summary",
|
||||||
direct_response="done",
|
direct_response=None,
|
||||||
tasks=[],
|
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||||
complexity="simple",
|
complexity="complex",
|
||||||
|
response_metadata={"latencyMs": 120},
|
||||||
),
|
),
|
||||||
execution=ExecutionBatchOutput(
|
execution=ExecutionBatchOutput(
|
||||||
task_results=[],
|
task_results=[
|
||||||
|
ExecutionTaskOutput(
|
||||||
|
task_id="t1",
|
||||||
|
status="SUCCESS",
|
||||||
|
execution_summary="execution-ok",
|
||||||
|
execution_data={},
|
||||||
|
user_feedback_needs=[],
|
||||||
|
response_metadata={"latencyMs": 300},
|
||||||
|
tool_calls=[
|
||||||
|
ExecutionToolCall(
|
||||||
|
tool_name="calendar_write",
|
||||||
|
args={"title": "A"},
|
||||||
|
result={"event_id": "evt-1"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
overall_status="SUCCESS",
|
overall_status="SUCCESS",
|
||||||
aggregate_summary="ok",
|
aggregate_summary="ok",
|
||||||
),
|
),
|
||||||
report=ReportOutput(
|
report=ReportOutput(
|
||||||
assistant_text="hello world",
|
assistant_text="hello world",
|
||||||
response_metadata={},
|
response_metadata={
|
||||||
|
"model": "qwen3.5-flash",
|
||||||
|
"inputTokens": 10,
|
||||||
|
"outputTokens": 5,
|
||||||
|
"cost": 0.123,
|
||||||
|
"latencyMs": 250,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -86,6 +110,13 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
|||||||
"step.finish",
|
"step.finish",
|
||||||
"step.start",
|
"step.start",
|
||||||
"step.finish",
|
"step.finish",
|
||||||
|
"text.start",
|
||||||
|
"text.delta",
|
||||||
|
"text.end",
|
||||||
|
"text.start",
|
||||||
|
"text.delta",
|
||||||
|
"text.end",
|
||||||
|
"tool.result",
|
||||||
"step.start",
|
"step.start",
|
||||||
"text.start",
|
"text.start",
|
||||||
"text.delta",
|
"text.delta",
|
||||||
@@ -97,11 +128,19 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
|||||||
assert calls[2]["data"]["stepName"] == "intent"
|
assert calls[2]["data"]["stepName"] == "intent"
|
||||||
assert calls[3]["data"]["stepName"] == "execution"
|
assert calls[3]["data"]["stepName"] == "execution"
|
||||||
assert calls[4]["data"]["stepName"] == "execution"
|
assert calls[4]["data"]["stepName"] == "execution"
|
||||||
assert calls[5]["data"]["stepName"] == "report"
|
assert calls[5]["data"]["stage"] == "intent"
|
||||||
assert calls[7]["data"]["delta"] == "hello world"
|
assert calls[8]["data"]["stage"] == "execution"
|
||||||
assert calls[6]["data"]["messageId"] == calls[7]["data"]["messageId"]
|
assert calls[11]["data"]["toolName"] == "calendar_write"
|
||||||
assert calls[7]["data"]["messageId"] == calls[8]["data"]["messageId"]
|
assert calls[12]["data"]["stepName"] == "report"
|
||||||
assert calls[9]["data"]["stepName"] == "report"
|
assert calls[14]["data"]["delta"] == "hello world"
|
||||||
|
assert calls[13]["data"]["messageId"] == calls[14]["data"]["messageId"]
|
||||||
|
assert calls[14]["data"]["messageId"] == calls[15]["data"]["messageId"]
|
||||||
|
assert calls[15]["data"]["model"] == "qwen3.5-flash"
|
||||||
|
assert calls[15]["data"]["inputTokens"] == 10
|
||||||
|
assert calls[15]["data"]["outputTokens"] == 5
|
||||||
|
assert calls[15]["data"]["cost"] == 0.123
|
||||||
|
assert calls[15]["data"]["latencyMs"] == 250
|
||||||
|
assert calls[16]["data"]["stepName"] == "report"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -140,3 +179,129 @@ async def test_runtime_emits_run_error_when_orchestrator_fails() -> None:
|
|||||||
]
|
]
|
||||||
assert calls[1]["data"]["stepName"] == "intent"
|
assert calls[1]["data"]["stepName"] == "intent"
|
||||||
assert calls[2]["data"]["message"] == "runtime execution failed"
|
assert calls[2]["data"]["message"] == "runtime execution failed"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_passes_binary_payload_to_orchestrator() -> None:
|
||||||
|
captured_user_input: object | None = None
|
||||||
|
|
||||||
|
class _FakePipeline:
|
||||||
|
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||||
|
assert session_id == "thread-1"
|
||||||
|
return str(event.get("type", ""))
|
||||||
|
|
||||||
|
class _CaptureOrchestrator:
|
||||||
|
async def run(self, **kwargs: object) -> RuntimeOutput:
|
||||||
|
nonlocal captured_user_input
|
||||||
|
captured_user_input = kwargs.get("user_input")
|
||||||
|
return RuntimeOutput(
|
||||||
|
intent=IntentOutput(
|
||||||
|
route="DIRECT_RESPONSE",
|
||||||
|
intent_summary="summary",
|
||||||
|
direct_response="done",
|
||||||
|
tasks=[],
|
||||||
|
complexity="simple",
|
||||||
|
),
|
||||||
|
execution=None,
|
||||||
|
report=ReportOutput(
|
||||||
|
assistant_text="ok",
|
||||||
|
response_metadata={},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime = AgentRouteRuntime(
|
||||||
|
orchestrator=_CaptureOrchestrator(),
|
||||||
|
pipeline=_FakePipeline(),
|
||||||
|
)
|
||||||
|
command = RunCommand.model_validate(
|
||||||
|
{
|
||||||
|
"threadId": "thread-1",
|
||||||
|
"runId": "run-1",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"id": "u1",
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "hello"},
|
||||||
|
{
|
||||||
|
"type": "binary",
|
||||||
|
"mimeType": "image/png",
|
||||||
|
"data": "aGVsbG8=",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await runtime.run(
|
||||||
|
command=command,
|
||||||
|
owner_id=uuid4(),
|
||||||
|
user_token="token",
|
||||||
|
user_context=_user_context(),
|
||||||
|
session=cast(AsyncSession, object()),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(captured_user_input, list)
|
||||||
|
first = captured_user_input[0]
|
||||||
|
assert isinstance(first, dict)
|
||||||
|
content = first.get("content")
|
||||||
|
assert isinstance(content, list)
|
||||||
|
binary = content[1]
|
||||||
|
assert isinstance(binary, dict)
|
||||||
|
assert binary.get("data") == "aGVsbG8="
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_direct_response_finishes_without_report_stage() -> None:
|
||||||
|
calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
class _FakePipeline:
|
||||||
|
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||||
|
assert session_id == "thread-1"
|
||||||
|
calls.append(event)
|
||||||
|
return f"{len(calls)}-0"
|
||||||
|
|
||||||
|
class _DirectOrchestrator:
|
||||||
|
async def run(self, **_: object) -> RuntimeOutput:
|
||||||
|
return RuntimeOutput(
|
||||||
|
intent=IntentOutput(
|
||||||
|
route="DIRECT_RESPONSE",
|
||||||
|
intent_summary="summary",
|
||||||
|
direct_response="direct-answer",
|
||||||
|
tasks=[],
|
||||||
|
complexity="simple",
|
||||||
|
response_metadata={"latencyMs": 88},
|
||||||
|
),
|
||||||
|
execution=None,
|
||||||
|
report=ReportOutput(
|
||||||
|
assistant_text="direct-answer",
|
||||||
|
response_metadata={"latencyMs": 88},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime = AgentRouteRuntime(
|
||||||
|
orchestrator=_DirectOrchestrator(),
|
||||||
|
pipeline=_FakePipeline(),
|
||||||
|
)
|
||||||
|
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||||
|
|
||||||
|
await runtime.run(
|
||||||
|
command=command,
|
||||||
|
owner_id=uuid4(),
|
||||||
|
user_token="token",
|
||||||
|
user_context=_user_context(),
|
||||||
|
session=cast(AsyncSession, object()),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [item["type"] for item in calls] == [
|
||||||
|
"run.started",
|
||||||
|
"step.start",
|
||||||
|
"step.finish",
|
||||||
|
"text.start",
|
||||||
|
"text.delta",
|
||||||
|
"text.end",
|
||||||
|
"run.finished",
|
||||||
|
]
|
||||||
|
assert calls[3]["data"]["stage"] == "intent"
|
||||||
|
assert calls[4]["data"]["delta"] == "direct-answer"
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ class _FakeRunner:
|
|||||||
"direct_response": "你好",
|
"direct_response": "你好",
|
||||||
"tasks": [],
|
"tasks": [],
|
||||||
"complexity": "simple",
|
"complexity": "simple",
|
||||||
|
"response_metadata": {"model": "qwen3.5-flash", "latencyMs": 100},
|
||||||
}
|
}
|
||||||
self.report_calls += 1
|
self.report_calls += 1
|
||||||
return {
|
return {
|
||||||
@@ -131,7 +132,7 @@ async def test_runtime_direct_response_skips_execution(
|
|||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "calendar.read",
|
"name": "calendar_read",
|
||||||
"description": "read",
|
"description": "read",
|
||||||
"parameters": {"type": "object", "properties": {}},
|
"parameters": {"type": "object", "properties": {}},
|
||||||
},
|
},
|
||||||
@@ -162,8 +163,10 @@ async def test_runtime_direct_response_skips_execution(
|
|||||||
|
|
||||||
assert result.intent.route == "DIRECT_RESPONSE"
|
assert result.intent.route == "DIRECT_RESPONSE"
|
||||||
assert result.execution is None
|
assert result.execution is None
|
||||||
assert result.report.assistant_text == "已完成"
|
assert result.report.assistant_text == "你好"
|
||||||
|
assert result.report.response_metadata["model"] == "qwen3.5-flash"
|
||||||
assert fake_runner.execution_calls == 0
|
assert fake_runner.execution_calls == 0
|
||||||
|
assert fake_runner.report_calls == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -183,7 +186,7 @@ async def test_runtime_complex_route_runs_execution(
|
|||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "calendar.read",
|
"name": "calendar_read",
|
||||||
"description": "read",
|
"description": "read",
|
||||||
"parameters": {"type": "object", "properties": {}},
|
"parameters": {"type": "object", "properties": {}},
|
||||||
},
|
},
|
||||||
@@ -191,7 +194,7 @@ async def test_runtime_complex_route_runs_execution(
|
|||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "calendar.write",
|
"name": "calendar_write",
|
||||||
"description": "write",
|
"description": "write",
|
||||||
"parameters": {"type": "object", "properties": {}},
|
"parameters": {"type": "object", "properties": {}},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from core.agentscope.schemas.system_agent_config import SystemAgentLLMConfig
|
|||||||
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
||||||
from core.agentscope.runtime.react_runner import (
|
from core.agentscope.runtime.react_runner import (
|
||||||
AgentScopeReActRunner,
|
AgentScopeReActRunner,
|
||||||
|
_chat_response_text,
|
||||||
|
_merge_stage_response_metadata,
|
||||||
_parse_json_text,
|
_parse_json_text,
|
||||||
_to_litellm_model,
|
_to_litellm_model,
|
||||||
)
|
)
|
||||||
@@ -32,10 +34,10 @@ def test_to_litellm_model_keeps_prefixed_model() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_to_litellm_model_builds_prefixed_model() -> None:
|
def test_to_litellm_model_uses_plain_model_name_when_unprefixed() -> None:
|
||||||
assert (
|
assert (
|
||||||
_to_litellm_model(provider_name="dashscope", model_code="qwen3.5-flash")
|
_to_litellm_model(provider_name="dashscope", model_code="qwen3.5-flash")
|
||||||
== "dashscope/qwen3.5-flash"
|
== "qwen3.5-flash"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,6 +51,24 @@ def test_parse_json_text_rejects_non_json() -> None:
|
|||||||
_parse_json_text("not-json")
|
_parse_json_text("not-json")
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_response_text_falls_back_to_choice_message_content() -> None:
|
||||||
|
response = SimpleNamespace(
|
||||||
|
content=None,
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": '{"assistant_text":"fallback","response_metadata":{}}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
_chat_response_text(response)
|
||||||
|
== '{"assistant_text":"fallback","response_metadata":{}}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_json_stage_wraps_json_decode_error(
|
async def test_run_json_stage_wraps_json_decode_error(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
@@ -113,3 +133,88 @@ async def test_run_json_stage_wraps_runtime_error(
|
|||||||
user_prompt="user",
|
user_prompt="user",
|
||||||
toolkit=None,
|
toolkit=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_json_stage_report_merges_usage_metadata(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
class _FakeLiteLLMService:
|
||||||
|
def run_completion_with_cost(self, **kwargs: object) -> object:
|
||||||
|
del kwargs
|
||||||
|
return SimpleNamespace(
|
||||||
|
response={
|
||||||
|
"model": "dashscope/qwen3.5-flash",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": '{"assistant_text":"ok","response_metadata":{}}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
usage=SimpleNamespace(
|
||||||
|
prompt_tokens=9,
|
||||||
|
completion_tokens=4,
|
||||||
|
cost=0.006,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = AgentScopeReActRunner()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
runner,
|
||||||
|
"_build_litellm_service",
|
||||||
|
lambda: _FakeLiteLLMService(),
|
||||||
|
)
|
||||||
|
|
||||||
|
report_stage = RuntimeStageConfig(
|
||||||
|
stage="report",
|
||||||
|
model_code="qwen3.5-flash",
|
||||||
|
provider_name="dashscope",
|
||||||
|
llm_config=SystemAgentLLMConfig(
|
||||||
|
temperature=0.1,
|
||||||
|
max_tokens=128,
|
||||||
|
timeout_seconds=30,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
payload = await runner.run_json_stage(
|
||||||
|
stage_config=report_stage,
|
||||||
|
agent_name="report-agent",
|
||||||
|
system_prompt="sys",
|
||||||
|
user_prompt="user",
|
||||||
|
toolkit=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = payload["response_metadata"]
|
||||||
|
assert metadata["model"] == "dashscope/qwen3.5-flash"
|
||||||
|
assert metadata["inputTokens"] == 9
|
||||||
|
assert metadata["outputTokens"] == 4
|
||||||
|
assert metadata["cost"] == 0.006
|
||||||
|
assert isinstance(metadata["latencyMs"], int)
|
||||||
|
assert metadata["latencyMs"] >= 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_stage_response_metadata_estimates_cost_from_pricing(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.agentscope.runtime.react_runner._estimate_cost_by_pricing",
|
||||||
|
lambda **kwargs: 0.0025,
|
||||||
|
)
|
||||||
|
payload = _merge_stage_response_metadata(
|
||||||
|
payload={"route": "DIRECT_RESPONSE", "response_metadata": {}},
|
||||||
|
stage_config=_stage_config(),
|
||||||
|
response=SimpleNamespace(
|
||||||
|
usage=SimpleNamespace(
|
||||||
|
prompt_tokens=12,
|
||||||
|
completion_tokens=8,
|
||||||
|
),
|
||||||
|
model="qwen3.5-flash",
|
||||||
|
),
|
||||||
|
latency_ms=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = payload["response_metadata"]
|
||||||
|
assert metadata["inputTokens"] == 12
|
||||||
|
assert metadata["outputTokens"] == 8
|
||||||
|
assert metadata["cost"] == 0.0025
|
||||||
|
|||||||
@@ -71,6 +71,63 @@ async def test_run_agentscope_task_calls_runtime_run(
|
|||||||
assert called["resume"] == 0
|
assert called["resume"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_agentscope_task_includes_recent_context_messages(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
captured_messages: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
class _FakeRuntime:
|
||||||
|
def __init__(self, **kwargs: object) -> None:
|
||||||
|
del kwargs
|
||||||
|
|
||||||
|
async def run(self, **kwargs: object) -> object:
|
||||||
|
command = kwargs.get("command")
|
||||||
|
if command is not None:
|
||||||
|
raw_messages = getattr(command, "messages", [])
|
||||||
|
if isinstance(raw_messages, list):
|
||||||
|
captured_messages.extend(raw_messages)
|
||||||
|
return object()
|
||||||
|
|
||||||
|
async def resume(self, **kwargs: object) -> object:
|
||||||
|
del kwargs
|
||||||
|
return object()
|
||||||
|
|
||||||
|
async def _fake_get_redis_client() -> object:
|
||||||
|
return object()
|
||||||
|
|
||||||
|
async def _fake_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||||
|
del kwargs
|
||||||
|
return [{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}]
|
||||||
|
|
||||||
|
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
tasks_module,
|
||||||
|
"get_or_init_redis_client",
|
||||||
|
_fake_get_redis_client,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
tasks_module,
|
||||||
|
"_build_recent_context_messages",
|
||||||
|
_fake_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
run_input = _run_input_payload()
|
||||||
|
run_input["messages"] = [{"id": "u1", "role": "user", "content": "现在几点"}]
|
||||||
|
await tasks_module.run_agentscope_task(
|
||||||
|
{
|
||||||
|
"command": "run",
|
||||||
|
"owner_id": str(uuid4()),
|
||||||
|
"run_input": run_input,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(captured_messages) == 2
|
||||||
|
assert captured_messages[0]["id"] == "ctx-1"
|
||||||
|
assert captured_messages[1]["id"] == "u1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_agentscope_task_calls_runtime_resume(
|
async def test_run_agentscope_task_calls_runtime_resume(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
|||||||
@@ -178,3 +178,89 @@ async def test_calendar_write_rejects_invalid_reminder_minutes(
|
|||||||
|
|
||||||
assert result["data"]["ok"] is False
|
assert result["data"]["ok"] is False
|
||||||
assert result["data"]["code"] == "INVALID_ARGUMENT"
|
assert result["data"]["code"] == "INVALID_ARGUMENT"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calendar_write_maps_invite_arguments(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
|
||||||
|
captured.update(cast(dict[str, object], kwargs["tool_args"]))
|
||||||
|
return {"type": "calendar_card.v1", "version": "v1", "data": {"ok": True}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
calendar_module,
|
||||||
|
"_execute_mutate_calendar_event",
|
||||||
|
_fake_execute,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
|
||||||
|
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||||
|
|
||||||
|
await calendar_module.calendar_write(
|
||||||
|
session=cast(AsyncSession, SimpleNamespace()),
|
||||||
|
owner_id=uuid4(),
|
||||||
|
user_token="token-abc",
|
||||||
|
operation="create",
|
||||||
|
invite_user_emails=["a@example.com"],
|
||||||
|
invite_user_names=["alice"],
|
||||||
|
invite_user_ids=[str(uuid4())],
|
||||||
|
invite_permission_view=True,
|
||||||
|
invite_permission_edit=True,
|
||||||
|
invite_permission_invite=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured["inviteUserEmails"] == ["a@example.com"]
|
||||||
|
assert captured["inviteUserNames"] == ["alice"]
|
||||||
|
assert isinstance(captured["inviteUserIds"], list)
|
||||||
|
assert captured["invitePermissionView"] is True
|
||||||
|
assert captured["invitePermissionEdit"] is True
|
||||||
|
assert captured["invitePermissionInvite"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_resolve_maps_identity_arguments(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
|
||||||
|
captured.update(cast(dict[str, object], kwargs["tool_args"]))
|
||||||
|
return {"type": "user_lookup.v1", "version": "v1", "data": {"ok": True}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
calendar_module,
|
||||||
|
"_execute_resolve_user_identity",
|
||||||
|
_fake_execute,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
|
||||||
|
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||||
|
|
||||||
|
result = await calendar_module.user_resolve(
|
||||||
|
session=cast(AsyncSession, SimpleNamespace()),
|
||||||
|
owner_id=uuid4(),
|
||||||
|
user_token="token-abc",
|
||||||
|
user_email="a@example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == "user_lookup.v1"
|
||||||
|
assert captured == {"userEmail": "a@example.com", "userName": None}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_resolve_requires_valid_user_token(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: False)
|
||||||
|
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||||
|
|
||||||
|
result = await calendar_module.user_resolve(
|
||||||
|
session=cast(AsyncSession, SimpleNamespace()),
|
||||||
|
owner_id=uuid4(),
|
||||||
|
user_token="bad-token",
|
||||||
|
user_name="alice",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["data"]["ok"] is False
|
||||||
|
assert result["data"]["code"] == "UNAUTHORIZED"
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from core.agentscope.prompts.runtime_prompt import build_intent_user_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_intent_user_prompt_keeps_multimodal_blocks() -> None:
|
||||||
|
prompt = build_intent_user_prompt(
|
||||||
|
user_input=[
|
||||||
|
{
|
||||||
|
"id": "u1",
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "请识别图片内容"},
|
||||||
|
{
|
||||||
|
"type": "binary",
|
||||||
|
"mimeType": "image/png",
|
||||||
|
"data": "aGVsbG8=",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(prompt, list)
|
||||||
|
assert prompt
|
||||||
|
assert prompt[0]["type"] == "text"
|
||||||
|
assert "[Output Schema]" in prompt[0]["text"]
|
||||||
|
image_blocks = [item for item in prompt if item.get("type") == "image"]
|
||||||
|
assert len(image_blocks) == 1
|
||||||
|
source = image_blocks[0]["source"]
|
||||||
|
assert source["type"] == "base64"
|
||||||
|
assert source["media_type"] == "image/png"
|
||||||
|
assert source["data"] == "aGVsbG8="
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_intent_user_prompt_filters_non_image_binary_block() -> None:
|
||||||
|
prompt = build_intent_user_prompt(
|
||||||
|
user_input=[
|
||||||
|
{
|
||||||
|
"id": "u1",
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "请处理这个输入"},
|
||||||
|
{
|
||||||
|
"type": "binary",
|
||||||
|
"mimeType": "application/pdf",
|
||||||
|
"data": "aGVsbG8=",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(prompt, list)
|
||||||
|
image_blocks = [item for item in prompt if item.get("type") == "image"]
|
||||||
|
assert image_blocks == []
|
||||||
@@ -20,11 +20,12 @@ async def test_build_toolkit_registers_calendar_tools() -> None:
|
|||||||
)
|
)
|
||||||
schemas = toolkit.get_json_schemas()
|
schemas = toolkit.get_json_schemas()
|
||||||
names = {item["function"]["name"] for item in schemas}
|
names = {item["function"]["name"] for item in schemas}
|
||||||
assert "calendar.read" in names
|
assert "calendar_read" in names
|
||||||
assert "calendar.write" in names
|
assert "calendar_write" in names
|
||||||
|
assert "user_resolve" in names
|
||||||
|
|
||||||
write_schema = next(
|
write_schema = next(
|
||||||
item for item in schemas if item["function"]["name"] == "calendar.write"
|
item for item in schemas if item["function"]["name"] == "calendar_write"
|
||||||
)
|
)
|
||||||
params = write_schema["function"]["parameters"]["properties"]
|
params = write_schema["function"]["parameters"]["properties"]
|
||||||
assert "user_token" not in params
|
assert "user_token" not in params
|
||||||
|
|||||||
@@ -33,11 +33,11 @@ def test_calculate_cost_uses_second_qwen_tier() -> None:
|
|||||||
|
|
||||||
def test_run_completion_extracts_usage_and_cost() -> None:
|
def test_run_completion_extracts_usage_and_cost() -> None:
|
||||||
service = LiteLLMService()
|
service = LiteLLMService()
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
result = service.run_completion_with_cost(
|
def _fake_completion(**kwargs: object) -> dict[str, object]:
|
||||||
model="dashscope/qwen3.5-flash",
|
captured.update(kwargs)
|
||||||
messages=[{"role": "user", "content": "hello"}],
|
return {
|
||||||
completion_fn=lambda **_: {
|
|
||||||
"model": "dashscope/qwen3.5-flash",
|
"model": "dashscope/qwen3.5-flash",
|
||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": 2000,
|
"prompt_tokens": 2000,
|
||||||
@@ -46,10 +46,17 @@ def test_run_completion_extracts_usage_and_cost() -> None:
|
|||||||
"prompt_tokens_details": {"cached_tokens": 500},
|
"prompt_tokens_details": {"cached_tokens": 500},
|
||||||
},
|
},
|
||||||
"choices": [{"message": {"content": "ok"}}],
|
"choices": [{"message": {"content": "ok"}}],
|
||||||
},
|
}
|
||||||
|
|
||||||
|
result = service.run_completion_with_cost(
|
||||||
|
model="dashscope/qwen3.5-flash",
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
completion_fn=_fake_completion,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.usage.prompt_tokens == 2000
|
assert result.usage.prompt_tokens == 2000
|
||||||
assert result.usage.completion_tokens == 100
|
assert result.usage.completion_tokens == 100
|
||||||
assert result.usage.total_tokens == 2100
|
assert result.usage.total_tokens == 2100
|
||||||
assert result.usage.cost == pytest.approx(0.00051)
|
assert result.usage.cost == pytest.approx(0.00051)
|
||||||
|
assert captured["response_format"] == {"type": "json_object"}
|
||||||
|
|||||||
@@ -10,6 +10,31 @@ from models.agent_chat_message import AgentChatMessageRole
|
|||||||
from v1.agent.repository import AgentRepository
|
from v1.agent.repository import AgentRepository
|
||||||
|
|
||||||
|
|
||||||
|
class _ExecuteResult:
|
||||||
|
def __init__(self, value: object) -> None:
|
||||||
|
self._value = value
|
||||||
|
|
||||||
|
def scalar_one_or_none(self) -> object:
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self, session_row: object) -> None:
|
||||||
|
self.session_row = session_row
|
||||||
|
self.added: list[object] = []
|
||||||
|
self.flushed = False
|
||||||
|
|
||||||
|
async def execute(self, stmt): # noqa: ANN001
|
||||||
|
del stmt
|
||||||
|
return _ExecuteResult(self.session_row)
|
||||||
|
|
||||||
|
def add(self, obj: object) -> None:
|
||||||
|
self.added.append(obj)
|
||||||
|
|
||||||
|
async def flush(self) -> None:
|
||||||
|
self.flushed = True
|
||||||
|
|
||||||
|
|
||||||
class _FakeToolResultStorage:
|
class _FakeToolResultStorage:
|
||||||
def __init__(self, payload: dict[str, object] | None) -> None:
|
def __init__(self, payload: dict[str, object] | None) -> None:
|
||||||
self._payload = payload
|
self._payload = payload
|
||||||
@@ -104,3 +129,48 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
|
|||||||
"mimeType": "image/png",
|
"mimeType": "image/png",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_persist_user_message_sets_session_title_when_empty() -> None:
|
||||||
|
session_id = str(uuid4())
|
||||||
|
session_row = SimpleNamespace(
|
||||||
|
message_count=0,
|
||||||
|
title=None,
|
||||||
|
last_activity_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
fake_session = _FakeSession(session_row)
|
||||||
|
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
await repository.persist_user_message(
|
||||||
|
session_id=session_id,
|
||||||
|
run_id="run-1",
|
||||||
|
content_text=" 请帮我安排明天下午开会 ",
|
||||||
|
metadata=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert session_row.title == "请帮我安排明天下午开会"
|
||||||
|
assert session_row.message_count == 1
|
||||||
|
assert fake_session.flushed is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_persist_user_message_keeps_existing_session_title() -> None:
|
||||||
|
session_id = str(uuid4())
|
||||||
|
session_row = SimpleNamespace(
|
||||||
|
message_count=1,
|
||||||
|
title="已有标题",
|
||||||
|
last_activity_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
fake_session = _FakeSession(session_row)
|
||||||
|
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
await repository.persist_user_message(
|
||||||
|
session_id=session_id,
|
||||||
|
run_id="run-2",
|
||||||
|
content_text="新的消息内容",
|
||||||
|
metadata=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert session_row.title == "已有标题"
|
||||||
|
assert session_row.message_count == 2
|
||||||
|
|||||||
@@ -175,3 +175,53 @@ async def test_enqueue_resume_accepts_valid_tool_contract(
|
|||||||
assert result.task_id == "task-resume-1"
|
assert result.task_id == "task-resume-1"
|
||||||
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
|
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
|
||||||
assert result.run_id == "run-resume-1"
|
assert result.run_id == "run-resume-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_events_retries_on_redis_timeout(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
async def _acquire(*, user_id: str) -> bool:
|
||||||
|
del user_id
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _release(*, user_id: str) -> None:
|
||||||
|
del user_id
|
||||||
|
|
||||||
|
monkeypatch.setattr(agent_router, "_acquire_sse_slot", _acquire)
|
||||||
|
monkeypatch.setattr(agent_router, "_release_sse_slot", _release)
|
||||||
|
|
||||||
|
class _Request:
|
||||||
|
async def is_disconnected(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
class _Service:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
async def stream_events(self, **kwargs): # noqa: ANN003
|
||||||
|
del kwargs
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == 1:
|
||||||
|
raise RuntimeError("Timeout reading from localhost:6379")
|
||||||
|
if self.calls == 2:
|
||||||
|
return [{"id": "1-0", "event": {"type": "RUN_FINISHED"}}]
|
||||||
|
return []
|
||||||
|
|
||||||
|
response = await agent_router.stream_events(
|
||||||
|
request=cast(Any, _Request()),
|
||||||
|
thread_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
service=cast(Any, _Service()),
|
||||||
|
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||||
|
last_event_id=None,
|
||||||
|
idle_limit=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks: list[str] = []
|
||||||
|
async for chunk in response.body_iterator:
|
||||||
|
chunks.append(str(chunk))
|
||||||
|
if any("RUN_FINISHED" in item for item in chunks):
|
||||||
|
break
|
||||||
|
|
||||||
|
merged = "".join(chunks)
|
||||||
|
assert "event: RUN_FINISHED" in merged
|
||||||
|
|||||||
@@ -124,6 +124,19 @@ class _FakeAttachmentStorage:
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
class _AlwaysFailAttachmentStorage:
|
||||||
|
async def upload_bytes(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
bucket: str,
|
||||||
|
path: str,
|
||||||
|
content: bytes,
|
||||||
|
content_type: str,
|
||||||
|
) -> str:
|
||||||
|
del bucket, path, content, content_type
|
||||||
|
raise RuntimeError("upload failed")
|
||||||
|
|
||||||
|
|
||||||
def _user() -> CurrentUser:
|
def _user() -> CurrentUser:
|
||||||
return CurrentUser(
|
return CurrentUser(
|
||||||
id=UUID("00000000-0000-0000-0000-000000000001"),
|
id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||||
@@ -317,6 +330,54 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
|
|||||||
assert isinstance(attachments[0]["path"], str)
|
assert isinstance(attachments[0]["path"], str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||||
|
)
|
||||||
|
repository = _FakeRepository()
|
||||||
|
service = AgentService(
|
||||||
|
repository=repository,
|
||||||
|
queue=_FakeQueue(),
|
||||||
|
stream=_FakeStream(),
|
||||||
|
attachment_storage=_AlwaysFailAttachmentStorage(),
|
||||||
|
)
|
||||||
|
run_input = RunAgentInput.model_validate(
|
||||||
|
{
|
||||||
|
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||||
|
"runId": "run-with-image-fail",
|
||||||
|
"state": {},
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"id": "u1",
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "帮我看下这张图"},
|
||||||
|
{
|
||||||
|
"type": "binary",
|
||||||
|
"data": "aGVsbG8=",
|
||||||
|
"mimeType": "image/png",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tools": [],
|
||||||
|
"context": [],
|
||||||
|
"forwardedProps": {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||||
|
raise AssertionError("expected HTTPException")
|
||||||
|
except HTTPException as exc:
|
||||||
|
assert exc.status_code == 502
|
||||||
|
assert exc.detail == "Failed to upload attachment"
|
||||||
|
|
||||||
|
assert repository.persisted_user_messages == []
|
||||||
|
|
||||||
|
|
||||||
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
||||||
service = AgentService(
|
service = AgentService(
|
||||||
repository=_FakeRepository(),
|
repository=_FakeRepository(),
|
||||||
|
|||||||
@@ -1,141 +0,0 @@
|
|||||||
# Agent Multimodal Smoke Implementation Plan
|
|
||||||
|
|
||||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
|
||||||
|
|
||||||
**Goal:** 完成 agent 三条主链路(runs/events/history)真实冒烟,并支持 RunAgentInput 图片信息在发送链路落 Supabase Storage、在 messages.metadata 持久化、在 history 返回中可渲染。
|
|
||||||
|
|
||||||
**Architecture:** 在 `v1/agent` 服务层新增“用户消息持久化 + 图片附件上传”步骤:`enqueue_run` 时解析用户消息 content block,图片上传到 `config.storage.bucket`,将路径写入 `messages.metadata`。运行时继续通过 AgentScope pipeline 输出 AG-UI 事件,SSE 从 Redis stream 订阅,历史查询从 `messages` 回放并附带附件信息。
|
|
||||||
|
|
||||||
**Tech Stack:** FastAPI, SQLAlchemy AsyncSession, Supabase Storage Admin Client, Redis SSE stream, AG-UI, pytest/httpx。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 1: 用户消息图片附件上传与落库
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Create: `backend/src/v1/agent/attachment_storage.py`
|
|
||||||
- Modify: `backend/src/v1/agent/service.py`
|
|
||||||
- Modify: `backend/src/v1/agent/repository.py`
|
|
||||||
- Test: `backend/tests/unit/v1/agent/test_service.py`
|
|
||||||
|
|
||||||
**Step 1: 写失败测试(RED)**
|
|
||||||
|
|
||||||
```python
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_enqueue_run_persists_user_message_with_uploaded_image_metadata() -> None:
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: 运行单测验证失败**
|
|
||||||
|
|
||||||
Run: `uv run pytest tests/unit/v1/agent/test_service.py::test_enqueue_run_persists_user_message_with_uploaded_image_metadata -q`
|
|
||||||
Expected: FAIL(缺少附件上传/metadata 持久化行为)
|
|
||||||
|
|
||||||
**Step 3: 最小实现(GREEN)**
|
|
||||||
|
|
||||||
```python
|
|
||||||
class AgentAttachmentStorage:
|
|
||||||
async def upload_bytes(...):
|
|
||||||
...
|
|
||||||
|
|
||||||
class AgentService:
|
|
||||||
async def enqueue_run(...):
|
|
||||||
# 解析 user content blocks
|
|
||||||
# 上传图片到 storage
|
|
||||||
# repository 持久化 user message(metadata 包含 bucket/path)
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: 运行单测验证通过**
|
|
||||||
|
|
||||||
Run: `uv run pytest tests/unit/v1/agent/test_service.py::test_enqueue_run_persists_user_message_with_uploaded_image_metadata -q`
|
|
||||||
Expected: PASS
|
|
||||||
|
|
||||||
### Task 2: history 渲染附件路径
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `backend/src/v1/agent/repository.py`
|
|
||||||
- Test: `backend/tests/unit/v1/agent/test_repository.py`
|
|
||||||
|
|
||||||
**Step 1: 写失败测试(RED)**
|
|
||||||
|
|
||||||
```python
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_history_includes_user_message_attachments_from_metadata() -> None:
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: 运行测试验证失败**
|
|
||||||
|
|
||||||
Run: `uv run pytest tests/unit/v1/agent/test_repository.py::test_history_includes_user_message_attachments_from_metadata -q`
|
|
||||||
Expected: FAIL(history 尚未渲染 attachments)
|
|
||||||
|
|
||||||
**Step 3: 最小实现(GREEN)**
|
|
||||||
|
|
||||||
```python
|
|
||||||
if role == "user" and isinstance(metadata.get("attachments"), list):
|
|
||||||
payload["attachments"] = metadata["attachments"]
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: 运行测试验证通过**
|
|
||||||
|
|
||||||
Run: `uv run pytest tests/unit/v1/agent/test_repository.py::test_history_includes_user_message_attachments_from_metadata -q`
|
|
||||||
Expected: PASS
|
|
||||||
|
|
||||||
### Task 3: 真实冒烟 runs + SSE + history(含图片输入)
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `backend/tests/integration/v1/agent/test_sse_flow_live.py`
|
|
||||||
|
|
||||||
**Step 1: 写失败测试(RED)**
|
|
||||||
|
|
||||||
```python
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.live
|
|
||||||
async def test_agent_runs_events_history_live_with_image_input() -> None:
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: 运行 live 测试验证失败(实现前或环境不完整)**
|
|
||||||
|
|
||||||
Run: `AGENT_LIVE_INTEGRATION=1 AGENT_LIVE_EMAIL=... AGENT_LIVE_PASSWORD=... uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_events_history_live_with_image_input -q -s`
|
|
||||||
Expected: FAIL(缺 metadata/path 或 history 不含附件)
|
|
||||||
|
|
||||||
**Step 3: 最小实现(GREEN)**
|
|
||||||
|
|
||||||
```python
|
|
||||||
# live 测试流程:
|
|
||||||
# 1) 登录拿 token
|
|
||||||
# 2) POST /runs 发送 text + image(data)
|
|
||||||
# 3) SSE 订阅直到 RUN_FINISHED/RUN_ERROR
|
|
||||||
# 4) GET /runs/{thread_id}/history
|
|
||||||
# 5) SQL 校验 sessions/messages 字段与 metadata.attachments
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: 运行 live 测试验证通过**
|
|
||||||
|
|
||||||
Run: `AGENT_LIVE_INTEGRATION=1 AGENT_LIVE_EMAIL=... AGENT_LIVE_PASSWORD=... uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_events_history_live_with_image_input -q -s`
|
|
||||||
Expected: PASS
|
|
||||||
|
|
||||||
### Task 4: 全量收口验证与安全门禁
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify (if needed): `backend/src/v1/agent/*`, `backend/tests/*`
|
|
||||||
|
|
||||||
**Step 1: 回归测试**
|
|
||||||
|
|
||||||
Run: `uv run pytest tests/unit/v1/agent tests/unit/core/agentscope tests/integration/v1/agent -q`
|
|
||||||
Expected: PASS
|
|
||||||
|
|
||||||
**Step 2: 静态检查**
|
|
||||||
|
|
||||||
Run: `uv run ruff check src/v1/agent src/core/agentscope tests/unit/v1/agent tests/integration/v1/agent`
|
|
||||||
Expected: PASS
|
|
||||||
|
|
||||||
Run: `uv run basedpyright src/v1/agent src/core/agentscope tests/unit/v1/agent tests/integration/v1/agent`
|
|
||||||
Expected: 0 errors
|
|
||||||
|
|
||||||
**Step 3: 评审门禁**
|
|
||||||
|
|
||||||
Run agents: `security-reviewer`, `refactor-cleaner`, `code-reviewer`
|
|
||||||
Expected: 无未解决 CRITICAL/HIGH
|
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
# Agent Multimodal Smoke Runbook
|
||||||
|
|
||||||
|
**Goal:** 固化 agent 三条主链路(runs/events/history)的真实冒烟标准与输入基线。
|
||||||
|
|
||||||
|
## 1. 覆盖范围
|
||||||
|
|
||||||
|
1. `POST /api/v1/agent/runs` - 接收多模态消息(文本+图片)
|
||||||
|
2. `GET /api/v1/agent/runs/{thread_id}/events` - SSE 事件流,事件名符合 AG-UI 标准(`RUN_STARTED`、`STEP_STARTED`、`TOOL_CALL_*`、`RUN_FINISHED`/`RUN_ERROR`)
|
||||||
|
3. `GET /api/v1/agent/runs/{thread_id}/history` - 返回 `STATE_SNAPSHOT`,含 `attachments` metadata
|
||||||
|
4. `sessions/messages` 落库完整:message_count、tokens、cost、latency、title、metadata
|
||||||
|
5. tool result 存储:大 payload 写 storage,metadata 记录 `storage_bucket`/`storage_path`
|
||||||
|
6. storage bucket 来源:必须来自环境变量 `SOCIAL_STORAGE__BUCKET`
|
||||||
|
|
||||||
|
## 2. 固定测试输入
|
||||||
|
|
||||||
|
- 图片夹具:`backend/tests/fixtures/images/calendar_text_cn.png`
|
||||||
|
- 多模态消息:
|
||||||
|
- 文本:`"识别图片中的日历内容并调用 calendar.write 创建日程"`
|
||||||
|
- 图片:`{"type":"binary","data":"<base64>","mimeType":"image/png"}`
|
||||||
|
|
||||||
|
## 3. 账号与凭据
|
||||||
|
|
||||||
|
- 冒烟账号:`dagronl@126.com` / `123456`
|
||||||
|
- 通过环境变量注入:`AGENT_LIVE_EMAIL`、`AGENT_LIVE_PASSWORD`
|
||||||
|
|
||||||
|
## 4. 执行命令
|
||||||
|
|
||||||
|
```bash
|
||||||
|
AGENT_LIVE_INTEGRATION=1 \
|
||||||
|
AGENT_LIVE_EMAIL="dagronl@126.com" \
|
||||||
|
AGENT_LIVE_PASSWORD="123456" \
|
||||||
|
uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_events_history_live_with_image_input -q -s
|
||||||
|
```
|
||||||
|
|
||||||
|
## 5. 结果记录模板
|
||||||
|
|
||||||
|
- `thread_id` / `run_id`
|
||||||
|
- `runs` 状态码与响应
|
||||||
|
- `events` 事件序列
|
||||||
|
- `history` 是否含 `attachments[].bucket/path/mimeType`
|
||||||
|
- `sessions` 字段:message_count / total_tokens / total_cost / status / title
|
||||||
|
- `messages` 字段:role / content / metadata / tokens / cost / latency
|
||||||
|
- `tool_result` 是否写 storage
|
||||||
|
|
||||||
|
## 6. 安全注意
|
||||||
|
|
||||||
|
- 禁止将密码/token 写入 git 跟踪文件
|
||||||
|
|
||||||
|
## 7. 已修复问题清单
|
||||||
|
|
||||||
|
| 问题 | 修复内容 |
|
||||||
|
|------|----------|
|
||||||
|
| bucket 写入失败回退 | 改为直接报错,禁止回退到硬编码 bucket |
|
||||||
|
| user.resolve 工具 | 新增按 email/name 解析 user_id |
|
||||||
|
| calendar.write 邀请参数 | 增加 invite 参数透传 |
|
||||||
|
| inbox_repository 缺失 | 修复 calendar runtime 依赖 |
|
||||||
|
| runtime 模型名拼接 | 修复无效 model name |
|
||||||
|
| 多模态透传 | runtime 透传 binary.data,不过滤为 `<omitted>` |
|
||||||
|
| sessions.title 生成 | 首条用户消息持久化时自动生成 |
|
||||||
|
| assistant latency 入库 | `messages.latency_ms` 列写入 |
|
||||||
|
| intent/execution 阶段消息落库 | 新增 `text.*` 和 `tool.result` 事件 |
|
||||||
|
| DIRECT_RESPONSE 早返回 | intent 判定后直接返回,不进入 report 阶段 |
|
||||||
|
|
||||||
|
## 8. 待修复问题(用户新增)
|
||||||
|
|
||||||
|
1. **意图/执行阶段 tokens/cost 入库** - 目前仅 report 阶段入库
|
||||||
|
2. **连续会话记忆测试** - 验证 session 是否从数据库读取历史上下文
|
||||||
|
3. **工具调用测试** - calendar 读/写/删/分享 + 用户查找 + 时间感知
|
||||||
|
4. **session 失败排查** - 找出最新失败原因并修复
|
||||||
@@ -1,583 +0,0 @@
|
|||||||
# 日历邀请弹窗优化 Implementation Plan
|
|
||||||
|
|
||||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
|
||||||
|
|
||||||
**Goal:** 优化日历邀请消息弹窗,显示完整信息(发送者名称 + 日历标题),使用公共弹窗组件替代所有旧弹窗代码
|
|
||||||
|
|
||||||
**Architecture:**
|
|
||||||
- 后端新增用户信息查询接口
|
|
||||||
- 前端创建公共弹窗组件 MessageActionSheet
|
|
||||||
- 删除所有旧的弹窗代码(好友请求、日历邀请),统一使用公共组件
|
|
||||||
|
|
||||||
**Tech Stack:** Flutter (Dart), FastAPI (Python)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 1: 后端添加用户信息查询接口
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `backend/src/v1/users/router.py`
|
|
||||||
- Modify: `backend/src/v1/users/service.py`
|
|
||||||
- Modify: `backend/src/v1/users/repository.py`
|
|
||||||
|
|
||||||
**Step 1: 添加 repository 方法**
|
|
||||||
|
|
||||||
修改 `backend/src/v1/users/repository.py`,在 `UserRepository` 和 `SQLAlchemyUserRepository` 中已有 `get_by_user_id` 方法,确认存在。
|
|
||||||
|
|
||||||
**Step 2: 添加 service 方法**
|
|
||||||
|
|
||||||
修改 `backend/src/v1/users/service.py`,添加:
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def get_user_by_id(self, user_id: UUID) -> UserBasicInfo:
|
|
||||||
from v1.friendships.schemas import UserBasicInfo
|
|
||||||
|
|
||||||
profile = await self._repository.get_by_user_id(user_id)
|
|
||||||
if not profile:
|
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
|
||||||
return UserBasicInfo(
|
|
||||||
id=str(profile.user_id),
|
|
||||||
username=profile.username,
|
|
||||||
avatar_url=profile.avatar_url,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: 添加 router 接口**
|
|
||||||
|
|
||||||
修改 `backend/src/v1/users/router.py`,添加:
|
|
||||||
|
|
||||||
```python
|
|
||||||
@router.get("/{user_id}", response_model=UserBasicInfo)
|
|
||||||
async def get_user(
|
|
||||||
user_id: UUID,
|
|
||||||
service: Annotated[UserService, Depends(get_user_service)],
|
|
||||||
):
|
|
||||||
return await service.get_user_by_id(user_id)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: 运行 lint 和 typecheck**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd backend && uv run ruff check src/v1/users/ && uv run basedpyright src/v1/users/
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 5: 提交**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add backend/src/v1/users/ && git commit -m "feat(users): add get user by id endpoint"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 2: 前端添加用户 API 接口
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `apps/lib/features/users/data/users_api.dart`
|
|
||||||
- Modify: `apps/lib/core/di/injection.dart`
|
|
||||||
|
|
||||||
**Step 1: 添加 UserBasicInfo 类和 getById 方法**
|
|
||||||
|
|
||||||
修改 `apps/lib/features/users/data/users_api.dart`:
|
|
||||||
|
|
||||||
```dart
|
|
||||||
class UserBasicInfo {
|
|
||||||
final String id;
|
|
||||||
final String username;
|
|
||||||
final String? avatarUrl;
|
|
||||||
|
|
||||||
UserBasicInfo({
|
|
||||||
required this.id,
|
|
||||||
required this.username,
|
|
||||||
this.avatarUrl,
|
|
||||||
});
|
|
||||||
|
|
||||||
factory UserBasicInfo.fromJson(Map<String, dynamic> json) {
|
|
||||||
return UserBasicInfo(
|
|
||||||
id: json['id'] as String,
|
|
||||||
username: json['username'] as String,
|
|
||||||
avatarUrl: json['avatar_url'] as String?,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class UsersApi {
|
|
||||||
final IApiClient _client;
|
|
||||||
static const _prefix = '/api/v1/users';
|
|
||||||
|
|
||||||
UsersApi(this._client);
|
|
||||||
|
|
||||||
// ... existing methods
|
|
||||||
|
|
||||||
Future<UserBasicInfo> getById(String userId) async {
|
|
||||||
final response = await _client.get('$_prefix/$userId');
|
|
||||||
return UserBasicInfo.fromJson(response.data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: 注册到 DI**
|
|
||||||
|
|
||||||
修改 `apps/lib/core/di/injection.dart`,添加:
|
|
||||||
|
|
||||||
```dart
|
|
||||||
sl.registerLazySingleton(() => UsersApi(sl<IApiClient>()));
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: 运行 flutter analyze**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd apps && flutter analyze lib/features/users/
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: 提交**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add apps/lib/features/users/ apps/lib/core/di/injection.dart && git commit -m "feat(users): add getById API and UserBasicInfo"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 3: 创建公共弹窗组件 MessageActionSheet
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Create: `apps/lib/features/messages/ui/widgets/message_action_sheet.dart`
|
|
||||||
|
|
||||||
**Step 1: 创建弹窗组件**
|
|
||||||
|
|
||||||
创建 `apps/lib/features/messages/ui/widgets/message_action_sheet.dart`:
|
|
||||||
|
|
||||||
```dart
|
|
||||||
import 'package:flutter/material.dart';
|
|
||||||
import '../../../../core/theme/design_tokens.dart';
|
|
||||||
import '../../../../shared/widgets/app_button.dart';
|
|
||||||
|
|
||||||
class MessageActionSheet extends StatelessWidget {
|
|
||||||
final String title;
|
|
||||||
final String? description;
|
|
||||||
final String? statusText;
|
|
||||||
final bool isReadOnly;
|
|
||||||
final VoidCallback? onAccept;
|
|
||||||
final VoidCallback? onDecline;
|
|
||||||
final IconData? icon;
|
|
||||||
final Color? iconColor;
|
|
||||||
|
|
||||||
const MessageActionSheet({
|
|
||||||
super.key,
|
|
||||||
required this.title,
|
|
||||||
this.description,
|
|
||||||
this.statusText,
|
|
||||||
this.isReadOnly = false,
|
|
||||||
this.onAccept,
|
|
||||||
this.onDecline,
|
|
||||||
this.icon,
|
|
||||||
this.iconColor,
|
|
||||||
});
|
|
||||||
|
|
||||||
@override
|
|
||||||
Widget build(BuildContext context) {
|
|
||||||
return Container(
|
|
||||||
width: double.infinity,
|
|
||||||
padding: const EdgeInsets.fromLTRB(24, 20, 24, 0),
|
|
||||||
decoration: const BoxDecoration(
|
|
||||||
color: AppColors.white,
|
|
||||||
borderRadius: BorderRadius.vertical(top: Radius.circular(24)),
|
|
||||||
),
|
|
||||||
child: Column(
|
|
||||||
mainAxisSize: MainAxisSize.min,
|
|
||||||
children: [
|
|
||||||
Container(
|
|
||||||
width: 40,
|
|
||||||
height: 4,
|
|
||||||
decoration: BoxDecoration(
|
|
||||||
color: AppColors.slate300,
|
|
||||||
borderRadius: BorderRadius.circular(2),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
const SizedBox(height: 20),
|
|
||||||
if (icon != null) ...[
|
|
||||||
Container(
|
|
||||||
width: 72,
|
|
||||||
height: 72,
|
|
||||||
decoration: BoxDecoration(
|
|
||||||
color: (iconColor ?? AppColors.blue500).withValues(alpha: 0.1),
|
|
||||||
shape: BoxShape.circle,
|
|
||||||
),
|
|
||||||
child: Icon(icon, size: 32, color: iconColor ?? AppColors.blue500),
|
|
||||||
),
|
|
||||||
const SizedBox(height: 16),
|
|
||||||
],
|
|
||||||
Text(
|
|
||||||
title,
|
|
||||||
style: const TextStyle(
|
|
||||||
fontSize: 20,
|
|
||||||
fontWeight: FontWeight.w600,
|
|
||||||
color: AppColors.slate900,
|
|
||||||
),
|
|
||||||
textAlign: TextAlign.center,
|
|
||||||
),
|
|
||||||
if (description != null && description!.isNotEmpty) ...[
|
|
||||||
const SizedBox(height: 8),
|
|
||||||
Text(
|
|
||||||
description!,
|
|
||||||
style: const TextStyle(fontSize: 14, color: AppColors.slate500),
|
|
||||||
textAlign: TextAlign.center,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
if (statusText != null) ...[
|
|
||||||
const SizedBox(height: 16),
|
|
||||||
Container(
|
|
||||||
padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 6),
|
|
||||||
decoration: BoxDecoration(
|
|
||||||
color: AppColors.slate100,
|
|
||||||
borderRadius: BorderRadius.circular(16),
|
|
||||||
),
|
|
||||||
child: Text(
|
|
||||||
statusText!,
|
|
||||||
style: const TextStyle(fontSize: 14, color: AppColors.slate600),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
const SizedBox(height: 24),
|
|
||||||
if (!isReadOnly) ...[
|
|
||||||
Row(
|
|
||||||
children: [
|
|
||||||
Expanded(
|
|
||||||
child: AppButton(
|
|
||||||
text: '拒绝',
|
|
||||||
isOutlined: true,
|
|
||||||
onPressed: () {
|
|
||||||
Navigator.pop(context);
|
|
||||||
onDecline?.call();
|
|
||||||
},
|
|
||||||
),
|
|
||||||
),
|
|
||||||
const SizedBox(width: AppSpacing.md),
|
|
||||||
Expanded(
|
|
||||||
child: AppButton(
|
|
||||||
text: '接受',
|
|
||||||
onPressed: () {
|
|
||||||
Navigator.pop(context);
|
|
||||||
onAccept?.call();
|
|
||||||
},
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
SizedBox(height: MediaQuery.of(context).padding.bottom + 12),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: 运行 flutter analyze**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd apps && flutter analyze lib/features/messages/ui/widgets/message_action_sheet.dart
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: 提交**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add apps/lib/features/messages/ui/widgets/message_action_sheet.dart && git commit -m "feat(messages): add MessageActionSheet component"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 4: 重构消息列表页面,使用公共组件并删除旧代码
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `apps/lib/features/messages/ui/screens/message_invite_list_screen.dart`
|
|
||||||
|
|
||||||
**Step 1: 添加依赖和字段**
|
|
||||||
|
|
||||||
在文件顶部添加:
|
|
||||||
|
|
||||||
```dart
|
|
||||||
import '../../../users/data/users_api.dart';
|
|
||||||
import '../widgets/message_action_sheet.dart';
|
|
||||||
```
|
|
||||||
|
|
||||||
在 `_MessageInviteListScreenState` 中添加:
|
|
||||||
|
|
||||||
```dart
|
|
||||||
late final UsersApi _usersApi;
|
|
||||||
```
|
|
||||||
|
|
||||||
在 `initState` 中添加:
|
|
||||||
|
|
||||||
```dart
|
|
||||||
_usersApi = sl<UsersApi>();
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: 添加获取日历邀请信息方法**
|
|
||||||
|
|
||||||
```dart
|
|
||||||
Future<(String calendarTitle, String senderName)?> _getCalendarInviteInfo(
|
|
||||||
InboxMessageResponse message,
|
|
||||||
) async {
|
|
||||||
if (message.scheduleItemId == null || message.senderId == null) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
final calendar = await _calendarApi.getById(message.scheduleItemId!);
|
|
||||||
final sender = await _usersApi.getById(message.senderId!);
|
|
||||||
return (calendar.title, sender.username);
|
|
||||||
} catch (e) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: 替换日历邀请弹窗方法**
|
|
||||||
|
|
||||||
删除旧的 `_showCalendarInviteSheet` 方法,替换为:
|
|
||||||
|
|
||||||
```dart
|
|
||||||
Future<void> _showCalendarInviteSheet(InboxMessageResponse message) async {
|
|
||||||
final itemId = message.scheduleItemId;
|
|
||||||
if (itemId == null) return;
|
|
||||||
|
|
||||||
final info = await _getCalendarInviteInfo(message);
|
|
||||||
final title = info != null
|
|
||||||
? '${info.$2} 邀请你加入日历'
|
|
||||||
: '日历邀请';
|
|
||||||
final description = info?.$1;
|
|
||||||
|
|
||||||
if (!mounted) return;
|
|
||||||
|
|
||||||
showModalBottomSheet<void>(
|
|
||||||
context: context,
|
|
||||||
backgroundColor: Colors.transparent,
|
|
||||||
builder: (ctx) => MessageActionSheet(
|
|
||||||
title: title,
|
|
||||||
description: description,
|
|
||||||
icon: Icons.calendar_today,
|
|
||||||
iconColor: AppColors.blue500,
|
|
||||||
onAccept: () async {
|
|
||||||
try {
|
|
||||||
await _calendarApi.acceptSubscription(itemId);
|
|
||||||
await _inboxApi.markAsRead(message.id);
|
|
||||||
if (mounted) {
|
|
||||||
Toast.show(context, '已接受', type: ToastType.success);
|
|
||||||
_loadMessages();
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
if (mounted) {
|
|
||||||
Toast.show(context, '操作失败', type: ToastType.error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
onDecline: () async {
|
|
||||||
try {
|
|
||||||
await _calendarApi.rejectSubscription(itemId);
|
|
||||||
await _inboxApi.markAsRead(message.id);
|
|
||||||
if (mounted) {
|
|
||||||
Toast.show(context, '已拒绝', type: ToastType.success);
|
|
||||||
_loadMessages();
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
if (mounted) {
|
|
||||||
Toast.show(context, '操作失败', type: ToastType.error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: 添加已读日历邀请弹窗方法**
|
|
||||||
|
|
||||||
```dart
|
|
||||||
Future<void> _showCalendarInviteReadOnlySheet(InboxMessageResponse message) async {
|
|
||||||
final itemId = message.scheduleItemId;
|
|
||||||
if (itemId == null) return;
|
|
||||||
|
|
||||||
final info = await _getCalendarInviteInfo(message);
|
|
||||||
final title = info != null
|
|
||||||
? '${info.$2} 邀请你加入日历'
|
|
||||||
: '日历邀请';
|
|
||||||
final description = info?.$1;
|
|
||||||
|
|
||||||
final statusText = message.status.value == 'accepted' ? '已接受' : '已拒绝';
|
|
||||||
|
|
||||||
if (!mounted) return;
|
|
||||||
|
|
||||||
showModalBottomSheet<void>(
|
|
||||||
context: context,
|
|
||||||
backgroundColor: Colors.transparent,
|
|
||||||
builder: (ctx) => MessageActionSheet(
|
|
||||||
title: title,
|
|
||||||
description: description,
|
|
||||||
statusText: statusText,
|
|
||||||
isReadOnly: true,
|
|
||||||
icon: Icons.calendar_today,
|
|
||||||
iconColor: AppColors.blue500,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 5: 替换好友请求弹窗方法**
|
|
||||||
|
|
||||||
删除旧的 `_showFriendRequestReadOnlySheet` 和 `_showFriendRequestActionSheet` 方法,替换为:
|
|
||||||
|
|
||||||
```dart
|
|
||||||
void _showFriendRequestSheet(MessageWithFriend item, {bool isReadOnly = false}) {
|
|
||||||
final message = item.message;
|
|
||||||
final friendRequest = item.friendRequest;
|
|
||||||
if (friendRequest == null) return;
|
|
||||||
|
|
||||||
final title = '${friendRequest.sender.username} 请求添加您为好友';
|
|
||||||
final description = message.content;
|
|
||||||
final statusText = isReadOnly
|
|
||||||
? (friendRequest.status == 'accepted'
|
|
||||||
? '已接受'
|
|
||||||
: friendRequest.status == 'rejected'
|
|
||||||
? '已拒绝'
|
|
||||||
: '已处理')
|
|
||||||
: null;
|
|
||||||
|
|
||||||
showModalBottomSheet<void>(
|
|
||||||
context: context,
|
|
||||||
backgroundColor: Colors.transparent,
|
|
||||||
isScrollControlled: true,
|
|
||||||
builder: (ctx) => MessageActionSheet(
|
|
||||||
title: title,
|
|
||||||
description: description,
|
|
||||||
statusText: statusText,
|
|
||||||
isReadOnly: isReadOnly,
|
|
||||||
icon: Icons.person_add_outlined,
|
|
||||||
iconColor: AppColors.emerald500,
|
|
||||||
onAccept: isReadOnly
|
|
||||||
? null
|
|
||||||
: () async {
|
|
||||||
await _processFriendRequest(item, accept: true);
|
|
||||||
},
|
|
||||||
onDecline: isReadOnly
|
|
||||||
? null
|
|
||||||
: () async {
|
|
||||||
await _processFriendRequest(item, accept: false);
|
|
||||||
},
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 6: 修改 _handleMessageTap 方法**
|
|
||||||
|
|
||||||
修改为调用新的统一方法:
|
|
||||||
|
|
||||||
```dart
|
|
||||||
case InboxMessageType.calendar:
|
|
||||||
final content = _parseCalendarContent(message.content);
|
|
||||||
if (content == null) return;
|
|
||||||
|
|
||||||
final type = content['type'] as String?;
|
|
||||||
if (type == 'invite') {
|
|
||||||
if (message.status.value == 'pending') {
|
|
||||||
await _showCalendarInviteSheet(message);
|
|
||||||
} else {
|
|
||||||
await _showCalendarInviteReadOnlySheet(message);
|
|
||||||
if (message.scheduleItemId != null && context.mounted) {
|
|
||||||
context.push('/calendar/events/${message.scheduleItemId}');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (type == 'update') {
|
|
||||||
if (message.scheduleItemId != null) {
|
|
||||||
context.push('/calendar/events/${message.scheduleItemId}');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
case InboxMessageType.friendRequest:
|
|
||||||
if (item.friendRequest == null) {
|
|
||||||
Toast.show(context, '发送者信息加载失败,请下拉重试', type: ToastType.error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
_showFriendRequestSheet(item, isReadOnly: message.isRead);
|
|
||||||
return;
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 7: 删除旧的 _FriendRequestSheet 类**
|
|
||||||
|
|
||||||
删除文件末尾的整个 `_FriendRequestSheet` 类(约605-749行)。
|
|
||||||
|
|
||||||
**Step 8: 运行 flutter analyze**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd apps && flutter analyze lib/features/messages/ui/screens/message_invite_list_screen.dart
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 9: 提交**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add apps/lib/features/messages/ && git commit -m "refactor(messages): use MessageActionSheet for all message types"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 5: 删除日历消息卡片中的旧弹窗代码
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `apps/lib/features/messages/ui/widgets/calendar_message_card.dart`
|
|
||||||
|
|
||||||
**Step 1: 修改 CalendarInviteCard**
|
|
||||||
|
|
||||||
CalendarInviteCard 是用于列表展示的卡片,不需要显示弹窗。检查是否有不必要的硬编码,如果有则清理。
|
|
||||||
|
|
||||||
**Step 2: 运行 flutter analyze**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd apps && flutter analyze lib/features/messages/ui/widgets/calendar_message_card.dart
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: 提交**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add apps/lib/features/calendar_message_card.dart && git commit/messages/ui/widgets -f "chore(messages): clean up calendar message card"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 6: 验证和测试
|
|
||||||
|
|
||||||
**Step 1: 运行完整测试**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd apps && flutter test test/features/messages/
|
|
||||||
cd backend && uv run pytest tests/unit/v1/users/ -v
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: 手动测试场景**
|
|
||||||
|
|
||||||
1. 用户 A 发送日历邀请给用户 B
|
|
||||||
2. 用户 B 打开未读消息,点击日历邀请
|
|
||||||
3. 弹窗显示:"XXX 邀请你加入 [日历标题]"
|
|
||||||
4. 点击接受/拒绝
|
|
||||||
5. 用户 B 打开已读消息,点击日历邀请
|
|
||||||
6. 弹窗显示状态标签
|
|
||||||
7. 好友请求未读/已读都使用相同弹窗组件
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
| Task | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| 1 | 后端添加用户信息查询接口 `/api/v1/users/{user_id}` |
|
|
||||||
| 2 | 前端添加 UsersApi.getById 方法 |
|
|
||||||
| 3 | 创建公共弹窗组件 MessageActionSheet |
|
|
||||||
| 4 | 重构消息列表页面,删除旧弹窗代码,统一使用 MessageActionSheet |
|
|
||||||
| 5 | 清理日历消息卡片旧代码 |
|
|
||||||
| 6 | 验证测试 |
|
|
||||||
|
|
||||||
**Plan complete and saved to `docs/plans/2026-03-11-calendar-invite-sheet.md`. Two execution options:**
|
|
||||||
|
|
||||||
1. **Subagent-Driven (this session)** - I dispatch fresh subagent per task, review between tasks, fast iteration
|
|
||||||
|
|
||||||
2. **Parallel Session (separate)** - Open new session with executing-plans, batch execution with checkpoints
|
|
||||||
|
|
||||||
Which approach?
|
|
||||||
Reference in New Issue
Block a user