feat(agent): 增强多模态链路与工具调用能力

This commit is contained in:
zl-q
2026-03-12 00:18:45 +08:00
parent 18db6c50e7
commit 21ba8e4a44
35 changed files with 2057 additions and 829 deletions
@@ -23,10 +23,12 @@ class MessageRepository:
role: AgentChatMessageRole,
content: str,
model_code: str | None = None,
tool_name: str | None = None,
metadata: dict[str, object] | None = None,
input_tokens: int = 0,
output_tokens: int = 0,
cost: Decimal = Decimal("0"),
latency_ms: int | None = None,
) -> AgentChatMessage:
message = AgentChatMessage(
session_id=session_id,
@@ -34,10 +36,12 @@ class MessageRepository:
role=role,
content=content,
model_code=model_code,
tool_name=tool_name,
metadata_json=metadata,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
)
self._session.add(message)
await self._session.flush()
+123 -3
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID
@@ -24,6 +25,7 @@ class SqlAlchemyEventStore:
def __init__(self, *, session_factory: Any) -> None:
self._session_factory = session_factory
self._message_buffers: dict[tuple[str, str], str] = {}
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
async def persist(self, event: dict[str, Any]) -> None:
event_type = str(event.get("type", "")).strip().upper()
@@ -48,6 +50,10 @@ class SqlAlchemyEventStore:
self._buffer_text_delta(session_key=session_key, event=event)
return
if event_type == "TEXT_MESSAGE_START":
self._buffer_text_context(session_key=session_key, event=event)
return
if event_type == "RUN_STARTED":
await self._update_session_state(
session_repo=session_repo,
@@ -72,7 +78,15 @@ class SqlAlchemyEventStore:
)
self._clear_session_buffers(session_key=session_key)
elif event_type == "TEXT_MESSAGE_END":
await self._persist_assistant_message(
await self._persist_text_message(
event=event,
session_id=session_id,
chat_session=chat_session,
session_repo=session_repo,
message_repo=message_repo,
)
elif event_type == "TOOL_CALL_RESULT":
await self._persist_tool_call_result(
event=event,
session_id=session_id,
chat_session=chat_session,
@@ -97,8 +111,28 @@ class SqlAlchemyEventStore:
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
for key in stale_keys:
self._message_buffers.pop(key, None)
stale_context_keys = [k for k in self._message_contexts if k[0] == session_key]
for key in stale_context_keys:
self._message_contexts.pop(key, None)
async def _persist_assistant_message(
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = event.get("messageId")
if not isinstance(message_id, str) or not message_id:
return
key = (session_key, message_id)
role = event.get("role")
stage = event.get("stage")
tool_name = event.get("toolName")
context: dict[str, object] = {}
if isinstance(role, str) and role:
context["role"] = role
if isinstance(stage, str) and stage:
context["stage"] = stage
if isinstance(tool_name, str) and tool_name:
context["tool_name"] = tool_name
self._message_contexts[key] = context
async def _persist_text_message(
self,
*,
event: dict[str, Any],
@@ -114,6 +148,8 @@ class SqlAlchemyEventStore:
if not content:
return
context = self._message_contexts.get(key, {})
input_tokens = self._to_int(event.get("inputTokens"))
output_tokens = self._to_int(event.get("outputTokens"))
token_delta = input_tokens + output_tokens
@@ -127,6 +163,20 @@ class SqlAlchemyEventStore:
metadata["run_id"] = run_id
if latency_ms is not None:
metadata["latency_ms"] = latency_ms
stage = event.get("stage")
if not isinstance(stage, str):
stage = context.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
role_value = context.get("role")
if not isinstance(role_value, str):
role_value = "assistant"
role = self._resolve_role(role_value)
tool_name = context.get("tool_name")
tool_name_value = (
tool_name if isinstance(tool_name, str) and tool_name else None
)
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
@@ -137,13 +187,15 @@ class SqlAlchemyEventStore:
await message_repo.append_message(
session_id=session_id,
seq=seq,
role=AgentChatMessageRole.ASSISTANT,
role=role,
content=content,
model_code=model_code if isinstance(model_code, str) else None,
tool_name=tool_name_value,
metadata=metadata,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -161,6 +213,74 @@ class SqlAlchemyEventStore:
cost_delta=cost,
)
self._message_buffers.pop(key, None)
self._message_contexts.pop(key, None)
async def _persist_tool_call_result(
self,
*,
event: dict[str, Any],
session_id: UUID,
chat_session: Any,
session_repo: SessionRepository,
message_repo: MessageRepository,
) -> None:
tool_name = event.get("toolName")
if not isinstance(tool_name, str) or not tool_name:
return
payload = {
"args": event.get("args"),
"result": event.get("result"),
"error": event.get("error"),
"call_id": event.get("callId"),
}
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
metadata: dict[str, object] = {"tool_name": tool_name}
run_id = event.get("runId")
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
stage = event.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
task_id = event.get("taskId")
if isinstance(task_id, str) and task_id:
metadata["task_id"] = task_id
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
)
if locked_session is None:
return
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
await message_repo.append_message(
session_id=session_id,
seq=seq,
role=AgentChatMessageRole.TOOL,
content=content,
tool_name=tool_name,
metadata=metadata,
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
status = (
current_status
if isinstance(current_status, AgentChatSessionStatus)
else AgentChatSessionStatus.RUNNING
)
await self._update_session_state(
session_repo=session_repo,
chat_session=chat_session,
status=status,
message_delta=1,
)
def _resolve_role(self, value: str) -> AgentChatMessageRole:
normalized = value.strip().lower()
if normalized == AgentChatMessageRole.SYSTEM.value:
return AgentChatMessageRole.SYSTEM
if normalized == AgentChatMessageRole.TOOL.value:
return AgentChatMessageRole.TOOL
return AgentChatMessageRole.ASSISTANT
async def _update_session_state(
self,
@@ -38,7 +38,38 @@ def _schema_json(model: type[Any]) -> str:
)
def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
def build_intent_user_prompt(
*, user_input: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(user_input, list):
instruction_text = "\n\n".join(
[
INTENT_TASK_INSTRUCTION,
"[Output Schema]",
_schema_json(IntentOutput),
"[User Input]",
"Use the following multimodal blocks as the latest user input.",
]
)
blocks = [
{
"type": "text",
"text": instruction_text,
}
]
user_blocks = _latest_user_content_blocks(user_input)
if not user_blocks:
user_blocks = [
{
"type": "text",
"text": json.dumps(
user_input, ensure_ascii=True, separators=(",", ":")
),
}
]
blocks.extend(user_blocks)
return blocks
normalized_input = (
user_input
if isinstance(user_input, str)
@@ -55,6 +86,101 @@ def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
)
def _latest_user_content_blocks(
user_input: list[dict[str, Any]],
) -> list[dict[str, Any]]:
for message in reversed(user_input):
if not isinstance(message, dict):
continue
if message.get("role") != "user":
continue
content = message.get("content")
if isinstance(content, str):
text = content.strip()
return [{"type": "text", "text": text}] if text else []
if not isinstance(content, list):
return []
blocks: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
text = item.get("text")
if isinstance(text, str) and text.strip():
blocks.append({"type": "text", "text": text})
continue
if item_type == "binary":
source_block = _binary_source_block(item)
if source_block is not None:
blocks.append(source_block)
continue
if item_type == "image":
source_block = _image_source_block(item)
if source_block is not None:
blocks.append(source_block)
return blocks
return []
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
mime_type = item.get("mimeType")
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
if not media_type.startswith("image/"):
return None
source_url = item.get("url")
if isinstance(source_url, str) and source_url:
return {"type": "image", "source": {"type": "url", "url": source_url}}
source_data = item.get("data")
if isinstance(source_data, str) and source_data:
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": source_data,
},
}
return None
def _image_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
source = item.get("source")
if not isinstance(source, dict):
return None
source_type = source.get("type")
if source_type == "url":
source_url = source.get("value") or source.get("url")
if isinstance(source_url, str) and source_url:
return {"type": "image", "source": {"type": "url", "url": source_url}}
if source_type in {"data", "base64"}:
source_data = source.get("value") or source.get("data")
if isinstance(source_data, str) and source_data:
mime_type = source.get("mimeType") or source.get("media_type")
media_type = (
mime_type if isinstance(mime_type, str) and mime_type else "image/png"
)
if not media_type.startswith("image/"):
return None
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": source_data,
},
}
return None
def build_execution_user_prompt(
*,
task_id: str,
@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from typing import Any, Protocol
from uuid import UUID
@@ -153,6 +154,44 @@ class AgentRouteRuntime:
},
)
await self._emit_stage_text(
thread_id=command.thread_id,
run_id=command.run_id,
stage_name="intent",
message_id=f"intent-{command.run_id}",
text=_intent_text_payload(result.intent),
response_metadata=result.intent.response_metadata,
)
if result.intent.route == "DIRECT_RESPONSE" and result.execution is None:
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "run.finished",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {},
},
)
return result
if result.execution is not None:
for index, task in enumerate(result.execution.task_results, start=1):
await self._emit_stage_text(
thread_id=command.thread_id,
run_id=command.run_id,
stage_name="execution",
message_id=f"execution-{command.run_id}-{index}",
text=task.execution_summary,
response_metadata=task.response_metadata,
)
await self._emit_tool_result_events(
thread_id=command.thread_id,
run_id=command.run_id,
task_id=task.task_id,
tool_calls=_task_tool_calls(task),
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
@@ -164,35 +203,18 @@ class AgentRouteRuntime:
)
report_message_id = f"assistant-{command.run_id}"
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "text.start",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"messageId": report_message_id, "role": "assistant"},
},
response_metadata = (
result.report.response_metadata
if isinstance(result.report.response_metadata, dict)
else {}
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "text.delta",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {
"messageId": report_message_id,
"delta": result.report.assistant_text,
},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "text.end",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"messageId": report_message_id},
},
await self._emit_stage_text(
thread_id=command.thread_id,
run_id=command.run_id,
stage_name="report",
message_id=report_message_id,
text=result.report.assistant_text,
response_metadata=response_metadata,
)
await self._pipeline.emit(
session_id=command.thread_id,
@@ -213,3 +235,178 @@ class AgentRouteRuntime:
},
)
return result
async def _emit_stage_text(
self,
*,
thread_id: str,
run_id: str,
stage_name: str,
message_id: str,
text: str,
response_metadata: dict[str, Any],
) -> None:
await self._pipeline.emit(
session_id=thread_id,
event={
"type": "text.start",
"threadId": thread_id,
"runId": run_id,
"data": {
"messageId": message_id,
"role": "assistant",
"stage": stage_name,
},
},
)
await self._pipeline.emit(
session_id=thread_id,
event={
"type": "text.delta",
"threadId": thread_id,
"runId": run_id,
"data": {"messageId": message_id, "delta": text},
},
)
await self._pipeline.emit(
session_id=thread_id,
event={
"type": "text.end",
"threadId": thread_id,
"runId": run_id,
"data": {
"messageId": message_id,
"stage": stage_name,
**_text_end_telemetry_payload(response_metadata),
},
},
)
async def _emit_tool_result_events(
self,
*,
thread_id: str,
run_id: str,
task_id: str,
tool_calls: list[dict[str, Any]],
) -> None:
for index, tool_call in enumerate(tool_calls, start=1):
tool_name = tool_call.get("tool_name")
if not isinstance(tool_name, str) or not tool_name:
continue
await self._pipeline.emit(
session_id=thread_id,
event={
"type": "tool.result",
"threadId": thread_id,
"runId": run_id,
"data": {
"callId": f"{run_id}-{task_id}-{index}",
"stage": "execution",
"taskId": task_id,
"toolName": tool_name,
"args": tool_call.get("args", {}),
"result": tool_call.get("result"),
"error": tool_call.get("error"),
},
},
)
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {}
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
if model is not None:
payload["model"] = model
input_tokens = _first_number(metadata, keys=("inputTokens", "input_tokens"))
if input_tokens is not None:
payload["inputTokens"] = input_tokens
output_tokens = _first_number(metadata, keys=("outputTokens", "output_tokens"))
if output_tokens is not None:
payload["outputTokens"] = output_tokens
latency_ms = _first_number(metadata, keys=("latencyMs", "latency_ms"))
if latency_ms is not None:
payload["latencyMs"] = latency_ms
cost = _first_number(metadata, keys=("cost", "total_cost"), allow_float=True)
if cost is not None:
payload["cost"] = cost
return payload
def _intent_text_payload(intent: Any) -> str:
direct_response = getattr(intent, "direct_response", None)
if isinstance(direct_response, str) and direct_response.strip():
return direct_response
return json.dumps(intent.model_dump(mode="json"), ensure_ascii=False)
def _task_tool_calls(task: Any) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
tool_calls = getattr(task, "tool_calls", None)
if isinstance(tool_calls, list):
for item in tool_calls:
if hasattr(item, "model_dump"):
dumped = item.model_dump(mode="json")
if isinstance(dumped, dict):
normalized.append(dumped)
elif isinstance(item, dict):
normalized.append(item)
if normalized:
return normalized
execution_data = getattr(task, "execution_data", None)
if not isinstance(execution_data, dict):
return []
fallback_calls = execution_data.get("tool_calls")
if not isinstance(fallback_calls, list):
return []
for item in fallback_calls:
if isinstance(item, dict):
normalized.append(item)
return normalized
def _first_non_empty_str(
metadata: dict[str, Any], *, keys: tuple[str, ...]
) -> str | None:
for key in keys:
value = metadata.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
def _first_number(
metadata: dict[str, Any],
*,
keys: tuple[str, ...],
allow_float: bool = False,
) -> int | float | None:
for key in keys:
value = metadata.get(key)
if isinstance(value, bool):
continue
if isinstance(value, int):
if value < 0:
continue
return value
if isinstance(value, float):
if value < 0:
continue
return value if allow_float else int(value)
if isinstance(value, str):
try:
parsed = float(value) if allow_float else int(value)
except ValueError:
continue
if parsed >= 0:
return parsed
return None
@@ -110,6 +110,19 @@ class AgentScopeRuntimeOrchestrator:
)
intent_output = IntentOutput.model_validate(intent_payload)
if intent_output.route == "DIRECT_RESPONSE":
assistant_text = (
intent_output.direct_response or intent_output.intent_summary
)
return RuntimeOutput(
intent=intent_output,
execution=None,
report=ReportOutput(
assistant_text=assistant_text,
response_metadata=dict(intent_output.response_metadata),
),
)
execution_output: ExecutionBatchOutput | None = None
if intent_output.route == "TASK_EXECUTION":
execution_toolkit = build_stage_toolkit(
@@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
import json
from time import perf_counter
from typing import Any, cast
from core.agentscope.runtime.config_loader import RuntimeStageConfig
@@ -14,7 +16,8 @@ def _to_litellm_model(*, provider_name: str, model_code: str) -> str:
normalized_model = model_code.strip()
if "/" in normalized_model:
return normalized_model
return f"{provider_name.strip().lower()}/{normalized_model}"
del provider_name
return normalized_model
def _parse_json_text(raw_text: str) -> dict[str, Any]:
@@ -30,6 +33,11 @@ def _parse_json_text(raw_text: str) -> dict[str, Any]:
class AgentScopeReActRunner:
def _build_litellm_service(self) -> Any:
from services.litellm.service import LiteLLMService
return LiteLLMService()
def _build_model(self, *, stage_config: RuntimeStageConfig) -> Any:
from agentscope.model import OpenAIChatModel
from agentscope.types import JSONSerializableObject
@@ -61,9 +69,16 @@ class AgentScopeReActRunner:
stage_config: RuntimeStageConfig,
agent_name: str,
system_prompt: str,
user_prompt: str,
user_prompt: str | list[dict[str, Any]],
toolkit: Any | None,
) -> dict[str, Any]:
if stage_config.stage == "report" and toolkit is None:
return await self._run_report_stage_direct(
stage_config=stage_config,
system_prompt=system_prompt,
user_prompt=user_prompt,
)
from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
@@ -79,9 +94,19 @@ class AgentScopeReActRunner:
max_iters=6,
)
try:
response = await agent(Msg(name="user", content=user_prompt, role="user"))
started_at = perf_counter()
response = await agent(
Msg(name="user", content=cast(Any, user_prompt), role="user")
)
latency_ms = int(round((perf_counter() - started_at) * 1000))
text_content = response.get_text_content() or "{}"
return _parse_json_text(text_content)
payload = _parse_json_text(text_content)
return _merge_stage_response_metadata(
payload=payload,
stage_config=stage_config,
response=response,
latency_ms=latency_ms,
)
except json.JSONDecodeError as exc:
logger.exception(
"agentscope stage output is not valid json",
@@ -96,3 +121,234 @@ class AgentScopeReActRunner:
agent_name=agent_name,
)
raise RuntimeError("agent execution failed") from exc
async def _run_report_stage_direct(
self,
*,
stage_config: RuntimeStageConfig,
system_prompt: str,
user_prompt: str | list[dict[str, Any]],
) -> dict[str, Any]:
try:
service = self._build_litellm_service()
started_at = perf_counter()
response_with_cost = await asyncio.to_thread(
service.run_completion_with_cost,
model=_to_litellm_model(
provider_name=stage_config.provider_name,
model_code=stage_config.model_code,
),
messages=_report_messages(
system_prompt=system_prompt,
user_prompt=user_prompt,
),
temperature=stage_config.llm_config.temperature,
max_tokens=stage_config.llm_config.max_tokens,
timeout=stage_config.llm_config.timeout_seconds,
response_format={"type": "json_object"},
)
latency_ms = int(round((perf_counter() - started_at) * 1000))
text_content = _chat_response_text(response_with_cost.response)
payload = _parse_json_text(text_content)
return _merge_report_response_metadata(
payload=payload,
stage_config=stage_config,
response_with_cost=response_with_cost,
latency_ms=latency_ms,
)
except json.JSONDecodeError as exc:
logger.exception(
"agentscope stage output is not valid json",
stage=stage_config.stage,
agent_name="report-agent",
)
raise RuntimeError("agent output format invalid") from exc
except Exception as exc:
logger.exception(
"agentscope stage execution failed",
stage=stage_config.stage,
agent_name="report-agent",
)
raise RuntimeError("agent execution failed") from exc
def _chat_response_text(response: Any) -> str:
content = _read_value(response, "content")
if isinstance(content, str) and content.strip():
return content
if not isinstance(content, list):
return _fallback_choice_content(response)
text_parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") != "text":
continue
text = block.get("text")
if isinstance(text, str) and text:
text_parts.append(text)
if text_parts:
return "".join(text_parts)
return _fallback_choice_content(response)
def _fallback_choice_content(response: Any) -> str:
choices = _read_value(response, "choices")
if not isinstance(choices, list) or not choices:
return "{}"
first_choice = choices[0]
message = getattr(first_choice, "message", None)
if message is None and isinstance(first_choice, dict):
message = first_choice.get("message")
if isinstance(message, dict):
content = message.get("content")
return content if isinstance(content, str) and content else "{}"
content = _read_value(message, "content")
return content if isinstance(content, str) and content else "{}"
def _read_value(source: Any, key: str) -> Any:
if isinstance(source, dict):
return source.get(key)
return getattr(source, key, None)
def _report_messages(
*, system_prompt: str, user_prompt: str | list[dict[str, Any]]
) -> list[dict[str, Any]]:
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
def _merge_stage_response_metadata(
*,
payload: dict[str, Any],
stage_config: RuntimeStageConfig,
response: Any,
latency_ms: int,
) -> dict[str, Any]:
result = dict(payload)
existing = result.get("response_metadata")
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
metadata.setdefault("model", stage_config.model_code)
usage = _read_value(response, "usage")
prompt_tokens = _to_non_negative_int(
_read_value(usage, "prompt_tokens") or _read_value(usage, "input_tokens")
)
completion_tokens = _to_non_negative_int(
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
)
cost = _to_non_negative_float(
_read_value(usage, "cost")
or _read_value(_read_value(usage, "metadata"), "cost")
)
resolved_model = _read_value(response, "model")
if cost is None and prompt_tokens is not None and completion_tokens is not None:
estimated_cost = _estimate_cost_by_pricing(
model=resolved_model
if isinstance(resolved_model, str)
else stage_config.model_code,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
if estimated_cost is not None:
cost = estimated_cost
if prompt_tokens is not None:
metadata["inputTokens"] = prompt_tokens
if completion_tokens is not None:
metadata["outputTokens"] = completion_tokens
if cost is not None:
metadata["cost"] = cost
if latency_ms >= 0:
metadata["latencyMs"] = latency_ms
result["response_metadata"] = metadata
return result
def _merge_report_response_metadata(
*,
payload: dict[str, Any],
stage_config: RuntimeStageConfig,
response_with_cost: Any,
latency_ms: int,
) -> dict[str, Any]:
result = dict(payload)
existing = result.get("response_metadata")
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
usage = _read_value(response_with_cost, "usage")
response = _read_value(response_with_cost, "response")
resolved_model = _read_value(response, "model")
if isinstance(resolved_model, str) and resolved_model.strip():
metadata["model"] = resolved_model.strip()
else:
metadata.setdefault("model", stage_config.model_code)
input_tokens = _to_non_negative_int(_read_value(usage, "prompt_tokens"))
output_tokens = _to_non_negative_int(_read_value(usage, "completion_tokens"))
cost = _to_non_negative_float(_read_value(usage, "cost"))
if input_tokens is not None:
metadata["inputTokens"] = input_tokens
if output_tokens is not None:
metadata["outputTokens"] = output_tokens
if cost is not None:
metadata["cost"] = cost
if latency_ms >= 0:
metadata["latencyMs"] = latency_ms
result["response_metadata"] = metadata
return result
def _to_non_negative_int(value: Any) -> int | None:
if isinstance(value, bool):
return None
if not isinstance(value, (int, float, str)):
return None
try:
parsed = int(value)
except (TypeError, ValueError):
return None
return parsed if parsed >= 0 else None
def _to_non_negative_float(value: Any) -> float | None:
if isinstance(value, bool):
return None
if not isinstance(value, (int, float, str)):
return None
try:
parsed = float(value)
except (TypeError, ValueError):
return None
return parsed if parsed >= 0 else None
def _estimate_cost_by_pricing(
*, model: str, prompt_tokens: int, completion_tokens: int
) -> float | None:
normalized_model = model.strip()
if not normalized_model:
return None
from services.litellm.service import LiteLLMService
service = LiteLLMService()
try:
return service.calculate_cost(
model=normalized_model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
except ValueError:
return None
@@ -1,8 +1,11 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import select
from core.agentscope.events import (
AgentScopeAgUiCodec,
AgentScopeEventPipeline,
@@ -18,6 +21,7 @@ from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from services.base.redis import get_or_init_redis_client
logger = get_logger("core.agentscope.runtime.tasks")
@@ -76,6 +80,56 @@ def _extract_user_token(
return None
async def _build_recent_context_messages(
*,
session: Any,
thread_id: str,
current_run_id: str,
max_messages: int = 20,
) -> list[dict[str, Any]]:
try:
session_uuid = UUID(thread_id)
except ValueError:
return []
utc_now = datetime.now(timezone.utc)
start_of_today = utc_now.replace(hour=0, minute=0, second=0, microsecond=0)
start_of_yesterday = start_of_today - timedelta(days=1)
stmt = (
select(AgentChatMessage)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.where(AgentChatMessage.created_at >= start_of_yesterday)
.order_by(AgentChatMessage.seq.asc())
)
rows = (await session.execute(stmt)).scalars().all()
normalized: list[dict[str, Any]] = []
for row in rows:
metadata = row.metadata_json if isinstance(row.metadata_json, dict) else {}
if metadata.get("run_id") == current_run_id:
continue
role = (
row.role.value
if isinstance(row.role, AgentChatMessageRole)
else str(row.role)
)
if role not in {"user", "assistant"}:
continue
normalized.append(
{
"id": str(row.id),
"role": role,
"content": row.content,
}
)
if len(normalized) <= max_messages:
return normalized
return normalized[-max_messages:]
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
command_type = str(command.get("command", "run")).strip().lower()
raw_run_input = command.get("run_input")
@@ -117,6 +171,21 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
)
async with AsyncSessionLocal() as session:
if command_type == "run":
context_messages = await _build_recent_context_messages(
session=session,
thread_id=parsed_run_input.thread_id,
current_run_id=parsed_run_input.run_id,
)
parsed_run_input = parsed_run_input.model_copy(
update={
"messages": [
*context_messages,
*parsed_run_input.messages,
]
}
)
if command_type == "resume":
await runtime.resume(
command=parsed_run_input,
@@ -5,12 +5,21 @@ from typing import Any, Literal
from pydantic import BaseModel, Field
class ExecutionToolCall(BaseModel):
tool_name: str = Field(min_length=1)
args: dict[str, Any] = Field(default_factory=dict)
result: Any | None = None
error: str | None = None
class ExecutionTaskOutput(BaseModel):
task_id: str = Field(min_length=1)
status: Literal["SUCCESS", "PARTIAL", "FAILED"]
execution_summary: str = Field(min_length=1)
execution_data: dict[str, Any] = Field(default_factory=dict)
user_feedback_needs: list[str] = Field(default_factory=list)
response_metadata: dict[str, Any] = Field(default_factory=dict)
tool_calls: list[ExecutionToolCall] = Field(default_factory=list)
class ExecutionBatchOutput(BaseModel):
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import Any
from typing import Literal
from pydantic import BaseModel, Field, model_validator
@@ -17,6 +18,7 @@ class IntentOutput(BaseModel):
direct_response: str | None = None
tasks: list[IntentTask] = Field(default_factory=list)
complexity: Literal["simple", "complex"]
response_metadata: dict[str, Any] = Field(default_factory=dict)
@model_validator(mode="after")
def validate_route(self) -> "IntentOutput":
@@ -1,3 +1,7 @@
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
from core.agentscope.tools.custom.calendar import (
calendar_read,
calendar_write,
user_resolve,
)
__all__ = ["calendar_read", "calendar_write"]
__all__ = ["calendar_read", "calendar_write", "user_resolve"]
@@ -7,6 +7,7 @@ from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
from core.agentscope.tools.custom.calendar_backend_ops import (
_execute_list_calendar_events,
_execute_mutate_calendar_event,
_execute_resolve_user_identity,
)
from core.config.settings import config
from core.agentscope.tools.response import build_tool_response
@@ -150,6 +151,30 @@ async def calendar_write(
bool,
Field(description="Whether to use the replace strategy for conflicts."),
] = False,
invite_user_emails: Annotated[
list[str] | None,
Field(description="Optional invite targets by email."),
] = None,
invite_user_names: Annotated[
list[str] | None,
Field(description="Optional invite targets by username."),
] = None,
invite_user_ids: Annotated[
list[str] | None,
Field(description="Optional invite targets by user ID (UUID string)."),
] = None,
invite_permission_view: Annotated[
bool,
Field(description="Invite permission: view."),
] = True,
invite_permission_edit: Annotated[
bool,
Field(description="Invite permission: edit."),
] = False,
invite_permission_invite: Annotated[
bool,
Field(description="Invite permission: invite others."),
] = False,
session: Any = None,
owner_id: Any = None,
user_token: str | None = None,
@@ -240,6 +265,15 @@ async def calendar_write(
tool_args["reminderMinutes"] = reminder_minutes
if status is not None:
tool_args["status"] = status
if invite_user_emails is not None:
tool_args["inviteUserEmails"] = invite_user_emails
if invite_user_names is not None:
tool_args["inviteUserNames"] = invite_user_names
if invite_user_ids is not None:
tool_args["inviteUserIds"] = invite_user_ids
tool_args["invitePermissionView"] = invite_permission_view
tool_args["invitePermissionEdit"] = invite_permission_edit
tool_args["invitePermissionInvite"] = invite_permission_invite
result = await _execute_mutate_calendar_event(
session=cast(Any, session),
@@ -247,3 +281,34 @@ async def calendar_write(
tool_args=tool_args,
)
return build_tool_response(result)
async def user_resolve(
user_email: Annotated[
str | None,
Field(description="User email to resolve user ID."),
] = None,
user_name: Annotated[
str | None,
Field(description="Username to resolve user ID."),
] = None,
session: Any = None,
owner_id: Any = None,
user_token: str | None = None,
) -> Any:
if session is None or owner_id is None:
raise ValueError("user.resolve missing runtime preset arguments")
if not isinstance(user_token, str) or not user_token.strip():
return build_tool_response(_unauthorized_response())
if not _verify_user_token(user_token=user_token, owner_id=cast(UUID, owner_id)):
return build_tool_response(_unauthorized_response())
result = await _execute_resolve_user_identity(
session=cast(Any, session),
owner_id=cast(UUID, owner_id),
tool_args={
"userEmail": user_email,
"userName": user_name,
},
)
return build_tool_response(result)
@@ -4,13 +4,20 @@ import re
from datetime import datetime, timedelta, timezone
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from services.base.supabase import supabase_service
from models.profile import Profile
from v1.auth.gateway import SupabaseAuthGateway
from v1.inbox_messages.repository import SQLAlchemyInboxMessageRepository
from v1.schedule_items.repository import SQLAlchemyScheduleItemRepository
from v1.schedule_items.schemas import (
ScheduleItemCreateRequest,
ScheduleItemMetadata,
ScheduleItemShareRequest,
ScheduleItemStatus,
ScheduleItemUpdateRequest,
)
@@ -72,9 +79,196 @@ def _service(session: AsyncSession, owner_id: UUID) -> ScheduleItemService:
repository=SQLAlchemyScheduleItemRepository(session),
session=session,
current_user=CurrentUser(id=owner_id),
inbox_repository=SQLAlchemyInboxMessageRepository(session),
)
def _parse_string_list(value: object, *, field_name: str) -> list[str]:
if value is None:
return []
if not isinstance(value, list):
raise ValueError(f"{field_name} must be a list of strings")
parsed: list[str] = []
for item in value:
if not isinstance(item, str) or not item.strip():
raise ValueError(f"{field_name} must be a list of non-empty strings")
parsed.append(item.strip())
return parsed
def _list_auth_users() -> list[object]:
admin_client = supabase_service.get_admin_client()
users: list[object] = []
page = 1
while page <= 100:
response = admin_client.auth.admin.list_users(page=page, per_page=100)
batch = (
list(response)
if isinstance(response, list)
else list(getattr(response, "users", []))
)
users.extend(batch)
if len(batch) < 100:
break
page += 1
return users
async def _get_profile_username(*, session: AsyncSession, user_id: UUID) -> str | None:
stmt = select(Profile.username).where(Profile.id == user_id)
return (await session.execute(stmt)).scalar_one_or_none()
async def _get_profile_by_username(
*, session: AsyncSession, username: str
) -> Profile | None:
stmt = (
select(Profile)
.where(Profile.username == username)
.where(Profile.deleted_at.is_(None))
)
return (await session.execute(stmt)).scalar_one_or_none()
def _find_auth_email_by_user_id(*, users: list[object], user_id: UUID) -> str | None:
target = str(user_id)
for user in users:
if str(getattr(user, "id", "")) == target:
email = getattr(user, "email", None)
if isinstance(email, str) and email.strip():
return email.strip()
return None
async def _resolve_identity(
*,
session: AsyncSession,
user_email: str | None,
user_name: str | None,
) -> dict[str, object]:
email = user_email.strip().lower() if isinstance(user_email, str) else ""
name = user_name.strip() if isinstance(user_name, str) else ""
if bool(email) == bool(name):
raise ValueError("provide exactly one of user_email or user_name")
if email:
auth_gateway = SupabaseAuthGateway()
user = await auth_gateway.get_user_by_email(email)
user_id = UUID(user.id)
username = await _get_profile_username(session=session, user_id=user_id)
return {
"userId": str(user_id),
"email": user.email,
"username": username,
"matchedBy": "email",
}
profile = await _get_profile_by_username(session=session, username=name)
if profile is None:
raise HTTPException(status_code=404, detail="User not found")
users = _list_auth_users()
email_value = _find_auth_email_by_user_id(users=users, user_id=profile.id)
return {
"userId": str(profile.id),
"email": email_value,
"username": profile.username,
"matchedBy": "username",
}
def _invite_permission(tool_args: dict[str, object]) -> dict[str, bool]:
return {
"permission_view": bool(tool_args.get("invitePermissionView", True)),
"permission_edit": bool(tool_args.get("invitePermissionEdit", False)),
"permission_invite": bool(tool_args.get("invitePermissionInvite", False)),
}
async def _share_event_with_invitees(
*,
session: AsyncSession,
owner_id: UUID,
event_id: UUID,
tool_args: dict[str, object],
) -> dict[str, object] | None:
email_targets = _parse_string_list(
tool_args.get("inviteUserEmails"),
field_name="inviteUserEmails",
)
name_targets = _parse_string_list(
tool_args.get("inviteUserNames"),
field_name="inviteUserNames",
)
id_targets = _parse_string_list(
tool_args.get("inviteUserIds"),
field_name="inviteUserIds",
)
if not email_targets and not name_targets and not id_targets:
return None
users = _list_auth_users() if id_targets else []
emails = {item.lower() for item in email_targets}
for user_id_raw in id_targets:
try:
user_id = UUID(user_id_raw)
except ValueError as exc:
raise ValueError("inviteUserIds must contain valid UUID strings") from exc
resolved_email = _find_auth_email_by_user_id(users=users, user_id=user_id)
if resolved_email is None:
raise HTTPException(status_code=404, detail="Invite user email not found")
emails.add(resolved_email.lower())
for username in name_targets:
resolved = await _resolve_identity(
session=session,
user_email=None,
user_name=username,
)
resolved_email = resolved.get("email")
if not isinstance(resolved_email, str) or not resolved_email:
raise HTTPException(status_code=404, detail="Invite user email not found")
emails.add(resolved_email.lower())
service = _service(session, owner_id)
permission = _invite_permission(tool_args)
invited: list[str] = []
for email in sorted(emails):
request = ScheduleItemShareRequest(email=email, **permission)
await service.share(event_id, request)
invited.append(email)
return {
"count": len(invited),
"emails": invited,
"permission": permission,
}
async def _execute_resolve_user_identity(
*,
session: AsyncSession,
owner_id: UUID,
tool_args: dict[str, object],
) -> dict[str, object]:
del owner_id
user_email_raw = tool_args.get("userEmail")
user_name_raw = tool_args.get("userName")
user_email = user_email_raw if isinstance(user_email_raw, str) else None
user_name = user_name_raw if isinstance(user_name_raw, str) else None
resolved = await _resolve_identity(
session=session,
user_email=user_email,
user_name=user_name,
)
return {
"type": "user_lookup.v1",
"version": "v1",
"data": {
"ok": True,
**resolved,
},
"actions": [],
}
def _resolve_metadata(tool_args: dict[str, object]) -> ScheduleItemMetadata:
location = tool_args.get("location")
location_value = location.strip() if isinstance(location, str) else None
@@ -185,6 +379,12 @@ async def _execute_create(
)
event_data = _event_payload(created)
event_id = str(event_data["id"])
invite_result = await _share_event_with_invitees(
session=service._session,
owner_id=service.require_user_id(),
event_id=UUID(event_id),
tool_args=tool_args,
)
return {
"type": "calendar_card.v1",
"version": "v1",
@@ -193,12 +393,13 @@ async def _execute_create(
"sourceType": "agent_generated",
"ok": True,
"message": "日程已创建",
"inviteResult": invite_result,
},
"actions": [
{
"type": "link",
"label": "查看详情",
"target": f"/calendar/events/{event_id}",
"target": f"/schedule-items/{event_id}",
}
],
}
@@ -274,6 +475,12 @@ async def _execute_update(
ScheduleItemUpdateRequest.model_validate(update_data),
)
event_data = _event_payload(updated)
invite_result = await _share_event_with_invitees(
session=service._session,
owner_id=service.require_user_id(),
event_id=UUID(str(event_data["id"])),
tool_args=tool_args,
)
return {
"type": "calendar_card.v1",
"version": "v1",
@@ -282,12 +489,13 @@ async def _execute_update(
"sourceType": "agent_generated",
"ok": True,
"message": "日程已更新",
"inviteResult": invite_result,
},
"actions": [
{
"type": "link",
"label": "查看详情",
"target": f"/calendar/events/{event_data['id']}",
"target": f"/schedule-items/{event_data['id']}",
}
],
}
@@ -4,8 +4,9 @@ from dataclasses import dataclass
TOOL_APPROVAL_REQUIRED: dict[str, bool] = {
"calendar.read": False,
"calendar.write": False,
"calendar_read": False,
"calendar_write": False,
"user_resolve": False,
}
+20 -5
View File
@@ -6,7 +6,11 @@ from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
from core.agentscope.tools.custom.calendar import (
calendar_read,
calendar_write,
user_resolve,
)
from core.agentscope.tools.hitl_middleware import register_tool_middlewares
from core.agentscope.tools.tool_meta import TOOL_META
@@ -25,10 +29,12 @@ class ToolGroup:
TOOL_GROUPS: dict[str, ToolGroup] = {
"intent": ToolGroup(stage="intent", tool_names=frozenset({"calendar.read"})),
"intent": ToolGroup(
stage="intent", tool_names=frozenset({"calendar_read", "user_resolve"})
),
"execution": ToolGroup(
stage="execution",
tool_names=frozenset({"calendar.read", "calendar.write"}),
tool_names=frozenset({"calendar_read", "calendar_write", "user_resolve"}),
),
"report": ToolGroup(stage="report", tool_names=frozenset()),
}
@@ -49,7 +55,7 @@ def _load_custom_tool_bindings(
) -> list[CustomToolBinding]:
return [
CustomToolBinding(
name="calendar.read",
name="calendar_read",
func=calendar_read,
preset_kwargs={
"session": session,
@@ -58,7 +64,7 @@ def _load_custom_tool_bindings(
},
),
CustomToolBinding(
name="calendar.write",
name="calendar_write",
func=calendar_write,
preset_kwargs={
"session": session,
@@ -66,6 +72,15 @@ def _load_custom_tool_bindings(
"user_token": user_token or "",
},
),
CustomToolBinding(
name="user_resolve",
func=user_resolve,
preset_kwargs={
"session": session,
"owner_id": owner_id,
"user_token": user_token or "",
},
),
]
+15 -10
View File
@@ -126,21 +126,26 @@ class LiteLLMService:
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
response_format: dict[str, Any] | None = None,
completion_fn: Callable[..., dict[str, Any]] | None = None,
) -> LiteLLMResponseWithCost:
caller = completion_fn or completion
request_model = model if model.startswith("openai/") else f"openai/{model}"
response_any = caller(
model=request_model,
api_key=self.proxy_api_key,
api_base=self.proxy_base_url,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
stream=False,
)
request_kwargs: dict[str, Any] = {
"model": request_model,
"api_key": self.proxy_api_key,
"api_base": self.proxy_base_url,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"timeout": timeout,
"stream": False,
}
if response_format is not None:
request_kwargs["response_format"] = response_format
response_any = caller(**request_kwargs)
response = self._normalize_response(response_any)
usage_raw = response.get("usage")
+15
View File
@@ -107,6 +107,10 @@ class AgentRepository:
raise HTTPException(status_code=404, detail="Session not found")
next_seq = int(session_row.message_count or 0) + 1
if not _has_title(session_row.title):
session_title = _derive_session_title(content_text)
if session_title is not None:
session_row.title = session_title
payload_metadata = dict(metadata or {})
payload_metadata["run_id"] = run_id
message = AgentChatMessage(
@@ -264,3 +268,14 @@ class AgentRepository:
if rendered:
payload["attachments"] = rendered
return payload
def _has_title(title: object) -> bool:
return isinstance(title, str) and bool(title.strip())
def _derive_session_title(content_text: str) -> str | None:
normalized = " ".join(content_text.split())
if not normalized:
return None
return normalized[:80]
+5
View File
@@ -203,6 +203,11 @@ async def stream_events(
user_id=str(current_user.id),
reason=str(exc),
)
if "Timeout reading from" in str(exc):
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
break
if not rows:
+13 -6
View File
@@ -212,12 +212,19 @@ class AgentService:
content_type=mime_type,
)
except Exception: # noqa: BLE001
bucket_name = "private"
stored_path = await self._attachment_storage.upload_bytes(
bucket=bucket_name,
path=path,
content=payload,
content_type=mime_type,
logger.exception(
"Attachment upload failed",
extra={
"bucket": bucket_name,
"path": path,
"mime_type": mime_type,
"thread_id": run_input.thread_id,
"run_id": run_input.run_id,
},
)
raise HTTPException(
status_code=502,
detail="Failed to upload attachment",
)
attachments.append(
{