refactor: 简化 AgentScope 运行时模块与事件处理
- 移除冗余的 user_token 参数传递 - 重构 tool.result 事件使用 ToolAgentOutput 模型 - 重构 text.end 事件使用 WorkerAgentOutput 模型 - 简化 store 模块的 tool result 处理逻辑 - 更新 router/service 适配新事件结构 - 清理废弃的测试文件与设计文档 - 新增 AgentRuns 多模态存储设计文档
This commit is contained in:
@@ -38,16 +38,13 @@ def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
|
||||
data = event.get("data")
|
||||
if isinstance(data, dict):
|
||||
if event_type == "tool.result":
|
||||
for key in (
|
||||
"messageId",
|
||||
"toolCallId",
|
||||
"callId",
|
||||
"toolName",
|
||||
"stage",
|
||||
"taskId",
|
||||
"ui",
|
||||
"content",
|
||||
):
|
||||
for key in ("messageId", "toolCallId", "toolAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
payload[key] = value
|
||||
return payload
|
||||
if event_type == "text.end":
|
||||
for key in ("messageId", "workerAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
payload[key] = value
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from core.agentscope.events.tool_result_summary import build_tool_content_summary
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from schemas.agent.runtime_models import (
|
||||
ToolAgentOutput,
|
||||
WorkerAgentOutputLite,
|
||||
WorkerAgentOutputRich,
|
||||
)
|
||||
|
||||
|
||||
class EventStore(Protocol):
|
||||
@@ -193,6 +196,19 @@ class SqlAlchemyEventStore:
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
|
||||
worker_payload = event.get("workerAgentOutput")
|
||||
if isinstance(worker_payload, dict):
|
||||
try:
|
||||
if "ui_hints" in worker_payload:
|
||||
worker_output = WorkerAgentOutputRich.model_validate(worker_payload)
|
||||
else:
|
||||
worker_output = WorkerAgentOutputLite.model_validate(worker_payload)
|
||||
except Exception:
|
||||
worker_output = None
|
||||
else:
|
||||
content = worker_output.answer
|
||||
metadata["worker_agent_output"] = worker_output.model_dump(mode="json")
|
||||
|
||||
role_value = context.get("role")
|
||||
if not isinstance(role_value, str):
|
||||
role_value = "assistant"
|
||||
@@ -252,6 +268,14 @@ class SqlAlchemyEventStore:
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
return
|
||||
|
||||
raw_output = event.get("toolAgentOutput")
|
||||
if not isinstance(raw_output, dict):
|
||||
return
|
||||
try:
|
||||
tool_output = ToolAgentOutput.model_validate(raw_output)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
run_id = event.get("runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
|
||||
task_id = event.get("taskId")
|
||||
@@ -264,43 +288,18 @@ class SqlAlchemyEventStore:
|
||||
else f"{task_id_value}-{uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
summary = build_tool_content_summary(
|
||||
tool_name=tool_name,
|
||||
args=event.get("args") if isinstance(event.get("args"), dict) else None,
|
||||
result=event.get("result"),
|
||||
error=event.get("error"),
|
||||
)
|
||||
|
||||
raw_result_value = event.get("result")
|
||||
raw_result: dict[str, object] = (
|
||||
raw_result_value if isinstance(raw_result_value, dict) else {}
|
||||
)
|
||||
ui_candidate = raw_result.get("ui")
|
||||
ui_schema = ui_candidate if isinstance(ui_candidate, dict) else None
|
||||
result_type = raw_result.get("type")
|
||||
result_data = raw_result.get("data")
|
||||
if (
|
||||
ui_schema is None
|
||||
and isinstance(result_type, str)
|
||||
and isinstance(result_data, dict)
|
||||
):
|
||||
ui_schema = raw_result
|
||||
|
||||
payload: dict[str, object] = {
|
||||
"toolName": tool_name,
|
||||
"ui_schema": ui_schema,
|
||||
"result": _sanitize_result(raw_result),
|
||||
"error": _sanitize_error(event.get("error")),
|
||||
"toolAgentOutput": tool_output.model_dump(mode="json"),
|
||||
"callId": call_id_value,
|
||||
"runId": run_id_value,
|
||||
"taskId": task_id_value,
|
||||
"content": summary,
|
||||
"content": tool_output.result_summary,
|
||||
}
|
||||
|
||||
metadata: dict[str, object] = {
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": call_id_value,
|
||||
"summary_version": "v1",
|
||||
"tool_agent_output": tool_output.model_dump(mode="json"),
|
||||
}
|
||||
if run_id_value:
|
||||
metadata["run_id"] = run_id_value
|
||||
@@ -332,9 +331,7 @@ class SqlAlchemyEventStore:
|
||||
storage_path=storage_path,
|
||||
)
|
||||
|
||||
content = summary or json.dumps(
|
||||
payload, ensure_ascii=False, separators=(",", ":")
|
||||
)
|
||||
content = tool_output.result_summary
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
@@ -429,63 +426,3 @@ def _sanitize_path_component(value: str) -> str:
|
||||
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
|
||||
compact = compact.strip(".-")
|
||||
return compact or "id"
|
||||
|
||||
|
||||
def _sanitize_error(value: object) -> str | None:
|
||||
if isinstance(value, str) and value.strip():
|
||||
return " ".join(value.split())[:300]
|
||||
if isinstance(value, dict):
|
||||
for key in ("message", "error", "detail"):
|
||||
item = value.get(key)
|
||||
if isinstance(item, str) and item.strip():
|
||||
return " ".join(item.split())[:300]
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_result(value: object) -> dict[str, object]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
normalized = key.strip().lower().replace("-", "_")
|
||||
if not normalized:
|
||||
return False
|
||||
exact = {
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"credential",
|
||||
"authorization",
|
||||
"auth",
|
||||
}
|
||||
if normalized in exact:
|
||||
return True
|
||||
patterns = (
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"auth",
|
||||
"credential",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"authorization",
|
||||
)
|
||||
return any(pattern in normalized for pattern in patterns)
|
||||
|
||||
def _sanitize_value(item: object) -> object:
|
||||
if isinstance(item, dict):
|
||||
return _sanitize_result(item)
|
||||
if isinstance(item, list):
|
||||
return [_sanitize_value(entry) for entry in item]
|
||||
return item
|
||||
|
||||
sanitized: dict[str, object] = {}
|
||||
for key, item in value.items():
|
||||
key_text = str(key)
|
||||
if _is_sensitive_key(key_text):
|
||||
sanitized[str(key)] = "[REDACTED]"
|
||||
continue
|
||||
sanitized[str(key)] = _sanitize_value(item)
|
||||
return sanitized
|
||||
|
||||
@@ -78,21 +78,19 @@ def build_intent_user_prompt(
|
||||
*, user_input: str | list[dict[str, Any]]
|
||||
) -> str | list[dict[str, Any]]:
|
||||
if isinstance(user_input, list):
|
||||
instruction_block = {
|
||||
"type": "text",
|
||||
"text": "\n\n".join(
|
||||
[
|
||||
ROUTER_STAGE_INSTRUCTION,
|
||||
"[Output Schema]",
|
||||
_schema_json(RouterAgentOutput),
|
||||
]
|
||||
),
|
||||
}
|
||||
return [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "\n\n".join(
|
||||
[
|
||||
ROUTER_STAGE_INSTRUCTION,
|
||||
"[Output Schema]",
|
||||
_schema_json(RouterAgentOutput),
|
||||
"[User Input]",
|
||||
json.dumps(
|
||||
user_input, ensure_ascii=True, separators=(",", ":")
|
||||
),
|
||||
]
|
||||
),
|
||||
}
|
||||
instruction_block,
|
||||
*user_input,
|
||||
]
|
||||
return "\n\n".join(
|
||||
[
|
||||
|
||||
@@ -50,11 +50,9 @@ class AgentScopeRuntimeOrchestrator:
|
||||
*,
|
||||
command: RunAgentInput,
|
||||
owner_id: UUID,
|
||||
user_token: str,
|
||||
user_context: UserContext,
|
||||
session: AsyncSession,
|
||||
) -> dict[str, Any]:
|
||||
del user_token
|
||||
return await self._execute(
|
||||
command=command,
|
||||
owner_id=owner_id,
|
||||
@@ -68,11 +66,9 @@ class AgentScopeRuntimeOrchestrator:
|
||||
*,
|
||||
command: RunAgentInput,
|
||||
owner_id: UUID,
|
||||
user_token: str,
|
||||
user_context: UserContext,
|
||||
session: AsyncSession,
|
||||
) -> dict[str, Any]:
|
||||
del user_token
|
||||
return await self._execute(
|
||||
command=command,
|
||||
owner_id=owner_id,
|
||||
@@ -116,7 +112,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
user_input = _to_resume_user_input(command)
|
||||
else:
|
||||
_, content_blocks = extract_latest_user_payload(command)
|
||||
user_input = _to_user_input_payload(content_blocks)
|
||||
user_input = _to_model_user_input(content_blocks)
|
||||
router_toolkit = build_stage_toolkit(
|
||||
stage="intent",
|
||||
session=session,
|
||||
@@ -159,16 +155,38 @@ class AgentScopeRuntimeOrchestrator:
|
||||
|
||||
worker_payload = result.get("worker") if isinstance(result, dict) else None
|
||||
worker = worker_payload if isinstance(worker_payload, dict) else {}
|
||||
response_metadata = worker.get("response_metadata")
|
||||
metadata = response_metadata if isinstance(response_metadata, dict) else {}
|
||||
assistant_text = _resolve_worker_answer(worker)
|
||||
tool_outputs_raw = worker.get("tool_outputs")
|
||||
if isinstance(tool_outputs_raw, list):
|
||||
for idx, item in enumerate(tool_outputs_raw, start=1):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
tool_name = item.get("tool_name")
|
||||
tool_call_id = item.get("tool_call_id")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
continue
|
||||
if not isinstance(tool_call_id, str) or not tool_call_id:
|
||||
tool_call_id = f"{run_id}-tool-{idx}"
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "tool.result",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"messageId": f"tool-{tool_call_id}",
|
||||
"toolCallId": tool_call_id,
|
||||
"toolAgentOutput": item,
|
||||
},
|
||||
},
|
||||
)
|
||||
await self._emit_stage_text(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
stage_name="worker",
|
||||
message_id=f"assistant-{run_id}",
|
||||
text=assistant_text,
|
||||
response_metadata=metadata,
|
||||
worker_agent_output=worker,
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
@@ -215,7 +233,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
stage_name: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
response_metadata: dict[str, Any],
|
||||
worker_agent_output: dict[str, Any],
|
||||
) -> None:
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
@@ -250,8 +268,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"messageId": message_id,
|
||||
"stage": stage_name,
|
||||
**_text_end_telemetry_payload(response_metadata),
|
||||
"workerAgentOutput": worker_agent_output,
|
||||
},
|
||||
},
|
||||
)
|
||||
@@ -271,6 +288,28 @@ def _to_user_input_payload(
|
||||
return content_blocks
|
||||
|
||||
|
||||
def _to_model_user_input(
|
||||
content_blocks: list[dict[str, Any]],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for block in content_blocks:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
text = block.get("text")
|
||||
if isinstance(text, str) and text.strip():
|
||||
normalized.append({"type": "text", "text": text})
|
||||
continue
|
||||
if block_type != "binary":
|
||||
continue
|
||||
url = block.get("url")
|
||||
if isinstance(url, str) and url:
|
||||
normalized.append({"type": "image_url", "image_url": {"url": url}})
|
||||
|
||||
return _to_user_input_payload(normalized)
|
||||
|
||||
|
||||
def _to_resume_user_input(command: RunAgentInput) -> list[dict[str, Any]]:
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for message in command.messages:
|
||||
@@ -296,66 +335,3 @@ def _resolve_worker_answer(worker: dict[str, Any]) -> str:
|
||||
return message
|
||||
|
||||
return "抱歉,这次没有产出可用结果,请重试。"
|
||||
|
||||
|
||||
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {}
|
||||
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
|
||||
if model is not None:
|
||||
payload["model"] = model
|
||||
|
||||
input_tokens = _first_number(metadata, keys=("inputTokens", "input_tokens"))
|
||||
if input_tokens is not None:
|
||||
payload["inputTokens"] = input_tokens
|
||||
|
||||
output_tokens = _first_number(metadata, keys=("outputTokens", "output_tokens"))
|
||||
if output_tokens is not None:
|
||||
payload["outputTokens"] = output_tokens
|
||||
|
||||
latency_ms = _first_number(metadata, keys=("latencyMs", "latency_ms"))
|
||||
if latency_ms is not None:
|
||||
payload["latencyMs"] = latency_ms
|
||||
|
||||
cost = _first_number(metadata, keys=("cost", "total_cost"), allow_float=True)
|
||||
if cost is not None:
|
||||
payload["cost"] = cost
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _first_non_empty_str(
|
||||
metadata: dict[str, Any], *, keys: tuple[str, ...]
|
||||
) -> str | None:
|
||||
for key in keys:
|
||||
value = metadata.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _first_number(
|
||||
metadata: dict[str, Any],
|
||||
*,
|
||||
keys: tuple[str, ...],
|
||||
allow_float: bool = False,
|
||||
) -> int | float | None:
|
||||
for key in keys:
|
||||
value = metadata.get(key)
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int):
|
||||
if value < 0:
|
||||
continue
|
||||
return value
|
||||
if isinstance(value, float):
|
||||
if value < 0:
|
||||
continue
|
||||
return value if allow_float else int(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed = float(value) if allow_float else int(value)
|
||||
except ValueError:
|
||||
continue
|
||||
if parsed >= 0:
|
||||
return parsed
|
||||
return None
|
||||
|
||||
@@ -68,16 +68,6 @@ def _build_user_context(*, owner_id: UUID, run_input: RunAgentInput) -> UserCont
|
||||
)
|
||||
|
||||
|
||||
def _extract_user_token(
|
||||
*, command: dict[str, Any], run_input: RunAgentInput
|
||||
) -> str | None:
|
||||
del run_input
|
||||
raw_token = command.get("user_token")
|
||||
if isinstance(raw_token, str) and raw_token.strip():
|
||||
return raw_token.strip()
|
||||
return None
|
||||
|
||||
|
||||
async def _build_recent_context_messages(
|
||||
*,
|
||||
session: Any,
|
||||
@@ -147,7 +137,6 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
if command_type == "resume":
|
||||
extract_latest_tool_result(parsed_run_input)
|
||||
user_context = _build_user_context(owner_id=owner_id, run_input=parsed_run_input)
|
||||
user_token = _extract_user_token(command=command, run_input=parsed_run_input) or ""
|
||||
|
||||
redis_client = await get_or_init_redis_client()
|
||||
bus = RedisStreamBus(
|
||||
@@ -189,7 +178,6 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
await runtime.resume(
|
||||
command=parsed_run_input,
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
user_context=user_context,
|
||||
session=session,
|
||||
)
|
||||
@@ -197,7 +185,6 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
await runtime.run(
|
||||
command=parsed_run_input,
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
user_context=user_context,
|
||||
session=session,
|
||||
)
|
||||
|
||||
@@ -167,16 +167,18 @@ class UiCompiler:
|
||||
timestamp: str | None = None,
|
||||
meta: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
hints = output.ui_hints or self._build_default_worker_hints(output)
|
||||
if output.error is not None and not self._contains_error_block(hints.blocks):
|
||||
output_ui_hints = getattr(output, "ui_hints", None)
|
||||
hints = output_ui_hints or self._build_default_worker_hints(output)
|
||||
output_error = getattr(output, "error", None)
|
||||
if output_error is not None and not self._contains_error_block(hints.blocks):
|
||||
hints = self._append_error_block(
|
||||
hints,
|
||||
UiHintErrorBlock(
|
||||
kind="error",
|
||||
errorCode=output.error.code,
|
||||
message=output.error.message,
|
||||
retryable=output.error.retryable,
|
||||
details=self._stringify_details(output.error.details),
|
||||
errorCode=output_error.code,
|
||||
message=output_error.message,
|
||||
retryable=output_error.retryable,
|
||||
details=self._stringify_details(output_error.details),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Protocol
|
||||
|
||||
from services.base.supabase import supabase_service
|
||||
|
||||
|
||||
class ToolResultStorage(Protocol):
|
||||
async def upload_json(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
payload: dict[str, object],
|
||||
) -> str: ...
|
||||
|
||||
async def read_json(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
|
||||
class SupabaseToolResultStorage:
|
||||
async def upload_json(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
payload: dict[str, object],
|
||||
) -> str:
|
||||
serialized = json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
|
||||
await supabase_service.upload_bytes(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
content=serialized.encode("utf-8"),
|
||||
content_type="application/json",
|
||||
)
|
||||
return path
|
||||
|
||||
async def read_json(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
) -> dict[str, object] | None:
|
||||
raw = await supabase_service.download_bytes(bucket=bucket, path=path)
|
||||
decoded = json.loads(raw.decode("utf-8"))
|
||||
if isinstance(decoded, dict):
|
||||
return decoded
|
||||
return None
|
||||
|
||||
|
||||
def create_tool_result_storage() -> ToolResultStorage:
|
||||
return SupabaseToolResultStorage()
|
||||
@@ -1,11 +1,3 @@
|
||||
from core.agentscope.schemas.agui_input import (
|
||||
extract_latest_tool_result,
|
||||
extract_latest_user_content,
|
||||
extract_latest_user_payload,
|
||||
extract_latest_user_text,
|
||||
parse_run_input,
|
||||
validate_run_request_messages_contract,
|
||||
)
|
||||
from schemas.agent.runtime_models import (
|
||||
ResultType,
|
||||
RouterAgentOutput,
|
||||
@@ -14,17 +6,17 @@ from schemas.agent.runtime_models import (
|
||||
ToolAgentOutput,
|
||||
ToolStatus,
|
||||
UiMode,
|
||||
WorkerAgentOutput,
|
||||
WorkerAgentOutputLite,
|
||||
WorkerAgentOutputRich,
|
||||
WorkerAgentOutput,
|
||||
resolve_worker_output_model,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.agent.ui_hints import (
|
||||
UiHintAction,
|
||||
UiHintBlock,
|
||||
UiHintsPayload,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
|
||||
__all__ = [
|
||||
"AgentType",
|
||||
@@ -43,10 +35,4 @@ __all__ = [
|
||||
"WorkerAgentOutputRich",
|
||||
"WorkerAgentOutput",
|
||||
"resolve_worker_output_model",
|
||||
"extract_latest_tool_result",
|
||||
"extract_latest_user_content",
|
||||
"extract_latest_user_payload",
|
||||
"extract_latest_user_text",
|
||||
"parse_run_input",
|
||||
"validate_run_request_messages_contract",
|
||||
]
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from core.agentscope.schemas.agui_input import * # noqa: F403
|
||||
@@ -374,7 +374,7 @@ class WorkerAgentOutputRich(WorkerAgentOutputLite):
|
||||
)
|
||||
|
||||
|
||||
WorkerAgentOutput = WorkerAgentOutputRich
|
||||
WorkerAgentOutput = WorkerAgentOutputLite | WorkerAgentOutputRich
|
||||
|
||||
|
||||
def resolve_worker_output_model(ui_mode: UiMode) -> type[WorkerAgentOutputLite]:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from schemas.messages.chat_message import AgentChatMessageMetadata
|
||||
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
|
||||
|
||||
__all__ = ["AgentChatMessageMetadata"]
|
||||
__all__ = ["AgentChatMessage", "AgentChatMessageMetadata"]
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from ..agent import AgentType, ToolAgentOutput, WorkerAgentOutput
|
||||
|
||||
@@ -22,3 +25,16 @@ class AgentChatMessageMetadata(BaseModel):
|
||||
user_message_attachments: UserMessageAttachments | None = None
|
||||
tool_agent_output: ToolAgentOutput | None = None
|
||||
worker_agent_output: WorkerAgentOutput | None = None
|
||||
|
||||
|
||||
class AgentChatMessage(BaseModel):
|
||||
"""Canonical schema aligned with `messages` table columns."""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
id: UUID
|
||||
seq: int
|
||||
role: str
|
||||
content: str
|
||||
metadata: AgentChatMessageMetadata | dict[str, object] | None = None
|
||||
timestamp: datetime
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
import json
|
||||
from typing import Protocol
|
||||
from uuid import UUID
|
||||
|
||||
@@ -9,10 +8,9 @@ from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.config.settings import config
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from services.base.supabase import supabase_service
|
||||
from schemas.messages.chat_message import AgentChatMessage as AgentChatMessageSchema
|
||||
|
||||
|
||||
class ToolResultPayloadStorage(Protocol):
|
||||
@@ -210,132 +208,17 @@ class AgentRepository:
|
||||
if isinstance(message.role, AgentChatMessageRole)
|
||||
else str(message.role)
|
||||
)
|
||||
payload: dict[str, object] = {
|
||||
"id": str(message.id),
|
||||
"role": role,
|
||||
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
if role == AgentChatMessageRole.TOOL.value:
|
||||
metadata = message.metadata_json or {}
|
||||
tool_call_id = metadata.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
payload["toolCallId"] = tool_call_id
|
||||
|
||||
parsed_content: dict[str, object] | None = None
|
||||
try:
|
||||
decoded = json.loads(message.content)
|
||||
if isinstance(decoded, dict):
|
||||
parsed_content = decoded
|
||||
except (TypeError, ValueError):
|
||||
parsed_content = None
|
||||
|
||||
hydrated_content: dict[str, object] | None = None
|
||||
if self._tool_result_storage is not None:
|
||||
storage_bucket = metadata.get("storage_bucket")
|
||||
storage_path = metadata.get("storage_path")
|
||||
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
|
||||
expected_bucket = config.storage.bucket
|
||||
message_session_id = getattr(message, "session_id", None)
|
||||
expected_prefix = (
|
||||
f"tool-results/{message_session_id}/"
|
||||
if message_session_id is not None
|
||||
else None
|
||||
)
|
||||
tool_call_id = metadata.get("tool_call_id")
|
||||
is_legacy_path = isinstance(
|
||||
tool_call_id, str
|
||||
) and storage_path.endswith(f"/{tool_call_id}.json")
|
||||
if (
|
||||
storage_bucket == expected_bucket
|
||||
and _is_safe_storage_path(storage_path)
|
||||
and (
|
||||
(
|
||||
expected_prefix is not None
|
||||
and storage_path.startswith(expected_prefix)
|
||||
)
|
||||
or (
|
||||
storage_path.startswith("tool-results/")
|
||||
and is_legacy_path
|
||||
)
|
||||
)
|
||||
):
|
||||
try:
|
||||
hydrated_content = (
|
||||
await self._tool_result_storage.read_json(
|
||||
bucket=storage_bucket,
|
||||
path=storage_path,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
hydrated_content = None
|
||||
|
||||
resolved_content = hydrated_content or parsed_content
|
||||
payload["content"] = message.content
|
||||
if resolved_content is not None:
|
||||
ui = resolved_content.get("ui")
|
||||
if not isinstance(ui, dict):
|
||||
ui = resolved_content.get("ui_schema")
|
||||
if isinstance(ui, dict):
|
||||
payload["ui"] = ui
|
||||
display_content = resolved_content.get("content")
|
||||
if not isinstance(display_content, str):
|
||||
nested_result = resolved_content.get("result")
|
||||
if isinstance(nested_result, dict):
|
||||
nested_content = nested_result.get("content")
|
||||
if isinstance(nested_content, str):
|
||||
display_content = nested_content
|
||||
if (
|
||||
isinstance(display_content, str)
|
||||
and display_content.strip()
|
||||
and (
|
||||
not payload["content"]
|
||||
or _looks_like_offloaded_placeholder(str(payload["content"]))
|
||||
)
|
||||
):
|
||||
payload["content"] = display_content
|
||||
else:
|
||||
payload["content"] = message.content
|
||||
|
||||
if role == AgentChatMessageRole.USER.value:
|
||||
metadata = message.metadata_json or {}
|
||||
user_attachments = metadata.get("user_message_attachments")
|
||||
if isinstance(user_attachments, dict):
|
||||
bucket = user_attachments.get("bucket")
|
||||
path = user_attachments.get("path")
|
||||
mime_type = user_attachments.get("mime_type")
|
||||
if (
|
||||
isinstance(bucket, str)
|
||||
and isinstance(path, str)
|
||||
and isinstance(mime_type, str)
|
||||
):
|
||||
try:
|
||||
signed_url = await supabase_service.create_signed_url(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
expires_in_seconds=3600,
|
||||
)
|
||||
attachment_block = {
|
||||
"type": "binary",
|
||||
"mimeType": mime_type,
|
||||
"url": signed_url,
|
||||
}
|
||||
existing_content = message.content
|
||||
if (
|
||||
isinstance(existing_content, str)
|
||||
and existing_content.strip()
|
||||
):
|
||||
content_blocks = [
|
||||
{"type": "text", "text": existing_content}
|
||||
]
|
||||
content_blocks.append(attachment_block)
|
||||
payload["content"] = content_blocks
|
||||
else:
|
||||
payload["content"] = [attachment_block]
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
return payload
|
||||
payload_model = AgentChatMessageSchema.model_validate(
|
||||
{
|
||||
"id": str(message.id),
|
||||
"seq": int(message.seq),
|
||||
"role": role,
|
||||
"content": message.content,
|
||||
"metadata": message.metadata_json,
|
||||
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
|
||||
}
|
||||
)
|
||||
return payload_model.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
|
||||
def _has_title(title: object) -> bool:
|
||||
@@ -347,19 +230,3 @@ def _derive_session_title(content_text: str) -> str | None:
|
||||
if not normalized:
|
||||
return None
|
||||
return normalized[:80]
|
||||
|
||||
|
||||
def _is_safe_storage_path(path: str) -> bool:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized.startswith("/"):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _looks_like_offloaded_placeholder(content: str) -> bool:
|
||||
normalized = content.strip().lower()
|
||||
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
|
||||
|
||||
@@ -11,9 +11,7 @@ from typing import Annotated, Union
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from core.agentscope.events import to_sse_event
|
||||
from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
@@ -38,6 +36,7 @@ from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import (
|
||||
AsrTranscribeResponse,
|
||||
AttachmentReference,
|
||||
AttachmentSignedUrlResponse,
|
||||
AttachmentUploadResponse,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
@@ -63,42 +62,6 @@ _ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def _verified_access_token_for_user(
|
||||
*,
|
||||
authorization: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> str:
|
||||
if not isinstance(authorization, str):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
normalized = authorization.strip()
|
||||
if not normalized:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
if not normalized.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
token = normalized[7:].strip()
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
jwt_secret = config.supabase.jwt_secret
|
||||
if jwt_secret is None:
|
||||
raise HTTPException(status_code=503, detail="Auth verifier unavailable")
|
||||
|
||||
verifier = JwtVerifier(
|
||||
issuer=str(config.supabase.jwt_issuer),
|
||||
jwt_secret=jwt_secret.get_secret_value(),
|
||||
jwt_algorithm=config.supabase.jwt_algorithm,
|
||||
)
|
||||
try:
|
||||
payload = verifier.verify(token)
|
||||
except TokenValidationError as exc:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
||||
|
||||
subject = payload.get("sub")
|
||||
if not isinstance(subject, str) or subject != str(current_user.id):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
return token
|
||||
|
||||
|
||||
def _looks_like_wav_header(header: bytes) -> bool:
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
return False
|
||||
@@ -164,7 +127,6 @@ async def enqueue_run(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
try:
|
||||
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
|
||||
@@ -174,15 +136,10 @@ async def enqueue_run(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
task = await service.enqueue_run(
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -202,7 +159,6 @@ async def enqueue_resume(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
if request.thread_id != thread_id:
|
||||
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
|
||||
@@ -214,15 +170,10 @@ async def enqueue_resume(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
task = await service.enqueue_resume(
|
||||
thread_id=thread_id,
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -304,20 +255,6 @@ async def stream_events(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{thread_id}/history")
|
||||
async def get_history_snapshot(
|
||||
thread_id: str,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
before: date | None = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
return await service.get_history_snapshot(
|
||||
thread_id=thread_id,
|
||||
before=before,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_user_history_snapshot(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
@@ -360,6 +297,25 @@ async def upload_attachment(
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/attachments/signed-url",
|
||||
response_model=AttachmentSignedUrlResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def create_attachment_signed_url(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
bucket: str = Query(min_length=1, max_length=100),
|
||||
path: str = Query(min_length=1, max_length=500),
|
||||
) -> AttachmentSignedUrlResponse:
|
||||
signed = await service.create_attachment_signed_url(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
current_user=current_user,
|
||||
)
|
||||
return AttachmentSignedUrlResponse(**signed)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcribe",
|
||||
response_model=AsrTranscribeResponse,
|
||||
|
||||
@@ -27,3 +27,9 @@ class AttachmentReference(BaseModel):
|
||||
|
||||
class AttachmentUploadResponse(BaseModel):
|
||||
attachment: AttachmentReference
|
||||
|
||||
|
||||
class AttachmentSignedUrlResponse(BaseModel):
|
||||
bucket: str
|
||||
path: str
|
||||
url: str
|
||||
|
||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
from datetime import date
|
||||
import hashlib
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import dashscope
|
||||
from ag_ui.core import RunAgentInput, StateSnapshotEvent
|
||||
@@ -23,19 +24,6 @@ _MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
|
||||
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
|
||||
|
||||
|
||||
def _normalize_bearer_token(value: str | None) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
lower = normalized.lower()
|
||||
if lower.startswith("bearer "):
|
||||
token = normalized[7:].strip()
|
||||
return token or None
|
||||
return normalized
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskAccepted:
|
||||
task_id: str
|
||||
@@ -70,14 +58,6 @@ class AgentRepositoryLike(Protocol):
|
||||
metadata: dict[str, object] | None,
|
||||
) -> None: ...
|
||||
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
@@ -148,7 +128,6 @@ class AgentService:
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
thread_id = run_input.thread_id
|
||||
@@ -188,7 +167,6 @@ class AgentService:
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _normalize_bearer_token(user_token),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
@@ -226,19 +204,28 @@ class AgentService:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
if self._attachment_storage is None:
|
||||
continue
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Attachment storage unavailable",
|
||||
)
|
||||
|
||||
try:
|
||||
bucket, path = self._attachment_storage.parse_signed_url(url)
|
||||
bucket, path = self._validate_binary_signed_url(
|
||||
url=url,
|
||||
thread_id=run_input.thread_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
user_attachments = UserMessageAttachments(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
break
|
||||
except Exception: # noqa: BLE001
|
||||
logger.warning("Failed to parse signed URL", url=url)
|
||||
continue
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to parse signed URL", url=url, error=str(exc))
|
||||
raise HTTPException(status_code=422, detail="Invalid signed image url")
|
||||
|
||||
metadata: dict[str, object] | None = None
|
||||
if user_attachments is not None:
|
||||
@@ -329,13 +316,57 @@ class AgentService:
|
||||
"url": signed_url,
|
||||
}
|
||||
|
||||
async def create_attachment_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, str]:
|
||||
if self._attachment_storage is None:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Attachment storage unavailable"
|
||||
)
|
||||
normalized_bucket = bucket.strip()
|
||||
if normalized_bucket != config.storage.bucket:
|
||||
raise HTTPException(status_code=422, detail="Invalid attachment bucket")
|
||||
|
||||
normalized_path = path.strip()
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/"
|
||||
if not _is_safe_attachment_path(
|
||||
normalized_path, expected_prefix=expected_prefix
|
||||
):
|
||||
raise HTTPException(status_code=422, detail="Invalid attachment path scope")
|
||||
|
||||
try:
|
||||
signed_url = await self._attachment_storage.create_signed_url(
|
||||
bucket=normalized_bucket,
|
||||
path=normalized_path,
|
||||
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Attachment signed URL generation failed",
|
||||
extra={
|
||||
"bucket": normalized_bucket,
|
||||
"path": normalized_path,
|
||||
"user_id": str(current_user.id),
|
||||
},
|
||||
)
|
||||
raise HTTPException(status_code=502, detail="Failed to generate signed URL")
|
||||
|
||||
return {
|
||||
"bucket": normalized_bucket,
|
||||
"path": normalized_path,
|
||||
"url": signed_url,
|
||||
}
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
) -> TaskAccepted:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
@@ -345,7 +376,6 @@ class AgentService:
|
||||
command={
|
||||
"command": "resume",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _normalize_bearer_token(user_token),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
@@ -428,6 +458,37 @@ class AgentService:
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
def _validate_binary_signed_url(
|
||||
self,
|
||||
*,
|
||||
url: str,
|
||||
thread_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> tuple[str, str]:
|
||||
if self._attachment_storage is None:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Attachment storage unavailable"
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
expected_host = urlparse(config.supabase.url).netloc
|
||||
if parsed.netloc != expected_host:
|
||||
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_HOST")
|
||||
|
||||
try:
|
||||
bucket, path = self._attachment_storage.parse_signed_url(url)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid signed image url"
|
||||
) from exc
|
||||
|
||||
if bucket != config.storage.bucket:
|
||||
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_BUCKET")
|
||||
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
|
||||
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
|
||||
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_PATH_SCOPE")
|
||||
return bucket, path
|
||||
|
||||
|
||||
class AsrService:
|
||||
def __init__(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user