feat(agent): complete closed-loop runtime and pricing fallback
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from core.agent.application.session_state_persistence import SessionStatePersistence
|
||||
from core.agent.infrastructure.persistence.message_repository import MessageRepository
|
||||
from core.agent.infrastructure.persistence.session_repository import SessionRepository
|
||||
from core.db import AsyncSessionLocal
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
|
||||
|
||||
class ResumeService:
    """Complete a session that is waiting on a pending tool call.

    The service opens its own database session from ``session_factory`` and
    commits once all rows for the resume step have been written.
    """

    def __init__(
        self,
        *,
        session_factory: async_sessionmaker[AsyncSession] = AsyncSessionLocal,
    ) -> None:
        self._state_persistence = SessionStatePersistence()
        self._session_factory = session_factory

    async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
        """Record the tool result and mark the session as completed.

        Args:
            session_id: Textual UUID of the chat session.
            tool_call_id: Identifier that must equal the
                ``pending_tool_call_id`` stored in the session snapshot.

        Returns:
            A dict with ``session_id``, ``resumed`` (always ``True``) and the
            new ``state_snapshot``.

        Raises:
            ValueError: If ``session_id`` is not a valid UUID, the session does
                not exist, or ``tool_call_id`` does not match the snapshot.
        """
        session_uuid = UUID(session_id)

        async with self._session_factory() as db:
            sessions = SessionRepository(db)
            messages = MessageRepository(db)

            # Row-level lock: the snapshot check and the writes below happen
            # while this session row is held FOR UPDATE.
            locked = await sessions.lock_session_for_update(session_id=session_uuid)
            if locked is None:
                raise ValueError("session not found")

            expected_call_id = (locked.state_snapshot or {}).get("pending_tool_call_id")
            if expected_call_id != tool_call_id:
                raise ValueError("pending tool call does not match")

            first_seq = await sessions.next_message_seq(session_id=session_uuid)

            # Two consecutive rows: the tool result, then the assistant's
            # acknowledgement. Both carry fixed placeholder content.
            await messages.append_message(
                session_id=session_uuid,
                seq=first_seq,
                role=AgentChatMessageRole.TOOL,
                content='{"status":"ok"}',
                metadata={"type": "tool_result", "tool_call_id": tool_call_id},
            )
            await messages.append_message(
                session_id=session_uuid,
                seq=first_seq + 1,
                role=AgentChatMessageRole.ASSISTANT,
                content="Tool result received",
                metadata={"type": "assistant_output"},
            )

            completed = self._state_persistence.build_completed_snapshot()
            await sessions.update_runtime_state(
                chat_session=locked,
                status=AgentChatSessionStatus.COMPLETED,
                state_snapshot=completed,
                message_delta=2,
            )
            await db.commit()

        return {"session_id": session_id, "resumed": True, "state_snapshot": completed}
|
||||
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from core.agent.application.session_state_persistence import SessionStatePersistence
|
||||
from core.agent.infrastructure.crewai.factory import create_runtime
|
||||
from core.agent.infrastructure.persistence.message_repository import MessageRepository
|
||||
from core.agent.infrastructure.persistence.session_repository import SessionRepository
|
||||
from core.db import AsyncSessionLocal
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.system_agents import SystemAgents
|
||||
|
||||
|
||||
def _to_int(value: object, default: int = 0) -> int:
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return default
|
||||
return default
|
||||
|
||||
|
||||
def _to_decimal(value: object) -> Decimal:
|
||||
if isinstance(value, (int, float, str, Decimal)):
|
||||
return Decimal(str(value))
|
||||
return Decimal("0")
|
||||
|
||||
|
||||
class RunService:
    """Execute one user turn against the configured model and persist it.

    The service opens its own database session from ``session_factory`` and
    commits once the turn has been written.
    """

    def __init__(
        self,
        *,
        session_factory: async_sessionmaker[AsyncSession] = AsyncSessionLocal,
    ) -> None:
        self._session_factory = session_factory
        self._state_persistence = SessionStatePersistence()

    async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
        """Run the model for ``user_input`` and store the resulting turn.

        A USER row and an ASSISTANT row are appended, the assistant row being
        tagged as a ``tool_call`` with a freshly generated pending id. The
        session is then moved to RUNNING with that id in its snapshot.

        Args:
            session_id: Textual UUID of the chat session.
            user_input: The user's message for this turn.

        Returns:
            A dict with ``session_id``, ``persisted`` (always ``True``),
            ``pending_tool_call_id``, ``state_snapshot`` and ``events``.

        Raises:
            ValueError: If ``session_id`` is not a valid UUID, the session does
                not exist, or no active system agent model is configured.
        """
        # Function-scope imports: only this method needs them.
        import asyncio
        import functools

        session_uuid = UUID(session_id)
        pending_tool_call_id = f"tool-{uuid4()}"

        async with self._session_factory() as db_session:
            session_repository = SessionRepository(db_session)
            message_repository = MessageRepository(db_session)

            chat_session = await session_repository.lock_session_for_update(
                session_id=session_uuid
            )
            if chat_session is None:
                raise ValueError("session not found")

            model_code, provider_name = await self._load_agent_model_selection(
                db_session
            )
            runtime = create_runtime(model_code=model_code, provider_name=provider_name)

            # runtime.execute is a plain synchronous call; run it on the
            # default executor so the event loop stays responsive while the
            # model request is in flight.
            # NOTE(review): the session row lock is still held for the whole
            # call — confirm that is acceptable for long completions.
            loop = asyncio.get_running_loop()
            runtime_result = await loop.run_in_executor(
                None, functools.partial(runtime.execute, user_input=user_input)
            )

            assistant_text = str(runtime_result.get("assistant_text", ""))
            prompt_tokens = _to_int(runtime_result.get("prompt_tokens", 0))
            completion_tokens = _to_int(runtime_result.get("completion_tokens", 0))
            total_tokens = _to_int(runtime_result.get("total_tokens", 0))
            cost = _to_decimal(runtime_result.get("cost", 0))
            agui_events = runtime_result.get("agui_events", [])

            next_seq = await session_repository.next_message_seq(
                session_id=session_uuid
            )
            await message_repository.append_message(
                session_id=session_uuid,
                seq=next_seq,
                role=AgentChatMessageRole.USER,
                content=user_input,
                model_code=model_code,
                metadata={"type": "user_input"},
            )
            await message_repository.append_message(
                session_id=session_uuid,
                seq=next_seq + 1,
                role=AgentChatMessageRole.ASSISTANT,
                # Placeholder text is stored when the model returned nothing.
                content=assistant_text or "Tool call pending approval",
                model_code=model_code,
                metadata={
                    "type": "tool_call",
                    "tool_call_id": pending_tool_call_id,
                },
                input_tokens=prompt_tokens,
                output_tokens=completion_tokens,
                cost=cost,
            )

            snapshot = self._state_persistence.build_running_snapshot(
                pending_tool_call_id=pending_tool_call_id
            )
            await session_repository.update_runtime_state(
                chat_session=chat_session,
                status=AgentChatSessionStatus.RUNNING,
                state_snapshot=snapshot,
                message_delta=2,
                token_delta=total_tokens,
                cost_delta=cost,
            )
            await db_session.commit()

        return {
            "session_id": session_id,
            "persisted": True,
            "pending_tool_call_id": pending_tool_call_id,
            "state_snapshot": snapshot,
            "events": agui_events,
        }

    async def _load_agent_model_selection(
        self, session: AsyncSession
    ) -> tuple[str, str]:
        """Return ``(model_code, factory_name)`` for one active system agent.

        Picks the first active ``SystemAgents`` row ordered by ``agent_type``.

        Raises:
            ValueError: If no active system agent row exists.
        """
        stmt = (
            select(Llm.model_code, LlmFactory.name)
            .join(SystemAgents, SystemAgents.llm_id == Llm.id)
            .join(LlmFactory, LlmFactory.id == Llm.factory_id)
            .where(SystemAgents.status == "active")
            .order_by(SystemAgents.agent_type.asc())
            .limit(1)
        )
        record = (await session.execute(stmt)).one_or_none()
        if record is None:
            raise ValueError("active system agent model is required")
        return str(record[0]), str(record[1])
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Protocol
|
||||
|
||||
from core.agent.domain.tool_correlation import build_tool_result_metadata
|
||||
from core.agent.domain.state_snapshot import AgentStateSnapshot
|
||||
|
||||
|
||||
class SessionStatePersistence:
    """Build plain-dict session snapshots from ``AgentStateSnapshot``."""

    def build_running_snapshot(
        self, *, pending_tool_call_id: str | None
    ) -> dict[str, object]:
        """Return a ``running`` snapshot carrying ``pending_tool_call_id``."""
        running = AgentStateSnapshot(
            status="running",
            pending_tool_call_id=pending_tool_call_id,
        )
        return running.model_dump()

    def build_completed_snapshot(self) -> dict[str, object]:
        """Return a ``completed`` snapshot with no pending tool call set."""
        completed = AgentStateSnapshot(status="completed")
        return completed.model_dump()
