feat(agent): complete closed-loop runtime and pricing fallback

This commit is contained in:
qzl
2026-03-05 15:34:37 +08:00
parent b02a322bf3
commit b486e78ff3
67 changed files with 3832 additions and 7 deletions
@@ -0,0 +1,24 @@
"""agent runtime baseline marker
Revision ID: 202603040001
Revises: 435419f8121c
Create Date: 2026-03-04 23:59:00
"""
from __future__ import annotations
from typing import Sequence, Union
# Alembic revision identifiers: this marker revision chains onto 435419f8121c.
revision: str = "202603040001"
down_revision: Union[str, Sequence[str], None] = "435419f8121c"
# No branch labels or cross-branch dependencies are declared.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply nothing: this revision only marks a baseline and changes no schema."""
def downgrade() -> None:
    """Revert nothing: the matching upgrade made no schema change."""
@@ -0,0 +1,66 @@
"""agent runtime closed loop contract
Revision ID: 202603050001
Revises: 202603040001
Create Date: 2026-03-05 12:00:00
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
# Alembic revision identifiers: this revision follows the 202603040001 baseline marker.
revision: str = "202603050001"
down_revision: Union[str, Sequence[str], None] = "202603040001"
# No branch labels or cross-branch dependencies are declared.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Install the schema pieces required by the closed-loop agent runtime.

    Steps, in order:
      1. add the nullable ``sessions.state_snapshot`` JSONB column;
      2. create the ``ix_messages_session_seq`` index on ``(session_id, seq)``;
      3. add and validate a CHECK requiring ``tool`` messages to carry
         ``metadata.type = 'tool_result'``;
      4. add and validate a CHECK requiring ``assistant`` messages to carry
         ``metadata.type`` of ``'tool_call'`` or ``'assistant_output'``.
    """
    add_tool_result_check = """
        ALTER TABLE messages
        ADD CONSTRAINT chk_messages_tool_result_metadata
        CHECK (
        role != 'tool'
        OR (metadata IS NOT NULL AND metadata->>'type' = 'tool_result')
        )
        NOT VALID
        """
    add_assistant_check = """
        ALTER TABLE messages
        ADD CONSTRAINT chk_messages_assistant_metadata
        CHECK (
        role != 'assistant'
        OR (metadata IS NOT NULL AND metadata->>'type' IN ('tool_call', 'assistant_output'))
        )
        NOT VALID
        """

    op.execute("ALTER TABLE sessions ADD COLUMN IF NOT EXISTS state_snapshot JSONB")

    # The concurrent index build is issued through Alembic's autocommit block.
    # NOTE(review): assumes only the index creation belongs inside the
    # autocommit block and the constraint statements run after it — confirm.
    autocommit = op.get_context().autocommit_block()
    with autocommit:
        op.execute(
            "CREATE INDEX CONCURRENTLY IF NOT EXISTS ix_messages_session_seq ON messages (session_id, seq)"
        )

    # Each constraint is added as NOT VALID and then validated in a second statement.
    op.execute(add_tool_result_check)
    op.execute(
        "ALTER TABLE messages VALIDATE CONSTRAINT chk_messages_tool_result_metadata"
    )
    op.execute(add_assistant_check)
    op.execute(
        "ALTER TABLE messages VALIDATE CONSTRAINT chk_messages_assistant_metadata"
    )
def downgrade() -> None:
    """Remove everything ``upgrade`` added, in reverse order.

    Every statement uses ``IF EXISTS`` so a partially applied upgrade can
    still be rolled back.
    """
    rollback_statements = (
        "ALTER TABLE messages DROP CONSTRAINT IF EXISTS chk_messages_assistant_metadata",
        "ALTER TABLE messages DROP CONSTRAINT IF EXISTS chk_messages_tool_result_metadata",
        "DROP INDEX IF EXISTS ix_messages_session_seq",
        "ALTER TABLE sessions DROP COLUMN IF EXISTS state_snapshot",
    )
    for statement in rollback_statements:
        op.execute(statement)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,68 @@
from __future__ import annotations
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from core.agent.application.session_state_persistence import SessionStatePersistence
from core.agent.infrastructure.persistence.message_repository import MessageRepository
from core.agent.infrastructure.persistence.session_repository import SessionRepository
from core.db import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
class ResumeService:
    """Resume a session that is waiting on a pending tool call."""

    def __init__(
        self,
        *,
        session_factory: async_sessionmaker[AsyncSession] = AsyncSessionLocal,
    ) -> None:
        """Store the session factory and create the snapshot builder."""
        self._session_factory = session_factory
        self._state_persistence = SessionStatePersistence()

    async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
        """Record the tool result and mark the session completed.

        Args:
            session_id: UUID string of the chat session.
            tool_call_id: identifier that must equal the session snapshot's
                ``pending_tool_call_id``.

        Returns:
            A dict with ``session_id``, ``resumed`` (always ``True``) and the
            new ``state_snapshot``.

        Raises:
            ValueError: if ``session_id`` is not a valid UUID, the session
                does not exist, or the pending tool call does not match.
        """
        session_uuid = UUID(session_id)
        async with self._session_factory() as db:
            sessions = SessionRepository(db)
            messages = MessageRepository(db)

            # Lock the session row before reading its snapshot.
            locked_session = await sessions.lock_session_for_update(
                session_id=session_uuid
            )
            if locked_session is None:
                raise ValueError("session not found")

            expected_call_id = (locked_session.state_snapshot or {}).get(
                "pending_tool_call_id"
            )
            if expected_call_id != tool_call_id:
                raise ValueError("pending tool call does not match")

            first_seq = await sessions.next_message_seq(session_id=session_uuid)

            # Tool result message, followed by the assistant acknowledgement.
            await messages.append_message(
                session_id=session_uuid,
                seq=first_seq,
                role=AgentChatMessageRole.TOOL,
                content='{"status":"ok"}',
                metadata={"type": "tool_result", "tool_call_id": tool_call_id},
            )
            await messages.append_message(
                session_id=session_uuid,
                seq=first_seq + 1,
                role=AgentChatMessageRole.ASSISTANT,
                content="Tool result received",
                metadata={"type": "assistant_output"},
            )

            completed_snapshot = self._state_persistence.build_completed_snapshot()
            await sessions.update_runtime_state(
                chat_session=locked_session,
                status=AgentChatSessionStatus.COMPLETED,
                state_snapshot=completed_snapshot,
                message_delta=2,
            )
            await db.commit()

        return {
            "session_id": session_id,
            "resumed": True,
            "state_snapshot": completed_snapshot,
        }
@@ -0,0 +1,134 @@
from __future__ import annotations
from decimal import Decimal
from uuid import UUID, uuid4
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from core.agent.application.session_state_persistence import SessionStatePersistence
from core.agent.infrastructure.crewai.factory import create_runtime
from core.agent.infrastructure.persistence.message_repository import MessageRepository
from core.agent.infrastructure.persistence.session_repository import SessionRepository
from core.db import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
def _to_int(value: object, default: int = 0) -> int:
if isinstance(value, int):
return value
if isinstance(value, str):
try:
return int(value)
except ValueError:
return default
return default
def _to_decimal(value: object) -> Decimal:
if isinstance(value, (int, float, str, Decimal)):
return Decimal(str(value))
return Decimal("0")
class RunService:
    """Execute one user turn of an agent session and persist the outcome."""

    def __init__(
        self,
        *,
        session_factory: async_sessionmaker[AsyncSession] = AsyncSessionLocal,
    ) -> None:
        """Store the session factory and create the snapshot builder."""
        self._session_factory = session_factory
        self._state_persistence = SessionStatePersistence()

    async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
        """Run the model for ``user_input`` and persist both messages.

        The session row is locked, the active system agent's model is looked
        up, the runtime is executed, and then the user message, the assistant
        message (tagged as a ``tool_call`` with a freshly generated
        ``pending_tool_call_id``) and the session counters are written in one
        transaction.

        Args:
            session_id: UUID string of the chat session.
            user_input: text of the user turn.

        Returns:
            A dict with ``session_id``, ``persisted``, ``pending_tool_call_id``,
            ``state_snapshot`` and the runtime's ``events``.

        Raises:
            ValueError: if ``session_id`` is not a valid UUID, the session
                does not exist, or no active system agent model is configured.
        """
        # Local import keeps the change self-contained within this block.
        import asyncio

        session_uuid = UUID(session_id)
        pending_tool_call_id = f"tool-{uuid4()}"
        async with self._session_factory() as db_session:
            session_repository = SessionRepository(db_session)
            message_repository = MessageRepository(db_session)
            chat_session = await session_repository.lock_session_for_update(
                session_id=session_uuid
            )
            if chat_session is None:
                raise ValueError("session not found")

            model_code, provider_name = await self._load_agent_model_selection(
                db_session
            )
            runtime = create_runtime(model_code=model_code, provider_name=provider_name)
            # runtime.execute() performs a synchronous LLM request. Calling it
            # directly would block the event loop for the whole round trip, so
            # it runs in a worker thread instead.
            runtime_result = await asyncio.to_thread(
                runtime.execute, user_input=user_input
            )

            assistant_text = str(runtime_result.get("assistant_text", ""))
            prompt_tokens = _to_int(runtime_result.get("prompt_tokens", 0))
            completion_tokens = _to_int(runtime_result.get("completion_tokens", 0))
            total_tokens = _to_int(runtime_result.get("total_tokens", 0))
            cost = _to_decimal(runtime_result.get("cost", 0))
            agui_events = runtime_result.get("agui_events", [])

            next_seq = await session_repository.next_message_seq(
                session_id=session_uuid
            )
            await message_repository.append_message(
                session_id=session_uuid,
                seq=next_seq,
                role=AgentChatMessageRole.USER,
                content=user_input,
                model_code=model_code,
                metadata={"type": "user_input"},
            )
            await message_repository.append_message(
                session_id=session_uuid,
                seq=next_seq + 1,
                role=AgentChatMessageRole.ASSISTANT,
                # Fall back to a placeholder when the model returned no text.
                content=assistant_text or "Tool call pending approval",
                model_code=model_code,
                metadata={
                    "type": "tool_call",
                    "tool_call_id": pending_tool_call_id,
                },
                input_tokens=prompt_tokens,
                output_tokens=completion_tokens,
                cost=cost,
            )

            snapshot = self._state_persistence.build_running_snapshot(
                pending_tool_call_id=pending_tool_call_id
            )
            await session_repository.update_runtime_state(
                chat_session=chat_session,
                status=AgentChatSessionStatus.RUNNING,
                state_snapshot=snapshot,
                message_delta=2,
                token_delta=total_tokens,
                cost_delta=cost,
            )
            await db_session.commit()

            return {
                "session_id": session_id,
                "persisted": True,
                "pending_tool_call_id": pending_tool_call_id,
                "state_snapshot": snapshot,
                "events": agui_events,
            }

    async def _load_agent_model_selection(
        self, session: AsyncSession
    ) -> tuple[str, str]:
        """Return ``(model_code, provider_name)`` for the first active system agent.

        Candidates are ordered by ``agent_type`` ascending and only the first
        row is used.

        Raises:
            ValueError: if no active system agent row exists.
        """
        stmt = (
            select(Llm.model_code, LlmFactory.name)
            .join(SystemAgents, SystemAgents.llm_id == Llm.id)
            .join(LlmFactory, LlmFactory.id == Llm.factory_id)
            .where(SystemAgents.status == "active")
            .order_by(SystemAgents.agent_type.asc())
            .limit(1)
        )
        record = (await session.execute(stmt)).one_or_none()
        if record is None:
            raise ValueError("active system agent model is required")
        return str(record[0]), str(record[1])
@@ -0,0 +1,60 @@
from __future__ import annotations
import hashlib
import json
from typing import Protocol
from core.agent.domain.tool_correlation import build_tool_result_metadata
from core.agent.domain.state_snapshot import AgentStateSnapshot
class SessionStatePersistence:
    """Build the plain-dict session snapshots stored on a chat session."""

    def build_running_snapshot(
        self, *, pending_tool_call_id: str | None
    ) -> dict[str, object]:
        """Return a ``running`` snapshot carrying the pending tool call id."""
        running = AgentStateSnapshot(
            status="running",
            pending_tool_call_id=pending_tool_call_id,
        )
        return running.model_dump()

    def build_completed_snapshot(self) -> dict[str, object]:
        """Return a ``completed`` snapshot with no pending tool call."""
        completed = AgentStateSnapshot(status="completed")
        return completed.model_dump()
class ToolResultStorage(Protocol):
    """Structural interface for a store that can persist a JSON payload."""

    # Uploads `payload` to `bucket`/`path` and returns a string;
    # persist_tool_result_payload records that string as `storage_etag`.
    async def upload_json(
        self,
        *,
        bucket: str,
        path: str,
        payload: dict[str, object],
    ) -> str: ...
