feat(agent): complete closed-loop runtime and pricing fallback

This commit is contained in:
qzl
2026-03-05 15:34:37 +08:00
parent b02a322bf3
commit b486e78ff3
67 changed files with 3832 additions and 7 deletions
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,68 @@
from __future__ import annotations
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from core.agent.application.session_state_persistence import SessionStatePersistence
from core.agent.infrastructure.persistence.message_repository import MessageRepository
from core.agent.infrastructure.persistence.session_repository import SessionRepository
from core.db import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
class ResumeService:
    """Resume a session that is waiting on a tool call and mark it completed."""

    def __init__(
        self,
        *,
        session_factory: async_sessionmaker[AsyncSession] = AsyncSessionLocal,
    ) -> None:
        """Store the session factory and create the snapshot builder."""
        self._session_factory = session_factory
        self._state_persistence = SessionStatePersistence()

    async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
        """Record the tool result for ``tool_call_id`` and complete the session.

        Args:
            session_id: Session identifier; must parse as a UUID.
            tool_call_id: Identifier that must equal the snapshot's
                ``pending_tool_call_id``.

        Returns:
            A dict with ``session_id``, ``resumed`` and ``state_snapshot``.

        Raises:
            ValueError: If the id is not a UUID, the session is missing, or the
                pending tool call does not match.
        """
        session_uuid = UUID(session_id)
        async with self._session_factory() as db_session:
            sessions = SessionRepository(db_session)
            messages = MessageRepository(db_session)

            # Row lock keeps concurrent commands from interleaving sequence numbers.
            chat_session = await sessions.lock_session_for_update(
                session_id=session_uuid
            )
            if chat_session is None:
                raise ValueError("session not found")

            stored_state = chat_session.state_snapshot or {}
            if stored_state.get("pending_tool_call_id") != tool_call_id:
                raise ValueError("pending tool call does not match")

            first_seq = await sessions.next_message_seq(session_id=session_uuid)
            planned_messages = (
                (
                    first_seq,
                    AgentChatMessageRole.TOOL,
                    '{"status":"ok"}',
                    {"type": "tool_result", "tool_call_id": tool_call_id},
                ),
                (
                    first_seq + 1,
                    AgentChatMessageRole.ASSISTANT,
                    "Tool result received",
                    {"type": "assistant_output"},
                ),
            )
            for seq, role, content, metadata in planned_messages:
                await messages.append_message(
                    session_id=session_uuid,
                    seq=seq,
                    role=role,
                    content=content,
                    metadata=metadata,
                )

            snapshot = self._state_persistence.build_completed_snapshot()
            await sessions.update_runtime_state(
                chat_session=chat_session,
                status=AgentChatSessionStatus.COMPLETED,
                state_snapshot=snapshot,
                message_delta=2,
            )
            await db_session.commit()

        return {"session_id": session_id, "resumed": True, "state_snapshot": snapshot}
@@ -0,0 +1,134 @@
from __future__ import annotations
from decimal import Decimal
from uuid import UUID, uuid4
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from core.agent.application.session_state_persistence import SessionStatePersistence
from core.agent.infrastructure.crewai.factory import create_runtime
from core.agent.infrastructure.persistence.message_repository import MessageRepository
from core.agent.infrastructure.persistence.session_repository import SessionRepository
from core.db import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
def _to_int(value: object, default: int = 0) -> int:
if isinstance(value, int):
return value
if isinstance(value, str):
try:
return int(value)
except ValueError:
return default
return default
def _to_decimal(value: object) -> Decimal:
if isinstance(value, (int, float, str, Decimal)):
return Decimal(str(value))
return Decimal("0")
class RunService:
    """Execute one user turn for an agent chat session and persist the outcome."""

    def __init__(
        self,
        *,
        session_factory: async_sessionmaker[AsyncSession] = AsyncSessionLocal,
    ) -> None:
        """Store the session factory and create the snapshot builder."""
        self._session_factory = session_factory
        self._state_persistence = SessionStatePersistence()

    async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
        """Run the model on ``user_input`` and persist both messages.

        Args:
            session_id: Session identifier; must parse as a UUID.
            user_input: Text of the user's message.

        Returns:
            A dict with ``session_id``, ``persisted``, ``pending_tool_call_id``,
            ``state_snapshot`` and ``events`` (the runtime's AG-UI events).

        Raises:
            ValueError: If the id is not a UUID, the session does not exist, or
                no active system agent model is configured.
        """
        # Function-scope import keeps this block self-contained.
        import asyncio

        session_uuid = UUID(session_id)
        pending_tool_call_id = f"tool-{uuid4()}"
        async with self._session_factory() as db_session:
            session_repository = SessionRepository(db_session)
            message_repository = MessageRepository(db_session)
            chat_session = await session_repository.lock_session_for_update(
                session_id=session_uuid
            )
            if chat_session is None:
                raise ValueError("session not found")

            model_code, provider_name = await self._load_agent_model_selection(
                db_session
            )
            runtime = create_runtime(model_code=model_code, provider_name=provider_name)
            # runtime.execute is a synchronous call; running it in a worker
            # thread keeps the event loop free while the LLM request is in flight.
            runtime_result = await asyncio.to_thread(
                runtime.execute, user_input=user_input
            )

            assistant_text = str(runtime_result.get("assistant_text", ""))
            prompt_tokens = _to_int(runtime_result.get("prompt_tokens", 0))
            completion_tokens = _to_int(runtime_result.get("completion_tokens", 0))
            total_tokens = _to_int(runtime_result.get("total_tokens", 0))
            cost = _to_decimal(runtime_result.get("cost", 0))
            agui_events = runtime_result.get("agui_events", [])

            next_seq = await session_repository.next_message_seq(
                session_id=session_uuid
            )
            await message_repository.append_message(
                session_id=session_uuid,
                seq=next_seq,
                role=AgentChatMessageRole.USER,
                content=user_input,
                model_code=model_code,
                metadata={"type": "user_input"},
            )
            await message_repository.append_message(
                session_id=session_uuid,
                seq=next_seq + 1,
                role=AgentChatMessageRole.ASSISTANT,
                # Empty model output is stored as a placeholder line.
                content=assistant_text or "Tool call pending approval",
                model_code=model_code,
                metadata={
                    "type": "tool_call",
                    "tool_call_id": pending_tool_call_id,
                },
                input_tokens=prompt_tokens,
                output_tokens=completion_tokens,
                cost=cost,
            )

            snapshot = self._state_persistence.build_running_snapshot(
                pending_tool_call_id=pending_tool_call_id
            )
            await session_repository.update_runtime_state(
                chat_session=chat_session,
                status=AgentChatSessionStatus.RUNNING,
                state_snapshot=snapshot,
                message_delta=2,
                token_delta=total_tokens,
                cost_delta=cost,
            )
            await db_session.commit()

        return {
            "session_id": session_id,
            "persisted": True,
            "pending_tool_call_id": pending_tool_call_id,
            "state_snapshot": snapshot,
            "events": agui_events,
        }

    async def _load_agent_model_selection(
        self, session: AsyncSession
    ) -> tuple[str, str]:
        """Return ``(model_code, factory_name)`` for the first active system agent.

        Active agents are ordered by ``agent_type`` ascending and only the first
        row is used.

        Raises:
            ValueError: If no active system agent row exists.
        """
        stmt = (
            select(Llm.model_code, LlmFactory.name)
            .join(SystemAgents, SystemAgents.llm_id == Llm.id)
            .join(LlmFactory, LlmFactory.id == Llm.factory_id)
            .where(SystemAgents.status == "active")
            .order_by(SystemAgents.agent_type.asc())
            .limit(1)
        )
        record = (await session.execute(stmt)).one_or_none()
        if record is None:
            raise ValueError("active system agent model is required")
        return str(record[0]), str(record[1])
