refactor: 重构 Agent 模块为 AgentScope,删除旧版 CrewAI/LiteLLM 实现
This commit is contained in:
@@ -2,13 +2,14 @@ from core.agentscope.events.agui_codec import AgentScopeAgUiCodec, to_agui_wire_
|
||||
from core.agentscope.events.pipeline import AgentScopeEventPipeline
|
||||
from core.agentscope.events.redis_bus import RedisStreamBus
|
||||
from core.agentscope.events.sse import to_sse_event
|
||||
from core.agentscope.events.store import NullEventStore
|
||||
from core.agentscope.events.store import NullEventStore, SqlAlchemyEventStore
|
||||
|
||||
__all__ = [
|
||||
"AgentScopeAgUiCodec",
|
||||
"AgentScopeEventPipeline",
|
||||
"RedisStreamBus",
|
||||
"NullEventStore",
|
||||
"SqlAlchemyEventStore",
|
||||
"to_agui_wire_event",
|
||||
"to_sse_event",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
|
||||
|
||||
class MessageRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def append_message(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
seq: int,
|
||||
role: AgentChatMessageRole,
|
||||
content: str,
|
||||
model_code: str | None = None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cost: Decimal = Decimal("0"),
|
||||
) -> AgentChatMessage:
|
||||
message = AgentChatMessage(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=role,
|
||||
content=content,
|
||||
model_code=model_code,
|
||||
metadata_json=metadata,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
)
|
||||
self._session.add(message)
|
||||
await self._session.flush()
|
||||
return message
|
||||
|
||||
|
||||
class SessionRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_session(self, *, session_id: UUID) -> AgentChatSession | None:
|
||||
return await self._session.get(AgentChatSession, session_id)
|
||||
|
||||
async def lock_session_for_update(
|
||||
self, *, session_id: UUID
|
||||
) -> AgentChatSession | None:
|
||||
stmt = (
|
||||
select(AgentChatSession)
|
||||
.where(AgentChatSession.id == session_id)
|
||||
.with_for_update()
|
||||
)
|
||||
return (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
async def next_message_seq(self, *, session_id: UUID) -> int:
|
||||
stmt = select(func.coalesce(func.max(AgentChatMessage.seq), 0)).where(
|
||||
AgentChatMessage.session_id == session_id
|
||||
)
|
||||
current = (await self._session.execute(stmt)).scalar_one()
|
||||
return int(current) + 1
|
||||
|
||||
async def update_runtime_state(
|
||||
self,
|
||||
*,
|
||||
chat_session: AgentChatSession,
|
||||
status: AgentChatSessionStatus,
|
||||
state_snapshot: dict[str, object],
|
||||
message_delta: int,
|
||||
token_delta: int = 0,
|
||||
cost_delta: Decimal = Decimal("0"),
|
||||
) -> None:
|
||||
chat_session.status = status
|
||||
chat_session.state_snapshot = state_snapshot
|
||||
chat_session.last_activity_at = datetime.now(timezone.utc)
|
||||
chat_session.message_count += message_delta
|
||||
chat_session.total_tokens += token_delta
|
||||
chat_session.total_cost += cost_delta
|
||||
await self._session.flush()
|
||||
@@ -1,6 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
|
||||
|
||||
class EventStore(Protocol):
|
||||
@@ -10,3 +16,200 @@ class EventStore(Protocol):
|
||||
class NullEventStore:
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
del event
|
||||
|
||||
|
||||
class SqlAlchemyEventStore:
|
||||
_session_factory: Callable[[], Any]
|
||||
|
||||
def __init__(self, *, session_factory: Any) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._message_buffers: dict[tuple[str, str], str] = {}
|
||||
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
event_type = str(event.get("type", "")).strip().upper()
|
||||
thread_id = event.get("threadId")
|
||||
if not isinstance(thread_id, str) or not thread_id:
|
||||
return
|
||||
try:
|
||||
session_id = UUID(thread_id)
|
||||
except ValueError:
|
||||
return
|
||||
session_key = str(session_id)
|
||||
|
||||
async with self._session_factory() as session:
|
||||
session_repo = SessionRepository(session)
|
||||
message_repo = MessageRepository(session)
|
||||
chat_session = await session_repo.get_session(session_id=session_id)
|
||||
if chat_session is None:
|
||||
self._clear_session_buffers(session_key=session_key)
|
||||
return
|
||||
|
||||
if event_type == "TEXT_MESSAGE_CONTENT":
|
||||
self._buffer_text_delta(session_key=session_key, event=event)
|
||||
return
|
||||
|
||||
if event_type == "RUN_STARTED":
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
chat_session=chat_session,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_delta=0,
|
||||
)
|
||||
elif event_type == "RUN_ERROR":
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
chat_session=chat_session,
|
||||
status=AgentChatSessionStatus.FAILED,
|
||||
message_delta=0,
|
||||
)
|
||||
self._clear_session_buffers(session_key=session_key)
|
||||
elif event_type == "RUN_FINISHED":
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
chat_session=chat_session,
|
||||
status=AgentChatSessionStatus.COMPLETED,
|
||||
message_delta=0,
|
||||
)
|
||||
self._clear_session_buffers(session_key=session_key)
|
||||
elif event_type == "TEXT_MESSAGE_END":
|
||||
await self._persist_assistant_message(
|
||||
event=event,
|
||||
session_id=session_id,
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
def _buffer_text_delta(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = event.get("messageId")
|
||||
delta = event.get("delta")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
if not isinstance(delta, str) or not delta:
|
||||
return
|
||||
key = (session_key, message_id)
|
||||
current = self._message_buffers.get(key, "")
|
||||
self._message_buffers[key] = f"{current}{delta}"
|
||||
|
||||
def _clear_session_buffers(self, *, session_key: str) -> None:
|
||||
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
|
||||
for key in stale_keys:
|
||||
self._message_buffers.pop(key, None)
|
||||
|
||||
async def _persist_assistant_message(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
session_id: UUID,
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
message_id_raw = event.get("messageId")
|
||||
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
|
||||
key = (str(session_id), message_id)
|
||||
content = self._message_buffers.get(key, "")
|
||||
if not content:
|
||||
return
|
||||
|
||||
input_tokens = self._to_int(event.get("inputTokens"))
|
||||
output_tokens = self._to_int(event.get("outputTokens"))
|
||||
token_delta = input_tokens + output_tokens
|
||||
cost = self._to_decimal(event.get("cost"))
|
||||
latency_ms = self._to_int_or_none(event.get("latencyMs"))
|
||||
run_id = event.get("runId")
|
||||
model_code = event.get("model")
|
||||
|
||||
metadata: dict[str, object] = {"message_id": message_id}
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
if latency_ms is not None:
|
||||
metadata["latency_ms"] = latency_ms
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
)
|
||||
if locked_session is None:
|
||||
return
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=content,
|
||||
model_code=model_code if isinstance(model_code, str) else None,
|
||||
metadata=metadata,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
status = (
|
||||
current_status
|
||||
if isinstance(current_status, AgentChatSessionStatus)
|
||||
else AgentChatSessionStatus.RUNNING
|
||||
)
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
chat_session=chat_session,
|
||||
status=status,
|
||||
message_delta=1,
|
||||
token_delta=token_delta,
|
||||
cost_delta=cost,
|
||||
)
|
||||
self._message_buffers.pop(key, None)
|
||||
|
||||
async def _update_session_state(
|
||||
self,
|
||||
*,
|
||||
session_repo: SessionRepository,
|
||||
chat_session: Any,
|
||||
status: AgentChatSessionStatus,
|
||||
message_delta: int,
|
||||
token_delta: int = 0,
|
||||
cost_delta: Decimal = Decimal("0"),
|
||||
) -> None:
|
||||
snapshot = (
|
||||
chat_session.state_snapshot
|
||||
if isinstance(chat_session.state_snapshot, dict)
|
||||
else {}
|
||||
)
|
||||
await session_repo.update_runtime_state(
|
||||
chat_session=chat_session,
|
||||
status=status,
|
||||
state_snapshot=snapshot,
|
||||
message_delta=message_delta,
|
||||
token_delta=token_delta,
|
||||
cost_delta=cost_delta,
|
||||
)
|
||||
|
||||
def _to_int(self, value: object) -> int:
|
||||
if isinstance(value, bool):
|
||||
return 0
|
||||
if not isinstance(value, (int, float, str)):
|
||||
return 0
|
||||
try:
|
||||
return max(int(value), 0)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
def _to_int_or_none(self, value: object) -> int | None:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if not isinstance(value, (int, float, str)):
|
||||
return None
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return parsed if parsed >= 0 else None
|
||||
|
||||
def _to_decimal(self, value: object) -> Decimal:
|
||||
try:
|
||||
parsed = Decimal(str(value))
|
||||
except (InvalidOperation, TypeError, ValueError):
|
||||
return Decimal("0")
|
||||
return parsed if parsed >= 0 else Decimal("0")
|
||||
|
||||
Reference in New Issue
Block a user