feat(agent): add redis short-term user context cache and align tests
This commit is contained in:
@@ -155,6 +155,98 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
await cleanup_session.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_embeds_profile_settings_in_runtime_system_prompt(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
session_uuid = uuid.uuid4()
|
||||
agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
|
||||
original_profile: Profile | None = None
|
||||
|
||||
def _fake_execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
captured["user_input"] = user_input
|
||||
captured["system_prompt"] = system_prompt
|
||||
return {
|
||||
"assistant_text": "Mocked answer",
|
||||
"prompt_tokens": 11,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 18,
|
||||
"cost": 0.0025,
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
|
||||
_fake_execute,
|
||||
)
|
||||
|
||||
await engine.dispose()
|
||||
async with AsyncSessionLocal() as lookup_session:
|
||||
owner_row = await lookup_session.execute(select(Profile.id).limit(1))
|
||||
owner_id = owner_row.scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
pytest.skip("No profile owner available in local database")
|
||||
original_profile = await lookup_session.get(Profile, owner_id)
|
||||
llm_row = await lookup_session.execute(
|
||||
select(Llm.id, LlmFactory.name)
|
||||
.join(LlmFactory, LlmFactory.id == Llm.factory_id)
|
||||
.where(LlmFactory.name.in_(("dashscope", "deepseek", "moonshot")))
|
||||
.limit(1)
|
||||
)
|
||||
llm_record = llm_row.one_or_none()
|
||||
if llm_record is None:
|
||||
pytest.skip("No supported llm provider available in local database")
|
||||
llm_id = llm_record[0]
|
||||
|
||||
try:
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
seed_session.add(
|
||||
SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
|
||||
)
|
||||
profile = await seed_session.get(Profile, owner_id)
|
||||
assert profile is not None
|
||||
profile.username = "demo-user"
|
||||
profile.bio = "hello\nworld"
|
||||
profile.settings = {
|
||||
"preferences": {
|
||||
"interface_language": "zh-CN",
|
||||
"ai_language": "en-US",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"country": "CN",
|
||||
}
|
||||
}
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.commit()
|
||||
|
||||
result = await RunService().run(session_id=str(session_uuid), user_input="hello")
|
||||
|
||||
assert result["persisted"] is True
|
||||
assert captured["user_input"] == "hello"
|
||||
system_prompt = captured["system_prompt"]
|
||||
assert isinstance(system_prompt, str)
|
||||
assert "# USER_PROFILE (JSON)" in system_prompt
|
||||
assert '"ai_language":"en-US"' in system_prompt
|
||||
assert '"timezone":"Asia/Shanghai"' in system_prompt
|
||||
assert '"country":"CN"' in system_prompt
|
||||
finally:
|
||||
await engine.dispose()
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
if original_profile is not None:
|
||||
profile = await cleanup_session.get(Profile, owner_id)
|
||||
if profile is not None:
|
||||
profile.username = original_profile.username
|
||||
profile.bio = original_profile.bio
|
||||
profile.settings = original_profile.settings
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_session_cascades_to_messages() -> None:
|
||||
session_uuid = uuid.uuid4()
|
||||
|
||||
@@ -44,11 +44,6 @@ class FakeUserService:
|
||||
bio=update.bio if update.bio is not None else self._user.bio,
|
||||
)
|
||||
|
||||
async def get_by_username(self, username: str) -> UserResponse:
|
||||
if username != self._user.username:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return self._user
|
||||
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
|
||||
if request.query:
|
||||
return self._search_results if self._search_results else [self._user]
|
||||
@@ -191,22 +186,3 @@ def test_search_users_empty_query_returns_422() -> None:
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_get_user_by_username_returns_404() -> None:
|
||||
user = UserResponse(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_user_service] = _override_user_service(
|
||||
FakeUserService(user)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/users/demo")
|
||||
assert response.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.config import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from models.profile import Profile
|
||||
|
||||
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
|
||||
|
||||
|
||||
async def _owner_id() -> UUID:
|
||||
async with AsyncSessionLocal() as session:
|
||||
owner_id = (await session.execute(select(Profile.id).limit(1))).scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
pytest.skip("profile owner not found")
|
||||
return owner_id
|
||||
|
||||
|
||||
def _jwt_for(user_id: UUID) -> str:
|
||||
secret = config.supabase.jwt_secret
|
||||
if not secret:
|
||||
pytest.skip("JWT secret not configured")
|
||||
issuer = f"{config.supabase.public_url.rstrip('/')}/auth/v1"
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"role": "authenticated",
|
||||
"aud": "authenticated",
|
||||
"iss": issuer,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
|
||||
}
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.live
|
||||
async def test_agent_sse_closed_loop_live() -> None:
|
||||
if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
|
||||
pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
|
||||
|
||||
owner_id = await _owner_id()
|
||||
token = _jwt_for(owner_id)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
run_resp = await client.post(
|
||||
f"{BASE_URL}/api/v1/agent/runs",
|
||||
headers=headers,
|
||||
json={"prompt": "请用一句话介绍你自己"},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
accepted = run_resp.json()
|
||||
session_id = str(accepted["session_id"])
|
||||
assert session_id
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
|
||||
event_names: list[str] = []
|
||||
async with client.stream("GET", events_url, headers=headers, timeout=20.0) as sse_resp:
|
||||
assert sse_resp.status_code == 200
|
||||
assert sse_resp.headers.get("content-type", "").startswith("text/event-stream")
|
||||
async for line in sse_resp.aiter_lines():
|
||||
if line.startswith("event:"):
|
||||
event_names.append(line.split(":", 1)[1].strip())
|
||||
|
||||
assert "RUN_STARTED" in event_names
|
||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
session_row = await session.get(AgentChatSession, UUID(session_id))
|
||||
assert session_row is not None
|
||||
assert session_row.message_count >= 1
|
||||
assert session_row.total_tokens >= 0
|
||||
assert session_row.total_cost >= 0
|
||||
|
||||
rows = await session.execute(
|
||||
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(session_id))
|
||||
)
|
||||
assert len(list(rows.scalars().all())) >= 1
|
||||
Reference in New Issue
Block a user