|
||||
|
||||
|
||||
class ToolResultStorage(Protocol):
    """Structural type for a storage backend that can upload JSON payloads."""

    async def upload_json(
        self, *, bucket: str, path: str, payload: dict[str, object]
    ) -> str:
        # The returned string is stored by the caller as ``storage_etag``.
        ...
|
||||
|
||||
|
||||
async def persist_tool_result_payload(
    *,
    storage: ToolResultStorage,
    run_id: str,
    turn_id: str,
    tool_call_id: str,
    tool_name: str,
    payload: dict[str, object],
    bucket: str,
    path: str,
) -> dict[str, object]:
    """Upload a tool result and return the metadata that describes it.

    The SHA-256 digest and byte count are computed over a local canonical
    encoding of ``payload`` (sorted keys, ASCII-escaped, UTF-8).
    NOTE(review): ``storage.upload_json`` receives the dict itself, so the
    stored bytes may differ from the hashed bytes — confirm with the backend.

    Returns:
        The dict from ``build_tool_result_metadata`` with an added
        ``storage_etag`` entry holding the upload's return value.
    """
    canonical_bytes = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode(
        "utf-8"
    )
    digest = hashlib.sha256(canonical_bytes).hexdigest()

    upload_etag = await storage.upload_json(bucket=bucket, path=path, payload=payload)

    result_metadata = build_tool_result_metadata(
        run_id=run_id,
        turn_id=turn_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        storage_bucket=bucket,
        storage_path=path,
        payload_sha256=digest,
        payload_bytes=len(canonical_bytes),
        payload_format="json",
    )
    result_metadata["storage_etag"] = upload_etag
    return result_metadata
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentStateSnapshot(BaseModel):
    """Serializable runtime state stored on a chat session."""

    # Lifecycle state; limited to the four literals below.
    status: Literal["pending", "running", "completed", "failed"]
    # Id of the tool call the session is waiting on, if any.
    pending_tool_call_id: str | None = None
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def reconstruct_tool_call_result_event(
    *,
    metadata: dict[str, object],
    payload: dict[str, object],
) -> dict[str, object]:
    """Rebuild a ``TOOL_CALL_RESULT`` event from stored metadata and payload.

    ``tool_call_id`` and ``tool_name`` are read from ``metadata`` and are
    ``None`` when absent; ``payload`` is passed through as ``data``.
    """
    event: dict[str, object] = {"type": "TOOL_CALL_RESULT", "data": payload}
    for field_name in ("tool_call_id", "tool_name"):
        event[field_name] = metadata.get(field_name)
    return event
