feat: 添加 Agent 步骤事件与图片附件功能

- 新增 stepStarted/stepFinished 事件类型支持
- 前端实现图片附件上传和预览功能
- 后端增强工具结果存储和事件处理
- 完善相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-12 09:29:57 +08:00
parent 87215f9d41
commit 7b8865e256
45 changed files with 3869 additions and 308 deletions
@@ -37,6 +37,21 @@ def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
data = event.get("data")
if isinstance(data, dict):
if event_type == "tool.result":
for key in (
"messageId",
"toolCallId",
"callId",
"toolName",
"stage",
"taskId",
"ui",
"content",
):
value = data.get(key)
if value is not None:
payload[key] = value
return payload
reserved = {"type", "threadId", "runId"}
data_map = cast(dict[str, Any], data)
payload.update({k: v for k, v in data_map.items() if k not in reserved})
+171 -15
View File
@@ -1,11 +1,14 @@
from __future__ import annotations
import json
import re
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID
from uuid import UUID, uuid4
from core.agentscope.events.tool_result_summary import build_tool_content_summary
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
@@ -14,6 +17,16 @@ class EventStore(Protocol):
async def persist(self, event: dict[str, Any]) -> None: ...
class ToolResultStorageLike(Protocol):
async def upload_json(
self,
*,
bucket: str,
path: str,
payload: dict[str, object],
) -> str: ...
class NullEventStore:
async def persist(self, event: dict[str, Any]) -> None:
del event
@@ -21,9 +34,20 @@ class NullEventStore:
class SqlAlchemyEventStore:
_session_factory: Callable[[], Any]
_tool_result_storage: ToolResultStorageLike | None
_tool_result_bucket: str | None
_logger = get_logger("core.agentscope.events.store")
def __init__(self, *, session_factory: Any) -> None:
def __init__(
self,
*,
session_factory: Any,
tool_result_storage: ToolResultStorageLike | None = None,
tool_result_bucket: str | None = None,
) -> None:
self._session_factory = session_factory
self._tool_result_storage = tool_result_storage
self._tool_result_bucket = tool_result_bucket
self._message_buffers: dict[tuple[str, str], str] = {}
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
@@ -228,23 +252,89 @@ class SqlAlchemyEventStore:
if not isinstance(tool_name, str) or not tool_name:
return
payload = {
"args": event.get("args"),
"result": event.get("result"),
"error": event.get("error"),
"call_id": event.get("callId"),
}
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
metadata: dict[str, object] = {"tool_name": tool_name}
run_id = event.get("runId")
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
task_id = event.get("taskId")
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
call_id_value = event.get("callId")
if not isinstance(call_id_value, str) or not call_id_value:
call_id_value = (
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
if run_id_value
else f"{task_id_value}-{uuid4().hex[:8]}"
)
summary = build_tool_content_summary(
tool_name=tool_name,
args=event.get("args") if isinstance(event.get("args"), dict) else None,
result=event.get("result"),
error=event.get("error"),
)
raw_result_value = event.get("result")
raw_result: dict[str, object] = (
raw_result_value if isinstance(raw_result_value, dict) else {}
)
ui_candidate = raw_result.get("ui")
ui_schema = ui_candidate if isinstance(ui_candidate, dict) else None
result_type = raw_result.get("type")
result_data = raw_result.get("data")
if (
ui_schema is None
and isinstance(result_type, str)
and isinstance(result_data, dict)
):
ui_schema = raw_result
payload: dict[str, object] = {
"toolName": tool_name,
"ui_schema": ui_schema,
"result": _sanitize_result(raw_result),
"error": _sanitize_error(event.get("error")),
"callId": call_id_value,
"runId": run_id_value,
"taskId": task_id_value,
"content": summary,
}
metadata: dict[str, object] = {
"tool_name": tool_name,
"tool_call_id": call_id_value,
"summary_version": "v1",
}
if run_id_value:
metadata["run_id"] = run_id_value
stage = event.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
task_id = event.get("taskId")
if isinstance(task_id, str) and task_id:
metadata["task_id"] = task_id
if task_id_value:
metadata["task_id"] = task_id_value
if self._tool_result_storage is not None and self._tool_result_bucket:
safe_run = _sanitize_path_component(run_id_value or "run")
safe_call = _sanitize_path_component(call_id_value)
storage_path = f"tool-results/{session_id}/{safe_run}/{safe_call}.json"
try:
await self._tool_result_storage.upload_json(
bucket=self._tool_result_bucket,
path=storage_path,
payload=payload,
)
metadata["storage_bucket"] = self._tool_result_bucket
metadata["storage_path"] = storage_path
except Exception: # noqa: BLE001
metadata["storage_upload_failed"] = True
self._logger.warning(
"tool result storage upload failed",
session_id=str(session_id),
run_id=run_id_value,
call_id=call_id_value,
storage_path=storage_path,
)
content = summary or json.dumps(
payload, ensure_ascii=False, separators=(",", ":")
)
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
@@ -333,3 +423,69 @@ class SqlAlchemyEventStore:
except (InvalidOperation, TypeError, ValueError):
return Decimal("0")
return parsed if parsed >= 0 else Decimal("0")
def _sanitize_path_component(value: str) -> str:
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
compact = compact.strip(".-")
return compact or "id"
def _sanitize_error(value: object) -> str | None:
if isinstance(value, str) and value.strip():
return " ".join(value.split())[:300]
if isinstance(value, dict):
for key in ("message", "error", "detail"):
item = value.get(key)
if isinstance(item, str) and item.strip():
return " ".join(item.split())[:300]
return None
def _sanitize_result(value: object) -> dict[str, object]:
if not isinstance(value, dict):
return {}
def _is_sensitive_key(key: str) -> bool:
normalized = key.strip().lower().replace("-", "_")
if not normalized:
return False
exact = {
"password",
"token",
"secret",
"api_key",
"apikey",
"credential",
"authorization",
"auth",
}
if normalized in exact:
return True
patterns = (
"password",
"token",
"secret",
"auth",
"credential",
"api_key",
"apikey",
"authorization",
)
return any(pattern in normalized for pattern in patterns)
def _sanitize_value(item: object) -> object:
if isinstance(item, dict):
return _sanitize_result(item)
if isinstance(item, list):
return [_sanitize_value(entry) for entry in item]
return item
sanitized: dict[str, object] = {}
for key, item in value.items():
key_text = str(key)
if _is_sensitive_key(key_text):
sanitized[str(key)] = "[REDACTED]"
continue
sanitized[str(key)] = _sanitize_value(item)
return sanitized
@@ -0,0 +1,124 @@
from __future__ import annotations
from typing import Any
def build_tool_content_summary(
*,
tool_name: str,
args: dict[str, Any] | None,
result: Any,
error: Any,
) -> str:
error_message = _extract_error_message(error)
if error_message is not None:
return _truncate(f"{tool_name} 执行失败:{error_message}")
normalized_args = args if isinstance(args, dict) else {}
normalized_result = result if isinstance(result, dict) else {}
business_failure = _extract_business_failure_message(normalized_result)
if business_failure is not None:
return _truncate(f"{tool_name} 执行失败:{business_failure}")
if tool_name == "calendar_write":
title = _pick_first_str(normalized_result, ("title",)) or _pick_first_str(
normalized_args, ("title",)
)
start_at = _pick_first_str(normalized_result, ("startAt", "start_at"))
if title and start_at:
return _truncate(f"已创建日程:{title}{start_at}")
if title:
return _truncate(f"已创建日程:{title}")
if tool_name == "calendar_read":
total = _extract_total(normalized_result)
query = _pick_first_str(normalized_args, ("query",)) or "全部"
if total is not None:
return _truncate(f"查询到 {total} 条日程({query}")
if tool_name == "calendar_delete":
target = _pick_first_str(normalized_result, ("title", "eventId", "event_id"))
if target:
return _truncate(f"已删除日程:{target}")
if tool_name == "calendar_share":
target = _pick_first_str(normalized_result, ("target", "user", "userName"))
if target:
return _truncate(f"已分享日程给 {target}")
if tool_name == "user_resolve":
target = _pick_first_str(normalized_result, ("name", "userName", "userId"))
if target:
return _truncate(f"已匹配用户:{target}")
result_content = _pick_first_str(normalized_result, ("content", "message"))
if result_content:
return _truncate(result_content)
return _truncate(f"{tool_name} 执行完成")
def _extract_error_message(error: Any) -> str | None:
if isinstance(error, str) and error.strip():
return error.strip()
if isinstance(error, dict):
for key in ("message", "error", "detail"):
value = error.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
def _pick_first_str(payload: dict[str, Any], keys: tuple[str, ...]) -> str | None:
for key in keys:
value = payload.get(key)
if isinstance(value, str):
normalized = " ".join(value.split())
if normalized:
return normalized
return None
def _extract_total(result: dict[str, Any]) -> int | None:
candidates: list[Any] = [result.get("total")]
data = result.get("data")
if isinstance(data, dict):
candidates.append(data.get("total"))
events = data.get("events")
if isinstance(events, list):
candidates.append(len(events))
for value in candidates:
if isinstance(value, bool):
continue
if isinstance(value, int) and value >= 0:
return value
if isinstance(value, str) and value.isdigit():
return int(value)
return None
def _extract_business_failure_message(result: dict[str, Any]) -> str | None:
top_ok = result.get("ok")
if top_ok is False:
top_message = _pick_first_str(result, ("message", "error", "detail"))
if top_message:
return top_message
data = result.get("data")
if isinstance(data, dict) and data.get("ok") is False:
data_message = _pick_first_str(data, ("message", "error", "detail"))
if data_message:
return data_message
code = _pick_first_str(data, ("code",))
if code:
return code
return None
def _truncate(text: str, limit: int = 80) -> str:
normalized = " ".join(text.split())
if len(normalized) <= limit:
return normalized
return normalized[: limit - 3] + "..."
@@ -42,11 +42,19 @@ def build_intent_user_prompt(
*, user_input: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(user_input, list):
context_messages = _conversation_context_messages(user_input)
context_hint = (
json.dumps(context_messages, ensure_ascii=True, separators=(",", ":"))
if context_messages
else "[]"
)
instruction_text = "\n\n".join(
[
INTENT_TASK_INSTRUCTION,
"[Output Schema]",
_schema_json(IntentOutput),
"[Conversation Context]",
context_hint,
"[User Input]",
"Use the following multimodal blocks as the latest user input.",
]
@@ -127,6 +135,56 @@ def _latest_user_content_blocks(
return []
def _conversation_context_messages(
user_input: list[dict[str, Any]],
) -> list[dict[str, str]]:
latest_user_index = -1
for index in range(len(user_input) - 1, -1, -1):
item = user_input[index]
if isinstance(item, dict) and item.get("role") == "user":
latest_user_index = index
break
if latest_user_index <= 0:
return []
context_items: list[dict[str, str]] = []
for item in user_input[:latest_user_index]:
if not isinstance(item, dict):
continue
role = item.get("role")
if role not in {"user", "assistant"}:
continue
content = item.get("content")
text = _content_to_text(content)
if text:
context_items.append({"role": str(role), "content": text})
if len(context_items) <= 12:
return context_items
return context_items[-12:]
def _content_to_text(content: Any) -> str:
if isinstance(content, str):
return " ".join(content.split())
if not isinstance(content, list):
return ""
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type == "text":
text = block.get("text")
if isinstance(text, str) and text.strip():
parts.append(" ".join(text.split()))
elif block_type in {"binary", "image"}:
parts.append("[image]")
return " ".join(parts).strip()
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
mime_type = item.get("mimeType")
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import re
from typing import Any, Protocol
from uuid import UUID
import structlog
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
@@ -31,7 +33,9 @@ class PipelineLike(Protocol):
class AgentRouteRuntime:
_orchestrator: OrchestratorLike
_pipeline: PipelineLike
_logger = get_logger("core.agentscope.runtime.agent_route_runtime")
_logger: structlog.stdlib.BoundLogger = get_logger(
"core.agentscope.runtime.agent_route_runtime"
)
def __init__(
self, *, orchestrator: OrchestratorLike, pipeline: PipelineLike
@@ -144,15 +148,6 @@ class AgentRouteRuntime:
"data": {"stepName": "execution"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.finish",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "execution"},
},
)
await self._emit_stage_text(
thread_id=command.thread_id,
@@ -191,6 +186,15 @@ class AgentRouteRuntime:
task_id=task.task_id,
tool_calls=_task_tool_calls(task),
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.finish",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "execution"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
@@ -294,6 +298,13 @@ class AgentRouteRuntime:
tool_name = tool_call.get("tool_name")
if not isinstance(tool_name, str) or not tool_name:
continue
call_id = f"{run_id}-{task_id}-{index}"
result_payload = _build_tool_result_event_payload(
tool_name=tool_name,
call_id=call_id,
raw_result=tool_call.get("result"),
raw_error=tool_call.get("error"),
)
await self._pipeline.emit(
session_id=thread_id,
event={
@@ -301,18 +312,175 @@ class AgentRouteRuntime:
"threadId": thread_id,
"runId": run_id,
"data": {
"callId": f"{run_id}-{task_id}-{index}",
"messageId": result_payload["messageId"],
"toolCallId": call_id,
"callId": call_id,
"stage": "execution",
"taskId": task_id,
"toolName": tool_name,
"args": tool_call.get("args", {}),
"result": tool_call.get("result"),
"error": tool_call.get("error"),
"args": _sanitize_result(tool_call.get("args", {})),
"result": result_payload["result"],
"error": result_payload["error"],
"ui": result_payload["ui"],
"content": result_payload["content"],
},
},
)
def _build_tool_result_event_payload(
*,
tool_name: str,
call_id: str,
raw_result: Any,
raw_error: Any,
) -> dict[str, Any]:
result = _sanitize_result(_normalize_tool_result(raw_result))
error = _sanitize_error(raw_error)
ui: dict[str, Any] | None = None
direct_ui = result.get("ui")
if isinstance(direct_ui, dict):
ui = direct_ui
elif isinstance(result.get("type"), str) and isinstance(result.get("data"), dict):
ui = result
text_content = _extract_result_text_content(result)
if text_content is None and isinstance(error, str):
text_content = error
if text_content is None:
text_content = f"{tool_name} 执行完成"
return {
"messageId": f"tool-result-{call_id}",
"result": result,
"error": error,
"ui": ui,
"content": text_content,
}
def _normalize_tool_result(raw_result: Any) -> dict[str, Any]:
if isinstance(raw_result, dict):
content = raw_result.get("content")
if isinstance(content, str):
parsed = _try_parse_json_object(content)
if parsed is not None:
return parsed
return raw_result
if isinstance(raw_result, str):
parsed = _try_parse_json_object(raw_result)
if parsed is not None:
return parsed
text = raw_result.strip()
if text:
return {"content": text}
if raw_result is not None:
return {"value": raw_result}
return {}
def _try_parse_json_object(value: str) -> dict[str, Any] | None:
raw = value.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
except ValueError:
return None
if not isinstance(parsed, dict):
return None
return parsed
def _extract_result_text_content(result: dict[str, Any]) -> str | None:
content = result.get("content")
if isinstance(content, str) and content.strip():
return content
data = result.get("data")
if isinstance(data, dict):
message = data.get("message")
if isinstance(message, str) and message.strip():
return message
return None
def _sanitize_error(value: Any) -> str | None:
if isinstance(value, str) and value.strip():
text = " ".join(value.split())
return _redact_sensitive_text(text)[:300]
if isinstance(value, dict):
for key in ("message", "error", "detail"):
item = value.get(key)
if isinstance(item, str) and item.strip():
text = " ".join(item.split())
return _redact_sensitive_text(text)[:300]
return None
def _sanitize_result(value: Any) -> dict[str, Any]:
if not isinstance(value, dict):
return {}
def _is_sensitive_key(key: str) -> bool:
normalized = key.strip().lower().replace("-", "_")
if not normalized:
return False
exact = {
"password",
"token",
"secret",
"api_key",
"apikey",
"credential",
"authorization",
"auth",
}
if normalized in exact:
return True
patterns = (
"password",
"token",
"secret",
"credential",
"api_key",
"apikey",
"authorization",
)
return any(pattern in normalized for pattern in patterns)
def _sanitize_value(item: Any) -> Any:
if isinstance(item, dict):
return _sanitize_result(item)
if isinstance(item, list):
return [_sanitize_value(entry) for entry in item]
if isinstance(item, str):
return _redact_sensitive_text(item)
return item
sanitized: dict[str, Any] = {}
for key, item in value.items():
key_text = str(key)
if _is_sensitive_key(key_text):
sanitized[key_text] = "[REDACTED]"
continue
sanitized[key_text] = _sanitize_value(item)
return sanitized
def _redact_sensitive_text(value: str) -> str:
redacted = value
key_value_patterns = (
r"(?i)(authorization)\s*[:=]\s*bearer\s+[^\s,;]+",
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s*[:=]\s*[^\s,;]+",
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s+[^\s,;]+",
)
for pattern in key_value_patterns:
redacted = re.sub(pattern, r"\1=[REDACTED]", redacted)
redacted = re.sub(r"(?i)bearer\s+[^\s,;]+", "Bearer [REDACTED]", redacted)
return redacted
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {}
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
@@ -52,6 +52,24 @@ def _tools_payload_from_schema(
return payload
def _merge_tool_schemas(
*schema_sets: list[dict[str, object]],
) -> list[dict[str, object]]:
merged: list[dict[str, object]] = []
seen_names: set[str] = set()
for schemas in schema_sets:
for schema in schemas:
function = schema.get("function")
if not isinstance(function, dict):
continue
name = function.get("name")
if not isinstance(name, str) or not name or name in seen_names:
continue
seen_names.add(name)
merged.append(schema)
return merged
class AgentScopeRuntimeOrchestrator:
_runner: Any
_config_loader: Callable[[AsyncSession], Awaitable[dict[str, RuntimeStageConfig]]]
@@ -96,10 +114,20 @@ class AgentScopeRuntimeOrchestrator:
enable_hitl=False,
)
intent_tools_schema = intent_toolkit.get_json_schemas()
execution_toolkit = build_stage_toolkit(
stage="execution",
session=session,
owner_id=owner_id,
user_token=user_token,
enable_hitl=True,
)
execution_tools_schema = execution_toolkit.get_json_schemas()
intent_prompt = build_system_prompt(
stage="intent",
user_context=user_context,
tools=_tools_payload_from_schema(intent_tools_schema),
tools=_tools_payload_from_schema(
_merge_tool_schemas(intent_tools_schema, execution_tools_schema)
),
)
intent_payload = await self._runner.run_json_stage(
stage_config=stage_config["intent"],
@@ -125,14 +153,6 @@ class AgentScopeRuntimeOrchestrator:
execution_output: ExecutionBatchOutput | None = None
if intent_output.route == "TASK_EXECUTION":
execution_toolkit = build_stage_toolkit(
stage="execution",
session=session,
owner_id=owner_id,
user_token=user_token,
enable_hitl=True,
)
execution_tools_schema = execution_toolkit.get_json_schemas()
execution_prompt = build_system_prompt(
stage="execution",
user_context=user_context,
@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import json
import math
from time import perf_counter
from typing import Any, cast
@@ -106,6 +107,9 @@ class AgentScopeReActRunner:
stage_config=stage_config,
response=response,
latency_ms=latency_ms,
system_prompt=system_prompt,
user_prompt=user_prompt,
assistant_text=text_content,
)
except json.JSONDecodeError as exc:
logger.exception(
@@ -234,6 +238,9 @@ def _merge_stage_response_metadata(
stage_config: RuntimeStageConfig,
response: Any,
latency_ms: int,
system_prompt: str,
user_prompt: str | list[dict[str, Any]],
assistant_text: str,
) -> dict[str, Any]:
result = dict(payload)
existing = result.get("response_metadata")
@@ -247,6 +254,15 @@ def _merge_stage_response_metadata(
completion_tokens = _to_non_negative_int(
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
)
if prompt_tokens is None:
prompt_tokens = _estimate_token_count(
{
"system": system_prompt,
"user": user_prompt,
}
)
if completion_tokens is None:
completion_tokens = _estimate_token_count(assistant_text)
cost = _to_non_negative_float(
_read_value(usage, "cost")
or _read_value(_read_value(usage, "metadata"), "cost")
@@ -352,3 +368,16 @@ def _estimate_cost_by_pricing(
)
except ValueError:
return None
def _estimate_token_count(value: object) -> int:
try:
serialized = (
value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
)
except (TypeError, ValueError):
serialized = str(value)
normalized = serialized.strip()
if not normalized:
return 0
return max(1, math.ceil(len(normalized) / 4))
+7 -8
View File
@@ -21,6 +21,7 @@ from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from core.agentscope.tools.tool_result_storage import create_tool_result_storage
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from services.base.redis import get_or_init_redis_client
@@ -67,16 +68,10 @@ def _build_user_context(*, owner_id: UUID, run_input: RunCommand) -> UserAgentCo
def _extract_user_token(
*, command: dict[str, Any], run_input: RunCommand
) -> str | None:
del run_input
raw_token = command.get("user_token")
if isinstance(raw_token, str) and raw_token.strip():
return raw_token.strip()
forwarded = (
run_input.forwarded_props if isinstance(run_input.forwarded_props, dict) else {}
)
for key in ("accessToken", "userToken", "token"):
value = forwarded.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
@@ -162,7 +157,11 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
)
pipeline = AgentScopeEventPipeline(
codec=AgentScopeAgUiCodec(),
store=SqlAlchemyEventStore(session_factory=AsyncSessionLocal),
store=SqlAlchemyEventStore(
session_factory=AsyncSessionLocal,
tool_result_storage=create_tool_result_storage(),
tool_result_bucket=config.storage.bucket,
),
bus=bus,
)
runtime = route_runtime_type(
@@ -67,6 +67,7 @@ def validate_run_request_messages_contract(run_input: RunAgentInput) -> None:
message = run_input.messages[0]
if getattr(message, "role", None) != "user":
raise ValueError("RunAgentInput.messages[0].role must be user")
_validate_user_content_blocks(getattr(message, "content", None))
extract_latest_user_payload(run_input)
@@ -106,84 +107,76 @@ def extract_latest_user_payload(
text_parts.append(text)
blocks.append({"type": "text", "text": text})
continue
if item_type not in {"image", "binary"}:
if item_type != "binary":
continue
source_type: str | None = None
source_value: str | None = None
source_mime: str | None = None
if item_type == "binary":
source_mime = (
item.get("mimeType")
if isinstance(item, dict)
else getattr(item, "mime_type", None)
)
source_url = (
item.get("url")
if isinstance(item, dict)
else getattr(item, "url", None)
)
source_data = (
item.get("data")
if isinstance(item, dict)
else getattr(item, "data", None)
)
if isinstance(source_url, str) and source_url:
source_type = "url"
source_value = source_url
elif isinstance(source_data, str) and source_data:
source_type = "data"
source_value = source_data
else:
source = getattr(item, "source", None)
source_type = (
source.get("type")
if isinstance(source, dict)
else getattr(source, "type", None)
)
source_value = (
source.get("value")
if isinstance(source, dict)
else getattr(source, "value", None)
)
source_mime = (
source.get("mimeType")
if isinstance(source, dict)
else getattr(source, "mimeType", None)
)
if (
source_type == "url"
and isinstance(source_value, str)
and source_value
):
source_url = (
item.get("url")
if isinstance(item, dict)
else getattr(item, "url", None)
)
if isinstance(source_url, str) and source_url:
blocks.append(
{"type": "image_url", "image_url": {"url": source_value}}
)
elif (
source_type == "data"
and isinstance(source_value, str)
and source_value
):
mime_type = (
source_mime
if isinstance(source_mime, str) and source_mime
else "image/png"
)
blocks.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{source_value}"
},
}
{"type": "image_url", "image_url": {"url": source_url}}
)
combined = "".join(text_parts).strip()
if combined:
if combined or blocks:
return combined, blocks
raise ValueError(
"RunAgentInput.messages requires at least one non-empty user message"
)
def _validate_user_content_blocks(content: Any) -> None:
if isinstance(content, str):
if content.strip():
return
raise ValueError(
"RunAgentInput.messages requires at least one non-empty user message"
)
if not isinstance(content, list):
raise ValueError("RunAgentInput.messages[0].content must be string or list")
has_text = False
has_binary = False
for item in content:
item_type = getattr(item, "type", None)
if item_type == "text":
text = getattr(item, "text", None)
if isinstance(text, str) and text.strip():
has_text = True
continue
if item_type == "binary":
mime_type = (
item.get("mimeType")
if isinstance(item, dict)
else getattr(item, "mime_type", None)
)
url = (
item.get("url")
if isinstance(item, dict)
else getattr(item, "url", None)
)
data = (
item.get("data")
if isinstance(item, dict)
else getattr(item, "data", None)
)
if not isinstance(mime_type, str) or not mime_type.startswith("image/"):
raise ValueError("binary content requires image mimeType")
if not isinstance(url, str) or not url:
raise ValueError("binary content requires url")
if isinstance(data, str) and data:
raise ValueError("binary content data is not allowed")
has_binary = True
continue
raise ValueError("unsupported content block type")
if not has_text and not has_binary:
raise ValueError(
"RunAgentInput.messages requires at least one non-empty user message"
)
def extract_latest_tool_result(
run_input: RunAgentInput,
) -> tuple[str, dict[str, object]]:
+3 -2
View File
@@ -3,10 +3,11 @@ from __future__ import annotations
import logging
from logging.config import dictConfig
from pathlib import Path
from typing import cast
import structlog
from core.config.settings import PROJECT_ROOT, RuntimeSettings, Settings
from core.config.settings import PROJECT_ROOT, RuntimeSettings, Settings, config
from core.logging.formatters import (
build_plain_formatter,
build_processor_formatter,
@@ -77,7 +78,7 @@ def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
def configure_logging(settings: Settings | None = None) -> None:
active_settings = settings or Settings()
active_settings = settings if settings is not None else cast(Settings, config)
runtime = active_settings.runtime
try:
+25 -10
View File
@@ -19,19 +19,14 @@ class SupabaseService(BaseServiceProvider):
async def initialize(self, **_: Any) -> bool:
try:
self._client = create_client(
self._settings.url,
self._settings.anon_key,
)
self._admin_client = create_client(
self._settings.url,
self._settings.service_role_key,
)
self._init_clients()
self._set_initialized(True)
self.logger.info("Supabase service initialized")
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Supabase service initialization failed", error=str(exc))
self.logger.warning(
"Supabase service initialization failed", error=str(exc)
)
self._client = None
self._admin_client = None
self._set_initialized(False)
@@ -51,7 +46,9 @@ class SupabaseService(BaseServiceProvider):
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
await asyncio.to_thread(client.auth.get_session)
await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1)
await asyncio.to_thread(
admin_client.auth.admin.list_users, page=1, per_page=1
)
return {
"status": "healthy",
"details": {
@@ -70,17 +67,35 @@ class SupabaseService(BaseServiceProvider):
return self._require_admin_client()
def _require_client(self) -> Any:
if self._client is None or self._admin_client is None:
self._init_clients()
self._set_initialized(True)
self.logger.info("Supabase service lazily initialized")
client = self._client
if client is None:
raise RuntimeError("Supabase client is not initialized")
return client
def _require_admin_client(self) -> Any:
if self._client is None or self._admin_client is None:
self._init_clients()
self._set_initialized(True)
self.logger.info("Supabase service lazily initialized")
admin_client = self._admin_client
if admin_client is None:
raise RuntimeError("Supabase admin client is not initialized")
return admin_client
def _init_clients(self) -> None:
self._client = create_client(
self._settings.url,
self._settings.anon_key,
)
self._admin_client = create_client(
self._settings.url,
self._settings.service_role_key,
)
supabase_service: SupabaseService = register_service_instance(
"supabase", SupabaseService()
+91 -1
View File
@@ -3,11 +3,20 @@ from __future__ import annotations
import asyncio
from typing import Any
from storage3.exceptions import StorageApiError
from core.config.settings import config
from services.base.supabase import supabase_service
class AgentAttachmentStorage:
def _validate_bucket(self, *, bucket: str) -> None:
expected = config.storage.bucket
if bucket != expected:
raise RuntimeError("Invalid attachment bucket")
def _bucket_client(self, *, bucket: str) -> Any:
self._validate_bucket(bucket=bucket)
client = supabase_service.get_admin_client()
storage = getattr(client, "storage", None)
if storage is None:
@@ -39,9 +48,82 @@ class AgentAttachmentStorage:
},
)
await asyncio.to_thread(_upload)
try:
await asyncio.to_thread(_upload)
except Exception as exc: # noqa: BLE001
if not _is_bucket_not_found_error(exc):
raise
await self._ensure_bucket_exists(bucket=bucket)
await asyncio.to_thread(_upload)
return path
async def _ensure_bucket_exists(self, *, bucket: str) -> None:
def _ensure() -> None:
client = supabase_service.get_admin_client()
storage = getattr(client, "storage", None)
if storage is None:
raise RuntimeError("Supabase storage client unavailable")
get_bucket = getattr(storage, "get_bucket", None)
if callable(get_bucket):
try:
get_bucket(bucket)
return
except Exception: # noqa: BLE001
pass
create_bucket = getattr(storage, "create_bucket", None)
if not callable(create_bucket):
raise RuntimeError("Supabase storage create_bucket is unavailable")
try:
create_bucket(bucket, options={"public": False})
except Exception as exc: # noqa: BLE001
message = str(exc).lower()
if "already exists" in message or "duplicate" in message:
return
raise
await asyncio.to_thread(_ensure)
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
def _download() -> object:
bucket_client = self._bucket_client(bucket=bucket)
download = getattr(bucket_client, "download", None)
if not callable(download):
raise RuntimeError("Supabase storage download is unavailable")
return download(path)
raw = await asyncio.to_thread(_download)
if isinstance(raw, bytes):
return raw
if isinstance(raw, bytearray):
return bytes(raw)
if isinstance(raw, memoryview):
return raw.tobytes()
raise RuntimeError("Invalid attachment payload")
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str:
def _create_signed_url() -> object:
bucket_client = self._bucket_client(bucket=bucket)
signer = getattr(bucket_client, "create_signed_url", None)
if not callable(signer):
raise RuntimeError("Supabase storage signed url is unavailable")
return signer(path, expires_in_seconds)
raw = await asyncio.to_thread(_create_signed_url)
if isinstance(raw, str):
return raw
if isinstance(raw, dict):
signed_url = raw.get("signedURL") or raw.get("signedUrl") or raw.get("url")
if isinstance(signed_url, str) and signed_url:
return signed_url
raise RuntimeError("Invalid signed url payload")
def create_attachment_storage() -> AgentAttachmentStorage | None:
try:
@@ -49,3 +131,11 @@ def create_attachment_storage() -> AgentAttachmentStorage | None:
except Exception:
return None
return AgentAttachmentStorage()
def _is_bucket_not_found_error(exc: Exception) -> bool:
if isinstance(exc, StorageApiError):
message = str(exc).lower()
return "bucket" in message and "not found" in message
message = str(exc).lower()
return "bucket" in message and "not found" in message
+138 -16
View File
@@ -9,6 +9,7 @@ from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.config.settings import config
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
@@ -200,6 +201,61 @@ class AgentRepository:
return None
return str(latest_id)
async def get_message_attachment_reference(
self,
*,
session_id: str,
message_id: str,
attachment_index: int,
) -> dict[str, str] | None:
try:
session_uuid = UUID(session_id)
message_uuid = UUID(message_id)
except ValueError as exc:
raise HTTPException(
status_code=422, detail="Invalid message/session id"
) from exc
stmt = (
select(AgentChatMessage)
.where(AgentChatMessage.id == message_uuid)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
)
message = (await self._session.execute(stmt)).scalar_one_or_none()
if message is None:
return None
metadata = (
message.metadata_json if isinstance(message.metadata_json, dict) else {}
)
attachments_raw = metadata.get("attachments")
if not isinstance(attachments_raw, list):
return None
if attachment_index < 0 or attachment_index >= len(attachments_raw):
return None
attachment = attachments_raw[attachment_index]
if not isinstance(attachment, dict):
return None
bucket = attachment.get("bucket")
path = attachment.get("path")
mime_type = attachment.get("mimeType")
if (
not isinstance(bucket, str)
or not bucket
or not isinstance(path, str)
or not path
or not isinstance(mime_type, str)
or not mime_type
):
return None
return {
"bucket": bucket,
"path": path,
"mimeType": mime_type,
}
async def _to_snapshot_message(
self, message: AgentChatMessage
) -> dict[str, object]:
@@ -233,30 +289,65 @@ class AgentRepository:
storage_bucket = metadata.get("storage_bucket")
storage_path = metadata.get("storage_path")
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
try:
hydrated_content = await self._tool_result_storage.read_json(
bucket=storage_bucket,
path=storage_path,
expected_bucket = config.storage.bucket
message_session_id = getattr(message, "session_id", None)
expected_prefix = (
f"tool-results/{message_session_id}/"
if message_session_id is not None
else None
)
tool_call_id = metadata.get("tool_call_id")
is_legacy_path = isinstance(
tool_call_id, str
) and storage_path.endswith(f"/{tool_call_id}.json")
if (
storage_bucket == expected_bucket
and _is_safe_storage_path(storage_path)
and (
(
expected_prefix is not None
and storage_path.startswith(expected_prefix)
)
or (
storage_path.startswith("tool-results/")
and is_legacy_path
)
)
except Exception:
hydrated_content = None
):
try:
hydrated_content = (
await self._tool_result_storage.read_json(
bucket=storage_bucket,
path=storage_path,
)
)
except Exception:
hydrated_content = None
resolved_content = hydrated_content or parsed_content
payload["content"] = message.content
if resolved_content is not None:
result = resolved_content.get("result")
if isinstance(result, dict):
result_content = result.get("content")
if isinstance(result_content, str):
payload["content"] = result_content
ui = resolved_content.get("ui")
if not isinstance(ui, dict):
ui = resolved_content.get("ui_schema")
if isinstance(ui, dict):
payload["ui"] = ui
display_content = resolved_content.get("content")
if isinstance(display_content, str):
if not isinstance(display_content, str):
nested_result = resolved_content.get("result")
if isinstance(nested_result, dict):
nested_content = nested_result.get("content")
if isinstance(nested_content, str):
display_content = nested_content
if (
isinstance(display_content, str)
and display_content.strip()
and (
not payload["content"]
or _looks_like_offloaded_placeholder(str(payload["content"]))
)
):
payload["content"] = display_content
if "content" not in payload:
payload["content"] = message.content
else:
payload["content"] = message.content
metadata = message.metadata_json or {}
@@ -264,7 +355,22 @@ class AgentRepository:
metadata.get("attachments") if isinstance(metadata, dict) else None
)
if isinstance(attachments, list):
rendered = [item for item in attachments if isinstance(item, dict)]
rendered: list[dict[str, object]] = []
for index, item in enumerate(attachments):
if not isinstance(item, dict):
continue
mime_type = item.get("mimeType")
if not isinstance(mime_type, str) or not mime_type:
continue
rendered.append(
{
"mimeType": mime_type,
"previewPath": (
f"/api/v1/agent/runs/{message.session_id}/attachments/"
f"{message.id}/{index}"
),
}
)
if rendered:
payload["attachments"] = rendered
return payload
@@ -279,3 +385,19 @@ def _derive_session_title(content_text: str) -> str | None:
if not normalized:
return None
return normalized[:80]
def _is_safe_storage_path(path: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return True
def _looks_like_offloaded_placeholder(content: str) -> bool:
normalized = content.strip().lower()
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
+121 -2
View File
@@ -10,7 +10,17 @@ import time
from typing import Annotated, Union
from ag_ui.core import RunAgentInput
from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFile
from fastapi import (
APIRouter,
Depends,
File,
Form,
Header,
Query,
Request,
status,
UploadFile,
)
from fastapi import HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
@@ -20,11 +30,18 @@ from core.agentscope.schemas.agui_input import (
parse_run_input,
validate_run_request_messages_contract,
)
from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
from core.auth.models import CurrentUser
from core.config.settings import config
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
from v1.agent.schemas import (
AsrTranscribeResponse,
AttachmentReference,
AttachmentUploadResponse,
TaskAcceptedResponse,
)
from v1.agent.service import AgentService, asr_service
from v1.users.dependencies import get_current_user
@@ -38,6 +55,7 @@ _SSE_SLOT_TTL_SECONDS = 15 * 60
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
_WAV_HEADER_MIN_BYTES = 12
_ALLOWED_AUDIO_CONTENT_TYPES = {
"audio/wav",
@@ -46,6 +64,42 @@ _ALLOWED_AUDIO_CONTENT_TYPES = {
}
def _verified_access_token_for_user(
*,
authorization: str | None,
current_user: CurrentUser,
) -> str | None:
if not isinstance(authorization, str):
return None
normalized = authorization.strip()
if not normalized:
return None
if not normalized.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="Unauthorized")
token = normalized[7:].strip()
if not token:
raise HTTPException(status_code=401, detail="Unauthorized")
jwt_secret = config.supabase.jwt_secret
if jwt_secret is None:
raise HTTPException(status_code=503, detail="Auth verifier unavailable")
verifier = JwtVerifier(
issuer=str(config.supabase.jwt_issuer),
jwt_secret=jwt_secret.get_secret_value(),
jwt_algorithm=config.supabase.jwt_algorithm,
)
try:
payload = verifier.verify(token)
except TokenValidationError as exc:
raise HTTPException(status_code=401, detail="Unauthorized") from exc
subject = payload.get("sub")
if not isinstance(subject, str) or subject != str(current_user.id):
raise HTTPException(status_code=403, detail="Forbidden")
return token
def _looks_like_wav_header(header: bytes) -> bool:
if len(header) < _WAV_HEADER_MIN_BYTES:
return False
@@ -111,6 +165,7 @@ async def enqueue_run(
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
authorization: str | None = Header(default=None, alias="Authorization"),
) -> TaskAcceptedResponse:
try:
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
@@ -120,10 +175,15 @@ async def enqueue_run(
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
user_token = _verified_access_token_for_user(
authorization=authorization,
current_user=current_user,
)
task = await service.enqueue_run(
run_input=request,
current_user=current_user,
user_token=user_token,
)
return TaskAcceptedResponse(
taskId=task.task_id,
@@ -143,6 +203,7 @@ async def enqueue_resume(
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
authorization: str | None = Header(default=None, alias="Authorization"),
) -> TaskAcceptedResponse:
if request.thread_id != thread_id:
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
@@ -154,10 +215,15 @@ async def enqueue_resume(
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
user_token = _verified_access_token_for_user(
authorization=authorization,
current_user=current_user,
)
task = await service.enqueue_resume(
thread_id=thread_id,
run_input=request,
current_user=current_user,
user_token=user_token,
)
return TaskAcceptedResponse(
taskId=task.task_id,
@@ -253,6 +319,31 @@ async def get_history_snapshot(
)
@router.get("/runs/{thread_id}/attachments/{message_id}/{attachment_index}")
async def get_attachment_preview(
thread_id: str,
message_id: str,
attachment_index: int,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> StreamingResponse:
if attachment_index < 0:
raise HTTPException(status_code=422, detail="Invalid attachment index")
payload, mime_type = await service.get_attachment_preview(
thread_id=thread_id,
message_id=message_id,
attachment_index=attachment_index,
current_user=current_user,
)
return StreamingResponse(
iter([payload]),
media_type=mime_type,
headers={
"Cache-Control": "private, max-age=300",
},
)
@router.get("/history")
async def get_user_history_snapshot(
service: Annotated[AgentService, Depends(get_agent_service)],
@@ -267,6 +358,34 @@ async def get_user_history_snapshot(
)
@router.post(
"/attachments",
response_model=AttachmentUploadResponse,
status_code=status.HTTP_200_OK,
)
async def upload_attachment(
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
thread_id: str = Form(alias="threadId"),
file: UploadFile = File(),
) -> AttachmentUploadResponse:
payload = await file.read()
if not payload:
raise HTTPException(status_code=422, detail="Empty attachment")
if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
raise HTTPException(status_code=413, detail="Attachment too large")
attachment = await service.upload_attachment(
thread_id=thread_id,
filename=file.filename,
content_type=file.content_type,
payload=payload,
current_user=current_user,
)
return AttachmentUploadResponse(
attachment=AttachmentReference.model_validate(attachment),
)
@router.post(
"/transcribe",
response_model=AsrTranscribeResponse,
+13
View File
@@ -14,3 +14,16 @@ class TaskAcceptedResponse(BaseModel):
class AsrTranscribeResponse(BaseModel):
transcript: str = Field(description="Transcribed text from audio")
class AttachmentReference(BaseModel):
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
bucket: str
path: str
mime_type: str = Field(alias="mimeType")
url: str
class AttachmentUploadResponse(BaseModel):
attachment: AttachmentReference
+297 -60
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import base64
from dataclasses import dataclass
from datetime import date
import hashlib
@@ -19,17 +18,22 @@ from core.config.settings import config
from core.logging import get_logger
logger = get_logger(__name__)
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
def _extract_user_token_from_run_input(run_input: RunAgentInput) -> str | None:
forwarded = run_input.forwarded_props
if not isinstance(forwarded, dict):
def _normalize_bearer_token(value: str | None) -> str | None:
if not isinstance(value, str):
return None
for key in ("accessToken", "userToken", "token"):
value = forwarded.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
normalized = value.strip()
if not normalized:
return None
lower = normalized.lower()
if lower.startswith("bearer "):
token = normalized[7:].strip()
return token or None
return normalized
@dataclass(frozen=True)
@@ -66,6 +70,14 @@ class AgentRepositoryLike(Protocol):
metadata: dict[str, object] | None,
) -> None: ...
async def get_message_attachment_reference(
self,
*,
session_id: str,
message_id: str,
attachment_index: int,
) -> dict[str, str] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
@@ -92,6 +104,16 @@ class AttachmentStorageLike(Protocol):
content_type: str,
) -> str: ...
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str: ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
if owner_id != str(current_user.id):
@@ -104,6 +126,8 @@ class AgentService:
_stream: EventStreamLike
_attachment_storage: AttachmentStorageLike | None
_SIGNED_URL_EXPIRES_IN_SECONDS = 3600
def __init__(
self,
*,
@@ -122,6 +146,7 @@ class AgentService:
*,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
) -> TaskAccepted:
created = False
thread_id = run_input.thread_id
@@ -161,7 +186,7 @@ class AgentService:
command={
"command": "run",
"owner_id": str(current_user.id),
"user_token": _extract_user_token_from_run_input(run_input),
"user_token": _normalize_bearer_token(user_token),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
@@ -179,57 +204,115 @@ class AgentService:
run_input: RunAgentInput,
current_user: CurrentUser,
) -> tuple[str, dict[str, object] | None]:
text, content_blocks = extract_latest_user_payload(run_input)
text, _ = extract_latest_user_payload(run_input)
content_blocks = _extract_latest_user_content_blocks(run_input)
attachments: list[dict[str, object]] = []
if self._attachment_storage is not None:
for index, block in enumerate(content_blocks):
if not isinstance(block, dict):
continue
if block.get("type") != "image_url":
continue
image_value = block.get("image_url")
if not isinstance(image_value, dict):
continue
url = image_value.get("url")
if not isinstance(url, str) or not url.startswith("data:"):
continue
decoded = _decode_data_url(url)
if decoded is None:
continue
mime_type, payload = decoded
suffix = _mime_to_suffix(mime_type)
checksum = hashlib.sha1(payload).hexdigest()[:16]
path = (
f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
f"{run_input.run_id}/attachment-{index}-{checksum}.{suffix}"
binary_blocks = [
block
for block in content_blocks
if isinstance(block, dict) and block.get("type") == "binary"
]
if binary_blocks:
if self._attachment_storage is None:
raise HTTPException(
status_code=503,
detail="Attachment storage unavailable",
)
bucket_name = config.storage.bucket
forwarded_props = (
run_input.forwarded_props
if isinstance(run_input.forwarded_props, dict)
else {}
)
raw_attachments = forwarded_props.get("attachments")
if not isinstance(raw_attachments, list):
raise HTTPException(
status_code=422, detail="Invalid attachments payload"
)
if len(raw_attachments) != len(binary_blocks):
raise HTTPException(
status_code=422, detail="Invalid attachments payload"
)
total_attachment_bytes = 0
expected_prefix = f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
for index, raw_attachment in enumerate(raw_attachments):
if not isinstance(raw_attachment, dict):
raise HTTPException(
status_code=422,
detail="Invalid attachment reference",
)
bucket = raw_attachment.get("bucket")
path = raw_attachment.get("path")
mime_type = raw_attachment.get("mimeType")
if (
not isinstance(bucket, str)
or not isinstance(path, str)
or not isinstance(mime_type, str)
):
raise HTTPException(
status_code=422,
detail="Invalid attachment reference",
)
if bucket != config.storage.bucket:
raise HTTPException(status_code=403, detail="Forbidden")
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
raise HTTPException(status_code=403, detail="Forbidden")
if mime_type.lower() not in _ALLOWED_ATTACHMENT_MIME_TYPES:
raise HTTPException(
status_code=422,
detail="Unsupported attachment type",
)
binary_block = binary_blocks[index]
binary_mime = binary_block.get("mimeType")
binary_url = binary_block.get("url")
if (
not isinstance(binary_mime, str)
or binary_mime != mime_type
or not isinstance(binary_url, str)
or not binary_url
):
raise HTTPException(
status_code=422,
detail="Invalid attachments payload",
)
try:
stored_path = await self._attachment_storage.upload_bytes(
bucket=bucket_name,
payload = await self._attachment_storage.download_bytes(
bucket=bucket,
path=path,
content=payload,
content_type=mime_type,
)
except Exception: # noqa: BLE001
logger.exception(
"Attachment upload failed",
"Attachment validation download failed",
extra={
"bucket": bucket_name,
"bucket": bucket,
"path": path,
"mime_type": mime_type,
"thread_id": run_input.thread_id,
"run_id": run_input.run_id,
},
)
raise HTTPException(
status_code=502,
detail="Failed to upload attachment",
detail="Failed to fetch attachment",
)
payload_size = len(payload)
if payload_size > _MAX_ATTACHMENT_BYTES:
raise HTTPException(
status_code=413,
detail="Attachment too large",
)
total_attachment_bytes += payload_size
if total_attachment_bytes > _MAX_TOTAL_ATTACHMENT_BYTES:
raise HTTPException(
status_code=413,
detail="Attachments too large",
)
attachments.append(
{
"bucket": bucket_name,
"path": stored_path,
"bucket": bucket,
"path": path,
"mimeType": mime_type,
}
)
@@ -238,12 +321,94 @@ class AgentService:
metadata["attachments"] = attachments
return text, metadata or None
async def upload_attachment(
self,
*,
thread_id: str,
filename: str | None,
content_type: str | None,
payload: bytes,
current_user: CurrentUser,
) -> dict[str, str]:
try:
owner = await self._repository.get_session_owner(session_id=thread_id)
except HTTPException as exc:
if exc.status_code != 404:
raise
try:
await self._repository.create_session_for_user(
user_id=str(current_user.id),
session_id=thread_id,
)
await self._repository.commit()
except IntegrityError:
await self._repository.rollback()
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
else:
ensure_session_owner(owner_id=owner, current_user=current_user)
if self._attachment_storage is None:
raise HTTPException(
status_code=503, detail="Attachment storage unavailable"
)
if not isinstance(content_type, str):
raise HTTPException(status_code=422, detail="Unsupported attachment type")
mime_type = content_type.lower()
if mime_type not in _ALLOWED_ATTACHMENT_MIME_TYPES:
raise HTTPException(status_code=422, detail="Unsupported attachment type")
if not payload:
raise HTTPException(status_code=422, detail="Empty attachment")
if len(payload) > _MAX_ATTACHMENT_BYTES:
raise HTTPException(status_code=413, detail="Attachment too large")
suffix = _mime_to_suffix(mime_type)
checksum = hashlib.sha1(payload).hexdigest()[:16]
filename_seed = filename if isinstance(filename, str) and filename else "upload"
filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
path = (
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
f"{filename_hash}-{checksum}.{suffix}"
)
bucket_name = config.storage.bucket
try:
stored_path = await self._attachment_storage.upload_bytes(
bucket=bucket_name,
path=path,
content=payload,
content_type=mime_type,
)
signed_url = await self._attachment_storage.create_signed_url(
bucket=bucket_name,
path=stored_path,
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
)
except Exception: # noqa: BLE001
logger.exception(
"Attachment upload failed",
extra={
"bucket": bucket_name,
"path": path,
"mime_type": mime_type,
"thread_id": thread_id,
},
)
raise HTTPException(status_code=502, detail="Failed to upload attachment")
return {
"bucket": bucket_name,
"path": stored_path,
"mimeType": mime_type,
"url": signed_url,
}
async def enqueue_resume(
self,
*,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
) -> TaskAccepted:
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
@@ -253,7 +418,7 @@ class AgentService:
command={
"command": "resume",
"owner_id": str(current_user.id),
"user_token": _extract_user_token_from_run_input(run_input),
"user_token": _normalize_bearer_token(user_token),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=dedup_key,
@@ -336,6 +501,63 @@ class AgentService:
current_user=current_user,
)
async def get_attachment_preview(
self,
*,
thread_id: str,
message_id: str,
attachment_index: int,
current_user: CurrentUser,
) -> tuple[bytes, str]:
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
if self._attachment_storage is None:
raise HTTPException(
status_code=503, detail="Attachment storage unavailable"
)
ref = await self._repository.get_message_attachment_reference(
session_id=thread_id,
message_id=message_id,
attachment_index=attachment_index,
)
if ref is None:
raise HTTPException(status_code=404, detail="Attachment not found")
bucket = ref.get("bucket")
path = ref.get("path")
mime_type = ref.get("mimeType")
if (
not isinstance(bucket, str)
or not isinstance(path, str)
or not isinstance(mime_type, str)
):
raise HTTPException(status_code=404, detail="Attachment not found")
if bucket != config.storage.bucket:
raise HTTPException(status_code=403, detail="Forbidden")
expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/"
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
raise HTTPException(status_code=403, detail="Forbidden")
try:
payload = await self._attachment_storage.download_bytes(
bucket=bucket,
path=path,
)
except Exception: # noqa: BLE001
logger.exception(
"Attachment download failed",
extra={
"thread_id": thread_id,
"message_id": message_id,
"attachment_index": attachment_index,
"bucket": bucket,
},
)
raise HTTPException(status_code=502, detail="Failed to fetch attachment")
return payload, mime_type
class AsrService:
def __init__(self) -> None:
@@ -445,22 +667,26 @@ class AsrService:
asr_service = AsrService()
def _decode_data_url(data_url: str) -> tuple[str, bytes] | None:
if not data_url.startswith("data:"):
return None
header, sep, payload = data_url.partition(",")
if not sep:
return None
mime_type = "image/png"
if ";" in header:
maybe_mime = header[5:].split(";", 1)[0].strip()
if maybe_mime:
mime_type = maybe_mime
try:
decoded = base64.b64decode(payload, validate=True)
except ValueError:
return None
return mime_type, decoded
def _extract_latest_user_content_blocks(
run_input: RunAgentInput,
) -> list[dict[str, Any]]:
if not run_input.messages:
return []
latest = run_input.messages[-1]
content = getattr(latest, "content", None)
if not isinstance(content, list):
return []
blocks: list[dict[str, Any]] = []
for item in content:
if isinstance(item, dict):
blocks.append(item)
continue
model_dump = getattr(item, "model_dump", None)
if callable(model_dump):
dumped = model_dump(mode="json", by_alias=True, exclude_none=True)
if isinstance(dumped, dict):
blocks.append(dumped)
return blocks
def _mime_to_suffix(mime_type: str) -> str:
@@ -470,3 +696,14 @@ def _mime_to_suffix(mime_type: str) -> str:
"image/webp": "webp",
}
return mapping.get(mime_type.lower(), "bin")
def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return normalized.startswith(expected_prefix)
@@ -1,16 +1,13 @@
from __future__ import annotations
from typing import Callable
from uuid import UUID
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.auth.dependencies import get_auth_service
from v1.users.dependencies import get_current_user
from v1.auth.rate_limit import reset_rate_limit_state
from v1.auth.schemas import (
AuthUser,
@@ -18,8 +18,14 @@ class _FakeAgentService:
def __init__(self) -> None:
self._stream_called = False
async def enqueue_run(self, *, run_input: RunAgentInput, current_user: CurrentUser):
del current_user
async def enqueue_run(
self,
*,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
):
del current_user, user_token
return SimpleNamespace(
task_id="task-run-1",
thread_id=run_input.thread_id,
@@ -33,8 +39,9 @@ class _FakeAgentService:
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
):
del thread_id, current_user
del thread_id, current_user, user_token
return SimpleNamespace(
task_id="task-resume-1",
thread_id=run_input.thread_id,
@@ -109,6 +116,23 @@ class _FakeAgentService:
},
}
async def upload_attachment(
self,
*,
thread_id: str,
filename: str | None,
content_type: str | None,
payload: bytes,
current_user: CurrentUser,
) -> dict[str, str]:
del filename, content_type, payload, current_user
return {
"bucket": "bucket-test",
"path": f"agent-inputs/user/{thread_id}/upload.png",
"mimeType": "image/png",
"url": "https://signed.example/upload.png",
}
class _FailingStreamAgentService(_FakeAgentService):
async def stream_events(
@@ -393,6 +417,31 @@ def test_resume_accepts_tool_message_without_user_message() -> None:
app.dependency_overrides = {}
def test_upload_attachment_returns_reference() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
)
client = TestClient(app)
file_payload = BytesIO(b"png")
file_payload.name = "demo.png"
try:
response = client.post(
"/api/v1/agent/attachments",
data={"threadId": "00000000-0000-0000-0000-000000000001"},
files={"file": ("demo.png", file_payload, "image/png")},
)
assert response.status_code == 200
body = response.json()
attachment = body["attachment"]
assert attachment["mimeType"] == "image/png"
assert "00000000-0000-0000-0000-000000000001" in attachment["path"]
finally:
app.dependency_overrides = {}
def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
@@ -40,3 +40,34 @@ def test_reserved_keys_in_data_cannot_override_wire_fields() -> None:
assert result["threadId"] == "thread-1"
assert result["runId"] == "run-1"
assert result["message"] == "ok"
def test_tool_result_wire_event_filters_sensitive_fields() -> None:
internal = {
"type": "tool.result",
"threadId": "thread-1",
"runId": "run-1",
"data": {
"messageId": "tool-result-1",
"toolCallId": "call-1",
"callId": "call-1",
"toolName": "calendar_write",
"content": "summary",
"ui": {"type": "calendar_operation.v1", "data": {"ok": True}},
"args": {"token": "secret"},
"result": {"raw": "secret"},
"error": "stack trace",
},
}
result = to_agui_wire_event(internal)
assert result["type"] == "TOOL_CALL_RESULT"
assert result["messageId"] == "tool-result-1"
assert result["toolCallId"] == "call-1"
assert result["toolName"] == "calendar_write"
assert result["content"] == "summary"
assert isinstance(result.get("ui"), dict)
assert "args" not in result
assert "result" not in result
assert "error" not in result
@@ -28,6 +28,27 @@ class _FakeSessionCtx:
del exc_type, exc, tb
class _FakeToolResultStorage:
def __init__(self) -> None:
self.upload_calls: list[dict[str, object]] = []
async def upload_json(
self,
*,
bucket: str,
path: str,
payload: dict[str, object],
) -> str:
self.upload_calls.append(
{
"bucket": bucket,
"path": path,
"payload": payload,
}
)
return path
@pytest.mark.asyncio
async def test_store_marks_session_running_on_run_started(
monkeypatch: pytest.MonkeyPatch,
@@ -300,7 +321,12 @@ async def test_store_persists_tool_call_result_as_tool_message(
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
fake_storage = _FakeToolResultStorage()
store = store_module.SqlAlchemyEventStore(
session_factory=lambda: _FakeSessionCtx(),
tool_result_storage=fake_storage,
tool_result_bucket="agent-tool-results",
)
await store.persist(
{
"type": "TOOL_CALL_RESULT",
@@ -310,7 +336,7 @@ async def test_store_persists_tool_call_result_as_tool_message(
"taskId": "t1",
"stage": "execution",
"args": {"title": "A"},
"result": {"event_id": "evt-1"},
"result": {"event_id": "evt-1", "token": "secret"},
}
)
@@ -318,9 +344,94 @@ async def test_store_persists_tool_call_result_as_tool_message(
assert getattr(append_kwargs["role"], "value", None) == "tool"
assert append_kwargs["tool_name"] == "calendar_write"
assert append_kwargs["metadata"]["task_id"] == "t1"
tool_call_id = append_kwargs["metadata"]["tool_call_id"]
assert isinstance(tool_call_id, str)
assert tool_call_id.startswith("run-1-t1-")
assert append_kwargs["metadata"]["storage_bucket"] == "agent-tool-results"
assert isinstance(append_kwargs["metadata"]["storage_path"], str)
assert append_kwargs["content"].startswith("已创建日程")
assert len(fake_storage.upload_calls) == 1
uploaded = fake_storage.upload_calls[0]
assert uploaded["bucket"] == "agent-tool-results"
payload = cast(dict[str, Any], uploaded["payload"])
assert payload["toolName"] == "calendar_write"
assert "args" not in payload
assert isinstance(payload.get("result"), dict)
assert payload["result"]["token"] == "[REDACTED]"
assert captured["message_delta"] == 1
@pytest.mark.asyncio
async def test_store_sanitizes_nested_sensitive_fields_in_result_payload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, object] = {}
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=0)
class _FakeSessionRepository:
def __init__(self, session: object) -> None:
del session
async def get_session(self, *, session_id): # noqa: ANN001
del session_id
return fake_chat_session
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
del session_id
return fake_chat_session
async def update_runtime_state(self, **kwargs): # noqa: ANN003
captured.update(kwargs)
class _FakeMessageRepository:
def __init__(self, session: object) -> None:
del session
async def append_message(self, **kwargs): # noqa: ANN003
captured["append_kwargs"] = kwargs
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
fake_storage = _FakeToolResultStorage()
store = store_module.SqlAlchemyEventStore(
session_factory=lambda: _FakeSessionCtx(),
tool_result_storage=fake_storage,
tool_result_bucket="agent-tool-results",
)
await store.persist(
{
"type": "TOOL_CALL_RESULT",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
"toolName": "calendar_write",
"result": {
"data": {
"ok": True,
"accessToken": "secret-a",
"nested": {
"refresh_token": "secret-b",
},
"items": [
{"authorizationHeader": "secret-c"},
],
}
},
}
)
payload = cast(dict[str, Any], fake_storage.upload_calls[0]["payload"])
stored_result = cast(dict[str, Any], payload["result"])
data = cast(dict[str, Any], stored_result["data"])
assert data["accessToken"] == "[REDACTED]"
nested = cast(dict[str, Any], data["nested"])
assert nested["refresh_token"] == "[REDACTED]"
items = cast(list[Any], data["items"])
assert isinstance(items[0], dict)
assert items[0]["authorizationHeader"] == "[REDACTED]"
@pytest.mark.asyncio
async def test_store_drops_buffer_when_session_missing(
monkeypatch: pytest.MonkeyPatch,
@@ -0,0 +1,73 @@
from __future__ import annotations
from core.agentscope.events.tool_result_summary import build_tool_content_summary
def test_summary_prioritizes_error() -> None:
text = build_tool_content_summary(
tool_name="calendar_write",
args={"title": "A"},
result={"message": "ignored"},
error={"message": "denied"},
)
assert text == "calendar_write 执行失败:denied"
def test_summary_for_calendar_write() -> None:
text = build_tool_content_summary(
tool_name="calendar_write",
args={"title": "项目评审"},
result={"startAt": "明天 10:00"},
error=None,
)
assert text == "已创建日程:项目评审(明天 10:00)"
def test_summary_for_calendar_read() -> None:
text = build_tool_content_summary(
tool_name="calendar_read",
args={"query": "今天"},
result={"data": {"total": 3}},
error=None,
)
assert text == "查询到 3 条日程(今天)"
def test_summary_falls_back_to_result_content() -> None:
text = build_tool_content_summary(
tool_name="unknown_tool",
args=None,
result={"content": "这是非常长的说明" * 20},
error=None,
)
assert text.startswith("这是非常长的说明")
assert len(text) <= 80
def test_summary_default_done() -> None:
text = build_tool_content_summary(
tool_name="unknown_tool",
args=None,
result=None,
error=None,
)
assert text == "unknown_tool 执行完成"
def test_summary_marks_business_failure_when_ok_false() -> None:
text = build_tool_content_summary(
tool_name="calendar_write",
args={"title": "上学"},
result={
"type": "calendar_operation.v1",
"data": {
"ok": False,
"code": "UNAUTHORIZED",
"message": "calendar.write requires validated user token",
},
},
error=None,
)
assert (
text == "calendar_write 执行失败:calendar.write requires validated user token"
)
@@ -109,7 +109,6 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
"step.start",
"step.finish",
"step.start",
"step.finish",
"text.start",
"text.delta",
"text.end",
@@ -117,6 +116,7 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
"text.delta",
"text.end",
"tool.result",
"step.finish",
"step.start",
"text.start",
"text.delta",
@@ -127,10 +127,14 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
assert calls[1]["data"]["stepName"] == "intent"
assert calls[2]["data"]["stepName"] == "intent"
assert calls[3]["data"]["stepName"] == "execution"
assert calls[4]["data"]["stepName"] == "execution"
assert calls[5]["data"]["stage"] == "intent"
assert calls[8]["data"]["stage"] == "execution"
assert calls[11]["data"]["toolName"] == "calendar_write"
assert calls[4]["data"]["stage"] == "intent"
assert calls[7]["data"]["stage"] == "execution"
assert calls[10]["data"]["toolName"] == "calendar_write"
assert calls[10]["data"]["toolCallId"] == "run-1-t1-1"
assert calls[10]["data"]["messageId"] == "tool-result-run-1-t1-1"
tool_content = calls[10]["data"]["content"]
assert tool_content == "calendar_write 执行完成"
assert calls[11]["data"]["stepName"] == "execution"
assert calls[12]["data"]["stepName"] == "report"
assert calls[14]["data"]["delta"] == "hello world"
assert calls[13]["data"]["messageId"] == calls[14]["data"]["messageId"]
@@ -305,3 +309,300 @@ async def test_runtime_direct_response_finishes_without_report_stage() -> None:
]
assert calls[3]["data"]["stage"] == "intent"
assert calls[4]["data"]["delta"] == "direct-answer"
@pytest.mark.asyncio
async def test_runtime_tool_result_parses_json_string_ui_payload() -> None:
calls: list[dict[str, Any]] = []
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
calls.append(event)
return f"{len(calls)}-0"
class _FakeOrchestrator:
async def run(self, **_: object) -> RuntimeOutput:
return RuntimeOutput(
intent=IntentOutput(
route="TASK_EXECUTION",
intent_summary="summary",
direct_response=None,
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
complexity="complex",
response_metadata={},
),
execution=ExecutionBatchOutput(
task_results=[
ExecutionTaskOutput(
task_id="t1",
status="SUCCESS",
execution_summary="execution-ok",
execution_data={},
user_feedback_needs=[],
response_metadata={},
tool_calls=[
ExecutionToolCall(
tool_name="calendar_write",
args={"title": "A"},
result='{"type":"calendar_card.v1","version":"v1","data":{"ok":true,"title":"A"},"actions":[]}',
)
],
)
],
overall_status="SUCCESS",
aggregate_summary="ok",
),
report=ReportOutput(
assistant_text="hello world",
response_metadata={},
),
)
runtime = AgentRouteRuntime(
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
)
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
tool_events = [item for item in calls if item.get("type") == "tool.result"]
assert len(tool_events) == 1
data = tool_events[0]["data"]
assert isinstance(data, dict)
assert isinstance(data.get("ui"), dict)
assert data["ui"]["type"] == "calendar_card.v1"
@pytest.mark.asyncio
async def test_runtime_tool_result_keeps_plain_text_content() -> None:
calls: list[dict[str, Any]] = []
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
calls.append(event)
return f"{len(calls)}-0"
class _FakeOrchestrator:
async def run(self, **_: object) -> RuntimeOutput:
return RuntimeOutput(
intent=IntentOutput(
route="TASK_EXECUTION",
intent_summary="summary",
direct_response=None,
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
complexity="complex",
response_metadata={},
),
execution=ExecutionBatchOutput(
task_results=[
ExecutionTaskOutput(
task_id="t1",
status="SUCCESS",
execution_summary="execution-ok",
execution_data={},
user_feedback_needs=[],
response_metadata={},
tool_calls=[
ExecutionToolCall(
tool_name="calendar_write",
args={"title": "A"},
result="created successfully",
)
],
)
],
overall_status="SUCCESS",
aggregate_summary="ok",
),
report=ReportOutput(
assistant_text="hello world",
response_metadata={},
),
)
runtime = AgentRouteRuntime(
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
)
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
tool_events = [item for item in calls if item.get("type") == "tool.result"]
assert len(tool_events) == 1
data = tool_events[0]["data"]
assert isinstance(data, dict)
assert data["content"] == "created successfully"
@pytest.mark.asyncio
async def test_runtime_tool_result_sanitizes_sensitive_payload() -> None:
calls: list[dict[str, Any]] = []
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
calls.append(event)
return f"{len(calls)}-0"
class _FakeOrchestrator:
async def run(self, **_: object) -> RuntimeOutput:
return RuntimeOutput(
intent=IntentOutput(
route="TASK_EXECUTION",
intent_summary="summary",
direct_response=None,
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
complexity="complex",
response_metadata={},
),
execution=ExecutionBatchOutput(
task_results=[
ExecutionTaskOutput(
task_id="t1",
status="SUCCESS",
execution_summary="execution-ok",
execution_data={},
user_feedback_needs=[],
response_metadata={},
tool_calls=[
ExecutionToolCall(
tool_name="calendar_write",
args={
"title": "A",
"accessToken": "arg-secret",
"author": "alice",
},
result={
"ok": True,
"accessToken": "secret-token",
"message": "Authorization: Bearer inline-token",
"nested": [
{
"authorizationHeader": "Bearer abc",
}
],
},
error="failed authorization=Bearer abc123 detail",
)
],
)
],
overall_status="SUCCESS",
aggregate_summary="ok",
),
report=ReportOutput(
assistant_text="hello world",
response_metadata={},
),
)
runtime = AgentRouteRuntime(
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
)
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
tool_events = [item for item in calls if item.get("type") == "tool.result"]
assert len(tool_events) == 1
data = tool_events[0]["data"]
assert isinstance(data, dict)
assert isinstance(data["result"], dict)
assert data["result"]["accessToken"] == "[REDACTED]"
assert data["result"]["message"] == "Authorization=[REDACTED]"
nested = data["result"]["nested"]
assert isinstance(nested, list)
assert nested[0]["authorizationHeader"] == "[REDACTED]"
assert isinstance(data["args"], dict)
assert data["args"]["accessToken"] == "[REDACTED]"
assert data["args"]["author"] == "alice"
assert data["error"] == "failed authorization=[REDACTED] detail"
@pytest.mark.asyncio
async def test_runtime_tool_result_keeps_non_object_result() -> None:
calls: list[dict[str, Any]] = []
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
calls.append(event)
return f"{len(calls)}-0"
class _FakeOrchestrator:
async def run(self, **_: object) -> RuntimeOutput:
return RuntimeOutput(
intent=IntentOutput(
route="TASK_EXECUTION",
intent_summary="summary",
direct_response=None,
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
complexity="complex",
response_metadata={},
),
execution=ExecutionBatchOutput(
task_results=[
ExecutionTaskOutput(
task_id="t1",
status="SUCCESS",
execution_summary="execution-ok",
execution_data={},
user_feedback_needs=[],
response_metadata={},
tool_calls=[
ExecutionToolCall(
tool_name="calendar_write",
args={"title": "A"},
result=["evt-1", "evt-2"],
)
],
)
],
overall_status="SUCCESS",
aggregate_summary="ok",
),
report=ReportOutput(
assistant_text="hello world",
response_metadata={},
),
)
runtime = AgentRouteRuntime(
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
)
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
tool_events = [item for item in calls if item.get("type") == "tool.result"]
assert len(tool_events) == 1
data = tool_events[0]["data"]
assert isinstance(data, dict)
assert isinstance(data["result"], dict)
assert data["result"]["value"] == ["evt-1", "evt-2"]
@@ -212,6 +212,9 @@ def test_merge_stage_response_metadata_estimates_cost_from_pricing(
model="qwen3.5-flash",
),
latency_ms=50,
system_prompt="system",
user_prompt="user",
assistant_text='{"route":"DIRECT_RESPONSE"}',
)
metadata = payload["response_metadata"]
@@ -50,6 +50,10 @@ async def test_run_agentscope_task_calls_runtime_run(
async def _fake_get_redis_client() -> object:
return object()
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return []
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
monkeypatch.setattr(
tasks_module,
@@ -60,7 +64,7 @@ async def test_run_agentscope_task_calls_runtime_run(
monkeypatch.setattr(
tasks_module,
"_build_recent_context_messages",
lambda **_: [],
_empty_context,
)
result = await tasks_module.run_agentscope_task(
@@ -101,6 +105,10 @@ async def test_run_agentscope_task_includes_recent_context_messages(
async def _fake_get_redis_client() -> object:
return object()
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return []
async def _fake_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return [{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}]
@@ -115,7 +123,7 @@ async def test_run_agentscope_task_includes_recent_context_messages(
monkeypatch.setattr(
tasks_module,
"_build_recent_context_messages",
lambda **_: [],
_empty_context,
)
monkeypatch.setattr(
tasks_module,
@@ -94,3 +94,46 @@ def test_validate_run_request_messages_contract_requires_single_user_message() -
match="RunAgentInput.messages must contain exactly one user message",
):
validate_run_request_messages_contract(run_input)
def test_validate_run_request_messages_contract_accepts_binary_url_blocks() -> None:
payload = _base_payload()
payload["messages"] = [
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "请分析"},
{
"type": "binary",
"mimeType": "image/png",
"url": "https://signed.example/a.png",
},
],
}
]
run_input = parse_run_input(payload)
validate_run_request_messages_contract(run_input)
def test_validate_run_request_messages_contract_rejects_binary_data_block() -> None:
payload = _base_payload()
payload["messages"] = [
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "请分析"},
{
"type": "binary",
"mimeType": "image/png",
"data": "aGVsbG8=",
},
],
}
]
run_input = parse_run_input(payload)
with pytest.raises(ValueError, match="binary content requires url"):
validate_run_request_messages_contract(run_input)
@@ -54,3 +54,20 @@ def test_build_intent_user_prompt_filters_non_image_binary_block() -> None:
assert isinstance(prompt, list)
image_blocks = [item for item in prompt if item.get("type") == "image"]
assert image_blocks == []
def test_build_intent_user_prompt_includes_previous_context_messages() -> None:
prompt = build_intent_user_prompt(
user_input=[
{"id": "u1", "role": "user", "content": "我的口令是蓝鲸42"},
{"id": "a1", "role": "assistant", "content": "已记住"},
{"id": "u2", "role": "user", "content": "请重复口令"},
]
)
assert isinstance(prompt, list)
assert prompt
instruction = prompt[0].get("text", "")
assert isinstance(instruction, str)
assert "[Conversation Context]" in instruction
assert "\\u84dd\\u9cb842" in instruction
@@ -67,10 +67,8 @@ async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None:
assert await service.initialize() is True
assert await service.close() is True
assert service.is_initialized is False
with pytest.raises(RuntimeError):
service.get_client()
with pytest.raises(RuntimeError):
service.get_admin_client()
assert service.get_client() is not None
assert service.get_admin_client() is not None
@pytest.mark.asyncio
@@ -117,7 +115,47 @@ def test_get_client_raises_before_init() -> None:
settings=SupabaseSettings(public_url="https://test.supabase.co")
)
assert service.get_client() is not None
assert service.get_admin_client() is not None
def test_get_client_raises_when_lazy_initialization_fails(
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = SupabaseService(
settings=SupabaseSettings(public_url="https://test.supabase.co")
)
def _fake_create_client(_: str, __: str) -> object:
raise RuntimeError("boom")
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
with pytest.raises(RuntimeError):
service.get_client()
with pytest.raises(RuntimeError):
service.get_admin_client()
def test_get_admin_client_lazily_initializes_clients(
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = SupabaseService(
settings=SupabaseSettings(public_url="https://test.supabase.co")
)
anon_client = MagicMock(name="anon")
admin_client = MagicMock(name="admin")
create_calls: list[tuple[str, str]] = []
def _fake_create_client(url: str, key: str) -> object:
create_calls.append((url, key))
return anon_client if len(create_calls) == 1 else admin_client
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
resolved_admin = service.get_admin_client()
assert resolved_admin is admin_client
assert service.get_client() is anon_client
assert service.is_initialized is True
assert len(create_calls) == 2
@@ -0,0 +1,85 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
import v1.agent.attachment_storage as attachment_storage_module
class _FakeBucket:
def __init__(self) -> None:
self.upload_calls: list[tuple[str, bytes, dict[str, str]]] = []
self.download_calls: list[str] = []
def upload(self, path: str, content: bytes, options: dict[str, str]) -> object:
self.upload_calls.append((path, content, options))
return {"path": path}
def download(self, path: str) -> object:
self.download_calls.append(path)
return b"ok"
class _FakeStorage:
def __init__(self, bucket: _FakeBucket) -> None:
self._bucket = bucket
def from_(self, bucket: str) -> object:
del bucket
return self._bucket
@pytest.mark.asyncio
async def test_attachment_storage_rejects_unexpected_bucket(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = attachment_storage_module.AgentAttachmentStorage()
monkeypatch.setattr(
attachment_storage_module.config.storage,
"bucket",
"allowed-bucket",
)
with pytest.raises(RuntimeError, match="Invalid attachment bucket"):
await storage.upload_bytes(
bucket="other-bucket",
path="agent-inputs/u/t/r/file.png",
content=b"data",
content_type="image/png",
)
@pytest.mark.asyncio
async def test_attachment_storage_accepts_configured_bucket(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = attachment_storage_module.AgentAttachmentStorage()
fake_bucket = _FakeBucket()
fake_client = SimpleNamespace(storage=_FakeStorage(fake_bucket))
monkeypatch.setattr(
attachment_storage_module.config.storage,
"bucket",
"allowed-bucket",
)
monkeypatch.setattr(
attachment_storage_module.supabase_service,
"get_admin_client",
lambda: fake_client,
)
path = await storage.upload_bytes(
bucket="allowed-bucket",
path="agent-inputs/u/t/r/file.png",
content=b"data",
content_type="image/png",
)
payload = await storage.download_bytes(
bucket="allowed-bucket",
path=path,
)
assert path == "agent-inputs/u/t/r/file.png"
assert payload == b"ok"
assert len(fake_bucket.upload_calls) == 1
assert fake_bucket.download_calls == ["agent-inputs/u/t/r/file.png"]
+182 -9
View File
@@ -6,6 +6,7 @@ from uuid import uuid4
import pytest
from core.config.settings import config
from models.agent_chat_message import AgentChatMessageRole
from v1.agent.repository import AgentRepository
@@ -62,7 +63,7 @@ async def test_tool_message_hydrates_content_from_object_storage() -> None:
content='{"offloaded":true}',
metadata_json={
"tool_call_id": "call-1",
"storage_bucket": "private",
"storage_bucket": config.storage.bucket,
"storage_path": "tool-results/run-1/call-1.json",
},
)
@@ -73,6 +74,43 @@ async def test_tool_message_hydrates_content_from_object_storage() -> None:
assert payload["content"] == "已跳转"
@pytest.mark.asyncio
async def test_tool_message_hydrates_ui_from_ui_schema_field() -> None:
repository = AgentRepository(
session=SimpleNamespace(), # type: ignore[arg-type]
tool_result_storage=_FakeToolResultStorage(
{
"toolName": "calendar_write",
"ui_schema": {
"type": "calendar_operation.v1",
"version": "v1",
"data": {"ok": True, "operation": "create"},
"actions": [],
},
}
),
)
message = SimpleNamespace(
id=uuid4(),
role=AgentChatMessageRole.TOOL,
created_at=datetime.now(timezone.utc),
content="已创建日程:项目评审(明天 10:00)",
metadata_json={
"tool_call_id": "call-3",
"storage_bucket": config.storage.bucket,
"storage_path": "tool-results/run-1/call-3.json",
},
)
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
assert payload["toolCallId"] == "call-3"
assert payload["content"] == "已创建日程:项目评审(明天 10:00)"
ui = payload.get("ui")
assert isinstance(ui, dict)
assert ui["type"] == "calendar_operation.v1"
@pytest.mark.asyncio
async def test_tool_message_keeps_inline_content_when_storage_payload_missing() -> None:
repository = AgentRepository(
@@ -86,7 +124,7 @@ async def test_tool_message_keeps_inline_content_when_storage_payload_missing()
content="inline-tool-content",
metadata_json={
"tool_call_id": "call-2",
"storage_bucket": "private",
"storage_bucket": config.storage.bucket,
"storage_path": "tool-results/run-1/call-2.json",
},
)
@@ -97,6 +135,111 @@ async def test_tool_message_keeps_inline_content_when_storage_payload_missing()
assert payload["content"] == "inline-tool-content"
@pytest.mark.asyncio
async def test_tool_message_skips_storage_when_path_not_matching_session() -> None:
repository = AgentRepository(
session=SimpleNamespace(), # type: ignore[arg-type]
tool_result_storage=_FakeToolResultStorage(
{
"ui_schema": {
"type": "calendar_operation.v1",
"version": "v1",
"data": {"ok": True},
"actions": [],
}
}
),
)
message = SimpleNamespace(
id=uuid4(),
session_id=uuid4(),
role=AgentChatMessageRole.TOOL,
created_at=datetime.now(timezone.utc),
content="summary",
metadata_json={
"tool_call_id": "call-x",
"storage_bucket": config.storage.bucket,
"storage_path": "tool-results/foreign-session/call-y.json",
},
)
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
assert payload["content"] == "summary"
assert "ui" not in payload
@pytest.mark.asyncio
async def test_tool_message_rejects_path_traversal() -> None:
repository = AgentRepository(
session=SimpleNamespace(), # type: ignore[arg-type]
tool_result_storage=_FakeToolResultStorage(
{
"ui_schema": {
"type": "calendar_operation.v1",
"version": "v1",
"data": {"ok": True},
"actions": [],
}
}
),
)
message = SimpleNamespace(
id=uuid4(),
session_id=uuid4(),
role=AgentChatMessageRole.TOOL,
created_at=datetime.now(timezone.utc),
content="summary",
metadata_json={
"tool_call_id": "call-z",
"storage_bucket": config.storage.bucket,
"storage_path": "tool-results/ok/../../evil/call-z.json",
},
)
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
assert payload["content"] == "summary"
assert "ui" not in payload
@pytest.mark.asyncio
async def test_tool_message_supports_legacy_storage_path() -> None:
repository = AgentRepository(
session=SimpleNamespace(), # type: ignore[arg-type]
tool_result_storage=_FakeToolResultStorage(
{
"ui_schema": {
"type": "calendar_operation.v1",
"version": "v1",
"data": {"ok": True},
"actions": [],
},
"content": "legacy content",
}
),
)
message = SimpleNamespace(
id=uuid4(),
session_id=uuid4(),
role=AgentChatMessageRole.TOOL,
created_at=datetime.now(timezone.utc),
content='{"offloaded":true}',
metadata_json={
"tool_call_id": "call-legacy",
"storage_bucket": config.storage.bucket,
"storage_path": "tool-results/old-run/call-legacy.json",
},
)
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
assert payload["content"] == "legacy content"
ui = payload.get("ui")
assert isinstance(ui, dict)
assert ui["type"] == "calendar_operation.v1"
@pytest.mark.asyncio
async def test_user_message_snapshot_includes_renderable_attachments() -> None:
repository = AgentRepository(
@@ -104,6 +247,7 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
)
message = SimpleNamespace(
id=uuid4(),
session_id=uuid4(),
role=AgentChatMessageRole.USER,
created_at=datetime.now(timezone.utc),
content="请分析这张图",
@@ -122,13 +266,13 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
assert payload["role"] == "user"
assert payload["content"] == "请分析这张图"
assert payload["attachments"] == [
{
"bucket": "agent-chat-attachments",
"path": "agent-inputs/u1/t1/r1/m1/att-1.png",
"mimeType": "image/png",
}
]
attachments = payload.get("attachments")
assert isinstance(attachments, list)
assert len(attachments) == 1
first = attachments[0]
assert isinstance(first, dict)
assert first["mimeType"] == "image/png"
assert isinstance(first.get("previewPath"), str)
@pytest.mark.asyncio
@@ -174,3 +318,32 @@ async def test_persist_user_message_keeps_existing_session_title() -> None:
assert session_row.title == "已有标题"
assert session_row.message_count == 2
@pytest.mark.asyncio
async def test_get_message_attachment_reference_returns_item() -> None:
session_id = str(uuid4())
message_id = str(uuid4())
message = SimpleNamespace(
metadata_json={
"attachments": [
{
"bucket": "bucket-test",
"path": "agent-inputs/u/t/r/a.png",
"mimeType": "image/png",
}
]
}
)
fake_session = _FakeSession(message)
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
ref = await repository.get_message_attachment_reference(
session_id=session_id,
message_id=message_id,
attachment_index=0,
)
assert ref is not None
assert ref["bucket"] == "bucket-test"
assert ref["mimeType"] == "image/png"
@@ -225,3 +225,44 @@ async def test_stream_events_retries_on_redis_timeout(
merged = "".join(chunks)
assert "event: RUN_FINISHED" in merged
@pytest.mark.asyncio
async def test_get_attachment_preview_rejects_negative_index() -> None:
class _Service:
async def get_attachment_preview(self, **kwargs): # noqa: ANN003
del kwargs
raise AssertionError("get_attachment_preview should not be called")
with pytest.raises(HTTPException) as exc_info:
await agent_router.get_attachment_preview(
thread_id="00000000-0000-0000-0000-000000000001",
message_id="00000000-0000-0000-0000-000000000010",
attachment_index=-1,
service=cast(Any, _Service()),
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
)
assert exc_info.value.status_code == 422
@pytest.mark.asyncio
async def test_get_attachment_preview_returns_streaming_response() -> None:
class _Service:
async def get_attachment_preview(self, **kwargs): # noqa: ANN003
del kwargs
return b"png-bytes", "image/png"
response = await agent_router.get_attachment_preview(
thread_id="00000000-0000-0000-0000-000000000001",
message_id="00000000-0000-0000-0000-000000000010",
attachment_index=0,
service=cast(Any, _Service()),
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
)
chunks: list[bytes] = []
async for chunk in response.body_iterator:
chunks.append(cast(bytes, chunk))
assert response.media_type == "image/png"
assert b"".join(chunks) == b"png-bytes"
+341 -13
View File
@@ -6,8 +6,10 @@ from uuid import UUID
from ag_ui.core import RunAgentInput
from fastapi import HTTPException
import pytest
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
from core.config.settings import config
import v1.agent.service as agent_service_module
from v1.agent.service import AgentService, AsrService
@@ -74,12 +76,32 @@ class _FakeRepository:
}
)
async def get_message_attachment_reference(
self,
*,
session_id: str,
message_id: str,
attachment_index: int,
) -> dict[str, str] | None:
del session_id, message_id
if attachment_index != 0:
return None
return {
"bucket": config.storage.bucket,
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/run-1/attachment-0-a.png",
"mimeType": "image/png",
}
class _FakeQueue:
def __init__(self) -> None:
self.commands: list[dict[str, object]] = []
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
del command, dedup_key
self.commands.append(command)
del dedup_key
return "task-1"
@@ -123,6 +145,33 @@ class _FakeAttachmentStorage:
)
return path
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
self.calls.append(
{
"bucket": bucket,
"path": path,
"download": True,
}
)
return b"png-bytes"
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str:
self.calls.append(
{
"bucket": bucket,
"path": path,
"signed": True,
"expires_in_seconds": expires_in_seconds,
}
)
return f"https://signed.example/{path}?exp={expires_in_seconds}"
class _AlwaysFailAttachmentStorage:
async def upload_bytes(
@@ -136,6 +185,20 @@ class _AlwaysFailAttachmentStorage:
del bucket, path, content, content_type
raise RuntimeError("upload failed")
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
del bucket, path
raise RuntimeError("download failed")
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str:
del bucket, path, expires_in_seconds
raise RuntimeError("sign failed")
def _user() -> CurrentUser:
return CurrentUser(
@@ -186,9 +249,10 @@ async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
async def test_enqueue_run_creates_missing_thread_session() -> None:
repository = _FakeRepository()
queue = _FakeQueue()
service = AgentService(
repository=repository,
queue=_FakeQueue(),
queue=queue,
stream=_FakeStream(),
)
run_input = _build_run_input(
@@ -206,6 +270,30 @@ async def test_enqueue_run_creates_missing_thread_session() -> None:
assert accepted.created is True
assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
assert repository.committed is True
assert queue.commands[0]["user_token"] is None
async def test_enqueue_run_uses_explicit_user_token() -> None:
repository = _FakeRepository()
queue = _FakeQueue()
service = AgentService(
repository=repository,
queue=queue,
stream=_FakeStream(),
)
run_input = _build_run_input(
thread_id="00000000-0000-0000-0000-000000000001",
run_id="run-1",
)
await service.enqueue_run(
run_input=run_input,
current_user=_user(),
user_token="Bearer access-token-1",
)
assert queue.commands
assert queue.commands[0]["user_token"] == "access-token-1"
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
@@ -270,7 +358,7 @@ async def test_enqueue_run_handles_session_create_race() -> None:
assert repository.rolled_back is True
async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
async def test_enqueue_run_uses_forwarded_attachments_and_injects_metadata(
monkeypatch,
) -> None:
monkeypatch.setattr(
@@ -297,15 +385,23 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
{"type": "text", "text": "帮我看下这张图"},
{
"type": "binary",
"data": "aGVsbG8=",
"mimeType": "image/png",
"url": "https://signed.example/upload.png",
},
],
}
],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {
"attachments": [
{
"bucket": "agent-test-bucket",
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
"mimeType": "image/png",
}
]
},
}
)
@@ -313,10 +409,9 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
assert accepted.task_id == "task-1"
assert len(attachment_storage.calls) == 1
upload = attachment_storage.calls[0]
assert upload["bucket"] == "agent-test-bucket"
assert upload["content"] == b"hello"
assert upload["content_type"] == "image/png"
download = attachment_storage.calls[0]
assert download["bucket"] == "agent-test-bucket"
assert download["download"] is True
assert repository.persisted_user_messages
persisted = repository.persisted_user_messages[0]
assert persisted["session_id"] == "00000000-0000-0000-0000-000000000001"
@@ -330,7 +425,7 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
assert isinstance(attachments[0]["path"], str)
async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
async def test_enqueue_run_raises_when_attachment_download_fails_without_fallback(
monkeypatch,
) -> None:
monkeypatch.setattr(
@@ -356,15 +451,23 @@ async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
{"type": "text", "text": "帮我看下这张图"},
{
"type": "binary",
"data": "aGVsbG8=",
"mimeType": "image/png",
"url": "https://signed.example/upload.png",
},
],
}
],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {
"attachments": [
{
"bucket": "agent-test-bucket",
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
"mimeType": "image/png",
}
]
},
}
)
@@ -373,11 +476,183 @@ async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
raise AssertionError("expected HTTPException")
except HTTPException as exc:
assert exc.status_code == 502
assert exc.detail == "Failed to upload attachment"
assert exc.detail == "Failed to fetch attachment"
assert repository.persisted_user_messages == []
async def test_enqueue_run_rejects_unsupported_attachment_type(
monkeypatch,
) -> None:
monkeypatch.setattr(
agent_service_module.config.storage, "bucket", "agent-test-bucket"
)
repository = _FakeRepository()
attachment_storage = _FakeAttachmentStorage()
service = AgentService(
repository=repository,
queue=_FakeQueue(),
stream=_FakeStream(),
attachment_storage=attachment_storage,
)
run_input = RunAgentInput.model_validate(
{
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-with-bad-image",
"state": {},
"messages": [
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "请看附件"},
{
"type": "binary",
"mimeType": "image/gif",
"url": "https://signed.example/upload.gif",
},
],
}
],
"tools": [],
"context": [],
"forwardedProps": {
"attachments": [
{
"bucket": "agent-test-bucket",
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.gif",
"mimeType": "image/gif",
}
]
},
}
)
with pytest.raises(HTTPException) as exc_info:
await service.enqueue_run(run_input=run_input, current_user=_user())
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Unsupported attachment type"
assert attachment_storage.calls == []
async def test_enqueue_run_rejects_attachment_too_large(
monkeypatch,
) -> None:
monkeypatch.setattr(agent_service_module, "_MAX_ATTACHMENT_BYTES", 4)
monkeypatch.setattr(
agent_service_module.config.storage, "bucket", "agent-test-bucket"
)
repository = _FakeRepository()
attachment_storage = _FakeAttachmentStorage()
service = AgentService(
repository=repository,
queue=_FakeQueue(),
stream=_FakeStream(),
attachment_storage=attachment_storage,
)
run_input = RunAgentInput.model_validate(
{
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-with-big-image",
"state": {},
"messages": [
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "请看附件"},
{
"type": "binary",
"mimeType": "image/png",
"url": "https://signed.example/upload.png",
},
],
}
],
"tools": [],
"context": [],
"forwardedProps": {
"attachments": [
{
"bucket": "agent-test-bucket",
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
"mimeType": "image/png",
}
]
},
}
)
with pytest.raises(HTTPException) as exc_info:
await service.enqueue_run(run_input=run_input, current_user=_user())
assert exc_info.value.status_code == 413
assert exc_info.value.detail == "Attachment too large"
assert len(attachment_storage.calls) == 1
assert attachment_storage.calls[0]["download"] is True
async def test_enqueue_run_accepts_binary_url_and_persists_metadata() -> None:
repository = _FakeRepository()
queue = _FakeQueue()
attachment_storage = _FakeAttachmentStorage()
service = AgentService(
repository=repository,
queue=queue,
stream=_FakeStream(),
attachment_storage=attachment_storage,
)
run_input = RunAgentInput.model_validate(
{
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-with-binary-url",
"state": {},
"messages": [
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "请分析"},
{
"type": "binary",
"mimeType": "image/png",
"url": "https://signed.example/upload-1.png",
},
],
}
],
"tools": [],
"context": [],
"forwardedProps": {
"attachments": [
{
"bucket": config.storage.bucket,
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload-1.png",
"mimeType": "image/png",
}
]
},
}
)
accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
assert accepted.task_id == "task-1"
persisted = repository.persisted_user_messages[-1]
metadata = persisted["metadata"]
assert isinstance(metadata, dict)
attachments = metadata.get("attachments")
assert isinstance(attachments, list)
assert attachments[0]["path"].endswith("upload-1.png")
queue_input = queue.commands[-1]["run_input"]
assert isinstance(queue_input, dict)
content = queue_input["messages"][0]["content"]
assert isinstance(content, list)
assert content[1]["type"] == "binary"
assert content[1]["url"] == "https://signed.example/upload-1.png"
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
service = AgentService(
repository=_FakeRepository(),
@@ -415,6 +690,59 @@ async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> Non
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
async def test_get_attachment_preview_returns_payload_and_mime() -> None:
service = AgentService(
repository=_FakeRepository(),
queue=_FakeQueue(),
stream=_FakeStream(),
attachment_storage=_FakeAttachmentStorage(),
)
payload, mime_type = await service.get_attachment_preview(
thread_id="00000000-0000-0000-0000-000000000001",
message_id="00000000-0000-0000-0000-000000000010",
attachment_index=0,
current_user=_user(),
)
assert payload == b"png-bytes"
assert mime_type == "image/png"
async def test_get_attachment_preview_rejects_invalid_path() -> None:
class _BadPathRepository(_FakeRepository):
async def get_message_attachment_reference(
self,
*,
session_id: str,
message_id: str,
attachment_index: int,
) -> dict[str, str] | None:
del session_id, message_id, attachment_index
return {
"bucket": "bucket-test",
"path": "agent-inputs/other-user/other-thread/run-1/a.png",
"mimeType": "image/png",
}
service = AgentService(
repository=_BadPathRepository(),
queue=_FakeQueue(),
stream=_FakeStream(),
attachment_storage=_FakeAttachmentStorage(),
)
with pytest.raises(HTTPException) as exc_info:
await service.get_attachment_preview(
thread_id="00000000-0000-0000-0000-000000000001",
message_id="00000000-0000-0000-0000-000000000010",
attachment_index=0,
current_user=_user(),
)
assert exc_info.value.status_code == 403
async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None:
result = SimpleNamespace(
status_code=200,