feat: 实现 AgentScope ReAct Runner 两阶段执行并重构事件处理

This commit is contained in:
zl-q
2026-03-16 09:01:01 +08:00
parent 072c09d99d
commit dcceb48d84
51 changed files with 5015 additions and 5663 deletions
@@ -38,9 +38,10 @@ _INTERNAL_TO_AGUI: dict[str, EventType] = {
def _convert_to_agui_type(internal_type: str) -> EventType:
return _INTERNAL_TO_AGUI.get(
internal_type, EventType(internal_type.upper().replace(".", "_"))
)
mapped = _INTERNAL_TO_AGUI.get(internal_type)
if mapped is not None:
return mapped
return EventType(internal_type.upper().replace(".", "_"))
def _is_agui_event(event: dict[str, Any]) -> bool:
@@ -142,32 +143,64 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
return event
internal_type = str(event.get("type", "")).strip()
thread_id = event.get("threadId")
run_id = event.get("runId")
data = event.get("data")
if internal_type == "text.end" and isinstance(data, dict):
text_end_payload: dict[str, Any] = {
"type": _convert_to_agui_type(internal_type).value,
}
if isinstance(thread_id, str) and thread_id:
text_end_payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
text_end_payload["runId"] = run_id
for key in ("messageId", "workerAgentOutput"):
value = data.get(key)
if value is not None:
text_end_payload[key] = value
return text_end_payload
if internal_type == "tool.result" and isinstance(data, dict):
tool_result_payload = {
"type": _convert_to_agui_type(internal_type).value,
}
if isinstance(thread_id, str) and thread_id:
tool_result_payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
tool_result_payload["runId"] = run_id
for key in ("messageId", "toolCallId", "toolAgentOutput"):
value = data.get(key)
if value is not None:
tool_result_payload[key] = value
return tool_result_payload
builder = _BUILDER_MAP.get(internal_type)
if builder:
agui_event = builder(event)
return agui_event.model_dump(by_alias=True, exclude_none=True)
payload = agui_event.model_dump(by_alias=True, exclude_none=True)
if isinstance(thread_id, str) and thread_id:
payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
payload["runId"] = run_id
if isinstance(data, dict):
reserved = {"type", "threadId", "runId"}
payload.update({k: v for k, v in data.items() if k not in reserved})
return payload
wire_type = _convert_to_agui_type(internal_type)
payload: dict[str, Any] = {
"type": wire_type.value,
}
thread_id = event.get("threadId")
run_id = event.get("runId")
if isinstance(thread_id, str) and thread_id:
payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
payload["runId"] = run_id
data = event.get("data")
if isinstance(data, dict):
if internal_type == "text.end":
for key in ("messageId", "workerAgentOutput"):
value = data.get(key)
if value is not None:
payload[key] = value
return payload
reserved = {"type", "threadId", "runId"}
data_map = cast(dict[str, Any], data)
payload.update({k: v for k, v in data_map.items() if k not in reserved})
@@ -50,5 +50,5 @@ class AgentScopeEventPipeline:
) -> str:
event_dict = to_dict(event)
wire_event = self._codec.to_wire(event_dict)
await self._store.persist(wire_event)
await self._store.persist(event_dict)
return await self._bus.publish(session_id=session_id, event=wire_event)
+36 -23
View File
@@ -55,8 +55,8 @@ class SqlAlchemyEventStore:
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
async def persist(self, event: dict[str, Any]) -> None:
event_type = str(event.get("type", "")).strip().upper()
thread_id = event.get("threadId")
event_type = str(event.get("type", "")).strip().upper().replace(".", "_")
thread_id = self._event_value(event, "threadId")
if not isinstance(thread_id, str) or not thread_id:
return
try:
@@ -124,8 +124,8 @@ class SqlAlchemyEventStore:
await session.commit()
def _buffer_text_delta(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = event.get("messageId")
delta = event.get("delta")
message_id = self._event_value(event, "messageId")
delta = self._event_value(event, "delta")
if not isinstance(message_id, str) or not message_id:
return
if not isinstance(delta, str) or not delta:
@@ -143,13 +143,13 @@ class SqlAlchemyEventStore:
self._message_contexts.pop(key, None)
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = event.get("messageId")
message_id = self._event_value(event, "messageId")
if not isinstance(message_id, str) or not message_id:
return
key = (session_key, message_id)
role = event.get("role")
stage = event.get("stage")
tool_name = event.get("toolName")
role = self._event_value(event, "role")
stage = self._event_value(event, "stage")
tool_name = self._event_value(event, "toolName")
context: dict[str, object] = {}
if isinstance(role, str) and role:
context["role"] = role
@@ -168,7 +168,7 @@ class SqlAlchemyEventStore:
session_repo: SessionRepository,
message_repo: MessageRepository,
) -> None:
message_id_raw = event.get("messageId")
message_id_raw = self._event_value(event, "messageId")
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
key = (str(session_id), message_id)
content = self._message_buffers.get(key, "")
@@ -177,26 +177,26 @@ class SqlAlchemyEventStore:
context = self._message_contexts.get(key, {})
input_tokens = self._to_int(event.get("inputTokens"))
output_tokens = self._to_int(event.get("outputTokens"))
input_tokens = self._to_int(self._event_value(event, "inputTokens"))
output_tokens = self._to_int(self._event_value(event, "outputTokens"))
token_delta = input_tokens + output_tokens
cost = self._to_decimal(event.get("cost"))
latency_ms = self._to_int_or_none(event.get("latencyMs"))
run_id = event.get("runId")
model_code = event.get("model")
cost = self._to_decimal(self._event_value(event, "cost"))
latency_ms = self._to_int_or_none(self._event_value(event, "latencyMs"))
run_id = self._event_value(event, "runId")
model_code = self._event_value(event, "model")
metadata: dict[str, object] = {"message_id": message_id}
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
if latency_ms is not None:
metadata["latency_ms"] = latency_ms
stage = event.get("stage")
stage = self._event_value(event, "stage")
if not isinstance(stage, str):
stage = context.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
worker_payload = event.get("workerAgentOutput")
worker_payload = self._event_value(event, "workerAgentOutput")
if isinstance(worker_payload, dict):
try:
if "ui_hints" in worker_payload:
@@ -264,11 +264,11 @@ class SqlAlchemyEventStore:
session_repo: SessionRepository,
message_repo: MessageRepository,
) -> None:
tool_name = event.get("toolName")
tool_name = self._event_value(event, "toolName")
if not isinstance(tool_name, str) or not tool_name:
return
raw_output = event.get("toolAgentOutput")
raw_output = self._event_value(event, "toolAgentOutput")
if not isinstance(raw_output, dict):
return
try:
@@ -276,11 +276,11 @@ class SqlAlchemyEventStore:
except Exception:
return
run_id = event.get("runId")
run_id = self._event_value(event, "runId")
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
task_id = event.get("taskId")
task_id = self._event_value(event, "taskId")
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
call_id_value = event.get("callId")
call_id_value = self._event_value(event, "callId")
if not isinstance(call_id_value, str) or not call_id_value:
call_id_value = (
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
@@ -303,7 +303,7 @@ class SqlAlchemyEventStore:
}
if run_id_value:
metadata["run_id"] = run_id_value
stage = event.get("stage")
stage = self._event_value(event, "stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
if task_id_value:
@@ -421,6 +421,19 @@ class SqlAlchemyEventStore:
return Decimal("0")
return parsed if parsed >= 0 else Decimal("0")
def _event_value(
self,
event: dict[str, Any],
key: str,
default: object | None = None,
) -> object | None:
if key in event:
return event.get(key)
data = event.get("data")
if isinstance(data, dict):
return data.get(key, default)
return default
def _sanitize_path_component(value: str) -> str:
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())