|
||||
|
||||
|
||||
def build_tool_result_metadata(
    *,
    run_id: str,
    turn_id: str,
    tool_call_id: str,
    tool_name: str,
    storage_bucket: str,
    storage_path: str,
    payload_sha256: str,
    payload_bytes: int,
    payload_format: str,
) -> dict[str, object]:
    """Assemble the ``tool_result`` metadata record for a stored payload.

    Every argument is copied into the result under its own name, after a
    fixed ``"type": "tool_result"`` entry.
    """
    return dict(
        type="tool_result",
        run_id=run_id,
        turn_id=turn_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        storage_bucket=storage_bucket,
        storage_path=storage_path,
        payload_sha256=payload_sha256,
        payload_bytes=payload_bytes,
        payload_format=payload_format,
    )
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from ag_ui.core.events import EventType
|
||||
|
||||
|
||||
_CAMEL_CASE_BOUNDARY_RE = re.compile(r"([a-z0-9])([A-Z])")
|
||||
_NON_ALNUM_RE = re.compile(r"[^A-Za-z0-9]+")
|
||||
_SENSITIVE_KEYS = {
|
||||
"apikey",
|
||||
"authorization",
|
||||
"token",
|
||||
"accesstoken",
|
||||
"refreshtoken",
|
||||
"secret",
|
||||
"password",
|
||||
}
|
||||
_TYPE_ALIASES = {
|
||||
"taskStarted": "STEP_STARTED",
|
||||
"taskFinished": "STEP_FINISHED",
|
||||
"llmChunk": "TEXT_MESSAGE_CONTENT",
|
||||
"llmStarted": "TEXT_MESSAGE_START",
|
||||
"llmFinished": "TEXT_MESSAGE_END",
|
||||
"toolCalled": "TOOL_CALL_START",
|
||||
"toolCompleted": "TOOL_CALL_RESULT",
|
||||
"error": "RUN_ERROR",
|
||||
}
|
||||
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
normalized = _NON_ALNUM_RE.sub("", key.lower())
|
||||
if normalized in _SENSITIVE_KEYS:
|
||||
return True
|
||||
if "token" in normalized:
|
||||
return True
|
||||
if "api" in normalized and "key" in normalized:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _to_upper_snake(value: str) -> str:
    """Convert an identifier such as ``runStarted`` to ``RUN_STARTED``.

    An underscore is inserted at each lower/digit-to-upper boundary, every run
    of non-alphanumeric characters collapses to one underscore, and leading or
    trailing underscores are removed before upper-casing.
    """
    split_on_case = _CAMEL_CASE_BOUNDARY_RE.sub(r"\1_\2", value)
    underscored = _NON_ALNUM_RE.sub("_", split_on_case).strip("_")
    return underscored.upper()
|
||||
|
||||
|
||||
def _to_event_type(value: str) -> EventType:
    """Look up the ``EventType`` member for ``value``.

    Raises:
        ValueError: With an ``unsupported AG-UI event type`` message when
            ``EventType`` rejects ``value``; the original error is chained.
    """
    try:
        event_type = EventType(value)
    except ValueError as exc:
        raise ValueError(f"unsupported AG-UI event type: {value}") from exc
    else:
        return event_type
|
||||
|
||||
|
||||
def _redact_sensitive(value: Any) -> Any:
    """Return a copy of ``value`` with sensitive dict entries masked.

    Dicts and lists are walked recursively. A dict entry whose key is judged
    sensitive by ``_is_sensitive_key`` is replaced with ``"***REDACTED***"``
    without inspecting its value. Any other kind of value is returned as-is.
    """
    if isinstance(value, dict):
        masked: dict[Any, Any] = {}
        for name, child in value.items():
            if _is_sensitive_key(str(name)):
                masked[name] = "***REDACTED***"
            else:
                masked[name] = _redact_sensitive(child)
        return masked
    if isinstance(value, list):
        return list(map(_redact_sensitive, value))
    return value