async def persist_tool_result_payload(
    *,
    storage: ToolResultStorage,
    run_id: str,
    turn_id: str,
    tool_call_id: str,
    tool_name: str,
    payload: dict[str, object],
    bucket: str,
    path: str,
) -> dict[str, object]:
    """Upload a tool result payload and return its correlation metadata.

    The SHA-256 digest and byte count are computed over a canonical JSON
    encoding of ``payload`` (sorted keys, ASCII-only). The payload dict itself
    is handed to ``storage.upload_json``.

    NOTE(review): the digest describes the canonical encoding produced here;
    whether the stored bytes are identical depends on the storage
    implementation — confirm.

    Returns:
        The metadata from ``build_tool_result_metadata`` plus a
        ``storage_etag`` entry holding the value returned by the upload.
    """
    canonical_bytes = json.dumps(
        payload, ensure_ascii=True, sort_keys=True
    ).encode("utf-8")
    digest = hashlib.sha256(canonical_bytes).hexdigest()

    upload_etag = await storage.upload_json(bucket=bucket, path=path, payload=payload)

    result_metadata = build_tool_result_metadata(
        run_id=run_id,
        turn_id=turn_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        storage_bucket=bucket,
        storage_path=path,
        payload_sha256=digest,
        payload_bytes=len(canonical_bytes),
        payload_format="json",
    )
    result_metadata["storage_etag"] = upload_etag
    return result_metadata
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,10 @@
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel
# Serializable runtime state of an agent session; SessionStatePersistence
# dumps instances of this model to plain dicts.
class AgentStateSnapshot(BaseModel):
    # Lifecycle state; restricted to the four literal values below.
    status: Literal["pending", "running", "completed", "failed"]
    # Identifier of the tool call the session is waiting on, if any.
    pending_tool_call_id: str | None = None
@@ -0,0 +1,40 @@
from __future__ import annotations
def reconstruct_tool_call_result_event(
    *,
    metadata: dict[str, object],
    payload: dict[str, object],
) -> dict[str, object]:
    """Rebuild a ``TOOL_CALL_RESULT`` event from stored metadata and payload.

    ``payload`` becomes the event ``data``. ``tool_call_id`` and ``tool_name``
    are copied from ``metadata`` and are ``None`` when absent.
    """
    event: dict[str, object] = {"type": "TOOL_CALL_RESULT", "data": payload}
    for correlation_key in ("tool_call_id", "tool_name"):
        event[correlation_key] = metadata.get(correlation_key)
    return event
def build_tool_result_metadata(
    *,
    run_id: str,
    turn_id: str,
    tool_call_id: str,
    tool_name: str,
    storage_bucket: str,
    storage_path: str,
    payload_sha256: str,
    payload_bytes: int,
    payload_format: str,
) -> dict[str, object]:
    """Assemble the metadata dict describing a stored tool result.

    The result always has ``type`` set to ``"tool_result"`` and echoes every
    argument under a key of the same name.
    """
    return dict(
        type="tool_result",
        run_id=run_id,
        turn_id=turn_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        storage_bucket=storage_bucket,
        storage_path=storage_path,
        payload_sha256=payload_sha256,
        payload_bytes=payload_bytes,
        payload_format=payload_format,
    )
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,90 @@
from __future__ import annotations
import re
from typing import Any
from ag_ui.core.events import EventType
_CAMEL_CASE_BOUNDARY_RE = re.compile(r"([a-z0-9])([A-Z])")
_NON_ALNUM_RE = re.compile(r"[^A-Za-z0-9]+")
_SENSITIVE_KEYS = {
"apikey",
"authorization",
"token",
"accesstoken",
"refreshtoken",
"secret",
"password",
}
_TYPE_ALIASES = {
"taskStarted": "STEP_STARTED",
"taskFinished": "STEP_FINISHED",
"llmChunk": "TEXT_MESSAGE_CONTENT",
"llmStarted": "TEXT_MESSAGE_START",
"llmFinished": "TEXT_MESSAGE_END",
"toolCalled": "TOOL_CALL_START",
"toolCompleted": "TOOL_CALL_RESULT",
"error": "RUN_ERROR",
}
def _is_sensitive_key(key: str) -> bool:
normalized = _NON_ALNUM_RE.sub("", key.lower())
if normalized in _SENSITIVE_KEYS:
return True
if "token" in normalized:
return True
if "api" in normalized and "key" in normalized:
return True
return False
def _to_upper_snake(value: str) -> str:
    """Convert a camelCase or separator-delimited name to UPPER_SNAKE_CASE.

    An underscore is inserted at each camelCase boundary, every run of
    non-alphanumeric characters collapses to one underscore, and leading or
    trailing underscores are trimmed before upper-casing.
    """
    snake = _NON_ALNUM_RE.sub("_", _CAMEL_CASE_BOUNDARY_RE.sub(r"\1_\2", value))
    return snake.strip("_").upper()
def _to_event_type(value: str) -> EventType:
    """Look up the AG-UI ``EventType`` whose value is ``value``.

    Raises:
        ValueError: with an "unsupported AG-UI event type" message when
            ``value`` is not a member value of ``EventType``.
    """
    try:
        event_type = EventType(value)
    except ValueError as exc:
        raise ValueError(f"unsupported AG-UI event type: {value}") from exc
    return event_type
def _redact_sensitive(value: Any) -> Any:
    """Return a copy of ``value`` with sensitive dict entries masked.

    Dicts and lists are walked recursively. A dict entry whose key is
    sensitive is replaced by the ``"***REDACTED***"`` marker without
    descending into it. Any other value is returned unchanged.
    """
    if isinstance(value, list):
        return [_redact_sensitive(element) for element in value]
    if not isinstance(value, dict):
        return value

    masked = {}
    for key, child in value.items():
        if _is_sensitive_key(str(key)):
            masked[key] = "***REDACTED***"
        else:
            masked[key] = _redact_sensitive(child)
    return masked
