feat(agent): 实现 Agent Runtime LLM 配置与消息元数据结构化支持
This commit is contained in:
@@ -5,6 +5,10 @@ from uuid import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from core.agent.application.session_state_persistence import SessionStatePersistence
|
||||
from core.agent.domain.message_metadata import (
|
||||
MessageMetadataAssistantOutput,
|
||||
MessageMetadataToolResult,
|
||||
)
|
||||
from core.agent.infrastructure.persistence.message_repository import MessageRepository
|
||||
from core.agent.infrastructure.persistence.session_repository import SessionRepository
|
||||
from core.db import AsyncSessionLocal
|
||||
@@ -46,14 +50,16 @@ class ResumeService:
|
||||
seq=next_seq,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
content='{"status":"ok"}',
|
||||
metadata={"type": "tool_result", "tool_call_id": tool_call_id},
|
||||
metadata=MessageMetadataToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
).model_dump(),
|
||||
)
|
||||
await message_repository.append_message(
|
||||
session_id=session_uuid,
|
||||
seq=next_seq + 1,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content="Tool result received",
|
||||
metadata={"type": "assistant_output"},
|
||||
metadata=MessageMetadataAssistantOutput().model_dump(),
|
||||
)
|
||||
|
||||
snapshot = self._state_persistence.build_completed_snapshot()
|
||||
|
||||
@@ -3,10 +3,16 @@ from __future__ import annotations
|
||||
from decimal import Decimal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from core.agent.application.session_state_persistence import SessionStatePersistence
|
||||
from core.agent.domain.message_metadata import (
|
||||
MessageMetadataToolCall,
|
||||
MessageMetadataUserInput,
|
||||
)
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.infrastructure.crewai.factory import create_runtime
|
||||
from core.agent.infrastructure.persistence.message_repository import MessageRepository
|
||||
from core.agent.infrastructure.persistence.session_repository import SessionRepository
|
||||
@@ -58,10 +64,16 @@ class RunService:
|
||||
if chat_session is None:
|
||||
raise ValueError("session not found")
|
||||
|
||||
model_code, provider_name = await self._load_agent_model_selection(
|
||||
db_session
|
||||
(
|
||||
model_code,
|
||||
provider_name,
|
||||
llm_config,
|
||||
) = await self._load_agent_model_selection(db_session)
|
||||
runtime = create_runtime(
|
||||
model_code=model_code,
|
||||
provider_name=provider_name,
|
||||
llm_config=llm_config,
|
||||
)
|
||||
runtime = create_runtime(model_code=model_code, provider_name=provider_name)
|
||||
runtime_result = runtime.execute(user_input=user_input)
|
||||
assistant_text = str(runtime_result.get("assistant_text", ""))
|
||||
prompt_tokens = _to_int(runtime_result.get("prompt_tokens", 0))
|
||||
@@ -79,7 +91,7 @@ class RunService:
|
||||
role=AgentChatMessageRole.USER,
|
||||
content=user_input,
|
||||
model_code=model_code,
|
||||
metadata={"type": "user_input"},
|
||||
metadata=MessageMetadataUserInput().model_dump(),
|
||||
)
|
||||
await message_repository.append_message(
|
||||
session_id=session_uuid,
|
||||
@@ -87,10 +99,9 @@ class RunService:
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=assistant_text or "Tool call pending approval",
|
||||
model_code=model_code,
|
||||
metadata={
|
||||
"type": "tool_call",
|
||||
"tool_call_id": pending_tool_call_id,
|
||||
},
|
||||
metadata=MessageMetadataToolCall(
|
||||
tool_call_id=pending_tool_call_id,
|
||||
).model_dump(),
|
||||
input_tokens=prompt_tokens,
|
||||
output_tokens=completion_tokens,
|
||||
cost=cost,
|
||||
@@ -119,9 +130,9 @@ class RunService:
|
||||
|
||||
async def _load_agent_model_selection(
|
||||
self, session: AsyncSession
|
||||
) -> tuple[str, str]:
|
||||
) -> tuple[str, str, SystemAgentLLMConfig]:
|
||||
stmt = (
|
||||
select(Llm.model_code, LlmFactory.name)
|
||||
select(Llm.model_code, LlmFactory.name, SystemAgents.config)
|
||||
.join(SystemAgents, SystemAgents.llm_id == Llm.id)
|
||||
.join(LlmFactory, LlmFactory.id == Llm.factory_id)
|
||||
.where(SystemAgents.status == "active")
|
||||
@@ -131,4 +142,11 @@ class RunService:
|
||||
record = (await session.execute(stmt)).one_or_none()
|
||||
if record is None:
|
||||
raise ValueError("active system agent model is required")
|
||||
return str(record[0]), str(record[1])
|
||||
|
||||
raw_config = record[2] if isinstance(record[2], dict) else {}
|
||||
try:
|
||||
llm_config = SystemAgentLLMConfig.model_validate(raw_config)
|
||||
except ValidationError as exc:
|
||||
raise ValueError("invalid system agent config") from exc
|
||||
|
||||
return str(record[0]), str(record[1]), llm_config
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MessageMetadataUserInput(BaseModel):
|
||||
type: Literal["user_input"] = "user_input"
|
||||
|
||||
|
||||
class MessageMetadataToolCall(BaseModel):
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class MessageMetadataToolResult(BaseModel):
|
||||
type: Literal["tool_result"] = "tool_result"
|
||||
tool_call_id: str
|
||||
run_id: str | None = None
|
||||
turn_id: str | None = None
|
||||
tool_name: str | None = None
|
||||
storage_bucket: str | None = None
|
||||
storage_path: str | None = None
|
||||
payload_sha256: str | None = None
|
||||
payload_bytes: int | None = None
|
||||
payload_format: str | None = None
|
||||
|
||||
|
||||
class MessageMetadataAssistantOutput(BaseModel):
|
||||
type: Literal["assistant_output"] = "assistant_output"
|
||||
|
||||
|
||||
MessageMetadata = (
|
||||
MessageMetadataUserInput
|
||||
| MessageMetadataToolCall
|
||||
| MessageMetadataToolResult
|
||||
| MessageMetadataAssistantOutput
|
||||
)
|
||||
@@ -0,0 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SystemAgentLLMConfig(BaseModel):
|
||||
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
|
||||
max_tokens: int | None = Field(default=None, ge=1)
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.domain.message_metadata import MessageMetadataToolResult
|
||||
|
||||
|
||||
def reconstruct_tool_call_result_event(
|
||||
*,
|
||||
@@ -26,15 +28,14 @@ def build_tool_result_metadata(
|
||||
payload_bytes: int,
|
||||
payload_format: str,
|
||||
) -> dict[str, object]:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"run_id": run_id,
|
||||
"turn_id": turn_id,
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"storage_bucket": storage_bucket,
|
||||
"storage_path": storage_path,
|
||||
"payload_sha256": payload_sha256,
|
||||
"payload_bytes": payload_bytes,
|
||||
"payload_format": payload_format,
|
||||
}
|
||||
return MessageMetadataToolResult(
|
||||
run_id=run_id,
|
||||
turn_id=turn_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
storage_bucket=storage_bucket,
|
||||
storage_path=storage_path,
|
||||
payload_sha256=payload_sha256,
|
||||
payload_bytes=payload_bytes,
|
||||
payload_format=payload_format,
|
||||
).model_dump()
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.infrastructure.config.resolver import AgentConfigResolver
|
||||
from core.agent.infrastructure.crewai.runtime import CrewAIRuntime
|
||||
|
||||
|
||||
def create_runtime(
|
||||
*, model_code: str | None, provider_name: str | None
|
||||
*,
|
||||
model_code: str | None,
|
||||
provider_name: str | None,
|
||||
llm_config: SystemAgentLLMConfig | None = None,
|
||||
) -> CrewAIRuntime:
|
||||
resolver = AgentConfigResolver()
|
||||
return CrewAIRuntime(
|
||||
resolver=resolver,
|
||||
model_code=model_code,
|
||||
provider_name=provider_name,
|
||||
llm_config=llm_config,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.infrastructure.agui.bridge import to_agui_events
|
||||
from core.agent.infrastructure.config.resolver import (
|
||||
AgentConfigResolver,
|
||||
@@ -47,11 +48,13 @@ class CrewAIRuntime:
|
||||
resolver: AgentConfigResolver,
|
||||
model_code: str | None,
|
||||
provider_name: str | None,
|
||||
llm_config: SystemAgentLLMConfig | None = None,
|
||||
) -> None:
|
||||
self._config: ResolvedAgentConfig = resolver.resolve(
|
||||
model_code=model_code,
|
||||
provider_name=provider_name,
|
||||
)
|
||||
self._llm_config = llm_config or SystemAgentLLMConfig()
|
||||
|
||||
def map_events(self, internal_events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
return to_agui_events(internal_events)
|
||||
@@ -65,6 +68,8 @@ class CrewAIRuntime:
|
||||
model=litellm_model,
|
||||
api_key=self._config.provider_api_key,
|
||||
messages=[{"role": "user", "content": user_input}],
|
||||
temperature=self._llm_config.temperature,
|
||||
max_tokens=self._llm_config.max_tokens,
|
||||
)
|
||||
if not isinstance(response, dict):
|
||||
raise ValueError("llm response must be a dict")
|
||||
|
||||
@@ -5,13 +5,26 @@ from typing import Any
|
||||
from litellm import completion
|
||||
|
||||
|
||||
def run_completion(*, model: str, api_key: str, messages: list[dict[str, Any]]) -> Any:
|
||||
response = completion(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
stream=False,
|
||||
)
|
||||
def run_completion(
|
||||
*,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, Any]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Any:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
}
|
||||
if temperature is not None:
|
||||
kwargs["temperature"] = temperature
|
||||
if max_tokens is not None:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
response = completion(**kwargs)
|
||||
model_dump = getattr(response, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
return model_dump()
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from models.llm import Llm
|
||||
@@ -38,7 +39,7 @@ class SystemAgentsSeed(BaseModel):
|
||||
agent_type: str
|
||||
llm_model_code: str
|
||||
status: str
|
||||
config: dict[str, Any]
|
||||
config: SystemAgentLLMConfig | None = None
|
||||
|
||||
|
||||
class SystemAgentsYaml(BaseModel):
|
||||
@@ -184,7 +185,9 @@ async def initialize_system_agents() -> None:
|
||||
agent_type=agent["agent_type"],
|
||||
llm_id=llm.id,
|
||||
status=agent["status"],
|
||||
config=agent["config"],
|
||||
config=SystemAgentLLMConfig.model_validate(
|
||||
agent.get("config") or {}
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
logger.info("Initialized system agents")
|
||||
|
||||
@@ -4,15 +4,18 @@ agents:
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
|
||||
- agent_type: TASK_EXECUTION
|
||||
llm_model_code: deepseek-v3.2
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
|
||||
- agent_type: RESULT_REPORTING
|
||||
llm_model_code: deepseek-v3.2
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
|
||||
Reference in New Issue
Block a user