|
||||
|
||||
|
||||
def to_agui_events(internal_events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Translate internal runtime events into AG-UI shaped event dicts.

    For each event the ``type`` is mapped through ``_TYPE_ALIASES`` (falling
    back to an UPPER_SNAKE conversion) and validated against ``EventType``;
    ``data`` is passed through redaction; every other key is copied unchanged.

    Raises:
        ValueError: If ``type`` is missing, blank or not a string, if the
            resulting type is unsupported, or if ``data`` is not a dict.
    """
    converted: list[dict[str, Any]] = []

    for source in internal_events:
        declared_type = source.get("type")
        if not (isinstance(declared_type, str) and declared_type.strip()):
            raise ValueError("event.type must be a non-empty string")
        type_name = declared_type.strip()

        if type_name in _TYPE_ALIASES:
            canonical_name = _TYPE_ALIASES[type_name]
        else:
            canonical_name = _to_upper_snake(type_name)

        # Pass-through keys first, then "type", then "data" — this fixes the
        # key order of the emitted dict.
        outgoing = {
            name: item for name, item in source.items() if name not in {"type", "data"}
        }
        outgoing["type"] = _to_event_type(canonical_name).value

        body = source.get("data")
        if not isinstance(body, dict):
            raise ValueError("event.data must be an object")
        outgoing["data"] = _redact_sensitive(body)

        converted.append(outgoing)

    return converted
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def to_sse_event(stream_id: str, event: dict[str, Any]) -> str:
    """Format ``event`` as one Server-Sent Events frame.

    The frame has ``id``, ``event`` and ``data`` lines followed by a blank
    line. ``event`` defaults to ``MESSAGE`` and ``data`` to ``{}`` when the
    corresponding keys are absent; ``data`` is ASCII-escaped JSON.
    """
    name = str(event.get("type", "MESSAGE"))
    body = json.dumps(event.get("data", {}), ensure_ascii=True)
    frame_lines = (f"id: {stream_id}", f"event: {name}", f"data: {body}")
    return "\n".join(frame_lines) + "\n\n"
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol, cast
|
||||
|
||||
from core.config.settings import config
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResolvedAgentConfig:
    """Immutable result of resolving model, provider and credentials."""

    # Model identifier after defaults have been applied.
    model_code: str
    # Excluded from repr() so the credential does not appear in logs/tracebacks.
    provider_api_key: str = field(repr=False)
    # Canonical provider name.
    provider_name: str
    # Streaming flag copied from the runtime settings.
    stream: bool
|
||||
|
||||
|
||||
class AgentRuntimeSettingsLike(Protocol):
    """Structural type for the runtime settings the resolver reads."""

    # Model used when the caller does not supply one.
    default_model_code: str
    # Copied into ``ResolvedAgentConfig.stream``.
    streaming_enabled: bool
|
||||
|
||||
|
||||
class LlmSettingsLike(Protocol):
    """Structural type for the LLM settings the resolver reads."""

    # Maps a provider name (or alias) to its API key.
    provider_keys: dict[str, str]
|
||||
|
||||
|
||||
class SettingsLike(Protocol):
    """Structural type for the settings object handed to the resolver."""

    # Runtime defaults (default model, streaming flag).
    agent_runtime: AgentRuntimeSettingsLike
    # Provider credentials.
    llm: LlmSettingsLike
|
||||
|
||||
|
||||
_PROVIDER_ALIASES = {
|
||||
"ark": "volcengine",
|
||||
"volcengine-ark": "volcengine",
|
||||
"z-ai": "zai",
|
||||
}
|
||||
_SUPPORTED_PROVIDERS = {
|
||||
"dashscope",
|
||||
"minimax",
|
||||
"moonshot",
|
||||
"deepseek",
|
||||
"volcengine",
|
||||
"zai",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_provider(provider: str) -> str:
|
||||
normalized = provider.strip().lower()
|
||||
canonical = _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
if canonical not in _SUPPORTED_PROVIDERS:
|
||||
raise ValueError(f"unsupported provider '{provider}'")
|
||||
return canonical
|
||||
|
||||
|
||||
def _infer_provider_from_model(model_code: str) -> str:
|
||||
lowered = model_code.strip().lower()
|
||||
if lowered.startswith("qwen"):
|
||||
return "dashscope"
|
||||
if lowered.startswith("deepseek"):
|
||||
return "deepseek"
|
||||
if lowered.startswith("kimi") or lowered.startswith("moonshot"):
|
||||
return "moonshot"
|
||||
if lowered.startswith("abab") or lowered.startswith("minimax"):
|
||||
return "minimax"
|
||||
if lowered.startswith("doubao") or lowered.startswith("ark"):
|
||||
return "volcengine"
|
||||
if lowered.startswith("glm") or lowered.startswith("zai"):
|
||||
return "zai"
|
||||
raise ValueError("provider_name is required for unknown model_code")
|
||||
|
||||
|
||||
class AgentConfigResolver:
    """Resolve model, provider and API key from settings into one config."""

    def __init__(self, settings: SettingsLike | None = None) -> None:
        # Falls back to the application-wide ``config`` when none is given.
        self._settings: SettingsLike = cast(SettingsLike, settings or config)

    def resolve(
        self,
        *,
        model_code: str | None,
        provider_name: str | None,
    ) -> ResolvedAgentConfig:
        """Build the effective agent configuration.

        Args:
            model_code: Requested model; the configured default is used when
                this is empty or ``None``.
            provider_name: Requested provider; inferred from the model prefix
                when empty or ``None``.

        Raises:
            ValueError: If no model can be determined, the provider is
                unsupported or cannot be inferred, or no non-blank API key is
                configured for the provider.
        """
        runtime_settings = self._settings.agent_runtime
        # The trailing `or ""` turns a missing default into the ValueError
        # below instead of an AttributeError on None.strip().
        resolved_model = (
            model_code or runtime_settings.default_model_code or ""
        ).strip()

        if not resolved_model:
            raise ValueError("llm_model_code is required")

        provider = _normalize_provider(
            provider_name or _infer_provider_from_model(resolved_model)
        )

        key_map: dict[str, str] = {}
        for raw_name, raw_key in self._settings.llm.provider_keys.items():
            if not isinstance(raw_name, str) or not isinstance(raw_key, str):
                continue
            if not raw_key.strip():
                continue
            try:
                canonical_name = _normalize_provider(raw_name)
            except ValueError:
                # A key configured for a provider this resolver does not
                # support must not break resolution of the supported ones.
                continue
            key_map[canonical_name] = raw_key

        resolved_key = key_map.get(provider, "").strip()
        if not resolved_key:
            raise ValueError(f"provider api key is required for provider '{provider}'")

        return ResolvedAgentConfig(
            model_code=resolved_model,
            provider_api_key=resolved_key,
            provider_name=provider,
            stream=runtime_settings.streaming_enabled,
        )
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.infrastructure.config.resolver import AgentConfigResolver
|
||||
from core.agent.infrastructure.crewai.runtime import CrewAIRuntime
|
||||
|
||||
|
||||
def create_runtime(
    *, model_code: str | None, provider_name: str | None
) -> CrewAIRuntime:
    """Build a ``CrewAIRuntime`` backed by a default ``AgentConfigResolver``."""
    return CrewAIRuntime(
        resolver=AgentConfigResolver(),
        model_code=model_code,
        provider_name=provider_name,
    )
|
||||
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.agent.infrastructure.agui.bridge import to_agui_events
|
||||
from core.agent.infrastructure.config.resolver import (
|
||||
AgentConfigResolver,
|
||||
ResolvedAgentConfig,
|
||||
)
|
||||
from core.agent.infrastructure.litellm.client import run_completion
|
||||
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
|
||||
|
||||
|
||||
def _to_litellm_model(*, provider_name: str, model_code: str) -> str:
|
||||
normalized_model = model_code.strip()
|
||||
if "/" in normalized_model:
|
||||
return normalized_model
|
||||
return f"{provider_name.strip().lower()}/{normalized_model}"
|
||||
|
||||
|
||||
def _extract_assistant_text(response: dict[str, Any]) -> str:
|
||||
choices = response.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
return ""
|
||||
first = choices[0]
|
||||
if not isinstance(first, dict):
|
||||
return ""
|
||||
message = first.get("message")
|
||||
if not isinstance(message, dict):
|
||||
return ""
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and isinstance(item.get("text"), str):
|
||||
text_parts.append(item["text"])
|
||||
return "".join(text_parts)
|
||||
return ""
|
||||
|
||||
|
||||
class CrewAIRuntime:
    """Run a single completion and describe it as AG-UI events."""

    def __init__(
        self,
        *,
        resolver: AgentConfigResolver,
        model_code: str | None,
        provider_name: str | None,
    ) -> None:
        # Resolution happens once, at construction; errors surface here.
        self._config: ResolvedAgentConfig = resolver.resolve(
            model_code=model_code,
            provider_name=provider_name,
        )

    def map_events(self, internal_events: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal events via ``to_agui_events``."""
        return to_agui_events(internal_events)

    def execute(self, *, user_input: str) -> dict[str, object]:
        """Send ``user_input`` as a single user message and summarise the reply.

        Returns:
            A dict with ``assistant_text``, the token counts, ``cost`` and
            ``agui_events`` (a started/chunk/finished triple).

        Raises:
            ValueError: If the completion result is not a dict.
        """
        cfg = self._config
        response = run_completion(
            model=_to_litellm_model(
                provider_name=cfg.provider_name,
                model_code=cfg.model_code,
            ),
            api_key=cfg.provider_api_key,
            messages=[{"role": "user", "content": user_input}],
        )
        if not isinstance(response, dict):
            raise ValueError("llm response must be a dict")

        usage = extract_usage_and_cost(response)
        text = _extract_assistant_text(response)

        # Shared by the "finished" event and the return value.
        usage_fields = {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
            "cost": usage.cost,
        }
        lifecycle = [
            {"type": "llmStarted", "data": {"model": cfg.model_code}},
            {"type": "llmChunk", "data": {"text": text}},
            {
                "type": "llmFinished",
                "data": {**usage_fields, "provider": cfg.provider_name},
            },
        ]
        return {
            "assistant_text": text,
            **usage_fields,
            "agui_events": self.map_events(lifecycle),
        }
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import inspect
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class RedisStreamClient(Protocol):
    """Structural type for the two stream commands the event store uses."""

    def xadd(self, *args: Any, **kwargs: Any) -> Any:
        # Append an entry to a stream.
        ...

    def xread(self, *args: Any, **kwargs: Any) -> Any:
        # Read entries from one or more streams; may return an awaitable.
        ...
|
||||
|
||||
|
||||
class RedisStreamEventStore:
    """Store and read per-session events in Redis streams.

    Each session uses the stream ``"{stream_prefix}:{session_id}"`` and every
    entry has one field, ``event``, holding the JSON-encoded event.
    Replies are accepted as either ``str`` or ``bytes``, so the store works
    whether or not the client decodes responses.
    """

    def __init__(
        self,
        *,
        client: RedisStreamClient,
        stream_prefix: str,
        read_count: int = 100,
        block_ms: int = 5000,
    ) -> None:
        self._client = client
        self._stream_prefix = stream_prefix
        self._read_count = read_count
        self._block_ms = block_ms

    def append_event_sync(self, *, session_id: UUID, event: dict[str, Any]) -> str:
        """Append ``event`` to the session stream and return the entry id."""
        stream = self._stream_name(session_id)
        payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
        # Decode rather than str(): str(b"1-0") would yield "b'1-0'".
        return self._as_text(self._client.xadd(stream, {"event": payload}))

    async def read_events(
        self,
        *,
        session_id: UUID,
        last_event_id: str | None,
    ) -> list[dict[str, Any]]:
        """Read entries that follow ``last_event_id`` from the session stream.

        With ``last_event_id=None`` the read starts at ``"$"``. Both sync and
        async clients are supported: an awaitable reply is awaited.

        Returns:
            A list of ``{"id": <entry id as str>, "event": <decoded event>}``
            dicts, empty when the client returns nothing.
        """
        stream = self._stream_name(session_id)
        start_id = "$" if last_event_id is None else last_event_id
        raw_response = self._client.xread(
            {stream: start_id},
            count=self._read_count,
            block=self._block_ms,
        )
        response = (
            await raw_response if inspect.isawaitable(raw_response) else raw_response
        )

        if not response:
            return []

        # Only one stream was requested, so only the first reply is relevant.
        _, entries = response[0]
        result: list[dict[str, Any]] = []
        for stream_id, fields in entries:
            # Field names arrive as bytes when responses are not decoded.
            raw_event = fields["event"] if "event" in fields else fields[b"event"]
            result.append(
                {
                    "id": self._as_text(stream_id),
                    "event": json.loads(raw_event),
                }
            )
        return result

    def _stream_name(self, session_id: UUID) -> str:
        """Return the stream key for ``session_id``."""
        return f"{self._stream_prefix}:{session_id}"

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return ``value`` as ``str``, decoding bytes as UTF-8."""
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).decode("utf-8")
        return str(value)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from litellm import completion
|
||||
|
||||
|
||||
def run_completion(*, model: str, api_key: str, messages: list[dict[str, Any]]) -> Any:
    """Run one non-streaming completion and return it in dumped form.

    When the response object exposes a callable ``model_dump``, its result is
    returned; otherwise the response object itself is returned.
    """
    response = completion(model=model, api_key=api_key, messages=messages, stream=False)
    dump = getattr(response, "model_dump", None)
    return dump() if callable(dump) else response
|
||||
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TieredModelPricing:
    """Per-token prices that apply up to a given prompt size."""

    # Largest prompt token count (inclusive) this tier applies to.
    max_prompt_tokens: int
    # Price of one prompt token.
    input_cost_per_token: float
    # Price of one completion token.
    output_cost_per_token: float
    # Price of one token written to the cache.
    cache_create_cost_per_token: float
    # Price of one token served from the cache.
    cache_hit_cost_per_token: float
|
||||
|
||||
|
||||
# Tiers in ascending order of prompt size. Each rate is written as a
# per-1000-token price divided by 1000 to obtain the per-token price.
# NOTE(review): the currency is not stated in this module — confirm.
QWEN35_FLASH_TIERED_PRICING: tuple[TieredModelPricing, ...] = (
    # Prompts up to 128K tokens.
    TieredModelPricing(
        max_prompt_tokens=128_000,
        input_cost_per_token=0.0002 / 1000,
        output_cost_per_token=0.002 / 1000,
        cache_create_cost_per_token=0.00025 / 1000,
        cache_hit_cost_per_token=0.00002 / 1000,
    ),
    # Prompts up to 256K tokens.
    TieredModelPricing(
        max_prompt_tokens=256_000,
        input_cost_per_token=0.0008 / 1000,
        output_cost_per_token=0.008 / 1000,
        cache_create_cost_per_token=0.001 / 1000,
        cache_hit_cost_per_token=0.00008 / 1000,
    ),
    # Prompts up to 1M tokens.
    TieredModelPricing(
        max_prompt_tokens=1_000_000,
        input_cost_per_token=0.0012 / 1000,
        output_cost_per_token=0.012 / 1000,
        cache_create_cost_per_token=0.0015 / 1000,
        cache_hit_cost_per_token=0.00012 / 1000,
    ),
)

# Lookup by lower-case ``provider/model`` name.
_MODEL_TIERED_PRICING: dict[str, tuple[TieredModelPricing, ...]] = {
    "dashscope/qwen3.5-flash": QWEN35_FLASH_TIERED_PRICING,
}
|
||||
|
||||
|
||||
def get_tiered_pricing(
    *, model_name: str, prompt_tokens: int
) -> TieredModelPricing | None:
    """Select the pricing tier for ``model_name`` at ``prompt_tokens``.

    The model name is matched after stripping and lower-casing.

    Returns:
        The first tier whose ``max_prompt_tokens`` is at least
        ``prompt_tokens``; the last tier when the prompt exceeds every tier;
        ``None`` when the model has no tiered pricing.
    """
    tiers = _MODEL_TIERED_PRICING.get(model_name.strip().lower())
    if tiers is None:
        return None

    fitting = (tier for tier in tiers if prompt_tokens <= tier.max_prompt_tokens)
    return next(fitting, tiers[-1])
|
||||
|
||||
|
||||
def calculate_tiered_model_cost(
    *,
    model_name: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> float | None:
    """Price a completion using the model's tiered table.

    Only the input and output rates are used; cache rates are not applied.

    Returns:
        The total cost, or ``None`` when the model has no tiered pricing.
    """
    tier = get_tiered_pricing(model_name=model_name, prompt_tokens=prompt_tokens)
    if tier is None:
        return None

    input_cost = prompt_tokens * tier.input_cost_per_token
    output_cost = completion_tokens * tier.output_cost_per_token
    return input_cost + output_cost
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from litellm import completion_cost
|
||||
|
||||
from core.agent.infrastructure.litellm.pricing import calculate_tiered_model_cost
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class UsageCost:
    """Token usage and cost for one completion."""

    # Tokens in the prompt.
    prompt_tokens: int
    # Tokens in the completion.
    completion_tokens: int
    # Total tokens for the call.
    total_tokens: int
    # Monetary cost of the call.
    cost: float
    # Where ``cost`` came from: "litellm" or "custom_pricing".
    cost_source: str = "litellm"
|
||||
|
||||
|
||||
def extract_usage_and_cost(response: dict[str, Any]) -> UsageCost:
    """Read token usage from ``response`` and price the completion.

    The cost comes from litellm's ``completion_cost`` when that succeeds;
    otherwise from the local tiered pricing table, in which case the result's
    ``cost_source`` is ``"custom_pricing"``.

    Raises:
        ValueError: If ``response`` has no ``usage`` dict, or neither pricing
            source can produce a cost.
    """
    usage = response.get("usage")
    if not isinstance(usage, dict):
        raise ValueError("missing usage in response")

    # `or 0` also covers a key that is present with a null value, which
    # int() would otherwise reject with TypeError.
    prompt_tokens = int(usage.get("prompt_tokens") or 0)
    completion_tokens = int(usage.get("completion_tokens") or 0)
    reported_total = usage.get("total_tokens")
    total_tokens = (
        int(reported_total)
        if reported_total is not None
        else prompt_tokens + completion_tokens
    )
    model_name = str(response.get("model") or "").strip().lower()

    litellm_error: Exception | None = None
    try:
        cost = completion_cost(completion_response=response)
    except Exception as exc:
        # Deliberately broad: any litellm pricing failure should fall through
        # to the local table rather than fail the whole request.
        cost = None
        litellm_error = exc

    if cost is not None:
        return UsageCost(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cost=float(cost),
        )

    local_cost = calculate_tiered_model_cost(
        model_name=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    if local_cost is None:
        raise ValueError(
            "unable to calculate litellm completion cost"
        ) from litellm_error

    return UsageCost(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        cost=float(local_cost),
        cost_source="custom_pricing",
    )
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
|
||||
|
||||
class MessageRepository:
    """Write access to chat message rows through one ``AsyncSession``."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def append_message(
        self,
        *,
        session_id: UUID,
        seq: int,
        role: AgentChatMessageRole,
        content: str,
        model_code: str | None = None,
        metadata: dict[str, object] | None = None,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: Decimal = Decimal("0"),
    ) -> AgentChatMessage:
        """Add a message row and flush it; the caller is left to commit.

        ``metadata`` is stored in the model's ``metadata_json`` attribute.

        Returns:
            The newly added ``AgentChatMessage`` instance.
        """
        columns = {
            "session_id": session_id,
            "seq": seq,
            "role": role,
            "content": content,
            "model_code": model_code,
            "metadata_json": metadata,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost": cost,
        }
        row = AgentChatMessage(**columns)
        db = self._session
        db.add(row)
        await db.flush()
        return row