def to_agui_events(internal_events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Translate internal runtime events into AG-UI event dicts.

    For each event the ``type`` is mapped through ``_TYPE_ALIASES`` (or
    converted to UPPER_SNAKE_CASE) and checked against ``EventType``. The
    ``data`` dict is redacted, and all other keys are copied unchanged.

    Raises:
        ValueError: if an event's ``type`` is not a non-empty string, is not
            a supported AG-UI event type, or its ``data`` is not a dict.
    """
    converted: list[dict[str, Any]] = []
    for source_event in internal_events:
        declared_type = source_event.get("type")
        if not (isinstance(declared_type, str) and declared_type.strip()):
            raise ValueError("event.type must be a non-empty string")
        type_name = declared_type.strip()

        # Start from every passthrough field; type and data are rebuilt below.
        agui_event = {
            field_name: field_value
            for field_name, field_value in source_event.items()
            if field_name not in {"type", "data"}
        }
        canonical_name = _TYPE_ALIASES.get(type_name, _to_upper_snake(type_name))
        agui_event["type"] = _to_event_type(canonical_name).value

        body = source_event.get("data")
        if not isinstance(body, dict):
            raise ValueError("event.data must be an object")
        agui_event["data"] = _redact_sensitive(body)

        converted.append(agui_event)
    return converted
@@ -0,0 +1,10 @@
from __future__ import annotations
import json
from typing import Any
def to_sse_event(stream_id: str, event: dict[str, Any]) -> str:
    """Format ``event`` as one Server-Sent Events frame.

    The frame carries ``stream_id`` as ``id``, the event ``type`` (default
    ``MESSAGE``) as ``event``, and the ASCII-only JSON encoding of the event
    ``data`` (default ``{}``) as ``data``, terminated by a blank line.
    """
    event_name = str(event.get("type", "MESSAGE"))
    body = json.dumps(event.get("data", {}), ensure_ascii=True)
    frame_lines = (f"id: {stream_id}", f"event: {event_name}", f"data: {body}")
    return "\n".join(frame_lines) + "\n\n"
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,104 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Protocol, cast
from core.config.settings import config
@dataclass(frozen=True)
class ResolvedAgentConfig:
    """Immutable result produced by ``AgentConfigResolver.resolve``."""

    # Model identifier after defaulting and whitespace stripping.
    model_code: str
    # Provider API key; excluded from the generated repr().
    provider_api_key: str = field(repr=False)
    # Canonical provider name.
    provider_name: str
    # Streaming flag copied from the runtime settings.
    stream: bool
class AgentRuntimeSettingsLike(Protocol):
    """Runtime settings attributes read by ``AgentConfigResolver``."""

    # Model used when resolve() is called without a model_code.
    default_model_code: str
    # Copied into ResolvedAgentConfig.stream.
    streaming_enabled: bool
class LlmSettingsLike(Protocol):
    """LLM settings attributes read by ``AgentConfigResolver``."""

    # Maps a provider name (or alias) to its API key.
    provider_keys: dict[str, str]
class SettingsLike(Protocol):
    """Structural view of the settings object that the resolver depends on."""

    agent_runtime: AgentRuntimeSettingsLike
    llm: LlmSettingsLike
# Alternative provider spellings mapped to their canonical names.
_PROVIDER_ALIASES = {
    "ark": "volcengine",
    "volcengine-ark": "volcengine",
    "z-ai": "zai",
}
# Canonical provider names accepted by _normalize_provider.
_SUPPORTED_PROVIDERS = {
    "dashscope",
    "minimax",
    "moonshot",
    "deepseek",
    "volcengine",
    "zai",
}
def _normalize_provider(provider: str) -> str:
    """Return the canonical provider name for ``provider``.

    The name is trimmed and lower-cased, then aliases are resolved.

    Raises:
        ValueError: if the resulting name is not a supported provider.
    """
    cleaned = provider.strip().lower()
    canonical = _PROVIDER_ALIASES.get(cleaned, cleaned)
    if canonical in _SUPPORTED_PROVIDERS:
        return canonical
    raise ValueError(f"unsupported provider '{provider}'")
def _infer_provider_from_model(model_code: str) -> str:
lowered = model_code.strip().lower()
if lowered.startswith("qwen"):
return "dashscope"
if lowered.startswith("deepseek"):
return "deepseek"
if lowered.startswith("kimi") or lowered.startswith("moonshot"):
return "moonshot"
if lowered.startswith("abab") or lowered.startswith("minimax"):
return "minimax"
if lowered.startswith("doubao") or lowered.startswith("ark"):
return "volcengine"
if lowered.startswith("glm") or lowered.startswith("zai"):
return "zai"
raise ValueError("provider_name is required for unknown model_code")
class AgentConfigResolver:
    """Resolve the model, provider and API key used for an agent run."""

    def __init__(self, settings: SettingsLike | None = None) -> None:
        """Use ``settings`` when given, otherwise the global ``config``."""
        self._settings: SettingsLike = cast(SettingsLike, settings or config)

    def resolve(
        self,
        *,
        model_code: str | None,
        provider_name: str | None,
    ) -> ResolvedAgentConfig:
        """Build the effective agent configuration.

        Args:
            model_code: requested model; falls back to the configured
                ``default_model_code`` when empty or ``None``.
            provider_name: requested provider; inferred from the model
                prefix when empty or ``None``.

        Returns:
            The resolved, immutable configuration.

        Raises:
            ValueError: if no model can be determined, the provider is
                unsupported or cannot be inferred, or no non-blank API key
                is configured for the provider.
        """
        runtime_settings = self._settings.agent_runtime
        resolved_model = (model_code or runtime_settings.default_model_code).strip()
        if not resolved_model:
            raise ValueError("llm_model_code is required")
        provider = _normalize_provider(
            provider_name or _infer_provider_from_model(resolved_model)
        )
        resolved_key = self._lookup_provider_key(provider)
        if not resolved_key:
            raise ValueError(f"provider api key is required for provider '{provider}'")
        return ResolvedAgentConfig(
            model_code=resolved_model,
            provider_api_key=resolved_key,
            provider_name=provider,
            stream=runtime_settings.streaming_enabled,
        )

    def _lookup_provider_key(self, provider: str) -> str:
        """Return the stripped API key configured for ``provider``, or ``""``.

        Entries with a blank or non-string value, or with a provider name
        that is not supported, are skipped. One stray entry in
        ``provider_keys`` must not break resolution for every other provider.
        When several entries map to the same canonical provider, the last
        one wins.
        """
        resolved_key = ""
        for configured_name, configured_key in self._settings.llm.provider_keys.items():
            if not isinstance(configured_key, str) or not configured_key.strip():
                continue
            try:
                canonical_name = _normalize_provider(configured_name)
            except ValueError:
                # Unsupported provider names are ignored.
                continue
            if canonical_name == provider:
                resolved_key = configured_key.strip()
        return resolved_key
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,15 @@
from __future__ import annotations
from core.agent.infrastructure.config.resolver import AgentConfigResolver
from core.agent.infrastructure.crewai.runtime import CrewAIRuntime
def create_runtime(
    *, model_code: str | None, provider_name: str | None
) -> CrewAIRuntime:
    """Create a ``CrewAIRuntime`` backed by a default ``AgentConfigResolver``.

    Both arguments are forwarded unchanged to the runtime, which resolves
    them during construction.
    """
    return CrewAIRuntime(
        resolver=AgentConfigResolver(),
        model_code=model_code,
        provider_name=provider_name,
    )
@@ -0,0 +1,101 @@
from __future__ import annotations
from typing import Any
from core.agent.infrastructure.agui.bridge import to_agui_events
from core.agent.infrastructure.config.resolver import (
AgentConfigResolver,
ResolvedAgentConfig,
)
from core.agent.infrastructure.litellm.client import run_completion
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
def _to_litellm_model(*, provider_name: str, model_code: str) -> str:
normalized_model = model_code.strip()
if "/" in normalized_model:
return normalized_model
return f"{provider_name.strip().lower()}/{normalized_model}"
def _extract_assistant_text(response: dict[str, Any]) -> str:
choices = response.get("choices")
if not isinstance(choices, list) or not choices:
return ""
first = choices[0]
if not isinstance(first, dict):
return ""
message = first.get("message")
if not isinstance(message, dict):
return ""
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict) and isinstance(item.get("text"), str):
text_parts.append(item["text"])
return "".join(text_parts)
return ""
class CrewAIRuntime:
    """Run a single non-streaming completion and report usage as AG-UI events."""

    def __init__(
        self,
        *,
        resolver: AgentConfigResolver,
        model_code: str | None,
        provider_name: str | None,
    ) -> None:
        """Resolve the configuration once; resolver errors propagate to the caller."""
        self._config: ResolvedAgentConfig = resolver.resolve(
            model_code=model_code,
            provider_name=provider_name,
        )

    def map_events(self, internal_events: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal events to AG-UI events via ``to_agui_events``."""
        return to_agui_events(internal_events)

    def execute(self, *, user_input: str) -> dict[str, object]:
        """Send ``user_input`` to the model as a single user message.

        Returns:
            A dict with ``assistant_text``, the token counters, ``cost`` and
            ``agui_events`` (start, chunk and finish events for this call).

        Raises:
            ValueError: if the completion response is not a dict.
        """
        resolved = self._config
        response = run_completion(
            model=_to_litellm_model(
                provider_name=resolved.provider_name,
                model_code=resolved.model_code,
            ),
            api_key=resolved.provider_api_key,
            messages=[{"role": "user", "content": user_input}],
        )
        if not isinstance(response, dict):
            raise ValueError("llm response must be a dict")

        usage_cost = extract_usage_and_cost(response)
        assistant_text = _extract_assistant_text(response)

        # Shared by the finish event and the returned summary.
        usage_fields = {
            "prompt_tokens": usage_cost.prompt_tokens,
            "completion_tokens": usage_cost.completion_tokens,
            "total_tokens": usage_cost.total_tokens,
            "cost": usage_cost.cost,
        }
        internal_events = [
            {"type": "llmStarted", "data": {"model": resolved.model_code}},
            {"type": "llmChunk", "data": {"text": assistant_text}},
            {
                "type": "llmFinished",
                "data": {**usage_fields, "provider": resolved.provider_name},
            },
        ]
        return {
            "assistant_text": assistant_text,
            **usage_fields,
            "agui_events": self.map_events(internal_events),
        }
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,66 @@
from __future__ import annotations
import json
import inspect
from typing import Any, Protocol
from uuid import UUID
class RedisStreamClient(Protocol):
    """Minimal Redis client surface used by ``RedisStreamEventStore``."""

    # Appends an entry to a stream; the store stringifies the return value.
    def xadd(self, *args: Any, **kwargs: Any) -> Any: ...
    # Reads stream entries; the store accepts a plain or an awaitable result.
    def xread(self, *args: Any, **kwargs: Any) -> Any: ...
class RedisStreamEventStore:
    """Append and read per-session agent events on Redis streams."""

    def __init__(
        self,
        *,
        client: RedisStreamClient,
        stream_prefix: str,
        read_count: int = 100,
        block_ms: int = 5000,
    ) -> None:
        """Keep the client and the stream read parameters."""
        self._client = client
        self._stream_prefix = stream_prefix
        self._read_count = read_count
        self._block_ms = block_ms

    def append_event_sync(self, *, session_id: UUID, event: dict[str, Any]) -> str:
        """Append ``event`` as compact JSON and return the stringified entry id."""
        encoded_event = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
        entry_id = self._client.xadd(
            self._stream_name(session_id), {"event": encoded_event}
        )
        return str(entry_id)

    async def read_events(
        self,
        *,
        session_id: UUID,
        last_event_id: str | None,
    ) -> list[dict[str, Any]]:
        """Read entries that follow ``last_event_id`` from the session stream.

        With ``last_event_id=None`` the read starts at ``$``. The client's
        result is awaited when it is awaitable. Only the first stream of the
        response is consumed.

        Returns:
            One ``{"id", "event"}`` dict per entry, with ``event`` decoded
            from JSON. An empty list when nothing was read.
        """
        stream_key = self._stream_name(session_id)
        cursor = last_event_id if last_event_id is not None else "$"
        pending = self._client.xread(
            {stream_key: cursor},
            count=self._read_count,
            block=self._block_ms,
        )
        if inspect.isawaitable(pending):
            response = await pending
        else:
            response = pending
        if not response:
            return []

        _, entries = response[0]
        return [
            {"id": entry_id, "event": json.loads(fields["event"])}
            for entry_id, fields in entries
        ]

    def _stream_name(self, session_id: UUID) -> str:
        """Return the stream key for ``session_id`` (``<prefix>:<session_id>``)."""
        return f"{self._stream_prefix}:{session_id}"
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Any
from litellm import completion
def run_completion(*, model: str, api_key: str, messages: list[dict[str, Any]]) -> Any:
    """Call LiteLLM ``completion`` without streaming.

    Returns:
        ``response.model_dump()`` when the response exposes a callable
        ``model_dump``, otherwise the response object itself.
    """
    response = completion(
        model=model,
        api_key=api_key,
        messages=messages,
        stream=False,
    )
    dump = getattr(response, "model_dump", None)
    return dump() if callable(dump) else response
@@ -0,0 +1,72 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class TieredModelPricing:
    """Per-token rates that apply up to a prompt-size ceiling."""

    # Largest prompt token count (inclusive) covered by this tier.
    max_prompt_tokens: int
    # Rate applied to prompt tokens.
    input_cost_per_token: float
    # Rate applied to completion tokens.
    output_cost_per_token: float
    # Cache rates; not used by calculate_tiered_model_cost.
    cache_create_cost_per_token: float
    cache_hit_cost_per_token: float
# Pricing tiers for qwen3.5-flash, ordered by ascending prompt-size ceiling.
# Each rate is written as "<amount> / 1000" to yield a per-token value
# (presumably <amount> is the published price per 1000 tokens — TODO confirm
# the unit and currency).
QWEN35_FLASH_TIERED_PRICING: tuple[TieredModelPricing, ...] = (
    TieredModelPricing(
        max_prompt_tokens=128_000,
        input_cost_per_token=0.0002 / 1000,
        output_cost_per_token=0.002 / 1000,
        cache_create_cost_per_token=0.00025 / 1000,
        cache_hit_cost_per_token=0.00002 / 1000,
    ),
    TieredModelPricing(
        max_prompt_tokens=256_000,
        input_cost_per_token=0.0008 / 1000,
        output_cost_per_token=0.008 / 1000,
        cache_create_cost_per_token=0.001 / 1000,
        cache_hit_cost_per_token=0.00008 / 1000,
    ),
    TieredModelPricing(
        max_prompt_tokens=1_000_000,
        input_cost_per_token=0.0012 / 1000,
        output_cost_per_token=0.012 / 1000,
        cache_create_cost_per_token=0.0015 / 1000,
        cache_hit_cost_per_token=0.00012 / 1000,
    ),
)
# Lookup table keyed by lower-case "provider/model" name.
_MODEL_TIERED_PRICING: dict[str, tuple[TieredModelPricing, ...]] = {
    "dashscope/qwen3.5-flash": QWEN35_FLASH_TIERED_PRICING,
}
def get_tiered_pricing(
    *, model_name: str, prompt_tokens: int
) -> TieredModelPricing | None:
    """Return the pricing tier that applies to ``prompt_tokens``.

    The model name is trimmed and lower-cased before lookup. The first tier
    whose ceiling is at least ``prompt_tokens`` is chosen, and prompts above
    every ceiling use the last tier.

    Returns:
        The matching tier, or ``None`` when the model has no tiered pricing.
    """
    tiers = _MODEL_TIERED_PRICING.get(model_name.strip().lower())
    if tiers is None:
        return None
    return next(
        (tier for tier in tiers if prompt_tokens <= tier.max_prompt_tokens),
        tiers[-1],
    )
def calculate_tiered_model_cost(
    *,
    model_name: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> float | None:
    """Price a completion using the locally defined tiered rates.

    The tier is selected from ``prompt_tokens``. Cache rates are not applied.

    Returns:
        Input cost plus output cost, or ``None`` when the model has no
        tiered pricing.
    """
    tier = get_tiered_pricing(model_name=model_name, prompt_tokens=prompt_tokens)
    if tier is None:
        return None
    input_cost = prompt_tokens * tier.input_cost_per_token
    output_cost = completion_tokens * tier.output_cost_per_token
    return input_cost + output_cost
@@ -0,0 +1,55 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from litellm import completion_cost
from core.agent.infrastructure.litellm.pricing import calculate_tiered_model_cost
@dataclass(frozen=True)
class UsageCost:
    """Token usage and cost extracted from one completion response."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost: float
    # "litellm" by default; extract_usage_and_cost sets "custom_pricing" when
    # the local tiered pricing fallback produced the cost.
    cost_source: str = "litellm"
def extract_usage_and_cost(response: dict[str, Any]) -> UsageCost:
    """Read token usage from ``response`` and price the completion.

    LiteLLM's ``completion_cost`` is tried first. If it raises or returns
    ``None``, the locally defined tiered pricing for the response's model is
    used and the result is marked with ``cost_source="custom_pricing"``.

    Raises:
        ValueError: if ``usage`` is missing or not a dict, or if neither
            LiteLLM nor the local pricing table can produce a cost.
    """
    usage = response.get("usage")
    if not isinstance(usage, dict):
        raise ValueError("missing usage in response")

    prompt_tokens = int(usage.get("prompt_tokens", 0))
    completion_tokens = int(usage.get("completion_tokens", 0))
    token_counts = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        # Falls back to the sum when the provider omits total_tokens.
        "total_tokens": int(
            usage.get("total_tokens", prompt_tokens + completion_tokens)
        ),
    }
    model_name = str(response.get("model", "")).strip().lower()

    try:
        litellm_cost = completion_cost(completion_response=response)
        if litellm_cost is None:
            raise ValueError("unable to calculate litellm completion cost")
        cost_value = float(litellm_cost)
    except Exception as exc:
        # Any LiteLLM pricing failure falls through to the local tiers.
        fallback_cost = calculate_tiered_model_cost(
            model_name=model_name,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        if fallback_cost is None:
            raise ValueError("unable to calculate litellm completion cost") from exc
        return UsageCost(
            **token_counts,
            cost=float(fallback_cost),
            cost_source="custom_pricing",
        )
    return UsageCost(**token_counts, cost=cost_value)
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,41 @@
from __future__ import annotations
from decimal import Decimal
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
class MessageRepository:
    """Write chat messages through a caller-owned ``AsyncSession``."""

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to ``session``."""
        self._session = session

    async def append_message(
        self,
        *,
        session_id: UUID,
        seq: int,
        role: AgentChatMessageRole,
        content: str,
        model_code: str | None = None,
        metadata: dict[str, object] | None = None,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: Decimal = Decimal("0"),
    ) -> AgentChatMessage:
        """Add one message row and flush it to the database.

        ``metadata`` is stored on the model's ``metadata_json`` attribute. The
        transaction is not committed here.

        Returns:
            The flushed ``AgentChatMessage`` instance.
        """
        new_message = AgentChatMessage(
            session_id=session_id,
            seq=seq,
            role=role,
            content=content,
            model_code=model_code,
            metadata_json=metadata,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
        )
        db = self._session
        db.add(new_message)
        await db.flush()
        return new_message
@@ -0,0 +1,77 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
class SessionRepository:
    """Read and update chat sessions through a caller-owned ``AsyncSession``."""

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to ``session``."""
        self._session = session

    async def get_session(self, *, session_id: UUID) -> AgentChatSession | None:
        """Fetch a session by primary key, or ``None`` when absent."""
        found = await self._session.get(AgentChatSession, session_id)
        return found

    async def lock_session_for_update(
        self, *, session_id: UUID
    ) -> AgentChatSession | None:
        """Select the session row with ``FOR UPDATE``; ``None`` when absent."""
        locking_query = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_id)
            .with_for_update()
        )
        result = await self._session.execute(locking_query)
        return result.scalar_one_or_none()

    async def next_message_seq(self, *, session_id: UUID) -> int:
        """Return one more than the highest message ``seq`` (1 when there are none)."""
        highest_seq = func.coalesce(func.max(AgentChatMessage.seq), 0)
        query = select(highest_seq).where(AgentChatMessage.session_id == session_id)
        result = await self._session.execute(query)
        return int(result.scalar_one()) + 1

    async def update_runtime_state(
        self,
        *,
        chat_session: AgentChatSession,
        status: AgentChatSessionStatus,
        state_snapshot: dict[str, object],
        message_delta: int,
        token_delta: int = 0,
        cost_delta: Decimal = Decimal("0"),
    ) -> None:
        """Set status and snapshot, bump the counters, and flush.

        ``last_activity_at`` is set to the current UTC time. The deltas are
        added to the existing message, token and cost totals.
        """
        chat_session.status = status
        chat_session.state_snapshot = state_snapshot
        chat_session.last_activity_at = datetime.now(timezone.utc)

        chat_session.message_count += message_delta
        chat_session.total_tokens += token_delta
        chat_session.total_cost += cost_delta

        await self._session.flush()

    async def soft_delete_session_with_messages(self, *, session_id: UUID) -> int:
        """Stamp ``deleted_at`` on a session and its not-yet-deleted messages.

        Returns:
            ``1`` when the session was soft-deleted, ``0`` when it does not
            exist or is already deleted.
        """
        current = await self.get_session(session_id=session_id)
        if current is None or current.deleted_at is not None:
            return 0

        stamp = datetime.now(timezone.utc)
        mark_session = (
            update(AgentChatSession)
            .where(AgentChatSession.id == session_id)
            .where(AgentChatSession.deleted_at.is_(None))
            .values(deleted_at=stamp)
        )
        mark_messages = (
            update(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_id)
            .where(AgentChatMessage.deleted_at.is_(None))
            .values(deleted_at=stamp)
        )
        for statement in (mark_session, mark_messages):
            await self._session.execute(statement)
        await self._session.flush()
        return 1
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,159 @@
from __future__ import annotations
import asyncio
import threading
from typing import Any, Callable, Protocol, cast
from uuid import UUID
import redis
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.celery.app import celery_app
from core.config.settings import config
from core.logging import get_logger
logger = get_logger("core.agent.infrastructure.queue.tasks")
# Module-wide event loop, created on first use by _ensure_background_loop and
# run on its own daemon thread.
_background_loop: asyncio.AbstractEventLoop | None = None
_background_thread: threading.Thread | None = None
# Set by the loop thread once _background_loop has been assigned.
_background_ready = threading.Event()
class PublishEvent(Protocol):
    """Callable that publishes one event of ``event_type`` carrying ``payload``."""

    def __call__(self, event_type: str, payload: dict[str, object]) -> None: ...
class RunServiceLike(Protocol):
    """Structural type of the service used by ``run_agent_task`` for ``run`` commands."""

    async def run(self, *, session_id: str, user_input: str) -> dict[str, object]: ...
class ResumeServiceLike(Protocol):
    """Structural type of the service used by ``run_agent_task`` for ``resume`` commands."""

    async def resume(
        self, *, session_id: str, tool_call_id: str
    ) -> dict[str, object]: ...
def _run_async(task: Callable[[], Any]) -> Any:
    """Run the coroutine produced by ``task`` on the background loop and wait.

    The calling thread blocks until the coroutine finishes. Its result is
    returned and its exception, if any, is re-raised here.
    """
    background_loop = _ensure_background_loop()
    coroutine = task()
    return asyncio.run_coroutine_threadsafe(coroutine, background_loop).result()
def _ensure_background_loop() -> asyncio.AbstractEventLoop:
    """Return the shared background event loop, starting it on first use.

    The loop runs forever on a daemon thread. The caller waits up to five
    seconds for the thread to publish the loop.

    Raises:
        RuntimeError: if the loop was not published within the wait.
    """
    global _background_loop, _background_thread

    if _background_loop is None:

        def _serve_forever() -> None:
            global _background_loop
            new_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(new_loop)
            # Publish the loop before signalling readiness to the waiter.
            _background_loop = new_loop
            _background_ready.set()
            new_loop.run_forever()

        _background_thread = threading.Thread(target=_serve_forever, daemon=True)
        _background_thread.start()
        _background_ready.wait(timeout=5)
        if _background_loop is None:
            raise RuntimeError("failed to initialize background event loop")

    return _background_loop
def _build_redis_publisher() -> PublishEvent:
    """Create a publisher that appends events to the session's Redis stream.

    A Redis client is built from ``config.redis.url`` and the stream
    parameters come from ``config.agent_runtime``. The returned callable
    requires a non-blank ``session_id`` in every payload and raises
    ``ValueError`` otherwise.
    """
    settings = cast(Any, config)
    runtime_settings = settings.agent_runtime
    event_store = RedisStreamEventStore(
        client=redis.from_url(settings.redis.url, decode_responses=True),
        stream_prefix=runtime_settings.redis_stream_prefix,
        read_count=runtime_settings.redis_stream_read_count,
        block_ms=runtime_settings.redis_stream_block_ms,
    )

    def _publish(event_type: str, payload: dict[str, object]) -> None:
        target_session = str(payload.get("session_id", "")).strip()
        if not target_session:
            raise ValueError("session_id is required in event payload")
        event_store.append_event_sync(
            session_id=UUID(target_session),
            event={"type": event_type, "data": payload},
        )

    return _publish
def run_agent_task(
    command: dict[str, Any],
    *,
    publish_event: PublishEvent | None = None,
    run_service: RunServiceLike | None = None,
    resume_service: ResumeServiceLike | None = None,
) -> dict[str, object]:
    """Execute a ``run`` or ``resume`` command and publish lifecycle events.

    Published sequence on success: ``RUN_STARTED`` or ``RUN_RESUMED``, then
    ``RUNTIME_EVENT`` with the service result, then one event per entry of
    the result's ``events`` list, then ``RUN_FINISHED``. On failure
    ``RUN_ERROR`` is published and the original exception is re-raised.

    Args:
        command: dict with ``command`` (``"run"`` by default, or
            ``"resume"``), ``session_id`` and either ``user_input`` (run) or
            ``tool_call_id`` (resume).
        publish_event: event publisher; defaults to a Redis stream publisher.
        run_service: service handling ``run``; defaults to ``RunService()``.
        resume_service: service handling ``resume``; defaults to
            ``ResumeService()``.

    Returns:
        The dict returned by the selected service.

    Raises:
        ValueError: for an invalid command type, a missing or malformed
            ``session_id``, or a missing ``user_input`` / ``tool_call_id``.
    """
    publisher = publish_event or _build_redis_publisher()
    service_run = run_service or RunService()
    service_resume = resume_service or ResumeService()
    command_type = str(command.get("command", "run"))
    session_id = str(command.get("session_id", ""))
    if command_type not in {"run", "resume"}:
        raise ValueError("invalid command type")
    if not session_id:
        raise ValueError("session_id is required")
    # Validates the format only; raises ValueError for a malformed id.
    UUID(session_id)

    start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
    publisher(start_event, {"session_id": session_id})
    try:
        if command_type == "resume":
            tool_call_id = str(command.get("tool_call_id", ""))
            if not tool_call_id:
                raise ValueError("tool_call_id is required")
            result = _run_async(
                lambda: service_resume.resume(
                    session_id=session_id,
                    tool_call_id=tool_call_id,
                )
            )
        else:
            user_input = str(command.get("user_input", ""))
            if not user_input:
                raise ValueError("user_input is required")
            result = _run_async(
                lambda: service_run.run(
                    session_id=session_id,
                    user_input=user_input,
                )
            )
        publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
        extra_events = result.get("events") if isinstance(result, dict) else None
        if isinstance(extra_events, list):
            for event in extra_events:
                if not isinstance(event, dict):
                    continue
                event_type = event.get("type")
                event_data = event.get("data")
                if not isinstance(event_type, str) or not isinstance(event_data, dict):
                    continue
                # session_id goes last so event data can never override it.
                # The publisher routes on this field, and an overridden value
                # would send the event to another session's stream.
                payload = {**event_data, "session_id": session_id}
                publisher(event_type, payload)
        publisher("RUN_FINISHED", {"session_id": session_id})
        return result
    except Exception:  # noqa: BLE001
        error_id = "agent_runtime_failed"
        logger.exception(
            "Agent task failed",
            session_id=session_id,
            error_id=error_id,
        )
        try:
            publisher("RUN_ERROR", {"session_id": session_id, "error_id": error_id})
        except Exception:  # noqa: BLE001
            # A failing publisher must not replace the original error that
            # the bare ``raise`` below propagates.
            logger.exception(
                "Failed to publish RUN_ERROR event",
                session_id=session_id,
                error_id=error_id,
            )
        raise
@celery_app.task(name="tasks.agent.run_command")
def run_command_task(command: dict[str, Any]) -> dict[str, object]:
    """Celery entry point registered as ``tasks.agent.run_command``.

    Delegates to ``run_agent_task`` with its default publisher and services.
    """
    return run_agent_task(command)
+1
View File
@@ -15,6 +15,7 @@ def create_celery_app() -> Celery:
"social_app",
broker=config.celery_broker_url,
backend=config.celery_result_backend,
include=["core.agent.infrastructure.queue.tasks"],
)
app.conf.update(
+14
View File
@@ -140,6 +140,18 @@ class StorageSettings(BaseModel):
retention_days: int = Field(default=30, ge=1, le=3650)
# Settings consumed by the agent runtime: Redis event-stream parameters plus
# model defaults.
class AgentRuntimeSettings(BaseModel):
    # Passed as ``stream_prefix`` to the Redis stream event store.
    redis_stream_prefix: str = "agent:events"
    # Passed as ``read_count`` to the event store (1..1000).
    redis_stream_read_count: int = Field(default=100, ge=1, le=1000)
    # Passed as ``block_ms`` to the event store (1..60000).
    redis_stream_block_ms: int = Field(default=5000, ge=1, le=60000)
    # Model code used when a caller supplies an empty one.
    default_model_code: str = ""
    # NOTE(review): consumer not visible here; presumably toggles streamed
    # output in the runtime -- confirm.
    streaming_enabled: bool = True
# LLM credentials section of the application settings.
class LlmSettings(BaseModel):
    # Maps a provider name (e.g. "dashscope") to its API key.
    provider_keys: dict[str, str] = Field(default_factory=dict)
class DatabaseSettings(BaseModel):
host: str = "localhost"
port: int = 5432
@@ -172,6 +184,8 @@ class Settings(BaseSettings):
redis: RedisSettings = RedisSettings()
supabase: SupabaseSettings = SupabaseSettings()
storage: StorageSettings = StorageSettings()
llm: LlmSettings = LlmSettings()
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
celery: CelerySettings = CelerySettings()
database: DatabaseSettings = DatabaseSettings()
@@ -15,11 +15,11 @@ factories:
request_url: https://api.deepseek.com/v1
avatar: https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/light/deepseek-color.png
- name: volcengine-ark
- name: volcengine
request_url: https://ark.cn-beijing.volces.com/api/v3
avatar: https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/light/doubao-color.png
- name: z-ai
- name: zai
request_url: https://api.z.ai/api/paas/v4
avatar: https://registry.npmmirror.com/@lobehub/icons-static-png/latest/files/light/zai.png
+4 -1
View File
@@ -45,7 +45,10 @@ class AgentChatMessage(TimestampMixin, SoftDeleteMixin, Base):
seq: Mapped[int] = mapped_column(Integer, nullable=False)
role: Mapped[AgentChatMessageRole] = mapped_column(
SqlEnum(
AgentChatMessageRole, name="agent_chat_message_role", native_enum=False
AgentChatMessageRole,
name="agent_chat_message_role",
native_enum=False,
values_callable=lambda enum_cls: [item.value for item in enum_cls],
),
nullable=False,
)
+6 -2
View File
@@ -7,6 +7,7 @@ from enum import Enum
from sqlalchemy import (
DateTime,
JSON,
Enum as SqlEnum,
Integer,
Numeric,
@@ -56,7 +57,10 @@ class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
title: Mapped[str | None] = mapped_column(String(255), nullable=True)
status: Mapped[AgentChatSessionStatus] = mapped_column(
SqlEnum(
AgentChatSessionStatus, name="agent_chat_session_status", native_enum=False
AgentChatSessionStatus,
name="agent_chat_session_status",
native_enum=False,
values_callable=lambda enum_cls: [item.value for item in enum_cls],
),
nullable=False,
default=AgentChatSessionStatus.PENDING,
@@ -76,6 +80,6 @@ class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
Numeric(12, 6), nullable=False, server_default=text("0")
)
state_snapshot: Mapped[dict | None] = mapped_column(
JSONB,
JSON().with_variant(JSONB, "postgresql"),
nullable=True,
)
+2 -2
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import uuid
from sqlalchemy import ForeignKey, String
from sqlalchemy import JSON, ForeignKey, String
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
@@ -26,7 +26,7 @@ class SystemAgents(TimestampMixin, Base):
nullable=False,
)
config: Mapped[dict] = mapped_column(
JSONB,
JSON().with_variant(JSONB, "postgresql"),
nullable=False,
server_default="{}",
)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+75
View File
@@ -0,0 +1,75 @@
from __future__ import annotations
from typing import Any, cast
from uuid import UUID
from fastapi import Depends
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.queue.tasks import run_command_task
from core.config.settings import config
from core.db import get_db
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
class CeleryQueueClient:
    """Submit agent commands to Celery with optional Redis-backed de-duplication.

    A dedup key is first claimed with an in-flight marker and later overwritten
    with the Celery task id, so repeated requests for the same key get the
    same task id back instead of enqueueing again.
    """

    _INFLIGHT_MARKER = "__inflight__"
    _DEDUP_TTL_SECONDS = 300
    # Bounded wait (attempts * interval = ~1 s) for a concurrent request that
    # holds the in-flight marker to publish its task id.
    _INFLIGHT_POLL_ATTEMPTS = 20
    _INFLIGHT_POLL_INTERVAL_SECONDS = 0.05

    def __init__(self) -> None:
        settings = cast(Any, config)
        self._redis = redis.from_url(settings.redis.url, decode_responses=True)

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Enqueue ``command`` and return the Celery task id.

        Args:
            command: Payload handed to the worker task; copied before use.
            dedup_key: When truthy, requests sharing the key within the TTL
                resolve to the same task id.

        Returns:
            The id of the newly enqueued task, or of the task already
            registered under ``dedup_key``.

        Raises:
            Exception: Whatever the broker or Redis raises. A marker claimed
                by this call is released first so retries are not blocked.
        """
        redis_key = f"agent:dedup:{dedup_key}" if dedup_key else None
        owns_marker = False
        if redis_key is not None:
            owns_marker = bool(
                await self._redis.set(
                    redis_key,
                    self._INFLIGHT_MARKER,
                    nx=True,
                    ex=self._DEDUP_TTL_SECONDS,
                )
            )
            if not owns_marker:
                existing = await self._wait_for_task_id(redis_key)
                if existing:
                    return existing
                # The marker holder never published an id (it presumably
                # failed); take over and enqueue instead of blocking callers
                # for the whole TTL.

        payload = dict(command)
        if dedup_key:
            payload["dedup_key"] = dedup_key

        try:
            # NOTE(review): ``getattr`` kept from the original; presumably it
            # sidesteps static typing of the decorated task -- confirm.
            delay = getattr(run_command_task, "delay")
            result = delay(payload)
        except Exception:
            if owns_marker and redis_key is not None:
                await self._redis.delete(redis_key)
            raise

        task_id = str(result.id)
        if redis_key is not None:
            await self._redis.set(redis_key, task_id, ex=self._DEDUP_TTL_SECONDS)
        return task_id

    async def _wait_for_task_id(self, redis_key: str) -> str | None:
        """Poll ``redis_key`` until it holds a task id; None if it never does."""
        import asyncio

        for _ in range(self._INFLIGHT_POLL_ATTEMPTS):
            existing = await self._redis.get(redis_key)
            if not existing:
                return None  # marker expired or was released
            if existing != self._INFLIGHT_MARKER:
                return existing
            await asyncio.sleep(self._INFLIGHT_POLL_INTERVAL_SECONDS)
        return None
class RedisEventStream:
    """Event-stream adapter backed by ``RedisStreamEventStore``."""

    def __init__(self) -> None:
        settings = cast(Any, config)
        runtime = settings.agent_runtime
        self._store = RedisStreamEventStore(
            client=redis.from_url(settings.redis.url, decode_responses=True),
            stream_prefix=runtime.redis_stream_prefix,
            read_count=runtime.redis_stream_read_count,
            block_ms=runtime.redis_stream_block_ms,
        )

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, Any]]:
        """Read events for ``session_id`` after ``last_event_id``.

        Each returned row is a copy of the store row with a ``cursor`` key set
        to the ``last_event_id`` the read started from.
        """
        fetched = await self._store.read_events(
            session_id=UUID(session_id),
            last_event_id=last_event_id,
        )
        annotated: list[dict[str, Any]] = []
        for entry in fetched:
            annotated.append({**entry, "cursor": last_event_id})
        return annotated
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
    """FastAPI dependency that builds an ``AgentService`` for one request."""
    repository = AgentRepository(session)
    queue_client = CeleryQueueClient()
    event_stream = RedisEventStream()
    return AgentService(
        repository=repository,
        queue=queue_client,
        stream=event_stream,
    )
