refactor: 重构 AgentScope ReAct Runner 与事件处理

- 重构 runtime/runner.py 实现 ReAct Agent 核心逻辑
- 更新事件编码器与存储机制
- 优化 prompt 系统与 tool 调用
- 调整 agent service 与 repository 配合
This commit is contained in:
qzl
2026-03-16 16:10:39 +08:00
parent ab073c88ed
commit 36b104fa37
22 changed files with 1288 additions and 319 deletions
@@ -10,11 +10,11 @@ from ag_ui.core import (
RunErrorEvent,
StepStartedEvent,
StepFinishedEvent,
TextMessageStartEvent,
TextMessageContentEvent,
TextMessageEndEvent,
ToolCallResultEvent,
)
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
from schemas.agent.ui_hints import UiHintsPayload
if TYPE_CHECKING:
pass
@@ -25,8 +25,6 @@ _INTERNAL_TO_AGUI: dict[str, EventType] = {
"run.error": EventType.RUN_ERROR,
"step.start": EventType.STEP_STARTED,
"step.finish": EventType.STEP_FINISHED,
"text.start": EventType.TEXT_MESSAGE_START,
"text.delta": EventType.TEXT_MESSAGE_CONTENT,
"text.end": EventType.TEXT_MESSAGE_END,
"tool.start": EventType.TOOL_CALL_START,
"tool.args": EventType.TOOL_CALL_ARGS,
@@ -53,6 +51,34 @@ def _is_agui_event(event: dict[str, Any]) -> bool:
return False
def _sanitize_agui_event(event: dict[str, Any]) -> dict[str, Any]:
payload = dict(event)
event_type = str(payload.get("type", "")).strip().upper()
if event_type in {
EventType.TEXT_MESSAGE_END.value,
EventType.TOOL_CALL_RESULT.value,
}:
ui_hints = payload.get("ui_hints")
if ui_hints is not None:
try:
ui_hints_payload = UiHintsPayload.model_validate(ui_hints)
ui_schema = compile_ui_hints(ui_hints_payload)
payload["ui_schema"] = ui_schema
except Exception:
pass
payload.pop("ui_hints", None)
if event_type == EventType.TEXT_MESSAGE_END.value:
for key in (
"inputTokens",
"outputTokens",
"cost",
"latencyMs",
"model",
):
payload.pop(key, None)
return payload
def _build_run_started(event: dict[str, Any]) -> RunStartedEvent:
return RunStartedEvent(
thread_id=event.get("threadId", ""),
@@ -77,31 +103,21 @@ def _build_run_error(event: dict[str, Any]) -> RunErrorEvent:
def _build_step_started(event: dict[str, Any]) -> StepStartedEvent:
data = event.get("data", {})
step_name = event.get("stepName", "")
if (not isinstance(step_name, str) or not step_name) and isinstance(data, dict):
step_name = data.get("stepName", "")
return StepStartedEvent(
step_name=data.get("stepName", ""),
step_name=step_name if isinstance(step_name, str) else "",
)
def _build_step_finished(event: dict[str, Any]) -> StepFinishedEvent:
data = event.get("data", {})
step_name = event.get("stepName", "")
if (not isinstance(step_name, str) or not step_name) and isinstance(data, dict):
step_name = data.get("stepName", "")
return StepFinishedEvent(
step_name=data.get("stepName", ""),
)
def _build_text_start(event: dict[str, Any]) -> TextMessageStartEvent:
data = event.get("data", {})
return TextMessageStartEvent(
message_id=data.get("messageId", ""),
role=data.get("role", "assistant"),
)
def _build_text_delta(event: dict[str, Any]) -> TextMessageContentEvent:
data = event.get("data", {})
return TextMessageContentEvent(
message_id=data.get("messageId", ""),
delta=data.get("delta", ""),
step_name=step_name if isinstance(step_name, str) else "",
)
@@ -128,8 +144,6 @@ _BUILDER_MAP: dict[str, Any] = {
"run.error": _build_run_error,
"step.start": _build_step_started,
"step.finish": _build_step_finished,
"text.start": _build_text_start,
"text.delta": _build_text_delta,
"text.end": _build_text_end,
"tool.result": _build_tool_result,
}
@@ -140,7 +154,7 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
return event.model_dump(by_alias=True, exclude_none=True)
if _is_agui_event(event):
return event
return _sanitize_agui_event(event)
internal_type = str(event.get("type", "")).strip()
@@ -156,24 +170,29 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
text_end_payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
text_end_payload["runId"] = run_id
for key in ("messageId", "workerAgentOutput"):
value = data.get(key)
if value is not None:
text_end_payload[key] = value
reserved = {
"type",
"threadId",
"runId",
"inputTokens",
"outputTokens",
"cost",
"latencyMs",
"model",
}
text_end_payload.update({k: v for k, v in data.items() if k not in reserved})
return text_end_payload
if internal_type == "tool.result" and isinstance(data, dict):
tool_result_payload = {
tool_result_payload: dict[str, Any] = {
"type": _convert_to_agui_type(internal_type).value,
}
if isinstance(thread_id, str) and thread_id:
tool_result_payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
tool_result_payload["runId"] = run_id
for key in ("messageId", "toolCallId", "toolAgentOutput"):
value = data.get(key)
if value is not None:
tool_result_payload[key] = value
reserved = {"type", "threadId", "runId"}
tool_result_payload.update({k: v for k, v in data.items() if k not in reserved})
return tool_result_payload
builder = _BUILDER_MAP.get(internal_type)
+68 -174
View File
@@ -1,35 +1,22 @@
from __future__ import annotations
import re
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID, uuid4
from uuid import UUID
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from schemas.agent.runtime_models import (
ToolAgentOutput,
WorkerAgentOutputLite,
WorkerAgentOutputRich,
)
from schemas.agent.runtime_models import ToolAgentOutput, WorkerAgentOutputRich
from schemas.agent.system_agent import AgentType
from schemas.messages.chat_message import AgentChatMessageMetadata
class EventStore(Protocol):
async def persist(self, event: dict[str, Any]) -> None: ...
class ToolResultStorageLike(Protocol):
async def upload_json(
self,
*,
bucket: str,
path: str,
payload: dict[str, object],
) -> str: ...
class NullEventStore:
async def persist(self, event: dict[str, Any]) -> None:
del event
@@ -37,22 +24,14 @@ class NullEventStore:
class SqlAlchemyEventStore:
_session_factory: Callable[[], Any]
_tool_result_storage: ToolResultStorageLike | None
_tool_result_bucket: str | None
_logger = get_logger("core.agentscope.events.store")
def __init__(
self,
*,
session_factory: Any,
tool_result_storage: ToolResultStorageLike | None = None,
tool_result_bucket: str | None = None,
) -> None:
self._session_factory = session_factory
self._tool_result_storage = tool_result_storage
self._tool_result_bucket = tool_result_bucket
self._message_buffers: dict[tuple[str, str], str] = {}
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
async def persist(self, event: dict[str, Any]) -> None:
event_type = str(event.get("type", "")).strip().upper().replace(".", "_")
@@ -63,22 +42,11 @@ class SqlAlchemyEventStore:
session_id = UUID(thread_id)
except ValueError:
return
session_key = str(session_id)
async with self._session_factory() as session:
session_repo = SessionRepository(session)
message_repo = MessageRepository(session)
chat_session = await session_repo.get_session(session_id=session_id)
if chat_session is None:
self._clear_session_buffers(session_key=session_key)
return
if event_type == "TEXT_MESSAGE_CONTENT":
self._buffer_text_delta(session_key=session_key, event=event)
return
if event_type == "TEXT_MESSAGE_START":
self._buffer_text_context(session_key=session_key, event=event)
return
if event_type == "RUN_STARTED":
@@ -95,7 +63,6 @@ class SqlAlchemyEventStore:
status=AgentChatSessionStatus.FAILED,
message_delta=0,
)
self._clear_session_buffers(session_key=session_key)
elif event_type == "RUN_FINISHED":
await self._update_session_state(
session_repo=session_repo,
@@ -103,7 +70,6 @@ class SqlAlchemyEventStore:
status=AgentChatSessionStatus.COMPLETED,
message_delta=0,
)
self._clear_session_buffers(session_key=session_key)
elif event_type == "TEXT_MESSAGE_END":
await self._persist_text_message(
event=event,
@@ -123,42 +89,6 @@ class SqlAlchemyEventStore:
await session.commit()
def _buffer_text_delta(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = self._event_value(event, "messageId")
delta = self._event_value(event, "delta")
if not isinstance(message_id, str) or not message_id:
return
if not isinstance(delta, str) or not delta:
return
key = (session_key, message_id)
current = self._message_buffers.get(key, "")
self._message_buffers[key] = f"{current}{delta}"
def _clear_session_buffers(self, *, session_key: str) -> None:
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
for key in stale_keys:
self._message_buffers.pop(key, None)
stale_context_keys = [k for k in self._message_contexts if k[0] == session_key]
for key in stale_context_keys:
self._message_contexts.pop(key, None)
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = self._event_value(event, "messageId")
if not isinstance(message_id, str) or not message_id:
return
key = (session_key, message_id)
role = self._event_value(event, "role")
stage = self._event_value(event, "stage")
tool_name = self._event_value(event, "toolName")
context: dict[str, object] = {}
if isinstance(role, str) and role:
context["role"] = role
if isinstance(stage, str) and stage:
context["stage"] = stage
if isinstance(tool_name, str) and tool_name:
context["tool_name"] = tool_name
self._message_contexts[key] = context
async def _persist_text_message(
self,
*,
@@ -170,13 +100,11 @@ class SqlAlchemyEventStore:
) -> None:
message_id_raw = self._event_value(event, "messageId")
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
key = (str(session_id), message_id)
content = self._message_buffers.get(key, "")
content_value = self._event_value(event, "answer")
content = content_value if isinstance(content_value, str) else ""
if not content:
return
context = self._message_contexts.get(key, {})
input_tokens = self._to_int(self._event_value(event, "inputTokens"))
output_tokens = self._to_int(self._event_value(event, "outputTokens"))
token_delta = input_tokens + output_tokens
@@ -185,35 +113,48 @@ class SqlAlchemyEventStore:
run_id = self._event_value(event, "runId")
model_code = self._event_value(event, "model")
metadata: dict[str, object] = {"message_id": message_id}
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
if latency_ms is not None:
metadata["latency_ms"] = latency_ms
stage = self._event_value(event, "stage")
if not isinstance(stage, str):
stage = context.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
run_id_value = run_id if isinstance(run_id, str) and run_id else None
if run_id_value is None:
return
worker_payload = self._event_value(event, "workerAgentOutput")
if isinstance(worker_payload, dict):
try:
if "ui_hints" in worker_payload:
worker_output = WorkerAgentOutputRich.model_validate(worker_payload)
else:
worker_output = WorkerAgentOutputLite.model_validate(worker_payload)
except Exception:
worker_output = None
else:
content = worker_output.answer
metadata["worker_agent_output"] = worker_output.model_dump(mode="json")
worker_output_fields = (
"status",
"answer",
"key_points",
"result_type",
"suggested_actions",
"error",
"ui_hints",
)
worker_output_payload: dict[str, object] = {}
for field in worker_output_fields:
value = self._event_value(event, field)
if value is not None:
worker_output_payload[field] = value
role_value = context.get("role")
if not worker_output_payload:
return
try:
worker_output = WorkerAgentOutputRich.model_validate(worker_output_payload)
metadata_model = AgentChatMessageMetadata(
run_id=run_id_value,
agent_type=AgentType.WORKER,
worker_agent_output=worker_output,
)
except Exception:
self._logger.warning(
"invalid worker metadata payload",
run_id=run_id_value,
message_id=message_id,
)
return
role_value = self._event_value(event, "role")
if not isinstance(role_value, str):
role_value = "assistant"
role = self._resolve_role(role_value)
tool_name = context.get("tool_name")
tool_name = self._event_value(event, "tool_name")
tool_name_value = (
tool_name if isinstance(tool_name, str) and tool_name else None
)
@@ -231,7 +172,7 @@ class SqlAlchemyEventStore:
content=content,
model_code=model_code if isinstance(model_code, str) else None,
tool_name=tool_name_value,
metadata=metadata,
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
@@ -252,8 +193,6 @@ class SqlAlchemyEventStore:
token_delta=token_delta,
cost_delta=cost,
)
self._message_buffers.pop(key, None)
self._message_contexts.pop(key, None)
async def _persist_tool_call_result(
self,
@@ -264,72 +203,33 @@ class SqlAlchemyEventStore:
session_repo: SessionRepository,
message_repo: MessageRepository,
) -> None:
tool_name = self._event_value(event, "toolName")
if not isinstance(tool_name, str) or not tool_name:
run_id = self._event_value(event, "runId")
run_id_value = run_id if isinstance(run_id, str) and run_id else None
if run_id_value is None:
return
raw_output = self._event_value(event, "toolAgentOutput")
if not isinstance(raw_output, dict):
return
raw_output: dict[str, object] = {
"tool_name": self._event_value(event, "tool_name"),
"tool_call_id": self._event_value(event, "tool_call_id"),
"tool_call_args": self._event_value(event, "tool_call_args"),
"status": self._event_value(event, "status"),
"result_summary": self._event_value(event, "result_summary"),
"error": self._event_value(event, "error"),
"ui_hints": self._event_value(event, "ui_hints"),
}
try:
tool_output = ToolAgentOutput.model_validate(raw_output)
except Exception:
return
run_id = self._event_value(event, "runId")
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
task_id = self._event_value(event, "taskId")
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
call_id_value = self._event_value(event, "callId")
if not isinstance(call_id_value, str) or not call_id_value:
call_id_value = (
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
if run_id_value
else f"{task_id_value}-{uuid4().hex[:8]}"
metadata_model = AgentChatMessageMetadata(
run_id=run_id_value,
tool_agent_output=tool_output,
)
payload: dict[str, object] = {
"toolAgentOutput": tool_output.model_dump(mode="json"),
"callId": call_id_value,
"runId": run_id_value,
"taskId": task_id_value,
"content": tool_output.result_summary,
}
metadata: dict[str, object] = {
"tool_name": tool_name,
"tool_call_id": call_id_value,
"tool_agent_output": tool_output.model_dump(mode="json"),
}
if run_id_value:
metadata["run_id"] = run_id_value
stage = self._event_value(event, "stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
if task_id_value:
metadata["task_id"] = task_id_value
if self._tool_result_storage is not None and self._tool_result_bucket:
safe_run = _sanitize_path_component(run_id_value or "run")
safe_call = _sanitize_path_component(call_id_value)
storage_path = f"tool-results/{session_id}/{safe_run}/{safe_call}.json"
try:
await self._tool_result_storage.upload_json(
bucket=self._tool_result_bucket,
path=storage_path,
payload=payload,
)
metadata["storage_bucket"] = self._tool_result_bucket
metadata["storage_path"] = storage_path
except Exception: # noqa: BLE001
metadata["storage_upload_failed"] = True
self._logger.warning(
"tool result storage upload failed",
session_id=str(session_id),
run_id=run_id_value,
call_id=call_id_value,
storage_path=storage_path,
)
except Exception:
self._logger.warning(
"invalid tool metadata payload",
run_id=run_id_value,
)
return
content = tool_output.result_summary
@@ -344,8 +244,8 @@ class SqlAlchemyEventStore:
seq=seq,
role=AgentChatMessageRole.TOOL,
content=content,
tool_name=tool_name,
metadata=metadata,
tool_name=tool_output.tool_name,
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -433,9 +333,3 @@ class SqlAlchemyEventStore:
if isinstance(data, dict):
return data.get(key, default)
return default
def _sanitize_path_component(value: str) -> str:
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
compact = compact.strip(".-")
return compact or "id"