feat(agent): 增强多模态链路与工具调用能力
This commit is contained in:
@@ -23,10 +23,12 @@ class MessageRepository:
|
||||
role: AgentChatMessageRole,
|
||||
content: str,
|
||||
model_code: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cost: Decimal = Decimal("0"),
|
||||
latency_ms: int | None = None,
|
||||
) -> AgentChatMessage:
|
||||
message = AgentChatMessage(
|
||||
session_id=session_id,
|
||||
@@ -34,10 +36,12 @@ class MessageRepository:
|
||||
role=role,
|
||||
content=content,
|
||||
model_code=model_code,
|
||||
tool_name=tool_name,
|
||||
metadata_json=metadata,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
self._session.add(message)
|
||||
await self._session.flush()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID
|
||||
@@ -24,6 +25,7 @@ class SqlAlchemyEventStore:
|
||||
def __init__(self, *, session_factory: Any) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._message_buffers: dict[tuple[str, str], str] = {}
|
||||
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
event_type = str(event.get("type", "")).strip().upper()
|
||||
@@ -48,6 +50,10 @@ class SqlAlchemyEventStore:
|
||||
self._buffer_text_delta(session_key=session_key, event=event)
|
||||
return
|
||||
|
||||
if event_type == "TEXT_MESSAGE_START":
|
||||
self._buffer_text_context(session_key=session_key, event=event)
|
||||
return
|
||||
|
||||
if event_type == "RUN_STARTED":
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
@@ -72,7 +78,15 @@ class SqlAlchemyEventStore:
|
||||
)
|
||||
self._clear_session_buffers(session_key=session_key)
|
||||
elif event_type == "TEXT_MESSAGE_END":
|
||||
await self._persist_assistant_message(
|
||||
await self._persist_text_message(
|
||||
event=event,
|
||||
session_id=session_id,
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
)
|
||||
elif event_type == "TOOL_CALL_RESULT":
|
||||
await self._persist_tool_call_result(
|
||||
event=event,
|
||||
session_id=session_id,
|
||||
chat_session=chat_session,
|
||||
@@ -97,8 +111,28 @@ class SqlAlchemyEventStore:
|
||||
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
|
||||
for key in stale_keys:
|
||||
self._message_buffers.pop(key, None)
|
||||
stale_context_keys = [k for k in self._message_contexts if k[0] == session_key]
|
||||
for key in stale_context_keys:
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
async def _persist_assistant_message(
|
||||
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = event.get("messageId")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
key = (session_key, message_id)
|
||||
role = event.get("role")
|
||||
stage = event.get("stage")
|
||||
tool_name = event.get("toolName")
|
||||
context: dict[str, object] = {}
|
||||
if isinstance(role, str) and role:
|
||||
context["role"] = role
|
||||
if isinstance(stage, str) and stage:
|
||||
context["stage"] = stage
|
||||
if isinstance(tool_name, str) and tool_name:
|
||||
context["tool_name"] = tool_name
|
||||
self._message_contexts[key] = context
|
||||
|
||||
async def _persist_text_message(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
@@ -114,6 +148,8 @@ class SqlAlchemyEventStore:
|
||||
if not content:
|
||||
return
|
||||
|
||||
context = self._message_contexts.get(key, {})
|
||||
|
||||
input_tokens = self._to_int(event.get("inputTokens"))
|
||||
output_tokens = self._to_int(event.get("outputTokens"))
|
||||
token_delta = input_tokens + output_tokens
|
||||
@@ -127,6 +163,20 @@ class SqlAlchemyEventStore:
|
||||
metadata["run_id"] = run_id
|
||||
if latency_ms is not None:
|
||||
metadata["latency_ms"] = latency_ms
|
||||
stage = event.get("stage")
|
||||
if not isinstance(stage, str):
|
||||
stage = context.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
|
||||
role_value = context.get("role")
|
||||
if not isinstance(role_value, str):
|
||||
role_value = "assistant"
|
||||
role = self._resolve_role(role_value)
|
||||
tool_name = context.get("tool_name")
|
||||
tool_name_value = (
|
||||
tool_name if isinstance(tool_name, str) and tool_name else None
|
||||
)
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
@@ -137,13 +187,15 @@ class SqlAlchemyEventStore:
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
role=role,
|
||||
content=content,
|
||||
model_code=model_code if isinstance(model_code, str) else None,
|
||||
tool_name=tool_name_value,
|
||||
metadata=metadata,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -161,6 +213,74 @@ class SqlAlchemyEventStore:
|
||||
cost_delta=cost,
|
||||
)
|
||||
self._message_buffers.pop(key, None)
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
async def _persist_tool_call_result(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
session_id: UUID,
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
tool_name = event.get("toolName")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
return
|
||||
|
||||
payload = {
|
||||
"args": event.get("args"),
|
||||
"result": event.get("result"),
|
||||
"error": event.get("error"),
|
||||
"call_id": event.get("callId"),
|
||||
}
|
||||
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||||
metadata: dict[str, object] = {"tool_name": tool_name}
|
||||
run_id = event.get("runId")
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
stage = event.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
task_id = event.get("taskId")
|
||||
if isinstance(task_id, str) and task_id:
|
||||
metadata["task_id"] = task_id
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
)
|
||||
if locked_session is None:
|
||||
return
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
content=content,
|
||||
tool_name=tool_name,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
status = (
|
||||
current_status
|
||||
if isinstance(current_status, AgentChatSessionStatus)
|
||||
else AgentChatSessionStatus.RUNNING
|
||||
)
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
chat_session=chat_session,
|
||||
status=status,
|
||||
message_delta=1,
|
||||
)
|
||||
|
||||
def _resolve_role(self, value: str) -> AgentChatMessageRole:
|
||||
normalized = value.strip().lower()
|
||||
if normalized == AgentChatMessageRole.SYSTEM.value:
|
||||
return AgentChatMessageRole.SYSTEM
|
||||
if normalized == AgentChatMessageRole.TOOL.value:
|
||||
return AgentChatMessageRole.TOOL
|
||||
return AgentChatMessageRole.ASSISTANT
|
||||
|
||||
async def _update_session_state(
|
||||
self,
|
||||
|
||||
@@ -38,7 +38,38 @@ def _schema_json(model: type[Any]) -> str:
|
||||
)
|
||||
|
||||
|
||||
def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
|
||||
def build_intent_user_prompt(
|
||||
*, user_input: str | list[dict[str, Any]]
|
||||
) -> str | list[dict[str, Any]]:
|
||||
if isinstance(user_input, list):
|
||||
instruction_text = "\n\n".join(
|
||||
[
|
||||
INTENT_TASK_INSTRUCTION,
|
||||
"[Output Schema]",
|
||||
_schema_json(IntentOutput),
|
||||
"[User Input]",
|
||||
"Use the following multimodal blocks as the latest user input.",
|
||||
]
|
||||
)
|
||||
blocks = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": instruction_text,
|
||||
}
|
||||
]
|
||||
user_blocks = _latest_user_content_blocks(user_input)
|
||||
if not user_blocks:
|
||||
user_blocks = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
user_input, ensure_ascii=True, separators=(",", ":")
|
||||
),
|
||||
}
|
||||
]
|
||||
blocks.extend(user_blocks)
|
||||
return blocks
|
||||
|
||||
normalized_input = (
|
||||
user_input
|
||||
if isinstance(user_input, str)
|
||||
@@ -55,6 +86,101 @@ def build_intent_user_prompt(*, user_input: str | list[dict[str, Any]]) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _latest_user_content_blocks(
|
||||
user_input: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
for message in reversed(user_input):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
if message.get("role") != "user":
|
||||
continue
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
text = content.strip()
|
||||
return [{"type": "text", "text": text}] if text else []
|
||||
if not isinstance(content, list):
|
||||
return []
|
||||
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
text = item.get("text")
|
||||
if isinstance(text, str) and text.strip():
|
||||
blocks.append({"type": "text", "text": text})
|
||||
continue
|
||||
|
||||
if item_type == "binary":
|
||||
source_block = _binary_source_block(item)
|
||||
if source_block is not None:
|
||||
blocks.append(source_block)
|
||||
continue
|
||||
|
||||
if item_type == "image":
|
||||
source_block = _image_source_block(item)
|
||||
if source_block is not None:
|
||||
blocks.append(source_block)
|
||||
|
||||
return blocks
|
||||
return []
|
||||
|
||||
|
||||
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
|
||||
mime_type = item.get("mimeType")
|
||||
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
|
||||
if not media_type.startswith("image/"):
|
||||
return None
|
||||
|
||||
source_url = item.get("url")
|
||||
if isinstance(source_url, str) and source_url:
|
||||
return {"type": "image", "source": {"type": "url", "url": source_url}}
|
||||
|
||||
source_data = item.get("data")
|
||||
if isinstance(source_data, str) and source_data:
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": source_data,
|
||||
},
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _image_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
|
||||
source = item.get("source")
|
||||
if not isinstance(source, dict):
|
||||
return None
|
||||
|
||||
source_type = source.get("type")
|
||||
if source_type == "url":
|
||||
source_url = source.get("value") or source.get("url")
|
||||
if isinstance(source_url, str) and source_url:
|
||||
return {"type": "image", "source": {"type": "url", "url": source_url}}
|
||||
|
||||
if source_type in {"data", "base64"}:
|
||||
source_data = source.get("value") or source.get("data")
|
||||
if isinstance(source_data, str) and source_data:
|
||||
mime_type = source.get("mimeType") or source.get("media_type")
|
||||
media_type = (
|
||||
mime_type if isinstance(mime_type, str) and mime_type else "image/png"
|
||||
)
|
||||
if not media_type.startswith("image/"):
|
||||
return None
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": source_data,
|
||||
},
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def build_execution_user_prompt(
|
||||
*,
|
||||
task_id: str,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
@@ -153,6 +154,44 @@ class AgentRouteRuntime:
|
||||
},
|
||||
)
|
||||
|
||||
await self._emit_stage_text(
|
||||
thread_id=command.thread_id,
|
||||
run_id=command.run_id,
|
||||
stage_name="intent",
|
||||
message_id=f"intent-{command.run_id}",
|
||||
text=_intent_text_payload(result.intent),
|
||||
response_metadata=result.intent.response_metadata,
|
||||
)
|
||||
|
||||
if result.intent.route == "DIRECT_RESPONSE" and result.execution is None:
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "run.finished",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {},
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
if result.execution is not None:
|
||||
for index, task in enumerate(result.execution.task_results, start=1):
|
||||
await self._emit_stage_text(
|
||||
thread_id=command.thread_id,
|
||||
run_id=command.run_id,
|
||||
stage_name="execution",
|
||||
message_id=f"execution-{command.run_id}-{index}",
|
||||
text=task.execution_summary,
|
||||
response_metadata=task.response_metadata,
|
||||
)
|
||||
await self._emit_tool_result_events(
|
||||
thread_id=command.thread_id,
|
||||
run_id=command.run_id,
|
||||
task_id=task.task_id,
|
||||
tool_calls=_task_tool_calls(task),
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
@@ -164,35 +203,18 @@ class AgentRouteRuntime:
|
||||
)
|
||||
|
||||
report_message_id = f"assistant-{command.run_id}"
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "text.start",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"messageId": report_message_id, "role": "assistant"},
|
||||
},
|
||||
response_metadata = (
|
||||
result.report.response_metadata
|
||||
if isinstance(result.report.response_metadata, dict)
|
||||
else {}
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "text.delta",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {
|
||||
"messageId": report_message_id,
|
||||
"delta": result.report.assistant_text,
|
||||
},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "text.end",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"messageId": report_message_id},
|
||||
},
|
||||
await self._emit_stage_text(
|
||||
thread_id=command.thread_id,
|
||||
run_id=command.run_id,
|
||||
stage_name="report",
|
||||
message_id=report_message_id,
|
||||
text=result.report.assistant_text,
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
@@ -213,3 +235,178 @@ class AgentRouteRuntime:
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
async def _emit_stage_text(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
stage_name: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
response_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "text.start",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": stage_name,
|
||||
},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "text.delta",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {"messageId": message_id, "delta": text},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "text.end",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"messageId": message_id,
|
||||
"stage": stage_name,
|
||||
**_text_end_telemetry_payload(response_metadata),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
async def _emit_tool_result_events(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
task_id: str,
|
||||
tool_calls: list[dict[str, Any]],
|
||||
) -> None:
|
||||
for index, tool_call in enumerate(tool_calls, start=1):
|
||||
tool_name = tool_call.get("tool_name")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
continue
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "tool.result",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"callId": f"{run_id}-{task_id}-{index}",
|
||||
"stage": "execution",
|
||||
"taskId": task_id,
|
||||
"toolName": tool_name,
|
||||
"args": tool_call.get("args", {}),
|
||||
"result": tool_call.get("result"),
|
||||
"error": tool_call.get("error"),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {}
|
||||
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
|
||||
if model is not None:
|
||||
payload["model"] = model
|
||||
|
||||
input_tokens = _first_number(metadata, keys=("inputTokens", "input_tokens"))
|
||||
if input_tokens is not None:
|
||||
payload["inputTokens"] = input_tokens
|
||||
|
||||
output_tokens = _first_number(metadata, keys=("outputTokens", "output_tokens"))
|
||||
if output_tokens is not None:
|
||||
payload["outputTokens"] = output_tokens
|
||||
|
||||
latency_ms = _first_number(metadata, keys=("latencyMs", "latency_ms"))
|
||||
if latency_ms is not None:
|
||||
payload["latencyMs"] = latency_ms
|
||||
|
||||
cost = _first_number(metadata, keys=("cost", "total_cost"), allow_float=True)
|
||||
if cost is not None:
|
||||
payload["cost"] = cost
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _intent_text_payload(intent: Any) -> str:
|
||||
direct_response = getattr(intent, "direct_response", None)
|
||||
if isinstance(direct_response, str) and direct_response.strip():
|
||||
return direct_response
|
||||
return json.dumps(intent.model_dump(mode="json"), ensure_ascii=False)
|
||||
|
||||
|
||||
def _task_tool_calls(task: Any) -> list[dict[str, Any]]:
|
||||
normalized: list[dict[str, Any]] = []
|
||||
|
||||
tool_calls = getattr(task, "tool_calls", None)
|
||||
if isinstance(tool_calls, list):
|
||||
for item in tool_calls:
|
||||
if hasattr(item, "model_dump"):
|
||||
dumped = item.model_dump(mode="json")
|
||||
if isinstance(dumped, dict):
|
||||
normalized.append(dumped)
|
||||
elif isinstance(item, dict):
|
||||
normalized.append(item)
|
||||
|
||||
if normalized:
|
||||
return normalized
|
||||
|
||||
execution_data = getattr(task, "execution_data", None)
|
||||
if not isinstance(execution_data, dict):
|
||||
return []
|
||||
fallback_calls = execution_data.get("tool_calls")
|
||||
if not isinstance(fallback_calls, list):
|
||||
return []
|
||||
|
||||
for item in fallback_calls:
|
||||
if isinstance(item, dict):
|
||||
normalized.append(item)
|
||||
return normalized
|
||||
|
||||
|
||||
def _first_non_empty_str(
|
||||
metadata: dict[str, Any], *, keys: tuple[str, ...]
|
||||
) -> str | None:
|
||||
for key in keys:
|
||||
value = metadata.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _first_number(
|
||||
metadata: dict[str, Any],
|
||||
*,
|
||||
keys: tuple[str, ...],
|
||||
allow_float: bool = False,
|
||||
) -> int | float | None:
|
||||
for key in keys:
|
||||
value = metadata.get(key)
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int):
|
||||
if value < 0:
|
||||
continue
|
||||
return value
|
||||
if isinstance(value, float):
|
||||
if value < 0:
|
||||
continue
|
||||
return value if allow_float else int(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed = float(value) if allow_float else int(value)
|
||||
except ValueError:
|
||||
continue
|
||||
if parsed >= 0:
|
||||
return parsed
|
||||
return None
|
||||
|
||||
@@ -110,6 +110,19 @@ class AgentScopeRuntimeOrchestrator:
|
||||
)
|
||||
intent_output = IntentOutput.model_validate(intent_payload)
|
||||
|
||||
if intent_output.route == "DIRECT_RESPONSE":
|
||||
assistant_text = (
|
||||
intent_output.direct_response or intent_output.intent_summary
|
||||
)
|
||||
return RuntimeOutput(
|
||||
intent=intent_output,
|
||||
execution=None,
|
||||
report=ReportOutput(
|
||||
assistant_text=assistant_text,
|
||||
response_metadata=dict(intent_output.response_metadata),
|
||||
),
|
||||
)
|
||||
|
||||
execution_output: ExecutionBatchOutput | None = None
|
||||
if intent_output.route == "TASK_EXECUTION":
|
||||
execution_toolkit = build_stage_toolkit(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from time import perf_counter
|
||||
from typing import Any, cast
|
||||
|
||||
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
||||
@@ -14,7 +16,8 @@ def _to_litellm_model(*, provider_name: str, model_code: str) -> str:
|
||||
normalized_model = model_code.strip()
|
||||
if "/" in normalized_model:
|
||||
return normalized_model
|
||||
return f"{provider_name.strip().lower()}/{normalized_model}"
|
||||
del provider_name
|
||||
return normalized_model
|
||||
|
||||
|
||||
def _parse_json_text(raw_text: str) -> dict[str, Any]:
|
||||
@@ -30,6 +33,11 @@ def _parse_json_text(raw_text: str) -> dict[str, Any]:
|
||||
|
||||
|
||||
class AgentScopeReActRunner:
|
||||
def _build_litellm_service(self) -> Any:
|
||||
from services.litellm.service import LiteLLMService
|
||||
|
||||
return LiteLLMService()
|
||||
|
||||
def _build_model(self, *, stage_config: RuntimeStageConfig) -> Any:
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from agentscope.types import JSONSerializableObject
|
||||
@@ -61,9 +69,16 @@ class AgentScopeReActRunner:
|
||||
stage_config: RuntimeStageConfig,
|
||||
agent_name: str,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
user_prompt: str | list[dict[str, Any]],
|
||||
toolkit: Any | None,
|
||||
) -> dict[str, Any]:
|
||||
if stage_config.stage == "report" and toolkit is None:
|
||||
return await self._run_report_stage_direct(
|
||||
stage_config=stage_config,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
@@ -79,9 +94,19 @@ class AgentScopeReActRunner:
|
||||
max_iters=6,
|
||||
)
|
||||
try:
|
||||
response = await agent(Msg(name="user", content=user_prompt, role="user"))
|
||||
started_at = perf_counter()
|
||||
response = await agent(
|
||||
Msg(name="user", content=cast(Any, user_prompt), role="user")
|
||||
)
|
||||
latency_ms = int(round((perf_counter() - started_at) * 1000))
|
||||
text_content = response.get_text_content() or "{}"
|
||||
return _parse_json_text(text_content)
|
||||
payload = _parse_json_text(text_content)
|
||||
return _merge_stage_response_metadata(
|
||||
payload=payload,
|
||||
stage_config=stage_config,
|
||||
response=response,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.exception(
|
||||
"agentscope stage output is not valid json",
|
||||
@@ -96,3 +121,234 @@ class AgentScopeReActRunner:
|
||||
agent_name=agent_name,
|
||||
)
|
||||
raise RuntimeError("agent execution failed") from exc
|
||||
|
||||
async def _run_report_stage_direct(
|
||||
self,
|
||||
*,
|
||||
stage_config: RuntimeStageConfig,
|
||||
system_prompt: str,
|
||||
user_prompt: str | list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
service = self._build_litellm_service()
|
||||
started_at = perf_counter()
|
||||
response_with_cost = await asyncio.to_thread(
|
||||
service.run_completion_with_cost,
|
||||
model=_to_litellm_model(
|
||||
provider_name=stage_config.provider_name,
|
||||
model_code=stage_config.model_code,
|
||||
),
|
||||
messages=_report_messages(
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
),
|
||||
temperature=stage_config.llm_config.temperature,
|
||||
max_tokens=stage_config.llm_config.max_tokens,
|
||||
timeout=stage_config.llm_config.timeout_seconds,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
latency_ms = int(round((perf_counter() - started_at) * 1000))
|
||||
|
||||
text_content = _chat_response_text(response_with_cost.response)
|
||||
payload = _parse_json_text(text_content)
|
||||
return _merge_report_response_metadata(
|
||||
payload=payload,
|
||||
stage_config=stage_config,
|
||||
response_with_cost=response_with_cost,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.exception(
|
||||
"agentscope stage output is not valid json",
|
||||
stage=stage_config.stage,
|
||||
agent_name="report-agent",
|
||||
)
|
||||
raise RuntimeError("agent output format invalid") from exc
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"agentscope stage execution failed",
|
||||
stage=stage_config.stage,
|
||||
agent_name="report-agent",
|
||||
)
|
||||
raise RuntimeError("agent execution failed") from exc
|
||||
|
||||
|
||||
def _chat_response_text(response: Any) -> str:
|
||||
content = _read_value(response, "content")
|
||||
if isinstance(content, str) and content.strip():
|
||||
return content
|
||||
|
||||
if not isinstance(content, list):
|
||||
return _fallback_choice_content(response)
|
||||
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") != "text":
|
||||
continue
|
||||
text = block.get("text")
|
||||
if isinstance(text, str) and text:
|
||||
text_parts.append(text)
|
||||
if text_parts:
|
||||
return "".join(text_parts)
|
||||
return _fallback_choice_content(response)
|
||||
|
||||
|
||||
def _fallback_choice_content(response: Any) -> str:
|
||||
choices = _read_value(response, "choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
return "{}"
|
||||
|
||||
first_choice = choices[0]
|
||||
message = getattr(first_choice, "message", None)
|
||||
if message is None and isinstance(first_choice, dict):
|
||||
message = first_choice.get("message")
|
||||
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content")
|
||||
return content if isinstance(content, str) and content else "{}"
|
||||
|
||||
content = _read_value(message, "content")
|
||||
return content if isinstance(content, str) and content else "{}"
|
||||
|
||||
|
||||
def _read_value(source: Any, key: str) -> Any:
|
||||
if isinstance(source, dict):
|
||||
return source.get(key)
|
||||
return getattr(source, key, None)
|
||||
|
||||
|
||||
def _report_messages(
|
||||
*, system_prompt: str, user_prompt: str | list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
|
||||
def _merge_stage_response_metadata(
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
stage_config: RuntimeStageConfig,
|
||||
response: Any,
|
||||
latency_ms: int,
|
||||
) -> dict[str, Any]:
|
||||
result = dict(payload)
|
||||
existing = result.get("response_metadata")
|
||||
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
|
||||
metadata.setdefault("model", stage_config.model_code)
|
||||
|
||||
usage = _read_value(response, "usage")
|
||||
prompt_tokens = _to_non_negative_int(
|
||||
_read_value(usage, "prompt_tokens") or _read_value(usage, "input_tokens")
|
||||
)
|
||||
completion_tokens = _to_non_negative_int(
|
||||
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
|
||||
)
|
||||
cost = _to_non_negative_float(
|
||||
_read_value(usage, "cost")
|
||||
or _read_value(_read_value(usage, "metadata"), "cost")
|
||||
)
|
||||
resolved_model = _read_value(response, "model")
|
||||
if cost is None and prompt_tokens is not None and completion_tokens is not None:
|
||||
estimated_cost = _estimate_cost_by_pricing(
|
||||
model=resolved_model
|
||||
if isinstance(resolved_model, str)
|
||||
else stage_config.model_code,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
if estimated_cost is not None:
|
||||
cost = estimated_cost
|
||||
|
||||
if prompt_tokens is not None:
|
||||
metadata["inputTokens"] = prompt_tokens
|
||||
if completion_tokens is not None:
|
||||
metadata["outputTokens"] = completion_tokens
|
||||
if cost is not None:
|
||||
metadata["cost"] = cost
|
||||
if latency_ms >= 0:
|
||||
metadata["latencyMs"] = latency_ms
|
||||
|
||||
result["response_metadata"] = metadata
|
||||
return result
|
||||
|
||||
|
||||
def _merge_report_response_metadata(
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
stage_config: RuntimeStageConfig,
|
||||
response_with_cost: Any,
|
||||
latency_ms: int,
|
||||
) -> dict[str, Any]:
|
||||
result = dict(payload)
|
||||
existing = result.get("response_metadata")
|
||||
metadata: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
|
||||
usage = _read_value(response_with_cost, "usage")
|
||||
response = _read_value(response_with_cost, "response")
|
||||
|
||||
resolved_model = _read_value(response, "model")
|
||||
if isinstance(resolved_model, str) and resolved_model.strip():
|
||||
metadata["model"] = resolved_model.strip()
|
||||
else:
|
||||
metadata.setdefault("model", stage_config.model_code)
|
||||
|
||||
input_tokens = _to_non_negative_int(_read_value(usage, "prompt_tokens"))
|
||||
output_tokens = _to_non_negative_int(_read_value(usage, "completion_tokens"))
|
||||
cost = _to_non_negative_float(_read_value(usage, "cost"))
|
||||
if input_tokens is not None:
|
||||
metadata["inputTokens"] = input_tokens
|
||||
if output_tokens is not None:
|
||||
metadata["outputTokens"] = output_tokens
|
||||
if cost is not None:
|
||||
metadata["cost"] = cost
|
||||
if latency_ms >= 0:
|
||||
metadata["latencyMs"] = latency_ms
|
||||
|
||||
result["response_metadata"] = metadata
|
||||
return result
|
||||
|
||||
|
||||
def _to_non_negative_int(value: Any) -> int | None:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if not isinstance(value, (int, float, str)):
|
||||
return None
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return parsed if parsed >= 0 else None
|
||||
|
||||
|
||||
def _to_non_negative_float(value: Any) -> float | None:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if not isinstance(value, (int, float, str)):
|
||||
return None
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return parsed if parsed >= 0 else None
|
||||
|
||||
|
||||
def _estimate_cost_by_pricing(
|
||||
*, model: str, prompt_tokens: int, completion_tokens: int
|
||||
) -> float | None:
|
||||
normalized_model = model.strip()
|
||||
if not normalized_model:
|
||||
return None
|
||||
from services.litellm.service import LiteLLMService
|
||||
|
||||
service = LiteLLMService()
|
||||
try:
|
||||
return service.calculate_cost(
|
||||
model=normalized_model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agentscope.events import (
|
||||
AgentScopeAgUiCodec,
|
||||
AgentScopeEventPipeline,
|
||||
@@ -18,6 +21,7 @@ from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.tasks")
|
||||
@@ -76,6 +80,56 @@ def _extract_user_token(
|
||||
return None
|
||||
|
||||
|
||||
async def _build_recent_context_messages(
|
||||
*,
|
||||
session: Any,
|
||||
thread_id: str,
|
||||
current_run_id: str,
|
||||
max_messages: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
try:
|
||||
session_uuid = UUID(thread_id)
|
||||
except ValueError:
|
||||
return []
|
||||
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
start_of_today = utc_now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
start_of_yesterday = start_of_today - timedelta(days=1)
|
||||
|
||||
stmt = (
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.where(AgentChatMessage.created_at >= start_of_yesterday)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
metadata = row.metadata_json if isinstance(row.metadata_json, dict) else {}
|
||||
if metadata.get("run_id") == current_run_id:
|
||||
continue
|
||||
role = (
|
||||
row.role.value
|
||||
if isinstance(row.role, AgentChatMessageRole)
|
||||
else str(row.role)
|
||||
)
|
||||
if role not in {"user", "assistant"}:
|
||||
continue
|
||||
normalized.append(
|
||||
{
|
||||
"id": str(row.id),
|
||||
"role": role,
|
||||
"content": row.content,
|
||||
}
|
||||
)
|
||||
|
||||
if len(normalized) <= max_messages:
|
||||
return normalized
|
||||
return normalized[-max_messages:]
|
||||
|
||||
|
||||
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
command_type = str(command.get("command", "run")).strip().lower()
|
||||
raw_run_input = command.get("run_input")
|
||||
@@ -117,6 +171,21 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
if command_type == "run":
|
||||
context_messages = await _build_recent_context_messages(
|
||||
session=session,
|
||||
thread_id=parsed_run_input.thread_id,
|
||||
current_run_id=parsed_run_input.run_id,
|
||||
)
|
||||
parsed_run_input = parsed_run_input.model_copy(
|
||||
update={
|
||||
"messages": [
|
||||
*context_messages,
|
||||
*parsed_run_input.messages,
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
if command_type == "resume":
|
||||
await runtime.resume(
|
||||
command=parsed_run_input,
|
||||
|
||||
@@ -5,12 +5,21 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ExecutionToolCall(BaseModel):
|
||||
tool_name: str = Field(min_length=1)
|
||||
args: dict[str, Any] = Field(default_factory=dict)
|
||||
result: Any | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ExecutionTaskOutput(BaseModel):
|
||||
task_id: str = Field(min_length=1)
|
||||
status: Literal["SUCCESS", "PARTIAL", "FAILED"]
|
||||
execution_summary: str = Field(min_length=1)
|
||||
execution_data: dict[str, Any] = Field(default_factory=dict)
|
||||
user_feedback_needs: list[str] = Field(default_factory=list)
|
||||
response_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
tool_calls: list[ExecutionToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExecutionBatchOutput(BaseModel):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
@@ -17,6 +18,7 @@ class IntentOutput(BaseModel):
|
||||
direct_response: str | None = None
|
||||
tasks: list[IntentTask] = Field(default_factory=list)
|
||||
complexity: Literal["simple", "complex"]
|
||||
response_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_route(self) -> "IntentOutput":
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
|
||||
from core.agentscope.tools.custom.calendar import (
|
||||
calendar_read,
|
||||
calendar_write,
|
||||
user_resolve,
|
||||
)
|
||||
|
||||
__all__ = ["calendar_read", "calendar_write"]
|
||||
__all__ = ["calendar_read", "calendar_write", "user_resolve"]
|
||||
|
||||
@@ -7,6 +7,7 @@ from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
|
||||
from core.agentscope.tools.custom.calendar_backend_ops import (
|
||||
_execute_list_calendar_events,
|
||||
_execute_mutate_calendar_event,
|
||||
_execute_resolve_user_identity,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.agentscope.tools.response import build_tool_response
|
||||
@@ -150,6 +151,30 @@ async def calendar_write(
|
||||
bool,
|
||||
Field(description="Whether to use the replace strategy for conflicts."),
|
||||
] = False,
|
||||
invite_user_emails: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Optional invite targets by email."),
|
||||
] = None,
|
||||
invite_user_names: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Optional invite targets by username."),
|
||||
] = None,
|
||||
invite_user_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Optional invite targets by user ID (UUID string)."),
|
||||
] = None,
|
||||
invite_permission_view: Annotated[
|
||||
bool,
|
||||
Field(description="Invite permission: view."),
|
||||
] = True,
|
||||
invite_permission_edit: Annotated[
|
||||
bool,
|
||||
Field(description="Invite permission: edit."),
|
||||
] = False,
|
||||
invite_permission_invite: Annotated[
|
||||
bool,
|
||||
Field(description="Invite permission: invite others."),
|
||||
] = False,
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
user_token: str | None = None,
|
||||
@@ -240,6 +265,15 @@ async def calendar_write(
|
||||
tool_args["reminderMinutes"] = reminder_minutes
|
||||
if status is not None:
|
||||
tool_args["status"] = status
|
||||
if invite_user_emails is not None:
|
||||
tool_args["inviteUserEmails"] = invite_user_emails
|
||||
if invite_user_names is not None:
|
||||
tool_args["inviteUserNames"] = invite_user_names
|
||||
if invite_user_ids is not None:
|
||||
tool_args["inviteUserIds"] = invite_user_ids
|
||||
tool_args["invitePermissionView"] = invite_permission_view
|
||||
tool_args["invitePermissionEdit"] = invite_permission_edit
|
||||
tool_args["invitePermissionInvite"] = invite_permission_invite
|
||||
|
||||
result = await _execute_mutate_calendar_event(
|
||||
session=cast(Any, session),
|
||||
@@ -247,3 +281,34 @@ async def calendar_write(
|
||||
tool_args=tool_args,
|
||||
)
|
||||
return build_tool_response(result)
|
||||
|
||||
|
||||
async def user_resolve(
|
||||
user_email: Annotated[
|
||||
str | None,
|
||||
Field(description="User email to resolve user ID."),
|
||||
] = None,
|
||||
user_name: Annotated[
|
||||
str | None,
|
||||
Field(description="Username to resolve user ID."),
|
||||
] = None,
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
user_token: str | None = None,
|
||||
) -> Any:
|
||||
if session is None or owner_id is None:
|
||||
raise ValueError("user.resolve missing runtime preset arguments")
|
||||
if not isinstance(user_token, str) or not user_token.strip():
|
||||
return build_tool_response(_unauthorized_response())
|
||||
if not _verify_user_token(user_token=user_token, owner_id=cast(UUID, owner_id)):
|
||||
return build_tool_response(_unauthorized_response())
|
||||
|
||||
result = await _execute_resolve_user_identity(
|
||||
session=cast(Any, session),
|
||||
owner_id=cast(UUID, owner_id),
|
||||
tool_args={
|
||||
"userEmail": user_email,
|
||||
"userName": user_name,
|
||||
},
|
||||
)
|
||||
return build_tool_response(result)
|
||||
|
||||
@@ -4,13 +4,20 @@ import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from services.base.supabase import supabase_service
|
||||
from models.profile import Profile
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.inbox_messages.repository import SQLAlchemyInboxMessageRepository
|
||||
from v1.schedule_items.repository import SQLAlchemyScheduleItemRepository
|
||||
from v1.schedule_items.schemas import (
|
||||
ScheduleItemCreateRequest,
|
||||
ScheduleItemMetadata,
|
||||
ScheduleItemShareRequest,
|
||||
ScheduleItemStatus,
|
||||
ScheduleItemUpdateRequest,
|
||||
)
|
||||
@@ -72,9 +79,196 @@ def _service(session: AsyncSession, owner_id: UUID) -> ScheduleItemService:
|
||||
repository=SQLAlchemyScheduleItemRepository(session),
|
||||
session=session,
|
||||
current_user=CurrentUser(id=owner_id),
|
||||
inbox_repository=SQLAlchemyInboxMessageRepository(session),
|
||||
)
|
||||
|
||||
|
||||
def _parse_string_list(value: object, *, field_name: str) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"{field_name} must be a list of strings")
|
||||
parsed: list[str] = []
|
||||
for item in value:
|
||||
if not isinstance(item, str) or not item.strip():
|
||||
raise ValueError(f"{field_name} must be a list of non-empty strings")
|
||||
parsed.append(item.strip())
|
||||
return parsed
|
||||
|
||||
|
||||
def _list_auth_users() -> list[object]:
|
||||
admin_client = supabase_service.get_admin_client()
|
||||
users: list[object] = []
|
||||
page = 1
|
||||
while page <= 100:
|
||||
response = admin_client.auth.admin.list_users(page=page, per_page=100)
|
||||
batch = (
|
||||
list(response)
|
||||
if isinstance(response, list)
|
||||
else list(getattr(response, "users", []))
|
||||
)
|
||||
users.extend(batch)
|
||||
if len(batch) < 100:
|
||||
break
|
||||
page += 1
|
||||
return users
|
||||
|
||||
|
||||
async def _get_profile_username(*, session: AsyncSession, user_id: UUID) -> str | None:
|
||||
stmt = select(Profile.username).where(Profile.id == user_id)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def _get_profile_by_username(
|
||||
*, session: AsyncSession, username: str
|
||||
) -> Profile | None:
|
||||
stmt = (
|
||||
select(Profile)
|
||||
.where(Profile.username == username)
|
||||
.where(Profile.deleted_at.is_(None))
|
||||
)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
|
||||
def _find_auth_email_by_user_id(*, users: list[object], user_id: UUID) -> str | None:
|
||||
target = str(user_id)
|
||||
for user in users:
|
||||
if str(getattr(user, "id", "")) == target:
|
||||
email = getattr(user, "email", None)
|
||||
if isinstance(email, str) and email.strip():
|
||||
return email.strip()
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_identity(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
user_email: str | None,
|
||||
user_name: str | None,
|
||||
) -> dict[str, object]:
|
||||
email = user_email.strip().lower() if isinstance(user_email, str) else ""
|
||||
name = user_name.strip() if isinstance(user_name, str) else ""
|
||||
if bool(email) == bool(name):
|
||||
raise ValueError("provide exactly one of user_email or user_name")
|
||||
|
||||
if email:
|
||||
auth_gateway = SupabaseAuthGateway()
|
||||
user = await auth_gateway.get_user_by_email(email)
|
||||
user_id = UUID(user.id)
|
||||
username = await _get_profile_username(session=session, user_id=user_id)
|
||||
return {
|
||||
"userId": str(user_id),
|
||||
"email": user.email,
|
||||
"username": username,
|
||||
"matchedBy": "email",
|
||||
}
|
||||
|
||||
profile = await _get_profile_by_username(session=session, username=name)
|
||||
if profile is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
users = _list_auth_users()
|
||||
email_value = _find_auth_email_by_user_id(users=users, user_id=profile.id)
|
||||
return {
|
||||
"userId": str(profile.id),
|
||||
"email": email_value,
|
||||
"username": profile.username,
|
||||
"matchedBy": "username",
|
||||
}
|
||||
|
||||
|
||||
def _invite_permission(tool_args: dict[str, object]) -> dict[str, bool]:
|
||||
return {
|
||||
"permission_view": bool(tool_args.get("invitePermissionView", True)),
|
||||
"permission_edit": bool(tool_args.get("invitePermissionEdit", False)),
|
||||
"permission_invite": bool(tool_args.get("invitePermissionInvite", False)),
|
||||
}
|
||||
|
||||
|
||||
async def _share_event_with_invitees(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
event_id: UUID,
|
||||
tool_args: dict[str, object],
|
||||
) -> dict[str, object] | None:
|
||||
email_targets = _parse_string_list(
|
||||
tool_args.get("inviteUserEmails"),
|
||||
field_name="inviteUserEmails",
|
||||
)
|
||||
name_targets = _parse_string_list(
|
||||
tool_args.get("inviteUserNames"),
|
||||
field_name="inviteUserNames",
|
||||
)
|
||||
id_targets = _parse_string_list(
|
||||
tool_args.get("inviteUserIds"),
|
||||
field_name="inviteUserIds",
|
||||
)
|
||||
if not email_targets and not name_targets and not id_targets:
|
||||
return None
|
||||
|
||||
users = _list_auth_users() if id_targets else []
|
||||
emails = {item.lower() for item in email_targets}
|
||||
for user_id_raw in id_targets:
|
||||
try:
|
||||
user_id = UUID(user_id_raw)
|
||||
except ValueError as exc:
|
||||
raise ValueError("inviteUserIds must contain valid UUID strings") from exc
|
||||
resolved_email = _find_auth_email_by_user_id(users=users, user_id=user_id)
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail="Invite user email not found")
|
||||
emails.add(resolved_email.lower())
|
||||
for username in name_targets:
|
||||
resolved = await _resolve_identity(
|
||||
session=session,
|
||||
user_email=None,
|
||||
user_name=username,
|
||||
)
|
||||
resolved_email = resolved.get("email")
|
||||
if not isinstance(resolved_email, str) or not resolved_email:
|
||||
raise HTTPException(status_code=404, detail="Invite user email not found")
|
||||
emails.add(resolved_email.lower())
|
||||
|
||||
service = _service(session, owner_id)
|
||||
permission = _invite_permission(tool_args)
|
||||
invited: list[str] = []
|
||||
for email in sorted(emails):
|
||||
request = ScheduleItemShareRequest(email=email, **permission)
|
||||
await service.share(event_id, request)
|
||||
invited.append(email)
|
||||
return {
|
||||
"count": len(invited),
|
||||
"emails": invited,
|
||||
"permission": permission,
|
||||
}
|
||||
|
||||
|
||||
async def _execute_resolve_user_identity(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
tool_args: dict[str, object],
|
||||
) -> dict[str, object]:
|
||||
del owner_id
|
||||
user_email_raw = tool_args.get("userEmail")
|
||||
user_name_raw = tool_args.get("userName")
|
||||
user_email = user_email_raw if isinstance(user_email_raw, str) else None
|
||||
user_name = user_name_raw if isinstance(user_name_raw, str) else None
|
||||
resolved = await _resolve_identity(
|
||||
session=session,
|
||||
user_email=user_email,
|
||||
user_name=user_name,
|
||||
)
|
||||
return {
|
||||
"type": "user_lookup.v1",
|
||||
"version": "v1",
|
||||
"data": {
|
||||
"ok": True,
|
||||
**resolved,
|
||||
},
|
||||
"actions": [],
|
||||
}
|
||||
|
||||
|
||||
def _resolve_metadata(tool_args: dict[str, object]) -> ScheduleItemMetadata:
|
||||
location = tool_args.get("location")
|
||||
location_value = location.strip() if isinstance(location, str) else None
|
||||
@@ -185,6 +379,12 @@ async def _execute_create(
|
||||
)
|
||||
event_data = _event_payload(created)
|
||||
event_id = str(event_data["id"])
|
||||
invite_result = await _share_event_with_invitees(
|
||||
session=service._session,
|
||||
owner_id=service.require_user_id(),
|
||||
event_id=UUID(event_id),
|
||||
tool_args=tool_args,
|
||||
)
|
||||
return {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
@@ -193,12 +393,13 @@ async def _execute_create(
|
||||
"sourceType": "agent_generated",
|
||||
"ok": True,
|
||||
"message": "日程已创建",
|
||||
"inviteResult": invite_result,
|
||||
},
|
||||
"actions": [
|
||||
{
|
||||
"type": "link",
|
||||
"label": "查看详情",
|
||||
"target": f"/calendar/events/{event_id}",
|
||||
"target": f"/schedule-items/{event_id}",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -274,6 +475,12 @@ async def _execute_update(
|
||||
ScheduleItemUpdateRequest.model_validate(update_data),
|
||||
)
|
||||
event_data = _event_payload(updated)
|
||||
invite_result = await _share_event_with_invitees(
|
||||
session=service._session,
|
||||
owner_id=service.require_user_id(),
|
||||
event_id=UUID(str(event_data["id"])),
|
||||
tool_args=tool_args,
|
||||
)
|
||||
return {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
@@ -282,12 +489,13 @@ async def _execute_update(
|
||||
"sourceType": "agent_generated",
|
||||
"ok": True,
|
||||
"message": "日程已更新",
|
||||
"inviteResult": invite_result,
|
||||
},
|
||||
"actions": [
|
||||
{
|
||||
"type": "link",
|
||||
"label": "查看详情",
|
||||
"target": f"/calendar/events/{event_data['id']}",
|
||||
"target": f"/schedule-items/{event_data['id']}",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
@@ -4,8 +4,9 @@ from dataclasses import dataclass
|
||||
|
||||
|
||||
TOOL_APPROVAL_REQUIRED: dict[str, bool] = {
|
||||
"calendar.read": False,
|
||||
"calendar.write": False,
|
||||
"calendar_read": False,
|
||||
"calendar_write": False,
|
||||
"user_resolve": False,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,11 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.tools.custom.calendar import calendar_read, calendar_write
|
||||
from core.agentscope.tools.custom.calendar import (
|
||||
calendar_read,
|
||||
calendar_write,
|
||||
user_resolve,
|
||||
)
|
||||
from core.agentscope.tools.hitl_middleware import register_tool_middlewares
|
||||
from core.agentscope.tools.tool_meta import TOOL_META
|
||||
|
||||
@@ -25,10 +29,12 @@ class ToolGroup:
|
||||
|
||||
|
||||
TOOL_GROUPS: dict[str, ToolGroup] = {
|
||||
"intent": ToolGroup(stage="intent", tool_names=frozenset({"calendar.read"})),
|
||||
"intent": ToolGroup(
|
||||
stage="intent", tool_names=frozenset({"calendar_read", "user_resolve"})
|
||||
),
|
||||
"execution": ToolGroup(
|
||||
stage="execution",
|
||||
tool_names=frozenset({"calendar.read", "calendar.write"}),
|
||||
tool_names=frozenset({"calendar_read", "calendar_write", "user_resolve"}),
|
||||
),
|
||||
"report": ToolGroup(stage="report", tool_names=frozenset()),
|
||||
}
|
||||
@@ -49,7 +55,7 @@ def _load_custom_tool_bindings(
|
||||
) -> list[CustomToolBinding]:
|
||||
return [
|
||||
CustomToolBinding(
|
||||
name="calendar.read",
|
||||
name="calendar_read",
|
||||
func=calendar_read,
|
||||
preset_kwargs={
|
||||
"session": session,
|
||||
@@ -58,7 +64,7 @@ def _load_custom_tool_bindings(
|
||||
},
|
||||
),
|
||||
CustomToolBinding(
|
||||
name="calendar.write",
|
||||
name="calendar_write",
|
||||
func=calendar_write,
|
||||
preset_kwargs={
|
||||
"session": session,
|
||||
@@ -66,6 +72,15 @@ def _load_custom_tool_bindings(
|
||||
"user_token": user_token or "",
|
||||
},
|
||||
),
|
||||
CustomToolBinding(
|
||||
name="user_resolve",
|
||||
func=user_resolve,
|
||||
preset_kwargs={
|
||||
"session": session,
|
||||
"owner_id": owner_id,
|
||||
"user_token": user_token or "",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -126,21 +126,26 @@ class LiteLLMService:
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
completion_fn: Callable[..., dict[str, Any]] | None = None,
|
||||
) -> LiteLLMResponseWithCost:
|
||||
caller = completion_fn or completion
|
||||
request_model = model if model.startswith("openai/") else f"openai/{model}"
|
||||
|
||||
response_any = caller(
|
||||
model=request_model,
|
||||
api_key=self.proxy_api_key,
|
||||
api_base=self.proxy_base_url,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
timeout=timeout,
|
||||
stream=False,
|
||||
)
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"model": request_model,
|
||||
"api_key": self.proxy_api_key,
|
||||
"api_base": self.proxy_base_url,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"timeout": timeout,
|
||||
"stream": False,
|
||||
}
|
||||
if response_format is not None:
|
||||
request_kwargs["response_format"] = response_format
|
||||
|
||||
response_any = caller(**request_kwargs)
|
||||
response = self._normalize_response(response_any)
|
||||
|
||||
usage_raw = response.get("usage")
|
||||
|
||||
@@ -107,6 +107,10 @@ class AgentRepository:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
next_seq = int(session_row.message_count or 0) + 1
|
||||
if not _has_title(session_row.title):
|
||||
session_title = _derive_session_title(content_text)
|
||||
if session_title is not None:
|
||||
session_row.title = session_title
|
||||
payload_metadata = dict(metadata or {})
|
||||
payload_metadata["run_id"] = run_id
|
||||
message = AgentChatMessage(
|
||||
@@ -264,3 +268,14 @@ class AgentRepository:
|
||||
if rendered:
|
||||
payload["attachments"] = rendered
|
||||
return payload
|
||||
|
||||
|
||||
def _has_title(title: object) -> bool:
|
||||
return isinstance(title, str) and bool(title.strip())
|
||||
|
||||
|
||||
def _derive_session_title(content_text: str) -> str | None:
|
||||
normalized = " ".join(content_text.split())
|
||||
if not normalized:
|
||||
return None
|
||||
return normalized[:80]
|
||||
|
||||
@@ -203,6 +203,11 @@ async def stream_events(
|
||||
user_id=str(current_user.id),
|
||||
reason=str(exc),
|
||||
)
|
||||
if "Timeout reading from" in str(exc):
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
break
|
||||
|
||||
if not rows:
|
||||
|
||||
@@ -212,12 +212,19 @@ class AgentService:
|
||||
content_type=mime_type,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
bucket_name = "private"
|
||||
stored_path = await self._attachment_storage.upload_bytes(
|
||||
bucket=bucket_name,
|
||||
path=path,
|
||||
content=payload,
|
||||
content_type=mime_type,
|
||||
logger.exception(
|
||||
"Attachment upload failed",
|
||||
extra={
|
||||
"bucket": bucket_name,
|
||||
"path": path,
|
||||
"mime_type": mime_type,
|
||||
"thread_id": run_input.thread_id,
|
||||
"run_id": run_input.run_id,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to upload attachment",
|
||||
)
|
||||
attachments.append(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user