refactor: 重构 Agent 模块为 AgentScope,删除旧版 CrewAI/LiteLLM 实现

This commit is contained in:
qzl
2026-03-11 20:51:56 +08:00
parent 177ed616bf
commit 145e3dc615
149 changed files with 5120 additions and 11356 deletions
@@ -2,13 +2,14 @@ from core.agentscope.events.agui_codec import AgentScopeAgUiCodec, to_agui_wire_
from core.agentscope.events.pipeline import AgentScopeEventPipeline
from core.agentscope.events.redis_bus import RedisStreamBus
from core.agentscope.events.sse import to_sse_event
from core.agentscope.events.store import NullEventStore
from core.agentscope.events.store import NullEventStore, SqlAlchemyEventStore
__all__ = [
"AgentScopeAgUiCodec",
"AgentScopeEventPipeline",
"RedisStreamBus",
"NullEventStore",
"SqlAlchemyEventStore",
"to_agui_wire_event",
"to_sse_event",
]
@@ -0,0 +1,87 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
class MessageRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def append_message(
self,
*,
session_id: UUID,
seq: int,
role: AgentChatMessageRole,
content: str,
model_code: str | None = None,
metadata: dict[str, object] | None = None,
input_tokens: int = 0,
output_tokens: int = 0,
cost: Decimal = Decimal("0"),
) -> AgentChatMessage:
message = AgentChatMessage(
session_id=session_id,
seq=seq,
role=role,
content=content,
model_code=model_code,
metadata_json=metadata,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
)
self._session.add(message)
await self._session.flush()
return message
class SessionRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def get_session(self, *, session_id: UUID) -> AgentChatSession | None:
return await self._session.get(AgentChatSession, session_id)
async def lock_session_for_update(
self, *, session_id: UUID
) -> AgentChatSession | None:
stmt = (
select(AgentChatSession)
.where(AgentChatSession.id == session_id)
.with_for_update()
)
return (await self._session.execute(stmt)).scalar_one_or_none()
async def next_message_seq(self, *, session_id: UUID) -> int:
stmt = select(func.coalesce(func.max(AgentChatMessage.seq), 0)).where(
AgentChatMessage.session_id == session_id
)
current = (await self._session.execute(stmt)).scalar_one()
return int(current) + 1
async def update_runtime_state(
self,
*,
chat_session: AgentChatSession,
status: AgentChatSessionStatus,
state_snapshot: dict[str, object],
message_delta: int,
token_delta: int = 0,
cost_delta: Decimal = Decimal("0"),
) -> None:
chat_session.status = status
chat_session.state_snapshot = state_snapshot
chat_session.last_activity_at = datetime.now(timezone.utc)
chat_session.message_count += message_delta
chat_session.total_tokens += token_delta
chat_session.total_cost += cost_delta
await self._session.flush()
+204 -1
View File
@@ -1,6 +1,12 @@
from __future__ import annotations
from typing import Any, Protocol
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
class EventStore(Protocol):
@@ -10,3 +16,200 @@ class EventStore(Protocol):
class NullEventStore:
async def persist(self, event: dict[str, Any]) -> None:
del event
class SqlAlchemyEventStore:
_session_factory: Callable[[], Any]
def __init__(self, *, session_factory: Any) -> None:
self._session_factory = session_factory
self._message_buffers: dict[tuple[str, str], str] = {}
async def persist(self, event: dict[str, Any]) -> None:
event_type = str(event.get("type", "")).strip().upper()
thread_id = event.get("threadId")
if not isinstance(thread_id, str) or not thread_id:
return
try:
session_id = UUID(thread_id)
except ValueError:
return
session_key = str(session_id)
async with self._session_factory() as session:
session_repo = SessionRepository(session)
message_repo = MessageRepository(session)
chat_session = await session_repo.get_session(session_id=session_id)
if chat_session is None:
self._clear_session_buffers(session_key=session_key)
return
if event_type == "TEXT_MESSAGE_CONTENT":
self._buffer_text_delta(session_key=session_key, event=event)
return
if event_type == "RUN_STARTED":
await self._update_session_state(
session_repo=session_repo,
chat_session=chat_session,
status=AgentChatSessionStatus.RUNNING,
message_delta=0,
)
elif event_type == "RUN_ERROR":
await self._update_session_state(
session_repo=session_repo,
chat_session=chat_session,
status=AgentChatSessionStatus.FAILED,
message_delta=0,
)
self._clear_session_buffers(session_key=session_key)
elif event_type == "RUN_FINISHED":
await self._update_session_state(
session_repo=session_repo,
chat_session=chat_session,
status=AgentChatSessionStatus.COMPLETED,
message_delta=0,
)
self._clear_session_buffers(session_key=session_key)
elif event_type == "TEXT_MESSAGE_END":
await self._persist_assistant_message(
event=event,
session_id=session_id,
chat_session=chat_session,
session_repo=session_repo,
message_repo=message_repo,
)
await session.commit()
def _buffer_text_delta(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = event.get("messageId")
delta = event.get("delta")
if not isinstance(message_id, str) or not message_id:
return
if not isinstance(delta, str) or not delta:
return
key = (session_key, message_id)
current = self._message_buffers.get(key, "")
self._message_buffers[key] = f"{current}{delta}"
def _clear_session_buffers(self, *, session_key: str) -> None:
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
for key in stale_keys:
self._message_buffers.pop(key, None)
async def _persist_assistant_message(
self,
*,
event: dict[str, Any],
session_id: UUID,
chat_session: Any,
session_repo: SessionRepository,
message_repo: MessageRepository,
) -> None:
message_id_raw = event.get("messageId")
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
key = (str(session_id), message_id)
content = self._message_buffers.get(key, "")
if not content:
return
input_tokens = self._to_int(event.get("inputTokens"))
output_tokens = self._to_int(event.get("outputTokens"))
token_delta = input_tokens + output_tokens
cost = self._to_decimal(event.get("cost"))
latency_ms = self._to_int_or_none(event.get("latencyMs"))
run_id = event.get("runId")
model_code = event.get("model")
metadata: dict[str, object] = {"message_id": message_id}
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
if latency_ms is not None:
metadata["latency_ms"] = latency_ms
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
)
if locked_session is None:
return
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
await message_repo.append_message(
session_id=session_id,
seq=seq,
role=AgentChatMessageRole.ASSISTANT,
content=content,
model_code=model_code if isinstance(model_code, str) else None,
metadata=metadata,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
status = (
current_status
if isinstance(current_status, AgentChatSessionStatus)
else AgentChatSessionStatus.RUNNING
)
await self._update_session_state(
session_repo=session_repo,
chat_session=chat_session,
status=status,
message_delta=1,
token_delta=token_delta,
cost_delta=cost,
)
self._message_buffers.pop(key, None)
async def _update_session_state(
self,
*,
session_repo: SessionRepository,
chat_session: Any,
status: AgentChatSessionStatus,
message_delta: int,
token_delta: int = 0,
cost_delta: Decimal = Decimal("0"),
) -> None:
snapshot = (
chat_session.state_snapshot
if isinstance(chat_session.state_snapshot, dict)
else {}
)
await session_repo.update_runtime_state(
chat_session=chat_session,
status=status,
state_snapshot=snapshot,
message_delta=message_delta,
token_delta=token_delta,
cost_delta=cost_delta,
)
def _to_int(self, value: object) -> int:
if isinstance(value, bool):
return 0
if not isinstance(value, (int, float, str)):
return 0
try:
return max(int(value), 0)
except (TypeError, ValueError):
return 0
def _to_int_or_none(self, value: object) -> int | None:
if isinstance(value, bool):
return None
if not isinstance(value, (int, float, str)):
return None
try:
parsed = int(value)
except (TypeError, ValueError):
return None
return parsed if parsed >= 0 else None
def _to_decimal(self, value: object) -> Decimal:
try:
parsed = Decimal(str(value))
except (InvalidOperation, TypeError, ValueError):
return Decimal("0")
return parsed if parsed >= 0 else Decimal("0")