+56
View File
@@ -0,0 +1,56 @@
from __future__ import annotations
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_session import AgentChatSession
class AgentRepository:
    """Persistence helper for agent chat sessions on one ``AsyncSession``."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    @staticmethod
    def _to_uuid(raw: str, *, detail: str) -> UUID:
        """Parse ``raw`` as a UUID or raise HTTP 422 with ``detail``."""
        try:
            return UUID(raw)
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=detail) from exc

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owning user id of a session as a string.

        Raises:
            HTTPException: 422 for a malformed id, 404 when no row matches.
        """
        session_uuid = self._to_uuid(session_id, detail="Invalid session_id")
        query = select(AgentChatSession.user_id).where(
            AgentChatSession.id == session_uuid
        )
        result = await self._session.execute(query)
        owner_id = result.scalar_one_or_none()
        if owner_id is None:
            raise HTTPException(status_code=404, detail="Session not found")
        return str(owner_id)

    async def create_session_for_user(self, *, user_id: str) -> str:
        """Add a new session for ``user_id``, flush it and return its id.

        Raises:
            HTTPException: 422 when ``user_id`` is not a valid UUID.
        """
        user_uuid = self._to_uuid(user_id, detail="Invalid user_id")
        new_session = AgentChatSession(user_id=user_uuid)
        self._session.add(new_session)
        await self._session.flush()
        await self._session.refresh(new_session)
        return str(new_session.id)

    async def commit(self) -> None:
        """Commit the underlying database session."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying database session."""
        await self._session.rollback()

    async def delete_session(self, *, session_id: str) -> None:
        """Delete the session row if it exists, then flush (no commit).

        Raises:
            HTTPException: 422 when ``session_id`` is not a valid UUID.
        """
        session_uuid = self._to_uuid(session_id, detail="Invalid session_id")
        existing = await self._session.get(AgentChatSession, session_uuid)
        if existing is None:
            return
        await self._session.delete(existing)
        await self._session.flush()
