feat: 优化 Agent 运行时与聊天设置体验
This commit is contained in:
@@ -74,12 +74,34 @@ def _router_role_rules() -> list[str]:
|
||||
def _worker_role_rules() -> list[str]:
|
||||
return [
|
||||
"Worker only: execute routed objective without changing router intent.",
|
||||
"Treat router output as objective/constraints contract, not as a fully-materialized tool-args payload.",
|
||||
"Infer deterministic required tool arguments from contract fields, tool schema, and runtime context.",
|
||||
"Ask minimal clarification only when required arguments cannot be inferred safely.",
|
||||
"Ground every claim in available evidence and tool results; never fabricate execution state.",
|
||||
"Keep status/result_type/answer/key_points/suggested_actions/error internally consistent.",
|
||||
"On partial/failed execution, return concise actionable error context.",
|
||||
]
|
||||
|
||||
|
||||
def build_worker_contract_prompt(*, router_output: RouterAgentOutput) -> str:
|
||||
contract_json = json.dumps(
|
||||
router_output.model_dump(mode="json", exclude_none=True),
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
return "\n".join(
|
||||
[
|
||||
"[Worker Contract]",
|
||||
"- Keep routed objective unchanged.",
|
||||
"- Use normalized_task_input as objective text.",
|
||||
"- Use multimodal_summary/key_entities/constraints as execution evidence.",
|
||||
"- Infer deterministic missing required tool args from evidence + tool schema.",
|
||||
"- Ask clarification only when safe inference is impossible.",
|
||||
"[RouterAgentOutput]",
|
||||
contract_json,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def build_agent_prompt(*, agent_type: AgentType) -> str:
|
||||
lines = [
|
||||
"[Agent Identity]",
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.message import Msg
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.agentscope.runtime.utils import extract_text_content, parse_json_dict
|
||||
from core.agentscope.utils import finalize_json_response
|
||||
|
||||
|
||||
class JsonReActAgent(ReActAgent):
|
||||
@@ -47,77 +46,14 @@ class JsonReActAgent(ReActAgent):
|
||||
*,
|
||||
output_model: type[BaseModel],
|
||||
) -> dict[str, Any]:
|
||||
schema_json = json.dumps(
|
||||
output_model.model_json_schema(),
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
last_error = ""
|
||||
|
||||
for attempt in range(1, self._finalize_retries + 2):
|
||||
prompt = await self.formatter.format(
|
||||
msgs=[
|
||||
Msg("system", self.sys_prompt, "system"),
|
||||
*await self.memory.get_memory(),
|
||||
Msg(
|
||||
"user",
|
||||
self._build_finalize_instruction(
|
||||
schema_json=schema_json,
|
||||
validation_error=last_error,
|
||||
attempt=attempt,
|
||||
),
|
||||
"user",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
original_stream = self.model.stream
|
||||
self.model.stream = False
|
||||
try:
|
||||
response = await self.model(
|
||||
prompt,
|
||||
tool_choice="none",
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
finally:
|
||||
self.model.stream = original_stream
|
||||
|
||||
raw_text = extract_text_content(getattr(response, "content", []))
|
||||
payload = parse_json_dict(raw_text)
|
||||
if payload is None:
|
||||
last_error = "Model output is not a valid JSON object."
|
||||
continue
|
||||
|
||||
try:
|
||||
validated = output_model.model_validate(payload)
|
||||
return validated.model_dump(mode="json", exclude_none=True)
|
||||
except ValidationError as exc:
|
||||
last_error = str(exc)
|
||||
|
||||
raise RuntimeError(
|
||||
f"failed to finalize structured output for {output_model.__name__}: {last_error}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_finalize_instruction(
|
||||
*,
|
||||
schema_json: str,
|
||||
validation_error: str,
|
||||
attempt: int,
|
||||
) -> str:
|
||||
error_part = (
|
||||
""
|
||||
if not validation_error
|
||||
else (
|
||||
"\n\n[Validation Error From Previous Attempt]\n"
|
||||
f"{validation_error}\n"
|
||||
"Fix all missing/invalid fields and regenerate."
|
||||
)
|
||||
)
|
||||
return (
|
||||
"Return JSON only. Do not output markdown, prose, or code fences. "
|
||||
"Follow this JSON Schema exactly and include all required fields. "
|
||||
"Do not call tools.\n\n"
|
||||
f"[Schema]\n{schema_json}\n\n"
|
||||
f"[Attempt]\n{attempt}{error_part}"
|
||||
_, payload = await finalize_json_response(
|
||||
model=self.model,
|
||||
formatter=self.formatter,
|
||||
base_messages=[
|
||||
Msg("system", self.sys_prompt, "system"),
|
||||
*await self.memory.get_memory(),
|
||||
],
|
||||
output_model=output_model,
|
||||
retries=self._finalize_retries,
|
||||
)
|
||||
return payload
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from agentscope.model import OpenAIChatModel
|
||||
|
||||
from core.logging import get_logger
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.runner")
|
||||
|
||||
|
||||
class TrackingChatModel:
|
||||
def __init__(self, inner: OpenAIChatModel) -> None:
|
||||
self._inner = inner
|
||||
self._total_input_tokens = 0
|
||||
self._total_output_tokens = 0
|
||||
self._total_latency_ms = 0
|
||||
self._cached_prompt_tokens = 0
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
return self._inner.stream
|
||||
|
||||
@stream.setter
|
||||
def stream(self, value: bool) -> None:
|
||||
self._inner.stream = value
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._inner, name)
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
self._log_model_call(kwargs)
|
||||
response = await self._inner(*args, **kwargs)
|
||||
if isinstance(response, AsyncGenerator):
|
||||
return self._track_stream(response)
|
||||
self._record_usage(getattr(response, "usage", None))
|
||||
return response
|
||||
|
||||
def usage_summary(self) -> dict[str, int]:
|
||||
return {
|
||||
"input_tokens": self._total_input_tokens,
|
||||
"output_tokens": self._total_output_tokens,
|
||||
"latency_ms": self._total_latency_ms,
|
||||
"cached_prompt_tokens": self._cached_prompt_tokens,
|
||||
}
|
||||
|
||||
def _log_model_call(self, kwargs: dict[str, Any]) -> None:
|
||||
tools = kwargs.get("tools")
|
||||
tool_names, generate_response_schema = self._extract_tool_debug_info(tools)
|
||||
logger.info(
|
||||
"model_call_debug",
|
||||
tool_choice=kwargs.get("tool_choice"),
|
||||
tool_count=len(tool_names),
|
||||
tool_names=tool_names,
|
||||
generate_response_schema=generate_response_schema,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_debug_info(
|
||||
tools: Any,
|
||||
) -> tuple[list[str], dict[str, Any] | None]:
|
||||
tool_names: list[str] = []
|
||||
generate_response_schema: dict[str, Any] | None = None
|
||||
if not isinstance(tools, list):
|
||||
return tool_names, generate_response_schema
|
||||
|
||||
for tool in tools:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
function = tool.get("function")
|
||||
if not isinstance(function, dict):
|
||||
continue
|
||||
name = function.get("name")
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
tool_names.append(name)
|
||||
if name != "generate_response":
|
||||
continue
|
||||
parameters = function.get("parameters")
|
||||
if not isinstance(parameters, dict):
|
||||
continue
|
||||
props = parameters.get("properties", {})
|
||||
generate_response_schema = {
|
||||
"required": parameters.get("required"),
|
||||
"properties": list(props.keys()) if isinstance(props, dict) else [],
|
||||
}
|
||||
return tool_names, generate_response_schema
|
||||
|
||||
async def _track_stream(
|
||||
self, response: AsyncGenerator[Any, None]
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
latest_usage = None
|
||||
async for chunk in response:
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage is not None:
|
||||
latest_usage = usage
|
||||
yield chunk
|
||||
self._record_usage(latest_usage)
|
||||
|
||||
def _record_usage(self, usage: Any) -> None:
|
||||
if usage is None:
|
||||
return
|
||||
self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
|
||||
self._total_output_tokens += max(
|
||||
int(getattr(usage, "output_tokens", 0) or 0), 0
|
||||
)
|
||||
self._total_latency_ms += max(
|
||||
int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
|
||||
)
|
||||
metadata = getattr(usage, "metadata", None)
|
||||
if metadata is None:
|
||||
return
|
||||
self._cached_prompt_tokens += max(self._extract_cached_tokens(metadata), 0)
|
||||
|
||||
@staticmethod
|
||||
def _extract_cached_tokens(metadata: Any) -> int:
|
||||
if isinstance(metadata, dict):
|
||||
prompt_details = metadata.get("prompt_tokens_details")
|
||||
if isinstance(prompt_details, dict):
|
||||
return int(prompt_details.get("cached_tokens", 0) or 0)
|
||||
return 0
|
||||
|
||||
prompt_details = getattr(metadata, "prompt_tokens_details", None)
|
||||
return int(getattr(prompt_details, "cached_tokens", 0) or 0)
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from schemas.agent.runtime_models import RouterAgentOutput
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
def _to_int(value: object) -> int:
|
||||
if value is None:
|
||||
return 0
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, Decimal):
|
||||
return int(value)
|
||||
if isinstance(value, float):
|
||||
return int(value)
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return 0
|
||||
try:
|
||||
return int(text)
|
||||
except ValueError:
|
||||
return int(float(text))
|
||||
return 0
|
||||
|
||||
|
||||
async def persist_router_message(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
model_code: str,
|
||||
router_output: RouterAgentOutput,
|
||||
response_metadata: dict[str, object],
|
||||
) -> None:
|
||||
session_id = UUID(thread_id)
|
||||
message_repo = MessageRepository(session)
|
||||
session_repo = SessionRepository(session)
|
||||
locked_session = await session_repo.lock_session_for_update(session_id=session_id)
|
||||
if locked_session is None:
|
||||
raise RuntimeError("chat session not found for router persistence")
|
||||
|
||||
seq = _to_int(getattr(locked_session, "message_count", 0)) + 1
|
||||
metadata = AgentChatMessageMetadata(
|
||||
run_id=run_id,
|
||||
agent_type=AgentType.ROUTER,
|
||||
router_agent_output=router_output,
|
||||
)
|
||||
message_payload = AgentChatMessage(
|
||||
id=uuid4(),
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.ASSISTANT.value,
|
||||
content="",
|
||||
model_code=model_code,
|
||||
tool_name=None,
|
||||
input_tokens=_to_int(response_metadata.get("inputTokens", 0)),
|
||||
output_tokens=_to_int(response_metadata.get("outputTokens", 0)),
|
||||
cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
|
||||
latency_ms=_to_int(response_metadata.get("latencyMs", 0)),
|
||||
metadata=metadata,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=message_payload.seq,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=message_payload.content,
|
||||
model_code=message_payload.model_code,
|
||||
tool_name=message_payload.tool_name,
|
||||
metadata=metadata.model_dump(mode="json", exclude_none=True),
|
||||
input_tokens=message_payload.input_tokens,
|
||||
output_tokens=message_payload.output_tokens,
|
||||
cost=message_payload.cost,
|
||||
latency_ms=message_payload.latency_ms,
|
||||
)
|
||||
await session_repo.update_runtime_state(
|
||||
chat_session=locked_session,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot=locked_session.state_snapshot or {},
|
||||
message_delta=1,
|
||||
token_delta=message_payload.input_tokens + message_payload.output_tokens,
|
||||
cost_delta=message_payload.cost,
|
||||
)
|
||||
await session.flush()
|
||||
@@ -1,30 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core.types import RunAgentInput
|
||||
from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
from agentscope.message import Msg
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
||||
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
|
||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||
from core.agentscope.tools.toolkit import build_stage_toolkit, build_toolkit
|
||||
from core.agentscope.runtime.utils import (
|
||||
normalize_tool_name,
|
||||
parse_tool_agent_output,
|
||||
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
||||
from core.agentscope.runtime.model_tracking import TrackingChatModel
|
||||
from core.agentscope.runtime.router_persistence import persist_router_message
|
||||
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
|
||||
from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||
from core.agentscope.utils import (
|
||||
finalize_json_response,
|
||||
patch_agentscope_json_repair_compat,
|
||||
)
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.runtime_models import (
|
||||
@@ -33,7 +31,6 @@ from schemas.agent.runtime_models import (
|
||||
resolve_worker_output_model,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
|
||||
from schemas.user import UserContext
|
||||
from services.litellm.service import LiteLLMService
|
||||
from sqlalchemy import select
|
||||
@@ -59,246 +56,9 @@ class StageExecutionResult:
|
||||
response_metadata: dict[str, Any]
|
||||
|
||||
|
||||
class _TrackingChatModel:
|
||||
def __init__(self, inner: OpenAIChatModel) -> None:
|
||||
self._inner = inner
|
||||
self._total_input_tokens = 0
|
||||
self._total_output_tokens = 0
|
||||
self._total_latency_ms = 0
|
||||
self._cached_prompt_tokens = 0
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
return self._inner.stream
|
||||
|
||||
@stream.setter
|
||||
def stream(self, value: bool) -> None:
|
||||
self._inner.stream = value
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._inner, name)
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
tools = kwargs.get("tools")
|
||||
tool_names: list[str] = []
|
||||
generate_response_schema: dict[str, Any] | None = None
|
||||
if isinstance(tools, list):
|
||||
for tool in tools:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
function = tool.get("function")
|
||||
if isinstance(function, dict):
|
||||
name = function.get("name")
|
||||
if isinstance(name, str):
|
||||
tool_names.append(name)
|
||||
if name == "generate_response":
|
||||
parameters = function.get("parameters")
|
||||
if isinstance(parameters, dict):
|
||||
generate_response_schema = {
|
||||
"required": parameters.get("required"),
|
||||
"properties": list(
|
||||
(
|
||||
parameters.get("properties", {})
|
||||
if isinstance(
|
||||
parameters.get("properties", {}), dict
|
||||
)
|
||||
else {}
|
||||
).keys()
|
||||
),
|
||||
}
|
||||
logger.info(
|
||||
"model_call_debug",
|
||||
tool_choice=kwargs.get("tool_choice"),
|
||||
tool_count=len(tool_names),
|
||||
tool_names=tool_names,
|
||||
generate_response_schema=generate_response_schema,
|
||||
)
|
||||
response = await self._inner(*args, **kwargs)
|
||||
if isinstance(response, AsyncGenerator):
|
||||
return self._track_stream(response)
|
||||
self._record_usage(getattr(response, "usage", None))
|
||||
return response
|
||||
|
||||
async def _track_stream(
|
||||
self, response: AsyncGenerator[Any, None]
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
latest_usage = None
|
||||
async for chunk in response:
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage is not None:
|
||||
latest_usage = usage
|
||||
yield chunk
|
||||
self._record_usage(latest_usage)
|
||||
|
||||
def _record_usage(self, usage: Any) -> None:
|
||||
if usage is None:
|
||||
return
|
||||
self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
|
||||
self._total_output_tokens += max(
|
||||
int(getattr(usage, "output_tokens", 0) or 0), 0
|
||||
)
|
||||
self._total_latency_ms += max(
|
||||
int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
|
||||
)
|
||||
metadata = getattr(usage, "metadata", None)
|
||||
if metadata is not None:
|
||||
cached_tokens = 0
|
||||
if isinstance(metadata, dict):
|
||||
prompt_details = metadata.get("prompt_tokens_details")
|
||||
if isinstance(prompt_details, dict):
|
||||
cached_tokens = int(prompt_details.get("cached_tokens", 0) or 0)
|
||||
else:
|
||||
prompt_details = getattr(metadata, "prompt_tokens_details", None)
|
||||
cached_tokens = int(getattr(prompt_details, "cached_tokens", 0) or 0)
|
||||
self._cached_prompt_tokens += max(cached_tokens, 0)
|
||||
|
||||
def usage_summary(self) -> dict[str, int]:
|
||||
return {
|
||||
"input_tokens": self._total_input_tokens,
|
||||
"output_tokens": self._total_output_tokens,
|
||||
"latency_ms": self._total_latency_ms,
|
||||
"cached_prompt_tokens": self._cached_prompt_tokens,
|
||||
}
|
||||
|
||||
|
||||
class _PipelineStageEmitter:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
session_id: str,
|
||||
run_id: str,
|
||||
stage: str,
|
||||
emit_text_events: bool,
|
||||
emit_tool_events: bool,
|
||||
) -> None:
|
||||
self._pipeline = pipeline
|
||||
self._session_id = session_id
|
||||
self._run_id = run_id
|
||||
self._stage = stage
|
||||
self._emit_text_events = emit_text_events
|
||||
self._emit_tool_events = emit_tool_events
|
||||
self._text_by_message_id: dict[str, str] = {}
|
||||
self._emitted_tool_calls: set[str] = set()
|
||||
self._emitted_tool_results: set[str] = set()
|
||||
self.latest_text_message_id: str | None = None
|
||||
self.latest_text: str = ""
|
||||
|
||||
async def handle_print(self, *, msg: Msg, last: bool) -> None:
|
||||
del last
|
||||
if self._emit_tool_events:
|
||||
await self._emit_tool_events_from_msg(msg)
|
||||
if self._emit_text_events:
|
||||
await self._emit_text_events_from_msg(msg)
|
||||
|
||||
async def _emit_text_events_from_msg(self, msg: Msg) -> None:
|
||||
text = msg.get_text_content(separator="") or ""
|
||||
if not text:
|
||||
return
|
||||
message_id = str(msg.id)
|
||||
self._text_by_message_id[message_id] = text
|
||||
self.latest_text_message_id = message_id
|
||||
self.latest_text = text
|
||||
|
||||
async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
|
||||
for block in msg.get_content_blocks("tool_use"):
|
||||
tool_call_id = str(block.get("id", "")).strip()
|
||||
tool_name = str(block.get("name", "")).strip()
|
||||
if (
|
||||
not tool_call_id
|
||||
or not tool_name
|
||||
or tool_call_id in self._emitted_tool_calls
|
||||
):
|
||||
continue
|
||||
payload = {
|
||||
"messageId": str(msg.id),
|
||||
"toolCallId": tool_call_id,
|
||||
"toolCallName": tool_name,
|
||||
"stage": self._stage,
|
||||
}
|
||||
await self._emit("TOOL_CALL_START", payload)
|
||||
await self._emit(
|
||||
"TOOL_CALL_ARGS",
|
||||
{
|
||||
**payload,
|
||||
"args": block.get("input", {}),
|
||||
},
|
||||
)
|
||||
await self._emit("TOOL_CALL_END", payload)
|
||||
self._emitted_tool_calls.add(tool_call_id)
|
||||
|
||||
for block in msg.get_content_blocks("tool_result"):
|
||||
tool_call_id = str(block.get("id", "")).strip()
|
||||
if not tool_call_id or tool_call_id in self._emitted_tool_results:
|
||||
continue
|
||||
tool_output = parse_tool_agent_output(block.get("output"))
|
||||
if tool_output is None:
|
||||
continue
|
||||
|
||||
tool_output_dict = tool_output.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
result_data = {
|
||||
"messageId": str(msg.id),
|
||||
"role": "tool",
|
||||
"stage": self._stage,
|
||||
"tool_name": tool_output.tool_name,
|
||||
"tool_call_id": tool_output.tool_call_id,
|
||||
"tool_call_args": tool_output.tool_call_args,
|
||||
"status": tool_output.status.value,
|
||||
"result_summary": tool_output.result_summary,
|
||||
}
|
||||
ui_hints = tool_output_dict.get("ui_hints")
|
||||
if ui_hints is not None:
|
||||
result_data["ui_hints"] = ui_hints
|
||||
if tool_output.error:
|
||||
result_data["error"] = tool_output.error.model_dump(mode="json")
|
||||
|
||||
await self._emit("TOOL_CALL_RESULT", result_data)
|
||||
self._emitted_tool_results.add(tool_call_id)
|
||||
|
||||
async def emit_final_text_end(
|
||||
self,
|
||||
*,
|
||||
worker_output: dict[str, Any],
|
||||
response_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
message_id = (
|
||||
self.latest_text_message_id or f"worker-{self._run_id}-{uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
output_data = {
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": self._stage,
|
||||
"status": worker_output.get("status"),
|
||||
"answer": worker_output.get("answer", ""),
|
||||
"key_points": worker_output.get("key_points", []),
|
||||
"result_type": worker_output.get("result_type"),
|
||||
"suggested_actions": worker_output.get("suggested_actions", []),
|
||||
"error": worker_output.get("error"),
|
||||
}
|
||||
ui_hints = worker_output.get("ui_hints")
|
||||
if ui_hints is not None:
|
||||
output_data["ui_hints"] = ui_hints
|
||||
|
||||
output_data.update(response_metadata)
|
||||
|
||||
await self._emit("TEXT_MESSAGE_END", output_data)
|
||||
|
||||
async def _emit(self, event_type: str, payload: dict[str, Any]) -> None:
|
||||
await self._pipeline.emit(
|
||||
session_id=self._session_id,
|
||||
event={
|
||||
"type": event_type,
|
||||
"threadId": self._session_id,
|
||||
"runId": self._run_id,
|
||||
**payload,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class AgentScopeRunner:
|
||||
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
|
||||
patch_agentscope_json_repair_compat()
|
||||
self._litellm_service = litellm_service or LiteLLMService()
|
||||
|
||||
async def execute(
|
||||
@@ -310,76 +70,30 @@ class AgentScopeRunner:
|
||||
run_input: RunAgentInput,
|
||||
) -> dict[str, Any]:
|
||||
owner_id = UUID(user_context.id)
|
||||
enabled_tool_names = self._extract_tool_names(run_input)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
router_toolkit, worker_toolkit = self._build_toolkits(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
worker_toolkit = self._build_worker_toolkit(
|
||||
session=session, owner_id=owner_id
|
||||
)
|
||||
router_config, worker_config = await self._load_stage_configs(
|
||||
session=session
|
||||
)
|
||||
|
||||
router_config = await self._load_system_agent_config(
|
||||
router_output = await self._execute_router_step(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_system_agent_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="router",
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
toolkit=router_toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=router_config,
|
||||
)
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
await self._persist_router_message(
|
||||
session=session,
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
model_code=router_config.model_code,
|
||||
router_output=router_output,
|
||||
response_metadata=router_result.response_metadata,
|
||||
)
|
||||
await session.commit()
|
||||
await self._emit_step_event(
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="router",
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="worker",
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=worker_config,
|
||||
worker_output_model=worker_output_model,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
worker_output = worker_output_model.model_validate(worker_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="worker",
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -387,40 +101,107 @@ class AgentScopeRunner:
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
|
||||
def _build_toolkits(
|
||||
def _build_worker_toolkit(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
enabled_tool_names: set[str] | None,
|
||||
) -> tuple[Any, Any]:
|
||||
return (
|
||||
build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=set(),
|
||||
),
|
||||
build_stage_toolkit(
|
||||
agent_type=AgentType.WORKER,
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
),
|
||||
) -> Any:
|
||||
return build_stage_toolkit(
|
||||
agent_type=AgentType.WORKER,
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
def _extract_tool_names(self, run_input: RunAgentInput) -> set[str] | None:
|
||||
raw_tools = getattr(run_input, "tools", None)
|
||||
if not isinstance(raw_tools, list):
|
||||
return None
|
||||
selected: set[str] = set()
|
||||
for item in raw_tools:
|
||||
if isinstance(item, dict):
|
||||
name = item.get("name")
|
||||
else:
|
||||
name = getattr(item, "name", None)
|
||||
if isinstance(name, str) and name.strip():
|
||||
selected.add(normalize_tool_name(name))
|
||||
return selected
|
||||
async def _load_stage_configs(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
) -> tuple[SystemAgentRuntimeConfig, SystemAgentRuntimeConfig]:
|
||||
router_config = await self._load_system_agent_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_system_agent_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
return router_config, worker_config
|
||||
|
||||
async def _execute_router_step(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
) -> RouterAgentOutput:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="router",
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
)
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
await persist_router_message(
|
||||
session=session,
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
model_code=stage_config.model_code,
|
||||
router_output=router_output,
|
||||
response_metadata=router_result.response_metadata,
|
||||
)
|
||||
await session.commit()
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="router",
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return router_output
|
||||
|
||||
async def _execute_worker_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
router_output: RouterAgentOutput,
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
) -> WorkerAgentOutputLite:
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="worker",
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
worker_output_model=worker_output_model,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
worker_output = worker_output_model.model_validate(worker_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="worker",
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return worker_output
|
||||
|
||||
async def _load_system_agent_config(
|
||||
self,
|
||||
@@ -451,7 +232,6 @@ class AgentScopeRunner:
|
||||
*,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
run_input: RunAgentInput,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
) -> StageExecutionResult:
|
||||
@@ -462,28 +242,26 @@ class AgentScopeRunner:
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
tools=None,
|
||||
)
|
||||
agent = self._build_agent(
|
||||
agent_name="router",
|
||||
system_prompt=system_prompt,
|
||||
toolkit=toolkit,
|
||||
response, payload = await finalize_json_response(
|
||||
model=tracking_model,
|
||||
)
|
||||
response_msg = await agent.reply_json(
|
||||
context_messages,
|
||||
formatter=OpenAIChatFormatter(),
|
||||
base_messages=[Msg("system", system_prompt, "system"), *context_messages],
|
||||
output_model=RouterAgentOutput,
|
||||
retries=0,
|
||||
)
|
||||
response_msg = Msg(
|
||||
name="router",
|
||||
role="assistant",
|
||||
content=list(getattr(response, "content", [])),
|
||||
metadata=payload,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"router_reply_received",
|
||||
run_id=run_input.run_id,
|
||||
thread_id=run_input.thread_id,
|
||||
message_id=str(response_msg.id),
|
||||
)
|
||||
payload = RouterAgentOutput.model_validate(
|
||||
response_msg.metadata or {}
|
||||
).model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
)
|
||||
return StageExecutionResult(
|
||||
message=response_msg,
|
||||
payload=payload,
|
||||
@@ -504,11 +282,9 @@ class AgentScopeRunner:
|
||||
worker_output_model: type[WorkerAgentOutputLite],
|
||||
pipeline: PipelineLike,
|
||||
) -> StageExecutionResult:
|
||||
worker_input = self._build_worker_input_messages(
|
||||
router_output=router_output,
|
||||
)
|
||||
worker_input = self._build_worker_input_messages(router_output=router_output)
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
emitter = _PipelineStageEmitter(
|
||||
emitter = PipelineStageEmitter(
|
||||
pipeline=pipeline,
|
||||
session_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
@@ -522,15 +298,14 @@ class AgentScopeRunner:
|
||||
agent_type=AgentType.WORKER,
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
tools=run_input.tools,
|
||||
tools=None,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
model=tracking_model,
|
||||
emitter=emitter,
|
||||
)
|
||||
response_msg = await agent.reply_json(
|
||||
worker_input,
|
||||
output_model=worker_output_model,
|
||||
worker_input, output_model=worker_output_model
|
||||
)
|
||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
||||
response_metadata = self._litellm_service.build_usage_metadata(
|
||||
@@ -552,24 +327,17 @@ class AgentScopeRunner:
|
||||
*,
|
||||
router_output: RouterAgentOutput,
|
||||
) -> list[Msg]:
|
||||
routing_contract = json.dumps(
|
||||
router_output.model_dump(mode="json", exclude_none=True),
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
routing_msg = Msg(
|
||||
name="router",
|
||||
role="user",
|
||||
content=(
|
||||
"Use the following routing contract as the execution source of truth. "
|
||||
f"Do not change the routed objective:\n{routing_contract}"
|
||||
),
|
||||
)
|
||||
return [routing_msg]
|
||||
return [
|
||||
Msg(
|
||||
name="router",
|
||||
role="user",
|
||||
content=build_worker_contract_prompt(router_output=router_output),
|
||||
)
|
||||
]
|
||||
|
||||
def _build_model(
|
||||
self, *, stage_config: SystemAgentRuntimeConfig
|
||||
) -> _TrackingChatModel:
|
||||
) -> TrackingChatModel:
|
||||
generate_kwargs: dict[str, Any] = {
|
||||
"temperature": stage_config.llm_config.temperature,
|
||||
"max_tokens": stage_config.llm_config.max_tokens,
|
||||
@@ -585,7 +353,7 @@ class AgentScopeRunner:
|
||||
client_kwargs={"base_url": self._litellm_service.proxy_base_url},
|
||||
generate_kwargs=generate_kwargs,
|
||||
)
|
||||
return _TrackingChatModel(model)
|
||||
return TrackingChatModel(model)
|
||||
|
||||
def _build_agent(
|
||||
self,
|
||||
@@ -593,8 +361,8 @@ class AgentScopeRunner:
|
||||
agent_name: str,
|
||||
system_prompt: str,
|
||||
toolkit: Any,
|
||||
model: _TrackingChatModel,
|
||||
emitter: _PipelineStageEmitter | None = None,
|
||||
model: TrackingChatModel,
|
||||
emitter: PipelineStageEmitter | None = None,
|
||||
) -> JsonReActAgent:
|
||||
return JsonReActAgent(
|
||||
name=agent_name,
|
||||
@@ -624,66 +392,5 @@ class AgentScopeRunner:
|
||||
},
|
||||
)
|
||||
|
||||
async def _persist_router_message(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
model_code: str,
|
||||
router_output: RouterAgentOutput,
|
||||
response_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
session_id = UUID(thread_id)
|
||||
message_repo = MessageRepository(session)
|
||||
session_repo = SessionRepository(session)
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
)
|
||||
if locked_session is None:
|
||||
raise RuntimeError("chat session not found for router persistence")
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
metadata = AgentChatMessageMetadata(
|
||||
run_id=run_id,
|
||||
agent_type=AgentType.ROUTER,
|
||||
router_agent_output=router_output,
|
||||
)
|
||||
message_payload = AgentChatMessage(
|
||||
id=uuid4(),
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.ASSISTANT.value,
|
||||
content="",
|
||||
model_code=model_code,
|
||||
tool_name=None,
|
||||
input_tokens=int(response_metadata.get("inputTokens", 0) or 0),
|
||||
output_tokens=int(response_metadata.get("outputTokens", 0) or 0),
|
||||
cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
|
||||
latency_ms=int(response_metadata.get("latencyMs", 0) or 0),
|
||||
metadata=metadata,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=message_payload.seq,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=message_payload.content,
|
||||
model_code=message_payload.model_code,
|
||||
tool_name=message_payload.tool_name,
|
||||
metadata=metadata.model_dump(mode="json", exclude_none=True),
|
||||
input_tokens=message_payload.input_tokens,
|
||||
output_tokens=message_payload.output_tokens,
|
||||
cost=message_payload.cost,
|
||||
latency_ms=message_payload.latency_ms,
|
||||
)
|
||||
await session_repo.update_runtime_state(
|
||||
chat_session=locked_session,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot=locked_session.state_snapshot or {},
|
||||
message_delta=1,
|
||||
token_delta=message_payload.input_tokens + message_payload.output_tokens,
|
||||
cost_delta=message_payload.cost,
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
|
||||
AgentScopeReActRunner = AgentScopeRunner
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
from uuid import uuid4
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
from core.agentscope.utils import parse_tool_agent_output
|
||||
|
||||
|
||||
class PipelineLike(Protocol):
|
||||
async def emit(self, *, session_id: str, event: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
class PipelineStageEmitter:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
session_id: str,
|
||||
run_id: str,
|
||||
stage: str,
|
||||
emit_text_events: bool,
|
||||
emit_tool_events: bool,
|
||||
) -> None:
|
||||
self._pipeline = pipeline
|
||||
self._session_id = session_id
|
||||
self._run_id = run_id
|
||||
self._stage = stage
|
||||
self._emit_text_events = emit_text_events
|
||||
self._emit_tool_events = emit_tool_events
|
||||
self._emitted_tool_calls: set[str] = set()
|
||||
self._emitted_tool_results: set[str] = set()
|
||||
self.latest_text_message_id: str | None = None
|
||||
self.latest_text: str = ""
|
||||
|
||||
async def handle_print(self, *, msg: Msg, last: bool) -> None:
|
||||
del last
|
||||
if self._emit_tool_events:
|
||||
await self._emit_tool_events_from_msg(msg)
|
||||
if self._emit_text_events:
|
||||
await self._emit_text_events_from_msg(msg)
|
||||
|
||||
async def emit_final_text_end(
|
||||
self,
|
||||
*,
|
||||
worker_output: dict[str, Any],
|
||||
response_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
message_id = (
|
||||
self.latest_text_message_id or f"worker-{self._run_id}-{uuid4().hex[:8]}"
|
||||
)
|
||||
payload = {
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": self._stage,
|
||||
"status": worker_output.get("status"),
|
||||
"answer": worker_output.get("answer", ""),
|
||||
"key_points": worker_output.get("key_points", []),
|
||||
"result_type": worker_output.get("result_type"),
|
||||
"suggested_actions": worker_output.get("suggested_actions", []),
|
||||
"error": worker_output.get("error"),
|
||||
**response_metadata,
|
||||
}
|
||||
ui_hints = worker_output.get("ui_hints")
|
||||
if ui_hints is not None:
|
||||
payload["ui_hints"] = ui_hints
|
||||
await self._emit("TEXT_MESSAGE_END", payload)
|
||||
|
||||
async def _emit_text_events_from_msg(self, msg: Msg) -> None:
|
||||
text = msg.get_text_content(separator="") or ""
|
||||
if not text:
|
||||
return
|
||||
self.latest_text_message_id = str(msg.id)
|
||||
self.latest_text = text
|
||||
|
||||
async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
|
||||
for block in msg.get_content_blocks("tool_use"):
|
||||
tool_call_id = str(block.get("id", "")).strip()
|
||||
tool_name = str(block.get("name", "")).strip()
|
||||
if (
|
||||
not tool_call_id
|
||||
or not tool_name
|
||||
or tool_call_id in self._emitted_tool_calls
|
||||
):
|
||||
continue
|
||||
base_payload = {
|
||||
"messageId": str(msg.id),
|
||||
"toolCallId": tool_call_id,
|
||||
"toolCallName": tool_name,
|
||||
"stage": self._stage,
|
||||
}
|
||||
await self._emit("TOOL_CALL_START", base_payload)
|
||||
await self._emit(
|
||||
"TOOL_CALL_ARGS", {**base_payload, "args": block.get("input", {})}
|
||||
)
|
||||
await self._emit("TOOL_CALL_END", base_payload)
|
||||
self._emitted_tool_calls.add(tool_call_id)
|
||||
|
||||
for block in msg.get_content_blocks("tool_result"):
|
||||
tool_call_id = str(block.get("id", "")).strip()
|
||||
if not tool_call_id or tool_call_id in self._emitted_tool_results:
|
||||
continue
|
||||
tool_output = parse_tool_agent_output(block.get("output"))
|
||||
if tool_output is None:
|
||||
continue
|
||||
payload = {
|
||||
"messageId": str(msg.id),
|
||||
"role": "tool",
|
||||
"stage": self._stage,
|
||||
"tool_name": tool_output.tool_name,
|
||||
"tool_call_id": tool_output.tool_call_id,
|
||||
"tool_call_args": tool_output.tool_call_args,
|
||||
"status": tool_output.status.value,
|
||||
"result_summary": tool_output.result_summary,
|
||||
}
|
||||
ui_hints = tool_output.model_dump(mode="json", exclude_none=True).get(
|
||||
"ui_hints"
|
||||
)
|
||||
if ui_hints is not None:
|
||||
payload["ui_hints"] = ui_hints
|
||||
if tool_output.error:
|
||||
payload["error"] = tool_output.error.model_dump(mode="json")
|
||||
|
||||
await self._emit("TOOL_CALL_RESULT", payload)
|
||||
self._emitted_tool_results.add(tool_call_id)
|
||||
|
||||
async def _emit(self, event_type: str, payload: dict[str, Any]) -> None:
|
||||
await self._pipeline.emit(
|
||||
session_id=self._session_id,
|
||||
event={
|
||||
"type": event_type,
|
||||
"threadId": self._session_id,
|
||||
"runId": self._run_id,
|
||||
**payload,
|
||||
},
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from agentscope.message import Msg
|
||||
@@ -21,10 +21,12 @@ from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from schemas.user import UserContext
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
from schemas.messages.chat_message import extract_user_message_attachments
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.users.dependencies import get_user_service
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.tasks")
|
||||
_MAX_CONTEXT_ATTACHMENTS = 3
|
||||
|
||||
|
||||
def _load_runtime() -> type[Any]:
|
||||
@@ -63,38 +65,43 @@ async def _build_recent_context_messages(
|
||||
metadata = msg.get("metadata")
|
||||
|
||||
if role == "user" and metadata:
|
||||
attachments = metadata.get("user_message_attachments")
|
||||
if attachments:
|
||||
bucket = attachments.get("bucket")
|
||||
path = attachments.get("path")
|
||||
mime_type = attachments.get("mime_type")
|
||||
if bucket and path:
|
||||
try:
|
||||
image_bytes = await supabase_service.download_bytes(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
)
|
||||
b64_data = base64.b64encode(image_bytes).decode("utf-8")
|
||||
converted.append(
|
||||
Msg(
|
||||
name="user",
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "text", "text": content},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type or "image/png",
|
||||
"data": b64_data,
|
||||
},
|
||||
},
|
||||
],
|
||||
)
|
||||
)
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
image_blocks: list[dict[str, Any]] = []
|
||||
attachments = extract_user_message_attachments(metadata)[
|
||||
:_MAX_CONTEXT_ATTACHMENTS
|
||||
]
|
||||
for attachment in attachments:
|
||||
try:
|
||||
image_bytes = await supabase_service.download_bytes(
|
||||
bucket=attachment.bucket,
|
||||
path=attachment.path,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
b64_data = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_blocks.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": attachment.mime_type or "image/png",
|
||||
"data": b64_data,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if image_blocks:
|
||||
multimodal_content: list[dict[str, Any]] = []
|
||||
if isinstance(content, str) and content:
|
||||
multimodal_content.append({"type": "text", "text": content})
|
||||
multimodal_content.extend(image_blocks)
|
||||
converted.append(
|
||||
Msg(
|
||||
name="user",
|
||||
role="user",
|
||||
content=cast(Any, multimodal_content),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
role = "assistant"
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
from core.agentscope.utils.compat import (
|
||||
patch_agentscope_json_repair_compat,
|
||||
safe_json_loads_with_repair,
|
||||
)
|
||||
from core.agentscope.utils.json_finalize import (
|
||||
build_json_finalize_instruction,
|
||||
finalize_json_response,
|
||||
)
|
||||
from core.agentscope.utils.parsing import (
|
||||
extract_text_content,
|
||||
parse_json_dict,
|
||||
parse_tool_agent_output,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"build_json_finalize_instruction",
|
||||
"extract_text_content",
|
||||
"finalize_json_response",
|
||||
"parse_json_dict",
|
||||
"parse_tool_agent_output",
|
||||
"patch_agentscope_json_repair_compat",
|
||||
"safe_json_loads_with_repair",
|
||||
]
|
||||
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from core.logging import get_logger
|
||||
|
||||
logger = get_logger("core.agentscope.utils.compat")
|
||||
_AGENTSCOPE_JSON_REPAIR_PATCHED = False
|
||||
|
||||
|
||||
def safe_json_loads_with_repair(json_str: str) -> dict[str, Any]:
|
||||
try:
|
||||
from json_repair import repair_json
|
||||
|
||||
repair_json_any: Any = repair_json
|
||||
try:
|
||||
repaired = repair_json_any(json_str, **{"stream_stable": True})
|
||||
except TypeError:
|
||||
repaired = repair_json_any(json_str)
|
||||
|
||||
if isinstance(repaired, dict):
|
||||
return repaired
|
||||
if isinstance(repaired, str):
|
||||
loaded = json.loads(repaired)
|
||||
return loaded if isinstance(loaded, dict) else {}
|
||||
return {}
|
||||
except Exception: # noqa: BLE001
|
||||
preview = json_str[:100] + "..." if len(json_str) > 100 else json_str
|
||||
logger.warning("failed_to_parse_tool_arguments", preview=preview)
|
||||
return {}
|
||||
|
||||
|
||||
def patch_agentscope_json_repair_compat() -> None:
|
||||
global _AGENTSCOPE_JSON_REPAIR_PATCHED
|
||||
if _AGENTSCOPE_JSON_REPAIR_PATCHED:
|
||||
return
|
||||
|
||||
try:
|
||||
from agentscope._utils import _common as common_mod
|
||||
from agentscope.model import _openai_model as openai_model_mod
|
||||
except Exception: # noqa: BLE001
|
||||
return
|
||||
|
||||
common_mod._json_loads_with_repair = safe_json_loads_with_repair
|
||||
openai_model_mod._json_loads_with_repair = safe_json_loads_with_repair
|
||||
_AGENTSCOPE_JSON_REPAIR_PATCHED = True
|
||||
logger.info("patched_agentscope_json_repair_compat")
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any, Protocol
|
||||
|
||||
from agentscope.message import Msg
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.agentscope.utils.parsing import extract_text_content, parse_json_dict
|
||||
|
||||
|
||||
class FormatterProtocol(Protocol):
|
||||
def format(self, *args: Any, **kwargs: Any) -> Awaitable[Any]: ...
|
||||
|
||||
|
||||
def build_json_finalize_instruction(
|
||||
*,
|
||||
schema_json: str,
|
||||
attempt: int,
|
||||
validation_error: str = "",
|
||||
) -> str:
|
||||
error_part = (
|
||||
""
|
||||
if not validation_error
|
||||
else (
|
||||
"\n\n[Validation Error From Previous Attempt]\n"
|
||||
f"{validation_error}\n"
|
||||
"Fix all missing/invalid fields and regenerate."
|
||||
)
|
||||
)
|
||||
return (
|
||||
"Return JSON only. Do not output markdown, prose, or code fences. "
|
||||
"Follow this JSON Schema exactly and include all required fields. "
|
||||
"Do not call tools.\n\n"
|
||||
f"[Schema]\n{schema_json}\n\n"
|
||||
f"[Attempt]\n{attempt}{error_part}"
|
||||
)
|
||||
|
||||
|
||||
async def finalize_json_response(
|
||||
*,
|
||||
model: Any,
|
||||
formatter: FormatterProtocol,
|
||||
base_messages: list[Msg],
|
||||
output_model: type[BaseModel],
|
||||
retries: int,
|
||||
) -> tuple[Any, dict[str, Any]]:
|
||||
schema_json = json.dumps(
|
||||
output_model.model_json_schema(),
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
last_error = ""
|
||||
|
||||
for attempt in range(1, retries + 2):
|
||||
prompt = await formatter.format(
|
||||
msgs=[
|
||||
*base_messages,
|
||||
Msg(
|
||||
"user",
|
||||
build_json_finalize_instruction(
|
||||
schema_json=schema_json,
|
||||
attempt=attempt,
|
||||
validation_error=last_error,
|
||||
),
|
||||
"user",
|
||||
),
|
||||
]
|
||||
)
|
||||
original_stream = model.stream
|
||||
model.stream = False
|
||||
try:
|
||||
response = await model(
|
||||
prompt,
|
||||
tool_choice="none",
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
finally:
|
||||
model.stream = original_stream
|
||||
|
||||
raw_text = extract_text_content(getattr(response, "content", []))
|
||||
payload = parse_json_dict(raw_text)
|
||||
if payload is None:
|
||||
last_error = "Model output is not a valid JSON object."
|
||||
continue
|
||||
|
||||
try:
|
||||
validated = output_model.model_validate(payload)
|
||||
return response, validated.model_dump(mode="json", exclude_none=True)
|
||||
except ValidationError as exc:
|
||||
last_error = str(exc)
|
||||
|
||||
raise RuntimeError(
|
||||
f"failed to finalize structured output for {output_model.__name__}: {last_error}"
|
||||
)
|
||||
-14
@@ -4,23 +4,9 @@ import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
|
||||
from schemas.agent.runtime_models import ToolAgentOutput
|
||||
|
||||
|
||||
def compile_ui_hints_safe(ui_hints: Any) -> dict[str, Any] | None:
|
||||
if not ui_hints:
|
||||
return None
|
||||
try:
|
||||
return compile_ui_hints(ui_hints)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def normalize_tool_name(value: str) -> str:
|
||||
return value.strip().replace(".", "_").replace("-", "_")
|
||||
|
||||
|
||||
def parse_tool_agent_output(output: Any) -> ToolAgentOutput | None:
|
||||
blocks = output if isinstance(output, Sequence) else []
|
||||
for block in blocks:
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import ClassVar
|
||||
from typing import Any, ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -11,7 +11,7 @@ from schemas.agent.runtime_models import RouterAgentOutput, WorkerAgentOutputRic
|
||||
from ..agent import AgentType, ToolAgentOutput
|
||||
|
||||
|
||||
class UserMessageAttachments(BaseModel):
|
||||
class UserMessageAttachment(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")
|
||||
|
||||
bucket: str
|
||||
@@ -23,7 +23,7 @@ class AgentChatMessageMetadata(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")
|
||||
run_id: str
|
||||
agent_type: AgentType | None = None
|
||||
user_message_attachments: UserMessageAttachments | None = None
|
||||
user_message_attachments: list[UserMessageAttachment] | None = None
|
||||
router_agent_output: RouterAgentOutput | None = None
|
||||
tool_agent_output: ToolAgentOutput | None = None
|
||||
worker_agent_output: WorkerAgentOutputRich | None = None
|
||||
@@ -46,3 +46,32 @@ class AgentChatMessage(BaseModel):
|
||||
latency_ms: int | None = Field(default=None, ge=0)
|
||||
metadata: AgentChatMessageMetadata | dict[str, object] | None = None
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
def extract_user_message_attachments(
|
||||
metadata: AgentChatMessageMetadata | dict[str, object] | None,
|
||||
) -> list[UserMessageAttachment]:
|
||||
if metadata is None:
|
||||
return []
|
||||
|
||||
if isinstance(metadata, AgentChatMessageMetadata):
|
||||
raw_value: Any = metadata.user_message_attachments
|
||||
else:
|
||||
raw_value = metadata.get("user_message_attachments")
|
||||
|
||||
if raw_value is None:
|
||||
return []
|
||||
|
||||
raw_items: list[Any]
|
||||
if isinstance(raw_value, list):
|
||||
raw_items = raw_value
|
||||
else:
|
||||
raw_items = [raw_value]
|
||||
|
||||
attachments: list[UserMessageAttachment] = []
|
||||
for item in raw_items:
|
||||
try:
|
||||
attachments.append(UserMessageAttachment.model_validate(item))
|
||||
except Exception:
|
||||
continue
|
||||
return attachments
|
||||
|
||||
@@ -37,6 +37,13 @@ class AttachmentSignedUrlResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class HistoryMessageAttachment(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
mime_type: str = Field(alias="mimeType")
|
||||
url: str
|
||||
|
||||
|
||||
class HistoryMessage(BaseModel):
|
||||
"""History message schema for /history endpoint response."""
|
||||
|
||||
@@ -46,9 +53,9 @@ class HistoryMessage(BaseModel):
|
||||
seq: int = Field(description="Message sequence number")
|
||||
role: str = Field(description="Message role: user | assistant | tool")
|
||||
content: str = Field(description="Message text content")
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="Temporary signed URL for user-attached images",
|
||||
attachments: list[HistoryMessageAttachment] = Field(
|
||||
default_factory=list,
|
||||
description="Temporary signed URLs for user-attached images",
|
||||
)
|
||||
ui_schema: UiSchemaRenderer | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -19,7 +19,8 @@ from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachments,
|
||||
UserMessageAttachment,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
from v1.agent.schemas import HistorySnapshotResponse
|
||||
|
||||
@@ -27,6 +28,7 @@ logger = get_logger(__name__)
|
||||
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
|
||||
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
|
||||
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
|
||||
_MAX_ATTACHMENTS_PER_MESSAGE = 3
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -230,7 +232,7 @@ class AgentService:
|
||||
) -> tuple[str, AgentChatMessageMetadata | None]:
|
||||
text, content_blocks = extract_latest_user_payload(run_input)
|
||||
|
||||
user_attachments: UserMessageAttachments | None = None
|
||||
user_attachments: list[UserMessageAttachment] = []
|
||||
for block in content_blocks:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
@@ -257,12 +259,15 @@ class AgentService:
|
||||
thread_id=run_input.thread_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
user_attachments = UserMessageAttachments(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
user_attachments.append(
|
||||
UserMessageAttachment(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
)
|
||||
break
|
||||
if len(user_attachments) > _MAX_ATTACHMENTS_PER_MESSAGE:
|
||||
raise HTTPException(status_code=422, detail="Too many attachments")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
@@ -270,7 +275,7 @@ class AgentService:
|
||||
raise HTTPException(status_code=422, detail="Invalid signed image url")
|
||||
|
||||
metadata: AgentChatMessageMetadata | None = None
|
||||
if user_attachments is not None:
|
||||
if user_attachments:
|
||||
metadata = AgentChatMessageMetadata(
|
||||
run_id=run_input.run_id,
|
||||
user_message_attachments=user_attachments,
|
||||
@@ -438,23 +443,38 @@ class AgentService:
|
||||
|
||||
messages: list[HistoryMessage] = []
|
||||
if day_payload:
|
||||
raw_messages = day_payload.get("messages") or []
|
||||
raw_messages_obj = day_payload.get("messages")
|
||||
raw_messages = (
|
||||
raw_messages_obj if isinstance(raw_messages_obj, list) else []
|
||||
)
|
||||
for msg_dict in raw_messages:
|
||||
msg = AgentChatMessage.model_validate(msg_dict)
|
||||
|
||||
signed_url: str | None = None
|
||||
if self._attachment_storage and msg.metadata:
|
||||
att = msg.metadata.user_message_attachments
|
||||
if att:
|
||||
signed_urls: dict[str, str] = {}
|
||||
attachments = extract_user_message_attachments(msg.metadata)
|
||||
if self._attachment_storage and attachments:
|
||||
expected_prefix = (
|
||||
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
|
||||
)
|
||||
for attachment in attachments:
|
||||
if not _is_safe_attachment_path(
|
||||
attachment.path,
|
||||
expected_prefix=expected_prefix,
|
||||
):
|
||||
continue
|
||||
signed_url = await self._attachment_storage.create_signed_url(
|
||||
bucket=att.bucket,
|
||||
path=att.path,
|
||||
bucket=attachment.bucket,
|
||||
path=attachment.path,
|
||||
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
|
||||
)
|
||||
key = f"{attachment.bucket}/{attachment.path}"
|
||||
signed_urls[key] = signed_url
|
||||
|
||||
converted = convert_message_to_history(msg, None)
|
||||
if signed_url:
|
||||
converted["url"] = signed_url
|
||||
def _get_signed_url(payload: dict[str, str]) -> str:
|
||||
key = f"{payload['bucket']}/{payload['path']}"
|
||||
return signed_urls[key]
|
||||
|
||||
converted = convert_message_to_history(msg, _get_signed_url)
|
||||
messages.append(HistoryMessage.model_validate(converted))
|
||||
|
||||
return HistorySnapshotResponse(
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessage,
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachments,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ def convert_message_to_history(
|
||||
将 AgentChatMessage 转换为 HistoryMessage 格式
|
||||
|
||||
转换规则:
|
||||
- role=user: 读取 metadata.user_message_attachments,将 bucket 转临时访问 url
|
||||
- role=user: 读取 metadata.user_message_attachments,转换为 attachments[]
|
||||
- role=tool: 读取 content 和 metadata.tool_agent_output.ui_hints,编译成 ui_schema
|
||||
- role=assistant: 读取 metadata.worker_agent_output.ui_hints,编译成 ui_schema
|
||||
"""
|
||||
@@ -31,11 +31,11 @@ def convert_message_to_history(
|
||||
content = message.content
|
||||
metadata = message.metadata
|
||||
|
||||
url: str | None = None
|
||||
attachments: list[dict[str, str]] = []
|
||||
ui_schema: dict[str, Any] | None = None
|
||||
|
||||
if role == "user":
|
||||
url = _convert_user_attachments(metadata, get_signed_url_fn)
|
||||
attachments = _convert_user_attachments(metadata, get_signed_url_fn)
|
||||
|
||||
elif role == "tool":
|
||||
ui_schema = _compile_tool_ui_hints(metadata)
|
||||
@@ -51,8 +51,8 @@ def convert_message_to_history(
|
||||
"timestamp": message.timestamp.isoformat(),
|
||||
}
|
||||
|
||||
if url:
|
||||
result["url"] = url
|
||||
if attachments:
|
||||
result["attachments"] = attachments
|
||||
|
||||
if ui_schema:
|
||||
result["ui_schema"] = ui_schema
|
||||
@@ -63,28 +63,33 @@ def convert_message_to_history(
|
||||
def _convert_user_attachments(
|
||||
metadata: AgentChatMessageMetadata | dict[str, Any] | None,
|
||||
get_signed_url_fn: Callable[[dict[str, str]], str] | None,
|
||||
) -> str | None:
|
||||
"""转换用户附件为临时访问 URL"""
|
||||
if not metadata:
|
||||
return None
|
||||
) -> list[dict[str, str]]:
|
||||
"""转换用户附件为临时访问 URL 列表"""
|
||||
if not metadata or not get_signed_url_fn:
|
||||
return []
|
||||
|
||||
if isinstance(metadata, AgentChatMessageMetadata):
|
||||
attachments = metadata.user_message_attachments
|
||||
resolved = extract_user_message_attachments(metadata)
|
||||
elif isinstance(metadata, dict):
|
||||
resolved = extract_user_message_attachments(metadata)
|
||||
else:
|
||||
attachments_data = metadata.get("user_message_attachments")
|
||||
if not attachments_data:
|
||||
return None
|
||||
attachments = UserMessageAttachments.model_validate(attachments_data)
|
||||
return []
|
||||
|
||||
if not attachments or not get_signed_url_fn:
|
||||
return None
|
||||
|
||||
try:
|
||||
return get_signed_url_fn(
|
||||
{"bucket": attachments.bucket, "path": attachments.path}
|
||||
signed_attachments: list[dict[str, str]] = []
|
||||
for attachment in resolved:
|
||||
try:
|
||||
signed_url = get_signed_url_fn(
|
||||
{"bucket": attachment.bucket, "path": attachment.path}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
signed_attachments.append(
|
||||
{
|
||||
"url": signed_url,
|
||||
"mimeType": attachment.mime_type,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
return signed_attachments
|
||||
|
||||
|
||||
def _compile_tool_ui_hints(
|
||||
|
||||
@@ -25,13 +25,14 @@ class AppVersionInfo(BaseModel):
|
||||
router = APIRouter(prefix="/app", tags=["app"])
|
||||
|
||||
|
||||
def _parse_version(filename: str) -> tuple[str, int] | None:
|
||||
def _parse_version(filename: str) -> tuple[str, int, tuple[int, ...]] | None:
|
||||
pattern = r"app[-_]v?(\d+\.\d+\.\d+)\+(\d+)\.(?:apk|ipa)"
|
||||
match = re.search(pattern, filename, re.IGNORECASE)
|
||||
if match:
|
||||
version = match.group(1)
|
||||
build = int(match.group(2))
|
||||
return (version, build)
|
||||
version_tuple = tuple(int(x) for x in version.split("."))
|
||||
return (version, build, version_tuple)
|
||||
return None
|
||||
|
||||
|
||||
@@ -44,21 +45,32 @@ def _get_latest_release(
|
||||
if not base_path.exists():
|
||||
return None
|
||||
|
||||
ext = "ipa" if platform == "ios" else "apk"
|
||||
target_ext = "ipa" if platform == "ios" else "apk"
|
||||
candidates = []
|
||||
|
||||
MIN_APK_SIZE = 1024 * 1024 # 1MB
|
||||
MIN_IPA_SIZE = 1024 * 1024 # 1MB
|
||||
|
||||
for f in base_path.iterdir():
|
||||
if f.is_file() and f.suffix.lstrip(".").lower() == ext.lower():
|
||||
parsed = _parse_version(f.name)
|
||||
if parsed:
|
||||
version, build = parsed
|
||||
candidates.append((version, build, f.name))
|
||||
if not f.is_file():
|
||||
continue
|
||||
ext = f.suffix.lstrip(".").lower()
|
||||
if ext != target_ext:
|
||||
continue
|
||||
# 简单校验文件大小,排除伪装文件
|
||||
if f.stat().st_size < (MIN_APK_SIZE if ext == "apk" else MIN_IPA_SIZE):
|
||||
continue
|
||||
parsed = _parse_version(f.name)
|
||||
if parsed:
|
||||
version, build, version_tuple = parsed
|
||||
candidates.append((version_tuple, build, f.name))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
candidates.sort(key=lambda x: (x[0], x[1]), reverse=True)
|
||||
return candidates[0][0], candidates[0][1], candidates[0][2]
|
||||
result = candidates[0]
|
||||
return result[2].replace("+", "."), result[1], result[2]
|
||||
|
||||
|
||||
def _compare_versions(
|
||||
|
||||
Reference in New Issue
Block a user