refactor: 重构 Agent 模块为 AgentScope,删除旧版 CrewAI/LiteLLM 实现
This commit is contained in:
@@ -0,0 +1,284 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.events import store as store_module
|
||||
|
||||
|
||||
class _SessionStatus(str, Enum):
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class _FakeSessionCtx:
|
||||
class _Session:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
async def __aenter__(self) -> object:
|
||||
return self._Session()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None: # noqa: ANN001
|
||||
del exc_type, exc, tb
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_marks_session_running_on_run_started(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot=None)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
captured["session_id"] = session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "RUN_STARTED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
}
|
||||
)
|
||||
|
||||
assert captured["status"] == _SessionStatus.RUNNING
|
||||
assert captured["message_delta"] == 0
|
||||
assert captured["token_delta"] == 0
|
||||
assert captured["cost_delta"] == Decimal("0")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_assistant_message_and_aggregates(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={"k": "v"}, message_count=6)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"delta": "hello",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_END",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"inputTokens": 3,
|
||||
"outputTokens": 5,
|
||||
"cost": "0.123",
|
||||
"latencyMs": 250,
|
||||
}
|
||||
)
|
||||
|
||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||
assert append_kwargs["seq"] == 7
|
||||
assert append_kwargs["content"] == "hello"
|
||||
assert append_kwargs["input_tokens"] == 3
|
||||
assert append_kwargs["output_tokens"] == 5
|
||||
assert append_kwargs["cost"] == Decimal("0.123")
|
||||
assert append_kwargs["metadata"]["latency_ms"] == 250
|
||||
assert captured["message_delta"] == 1
|
||||
assert captured["token_delta"] == 8
|
||||
assert captured["cost_delta"] == Decimal("0.123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_uses_canonical_thread_id_for_buffer_keys(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=1)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
compact_thread_id = "00000000000000000000000000000001"
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"threadId": compact_thread_id,
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"delta": "hello",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_END",
|
||||
"threadId": compact_thread_id,
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
}
|
||||
)
|
||||
|
||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||
assert append_kwargs["content"] == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_clears_buffer_on_run_finished(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=0)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
thread_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"delta": "stale",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "RUN_FINISHED",
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_END",
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
}
|
||||
)
|
||||
|
||||
assert "append_kwargs" not in captured
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_drops_buffer_when_session_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
thread_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"threadId": thread_id,
|
||||
"messageId": "assistant-run-1",
|
||||
"delta": "orphan",
|
||||
}
|
||||
)
|
||||
|
||||
assert store._message_buffers == {}
|
||||
Reference in New Issue
Block a user