@@ -0,0 +1,60 @@
from __future__ import annotations
import hashlib
import json
from typing import Protocol
from core.agent.domain.tool_correlation import build_tool_result_metadata
from core.agent.domain.state_snapshot import AgentStateSnapshot
class SessionStatePersistence:
    """Build the plain-dict state snapshots stored on a chat session."""

    def build_running_snapshot(
        self, *, pending_tool_call_id: str | None
    ) -> dict[str, object]:
        """Return a ``running`` snapshot carrying the pending tool call id."""
        running = AgentStateSnapshot(
            status="running",
            pending_tool_call_id=pending_tool_call_id,
        )
        return running.model_dump()

    def build_completed_snapshot(self) -> dict[str, object]:
        """Return a ``completed`` snapshot with no pending tool call."""
        completed = AgentStateSnapshot(status="completed")
        return completed.model_dump()
class ToolResultStorage(Protocol):
    """Structural interface for a store that can upload a JSON payload."""

    async def upload_json(
        self,
        *,
        bucket: str,
        path: str,
        payload: dict[str, object],
    ) -> str:
        """Upload ``payload`` to ``bucket``/``path`` and return a string.

        ``persist_tool_result_payload`` stores the returned value under the
        ``storage_etag`` metadata key.
        """
        ...
async def persist_tool_result_payload(
    *,
    storage: ToolResultStorage,
    run_id: str,
    turn_id: str,
    tool_call_id: str,
    tool_name: str,
    payload: dict[str, object],
    bucket: str,
    path: str,
) -> dict[str, object]:
    """Upload a tool result payload and return metadata describing it.

    The SHA-256 digest and byte length are computed over the payload serialised
    as ASCII JSON with sorted keys. The value returned by
    ``storage.upload_json`` is added to the metadata as ``storage_etag``.
    """
    serialized = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8")
    digest = hashlib.sha256(serialized).hexdigest()

    upload_etag = await storage.upload_json(bucket=bucket, path=path, payload=payload)

    result_metadata = build_tool_result_metadata(
        run_id=run_id,
        turn_id=turn_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        storage_bucket=bucket,
        storage_path=path,
        payload_sha256=digest,
        payload_bytes=len(serialized),
        payload_format="json",
    )
    result_metadata["storage_etag"] = upload_etag
    return result_metadata
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,10 @@
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel
class AgentStateSnapshot(BaseModel):
    """Serializable snapshot of an agent session's runtime state."""

    # Lifecycle state; limited to the four literal values below.
    status: Literal["pending", "running", "completed", "failed"]
    # Identifier of the tool call the session is waiting on, if any.
    pending_tool_call_id: str | None = None
@@ -0,0 +1,40 @@
from __future__ import annotations
def reconstruct_tool_call_result_event(
    *,
    metadata: dict[str, object],
    payload: dict[str, object],
) -> dict[str, object]:
    """Rebuild a ``TOOL_CALL_RESULT`` event from stored metadata and payload.

    ``tool_call_id`` and ``tool_name`` are copied from ``metadata`` and are
    ``None`` when the key is absent.
    """
    event: dict[str, object] = {"type": "TOOL_CALL_RESULT", "data": payload}
    for correlation_key in ("tool_call_id", "tool_name"):
        event[correlation_key] = metadata.get(correlation_key)
    return event
def build_tool_result_metadata(
    *,
    run_id: str,
    turn_id: str,
    tool_call_id: str,
    tool_name: str,
    storage_bucket: str,
    storage_path: str,
    payload_sha256: str,
    payload_bytes: int,
    payload_format: str,
) -> dict[str, object]:
    """Assemble the ``tool_result`` metadata record for a stored payload.

    Every argument is copied into the result under a key of the same name,
    preceded by a fixed ``"type": "tool_result"`` entry.
    """
    metadata: dict[str, object] = dict(
        type="tool_result",
        run_id=run_id,
        turn_id=turn_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        storage_bucket=storage_bucket,
        storage_path=storage_path,
        payload_sha256=payload_sha256,
        payload_bytes=payload_bytes,
        payload_format=payload_format,
    )
    return metadata
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,90 @@
from __future__ import annotations
import re
from typing import Any
from ag_ui.core.events import EventType
_CAMEL_CASE_BOUNDARY_RE = re.compile(r"([a-z0-9])([A-Z])")
_NON_ALNUM_RE = re.compile(r"[^A-Za-z0-9]+")
_SENSITIVE_KEYS = {
"apikey",
"authorization",
"token",
"accesstoken",
"refreshtoken",
"secret",
"password",
}
_TYPE_ALIASES = {
"taskStarted": "STEP_STARTED",
"taskFinished": "STEP_FINISHED",
"llmChunk": "TEXT_MESSAGE_CONTENT",
"llmStarted": "TEXT_MESSAGE_START",
"llmFinished": "TEXT_MESSAGE_END",
"toolCalled": "TOOL_CALL_START",
"toolCompleted": "TOOL_CALL_RESULT",
"error": "RUN_ERROR",
}
def _is_sensitive_key(key: str) -> bool:
normalized = _NON_ALNUM_RE.sub("", key.lower())
if normalized in _SENSITIVE_KEYS:
return True
if "token" in normalized:
return True
if "api" in normalized and "key" in normalized:
return True
return False
def _to_upper_snake(value: str) -> str:
    """Convert a camelCase or punctuated name to ``UPPER_SNAKE_CASE``.

    An underscore is inserted at each lower-to-upper boundary, every run of
    non-alphanumeric characters becomes one underscore, and leading/trailing
    underscores are removed before upper-casing.
    """
    snake = _NON_ALNUM_RE.sub("_", _CAMEL_CASE_BOUNDARY_RE.sub(r"\1_\2", value))
    return snake.strip("_").upper()
def _to_event_type(value: str) -> EventType:
    """Look up the ``EventType`` member whose value is ``value``.

    Raises:
        ValueError: If ``value`` is not a known event type; the original lookup
            error is chained as the cause.
    """
    try:
        event_type = EventType(value)
    except ValueError as exc:
        raise ValueError(f"unsupported AG-UI event type: {value}") from exc
    return event_type
def _redact_sensitive(value: Any) -> Any:
    """Return a copy of ``value`` with sensitive dict entries masked.

    Dicts and lists are walked recursively. A dict entry whose key is judged
    sensitive has its whole value replaced by ``"***REDACTED***"``. Any other
    value is returned unchanged.
    """
    if isinstance(value, list):
        return [_redact_sensitive(element) for element in value]
    if not isinstance(value, dict):
        return value

    masked: dict[Any, Any] = {}
    for key, child in value.items():
        if _is_sensitive_key(str(key)):
            masked[key] = "***REDACTED***"
        else:
            masked[key] = _redact_sensitive(child)
    return masked