|
||||
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
|
||||
|
||||
class SessionRepository:
    """Read and write chat session rows through one ``AsyncSession``.

    No method commits; each write is flushed and the caller is left to commit.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def get_session(self, *, session_id: UUID) -> AgentChatSession | None:
        """Fetch a session by primary key, or ``None`` when absent."""
        found = await self._session.get(AgentChatSession, session_id)
        return found

    async def lock_session_for_update(
        self, *, session_id: UUID
    ) -> AgentChatSession | None:
        """Fetch a session with ``SELECT ... FOR UPDATE``, or ``None``."""
        locking_query = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_id)
            .with_for_update()
        )
        rows = await self._session.execute(locking_query)
        return rows.scalar_one_or_none()

    async def next_message_seq(self, *, session_id: UUID) -> int:
        """Return one more than the highest ``seq`` stored for the session.

        A session without messages yields ``1``.
        """
        highest_seq = select(func.coalesce(func.max(AgentChatMessage.seq), 0)).where(
            AgentChatMessage.session_id == session_id
        )
        rows = await self._session.execute(highest_seq)
        return int(rows.scalar_one()) + 1

    async def update_runtime_state(
        self,
        *,
        chat_session: AgentChatSession,
        status: AgentChatSessionStatus,
        state_snapshot: dict[str, object],
        message_delta: int,
        token_delta: int = 0,
        cost_delta: Decimal = Decimal("0"),
    ) -> None:
        """Apply a state change and counter increments, then flush.

        ``last_activity_at`` is set to the current UTC time.
        """
        # Counters are incremented in place on the loaded entity.
        chat_session.message_count += message_delta
        chat_session.total_tokens += token_delta
        chat_session.total_cost += cost_delta

        chat_session.status = status
        chat_session.state_snapshot = state_snapshot
        chat_session.last_activity_at = datetime.now(timezone.utc)

        await self._session.flush()

    async def soft_delete_session_with_messages(self, *, session_id: UUID) -> int:
        """Stamp ``deleted_at`` on a session and its not-yet-deleted messages.

        Returns:
            ``1`` when the session existed and was not already deleted,
            otherwise ``0`` (in which case nothing is written).
        """
        current = await self.get_session(session_id=session_id)
        if current is None or current.deleted_at is not None:
            return 0

        # One timestamp shared by the session and all of its messages.
        stamp = datetime.now(timezone.utc)
        mark_session = (
            update(AgentChatSession)
            .where(AgentChatSession.id == session_id)
            .where(AgentChatSession.deleted_at.is_(None))
            .values(deleted_at=stamp)
        )
        mark_messages = (
            update(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_id)
            .where(AgentChatMessage.deleted_at.is_(None))
            .values(deleted_at=stamp)
        )
        for statement in (mark_session, mark_messages):
            await self._session.execute(statement)
        await self._session.flush()
        return 1
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,159 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Any, Callable, Protocol, cast
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.celery.app import celery_app
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
|
||||
# Module logger for the agent queue tasks.
logger = get_logger("core.agent.infrastructure.queue.tasks")
|
||||
|
||||
# Event loop created by the background worker thread; stays None until
# _ensure_background_loop() has started that thread.
_background_loop: asyncio.AbstractEventLoop | None = None
|
||||
# Daemon thread that runs the background loop via run_forever().
_background_thread: threading.Thread | None = None
|
||||
# Set by the worker thread once it has published its loop in _background_loop.
_background_ready = threading.Event()
|
||||
|
||||
|
||||
class PublishEvent(Protocol):
    """Structural type for a callable that publishes one event.

    Implementations receive the event type name and its payload dict and
    return nothing.
    """

    def __call__(self, event_type: str, payload: dict[str, object]) -> None:
        ...
|
||||
|
||||
|
||||
class RunServiceLike(Protocol):
    """Structural type for a service that starts an agent run.

    ``run`` is a coroutine taking keyword-only ``session_id`` and
    ``user_input`` strings and resolving to a result dict.
    """

    async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
        ...
|
||||
|
||||
|
||||
class ResumeServiceLike(Protocol):
    """Structural type for a service that resumes a paused agent run.

    ``resume`` is a coroutine taking keyword-only ``session_id`` and
    ``tool_call_id`` strings and resolving to a result dict.
    """

    async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
        ...
|
||||
|
||||
|
||||
def _run_async(task: Callable[[], Any]) -> Any:
    """Run the coroutine produced by ``task`` on the background loop.

    Blocks the calling thread until the coroutine finishes, then returns its
    result (or raises whatever it raised).
    """
    background_loop = _ensure_background_loop()
    coroutine = task()
    return asyncio.run_coroutine_threadsafe(coroutine, background_loop).result()
|
||||
|
||||
|
||||
def _ensure_background_loop() -> asyncio.AbstractEventLoop:
    """Return the shared background event loop, starting it on first use.

    The loop runs ``run_forever()`` on a daemon thread. A cached loop is only
    reused while it is still usable: if it has been closed, or the thread that
    was running it has died, a fresh loop and thread are started instead.
    Without that check, work scheduled on the dead loop would never run and
    callers blocking on the result would hang.

    Returns:
        The running background event loop.

    Raises:
        RuntimeError: if the worker thread has not published a loop within
            five seconds.
    """
    global _background_loop, _background_thread

    cached_loop = _background_loop
    if cached_loop is not None:
        worker = _background_thread
        # A loop with no recorded thread is trusted as-is; only a recorded
        # thread that has since died marks the loop as stale.
        thread_ok = worker is None or worker.is_alive()
        if thread_ok and not cached_loop.is_closed():
            return cached_loop
        # Stale state: forget the dead loop and re-arm the ready signal so the
        # wait below really waits for the replacement loop.
        _background_loop = None
        _background_ready.clear()

    def _loop_worker() -> None:
        global _background_loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        _background_loop = loop
        _background_ready.set()
        loop.run_forever()

    _background_thread = threading.Thread(target=_loop_worker, daemon=True)
    _background_thread.start()
    _background_ready.wait(timeout=5)
    if _background_loop is None:
        raise RuntimeError("failed to initialize background event loop")
    return _background_loop
|
||||
|
||||
|
||||
def _build_redis_publisher() -> PublishEvent:
    """Create a publisher that appends events to a Redis-backed event store.

    The Redis client and the stream options (prefix, read count, block time)
    come from the application config. The returned callable requires a
    non-blank ``session_id`` entry in every payload and raises ``ValueError``
    when it is missing.
    """
    settings = cast(Any, config)
    redis_client = redis.from_url(settings.redis.url, decode_responses=True)
    store = RedisStreamEventStore(
        client=redis_client,
        stream_prefix=settings.agent_runtime.redis_stream_prefix,
        read_count=settings.agent_runtime.redis_stream_read_count,
        block_ms=settings.agent_runtime.redis_stream_block_ms,
    )

    def _publish(event_type: str, payload: dict[str, object]) -> None:
        raw_session_id = str(payload.get("session_id", "")).strip()
        if raw_session_id == "":
            raise ValueError("session_id is required in event payload")
        target_session = UUID(raw_session_id)
        envelope = {"type": event_type, "data": payload}
        store.append_event_sync(session_id=target_session, event=envelope)

    return _publish
|
||||
|
||||
|
||||
def run_agent_task(
    command: dict[str, Any],
    *,
    publish_event: PublishEvent | None = None,
    run_service: RunServiceLike | None = None,
    resume_service: ResumeServiceLike | None = None,
) -> dict[str, object]:
    """Execute one agent ``run`` or ``resume`` command and publish its events.

    Args:
        command: Mapping with ``"command"`` (``"run"`` by default, or
            ``"resume"``), ``"session_id"`` (a UUID string) and either
            ``"user_input"`` (run) or ``"tool_call_id"`` (resume).
        publish_event: Event publisher; defaults to the Redis publisher.
        run_service: Service used for ``run``; defaults to ``RunService()``.
        resume_service: Service used for ``resume``; defaults to
            ``ResumeService()``.

    Returns:
        The result returned by the selected service.

    Raises:
        ValueError: for an unknown command type, or a missing or malformed
            ``session_id``; these are raised before any event is published.
            A missing ``user_input`` / ``tool_call_id`` also raises
            ``ValueError``, but after the start event, so ``RUN_ERROR`` is
            published for it.
        Exception: any failure from the service or from publishing is logged,
            reported as ``RUN_ERROR`` on a best-effort basis, and re-raised.
    """
    publisher = publish_event or _build_redis_publisher()
    service_run = run_service or RunService()
    service_resume = resume_service or ResumeService()

    command_type = str(command.get("command", "run"))
    session_id = str(command.get("session_id", ""))

    if command_type not in {"run", "resume"}:
        raise ValueError("invalid command type")
    if not session_id:
        raise ValueError("session_id is required")
    # Called only for validation: raises ValueError on a malformed UUID string.
    UUID(session_id)

    start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
    publisher(start_event, {"session_id": session_id})

    try:
        if command_type == "resume":
            tool_call_id = str(command.get("tool_call_id", ""))
            if not tool_call_id:
                raise ValueError("tool_call_id is required")
            result = _run_async(
                lambda: service_resume.resume(
                    session_id=session_id,
                    tool_call_id=tool_call_id,
                )
            )
        else:
            user_input = str(command.get("user_input", ""))
            if not user_input:
                raise ValueError("user_input is required")
            result = _run_async(
                lambda: service_run.run(
                    session_id=session_id,
                    user_input=user_input,
                )
            )

        publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
        extra_events = result.get("events") if isinstance(result, dict) else None
        if isinstance(extra_events, list):
            for event in extra_events:
                # Entries that are not {"type": str, "data": dict} are skipped.
                if not isinstance(event, dict):
                    continue
                event_type = event.get("type")
                event_data = event.get("data")
                if not isinstance(event_type, str) or not isinstance(event_data, dict):
                    continue
                payload = {"session_id": session_id, **event_data}
                publisher(event_type, payload)
        publisher("RUN_FINISHED", {"session_id": session_id})
        return result
    except Exception:  # noqa: BLE001
        error_id = "agent_runtime_failed"
        logger.exception(
            "Agent task failed",
            session_id=session_id,
            error_id=error_id,
        )
        # Reporting the error is best-effort: if the publisher itself fails
        # (for example the event store is unreachable), that secondary failure
        # must not replace the original exception seen by the caller.
        try:
            publisher("RUN_ERROR", {"session_id": session_id, "error_id": error_id})
        except Exception:  # noqa: BLE001
            logger.exception(
                "Failed to publish RUN_ERROR event",
                session_id=session_id,
                error_id=error_id,
            )
        raise
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.agent.run_command")
def run_command_task(command: dict[str, Any]) -> dict[str, object]:
    """Celery task ``tasks.agent.run_command``: run ``command`` with default collaborators.

    Delegates to ``run_agent_task`` without overriding the publisher or
    services, and returns its result unchanged.
    """
    outcome = run_agent_task(command)
    return outcome
|
||||
@@ -15,6 +15,7 @@ def create_celery_app() -> Celery:
|
||||
"social_app",
|
||||
broker=config.celery_broker_url,
|
||||
backend=config.celery_result_backend,
|
||||
include=["core.agent.infrastructure.queue.tasks"],
|
||||
)
|
||||
|
||||
app.conf.update(
|
||||
|
||||
@@ -140,6 +140,18 @@ class StorageSettings(BaseModel):
|
||||
retention_days: int = Field(default=30, ge=1, le=3650)
|
||||
|
||||
|
||||
class AgentRuntimeSettings(BaseModel):
    """Configuration for the agent runtime and its Redis stream event store."""

    # Handed to RedisStreamEventStore as ``stream_prefix``.
    redis_stream_prefix: str = "agent:events"

    # Handed to RedisStreamEventStore as ``read_count``; must be within 1..1000.
    redis_stream_read_count: int = Field(
        default=100,
        ge=1,
        le=1000,
    )

    # Handed to RedisStreamEventStore as ``block_ms``; must be within 1..60000.
    redis_stream_block_ms: int = Field(
        default=5000,
        ge=1,
        le=60000,
    )

    # Empty by default. NOTE(review): presumably the model code used when a
    # request names none; confirm against the consumer.
    default_model_code: str = ""

    # Enabled by default. NOTE(review): presumably toggles streamed responses;
    # confirm against the consumer.
    streaming_enabled: bool = True
|
||||
|
||||
|
||||
class LlmSettings(BaseModel):
    """LLM-related configuration."""

    # String-to-string mapping, empty by default. NOTE(review): presumably
    # provider name -> API key; confirm against the code that reads it.
    provider_keys: dict[str, str] = Field(
        default_factory=dict,
    )
|
||||
|
||||
|
||||
class DatabaseSettings(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
@@ -172,6 +184,8 @@ class Settings(BaseSettings):
|
||||
redis: RedisSettings = RedisSettings()
|
||||
supabase: SupabaseSettings = SupabaseSettings()
|
||||
storage: StorageSettings = StorageSettings()
|
||||
llm: LlmSettings = LlmSettings()
|
||||
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
|
||||
celery: CelerySettings = CelerySettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ factories:
|
||||
request_url: https://api.deepseek.com/v1
|
||||
avatar: https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/light/deepseek-color.png
|
||||
|
||||
- name: volcengine-ark
|
||||
- name: volcengine
|
||||
request_url: https://ark.cn-beijing.volces.com/api/v3
|
||||
avatar: https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/light/doubao-color.png
|
||||
|
||||
- name: z-ai
|
||||
- name: zai
|
||||
request_url: https://api.z.ai/api/paas/v4
|
||||
avatar: https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/light/zai.png
|
||||
|
||||
|
||||
Reference in New Issue
Block a user