feat(agent): 增强多模态链路与工具调用能力

This commit is contained in:
zl-q
2026-03-12 00:18:45 +08:00
parent 18db6c50e7
commit 21ba8e4a44
35 changed files with 2057 additions and 829 deletions
+20
View File
@@ -0,0 +1,20 @@
{
"$schema": "https://opencode.ai/config.json",
"mcp": {
"supabase": {
"type": "local",
"enabled": true,
"command": [
"npx",
"-y",
"@aliyun-rds/supabase-mcp-server",
"--supabase-url",
"http://47.112.66.83",
"--supabase-anon-key",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJvbGUiOiJhbm9uIiwiaWF0IjoxNzczMDI3NDE5LCJleHAiOjEzMjgzNjY3NDE5fQ.NVXDla5_nYPdcJk_81fc3k1UrnNTrNne_trMqt6Hg4g",
"--supabase-service-role-key",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJvbGUiOiJzZXJ2aWNlX3JvbGUiLCJpYXQiOjE3NzMwMjc0MTksImV4cCI6MTMyODM2Njc0MTl9.RzQBia-3QcjupsHnqaxgDWB7wnY9R7Ms9R8pMokyvLY"
]
}
}
}
@@ -23,10 +23,12 @@ class MessageRepository:
role: AgentChatMessageRole,
content: str,
model_code: str | None = None,
tool_name: str | None = None,
metadata: dict[str, object] | None = None,
input_tokens: int = 0,
output_tokens: int = 0,
cost: Decimal = Decimal("0"),
latency_ms: int | None = None,
) -> AgentChatMessage:
message = AgentChatMessage(
session_id=session_id,
@@ -34,10 +36,12 @@ class MessageRepository:
role=role,
content=content,
model_code=model_code,
tool_name=tool_name,
metadata_json=metadata,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
)
self._session.add(message)
await self._session.flush()
+123 -3
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID
@@ -24,6 +25,7 @@ class SqlAlchemyEventStore:
def __init__(self, *, session_factory: Any) -> None:
self._session_factory = session_factory
self._message_buffers: dict[tuple[str, str], str] = {}
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
async def persist(self, event: dict[str, Any]) -> None:
event_type = str(event.get("type", "")).strip().upper()
@@ -48,6 +50,10 @@ class SqlAlchemyEventStore:
self._buffer_text_delta(session_key=session_key, event=event)
return
if event_type == "TEXT_MESSAGE_START":
self._buffer_text_context(session_key=session_key, event=event)
return
if event_type == "RUN_STARTED":
await self._update_session_state(
session_repo=session_repo,
@@ -72,7 +78,15 @@ class SqlAlchemyEventStore:
)
self._clear_session_buffers(session_key=session_key)
elif event_type == "TEXT_MESSAGE_END":
await self._persist_assistant_message(
await self._persist_text_message(
event=event,
session_id=session_id,
chat_session=chat_session,
session_repo=session_repo,
message_repo=message_repo,
)
elif event_type == "TOOL_CALL_RESULT":
await self._persist_tool_call_result(
event=event,
session_id=session_id,
chat_session=chat_session,
@@ -97,8 +111,28 @@ class SqlAlchemyEventStore:
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
for key in stale_keys:
self._message_buffers.pop(key, None)
stale_context_keys = [k for k in self._message_contexts if k[0] == session_key]
for key in stale_context_keys:
self._message_contexts.pop(key, None)
async def _persist_assistant_message(
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = event.get("messageId")
if not isinstance(message_id, str) or not message_id:
return
key = (session_key, message_id)
role = event.get("role")
stage = event.get("stage")
tool_name = event.get("toolName")
context: dict[str, object] = {}
if isinstance(role, str) and role:
context["role"] = role
if isinstance(stage, str) and stage:
context["stage"] = stage
if isinstance(tool_name, str) and tool_name:
context["tool_name"] = tool_name
self._message_contexts[key] = context
async def _persist_text_message(
self,
*,
event: dict[str, Any],
@@ -114,6 +148,8 @@ class SqlAlchemyEventStore:
if not content:
return
context = self._message_contexts.get(key, {})
input_tokens = self._to_int(event.get("inputTokens"))
output_tokens = self._to_int(event.get("outputTokens"))
token_delta = input_tokens + output_tokens
@@ -127,6 +163,20 @@ class SqlAlchemyEventStore:
metadata["run_id"] = run_id
if latency_ms is not None:
metadata["latency_ms"] = latency_ms
stage = event.get("stage")
if not isinstance(stage, str):
stage = context.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
role_value = context.get("role")
if not isinstance(role_value, str):
role_value = "assistant"
role = self._resolve_role(role_value)
tool_name = context.get("tool_name")
tool_name_value = (
tool_name if isinstance(tool_name, str) and tool_name else None
)
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
@@ -137,13 +187,15 @@ class SqlAlchemyEventStore:
await message_repo.append_message(
session_id=session_id,
seq=seq,
role=AgentChatMessageRole.ASSISTANT,
role=role,
content=content,
model_code=model_code if isinstance(model_code, str) else None,
tool_name=tool_name_value,
metadata=metadata,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -161,6 +213,74 @@ class SqlAlchemyEventStore:
cost_delta=cost,
)
self._message_buffers.pop(key, None)
self._message_contexts.pop(key, None)
async def _persist_tool_call_result(
self,
*,
event: dict[str, Any],
session_id: UUID,
chat_session: Any,
session_repo: SessionRepository,
message_repo: MessageRepository,
) -> None:
tool_name = event.get("toolName")
if not isinstance(tool_name, str) or not tool_name:
return
payload = {
"args": event.get("args"),
"result": event.get("result"),
"error": event.get("error"),
"call_id": event.get("callId"),
}
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
metadata: dict[str, object] = {"tool_name": tool_name}
run_id = event.get("runId")
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
stage = event.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
task_id = event.get("taskId")
if isinstance(task_id, str) and task_id:
metadata["task_id"] = task_id
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
)
if locked_session is None:
return
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
await message_repo.append_message(
session_id=session_id,
seq=seq,
role=AgentChatMessageRole.TOOL,
content=content,
tool_name=tool_name,
metadata=metadata,
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
status = (
current_status
if isinstance(current_status, AgentChatSessionStatus)
else AgentChatSessionStatus.RUNNING
)
await self._update_session_state(
session_repo=session_repo,
chat_session=chat_session,
status=status,
message_delta=1,
)
def _resolve_role(self, value: str) -> AgentChatMessageRole:
normalized = value.strip().lower()
if normalized == AgentChatMessageRole.SYSTEM.value:
return AgentChatMessageRole.SYSTEM
if normalized == AgentChatMessageRole.TOOL.value:
return AgentChatMessageRole.TOOL
return AgentChatMessageRole.ASSISTANT
async def _update_session_state(
self,
@@ -38,7 +38,38 @@ def _schema_json(model: type[Any]) -> str:
)
def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
def build_intent_user_prompt(
*, user_input: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(user_input, list):
instruction_text = "\n\n".join(
[
INTENT_TASK_INSTRUCTION,
"[Output Schema]",
_schema_json(IntentOutput),
"[User Input]",
"Use the following multimodal blocks as the latest user input.",
]
)
blocks = [
{
"type": "text",
"text": instruction_text,
}
]
user_blocks = _latest_user_content_blocks(user_input)
if not user_blocks:
user_blocks = [
{
"type": "text",
"text": json.dumps(
user_input, ensure_ascii=True, separators=(",", ":")
),
}
]
blocks.extend(user_blocks)
return blocks
normalized_input = (
user_input
if isinstance(user_input, str)
@@ -55,6 +86,101 @@ def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
)
def _latest_user_content_blocks(
user_input: list[dict[str, Any]],
) -> list[dict[str, Any]]:
for message in reversed(user_input):
if not isinstance(message, dict):
continue
if message.get("role") != "user":
continue
content = message.get("content")
if isinstance(content, str):
text = content.strip()
return [{"type": "text", "text": text}] if text else []
if not isinstance(content, list):
return []
blocks: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
text = item.get("text")
if isinstance(text, str) and text.strip():
blocks.append({"type": "text", "text": text})
continue
if item_type == "binary":
source_block = _binary_source_block(item)
if source_block is not None:
blocks.append(source_block)
continue
if item_type == "image":
source_block = _image_source_block(item)
if source_block is not None:
blocks.append(source_block)
return blocks
return []
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
mime_type = item.get("mimeType")
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
if not media_type.startswith("image/"):
return None
source_url = item.get("url")
if isinstance(source_url, str) and source_url:
return {"type": "image", "source": {"type": "url", "url": source_url}}
source_data = item.get("data")
if isinstance(source_data, str) and source_data:
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": source_data,
},
}
return None
def _image_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
source = item.get("source")
if not isinstance(source, dict):
return None
source_type = source.get("type")
if source_type == "url":
source_url = source.get("value") or source.get("url")
if isinstance(source_url, str) and source_url:
return {"type": "image", "source": {"type": "url", "url": source_url}}
if source_type in {"data", "base64"}:
source_data = source.get("value") or source.get("data")
if isinstance(source_data, str) and source_data:
mime_type = source.get("mimeType") or source.get("media_type")
media_type = (
mime_type if isinstance(mime_type, str) and mime_type else "image/png"
)
if not media_type.startswith("image/"):
return None
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": source_data,
},
}
return None
def build_execution_user_prompt(
*,
task_id: str,
@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from typing import Any, Protocol
from uuid import UUID
@@ -153,6 +154,44 @@ class AgentRouteRuntime:
},
)
await self._emit_stage_text(
thread_id=command.thread_id,
run_id=command.run_id,
stage_name="intent",
message_id=f"intent-{command.run_id}",
text=_intent_text_payload(result.intent),
response_metadata=result.intent.response_metadata,
)
if result.intent.route == "DIRECT_RESPONSE" and result.execution is None:
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "run.finished",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {},
},
)
return result
if result.execution is not None:
for index, task in enumerate(result.execution.task_results, start=1):
await self._emit_stage_text(
thread_id=command.thread_id,
run_id=command.run_id,
stage_name="execution",
message_id=f"execution-{command.run_id}-{index}",
text=task.execution_summary,
response_metadata=task.response_metadata,
)
await self._emit_tool_result_events(
thread_id=command.thread_id,
run_id=command.run_id,
task_id=task.task_id,
tool_calls=_task_tool_calls(task),
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
@@ -164,35 +203,18 @@ class AgentRouteRuntime:
)
report_message_id = f"assistant-{command.run_id}"
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "text.start",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"messageId": report_message_id, "role": "assistant"},
},
response_metadata = (
result.report.response_metadata
if isinstance(result.report.response_metadata, dict)
else {}
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "text.delta",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {
"messageId": report_message_id,
"delta": result.report.assistant_text,
},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "text.end",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"messageId": report_message_id},
},
await self._emit_stage_text(
thread_id=command.thread_id,
run_id=command.run_id,
stage_name="report",
message_id=report_message_id,
text=result.report.assistant_text,
response_metadata=response_metadata,
)
await self._pipeline.emit(
session_id=command.thread_id,
@@ -213,3 +235,178 @@ class AgentRouteRuntime:
},
)
return result
async def _emit_stage_text(
self,
*,
thread_id: str,
run_id: str,
stage_name: str,
message_id: str,
text: str,
response_metadata: dict[str, Any],
) -> None:
await self._pipeline.emit(
session_id=thread_id,
event={
"type": "text.start",
"threadId": thread_id,
"runId": run_id,
"data": {
"messageId": message_id,
"role": "assistant",
"stage": stage_name,
},
},
)
await self._pipeline.emit(
session_id=thread_id,
event={
"type": "text.delta",
"threadId": thread_id,
"runId": run_id,
"data": {"messageId": message_id, "delta": text},
},
)
await self._pipeline.emit(
session_id=thread_id,
event={
"type": "text.end",
"threadId": thread_id,
"runId": run_id,
"data": {
"messageId": message_id,
"stage": stage_name,
**_text_end_telemetry_payload(response_metadata),
},
},
)
async def _emit_tool_result_events(
self,
*,
thread_id: str,
run_id: str,
task_id: str,
tool_calls: list[dict[str, Any]],
) -> None:
for index, tool_call in enumerate(tool_calls, start=1):
tool_name = tool_call.get("tool_name")
if not isinstance(tool_name, str) or not tool_name:
continue
await self._pipeline.emit(
session_id=thread_id,
event={
"type": "tool.result",
"threadId": thread_id,
"runId": run_id,
"data": {
"callId": f"{run_id}-{task_id}-{index}",
"stage": "execution",
"taskId": task_id,
"toolName": tool_name,
"args": tool_call.get("args", {}),
"result": tool_call.get("result"),
"error": tool_call.get("error"),
},
},
)
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {}
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
if model is not None:
payload["model"] = model
input_tokens = _first_number(metadata, keys=("inputTokens", "input_tokens"))
if input_tokens is not None:
payload["inputTokens"] = input_tokens
output_tokens = _first_number(metadata, keys=("outputTokens", "output_tokens"))
if output_tokens is not None:
payload["outputTokens"] = output_tokens
latency_ms = _first_number(metadata, keys=("latencyMs", "latency_ms"))
if latency_ms is not None:
payload["latencyMs"] = latency_ms
cost = _first_number(metadata, keys=("cost", "total_cost"), allow_float=True)
if cost is not None:
payload["cost"] = cost
return payload
def _intent_text_payload(intent: Any) -> str:
direct_response = getattr(intent, "direct_response", None)
if isinstance(direct_response, str) and direct_response.strip():
return direct_response
return json.dumps(intent.model_dump(mode="json"), ensure_ascii=False)
def _task_tool_calls(task: Any) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
tool_calls = getattr(task, "tool_calls", None)
if isinstance(tool_calls, list):
for item in tool_calls:
if hasattr(item, "model_dump"):
dumped = item.model_dump(mode="json")
if isinstance(dumped, dict):
normalized.append(dumped)
elif isinstance(item, dict):
normalized.append(item)
if normalized:
return normalized
execution_data = getattr(task, "execution_data", None)
if not isinstance(execution_data, dict):
return []
fallback_calls = execution_data.get("tool_calls")
if not isinstance(fallback_calls, list):
return []
for item in fallback_calls:
if isinstance(item, dict):
normalized.append(item)
return normalized
def _first_non_empty_str(
metadata: dict[str, Any], *, keys: tuple[str, ...]
) -> str | None:
for key in keys:
value = metadata.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
def _first_number(
metadata: dict[str, Any],
*,
keys: tuple[str, ...],
allow_float: bool = False,
) -> int | float | None:
for key in keys:
value = metadata.get(key)
if isinstance(value, bool):
continue
if isinstance(value, int):
if value < 0:
continue
return value
if isinstance(value, float):
if value < 0:
continue
return value if allow_float else int(value)
if isinstance(value, str):
try:
parsed = float(value) if allow_float else int(value)
except ValueError:
continue
if parsed >= 0:
return parsed
return None
@@ -110,6 +110,19 @@ class AgentScopeRuntimeOrchestrator:
)
intent_output = IntentOutput.model_validate(intent_payload)
if intent_output.route == "DIRECT_RESPONSE":
assistant_text = (
intent_output.direct_response or intent_output.intent_summary
)
return RuntimeOutput(
intent=intent_output,
execution=None,
report=ReportOutput(
assistant_text=assistant_text,
response_metadata=dict(intent_output.response_metadata),
),
)
execution_output: ExecutionBatchOutput | None = None
if intent_output.route == "TASK_EXECUTION":
execution_toolkit = build_stage_toolkit(
@@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
import json
from time import perf_counter
from typing import Any, cast
from core.agentscope.runtime.config_loader import RuntimeStageConfig
@@ -14,7 +16,8 @@ def _to_litellm_model(*, provider_name: str, model_code: str) -> str:
normalized_model = model_code.strip()
if "/" in normalized_model:
return normalized_model
return f"{provider_name.strip().lower()}/{normalized_model}"
del provider_name
return normalized_model
def _parse_json_text(raw_text: str) -> dict[str, Any]:
@@ -30,6 +33,11 @@ def _parse_json_text(raw_text: str) -> dict[str, Any]:
class AgentScopeReActRunner:
def _build_litellm_service(self) -> Any:
from services.litellm.service import LiteLLMService
return LiteLLMService()
def _build_model(self, *, stage_config: RuntimeStageConfig) -> Any:
from agentscope.model import OpenAIChatModel
from agentscope.types import JSONSerializableObject
@@ -61,9 +69,16 @@ class AgentScopeReActRunner:
stage_config: RuntimeStageConfig,
agent_name: str,
system_prompt: str,
user_prompt: str,
user_prompt: str | list[dict[str, Any]],
toolkit: Any | None,
) -> dict[str, Any]:
if stage_config.stage == "report" and toolkit is None:
return await self._run_report_stage_direct(
stage_config=stage_config,
system_prompt=system_prompt,
user_prompt=user_prompt,
)
from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
@@ -79,9 +94,19 @@ class AgentScopeReActRunner:
max_iters=6,
)
try:
response = await agent(Msg(name="user", content=user_prompt, role="user"))
started_at = perf_counter()
response = await agent(
Msg(name="user", content=cast(Any, user_prompt), role="user")
)
latency_ms = int(round((perf_counter() - started_at) * 1000))
text_content = response.get_text_content() or "{}"
return _parse_json_text(text_content)
payload = _parse_json_text(text_content)
return _merge_stage_response_metadata(
payload=payload,
stage_config=stage_config,
response=response,
latency_ms=latency_ms,
)
except json.JSONDecodeError as exc:
logger.exception(
"agentscope stage output is not valid json",
@@ -96,3 +121,234 @@ class AgentScopeReActRunner:
agent_name=agent_name,
)
raise RuntimeError("agent execution failed") from exc
async def _run_report_stage_direct(
self,
*,
stage_config: RuntimeStageConfig,
system_prompt: str,
user_prompt: str | list[dict[str, Any]],
) -> dict[str, Any]:
try:
service = self._build_litellm_service()
started_at = perf_counter()
response_with_cost = await asyncio.to_thread(
service.run_completion_with_cost,
model=_to_litellm_model(
provider_name=stage_config.provider_name,
model_code=stage_config.model_code,
),
messages=_report_messages(
system_prompt=system_prompt,
user_prompt=user_prompt,
),
temperature=stage_config.llm_config.temperature,
max_tokens=stage_config.llm_config.max_tokens,
timeout=stage_config.llm_config.timeout_seconds,
response_format={"type": "json_object"},
)
latency_ms = int(round((perf_counter() - started_at) * 1000))
text_content = _chat_response_text(response_with_cost.response)
payload = _parse_json_text(text_content)
return _merge_report_response_metadata(
payload=payload,
stage_config=stage_config,
response_with_cost=response_with_cost,
latency_ms=latency_ms,
)
except json.JSONDecodeError as exc:
logger.exception(
"agentscope stage output is not valid json",
stage=stage_config.stage,
agent_name="report-agent",
)
raise RuntimeError("agent output format invalid") from exc
except Exception as exc:
logger.exception(
"agentscope stage execution failed",
stage=stage_config.stage,
agent_name="report-agent",
)
raise RuntimeError("agent execution failed") from exc
def _chat_response_text(response: Any) -> str:
content = _read_value(response, "content")
if isinstance(content, str) and content.strip():
return content
if not isinstance(content, list):
return _fallback_choice_content(response)
text_parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") != "text":
continue
text = block.get("text")
if isinstance(text, str) and text:
text_parts.append(text)
if text_parts:
return "".join(text_parts)
return _fallback_choice_content(response)
def _fallback_choice_content(response: Any) -> str:
choices = _read_value(response, "choices")
if not isinstance(choices, list) or not choices:
return "{}"
first_choice = choices[0]
message = getattr(first_choice, "message", None)
if message is None and isinstance(first_choice, dict):
message = first_choice.get("message")
if isinstance(message, dict):
content = message.get("content")
return content if isinstance(content, str) and content else "{}"
content = _read_value(message, "content")
return content if isinstance(content, str) and content else "{}"
def _read_value(source: Any, key: str) -> Any:
if isinstance(source, dict):
return source.get(key)
return getattr(source, key, None)
def _report_messages(
*, system_prompt: str, user_prompt: str | list[dict[str, Any]]
) -> list[dict[str, Any]]:
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
def _merge_stage_response_metadata(
*,
payload: dict[str, Any],
stage_config: RuntimeStageConfig,
response: Any,
latency_ms: int,
) -> dict[str, Any]:
result = dict(payload)
existing = result.get("response_metadata")
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
metadata.setdefault("model", stage_config.model_code)
usage = _read_value(response, "usage")
prompt_tokens = _to_non_negative_int(
_read_value(usage, "prompt_tokens") or _read_value(usage, "input_tokens")
)
completion_tokens = _to_non_negative_int(
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
)
cost = _to_non_negative_float(
_read_value(usage, "cost")
or _read_value(_read_value(usage, "metadata"), "cost")
)
resolved_model = _read_value(response, "model")
if cost is None and prompt_tokens is not None and completion_tokens is not None:
estimated_cost = _estimate_cost_by_pricing(
model=resolved_model
if isinstance(resolved_model, str)
else stage_config.model_code,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
if estimated_cost is not None:
cost = estimated_cost
if prompt_tokens is not None:
metadata["inputTokens"] = prompt_tokens
if completion_tokens is not None:
metadata["outputTokens"] = completion_tokens
if cost is not None:
metadata["cost"] = cost
if latency_ms >= 0:
metadata["latencyMs"] = latency_ms
result["response_metadata"] = metadata
return result
def _merge_report_response_metadata(
*,
payload: dict[str, Any],
stage_config: RuntimeStageConfig,
response_with_cost: Any,
latency_ms: int,
) -> dict[str, Any]:
result = dict(payload)
existing = result.get("response_metadata")
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
usage = _read_value(response_with_cost, "usage")
response = _read_value(response_with_cost, "response")
resolved_model = _read_value(response, "model")
if isinstance(resolved_model, str) and resolved_model.strip():
metadata["model"] = resolved_model.strip()
else:
metadata.setdefault("model", stage_config.model_code)
input_tokens = _to_non_negative_int(_read_value(usage, "prompt_tokens"))
output_tokens = _to_non_negative_int(_read_value(usage, "completion_tokens"))
cost = _to_non_negative_float(_read_value(usage, "cost"))
if input_tokens is not None:
metadata["inputTokens"] = input_tokens
if output_tokens is not None:
metadata["outputTokens"] = output_tokens
if cost is not None:
metadata["cost"] = cost
if latency_ms >= 0:
metadata["latencyMs"] = latency_ms
result["response_metadata"] = metadata
return result
def _to_non_negative_int(value: Any) -> int | None:
if isinstance(value, bool):
return None
if not isinstance(value, (int, float, str)):
return None
try:
parsed = int(value)
except (TypeError, ValueError):
return None
return parsed if parsed >= 0 else None
def _to_non_negative_float(value: Any) -> float | None:
if isinstance(value, bool):
return None
if not isinstance(value, (int, float, str)):
return None
try:
parsed = float(value)
except (TypeError, ValueError):
return None
return parsed if parsed >= 0 else None
def _estimate_cost_by_pricing(
*, model: str, prompt_tokens: int, completion_tokens: int
) -> float | None:
normalized_model = model.strip()
if not normalized_model:
return None
from services.litellm.service import LiteLLMService
service = LiteLLMService()
try:
return service.calculate_cost(
model=normalized_model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
except ValueError:
return None
@@ -1,8 +1,11 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import select
from core.agentscope.events import (
AgentScopeAgUiCodec,
AgentScopeEventPipeline,
@@ -18,6 +21,7 @@ from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from services.base.redis import get_or_init_redis_client
logger = get_logger("core.agentscope.runtime.tasks")
@@ -76,6 +80,56 @@ def _extract_user_token(
return None
async def _build_recent_context_messages(
*,
session: Any,
thread_id: str,
current_run_id: str,
max_messages: int = 20,
) -> list[dict[str, Any]]:
try:
session_uuid = UUID(thread_id)
except ValueError:
return []
utc_now = datetime.now(timezone.utc)
start_of_today = utc_now.replace(hour=0, minute=0, second=0, microsecond=0)
start_of_yesterday = start_of_today - timedelta(days=1)
stmt = (
select(AgentChatMessage)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.where(AgentChatMessage.created_at >= start_of_yesterday)
.order_by(AgentChatMessage.seq.asc())
)
rows = (await session.execute(stmt)).scalars().all()
normalized: list[dict[str, Any]] = []
for row in rows:
metadata = row.metadata_json if isinstance(row.metadata_json, dict) else {}
if metadata.get("run_id") == current_run_id:
continue
role = (
row.role.value
if isinstance(row.role, AgentChatMessageRole)
else str(row.role)
)
if role not in {"user", "assistant"}:
continue
normalized.append(
{
"id": str(row.id),
"role": role,
"content": row.content,
}
)
if len(normalized) <= max_messages:
return normalized
return normalized[-max_messages:]
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
command_type = str(command.get("command", "run")).strip().lower()
raw_run_input = command.get("run_input")
@@ -117,6 +171,21 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
)
async with AsyncSessionLocal() as session:
if command_type == "run":
context_messages = await _build_recent_context_messages(
session=session,
thread_id=parsed_run_input.thread_id,
current_run_id=parsed_run_input.run_id,
)
parsed_run_input = parsed_run_input.model_copy(
update={
"messages": [
*context_messages,
*parsed_run_input.messages,
]
}
)
if command_type == "resume":
await runtime.resume(
command=parsed_run_input,
@@ -5,12 +5,21 @@ from typing import Any, Literal
from pydantic import BaseModel, Field
class ExecutionToolCall(BaseModel):
tool_name: str = Field(min_length=1)
args: dict[str, Any] = Field(default_factory=dict)
result: Any | None = None
error: str | None = None
class ExecutionTaskOutput(BaseModel):
task_id: str = Field(min_length=1)
status: Literal["SUCCESS", "PARTIAL", "FAILED"]
execution_summary: str = Field(min_length=1)
execution_data: dict[str, Any] = Field(default_factory=dict)
user_feedback_needs: list[str] = Field(default_factory=list)
response_metadata: dict[str, Any] = Field(default_factory=dict)
tool_calls: list[ExecutionToolCall] = Field(default_factory=list)
class ExecutionBatchOutput(BaseModel):
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import Any
from typing import Literal
from pydantic import BaseModel, Field, model_validator
@@ -17,6 +18,7 @@ class IntentOutput(BaseModel):
direct_response: str | None = None
tasks: list[IntentTask] = Field(default_factory=list)
complexity: Literal["simple", "complex"]
response_metadata: dict[str, Any] = Field(default_factory=dict)
@model_validator(mode="after")
def validate_route(self) -> "IntentOutput":
@@ -1,3 +1,7 @@
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
from core.agentscope.tools.custom.calendar import (
calendar_read,
calendar_write,
user_resolve,
)
__all__ = ["calendar_read", "calendar_write"]
__all__ = ["calendar_read", "calendar_write", "user_resolve"]
@@ -7,6 +7,7 @@ from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
from core.agentscope.tools.custom.calendar_backend_ops import (
_execute_list_calendar_events,
_execute_mutate_calendar_event,
_execute_resolve_user_identity,
)
from core.config.settings import config
from core.agentscope.tools.response import build_tool_response
@@ -150,6 +151,30 @@ async def calendar_write(
bool,
Field(description="Whether to use the replace strategy for conflicts."),
] = False,
invite_user_emails: Annotated[
list[str] | None,
Field(description="Optional invite targets by email."),
] = None,
invite_user_names: Annotated[
list[str] | None,
Field(description="Optional invite targets by username."),
] = None,
invite_user_ids: Annotated[
list[str] | None,
Field(description="Optional invite targets by user ID (UUID string)."),
] = None,
invite_permission_view: Annotated[
bool,
Field(description="Invite permission: view."),
] = True,
invite_permission_edit: Annotated[
bool,
Field(description="Invite permission: edit."),
] = False,
invite_permission_invite: Annotated[
bool,
Field(description="Invite permission: invite others."),
] = False,
session: Any = None,
owner_id: Any = None,
user_token: str | None = None,
@@ -240,6 +265,15 @@ async def calendar_write(
tool_args["reminderMinutes"] = reminder_minutes
if status is not None:
tool_args["status"] = status
if invite_user_emails is not None:
tool_args["inviteUserEmails"] = invite_user_emails
if invite_user_names is not None:
tool_args["inviteUserNames"] = invite_user_names
if invite_user_ids is not None:
tool_args["inviteUserIds"] = invite_user_ids
tool_args["invitePermissionView"] = invite_permission_view
tool_args["invitePermissionEdit"] = invite_permission_edit
tool_args["invitePermissionInvite"] = invite_permission_invite
result = await _execute_mutate_calendar_event(
session=cast(Any, session),
@@ -247,3 +281,34 @@ async def calendar_write(
tool_args=tool_args,
)
return build_tool_response(result)
async def user_resolve(
user_email: Annotated[
str | None,
Field(description="User email to resolve user ID."),
] = None,
user_name: Annotated[
str | None,
Field(description="Username to resolve user ID."),
] = None,
session: Any = None,
owner_id: Any = None,
user_token: str | None = None,
) -> Any:
if session is None or owner_id is None:
raise ValueError("user.resolve missing runtime preset arguments")
if not isinstance(user_token, str) or not user_token.strip():
return build_tool_response(_unauthorized_response())
if not _verify_user_token(user_token=user_token, owner_id=cast(UUID, owner_id)):
return build_tool_response(_unauthorized_response())
result = await _execute_resolve_user_identity(
session=cast(Any, session),
owner_id=cast(UUID, owner_id),
tool_args={
"userEmail": user_email,
"userName": user_name,
},
)
return build_tool_response(result)
@@ -4,13 +4,20 @@ import re
from datetime import datetime, timedelta, timezone
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from services.base.supabase import supabase_service
from models.profile import Profile
from v1.auth.gateway import SupabaseAuthGateway
from v1.inbox_messages.repository import SQLAlchemyInboxMessageRepository
from v1.schedule_items.repository import SQLAlchemyScheduleItemRepository
from v1.schedule_items.schemas import (
ScheduleItemCreateRequest,
ScheduleItemMetadata,
ScheduleItemShareRequest,
ScheduleItemStatus,
ScheduleItemUpdateRequest,
)
@@ -72,9 +79,196 @@ def _service(session: AsyncSession, owner_id: UUID) -> ScheduleItemService:
repository=SQLAlchemyScheduleItemRepository(session),
session=session,
current_user=CurrentUser(id=owner_id),
inbox_repository=SQLAlchemyInboxMessageRepository(session),
)
def _parse_string_list(value: object, *, field_name: str) -> list[str]:
if value is None:
return []
if not isinstance(value, list):
raise ValueError(f"{field_name} must be a list of strings")
parsed: list[str] = []
for item in value:
if not isinstance(item, str) or not item.strip():
raise ValueError(f"{field_name} must be a list of non-empty strings")
parsed.append(item.strip())
return parsed
def _list_auth_users() -> list[object]:
admin_client = supabase_service.get_admin_client()
users: list[object] = []
page = 1
while page <= 100:
response = admin_client.auth.admin.list_users(page=page, per_page=100)
batch = (
list(response)
if isinstance(response, list)
else list(getattr(response, "users", []))
)
users.extend(batch)
if len(batch) < 100:
break
page += 1
return users
async def _get_profile_username(*, session: AsyncSession, user_id: UUID) -> str | None:
stmt = select(Profile.username).where(Profile.id == user_id)
return (await session.execute(stmt)).scalar_one_or_none()
async def _get_profile_by_username(
*, session: AsyncSession, username: str
) -> Profile | None:
stmt = (
select(Profile)
.where(Profile.username == username)
.where(Profile.deleted_at.is_(None))
)
return (await session.execute(stmt)).scalar_one_or_none()
def _find_auth_email_by_user_id(*, users: list[object], user_id: UUID) -> str | None:
target = str(user_id)
for user in users:
if str(getattr(user, "id", "")) == target:
email = getattr(user, "email", None)
if isinstance(email, str) and email.strip():
return email.strip()
return None
async def _resolve_identity(
*,
session: AsyncSession,
user_email: str | None,
user_name: str | None,
) -> dict[str, object]:
email = user_email.strip().lower() if isinstance(user_email, str) else ""
name = user_name.strip() if isinstance(user_name, str) else ""
if bool(email) == bool(name):
raise ValueError("provide exactly one of user_email or user_name")
if email:
auth_gateway = SupabaseAuthGateway()
user = await auth_gateway.get_user_by_email(email)
user_id = UUID(user.id)
username = await _get_profile_username(session=session, user_id=user_id)
return {
"userId": str(user_id),
"email": user.email,
"username": username,
"matchedBy": "email",
}
profile = await _get_profile_by_username(session=session, username=name)
if profile is None:
raise HTTPException(status_code=404, detail="User not found")
users = _list_auth_users()
email_value = _find_auth_email_by_user_id(users=users, user_id=profile.id)
return {
"userId": str(profile.id),
"email": email_value,
"username": profile.username,
"matchedBy": "username",
}
def _invite_permission(tool_args: dict[str, object]) -> dict[str, bool]:
return {
"permission_view": bool(tool_args.get("invitePermissionView", True)),
"permission_edit": bool(tool_args.get("invitePermissionEdit", False)),
"permission_invite": bool(tool_args.get("invitePermissionInvite", False)),
}
async def _share_event_with_invitees(
*,
session: AsyncSession,
owner_id: UUID,
event_id: UUID,
tool_args: dict[str, object],
) -> dict[str, object] | None:
email_targets = _parse_string_list(
tool_args.get("inviteUserEmails"),
field_name="inviteUserEmails",
)
name_targets = _parse_string_list(
tool_args.get("inviteUserNames"),
field_name="inviteUserNames",
)
id_targets = _parse_string_list(
tool_args.get("inviteUserIds"),
field_name="inviteUserIds",
)
if not email_targets and not name_targets and not id_targets:
return None
users = _list_auth_users() if id_targets else []
emails = {item.lower() for item in email_targets}
for user_id_raw in id_targets:
try:
user_id = UUID(user_id_raw)
except ValueError as exc:
raise ValueError("inviteUserIds must contain valid UUID strings") from exc
resolved_email = _find_auth_email_by_user_id(users=users, user_id=user_id)
if resolved_email is None:
raise HTTPException(status_code=404, detail="Invite user email not found")
emails.add(resolved_email.lower())
for username in name_targets:
resolved = await _resolve_identity(
session=session,
user_email=None,
user_name=username,
)
resolved_email = resolved.get("email")
if not isinstance(resolved_email, str) or not resolved_email:
raise HTTPException(status_code=404, detail="Invite user email not found")
emails.add(resolved_email.lower())
service = _service(session, owner_id)
permission = _invite_permission(tool_args)
invited: list[str] = []
for email in sorted(emails):
request = ScheduleItemShareRequest(email=email, **permission)
await service.share(event_id, request)
invited.append(email)
return {
"count": len(invited),
"emails": invited,
"permission": permission,
}
async def _execute_resolve_user_identity(
*,
session: AsyncSession,
owner_id: UUID,
tool_args: dict[str, object],
) -> dict[str, object]:
del owner_id
user_email_raw = tool_args.get("userEmail")
user_name_raw = tool_args.get("userName")
user_email = user_email_raw if isinstance(user_email_raw, str) else None
user_name = user_name_raw if isinstance(user_name_raw, str) else None
resolved = await _resolve_identity(
session=session,
user_email=user_email,
user_name=user_name,
)
return {
"type": "user_lookup.v1",
"version": "v1",
"data": {
"ok": True,
**resolved,
},
"actions": [],
}
def _resolve_metadata(tool_args: dict[str, object]) -> ScheduleItemMetadata:
location = tool_args.get("location")
location_value = location.strip() if isinstance(location, str) else None
@@ -185,6 +379,12 @@ async def _execute_create(
)
event_data = _event_payload(created)
event_id = str(event_data["id"])
invite_result = await _share_event_with_invitees(
session=service._session,
owner_id=service.require_user_id(),
event_id=UUID(event_id),
tool_args=tool_args,
)
return {
"type": "calendar_card.v1",
"version": "v1",
@@ -193,12 +393,13 @@ async def _execute_create(
"sourceType": "agent_generated",
"ok": True,
"message": "日程已创建",
"inviteResult": invite_result,
},
"actions": [
{
"type": "link",
"label": "查看详情",
"target": f"/calendar/events/{event_id}",
"target": f"/schedule-items/{event_id}",
}
],
}
@@ -274,6 +475,12 @@ async def _execute_update(
ScheduleItemUpdateRequest.model_validate(update_data),
)
event_data = _event_payload(updated)
invite_result = await _share_event_with_invitees(
session=service._session,
owner_id=service.require_user_id(),
event_id=UUID(str(event_data["id"])),
tool_args=tool_args,
)
return {
"type": "calendar_card.v1",
"version": "v1",
@@ -282,12 +489,13 @@ async def _execute_update(
"sourceType": "agent_generated",
"ok": True,
"message": "日程已更新",
"inviteResult": invite_result,
},
"actions": [
{
"type": "link",
"label": "查看详情",
"target": f"/calendar/events/{event_data['id']}",
"target": f"/schedule-items/{event_data['id']}",
}
],
}
@@ -4,8 +4,9 @@ from dataclasses import dataclass
TOOL_APPROVAL_REQUIRED: dict[str, bool] = {
"calendar.read": False,
"calendar.write": False,
"calendar_read": False,
"calendar_write": False,
"user_resolve": False,
}
+20 -5
View File
@@ -6,7 +6,11 @@ from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
from core.agentscope.tools.custom.calendar import (
calendar_read,
calendar_write,
user_resolve,
)
from core.agentscope.tools.hitl_middleware import register_tool_middlewares
from core.agentscope.tools.tool_meta import TOOL_META
@@ -25,10 +29,12 @@ class ToolGroup:
TOOL_GROUPS: dict[str, ToolGroup] = {
"intent": ToolGroup(stage="intent", tool_names=frozenset({"calendar.read"})),
"intent": ToolGroup(
stage="intent", tool_names=frozenset({"calendar_read", "user_resolve"})
),
"execution": ToolGroup(
stage="execution",
tool_names=frozenset({"calendar.read", "calendar.write"}),
tool_names=frozenset({"calendar_read", "calendar_write", "user_resolve"}),
),
"report": ToolGroup(stage="report", tool_names=frozenset()),
}
@@ -49,7 +55,7 @@ def _load_custom_tool_bindings(
) -> list[CustomToolBinding]:
return [
CustomToolBinding(
name="calendar.read",
name="calendar_read",
func=calendar_read,
preset_kwargs={
"session": session,
@@ -58,7 +64,7 @@ def _load_custom_tool_bindings(
},
),
CustomToolBinding(
name="calendar.write",
name="calendar_write",
func=calendar_write,
preset_kwargs={
"session": session,
@@ -66,6 +72,15 @@ def _load_custom_tool_bindings(
"user_token": user_token or "",
},
),
CustomToolBinding(
name="user_resolve",
func=user_resolve,
preset_kwargs={
"session": session,
"owner_id": owner_id,
"user_token": user_token or "",
},
),
]
+15 -10
View File
@@ -126,21 +126,26 @@ class LiteLLMService:
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
response_format: dict[str, Any] | None = None,
completion_fn: Callable[..., dict[str, Any]] | None = None,
) -> LiteLLMResponseWithCost:
caller = completion_fn or completion
request_model = model if model.startswith("openai/") else f"openai/{model}"
response_any = caller(
model=request_model,
api_key=self.proxy_api_key,
api_base=self.proxy_base_url,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
stream=False,
)
request_kwargs: dict[str, Any] = {
"model": request_model,
"api_key": self.proxy_api_key,
"api_base": self.proxy_base_url,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"timeout": timeout,
"stream": False,
}
if response_format is not None:
request_kwargs["response_format"] = response_format
response_any = caller(**request_kwargs)
response = self._normalize_response(response_any)
usage_raw = response.get("usage")
+15
View File
@@ -107,6 +107,10 @@ class AgentRepository:
raise HTTPException(status_code=404, detail="Session not found")
next_seq = int(session_row.message_count or 0) + 1
if not _has_title(session_row.title):
session_title = _derive_session_title(content_text)
if session_title is not None:
session_row.title = session_title
payload_metadata = dict(metadata or {})
payload_metadata["run_id"] = run_id
message = AgentChatMessage(
@@ -264,3 +268,14 @@ class AgentRepository:
if rendered:
payload["attachments"] = rendered
return payload
def _has_title(title: object) -> bool:
return isinstance(title, str) and bool(title.strip())
def _derive_session_title(content_text: str) -> str | None:
normalized = " ".join(content_text.split())
if not normalized:
return None
return normalized[:80]
+5
View File
@@ -203,6 +203,11 @@ async def stream_events(
user_id=str(current_user.id),
reason=str(exc),
)
if "Timeout reading from" in str(exc):
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
break
if not rows:
+13 -6
View File
@@ -212,12 +212,19 @@ class AgentService:
content_type=mime_type,
)
except Exception: # noqa: BLE001
bucket_name = "private"
stored_path = await self._attachment_storage.upload_bytes(
bucket=bucket_name,
path=path,
content=payload,
content_type=mime_type,
logger.exception(
"Attachment upload failed",
extra={
"bucket": bucket_name,
"path": path,
"mime_type": mime_type,
"thread_id": run_input.thread_id,
"run_id": run_input.run_id,
},
)
raise HTTPException(
status_code=502,
detail="Failed to upload attachment",
)
attachments.append(
{
@@ -1,6 +1,8 @@
from __future__ import annotations
import base64
import os
from pathlib import Path
from uuid import UUID, uuid4
import httpx
@@ -12,6 +14,9 @@ from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
FIXTURE_IMAGE_PATH = (
Path(__file__).resolve().parents[3] / "fixtures" / "images" / "calendar_text_cn.png"
)
async def _live_access_token(client: httpx.AsyncClient) -> str:
@@ -108,6 +113,8 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
image_data = base64.b64encode(FIXTURE_IMAGE_PATH.read_bytes()).decode("ascii")
async with httpx.AsyncClient(timeout=30.0) as client:
token = await _live_access_token(client)
headers = {"Authorization": f"Bearer {token}"}
@@ -128,7 +135,7 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
{"type": "text", "text": "请描述图片里的内容"},
{
"type": "binary",
"data": "aGVsbG8=",
"data": image_data,
"mimeType": "image/png",
},
],
@@ -142,19 +149,20 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
assert run_resp.status_code == 202
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
sse_resp = await client.get(
events_url,
headers=headers,
params={"idle_limit": 150},
timeout=60.0,
)
assert sse_resp.status_code == 200
assert sse_resp.headers.get("content-type", "").startswith("text/event-stream")
event_names = [
line.split(":", 1)[1].strip()
for line in sse_resp.text.splitlines()
if line.startswith("event:")
]
event_names: list[str] = []
async with client.stream(
"GET", events_url, headers=headers, timeout=90.0
) as sse_resp:
assert sse_resp.status_code == 200
assert sse_resp.headers.get("content-type", "").startswith(
"text/event-stream"
)
async for line in sse_resp.aiter_lines():
if line.startswith("event:"):
event_name = line.split(":", 1)[1].strip()
event_names.append(event_name)
if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
break
assert "RUN_STARTED" in event_names
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
@@ -194,7 +202,14 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
)
all_messages = list(rows.scalars().all())
assert all_messages
user_rows = [row for row in all_messages if str(row.role) == "user"]
user_rows = [
row
for row in all_messages
if (
getattr(row.role, "value", row.role) == "user"
or str(getattr(row.role, "value", row.role)) == "user"
)
]
assert user_rows
metadata = user_rows[0].metadata_json or {}
attachments = metadata.get("attachments")
@@ -99,6 +99,16 @@ async def test_store_persists_assistant_message_and_aggregates(
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
await store.persist(
{
"type": "TEXT_MESSAGE_START",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
"messageId": "assistant-run-1",
"role": "assistant",
"stage": "report",
}
)
await store.persist(
{
"type": "TEXT_MESSAGE_CONTENT",
@@ -128,6 +138,8 @@ async def test_store_persists_assistant_message_and_aggregates(
assert append_kwargs["output_tokens"] == 5
assert append_kwargs["cost"] == Decimal("0.123")
assert append_kwargs["metadata"]["latency_ms"] == 250
assert append_kwargs["metadata"]["stage"] == "report"
assert append_kwargs["latency_ms"] == 250
assert captured["message_delta"] == 1
assert captured["token_delta"] == 8
assert captured["cost_delta"] == Decimal("0.123")
@@ -255,6 +267,60 @@ async def test_store_clears_buffer_on_run_finished(
assert "append_kwargs" not in captured
@pytest.mark.asyncio
async def test_store_persists_tool_call_result_as_tool_message(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, object] = {}
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2)
class _FakeSessionRepository:
def __init__(self, session: object) -> None:
del session
async def get_session(self, *, session_id): # noqa: ANN001
del session_id
return fake_chat_session
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
del session_id
return fake_chat_session
async def update_runtime_state(self, **kwargs): # noqa: ANN003
captured.update(kwargs)
class _FakeMessageRepository:
def __init__(self, session: object) -> None:
del session
async def append_message(self, **kwargs): # noqa: ANN003
captured["append_kwargs"] = kwargs
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
await store.persist(
{
"type": "TOOL_CALL_RESULT",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
"toolName": "calendar_write",
"taskId": "t1",
"stage": "execution",
"args": {"title": "A"},
"result": {"event_id": "evt-1"},
}
)
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
assert getattr(append_kwargs["role"], "value", None) == "tool"
assert append_kwargs["tool_name"] == "calendar_write"
assert append_kwargs["metadata"]["task_id"] == "t1"
assert captured["message_delta"] == 1
@pytest.mark.asyncio
async def test_store_drops_buffer_when_session_missing(
monkeypatch: pytest.MonkeyPatch,
@@ -13,8 +13,9 @@ from core.agentscope.schemas.user_context import (
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
from core.agentscope.schemas import ReportOutput, RuntimeOutput
from core.agentscope.schemas.agent_runtime import RunCommand
from core.agentscope.schemas.execution import ExecutionBatchOutput
from core.agentscope.schemas.intent import IntentOutput
from core.agentscope.schemas.execution import ExecutionBatchOutput, ExecutionTaskOutput
from core.agentscope.schemas.execution import ExecutionToolCall
from core.agentscope.schemas.intent import IntentOutput, IntentTask
def _user_context() -> UserAgentContext:
@@ -50,20 +51,43 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
async def run(self, **_: object) -> RuntimeOutput:
return RuntimeOutput(
intent=IntentOutput(
route="DIRECT_RESPONSE",
route="TASK_EXECUTION",
intent_summary="summary",
direct_response="done",
tasks=[],
complexity="simple",
direct_response=None,
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
complexity="complex",
response_metadata={"latencyMs": 120},
),
execution=ExecutionBatchOutput(
task_results=[],
task_results=[
ExecutionTaskOutput(
task_id="t1",
status="SUCCESS",
execution_summary="execution-ok",
execution_data={},
user_feedback_needs=[],
response_metadata={"latencyMs": 300},
tool_calls=[
ExecutionToolCall(
tool_name="calendar_write",
args={"title": "A"},
result={"event_id": "evt-1"},
)
],
)
],
overall_status="SUCCESS",
aggregate_summary="ok",
),
report=ReportOutput(
assistant_text="hello world",
response_metadata={},
response_metadata={
"model": "qwen3.5-flash",
"inputTokens": 10,
"outputTokens": 5,
"cost": 0.123,
"latencyMs": 250,
},
),
)
@@ -86,6 +110,13 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
"step.finish",
"step.start",
"step.finish",
"text.start",
"text.delta",
"text.end",
"text.start",
"text.delta",
"text.end",
"tool.result",
"step.start",
"text.start",
"text.delta",
@@ -97,11 +128,19 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
assert calls[2]["data"]["stepName"] == "intent"
assert calls[3]["data"]["stepName"] == "execution"
assert calls[4]["data"]["stepName"] == "execution"
assert calls[5]["data"]["stepName"] == "report"
assert calls[7]["data"]["delta"] == "hello world"
assert calls[6]["data"]["messageId"] == calls[7]["data"]["messageId"]
assert calls[7]["data"]["messageId"] == calls[8]["data"]["messageId"]
assert calls[9]["data"]["stepName"] == "report"
assert calls[5]["data"]["stage"] == "intent"
assert calls[8]["data"]["stage"] == "execution"
assert calls[11]["data"]["toolName"] == "calendar_write"
assert calls[12]["data"]["stepName"] == "report"
assert calls[14]["data"]["delta"] == "hello world"
assert calls[13]["data"]["messageId"] == calls[14]["data"]["messageId"]
assert calls[14]["data"]["messageId"] == calls[15]["data"]["messageId"]
assert calls[15]["data"]["model"] == "qwen3.5-flash"
assert calls[15]["data"]["inputTokens"] == 10
assert calls[15]["data"]["outputTokens"] == 5
assert calls[15]["data"]["cost"] == 0.123
assert calls[15]["data"]["latencyMs"] == 250
assert calls[16]["data"]["stepName"] == "report"
@pytest.mark.asyncio
@@ -140,3 +179,129 @@ async def test_runtime_emits_run_error_when_orchestrator_fails() -> None:
]
assert calls[1]["data"]["stepName"] == "intent"
assert calls[2]["data"]["message"] == "runtime execution failed"
@pytest.mark.asyncio
async def test_runtime_passes_binary_payload_to_orchestrator() -> None:
captured_user_input: object | None = None
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
return str(event.get("type", ""))
class _CaptureOrchestrator:
async def run(self, **kwargs: object) -> RuntimeOutput:
nonlocal captured_user_input
captured_user_input = kwargs.get("user_input")
return RuntimeOutput(
intent=IntentOutput(
route="DIRECT_RESPONSE",
intent_summary="summary",
direct_response="done",
tasks=[],
complexity="simple",
),
execution=None,
report=ReportOutput(
assistant_text="ok",
response_metadata={},
),
)
runtime = AgentRouteRuntime(
orchestrator=_CaptureOrchestrator(),
pipeline=_FakePipeline(),
)
command = RunCommand.model_validate(
{
"threadId": "thread-1",
"runId": "run-1",
"messages": [
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "binary",
"mimeType": "image/png",
"data": "aGVsbG8=",
},
],
}
],
}
)
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
assert isinstance(captured_user_input, list)
first = captured_user_input[0]
assert isinstance(first, dict)
content = first.get("content")
assert isinstance(content, list)
binary = content[1]
assert isinstance(binary, dict)
assert binary.get("data") == "aGVsbG8="
@pytest.mark.asyncio
async def test_runtime_direct_response_finishes_without_report_stage() -> None:
calls: list[dict[str, Any]] = []
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
calls.append(event)
return f"{len(calls)}-0"
class _DirectOrchestrator:
async def run(self, **_: object) -> RuntimeOutput:
return RuntimeOutput(
intent=IntentOutput(
route="DIRECT_RESPONSE",
intent_summary="summary",
direct_response="direct-answer",
tasks=[],
complexity="simple",
response_metadata={"latencyMs": 88},
),
execution=None,
report=ReportOutput(
assistant_text="direct-answer",
response_metadata={"latencyMs": 88},
),
)
runtime = AgentRouteRuntime(
orchestrator=_DirectOrchestrator(),
pipeline=_FakePipeline(),
)
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
assert [item["type"] for item in calls] == [
"run.started",
"step.start",
"step.finish",
"text.start",
"text.delta",
"text.end",
"run.finished",
]
assert calls[3]["data"]["stage"] == "intent"
assert calls[4]["data"]["delta"] == "direct-answer"
@@ -68,6 +68,7 @@ class _FakeRunner:
"direct_response": "你好",
"tasks": [],
"complexity": "simple",
"response_metadata": {"model": "qwen3.5-flash", "latencyMs": 100},
}
self.report_calls += 1
return {
@@ -131,7 +132,7 @@ async def test_runtime_direct_response_skips_execution(
{
"type": "function",
"function": {
"name": "calendar.read",
"name": "calendar_read",
"description": "read",
"parameters": {"type": "object", "properties": {}},
},
@@ -162,8 +163,10 @@ async def test_runtime_direct_response_skips_execution(
assert result.intent.route == "DIRECT_RESPONSE"
assert result.execution is None
assert result.report.assistant_text == "已完成"
assert result.report.assistant_text == "你好"
assert result.report.response_metadata["model"] == "qwen3.5-flash"
assert fake_runner.execution_calls == 0
assert fake_runner.report_calls == 0
@pytest.mark.asyncio
@@ -183,7 +186,7 @@ async def test_runtime_complex_route_runs_execution(
{
"type": "function",
"function": {
"name": "calendar.read",
"name": "calendar_read",
"description": "read",
"parameters": {"type": "object", "properties": {}},
},
@@ -191,7 +194,7 @@ async def test_runtime_complex_route_runs_execution(
{
"type": "function",
"function": {
"name": "calendar.write",
"name": "calendar_write",
"description": "write",
"parameters": {"type": "object", "properties": {}},
},
@@ -9,6 +9,8 @@ from core.agentscope.schemas.system_agent_config import SystemAgentLLMConfig
from core.agentscope.runtime.config_loader import RuntimeStageConfig
from core.agentscope.runtime.react_runner import (
AgentScopeReActRunner,
_chat_response_text,
_merge_stage_response_metadata,
_parse_json_text,
_to_litellm_model,
)
@@ -32,10 +34,10 @@ def test_to_litellm_model_keeps_prefixed_model() -> None:
)
def test_to_litellm_model_builds_prefixed_model() -> None:
def test_to_litellm_model_uses_plain_model_name_when_unprefixed() -> None:
assert (
_to_litellm_model(provider_name="dashscope", model_code="qwen3.5-flash")
== "dashscope/qwen3.5-flash"
== "qwen3.5-flash"
)
@@ -49,6 +51,24 @@ def test_parse_json_text_rejects_non_json() -> None:
_parse_json_text("not-json")
def test_chat_response_text_falls_back_to_choice_message_content() -> None:
response = SimpleNamespace(
content=None,
choices=[
{
"message": {
"content": '{"assistant_text":"fallback","response_metadata":{}}'
}
}
],
)
assert (
_chat_response_text(response)
== '{"assistant_text":"fallback","response_metadata":{}}'
)
@pytest.mark.asyncio
async def test_run_json_stage_wraps_json_decode_error(
monkeypatch: pytest.MonkeyPatch,
@@ -113,3 +133,88 @@ async def test_run_json_stage_wraps_runtime_error(
user_prompt="user",
toolkit=None,
)
@pytest.mark.asyncio
async def test_run_json_stage_report_merges_usage_metadata(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _FakeLiteLLMService:
def run_completion_with_cost(self, **kwargs: object) -> object:
del kwargs
return SimpleNamespace(
response={
"model": "dashscope/qwen3.5-flash",
"choices": [
{
"message": {
"content": '{"assistant_text":"ok","response_metadata":{}}'
}
}
],
},
usage=SimpleNamespace(
prompt_tokens=9,
completion_tokens=4,
cost=0.006,
),
)
runner = AgentScopeReActRunner()
monkeypatch.setattr(
runner,
"_build_litellm_service",
lambda: _FakeLiteLLMService(),
)
report_stage = RuntimeStageConfig(
stage="report",
model_code="qwen3.5-flash",
provider_name="dashscope",
llm_config=SystemAgentLLMConfig(
temperature=0.1,
max_tokens=128,
timeout_seconds=30,
),
)
payload = await runner.run_json_stage(
stage_config=report_stage,
agent_name="report-agent",
system_prompt="sys",
user_prompt="user",
toolkit=None,
)
metadata = payload["response_metadata"]
assert metadata["model"] == "dashscope/qwen3.5-flash"
assert metadata["inputTokens"] == 9
assert metadata["outputTokens"] == 4
assert metadata["cost"] == 0.006
assert isinstance(metadata["latencyMs"], int)
assert metadata["latencyMs"] >= 0
def test_merge_stage_response_metadata_estimates_cost_from_pricing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
"core.agentscope.runtime.react_runner._estimate_cost_by_pricing",
lambda **kwargs: 0.0025,
)
payload = _merge_stage_response_metadata(
payload={"route": "DIRECT_RESPONSE", "response_metadata": {}},
stage_config=_stage_config(),
response=SimpleNamespace(
usage=SimpleNamespace(
prompt_tokens=12,
completion_tokens=8,
),
model="qwen3.5-flash",
),
latency_ms=50,
)
metadata = payload["response_metadata"]
assert metadata["inputTokens"] == 12
assert metadata["outputTokens"] == 8
assert metadata["cost"] == 0.0025
@@ -71,6 +71,63 @@ async def test_run_agentscope_task_calls_runtime_run(
assert called["resume"] == 0
@pytest.mark.asyncio
async def test_run_agentscope_task_includes_recent_context_messages(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured_messages: list[dict[str, Any]] = []
class _FakeRuntime:
def __init__(self, **kwargs: object) -> None:
del kwargs
async def run(self, **kwargs: object) -> object:
command = kwargs.get("command")
if command is not None:
raw_messages = getattr(command, "messages", [])
if isinstance(raw_messages, list):
captured_messages.extend(raw_messages)
return object()
async def resume(self, **kwargs: object) -> object:
del kwargs
return object()
async def _fake_get_redis_client() -> object:
return object()
async def _fake_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return [{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}]
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
monkeypatch.setattr(
tasks_module,
"get_or_init_redis_client",
_fake_get_redis_client,
)
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
monkeypatch.setattr(
tasks_module,
"_build_recent_context_messages",
_fake_context,
)
run_input = _run_input_payload()
run_input["messages"] = [{"id": "u1", "role": "user", "content": "现在几点"}]
await tasks_module.run_agentscope_task(
{
"command": "run",
"owner_id": str(uuid4()),
"run_input": run_input,
}
)
assert len(captured_messages) == 2
assert captured_messages[0]["id"] == "ctx-1"
assert captured_messages[1]["id"] == "u1"
@pytest.mark.asyncio
async def test_run_agentscope_task_calls_runtime_resume(
monkeypatch: pytest.MonkeyPatch,
@@ -178,3 +178,89 @@ async def test_calendar_write_rejects_invalid_reminder_minutes(
assert result["data"]["ok"] is False
assert result["data"]["code"] == "INVALID_ARGUMENT"
@pytest.mark.asyncio
async def test_calendar_write_maps_invite_arguments(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, object] = {}
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
captured.update(cast(dict[str, object], kwargs["tool_args"]))
return {"type": "calendar_card.v1", "version": "v1", "data": {"ok": True}}
monkeypatch.setattr(
calendar_module,
"_execute_mutate_calendar_event",
_fake_execute,
)
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
await calendar_module.calendar_write(
session=cast(AsyncSession, SimpleNamespace()),
owner_id=uuid4(),
user_token="token-abc",
operation="create",
invite_user_emails=["a@example.com"],
invite_user_names=["alice"],
invite_user_ids=[str(uuid4())],
invite_permission_view=True,
invite_permission_edit=True,
invite_permission_invite=True,
)
assert captured["inviteUserEmails"] == ["a@example.com"]
assert captured["inviteUserNames"] == ["alice"]
assert isinstance(captured["inviteUserIds"], list)
assert captured["invitePermissionView"] is True
assert captured["invitePermissionEdit"] is True
assert captured["invitePermissionInvite"] is True
@pytest.mark.asyncio
async def test_user_resolve_maps_identity_arguments(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, object] = {}
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
captured.update(cast(dict[str, object], kwargs["tool_args"]))
return {"type": "user_lookup.v1", "version": "v1", "data": {"ok": True}}
monkeypatch.setattr(
calendar_module,
"_execute_resolve_user_identity",
_fake_execute,
)
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
result = await calendar_module.user_resolve(
session=cast(AsyncSession, SimpleNamespace()),
owner_id=uuid4(),
user_token="token-abc",
user_email="a@example.com",
)
assert result["type"] == "user_lookup.v1"
assert captured == {"userEmail": "a@example.com", "userName": None}
@pytest.mark.asyncio
async def test_user_resolve_requires_valid_user_token(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: False)
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
result = await calendar_module.user_resolve(
session=cast(AsyncSession, SimpleNamespace()),
owner_id=uuid4(),
user_token="bad-token",
user_name="alice",
)
assert result["data"]["ok"] is False
assert result["data"]["code"] == "UNAUTHORIZED"
@@ -0,0 +1,56 @@
from __future__ import annotations
from core.agentscope.prompts.runtime_prompt import build_intent_user_prompt
def test_build_intent_user_prompt_keeps_multimodal_blocks() -> None:
prompt = build_intent_user_prompt(
user_input=[
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "请识别图片内容"},
{
"type": "binary",
"mimeType": "image/png",
"data": "aGVsbG8=",
},
],
}
]
)
assert isinstance(prompt, list)
assert prompt
assert prompt[0]["type"] == "text"
assert "[Output Schema]" in prompt[0]["text"]
image_blocks = [item for item in prompt if item.get("type") == "image"]
assert len(image_blocks) == 1
source = image_blocks[0]["source"]
assert source["type"] == "base64"
assert source["media_type"] == "image/png"
assert source["data"] == "aGVsbG8="
def test_build_intent_user_prompt_filters_non_image_binary_block() -> None:
prompt = build_intent_user_prompt(
user_input=[
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "请处理这个输入"},
{
"type": "binary",
"mimeType": "application/pdf",
"data": "aGVsbG8=",
},
],
}
]
)
assert isinstance(prompt, list)
image_blocks = [item for item in prompt if item.get("type") == "image"]
assert image_blocks == []
@@ -20,11 +20,12 @@ async def test_build_toolkit_registers_calendar_tools() -> None:
)
schemas = toolkit.get_json_schemas()
names = {item["function"]["name"] for item in schemas}
assert "calendar.read" in names
assert "calendar.write" in names
assert "calendar_read" in names
assert "calendar_write" in names
assert "user_resolve" in names
write_schema = next(
item for item in schemas if item["function"]["name"] == "calendar.write"
item for item in schemas if item["function"]["name"] == "calendar_write"
)
params = write_schema["function"]["parameters"]["properties"]
assert "user_token" not in params
@@ -33,11 +33,11 @@ def test_calculate_cost_uses_second_qwen_tier() -> None:
def test_run_completion_extracts_usage_and_cost() -> None:
service = LiteLLMService()
captured: dict[str, object] = {}
result = service.run_completion_with_cost(
model="dashscope/qwen3.5-flash",
messages=[{"role": "user", "content": "hello"}],
completion_fn=lambda **_: {
def _fake_completion(**kwargs: object) -> dict[str, object]:
captured.update(kwargs)
return {
"model": "dashscope/qwen3.5-flash",
"usage": {
"prompt_tokens": 2000,
@@ -46,10 +46,17 @@ def test_run_completion_extracts_usage_and_cost() -> None:
"prompt_tokens_details": {"cached_tokens": 500},
},
"choices": [{"message": {"content": "ok"}}],
},
}
result = service.run_completion_with_cost(
model="dashscope/qwen3.5-flash",
messages=[{"role": "user", "content": "hello"}],
response_format={"type": "json_object"},
completion_fn=_fake_completion,
)
assert result.usage.prompt_tokens == 2000
assert result.usage.completion_tokens == 100
assert result.usage.total_tokens == 2100
assert result.usage.cost == pytest.approx(0.00051)
assert captured["response_format"] == {"type": "json_object"}
@@ -10,6 +10,31 @@ from models.agent_chat_message import AgentChatMessageRole
from v1.agent.repository import AgentRepository
class _ExecuteResult:
def __init__(self, value: object) -> None:
self._value = value
def scalar_one_or_none(self) -> object:
return self._value
class _FakeSession:
def __init__(self, session_row: object) -> None:
self.session_row = session_row
self.added: list[object] = []
self.flushed = False
async def execute(self, stmt): # noqa: ANN001
del stmt
return _ExecuteResult(self.session_row)
def add(self, obj: object) -> None:
self.added.append(obj)
async def flush(self) -> None:
self.flushed = True
class _FakeToolResultStorage:
def __init__(self, payload: dict[str, object] | None) -> None:
self._payload = payload
@@ -104,3 +129,48 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
"mimeType": "image/png",
}
]
@pytest.mark.asyncio
async def test_persist_user_message_sets_session_title_when_empty() -> None:
session_id = str(uuid4())
session_row = SimpleNamespace(
message_count=0,
title=None,
last_activity_at=datetime.now(timezone.utc),
)
fake_session = _FakeSession(session_row)
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
await repository.persist_user_message(
session_id=session_id,
run_id="run-1",
content_text=" 请帮我安排明天下午开会 ",
metadata=None,
)
assert session_row.title == "请帮我安排明天下午开会"
assert session_row.message_count == 1
assert fake_session.flushed is True
@pytest.mark.asyncio
async def test_persist_user_message_keeps_existing_session_title() -> None:
session_id = str(uuid4())
session_row = SimpleNamespace(
message_count=1,
title="已有标题",
last_activity_at=datetime.now(timezone.utc),
)
fake_session = _FakeSession(session_row)
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
await repository.persist_user_message(
session_id=session_id,
run_id="run-2",
content_text="新的消息内容",
metadata=None,
)
assert session_row.title == "已有标题"
assert session_row.message_count == 2
@@ -175,3 +175,53 @@ async def test_enqueue_resume_accepts_valid_tool_contract(
assert result.task_id == "task-resume-1"
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
assert result.run_id == "run-resume-1"
@pytest.mark.asyncio
async def test_stream_events_retries_on_redis_timeout(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _acquire(*, user_id: str) -> bool:
del user_id
return True
async def _release(*, user_id: str) -> None:
del user_id
monkeypatch.setattr(agent_router, "_acquire_sse_slot", _acquire)
monkeypatch.setattr(agent_router, "_release_sse_slot", _release)
class _Request:
async def is_disconnected(self) -> bool:
return False
class _Service:
def __init__(self) -> None:
self.calls = 0
async def stream_events(self, **kwargs): # noqa: ANN003
del kwargs
self.calls += 1
if self.calls == 1:
raise RuntimeError("Timeout reading from localhost:6379")
if self.calls == 2:
return [{"id": "1-0", "event": {"type": "RUN_FINISHED"}}]
return []
response = await agent_router.stream_events(
request=cast(Any, _Request()),
thread_id="00000000-0000-0000-0000-000000000001",
service=cast(Any, _Service()),
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
last_event_id=None,
idle_limit=2,
)
chunks: list[str] = []
async for chunk in response.body_iterator:
chunks.append(str(chunk))
if any("RUN_FINISHED" in item for item in chunks):
break
merged = "".join(chunks)
assert "event: RUN_FINISHED" in merged
@@ -124,6 +124,19 @@ class _FakeAttachmentStorage:
return path
class _AlwaysFailAttachmentStorage:
async def upload_bytes(
self,
*,
bucket: str,
path: str,
content: bytes,
content_type: str,
) -> str:
del bucket, path, content, content_type
raise RuntimeError("upload failed")
def _user() -> CurrentUser:
return CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
@@ -317,6 +330,54 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
assert isinstance(attachments[0]["path"], str)
async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
monkeypatch,
) -> None:
monkeypatch.setattr(
agent_service_module.config.storage, "bucket", "agent-test-bucket"
)
repository = _FakeRepository()
service = AgentService(
repository=repository,
queue=_FakeQueue(),
stream=_FakeStream(),
attachment_storage=_AlwaysFailAttachmentStorage(),
)
run_input = RunAgentInput.model_validate(
{
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-with-image-fail",
"state": {},
"messages": [
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "帮我看下这张图"},
{
"type": "binary",
"data": "aGVsbG8=",
"mimeType": "image/png",
},
],
}
],
"tools": [],
"context": [],
"forwardedProps": {},
}
)
try:
await service.enqueue_run(run_input=run_input, current_user=_user())
raise AssertionError("expected HTTPException")
except HTTPException as exc:
assert exc.status_code == 502
assert exc.detail == "Failed to upload attachment"
assert repository.persisted_user_messages == []
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
service = AgentService(
repository=_FakeRepository(),
@@ -1,141 +0,0 @@
# Agent Multimodal Smoke Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 完成 agent 三条主链路(runs/events/history)真实冒烟,并支持 RunAgentInput 图片信息在发送链路落 Supabase Storage、在 messages.metadata 持久化、在 history 返回中可渲染。
**Architecture:**`v1/agent` 服务层新增“用户消息持久化 + 图片附件上传”步骤:`enqueue_run` 时解析用户消息 content block,图片上传到 `config.storage.bucket`,将路径写入 `messages.metadata`。运行时继续通过 AgentScope pipeline 输出 AG-UI 事件,SSE 从 Redis stream 订阅,历史查询从 `messages` 回放并附带附件信息。
**Tech Stack:** FastAPI, SQLAlchemy AsyncSession, Supabase Storage Admin Client, Redis SSE stream, AG-UI, pytest/httpx。
---
### Task 1: 用户消息图片附件上传与落库
**Files:**
- Create: `backend/src/v1/agent/attachment_storage.py`
- Modify: `backend/src/v1/agent/service.py`
- Modify: `backend/src/v1/agent/repository.py`
- Test: `backend/tests/unit/v1/agent/test_service.py`
**Step 1: 写失败测试(RED**
```python
@pytest.mark.asyncio
async def test_enqueue_run_persists_user_message_with_uploaded_image_metadata() -> None:
...
```
**Step 2: 运行单测验证失败**
Run: `uv run pytest tests/unit/v1/agent/test_service.py::test_enqueue_run_persists_user_message_with_uploaded_image_metadata -q`
Expected: FAIL(缺少附件上传/metadata 持久化行为)
**Step 3: 最小实现(GREEN**
```python
class AgentAttachmentStorage:
async def upload_bytes(...):
...
class AgentService:
async def enqueue_run(...):
# 解析 user content blocks
# 上传图片到 storage
# repository 持久化 user message(metadata 包含 bucket/path)
...
```
**Step 4: 运行单测验证通过**
Run: `uv run pytest tests/unit/v1/agent/test_service.py::test_enqueue_run_persists_user_message_with_uploaded_image_metadata -q`
Expected: PASS
### Task 2: history 渲染附件路径
**Files:**
- Modify: `backend/src/v1/agent/repository.py`
- Test: `backend/tests/unit/v1/agent/test_repository.py`
**Step 1: 写失败测试(RED**
```python
@pytest.mark.asyncio
async def test_history_includes_user_message_attachments_from_metadata() -> None:
...
```
**Step 2: 运行测试验证失败**
Run: `uv run pytest tests/unit/v1/agent/test_repository.py::test_history_includes_user_message_attachments_from_metadata -q`
Expected: FAILhistory 尚未渲染 attachments
**Step 3: 最小实现(GREEN**
```python
if role == "user" and isinstance(metadata.get("attachments"), list):
payload["attachments"] = metadata["attachments"]
```
**Step 4: 运行测试验证通过**
Run: `uv run pytest tests/unit/v1/agent/test_repository.py::test_history_includes_user_message_attachments_from_metadata -q`
Expected: PASS
### Task 3: 真实冒烟 runs + SSE + history(含图片输入)
**Files:**
- Modify: `backend/tests/integration/v1/agent/test_sse_flow_live.py`
**Step 1: 写失败测试(RED**
```python
@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_runs_events_history_live_with_image_input() -> None:
...
```
**Step 2: 运行 live 测试验证失败(实现前或环境不完整)**
Run: `AGENT_LIVE_INTEGRATION=1 AGENT_LIVE_EMAIL=... AGENT_LIVE_PASSWORD=... uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_events_history_live_with_image_input -q -s`
Expected: FAIL(缺 metadata/path 或 history 不含附件)
**Step 3: 最小实现(GREEN**
```python
# live 测试流程:
# 1) 登录拿 token
# 2) POST /runs 发送 text + image(data)
# 3) SSE 订阅直到 RUN_FINISHED/RUN_ERROR
# 4) GET /runs/{thread_id}/history
# 5) SQL 校验 sessions/messages 字段与 metadata.attachments
```
**Step 4: 运行 live 测试验证通过**
Run: `AGENT_LIVE_INTEGRATION=1 AGENT_LIVE_EMAIL=... AGENT_LIVE_PASSWORD=... uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_events_history_live_with_image_input -q -s`
Expected: PASS
### Task 4: 全量收口验证与安全门禁
**Files:**
- Modify (if needed): `backend/src/v1/agent/*`, `backend/tests/*`
**Step 1: 回归测试**
Run: `uv run pytest tests/unit/v1/agent tests/unit/core/agentscope tests/integration/v1/agent -q`
Expected: PASS
**Step 2: 静态检查**
Run: `uv run ruff check src/v1/agent src/core/agentscope tests/unit/v1/agent tests/integration/v1/agent`
Expected: PASS
Run: `uv run basedpyright src/v1/agent src/core/agentscope tests/unit/v1/agent tests/integration/v1/agent`
Expected: 0 errors
**Step 3: 评审门禁**
Run agents: `security-reviewer`, `refactor-cleaner`, `code-reviewer`
Expected: 无未解决 CRITICAL/HIGH
@@ -0,0 +1,69 @@
# Agent Multimodal Smoke Runbook
**Goal:** 固化 agent 三条主链路(runs/events/history)的真实冒烟标准与输入基线。
## 1. 覆盖范围
1. `POST /api/v1/agent/runs` - 接收多模态消息(文本+图片)
2. `GET /api/v1/agent/runs/{thread_id}/events` - SSE 事件流,事件名符合 AG-UI 标准(`RUN_STARTED``STEP_STARTED``TOOL_CALL_*``RUN_FINISHED`/`RUN_ERROR`
3. `GET /api/v1/agent/runs/{thread_id}/history` - 返回 `STATE_SNAPSHOT`,含 `attachments` metadata
4. `sessions/messages` 落库完整:message_count、tokens、cost、latency、title、metadata
5. tool result 存储:大 payload 写 storagemetadata 记录 `storage_bucket`/`storage_path`
6. storage bucket 来源:必须来自环境变量 `SOCIAL_STORAGE__BUCKET`
## 2. 固定测试输入
- 图片夹具:`backend/tests/fixtures/images/calendar_text_cn.png`
- 多模态消息:
- 文本:`"识别图片中的日历内容并调用 calendar.write 创建日程"`
- 图片:`{"type":"binary","data":"<base64>","mimeType":"image/png"}`
## 3. 账号与凭据
- 冒烟账号:`dagronl@126.com` / `123456`
- 通过环境变量注入:`AGENT_LIVE_EMAIL``AGENT_LIVE_PASSWORD`
## 4. 执行命令
```bash
AGENT_LIVE_INTEGRATION=1 \
AGENT_LIVE_EMAIL="dagronl@126.com" \
AGENT_LIVE_PASSWORD="123456" \
uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_events_history_live_with_image_input -q -s
```
## 5. 结果记录模板
- `thread_id` / `run_id`
- `runs` 状态码与响应
- `events` 事件序列
- `history` 是否含 `attachments[].bucket/path/mimeType`
- `sessions` 字段:message_count / total_tokens / total_cost / status / title
- `messages` 字段:role / content / metadata / tokens / cost / latency
- `tool_result` 是否写 storage
## 6. 安全注意
- 禁止将密码/token 写入 git 跟踪文件
## 7. 已修复问题清单
| 问题 | 修复内容 |
|------|----------|
| bucket 写入失败回退 | 改为直接报错,禁止回退到硬编码 bucket |
| user.resolve 工具 | 新增按 email/name 解析 user_id |
| calendar.write 邀请参数 | 增加 invite 参数透传 |
| inbox_repository 缺失 | 修复 calendar runtime 依赖 |
| runtime 模型名拼接 | 修复无效 model name |
| 多模态透传 | runtime 透传 binary.data,不过滤为 `<omitted>` |
| sessions.title 生成 | 首条用户消息持久化时自动生成 |
| assistant latency 入库 | `messages.latency_ms` 列写入 |
| intent/execution 阶段消息落库 | 新增 `text.*``tool.result` 事件 |
| DIRECT_RESPONSE 早返回 | intent 判定后直接返回,不进入 report 阶段 |
## 8. 待修复问题(用户新增)
1. **意图/执行阶段 tokens/cost 入库** - 目前仅 report 阶段入库
2. **连续会话记忆测试** - 验证 session 是否从数据库读取历史上下文
3. **工具调用测试** - calendar 读/写/删/分享 + 用户查找 + 时间感知
4. **session 失败排查** - 找出最新失败原因并修复
@@ -1,583 +0,0 @@
# 日历邀请弹窗优化 Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 优化日历邀请消息弹窗,显示完整信息(发送者名称 + 日历标题),使用公共弹窗组件替代所有旧弹窗代码
**Architecture:**
- 后端新增用户信息查询接口
- 前端创建公共弹窗组件 MessageActionSheet
- 删除所有旧的弹窗代码(好友请求、日历邀请),统一使用公共组件
**Tech Stack:** Flutter (Dart), FastAPI (Python)
---
### Task 1: 后端添加用户信息查询接口
**Files:**
- Modify: `backend/src/v1/users/router.py`
- Modify: `backend/src/v1/users/service.py`
- Modify: `backend/src/v1/users/repository.py`
**Step 1: 添加 repository 方法**
修改 `backend/src/v1/users/repository.py`,在 `UserRepository``SQLAlchemyUserRepository` 中已有 `get_by_user_id` 方法,确认存在。
**Step 2: 添加 service 方法**
修改 `backend/src/v1/users/service.py`,添加:
```python
async def get_user_by_id(self, user_id: UUID) -> UserBasicInfo:
from v1.friendships.schemas import UserBasicInfo
profile = await self._repository.get_by_user_id(user_id)
if not profile:
raise HTTPException(status_code=404, detail="User not found")
return UserBasicInfo(
id=str(profile.user_id),
username=profile.username,
avatar_url=profile.avatar_url,
)
```
**Step 3: 添加 router 接口**
修改 `backend/src/v1/users/router.py`,添加:
```python
@router.get("/{user_id}", response_model=UserBasicInfo)
async def get_user(
user_id: UUID,
service: Annotated[UserService, Depends(get_user_service)],
):
return await service.get_user_by_id(user_id)
```
**Step 4: 运行 lint 和 typecheck**
```bash
cd backend && uv run ruff check src/v1/users/ && uv run basedpyright src/v1/users/
```
**Step 5: 提交**
```bash
git add backend/src/v1/users/ && git commit -m "feat(users): add get user by id endpoint"
```
---
### Task 2: 前端添加用户 API 接口
**Files:**
- Modify: `apps/lib/features/users/data/users_api.dart`
- Modify: `apps/lib/core/di/injection.dart`
**Step 1: 添加 UserBasicInfo 类和 getById 方法**
修改 `apps/lib/features/users/data/users_api.dart`
```dart
class UserBasicInfo {
final String id;
final String username;
final String? avatarUrl;
UserBasicInfo({
required this.id,
required this.username,
this.avatarUrl,
});
factory UserBasicInfo.fromJson(Map<String, dynamic> json) {
return UserBasicInfo(
id: json['id'] as String,
username: json['username'] as String,
avatarUrl: json['avatar_url'] as String?,
);
}
}
class UsersApi {
final IApiClient _client;
static const _prefix = '/api/v1/users';
UsersApi(this._client);
// ... existing methods
Future<UserBasicInfo> getById(String userId) async {
final response = await _client.get('$_prefix/$userId');
return UserBasicInfo.fromJson(response.data);
}
}
```
**Step 2: 注册到 DI**
修改 `apps/lib/core/di/injection.dart`,添加:
```dart
sl.registerLazySingleton(() => UsersApi(sl<IApiClient>()));
```
**Step 3: 运行 flutter analyze**
```bash
cd apps && flutter analyze lib/features/users/
```
**Step 4: 提交**
```bash
git add apps/lib/features/users/ apps/lib/core/di/injection.dart && git commit -m "feat(users): add getById API and UserBasicInfo"
```
---
### Task 3: 创建公共弹窗组件 MessageActionSheet
**Files:**
- Create: `apps/lib/features/messages/ui/widgets/message_action_sheet.dart`
**Step 1: 创建弹窗组件**
创建 `apps/lib/features/messages/ui/widgets/message_action_sheet.dart`
```dart
import 'package:flutter/material.dart';
import '../../../../core/theme/design_tokens.dart';
import '../../../../shared/widgets/app_button.dart';
class MessageActionSheet extends StatelessWidget {
final String title;
final String? description;
final String? statusText;
final bool isReadOnly;
final VoidCallback? onAccept;
final VoidCallback? onDecline;
final IconData? icon;
final Color? iconColor;
const MessageActionSheet({
super.key,
required this.title,
this.description,
this.statusText,
this.isReadOnly = false,
this.onAccept,
this.onDecline,
this.icon,
this.iconColor,
});
@override
Widget build(BuildContext context) {
return Container(
width: double.infinity,
padding: const EdgeInsets.fromLTRB(24, 20, 24, 0),
decoration: const BoxDecoration(
color: AppColors.white,
borderRadius: BorderRadius.vertical(top: Radius.circular(24)),
),
child: Column(
mainAxisSize: MainAxisSize.min,
children: [
Container(
width: 40,
height: 4,
decoration: BoxDecoration(
color: AppColors.slate300,
borderRadius: BorderRadius.circular(2),
),
),
const SizedBox(height: 20),
if (icon != null) ...[
Container(
width: 72,
height: 72,
decoration: BoxDecoration(
color: (iconColor ?? AppColors.blue500).withValues(alpha: 0.1),
shape: BoxShape.circle,
),
child: Icon(icon, size: 32, color: iconColor ?? AppColors.blue500),
),
const SizedBox(height: 16),
],
Text(
title,
style: const TextStyle(
fontSize: 20,
fontWeight: FontWeight.w600,
color: AppColors.slate900,
),
textAlign: TextAlign.center,
),
if (description != null && description!.isNotEmpty) ...[
const SizedBox(height: 8),
Text(
description!,
style: const TextStyle(fontSize: 14, color: AppColors.slate500),
textAlign: TextAlign.center,
),
],
if (statusText != null) ...[
const SizedBox(height: 16),
Container(
padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 6),
decoration: BoxDecoration(
color: AppColors.slate100,
borderRadius: BorderRadius.circular(16),
),
child: Text(
statusText!,
style: const TextStyle(fontSize: 14, color: AppColors.slate600),
),
),
],
const SizedBox(height: 24),
if (!isReadOnly) ...[
Row(
children: [
Expanded(
child: AppButton(
text: '拒绝',
isOutlined: true,
onPressed: () {
Navigator.pop(context);
onDecline?.call();
},
),
),
const SizedBox(width: AppSpacing.md),
Expanded(
child: AppButton(
text: '接受',
onPressed: () {
Navigator.pop(context);
onAccept?.call();
},
),
),
],
),
],
SizedBox(height: MediaQuery.of(context).padding.bottom + 12),
],
),
);
}
}
```
**Step 2: 运行 flutter analyze**
```bash
cd apps && flutter analyze lib/features/messages/ui/widgets/message_action_sheet.dart
```
**Step 3: 提交**
```bash
git add apps/lib/features/messages/ui/widgets/message_action_sheet.dart && git commit -m "feat(messages): add MessageActionSheet component"
```
---
### Task 4: 重构消息列表页面,使用公共组件并删除旧代码
**Files:**
- Modify: `apps/lib/features/messages/ui/screens/message_invite_list_screen.dart`
**Step 1: 添加依赖和字段**
在文件顶部添加:
```dart
import '../../../users/data/users_api.dart';
import '../widgets/message_action_sheet.dart';
```
`_MessageInviteListScreenState` 中添加:
```dart
late final UsersApi _usersApi;
```
`initState` 中添加:
```dart
_usersApi = sl<UsersApi>();
```
**Step 2: 添加获取日历邀请信息方法**
```dart
Future<(String calendarTitle, String senderName)?> _getCalendarInviteInfo(
InboxMessageResponse message,
) async {
if (message.scheduleItemId == null || message.senderId == null) {
return null;
}
try {
final calendar = await _calendarApi.getById(message.scheduleItemId!);
final sender = await _usersApi.getById(message.senderId!);
return (calendar.title, sender.username);
} catch (e) {
return null;
}
}
```
**Step 3: 替换日历邀请弹窗方法**
删除旧的 `_showCalendarInviteSheet` 方法,替换为:
```dart
Future<void> _showCalendarInviteSheet(InboxMessageResponse message) async {
final itemId = message.scheduleItemId;
if (itemId == null) return;
final info = await _getCalendarInviteInfo(message);
final title = info != null
? '${info.$2} 邀请你加入日历'
: '日历邀请';
final description = info?.$1;
if (!mounted) return;
showModalBottomSheet<void>(
context: context,
backgroundColor: Colors.transparent,
builder: (ctx) => MessageActionSheet(
title: title,
description: description,
icon: Icons.calendar_today,
iconColor: AppColors.blue500,
onAccept: () async {
try {
await _calendarApi.acceptSubscription(itemId);
await _inboxApi.markAsRead(message.id);
if (mounted) {
Toast.show(context, '已接受', type: ToastType.success);
_loadMessages();
}
} catch (e) {
if (mounted) {
Toast.show(context, '操作失败', type: ToastType.error);
}
}
},
onDecline: () async {
try {
await _calendarApi.rejectSubscription(itemId);
await _inboxApi.markAsRead(message.id);
if (mounted) {
Toast.show(context, '已拒绝', type: ToastType.success);
_loadMessages();
}
} catch (e) {
if (mounted) {
Toast.show(context, '操作失败', type: ToastType.error);
}
}
},
),
);
}
```
**Step 4: 添加已读日历邀请弹窗方法**
```dart
Future<void> _showCalendarInviteReadOnlySheet(InboxMessageResponse message) async {
final itemId = message.scheduleItemId;
if (itemId == null) return;
final info = await _getCalendarInviteInfo(message);
final title = info != null
? '${info.$2} 邀请你加入日历'
: '日历邀请';
final description = info?.$1;
final statusText = message.status.value == 'accepted' ? '已接受' : '已拒绝';
if (!mounted) return;
showModalBottomSheet<void>(
context: context,
backgroundColor: Colors.transparent,
builder: (ctx) => MessageActionSheet(
title: title,
description: description,
statusText: statusText,
isReadOnly: true,
icon: Icons.calendar_today,
iconColor: AppColors.blue500,
),
);
}
```
**Step 5: 替换好友请求弹窗方法**
删除旧的 `_showFriendRequestReadOnlySheet``_showFriendRequestActionSheet` 方法,替换为:
```dart
void _showFriendRequestSheet(MessageWithFriend item, {bool isReadOnly = false}) {
final message = item.message;
final friendRequest = item.friendRequest;
if (friendRequest == null) return;
final title = '${friendRequest.sender.username} 请求添加您为好友';
final description = message.content;
final statusText = isReadOnly
? (friendRequest.status == 'accepted'
? '已接受'
: friendRequest.status == 'rejected'
? '已拒绝'
: '已处理')
: null;
showModalBottomSheet<void>(
context: context,
backgroundColor: Colors.transparent,
isScrollControlled: true,
builder: (ctx) => MessageActionSheet(
title: title,
description: description,
statusText: statusText,
isReadOnly: isReadOnly,
icon: Icons.person_add_outlined,
iconColor: AppColors.emerald500,
onAccept: isReadOnly
? null
: () async {
await _processFriendRequest(item, accept: true);
},
onDecline: isReadOnly
? null
: () async {
await _processFriendRequest(item, accept: false);
},
),
);
}
```
**Step 6: 修改 _handleMessageTap 方法**
修改为调用新的统一方法:
```dart
case InboxMessageType.calendar:
final content = _parseCalendarContent(message.content);
if (content == null) return;
final type = content['type'] as String?;
if (type == 'invite') {
if (message.status.value == 'pending') {
await _showCalendarInviteSheet(message);
} else {
await _showCalendarInviteReadOnlySheet(message);
if (message.scheduleItemId != null && context.mounted) {
context.push('/calendar/events/${message.scheduleItemId}');
}
}
} else if (type == 'update') {
if (message.scheduleItemId != null) {
context.push('/calendar/events/${message.scheduleItemId}');
}
}
return;
case InboxMessageType.friendRequest:
if (item.friendRequest == null) {
Toast.show(context, '发送者信息加载失败,请下拉重试', type: ToastType.error);
return;
}
_showFriendRequestSheet(item, isReadOnly: message.isRead);
return;
```
**Step 7: 删除旧的 _FriendRequestSheet 类**
删除文件末尾的整个 `_FriendRequestSheet` 类(约605-749行)。
**Step 8: 运行 flutter analyze**
```bash
cd apps && flutter analyze lib/features/messages/ui/screens/message_invite_list_screen.dart
```
**Step 9: 提交**
```bash
git add apps/lib/features/messages/ && git commit -m "refactor(messages): use MessageActionSheet for all message types"
```
---
### Task 5: 删除日历消息卡片中的旧弹窗代码
**Files:**
- Modify: `apps/lib/features/messages/ui/widgets/calendar_message_card.dart`
**Step 1: 修改 CalendarInviteCard**
CalendarInviteCard 是用于列表展示的卡片,不需要显示弹窗。检查是否有不必要的硬编码,如果有则清理。
**Step 2: 运行 flutter analyze**
```bash
cd apps && flutter analyze lib/features/messages/ui/widgets/calendar_message_card.dart
```
**Step 3: 提交**
```bash
git add apps/lib/features/calendar_message_card.dart && git commit/messages/ui/widgets -f "chore(messages): clean up calendar message card"
```
---
### Task 6: 验证和测试
**Step 1: 运行完整测试**
```bash
cd apps && flutter test test/features/messages/
cd backend && uv run pytest tests/unit/v1/users/ -v
```
**Step 2: 手动测试场景**
1. 用户 A 发送日历邀请给用户 B
2. 用户 B 打开未读消息,点击日历邀请
3. 弹窗显示:"XXX 邀请你加入 [日历标题]"
4. 点击接受/拒绝
5. 用户 B 打开已读消息,点击日历邀请
6. 弹窗显示状态标签
7. 好友请求未读/已读都使用相同弹窗组件
---
## Summary
| Task | Description |
|------|-------------|
| 1 | 后端添加用户信息查询接口 `/api/v1/users/{user_id}` |
| 2 | 前端添加 UsersApi.getById 方法 |
| 3 | 创建公共弹窗组件 MessageActionSheet |
| 4 | 重构消息列表页面,删除旧弹窗代码,统一使用 MessageActionSheet |
| 5 | 清理日历消息卡片旧代码 |
| 6 | 验证测试 |
**Plan complete and saved to `docs/plans/2026-03-11-calendar-invite-sheet.md`. Two execution options:**
1. **Subagent-Driven (this session)** - I dispatch fresh subagent per task, review between tasks, fast iteration
2. **Parallel Session (separate)** - Open new session with executing-plans, batch execution with checkpoints
Which approach?