def to_agui_events(internal_events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Translate internal runtime events into AG-UI event dicts.

    Each event's ``type`` is mapped through the alias table (or converted to
    upper snake case) and validated as an ``EventType``. Its ``data`` is
    redacted. All other keys are copied through unchanged.

    Raises:
        ValueError: If ``type`` is not a non-empty string, the resulting type
            is unsupported, or ``data`` is not a dict.
    """
    converted: list[dict[str, Any]] = []
    for source_event in internal_events:
        declared_type = source_event.get("type")
        if not (isinstance(declared_type, str) and declared_type.strip()):
            raise ValueError("event.type must be a non-empty string")
        type_name = declared_type.strip()

        # Pass-through keys come first so "type" and "data" are appended last.
        outgoing: dict[str, Any] = {}
        for key, value in source_event.items():
            if key not in {"type", "data"}:
                outgoing[key] = value

        canonical_name = _TYPE_ALIASES.get(type_name, _to_upper_snake(type_name))
        outgoing["type"] = _to_event_type(canonical_name).value

        body = source_event.get("data")
        if not isinstance(body, dict):
            raise ValueError("event.data must be an object")
        outgoing["data"] = _redact_sensitive(body)

        converted.append(outgoing)
    return converted
@@ -0,0 +1,10 @@
from __future__ import annotations
import json
from typing import Any
def to_sse_event(stream_id: str, event: dict[str, Any]) -> str:
    """Format ``event`` as one Server-Sent Events frame.

    The frame has ``id``, ``event`` and ``data`` fields and ends with a blank
    line. ``type`` defaults to ``MESSAGE`` and ``data`` to an empty object;
    the data is serialised as ASCII-only JSON.
    """
    event_name = str(event.get("type", "MESSAGE"))
    body = json.dumps(event.get("data", {}), ensure_ascii=True)
    frame_lines = (
        f"id: {stream_id}",
        f"event: {event_name}",
        f"data: {body}",
        "",
        "",
    )
    return "\n".join(frame_lines)
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,104 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Protocol, cast
from core.config.settings import config
@dataclass(frozen=True)
class ResolvedAgentConfig:
    """Immutable result of resolving model, provider and API key settings."""

    # Model identifier after defaulting and whitespace stripping.
    model_code: str
    # API key for the provider; excluded from repr() so it is not printed.
    provider_api_key: str = field(repr=False)
    # Canonical provider name.
    provider_name: str
    # Streaming flag copied from the runtime settings.
    stream: bool
class AgentRuntimeSettingsLike(Protocol):
    """Structural type for the runtime settings read by the resolver."""

    # Model used when the caller supplies no model code.
    default_model_code: str
    # Copied into ResolvedAgentConfig.stream.
    streaming_enabled: bool
class LlmSettingsLike(Protocol):
    """Structural type for the LLM settings read by the resolver."""

    # Maps a provider name (or alias) to its API key.
    provider_keys: dict[str, str]
class SettingsLike(Protocol):
    """Structural type for the settings object accepted by the resolver."""

    # Runtime settings section.
    agent_runtime: AgentRuntimeSettingsLike
    # LLM settings section.
    llm: LlmSettingsLike
_PROVIDER_ALIASES = {
"ark": "volcengine",
"volcengine-ark": "volcengine",
"z-ai": "zai",
}
_SUPPORTED_PROVIDERS = {
"dashscope",
"minimax",
"moonshot",
"deepseek",
"volcengine",
"zai",
}
def _normalize_provider(provider: str) -> str:
normalized = provider.strip().lower()
canonical = _PROVIDER_ALIASES.get(normalized, normalized)
if canonical not in _SUPPORTED_PROVIDERS:
raise ValueError(f"unsupported provider '{provider}'")
return canonical
def _infer_provider_from_model(model_code: str) -> str:
lowered = model_code.strip().lower()
if lowered.startswith("qwen"):
return "dashscope"
if lowered.startswith("deepseek"):
return "deepseek"
if lowered.startswith("kimi") or lowered.startswith("moonshot"):
return "moonshot"
if lowered.startswith("abab") or lowered.startswith("minimax"):
return "minimax"
if lowered.startswith("doubao") or lowered.startswith("ark"):
return "volcengine"
if lowered.startswith("glm") or lowered.startswith("zai"):
return "zai"
raise ValueError("provider_name is required for unknown model_code")
class AgentConfigResolver:
    """Resolve the model, provider and API key to use for an agent run."""

    def __init__(self, settings: SettingsLike | None = None) -> None:
        """Use ``settings`` when given, otherwise the module-level ``config``."""
        self._settings: SettingsLike = cast(SettingsLike, settings or config)

    def resolve(
        self,
        *,
        model_code: str | None,
        provider_name: str | None,
    ) -> ResolvedAgentConfig:
        """Build a ``ResolvedAgentConfig`` from explicit values and settings.

        Args:
            model_code: Requested model; the configured default is used when
                falsy.
            provider_name: Requested provider; inferred from the model prefix
                when falsy.

        Raises:
            ValueError: If no model can be determined, the provider is
                unsupported or cannot be inferred, a configured provider key
                has an unsupported name, or no API key exists for the provider.
        """
        runtime_settings = self._settings.agent_runtime

        chosen_model = model_code or runtime_settings.default_model_code
        resolved_model = chosen_model.strip()
        if not resolved_model:
            raise ValueError("llm_model_code is required")

        requested_provider = provider_name or _infer_provider_from_model(resolved_model)
        provider = _normalize_provider(requested_provider)

        # Index non-blank keys by canonical provider name so aliases match.
        keys_by_provider: dict[str, str] = {}
        for configured_name, api_key in self._settings.llm.provider_keys.items():
            if api_key.strip():
                keys_by_provider[_normalize_provider(configured_name)] = api_key

        resolved_key = keys_by_provider.get(provider, "").strip()
        if not resolved_key:
            raise ValueError(f"provider api key is required for provider '{provider}'")

        return ResolvedAgentConfig(
            model_code=resolved_model,
            provider_api_key=resolved_key,
            provider_name=provider,
            stream=runtime_settings.streaming_enabled,
        )
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,15 @@
from __future__ import annotations
from core.agent.infrastructure.config.resolver import AgentConfigResolver
from core.agent.infrastructure.crewai.runtime import CrewAIRuntime
def create_runtime(
    *, model_code: str | None, provider_name: str | None
) -> CrewAIRuntime:
    """Create a ``CrewAIRuntime`` backed by a default ``AgentConfigResolver``."""
    return CrewAIRuntime(
        resolver=AgentConfigResolver(),
        model_code=model_code,
        provider_name=provider_name,
    )
@@ -0,0 +1,101 @@
from __future__ import annotations
from typing import Any
from core.agent.infrastructure.agui.bridge import to_agui_events
from core.agent.infrastructure.config.resolver import (
AgentConfigResolver,
ResolvedAgentConfig,
)
from core.agent.infrastructure.litellm.client import run_completion
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
def _to_litellm_model(*, provider_name: str, model_code: str) -> str:
normalized_model = model_code.strip()
if "/" in normalized_model:
return normalized_model
return f"{provider_name.strip().lower()}/{normalized_model}"
def _extract_assistant_text(response: dict[str, Any]) -> str:
choices = response.get("choices")
if not isinstance(choices, list) or not choices:
return ""
first = choices[0]
if not isinstance(first, dict):
return ""
message = first.get("message")
if not isinstance(message, dict):
return ""
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict) and isinstance(item.get("text"), str):
text_parts.append(item["text"])
return "".join(text_parts)
return ""
class CrewAIRuntime:
    """Run a single completion and report it as text, usage and AG-UI events."""

    def __init__(
        self,
        *,
        resolver: AgentConfigResolver,
        model_code: str | None,
        provider_name: str | None,
    ) -> None:
        """Resolve and store the agent configuration for this runtime."""
        self._config: ResolvedAgentConfig = resolver.resolve(
            model_code=model_code,
            provider_name=provider_name,
        )

    def map_events(self, internal_events: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal events to AG-UI events."""
        return to_agui_events(internal_events)

    def execute(self, *, user_input: str) -> dict[str, object]:
        """Send ``user_input`` as a single user message and summarise the reply.

        Returns:
            A dict with ``assistant_text``, ``prompt_tokens``,
            ``completion_tokens``, ``total_tokens``, ``cost`` and
            ``agui_events``.

        Raises:
            ValueError: If the completion response is not a dict.
        """
        resolved = self._config
        response = run_completion(
            model=_to_litellm_model(
                provider_name=resolved.provider_name,
                model_code=resolved.model_code,
            ),
            api_key=resolved.provider_api_key,
            messages=[{"role": "user", "content": user_input}],
        )
        if not isinstance(response, dict):
            raise ValueError("llm response must be a dict")

        usage_cost = extract_usage_and_cost(response)
        assistant_text = _extract_assistant_text(response)

        # Shared by the llmFinished event and the returned summary.
        usage_fields: dict[str, Any] = {
            "prompt_tokens": usage_cost.prompt_tokens,
            "completion_tokens": usage_cost.completion_tokens,
            "total_tokens": usage_cost.total_tokens,
            "cost": usage_cost.cost,
        }
        internal_events: list[dict[str, Any]] = [
            {"type": "llmStarted", "data": {"model": resolved.model_code}},
            {"type": "llmChunk", "data": {"text": assistant_text}},
            {
                "type": "llmFinished",
                "data": {**usage_fields, "provider": resolved.provider_name},
            },
        ]
        return {
            "assistant_text": assistant_text,
            **usage_fields,
            "agui_events": self.map_events(internal_events),
        }
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,66 @@
from __future__ import annotations
import json
import inspect
from typing import Any, Protocol
from uuid import UUID
class RedisStreamClient(Protocol):
    """Structural type for the Redis client methods used by the event store."""

    # Called by the store with a stream name and a field mapping.
    def xadd(self, *args: Any, **kwargs: Any) -> Any: ...

    # Called by the store with a stream-to-id mapping plus count and block.
    def xread(self, *args: Any, **kwargs: Any) -> Any: ...