+104
View File
@@ -0,0 +1,104 @@
from __future__ import annotations
from collections.abc import AsyncIterator
import asyncio
from typing import Annotated
from fastapi import APIRouter, Depends, Header, Query, Request, status
from fastapi.responses import StreamingResponse
from core.agent.infrastructure.agui.stream import to_sse_event
from core.auth.models import CurrentUser
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import ResumeRequest, RunRequest, TaskAcceptedResponse
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
@router.post(
    "/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
async def enqueue_run(
    request: RunRequest,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    """Queue a new agent run and answer 202 with the task and session ids."""
    accepted = await service.enqueue_run(
        session_id=request.session_id,
        prompt=request.prompt,
        current_user=current_user,
    )
    response_fields = {
        "task_id": accepted.task_id,
        "session_id": accepted.session_id,
        "created": accepted.created,
    }
    return TaskAcceptedResponse(**response_fields)
@router.post(
    "/runs/{session_id}/resume",
    response_model=TaskAcceptedResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_resume(
    session_id: str,
    request: ResumeRequest,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    """Queue a resume command for a pending tool call and answer 202."""
    accepted = await service.enqueue_resume(
        session_id=session_id,
        tool_call_id=request.tool_call_id,
        current_user=current_user,
    )
    response_fields = {
        "task_id": accepted.task_id,
        "session_id": accepted.session_id,
        "created": accepted.created,
    }
    return TaskAcceptedResponse(**response_fields)
@router.get("/runs/{session_id}/events")
async def stream_events(
    request: Request,
    session_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
    idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
    """Stream a session's events to the client as Server-Sent Events.

    Args:
        request: Used to detect client disconnects.
        session_id: Session whose events are streamed.
        service: Performs the ownership check and reads event rows.
        current_user: Authenticated caller.
        last_event_id: Resume cursor from the ``Last-Event-ID`` header.
        idle_limit: Number of consecutive empty polls after which the stream
            ends (each empty poll sleeps 0.2 s).

    Raises:
        HTTPException: Propagated from the service on the first read, i.e.
            before any response bytes are sent.
    """
    # The first read happens eagerly: the service raises HTTPException for
    # invalid/foreign sessions, and that can only become a proper HTTP status
    # while the response has not started streaming yet.
    first_rows = await service.stream_events(
        session_id=session_id,
        last_event_id=last_event_id,
        current_user=current_user,
    )

    async def _event_iter() -> AsyncIterator[str]:
        cursor = last_event_id
        idle_polls = 0
        rows = first_rows
        while True:
            advanced = False
            for row in rows:
                row_id = str(row.get("id", ""))
                if not row_id:
                    continue
                # Move past every identifiable row, even a malformed one,
                # so a bad entry cannot be re-read forever.
                cursor = row_id
                advanced = True
                event = row.get("event")
                if isinstance(event, dict):
                    yield to_sse_event(row_id, event)

            if advanced:
                idle_polls = 0
            else:
                idle_polls += 1
                yield ": keep-alive\n\n"
                await asyncio.sleep(0.2)

            if idle_polls >= idle_limit or await request.is_disconnected():
                return
            rows = await service.stream_events(
                session_id=session_id,
                last_event_id=cursor,
                current_user=current_user,
            )

    return StreamingResponse(
        _event_iter(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
+18
View File
@@ -0,0 +1,18 @@
from __future__ import annotations
from pydantic import BaseModel, Field
# Request body for POST /agent/runs.
class RunRequest(BaseModel):
    # Existing session to continue; when omitted the service creates a new one.
    session_id: str | None = Field(default=None, min_length=1, max_length=100)
    # User input forwarded to the worker as ``user_input``.
    prompt: str = Field(min_length=1, max_length=5000)
# Request body for POST /agent/runs/{session_id}/resume.
class ResumeRequest(BaseModel):
    # Identifier of the tool call the run should resume from.
    tool_call_id: str = Field(min_length=1, max_length=200)
# 202 response body returned by the run and resume endpoints.
class TaskAcceptedResponse(BaseModel):
    # Id of the queued task.
    task_id: str
    # Session the task operates on.
    session_id: str
    # True when the session was created by this request.
    created: bool
+132
View File
@@ -0,0 +1,132 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
from fastapi import HTTPException
from core.auth.models import CurrentUser
@dataclass(frozen=True)
class TaskAccepted:
    """Immutable result of enqueueing an agent command."""

    # Id returned by the queue client.
    task_id: str
    # Session the command targets.
    session_id: str
    # True when the session was created as part of this request.
    created: bool
class AgentRepositoryLike(Protocol):
    """Structural interface for the persistence operations ``AgentService`` uses."""

    async def get_session_owner(self, *, session_id: str) -> str: ...

    async def create_session_for_user(self, *, user_id: str) -> str: ...

    async def commit(self) -> None: ...

    async def rollback(self) -> None: ...
class QueueClientLike(Protocol):
    """Structural interface for the task queue; ``enqueue`` returns a task id."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str: ...
class EventStreamLike(Protocol):
    """Structural interface for reading a session's event rows after a cursor."""

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, object]]: ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Raise HTTP 403 unless ``current_user`` is the session owner.

    The user's id is compared in its string form against ``owner_id``.
    """
    if str(current_user.id) == owner_id:
        return
    raise HTTPException(status_code=403, detail="Forbidden")
class AgentService:
    """Authorize agent requests and hand them to the queue or event stream."""

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
    ) -> None:
        self._repository = repository
        self._queue = queue
        self._stream = stream

    async def enqueue_run(
        self,
        *,
        session_id: str | None,
        prompt: str,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Enqueue a "run" command, creating a session when none is given.

        Args:
            session_id: Existing session to use, or None to create one owned
                by ``current_user``.
            prompt: Text sent to the worker as ``user_input``.
            current_user: Caller; must own an existing session.

        Returns:
            The accepted task, with ``created`` True for a new session.

        Raises:
            HTTPException: From the repository lookup or the ownership check.
        """
        created = session_id is None
        if session_id is None:
            target_session_id = await self._repository.create_session_for_user(
                user_id=str(current_user.id)
            )
            # Committed before enqueueing, presumably so the worker can see
            # the new session row.
            await self._repository.commit()
        else:
            target_session_id = session_id
            owner = await self._repository.get_session_owner(
                session_id=target_session_id
            )
            ensure_session_owner(owner_id=owner, current_user=current_user)

        # NOTE(review): if enqueue raises after the commit above, the freshly
        # created session stays in the database without a task. Errors are
        # propagated unchanged, as before.
        task_id = await self._queue.enqueue(
            command={
                "command": "run",
                "session_id": target_session_id,
                "user_input": prompt,
            },
            dedup_key=None,
        )
        return TaskAccepted(
            task_id=task_id, session_id=target_session_id, created=created
        )

    async def enqueue_resume(
        self,
        *,
        session_id: str,
        tool_call_id: str,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Enqueue a "resume" command for a session owned by the caller.

        The queue is given the dedup key ``resume:{session_id}:{tool_call_id}``
        so repeated resumes of the same tool call can share one task.

        Raises:
            HTTPException: From the repository lookup or the ownership check.
        """
        owner = await self._repository.get_session_owner(session_id=session_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        task_id = await self._queue.enqueue(
            command={
                "command": "resume",
                "session_id": session_id,
                "tool_call_id": tool_call_id,
            },
            dedup_key=f"resume:{session_id}:{tool_call_id}",
        )
        return TaskAccepted(task_id=task_id, session_id=session_id, created=False)

    async def stream_events(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Return event rows after ``last_event_id`` for an owned session.

        Raises:
            HTTPException: From the repository lookup or the ownership check.
        """
        owner = await self._repository.get_session_owner(session_id=session_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        return await self._stream.read(
            session_id=session_id,
            last_event_id=last_event_id,
        )
+2
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
from fastapi import APIRouter
from core.http.models import HealthResponse
from v1.agent.router import router as agent_router
from v1.auth.router import router as auth_router
from v1.friendships.router import router as friendships_router
from v1.inbox_messages.router import router as inbox_messages_router
@@ -13,6 +14,7 @@ from v1.users.router import router as users_router
router = APIRouter(prefix="/api/v1")
router.include_router(auth_router)
router.include_router(agent_router)
router.include_router(friendships_router)
router.include_router(infra_router)
router.include_router(users_router)
@@ -0,0 +1,97 @@
from __future__ import annotations
import os
from datetime import datetime, timedelta, timezone
from uuid import UUID
import httpx
import jwt
import pytest
from sqlalchemy import select
from core.config import config
from core.db.session import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.profile import Profile
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
async def _owner_id() -> UUID:
    """Return the id of one existing profile, or fail if there is none."""
    async with AsyncSessionLocal() as session:
        result = await session.execute(select(Profile.id).limit(1))
        profile_id = result.scalar_one_or_none()
        if profile_id is not None:
            return profile_id
        raise RuntimeError("profile owner not found")
def _jwt_for(user_id: UUID) -> str:
    """Sign a 30-minute HS256 token for ``user_id`` with the configured secret.

    Raises:
        RuntimeError: When no JWT secret is configured.
    """
    signing_secret = config.supabase.jwt_secret
    if not signing_secret:
        raise RuntimeError("JWT secret not configured")
    base_url = config.supabase.public_url.rstrip("/")
    claims = {
        "sub": str(user_id),
        "role": "authenticated",
        "aud": "authenticated",
        "iss": f"{base_url}/auth/v1",
        "iat": datetime.now(timezone.utc),
        "exp": datetime.now(timezone.utc) + timedelta(minutes=30),
    }
    return jwt.encode(claims, signing_secret, algorithm="HS256")
@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_closed_loop_live() -> None:
    """Live end-to-end check: POST a run, follow its SSE stream, verify the DB.

    Skipped unless ``AGENT_LIVE_E2E=1``; talks to the server at ``BASE_URL``.
    """
    if os.getenv("AGENT_LIVE_E2E") != "1":
        pytest.skip("set AGENT_LIVE_E2E=1 to run live closed-loop test")

    owner_id = await _owner_id()
    token = _jwt_for(owner_id)
    headers = {"Authorization": f"Bearer {token}"}
    terminal_events = {"RUN_FINISHED", "RUN_ERROR"}

    async with httpx.AsyncClient(timeout=30.0) as client:
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={"prompt": "请用一句话介绍你自己"},
        )
        assert run_resp.status_code == 202
        accepted = run_resp.json()
        session_id = str(accepted["session_id"])
        assert session_id

        events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
        event_names: list[str] = []
        async with client.stream(
            "GET", events_url, headers=headers, timeout=20.0
        ) as sse_resp:
            assert sse_resp.status_code == 200
            assert sse_resp.headers.get("content-type", "").startswith(
                "text/event-stream"
            )
            async for line in sse_resp.aiter_lines():
                if not line.startswith("event:"):
                    continue
                event_name = line.split(":", 1)[1].strip()
                event_names.append(event_name)
                # Stop at the terminal event; otherwise the loop only ends
                # when the server gives up after its idle limit.
                if event_name in terminal_events:
                    break

    assert "RUN_STARTED" in event_names
    assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names

    async with AsyncSessionLocal() as session:
        session_row = await session.get(AgentChatSession, UUID(session_id))
        assert session_row is not None
        assert session_row.message_count >= 1
        assert session_row.total_tokens >= 0
        assert session_row.total_cost >= 0
        rows = await session.execute(
            select(AgentChatMessage).where(
                AgentChatMessage.session_id == UUID(session_id)
            )
        )
        assert len(list(rows.scalars().all())) >= 1
@@ -0,0 +1,213 @@
from __future__ import annotations
import uuid
from decimal import Decimal
import pytest
from sqlalchemy import delete, select
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.persistence.session_repository import SessionRepository
from core.agent.infrastructure.queue.tasks import run_agent_task
from core.db import AsyncSessionLocal, engine
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.profile import Profile
from models.system_agents import SystemAgents
@pytest.mark.asyncio
async def test_run_then_resume_persists_messages_and_session_state(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Run then resume a session against the database and verify what is stored.

    The CrewAI runtime is replaced by a canned answer. Every row this test
    seeds (session, messages, system agent, and the LLM/factory pair when it
    had to create one) is removed again in the ``finally`` block.
    """

    def _fake_execute(self, *, user_input: str) -> dict[str, object]:
        del user_input
        return {
            "assistant_text": "Mocked answer",
            "prompt_tokens": 11,
            "completion_tokens": 7,
            "total_tokens": 18,
            "cost": 0.0025,
            "agui_events": [
                {"type": "TEXT_MESSAGE_START", "data": {"session_id": "__TBD__"}},
                {
                    "type": "TEXT_MESSAGE_CONTENT",
                    "data": {"session_id": "__TBD__", "text": "Mocked answer"},
                },
                {"type": "TEXT_MESSAGE_END", "data": {"session_id": "__TBD__"}},
            ],
        }

    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
        _fake_execute,
    )

    async with AsyncSessionLocal() as lookup_session:
        existing_owner = await lookup_session.execute(
            select(AgentChatSession.user_id).limit(1)
        )
        owner_id = existing_owner.scalar_one_or_none()
        if owner_id is None:
            pytest.skip("No existing session owner available in local database")

    factory_id = uuid.uuid4()
    session_uuid = uuid.uuid4()
    agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
    seeded_llm = False  # True when this test had to create the LLM rows

    async with AsyncSessionLocal() as seed_session:
        llm_row = await seed_session.execute(select(Llm.id).limit(1))
        llm_id = llm_row.scalar_one_or_none()
        if llm_id is None:
            seeded_llm = True
            seed_session.add(
                LlmFactory(
                    id=factory_id,
                    name=f"dashscope-test-{uuid.uuid4().hex[:8]}",
                    request_url="https://dashscope.example",
                )
            )
            llm_id = uuid.uuid4()
            seed_session.add(
                Llm(
                    id=llm_id,
                    factory_id=factory_id,
                    model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
                )
            )
        seed_session.add(
            SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
        )
        seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
        await seed_session.commit()

    published: list[str] = []

    def _publish(event_type: str, payload: dict[str, object]) -> None:
        del payload
        published.append(event_type)

    try:
        run_result = run_agent_task(
            {
                "command": "run",
                "session_id": str(session_uuid),
                "user_input": "hello",
            },
            publish_event=_publish,
            run_service=RunService(),
            resume_service=ResumeService(),
        )
        pending_tool_call_id = str(run_result["pending_tool_call_id"])
        run_agent_task(
            {
                "command": "resume",
                "session_id": str(session_uuid),
                "tool_call_id": pending_tool_call_id,
            },
            publish_event=_publish,
            run_service=RunService(),
            resume_service=ResumeService(),
        )
        await engine.dispose()

        async with AsyncSessionLocal() as verify_session:
            db_session = await verify_session.get(AgentChatSession, session_uuid)
            assert db_session is not None
            assert db_session.status == AgentChatSessionStatus.COMPLETED
            assert db_session.message_count == 4
            assert db_session.total_tokens == 18
            assert db_session.total_cost == Decimal("0.002500")
            assert db_session.state_snapshot == {
                "status": "completed",
                "pending_tool_call_id": None,
            }
            rows = await verify_session.execute(
                select(AgentChatMessage)
                .where(AgentChatMessage.session_id == session_uuid)
                .order_by(AgentChatMessage.seq.asc())
            )
            messages = list(rows.scalars().all())
            assert [item.role for item in messages] == [
                AgentChatMessageRole.USER,
                AgentChatMessageRole.ASSISTANT,
                AgentChatMessageRole.TOOL,
                AgentChatMessageRole.ASSISTANT,
            ]
            assert messages[1].input_tokens == 11
            assert messages[1].output_tokens == 7
            assert messages[1].cost == Decimal("0.002500")

        assert "RUN_STARTED" in published
        assert "RUN_RESUMED" in published
        assert "TEXT_MESSAGE_CONTENT" in published
    finally:
        async with AsyncSessionLocal() as cleanup_session:
            # Children first, mirroring the sibling soft-delete test: messages
            # before the session, the agent before the LLM it points at.
            await cleanup_session.execute(
                delete(AgentChatMessage).where(
                    AgentChatMessage.session_id == session_uuid
                )
            )
            await cleanup_session.execute(
                delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
            )
            await cleanup_session.execute(
                delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
            )
            if seeded_llm:
                await cleanup_session.execute(delete(Llm).where(Llm.id == llm_id))
                await cleanup_session.execute(
                    delete(LlmFactory).where(LlmFactory.id == factory_id)
                )
            await cleanup_session.commit()
@pytest.mark.asyncio
async def test_soft_delete_session_cascades_to_messages() -> None:
    """Soft-deleting a session must also stamp ``deleted_at`` on its messages."""
    session_uuid = uuid.uuid4()
    await engine.dispose()

    async with AsyncSessionLocal() as lookup_session:
        profile_result = await lookup_session.execute(select(Profile.id).limit(1))
        owner_id = profile_result.scalar_one_or_none()
        if owner_id is None:
            pytest.skip("No profile owner available in local database")

    async with AsyncSessionLocal() as seed_session:
        seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
        await seed_session.flush()
        first_message = AgentChatMessage(
            session_id=session_uuid,
            seq=1,
            role=AgentChatMessageRole.USER,
            content="hello",
        )
        seed_session.add(first_message)
        await seed_session.commit()

    messages_of_session = AgentChatMessage.session_id == session_uuid
    try:
        async with AsyncSessionLocal() as mutate_session:
            repository = SessionRepository(mutate_session)
            affected = await repository.soft_delete_session_with_messages(
                session_id=session_uuid
            )
            await mutate_session.commit()
        assert affected == 1

        async with AsyncSessionLocal() as verify_session:
            stored_session = await verify_session.get(AgentChatSession, session_uuid)
            assert stored_session is not None
            assert stored_session.deleted_at is not None
            message_rows = await verify_session.execute(
                select(AgentChatMessage).where(
                    AgentChatMessage.session_id == session_uuid
                )
            )
            stored_messages = list(message_rows.scalars().all())
            assert len(stored_messages) == 1
            assert stored_messages[0].deleted_at is not None
    finally:
        async with AsyncSessionLocal() as cleanup_session:
            await cleanup_session.execute(
                delete(AgentChatMessage).where(messages_of_session)
            )
            await cleanup_session.execute(
                delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
            )
            await cleanup_session.commit()
@@ -0,0 +1,69 @@
from __future__ import annotations

import pytest

from core.agent.application.session_state_persistence import persist_tool_result_payload
from core.agent.domain.tool_correlation import reconstruct_tool_call_result_event
from core.agent.infrastructure.queue.tasks import run_agent_task
class _FakeStorage:
    """In-memory storage double that records uploads under ``bucket/path``."""

    def __init__(self) -> None:
        self.writes: dict[str, dict[str, object]] = {}

    async def upload_json(
        self, *, bucket: str, path: str, payload: dict[str, object]
    ) -> str:
        """Remember ``payload`` for the object key and return a fixed etag."""
        object_key = f"{bucket}/{path}"
        self.writes[object_key] = payload
        return "etag-1"
def test_closed_loop_run_flow_frontend_to_sse() -> None:
    """A "run" command returns the service result, bracketed by start/finish events."""
    session_id = "00000000-0000-0000-0000-000000000001"
    published: list[str] = []

    class _FakeRunService:
        async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
            return {"session_id": session_id, "user_input": user_input}

    def _publish(event_type: str, payload: dict[str, object]) -> None:
        del payload
        published.append(event_type)

    command = {
        "command": "run",
        "session_id": session_id,
        "user_input": "hello",
    }
    result = run_agent_task(
        command,
        publish_event=_publish,
        run_service=_FakeRunService(),
    )

    assert result["session_id"] == session_id
    first_event, last_event = published[0], published[-1]
    assert first_event == "RUN_STARTED"
    assert last_event == "RUN_FINISHED"
@pytest.mark.asyncio
async def test_tool_result_full_payload_persist_and_reconstruct() -> None:
    """Persisted tool-result metadata can be turned back into a TOOL_CALL_RESULT event.

    Marked for pytest-asyncio explicitly, like the other async tests, so the
    coroutine is actually awaited regardless of the plugin's mode.
    """
    storage = _FakeStorage()
    payload = {
        "schema": "ui.v1",
        "components": [{"type": "card", "title": "Weather"}],
    }

    metadata = await persist_tool_result_payload(
        storage=storage,
        run_id="run-1",
        turn_id="turn-1",
        tool_call_id="call-1",
        tool_name="weather",
        payload=payload,
        bucket="private",
        path="tool-results/run-1/call-1.json",
    )
    event = reconstruct_tool_call_result_event(metadata=metadata, payload=payload)

    assert metadata["type"] == "tool_result"
    assert metadata["storage_bucket"] == "private"
    assert event["type"] == "TOOL_CALL_RESULT"
    assert event["data"]["schema"] == "ui.v1"
@@ -0,0 +1,107 @@
from __future__ import annotations
from types import SimpleNamespace
from uuid import uuid4
from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.agent.dependencies import get_agent_service
from v1.users.dependencies import get_current_user
class _FakeAgentService:
    """Stand-in for ``AgentService`` returning canned results to the router."""

    def __init__(self) -> None:
        self._stream_called = False

    async def enqueue_run(
        self, *, session_id: str | None, prompt: str, current_user: CurrentUser
    ):
        """Accept a run; a missing session id is reported as newly created."""
        del prompt, current_user
        was_created = session_id is None
        return SimpleNamespace(
            task_id="task-run-1",
            session_id=session_id or "auto-created-session",
            created=was_created,
        )

    async def enqueue_resume(
        self,
        *,
        session_id: str,
        tool_call_id: str,
        current_user: CurrentUser,
    ):
        """Accept a resume for the given session."""
        del tool_call_id, current_user
        return SimpleNamespace(
            task_id="task-resume-1", session_id=session_id, created=False
        )

    async def stream_events(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Yield one RUN_STARTED row on the first call and nothing afterwards."""
        del session_id, current_user
        if not self._stream_called:
            self._stream_called = True
            return [
                {"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
            ]
        return []
def test_run_requires_auth_and_returns_202_task_id() -> None:
    """POST /agent/runs is 401 without a user and 202 with one."""
    runs_url = "/api/v1/agent/runs"
    existing_session_body = {"session_id": "session-1", "prompt": "hello"}

    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    client = TestClient(app)
    try:
        unauthorized = client.post(runs_url, json=existing_session_body)
        assert unauthorized.status_code == 401

        app.dependency_overrides[get_current_user] = lambda: CurrentUser(
            id=uuid4(), email="user@example.com"
        )

        authorized = client.post(runs_url, json=existing_session_body)
        assert authorized.status_code == 202
        authorized_body = authorized.json()
        assert authorized_body["task_id"] == "task-run-1"
        assert authorized_body["created"] is False

        first_chat = client.post(runs_url, json={"prompt": "hello"})
        assert first_chat.status_code == 202
        first_chat_body = first_chat.json()
        assert first_chat_body["session_id"] == "auto-created-session"
        assert first_chat_body["created"] is True
    finally:
        app.dependency_overrides = {}
def test_stream_reads_from_last_event_id() -> None:
    """The events endpoint answers as SSE and emits the row after the cursor."""
    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    client = TestClient(app)
    try:
        response = client.get(
            "/api/v1/agent/runs/session-1/events?idle_limit=1",
            headers={"Last-Event-ID": "1-0"},
        )
        assert response.status_code == 200
        content_type = response.headers["content-type"]
        assert content_type.startswith("text/event-stream")
        body = response.text
        assert "id: 2-0" in body
        assert "event: RUN_STARTED" in body
    finally:
        app.dependency_overrides = {}
@@ -0,0 +1,138 @@
from __future__ import annotations
import pytest
from core.agent.infrastructure.agui.bridge import to_agui_events
from core.agent.infrastructure.agui.stream import to_sse_event
def test_bridge_normalizes_event_type_to_upper_snake() -> None:
    """A camelCase event type comes out as UPPER_SNAKE_CASE."""
    (normalized,) = to_agui_events([{"type": "runStarted", "data": {"ok": True}}])[:1]
    assert normalized["type"] == "RUN_STARTED"
def test_bridge_supports_core_agui_event_taxonomy() -> None:
    """Every core AG-UI event type is accepted and normalized, in order."""
    expected_by_input = [
        ("runStarted", "RUN_STARTED"),
        ("runFinished", "RUN_FINISHED"),
        ("stepStarted", "STEP_STARTED"),
        ("stepFinished", "STEP_FINISHED"),
        ("textMessageStart", "TEXT_MESSAGE_START"),
        ("textMessageContent", "TEXT_MESSAGE_CONTENT"),
        ("textMessageEnd", "TEXT_MESSAGE_END"),
        ("toolCallStart", "TOOL_CALL_START"),
        ("toolCallArgs", "TOOL_CALL_ARGS"),
        ("toolCallEnd", "TOOL_CALL_END"),
        ("toolCallResult", "TOOL_CALL_RESULT"),
        ("stateSnapshot", "STATE_SNAPSHOT"),
        ("stateDelta", "STATE_DELTA"),
        ("reasoningMessageStart", "REASONING_MESSAGE_START"),
        ("reasoningMessageContent", "REASONING_MESSAGE_CONTENT"),
        ("reasoningMessageEnd", "REASONING_MESSAGE_END"),
    ]
    raw_events = [{"type": camel, "data": {}} for camel, _ in expected_by_input]

    normalized = to_agui_events(raw_events)

    assert [event["type"] for event in normalized] == [
        upper for _, upper in expected_by_input
    ]
def test_bridge_preserves_common_agui_fields() -> None:
    """Envelope fields (id, run_id, timestamp, parent_message_id) pass through."""
    source_event = {
        "type": "toolCallResult",
        "id": "evt-1",
        "run_id": "run-1",
        "timestamp": "2026-03-05T12:00:00Z",
        "parent_message_id": "msg-1",
        "data": {"ok": True},
    }

    converted = to_agui_events([source_event])[0]

    assert converted["type"] == "TOOL_CALL_RESULT"
    assert converted["id"] == "evt-1"
    assert converted["run_id"] == "run-1"
    assert converted["timestamp"] == "2026-03-05T12:00:00Z"
    assert converted["parent_message_id"] == "msg-1"
def test_bridge_rejects_empty_event_type() -> None:
    """An empty ``type`` is rejected with ValueError."""
    untyped_event = {"type": "", "data": {}}
    with pytest.raises(ValueError):
        to_agui_events([untyped_event])
def test_bridge_rejects_non_object_data() -> None:
    """A non-dict ``data`` value is rejected with ValueError."""
    bad_event = {"type": "runStarted", "data": "not-object"}
    with pytest.raises(ValueError):
        to_agui_events([bad_event])
def test_bridge_redacts_sensitive_fields_in_data() -> None:
    """Sensitive keys are masked at any nesting depth; other keys are kept."""
    event = {
        "type": "toolCallArgs",
        "data": {
            "api_key": "k-1",
            "payload": {"authorization": "Bearer x"},
            "safe": "ok",
        },
    }

    redacted = to_agui_events([event])[0]["data"]

    assert redacted["api_key"] == "***REDACTED***"
    assert redacted["payload"]["authorization"] == "***REDACTED***"
    assert redacted["safe"] == "ok"
def test_bridge_redacts_sensitive_key_variants() -> None:
    """Dashed, snake_case and camelCase spellings of secret keys are all masked."""
    event = {
        "type": "toolCallArgs",
        "data": {
            "x-api-key": "k-2",
            "auth_token": "t-1",
            "openaiApiKey": "k-3",
        },
    }

    redacted = to_agui_events([event])[0]["data"]

    assert redacted["x-api-key"] == "***REDACTED***"
    assert redacted["auth_token"] == "***REDACTED***"
    assert redacted["openaiApiKey"] == "***REDACTED***"
def test_bridge_rejects_unknown_event_type() -> None:
    """A type outside the known taxonomy is rejected with ValueError."""
    unknown_event = {"type": "NOT_A_REAL_EVENT", "data": {}}
    with pytest.raises(ValueError):
        to_agui_events([unknown_event])
def test_sse_format_includes_id_event_data() -> None:
    """The SSE frame starts with ``id``, ``event`` and a JSON ``data`` line."""
    event = {"type": "RUN_STARTED", "data": {"a": 1}}
    frame = to_sse_event(stream_id="1-0", event=event)
    expected_prefix = "id: 1-0\nevent: RUN_STARTED\ndata: {"
    assert frame.startswith(expected_prefix)
@@ -0,0 +1,96 @@
from __future__ import annotations
import pytest
from types import SimpleNamespace
from pytest import MonkeyPatch
from core.agent.infrastructure.config.resolver import AgentConfigResolver
from core.config.settings import Settings
def test_runtime_raises_if_model_or_api_key_missing() -> None:
    """Resolution fails without a model code, and without a provider key."""
    runtime_settings = SimpleNamespace(default_model_code="", streaming_enabled=True)
    llm_settings = SimpleNamespace(provider_keys={})
    resolver = AgentConfigResolver(
        settings=SimpleNamespace(agent_runtime=runtime_settings, llm=llm_settings)
    )

    with pytest.raises(ValueError):
        resolver.resolve(model_code="", provider_name="dashscope")
    with pytest.raises(ValueError):
        resolver.resolve(model_code="gpt-4o-mini", provider_name="dashscope")
def test_runtime_reads_provider_api_key_from_settings() -> None:
    """An empty model code falls back to the default; the key comes from settings."""
    runtime_settings = SimpleNamespace(
        default_model_code="gpt-4o-mini",
        streaming_enabled=True,
    )
    llm_settings = SimpleNamespace(provider_keys={"dashscope": "env-like-api-key"})
    resolver = AgentConfigResolver(
        settings=SimpleNamespace(agent_runtime=runtime_settings, llm=llm_settings)
    )

    resolved = resolver.resolve(model_code="", provider_name="dashscope")

    assert resolved.model_code == "gpt-4o-mini"
    assert resolved.provider_api_key == "env-like-api-key"
def test_runtime_reads_provider_api_key_from_env(monkeypatch: MonkeyPatch) -> None:
    """A key supplied through the SOCIAL_LLM__PROVIDER_KEYS__* variable is resolved."""
    # The variable must be set before Settings() is instantiated.
    monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE", "env-key")
    real_settings = Settings()
    config = AgentConfigResolver(settings=real_settings).resolve(
        model_code="gpt-4o-mini", provider_name="dashscope"
    )
    assert config.provider_api_key == "env-key"
def test_runtime_supports_provider_alias_to_env_key() -> None:
    """Provider name "volcengine-ark" must resolve to the key stored under "ark"."""
    runtime_cfg = SimpleNamespace(
        default_model_code="deepseek-v3.2",
        streaming_enabled=True,
    )
    llm_cfg = SimpleNamespace(provider_keys={"ark": "ark-key"})
    resolver = AgentConfigResolver(
        settings=SimpleNamespace(agent_runtime=runtime_cfg, llm=llm_cfg)
    )
    config = resolver.resolve(model_code="", provider_name="volcengine-ark")
    assert config.provider_api_key == "ark-key"
def test_runtime_rejects_unsupported_provider() -> None:
    """A provider name that is not recognised must raise ValueError."""
    fake_settings = SimpleNamespace(
        agent_runtime=SimpleNamespace(
            default_model_code="qwen3.5-flash", streaming_enabled=True
        ),
        llm=SimpleNamespace(provider_keys={"dashscope": "dash-key"}),
    )
    resolver = AgentConfigResolver(settings=fake_settings)
    # A valid model default and a valid key exist; only the provider name is bad.
    with pytest.raises(ValueError):
        resolver.resolve(model_code="", provider_name="unknown-provider")
def test_runtime_config_repr_does_not_expose_api_key() -> None:
    """repr() of the resolved config must not contain the raw provider key."""
    secret = "very-secret-key"
    runtime_cfg = SimpleNamespace(
        default_model_code="qwen3.5-flash", streaming_enabled=True
    )
    llm_cfg = SimpleNamespace(provider_keys={"dashscope": secret})
    resolver = AgentConfigResolver(
        settings=SimpleNamespace(agent_runtime=runtime_cfg, llm=llm_cfg)
    )
    config = resolver.resolve(model_code="", provider_name="dashscope")
    rendered = repr(config)
    assert "very-secret-key" not in rendered
@@ -0,0 +1,97 @@
from __future__ import annotations
from types import SimpleNamespace
from core.agent.infrastructure.config.resolver import AgentConfigResolver
from core.agent.infrastructure.crewai.runtime import CrewAIRuntime
def test_runtime_emits_text_tool_reasoning_events() -> None:
    """map_events must translate camelCase runtime event types, preserving order."""
    fake_settings = SimpleNamespace(
        agent_runtime=SimpleNamespace(
            default_model_code="",
            streaming_enabled=True,
        ),
        llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
    )
    runtime = CrewAIRuntime(
        resolver=AgentConfigResolver(settings=fake_settings),
        model_code="gpt-4o-mini",
        provider_name="dashscope",
    )

    raw_events = [
        {"type": "textMessageContent", "data": {"text": "hello"}},
        {"type": "toolCallStart", "data": {"tool_name": "weather"}},
        {"type": "toolCallResult", "data": {"ok": True}},
        {"type": "reasoningMessageContent", "data": {"text": "thinking"}},
        {"type": "runFinished", "data": {"status": "completed"}},
    ]
    expected_types = [
        "TEXT_MESSAGE_CONTENT",
        "TOOL_CALL_START",
        "TOOL_CALL_RESULT",
        "REASONING_MESSAGE_CONTENT",
        "RUN_FINISHED",
    ]

    mapped = runtime.map_events(raw_events)
    assert [item["type"] for item in mapped] == expected_types
def test_runtime_execute_uses_provider_prefixed_litellm_model(
    monkeypatch,
) -> None:
    """execute() must call the completion helper with "<provider>/<model>" and the key."""
    seen: dict[str, object] = {}

    def _stub_completion(
        *, model: str, api_key: str, messages: list[dict[str, object]]
    ):
        # Record what the runtime passed in, then answer with a minimal response.
        seen.update(model=model, api_key=api_key, messages=messages)
        choice = {"message": {"content": "hello"}}
        return {"choices": [choice], "usage": {}}

    def _stub_usage(_response):
        # Fixed usage figures; the stubbed response carries an empty "usage" dict.
        return SimpleNamespace(
            prompt_tokens=1,
            completion_tokens=2,
            total_tokens=3,
            cost=0.001,
        )

    # Patch the names as they are looked up inside the runtime module.
    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.runtime.run_completion",
        _stub_completion,
    )
    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.runtime.extract_usage_and_cost",
        _stub_usage,
    )

    fake_settings = SimpleNamespace(
        agent_runtime=SimpleNamespace(
            default_model_code="",
            streaming_enabled=True,
        ),
        llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
    )
    runtime = CrewAIRuntime(
        resolver=AgentConfigResolver(settings=fake_settings),
        model_code="qwen3.5-flash",
        provider_name="dashscope",
    )

    outcome = runtime.execute(user_input="hi")

    assert seen["model"] == "dashscope/qwen3.5-flash"
    assert seen["api_key"] == "env-api-key"
    assert outcome["assistant_text"] == "hello"
@@ -0,0 +1,61 @@
from __future__ import annotations
import pytest
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
def test_usage_tracker_extracts_tokens_and_cost(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Token counts come from response["usage"]; cost comes from completion_cost."""

    def _fixed_cost(completion_response):
        # Stand-in for the cost helper: always report the same price.
        return 0.123

    monkeypatch.setattr(
        "core.agent.infrastructure.litellm.usage_tracker.completion_cost",
        _fixed_cost,
    )

    token_counts = {"prompt_tokens": 11, "completion_tokens": 7, "total_tokens": 18}
    report = extract_usage_and_cost({"usage": token_counts})

    assert report.prompt_tokens == 11
    assert report.completion_tokens == 7
    assert report.total_tokens == 18
    assert report.cost == 0.123
@pytest.mark.parametrize(
    ("prompt_tokens", "completion_tokens", "expected_cost"),
    [
        (128000, 1000, 0.0276),
        (200000, 1000, 0.168),
        (300000, 1000, 0.372),
    ],
)
def test_usage_tracker_falls_back_to_local_qwen35_pricing_when_model_unmapped(
    monkeypatch: pytest.MonkeyPatch,
    prompt_tokens: int,
    completion_tokens: int,
    expected_cost: float,
) -> None:
    """When completion_cost raises, cost is computed locally and tagged "custom_pricing".

    The three parameter sets use different prompt sizes with the same
    completion size; the expected costs are the fixtures' stated values.
    """

    def _always_unmapped(*, completion_response: object) -> float:
        # Simulate the cost helper failing for a model it has no price entry for.
        del completion_response
        raise Exception("This model isn't mapped yet")

    monkeypatch.setattr(
        "core.agent.infrastructure.litellm.usage_tracker.completion_cost",
        _always_unmapped,
    )

    token_counts = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    report = extract_usage_and_cost(
        {"model": "dashscope/qwen3.5-flash", "usage": token_counts}
    )

    assert report.cost == pytest.approx(expected_cost)
    assert report.cost_source == "custom_pricing"
@@ -0,0 +1,67 @@
from __future__ import annotations
import pytest
from core.agent.infrastructure.queue.tasks import run_agent_task
class _FakeRunService:
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
return {"session_id": session_id, "user_input": user_input}
class _FakeResumeService:
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
return {"session_id": session_id, "tool_call_id": tool_call_id}
def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
    """A successful "run" command publishes start, runtime and finish events in order."""
    session_id = "00000000-0000-0000-0000-000000000001"
    published: list[str] = []

    def _record(event_type: str, payload: dict[str, object]) -> None:
        # Only the sequence of event types is checked; the payload is discarded.
        del payload
        published.append(event_type)

    command: dict[str, object] = {
        "command": "run",
        "session_id": session_id,
        "user_input": "hello",
    }
    outcome = run_agent_task(
        command,
        publish_event=_record,
        run_service=_FakeRunService(),
        resume_service=_FakeResumeService(),
    )

    assert outcome["session_id"] == session_id
    assert published == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
def test_run_agent_task_emits_error_event_on_exception() -> None:
    """If the run service raises, the error propagates and RUN_ERROR is published."""
    session_id = "00000000-0000-0000-0000-000000000001"

    class _BrokenRunService(_FakeRunService):
        """Run service whose run() always fails."""

        async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
            del session_id, user_input
            raise RuntimeError("boom")

    published: list[str] = []

    def _record(event_type: str, payload: dict[str, object]) -> None:
        # Only the sequence of event types is checked; the payload is discarded.
        del payload
        published.append(event_type)

    command: dict[str, object] = {
        "command": "run",
        "session_id": session_id,
        "user_input": "hello",
    }
    with pytest.raises(RuntimeError):
        run_agent_task(
            command,
            publish_event=_record,
            run_service=_BrokenRunService(),
            resume_service=_FakeResumeService(),
        )

    # Checked after the with-block so the assertion actually executes.
    assert published == ["RUN_STARTED", "RUN_ERROR"]
@@ -0,0 +1,57 @@
from __future__ import annotations
from uuid import uuid4
import pytest
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
class _FakeRedisClient:
def __init__(self) -> None:
self.calls: list[tuple[str, dict[str, str]]] = []
def xadd(self, stream: str, fields: dict[str, str]) -> str:
self.calls.append((stream, fields))
return "1-0"
async def xread(
self,
streams: dict[str, str],
count: int,
block: int,
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
del count, block
key, start_id = next(iter(streams.items()))
if start_id == "$":
return [(key, [("11-0", {"event": '{"type":"RUN_STARTED"}'})])]
return [(key, [("12-0", {"event": '{"type":"RUN_FINISHED"}'})])]
def test_append_event_writes_json_payload() -> None:
    """append_event_sync must xadd compact JSON to "<prefix>:<session_id>"."""
    fake_client = _FakeRedisClient()
    session_id = uuid4()
    store = RedisStreamEventStore(client=fake_client, stream_prefix="agent:events")

    returned_id = store.append_event_sync(
        session_id=session_id, event={"type": "RUN_STARTED"}
    )

    # The store hands back whatever id the client's xadd produced.
    assert returned_id == "1-0"
    assert len(fake_client.calls) == 1
    target_stream, written = fake_client.calls[0]
    assert target_stream == f"agent:events:{session_id}"
    # The expected payload has no whitespace between JSON tokens.
    assert written["event"] == '{"type":"RUN_STARTED"}'
@pytest.mark.asyncio
async def test_read_events_respects_last_event_id() -> None:
    """read_events must pass the caller's cursor through to the client's xread.

    The fake client answers with entry "11-0" only for start id "$" and with
    "12-0" otherwise, so the two assertions distinguish the cursor values.
    """
    fake_client = _FakeRedisClient()
    session_id = uuid4()
    store = RedisStreamEventStore(client=fake_client, stream_prefix="agent:events")

    without_cursor = await store.read_events(session_id=session_id, last_event_id=None)
    with_cursor = await store.read_events(session_id=session_id, last_event_id="11-0")

    assert without_cursor[0]["id"] == "11-0"
    assert with_cursor[0]["id"] == "12-0"
@@ -0,0 +1,22 @@
from __future__ import annotations
import pytest
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
@pytest.mark.asyncio
async def test_run_service_rejects_invalid_session_id() -> None:
    """RunService.run must raise ValueError for the session id "session-1"."""
    service = RunService()
    bad_session_id = "session-1"
    with pytest.raises(ValueError):
        await service.run(session_id=bad_session_id, user_input="hello")
@pytest.mark.asyncio
async def test_resume_service_requires_pending_tool_call() -> None:
    """ResumeService.resume must raise ValueError for these ids on a fresh service."""
    service = ResumeService()
    request = {"session_id": "session-1", "tool_call_id": "call-1"}
    with pytest.raises(ValueError):
        await service.resume(**request)
@@ -0,0 +1,12 @@
from __future__ import annotations
from core.agent.domain.state_snapshot import AgentStateSnapshot
def test_state_snapshot_serialization_round_trip() -> None:
    """Dump a snapshot to a dict and rebuild an equal snapshot from that dict.

    The dump is checked field by field, then fed back through
    ``model_validate`` so the test covers both directions of the round trip
    its name promises.
    """
    snapshot = AgentStateSnapshot(status="running", pending_tool_call_id="call-1")

    payload = snapshot.model_dump()
    assert payload["status"] == "running"
    assert payload["pending_tool_call_id"] == "call-1"

    # Second leg: the dumped payload must reconstruct an equal snapshot.
    restored = AgentStateSnapshot.model_validate(payload)
    assert restored == snapshot
@@ -0,0 +1,20 @@
from __future__ import annotations
from core.agent.domain.tool_correlation import build_tool_result_metadata
def test_tool_correlation_builds_tool_result_metadata() -> None:
    """The built metadata is typed "tool_result" and keeps the tool call id."""
    fields = {
        "run_id": "run-1",
        "turn_id": "turn-1",
        "tool_call_id": "call-1",
        "tool_name": "weather",
        "storage_bucket": "private",
        "storage_path": "tool-results/run-1/call-1.json",
        "payload_sha256": "sha256",
        "payload_bytes": 128,
        "payload_format": "json",
    }
    built = build_tool_result_metadata(**fields)
    assert built["type"] == "tool_result"
    assert built["tool_call_id"] == "call-1"
@@ -0,0 +1,29 @@
from __future__ import annotations
from pathlib import Path
def test_session_has_state_snapshot_and_status_contract() -> None:
    """The session model source must mention state_snapshot and the status enum.

    This is a textual check on the model file, not an import of the model.
    """
    project_root = Path(__file__).resolve().parents[3]
    model_file = project_root / "src" / "models" / "agent_chat_session.py"
    source_text = model_file.read_text(encoding="utf-8")
    for required in ("state_snapshot", "AgentChatSessionStatus"):
        assert required in source_text
def test_message_has_token_cost_and_metadata_contract() -> None:
    """The message model source names the token/cost/metadata fields, and the
    closed-loop contract migration file exists.

    Both checks work on files on disk; nothing is imported or executed.
    """
    project_root = Path(__file__).resolve().parents[3]

    model_file = project_root / "src" / "models" / "agent_chat_message.py"
    source_text = model_file.read_text(encoding="utf-8")
    # '"metadata"' is matched with its quotes, i.e. as a string literal in the source.
    for required in ("input_tokens", "output_tokens", "cost", '"metadata"'):
        assert required in source_text

    migration_file = (
        project_root
        / "alembic"
        / "versions"
        / "20260305_agent_runtime_closed_loop_contract.py"
    )
    assert migration_file.exists()
@@ -0,0 +1,20 @@
from __future__ import annotations
from pytest import MonkeyPatch
from core.config.settings import Settings
def test_social_prefixed_llm_provider_keys_populates_settings(
    monkeypatch: MonkeyPatch,
) -> None:
    """SOCIAL_LLM__PROVIDER_KEYS__<NAME> variables populate settings.llm.provider_keys."""
    monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE", "dash-key")
    monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__DEEPSEEK", "deep-key")
    monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__ARK", "ark-key")

    loaded = Settings()

    # Compare case-insensitively: provider names are lower-cased before lookup.
    normalized = {
        name.lower(): secret for name, secret in loaded.llm.provider_keys.items()
    }
    expected = {"dashscope": "dash-key", "deepseek": "deep-key", "ark": "ark-key"}
    for provider, secret in expected.items():
        assert normalized[provider] == secret
@@ -0,0 +1,16 @@
from __future__ import annotations
from uuid import uuid4
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from v1.agent.service import ensure_session_owner
def test_owner_guard_denies_non_owner() -> None:
    """ensure_session_owner must raise HTTPException when the owner id differs."""
    caller = CurrentUser(id=uuid4(), email="self@example.com")
    foreign_owner_id = "other-user"
    with pytest.raises(HTTPException):
        ensure_session_owner(owner_id=foreign_owner_id, current_user=caller)
+125
View File
@@ -0,0 +1,125 @@
from __future__ import annotations

from uuid import UUID

import pytest

from core.auth.models import CurrentUser
from v1.agent.service import AgentService
class _FakeRepository:
def __init__(self) -> None:
self.committed = False
self.rolled_back = False
self.deleted_session_id: str | None = None
async def get_session_owner(self, *, session_id: str) -> str:
del session_id
return "00000000-0000-0000-0000-000000000001"
async def create_session_for_user(self, *, user_id: str) -> str:
del user_id
return "00000000-0000-0000-0000-000000000999"
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
self.rolled_back = True
async def delete_session(self, *, session_id: str) -> None:
self.deleted_session_id = session_id
class _FakeQueue:
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
del command, dedup_key
return "task-1"
class _FailingQueue:
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
del command, dedup_key
raise RuntimeError("enqueue failed")
class _FakeStream:
async def read(
self, *, session_id: str, last_event_id: str | None
) -> list[dict[str, object]]:
del session_id
return [
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
]
def _user() -> CurrentUser:
    """Build the CurrentUser whose id equals the owner id _FakeRepository returns."""
    owner_uuid = UUID("00000000-0000-0000-0000-000000000001")
    return CurrentUser(id=owner_uuid, email="user@example.com")
@pytest.mark.asyncio
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
    """Two identical resume requests must report the same task id.

    The explicit asyncio marker matches the other async tests in this suite,
    so the coroutine is awaited even when pytest-asyncio runs in strict mode.

    NOTE(review): _FakeQueue returns "task-1" on every call, so this only
    shows the queue's id is passed through; it does not exercise any
    locking or de-duplication logic. Confirm whether a stricter fake is wanted.
    """
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )
    user = _user()
    request = {
        "session_id": "session-1",
        "tool_call_id": "call-1",
        "current_user": user,
    }

    first = await service.enqueue_resume(**request)
    second = await service.enqueue_resume(**request)

    assert first.task_id == second.task_id
@pytest.mark.asyncio
async def test_enqueue_run_without_session_creates_new_session() -> None:
    """enqueue_run with session_id=None must create and commit a new session.

    The explicit asyncio marker matches the other async tests in this suite,
    so the coroutine is awaited even when pytest-asyncio runs in strict mode.
    """
    repository = _FakeRepository()
    service = AgentService(
        repository=repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    accepted = await service.enqueue_run(
        session_id=None,
        prompt="hello",
        current_user=_user(),
    )

    # The id is the fixed value _FakeRepository.create_session_for_user returns.
    assert accepted.session_id == "00000000-0000-0000-0000-000000000999"
    assert accepted.created is True
    assert repository.committed is True
@pytest.mark.asyncio
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
    """A queue failure propagates and the newly created session is not deleted.

    The explicit asyncio marker matches the other async tests in this suite,
    so the coroutine is awaited even when pytest-asyncio runs in strict mode.
    ``pytest.raises`` replaces the manual try / raise AssertionError / except
    pattern while keeping the exact-message check.
    """
    repository = _FakeRepository()
    service = AgentService(
        repository=repository,
        queue=_FailingQueue(),
        stream=_FakeStream(),
    )

    with pytest.raises(RuntimeError) as exc_info:
        await service.enqueue_run(
            session_id=None,
            prompt="hello",
            current_user=_user(),
        )

    # Exact match on the message raised by _FailingQueue.enqueue.
    assert str(exc_info.value) == "enqueue failed"
    # delete_session was never called on the repository.
    assert repository.deleted_session_id is None