feat(agent): 增强多模态链路与工具调用能力
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"mcp": {
|
||||
"supabase": {
|
||||
"type": "local",
|
||||
"enabled": true,
|
||||
"command": [
|
||||
"npx",
|
||||
"-y",
|
||||
"@aliyun-rds/supabase-mcp-server",
|
||||
"--supabase-url",
|
||||
"http://47.112.66.83",
|
||||
"--supabase-anon-key",
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJvbGUiOiJhbm9uIiwiaWF0IjoxNzczMDI3NDE5LCJleHAiOjEzMjgzNjY3NDE5fQ.NVXDla5_nYPdcJk_81fc3k1UrnNTrNne_trMqt6Hg4g",
|
||||
"--supabase-service-role-key",
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJvbGUiOiJzZXJ2aWNlX3JvbGUiLCJpYXQiOjE3NzMwMjc0MTksImV4cCI6MTMyODM2Njc0MTl9.RzQBia-3QcjupsHnqaxgDWB7wnY9R7Ms9R8pMokyvLY"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,10 +23,12 @@ class MessageRepository:
|
||||
role: AgentChatMessageRole,
|
||||
content: str,
|
||||
model_code: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cost: Decimal = Decimal("0"),
|
||||
latency_ms: int | None = None,
|
||||
) -> AgentChatMessage:
|
||||
message = AgentChatMessage(
|
||||
session_id=session_id,
|
||||
@@ -34,10 +36,12 @@ class MessageRepository:
|
||||
role=role,
|
||||
content=content,
|
||||
model_code=model_code,
|
||||
tool_name=tool_name,
|
||||
metadata_json=metadata,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
self._session.add(message)
|
||||
await self._session.flush()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID
|
||||
@@ -24,6 +25,7 @@ class SqlAlchemyEventStore:
|
||||
def __init__(self, *, session_factory: Any) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._message_buffers: dict[tuple[str, str], str] = {}
|
||||
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
event_type = str(event.get("type", "")).strip().upper()
|
||||
@@ -48,6 +50,10 @@ class SqlAlchemyEventStore:
|
||||
self._buffer_text_delta(session_key=session_key, event=event)
|
||||
return
|
||||
|
||||
if event_type == "TEXT_MESSAGE_START":
|
||||
self._buffer_text_context(session_key=session_key, event=event)
|
||||
return
|
||||
|
||||
if event_type == "RUN_STARTED":
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
@@ -72,7 +78,15 @@ class SqlAlchemyEventStore:
|
||||
)
|
||||
self._clear_session_buffers(session_key=session_key)
|
||||
elif event_type == "TEXT_MESSAGE_END":
|
||||
await self._persist_assistant_message(
|
||||
await self._persist_text_message(
|
||||
event=event,
|
||||
session_id=session_id,
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
)
|
||||
elif event_type == "TOOL_CALL_RESULT":
|
||||
await self._persist_tool_call_result(
|
||||
event=event,
|
||||
session_id=session_id,
|
||||
chat_session=chat_session,
|
||||
@@ -97,8 +111,28 @@ class SqlAlchemyEventStore:
|
||||
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
|
||||
for key in stale_keys:
|
||||
self._message_buffers.pop(key, None)
|
||||
stale_context_keys = [k for k in self._message_contexts if k[0] == session_key]
|
||||
for key in stale_context_keys:
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
async def _persist_assistant_message(
|
||||
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = event.get("messageId")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
key = (session_key, message_id)
|
||||
role = event.get("role")
|
||||
stage = event.get("stage")
|
||||
tool_name = event.get("toolName")
|
||||
context: dict[str, object] = {}
|
||||
if isinstance(role, str) and role:
|
||||
context["role"] = role
|
||||
if isinstance(stage, str) and stage:
|
||||
context["stage"] = stage
|
||||
if isinstance(tool_name, str) and tool_name:
|
||||
context["tool_name"] = tool_name
|
||||
self._message_contexts[key] = context
|
||||
|
||||
async def _persist_text_message(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
@@ -114,6 +148,8 @@ class SqlAlchemyEventStore:
|
||||
if not content:
|
||||
return
|
||||
|
||||
context = self._message_contexts.get(key, {})
|
||||
|
||||
input_tokens = self._to_int(event.get("inputTokens"))
|
||||
output_tokens = self._to_int(event.get("outputTokens"))
|
||||
token_delta = input_tokens + output_tokens
|
||||
@@ -127,6 +163,20 @@ class SqlAlchemyEventStore:
|
||||
metadata["run_id"] = run_id
|
||||
if latency_ms is not None:
|
||||
metadata["latency_ms"] = latency_ms
|
||||
stage = event.get("stage")
|
||||
if not isinstance(stage, str):
|
||||
stage = context.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
|
||||
role_value = context.get("role")
|
||||
if not isinstance(role_value, str):
|
||||
role_value = "assistant"
|
||||
role = self._resolve_role(role_value)
|
||||
tool_name = context.get("tool_name")
|
||||
tool_name_value = (
|
||||
tool_name if isinstance(tool_name, str) and tool_name else None
|
||||
)
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
@@ -137,13 +187,15 @@ class SqlAlchemyEventStore:
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
role=role,
|
||||
content=content,
|
||||
model_code=model_code if isinstance(model_code, str) else None,
|
||||
tool_name=tool_name_value,
|
||||
metadata=metadata,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -161,6 +213,74 @@ class SqlAlchemyEventStore:
|
||||
cost_delta=cost,
|
||||
)
|
||||
self._message_buffers.pop(key, None)
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
async def _persist_tool_call_result(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
session_id: UUID,
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
tool_name = event.get("toolName")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
return
|
||||
|
||||
payload = {
|
||||
"args": event.get("args"),
|
||||
"result": event.get("result"),
|
||||
"error": event.get("error"),
|
||||
"call_id": event.get("callId"),
|
||||
}
|
||||
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||||
metadata: dict[str, object] = {"tool_name": tool_name}
|
||||
run_id = event.get("runId")
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
stage = event.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
task_id = event.get("taskId")
|
||||
if isinstance(task_id, str) and task_id:
|
||||
metadata["task_id"] = task_id
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
)
|
||||
if locked_session is None:
|
||||
return
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
content=content,
|
||||
tool_name=tool_name,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
status = (
|
||||
current_status
|
||||
if isinstance(current_status, AgentChatSessionStatus)
|
||||
else AgentChatSessionStatus.RUNNING
|
||||
)
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
chat_session=chat_session,
|
||||
status=status,
|
||||
message_delta=1,
|
||||
)
|
||||
|
||||
def _resolve_role(self, value: str) -> AgentChatMessageRole:
|
||||
normalized = value.strip().lower()
|
||||
if normalized == AgentChatMessageRole.SYSTEM.value:
|
||||
return AgentChatMessageRole.SYSTEM
|
||||
if normalized == AgentChatMessageRole.TOOL.value:
|
||||
return AgentChatMessageRole.TOOL
|
||||
return AgentChatMessageRole.ASSISTANT
|
||||
|
||||
async def _update_session_state(
|
||||
self,
|
||||
|
||||
@@ -38,7 +38,38 @@ def _schema_json(model: type[Any]) -> str:
|
||||
)
|
||||
|
||||
|
||||
def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
|
||||
def build_intent_user_prompt(
|
||||
*, user_input: str | list[dict[str, Any]]
|
||||
) -> str | list[dict[str, Any]]:
|
||||
if isinstance(user_input, list):
|
||||
instruction_text = "\n\n".join(
|
||||
[
|
||||
INTENT_TASK_INSTRUCTION,
|
||||
"[Output Schema]",
|
||||
_schema_json(IntentOutput),
|
||||
"[User Input]",
|
||||
"Use the following multimodal blocks as the latest user input.",
|
||||
]
|
||||
)
|
||||
blocks = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": instruction_text,
|
||||
}
|
||||
]
|
||||
user_blocks = _latest_user_content_blocks(user_input)
|
||||
if not user_blocks:
|
||||
user_blocks = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
user_input, ensure_ascii=True, separators=(",", ":")
|
||||
),
|
||||
}
|
||||
]
|
||||
blocks.extend(user_blocks)
|
||||
return blocks
|
||||
|
||||
normalized_input = (
|
||||
user_input
|
||||
if isinstance(user_input, str)
|
||||
@@ -55,6 +86,101 @@ def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _latest_user_content_blocks(
|
||||
user_input: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
for message in reversed(user_input):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
if message.get("role") != "user":
|
||||
continue
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
text = content.strip()
|
||||
return [{"type": "text", "text": text}] if text else []
|
||||
if not isinstance(content, list):
|
||||
return []
|
||||
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
text = item.get("text")
|
||||
if isinstance(text, str) and text.strip():
|
||||
blocks.append({"type": "text", "text": text})
|
||||
continue
|
||||
|
||||
if item_type == "binary":
|
||||
source_block = _binary_source_block(item)
|
||||
if source_block is not None:
|
||||
blocks.append(source_block)
|
||||
continue
|
||||
|
||||
if item_type == "image":
|
||||
source_block = _image_source_block(item)
|
||||
if source_block is not None:
|
||||
blocks.append(source_block)
|
||||
|
||||
return blocks
|
||||
return []
|
||||
|
||||
|
||||
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
|
||||
mime_type = item.get("mimeType")
|
||||
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
|
||||
if not media_type.startswith("image/"):
|
||||
return None
|
||||
|
||||
source_url = item.get("url")
|
||||
if isinstance(source_url, str) and source_url:
|
||||
return {"type": "image", "source": {"type": "url", "url": source_url}}
|
||||
|
||||
source_data = item.get("data")
|
||||
if isinstance(source_data, str) and source_data:
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": source_data,
|
||||
},
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _image_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
|
||||
source = item.get("source")
|
||||
if not isinstance(source, dict):
|
||||
return None
|
||||
|
||||
source_type = source.get("type")
|
||||
if source_type == "url":
|
||||
source_url = source.get("value") or source.get("url")
|
||||
if isinstance(source_url, str) and source_url:
|
||||
return {"type": "image", "source": {"type": "url", "url": source_url}}
|
||||
|
||||
if source_type in {"data", "base64"}:
|
||||
source_data = source.get("value") or source.get("data")
|
||||
if isinstance(source_data, str) and source_data:
|
||||
mime_type = source.get("mimeType") or source.get("media_type")
|
||||
media_type = (
|
||||
mime_type if isinstance(mime_type, str) and mime_type else "image/png"
|
||||
)
|
||||
if not media_type.startswith("image/"):
|
||||
return None
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": source_data,
|
||||
},
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def build_execution_user_prompt(
|
||||
*,
|
||||
task_id: str,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
@@ -153,6 +154,44 @@ class AgentRouteRuntime:
|
||||
},
|
||||
)
|
||||
|
||||
await self._emit_stage_text(
|
||||
thread_id=command.thread_id,
|
||||
run_id=command.run_id,
|
||||
stage_name="intent",
|
||||
message_id=f"intent-{command.run_id}",
|
||||
text=_intent_text_payload(result.intent),
|
||||
response_metadata=result.intent.response_metadata,
|
||||
)
|
||||
|
||||
if result.intent.route == "DIRECT_RESPONSE" and result.execution is None:
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "run.finished",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {},
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
if result.execution is not None:
|
||||
for index, task in enumerate(result.execution.task_results, start=1):
|
||||
await self._emit_stage_text(
|
||||
thread_id=command.thread_id,
|
||||
run_id=command.run_id,
|
||||
stage_name="execution",
|
||||
message_id=f"execution-{command.run_id}-{index}",
|
||||
text=task.execution_summary,
|
||||
response_metadata=task.response_metadata,
|
||||
)
|
||||
await self._emit_tool_result_events(
|
||||
thread_id=command.thread_id,
|
||||
run_id=command.run_id,
|
||||
task_id=task.task_id,
|
||||
tool_calls=_task_tool_calls(task),
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
@@ -164,35 +203,18 @@ class AgentRouteRuntime:
|
||||
)
|
||||
|
||||
report_message_id = f"assistant-{command.run_id}"
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "text.start",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"messageId": report_message_id, "role": "assistant"},
|
||||
},
|
||||
response_metadata = (
|
||||
result.report.response_metadata
|
||||
if isinstance(result.report.response_metadata, dict)
|
||||
else {}
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "text.delta",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {
|
||||
"messageId": report_message_id,
|
||||
"delta": result.report.assistant_text,
|
||||
},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "text.end",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"messageId": report_message_id},
|
||||
},
|
||||
await self._emit_stage_text(
|
||||
thread_id=command.thread_id,
|
||||
run_id=command.run_id,
|
||||
stage_name="report",
|
||||
message_id=report_message_id,
|
||||
text=result.report.assistant_text,
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
@@ -213,3 +235,178 @@ class AgentRouteRuntime:
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
async def _emit_stage_text(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
stage_name: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
response_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "text.start",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": stage_name,
|
||||
},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "text.delta",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {"messageId": message_id, "delta": text},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "text.end",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"messageId": message_id,
|
||||
"stage": stage_name,
|
||||
**_text_end_telemetry_payload(response_metadata),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
async def _emit_tool_result_events(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
task_id: str,
|
||||
tool_calls: list[dict[str, Any]],
|
||||
) -> None:
|
||||
for index, tool_call in enumerate(tool_calls, start=1):
|
||||
tool_name = tool_call.get("tool_name")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
continue
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "tool.result",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"callId": f"{run_id}-{task_id}-{index}",
|
||||
"stage": "execution",
|
||||
"taskId": task_id,
|
||||
"toolName": tool_name,
|
||||
"args": tool_call.get("args", {}),
|
||||
"result": tool_call.get("result"),
|
||||
"error": tool_call.get("error"),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {}
|
||||
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
|
||||
if model is not None:
|
||||
payload["model"] = model
|
||||
|
||||
input_tokens = _first_number(metadata, keys=("inputTokens", "input_tokens"))
|
||||
if input_tokens is not None:
|
||||
payload["inputTokens"] = input_tokens
|
||||
|
||||
output_tokens = _first_number(metadata, keys=("outputTokens", "output_tokens"))
|
||||
if output_tokens is not None:
|
||||
payload["outputTokens"] = output_tokens
|
||||
|
||||
latency_ms = _first_number(metadata, keys=("latencyMs", "latency_ms"))
|
||||
if latency_ms is not None:
|
||||
payload["latencyMs"] = latency_ms
|
||||
|
||||
cost = _first_number(metadata, keys=("cost", "total_cost"), allow_float=True)
|
||||
if cost is not None:
|
||||
payload["cost"] = cost
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _intent_text_payload(intent: Any) -> str:
|
||||
direct_response = getattr(intent, "direct_response", None)
|
||||
if isinstance(direct_response, str) and direct_response.strip():
|
||||
return direct_response
|
||||
return json.dumps(intent.model_dump(mode="json"), ensure_ascii=False)
|
||||
|
||||
|
||||
def _task_tool_calls(task: Any) -> list[dict[str, Any]]:
|
||||
normalized: list[dict[str, Any]] = []
|
||||
|
||||
tool_calls = getattr(task, "tool_calls", None)
|
||||
if isinstance(tool_calls, list):
|
||||
for item in tool_calls:
|
||||
if hasattr(item, "model_dump"):
|
||||
dumped = item.model_dump(mode="json")
|
||||
if isinstance(dumped, dict):
|
||||
normalized.append(dumped)
|
||||
elif isinstance(item, dict):
|
||||
normalized.append(item)
|
||||
|
||||
if normalized:
|
||||
return normalized
|
||||
|
||||
execution_data = getattr(task, "execution_data", None)
|
||||
if not isinstance(execution_data, dict):
|
||||
return []
|
||||
fallback_calls = execution_data.get("tool_calls")
|
||||
if not isinstance(fallback_calls, list):
|
||||
return []
|
||||
|
||||
for item in fallback_calls:
|
||||
if isinstance(item, dict):
|
||||
normalized.append(item)
|
||||
return normalized
|
||||
|
||||
|
||||
def _first_non_empty_str(
|
||||
metadata: dict[str, Any], *, keys: tuple[str, ...]
|
||||
) -> str | None:
|
||||
for key in keys:
|
||||
value = metadata.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _first_number(
|
||||
metadata: dict[str, Any],
|
||||
*,
|
||||
keys: tuple[str, ...],
|
||||
allow_float: bool = False,
|
||||
) -> int | float | None:
|
||||
for key in keys:
|
||||
value = metadata.get(key)
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int):
|
||||
if value < 0:
|
||||
continue
|
||||
return value
|
||||
if isinstance(value, float):
|
||||
if value < 0:
|
||||
continue
|
||||
return value if allow_float else int(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed = float(value) if allow_float else int(value)
|
||||
except ValueError:
|
||||
continue
|
||||
if parsed >= 0:
|
||||
return parsed
|
||||
return None
|
||||
|
||||
@@ -110,6 +110,19 @@ class AgentScopeRuntimeOrchestrator:
|
||||
)
|
||||
intent_output = IntentOutput.model_validate(intent_payload)
|
||||
|
||||
if intent_output.route == "DIRECT_RESPONSE":
|
||||
assistant_text = (
|
||||
intent_output.direct_response or intent_output.intent_summary
|
||||
)
|
||||
return RuntimeOutput(
|
||||
intent=intent_output,
|
||||
execution=None,
|
||||
report=ReportOutput(
|
||||
assistant_text=assistant_text,
|
||||
response_metadata=dict(intent_output.response_metadata),
|
||||
),
|
||||
)
|
||||
|
||||
execution_output: ExecutionBatchOutput | None = None
|
||||
if intent_output.route == "TASK_EXECUTION":
|
||||
execution_toolkit = build_stage_toolkit(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from time import perf_counter
|
||||
from typing import Any, cast
|
||||
|
||||
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
||||
@@ -14,7 +16,8 @@ def _to_litellm_model(*, provider_name: str, model_code: str) -> str:
|
||||
normalized_model = model_code.strip()
|
||||
if "/" in normalized_model:
|
||||
return normalized_model
|
||||
return f"{provider_name.strip().lower()}/{normalized_model}"
|
||||
del provider_name
|
||||
return normalized_model
|
||||
|
||||
|
||||
def _parse_json_text(raw_text: str) -> dict[str, Any]:
|
||||
@@ -30,6 +33,11 @@ def _parse_json_text(raw_text: str) -> dict[str, Any]:
|
||||
|
||||
|
||||
class AgentScopeReActRunner:
|
||||
def _build_litellm_service(self) -> Any:
|
||||
from services.litellm.service import LiteLLMService
|
||||
|
||||
return LiteLLMService()
|
||||
|
||||
def _build_model(self, *, stage_config: RuntimeStageConfig) -> Any:
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from agentscope.types import JSONSerializableObject
|
||||
@@ -61,9 +69,16 @@ class AgentScopeReActRunner:
|
||||
stage_config: RuntimeStageConfig,
|
||||
agent_name: str,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
user_prompt: str | list[dict[str, Any]],
|
||||
toolkit: Any | None,
|
||||
) -> dict[str, Any]:
|
||||
if stage_config.stage == "report" and toolkit is None:
|
||||
return await self._run_report_stage_direct(
|
||||
stage_config=stage_config,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
@@ -79,9 +94,19 @@ class AgentScopeReActRunner:
|
||||
max_iters=6,
|
||||
)
|
||||
try:
|
||||
response = await agent(Msg(name="user", content=user_prompt, role="user"))
|
||||
started_at = perf_counter()
|
||||
response = await agent(
|
||||
Msg(name="user", content=cast(Any, user_prompt), role="user")
|
||||
)
|
||||
latency_ms = int(round((perf_counter() - started_at) * 1000))
|
||||
text_content = response.get_text_content() or "{}"
|
||||
return _parse_json_text(text_content)
|
||||
payload = _parse_json_text(text_content)
|
||||
return _merge_stage_response_metadata(
|
||||
payload=payload,
|
||||
stage_config=stage_config,
|
||||
response=response,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.exception(
|
||||
"agentscope stage output is not valid json",
|
||||
@@ -96,3 +121,234 @@ class AgentScopeReActRunner:
|
||||
agent_name=agent_name,
|
||||
)
|
||||
raise RuntimeError("agent execution failed") from exc
|
||||
|
||||
async def _run_report_stage_direct(
|
||||
self,
|
||||
*,
|
||||
stage_config: RuntimeStageConfig,
|
||||
system_prompt: str,
|
||||
user_prompt: str | list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
service = self._build_litellm_service()
|
||||
started_at = perf_counter()
|
||||
response_with_cost = await asyncio.to_thread(
|
||||
service.run_completion_with_cost,
|
||||
model=_to_litellm_model(
|
||||
provider_name=stage_config.provider_name,
|
||||
model_code=stage_config.model_code,
|
||||
),
|
||||
messages=_report_messages(
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
),
|
||||
temperature=stage_config.llm_config.temperature,
|
||||
max_tokens=stage_config.llm_config.max_tokens,
|
||||
timeout=stage_config.llm_config.timeout_seconds,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
latency_ms = int(round((perf_counter() - started_at) * 1000))
|
||||
|
||||
text_content = _chat_response_text(response_with_cost.response)
|
||||
payload = _parse_json_text(text_content)
|
||||
return _merge_report_response_metadata(
|
||||
payload=payload,
|
||||
stage_config=stage_config,
|
||||
response_with_cost=response_with_cost,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.exception(
|
||||
"agentscope stage output is not valid json",
|
||||
stage=stage_config.stage,
|
||||
agent_name="report-agent",
|
||||
)
|
||||
raise RuntimeError("agent output format invalid") from exc
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"agentscope stage execution failed",
|
||||
stage=stage_config.stage,
|
||||
agent_name="report-agent",
|
||||
)
|
||||
raise RuntimeError("agent execution failed") from exc
|
||||
|
||||
|
||||
def _chat_response_text(response: Any) -> str:
|
||||
content = _read_value(response, "content")
|
||||
if isinstance(content, str) and content.strip():
|
||||
return content
|
||||
|
||||
if not isinstance(content, list):
|
||||
return _fallback_choice_content(response)
|
||||
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") != "text":
|
||||
continue
|
||||
text = block.get("text")
|
||||
if isinstance(text, str) and text:
|
||||
text_parts.append(text)
|
||||
if text_parts:
|
||||
return "".join(text_parts)
|
||||
return _fallback_choice_content(response)
|
||||
|
||||
|
||||
def _fallback_choice_content(response: Any) -> str:
|
||||
choices = _read_value(response, "choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
return "{}"
|
||||
|
||||
first_choice = choices[0]
|
||||
message = getattr(first_choice, "message", None)
|
||||
if message is None and isinstance(first_choice, dict):
|
||||
message = first_choice.get("message")
|
||||
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content")
|
||||
return content if isinstance(content, str) and content else "{}"
|
||||
|
||||
content = _read_value(message, "content")
|
||||
return content if isinstance(content, str) and content else "{}"
|
||||
|
||||
|
||||
def _read_value(source: Any, key: str) -> Any:
|
||||
if isinstance(source, dict):
|
||||
return source.get(key)
|
||||
return getattr(source, key, None)
|
||||
|
||||
|
||||
def _report_messages(
|
||||
*, system_prompt: str, user_prompt: str | list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
|
||||
def _merge_stage_response_metadata(
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
stage_config: RuntimeStageConfig,
|
||||
response: Any,
|
||||
latency_ms: int,
|
||||
) -> dict[str, Any]:
|
||||
result = dict(payload)
|
||||
existing = result.get("response_metadata")
|
||||
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
|
||||
metadata.setdefault("model", stage_config.model_code)
|
||||
|
||||
usage = _read_value(response, "usage")
|
||||
prompt_tokens = _to_non_negative_int(
|
||||
_read_value(usage, "prompt_tokens") or _read_value(usage, "input_tokens")
|
||||
)
|
||||
completion_tokens = _to_non_negative_int(
|
||||
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
|
||||
)
|
||||
cost = _to_non_negative_float(
|
||||
_read_value(usage, "cost")
|
||||
or _read_value(_read_value(usage, "metadata"), "cost")
|
||||
)
|
||||
resolved_model = _read_value(response, "model")
|
||||
if cost is None and prompt_tokens is not None and completion_tokens is not None:
|
||||
estimated_cost = _estimate_cost_by_pricing(
|
||||
model=resolved_model
|
||||
if isinstance(resolved_model, str)
|
||||
else stage_config.model_code,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
if estimated_cost is not None:
|
||||
cost = estimated_cost
|
||||
|
||||
if prompt_tokens is not None:
|
||||
metadata["inputTokens"] = prompt_tokens
|
||||
if completion_tokens is not None:
|
||||
metadata["outputTokens"] = completion_tokens
|
||||
if cost is not None:
|
||||
metadata["cost"] = cost
|
||||
if latency_ms >= 0:
|
||||
metadata["latencyMs"] = latency_ms
|
||||
|
||||
result["response_metadata"] = metadata
|
||||
return result
|
||||
|
||||
|
||||
def _merge_report_response_metadata(
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
stage_config: RuntimeStageConfig,
|
||||
response_with_cost: Any,
|
||||
latency_ms: int,
|
||||
) -> dict[str, Any]:
|
||||
result = dict(payload)
|
||||
existing = result.get("response_metadata")
|
||||
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
|
||||
usage = _read_value(response_with_cost, "usage")
|
||||
response = _read_value(response_with_cost, "response")
|
||||
|
||||
resolved_model = _read_value(response, "model")
|
||||
if isinstance(resolved_model, str) and resolved_model.strip():
|
||||
metadata["model"] = resolved_model.strip()
|
||||
else:
|
||||
metadata.setdefault("model", stage_config.model_code)
|
||||
|
||||
input_tokens = _to_non_negative_int(_read_value(usage, "prompt_tokens"))
|
||||
output_tokens = _to_non_negative_int(_read_value(usage, "completion_tokens"))
|
||||
cost = _to_non_negative_float(_read_value(usage, "cost"))
|
||||
if input_tokens is not None:
|
||||
metadata["inputTokens"] = input_tokens
|
||||
if output_tokens is not None:
|
||||
metadata["outputTokens"] = output_tokens
|
||||
if cost is not None:
|
||||
metadata["cost"] = cost
|
||||
if latency_ms >= 0:
|
||||
metadata["latencyMs"] = latency_ms
|
||||
|
||||
result["response_metadata"] = metadata
|
||||
return result
|
||||
|
||||
|
||||
def _to_non_negative_int(value: Any) -> int | None:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if not isinstance(value, (int, float, str)):
|
||||
return None
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return parsed if parsed >= 0 else None
|
||||
|
||||
|
||||
def _to_non_negative_float(value: Any) -> float | None:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if not isinstance(value, (int, float, str)):
|
||||
return None
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return parsed if parsed >= 0 else None
|
||||
|
||||
|
||||
def _estimate_cost_by_pricing(
|
||||
*, model: str, prompt_tokens: int, completion_tokens: int
|
||||
) -> float | None:
|
||||
normalized_model = model.strip()
|
||||
if not normalized_model:
|
||||
return None
|
||||
from services.litellm.service import LiteLLMService
|
||||
|
||||
service = LiteLLMService()
|
||||
try:
|
||||
return service.calculate_cost(
|
||||
model=normalized_model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agentscope.events import (
|
||||
AgentScopeAgUiCodec,
|
||||
AgentScopeEventPipeline,
|
||||
@@ -18,6 +21,7 @@ from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.tasks")
|
||||
@@ -76,6 +80,56 @@ def _extract_user_token(
|
||||
return None
|
||||
|
||||
|
||||
async def _build_recent_context_messages(
|
||||
*,
|
||||
session: Any,
|
||||
thread_id: str,
|
||||
current_run_id: str,
|
||||
max_messages: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
try:
|
||||
session_uuid = UUID(thread_id)
|
||||
except ValueError:
|
||||
return []
|
||||
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
start_of_today = utc_now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
start_of_yesterday = start_of_today - timedelta(days=1)
|
||||
|
||||
stmt = (
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.where(AgentChatMessage.created_at >= start_of_yesterday)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
metadata = row.metadata_json if isinstance(row.metadata_json, dict) else {}
|
||||
if metadata.get("run_id") == current_run_id:
|
||||
continue
|
||||
role = (
|
||||
row.role.value
|
||||
if isinstance(row.role, AgentChatMessageRole)
|
||||
else str(row.role)
|
||||
)
|
||||
if role not in {"user", "assistant"}:
|
||||
continue
|
||||
normalized.append(
|
||||
{
|
||||
"id": str(row.id),
|
||||
"role": role,
|
||||
"content": row.content,
|
||||
}
|
||||
)
|
||||
|
||||
if len(normalized) <= max_messages:
|
||||
return normalized
|
||||
return normalized[-max_messages:]
|
||||
|
||||
|
||||
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
command_type = str(command.get("command", "run")).strip().lower()
|
||||
raw_run_input = command.get("run_input")
|
||||
@@ -117,6 +171,21 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
if command_type == "run":
|
||||
context_messages = await _build_recent_context_messages(
|
||||
session=session,
|
||||
thread_id=parsed_run_input.thread_id,
|
||||
current_run_id=parsed_run_input.run_id,
|
||||
)
|
||||
parsed_run_input = parsed_run_input.model_copy(
|
||||
update={
|
||||
"messages": [
|
||||
*context_messages,
|
||||
*parsed_run_input.messages,
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
if command_type == "resume":
|
||||
await runtime.resume(
|
||||
command=parsed_run_input,
|
||||
|
||||
@@ -5,12 +5,21 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ExecutionToolCall(BaseModel):
|
||||
tool_name: str = Field(min_length=1)
|
||||
args: dict[str, Any] = Field(default_factory=dict)
|
||||
result: Any | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ExecutionTaskOutput(BaseModel):
|
||||
task_id: str = Field(min_length=1)
|
||||
status: Literal["SUCCESS", "PARTIAL", "FAILED"]
|
||||
execution_summary: str = Field(min_length=1)
|
||||
execution_data: dict[str, Any] = Field(default_factory=dict)
|
||||
user_feedback_needs: list[str] = Field(default_factory=list)
|
||||
response_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
tool_calls: list[ExecutionToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExecutionBatchOutput(BaseModel):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
@@ -17,6 +18,7 @@ class IntentOutput(BaseModel):
|
||||
direct_response: str | None = None
|
||||
tasks: list[IntentTask] = Field(default_factory=list)
|
||||
complexity: Literal["simple", "complex"]
|
||||
response_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_route(self) -> "IntentOutput":
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
|
||||
from core.agentscope.tools.custom.calendar import (
|
||||
calendar_read,
|
||||
calendar_write,
|
||||
user_resolve,
|
||||
)
|
||||
|
||||
__all__ = ["calendar_read", "calendar_write"]
|
||||
__all__ = ["calendar_read", "calendar_write", "user_resolve"]
|
||||
|
||||
@@ -7,6 +7,7 @@ from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
|
||||
from core.agentscope.tools.custom.calendar_backend_ops import (
|
||||
_execute_list_calendar_events,
|
||||
_execute_mutate_calendar_event,
|
||||
_execute_resolve_user_identity,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.agentscope.tools.response import build_tool_response
|
||||
@@ -150,6 +151,30 @@ async def calendar_write(
|
||||
bool,
|
||||
Field(description="Whether to use the replace strategy for conflicts."),
|
||||
] = False,
|
||||
invite_user_emails: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Optional invite targets by email."),
|
||||
] = None,
|
||||
invite_user_names: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Optional invite targets by username."),
|
||||
] = None,
|
||||
invite_user_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Optional invite targets by user ID (UUID string)."),
|
||||
] = None,
|
||||
invite_permission_view: Annotated[
|
||||
bool,
|
||||
Field(description="Invite permission: view."),
|
||||
] = True,
|
||||
invite_permission_edit: Annotated[
|
||||
bool,
|
||||
Field(description="Invite permission: edit."),
|
||||
] = False,
|
||||
invite_permission_invite: Annotated[
|
||||
bool,
|
||||
Field(description="Invite permission: invite others."),
|
||||
] = False,
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
user_token: str | None = None,
|
||||
@@ -240,6 +265,15 @@ async def calendar_write(
|
||||
tool_args["reminderMinutes"] = reminder_minutes
|
||||
if status is not None:
|
||||
tool_args["status"] = status
|
||||
if invite_user_emails is not None:
|
||||
tool_args["inviteUserEmails"] = invite_user_emails
|
||||
if invite_user_names is not None:
|
||||
tool_args["inviteUserNames"] = invite_user_names
|
||||
if invite_user_ids is not None:
|
||||
tool_args["inviteUserIds"] = invite_user_ids
|
||||
tool_args["invitePermissionView"] = invite_permission_view
|
||||
tool_args["invitePermissionEdit"] = invite_permission_edit
|
||||
tool_args["invitePermissionInvite"] = invite_permission_invite
|
||||
|
||||
result = await _execute_mutate_calendar_event(
|
||||
session=cast(Any, session),
|
||||
@@ -247,3 +281,34 @@ async def calendar_write(
|
||||
tool_args=tool_args,
|
||||
)
|
||||
return build_tool_response(result)
|
||||
|
||||
|
||||
async def user_resolve(
|
||||
user_email: Annotated[
|
||||
str | None,
|
||||
Field(description="User email to resolve user ID."),
|
||||
] = None,
|
||||
user_name: Annotated[
|
||||
str | None,
|
||||
Field(description="Username to resolve user ID."),
|
||||
] = None,
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
user_token: str | None = None,
|
||||
) -> Any:
|
||||
if session is None or owner_id is None:
|
||||
raise ValueError("user.resolve missing runtime preset arguments")
|
||||
if not isinstance(user_token, str) or not user_token.strip():
|
||||
return build_tool_response(_unauthorized_response())
|
||||
if not _verify_user_token(user_token=user_token, owner_id=cast(UUID, owner_id)):
|
||||
return build_tool_response(_unauthorized_response())
|
||||
|
||||
result = await _execute_resolve_user_identity(
|
||||
session=cast(Any, session),
|
||||
owner_id=cast(UUID, owner_id),
|
||||
tool_args={
|
||||
"userEmail": user_email,
|
||||
"userName": user_name,
|
||||
},
|
||||
)
|
||||
return build_tool_response(result)
|
||||
|
||||
@@ -4,13 +4,20 @@ import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from services.base.supabase import supabase_service
|
||||
from models.profile import Profile
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.inbox_messages.repository import SQLAlchemyInboxMessageRepository
|
||||
from v1.schedule_items.repository import SQLAlchemyScheduleItemRepository
|
||||
from v1.schedule_items.schemas import (
|
||||
ScheduleItemCreateRequest,
|
||||
ScheduleItemMetadata,
|
||||
ScheduleItemShareRequest,
|
||||
ScheduleItemStatus,
|
||||
ScheduleItemUpdateRequest,
|
||||
)
|
||||
@@ -72,9 +79,196 @@ def _service(session: AsyncSession, owner_id: UUID) -> ScheduleItemService:
|
||||
repository=SQLAlchemyScheduleItemRepository(session),
|
||||
session=session,
|
||||
current_user=CurrentUser(id=owner_id),
|
||||
inbox_repository=SQLAlchemyInboxMessageRepository(session),
|
||||
)
|
||||
|
||||
|
||||
def _parse_string_list(value: object, *, field_name: str) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"{field_name} must be a list of strings")
|
||||
parsed: list[str] = []
|
||||
for item in value:
|
||||
if not isinstance(item, str) or not item.strip():
|
||||
raise ValueError(f"{field_name} must be a list of non-empty strings")
|
||||
parsed.append(item.strip())
|
||||
return parsed
|
||||
|
||||
|
||||
def _list_auth_users() -> list[object]:
|
||||
admin_client = supabase_service.get_admin_client()
|
||||
users: list[object] = []
|
||||
page = 1
|
||||
while page <= 100:
|
||||
response = admin_client.auth.admin.list_users(page=page, per_page=100)
|
||||
batch = (
|
||||
list(response)
|
||||
if isinstance(response, list)
|
||||
else list(getattr(response, "users", []))
|
||||
)
|
||||
users.extend(batch)
|
||||
if len(batch) < 100:
|
||||
break
|
||||
page += 1
|
||||
return users
|
||||
|
||||
|
||||
async def _get_profile_username(*, session: AsyncSession, user_id: UUID) -> str | None:
|
||||
stmt = select(Profile.username).where(Profile.id == user_id)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def _get_profile_by_username(
|
||||
*, session: AsyncSession, username: str
|
||||
) -> Profile | None:
|
||||
stmt = (
|
||||
select(Profile)
|
||||
.where(Profile.username == username)
|
||||
.where(Profile.deleted_at.is_(None))
|
||||
)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
|
||||
def _find_auth_email_by_user_id(*, users: list[object], user_id: UUID) -> str | None:
|
||||
target = str(user_id)
|
||||
for user in users:
|
||||
if str(getattr(user, "id", "")) == target:
|
||||
email = getattr(user, "email", None)
|
||||
if isinstance(email, str) and email.strip():
|
||||
return email.strip()
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_identity(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
user_email: str | None,
|
||||
user_name: str | None,
|
||||
) -> dict[str, object]:
|
||||
email = user_email.strip().lower() if isinstance(user_email, str) else ""
|
||||
name = user_name.strip() if isinstance(user_name, str) else ""
|
||||
if bool(email) == bool(name):
|
||||
raise ValueError("provide exactly one of user_email or user_name")
|
||||
|
||||
if email:
|
||||
auth_gateway = SupabaseAuthGateway()
|
||||
user = await auth_gateway.get_user_by_email(email)
|
||||
user_id = UUID(user.id)
|
||||
username = await _get_profile_username(session=session, user_id=user_id)
|
||||
return {
|
||||
"userId": str(user_id),
|
||||
"email": user.email,
|
||||
"username": username,
|
||||
"matchedBy": "email",
|
||||
}
|
||||
|
||||
profile = await _get_profile_by_username(session=session, username=name)
|
||||
if profile is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
users = _list_auth_users()
|
||||
email_value = _find_auth_email_by_user_id(users=users, user_id=profile.id)
|
||||
return {
|
||||
"userId": str(profile.id),
|
||||
"email": email_value,
|
||||
"username": profile.username,
|
||||
"matchedBy": "username",
|
||||
}
|
||||
|
||||
|
||||
def _invite_permission(tool_args: dict[str, object]) -> dict[str, bool]:
|
||||
return {
|
||||
"permission_view": bool(tool_args.get("invitePermissionView", True)),
|
||||
"permission_edit": bool(tool_args.get("invitePermissionEdit", False)),
|
||||
"permission_invite": bool(tool_args.get("invitePermissionInvite", False)),
|
||||
}
|
||||
|
||||
|
||||
async def _share_event_with_invitees(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
event_id: UUID,
|
||||
tool_args: dict[str, object],
|
||||
) -> dict[str, object] | None:
|
||||
email_targets = _parse_string_list(
|
||||
tool_args.get("inviteUserEmails"),
|
||||
field_name="inviteUserEmails",
|
||||
)
|
||||
name_targets = _parse_string_list(
|
||||
tool_args.get("inviteUserNames"),
|
||||
field_name="inviteUserNames",
|
||||
)
|
||||
id_targets = _parse_string_list(
|
||||
tool_args.get("inviteUserIds"),
|
||||
field_name="inviteUserIds",
|
||||
)
|
||||
if not email_targets and not name_targets and not id_targets:
|
||||
return None
|
||||
|
||||
users = _list_auth_users() if id_targets else []
|
||||
emails = {item.lower() for item in email_targets}
|
||||
for user_id_raw in id_targets:
|
||||
try:
|
||||
user_id = UUID(user_id_raw)
|
||||
except ValueError as exc:
|
||||
raise ValueError("inviteUserIds must contain valid UUID strings") from exc
|
||||
resolved_email = _find_auth_email_by_user_id(users=users, user_id=user_id)
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail="Invite user email not found")
|
||||
emails.add(resolved_email.lower())
|
||||
for username in name_targets:
|
||||
resolved = await _resolve_identity(
|
||||
session=session,
|
||||
user_email=None,
|
||||
user_name=username,
|
||||
)
|
||||
resolved_email = resolved.get("email")
|
||||
if not isinstance(resolved_email, str) or not resolved_email:
|
||||
raise HTTPException(status_code=404, detail="Invite user email not found")
|
||||
emails.add(resolved_email.lower())
|
||||
|
||||
service = _service(session, owner_id)
|
||||
permission = _invite_permission(tool_args)
|
||||
invited: list[str] = []
|
||||
for email in sorted(emails):
|
||||
request = ScheduleItemShareRequest(email=email, **permission)
|
||||
await service.share(event_id, request)
|
||||
invited.append(email)
|
||||
return {
|
||||
"count": len(invited),
|
||||
"emails": invited,
|
||||
"permission": permission,
|
||||
}
|
||||
|
||||
|
||||
async def _execute_resolve_user_identity(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
tool_args: dict[str, object],
|
||||
) -> dict[str, object]:
|
||||
del owner_id
|
||||
user_email_raw = tool_args.get("userEmail")
|
||||
user_name_raw = tool_args.get("userName")
|
||||
user_email = user_email_raw if isinstance(user_email_raw, str) else None
|
||||
user_name = user_name_raw if isinstance(user_name_raw, str) else None
|
||||
resolved = await _resolve_identity(
|
||||
session=session,
|
||||
user_email=user_email,
|
||||
user_name=user_name,
|
||||
)
|
||||
return {
|
||||
"type": "user_lookup.v1",
|
||||
"version": "v1",
|
||||
"data": {
|
||||
"ok": True,
|
||||
**resolved,
|
||||
},
|
||||
"actions": [],
|
||||
}
|
||||
|
||||
|
||||
def _resolve_metadata(tool_args: dict[str, object]) -> ScheduleItemMetadata:
|
||||
location = tool_args.get("location")
|
||||
location_value = location.strip() if isinstance(location, str) else None
|
||||
@@ -185,6 +379,12 @@ async def _execute_create(
|
||||
)
|
||||
event_data = _event_payload(created)
|
||||
event_id = str(event_data["id"])
|
||||
invite_result = await _share_event_with_invitees(
|
||||
session=service._session,
|
||||
owner_id=service.require_user_id(),
|
||||
event_id=UUID(event_id),
|
||||
tool_args=tool_args,
|
||||
)
|
||||
return {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
@@ -193,12 +393,13 @@ async def _execute_create(
|
||||
"sourceType": "agent_generated",
|
||||
"ok": True,
|
||||
"message": "日程已创建",
|
||||
"inviteResult": invite_result,
|
||||
},
|
||||
"actions": [
|
||||
{
|
||||
"type": "link",
|
||||
"label": "查看详情",
|
||||
"target": f"/calendar/events/{event_id}",
|
||||
"target": f"/schedule-items/{event_id}",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -274,6 +475,12 @@ async def _execute_update(
|
||||
ScheduleItemUpdateRequest.model_validate(update_data),
|
||||
)
|
||||
event_data = _event_payload(updated)
|
||||
invite_result = await _share_event_with_invitees(
|
||||
session=service._session,
|
||||
owner_id=service.require_user_id(),
|
||||
event_id=UUID(str(event_data["id"])),
|
||||
tool_args=tool_args,
|
||||
)
|
||||
return {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
@@ -282,12 +489,13 @@ async def _execute_update(
|
||||
"sourceType": "agent_generated",
|
||||
"ok": True,
|
||||
"message": "日程已更新",
|
||||
"inviteResult": invite_result,
|
||||
},
|
||||
"actions": [
|
||||
{
|
||||
"type": "link",
|
||||
"label": "查看详情",
|
||||
"target": f"/calendar/events/{event_data['id']}",
|
||||
"target": f"/schedule-items/{event_data['id']}",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
@@ -4,8 +4,9 @@ from dataclasses import dataclass
|
||||
|
||||
|
||||
TOOL_APPROVAL_REQUIRED: dict[str, bool] = {
|
||||
"calendar.read": False,
|
||||
"calendar.write": False,
|
||||
"calendar_read": False,
|
||||
"calendar_write": False,
|
||||
"user_resolve": False,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,11 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
|
||||
from core.agentscope.tools.custom.calendar import (
|
||||
calendar_read,
|
||||
calendar_write,
|
||||
user_resolve,
|
||||
)
|
||||
from core.agentscope.tools.hitl_middleware import register_tool_middlewares
|
||||
from core.agentscope.tools.tool_meta import TOOL_META
|
||||
|
||||
@@ -25,10 +29,12 @@ class ToolGroup:
|
||||
|
||||
|
||||
TOOL_GROUPS: dict[str, ToolGroup] = {
|
||||
"intent": ToolGroup(stage="intent", tool_names=frozenset({"calendar.read"})),
|
||||
"intent": ToolGroup(
|
||||
stage="intent", tool_names=frozenset({"calendar_read", "user_resolve"})
|
||||
),
|
||||
"execution": ToolGroup(
|
||||
stage="execution",
|
||||
tool_names=frozenset({"calendar.read", "calendar.write"}),
|
||||
tool_names=frozenset({"calendar_read", "calendar_write", "user_resolve"}),
|
||||
),
|
||||
"report": ToolGroup(stage="report", tool_names=frozenset()),
|
||||
}
|
||||
@@ -49,7 +55,7 @@ def _load_custom_tool_bindings(
|
||||
) -> list[CustomToolBinding]:
|
||||
return [
|
||||
CustomToolBinding(
|
||||
name="calendar.read",
|
||||
name="calendar_read",
|
||||
func=calendar_read,
|
||||
preset_kwargs={
|
||||
"session": session,
|
||||
@@ -58,7 +64,7 @@ def _load_custom_tool_bindings(
|
||||
},
|
||||
),
|
||||
CustomToolBinding(
|
||||
name="calendar.write",
|
||||
name="calendar_write",
|
||||
func=calendar_write,
|
||||
preset_kwargs={
|
||||
"session": session,
|
||||
@@ -66,6 +72,15 @@ def _load_custom_tool_bindings(
|
||||
"user_token": user_token or "",
|
||||
},
|
||||
),
|
||||
CustomToolBinding(
|
||||
name="user_resolve",
|
||||
func=user_resolve,
|
||||
preset_kwargs={
|
||||
"session": session,
|
||||
"owner_id": owner_id,
|
||||
"user_token": user_token or "",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -126,21 +126,26 @@ class LiteLLMService:
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
completion_fn: Callable[..., dict[str, Any]] | None = None,
|
||||
) -> LiteLLMResponseWithCost:
|
||||
caller = completion_fn or completion
|
||||
request_model = model if model.startswith("openai/") else f"openai/{model}"
|
||||
|
||||
response_any = caller(
|
||||
model=request_model,
|
||||
api_key=self.proxy_api_key,
|
||||
api_base=self.proxy_base_url,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
timeout=timeout,
|
||||
stream=False,
|
||||
)
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"model": request_model,
|
||||
"api_key": self.proxy_api_key,
|
||||
"api_base": self.proxy_base_url,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"timeout": timeout,
|
||||
"stream": False,
|
||||
}
|
||||
if response_format is not None:
|
||||
request_kwargs["response_format"] = response_format
|
||||
|
||||
response_any = caller(**request_kwargs)
|
||||
response = self._normalize_response(response_any)
|
||||
|
||||
usage_raw = response.get("usage")
|
||||
|
||||
@@ -107,6 +107,10 @@ class AgentRepository:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
next_seq = int(session_row.message_count or 0) + 1
|
||||
if not _has_title(session_row.title):
|
||||
session_title = _derive_session_title(content_text)
|
||||
if session_title is not None:
|
||||
session_row.title = session_title
|
||||
payload_metadata = dict(metadata or {})
|
||||
payload_metadata["run_id"] = run_id
|
||||
message = AgentChatMessage(
|
||||
@@ -264,3 +268,14 @@ class AgentRepository:
|
||||
if rendered:
|
||||
payload["attachments"] = rendered
|
||||
return payload
|
||||
|
||||
|
||||
def _has_title(title: object) -> bool:
|
||||
return isinstance(title, str) and bool(title.strip())
|
||||
|
||||
|
||||
def _derive_session_title(content_text: str) -> str | None:
|
||||
normalized = " ".join(content_text.split())
|
||||
if not normalized:
|
||||
return None
|
||||
return normalized[:80]
|
||||
|
||||
@@ -203,6 +203,11 @@ async def stream_events(
|
||||
user_id=str(current_user.id),
|
||||
reason=str(exc),
|
||||
)
|
||||
if "Timeout reading from" in str(exc):
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
break
|
||||
|
||||
if not rows:
|
||||
|
||||
@@ -212,12 +212,19 @@ class AgentService:
|
||||
content_type=mime_type,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
bucket_name = "private"
|
||||
stored_path = await self._attachment_storage.upload_bytes(
|
||||
bucket=bucket_name,
|
||||
path=path,
|
||||
content=payload,
|
||||
content_type=mime_type,
|
||||
logger.exception(
|
||||
"Attachment upload failed",
|
||||
extra={
|
||||
"bucket": bucket_name,
|
||||
"path": path,
|
||||
"mime_type": mime_type,
|
||||
"thread_id": run_input.thread_id,
|
||||
"run_id": run_input.run_id,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to upload attachment",
|
||||
)
|
||||
attachments.append(
|
||||
{
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import httpx
|
||||
@@ -12,6 +14,9 @@ from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
|
||||
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
|
||||
FIXTURE_IMAGE_PATH = (
|
||||
Path(__file__).resolve().parents[3] / "fixtures" / "images" / "calendar_text_cn.png"
|
||||
)
|
||||
|
||||
|
||||
async def _live_access_token(client: httpx.AsyncClient) -> str:
|
||||
@@ -108,6 +113,8 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
|
||||
pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
|
||||
|
||||
image_data = base64.b64encode(FIXTURE_IMAGE_PATH.read_bytes()).decode("ascii")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
token = await _live_access_token(client)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
@@ -128,7 +135,7 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
{"type": "text", "text": "请描述图片里的内容"},
|
||||
{
|
||||
"type": "binary",
|
||||
"data": "aGVsbG8=",
|
||||
"data": image_data,
|
||||
"mimeType": "image/png",
|
||||
},
|
||||
],
|
||||
@@ -142,19 +149,20 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
||||
sse_resp = await client.get(
|
||||
events_url,
|
||||
headers=headers,
|
||||
params={"idle_limit": 150},
|
||||
timeout=60.0,
|
||||
)
|
||||
assert sse_resp.status_code == 200
|
||||
assert sse_resp.headers.get("content-type", "").startswith("text/event-stream")
|
||||
event_names = [
|
||||
line.split(":", 1)[1].strip()
|
||||
for line in sse_resp.text.splitlines()
|
||||
if line.startswith("event:")
|
||||
]
|
||||
event_names: list[str] = []
|
||||
async with client.stream(
|
||||
"GET", events_url, headers=headers, timeout=90.0
|
||||
) as sse_resp:
|
||||
assert sse_resp.status_code == 200
|
||||
assert sse_resp.headers.get("content-type", "").startswith(
|
||||
"text/event-stream"
|
||||
)
|
||||
async for line in sse_resp.aiter_lines():
|
||||
if line.startswith("event:"):
|
||||
event_name = line.split(":", 1)[1].strip()
|
||||
event_names.append(event_name)
|
||||
if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
|
||||
break
|
||||
|
||||
assert "RUN_STARTED" in event_names
|
||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||
@@ -194,7 +202,14 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
)
|
||||
all_messages = list(rows.scalars().all())
|
||||
assert all_messages
|
||||
user_rows = [row for row in all_messages if str(row.role) == "user"]
|
||||
user_rows = [
|
||||
row
|
||||
for row in all_messages
|
||||
if (
|
||||
getattr(row.role, "value", row.role) == "user"
|
||||
or str(getattr(row.role, "value", row.role)) == "user"
|
||||
)
|
||||
]
|
||||
assert user_rows
|
||||
metadata = user_rows[0].metadata_json or {}
|
||||
attachments = metadata.get("attachments")
|
||||
|
||||
@@ -99,6 +99,16 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_START",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"role": "assistant",
|
||||
"stage": "report",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
@@ -128,6 +138,8 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
assert append_kwargs["output_tokens"] == 5
|
||||
assert append_kwargs["cost"] == Decimal("0.123")
|
||||
assert append_kwargs["metadata"]["latency_ms"] == 250
|
||||
assert append_kwargs["metadata"]["stage"] == "report"
|
||||
assert append_kwargs["latency_ms"] == 250
|
||||
assert captured["message_delta"] == 1
|
||||
assert captured["token_delta"] == 8
|
||||
assert captured["cost_delta"] == Decimal("0.123")
|
||||
@@ -255,6 +267,60 @@ async def test_store_clears_buffer_on_run_finished(
|
||||
assert "append_kwargs" not in captured
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_tool_call_result_as_tool_message(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TOOL_CALL_RESULT",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"toolName": "calendar_write",
|
||||
"taskId": "t1",
|
||||
"stage": "execution",
|
||||
"args": {"title": "A"},
|
||||
"result": {"event_id": "evt-1"},
|
||||
}
|
||||
)
|
||||
|
||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||
assert getattr(append_kwargs["role"], "value", None) == "tool"
|
||||
assert append_kwargs["tool_name"] == "calendar_write"
|
||||
assert append_kwargs["metadata"]["task_id"] == "t1"
|
||||
assert captured["message_delta"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_drops_buffer_when_session_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -13,8 +13,9 @@ from core.agentscope.schemas.user_context import (
|
||||
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
|
||||
from core.agentscope.schemas import ReportOutput, RuntimeOutput
|
||||
from core.agentscope.schemas.agent_runtime import RunCommand
|
||||
from core.agentscope.schemas.execution import ExecutionBatchOutput
|
||||
from core.agentscope.schemas.intent import IntentOutput
|
||||
from core.agentscope.schemas.execution import ExecutionBatchOutput, ExecutionTaskOutput
|
||||
from core.agentscope.schemas.execution import ExecutionToolCall
|
||||
from core.agentscope.schemas.intent import IntentOutput, IntentTask
|
||||
|
||||
|
||||
def _user_context() -> UserAgentContext:
|
||||
@@ -50,20 +51,43 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="DIRECT_RESPONSE",
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response="done",
|
||||
tasks=[],
|
||||
complexity="simple",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={"latencyMs": 120},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[],
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={"latencyMs": 300},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result={"event_id": "evt-1"},
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
response_metadata={
|
||||
"model": "qwen3.5-flash",
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 5,
|
||||
"cost": 0.123,
|
||||
"latencyMs": 250,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@@ -86,6 +110,13 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
"step.finish",
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"tool.result",
|
||||
"step.start",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
@@ -97,11 +128,19 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
assert calls[2]["data"]["stepName"] == "intent"
|
||||
assert calls[3]["data"]["stepName"] == "execution"
|
||||
assert calls[4]["data"]["stepName"] == "execution"
|
||||
assert calls[5]["data"]["stepName"] == "report"
|
||||
assert calls[7]["data"]["delta"] == "hello world"
|
||||
assert calls[6]["data"]["messageId"] == calls[7]["data"]["messageId"]
|
||||
assert calls[7]["data"]["messageId"] == calls[8]["data"]["messageId"]
|
||||
assert calls[9]["data"]["stepName"] == "report"
|
||||
assert calls[5]["data"]["stage"] == "intent"
|
||||
assert calls[8]["data"]["stage"] == "execution"
|
||||
assert calls[11]["data"]["toolName"] == "calendar_write"
|
||||
assert calls[12]["data"]["stepName"] == "report"
|
||||
assert calls[14]["data"]["delta"] == "hello world"
|
||||
assert calls[13]["data"]["messageId"] == calls[14]["data"]["messageId"]
|
||||
assert calls[14]["data"]["messageId"] == calls[15]["data"]["messageId"]
|
||||
assert calls[15]["data"]["model"] == "qwen3.5-flash"
|
||||
assert calls[15]["data"]["inputTokens"] == 10
|
||||
assert calls[15]["data"]["outputTokens"] == 5
|
||||
assert calls[15]["data"]["cost"] == 0.123
|
||||
assert calls[15]["data"]["latencyMs"] == 250
|
||||
assert calls[16]["data"]["stepName"] == "report"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -140,3 +179,129 @@ async def test_runtime_emits_run_error_when_orchestrator_fails() -> None:
|
||||
]
|
||||
assert calls[1]["data"]["stepName"] == "intent"
|
||||
assert calls[2]["data"]["message"] == "runtime execution failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_passes_binary_payload_to_orchestrator() -> None:
|
||||
captured_user_input: object | None = None
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
return str(event.get("type", ""))
|
||||
|
||||
class _CaptureOrchestrator:
|
||||
async def run(self, **kwargs: object) -> RuntimeOutput:
|
||||
nonlocal captured_user_input
|
||||
captured_user_input = kwargs.get("user_input")
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="DIRECT_RESPONSE",
|
||||
intent_summary="summary",
|
||||
direct_response="done",
|
||||
tasks=[],
|
||||
complexity="simple",
|
||||
),
|
||||
execution=None,
|
||||
report=ReportOutput(
|
||||
assistant_text="ok",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_CaptureOrchestrator(),
|
||||
pipeline=_FakePipeline(),
|
||||
)
|
||||
command = RunCommand.model_validate(
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"data": "aGVsbG8=",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert isinstance(captured_user_input, list)
|
||||
first = captured_user_input[0]
|
||||
assert isinstance(first, dict)
|
||||
content = first.get("content")
|
||||
assert isinstance(content, list)
|
||||
binary = content[1]
|
||||
assert isinstance(binary, dict)
|
||||
assert binary.get("data") == "aGVsbG8="
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_direct_response_finishes_without_report_stage() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _DirectOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="DIRECT_RESPONSE",
|
||||
intent_summary="summary",
|
||||
direct_response="direct-answer",
|
||||
tasks=[],
|
||||
complexity="simple",
|
||||
response_metadata={"latencyMs": 88},
|
||||
),
|
||||
execution=None,
|
||||
report=ReportOutput(
|
||||
assistant_text="direct-answer",
|
||||
response_metadata={"latencyMs": 88},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_DirectOrchestrator(),
|
||||
pipeline=_FakePipeline(),
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert [item["type"] for item in calls] == [
|
||||
"run.started",
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"run.finished",
|
||||
]
|
||||
assert calls[3]["data"]["stage"] == "intent"
|
||||
assert calls[4]["data"]["delta"] == "direct-answer"
|
||||
|
||||
@@ -68,6 +68,7 @@ class _FakeRunner:
|
||||
"direct_response": "你好",
|
||||
"tasks": [],
|
||||
"complexity": "simple",
|
||||
"response_metadata": {"model": "qwen3.5-flash", "latencyMs": 100},
|
||||
}
|
||||
self.report_calls += 1
|
||||
return {
|
||||
@@ -131,7 +132,7 @@ async def test_runtime_direct_response_skips_execution(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calendar.read",
|
||||
"name": "calendar_read",
|
||||
"description": "read",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
@@ -162,8 +163,10 @@ async def test_runtime_direct_response_skips_execution(
|
||||
|
||||
assert result.intent.route == "DIRECT_RESPONSE"
|
||||
assert result.execution is None
|
||||
assert result.report.assistant_text == "已完成"
|
||||
assert result.report.assistant_text == "你好"
|
||||
assert result.report.response_metadata["model"] == "qwen3.5-flash"
|
||||
assert fake_runner.execution_calls == 0
|
||||
assert fake_runner.report_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -183,7 +186,7 @@ async def test_runtime_complex_route_runs_execution(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calendar.read",
|
||||
"name": "calendar_read",
|
||||
"description": "read",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
@@ -191,7 +194,7 @@ async def test_runtime_complex_route_runs_execution(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calendar.write",
|
||||
"name": "calendar_write",
|
||||
"description": "write",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
|
||||
@@ -9,6 +9,8 @@ from core.agentscope.schemas.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
||||
from core.agentscope.runtime.react_runner import (
|
||||
AgentScopeReActRunner,
|
||||
_chat_response_text,
|
||||
_merge_stage_response_metadata,
|
||||
_parse_json_text,
|
||||
_to_litellm_model,
|
||||
)
|
||||
@@ -32,10 +34,10 @@ def test_to_litellm_model_keeps_prefixed_model() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_to_litellm_model_builds_prefixed_model() -> None:
|
||||
def test_to_litellm_model_uses_plain_model_name_when_unprefixed() -> None:
|
||||
assert (
|
||||
_to_litellm_model(provider_name="dashscope", model_code="qwen3.5-flash")
|
||||
== "dashscope/qwen3.5-flash"
|
||||
== "qwen3.5-flash"
|
||||
)
|
||||
|
||||
|
||||
@@ -49,6 +51,24 @@ def test_parse_json_text_rejects_non_json() -> None:
|
||||
_parse_json_text("not-json")
|
||||
|
||||
|
||||
def test_chat_response_text_falls_back_to_choice_message_content() -> None:
|
||||
response = SimpleNamespace(
|
||||
content=None,
|
||||
choices=[
|
||||
{
|
||||
"message": {
|
||||
"content": '{"assistant_text":"fallback","response_metadata":{}}'
|
||||
}
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert (
|
||||
_chat_response_text(response)
|
||||
== '{"assistant_text":"fallback","response_metadata":{}}'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_json_stage_wraps_json_decode_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@@ -113,3 +133,88 @@ async def test_run_json_stage_wraps_runtime_error(
|
||||
user_prompt="user",
|
||||
toolkit=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_json_stage_report_merges_usage_metadata(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeLiteLLMService:
|
||||
def run_completion_with_cost(self, **kwargs: object) -> object:
|
||||
del kwargs
|
||||
return SimpleNamespace(
|
||||
response={
|
||||
"model": "dashscope/qwen3.5-flash",
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": '{"assistant_text":"ok","response_metadata":{}}'
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
usage=SimpleNamespace(
|
||||
prompt_tokens=9,
|
||||
completion_tokens=4,
|
||||
cost=0.006,
|
||||
),
|
||||
)
|
||||
|
||||
runner = AgentScopeReActRunner()
|
||||
monkeypatch.setattr(
|
||||
runner,
|
||||
"_build_litellm_service",
|
||||
lambda: _FakeLiteLLMService(),
|
||||
)
|
||||
|
||||
report_stage = RuntimeStageConfig(
|
||||
stage="report",
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
llm_config=SystemAgentLLMConfig(
|
||||
temperature=0.1,
|
||||
max_tokens=128,
|
||||
timeout_seconds=30,
|
||||
),
|
||||
)
|
||||
payload = await runner.run_json_stage(
|
||||
stage_config=report_stage,
|
||||
agent_name="report-agent",
|
||||
system_prompt="sys",
|
||||
user_prompt="user",
|
||||
toolkit=None,
|
||||
)
|
||||
|
||||
metadata = payload["response_metadata"]
|
||||
assert metadata["model"] == "dashscope/qwen3.5-flash"
|
||||
assert metadata["inputTokens"] == 9
|
||||
assert metadata["outputTokens"] == 4
|
||||
assert metadata["cost"] == 0.006
|
||||
assert isinstance(metadata["latencyMs"], int)
|
||||
assert metadata["latencyMs"] >= 0
|
||||
|
||||
|
||||
def test_merge_stage_response_metadata_estimates_cost_from_pricing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.react_runner._estimate_cost_by_pricing",
|
||||
lambda **kwargs: 0.0025,
|
||||
)
|
||||
payload = _merge_stage_response_metadata(
|
||||
payload={"route": "DIRECT_RESPONSE", "response_metadata": {}},
|
||||
stage_config=_stage_config(),
|
||||
response=SimpleNamespace(
|
||||
usage=SimpleNamespace(
|
||||
prompt_tokens=12,
|
||||
completion_tokens=8,
|
||||
),
|
||||
model="qwen3.5-flash",
|
||||
),
|
||||
latency_ms=50,
|
||||
)
|
||||
|
||||
metadata = payload["response_metadata"]
|
||||
assert metadata["inputTokens"] == 12
|
||||
assert metadata["outputTokens"] == 8
|
||||
assert metadata["cost"] == 0.0025
|
||||
|
||||
@@ -71,6 +71,63 @@ async def test_run_agentscope_task_calls_runtime_run(
|
||||
assert called["resume"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_includes_recent_context_messages(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured_messages: list[dict[str, Any]] = []
|
||||
|
||||
class _FakeRuntime:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
del kwargs
|
||||
|
||||
async def run(self, **kwargs: object) -> object:
|
||||
command = kwargs.get("command")
|
||||
if command is not None:
|
||||
raw_messages = getattr(command, "messages", [])
|
||||
if isinstance(raw_messages, list):
|
||||
captured_messages.extend(raw_messages)
|
||||
return object()
|
||||
|
||||
async def resume(self, **kwargs: object) -> object:
|
||||
del kwargs
|
||||
return object()
|
||||
|
||||
async def _fake_get_redis_client() -> object:
|
||||
return object()
|
||||
|
||||
async def _fake_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||
del kwargs
|
||||
return [{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}]
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"get_or_init_redis_client",
|
||||
_fake_get_redis_client,
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"_build_recent_context_messages",
|
||||
_fake_context,
|
||||
)
|
||||
|
||||
run_input = _run_input_payload()
|
||||
run_input["messages"] = [{"id": "u1", "role": "user", "content": "现在几点"}]
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": run_input,
|
||||
}
|
||||
)
|
||||
|
||||
assert len(captured_messages) == 2
|
||||
assert captured_messages[0]["id"] == "ctx-1"
|
||||
assert captured_messages[1]["id"] == "u1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_calls_runtime_resume(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -178,3 +178,89 @@ async def test_calendar_write_rejects_invalid_reminder_minutes(
|
||||
|
||||
assert result["data"]["ok"] is False
|
||||
assert result["data"]["code"] == "INVALID_ARGUMENT"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calendar_write_maps_invite_arguments(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
|
||||
captured.update(cast(dict[str, object], kwargs["tool_args"]))
|
||||
return {"type": "calendar_card.v1", "version": "v1", "data": {"ok": True}}
|
||||
|
||||
monkeypatch.setattr(
|
||||
calendar_module,
|
||||
"_execute_mutate_calendar_event",
|
||||
_fake_execute,
|
||||
)
|
||||
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
|
||||
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||
|
||||
await calendar_module.calendar_write(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token-abc",
|
||||
operation="create",
|
||||
invite_user_emails=["a@example.com"],
|
||||
invite_user_names=["alice"],
|
||||
invite_user_ids=[str(uuid4())],
|
||||
invite_permission_view=True,
|
||||
invite_permission_edit=True,
|
||||
invite_permission_invite=True,
|
||||
)
|
||||
|
||||
assert captured["inviteUserEmails"] == ["a@example.com"]
|
||||
assert captured["inviteUserNames"] == ["alice"]
|
||||
assert isinstance(captured["inviteUserIds"], list)
|
||||
assert captured["invitePermissionView"] is True
|
||||
assert captured["invitePermissionEdit"] is True
|
||||
assert captured["invitePermissionInvite"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_resolve_maps_identity_arguments(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
|
||||
captured.update(cast(dict[str, object], kwargs["tool_args"]))
|
||||
return {"type": "user_lookup.v1", "version": "v1", "data": {"ok": True}}
|
||||
|
||||
monkeypatch.setattr(
|
||||
calendar_module,
|
||||
"_execute_resolve_user_identity",
|
||||
_fake_execute,
|
||||
)
|
||||
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
|
||||
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||
|
||||
result = await calendar_module.user_resolve(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token-abc",
|
||||
user_email="a@example.com",
|
||||
)
|
||||
|
||||
assert result["type"] == "user_lookup.v1"
|
||||
assert captured == {"userEmail": "a@example.com", "userName": None}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_resolve_requires_valid_user_token(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: False)
|
||||
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||
|
||||
result = await calendar_module.user_resolve(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="bad-token",
|
||||
user_name="alice",
|
||||
)
|
||||
|
||||
assert result["data"]["ok"] is False
|
||||
assert result["data"]["code"] == "UNAUTHORIZED"
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agentscope.prompts.runtime_prompt import build_intent_user_prompt
|
||||
|
||||
|
||||
def test_build_intent_user_prompt_keeps_multimodal_blocks() -> None:
|
||||
prompt = build_intent_user_prompt(
|
||||
user_input=[
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请识别图片内容"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"data": "aGVsbG8=",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(prompt, list)
|
||||
assert prompt
|
||||
assert prompt[0]["type"] == "text"
|
||||
assert "[Output Schema]" in prompt[0]["text"]
|
||||
image_blocks = [item for item in prompt if item.get("type") == "image"]
|
||||
assert len(image_blocks) == 1
|
||||
source = image_blocks[0]["source"]
|
||||
assert source["type"] == "base64"
|
||||
assert source["media_type"] == "image/png"
|
||||
assert source["data"] == "aGVsbG8="
|
||||
|
||||
|
||||
def test_build_intent_user_prompt_filters_non_image_binary_block() -> None:
|
||||
prompt = build_intent_user_prompt(
|
||||
user_input=[
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请处理这个输入"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "application/pdf",
|
||||
"data": "aGVsbG8=",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(prompt, list)
|
||||
image_blocks = [item for item in prompt if item.get("type") == "image"]
|
||||
assert image_blocks == []
|
||||
@@ -20,11 +20,12 @@ async def test_build_toolkit_registers_calendar_tools() -> None:
|
||||
)
|
||||
schemas = toolkit.get_json_schemas()
|
||||
names = {item["function"]["name"] for item in schemas}
|
||||
assert "calendar.read" in names
|
||||
assert "calendar.write" in names
|
||||
assert "calendar_read" in names
|
||||
assert "calendar_write" in names
|
||||
assert "user_resolve" in names
|
||||
|
||||
write_schema = next(
|
||||
item for item in schemas if item["function"]["name"] == "calendar.write"
|
||||
item for item in schemas if item["function"]["name"] == "calendar_write"
|
||||
)
|
||||
params = write_schema["function"]["parameters"]["properties"]
|
||||
assert "user_token" not in params
|
||||
|
||||
@@ -33,11 +33,11 @@ def test_calculate_cost_uses_second_qwen_tier() -> None:
|
||||
|
||||
def test_run_completion_extracts_usage_and_cost() -> None:
|
||||
service = LiteLLMService()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
result = service.run_completion_with_cost(
|
||||
model="dashscope/qwen3.5-flash",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
completion_fn=lambda **_: {
|
||||
def _fake_completion(**kwargs: object) -> dict[str, object]:
|
||||
captured.update(kwargs)
|
||||
return {
|
||||
"model": "dashscope/qwen3.5-flash",
|
||||
"usage": {
|
||||
"prompt_tokens": 2000,
|
||||
@@ -46,10 +46,17 @@ def test_run_completion_extracts_usage_and_cost() -> None:
|
||||
"prompt_tokens_details": {"cached_tokens": 500},
|
||||
},
|
||||
"choices": [{"message": {"content": "ok"}}],
|
||||
},
|
||||
}
|
||||
|
||||
result = service.run_completion_with_cost(
|
||||
model="dashscope/qwen3.5-flash",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
response_format={"type": "json_object"},
|
||||
completion_fn=_fake_completion,
|
||||
)
|
||||
|
||||
assert result.usage.prompt_tokens == 2000
|
||||
assert result.usage.completion_tokens == 100
|
||||
assert result.usage.total_tokens == 2100
|
||||
assert result.usage.cost == pytest.approx(0.00051)
|
||||
assert captured["response_format"] == {"type": "json_object"}
|
||||
|
||||
@@ -10,6 +10,31 @@ from models.agent_chat_message import AgentChatMessageRole
|
||||
from v1.agent.repository import AgentRepository
|
||||
|
||||
|
||||
class _ExecuteResult:
|
||||
def __init__(self, value: object) -> None:
|
||||
self._value = value
|
||||
|
||||
def scalar_one_or_none(self) -> object:
|
||||
return self._value
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, session_row: object) -> None:
|
||||
self.session_row = session_row
|
||||
self.added: list[object] = []
|
||||
self.flushed = False
|
||||
|
||||
async def execute(self, stmt): # noqa: ANN001
|
||||
del stmt
|
||||
return _ExecuteResult(self.session_row)
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def flush(self) -> None:
|
||||
self.flushed = True
|
||||
|
||||
|
||||
class _FakeToolResultStorage:
|
||||
def __init__(self, payload: dict[str, object] | None) -> None:
|
||||
self._payload = payload
|
||||
@@ -104,3 +129,48 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_user_message_sets_session_title_when_empty() -> None:
|
||||
session_id = str(uuid4())
|
||||
session_row = SimpleNamespace(
|
||||
message_count=0,
|
||||
title=None,
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
)
|
||||
fake_session = _FakeSession(session_row)
|
||||
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
|
||||
|
||||
await repository.persist_user_message(
|
||||
session_id=session_id,
|
||||
run_id="run-1",
|
||||
content_text=" 请帮我安排明天下午开会 ",
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
assert session_row.title == "请帮我安排明天下午开会"
|
||||
assert session_row.message_count == 1
|
||||
assert fake_session.flushed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_user_message_keeps_existing_session_title() -> None:
|
||||
session_id = str(uuid4())
|
||||
session_row = SimpleNamespace(
|
||||
message_count=1,
|
||||
title="已有标题",
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
)
|
||||
fake_session = _FakeSession(session_row)
|
||||
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
|
||||
|
||||
await repository.persist_user_message(
|
||||
session_id=session_id,
|
||||
run_id="run-2",
|
||||
content_text="新的消息内容",
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
assert session_row.title == "已有标题"
|
||||
assert session_row.message_count == 2
|
||||
|
||||
@@ -175,3 +175,53 @@ async def test_enqueue_resume_accepts_valid_tool_contract(
|
||||
assert result.task_id == "task-resume-1"
|
||||
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
|
||||
assert result.run_id == "run-resume-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_retries_on_redis_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _acquire(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
|
||||
monkeypatch.setattr(agent_router, "_acquire_sse_slot", _acquire)
|
||||
monkeypatch.setattr(agent_router, "_release_sse_slot", _release)
|
||||
|
||||
class _Request:
|
||||
async def is_disconnected(self) -> bool:
|
||||
return False
|
||||
|
||||
class _Service:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def stream_events(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
raise RuntimeError("Timeout reading from localhost:6379")
|
||||
if self.calls == 2:
|
||||
return [{"id": "1-0", "event": {"type": "RUN_FINISHED"}}]
|
||||
return []
|
||||
|
||||
response = await agent_router.stream_events(
|
||||
request=cast(Any, _Request()),
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
last_event_id=None,
|
||||
idle_limit=2,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
async for chunk in response.body_iterator:
|
||||
chunks.append(str(chunk))
|
||||
if any("RUN_FINISHED" in item for item in chunks):
|
||||
break
|
||||
|
||||
merged = "".join(chunks)
|
||||
assert "event: RUN_FINISHED" in merged
|
||||
|
||||
@@ -124,6 +124,19 @@ class _FakeAttachmentStorage:
|
||||
return path
|
||||
|
||||
|
||||
class _AlwaysFailAttachmentStorage:
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str:
|
||||
del bucket, path, content, content_type
|
||||
raise RuntimeError("upload failed")
|
||||
|
||||
|
||||
def _user() -> CurrentUser:
|
||||
return CurrentUser(
|
||||
id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
@@ -317,6 +330,54 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
|
||||
assert isinstance(attachments[0]["path"], str)
|
||||
|
||||
|
||||
async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_AlwaysFailAttachmentStorage(),
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-image-fail",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "帮我看下这张图"},
|
||||
{
|
||||
"type": "binary",
|
||||
"data": "aGVsbG8=",
|
||||
"mimeType": "image/png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
raise AssertionError("expected HTTPException")
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 502
|
||||
assert exc.detail == "Failed to upload attachment"
|
||||
|
||||
assert repository.persisted_user_messages == []
|
||||
|
||||
|
||||
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
# Agent Multimodal Smoke Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** 完成 agent 三条主链路(runs/events/history)真实冒烟,并支持 RunAgentInput 图片信息在发送链路落 Supabase Storage、在 messages.metadata 持久化、在 history 返回中可渲染。
|
||||
|
||||
**Architecture:** 在 `v1/agent` 服务层新增“用户消息持久化 + 图片附件上传”步骤:`enqueue_run` 时解析用户消息 content block,图片上传到 `config.storage.bucket`,将路径写入 `messages.metadata`。运行时继续通过 AgentScope pipeline 输出 AG-UI 事件,SSE 从 Redis stream 订阅,历史查询从 `messages` 回放并附带附件信息。
|
||||
|
||||
**Tech Stack:** FastAPI, SQLAlchemy AsyncSession, Supabase Storage Admin Client, Redis SSE stream, AG-UI, pytest/httpx。
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 用户消息图片附件上传与落库
|
||||
|
||||
**Files:**
|
||||
- Create: `backend/src/v1/agent/attachment_storage.py`
|
||||
- Modify: `backend/src/v1/agent/service.py`
|
||||
- Modify: `backend/src/v1/agent/repository.py`
|
||||
- Test: `backend/tests/unit/v1/agent/test_service.py`
|
||||
|
||||
**Step 1: 写失败测试(RED)**
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_run_persists_user_message_with_uploaded_image_metadata() -> None:
|
||||
...
|
||||
```
|
||||
|
||||
**Step 2: 运行单测验证失败**
|
||||
|
||||
Run: `uv run pytest tests/unit/v1/agent/test_service.py::test_enqueue_run_persists_user_message_with_uploaded_image_metadata -q`
|
||||
Expected: FAIL(缺少附件上传/metadata 持久化行为)
|
||||
|
||||
**Step 3: 最小实现(GREEN)**
|
||||
|
||||
```python
|
||||
class AgentAttachmentStorage:
|
||||
async def upload_bytes(...):
|
||||
...
|
||||
|
||||
class AgentService:
|
||||
async def enqueue_run(...):
|
||||
# 解析 user content blocks
|
||||
# 上传图片到 storage
|
||||
# repository 持久化 user message(metadata 包含 bucket/path)
|
||||
...
|
||||
```
|
||||
|
||||
**Step 4: 运行单测验证通过**
|
||||
|
||||
Run: `uv run pytest tests/unit/v1/agent/test_service.py::test_enqueue_run_persists_user_message_with_uploaded_image_metadata -q`
|
||||
Expected: PASS
|
||||
|
||||
### Task 2: history 渲染附件路径
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent/repository.py`
|
||||
- Test: `backend/tests/unit/v1/agent/test_repository.py`
|
||||
|
||||
**Step 1: 写失败测试(RED)**
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_includes_user_message_attachments_from_metadata() -> None:
|
||||
...
|
||||
```
|
||||
|
||||
**Step 2: 运行测试验证失败**
|
||||
|
||||
Run: `uv run pytest tests/unit/v1/agent/test_repository.py::test_history_includes_user_message_attachments_from_metadata -q`
|
||||
Expected: FAIL(history 尚未渲染 attachments)
|
||||
|
||||
**Step 3: 最小实现(GREEN)**
|
||||
|
||||
```python
|
||||
if role == "user" and isinstance(metadata.get("attachments"), list):
|
||||
payload["attachments"] = metadata["attachments"]
|
||||
```
|
||||
|
||||
**Step 4: 运行测试验证通过**
|
||||
|
||||
Run: `uv run pytest tests/unit/v1/agent/test_repository.py::test_history_includes_user_message_attachments_from_metadata -q`
|
||||
Expected: PASS
|
||||
|
||||
### Task 3: 真实冒烟 runs + SSE + history(含图片输入)
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/tests/integration/v1/agent/test_sse_flow_live.py`
|
||||
|
||||
**Step 1: 写失败测试(RED)**
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.live
|
||||
async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
...
|
||||
```
|
||||
|
||||
**Step 2: 运行 live 测试验证失败(实现前或环境不完整)**
|
||||
|
||||
Run: `AGENT_LIVE_INTEGRATION=1 AGENT_LIVE_EMAIL=... AGENT_LIVE_PASSWORD=... uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_events_history_live_with_image_input -q -s`
|
||||
Expected: FAIL(缺 metadata/path 或 history 不含附件)
|
||||
|
||||
**Step 3: 最小实现(GREEN)**
|
||||
|
||||
```python
|
||||
# live 测试流程:
|
||||
# 1) 登录拿 token
|
||||
# 2) POST /runs 发送 text + image(data)
|
||||
# 3) SSE 订阅直到 RUN_FINISHED/RUN_ERROR
|
||||
# 4) GET /runs/{thread_id}/history
|
||||
# 5) SQL 校验 sessions/messages 字段与 metadata.attachments
|
||||
```
|
||||
|
||||
**Step 4: 运行 live 测试验证通过**
|
||||
|
||||
Run: `AGENT_LIVE_INTEGRATION=1 AGENT_LIVE_EMAIL=... AGENT_LIVE_PASSWORD=... uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_events_history_live_with_image_input -q -s`
|
||||
Expected: PASS
|
||||
|
||||
### Task 4: 全量收口验证与安全门禁
|
||||
|
||||
**Files:**
|
||||
- Modify (if needed): `backend/src/v1/agent/*`, `backend/tests/*`
|
||||
|
||||
**Step 1: 回归测试**
|
||||
|
||||
Run: `uv run pytest tests/unit/v1/agent tests/unit/core/agentscope tests/integration/v1/agent -q`
|
||||
Expected: PASS
|
||||
|
||||
**Step 2: 静态检查**
|
||||
|
||||
Run: `uv run ruff check src/v1/agent src/core/agentscope tests/unit/v1/agent tests/integration/v1/agent`
|
||||
Expected: PASS
|
||||
|
||||
Run: `uv run basedpyright src/v1/agent src/core/agentscope tests/unit/v1/agent tests/integration/v1/agent`
|
||||
Expected: 0 errors
|
||||
|
||||
**Step 3: 评审门禁**
|
||||
|
||||
Run agents: `security-reviewer`, `refactor-cleaner`, `code-reviewer`
|
||||
Expected: 无未解决 CRITICAL/HIGH
|
||||
@@ -0,0 +1,69 @@
|
||||
# Agent Multimodal Smoke Runbook
|
||||
|
||||
**Goal:** 固化 agent 三条主链路(runs/events/history)的真实冒烟标准与输入基线。
|
||||
|
||||
## 1. 覆盖范围
|
||||
|
||||
1. `POST /api/v1/agent/runs` - 接收多模态消息(文本+图片)
|
||||
2. `GET /api/v1/agent/runs/{thread_id}/events` - SSE 事件流,事件名符合 AG-UI 标准(`RUN_STARTED`、`STEP_STARTED`、`TOOL_CALL_*`、`RUN_FINISHED`/`RUN_ERROR`)
|
||||
3. `GET /api/v1/agent/runs/{thread_id}/history` - 返回 `STATE_SNAPSHOT`,含 `attachments` metadata
|
||||
4. `sessions/messages` 落库完整:message_count、tokens、cost、latency、title、metadata
|
||||
5. tool result 存储:大 payload 写 storage,metadata 记录 `storage_bucket`/`storage_path`
|
||||
6. storage bucket 来源:必须来自环境变量 `SOCIAL_STORAGE__BUCKET`
|
||||
|
||||
## 2. 固定测试输入
|
||||
|
||||
- 图片夹具:`backend/tests/fixtures/images/calendar_text_cn.png`
|
||||
- 多模态消息:
|
||||
- 文本:`"识别图片中的日历内容并调用 calendar.write 创建日程"`
|
||||
- 图片:`{"type":"binary","data":"<base64>","mimeType":"image/png"}`
|
||||
|
||||
## 3. 账号与凭据
|
||||
|
||||
- 冒烟账号:`dagronl@126.com` / `123456`
|
||||
- 通过环境变量注入:`AGENT_LIVE_EMAIL`、`AGENT_LIVE_PASSWORD`
|
||||
|
||||
## 4. 执行命令
|
||||
|
||||
```bash
|
||||
AGENT_LIVE_INTEGRATION=1 \
|
||||
AGENT_LIVE_EMAIL="dagronl@126.com" \
|
||||
AGENT_LIVE_PASSWORD="123456" \
|
||||
uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_events_history_live_with_image_input -q -s
|
||||
```
|
||||
|
||||
## 5. 结果记录模板
|
||||
|
||||
- `thread_id` / `run_id`
|
||||
- `runs` 状态码与响应
|
||||
- `events` 事件序列
|
||||
- `history` 是否含 `attachments[].bucket/path/mimeType`
|
||||
- `sessions` 字段:message_count / total_tokens / total_cost / status / title
|
||||
- `messages` 字段:role / content / metadata / tokens / cost / latency
|
||||
- `tool_result` 是否写 storage
|
||||
|
||||
## 6. 安全注意
|
||||
|
||||
- 禁止将密码/token 写入 git 跟踪文件
|
||||
|
||||
## 7. 已修复问题清单
|
||||
|
||||
| 问题 | 修复内容 |
|
||||
|------|----------|
|
||||
| bucket 写入失败回退 | 改为直接报错,禁止回退到硬编码 bucket |
|
||||
| user.resolve 工具 | 新增按 email/name 解析 user_id |
|
||||
| calendar.write 邀请参数 | 增加 invite 参数透传 |
|
||||
| inbox_repository 缺失 | 修复 calendar runtime 依赖 |
|
||||
| runtime 模型名拼接 | 修复无效 model name |
|
||||
| 多模态透传 | runtime 透传 binary.data,不过滤为 `<omitted>` |
|
||||
| sessions.title 生成 | 首条用户消息持久化时自动生成 |
|
||||
| assistant latency 入库 | `messages.latency_ms` 列写入 |
|
||||
| intent/execution 阶段消息落库 | 新增 `text.*` 和 `tool.result` 事件 |
|
||||
| DIRECT_RESPONSE 早返回 | intent 判定后直接返回,不进入 report 阶段 |
|
||||
|
||||
## 8. 待修复问题(用户新增)
|
||||
|
||||
1. **意图/执行阶段 tokens/cost 入库** - 目前仅 report 阶段入库
|
||||
2. **连续会话记忆测试** - 验证 session 是否从数据库读取历史上下文
|
||||
3. **工具调用测试** - calendar 读/写/删/分享 + 用户查找 + 时间感知
|
||||
4. **session 失败排查** - 找出最新失败原因并修复
|
||||
@@ -1,583 +0,0 @@
|
||||
# 日历邀请弹窗优化 Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** 优化日历邀请消息弹窗,显示完整信息(发送者名称 + 日历标题),使用公共弹窗组件替代所有旧弹窗代码
|
||||
|
||||
**Architecture:**
|
||||
- 后端新增用户信息查询接口
|
||||
- 前端创建公共弹窗组件 MessageActionSheet
|
||||
- 删除所有旧的弹窗代码(好友请求、日历邀请),统一使用公共组件
|
||||
|
||||
**Tech Stack:** Flutter (Dart), FastAPI (Python)
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 后端添加用户信息查询接口
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/users/router.py`
|
||||
- Modify: `backend/src/v1/users/service.py`
|
||||
- Modify: `backend/src/v1/users/repository.py`
|
||||
|
||||
**Step 1: 添加 repository 方法**
|
||||
|
||||
修改 `backend/src/v1/users/repository.py`,在 `UserRepository` 和 `SQLAlchemyUserRepository` 中已有 `get_by_user_id` 方法,确认存在。
|
||||
|
||||
**Step 2: 添加 service 方法**
|
||||
|
||||
修改 `backend/src/v1/users/service.py`,添加:
|
||||
|
||||
```python
|
||||
async def get_user_by_id(self, user_id: UUID) -> UserBasicInfo:
|
||||
from v1.friendships.schemas import UserBasicInfo
|
||||
|
||||
profile = await self._repository.get_by_user_id(user_id)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return UserBasicInfo(
|
||||
id=str(profile.user_id),
|
||||
username=profile.username,
|
||||
avatar_url=profile.avatar_url,
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: 添加 router 接口**
|
||||
|
||||
修改 `backend/src/v1/users/router.py`,添加:
|
||||
|
||||
```python
|
||||
@router.get("/{user_id}", response_model=UserBasicInfo)
|
||||
async def get_user(
|
||||
user_id: UUID,
|
||||
service: Annotated[UserService, Depends(get_user_service)],
|
||||
):
|
||||
return await service.get_user_by_id(user_id)
|
||||
```
|
||||
|
||||
**Step 4: 运行 lint 和 typecheck**
|
||||
|
||||
```bash
|
||||
cd backend && uv run ruff check src/v1/users/ && uv run basedpyright src/v1/users/
|
||||
```
|
||||
|
||||
**Step 5: 提交**
|
||||
|
||||
```bash
|
||||
git add backend/src/v1/users/ && git commit -m "feat(users): add get user by id endpoint"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: 前端添加用户 API 接口
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/features/users/data/users_api.dart`
|
||||
- Modify: `apps/lib/core/di/injection.dart`
|
||||
|
||||
**Step 1: 添加 UserBasicInfo 类和 getById 方法**
|
||||
|
||||
修改 `apps/lib/features/users/data/users_api.dart`:
|
||||
|
||||
```dart
|
||||
class UserBasicInfo {
|
||||
final String id;
|
||||
final String username;
|
||||
final String? avatarUrl;
|
||||
|
||||
UserBasicInfo({
|
||||
required this.id,
|
||||
required this.username,
|
||||
this.avatarUrl,
|
||||
});
|
||||
|
||||
factory UserBasicInfo.fromJson(Map<String, dynamic> json) {
|
||||
return UserBasicInfo(
|
||||
id: json['id'] as String,
|
||||
username: json['username'] as String,
|
||||
avatarUrl: json['avatar_url'] as String?,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
class UsersApi {
|
||||
final IApiClient _client;
|
||||
static const _prefix = '/api/v1/users';
|
||||
|
||||
UsersApi(this._client);
|
||||
|
||||
// ... existing methods
|
||||
|
||||
Future<UserBasicInfo> getById(String userId) async {
|
||||
final response = await _client.get('$_prefix/$userId');
|
||||
return UserBasicInfo.fromJson(response.data);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 注册到 DI**
|
||||
|
||||
修改 `apps/lib/core/di/injection.dart`,添加:
|
||||
|
||||
```dart
|
||||
sl.registerLazySingleton(() => UsersApi(sl<IApiClient>()));
|
||||
```
|
||||
|
||||
**Step 3: 运行 flutter analyze**
|
||||
|
||||
```bash
|
||||
cd apps && flutter analyze lib/features/users/
|
||||
```
|
||||
|
||||
**Step 4: 提交**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/users/ apps/lib/core/di/injection.dart && git commit -m "feat(users): add getById API and UserBasicInfo"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: 创建公共弹窗组件 MessageActionSheet
|
||||
|
||||
**Files:**
|
||||
- Create: `apps/lib/features/messages/ui/widgets/message_action_sheet.dart`
|
||||
|
||||
**Step 1: 创建弹窗组件**
|
||||
|
||||
创建 `apps/lib/features/messages/ui/widgets/message_action_sheet.dart`:
|
||||
|
||||
```dart
|
||||
import 'package:flutter/material.dart';
|
||||
import '../../../../core/theme/design_tokens.dart';
|
||||
import '../../../../shared/widgets/app_button.dart';
|
||||
|
||||
class MessageActionSheet extends StatelessWidget {
|
||||
final String title;
|
||||
final String? description;
|
||||
final String? statusText;
|
||||
final bool isReadOnly;
|
||||
final VoidCallback? onAccept;
|
||||
final VoidCallback? onDecline;
|
||||
final IconData? icon;
|
||||
final Color? iconColor;
|
||||
|
||||
const MessageActionSheet({
|
||||
super.key,
|
||||
required this.title,
|
||||
this.description,
|
||||
this.statusText,
|
||||
this.isReadOnly = false,
|
||||
this.onAccept,
|
||||
this.onDecline,
|
||||
this.icon,
|
||||
this.iconColor,
|
||||
});
|
||||
|
||||
@override
|
||||
Widget build(BuildContext context) {
|
||||
return Container(
|
||||
width: double.infinity,
|
||||
padding: const EdgeInsets.fromLTRB(24, 20, 24, 0),
|
||||
decoration: const BoxDecoration(
|
||||
color: AppColors.white,
|
||||
borderRadius: BorderRadius.vertical(top: Radius.circular(24)),
|
||||
),
|
||||
child: Column(
|
||||
mainAxisSize: MainAxisSize.min,
|
||||
children: [
|
||||
Container(
|
||||
width: 40,
|
||||
height: 4,
|
||||
decoration: BoxDecoration(
|
||||
color: AppColors.slate300,
|
||||
borderRadius: BorderRadius.circular(2),
|
||||
),
|
||||
),
|
||||
const SizedBox(height: 20),
|
||||
if (icon != null) ...[
|
||||
Container(
|
||||
width: 72,
|
||||
height: 72,
|
||||
decoration: BoxDecoration(
|
||||
color: (iconColor ?? AppColors.blue500).withValues(alpha: 0.1),
|
||||
shape: BoxShape.circle,
|
||||
),
|
||||
child: Icon(icon, size: 32, color: iconColor ?? AppColors.blue500),
|
||||
),
|
||||
const SizedBox(height: 16),
|
||||
],
|
||||
Text(
|
||||
title,
|
||||
style: const TextStyle(
|
||||
fontSize: 20,
|
||||
fontWeight: FontWeight.w600,
|
||||
color: AppColors.slate900,
|
||||
),
|
||||
textAlign: TextAlign.center,
|
||||
),
|
||||
if (description != null && description!.isNotEmpty) ...[
|
||||
const SizedBox(height: 8),
|
||||
Text(
|
||||
description!,
|
||||
style: const TextStyle(fontSize: 14, color: AppColors.slate500),
|
||||
textAlign: TextAlign.center,
|
||||
),
|
||||
],
|
||||
if (statusText != null) ...[
|
||||
const SizedBox(height: 16),
|
||||
Container(
|
||||
padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 6),
|
||||
decoration: BoxDecoration(
|
||||
color: AppColors.slate100,
|
||||
borderRadius: BorderRadius.circular(16),
|
||||
),
|
||||
child: Text(
|
||||
statusText!,
|
||||
style: const TextStyle(fontSize: 14, color: AppColors.slate600),
|
||||
),
|
||||
),
|
||||
],
|
||||
const SizedBox(height: 24),
|
||||
if (!isReadOnly) ...[
|
||||
Row(
|
||||
children: [
|
||||
Expanded(
|
||||
child: AppButton(
|
||||
text: '拒绝',
|
||||
isOutlined: true,
|
||||
onPressed: () {
|
||||
Navigator.pop(context);
|
||||
onDecline?.call();
|
||||
},
|
||||
),
|
||||
),
|
||||
const SizedBox(width: AppSpacing.md),
|
||||
Expanded(
|
||||
child: AppButton(
|
||||
text: '接受',
|
||||
onPressed: () {
|
||||
Navigator.pop(context);
|
||||
onAccept?.call();
|
||||
},
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
SizedBox(height: MediaQuery.of(context).padding.bottom + 12),
|
||||
],
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 运行 flutter analyze**
|
||||
|
||||
```bash
|
||||
cd apps && flutter analyze lib/features/messages/ui/widgets/message_action_sheet.dart
|
||||
```
|
||||
|
||||
**Step 3: 提交**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/messages/ui/widgets/message_action_sheet.dart && git commit -m "feat(messages): add MessageActionSheet component"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: 重构消息列表页面,使用公共组件并删除旧代码
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/features/messages/ui/screens/message_invite_list_screen.dart`
|
||||
|
||||
**Step 1: 添加依赖和字段**
|
||||
|
||||
在文件顶部添加:
|
||||
|
||||
```dart
|
||||
import '../../../users/data/users_api.dart';
|
||||
import '../widgets/message_action_sheet.dart';
|
||||
```
|
||||
|
||||
在 `_MessageInviteListScreenState` 中添加:
|
||||
|
||||
```dart
|
||||
late final UsersApi _usersApi;
|
||||
```
|
||||
|
||||
在 `initState` 中添加:
|
||||
|
||||
```dart
|
||||
_usersApi = sl<UsersApi>();
|
||||
```
|
||||
|
||||
**Step 2: 添加获取日历邀请信息方法**
|
||||
|
||||
```dart
|
||||
Future<(String calendarTitle, String senderName)?> _getCalendarInviteInfo(
|
||||
InboxMessageResponse message,
|
||||
) async {
|
||||
if (message.scheduleItemId == null || message.senderId == null) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
final calendar = await _calendarApi.getById(message.scheduleItemId!);
|
||||
final sender = await _usersApi.getById(message.senderId!);
|
||||
return (calendar.title, sender.username);
|
||||
} catch (e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: 替换日历邀请弹窗方法**
|
||||
|
||||
删除旧的 `_showCalendarInviteSheet` 方法,替换为:
|
||||
|
||||
```dart
|
||||
Future<void> _showCalendarInviteSheet(InboxMessageResponse message) async {
|
||||
final itemId = message.scheduleItemId;
|
||||
if (itemId == null) return;
|
||||
|
||||
final info = await _getCalendarInviteInfo(message);
|
||||
final title = info != null
|
||||
? '${info.$2} 邀请你加入日历'
|
||||
: '日历邀请';
|
||||
final description = info?.$1;
|
||||
|
||||
if (!mounted) return;
|
||||
|
||||
showModalBottomSheet<void>(
|
||||
context: context,
|
||||
backgroundColor: Colors.transparent,
|
||||
builder: (ctx) => MessageActionSheet(
|
||||
title: title,
|
||||
description: description,
|
||||
icon: Icons.calendar_today,
|
||||
iconColor: AppColors.blue500,
|
||||
onAccept: () async {
|
||||
try {
|
||||
await _calendarApi.acceptSubscription(itemId);
|
||||
await _inboxApi.markAsRead(message.id);
|
||||
if (mounted) {
|
||||
Toast.show(context, '已接受', type: ToastType.success);
|
||||
_loadMessages();
|
||||
}
|
||||
} catch (e) {
|
||||
if (mounted) {
|
||||
Toast.show(context, '操作失败', type: ToastType.error);
|
||||
}
|
||||
}
|
||||
},
|
||||
onDecline: () async {
|
||||
try {
|
||||
await _calendarApi.rejectSubscription(itemId);
|
||||
await _inboxApi.markAsRead(message.id);
|
||||
if (mounted) {
|
||||
Toast.show(context, '已拒绝', type: ToastType.success);
|
||||
_loadMessages();
|
||||
}
|
||||
} catch (e) {
|
||||
if (mounted) {
|
||||
Toast.show(context, '操作失败', type: ToastType.error);
|
||||
}
|
||||
}
|
||||
},
|
||||
),
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: 添加已读日历邀请弹窗方法**
|
||||
|
||||
```dart
|
||||
Future<void> _showCalendarInviteReadOnlySheet(InboxMessageResponse message) async {
|
||||
final itemId = message.scheduleItemId;
|
||||
if (itemId == null) return;
|
||||
|
||||
final info = await _getCalendarInviteInfo(message);
|
||||
final title = info != null
|
||||
? '${info.$2} 邀请你加入日历'
|
||||
: '日历邀请';
|
||||
final description = info?.$1;
|
||||
|
||||
final statusText = message.status.value == 'accepted' ? '已接受' : '已拒绝';
|
||||
|
||||
if (!mounted) return;
|
||||
|
||||
showModalBottomSheet<void>(
|
||||
context: context,
|
||||
backgroundColor: Colors.transparent,
|
||||
builder: (ctx) => MessageActionSheet(
|
||||
title: title,
|
||||
description: description,
|
||||
statusText: statusText,
|
||||
isReadOnly: true,
|
||||
icon: Icons.calendar_today,
|
||||
iconColor: AppColors.blue500,
|
||||
),
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
**Step 5: 替换好友请求弹窗方法**
|
||||
|
||||
删除旧的 `_showFriendRequestReadOnlySheet` 和 `_showFriendRequestActionSheet` 方法,替换为:
|
||||
|
||||
```dart
|
||||
void _showFriendRequestSheet(MessageWithFriend item, {bool isReadOnly = false}) {
|
||||
final message = item.message;
|
||||
final friendRequest = item.friendRequest;
|
||||
if (friendRequest == null) return;
|
||||
|
||||
final title = '${friendRequest.sender.username} 请求添加您为好友';
|
||||
final description = message.content;
|
||||
final statusText = isReadOnly
|
||||
? (friendRequest.status == 'accepted'
|
||||
? '已接受'
|
||||
: friendRequest.status == 'rejected'
|
||||
? '已拒绝'
|
||||
: '已处理')
|
||||
: null;
|
||||
|
||||
showModalBottomSheet<void>(
|
||||
context: context,
|
||||
backgroundColor: Colors.transparent,
|
||||
isScrollControlled: true,
|
||||
builder: (ctx) => MessageActionSheet(
|
||||
title: title,
|
||||
description: description,
|
||||
statusText: statusText,
|
||||
isReadOnly: isReadOnly,
|
||||
icon: Icons.person_add_outlined,
|
||||
iconColor: AppColors.emerald500,
|
||||
onAccept: isReadOnly
|
||||
? null
|
||||
: () async {
|
||||
await _processFriendRequest(item, accept: true);
|
||||
},
|
||||
onDecline: isReadOnly
|
||||
? null
|
||||
: () async {
|
||||
await _processFriendRequest(item, accept: false);
|
||||
},
|
||||
),
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
**Step 6: 修改 _handleMessageTap 方法**
|
||||
|
||||
修改为调用新的统一方法:
|
||||
|
||||
```dart
|
||||
case InboxMessageType.calendar:
|
||||
final content = _parseCalendarContent(message.content);
|
||||
if (content == null) return;
|
||||
|
||||
final type = content['type'] as String?;
|
||||
if (type == 'invite') {
|
||||
if (message.status.value == 'pending') {
|
||||
await _showCalendarInviteSheet(message);
|
||||
} else {
|
||||
await _showCalendarInviteReadOnlySheet(message);
|
||||
if (message.scheduleItemId != null && context.mounted) {
|
||||
context.push('/calendar/events/${message.scheduleItemId}');
|
||||
}
|
||||
}
|
||||
} else if (type == 'update') {
|
||||
if (message.scheduleItemId != null) {
|
||||
context.push('/calendar/events/${message.scheduleItemId}');
|
||||
}
|
||||
}
|
||||
return;
|
||||
case InboxMessageType.friendRequest:
|
||||
if (item.friendRequest == null) {
|
||||
Toast.show(context, '发送者信息加载失败,请下拉重试', type: ToastType.error);
|
||||
return;
|
||||
}
|
||||
_showFriendRequestSheet(item, isReadOnly: message.isRead);
|
||||
return;
|
||||
```
|
||||
|
||||
**Step 7: 删除旧的 _FriendRequestSheet 类**
|
||||
|
||||
删除文件末尾的整个 `_FriendRequestSheet` 类(约605-749行)。
|
||||
|
||||
**Step 8: 运行 flutter analyze**
|
||||
|
||||
```bash
|
||||
cd apps && flutter analyze lib/features/messages/ui/screens/message_invite_list_screen.dart
|
||||
```
|
||||
|
||||
**Step 9: 提交**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/messages/ && git commit -m "refactor(messages): use MessageActionSheet for all message types"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: 删除日历消息卡片中的旧弹窗代码
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/features/messages/ui/widgets/calendar_message_card.dart`
|
||||
|
||||
**Step 1: 修改 CalendarInviteCard**
|
||||
|
||||
CalendarInviteCard 是用于列表展示的卡片,不需要显示弹窗。检查是否有不必要的硬编码,如果有则清理。
|
||||
|
||||
**Step 2: 运行 flutter analyze**
|
||||
|
||||
```bash
|
||||
cd apps && flutter analyze lib/features/messages/ui/widgets/calendar_message_card.dart
|
||||
```
|
||||
|
||||
**Step 3: 提交**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/calendar_message_card.dart && git commit/messages/ui/widgets -f "chore(messages): clean up calendar message card"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: 验证和测试
|
||||
|
||||
**Step 1: 运行完整测试**
|
||||
|
||||
```bash
|
||||
cd apps && flutter test test/features/messages/
|
||||
cd backend && uv run pytest tests/unit/v1/users/ -v
|
||||
```
|
||||
|
||||
**Step 2: 手动测试场景**
|
||||
|
||||
1. 用户 A 发送日历邀请给用户 B
|
||||
2. 用户 B 打开未读消息,点击日历邀请
|
||||
3. 弹窗显示:"XXX 邀请你加入 [日历标题]"
|
||||
4. 点击接受/拒绝
|
||||
5. 用户 B 打开已读消息,点击日历邀请
|
||||
6. 弹窗显示状态标签
|
||||
7. 好友请求未读/已读都使用相同弹窗组件
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
| Task | Description |
|
||||
|------|-------------|
|
||||
| 1 | 后端添加用户信息查询接口 `/api/v1/users/{user_id}` |
|
||||
| 2 | 前端添加 UsersApi.getById 方法 |
|
||||
| 3 | 创建公共弹窗组件 MessageActionSheet |
|
||||
| 4 | 重构消息列表页面,删除旧弹窗代码,统一使用 MessageActionSheet |
|
||||
| 5 | 清理日历消息卡片旧代码 |
|
||||
| 6 | 验证测试 |
|
||||
|
||||
**Plan complete and saved to `docs/plans/2026-03-11-calendar-invite-sheet.md`. Two execution options:**
|
||||
|
||||
1. **Subagent-Driven (this session)** - I dispatch fresh subagent per task, review between tasks, fast iteration
|
||||
|
||||
2. **Parallel Session (separate)** - Open new session with executing-plans, batch execution with checkpoints
|
||||
|
||||
Which approach?
|
||||
Reference in New Issue
Block a user