class RedisStreamEventStore:
    """Append and read per-session events on Redis streams.

    Each session has its own stream named ``<stream_prefix>:<session_id>``.
    Events are stored as compact JSON in a single ``event`` field. Both
    decoding clients (str replies) and raw clients (bytes replies) work.
    """

    def __init__(
        self,
        *,
        client: RedisStreamClient,
        stream_prefix: str,
        read_count: int = 100,
        block_ms: int = 5000,
    ) -> None:
        """Store the client and read parameters.

        Args:
            client: Object exposing ``xadd`` and ``xread``.
            stream_prefix: Prefix of every stream name.
            read_count: ``count`` argument passed to ``xread``.
            block_ms: ``block`` argument passed to ``xread``.
        """
        self._client = client
        self._stream_prefix = stream_prefix
        self._read_count = read_count
        self._block_ms = block_ms

    def append_event_sync(self, *, session_id: UUID, event: dict[str, Any]) -> str:
        """Append ``event`` to the session stream and return the entry id as text."""
        stream = self._stream_name(session_id)
        payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
        # A raw client returns the id as bytes; str() on bytes would give "b'...'".
        return self._to_text(self._client.xadd(stream, {"event": payload}))

    async def read_events(
        self,
        *,
        session_id: UUID,
        last_event_id: str | None,
    ) -> list[dict[str, Any]]:
        """Read entries after ``last_event_id`` from the session stream.

        When ``last_event_id`` is ``None`` the read starts at ``$``. The
        client's ``xread`` result is awaited if it is awaitable.

        Returns:
            A list of ``{"id": <entry id>, "event": <decoded JSON>}`` dicts,
            empty when the read returns nothing.

        Raises:
            KeyError: If an entry has no ``event`` field.
        """
        stream = self._stream_name(session_id)
        start_id = "$" if last_event_id is None else last_event_id
        raw_response = self._client.xread(
            {stream: start_id},
            count=self._read_count,
            block=self._block_ms,
        )
        response = (
            await raw_response if inspect.isawaitable(raw_response) else raw_response
        )
        if not response:
            return []

        # Only one stream is requested, so only the first reply element matters.
        _, entries = response[0]
        result: list[dict[str, Any]] = []
        for stream_id, fields in entries:
            # Field names are bytes when the client does not decode responses.
            raw_event = fields.get("event")
            if raw_event is None:
                raw_event = fields.get(b"event")
            if raw_event is None:
                raise KeyError("event")
            result.append(
                {
                    "id": self._to_text(stream_id),
                    "event": json.loads(raw_event),
                }
            )
        return result

    def _stream_name(self, session_id: UUID) -> str:
        """Return the stream name for ``session_id``."""
        return f"{self._stream_prefix}:{session_id}"

    @staticmethod
    def _to_text(value: Any) -> str:
        """Return ``value`` as text, decoding bytes as UTF-8."""
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).decode("utf-8")
        return str(value)
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Any
from litellm import completion
def run_completion(*, model: str, api_key: str, messages: list[dict[str, Any]]) -> Any:
    """Run a non-streaming completion and return it in dumped form if possible.

    When the response object exposes a callable ``model_dump``, its result is
    returned; otherwise the response itself is returned.
    """
    response = completion(model=model, api_key=api_key, messages=messages, stream=False)
    dump = getattr(response, "model_dump", None)
    return dump() if callable(dump) else response
@@ -0,0 +1,72 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class TieredModelPricing:
    """Per-token prices that apply up to a given prompt size."""

    # Largest prompt token count (inclusive) covered by this tier.
    max_prompt_tokens: int
    # Price of one prompt token.
    input_cost_per_token: float
    # Price of one completion token.
    output_cost_per_token: float
    # Cache prices; declared here but not read by calculate_tiered_model_cost.
    cache_create_cost_per_token: float
    cache_hit_cost_per_token: float
# Pricing tiers in ascending order of max_prompt_tokens; lookup returns the
# first tier whose limit is not exceeded. Each price is written as a
# per-1000-token figure divided by 1000 (currency not stated here — TODO confirm).
QWEN35_FLASH_TIERED_PRICING: tuple[TieredModelPricing, ...] = (
    TieredModelPricing(
        max_prompt_tokens=128_000,
        input_cost_per_token=0.0002 / 1000,
        output_cost_per_token=0.002 / 1000,
        cache_create_cost_per_token=0.00025 / 1000,
        cache_hit_cost_per_token=0.00002 / 1000,
    ),
    TieredModelPricing(
        max_prompt_tokens=256_000,
        input_cost_per_token=0.0008 / 1000,
        output_cost_per_token=0.008 / 1000,
        cache_create_cost_per_token=0.001 / 1000,
        cache_hit_cost_per_token=0.00008 / 1000,
    ),
    TieredModelPricing(
        max_prompt_tokens=1_000_000,
        input_cost_per_token=0.0012 / 1000,
        output_cost_per_token=0.012 / 1000,
        cache_create_cost_per_token=0.0015 / 1000,
        cache_hit_cost_per_token=0.00012 / 1000,
    ),
)

# Keys are lower-case model names; get_tiered_pricing lower-cases its input
# before looking it up here.
_MODEL_TIERED_PRICING: dict[str, tuple[TieredModelPricing, ...]] = {
    "dashscope/qwen3.5-flash": QWEN35_FLASH_TIERED_PRICING,
}
def get_tiered_pricing(
    *, model_name: str, prompt_tokens: int
) -> TieredModelPricing | None:
    """Return the pricing tier for ``model_name`` at ``prompt_tokens``.

    The model name is stripped and lower-cased before lookup. The first tier
    whose ``max_prompt_tokens`` is at least ``prompt_tokens`` is returned, and
    prompts larger than every tier fall back to the last tier. Models without
    a pricing table yield ``None``.
    """
    tiers = _MODEL_TIERED_PRICING.get(model_name.strip().lower())
    if tiers is None:
        return None
    return next(
        (tier for tier in tiers if prompt_tokens <= tier.max_prompt_tokens),
        tiers[-1],
    )
def calculate_tiered_model_cost(
    *,
    model_name: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> float | None:
    """Return the cost of a completion under the model's tiered pricing.

    The tier is chosen from ``prompt_tokens``, then prompt and completion
    tokens are each multiplied by that tier's price and summed. Returns
    ``None`` when the model has no pricing table.
    """
    tier = get_tiered_pricing(model_name=model_name, prompt_tokens=prompt_tokens)
    if tier is None:
        return None
    input_cost = prompt_tokens * tier.input_cost_per_token
    output_cost = completion_tokens * tier.output_cost_per_token
    return input_cost + output_cost
@@ -0,0 +1,55 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from litellm import completion_cost
from core.agent.infrastructure.litellm.pricing import calculate_tiered_model_cost
@dataclass(frozen=True)
class UsageCost:
    """Token usage and cost extracted from one completion response."""

    # Token counts taken from the response's "usage" section.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # Cost of the completion.
    cost: float
    # Origin of the cost: "litellm" by default, or "custom_pricing" when the
    # local tiered pricing table was used.
    cost_source: str = "litellm"
def extract_usage_and_cost(response: dict[str, Any]) -> UsageCost:
    """Extract token usage from ``response`` and compute its cost.

    LiteLLM's ``completion_cost`` is tried first. If it raises or returns
    ``None``, the local tiered pricing table is used and the result is tagged
    with ``cost_source="custom_pricing"``.

    Args:
        response: Completion response as a dict with a ``usage`` mapping.

    Returns:
        A ``UsageCost`` with the token counts, cost and cost source.

    Raises:
        ValueError: If ``usage`` is missing or not a dict, or if neither
            LiteLLM nor the local pricing table can price the response.
    """
    usage = response.get("usage")
    if not isinstance(usage, dict):
        raise ValueError("missing usage in response")

    # A counter can be present but null; treat it like a missing counter
    # instead of letting int(None) raise TypeError.
    prompt_tokens = int(usage.get("prompt_tokens") or 0)
    completion_tokens = int(usage.get("completion_tokens") or 0)
    reported_total = usage.get("total_tokens")
    total_tokens = (
        int(reported_total)
        if reported_total is not None
        else prompt_tokens + completion_tokens
    )
    # `or ""` stops a null model from becoming the literal string "none".
    model_name = str(response.get("model") or "").strip().lower()

    cost: float | None = None
    failure: Exception | None = None
    try:
        litellm_cost = completion_cost(completion_response=response)
        cost = None if litellm_cost is None else float(litellm_cost)
    except Exception as exc:  # noqa: BLE001 - any pricing failure triggers the fallback
        failure = exc

    if cost is not None:
        return UsageCost(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cost=cost,
        )

    local_cost = calculate_tiered_model_cost(
        model_name=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    if local_cost is None:
        raise ValueError("unable to calculate litellm completion cost") from failure
    return UsageCost(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        cost=float(local_cost),
        cost_source="custom_pricing",
    )
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,41 @@
from __future__ import annotations
from decimal import Decimal
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
class MessageRepository:
    """Persist chat messages through a SQLAlchemy async session."""

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to ``session``."""
        self._session = session

    async def append_message(
        self,
        *,
        session_id: UUID,
        seq: int,
        role: AgentChatMessageRole,
        content: str,
        model_code: str | None = None,
        metadata: dict[str, object] | None = None,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: Decimal = Decimal("0"),
    ) -> AgentChatMessage:
        """Add a message row, flush it, and return the new instance.

        ``metadata`` is stored in the model's ``metadata_json`` attribute. The
        session is flushed but not committed.
        """
        column_values = {
            "session_id": session_id,
            "seq": seq,
            "role": role,
            "content": content,
            "model_code": model_code,
            "metadata_json": metadata,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost": cost,
        }
        new_message = AgentChatMessage(**column_values)
        self._session.add(new_message)
        await self._session.flush()
        return new_message
@@ -0,0 +1,77 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
class SessionRepository:
    """Read and update chat session rows through a SQLAlchemy async session."""

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to ``session``."""
        self._session = session

    async def get_session(self, *, session_id: UUID) -> AgentChatSession | None:
        """Return the session with this primary key, or ``None``."""
        found = await self._session.get(AgentChatSession, session_id)
        return found

    async def lock_session_for_update(
        self, *, session_id: UUID
    ) -> AgentChatSession | None:
        """Select the session with ``FOR UPDATE``; ``None`` when it is absent."""
        locking_query = select(AgentChatSession).where(
            AgentChatSession.id == session_id
        )
        outcome = await self._session.execute(locking_query.with_for_update())
        return outcome.scalar_one_or_none()

    async def next_message_seq(self, *, session_id: UUID) -> int:
        """Return one more than the session's highest message ``seq`` (1 if none)."""
        highest_seq = func.coalesce(func.max(AgentChatMessage.seq), 0)
        outcome = await self._session.execute(
            select(highest_seq).where(AgentChatMessage.session_id == session_id)
        )
        return int(outcome.scalar_one()) + 1

    async def update_runtime_state(
        self,
        *,
        chat_session: AgentChatSession,
        status: AgentChatSessionStatus,
        state_snapshot: dict[str, object],
        message_delta: int,
        token_delta: int = 0,
        cost_delta: Decimal = Decimal("0"),
    ) -> None:
        """Set status and snapshot, stamp activity, and bump the running totals.

        The deltas are added to the existing counters. The session is flushed
        but not committed.
        """
        chat_session.status = status
        chat_session.state_snapshot = state_snapshot
        chat_session.last_activity_at = datetime.now(timezone.utc)

        chat_session.message_count += message_delta
        chat_session.total_tokens += token_delta
        chat_session.total_cost += cost_delta

        await self._session.flush()

    async def soft_delete_session_with_messages(self, *, session_id: UUID) -> int:
        """Mark the session and its messages deleted with one shared timestamp.

        Returns:
            ``1`` when the session existed and was not already deleted,
            otherwise ``0``.
        """
        existing = await self.get_session(session_id=session_id)
        is_live = existing is not None and existing.deleted_at is None
        if not is_live:
            return 0

        deleted_at = datetime.now(timezone.utc)
        # Session first, then its messages; rows already deleted are skipped.
        statements = (
            update(AgentChatSession)
            .where(AgentChatSession.id == session_id)
            .where(AgentChatSession.deleted_at.is_(None))
            .values(deleted_at=deleted_at),
            update(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_id)
            .where(AgentChatMessage.deleted_at.is_(None))
            .values(deleted_at=deleted_at),
        )
        for statement in statements:
            await self._session.execute(statement)
        await self._session.flush()
        return 1
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,159 @@
from __future__ import annotations
import asyncio
import threading
from typing import Any, Callable, Protocol, cast
from uuid import UUID
import redis
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.celery.app import celery_app
from core.config.settings import config
from core.logging import get_logger
logger = get_logger("core.agent.infrastructure.queue.tasks")
# Event loop shared by all tasks in this process; created lazily by
# _ensure_background_loop.
_background_loop: asyncio.AbstractEventLoop | None = None
# Daemon thread that runs the shared loop.
_background_thread: threading.Thread | None = None
# Set by the loop thread once _background_loop has been assigned.
_background_ready = threading.Event()
class PublishEvent(Protocol):
    """Callable that publishes one event given its type and payload."""

    def __call__(self, event_type: str, payload: dict[str, object]) -> None: ...
class RunServiceLike(Protocol):
    """Structural type for the service that handles the ``run`` command."""

    async def run(self, *, session_id: str, user_input: str) -> dict[str, object]: ...
class ResumeServiceLike(Protocol):
    """Structural type for the service that handles the ``resume`` command."""

    async def resume(
        self, *, session_id: str, tool_call_id: str
    ) -> dict[str, object]: ...
def _run_async(task: Callable[[], Any]) -> Any:
    """Run the coroutine produced by ``task`` on the background loop and wait.

    ``task`` is called only after the background loop is available. The calling
    thread blocks until the coroutine finishes, and its result or exception is
    propagated.
    """
    background_loop = _ensure_background_loop()
    pending = asyncio.run_coroutine_threadsafe(task(), background_loop)
    return pending.result()
def _ensure_background_loop() -> asyncio.AbstractEventLoop:
    """Return the shared background event loop, starting it on first use.

    On the first call a daemon thread is started that creates a new event
    loop, publishes it in ``_background_loop``, signals ``_background_ready``
    and runs the loop forever. The caller waits up to five seconds for that
    signal.

    Raises:
        RuntimeError: If the loop has not been assigned after the wait.
    """
    global _background_loop, _background_thread
    # Fast path: the loop already exists.
    if _background_loop is not None:
        return _background_loop

    def _loop_worker() -> None:
        global _background_loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # Publish the loop before signalling so the waiter sees it.
        _background_loop = loop
        _background_ready.set()
        loop.run_forever()

    # NOTE(review): no lock guards the check above, so two threads making the
    # first call at the same time could each start a loop thread — confirm
    # whether callers can run concurrently.
    _background_thread = threading.Thread(target=_loop_worker, daemon=True)
    _background_thread.start()
    _background_ready.wait(timeout=5)
    if _background_loop is None:
        raise RuntimeError("failed to initialize background event loop")
    return _background_loop
def _build_redis_publisher() -> PublishEvent:
    """Create a publisher that appends events to per-session Redis streams.

    The returned callable routes each event by the ``session_id`` found in its
    payload and raises ``ValueError`` when that id is missing or blank.
    """
    settings = cast(Any, config)
    redis_client = redis.from_url(settings.redis.url, decode_responses=True)
    runtime_settings = settings.agent_runtime
    event_store = RedisStreamEventStore(
        client=redis_client,
        stream_prefix=runtime_settings.redis_stream_prefix,
        read_count=runtime_settings.redis_stream_read_count,
        block_ms=runtime_settings.redis_stream_block_ms,
    )

    def _publish(event_type: str, payload: dict[str, object]) -> None:
        target_session = str(payload.get("session_id", "")).strip()
        if target_session:
            event_store.append_event_sync(
                session_id=UUID(target_session),
                event={"type": event_type, "data": payload},
            )
            return
        raise ValueError("session_id is required in event payload")

    return _publish
def run_agent_task(
    command: dict[str, Any],
    *,
    publish_event: PublishEvent | None = None,
    run_service: RunServiceLike | None = None,
    resume_service: ResumeServiceLike | None = None,
) -> dict[str, object]:
    """Execute a ``run`` or ``resume`` command and publish its lifecycle events.

    Published events, in order: ``RUN_STARTED`` or ``RUN_RESUMED``, then
    ``RUNTIME_EVENT`` with the service result, then each well-formed event from
    the result's ``events`` list, then ``RUN_FINISHED``. On failure
    ``RUN_ERROR`` is published and the original exception is re-raised.

    Args:
        command: Dict with ``command`` (``"run"`` by default, or ``"resume"``),
            ``session_id``, and ``user_input`` or ``tool_call_id``.
        publish_event: Publisher to use; a Redis publisher is built when omitted.
        run_service: Service for ``run``; defaults to ``RunService()``.
        resume_service: Service for ``resume``; defaults to ``ResumeService()``.

    Returns:
        The dict returned by the selected service.

    Raises:
        ValueError: For an unknown command type, a missing or non-UUID
            ``session_id``, or a missing ``user_input`` / ``tool_call_id``.
    """
    publisher = publish_event or _build_redis_publisher()
    service_run = run_service or RunService()
    service_resume = resume_service or ResumeService()

    command_type = str(command.get("command", "run"))
    session_id = str(command.get("session_id", ""))
    if command_type not in {"run", "resume"}:
        raise ValueError("invalid command type")
    if not session_id:
        raise ValueError("session_id is required")
    # Validate the id format before anything is published.
    UUID(session_id)

    start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
    publisher(start_event, {"session_id": session_id})
    try:
        if command_type == "resume":
            tool_call_id = str(command.get("tool_call_id", ""))
            if not tool_call_id:
                raise ValueError("tool_call_id is required")
            result = _run_async(
                lambda: service_resume.resume(
                    session_id=session_id,
                    tool_call_id=tool_call_id,
                )
            )
        else:
            user_input = str(command.get("user_input", ""))
            if not user_input:
                raise ValueError("user_input is required")
            result = _run_async(
                lambda: service_run.run(
                    session_id=session_id,
                    user_input=user_input,
                )
            )

        publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
        extra_events = result.get("events") if isinstance(result, dict) else None
        if isinstance(extra_events, list):
            for event in extra_events:
                if not isinstance(event, dict):
                    continue
                event_type = event.get("type")
                event_data = event.get("data")
                if not isinstance(event_type, str) or not isinstance(event_data, dict):
                    continue
                # session_id goes last so event data can never override the
                # routing key and land the event on another session's stream.
                payload = {**event_data, "session_id": session_id}
                publisher(event_type, payload)

        publisher("RUN_FINISHED", {"session_id": session_id})
        return result
    except Exception:  # noqa: BLE001
        error_id = "agent_runtime_failed"
        logger.exception(
            "Agent task failed",
            session_id=session_id,
            error_id=error_id,
        )
        try:
            publisher("RUN_ERROR", {"session_id": session_id, "error_id": error_id})
        except Exception:  # noqa: BLE001
            # Publishing is best effort here: a broken publisher must not hide
            # the failure that brought us into this handler.
            logger.exception(
                "Failed to publish RUN_ERROR event",
                session_id=session_id,
                error_id=error_id,
            )
        raise
@celery_app.task(name="tasks.agent.run_command")
def run_command_task(command: dict[str, Any]) -> dict[str, object]:
    """Celery task ``tasks.agent.run_command``: delegate to ``run_agent_task``.

    Uses the default publisher and services, since none are passed.
    """
    return run_agent_task(command)
+1
View File
@@ -15,6 +15,7 @@ def create_celery_app() -> Celery:
"social_app",
broker=config.celery_broker_url,
backend=config.celery_result_backend,
include=["core.agent.infrastructure.queue.tasks"],
)
app.conf.update(
+14
View File
@@ -140,6 +140,18 @@ class StorageSettings(BaseModel):
retention_days: int = Field(default=30, ge=1, le=3650)
class AgentRuntimeSettings(BaseModel):
    """Settings for the agent runtime and its Redis event streams."""

    # Prefix of the per-session Redis stream names.
    redis_stream_prefix: str = "agent:events"
    # Stream read count; allowed range 1-1000.
    redis_stream_read_count: int = Field(default=100, ge=1, le=1000)
    # Stream read blocking time in milliseconds; allowed range 1-60000.
    redis_stream_block_ms: int = Field(default=5000, ge=1, le=60000)
    # Model used when none is requested; empty by default.
    default_model_code: str = ""
    # Streaming flag; enabled by default.
    streaming_enabled: bool = True
class LlmSettings(BaseModel):
    """Settings for LLM providers."""

    # Maps a provider name to its API key; empty by default.
    provider_keys: dict[str, str] = Field(default_factory=dict)
class DatabaseSettings(BaseModel):
host: str = "localhost"
port: int = 5432
@@ -172,6 +184,8 @@ class Settings(BaseSettings):
redis: RedisSettings = RedisSettings()
supabase: SupabaseSettings = SupabaseSettings()
storage: StorageSettings = StorageSettings()
llm: LlmSettings = LlmSettings()
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
celery: CelerySettings = CelerySettings()
database: DatabaseSettings = DatabaseSettings()
@@ -15,11 +15,11 @@ factories:
request_url: https://api.deepseek.com/v1
avatar: https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/light/deepseek-color.png
- name: volcengine-ark
- name: volcengine
request_url: https://ark.cn-beijing.volces.com/api/v3
avatar: https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/light/doubao-color.png
- name: z-ai
- name: zai
request_url: https://api.z.ai/api/paas/v4
avatar: https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/light/zai.png
+4 -1
View File
@@ -45,7 +45,10 @@ class AgentChatMessage(TimestampMixin, SoftDeleteMixin, Base):
seq: Mapped[int] = mapped_column(Integer, nullable=False)
role: Mapped[AgentChatMessageRole] = mapped_column(
SqlEnum(
AgentChatMessageRole, name="agent_chat_message_role", native_enum=False
AgentChatMessageRole,
name="agent_chat_message_role",
native_enum=False,
values_callable=lambda enum_cls: [item.value for item in enum_cls],
),
nullable=False,
)
+6 -2
View File
@@ -7,6 +7,7 @@ from enum import Enum
from sqlalchemy import (
DateTime,
JSON,
Enum as SqlEnum,
Integer,
Numeric,
@@ -56,7 +57,10 @@ class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
title: Mapped[str | None] = mapped_column(String(255), nullable=True)
status: Mapped[AgentChatSessionStatus] = mapped_column(
SqlEnum(
AgentChatSessionStatus, name="agent_chat_session_status", native_enum=False
AgentChatSessionStatus,
name="agent_chat_session_status",
native_enum=False,
values_callable=lambda enum_cls: [item.value for item in enum_cls],
),
nullable=False,
default=AgentChatSessionStatus.PENDING,
@@ -76,6 +80,6 @@ class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
Numeric(12, 6), nullable=False, server_default=text("0")
)
state_snapshot: Mapped[dict | None] = mapped_column(
JSONB,
JSON().with_variant(JSONB, "postgresql"),
nullable=True,
)
+2 -2
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import uuid
from sqlalchemy import ForeignKey, String
from sqlalchemy import JSON, ForeignKey, String
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
@@ -26,7 +26,7 @@ class SystemAgents(TimestampMixin, Base):
nullable=False,
)
config: Mapped[dict] = mapped_column(
JSONB,
JSON().with_variant(JSONB, "postgresql"),
nullable=False,
server_default="{}",
)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+75
View File
@@ -0,0 +1,75 @@
from __future__ import annotations
from typing import Any, cast
from uuid import UUID
from fastapi import Depends
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.queue.tasks import run_command_task
from core.config.settings import config
from core.db import get_db
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
class CeleryQueueClient:
    """Queue client that submits agent commands to Celery.

    When a ``dedup_key`` is supplied, a Redis marker (``agent:dedup:<key>``,
    300 s TTL) is used to hand back the task id of an already submitted
    command instead of submitting it again.
    """

    def __init__(self) -> None:
        app_settings = cast(Any, config)
        redis_url = app_settings.redis.url
        self._redis = redis.from_url(redis_url, decode_responses=True)

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Submit ``command`` to the Celery task and return the task id.

        With a truthy ``dedup_key`` the marker is claimed via ``SET NX``. If
        the claim fails and the marker already stores a task id, that id is
        returned and nothing is submitted. If the marker still holds the
        ``__inflight__`` placeholder (or has vanished), the command is
        submitted anyway and the marker is overwritten with the new task id.
        """
        marker_key = f"agent:dedup:{dedup_key}" if dedup_key else None
        if marker_key is not None:
            acquired = await self._redis.set(
                marker_key, "__inflight__", nx=True, ex=300
            )
            if not acquired:
                known_task_id = await self._redis.get(marker_key)
                if known_task_id and known_task_id != "__inflight__":
                    return known_task_id

        # Copy the command so the caller's dict is never mutated.
        if dedup_key:
            body = {**command, "dedup_key": dedup_key}
        else:
            body = dict(command)

        submit = getattr(run_command_task, "delay")
        task_id = str(submit(body).id)

        if marker_key is not None:
            # Replace the placeholder with the real id for later callers.
            await self._redis.set(marker_key, task_id, ex=300)
        return task_id
class RedisEventStream:
    """Event-stream adapter that reads session events through ``RedisStreamEventStore``."""

    def __init__(self) -> None:
        app_settings = cast(Any, config)
        connection = redis.from_url(app_settings.redis.url, decode_responses=True)
        runtime = app_settings.agent_runtime
        self._store = RedisStreamEventStore(
            client=connection,
            stream_prefix=runtime.redis_stream_prefix,
            read_count=runtime.redis_stream_read_count,
            block_ms=runtime.redis_stream_block_ms,
        )

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, Any]]:
        """Read events for ``session_id`` that follow ``last_event_id``.

        ``session_id`` is parsed as a UUID before it reaches the store. Each
        returned row is a copy of the store's row with a ``"cursor"`` key set
        to the ``last_event_id`` that was passed in.
        """
        entries = await self._store.read_events(
            session_id=UUID(session_id),
            last_event_id=last_event_id,
        )
        tagged: list[dict[str, Any]] = []
        for entry in entries:
            tagged.append({**entry, "cursor": last_event_id})
        return tagged
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
    """Build an ``AgentService`` bound to the injected database session.

    Every call constructs a new ``CeleryQueueClient`` and a new
    ``RedisEventStream``; each of those creates its own Redis client.
    """
    repository = AgentRepository(session)
    queue = CeleryQueueClient()
    stream = RedisEventStream()
    return AgentService(repository=repository, queue=queue, stream=stream)
+56
View File
@@ -0,0 +1,56 @@
from __future__ import annotations
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_session import AgentChatSession
class AgentRepository:
    """Database access for agent chat sessions on top of an ``AsyncSession``."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    @staticmethod
    def _to_uuid(raw: str, *, detail: str) -> UUID:
        """Parse ``raw`` as a UUID, raising HTTP 422 with ``detail`` on failure."""
        try:
            return UUID(raw)
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=detail) from exc

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owning user id of the session, as a string.

        Raises:
            HTTPException: 422 if ``session_id`` is not a valid UUID, 404 if
                no matching session row is found.
        """
        session_uuid = self._to_uuid(session_id, detail="Invalid session_id")
        query = select(AgentChatSession.user_id).where(
            AgentChatSession.id == session_uuid
        )
        result = await self._session.execute(query)
        owner_id = result.scalar_one_or_none()
        if owner_id is None:
            raise HTTPException(status_code=404, detail="Session not found")
        return str(owner_id)

    async def create_session_for_user(self, *, user_id: str) -> str:
        """Add a new chat session for ``user_id`` and return its id as a string.

        The row is flushed and refreshed but not committed.

        Raises:
            HTTPException: 422 if ``user_id`` is not a valid UUID.
        """
        user_uuid = self._to_uuid(user_id, detail="Invalid user_id")
        record = AgentChatSession(user_id=user_uuid)
        self._session.add(record)
        await self._session.flush()
        await self._session.refresh(record)
        return str(record.id)

    async def commit(self) -> None:
        """Commit the underlying database session."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying database session."""
        await self._session.rollback()

    async def delete_session(self, *, session_id: str) -> None:
        """Delete the chat session with ``session_id`` if it exists.

        Raises:
            HTTPException: 422 if ``session_id`` is not a valid UUID.
        """
        session_uuid = self._to_uuid(session_id, detail="Invalid session_id")
        record = await self._session.get(AgentChatSession, session_uuid)
        if record is not None:
            await self._session.delete(record)
            await self._session.flush()
+104
View File
@@ -0,0 +1,104 @@
from __future__ import annotations
from collections.abc import AsyncIterator
import asyncio
from typing import Annotated
from fastapi import APIRouter, Depends, Header, Query, Request, status
from fastapi.responses import StreamingResponse
from core.agent.infrastructure.agui.stream import to_sse_event
from core.auth.models import CurrentUser
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import ResumeRequest, RunRequest, TaskAcceptedResponse
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
@router.post(
    "/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
async def enqueue_run(
    request: RunRequest,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    # Queue an agent run for the authenticated user and answer 202 with the
    # accepted task's identifiers. The service decides whether a new session
    # has to be created (reported back through ``created``).
    accepted = await service.enqueue_run(
        session_id=request.session_id,
        prompt=request.prompt,
        current_user=current_user,
    )
    response = TaskAcceptedResponse(
        task_id=accepted.task_id,
        session_id=accepted.session_id,
        created=accepted.created,
    )
    return response
@router.post(
    "/runs/{session_id}/resume",
    response_model=TaskAcceptedResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_resume(
    session_id: str,
    request: ResumeRequest,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    # Queue a resume command for the given session and tool call, then
    # answer 202 with the accepted task's identifiers.
    accepted = await service.enqueue_resume(
        session_id=session_id,
        tool_call_id=request.tool_call_id,
        current_user=current_user,
    )
    response = TaskAcceptedResponse(
        task_id=accepted.task_id,
        session_id=accepted.session_id,
        created=accepted.created,
    )
    return response
@router.get("/runs/{session_id}/events")
async def stream_events(
    request: Request,
    session_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
    idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
    # Stream a session's events to the client as Server-Sent Events.
    #
    # Parameters:
    #   session_id:    session whose events are streamed.
    #   last_event_id: value of the ``Last-Event-ID`` header; reading resumes
    #                  after this id (None starts from the store's default).
    #   idle_limit:    number of consecutive empty reads after which the
    #                  stream ends (1..3600).
    #
    # The first read is awaited here, before the StreamingResponse is
    # created. ``service.stream_events`` performs the ownership check and can
    # raise HTTPException (403/404/422); raised inside the body generator it
    # would fire after the 200 response had already started and could not be
    # turned into a proper error response.
    first_rows = await service.stream_events(
        session_id=session_id,
        last_event_id=last_event_id,
        current_user=current_user,
    )

    async def _event_iter() -> AsyncIterator[str]:
        cursor = last_event_id
        idle_polls = 0
        rows = first_rows
        while True:
            if rows:
                idle_polls = 0
                for row in rows:
                    row_id = str(row.get("id", ""))
                    event = row.get("event")
                    # Skip malformed rows; they do not advance the cursor.
                    if not row_id or not isinstance(event, dict):
                        continue
                    cursor = row_id
                    yield to_sse_event(row_id, event)
            else:
                idle_polls += 1
                # SSE comment line keeps intermediaries from timing out.
                yield ": keep-alive\n\n"
                await asyncio.sleep(0.2)

            # Same stop conditions as before each subsequent read.
            if await request.is_disconnected() or idle_polls >= idle_limit:
                return

            rows = await service.stream_events(
                session_id=session_id,
                last_event_id=cursor,
                current_user=current_user,
            )

    return StreamingResponse(
        _event_iter(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
+18
View File
@@ -0,0 +1,18 @@
from __future__ import annotations
from pydantic import BaseModel, Field
class RunRequest(BaseModel):
    # Request body for starting an agent run.

    # Session to run in. Optional (defaults to None); when provided it must
    # be 1-100 characters long.
    session_id: str | None = Field(default=None, min_length=1, max_length=100)
    # Prompt text for the run, 1-5000 characters.
    prompt: str = Field(min_length=1, max_length=5000)
class ResumeRequest(BaseModel):
    # Request body for resuming a run.

    # Identifier of the tool call being resumed, 1-200 characters.
    tool_call_id: str = Field(min_length=1, max_length=200)
class TaskAcceptedResponse(BaseModel):
    # Response body returned with 202 once a command has been queued.

    # Identifier of the queued task.
    task_id: str
    # Session the task belongs to.
    session_id: str
    # Whether a new session was created while handling the request.
    created: bool
+132
View File
@@ -0,0 +1,132 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
from fastapi import HTTPException
from core.auth.models import CurrentUser
@dataclass(frozen=True)
class TaskAccepted:
    """Immutable result of queuing an agent command."""

    # Identifier returned by the queue client.
    task_id: str
    # Session the command targets.
    session_id: str
    # True when the session was created as part of this request.
    created: bool
class AgentRepositoryLike(Protocol):
    """Structural interface of the repository that ``AgentService`` depends on."""

    # Returns the owner id of the session as a string.
    async def get_session_owner(self, *, session_id: str) -> str: ...
    # Creates a session for the user and returns the new session id.
    async def create_session_for_user(self, *, user_id: str) -> str: ...
    # Commits pending work.
    async def commit(self) -> None: ...
    # Rolls back pending work.
    async def rollback(self) -> None: ...
class QueueClientLike(Protocol):
    """Structural interface of the queue client that ``AgentService`` depends on."""

    # Submits ``command`` and returns a task id. ``dedup_key`` may be None.
    # NOTE(review): presumably None means "no de-duplication" - confirm
    # against the concrete client.
    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str: ...
class EventStreamLike(Protocol):
    """Structural interface of the event stream that ``AgentService`` depends on."""

    # Returns event rows for ``session_id``; ``last_event_id`` may be None.
    # NOTE(review): presumably rows are those after ``last_event_id`` -
    # confirm against the concrete stream implementation.
    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, object]]: ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Verify that ``current_user`` owns the session.

    Raises:
        HTTPException: 403 when ``owner_id`` differs from the stringified
            ``current_user.id``.
    """
    requester_id = str(current_user.id)
    if owner_id == requester_id:
        return
    raise HTTPException(status_code=403, detail="Forbidden")
class AgentService:
    """Application service that queues agent commands and reads session events.

    Every operation on an existing session first checks that the caller owns
    that session.
    """

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
    ) -> None:
        self._repository = repository
        self._queue = queue
        self._stream = stream

    async def _require_owner(
        self, *, session_id: str, current_user: CurrentUser
    ) -> None:
        """Look up the session owner and reject callers who are not the owner."""
        owner = await self._repository.get_session_owner(session_id=session_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)

    async def enqueue_run(
        self,
        *,
        session_id: str | None,
        prompt: str,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a ``run`` command, creating a session when none is given.

        A newly created session is committed before the command is queued.
        Errors from the queue propagate to the caller unchanged.
        """
        if session_id is None:
            target_session_id = await self._repository.create_session_for_user(
                user_id=str(current_user.id)
            )
            await self._repository.commit()
            created = True
        else:
            await self._require_owner(
                session_id=session_id, current_user=current_user
            )
            target_session_id = session_id
            created = False

        command: dict[str, object] = {
            "command": "run",
            "session_id": target_session_id,
            "user_input": prompt,
        }
        task_id = await self._queue.enqueue(command=command, dedup_key=None)
        return TaskAccepted(
            task_id=task_id, session_id=target_session_id, created=created
        )

    async def enqueue_resume(
        self,
        *,
        session_id: str,
        tool_call_id: str,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a ``resume`` command for a session owned by the caller.

        The command is queued with the dedup key
        ``resume:<session_id>:<tool_call_id>``.
        """
        await self._require_owner(session_id=session_id, current_user=current_user)
        command: dict[str, object] = {
            "command": "resume",
            "session_id": session_id,
            "tool_call_id": tool_call_id,
        }
        task_id = await self._queue.enqueue(
            command=command,
            dedup_key=f"resume:{session_id}:{tool_call_id}",
        )
        return TaskAccepted(task_id=task_id, session_id=session_id, created=False)

    async def stream_events(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Return event rows for a session owned by the caller."""
        await self._require_owner(session_id=session_id, current_user=current_user)
        rows = await self._stream.read(
            session_id=session_id,
            last_event_id=last_event_id,
        )
        return rows
+2
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
from fastapi import APIRouter
from core.http.models import HealthResponse
from v1.agent.router import router as agent_router
from v1.auth.router import router as auth_router
from v1.friendships.router import router as friendships_router
from v1.inbox_messages.router import router as inbox_messages_router
@@ -13,6 +14,7 @@ from v1.users.router import router as users_router
router = APIRouter(prefix="/api/v1")
router.include_router(auth_router)
router.include_router(agent_router)
router.include_router(friendships_router)
router.include_router(infra_router)
router.include_router(users_router)